Test Failed
Push — main ( 5c6504...8c852d )
by Adriano
01:51
created

pkg/csvql/main.go   B

Size/Duplication

Total Lines 284
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
cc 46
eloc 174
dl 0
loc 284
rs 8.72
c 0
b 0
f 0

11 Methods

Rating   Name   Duplication   Size   Complexity  
A csvql.New 0 20 2
B csvql.*csvql.printResult 0 32 6
B csvql.*csvql.initializePrompt 0 36 8
A csvql.*csvql.buildTable 0 16 4
B csvql.*csvql.Run 0 24 6
A csvql.*csvql.openConnection 0 9 2
B csvql.*csvql.loadDataFromFile 0 34 7
A csvql.*csvql.buildInsert 0 12 2
A csvql.*csvql.execute 0 8 2
A csvql.*csvql.loadTotalRows 0 21 5
A csvql.*csvql.openFile 0 9 2
1
package csvql
2
3
import (
4
	"bytes"
5
	"database/sql"
6
	"encoding/csv"
7
	"fmt"
8
	"github.com/chzyer/readline"
9
	"github.com/fatih/color"
10
	_ "github.com/mattn/go-sqlite3"
11
	"github.com/rodaine/table"
12
	"github.com/schollz/progressbar/v3"
13
	"io"
14
	"os"
15
	"strings"
16
)
17
18
// Strings shown by the interactive readline prompt.
const (
	cliPrompt          = "csvql> "
	cliInterruptPrompt = "^C"
	cliEOFPrompt       = "exit"
)

// dataSourceNameDefault is the SQLite data source used when the caller
// supplies none: an in-memory database.
const dataSourceNameDefault = ":memory:"

