Test Failed
Pull Request — master (#36)
by Frank
03:43 queued 02:03
created

pq.parseEnviron   F

Complexity

Conditions 26

Size

Total Lines 68
Code Lines 54

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 26
eloc 54
nop 1
dl 0
loc 68
rs 0
c 0
b 0
f 0

How to fix   Long Method    Complexity   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include Extract Method: moving a cohesive part of the body into a new, well-named method.

Complexity

Complex functions like pq.parseEnviron often do a lot of different things. To break such a function down, we need to identify a cohesive component within it. A common approach to find such a component is to look for variables or helpers that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1
package pq
2
3
import (
4
	"bufio"
5
	"context"
6
	"crypto/md5"
7
	"crypto/sha256"
8
	"database/sql"
9
	"database/sql/driver"
10
	"encoding/binary"
11
	"errors"
12
	"fmt"
13
	"io"
14
	"net"
15
	"os"
16
	"os/user"
17
	"path"
18
	"path/filepath"
19
	"strconv"
20
	"strings"
21
	"sync"
22
	"time"
23
	"unicode"
24
25
	"github.com/lib/pq/oid"
26
	"github.com/lib/pq/scram"
27
)
28
29
// Exported sentinel errors. Callers can compare returned errors against
// these values.
var (
	ErrNotSupported              = errors.New("pq: Unsupported command")
	ErrInFailedTransaction       = errors.New("pq: Could not complete operation in a failed transaction")
	ErrSSLNotSupported           = errors.New("pq: SSL is not enabled on the server")
	ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less")
	ErrCouldNotDetectUsername    = errors.New("pq: Could not detect default username. Please provide one explicitly")
)

// Unexported errors used by the simple-query path and by the Result
// returned for empty statements.
var (
	errUnexpectedReady = errors.New("unexpected ReadyForQuery")
	errNoRowsAffected  = errors.New("no RowsAffected available after the empty statement")
	errNoLastInsertID  = errors.New("no LastInsertId available after the empty statement")
)
41
42
// Compile time validation that our types implement the expected interfaces
43
var (
44
	_ driver.Driver = Driver{}
45
)
46
47
// Driver is the Postgres database driver.
48
type Driver struct{}
49
50
// Open opens a new connection to the database. name is a connection string.
51
// Most users should only use it through database/sql package from the standard
52
// library.
53
func (d Driver) Open(name string) (driver.Conn, error) {
54
	return Open(name)
55
}
56
57
func init() {
58
	sql.Register("postgres", &Driver{})
59
}
60
61
// parameterStatus holds the per-session server parameters the driver keeps
// track of.
type parameterStatus struct {
	// serverVersion is encoded like server_version_num; 0 means the version
	// is not known.
	serverVersion int

	// currentLocation is derived from the session's TimeZone value and is
	// left unset when that value is unavailable.
	currentLocation *time.Location
}
70
71
type transactionStatus byte
72
73
const (
74
	txnStatusIdle                transactionStatus = 'I'
75
	txnStatusIdleInTransaction   transactionStatus = 'T'
76
	txnStatusInFailedTransaction transactionStatus = 'E'
77
)
78
79
func (s transactionStatus) String() string {
80
	switch s {
81
	case txnStatusIdle:
82
		return "idle"
83
	case txnStatusIdleInTransaction:
84
		return "idle in transaction"
85
	case txnStatusInFailedTransaction:
86
		return "in a failed transaction"
87
	default:
88
		errorf("unknown transactionStatus %d", s)
89
	}
90
91
	panic("not reached")
92
}
93
94
// Dialer is the dialer interface. It can be used to obtain more control over
// how pq creates network connections.
type Dialer interface {
	Dial(network, address string) (net.Conn, error)
	DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
}

// DialerContext is the context-aware dialer interface. When a Dialer also
// implements it, pq dials through DialContext.
type DialerContext interface {
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}

// defaultDialer adapts a zero-value net.Dialer to Dialer and DialerContext.
type defaultDialer struct {
	d net.Dialer
}

// Dial connects to address on the given network.
func (d defaultDialer) Dial(network, address string) (net.Conn, error) {
	c, err := d.d.Dial(network, address)
	return c, err
}

// DialTimeout is DialContext bounded by timeout.
func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
	bounded, release := context.WithTimeout(context.Background(), timeout)
	defer release()
	return d.DialContext(bounded, network, address)
}

// DialContext connects to address on the given network, honoring ctx.
func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	c, err := d.d.DialContext(ctx, network, address)
	return c, err
}
121
122
type conn struct {
123
	c         net.Conn
124
	buf       *bufio.Reader
125
	namei     int
126
	scratch   [512]byte
127
	txnStatus transactionStatus
128
	txnFinish func()
129
130
	// Save connection arguments to use during CancelRequest.
131
	dialer Dialer
132
	opts   values
133
134
	// Cancellation key data for use with CancelRequest messages.
135
	processID int
136
	secretKey int
137
138
	parameterStatus parameterStatus
139
140
	saveMessageType   byte
141
	saveMessageBuffer []byte
142
143
	// If an error is set, this connection is bad and all public-facing
144
	// functions should return the appropriate error by calling get()
145
	// (ErrBadConn) or getForNext().
146
	err syncErr
147
148
	// If set, this connection should never use the binary format when
149
	// receiving query results from prepared statements.  Only provided for
150
	// debugging.
151
	disablePreparedBinaryResult bool
152
153
	// Whether to always send []byte parameters over as binary.  Enables single
154
	// round-trip mode for non-prepared Query calls.
155
	binaryParameters bool
156
157
	// If true this connection is in the middle of a COPY
158
	inCopy bool
159
160
	// If not nil, notices will be synchronously sent here
161
	noticeHandler func(*Error)
162
163
	// If not nil, notifications will be synchronously sent here
164
	notificationHandler func(*Notification)
165
166
	// GSSAPI context
167
	gss GSS
168
}
169
170
// syncErr is a mutex-guarded, write-once error slot.
type syncErr struct {
	err error
	sync.Mutex
}

// get reports driver.ErrBadConn once an error has been recorded, and nil
// before that.
func (e *syncErr) get() error {
	e.Lock()
	defer e.Unlock()
	if e.err == nil {
		return nil
	}
	return driver.ErrBadConn
}

// getForNext returns the recorded error itself, or nil if none. Currently
// only used by rows.Next.
func (e *syncErr) getForNext() error {
	e.Lock()
	stored := e.err
	e.Unlock()
	return stored
}

// set records err unless an error is already stored, so the first error
// wins. Passing nil is a programming error and panics.
func (e *syncErr) set(err error) {
	if err == nil {
		panic("attempt to set nil err")
	}
	e.Lock()
	defer e.Unlock()
	if e.err != nil {
		return
	}
	e.err = err
}
203
204
// Handle driver-side settings in parsed connection string.
205
func (cn *conn) handleDriverSettings(o values) (err error) {
206
	boolSetting := func(key string, val *bool) error {
207
		if value, ok := o[key]; ok {
208
			if value == "yes" {
209
				*val = true
210
			} else if value == "no" {
211
				*val = false
212
			} else {
213
				return fmt.Errorf("unrecognized value %q for %s", value, key)
214
			}
215
		}
216
		return nil
217
	}
218
219
	err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult)
220
	if err != nil {
221
		return err
222
	}
223
	return boolSetting("binary_parameters", &cn.binaryParameters)
