internal/cli/tables-to-go-cli.go   F
last analyzed

Size/Duplication

Total Lines 355
Duplicated Lines 0 %

Test Coverage

Coverage 85.21%

Importance

Changes 0
Metric Value
cc 77
eloc 202
dl 0
loc 355
ccs 144
cts 169
cp 0.8521
crap 96.1816
rs 2.24
c 0
b 0
f 0

12 Methods

Rating   Name   Duplication   Size   Complexity  
A cli.columnInfo.hasTrue 0 2 4
C cli.generateImports 0 27 9
A cli.replaceSpace 0 5 3
A cli.indexCaseInsensitive 0 3 1
F cli.mapDbColumnTypeToGoType 0 46 14
C cli.createTableStructString 0 80 11
B cli.formatColumnName 0 31 7
A cli.toInitialisms 0 10 3
A cli.validVariableName 0 7 5
F cli.Run 0 64 14
A cli.getNullType 0 5 2
A cli.camelCaseString 0 16 4
1
package cli
2
3
import (
4
	"fmt"
5
	"strings"
6
	"unicode"
7
8
	"github.com/iancoleman/strcase"
9
10
	"github.com/fraenky8/tables-to-go/pkg/database"
11
	"github.com/fraenky8/tables-to-go/pkg/output"
12
	"github.com/fraenky8/tables-to-go/pkg/settings"
13
	"github.com/fraenky8/tables-to-go/pkg/tagger"
14
)
15
16
var (
17
	taggers tagger.Tagger
18
19
	// some strings for idiomatic go in column names
20
	// see https://github.com/golang/go/wiki/CodeReviewComments#initialisms
21
	initialisms = []string{"ID", "JSON", "XML", "HTTP", "URL"}
22
)
23
24
// Run runs the transformations by creating the concrete Database by the provided settings
25
func Run(settings *settings.Settings, db database.Database, out output.Writer) (err error) {
26
27 1
	taggers = tagger.NewTaggers(settings)
28
29 1
	fmt.Printf("running for %q...\r\n", settings.DbType)
30
31 1
	tables, err := db.GetTables()
32 1
	if err != nil {
33
		return fmt.Errorf("could not get tables: %v", err)
34
	}
35
36 1
	if settings.Verbose {
37
		fmt.Printf("> number of tables: %v\r\n", len(tables))
38
	}
39
40 1
	if err = db.PrepareGetColumnsOfTableStmt(); err != nil {
41
		return fmt.Errorf("could not prepare the get-column-statement: %v", err)
42
	}
43
44 1
	for _, table := range tables {
45
46 1
		if settings.Verbose {
47
			fmt.Printf("> processing table %q\r\n", table.Name)
48
		}
49
50 1
		if err = db.GetColumnsOfTable(table); err != nil {
51
			if !settings.Force {
52
				return fmt.Errorf("could not get columns of table %q: %v", table.Name, err)
53
			}
54
			fmt.Printf("could not get columns of table %q: %v\n", table.Name, err)
55
			continue
56
		}
57
58 1
		if settings.Verbose {
59
			fmt.Printf("\t> number of columns: %v\r\n", len(table.Columns))
60
		}
61
62 1
		tableName, content, err := createTableStructString(settings, db, table)
63
64 1
		if err != nil {
65
			if !settings.Force {
66
				return fmt.Errorf("could not create string for table %q: %v", table.Name, err)
67
			}
68
			fmt.Printf("could not create string for table %q: %v\n", table.Name, err)
69
			continue
70
		}
71
72 1
		fileName := camelCaseString(tableName)
73 1
		if settings.IsFileNameFormatSnakeCase() {
74
			fileName = strcase.ToSnake(fileName)
75
		}
76
77 1
		err = out.Write(fileName, content)
78 1
		if err != nil {
79
			if !settings.Force {
80
				return fmt.Errorf("could not write struct for table %q: %v", table.Name, err)
81
			}
82
			fmt.Printf("could not write struct for table %q: %v\n", table.Name, err)
83
		}
84
	}
85
86 1
	fmt.Println("done!")
87
88 1
	return nil
89
}
90
91
// columnInfo holds the type traits detected while mapping database columns to
// Go types. The import generation reads these flags to decide which packages
// the generated file needs.
type columnInfo struct {
	// isNullable is set when a column was mapped to a null-capable Go type.
	isNullable bool
	// isTemporal is set when the mapped type refers to the time package.
	isTemporal bool
	// isNullablePrimitive is set for nullable columns that are not temporal.
	isNullablePrimitive bool
	// isNullableTemporal is set for nullable temporal columns.
	isNullableTemporal bool
}

