Passed
Pull Request — master (#23)
by Frank
06:05 queued 03:00
created

sqlx.*Row.Err   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 0
dl 0
loc 2
rs 10
c 0
b 0
f 0
1
package sqlx
2
3
import (
4
	"database/sql"
5
	"database/sql/driver"
6
	"errors"
7
	"fmt"
8
9
	"io/ioutil"
10
	"path/filepath"
11
	"reflect"
12
	"strings"
13
	"sync"
14
15
	"github.com/jmoiron/sqlx/reflectx"
16
)
17
18
// Although the NameMapper is convenient, in practice it should not
19
// be relied on except for application code.  If you are writing a library
20
// that uses sqlx, you should be aware that the name mappings you expect
21
// can be overridden by your user's application.
22
23
// NameMapper is used to map column names to struct field names.  By default,
24
// it uses strings.ToLower to lowercase struct field names.  It can be set
25
// to whatever you want, but it is encouraged to be set before sqlx is used
26
// as name-to-field mappings are cached after first use on a type.
27
var NameMapper = strings.ToLower
28
var origMapper = reflect.ValueOf(NameMapper)
29
30
// Rather than creating on init, this is created when necessary so that
31
// importers have time to customize the NameMapper.
32
var mpr *reflectx.Mapper
33
34
// mprMu protects mpr.
35
var mprMu sync.Mutex
36
37
// mapper returns a valid mapper using the configured NameMapper func.
38
func mapper() *reflectx.Mapper {
39
	mprMu.Lock()
40
	defer mprMu.Unlock()
41
42
	if mpr == nil {
43
		mpr = reflectx.NewMapperFunc("db", NameMapper)
44
	} else if origMapper != reflect.ValueOf(NameMapper) {
45
		// if NameMapper has changed, create a new mapper
46
		mpr = reflectx.NewMapperFunc("db", NameMapper)
47
		origMapper = reflect.ValueOf(NameMapper)
48
	}
49
	return mpr
50
}
51
52
// isScannable takes the reflect.Type and the actual dest value and returns
53
// whether or not it's Scannable.  Something is scannable if:
54
//   * it is not a struct
55
//   * it implements sql.Scanner
56
//   * it has no exported fields
57
func isScannable(t reflect.Type) bool {
58
	if reflect.PtrTo(t).Implements(_scannerInterface) {
59
		return true
60
	}
61
	if t.Kind() != reflect.Struct {
62
		return true
63
	}
64
65
	// it's not important that we use the right mapper for this particular object,
66
	// we're only concerned on how many exported fields this struct has
67
	m := mapper()
68
	if len(m.TypeMap(t).Index) == 0 {
69
		return true
70
	}
71
	return false
72
}
73
74
// ColScanner is an interface used by MapScan and SliceScan
75
type ColScanner interface {
76
	Columns() ([]string, error)
77
	Scan(dest ...interface{}) error
78
	Err() error
79
}
80
81
// Queryer is an interface used by Get and Select
82
type Queryer interface {
83
	Query(query string, args ...interface{}) (*sql.Rows, error)
84
	Queryx(query string, args ...interface{}) (*Rows, error)
85
	QueryRowx(query string, args ...interface{}) *Row
86
}
87
88
// Execer is an interface used by MustExec and LoadFile
89
type Execer interface {
90
	Exec(query string, args ...interface{}) (sql.Result, error)
91
}
92
93
// Binder is an interface for something which can bind queries (Tx, DB)
94
type binder interface {
95
	DriverName() string
96
	Rebind(string) string
97
	BindNamed(string, interface{}) (string, []interface{}, error)
98
}
99
100
// Ext is a union interface which can bind, query, and exec, used by
101
// NamedQuery and NamedExec.
102
type Ext interface {
103
	binder
104
	Queryer
105
	Execer
106
}
107
108
// Preparer is an interface used by Preparex.
109
type Preparer interface {
110
	Prepare(query string) (*sql.Stmt, error)
111
}
112
113
// determine if any of our extensions are unsafe
114
func isUnsafe(i interface{}) bool {
115
	switch v := i.(type) {
116
	case Row:
117
		return v.unsafe
118
	case *Row:
119
		return v.unsafe
120
	case Rows:
121
		return v.unsafe
122
	case *Rows:
123
		return v.unsafe
124
	case NamedStmt:
125
		return v.Stmt.unsafe
126
	case *NamedStmt:
127
		return v.Stmt.unsafe
128
	case Stmt:
129
		return v.unsafe
130
	case *Stmt:
131
		return v.unsafe
132
	case qStmt:
133
		return v.unsafe
134
	case *qStmt:
135
		return v.unsafe
136
	case DB:
137
		return v.unsafe
138
	case *DB:
139
		return v.unsafe
140
	case Tx:
141
		return v.unsafe
142
	case *Tx:
143
		return v.unsafe
144
	case sql.Rows, *sql.Rows:
145
		return false
146
	default:
147
		return false
148
	}
149
}
150
151
func mapperFor(i interface{}) *reflectx.Mapper {
152
	switch i.(type) {
153
	case DB:
154
		return i.(DB).Mapper
155
	case *DB:
156
		return i.(*DB).Mapper
157
	case Tx:
158
		return i.(Tx).Mapper
159
	case *Tx:
160
		return i.(*Tx).Mapper
161
	default:
162
		return mapper()
163
	}
164
}
// reflect.Type values of the sql.Scanner and driver.Valuer interfaces, for use
// with reflect.Type.Implements.
var (
	_scannerInterface = reflect.TypeOf((*sql.Scanner)(nil)).Elem()
	_valuerInterface  = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
)
168
169
// Row is a reimplementation of sql.Row in order to gain access to the underlying
170
// sql.Rows.Columns() data, necessary for StructScan.
171
type Row struct {
172
	err    error
173
	unsafe bool
174
	rows   *sql.Rows
175
	Mapper *reflectx.Mapper
176
}
177
178
// Scan is a fixed implementation of sql.Row.Scan, which does not discard the
179
// underlying error from the internal rows object if it exists.
180
func (r *Row) Scan(dest ...interface{}) error {
181
	if r.err != nil {
182
		return r.err
183
	}
184
185
	// TODO(bradfitz): for now we need to defensively clone all
186
	// []byte that the driver returned (not permitting
187
	// *RawBytes in Rows.Scan), since we're about to close
188
	// the Rows in our defer, when we return from this function.
189
	// the contract with the driver.Next(...) interface is that it
190
	// can return slices into read-only temporary memory that's
191
	// only valid until the next Scan/Close.  But the TODO is that
192
	// for a lot of drivers, this copy will be unnecessary.  We
193
	// should provide an optional interface for drivers to
194
	// implement to say, "don't worry, the []bytes that I return
195
	// from Next will not be modified again." (for instance, if
196
	// they were obtained from the network anyway) But for now we
197
	// don't care.
198
	defer r.rows.Close()
199
	for _, dp := range dest {
200
		if _, ok := dp.(*sql.RawBytes); ok {
201
			return errors.New("sql: RawBytes isn't allowed on Row.Scan")
202
		}
203
	}
204
205
	if !r.rows.Next() {
206
		if err := r.rows.Err(); err != nil {
207
			return err
208
		}
209
		return sql.ErrNoRows
210
	}
211
	err := r.rows.Scan(dest...)
212
	if err != nil {
213
		return err
214
	}
215
	// Make sure the query can be processed to completion with no errors.
216
	if err := r.rows.Close(); err != nil {
217
		return err
218
	}
219
	return nil
220
}
221
222
// Columns returns the underlying sql.Rows.Columns(), or the deferred error usually
223
// returned by Row.Scan()
224
func (r *Row) Columns() ([]string, error) {
225
	if r.err != nil {
226
		return []string{}, r.err
227
	}
228
	return r.rows.Columns()
229
}
230
231
// ColumnTypes returns the underlying sql.Rows.ColumnTypes(), or the deferred error
232
func (r *Row) ColumnTypes() ([]*sql.ColumnType, error) {
233
	if r.err != nil {
234
		return []*sql.ColumnType{}, r.err
235
	}
236
	return r.rows.ColumnTypes()
237
}
238
239
// Err returns the error encountered while scanning.
240
func (r *Row) Err() error {
241
	return r.err
242
}
243
244
// DB is a wrapper around sql.DB which keeps track of the driverName upon Open,
245
// used mostly to automatically bind named queries using the right bindvars.
246
type DB struct {
247
	*sql.DB
248
	driverName string
249
	unsafe     bool
250
	Mapper     *reflectx.Mapper
251
}
252
253
// NewDb returns a new sqlx DB wrapper for a pre-existing *sql.DB.  The
254
// driverName of the original database is required for named query support.
255
func NewDb(db *sql.DB, driverName string) *DB {
256
	return &DB{DB: db, driverName: driverName, Mapper: mapper()}
257
}
258
259
// DriverName returns the driverName passed to the Open function for this DB.
260
func (db *DB) DriverName() string {
261
	return db.driverName
262
}
263
264
// Open is the same as sql.Open, but returns an *sqlx.DB instead.
265
func Open(driverName, dataSourceName string) (*DB, error) {
266
	db, err := sql.Open(driverName, dataSourceName)
267
	if err != nil {
268
		return nil, err
269
	}
270
	return &DB{DB: db, driverName: driverName, Mapper: mapper()}, err
271
}
272
273
// MustOpen is the same as sql.Open, but returns an *sqlx.DB instead and panics on error.
274
func MustOpen(driverName, dataSourceName string) *DB {
275
	db, err := Open(driverName, dataSourceName)
276
	if err != nil {
277
		panic(err)
278
	}
279
	return db
280
}
281
282
// MapperFunc sets a new mapper for this db using the default sqlx struct tag
283
// and the provided mapper function.
284
func (db *DB) MapperFunc(mf func(string) string) {
285
	db.Mapper = reflectx.NewMapperFunc("db", mf)
286
}
287
288
// Rebind transforms a query from QUESTION to the DB driver's bindvar type.
289
func (db *DB) Rebind(query string) string {
290
	return Rebind(BindType(db.driverName), query)
291
}
292
293
// Unsafe returns a version of DB which will silently succeed to scan when
294
// columns in the SQL result have no fields in the destination struct.
295
// sqlx.Stmt and sqlx.Tx which are created from this DB will inherit its
296
// safety behavior.
297
func (db *DB) Unsafe() *DB {
298
	return &DB{DB: db.DB, driverName: db.driverName, unsafe: true, Mapper: db.Mapper}
299
}
300
301
// BindNamed binds a query using the DB driver's bindvar type.
302
func (db *DB) BindNamed(query string, arg interface{}) (string, []interface{}, error) {
303
	return bindNamedMapper(BindType(db.driverName), query, arg, db.Mapper)
304
}
305
306
// NamedQuery using this DB.
307
// Any named placeholder parameters are replaced with fields from arg.
308
func (db *DB) NamedQuery(query string, arg interface{}) (*Rows, error) {
309
	return NamedQuery(db, query, arg)
310
}
311
312
// NamedExec using this DB.
313
// Any named placeholder parameters are replaced with fields from arg.
314
func (db *DB) NamedExec(query string, arg interface{}) (sql.Result, error) {
315
	return NamedExec(db, query, arg)
316
}
317
318
// Select using this DB.
319
// Any placeholder parameters are replaced with supplied args.
320
func (db *DB) Select(dest interface{}, query string, args ...interface{}) error {
321
	return Select(db, dest, query, args...)
322
}
323
324
// Get using this DB.
325
// Any placeholder parameters are replaced with supplied args.
326
// An error is returned if the result set is empty.
327
func (db *DB) Get(dest interface{}, query string, args ...interface{}) error {
328
	return Get(db, dest, query, args...)
329
}
330
331
// MustBegin starts a transaction, and panics on error.  Returns an *sqlx.Tx instead
332
// of an *sql.Tx.
333
func (db *DB) MustBegin() *Tx {
334
	tx, err := db.Beginx()
335
	if err != nil {
336
		panic(err)
337
	}
338
	return tx
339
}
340
341
// Beginx begins a transaction and returns an *sqlx.Tx instead of an *sql.Tx.
342
func (db *DB) Beginx() (*Tx, error) {
343
	tx, err := db.DB.Begin()
344
	if err != nil {
345
		return nil, err
346
	}
347
	return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err
348
}
349
350
// Queryx queries the database and returns an *sqlx.Rows.
351
// Any placeholder parameters are replaced with supplied args.
352
func (db *DB) Queryx(query string, args ...interface{}) (*Rows, error) {
353
	r, err := db.DB.Query(query, args...)
354
	if err != nil {
355
		return nil, err
356
	}
357
	return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err
358
}
359
360
// QueryRowx queries the database and returns an *sqlx.Row.
361
// Any placeholder parameters are replaced with supplied args.
362
func (db *DB) QueryRowx(query string, args ...interface{}) *Row {
363
	rows, err := db.DB.Query(query, args...)
364
	return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper}
365
}
366
367
// MustExec (panic) runs MustExec using this database.
368
// Any placeholder parameters are replaced with supplied args.
369
func (db *DB) MustExec(query string, args ...interface{}) sql.Result {
370
	return MustExec(db, query, args...)
371
}
372
373
// Preparex returns an sqlx.Stmt instead of a sql.Stmt
374
func (db *DB) Preparex(query string) (*Stmt, error) {
375
	return Preparex(db, query)
376
}
377
378
// PrepareNamed returns an sqlx.NamedStmt
379
func (db *DB) PrepareNamed(query string) (*NamedStmt, error) {
380
	return prepareNamed(db, query)
381
}
382
383
// Tx is an sqlx wrapper around sql.Tx with extra functionality
384
type Tx struct {
385
	*sql.Tx
386
	driverName string
387
	unsafe     bool
388
	Mapper     *reflectx.Mapper
389
}
390
391
// DriverName returns the driverName used by the DB which began this transaction.
392
func (tx *Tx) DriverName() string {
393
	return tx.driverName
394
}
395
396
// Rebind a query within a transaction's bindvar type.
397
func (tx *Tx) Rebind(query string) string {
398
	return Rebind(BindType(tx.driverName), query)
399
}
400
401
// Unsafe returns a version of Tx which will silently succeed to scan when
402
// columns in the SQL result have no fields in the destination struct.
403
func (tx *Tx) Unsafe() *Tx {
404
	return &Tx{Tx: tx.Tx, driverName: tx.driverName, unsafe: true, Mapper: tx.Mapper}
405
}
406
407
// BindNamed binds a query within a transaction's bindvar type.
408
func (tx *Tx) BindNamed(query string, arg interface{}) (string, []interface{}, error) {
409
	return bindNamedMapper(BindType(tx.driverName), query, arg, tx.Mapper)
410
}
411
412
// NamedQuery within a transaction.
413
// Any named placeholder parameters are replaced with fields from arg.
414
func (tx *Tx) NamedQuery(query string, arg interface{}) (*Rows, error) {
415
	return NamedQuery(tx, query, arg)
416
}
417
418
// NamedExec a named query within a transaction.
419
// Any named placeholder parameters are replaced with fields from arg.
420
func (tx *Tx) NamedExec(query string, arg interface{}) (sql.Result, error) {
421
	return NamedExec(tx, query, arg)
422
}
423
424
// Select within a transaction.
425
// Any placeholder parameters are replaced with supplied args.
426
func (tx *Tx) Select(dest interface{}, query string, args ...interface{}) error {
427
	return Select(tx, dest, query, args...)
428
}
429
430
// Queryx within a transaction.
431
// Any placeholder parameters are replaced with supplied args.
432
func (tx *Tx) Queryx(query string, args ...interface{}) (*Rows, error) {
433
	r, err := tx.Tx.Query(query, args...)
434
	if err != nil {
435
		return nil, err
436
	}
437
	return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, err
438
}
439
440
// QueryRowx within a transaction.
441
// Any placeholder parameters are replaced with supplied args.
442
func (tx *Tx) QueryRowx(query string, args ...interface{}) *Row {
443
	rows, err := tx.Tx.Query(query, args...)
444
	return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper}