224
}
225
226
func (cn *conn) handlePgpass(o values) {
227
	// if a password was supplied, do not process .pgpass
228
	if _, ok := o["password"]; ok {
229
		return
230
	}
231
	filename := os.Getenv("PGPASSFILE")
232
	if filename == "" {
233
		// XXX this code doesn't work on Windows where the default filename is
234
		// XXX %APPDATA%\postgresql\pgpass.conf
235
		// Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470
236
		userHome := os.Getenv("HOME")
237
		if userHome == "" {
238
			user, err := user.Current()
239
			if err != nil {
240
				return
241
			}
242
			userHome = user.HomeDir
243
		}
244
		filename = filepath.Join(userHome, ".pgpass")
245
	}
246
	fileinfo, err := os.Stat(filename)
247
	if err != nil {
248
		return
249
	}
250
	mode := fileinfo.Mode()
251
	if mode&(0x77) != 0 {
252
		// XXX should warn about incorrect .pgpass permissions as psql does
253
		return
254
	}
255
	file, err := os.Open(filename)
256
	if err != nil {
257
		return
258
	}
259
	defer file.Close()
260
	scanner := bufio.NewScanner(io.Reader(file))
261
	hostname := o["host"]
262
	ntw, _ := network(o)
263
	port := o["port"]
264
	db := o["dbname"]
265
	username := o["user"]
266
	// From: https://github.com/tg/pgpass/blob/master/reader.go
267
	getFields := func(s string) []string {
268
		fs := make([]string, 0, 5)
269
		f := make([]rune, 0, len(s))
270
271
		var esc bool
272
		for _, c := range s {
273
			switch {
274
			case esc:
275
				f = append(f, c)
276
				esc = false
277
			case c == '\\':
278
				esc = true
279
			case c == ':':
280
				fs = append(fs, string(f))
281
				f = f[:0]
282
			default:
283
				f = append(f, c)
284
			}
285
		}
286
		return append(fs, string(f))
287
	}
288
	for scanner.Scan() {
289
		line := scanner.Text()
290
		if len(line) == 0 || line[0] == '#' {
291
			continue
292
		}
293
		split := getFields(line)
294
		if len(split) != 5 {
295
			continue
296
		}
297
		if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
298
			o["password"] = split[4]
299
			return
300
		}
301
	}
302
}
303
304
func (cn *conn) writeBuf(b byte) *writeBuf {
305
	cn.scratch[0] = b
306
	return &writeBuf{
307
		buf: cn.scratch[:5],
308
		pos: 1,
309
	}
310
}
311
312
// Open opens a new connection to the database. dsn is a connection string.
313
// Most users should only use it through database/sql package from the standard
314
// library.
315
func Open(dsn string) (_ driver.Conn, err error) {
316
	return DialOpen(defaultDialer{}, dsn)
317
}
318
319
// DialOpen opens a new connection to the database using a dialer.
320
func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) {
321
	c, err := NewConnector(dsn)
322
	if err != nil {
323
		return nil, err
324
	}
325
	c.dialer = d
326
	return c.open(context.Background())
327
}
328
329
func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
330
	// Handle any panics during connection initialization.  Note that we
331
	// specifically do *not* want to use errRecover(), as that would turn any
332
	// connection errors into ErrBadConns, hiding the real error message from
333
	// the user.
334
	defer errRecoverNoErrBadConn(&err)
335
336
	// Create a new values map (copy). This makes it so maps in different
337
	// connections do not reference the same underlying data structure, so it
338
	// is safe for multiple connections to concurrently write to their opts.
339
	o := make(values)
340
	for k, v := range c.opts {
341
		o[k] = v
342
	}
343
344
	cn = &conn{
345
		opts:   o,
346
		dialer: c.dialer,
347
	}
348
	err = cn.handleDriverSettings(o)
349
	if err != nil {
350
		return nil, err
351
	}
352
	cn.handlePgpass(o)
353
354
	cn.c, err = dial(ctx, c.dialer, o)
355
	if err != nil {
356
		return nil, err
357
	}
358
359
	err = cn.ssl(o)
360
	if err != nil {
361
		if cn.c != nil {
362
			cn.c.Close()
363
		}
364
		return nil, err
365
	}
366
367
	// cn.startup panics on error. Make sure we don't leak cn.c.
368
	panicking := true
369
	defer func() {
370
		if panicking {
371
			cn.c.Close()
372
		}
373
	}()
374
375
	cn.buf = bufio.NewReader(cn.c)
376
	cn.startup(o)
377
378
	// reset the deadline, in case one was set (see dial)
379
	if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
380
		err = cn.c.SetDeadline(time.Time{})
381
	}
382
	panicking = false
383
	return cn, err
384
}
385
386
// dial opens the network connection described by o using d, preferring
// DialContext when d implements DialerContext.
//
// A positive connect_timeout (in seconds) bounds the whole connection
// establishment (see dialWithTimeout). As in libpq, a zero, negative or
// missing connect_timeout means wait indefinitely. Previously a negative
// value produced an already-expired dial that always failed. A value that
// is not an integer is an error.
func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) {
	network, address := network(o)

	if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
		seconds, err := strconv.ParseInt(timeout, 10, 0)
		if err != nil {
			return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
		}
		if seconds > 0 {
			return dialWithTimeout(ctx, d, network, address, time.Duration(seconds)*time.Second)
		}
	}
	if dctx, ok := d.(DialerContext); ok {
		return dctx.DialContext(ctx, network, address)
	}
	return d.Dial(network, address)
}

// dialWithTimeout dials with the given timeout. It then sets an absolute
// deadline, computed before dialing, on the new connection. That deadline
// also covers the initial handshake; the caller resets it once startup()
// is done. If the deadline cannot be set, the connection is closed instead
// of being leaked.
func dialWithTimeout(ctx context.Context, d Dialer, network, address string, duration time.Duration) (net.Conn, error) {
	deadline := time.Now().Add(duration)
	var (
		nc  net.Conn
		err error
	)
	if dctx, ok := d.(DialerContext); ok {
		ctx, cancel := context.WithTimeout(ctx, duration)
		defer cancel()
		nc, err = dctx.DialContext(ctx, network, address)
	} else {
		nc, err = d.DialTimeout(network, address, duration)
	}
	if err != nil {
		return nil, err
	}
	if err = nc.SetDeadline(deadline); err != nil {
		nc.Close()
		return nil, err
	}
	return nc, nil
}
421
422
// network returns the network ("unix" or "tcp") and the address to dial for
// the options in o. A host starting with "/" is the directory holding the
// server's Unix-domain socket, named .s.PGSQL.<port>.
func network(o values) (string, string) {
	host := o["host"]

	if strings.HasPrefix(host, "/") {
		sockPath := path.Join(host, ".s.PGSQL."+o["port"])
		return "unix", sockPath
	}

	return "tcp", net.JoinHostPort(host, o["port"])
}

// values maps connection option names to their string values.
type values map[string]string

// scanner implements a tokenizer for libpq-style option strings.
type scanner struct {
	s []rune
	i int
}

// newScanner returns a new scanner initialized with the option string s.
func newScanner(s string) *scanner {
	return &scanner{[]rune(s), 0}
}

// Next returns the next rune.
// It returns 0, false if the end of the text has been reached.
func (s *scanner) Next() (rune, bool) {
	if s.i >= len(s.s) {
		return 0, false
	}
	r := s.s[s.i]
	s.i++
	return r, true
}

// SkipSpaces returns the next non-whitespace rune.
// It returns 0, false if the end of the text has been reached.
func (s *scanner) SkipSpaces() (rune, bool) {
	r, ok := s.Next()
	for unicode.IsSpace(r) && ok {
		r, ok = s.Next()
	}
	return r, ok
}

// parseOpts parses the whitespace-separated key=value pairs in name and
// stores them in o, overwriting existing keys. Values may be single-quoted.
// A backslash makes the next character literal, both inside and outside
// quotes. A key whose "=" ends the string gets an empty value.
//
// The parsing code is based on conninfo_parse from libpq's fe-connect.c
func parseOpts(name string, o values) error {
	s := newScanner(name)

	for {
		var (
			keyRunes, valRunes []rune
			r                  rune
			ok                 bool
		)

		if r, ok = s.SkipSpaces(); !ok {
			break
		}

		// Scan the key
		for !unicode.IsSpace(r) && r != '=' {
			keyRunes = append(keyRunes, r)
			if r, ok = s.Next(); !ok {
				break
			}
		}

		// Skip any whitespace if we're not at the = yet
		if r != '=' {
			r, ok = s.SkipSpaces()
		}

		// The current character should be =. (The message previously ended
		// with a stray double quote.)
		if r != '=' || !ok {
			return fmt.Errorf(`missing "=" after %q in connection info string`, string(keyRunes))
		}

		// Skip any whitespace after the =
		if r, ok = s.SkipSpaces(); !ok {
			// If we reach the end here, the last value is just an empty string as per libpq.
			o[string(keyRunes)] = ""
			break
		}

		if r != '\'' {
			// Unquoted value: runs until whitespace or end of input.
			for !unicode.IsSpace(r) {
				if r == '\\' {
					if r, ok = s.Next(); !ok {
						return errors.New("missing character after backslash")
					}
				}
				valRunes = append(valRunes, r)

				if r, ok = s.Next(); !ok {
					break
				}
			}
		} else {
		quote:
			for {
				if r, ok = s.Next(); !ok {
					return errors.New("unterminated quoted string literal in connection string")
				}
				switch r {
				case '\'':
					break quote
				case '\\':
					// A backslash at the very end leaves the quote open.
					if r, ok = s.Next(); !ok {
						return errors.New("unterminated quoted string literal in connection string")
					}
					valRunes = append(valRunes, r)
				default:
					valRunes = append(valRunes, r)
				}
			}
		}

		o[string(keyRunes)] = string(valRunes)
	}

	return nil
}
545
546
func (cn *conn) isInTransaction() bool {
547
	return cn.txnStatus == txnStatusIdleInTransaction ||
548
		cn.txnStatus == txnStatusInFailedTransaction
549
}
550
551
func (cn *conn) checkIsInTransaction(intxn bool) {
552
	if cn.isInTransaction() != intxn {
553
		cn.err.set(driver.ErrBadConn)
554
		errorf("unexpected transaction status %v", cn.txnStatus)
555
	}
556
}
557
558
func (cn *conn) Begin() (_ driver.Tx, err error) {
559
	return cn.begin("")
560
}
561
562
func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
563
	if err := cn.err.get(); err != nil {
564
		return nil, err
565
	}
