Passed
Pull Request — main (#5)
by Adriano
02:14
created

csv.*csvHandler.Close   A

Complexity

Conditions 4

Size

Total Lines 12
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 4
eloc 9
nop 0
dl 0
loc 12
rs 9.95
c 0
b 0
f 0
1
package csv
2
3
import (
4
	"adrianolaselva.github.io/csvql/pkg/filehandler"
5
	"adrianolaselva.github.io/csvql/pkg/storage"
6
	"bytes"
7
	"database/sql"
8
	"encoding/csv"
9
	"errors"
10
	"fmt"
11
	"github.com/schollz/progressbar/v3"
12
	"io"
13
	"os"
14
	"path/filepath"
15
	"strings"
16
	"sync"
17
)
18
19
// bufferMaxLength is the size in bytes (32 KiB) of the chunk buffer used
// when a file is scanned to count its lines.
const bufferMaxLength = 32 * 1024
22
23
type csvHandler struct {
24
	mx          sync.Mutex
25
	bar         *progressbar.ProgressBar
26
	storage     storage.Storage
27
	files       []*os.File
28
	fileInputs  []string
29
	totalLines  int
30
	limitLines  int
31
	currentLine int
32
	delimiter   rune
33
}
34
35
// NewCsvHandler returns a filehandler.FileHandler backed by CSV files.
//
// fileInputs lists the files to import, delimiter is the field separator,
// bar reports progress, storage receives the imported data and limitLines
// caps the number of lines considered (the cap only applies when positive).
func NewCsvHandler(fileInputs []string, delimiter rune, bar *progressbar.ProgressBar, storage storage.Storage, limitLines int) filehandler.FileHandler {
	handler := &csvHandler{
		bar:        bar,
		storage:    storage,
		fileInputs: fileInputs,
		delimiter:  delimiter,
		limitLines: limitLines,
	}

	return handler
}
38
39
// Import import data
40
func (c *csvHandler) Import() error {
41
	if err := c.openFiles(); err != nil {
42
		return err
43
	}
44
45
	wg := new(sync.WaitGroup)
46
	wg.Add(len(c.fileInputs))
47
	errChannels := make(chan error, len(c.fileInputs))
48
49
	for _, file := range c.fileInputs {
50
		go func(wg *sync.WaitGroup, file string, errChan chan error) {
51
			defer wg.Done()
52
			err := c.loadTotalRows(file)
53
			errChan <- err
54
		}(wg, file, errChannels)
55
	}
56
57
	wg.Wait()
58
	if err := <-errChannels; err != nil {
59
		return err
60
	}
61
62
	if c.limitLines > 0 && c.totalLines > c.limitLines {
63
		c.totalLines = c.limitLines
64
	}
65
66
	wg.Add(len(c.files))
67
	errChannels = make(chan error, len(c.files))
68
	for _, file := range c.files {
69
		tableName := strings.ReplaceAll(strings.ToLower(filepath.Base(file.Name())), filepath.Ext(file.Name()), "")
70
		go func(wg *sync.WaitGroup, file *os.File, tableName string, errChan chan error) {
71
			defer wg.Done()
72
			errChan <- c.loadDataFromFile(tableName, file)
73
		}(wg, file, tableName, errChannels)
74
	}
75
76
	wg.Wait()
77
	if err := <-errChannels; err != nil {
78
		return err
79
	}
80
81
	return nil
82
}
83
84
// Query execute statements
85
func (c *csvHandler) Query(cmd string) (*sql.Rows, error) {
86
	rows, err := c.storage.Query(cmd)
87
	if err != nil {
88
		return nil, fmt.Errorf("failed to execute query: %w", err)
89
	}
90
91
	return rows, nil
92
}
93
94
// Lines return total lines
95
func (c *csvHandler) Lines() int {
96
	return c.totalLines
97
}
98
99
// Close execute in defer
100
func (c *csvHandler) Close() error {
101
	defer func(storage storage.Storage) {
102
		_ = storage.Close()
103
	}(c.storage)
104
105
	defer func(files []*os.File) {
106
		for _, file := range files {
107
			_ = file.Close()
108
		}
109
	}(c.files)
110
111
	return nil
112
}
113
114
// loadDataFromFile load data from file
115
func (c *csvHandler) loadDataFromFile(tableName string, file *os.File) error {
116
	c.mx.Lock()
117
	defer c.mx.Unlock()
118
119
	c.bar.ChangeMax(c.totalLines)
120
121
	r := csv.NewReader(file)
122
	r.Comma = c.delimiter
123
124
	columns, err := c.readHeader(tableName, r)
125
	if err != nil {
126
		return fmt.Errorf("failed to load headers and build structure: %w", err)
127
	}
128
129
	c.currentLine = 0
130
	for {
131
		err := c.readline(tableName, columns, r)
132
		if errors.Is(err, io.EOF) {
133
			break
134
		}
135
136
		if err != nil {
137
			return err
138
		}
139
	}
140
141
	return nil
142
}
143
144
// readHeader read header
145
func (c *csvHandler) readHeader(tableName string, r *csv.Reader) ([]string, error) {
146
	columns, err := r.Read()
147
	if err != nil {
148
		return nil, fmt.Errorf("failed to load headers: %w", err)
149
	}
150
151
	if err := c.storage.BuildStructure(tableName, columns); err != nil {
152
		return nil, fmt.Errorf("failed to load headers and build structure: %w", err)
153
	}
154
155
	return columns, nil
156
}
157
158
// readline read line
159
func (c *csvHandler) readline(tableName string, columns []string, r *csv.Reader) error {
160
	records, err := r.Read()
161
	if err != nil {
162
		return fmt.Errorf("failed to read line: %w", err)
163
	}
164
165
	if c.totalLines == c.currentLine {
166
		return io.EOF
167
	}
168
169
	_ = c.bar.Add(1)
170
	c.currentLine++
171
172
	if err := c.storage.InsertRow(tableName, columns, c.convertToAnyArray(records)); err != nil {
173
		return fmt.Errorf("failed to process row number %d: %w", c.currentLine, err)
174
	}
175
176
	return nil
177
}
178
179
// convertToAnyArray convert string array to any array
180
func (c *csvHandler) convertToAnyArray(records []string) []any {
181
	values := make([]any, 0, len(records))
182
	for _, r := range records {
183
		values = append(values, r)
184
	}
185
186
	return values
187
}
188
189
// openFile open file
190
func (c *csvHandler) openFiles() error {
191
	wg := new(sync.WaitGroup)
192
	wg.Add(len(c.fileInputs))
193
	errChannels := make(chan error, len(c.fileInputs))
194
195
	for _, file := range c.fileInputs {
196
		go func(wg *sync.WaitGroup, file string, errChan chan error) {
197
			defer wg.Done()
198
199
			f, err := os.Open(file)
200
			if err != nil {
201
				errChan <- fmt.Errorf("failed to open file: %w", err)
202
				return
203
			}
204
205
			c.files = append(c.files, f)
206
			errChan <- nil
207
		}(wg, file, errChannels)
208
	}
209
210
	wg.Wait()
211
	if err := <-errChannels; err != nil {
212
		return fmt.Errorf("failed to open file: %w", err)
213
	}
214
215
	return nil
216
}
217
218
// loadTotalRows load total rows in file
219
func (c *csvHandler) loadTotalRows(file string) error {
220
	r, err := os.Open(file)
221
	if err != nil {
222
		return fmt.Errorf("failed to open file %s: %w", file, err)
223
	}
224
	defer func(r *os.File) {
225
		_ = r.Close()
226
	}(r)
227
228
	buf := make([]byte, bufferMaxLength)
229
	c.totalLines = 0
230
	lineSep := []byte{'\n'}
231
232
	for {
233
		r, err := r.Read(buf)
234
		c.totalLines += bytes.Count(buf[:r], lineSep)
235
236
		switch {
237
		case err == io.EOF:
238
			return nil
239
240
		case err != nil:
241
			return fmt.Errorf("failed to totalize rows: %w", err)
242
		}
243
	}
244
}
245