Passed
Pull Request — master (#20)
by
unknown
02:07
created

cli.camelCaseString   A

Complexity

Conditions 4

Size

Total Lines 16
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 9
CRAP Score 4

Importance

Changes 0
Metric Value
cc 4
eloc 10
nop 1
dl 0
loc 16
ccs 9
cts 9
cp 1
crap 4
rs 9.9
c 0
b 0
f 0
1
package cli
2
3
import (
4
	"fmt"
5
	"strings"
6
7
	"github.com/fraenky8/tables-to-go/pkg/database"
8
	"github.com/fraenky8/tables-to-go/pkg/output"
9
	"github.com/fraenky8/tables-to-go/pkg/settings"
10
	"github.com/fraenky8/tables-to-go/pkg/tagger"
11
)
12
13
var (
14
	taggers tagger.Tagger
15
16
	// some strings for idiomatic go in column names
17
	// see https://github.com/golang/go/wiki/CodeReviewComments#initialisms
18
	initialisms = []string{"ID", "JSON", "XML", "HTTP", "URL"}
19
)
20
21
// Run runs the transformations by creating the concrete Database by the provided settings
22
func Run(settings *settings.Settings, db database.Database, out output.Writer) (err error) {
23
24 1
	taggers = tagger.NewTaggers(settings)
25
26 1
	fmt.Printf("running for %q...\r\n", settings.DbType)
27
28 1
	tables, err := db.GetTables()
29 1
	if err != nil {
30
		return fmt.Errorf("could not get tables: %v", err)
31
	}
32
33 1
	if settings.Verbose {
34
		fmt.Printf("> number of tables: %v\r\n", len(tables))
35
	}
36
37 1
	if err = db.PrepareGetColumnsOfTableStmt(); err != nil {
38
		return fmt.Errorf("could not prepare the get-column-statement: %v", err)
39
	}
40
41 1
	for _, table := range tables {
42
43 1
		if settings.Verbose {
44
			fmt.Printf("> processing table %q\r\n", table.Name)
45
		}
46
47 1
		if err = db.GetColumnsOfTable(table); err != nil {
48
			if !settings.Force {
49
				return fmt.Errorf("could not get columns of table %s: %v", table.Name, err)
50
			}
51
			fmt.Printf("could not get columns of table %s: %v\n", table.Name, err)
52
		}
53
54 1
		if settings.Verbose {
55
			fmt.Printf("\t> number of columns: %v\r\n", len(table.Columns))
56
		}
57
58 1
		tableName, content := createTableStructString(settings, db, table)
59
60 1
		err = out.Write(tableName, content)
61 1
		if err != nil {
62
			if !settings.Force {
63
				return fmt.Errorf("could not write struct for table %s: %v", table.Name, err)
64
			}
65
			fmt.Printf("could not write struct for table %s: %v\n", table.Name, err)
66
		}
67
	}
68
69 1
	fmt.Println("done!")
70
71 1
	return nil
72
}
73
74
// columnInfo records which kinds of Go types were chosen for a column (or,
// when aggregated, for any column of a table). generateImports uses it to
// decide which imports a generated file needs.
type columnInfo struct {
	// isNullable is set when a nullable variant of the type was chosen.
	isNullable bool
	// isTemporal is set when the chosen type refers to the "time" package.
	isTemporal bool
	// isNullablePrimitive is set for nullable, non-temporal columns.
	isNullablePrimitive bool
	// isNullableTemporal is set for nullable temporal columns.
	isNullableTemporal bool
}

// hasTrue reports whether at least one flag of c is set. All fields are
// bools, so this is equivalent to c differing from the zero value.
func (c columnInfo) hasTrue() bool {
	return c != (columnInfo{})
}
84
85
func createTableStructString(settings *settings.Settings, db database.Database, table *database.Table) (string, string) {
86
87 1
	var structFields strings.Builder
88
89 1
	columnInfo := columnInfo{}
90 1
	columns := map[string]struct{}{}
91
92 1
	for _, column := range table.Columns {
93
94 1
		columnName := strings.Title(column.Name)
95 1
		if settings.IsOutputFormatCamelCase() {
96 1
			columnName = camelCaseString(column.Name)
97
		}
98 1
		if settings.ShouldInitialism() {
99 1
			columnName = toInitialisms(columnName)
100
		}
101
102
		// ISSUE-4: if columns are part of multiple constraints
103
		// then the sql returns multiple rows per column name.
104
		// Therefore we check if we already added a column with
105
		// that name to the struct, if so, skip.
106 1
		if _, ok := columns[columnName]; ok {
107
			continue
108
		}
109 1
		columns[columnName] = struct{}{}
110
111 1
		if settings.VVerbose {
112
			fmt.Printf("\t\t> %v\r\n", column.Name)
113
		}
114
115 1
		columnType, col := mapDbColumnTypeToGoType(settings, db, column)
116
117
		// save that we saw types of columns at least once
118 1
		if !columnInfo.isTemporal {
119 1
			columnInfo.isTemporal = col.isTemporal
120
		}
121 1
		if !columnInfo.isNullableTemporal {
122 1
			columnInfo.isNullableTemporal = col.isNullableTemporal
123
		}
124 1
		if !columnInfo.isNullablePrimitive {
125 1
			columnInfo.isNullablePrimitive = col.isNullablePrimitive
126
		}
127
128 1
		structFields.WriteString(columnName)
129 1
		structFields.WriteString(" ")
130 1
		structFields.WriteString(columnType)
131 1
		structFields.WriteString(" ")
132 1
		structFields.WriteString(taggers.GenerateTag(db, column))
133 1
		structFields.WriteString("\n")
134
	}
135
136 1
	if settings.IsMastermindStructableRecorder {
137
		structFields.WriteString("\t\nstructable.Recorder\n")
138
	}
139
140 1
	var fileContent strings.Builder
141
142
	// write header infos
143 1
	fileContent.WriteString("package ")
144 1
	fileContent.WriteString(settings.PackageName)
145 1
	fileContent.WriteString("\n\n")
146
147
	// write imports
148 1
	generateImports(&fileContent, settings, db, columnInfo)
149
150 1
	tableName := strings.Title(settings.Prefix + table.Name + settings.Suffix)
151 1
	if settings.IsOutputFormatCamelCase() {
152 1
		tableName = camelCaseString(tableName)
153
	}
154
155
	// write struct with fields
156 1
	fileContent.WriteString("type ")
157 1
	fileContent.WriteString(tableName)
158 1
	fileContent.WriteString(" struct {\n")
159 1
	fileContent.WriteString(structFields.String())
160 1
	fileContent.WriteString("}")
161
162 1
	return tableName, fileContent.String()
163
}
164
165
func generateImports(content *strings.Builder, settings *settings.Settings, db database.Database, columnInfo columnInfo) {
166
167 1
	if !columnInfo.hasTrue() && !settings.IsMastermindStructableRecorder {
168 1
		return
169
	}
170
171 1
	content.WriteString("import (\n")
172
173 1
	if columnInfo.isNullablePrimitive && settings.IsNullTypeSQL() {
174 1
		content.WriteString("\t\"database/sql\"\n")
175
	}
176
177 1
	if columnInfo.isTemporal {
178 1
		content.WriteString("\t\"time\"\n")
179
	}
180
181 1
	if columnInfo.isNullableTemporal && settings.IsNullTypeSQL() {
182 1
		content.WriteString("\t\n")
183 1
		content.WriteString(db.GetDriverImportLibrary())
184 1
		content.WriteString("\n")
185
	}
186
187 1
	if settings.IsMastermindStructableRecorder {
188
		content.WriteString("\t\n\"github.com/Masterminds/structable\"\n")
189
	}
190
191 1
	content.WriteString(")\n\n")
192
}
193
194
func mapDbColumnTypeToGoType(s *settings.Settings, db database.Database, column database.Column) (goType string, columnInfo columnInfo) {
195 1
	if db.IsString(column) || db.IsText(column) {
196 1
		goType = "string"
197 1
		if db.IsNullable(column) {
198 1
			goType = getNullType(s, "*string", "sql.NullString")
199 1
			columnInfo.isNullable = true
200
		}
201 1
	} else if db.IsInteger(column) {
202 1
		goType = "int"
203 1
		if db.IsNullable(column) {
204 1
			goType = getNullType(s, "*int", "sql.NullInt64")
205 1
			columnInfo.isNullable = true
206
		}
207 1
	} else if db.IsFloat(column) {
208 1
		goType = "float64"
209 1
		if db.IsNullable(column) {
210 1
			goType = getNullType(s, "*float64", "sql.NullFloat64")
211 1
			columnInfo.isNullable = true
212
		}
213 1
	} else if db.IsTemporal(column) {
214 1
		if !db.IsNullable(column) {
215 1
			goType = "time.Time"
216 1
			columnInfo.isTemporal = true
217
		} else {
218 1
			goType = getNullType(s, "*time.Time", db.GetTemporalDriverDataType())
219 1
			columnInfo.isTemporal = s.Null == settings.NullTypeNative
220 1
			columnInfo.isNullableTemporal = true
221 1
			columnInfo.isNullable = true
222
		}
223
	} else {
224
		// TODO handle special data types
225 1
		switch column.DataType {
226
		case "boolean":
227 1
			goType = "bool"
228 1
			if db.IsNullable(column) {
229 1
				goType = getNullType(s, "*bool", "sql.NullBool")
230 1
				columnInfo.isNullable = true
231
			}
232
		default:
233
			goType = getNullType(s, "*string", "sql.NullString")
234
		}
235
	}
236
237 1
	columnInfo.isNullablePrimitive = columnInfo.isNullable && !db.IsTemporal(column)
238
239 1
	return goType, columnInfo
240
}
241
242
func getNullType(settings *settings.Settings, primitive string, sql string) string {
243 1
	if settings.IsNullTypeSQL() {
244 1
		return sql
245
	}
246 1
	return primitive
247
}
248
249
// camelCaseString converts a snake_case identifier to CamelCase. s is split on
// underscores, and each part is lower-cased, then title-cased, before the
// parts are joined: "user_ID" becomes "UserId" and "foo__bar" becomes
// "FooBar". A string without underscores is only title-cased, so "userID"
// becomes "UserID". The empty string is returned unchanged.
func camelCaseString(s string) string {
	if s == "" {
		return s
	}

	parts := strings.Split(s, "_")
	if len(parts) == 1 {
		return strings.Title(s)
	}

	// A Builder avoids re-copying the growing prefix on every iteration,
	// which `+=` on a string would do (quadratic in the number of parts).
	var b strings.Builder
	b.Grow(len(s))
	for _, part := range parts {
		b.WriteString(strings.Title(strings.ToLower(part)))
	}
	return b.String()
}
266
267
func toInitialisms(s string) string {
268 1
	for _, substr := range initialisms {
269 1
		idx := indexCaseInsensitive(s, substr)
270 1
		if idx == -1 {
271 1
			continue
272
		}
273 1
		toReplace := s[idx : idx+len(substr)]
274 1
		s = strings.ReplaceAll(s, toReplace, substr)
275
	}
276 1
	return s
277
}
278
279
// indexCaseInsensitive returns the byte offset in s of the first occurrence of
// substr, compared case-insensitively with strings.EqualFold, or -1 if there
// is none. An empty substr matches at offset 0.
//
// The offset always indexes the original s, and the match is exactly
// s[i:i+len(substr)], so callers may slice s with it. Searching
// strings.ToLower(s) instead is unsafe: lowercasing changes the byte length of
// some runes (e.g. 'Ⱥ' U+023A is 2 bytes, its lower case 'ⱥ' U+2C65 is 3).
// The resulting offset is then misaligned with s and can even slice past
// its end.
func indexCaseInsensitive(s, substr string) int {
	n := len(substr)
	if n == 0 {
		return 0
	}
	// Ranging over a string visits rune start offsets, so every candidate
	// window begins on a character boundary.
	for i := range s {
		if i+n > len(s) {
			break
		}
		if strings.EqualFold(s[i:i+n], substr) {
			return i
		}
	}
	return -1
}
283