566
	defer cn.errRecover(&err)
567
568
	cn.checkIsInTransaction(false)
569
	_, commandTag, err := cn.simpleExec("BEGIN" + mode)
570
	if err != nil {
571
		return nil, err
572
	}
573
	if commandTag != "BEGIN" {
574
		cn.err.set(driver.ErrBadConn)
575
		return nil, fmt.Errorf("unexpected command tag %s", commandTag)
576
	}
577
	if cn.txnStatus != txnStatusIdleInTransaction {
578
		cn.err.set(driver.ErrBadConn)
579
		return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
580
	}
581
	return cn, nil
582
}
583
584
func (cn *conn) closeTxn() {
585
	if finish := cn.txnFinish; finish != nil {
586
		finish()
587
	}
588
}
589
590
func (cn *conn) Commit() (err error) {
591
	defer cn.closeTxn()
592
	if err := cn.err.get(); err != nil {
593
		return err
594
	}
595
	defer cn.errRecover(&err)
596
597
	cn.checkIsInTransaction(true)
598
	// We don't want the client to think that everything is okay if it tries
599
	// to commit a failed transaction.  However, no matter what we return,
600
	// database/sql will release this connection back into the free connection
601
	// pool so we have to abort the current transaction here.  Note that you
602
	// would get the same behaviour if you issued a COMMIT in a failed
603
	// transaction, so it's also the least surprising thing to do here.
604
	if cn.txnStatus == txnStatusInFailedTransaction {
605
		if err := cn.rollback(); err != nil {
606
			return err
607
		}
608
		return ErrInFailedTransaction
609
	}
610
611
	_, commandTag, err := cn.simpleExec("COMMIT")
612
	if err != nil {
613
		if cn.isInTransaction() {
614
			cn.err.set(driver.ErrBadConn)
615
		}
616
		return err
617
	}
618
	if commandTag != "COMMIT" {
619
		cn.err.set(driver.ErrBadConn)
620
		return fmt.Errorf("unexpected command tag %s", commandTag)
621
	}
622
	cn.checkIsInTransaction(false)
623
	return nil
624
}
625
626
func (cn *conn) Rollback() (err error) {
627
	defer cn.closeTxn()
628
	if err := cn.err.get(); err != nil {
629
		return err
630
	}
631
	defer cn.errRecover(&err)
632
	return cn.rollback()
633
}
634
635
func (cn *conn) rollback() (err error) {
636
	cn.checkIsInTransaction(true)
637
	_, commandTag, err := cn.simpleExec("ROLLBACK")
638
	if err != nil {
639
		if cn.isInTransaction() {
640
			cn.err.set(driver.ErrBadConn)
641
		}
642
		return err
643
	}
644
	if commandTag != "ROLLBACK" {
645
		return fmt.Errorf("unexpected command tag %s", commandTag)
646
	}
647
	cn.checkIsInTransaction(false)
648
	return nil
649
}
650
651
// gname returns a fresh statement name for this connection: the decimal
// form of a per-connection counter, incremented before use.
func (cn *conn) gname() string {
	cn.namei++
	return strconv.Itoa(cn.namei)
}
655
656
func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
657
	b := cn.writeBuf('Q')
658
	b.string(q)
659
	cn.send(b)
660
661
	for {
662
		t, r := cn.recv1()
663
		switch t {
664
		case 'C':
665
			res, commandTag = cn.parseComplete(r.string())
666
		case 'Z':
667
			cn.processReadyForQuery(r)
668
			if res == nil && err == nil {
669
				err = errUnexpectedReady
670
			}
671
			// done
672
			return
673
		case 'E':
674
			err = parseError(r)
675
		case 'I':
676
			res = emptyRows
677
		case 'T', 'D':
678
			// ignore any results
679
		default:
680
			cn.err.set(driver.ErrBadConn)
681
			errorf("unknown response for simple query: %q", t)
682
		}
683
	}
684
}
685
686
func (cn *conn) simpleQuery(q string) (res *rows, err error) {
687
	defer cn.errRecover(&err)
688
689
	b := cn.writeBuf('Q')
690
	b.string(q)
691
	cn.send(b)
692
693
	for {
694
		t, r := cn.recv1()
695
		switch t {
696
		case 'C', 'I':
697
			// We allow queries which don't return any results through Query as
698
			// well as Exec.  We still have to give database/sql a rows object
699
			// the user can close, though, to avoid connections from being
700
			// leaked.  A "rows" with done=true works fine for that purpose.
701
			if err != nil {
702
				cn.err.set(driver.ErrBadConn)
703
				errorf("unexpected message %q in simple query execution", t)
704
			}
705
			if res == nil {
706
				res = &rows{
707
					cn: cn,
708
				}
709
			}
710
			// Set the result and tag to the last command complete if there wasn't a
711
			// query already run. Although queries usually return from here and cede
712
			// control to Next, a query with zero results does not.
713
			if t == 'C' {
714
				res.result, res.tag = cn.parseComplete(r.string())
715
				if res.colNames != nil {
716
					return
717
				}
718
			}
719
			res.done = true
720
		case 'Z':
721
			cn.processReadyForQuery(r)
722
			// done
723
			return
724
		case 'E':
725
			res = nil
726
			err = parseError(r)
727
		case 'D':
728
			if res == nil {
729
				cn.err.set(driver.ErrBadConn)
730
				errorf("unexpected DataRow in simple query execution")
731
			}
732
			// the query didn't fail; kick off to Next
733
			cn.saveMessage(t, r)
734
			return
735
		case 'T':
736
			// res might be non-nil here if we received a previous
737
			// CommandComplete, but that's fine; just overwrite it
738
			res = &rows{cn: cn}
739
			res.rowsHeader = parsePortalRowDescribe(r)
740
741
			// To work around a bug in QueryRow in Go 1.2 and earlier, wait
742
			// until the first DataRow has been received.
743
		default:
744
			cn.err.set(driver.ErrBadConn)
745
			errorf("unknown response for simple query: %q", t)
746
		}
747
	}
748
}
749
750
type noRows struct{}
751
752
var emptyRows noRows
753
754
var _ driver.Result = noRows{}
755
756
func (noRows) LastInsertId() (int64, error) {
757
	return 0, errNoLastInsertID
758
}
759
760
func (noRows) RowsAffected() (int64, error) {
761
	return 0, errNoRowsAffected
762
}
763
764
// Decides which column formats to use for a prepared statement.  The input is
765
// an array of type oids, one element per result column.
766
func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) {
767
	if len(colTyps) == 0 {
768
		return nil, colFmtDataAllText
769
	}
770
771
	colFmts = make([]format, len(colTyps))
772
	if forceText {
773
		return colFmts, colFmtDataAllText
774
	}
775
776
	allBinary := true
777
	allText := true
778
	for i, t := range colTyps {
779
		switch t.OID {
780
		// This is the list of types to use binary mode for when receiving them
781
		// through a prepared statement.  If a type appears in this list, it
782
		// must also be implemented in binaryDecode in encode.go.
783
		case oid.T_bytea:
784
			fallthrough
785
		case oid.T_int8:
786
			fallthrough
787
		case oid.T_int4:
788
			fallthrough
789
		case oid.T_int2:
790
			fallthrough
791
		case oid.T_uuid:
792
			colFmts[i] = formatBinary
793
			allText = false
794
795
		default:
796
			allBinary = false
797
		}
798
	}
799
800
	if allBinary {
801
		return colFmts, colFmtDataAllBinary
802
	} else if allText {
803
		return colFmts, colFmtDataAllText
804
	} else {
805
		colFmtData = make([]byte, 2+len(colFmts)*2)
806
		binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
807
		for i, v := range colFmts {
808
			binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v))
809
		}