// hasTrue reports whether at least one of the flags is set.
func (c columnInfo) hasTrue() bool {
	// All fields are booleans, so "any flag set" is the same as "differs
	// from the zero value".
	return c != (columnInfo{})
}
101
102
func createTableStructString(settings *settings.Settings, db database.Database, table *database.Table) (string, string, error) {
103
104 1
	var structFields strings.Builder
105 1
	tableName := strings.Title(settings.Prefix + table.Name + settings.Suffix)
106
	// Replace any whitespace with underscores
107 1
	tableName = strings.Map(replaceSpace, tableName)
108 1
	if settings.IsOutputFormatCamelCase() {
109 1
		tableName = camelCaseString(tableName)
110
	}
111
112
	// Check that the table name doesn't contain any invalid characters for Go variables
113 1
	if !validVariableName(tableName) {
114
		return "", "", fmt.Errorf("table name %q contains invalid characters", table.Name)
115
	}
116
117 1
	columnInfo := columnInfo{}
118 1
	columns := map[string]struct{}{}
119
120 1
	for _, column := range table.Columns {
121 1
		columnName, err := formatColumnName(settings, column.Name, table.Name)
122 1
		if err != nil {
123
			return "", "", err
124
		}
125
126
		// ISSUE-4: if columns are part of multiple constraints
127
		// then the sql returns multiple rows per column name.
128
		// Therefore we check if we already added a column with
129
		// that name to the struct, if so, skip.
130 1
		if _, ok := columns[columnName]; ok {
131
			continue
132
		}
133 1
		columns[columnName] = struct{}{}
134
135 1
		if settings.VVerbose {
136
			fmt.Printf("\t\t> %v\r\n", column.Name)
137
		}
138
139 1
		columnType, col := mapDbColumnTypeToGoType(settings, db, column)
140
141
		// save that we saw types of columns at least once
142 1
		if !columnInfo.isTemporal {
143 1
			columnInfo.isTemporal = col.isTemporal
144
		}
145 1
		if !columnInfo.isNullableTemporal {
146 1
			columnInfo.isNullableTemporal = col.isNullableTemporal
147
		}
148 1
		if !columnInfo.isNullablePrimitive {
149 1
			columnInfo.isNullablePrimitive = col.isNullablePrimitive
150
		}
151
152 1
		structFields.WriteString(columnName)
153 1
		structFields.WriteString(" ")
154 1
		structFields.WriteString(columnType)
155 1
		structFields.WriteString(" ")
156 1
		structFields.WriteString(taggers.GenerateTag(db, column))
157 1
		structFields.WriteString("\n")
158
	}
159
160 1
	if settings.IsMastermindStructableRecorder {
161
		structFields.WriteString("\t\nstructable.Recorder\n")
162
	}
163
164 1
	var fileContent strings.Builder
165
166
	// write header infos
167 1
	fileContent.WriteString("package ")
168 1
	fileContent.WriteString(settings.PackageName)
169 1
	fileContent.WriteString("\n\n")
170
171
	// write imports
172 1
	generateImports(&fileContent, settings, db, columnInfo)
173
174
	// write struct with fields
175 1
	fileContent.WriteString("type ")
176 1
	fileContent.WriteString(tableName)
177 1
	fileContent.WriteString(" struct {\n")
178 1
	fileContent.WriteString(structFields.String())
179 1
	fileContent.WriteString("}")
180
181 1
	return tableName, fileContent.String(), nil
182
}
183
184
func generateImports(content *strings.Builder, settings *settings.Settings, db database.Database, columnInfo columnInfo) {
185
186 1
	if !columnInfo.hasTrue() && !settings.IsMastermindStructableRecorder {
187 1
		return
188
	}
189
190 1
	content.WriteString("import (\n")
191
192 1
	if columnInfo.isNullablePrimitive && settings.IsNullTypeSQL() {
193 1
		content.WriteString("\t\"database/sql\"\n")
194
	}
195
196 1
	if columnInfo.isTemporal {
197 1
		content.WriteString("\t\"time\"\n")
198
	}
199
200 1
	if columnInfo.isNullableTemporal && settings.IsNullTypeSQL() {
201 1
		content.WriteString("\t\n")
202 1
		content.WriteString(db.GetDriverImportLibrary())
203 1
		content.WriteString("\n")
204
	}
205
206 1
	if settings.IsMastermindStructableRecorder {
207
		content.WriteString("\t\n\"github.com/Masterminds/structable\"\n")
208
	}
209
210 1
	content.WriteString(")\n\n")
211
}
212
213
func mapDbColumnTypeToGoType(s *settings.Settings, db database.Database, column database.Column) (goType string, columnInfo columnInfo) {
214 1
	if db.IsString(column) || db.IsText(column) {
215 1
		goType = "string"
216 1
		if db.IsNullable(column) {
217 1
			goType = getNullType(s, "*string", "sql.NullString")
218 1
			columnInfo.isNullable = true
219
		}
220 1
	} else if db.IsInteger(column) {
221 1
		goType = "int"
222 1
		if db.IsNullable(column) {
223 1
			goType = getNullType(s, "*int", "sql.NullInt64")
224 1
			columnInfo.isNullable = true
225
		}
226 1
	} else if db.IsFloat(column) {
227 1
		goType = "float64"
228 1
		if db.IsNullable(column) {
229 1
			goType = getNullType(s, "*float64", "sql.NullFloat64")
230 1
			columnInfo.isNullable = true
231
		}
232 1
	} else if db.IsTemporal(column) {
233 1
		if !db.IsNullable(column) {
234 1
			goType = "time.Time"
235 1
			columnInfo.isTemporal = true
236
		} else {
237 1
			goType = getNullType(s, "*time.Time", db.GetTemporalDriverDataType())
238 1
			columnInfo.isTemporal = s.Null == settings.NullTypeNative
239 1
			columnInfo.isNullableTemporal = true
240 1
			columnInfo.isNullable = true
241
		}
242
	} else {
243
		// TODO handle special data types
244 1
		switch column.DataType {
245
		case "boolean":
246 1
			goType = "bool"
247 1
			if db.IsNullable(column) {
248 1
				goType = getNullType(s, "*bool", "sql.NullBool")
249 1
				columnInfo.isNullable = true
250
			}
251
		default:
252
			goType = getNullType(s, "*string", "sql.NullString")
253
		}
254
	}
255
256 1
	columnInfo.isNullablePrimitive = columnInfo.isNullable && !db.IsTemporal(column)
257
258 1
	return goType, columnInfo
259
}
260
261
// camelCaseString converts an underscore separated name to CamelCase.
//
// A string without underscores (including the empty string) only gets its
// words title-cased; the remaining letters keep their case. Otherwise every
// underscore separated part is lower-cased, title-cased and the parts are
// joined without a separator, e.g. "USER_name" becomes "UserName".
func camelCaseString(s string) string {
	if !strings.Contains(s, "_") {
		return strings.Title(s)
	}

	var b strings.Builder
	b.Grow(len(s))
	for _, part := range strings.Split(s, "_") {
		b.WriteString(strings.Title(strings.ToLower(part)))
	}
	return b.String()
}
278
279
func getNullType(settings *settings.Settings, primitive string, sql string) string {
280 1
	if settings.IsNullTypeSQL() {
281 1
		return sql
282
	}
283 1
	return primitive
284
}
285
286
func toInitialisms(s string) string {
287 1
	for _, substr := range initialisms {
288 1
		idx := indexCaseInsensitive(s, substr)
289 1
		if idx == -1 {
290 1
			continue
291
		}
292 1
		toReplace := s[idx : idx+len(substr)]
293 1
		s = strings.ReplaceAll(s, toReplace, substr)
294
	}
295 1
	return s
296
}
297
298
// indexCaseInsensitive returns the byte index of the first occurrence of
// substr in s, ignoring case, or -1 when substr does not occur. The index
// refers to the lower-cased copy of s.
func indexCaseInsensitive(s, substr string) int {
	return strings.Index(strings.ToLower(s), strings.ToLower(substr))
}
302
303
// validVariableName reports whether s consists solely of Unicode letters,
// Unicode digits and underscores. The empty string is considered valid.
func validVariableName(s string) bool {
	invalid := func(r rune) bool {
		return r != '_' && !unicode.IsLetter(r) && !unicode.IsDigit(r)
	}
	return strings.IndexFunc(s, invalid) == -1
}
313
314
// replaceSpace maps Unicode white space and the zero width space (U+200B) to
// an underscore and returns every other rune unchanged. It is meant to be
// used with strings.Map when building Go identifiers.
func replaceSpace(r rune) rune {
	switch {
	case r == '\u200B', unicode.IsSpace(r):
		return '_'
	}
	return r
}
322
323
// FormatColumnName checks for invalid characters and transforms a column name
324
// according to the provided settings.
325
func formatColumnName(settings *settings.Settings, column, table string) (string, error) {
326
327
	// Replace any whitespace with underscores
328 1
	columnName := strings.Map(replaceSpace, column)
329 1
	columnName = strings.Title(columnName)
330
331 1
	if settings.IsOutputFormatCamelCase() {
332 1
		columnName = camelCaseString(columnName)
333
	}
334 1
	if settings.ShouldInitialism() {
335 1
		columnName = toInitialisms(columnName)
336
	}
337
338
	// Check that the column name doesn't contain any invalid characters for Go variables
339 1
	if !validVariableName(columnName) {
340 1
		return "", fmt.Errorf("column name %q in table %q contains invalid characters", column, table)
341
	}
342
	// First character of an identifier in Go must be letter or _
343
	// We want it to be an uppercase letter to be a public field
344 1
	if !unicode.IsLetter([]rune(columnName)[0]) {
345 1
		prefix := "X_"
346 1
		if settings.IsOutputFormatCamelCase() {
347 1
			prefix = "X"
348
		}
349 1
		if settings.Verbose {
350
			fmt.Printf("\t\t>column %q in table %q doesn't start with a letter; prepending with %q\n", column, table, prefix)
351
		}
352 1
		columnName = prefix + columnName
353
	}
354
355 1
	return columnName, nil
356
}
357