// SQL templates for the single "rows" table the CSV data is loaded into.
const (
	sqlCreateTableTemplate = "CREATE TABLE rows (%s\n);"
	sqlInsertTemplate      = "INSERT INTO rows (%s) VALUES (%s);"
)
26
27
// Csvql is the public handle returned by New. Run loads the configured CSV
// file into SQLite and then serves an interactive SQL prompt, returning when
// the prompt exits or any loading step fails.
type Csvql interface {
	// Run executes the whole load-then-query session.
	Run() error
}
30
31
type csvql struct {
32
	db      *sql.DB
33
	file    *os.File
34
	bar     *progressbar.ProgressBar
35
	params  CsvqlParams
36
	columns []string
37
	lines   int
38
}
39
40
// New builds a Csvql from params. An empty params.DataSourceName is replaced
// by dataSourceNameDefault (":memory:", an in-memory SQLite database). Calling
// Run on the result loads params.FileInput and starts the SQL prompt.
func New(params CsvqlParams) Csvql {
	// The bar is driven by a row count: its maximum is the input's newline
	// count and it advances by one per created table / inserted row. Show a
	// plain "current/total" count; OptionShowBytes would render rows as kB.
	bar := progressbar.NewOptions(0,
		progressbar.OptionSetWriter(os.Stdout),
		progressbar.OptionEnableColorCodes(true),
		progressbar.OptionShowCount(),
		progressbar.OptionFullWidth(),
		progressbar.OptionSetDescription("[cyan][1/1][reset] loading data..."),
		progressbar.OptionSetTheme(progressbar.Theme{
			Saucer:        "[green]=[reset]",
			SaucerHead:    "[green]>[reset]",
			SaucerPadding: " ",
			BarStart:      "[",
			BarEnd:        "]",
		}))

	if params.DataSourceName == "" {
		params.DataSourceName = dataSourceNameDefault
	}

	return &csvql{params: params, bar: bar}
}
61
62
// Run performs the full session in order:
//  1. count the input's lines (progress bar maximum),
//  2. open the CSV file and the database,
//  3. load the CSV into the "rows" table,
//  4. serve the interactive prompt.
//
// The file and the database are closed when Run returns. The first failing
// step's error is returned unchanged.
func (c *csvql) Run() error {
	err := c.loadTotalRows()
	if err != nil {
		return err
	}

	if err = c.openFile(); err != nil {
		return err
	}
	defer c.file.Close()

	if err = c.openConnection(); err != nil {
		return err
	}
	defer c.db.Close()

	if err = c.loadDataFromFile(); err != nil {
		return err
	}

	return c.initializePrompt()
}
87
88
// initializePrompt runs the interactive loop. Each non-blank line read at the
// "csvql> " prompt is trimmed and executed as SQL, and query errors are
// printed to stderr without ending the session. The loop ends (returning nil)
// on Ctrl-C at an empty prompt or on EOF. Any other readline error is
// returned.
func (c *csvql) initializePrompt() error {
	l, err := readline.NewEx(&readline.Config{
		Prompt:          cliPrompt,
		InterruptPrompt: cliInterruptPrompt,
		EOFPrompt:       cliEOFPrompt,
		AutoComplete: readline.SegmentFunc(func(i [][]rune, i2 int) [][]rune {
			return nil
		}),
	})
	if err != nil {
		return err
	}
	// Release readline's terminal state and resources however the loop ends.
	defer l.Close()

	l.CaptureExitSignal()

	for {
		line, err := l.Readline()
		switch {
		case err == readline.ErrInterrupt:
			// Ctrl-C on an empty line quits; with text typed it just clears it.
			if len(line) == 0 {
				return nil
			}
			continue
		case err == io.EOF:
			return nil
		case err != nil:
			// Any other failure (e.g. the instance was closed) would otherwise
			// repeat forever, executing a stale line each time.
			return err
		}

		line = strings.TrimSpace(line)
		if line == "" {
			// Nothing to run; don't send an empty statement to SQLite.
			continue
		}

		if err := c.execute(line); err != nil {
			fmt.Fprintf(os.Stderr, "%s\n", err.Error())
		}
	}
}
125
126
// execute runs one SQL statement from the prompt and prints its result rows.
// The rows are closed before returning. A query error is returned unchanged.
func (c *csvql) execute(line string) error {
	rows, err := c.db.Query(line)
	if err == nil {
		defer rows.Close()
		return c.printResult(rows)
	}
	return err
}
135
136
// printResult renders every row of rows as a table on stdout. Headers are
// green and underlined, and the first column is yellow.
//
// It returns the first error from reading the column names, scanning a row,
// or iterating the result set. On error nothing is printed.
func (c *csvql) printResult(rows *sql.Rows) error {
	columns, err := rows.Columns()
	if err != nil {
		return err
	}

	headers := make([]any, len(columns))
	for i, name := range columns {
		headers[i] = name
	}

	tbl := table.New(headers...).
		WithHeaderFormatter(color.New(color.FgGreen, color.Underline).SprintfFunc()).
		WithFirstColumnFormatter(color.New(color.FgYellow).SprintfFunc())

	for rows.Next() {
		// Scan needs one pointer per column; point each at its values slot.
		values := make([]any, len(columns))
		pointers := make([]any, len(columns))
		for i := range values {
			pointers[i] = &values[i]
		}

		if err := rows.Scan(pointers...); err != nil {
			return err
		}

		tbl.AddRow(values...)
	}

	// rows.Next reports false both when the set is exhausted and when
	// iteration fails. Only rows.Err tells them apart, so check it before
	// printing what may be a truncated table.
	if err := rows.Err(); err != nil {
		return err
	}

	tbl.Print()

	return nil
}
169
170
// openConnection opens the SQLite database named by params.DataSourceName
// and stores the handle in c.db. The caller must close it.
func (c *csvql) openConnection() error {
	db, err := sql.Open("sqlite3", c.params.DataSourceName)
	if err != nil {
		return err
	}

	// Each new connection to ":memory:" (the default DSN) sees its own empty
	// database. Pin the pool to one connection so the "rows" table created
	// during loading stays visible to every later query.
	db.SetMaxOpenConns(1)

	// sql.Open only validates its arguments. Ping actually opens the
	// database, so a bad DSN is reported here instead of on first use.
	if err := db.Ping(); err != nil {
		_ = db.Close() // best effort; the Ping error is the one to report
		return err
	}

	c.db = db

	return nil
}
180
181
// openFile opens params.FileInput for reading and keeps the handle in c.file.
// On failure c.file is left untouched and the os.Open error is returned.
func (c *csvql) openFile() error {
	file, err := os.Open(c.params.FileInput)
	if err == nil {
		c.file = file
	}
	return err
}
191
192
// loadTotalRows counts the '\n' bytes in params.FileInput and stores the total
// in c.lines, which loadDataFromFile uses as the progress bar maximum. The
// file is read in 32 KiB chunks through its own handle, closed on return.
func (c *csvql) loadTotalRows() error {
	f, err := os.Open(c.params.FileInput)
	if err != nil {
		return err
	}
	defer f.Close()

	c.lines = 0
	chunk := make([]byte, 32*1024)
	newline := []byte{'\n'}

	for {
		n, readErr := f.Read(chunk)
		// Count whatever was read before looking at the error: Read may
		// return data together with io.EOF.
		c.lines += bytes.Count(chunk[:n], newline)

		if readErr == io.EOF {
			return nil
		}
		if readErr != nil {
			return readErr
		}
	}
}
216
217
func (c *csvql) loadDataFromFile() error {
218
	c.bar.ChangeMax(c.lines)
219
	defer c.bar.Finish()
220
221
	r := csv.NewReader(c.file)
222
	r.Comma = c.params.Comma
223
224
	headers, err := r.Read()
225
	if err != nil {
226
		return err
227
	}
228
229
	c.columns = headers
230
	if err := c.buildTable(); err != nil {
231
		return err
232
	}
233
234
	for {
235
		records, err := r.Read()
236
		if err == io.EOF {
237
			break
238
		}
239
240
		var values []any
241
		for _, r := range records {
242
			values = append(values, r)
243
		}
244
245
		if err := c.buildInsert(values); err != nil {
246
			return err
247
		}
248
	}
249
250
	return nil
251
}
252
253
// buildTable creates the "rows" table with one text column per header in
// c.columns, then advances the progress bar by one step (even on error).
//
// Each header is trimmed and written as a double-quoted SQL identifier, with
// embedded '"' doubled. Names containing spaces, punctuation or SQL keywords
// ("first name", "e-mail", "order") therefore yield a valid statement instead
// of a syntax error or injected SQL. Plain names behave exactly as before,
// since "name" and name denote the same SQLite column.
func (c *csvql) buildTable() error {
	defer c.bar.Add(1)

	var tableAttrsRaw strings.Builder
	for i, name := range c.columns {
		ident := strings.ReplaceAll(strings.TrimSpace(name), `"`, `""`)
		fmt.Fprintf(&tableAttrsRaw, "\n\t\"%s\" text", ident)
		if i < len(c.columns)-1 {
			tableAttrsRaw.WriteString(",")
		}
	}

	if _, err := c.db.Exec(fmt.Sprintf(sqlCreateTableTemplate, tableAttrsRaw.String())); err != nil {
		return err
	}

	return nil
}
271
272
// buildInsert inserts one CSV record into the "rows" table, binding values
// positionally to the columns named in c.columns. It then advances the
// progress bar by one step (even on error). values should hold one entry per
// column; otherwise Exec returns an argument-count error.
//
// Column names are trimmed and double-quoted (embedded '"' doubled) so
// headers with spaces, punctuation or SQL keywords form a valid statement.
// Building the placeholder list with Join also avoids the slice-bounds panic
// the old "?, " trimming hit when c.columns was empty.
func (c *csvql) buildInsert(values []any) error {
	defer c.bar.Add(1)

	columns := make([]string, len(c.columns))
	placeholders := make([]string, len(c.columns))
	for i, name := range c.columns {
		columns[i] = `"` + strings.ReplaceAll(strings.TrimSpace(name), `"`, `""`) + `"`
		placeholders[i] = "?"
	}
	insertRaw := fmt.Sprintf(sqlInsertTemplate, strings.Join(columns, ", "), strings.Join(placeholders, ", "))

	if _, err := c.db.Exec(insertRaw, values...); err != nil {
		return err
	}

	return nil
}
286