810
		return colFmts, colFmtData
811
	}
812
}
813
814
func (cn *conn) prepareTo(q, stmtName string) *stmt {
815
	st := &stmt{cn: cn, name: stmtName}
816
817
	b := cn.writeBuf('P')
818
	b.string(st.name)
819
	b.string(q)
820
	b.int16(0)
821
822
	b.next('D')
823
	b.byte('S')
824
	b.string(st.name)
825
826
	b.next('S')
827
	cn.send(b)
828
829
	cn.readParseResponse()
830
	st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
831
	st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
832
	cn.readReadyForQuery()
833
	return st
834
}
835
836
func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
837
	if err := cn.err.get(); err != nil {
838
		return nil, err
839
	}
840
	defer cn.errRecover(&err)
841
842
	if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") {
843
		s, err := cn.prepareCopyIn(q)
844
		if err == nil {
845
			cn.inCopy = true
846
		}
847
		return s, err
848
	}
849
	return cn.prepareTo(q, cn.gname()), nil
850
}
851
852
// Close sends a Terminate ('X') message and always closes the socket, even
// when sending fails or panics. A sending error takes precedence over the
// error from closing the socket.
func (cn *conn) Close() (err error) {
	// Skip cn.bad return here because we always want to close a connection.
	defer cn.errRecover(&err)

	// Error handling is panic-based (see errRecover), so the socket close
	// must live in a defer to be guaranteed to run.
	defer func() {
		if closeErr := cn.c.Close(); err == nil {
			err = closeErr
		}
	}()

	// Don't go through send(); ListenerConn relies on us not scribbling on the
	// scratch buffer of this connection.
	return cn.sendSimpleMessage('X')
}
869
870
// Implement the "Queryer" interface
871
func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
872
	return cn.query(query, args)
873
}
874
875
func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
876
	if err := cn.err.get(); err != nil {
877
		return nil, err
878
	}
879
	if cn.inCopy {
880
		return nil, errCopyInProgress
881
	}
882
	defer cn.errRecover(&err)
883
884
	// Check to see if we can use the "simpleQuery" interface, which is
885
	// *much* faster than going through prepare/exec
886
	if len(args) == 0 {
887
		return cn.simpleQuery(query)
888
	}
889
890
	if cn.binaryParameters {
891
		cn.sendBinaryModeQuery(query, args)
892
893
		cn.readParseResponse()
894
		cn.readBindResponse()
895
		rows := &rows{cn: cn}
896
		rows.rowsHeader = cn.readPortalDescribeResponse()
897
		cn.postExecuteWorkaround()
898
		return rows, nil
899
	}
900
	st := cn.prepareTo(query, "")
901
	st.exec(args)
902
	return &rows{
903
		cn:         cn,
904
		rowsHeader: st.rowsHeader,
905
	}, nil
906
}
907
908
// Implement the optional "Execer" interface for one-shot queries
909
func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
910
	if err := cn.err.get(); err != nil {
911
		return nil, err
912
	}
913
	defer cn.errRecover(&err)
914
915
	// Check to see if we can use the "simpleExec" interface, which is
916
	// *much* faster than going through prepare/exec
917
	if len(args) == 0 {
918
		// ignore commandTag, our caller doesn't care
919
		r, _, err := cn.simpleExec(query)
920
		return r, err
921
	}
922
923
	if cn.binaryParameters {
924
		cn.sendBinaryModeQuery(query, args)
925
926
		cn.readParseResponse()
927
		cn.readBindResponse()
928
		cn.readPortalDescribeResponse()
929
		cn.postExecuteWorkaround()
930
		res, _, err = cn.readExecuteResponse("Execute")
931
		return res, err
932
	}
933
	// Use the unnamed statement to defer planning until bind
934
	// time, or else value-based selectivity estimates cannot be
935
	// used.
936
	st := cn.prepareTo(query, "")
937
	r, err := st.Exec(args)
938
	if err != nil {
939
		panic(err)
940
	}
941
	return r, err
942
}
943
944
// safeRetryError wraps a write failure in which no bytes reached the
// server. send uses it to signal that retrying is presumed safe.
type safeRetryError struct {
	Err error
}