445
}
446
447
// Get within a transaction.
448
// Any placeholder parameters are replaced with supplied args.
449
// An error is returned if the result set is empty.
450
func (tx *Tx) Get(dest interface{}, query string, args ...interface{}) error {
451
	return Get(tx, dest, query, args...)
452
}
453
454
// MustExec runs MustExec within a transaction.
455
// Any placeholder parameters are replaced with supplied args.
456
func (tx *Tx) MustExec(query string, args ...interface{}) sql.Result {
457
	return MustExec(tx, query, args...)
458
}
459
460
// Preparex  a statement within a transaction.
461
func (tx *Tx) Preparex(query string) (*Stmt, error) {
462
	return Preparex(tx, query)
463
}
464
465
// Stmtx returns a version of the prepared statement which runs within a transaction.  Provided
466
// stmt can be either *sql.Stmt or *sqlx.Stmt.
467
func (tx *Tx) Stmtx(stmt interface{}) *Stmt {
468
	var s *sql.Stmt
469
	switch v := stmt.(type) {
470
	case Stmt:
471
		s = v.Stmt
472
	case *Stmt:
473
		s = v.Stmt
474
	case *sql.Stmt:
475
		s = v
476
	default:
477
		panic(fmt.Sprintf("non-statement type %v passed to Stmtx", reflect.ValueOf(stmt).Type()))
478
	}
479
	return &Stmt{Stmt: tx.Stmt(s), Mapper: tx.Mapper}
480
}
481
482
// NamedStmt returns a version of the prepared statement which runs within a transaction.
483
func (tx *Tx) NamedStmt(stmt *NamedStmt) *NamedStmt {
484
	return &NamedStmt{
485
		QueryString: stmt.QueryString,
486
		Params:      stmt.Params,
487
		Stmt:        tx.Stmtx(stmt.Stmt),
488
	}
489
}
490
491
// PrepareNamed returns an sqlx.NamedStmt
492
func (tx *Tx) PrepareNamed(query string) (*NamedStmt, error) {
493
	return prepareNamed(tx, query)
494
}
495
496
// Stmt is an sqlx wrapper around sql.Stmt with extra functionality
497
type Stmt struct {
498
	*sql.Stmt
499
	unsafe bool
500
	Mapper *reflectx.Mapper
501
}
502
503
// Unsafe returns a version of Stmt which will silently succeed to scan when
504
// columns in the SQL result have no fields in the destination struct.
505
func (s *Stmt) Unsafe() *Stmt {
506
	return &Stmt{Stmt: s.Stmt, unsafe: true, Mapper: s.Mapper}
507
}
508
509
// Select using the prepared statement.
510
// Any placeholder parameters are replaced with supplied args.
511
func (s *Stmt) Select(dest interface{}, args ...interface{}) error {
512
	return Select(&qStmt{s}, dest, "", args...)
513
}
514
515
// Get using the prepared statement.
516
// Any placeholder parameters are replaced with supplied args.
517
// An error is returned if the result set is empty.
518
func (s *Stmt) Get(dest interface{}, args ...interface{}) error {
519
	return Get(&qStmt{s}, dest, "", args...)
520
}
521
522
// MustExec (panic) using this statement.  Note that the query portion of the error
523
// output will be blank, as Stmt does not expose its query.
524
// Any placeholder parameters are replaced with supplied args.
525
func (s *Stmt) MustExec(args ...interface{}) sql.Result {
526
	return MustExec(&qStmt{s}, "", args...)
527
}
528
529
// QueryRowx using this statement.
530
// Any placeholder parameters are replaced with supplied args.
531
func (s *Stmt) QueryRowx(args ...interface{}) *Row {
532
	qs := &qStmt{s}
533
	return qs.QueryRowx("", args...)
534
}
535
536
// Queryx using this statement.
537
// Any placeholder parameters are replaced with supplied args.
538
func (s *Stmt) Queryx(args ...interface{}) (*Rows, error) {
539
	qs := &qStmt{s}
540
	return qs.Queryx("", args...)
541
}
542
543
// qStmt is an unexposed wrapper which lets you use a Stmt as a Queryer & Execer by
544
// implementing those interfaces and ignoring the `query` argument.
545
type qStmt struct{ *Stmt }
546
547
func (q *qStmt) Query(query string, args ...interface{}) (*sql.Rows, error) {
548
	return q.Stmt.Query(args...)
549
}
550
551
func (q *qStmt) Queryx(query string, args ...interface{}) (*Rows, error) {
552
	r, err := q.Stmt.Query(args...)
553
	if err != nil {
554
		return nil, err
555
	}
556
	return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err
557
}
558
559
func (q *qStmt) QueryRowx(query string, args ...interface{}) *Row {
560
	rows, err := q.Stmt.Query(args...)
561
	return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}
