Passed
Pull Request — master (#17)
by Frank
02:58
created

cli.columnInfo.hasTrue   A

Complexity

Conditions 4

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 2
nop 0
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
package cli
2
3
import (
4
	"fmt"
5
	"strings"
6
7
	"github.com/fraenky8/tables-to-go/pkg/config"
8
	"github.com/fraenky8/tables-to-go/pkg/database"
9
	"github.com/fraenky8/tables-to-go/pkg/output"
10
	"github.com/fraenky8/tables-to-go/pkg/tagger"
11
)
12
13
var (
14
	taggers tagger.Tagger
15
16
	// some strings for idiomatic go in column names
17
	// see https://github.com/golang/go/wiki/CodeReviewComments#initialisms
18
	initialisms = []string{"ID", "JSON", "XML", "HTTP", "URL"}
19
)
20
21
// Run generates one Go struct per table of db and hands each one to out.
//
// It initialises the package-level taggers from settings, fetches all tables,
// prepares the column statement, then loads the columns of each table, renders
// its struct via createTableStructString and writes it with out.Write.
// Progress is printed to stdout; settings.Verbose adds table and column counts.
// The first error aborts the run and is returned with context.
func Run(settings *config.Settings, db database.Database, out output.Writer) (err error) {
	taggers = tagger.NewTaggers(settings)

	fmt.Printf("running for %q...\r\n", settings.DbType)

	// verbosef prints only when verbose output was requested.
	verbosef := func(format string, args ...interface{}) {
		if settings.Verbose {
			fmt.Printf(format, args...)
		}
	}

	tables, err := db.GetTables()
	if err != nil {
		return fmt.Errorf("could not get tables: %v", err)
	}
	verbosef("> number of tables: %v\r\n", len(tables))

	if err = db.PrepareGetColumnsOfTableStmt(); err != nil {
		return fmt.Errorf("could not prepare the get-column-statement: %v", err)
	}

	for i := range tables {
		table := tables[i]
		verbosef("> processing table %q\r\n", table.Name)

		if err = db.GetColumnsOfTable(table); err != nil {
			return fmt.Errorf("could not get columns of table %s: %v", table.Name, err)
		}
		verbosef("\t> number of columns: %v\r\n", len(table.Columns))

		tableName, content := createTableStructString(settings, db, table)
		if err = out.Write(tableName, content); err != nil {
			return fmt.Errorf("could not write struct for table %s: %v", table.Name, err)
		}
	}

	fmt.Println("done!")
	return nil
}
67
68
// columnInfo records which kinds of Go types a column (or, OR-ed together,
// a whole table) was mapped to. generateImports uses it to decide which
// packages the generated file must import.
type columnInfo struct {
	// isNullable: the column was mapped to a nullable Go type.
	isNullable bool
	// isTemporal: the generated type refers to the time package.
	isTemporal bool
	// isNullablePrimitive: the column is nullable and not temporal.
	isNullablePrimitive bool
	// isNullableTemporal: the column is nullable and temporal.
	isNullableTemporal bool
}

// hasTrue reports whether at least one of the flags is set.
func (c columnInfo) hasTrue() bool {
	switch {
	case c.isNullable, c.isTemporal, c.isNullablePrimitive, c.isNullableTemporal:
		return true
	default:
		return false
	}
}
78
79
// createTableStructString renders the Go source for one table and returns the
// struct's type name together with the complete file content.
//
// The content is a package clause (settings.PackageName), the import block
// produced by generateImports and a struct with one field per distinct column.
// Field names are Title-cased, or camel-cased when that output format is
// configured, and have initialisms upper-cased when enabled. Each field line is
// "<name> <go type> <tag>", the tag coming from the package-level taggers.
// When the Masterminds structable recorder is enabled, a structable.Recorder
// field is embedded at the end of the struct.
func createTableStructString(settings *config.Settings, db database.Database, table *database.Table) (string, string) {
	var (
		fields strings.Builder
		// info accumulates, over all emitted columns, which kinds of Go types
		// were used; generateImports derives the imports from it.
		info columnInfo
		seen = map[string]struct{}{}
	)

	for _, column := range table.Columns {
		var fieldName string
		if settings.IsOutputFormatCamelCase() {
			fieldName = camelCaseString(column.Name)
		} else {
			fieldName = strings.Title(column.Name)
		}
		if settings.ShouldInitialism() {
			fieldName = toInitialisms(fieldName)
		}

		// ISSUE-4: a column that is part of several constraints is returned
		// once per constraint, so a field name that was already written is
		// skipped.
		if _, dup := seen[fieldName]; dup {
			continue
		}
		seen[fieldName] = struct{}{}

		if settings.VVerbose {
			fmt.Printf("\t\t> %v\r\n", column.Name)
		}

		goType, colInfo := mapDbColumnTypeToGoType(settings, db, column)

		// Keep each flag once any column has set it.
		info.isTemporal = info.isTemporal || colInfo.isTemporal
		info.isNullableTemporal = info.isNullableTemporal || colInfo.isNullableTemporal
		info.isNullablePrimitive = info.isNullablePrimitive || colInfo.isNullablePrimitive

		fields.WriteString(fieldName + " " + goType + " " + taggers.GenerateTag(db, column) + "\n")
	}

	if settings.IsMastermindStructableRecorder {
		fields.WriteString("\t\nstructable.Recorder\n")
	}

	var file strings.Builder
	file.WriteString("package " + settings.PackageName + "\n\n")
	generateImports(&file, settings, db, info)

	tableName := strings.Title(settings.Prefix + table.Name + settings.Suffix)
	if settings.IsOutputFormatCamelCase() {
		tableName = camelCaseString(tableName)
	}

	file.WriteString("type " + tableName + " struct {\n" + fields.String() + "}")

	return tableName, file.String()
}
158
159
func generateImports(content *strings.Builder, settings *config.Settings, db database.Database, columnInfo columnInfo) {
160
161
	if !columnInfo.hasTrue() && !settings.IsMastermindStructableRecorder {
162
		return
163
	}
164
165
	content.WriteString("import (\n")
166
167
	if columnInfo.isNullablePrimitive && settings.IsNullTypeSQL() {
168
		content.WriteString("\t\"database/sql\"\n")
169
	}
170
171
	if columnInfo.isTemporal {
172
		content.WriteString("\t\"time\"\n")
173
	}
174
175
	if columnInfo.isNullableTemporal && settings.IsNullTypeSQL() {
176
		content.WriteString("\t\n")
177
		content.WriteString(db.GetDriverImportLibrary())
178
		content.WriteString("\n")
179
	}
180
181
	if settings.IsMastermindStructableRecorder {
182
		content.WriteString("\t\n\"github.com/Masterminds/structable\"\n")
183
	}
184
185
	content.WriteString(")\n\n")
186
}
187
188
func mapDbColumnTypeToGoType(settings *config.Settings, db database.Database, column database.Column) (goType string, columnInfo columnInfo) {
189
	if db.IsString(column) || db.IsText(column) {
190
		goType = "string"
191
		if db.IsNullable(column) {
192
			goType = getNullType(settings, "*string", "sql.NullString")
193
			columnInfo.isNullable = true
194
		}
195
	} else if db.IsInteger(column) {
196
		goType = "int"
197
		if db.IsNullable(column) {
198
			goType = getNullType(settings, "*int", "sql.NullInt64")
199
			columnInfo.isNullable = true
200
		}
201
	} else if db.IsFloat(column) {
202
		goType = "float64"
203
		if db.IsNullable(column) {
204
			goType = getNullType(settings, "*float64", "sql.NullFloat64")
205
			columnInfo.isNullable = true
206
		}
207
	} else if db.IsTemporal(column) {
208
		if !db.IsNullable(column) {
209
			goType = "time.Time"
210
			columnInfo.isTemporal = true
211
		} else {
212
			goType = getNullType(settings, "*time.Time", db.GetTemporalDriverDataType())
213
			columnInfo.isTemporal = settings.Null == config.NullTypeNative
214
			columnInfo.isNullableTemporal = true
215
			columnInfo.isNullable = true
216
		}
217
	} else {
218
		// TODO handle special data types
219
		switch column.DataType {
220
		case "boolean":
221
			goType = "bool"
222
			if db.IsNullable(column) {
223
				goType = getNullType(settings, "*bool", "sql.NullBool")
224
				columnInfo.isNullable = true
225
			}
226
		default:
227
			goType = getNullType(settings, "*string", "sql.NullString")
228
		}
229
	}
230
231
	columnInfo.isNullablePrimitive = columnInfo.isNullable && !db.IsTemporal(column)
232
233
	return goType, columnInfo
234
}
235
236
// getNullType chooses the Go type for a nullable column: sql when the settings
// ask for database/sql null types, otherwise primitive (callers in this file
// pass the pointer type, e.g. "*string").
func getNullType(settings *config.Settings, primitive string, sql string) string {
	if !settings.IsNullTypeSQL() {
		return primitive
	}
	return sql
}
242
243
func camelCaseString(s string) (cc string) {
244
	splitted := strings.Split(s, "_")
245
246
	if len(splitted) == 1 {
247
		return strings.Title(s)
248
	}
249
250
	for _, part := range splitted {
251
		cc += strings.Title(strings.ToLower(part))
252
	}
253
	return cc
254
}
255
256
// toInitialisms upper-cases the known initialisms (see initialisms) inside s.
//
// For each initialism, in list order, the first case-insensitive occurrence
// determines the exact spelling that is replaced; every occurrence with that
// same spelling is rewritten in upper case, other spellings are left as-is.
// It expects indexCaseInsensitive to return a byte offset into its first
// argument.
// NOTE(review): matching is by substring, not by word boundary, so e.g.
// "Valid" becomes "ValID".
func toInitialisms(s string) string {
	result := s
	for _, initialism := range initialisms {
		pos := indexCaseInsensitive(result, initialism)
		if pos < 0 {
			continue
		}
		spelling := result[pos : pos+len(initialism)]
		result = strings.ReplaceAll(result, spelling, initialism)
	}
	return result
}
267
268
// indexCaseInsensitive returns the byte offset in s of the first occurrence of
// substr under Unicode case folding, or -1 if there is none. An empty substr
// matches at offset 0.
//
// The returned offset i always satisfies that s[i:i+len(substr)] is in range
// and EqualFold-matches substr. The previous implementation searched in
// strings.ToLower(s). Lowering can change byte lengths (e.g. "İ" is 2 bytes,
// lowered "i" is 1), so that offset did not line up with s, and callers
// slicing s could cut through a rune or run out of range. Matches whose byte
// length differs from substr are not found; this does not matter for the
// ASCII initialisms used in this file.
func indexCaseInsensitive(s, substr string) int {
	for i := 0; i+len(substr) <= len(s); i++ {
		if strings.EqualFold(s[i:i+len(substr)], substr) {
			return i
		}
	}
	return -1
}
272