// Error returns the wrapped error's message unchanged.
func (se *safeRetryError) Error() string {
	inner := se.Err
	return inner.Error()
}
951
952
func (cn *conn) send(m *writeBuf) {
953
	n, err := cn.c.Write(m.wrap())
954
	if err != nil {
955
		if n == 0 {
956
			err = &safeRetryError{Err: err}
957
		}
958
		panic(err)
959
	}
960
}
961
962
func (cn *conn) sendStartupPacket(m *writeBuf) error {
963
	_, err := cn.c.Write((m.wrap())[1:])
964
	return err
965
}
966
967
// Send a message of type typ to the server on the other end of cn.  The
968
// message should have no payload.  This method does not use the scratch
969
// buffer.
970
func (cn *conn) sendSimpleMessage(typ byte) (err error) {
971
	_, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'})
972
	return err
973
}
974
975
// saveMessage memorizes a message and its buffer in the conn struct.
976
// recvMessage will then return these values on the next call to it.  This
977
// method is useful in cases where you have to see what the next message is
978
// going to be (e.g. to see whether it's an error or not) but you can't handle
979
// the message yourself.
980
func (cn *conn) saveMessage(typ byte, buf *readBuf) {
981
	if cn.saveMessageType != 0 {
982
		cn.err.set(driver.ErrBadConn)
983
		errorf("unexpected saveMessageType %d", cn.saveMessageType)
984
	}
985
	cn.saveMessageType = typ
986
	cn.saveMessageBuffer = *buf
987
}
988
989
// recvMessage receives any message from the backend, or returns an error if
990
// a problem occurred while reading the message.
991
func (cn *conn) recvMessage(r *readBuf) (byte, error) {
992
	// workaround for a QueryRow bug, see exec
993
	if cn.saveMessageType != 0 {
994
		t := cn.saveMessageType
995
		*r = cn.saveMessageBuffer
996
		cn.saveMessageType = 0
997
		cn.saveMessageBuffer = nil
998
		return t, nil
999
	}
1000
1001
	x := cn.scratch[:5]
1002
	_, err := io.ReadFull(cn.buf, x)
1003
	if err != nil {
1004
		return 0, err
1005
	}
1006
1007
	// read the type and length of the message that follows
1008
	t := x[0]
1009
	n := int(binary.BigEndian.Uint32(x[1:])) - 4
1010
	var y []byte
1011
	if n <= len(cn.scratch) {
1012
		y = cn.scratch[:n]
1013
	} else {
1014
		y = make([]byte, n)
1015
	}
1016
	_, err = io.ReadFull(cn.buf, y)
1017
	if err != nil {
1018
		return 0, err
1019
	}
1020
	*r = y
1021
	return t, nil
1022
}
1023
1024
// recv receives a message from the backend, but if an error happened while
1025
// reading the message or the received message was an ErrorResponse, it panics.
1026
// NoticeResponses are ignored.  This function should generally be used only
1027
// during the startup sequence.
1028
func (cn *conn) recv() (t byte, r *readBuf) {
1029
	for {
1030
		var err error
1031
		r = &readBuf{}
1032
		t, err = cn.recvMessage(r)
1033
		if err != nil {
1034
			panic(err)
1035
		}
1036
		switch t {
1037
		case 'E':
1038
			panic(parseError(r))
1039
		case 'N':
1040
			if n := cn.noticeHandler; n != nil {
1041
				n(parseError(r))
1042
			}
1043
		case 'A':
1044
			if n := cn.notificationHandler; n != nil {
1045
				n(recvNotification(r))
1046
			}
1047
		default:
1048
			return
1049
		}
1050
	}
1051
}
1052
1053
// recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by
1054
// the caller to avoid an allocation.
1055
func (cn *conn) recv1Buf(r *readBuf) byte {
1056
	for {
1057
		t, err := cn.recvMessage(r)
1058
		if err != nil {
1059
			panic(err)
1060
		}
1061
1062
		switch t {
1063
		case 'A':
1064
			if n := cn.notificationHandler; n != nil {
1065
				n(recvNotification(r))
1066
			}
1067
		case 'N':
1068
			if n := cn.noticeHandler; n != nil {
1069
				n(parseError(r))
1070
			}
1071
		case 'S':
1072
			cn.processParameterStatus(r)
1073
		default:
1074
			return t
1075
		}
1076
	}
1077
}
1078
1079
// recv1 receives a message from the backend, panicking if an error occurs
1080
// while attempting to read it.  All asynchronous messages are ignored, with
1081
// the exception of ErrorResponse.
1082
func (cn *conn) recv1() (t byte, r *readBuf) {
1083
	r = &readBuf{}
1084
	t = cn.recv1Buf(r)
1085
	return t, r
1086
}
1087
1088
// ssl negotiates TLS on cn.c if the options in o call for it. The package-level
// ssl(o) decides: a nil upgrade function means no TLS, and nothing is sent.
// Otherwise an SSLRequest packet is sent and a single reply byte is read
// directly from the socket. Any reply other than 'S' yields ErrSSLNotSupported;
// on 'S' the connection is replaced by the upgraded one.
func (cn *conn) ssl(o values) error {
	upgrade, err := ssl(o)
	switch {
	case err != nil:
		return err
	case upgrade == nil:
		// TLS not requested.
		return nil
	}

	req := cn.writeBuf(0)
	req.int32(80877103) // SSLRequest code
	if err := cn.sendStartupPacket(req); err != nil {
		return err
	}

	reply := cn.scratch[:1]
	if _, err := io.ReadFull(cn.c, reply); err != nil {
		return err
	}
	if reply[0] != 'S' {
		return ErrSSLNotSupported
	}

	cn.c, err = upgrade(cn.c)
	return err
}
1118
1119
// isDriverSetting reports whether key only configures the driver itself. Such
// keys are never forwarded to the server as run-time parameters in the
// startup packet.
func isDriverSetting(key string) bool {
	switch key {
	case "host", "port",
		"password",
		"sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline",
		"fallback_application_name",
		"connect_timeout",
		"disable_prepared_binary_result",
		"binary_parameters",
		"krbsrvname",
		"krbspn":
		return true
	}
	return false
}
1146
1147
func (cn *conn) startup(o values) {
1148
	w := cn.writeBuf(0)
1149
	w.int32(196608)
1150
	// Send the backend the name of the database we want to connect to, and the
1151
	// user we want to connect as.  Additionally, we send over any run-time
1152
	// parameters potentially included in the connection string.  If the server
1153
	// doesn't recognize any of them, it will reply with an error.
1154
	for k, v := range o {
1155
		if isDriverSetting(k) {
1156
			// skip options which can't be run-time parameters
1157
			continue
1158
		}
1159
		// The protocol requires us to supply the database name as "database"
1160
		// instead of "dbname".
1161
		if k == "dbname" {
1162
			k = "database"
1163
		}
1164
		w.string(k)
1165
		w.string(v)
1166
	}
1167
	w.string("")
1168
	if err := cn.sendStartupPacket(w); err != nil {
1169
		panic(err)
1170
	}
1171
1172
	for {
1173
		t, r := cn.recv()
1174
		switch t {
1175
		case 'K':
1176
			cn.processBackendKeyData(r)
1177
		case 'S':
1178
			cn.processParameterStatus(r)
1179
		case 'R':
1180
			cn.auth(r, o)
1181
		case 'Z':
1182
			cn.processReadyForQuery(r)
1183
			return
1184
		default:
1185
			errorf("unknown response for startup: %q", t)
1186
		}
1187
	}
1188
}
1189
1190
func (cn *conn) auth(r *readBuf, o values) {
1191
	switch code := r.int32(); code {
1192
	case 0:
1193
		// OK
1194
	case 3:
1195
		w := cn.writeBuf('p')
1196
		w.string(o["password"])
1197
		cn.send(w)
1198
1199
		t, r := cn.recv()
1200
		if t != 'R' {
1201
			errorf("unexpected password response: %q", t)
1202
		}
1203
1204
		if r.int32() != 0 {
1205
			errorf("unexpected authentication response: %q", t)
1206
		}
1207
	case 5:
1208
		s := string(r.next(4))
1209
		w := cn.writeBuf('p')
1210
		w.string("md5" + md5s(md5s(o["password"]+o["user"])+s))
1211
		cn.send(w)
1212
1213
		t, r := cn.recv()
1214
		if t != 'R' {
1215
			errorf("unexpected password response: %q", t)
1216
		}
1217
1218
		if r.int32() != 0 {
1219
			errorf("unexpected authentication response: %q", t)
1220
		}
1221
	case 7: // GSSAPI, startup
1222
		if newGss == nil {
1223
			errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)")
1224
		}
1225
		cli, err := newGss()
1226
		if err != nil {
1227
			errorf("kerberos error: %s", err.Error())
1228
		}
1229
1230
		var token []byte
1231
1232
		if spn, ok := o["krbspn"]; ok {
1233
			// Use the supplied SPN if provided..
1234
			token, err = cli.GetInitTokenFromSpn(spn)
1235
		} else {
1236
			// Allow the kerberos service name to be overridden
1237
			service := "postgres"
1238
			if val, ok := o["krbsrvname"]; ok {
1239
				service = val
1240
			}
1241
1242
			token, err = cli.GetInitToken(o["host"], service)
1243
		}
1244
1245
		if err != nil {
1246
			errorf("failed to get Kerberos ticket: %q", err)
1247
		}
1248
1249
		w := cn.writeBuf('p')
1250
		w.bytes(token)
1251
		cn.send(w)
1252
1253
		// Store for GSSAPI continue message
1254
		cn.gss = cli
1255
1256
	case 8: // GSSAPI continue
1257
1258
		if cn.gss == nil {
1259
			errorf("GSSAPI protocol error")
1260
		}
1261
1262
		b := []byte(*r)
1263
1264
		done, tokOut, err := cn.gss.Continue(b)
1265
		if err == nil && !done {
1266
			w := cn.writeBuf('p')
1267
			w.bytes(tokOut)
1268
			cn.send(w)
1269
		}
1270
1271
		// Errors fall through and read the more detailed message
1272
		// from the server..
1273
1274
	case 10:
1275
		sc := scram.NewClient(sha256.New, o["user"], o["password"])
1276
		sc.Step(nil)
1277
		if sc.Err() != nil {
1278
			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1279
		}
1280
		scOut := sc.Out()
1281
1282
		w := cn.writeBuf('p')
1283
		w.string("SCRAM-SHA-256")
1284
		w.int32(len(scOut))
1285
		w.bytes(scOut)
1286
		cn.send(w)
1287
1288
		t, r := cn.recv()
1289
		if t != 'R' {
1290
			errorf("unexpected password response: %q", t)
1291
		}
1292
1293
		if r.int32() != 11 {
1294
			errorf("unexpected authentication response: %q", t)
1295
		}
1296
1297
		nextStep := r.next(len(*r))
1298
		sc.Step(nextStep)
1299
		if sc.Err() != nil {
1300
			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1301
		}
1302
1303
		scOut = sc.Out()
1304
		w = cn.writeBuf('p')
1305
		w.bytes(scOut)
1306
		cn.send(w)
1307
1308
		t, r = cn.recv()
1309
		if t != 'R' {
1310
			errorf("unexpected password response: %q", t)
1311
		}
1312
1313
		if r.int32() != 12 {
1314
			errorf("unexpected authentication response: %q", t)
1315
		}
1316
1317
		nextStep = r.next(len(*r))
1318
		sc.Step(nextStep)
1319
		if sc.Err() != nil {
1320
			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1321
		}
1322
1323
	default:
1324
		errorf("unknown authentication response: %d", code)
1325
	}
1326
}
1327
1328
// format is a PostgreSQL wire format code for parameter and column values.
type format int

const (
	// formatText selects the textual representation.
	formatText format = 0
	// formatBinary selects the binary representation.
	formatBinary format = 1
)