562
}
563
564
func (q *qStmt) Exec(query string, args ...interface{}) (sql.Result, error) {
565
	return q.Stmt.Exec(args...)
566
}
567
568
// Rows is a wrapper around sql.Rows which caches costly reflect operations
569
// during a looped StructScan
570
type Rows struct {
571
	*sql.Rows
572
	unsafe bool
573
	Mapper *reflectx.Mapper
574
	// these fields cache memory use for a rows during iteration w/ structScan
575
	started bool
576
	fields  [][]int
577
	values  []interface{}
578
}
579
580
// SliceScan using this Rows.
581
func (r *Rows) SliceScan() ([]interface{}, error) {
582
	return SliceScan(r)
583
}
584
585
// MapScan using this Rows.
586
func (r *Rows) MapScan(dest map[string]interface{}) error {
587
	return MapScan(r, dest)
588
}
589
590
// StructScan is like sql.Rows.Scan, but scans a single Row into a single Struct.
591
// Use this and iterate over Rows manually when the memory load of Select() might be
592
// prohibitive.  *Rows.StructScan caches the reflect work of matching up column
593
// positions to fields to avoid that overhead per scan, which means it is not safe
594
// to run StructScan on the same Rows instance with different struct types.
595
func (r *Rows) StructScan(dest interface{}) error {
596
	v := reflect.ValueOf(dest)
597
598
	if v.Kind() != reflect.Ptr {
599
		return errors.New("must pass a pointer, not a value, to StructScan destination")
600
	}
601
602
	v = v.Elem()
603
604
	if !r.started {
605
		columns, err := r.Columns()
606
		if err != nil {
607
			return err
608
		}
609
		m := r.Mapper
610
611
		r.fields = m.TraversalsByName(v.Type(), columns)
612
		// if we are not unsafe and are missing fields, return an error
613
		if f, err := missingFields(r.fields); err != nil && !r.unsafe {
614
			return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
615
		}
616
		r.values = make([]interface{}, len(columns))
617
		r.started = true
618
	}
619
620
	err := fieldsByTraversal(v, r.fields, r.values, true)
621
	if err != nil {
622
		return err
623
	}
624
	// scan into the struct field pointers and append to our results
625
	err = r.Scan(r.values...)
626
	if err != nil {
627
		return err
628
	}
629
	return r.Err()
630
}
631
632
// Connect to a database and verify with a ping.
633
func Connect(driverName, dataSourceName string) (*DB, error) {
634
	db, err := Open(driverName, dataSourceName)
635
	if err != nil {
636
		return nil, err
637
	}
638
	err = db.Ping()
639
	if err != nil {
640
		db.Close()
641
		return nil, err
642
	}
643
	return db, nil
644
}
645
646
// MustConnect connects to a database and panics on error.
647
func MustConnect(driverName, dataSourceName string) *DB {
648
	db, err := Connect(driverName, dataSourceName)
649
	if err != nil {
650
		panic(err)
651
	}
652
	return db
653
}
654
655
// Preparex prepares a statement.
656
func Preparex(p Preparer, query string) (*Stmt, error) {
657
	s, err := p.Prepare(query)
658
	if err != nil {
659
		return nil, err
660
	}
661
	return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err
662
}
663
664
// Select executes a query using the provided Queryer, and StructScans each row
665
// into dest, which must be a slice.  If the slice elements are scannable, then
666
// the result set must have only one column.  Otherwise, StructScan is used.
667
// The *sql.Rows are closed automatically.
668
// Any placeholder parameters are replaced with supplied args.
669
func Select(q Queryer, dest interface{}, query string, args ...interface{}) error {
670
	rows, err := q.Queryx(query, args...)
671
	if err != nil {
672
		return err
673
	}
674
	// if something happens here, we want to make sure the rows are Closed
675
	defer rows.Close()
676
	return scanAll(rows, dest, false)
677
}
678
679
// Get does a QueryRow using the provided Queryer, and scans the resulting row
680
// to dest.  If dest is scannable, the result must only have one column.  Otherwise,
681
// StructScan is used.  Get will return sql.ErrNoRows like row.Scan would.
682
// Any placeholder parameters are replaced with supplied args.
683
// An error is returned if the result set is empty.
684
func Get(q Queryer, dest interface{}, query string, args ...interface{}) error {
685
	r := q.QueryRowx(query, args...)
686
	return r.scanAny(dest, false)
687
}
688
689
// LoadFile exec's every statement in a file (as a single call to Exec).
690
// LoadFile may return a nil *sql.Result if errors are encountered locating or
691
// reading the file at path.  LoadFile reads the entire file into memory, so it
692
// is not suitable for loading large data dumps, but can be useful for initializing
693
// schemas or loading indexes.
694
//
695
// FIXME: this does not really work with multi-statement files for mattn/go-sqlite3
696
// or the go-mysql-driver/mysql drivers;  pq seems to be an exception here.  Detecting
697
// this by requiring something with DriverName() and then attempting to split the
698
// queries will be difficult to get right, and its current driver-specific behavior
699
// is deemed at least not complex in its incorrectness.
700
func LoadFile(e Execer, path string) (*sql.Result, error) {
701
	realpath, err := filepath.Abs(path)
702
	if err != nil {
703
		return nil, err
704
	}
705
	contents, err := ioutil.ReadFile(realpath)
706
	if err != nil {
707
		return nil, err
708
	}
709
	res, err := e.Exec(string(contents))
710
	return &res, err
711
}
712
713
// MustExec execs the query using e and panics if there was an error.
714
// Any placeholder parameters are replaced with supplied args.
715
func MustExec(e Execer, query string, args ...interface{}) sql.Result {
716
	res, err := e.Exec(query, args...)
717
	if err != nil {
718
		panic(err)
719
	}
720
	return res
721
}
722
723
// SliceScan using this Rows.
724
func (r *Row) SliceScan() ([]interface{}, error) {
725
	return SliceScan(r)
726
}
727
728
// MapScan using this Rows.
729
func (r *Row) MapScan(dest map[string]interface{}) error {
730
	return MapScan(r, dest)
731
}
732
733
func (r *Row) scanAny(dest interface{}, structOnly bool) error {
734
	if r.err != nil {
735
		return r.err
736
	}
737
	if r.rows == nil {
738
		r.err = sql.ErrNoRows
739
		return r.err
740
	}
741
	defer r.rows.Close()
742
743
	v := reflect.ValueOf(dest)
744
	if v.Kind() != reflect.Ptr {
745
		return errors.New("must pass a pointer, not a value, to StructScan destination")
746
	}
747
	if v.IsNil() {
748
		return errors.New("nil pointer passed to StructScan destination")
749
	}
750
751
	base := reflectx.Deref(v.Type())
752
	scannable := isScannable(base)
753
754
	if structOnly && scannable {
755
		return structOnlyError(base)
756
	}
757
758
	columns, err := r.Columns()
759
	if err != nil {
760
		return err
761
	}
762
763
	if scannable && len(columns) > 1 {
764
		return fmt.Errorf("scannable dest type %s with >1 columns (%d) in result", base.Kind(), len(columns))
765
	}
766
767
	if scannable {
768
		return r.Scan(dest)
769
	}
770
771
	m := r.Mapper
772
773
	fields := m.TraversalsByName(v.Type(), columns)
774
	// if we are not unsafe and are missing fields, return an error
775
	if f, err := missingFields(fields); err != nil && !r.unsafe {
776
		return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
777
	}
778
	values := make([]interface{}, len(columns))
779
780
	err = fieldsByTraversal(v, fields, values, true)
781
	if err != nil {
782
		return err
783
	}
784
	// scan into the struct field pointers and append to our results
785
	return r.Scan(values...)
786
}
787
788
// StructScan a single Row into dest.
789
func (r *Row) StructScan(dest interface{}) error {
790
	return r.scanAny(dest, true)
791
}
792
793
// SliceScan a row, returning a []interface{} with values similar to MapScan.
794
// This function is primarily intended for use where the number of columns
795
// is not known.  Because you can pass an []interface{} directly to Scan,
796
// it's recommended that you do that as it will not have to allocate new
797
// slices per row.
798
func SliceScan(r ColScanner) ([]interface{}, error) {
799
	// ignore r.started, since we needn't use reflect for anything.
800
	columns, err := r.Columns()
801
	if err != nil {
802
		return []interface{}{}, err
803
	}
804
805
	values := make([]interface{}, len(columns))
806
	for i := range values {
807
		values[i] = new(interface{})
808
	}
809
810
	err = r.Scan(values...)
811
812
	if err != nil {
813
		return values, err
814
	}
815
816
	for i := range columns {
817
		values[i] = *(values[i].(*interface{}))
818
	}
819
820
	return values, r.Err()
821
}
822
823
// MapScan scans a single Row into the dest map[string]interface{}.
824
// Use this to get results for SQL that might not be under your control
825
// (for instance, if you're building an interface for an SQL server that
826
// executes SQL from input).  Please do not use this as a primary interface!
827
// This will modify the map sent to it in place, so reuse the same map with
828
// care.  Columns which occur more than once in the result will overwrite
829
// each other!
830
func MapScan(r ColScanner, dest map[string]interface{}) error {
831
	// ignore r.started, since we needn't use reflect for anything.
832
	columns, err := r.Columns()
833
	if err != nil {
834
		return err
835
	}
836
837
	values := make([]interface{}, len(columns))
838
	for i := range values {
839
		values[i] = new(interface{})
840
	}
841
842
	err = r.Scan(values...)
843
	if err != nil {
844
		return err
845
	}
846
847
	for i, column := range columns {
848
		dest[column] = *(values[i].(*interface{}))
849
	}
850
851
	return r.Err()
852
}
// rowsi is the subset of the *sql.Rows method set that scanAll and StructScan
// rely on. Both *sql.Rows and *Rows (which embeds it) satisfy it.
type rowsi interface {
	Next() bool
	Scan(...interface{}) error
	Columns() ([]string, error)
	Err() error
	Close() error
}
861
862
// structOnlyError returns an error appropriate for type when a non-scannable
863
// struct is expected but something else is given
864
func structOnlyError(t reflect.Type) error {
865
	isStruct := t.Kind() == reflect.Struct
866
	isScanner := reflect.PtrTo(t).Implements(_scannerInterface)
867
	if !isStruct {
868
		return fmt.Errorf("expected %s but got %s", reflect.Struct, t.Kind())
869
	}
870
	if isScanner {
871
		return fmt.Errorf("structscan expects a struct dest but the provided struct type %s implements scanner", t.Name())
872
	}
873
	return fmt.Errorf("expected a struct, but struct %s has no exported fields", t.Name())
874
}
875
876
// scanAll scans all rows into a destination, which must be a slice of any
877
// type.  If the destination slice type is a Struct, then StructScan will be
878
// used on each row.  If the destination is some other kind of base type, then
879
// each row must only have one column which can scan into that type.  This
880
// allows you to do something like:
881
//
882
//    rows, _ := db.Query("select id from people;")
883
//    var ids []int
884
//    scanAll(rows, &ids, false)
885
//
886
// and ids will be a list of the id results.  I realize that this is a desirable
887
// interface to expose to users, but for now it will only be exposed via changes
888
// to `Get` and `Select`.  The reason that this has been implemented like this is
889
// this is the only way to not duplicate reflect work in the new API while
890
// maintaining backwards compatibility.
891
func scanAll(rows rowsi, dest interface{}, structOnly bool) error {
892
	var v, vp reflect.Value
893
894
	value := reflect.ValueOf(dest)
895
896
	// json.Unmarshal returns errors for these
897
	if value.Kind() != reflect.Ptr {
898
		return errors.New("must pass a pointer, not a value, to StructScan destination")
899
	}
900
	if value.IsNil() {
901
		return errors.New("nil pointer passed to StructScan destination")
902
	}
903
	direct := reflect.Indirect(value)
904
905
	slice, err := baseType(value.Type(), reflect.Slice)
906
	if err != nil {
907
		return err
908
	}
909
910
	isPtr := slice.Elem().Kind() == reflect.Ptr
911
	base := reflectx.Deref(slice.Elem())
912
	scannable := isScannable(base)
913
914
	if structOnly && scannable {
915
		return structOnlyError(base)
916
	}
917
918
	columns, err := rows.Columns()
919
	if err != nil {
920
		return err
921
	}
922
923
	// if it's a base type make sure it only has 1 column;  if not return an error
924
	if scannable && len(columns) > 1 {
925
		return fmt.Errorf("non-struct dest type %s with >1 columns (%d)", base.Kind(), len(columns))
926
	}
927
928
	if !scannable {
929
		var values []interface{}
930
		var m *reflectx.Mapper
931
932
		switch rows.(type) {
933
		case *Rows:
934
			m = rows.(*Rows).Mapper
935
		default:
936
			m = mapper()
937
		}
938
939
		fields := m.TraversalsByName(base, columns)
940
		// if we are not unsafe and are missing fields, return an error
941
		if f, err := missingFields(fields); err != nil && !isUnsafe(rows) {
942
			return fmt.Errorf("missing destination name %s in %T", columns[f], dest)
943
		}
944
		values = make([]interface{}, len(columns))
945
946
		for rows.Next() {
947
			// create a new struct type (which returns PtrTo) and indirect it
948
			vp = reflect.New(base)
949
			v = reflect.Indirect(vp)
950
951
			err = fieldsByTraversal(v, fields, values, true)
952
			if err != nil {
953
				return err
954
			}
955
956
			// scan into the struct field pointers and append to our results
957
			err = rows.Scan(values...)
958
			if err != nil {
959
				return err
960
			}
961
962
			if isPtr {
963
				direct.Set(reflect.Append(direct, vp))
964
			} else {
965
				direct.Set(reflect.Append(direct, v))
966
			}
967
		}
968
	} else {
969
		for rows.Next() {
970
			vp = reflect.New(base)
971
			err = rows.Scan(vp.Interface())
972
			if err != nil {
973
				return err
974
			}
975
			// append
976
			if isPtr {
977
				direct.Set(reflect.Append(direct, vp))
978
			} else {
979
				direct.Set(reflect.Append(direct, reflect.Indirect(vp)))
980
			}
981
		}
982
	}
983
984
	return rows.Err()
985
}
986
987
// FIXME: StructScan was the very first bit of API in sqlx, and now unfortunately
988
// it doesn't really feel like it's named properly.  There is an incongruency
989
// between this and the way that StructScan (which might better be ScanStruct
990
// anyway) works on a rows object.
991
992
// StructScan all rows from an sql.Rows or an sqlx.Rows into the dest slice.
993
// StructScan will scan in the entire rows result, so if you do not want to
994
// allocate structs for the entire result, use Queryx and see sqlx.Rows.StructScan.
995
// If rows is sqlx.Rows, it will use its mapper, otherwise it will use the default.
996
func StructScan(rows rowsi, dest interface{}) error {
997
	return scanAll(rows, dest, true)
998
999
}
1000
1001
// reflect helpers
1002
1003
func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
1004
	t = reflectx.Deref(t)
