Passed
Pull Request — master (#20)
by
unknown
01:54
created

cli.indexCaseInsensitive   A

Complexity

Conditions 1

Size

Total Lines 3
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 2
CRAP Score 1

Importance

Changes 0
Metric Value
cc 1
eloc 3
nop 2
dl 0
loc 3
ccs 2
cts 2
cp 1
crap 1
rs 10
c 0
b 0
f 0
1
package cli
2
3
import (
4
	"fmt"
5
	"strings"
6
7
	"github.com/fraenky8/tables-to-go/pkg/database"
8
	"github.com/fraenky8/tables-to-go/pkg/output"
9
	"github.com/fraenky8/tables-to-go/pkg/settings"
10
	"github.com/fraenky8/tables-to-go/pkg/tagger"
11
)
12
13
var (
14
	taggers tagger.Tagger
15
16
	// some strings for idiomatic go in column names
17
	// see https://github.com/golang/go/wiki/CodeReviewComments#initialisms
18
	initialisms = []string{"ID", "JSON", "XML", "HTTP", "URL"}
19
)
20
21
// Run runs the transformations by creating the concrete Database by the provided settings
22
func Run(settings *settings.Settings, db database.Database, out output.Writer) (err error) {
23
24 1
	taggers = tagger.NewTaggers(settings)
25
26 1
	fmt.Printf("running for %q...\r\n", settings.DbType)
27
28 1
	tables, err := db.GetTables()
29 1
	if err != nil {
30
		return fmt.Errorf("could not get tables: %v", err)
31
	}
32
33 1
	if settings.Verbose {
34
		fmt.Printf("> number of tables: %v\r\n", len(tables))
35
	}
36
37 1
	if err = db.PrepareGetColumnsOfTableStmt(); err != nil {
38
		return fmt.Errorf("could not prepare the get-column-statement: %v", err)
39
	}
40
41 1
	for _, table := range tables {
42
43 1
		if settings.Verbose {
44
			fmt.Printf("> processing table %q\r\n", table.Name)
45
		}
46
47 1
		if err = db.GetColumnsOfTable(table); err != nil {
48
			if !settings.Force {
49
				return fmt.Errorf("could not get columns of table %q: %v", table.Name, err)
50
			}
51
			fmt.Printf("could not get columns of table %q: %v\n", table.Name, err)
52
			continue
53
		}
54
55 1
		if settings.Verbose {
56
			fmt.Printf("\t> number of columns: %v\r\n", len(table.Columns))
57
		}
58
59 1
		tableName, content := createTableStructString(settings, db, table)
60
61 1
		err = out.Write(tableName, content)
62 1
		if err != nil {
63
			if !settings.Force {
64
				return fmt.Errorf("could not write struct for table %q: %v", table.Name, err)
65
			}
66
			fmt.Printf("could not write struct for table %q: %v\n", table.Name, err)
67
		}
68
	}
69
70 1
	fmt.Println("done!")
71
72 1
	return nil
73
}
74
75
// columnInfo describes which kinds of Go types were produced for a column by
// mapDbColumnTypeToGoType. Merged across all columns of a table, it tells
// generateImports which imports the generated file needs.
type columnInfo struct {
	// isNullable is set when the column was mapped as a nullable column.
	isNullable bool
	// isTemporal is set when the generated field type needs package "time".
	isTemporal bool
	// isNullablePrimitive is set for nullable columns that are not temporal.
	isNullablePrimitive bool
	// isNullableTemporal is set for nullable temporal columns.
	isNullableTemporal bool
}

// hasTrue reports whether at least one of the flags is set.
func (ci columnInfo) hasTrue() bool {
	switch {
	case ci.isNullable, ci.isTemporal, ci.isNullableTemporal, ci.isNullablePrimitive:
		return true
	default:
		return false
	}
}
85
86
func createTableStructString(settings *settings.Settings, db database.Database, table *database.Table) (string, string) {
87
88 1
	var structFields strings.Builder
89
90 1
	columnInfo := columnInfo{}
91 1
	columns := map[string]struct{}{}
92
93 1
	for _, column := range table.Columns {
94
95 1
		columnName := strings.Title(column.Name)
96 1
		if settings.IsOutputFormatCamelCase() {
97 1
			columnName = camelCaseString(column.Name)
98
		}
99 1
		if settings.ShouldInitialism() {
100 1
			columnName = toInitialisms(columnName)
101
		}
102
103
		// ISSUE-4: if columns are part of multiple constraints
104
		// then the sql returns multiple rows per column name.
105
		// Therefore we check if we already added a column with
106
		// that name to the struct, if so, skip.
107 1
		if _, ok := columns[columnName]; ok {
108
			continue
109
		}
110 1
		columns[columnName] = struct{}{}
111
112 1
		if settings.VVerbose {
113
			fmt.Printf("\t\t> %v\r\n", column.Name)
114
		}
115
116 1
		columnType, col := mapDbColumnTypeToGoType(settings, db, column)
117
118
		// save that we saw types of columns at least once
119 1
		if !columnInfo.isTemporal {
120 1
			columnInfo.isTemporal = col.isTemporal
121
		}
122 1
		if !columnInfo.isNullableTemporal {
123 1
			columnInfo.isNullableTemporal = col.isNullableTemporal
124
		}
125 1
		if !columnInfo.isNullablePrimitive {
126 1
			columnInfo.isNullablePrimitive = col.isNullablePrimitive
127
		}
128
129 1
		structFields.WriteString(columnName)
130 1
		structFields.WriteString(" ")
131 1
		structFields.WriteString(columnType)
132 1
		structFields.WriteString(" ")
133 1
		structFields.WriteString(taggers.GenerateTag(db, column))
134 1
		structFields.WriteString("\n")
135
	}
136
137 1
	if settings.IsMastermindStructableRecorder {
138
		structFields.WriteString("\t\nstructable.Recorder\n")
139
	}
140
141 1
	var fileContent strings.Builder
142
143
	// write header infos
144 1
	fileContent.WriteString("package ")
145 1
	fileContent.WriteString(settings.PackageName)
146 1
	fileContent.WriteString("\n\n")
147
148
	// write imports
149 1
	generateImports(&fileContent, settings, db, columnInfo)
150
151 1
	tableName := strings.Title(settings.Prefix + table.Name + settings.Suffix)
152 1
	if settings.IsOutputFormatCamelCase() {
153 1
		tableName = camelCaseString(tableName)
154
	}
155
156
	// write struct with fields
157 1
	fileContent.WriteString("type ")
158 1
	fileContent.WriteString(tableName)
159 1
	fileContent.WriteString(" struct {\n")
160 1
	fileContent.WriteString(structFields.String())
161 1
	fileContent.WriteString("}")
162
163 1
	return tableName, fileContent.String()
164
}
165
166
func generateImports(content *strings.Builder, settings *settings.Settings, db database.Database, columnInfo columnInfo) {
167
168 1
	if !columnInfo.hasTrue() && !settings.IsMastermindStructableRecorder {
169 1
		return
170
	}
171
172 1
	content.WriteString("import (\n")
173
174 1
	if columnInfo.isNullablePrimitive && settings.IsNullTypeSQL() {
175 1
		content.WriteString("\t\"database/sql\"\n")
176
	}
177
178 1
	if columnInfo.isTemporal {
179 1
		content.WriteString("\t\"time\"\n")
180
	}
181
182 1
	if columnInfo.isNullableTemporal && settings.IsNullTypeSQL() {
183 1
		content.WriteString("\t\n")
184 1
		content.WriteString(db.GetDriverImportLibrary())
185 1
		content.WriteString("\n")
186
	}
187
188 1
	if settings.IsMastermindStructableRecorder {
189
		content.WriteString("\t\n\"github.com/Masterminds/structable\"\n")
190
	}
191
192 1
	content.WriteString(")\n\n")
193
}
194
195
func mapDbColumnTypeToGoType(s *settings.Settings, db database.Database, column database.Column) (goType string, columnInfo columnInfo) {
196 1
	if db.IsString(column) || db.IsText(column) {
197 1
		goType = "string"
198 1
		if db.IsNullable(column) {
199 1
			goType = getNullType(s, "*string", "sql.NullString")
200 1
			columnInfo.isNullable = true
201
		}
202 1
	} else if db.IsInteger(column) {
203 1
		goType = "int"
204 1
		if db.IsNullable(column) {
205 1
			goType = getNullType(s, "*int", "sql.NullInt64")
206 1
			columnInfo.isNullable = true
207
		}
208 1
	} else if db.IsFloat(column) {
209 1
		goType = "float64"
210 1
		if db.IsNullable(column) {
211 1
			goType = getNullType(s, "*float64", "sql.NullFloat64")
212 1
			columnInfo.isNullable = true
213
		}
214 1
	} else if db.IsTemporal(column) {
215 1
		if !db.IsNullable(column) {
216 1
			goType = "time.Time"
217 1
			columnInfo.isTemporal = true
218
		} else {
219 1
			goType = getNullType(s, "*time.Time", db.GetTemporalDriverDataType())
220 1
			columnInfo.isTemporal = s.Null == settings.NullTypeNative
221 1
			columnInfo.isNullableTemporal = true
222 1
			columnInfo.isNullable = true
223
		}
224
	} else {
225
		// TODO handle special data types
226 1
		switch column.DataType {
227
		case "boolean":
228 1
			goType = "bool"
229 1
			if db.IsNullable(column) {
230 1
				goType = getNullType(s, "*bool", "sql.NullBool")
231 1
				columnInfo.isNullable = true
232
			}
233
		default:
234
			goType = getNullType(s, "*string", "sql.NullString")
235
		}
236
	}
237
238 1
	columnInfo.isNullablePrimitive = columnInfo.isNullable && !db.IsTemporal(column)
239
240 1
	return goType, columnInfo
241
}
242
243
func getNullType(settings *settings.Settings, primitive string, sql string) string {
244 1
	if settings.IsNullTypeSQL() {
245 1
		return sql
246
	}
247 1
	return primitive
248
}
249
250
// camelCaseString converts a snake_case name to CamelCase, e.g.
// "user_id" -> "UserId".
//
// Without an underscore, s is only title-cased (strings.Title) and keeps the
// casing of its remaining letters ("fOO" -> "FOO"). With underscores, every
// part is lower-cased and then title-cased ("FOO_BAR" -> "FooBar"). Empty
// parts from leading, trailing or repeated underscores vanish. An empty
// input is returned unchanged.
func camelCaseString(s string) string {
	if s == "" {
		return s
	}

	parts := strings.Split(s, "_")
	if len(parts) == 1 {
		return strings.Title(s)
	}

	// A Builder avoids the quadratic copying of repeated string concatenation.
	var b strings.Builder
	b.Grow(len(s))
	for _, part := range parts {
		b.WriteString(strings.Title(strings.ToLower(part)))
	}
	return b.String()
}
267
268
func toInitialisms(s string) string {
269 1
	for _, substr := range initialisms {
270 1
		idx := indexCaseInsensitive(s, substr)
271 1
		if idx == -1 {
272 1
			continue
273
		}
274 1
		toReplace := s[idx : idx+len(substr)]
275 1
		s = strings.ReplaceAll(s, toReplace, substr)
276
	}
277 1
	return s
278
}
279
280
// indexCaseInsensitive returns the byte index of the first occurrence of
// substr in s, ignoring case, or -1 if substr does not occur. Both strings
// are lower-cased before searching, so the index refers to the lower-cased
// copy of s. It matches an offset in s itself only when lower-casing keeps
// the byte length of the preceding characters (always true for ASCII).
func indexCaseInsensitive(s, substr string) int {
	return strings.Index(strings.ToLower(s), strings.ToLower(substr))
}
284