var (
	// colFmtDataAllBinary is a result-column format section with a single
	// format code, 1, which then applies to every column (i.e. all binary).
	colFmtDataAllBinary = []byte{0, 1, 0, 1}

	// colFmtDataAllText is a result-column format section with zero format
	// codes, meaning every column uses text.
	colFmtDataAllText = []byte{0, 0}
)
1338
1339
type stmt struct {
1340
	cn   *conn
1341
	name string
1342
	rowsHeader
1343
	colFmtData []byte
1344
	paramTyps  []oid.Oid
1345
	closed     bool
1346
}
1347
1348
// Close implements driver.Stmt by closing the statement on the server.
//
// It is a no-op for a statement that is already closed. If the connection
// holds a sticky error, that error is returned without contacting the
// server. Otherwise it sends Close for the statement followed by Sync, and
// expects CloseComplete ('3') and then ReadyForQuery ('Z'). Any other reply
// marks the connection bad and is reported as an error.
func (st *stmt) Close() (err error) {
	if st.closed {
		return nil
	}
	cn := st.cn
	if connErr := cn.err.get(); connErr != nil {
		return connErr
	}
	defer cn.errRecover(&err)

	closeMsg := cn.writeBuf('C')
	closeMsg.byte('S') // closing a statement (not a portal)
	closeMsg.string(st.name)
	cn.send(closeMsg)
	cn.send(cn.writeBuf('S'))

	if typ, _ := cn.recv1(); typ != '3' {
		cn.err.set(driver.ErrBadConn)
		errorf("unexpected close response: %q", typ)
	}
	st.closed = true

	typ, msg := cn.recv1()
	if typ != 'Z' {
		cn.err.set(driver.ErrBadConn)
		errorf("expected ready for query, but got: %q", typ)
	}
	cn.processReadyForQuery(msg)

	return nil
}
1380
1381
// Query implements driver.Stmt. It executes the statement with arguments v
// and returns the resulting rows.
//
// On failure it returns a true nil driver.Rows. Returning st.query's nil
// *rows directly would wrap it in a non-nil interface value.
func (st *stmt) Query(v []driver.Value) (driver.Rows, error) {
	rs, err := st.query(v)
	if err != nil {
		return nil, err
	}
	return rs, nil
}
1384
1385
func (st *stmt) query(v []driver.Value) (r *rows, err error) {
1386
	if err := st.cn.err.get(); err != nil {
1387
		return nil, err
1388
	}
1389
	defer st.cn.errRecover(&err)
1390
1391
	st.exec(v)
1392
	return &rows{
1393
		cn:         st.cn,
1394
		rowsHeader: st.rowsHeader,
1395
	}, nil
1396
}
1397
1398
// Exec implements driver.Stmt. It executes the statement with arguments v,
// discards any returned rows, and reports the result of the command.
func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
	cn := st.cn
	if connErr := cn.err.get(); connErr != nil {
		return nil, connErr
	}
	defer cn.errRecover(&err)

	st.exec(v)
	res, _, err = cn.readExecuteResponse("simple query")
	return res, err
}
1408
1409
// exec binds arguments v to the prepared statement and executes it. Bind
// (unnamed portal), Execute and Sync are sent in a single write.
//
// The argument count must be below 65536 and equal to the number of
// parameter types on the statement. When cn.binaryParameters is set, the
// arguments go through sendBinaryParameters. Otherwise each argument is
// text-encoded for its parameter type, and nil is sent as length -1.
//
// Afterwards the Bind response is read and postExecuteWorkaround is applied.
// Failures are raised with errorf or panic.
func (st *stmt) exec(v []driver.Value) {
	nargs := len(v)
	if nargs >= 65536 {
		errorf("got %d parameters but PostgreSQL only supports 65535 parameters", nargs)
	}
	if nparams := len(st.paramTyps); nargs != nparams {
		errorf("got %d parameters but the statement requires %d", nargs, nparams)
	}

	cn := st.cn
	w := cn.writeBuf('B')
	w.byte(0) // unnamed portal
	w.string(st.name)

	if !cn.binaryParameters {
		w.int16(0) // no parameter format codes: every parameter is text
		w.int16(nargs)
		for i, arg := range v {
			if arg == nil {
				w.int32(-1)
				continue
			}
			encoded := encode(&cn.parameterStatus, arg, st.paramTyps[i])
			w.int32(len(encoded))
			w.bytes(encoded)
		}
	} else {
		cn.sendBinaryParameters(w, v)
	}
	w.bytes(st.colFmtData)

	w.next('E')
	w.byte(0)  // unnamed portal
	w.int32(0) // no row limit

	w.next('S')
	cn.send(w)

	cn.readBindResponse()
	cn.postExecuteWorkaround()
}
1450
1451
// NumInput implements driver.Stmt. It reports how many placeholder arguments
// the statement takes: one per recorded parameter type.
func (st *stmt) NumInput() int {
	n := len(st.paramTyps)
	return n
}
1454
1455
// parseComplete parses the "command tag" from a CommandComplete message, and
1456
// returns the number of rows affected (if applicable) and a string
1457
// identifying only the command that was executed, e.g. "ALTER TABLE".  If the
1458
// command tag could not be parsed, parseComplete panics.
1459
func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
1460
	commandsWithAffectedRows := []string{
1461
		"SELECT ",
1462
		// INSERT is handled below
1463
		"UPDATE ",
1464
		"DELETE ",
1465
		"FETCH ",
1466
		"MOVE ",
1467
		"COPY ",
1468
	}
1469
1470
	var affectedRows *string
1471
	for _, tag := range commandsWithAffectedRows {
1472
		if strings.HasPrefix(commandTag, tag) {
1473
			t := commandTag[len(tag):]
1474
			affectedRows = &t
1475
			commandTag = tag[:len(tag)-1]
1476
			break
1477
		}
1478
	}
1479
	// INSERT also includes the oid of the inserted row in its command tag.
1480
	// Oids in user tables are deprecated, and the oid is only returned when
1481
	// exactly one row is inserted, so it's unlikely to be of value to any
1482
	// real-world application and we can ignore it.
1483
	if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
1484
		parts := strings.Split(commandTag, " ")
1485
		if len(parts) != 3 {
1486
			cn.err.set(driver.ErrBadConn)
1487
			errorf("unexpected INSERT command tag %s", commandTag)
1488
		}
1489
		affectedRows = &parts[len(parts)-1]
1490
		commandTag = "INSERT"
1491
	}
1492
	// There should be no affected rows attached to the tag, just return it
1493
	if affectedRows == nil {
1494
		return driver.RowsAffected(0), commandTag
1495
	}
1496
	n, err := strconv.ParseInt(*affectedRows, 10, 64)
1497
	if err != nil {
1498
		cn.err.set(driver.ErrBadConn)
1499
		errorf("could not parse commandTag: %s", err)
1500
	}
1501
	return driver.RowsAffected(n), commandTag
1502
}
1503
1504
type rowsHeader struct {
1505
	colNames []string
1506
	colTyps  []fieldDesc
1507
	colFmts  []format
1508
}
1509
1510
type rows struct {
1511
	cn     *conn
1512
	finish func()
1513
	rowsHeader
1514
	done   bool
1515
	rb     readBuf
1516
	result driver.Result
1517
	tag    string
1518
1519
	next *rowsHeader
1520
}
1521
1522
func (rs *rows) Close() error {
1523
	if finish := rs.finish; finish != nil {
1524
		defer finish()
1525
	}
1526
	// no need to look at cn.bad as Next() will
1527
	for {
1528
		err := rs.Next(nil)
1529
		switch err {
1530
		case nil:
1531
		case io.EOF:
1532
			// rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row
1533
			// description, used with HasNextResultSet). We need to fetch messages until
1534
			// we hit a 'Z', which is done by waiting for done to be set.
1535
			if rs.done {
1536
				return nil
1537
			}
1538
		default:
1539
			return err
1540
		}
1541
	}
1542
}
1543
1544
// Columns implements driver.Rows by returning the current result set's
// column names.
func (rs *rows) Columns() []string {
	names := rs.colNames
	return names
}
1547
1548
// Result returns the result recorded from the last CommandComplete. If none
// has been seen yet, it returns emptyRows.
func (rs *rows) Result() driver.Result {
	if rs.result != nil {
		return rs.result
	}
	return emptyRows
}
1554
1555
// Tag returns the command name recorded from the last CommandComplete, or ""
// if none has been seen.
func (rs *rows) Tag() string {
	tag := rs.tag
	return tag
}
1558
1559
func (rs *rows) Next(dest []driver.Value) (err error) {
1560
	if rs.done {
1561
		return io.EOF
1562
	}
1563
1564
	conn := rs.cn
1565
	if err := conn.err.getForNext(); err != nil {
1566
		return err
1567
	}
1568
	defer conn.errRecover(&err)
1569
1570
	for {
1571
		t := conn.recv1Buf(&rs.rb)
1572
		switch t {
1573
		case 'E':
1574
			err = parseError(&rs.rb)
1575
		case 'C', 'I':
1576
			if t == 'C' {
1577
				rs.result, rs.tag = conn.parseComplete(rs.rb.string())
1578
			}
1579
			continue
1580
		case 'Z':
1581
			conn.processReadyForQuery(&rs.rb)
1582
			rs.done = true
1583
			if err != nil {
1584
				return err
1585
			}
1586
			return io.EOF
1587
		case 'D':
1588
			n := rs.rb.int16()
1589
			if err != nil {
1590
				conn.err.set(driver.ErrBadConn)
1591
				errorf("unexpected DataRow after error %s", err)
1592
			}
1593
			if n < len(dest) {
1594
				dest = dest[:n]
1595
			}
1596
			for i := range dest {
1597
				l := rs.rb.int32()
1598
				if l == -1 {
1599
					dest[i] = nil
1600
					continue
1601
				}
1602
				dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i])
1603
			}
1604
			return
1605
		case 'T':
1606
			next := parsePortalRowDescribe(&rs.rb)
1607
			rs.next = &next
1608
			return io.EOF
1609
		default:
1610
			errorf("unexpected message after execute: %q", t)
1611
		}