1005
	if t.Kind() != expected {
1006
		return nil, fmt.Errorf("expected %s but got %s", expected, t.Kind())
1007
	}
1008
	return t, nil
1009
}
1010
1011
// fieldsByName fills a values interface with fields from the passed value based
1012
// on the traversals in int.  If ptrs is true, return addresses instead of values.
1013
// We write this instead of using FieldsByName to save allocations and map lookups
1014
// when iterating over many rows.  Empty traversals will get an interface pointer.
1015
// Because of the necessity of requesting ptrs or values, it's considered a bit too
1016
// specialized for inclusion in reflectx itself.
1017
func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error {
1018
	v = reflect.Indirect(v)
1019
	if v.Kind() != reflect.Struct {
1020
		return errors.New("argument not a struct")
1021
	}
1022
1023
	for i, traversal := range traversals {
1024
		if len(traversal) == 0 {
1025
			values[i] = new(interface{})
1026
			continue
1027
		}
1028
		f := reflectx.FieldByIndexes(v, traversal)
1029
		if ptrs {
1030
			values[i] = f.Addr().Interface()
1031
		} else {
1032
			values[i] = f.Interface()
1033
		}
1034
	}
1035
	return nil
1036
}
// missingFields returns the index of the first empty traversal (a column with
// no destination field) together with a non-nil error. If every traversal is
// non-empty, it returns 0 and nil.
func missingFields(transversals [][]int) (field int, err error) {
	for i := 0; i < len(transversals); i++ {
		if len(transversals[i]) > 0 {
			continue
		}
		return i, errors.New("missing field")
	}
	return 0, nil
}