Passed
Pull Request — master (#21)
by
unknown
02:06
created

cli.replaceSpace   A

Complexity

Conditions 2

Size

Total Lines 5
Code Lines 4

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 2
CRAP Score 2.1481

Importance

Changes 0
Metric Value
cc 2
eloc 4
nop 1
dl 0
loc 5
ccs 2
cts 3
cp 0.6667
crap 2.1481
rs 10
c 0
b 0
f 0
1
package cli
2
3
import (
4
	"fmt"
5
	"strings"
6
	"unicode"
7
8
	"github.com/fraenky8/tables-to-go/pkg/database"
9
	"github.com/fraenky8/tables-to-go/pkg/output"
10
	"github.com/fraenky8/tables-to-go/pkg/settings"
11
	"github.com/fraenky8/tables-to-go/pkg/tagger"
12
)
13
14
var (
15
	taggers tagger.Tagger
16
17
	// some strings for idiomatic go in column names
18
	// see https://github.com/golang/go/wiki/CodeReviewComments#initialisms
19
	initialisms = []string{"ID", "JSON", "XML", "HTTP", "URL"}
20
)
21
22
// Run runs the transformations by creating the concrete Database by the provided settings
23
func Run(settings *settings.Settings, db database.Database, out output.Writer) (err error) {
24
25 1
	taggers = tagger.NewTaggers(settings)
26
27 1
	fmt.Printf("running for %q...\r\n", settings.DbType)
28
29 1
	tables, err := db.GetTables()
30 1
	if err != nil {
31
		return fmt.Errorf("could not get tables: %v", err)
32
	}
33
34 1
	if settings.Verbose {
35
		fmt.Printf("> number of tables: %v\r\n", len(tables))
36
	}
37
38 1
	if err = db.PrepareGetColumnsOfTableStmt(); err != nil {
39
		return fmt.Errorf("could not prepare the get-column-statement: %v", err)
40
	}
41
42 1
	for _, table := range tables {
43
44 1
		if settings.Verbose {
45
			fmt.Printf("> processing table %q\r\n", table.Name)
46
		}
47
48 1
		if err = db.GetColumnsOfTable(table); err != nil {
49
			if !settings.Force {
50
				return fmt.Errorf("could not get columns of table %s: %v", table.Name, err)
51
			}
52
			fmt.Printf("could not get columns of table %s: %v\n", table.Name, err)
53
			continue
54
		}
55
56 1
		if settings.Verbose {
57
			fmt.Printf("\t> number of columns: %v\r\n", len(table.Columns))
58
		}
59
60 1
		tableName, content, err := createTableStructString(settings, db, table)
61 1
		if err != nil {
62
			if !settings.Force {
63
				return fmt.Errorf("could not create string for table %s: %v", table.Name, err)
64
			}
65
			fmt.Printf("could not create string for table %s: %v\n", table.Name, err)
66
			continue
67
		}
68
69 1
		err = out.Write(tableName, content)
70 1
		if err != nil {
71
			if !settings.Force {
72
				return fmt.Errorf("could not write struct for table %s: %v", table.Name, err)
73
			}
74
			fmt.Printf("could not write struct for table %s: %v\n", table.Name, err)
75
		}
76
	}
77
78 1
	fmt.Println("done!")
79
80 1
	return nil
81
}
82
83
// columnInfo records which kinds of Go types a column maps to. When
// accumulated over all columns of a table, generateImports uses it to decide
// which imports the generated file needs.
type columnInfo struct {
	isNullable          bool // nullable column
	isTemporal          bool // the Go type refers to time.Time
	isNullablePrimitive bool // nullable column that is not temporal
	isNullableTemporal  bool // nullable temporal column
}

// hasTrue reports whether at least one of the flags is set.
func (c columnInfo) hasTrue() bool {
	switch {
	case c.isNullable, c.isTemporal, c.isNullableTemporal, c.isNullablePrimitive:
		return true
	}
	return false
}
93
94
func createTableStructString(settings *settings.Settings, db database.Database, table *database.Table) (string, string, error) {
95
96 1
	var structFields strings.Builder
97 1
	tableName := strings.Title(settings.Prefix + table.Name + settings.Suffix)
98
	// Replace any whitespace with underscores
99 1
	tableName = strings.Map(replaceSpace, tableName)
100 1
	if settings.IsOutputFormatCamelCase() {
101 1
		tableName = camelCaseString(tableName)
102
	}
103
104
	// Check that the table name doesn't contain any invalid characters for Go variables
105 1
	if !validVariableName(tableName) {
106
		return "", "", fmt.Errorf("Table name %q contains invalid characters", table.Name)
107
	}
108
109 1
	columnInfo := columnInfo{}
110 1
	columns := map[string]struct{}{}
111
112 1
	for _, column := range table.Columns {
113
114 1
		columnName := strings.Title(column.Name)
115
		// Replace any whitespace with underscores
116 1
		columnName = strings.Map(replaceSpace, columnName)
117 1
		if settings.IsOutputFormatCamelCase() {
118 1
			columnName = camelCaseString(columnName)
119
		}
120 1
		if settings.ShouldInitialism() {
121 1
			columnName = toInitialisms(columnName)
122
		}
123
124
		// Check that the column name doesn't contain any invalid characters for Go variables
125 1
		if !validVariableName(columnName) {
126
			return "", "", fmt.Errorf("Column name %q in table %q contains invalid characters", column.Name, table.Name)
127
		}
128
		// First character of an identifier in Go must be letter or _
129
		// We want it to be an uppercase letter to be a public field
130 1
		if !unicode.IsLetter([]rune(columnName)[0]) {
131
			if settings.Verbose {
132
				fmt.Printf("\t\t>Column %q in table %q doesn't start with a letter; prepending with X_\n", column.Name, table.Name)
133
			}
134
			columnName = "X_" + columnName
135
		}
136
		// ISSUE-4: if columns are part of multiple constraints
137
		// then the sql returns multiple rows per column name.
138
		// Therefore we check if we already added a column with
139
		// that name to the struct, if so, skip.
140 1
		if _, ok := columns[columnName]; ok {
141
			continue
142
		}
143 1
		columns[columnName] = struct{}{}
144
145 1
		if settings.VVerbose {
146
			fmt.Printf("\t\t> %v\r\n", column.Name)
147
		}
148
149 1
		columnType, col := mapDbColumnTypeToGoType(settings, db, column)
150
151
		// save that we saw types of columns at least once
152 1
		if !columnInfo.isTemporal {
153 1
			columnInfo.isTemporal = col.isTemporal
154
		}
155 1
		if !columnInfo.isNullableTemporal {
156 1
			columnInfo.isNullableTemporal = col.isNullableTemporal
157
		}
158 1
		if !columnInfo.isNullablePrimitive {
159 1
			columnInfo.isNullablePrimitive = col.isNullablePrimitive
160
		}
161
162 1
		structFields.WriteString(columnName)
163 1
		structFields.WriteString(" ")
164 1
		structFields.WriteString(columnType)
165 1
		structFields.WriteString(" ")
166 1
		structFields.WriteString(taggers.GenerateTag(db, column))
167 1
		structFields.WriteString("\n")
168
	}
169
170 1
	if settings.IsMastermindStructableRecorder {
171
		structFields.WriteString("\t\nstructable.Recorder\n")
172
	}
173
174 1
	var fileContent strings.Builder
175
176
	// write header infos
177 1
	fileContent.WriteString("package ")
178 1
	fileContent.WriteString(settings.PackageName)
179 1
	fileContent.WriteString("\n\n")
180
181
	// write imports
182 1
	generateImports(&fileContent, settings, db, columnInfo)
183
184
	// write struct with fields
185 1
	fileContent.WriteString("type ")
186 1
	fileContent.WriteString(tableName)
187 1
	fileContent.WriteString(" struct {\n")
188 1
	fileContent.WriteString(structFields.String())
189 1
	fileContent.WriteString("}")
190
191 1
	return tableName, fileContent.String(), nil
192
}
193
194
func generateImports(content *strings.Builder, settings *settings.Settings, db database.Database, columnInfo columnInfo) {
195
196 1
	if !columnInfo.hasTrue() && !settings.IsMastermindStructableRecorder {
197 1
		return
198
	}
199
200 1
	content.WriteString("import (\n")
201
202 1
	if columnInfo.isNullablePrimitive && settings.IsNullTypeSQL() {
203 1
		content.WriteString("\t\"database/sql\"\n")
204
	}
205
206 1
	if columnInfo.isTemporal {
207 1
		content.WriteString("\t\"time\"\n")
208
	}
209
210 1
	if columnInfo.isNullableTemporal && settings.IsNullTypeSQL() {
211 1
		content.WriteString("\t\n")
212 1
		content.WriteString(db.GetDriverImportLibrary())
213 1
		content.WriteString("\n")
214
	}
215
216 1
	if settings.IsMastermindStructableRecorder {
217
		content.WriteString("\t\n\"github.com/Masterminds/structable\"\n")
218
	}
219
220 1
	content.WriteString(")\n\n")
221
}
222
223
func mapDbColumnTypeToGoType(s *settings.Settings, db database.Database, column database.Column) (goType string, columnInfo columnInfo) {
224 1
	if db.IsString(column) || db.IsText(column) {
225 1
		goType = "string"
226 1
		if db.IsNullable(column) {
227 1
			goType = getNullType(s, "*string", "sql.NullString")
228 1
			columnInfo.isNullable = true
229
		}
230 1
	} else if db.IsInteger(column) {
231 1
		goType = "int"
232 1
		if db.IsNullable(column) {
233 1
			goType = getNullType(s, "*int", "sql.NullInt64")
234 1
			columnInfo.isNullable = true
235
		}
236 1
	} else if db.IsFloat(column) {
237 1
		goType = "float64"
238 1
		if db.IsNullable(column) {
239 1
			goType = getNullType(s, "*float64", "sql.NullFloat64")
240 1
			columnInfo.isNullable = true
241
		}
242 1
	} else if db.IsTemporal(column) {
243 1
		if !db.IsNullable(column) {
244 1
			goType = "time.Time"
245 1
			columnInfo.isTemporal = true
246
		} else {
247 1
			goType = getNullType(s, "*time.Time", db.GetTemporalDriverDataType())
248 1
			columnInfo.isTemporal = s.Null == settings.NullTypeNative
249 1
			columnInfo.isNullableTemporal = true
250 1
			columnInfo.isNullable = true
251
		}
252
	} else {
253
		// TODO handle special data types
254 1
		switch column.DataType {
255
		case "boolean":
256 1
			goType = "bool"
257 1
			if db.IsNullable(column) {
258 1
				goType = getNullType(s, "*bool", "sql.NullBool")
259 1
				columnInfo.isNullable = true
260
			}
261
		default:
262
			goType = getNullType(s, "*string", "sql.NullString")
263
		}
264
	}
265
266 1
	columnInfo.isNullablePrimitive = columnInfo.isNullable && !db.IsTemporal(column)
267
268 1
	return goType, columnInfo
269
}
270
271
func getNullType(settings *settings.Settings, primitive string, sql string) string {
272 1
	if settings.IsNullTypeSQL() {
273 1
		return sql
274
	}
275 1
	return primitive
276
}
277
278
// camelCaseString converts an underscore-separated name to CamelCase.
//
// A name without underscores is only title-cased; its remaining letters keep
// their case. Otherwise every underscore-separated part is lower-cased, then
// title-cased, and the underscores are dropped, e.g. "USER_ID" becomes
// "UserId". Empty parts (leading, trailing or repeated underscores) vanish,
// so "_" becomes "".
func camelCaseString(s string) string {
	if s == "" {
		return s
	}

	parts := strings.Split(s, "_")
	if len(parts) == 1 {
		return strings.Title(s)
	}

	// A Builder avoids re-copying the partial result on every iteration.
	var cc strings.Builder
	cc.Grow(len(s))
	for _, part := range parts {
		cc.WriteString(strings.Title(strings.ToLower(part)))
	}
	return cc.String()
}
295
296
func toInitialisms(s string) string {
297 1
	for _, substr := range initialisms {
298 1
		idx := indexCaseInsensitive(s, substr)
299 1
		if idx == -1 {
300 1
			continue
301
		}
302 1
		toReplace := s[idx : idx+len(substr)]
303 1
		s = strings.ReplaceAll(s, toReplace, substr)
304
	}
305 1
	return s
306
}
307
308
// indexCaseInsensitive returns the byte index in s of the first
// case-insensitive (Unicode simple folding) match of substr, or -1 if there
// is none. An empty substr matches at index 0.
//
// The index always refers to s itself, so s[i:] can safely be sliced with it.
// Comparing lower-cased copies instead would return wrong positions when
// lower-casing changes a character's byte length (e.g. 'İ' or the Kelvin
// sign). The cost is O(len(s) * runes(substr)), which is fine for
// identifiers.
func indexCaseInsensitive(s, substr string) int {
	if substr == "" {
		return 0
	}

	width := len([]rune(substr)) // a fold-equal match spans this many runes
	for start := range s {
		end, seen := len(s), 0
		for off := range s[start:] {
			if seen == width {
				end = start + off
				break
			}
			seen++
		}
		if seen < width {
			// Fewer than width runes remain, so no later start can match.
			return -1
		}
		if strings.EqualFold(s[start:end], substr) {
			return start
		}
	}
	return -1
}
312
313
// validVariableName reports whether str consists only of Unicode letters,
// Unicode digits and underscores. The empty string is reported as valid.
// Whether the first character is acceptable is not checked here.
func validVariableName(str string) bool {
	for _, r := range str {
		switch {
		case unicode.IsLetter(r), unicode.IsDigit(r), r == '_':
			continue
		default:
			return false
		}
	}
	return true
}
323
324
// replaceSpace is a strings.Map callback. It turns any white-space rune (as
// classified by unicode.IsSpace) into an underscore and returns every other
// rune unchanged, so names containing spaces can become Go identifiers.
func replaceSpace(r rune) rune {
	if !unicode.IsSpace(r) {
		return r
	}
	return '_'
}
332