1612
	}
1613
}
1614
1615
// HasNextResultSet reports whether another result set has been announced
// and the query has not yet finished.
func (rs *rows) HasNextResultSet() bool {
	return !rs.done && rs.next != nil
}
1619
1620
// NextResultSet switches rs to the pending result set's column header. It
// returns io.EOF when no further result set has been announced.
func (rs *rows) NextResultSet() error {
	hdr := rs.next
	if hdr == nil {
		return io.EOF
	}
	rs.next = nil
	rs.rowsHeader = *hdr
	return nil
}
1628
1629
// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) so it
// can be used as part of an SQL statement. For example:
//
//	tblname := "my_table"
//	data := "my_data"
//	quoted := pq.QuoteIdentifier(tblname)
//	err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
//
// Any double quotes in name are escaped by doubling them. The quoted
// identifier is case sensitive when used in a query. If the input contains a
// zero byte, the result is truncated immediately before it.
func QuoteIdentifier(name string) string {
	if nul := strings.IndexByte(name, 0); nul >= 0 {
		name = name[:nul]
	}
	escaped := strings.ReplaceAll(name, `"`, `""`)
	return `"` + escaped + `"`
}
1647
1648
// QuoteLiteral quotes a 'literal' so it can be embedded in an SQL statement.
// It is typically used with DDL and other statements that do not accept
// parameters. For example:
//
//	exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z")
//	err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date))
//
// Single quotes are escaped by doubling them. If the literal contains any
// backslash ("\"), every backslash is doubled ("\\") and the result is written
// as a PostgreSQL C-style escape string: a space followed by E'...'.
//
// This mirrors libpq's PQEscapeStringInternal in src/interfaces/libpq/fe-exec.c:
// https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c
func QuoteLiteral(literal string) string {
	escaped := strings.ReplaceAll(literal, `'`, `''`)
	if !strings.Contains(escaped, `\`) {
		return `'` + escaped + `'`
	}
	return ` E'` + strings.ReplaceAll(escaped, `\`, `\\`) + `'`
}
1680
1681
// md5s returns the MD5 digest of s as a lowercase hexadecimal string.
func md5s(s string) string {
	sum := md5.Sum([]byte(s))
	return fmt.Sprintf("%x", sum)
}
1686
1687
// sendBinaryParameters appends the parameter format codes and parameter
// values of a Bind message to b.
//
// If any argument is a []byte, one format code is written per argument: 1
// (binary) for []byte values and 0 for everything else. If there are none, a
// zero count is written so every parameter defaults to text. nil arguments
// are written with length -1; all others are encoded with binaryEncode.
func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
	anyBinary := false
	for _, arg := range args {
		if _, isBytes := arg.([]byte); isBytes {
			anyBinary = true
			break
		}
	}

	if anyBinary {
		b.int16(len(args))
		for _, arg := range args {
			code := 0
			if _, isBytes := arg.([]byte); isBytes {
				code = 1
			}
			b.int16(code)
		}
	} else {
		b.int16(0)
	}

	b.int16(len(args))
	for _, arg := range args {
		if arg == nil {
			b.int32(-1)
			continue
		}
		datum := binaryEncode(&cn.parameterStatus, arg)
		b.int32(len(datum))
		b.bytes(datum)
	}
}
1721
1722
// sendBinaryModeQuery sends query and args with the extended protocol in one
// write. The messages are:
//   - Parse: unnamed statement, no declared parameter types;
//   - Bind: unnamed portal and statement, parameters via sendBinaryParameters,
//     all result columns as text;
//   - Describe of the unnamed portal;
//   - Execute;
//   - Sync.
//
// 65536 or more arguments are rejected with errorf.
func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
	if n := len(args); n >= 65536 {
		errorf("got %d parameters but PostgreSQL only supports 65535 parameters", n)
	}

	msg := cn.writeBuf('P')
	msg.byte(0) // unnamed statement
	msg.string(query)
	msg.int16(0) // no parameter types specified

	msg.next('B')
	msg.int16(0) // empty portal name and empty statement name
	cn.sendBinaryParameters(msg, args)
	msg.bytes(colFmtDataAllText)

	msg.next('D')
	msg.byte('P') // describe a portal...
	msg.byte(0)   // ...the unnamed one

	msg.next('E')
	msg.byte(0)
	msg.int32(0)

	msg.next('S')
	cn.send(msg)
}
1748
1749
func (cn *conn) processParameterStatus(r *readBuf) {
1750
	var err error
1751
1752
	param := r.string()
1753
	switch param {
1754
	case "server_version":
1755
		var major1 int
1756
		var major2 int
1757
		_, err = fmt.Sscanf(r.string(), "%d.%d", &major1, &major2)
1758
		if err == nil {
1759
			cn.parameterStatus.serverVersion = major1*10000 + major2*100
1760
		}
1761
1762
	case "TimeZone":
1763
		cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
1764
		if err != nil {
1765
			cn.parameterStatus.currentLocation = nil
1766
		}
1767
1768
	default:
1769
		// ignore
1770
	}
1771
}
1772
1773
// processReadyForQuery records the transaction status byte carried by a
// ReadyForQuery message.
func (cn *conn) processReadyForQuery(r *readBuf) {
	status := r.byte()
	cn.txnStatus = transactionStatus(status)
}
1776
1777
func (cn *conn) readReadyForQuery() {
1778
	t, r := cn.recv1()
1779
	switch t {
1780
	case 'Z':
1781
		cn.processReadyForQuery(r)
1782
		return
1783
	default:
1784
		cn.err.set(driver.ErrBadConn)
1785
		errorf("unexpected message %q; expected ReadyForQuery", t)
1786
	}
1787
}
1788
1789
// processBackendKeyData stores the backend process ID and secret key from a
// BackendKeyData message, in that wire order.
func (cn *conn) processBackendKeyData(r *readBuf) {
	pid := r.int32()
	key := r.int32()
	cn.processID, cn.secretKey = pid, key
}
1793
1794
func (cn *conn) readParseResponse() {
1795
	t, r := cn.recv1()
1796
	switch t {
1797
	case '1':
1798
		return
1799
	case 'E':
1800
		err := parseError(r)
1801
		cn.readReadyForQuery()
1802
		panic(err)
1803
	default:
1804
		cn.err.set(driver.ErrBadConn)
1805
		errorf("unexpected Parse response %q", t)
1806
	}
1807
}
1808
1809
func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) {
1810
	for {
1811
		t, r := cn.recv1()
1812
		switch t {
1813
		case 't':
1814
			nparams := r.int16()
1815
			paramTyps = make([]oid.Oid, nparams)
1816
			for i := range paramTyps {
1817
				paramTyps[i] = r.oid()
1818
			}
1819
		case 'n':
1820
			return paramTyps, nil, nil
1821
		case 'T':
1822
			colNames, colTyps = parseStatementRowDescribe(r)
1823
			return paramTyps, colNames, colTyps
1824
		case 'E':
1825
			err := parseError(r)
1826
			cn.readReadyForQuery()
1827
			panic(err)
1828
		default:
1829
			cn.err.set(driver.ErrBadConn)
1830
			errorf("unexpected Describe statement response %q", t)
1831
		}
1832
	}
1833
}
1834
1835
func (cn *conn) readPortalDescribeResponse() rowsHeader {
1836
	t, r := cn.recv1()
1837
	switch t {
1838
	case 'T':
1839
		return parsePortalRowDescribe(r)
1840
	case 'n':
1841
		return rowsHeader{}
1842
	case 'E':
1843
		err := parseError(r)
1844
		cn.readReadyForQuery()
1845
		panic(err)
1846
	default:
1847
		cn.err.set(driver.ErrBadConn)
1848
		errorf("unexpected Describe response %q", t)
1849
	}
1850
	panic("not reached")
1851
}
1852
1853
func (cn *conn) readBindResponse() {
1854
	t, r := cn.recv1()
1855
	switch t {
1856
	case '2':
1857
		return
1858
	case 'E':
1859
		err := parseError(r)
1860
		cn.readReadyForQuery()
1861
		panic(err)
1862
	default:
1863
		cn.err.set(driver.ErrBadConn)
1864
		errorf("unexpected Bind response %q", t)
1865
	}
1866
}
1867
1868
func (cn *conn) postExecuteWorkaround() {
1869
	// Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores
1870
	// any errors from rows.Next, which masks errors that happened during the
1871
	// execution of the query.  To avoid the problem in common cases, we wait
1872
	// here for one more message from the database.  If it's not an error the
1873
	// query will likely succeed (or perhaps has already, if it's a
1874
	// CommandComplete), so we push the message into the conn struct; recv1
1875
	// will return it as the next message for rows.Next or rows.Close.
1876
	// However, if it's an error, we wait until ReadyForQuery and then return
1877
	// the error to our caller.
1878
	for {
1879
		t, r := cn.recv1()
1880
		switch t {
1881
		case 'E':
1882
			err := parseError(r)
1883
			cn.readReadyForQuery()
1884
			panic(err)
1885
		case 'C', 'D', 'I':
1886
			// the query didn't fail, but we can't process this message
1887
			cn.saveMessage(t, r)
1888
			return
1889
		default:
1890
			cn.err.set(driver.ErrBadConn)
1891
			errorf("unexpected message during extended query execution: %q", t)
1892
		}
1893
	}
1894
}
1895
1896
// Only for Exec(), since we ignore the returned data
1897
func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) {
1898
	for {
1899
		t, r := cn.recv1()
1900
		switch t {
1901
		case 'C':
1902
			if err != nil {
1903
				cn.err.set(driver.ErrBadConn)
1904
				errorf("unexpected CommandComplete after error %s", err)
1905
			}
1906
			res, commandTag = cn.parseComplete(r.string())
1907
		case 'Z':
1908
			cn.processReadyForQuery(r)
1909
			if res == nil && err == nil {
1910
				err = errUnexpectedReady
1911
			}
1912
			return res, commandTag, err
1913
		case 'E':
1914
			err = parseError(r)
1915
		case 'T', 'D', 'I':
1916
			if err != nil {
1917
				cn.err.set(driver.ErrBadConn)
1918
				errorf("unexpected %q after error %s", t, err)
1919
			}
1920
			if t == 'I' {
1921
				res = emptyRows
1922
			}
1923
			// ignore any results
1924
		default:
1925
			cn.err.set(driver.ErrBadConn)
1926
			errorf("unknown %s response: %q", protocolState, t)
1927
		}
1928
	}
1929
}
1930
1931
// parseStatementRowDescribe decodes a RowDescription received in reply to a
// statement Describe into column names and type information. Column format
// codes are skipped: they are not yet known when a statement, rather than a
// portal, is described.
func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) {
	ncols := r.int16()
	colNames = make([]string, ncols)
	colTyps = make([]fieldDesc, ncols)
	for i := 0; i < ncols; i++ {
		colNames[i] = r.string()
		r.next(6) // table OID (4 bytes) and column attribute number (2 bytes)
		typ := &colTyps[i]
		typ.OID = r.oid()
		typ.Len = r.int16()
		typ.Mod = r.int32()
		r.next(2) // format code; always 0 for a statement Describe
	}
	return colNames, colTyps
}
1946
1947
func parsePortalRowDescribe(r *readBuf) rowsHeader {
1948
	n := r.int16()
1949
	colNames := make([]string, n)
1950
	colFmts := make([]format, n)
1951
	colTyps := make([]fieldDesc, n)
1952
	for i := range colNames {
1953
		colNames[i] = r.string()
1954
		r.next(6)
1955
		colTyps[i].OID = r.oid()
1956
		colTyps[i].Len = r.int16()
1957
		colTyps[i].Mod = r.int32()
1958
		colFmts[i] = format(r.int16())
1959
	}
1960
	return rowsHeader{
1961
		colNames: colNames,
1962
		colFmts:  colFmts,
1963
		colTyps:  colTyps,
1964
	}
1965
}
1966
1967
// environConnKeys maps the libpq environment variables that pq honours to the
// connection option each one sets.
var environConnKeys = map[string]string{
	"PGHOST":        "host",
	"PGPORT":        "port",
	"PGDATABASE":    "dbname",
	"PGUSER":        "user",
	"PGPASSWORD":    "password",
	"PGOPTIONS":     "options",
	"PGAPPNAME":     "application_name",
	"PGSSLMODE":     "sslmode",
	"PGSSLCERT":     "sslcert",
	"PGSSLKEY":      "sslkey",
	"PGSSLROOTCERT": "sslrootcert",
	// The driver supports the krbsrvname option (Kerberos service name), so
	// accept its libpq environment variable as well.
	"PGKRBSRVNAME":      "krbsrvname",
	"PGCONNECT_TIMEOUT": "connect_timeout",
	"PGCLIENTENCODING":  "client_encoding",
	"PGDATESTYLE":       "datestyle",
	"PGTZ":              "timezone",
	"PGGEQO":            "geqo",
}

// environUnsupported lists well-defined libpq environment variables that pq
// cannot honour. parseEnviron panics if any of them is present, rather than
// silently ignoring the setting, so they should be unset before connecting.
var environUnsupported = map[string]bool{
	"PGHOSTADDR":    true,
	"PGSERVICE":     true,
	"PGSERVICEFILE": true,
	"PGREALM":       true,
	"PGREQUIRESSL":  true,
	"PGSSLCRL":      true,
	"PGREQUIREPEER": true,
	"PGGSSLIB":      true,
	"PGSYSCONFDIR":  true,
	"PGLOCALEDIR":   true,
}

// parseEnviron mimics some of libpq's environment handling.
//
// To ease testing, it does not read os.Environ directly but accepts output in
// the same "NAME=value" form. Recognised variables are mapped to connection
// options; later entries override earlier ones. Unrelated variables, and
// entries without an "=", are ignored. A variable listed in
// environUnsupported causes a panic.
//
// Environment-set connection information is intended to take precedence
// over library defaults, but to be overridden by any explicitly passed
// information (such as in the URL or connection string).
func parseEnviron(env []string) (out map[string]string) {
	out = make(map[string]string)

	for _, kv := range env {
		name, value, hasValue := kv, "", false
		if eq := strings.IndexByte(kv, '='); eq >= 0 {
			name, value, hasValue = kv[:eq], kv[eq+1:], true
		}

		if environUnsupported[name] {
			panic(fmt.Sprintf("setting %v not supported", name))
		}
		// Entries without "=" carry no value. The original indexed past the
		// end of its SplitN result here and crashed.
		if key, ok := environConnKeys[name]; ok && hasValue {
			out[key] = value
		}
	}

	return out
}
2044
2045
// isUTF8 returns whether name is a fuzzy variation of the string "UTF-8".
2046
func isUTF8(name string) bool {
2047
	// Recognize all sorts of silly things as "UTF-8", like Postgres does
2048
	s := strings.Map(alnumLowerASCII, name)
2049
	return s == "utf8" || s == "unicode"
2050
}
2051
2052
// alnumLowerASCII is a strings.Map callback. It lowercases ASCII letters,
// keeps ASCII lowercase letters and digits, and drops every other rune by
// returning -1.
func alnumLowerASCII(ch rune) rune {
	switch {
	case ch >= 'A' && ch <= 'Z':
		return ch - 'A' + 'a'
	case ch >= 'a' && ch <= 'z', ch >= '0' && ch <= '9':
		return ch
	default:
		return -1
	}
}
2061