Test Failed
Push — main ( 973aa1...436074 )
by Christian
02:37
created

pq.*conn.sendBinaryModeQuery   A

Complexity

Conditions 2

Size

Total Lines 25
Code Lines 19

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 2
eloc 19
nop 2
dl 0
loc 25
rs 9.45
c 0
b 0
f 0
1
package pq
2
3
import (
4
	"bufio"
5
	"context"
6
	"crypto/md5"
7
	"crypto/sha256"
8
	"database/sql"
9
	"database/sql/driver"
10
	"encoding/binary"
11
	"errors"
12
	"fmt"
13
	"io"
14
	"net"
15
	"os"
16
	"os/user"
17
	"path"
18
	"path/filepath"
19
	"strconv"
20
	"strings"
21
	"sync"
22
	"time"
23
	"unicode"
24
25
	"github.com/lib/pq/oid"
26
	"github.com/lib/pq/scram"
27
)
28
29
// Errors that the driver returns to its users.
var (
	ErrNotSupported              = errors.New("pq: Unsupported command")
	ErrInFailedTransaction       = errors.New("pq: Could not complete operation in a failed transaction")
	ErrSSLNotSupported           = errors.New("pq: SSL is not enabled on the server")
	ErrSSLKeyUnknownOwnership    = errors.New("pq: Could not get owner information for private key, may not be properly protected")
	ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key has world access. Permissions should be u=rw,g=r (0640) if owned by root, or u=rw (0600), or less")
	ErrCouldNotDetectUsername    = errors.New("pq: Could not detect default username. Please provide one explicitly")
)

// Unexported sentinel errors.
var (
	errUnexpectedReady = errors.New("unexpected ReadyForQuery")
	errNoRowsAffected  = errors.New("no RowsAffected available after the empty statement")
	errNoLastInsertID  = errors.New("no LastInsertId available after the empty statement")
)
43
44
// Compile time validation that our types implement the expected interfaces
45
var (
46
	_ driver.Driver = Driver{}
47
)
48
49
// Driver is the Postgres database driver.
50
type Driver struct{}
51
52
// Open opens a new connection to the database. name is a connection string.
53
// Most users should only use it through database/sql package from the standard
54
// library.
55
func (d Driver) Open(name string) (driver.Conn, error) {
56
	return Open(name)
57
}
58
59
func init() {
60
	sql.Register("postgres", &Driver{})
61
}
62
63
// parameterStatus holds session parameter values tracked for a connection.
type parameterStatus struct {
	// serverVersion uses the same encoding as server_version_num; 0 means
	// the version is unavailable.
	serverVersion int

	// currentLocation is derived from the session's TimeZone value when one
	// is available.
	currentLocation *time.Location
}
72
73
type transactionStatus byte
74
75
const (
76
	txnStatusIdle                transactionStatus = 'I'
77
	txnStatusIdleInTransaction   transactionStatus = 'T'
78
	txnStatusInFailedTransaction transactionStatus = 'E'
79
)
80
81
func (s transactionStatus) String() string {
82
	switch s {
83
	case txnStatusIdle:
84
		return "idle"
85
	case txnStatusIdleInTransaction:
86
		return "idle in transaction"
87
	case txnStatusInFailedTransaction:
88
		return "in a failed transaction"
89
	default:
90
		errorf("unknown transactionStatus %d", s)
91
	}
92
93
	panic("not reached")
94
}
95
96
// Dialer is the dialer interface. It can be used to obtain more control over
// how pq creates network connections.
type Dialer interface {
	// Dial connects to address on the named network.
	Dial(network, address string) (net.Conn, error)
	// DialTimeout is like Dial but abandons the attempt after timeout.
	DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
}

// DialerContext is the context-aware dialer interface. When a Dialer also
// implements DialerContext, pq dials through DialContext so the caller's
// context governs connection establishment.
type DialerContext interface {
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
107
108
// defaultDialer adapts a net.Dialer to the Dialer and DialerContext
// interfaces.
type defaultDialer struct {
	d net.Dialer
}

// Dial connects to address on the named network using the wrapped net.Dialer.
func (d defaultDialer) Dial(network, address string) (net.Conn, error) {
	return d.d.Dial(network, address)
}

// DialTimeout connects like Dial but gives up once timeout has elapsed.
func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
	deadline := time.Now().Add(timeout)
	ctx, cancel := context.WithDeadline(context.Background(), deadline)
	defer cancel()
	return d.d.DialContext(ctx, network, address)
}

// DialContext connects while honoring ctx's cancellation and deadline.
func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	return d.d.DialContext(ctx, network, address)
}
123
124
type conn struct {
125
	c         net.Conn
126
	buf       *bufio.Reader
127
	namei     int
128
	scratch   [512]byte
129
	txnStatus transactionStatus
130
	txnFinish func()
131
132
	// Save connection arguments to use during CancelRequest.
133
	dialer Dialer
134
	opts   values
135
136
	// Cancellation key data for use with CancelRequest messages.
137
	processID int
138
	secretKey int
139
140
	parameterStatus parameterStatus
141
142
	saveMessageType   byte
143
	saveMessageBuffer []byte
144
145
	// If an error is set, this connection is bad and all public-facing
146
	// functions should return the appropriate error by calling get()
147
	// (ErrBadConn) or getForNext().
148
	err syncErr
149
150
	// If set, this connection should never use the binary format when
151
	// receiving query results from prepared statements.  Only provided for
152
	// debugging.
153
	disablePreparedBinaryResult bool
154
155
	// Whether to always send []byte parameters over as binary.  Enables single
156
	// round-trip mode for non-prepared Query calls.
157
	binaryParameters bool
158
159
	// If true this connection is in the middle of a COPY
160
	inCopy bool
161
162
	// If not nil, notices will be synchronously sent here
163
	noticeHandler func(*Error)
164
165
	// If not nil, notifications will be synchronously sent here
166
	notificationHandler func(*Notification)
167
168
	// GSSAPI context
169
	gss GSS
170
}
171
172
// syncErr records the first error that made a connection unusable. Access
// is guarded by the embedded mutex.
type syncErr struct {
	err error
	sync.Mutex
}

// get returns driver.ErrBadConn once any error has been recorded, and nil
// otherwise.
func (e *syncErr) get() error {
	if e.getForNext() == nil {
		return nil
	}
	return driver.ErrBadConn
}

// getForNext returns the recorded error itself, not ErrBadConn. Currently
// only used by rows.Next.
func (e *syncErr) getForNext() error {
	e.Lock()
	defer e.Unlock()
	return e.err
}

// set records err unless an error is already stored, so the first error
// wins. Passing nil is a programming error and panics.
func (e *syncErr) set(err error) {
	if err == nil {
		panic("attempt to set nil err")
	}
	e.Lock()
	if e.err == nil {
		e.err = err
	}
	e.Unlock()
}
205
206
// Handle driver-side settings in parsed connection string.
207
func (cn *conn) handleDriverSettings(o values) (err error) {
208
	boolSetting := func(key string, val *bool) error {
209
		if value, ok := o[key]; ok {
210
			if value == "yes" {
211
				*val = true
212
			} else if value == "no" {
213
				*val = false
214
			} else {
215
				return fmt.Errorf("unrecognized value %q for %s", value, key)
216
			}
217
		}
218
		return nil
219
	}
220
221
	err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult)
222
	if err != nil {
223
		return err
224
	}
225
	return boolSetting("binary_parameters", &cn.binaryParameters)
226
}
227
228
func (cn *conn) handlePgpass(o values) {
229
	// if a password was supplied, do not process .pgpass
230
	if _, ok := o["password"]; ok {
231
		return
232
	}
233
	filename := os.Getenv("PGPASSFILE")
234
	if filename == "" {
235
		// XXX this code doesn't work on Windows where the default filename is
236
		// XXX %APPDATA%\postgresql\pgpass.conf
237
		// Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470
238
		userHome := os.Getenv("HOME")
239
		if userHome == "" {
240
			user, err := user.Current()
241
			if err != nil {
242
				return
243
			}
244
			userHome = user.HomeDir
245
		}
246
		filename = filepath.Join(userHome, ".pgpass")
247
	}
248
	fileinfo, err := os.Stat(filename)
249
	if err != nil {
250
		return
251
	}
252
	mode := fileinfo.Mode()
253
	if mode&(0x77) != 0 {
254
		// XXX should warn about incorrect .pgpass permissions as psql does
255
		return
256
	}
257
	file, err := os.Open(filename)
258
	if err != nil {
259
		return
260
	}
261
	defer file.Close()
262
	scanner := bufio.NewScanner(io.Reader(file))
263
	hostname := o["host"]
264
	ntw, _ := network(o)
265
	port := o["port"]
266
	db := o["dbname"]
267
	username := o["user"]
268
	// From: https://github.com/tg/pgpass/blob/master/reader.go
269
	getFields := func(s string) []string {
270
		fs := make([]string, 0, 5)
271
		f := make([]rune, 0, len(s))
272
273
		var esc bool
274
		for _, c := range s {
275
			switch {
276
			case esc:
277
				f = append(f, c)
278
				esc = false
279
			case c == '\\':
280
				esc = true
281
			case c == ':':
282
				fs = append(fs, string(f))
283
				f = f[:0]
284
			default:
285
				f = append(f, c)
286
			}
287
		}
288
		return append(fs, string(f))
289
	}
290
	for scanner.Scan() {
291
		line := scanner.Text()
292
		if len(line) == 0 || line[0] == '#' {
293
			continue
294
		}
295
		split := getFields(line)
296
		if len(split) != 5 {
297
			continue
298
		}
299
		if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
300
			o["password"] = split[4]
301
			return
302
		}
303
	}
304
}
305
306
func (cn *conn) writeBuf(b byte) *writeBuf {
307
	cn.scratch[0] = b
308
	return &writeBuf{
309
		buf: cn.scratch[:5],
310
		pos: 1,
311
	}
312
}
313
314
// Open opens a new connection to the database. dsn is a connection string.
315
// Most users should only use it through database/sql package from the standard
316
// library.
317
func Open(dsn string) (_ driver.Conn, err error) {
318
	return DialOpen(defaultDialer{}, dsn)
319
}
320
321
// DialOpen opens a new connection to the database using a dialer.
322
func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) {
323
	c, err := NewConnector(dsn)
324
	if err != nil {
325
		return nil, err
326
	}
327
	c.Dialer(d)
328
	return c.open(context.Background())
329
}
330
331
func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
332
	// Handle any panics during connection initialization.  Note that we
333
	// specifically do *not* want to use errRecover(), as that would turn any
334
	// connection errors into ErrBadConns, hiding the real error message from
335
	// the user.
336
	defer errRecoverNoErrBadConn(&err)
337
338
	// Create a new values map (copy). This makes it so maps in different
339
	// connections do not reference the same underlying data structure, so it
340
	// is safe for multiple connections to concurrently write to their opts.
341
	o := make(values)
342
	for k, v := range c.opts {
343
		o[k] = v
344
	}
345
346
	cn = &conn{
347
		opts:   o,
348
		dialer: c.dialer,
349
	}
350
	err = cn.handleDriverSettings(o)
351
	if err != nil {
352
		return nil, err
353
	}
354
	cn.handlePgpass(o)
355
356
	cn.c, err = dial(ctx, c.dialer, o)
357
	if err != nil {
358
		return nil, err
359
	}
360
361
	err = cn.ssl(o)
362
	if err != nil {
363
		if cn.c != nil {
364
			cn.c.Close()
365
		}
366
		return nil, err
367
	}
368
369
	// cn.startup panics on error. Make sure we don't leak cn.c.
370
	panicking := true
371
	defer func() {
372
		if panicking {
373
			cn.c.Close()
374
		}
375
	}()
376
377
	cn.buf = bufio.NewReader(cn.c)
378
	cn.startup(o)
379
380
	// reset the deadline, in case one was set (see dial)
381
	if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
382
		err = cn.c.SetDeadline(time.Time{})
383
	}
384
	panicking = false
385
	return cn, err
386
}
387
388
// dial opens the network connection described by o using d, preferring
// d's DialContext when d implements DialerContext.
//
// A positive connect_timeout (in seconds) bounds the dial itself. It also
// sets a deadline on the returned connection so that it covers the rest of
// the handshake; the caller clears that deadline after startup. As in libpq,
// a missing, zero or negative connect_timeout means wait indefinitely. A
// non-integer value is an error.
func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) {
	network, address := network(o)

	if timeout, ok := o["connect_timeout"]; ok {
		seconds, err := strconv.ParseInt(timeout, 10, 0)
		if err != nil {
			return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
		}
		// Previously only the literal "0" disabled the timeout. Values such as
		// "-1" or "00" produced a zero/negative timeout that failed every dial.
		if seconds > 0 {
			duration := time.Duration(seconds) * time.Second
			deadline := time.Now().Add(duration)
			var conn net.Conn
			if dctx, ok := d.(DialerContext); ok {
				ctx, cancel := context.WithTimeout(ctx, duration)
				defer cancel()
				conn, err = dctx.DialContext(ctx, network, address)
			} else {
				conn, err = d.DialTimeout(network, address, duration)
			}
			if err != nil {
				return nil, err
			}
			err = conn.SetDeadline(deadline)
			return conn, err
		}
	}

	if dctx, ok := d.(DialerContext); ok {
		return dctx.DialContext(ctx, network, address)
	}
	return d.Dial(network, address)
}
423
424
// network maps the host/port options to a net.Dial network and address. A
// host starting with "/" names a directory holding a Unix socket called
// .s.PGSQL.<port>; anything else is dialed over TCP.
func network(o values) (string, string) {
	host, port := o["host"], o["port"]
	if !strings.HasPrefix(host, "/") {
		return "tcp", net.JoinHostPort(host, port)
	}
	return "unix", path.Join(host, ".s.PGSQL."+port)
}
434
435
// values maps connection option names to their string values.
type values map[string]string

// scanner implements a tokenizer for libpq-style option strings.
type scanner struct {
	s []rune
	i int
}

// newScanner returns a new scanner initialized with the option string s.
func newScanner(s string) *scanner {
	return &scanner{[]rune(s), 0}
}

// Next returns the next rune.
// It returns 0, false if the end of the text has been reached.
func (s *scanner) Next() (rune, bool) {
	if s.i >= len(s.s) {
		return 0, false
	}
	r := s.s[s.i]
	s.i++
	return r, true
}

// SkipSpaces returns the next non-whitespace rune.
// It returns 0, false if the end of the text has been reached.
func (s *scanner) SkipSpaces() (rune, bool) {
	r, ok := s.Next()
	for unicode.IsSpace(r) && ok {
		r, ok = s.Next()
	}
	return r, ok
}

// parseOpts parses the options from name and adds them to the values.
//
// name is a whitespace-separated list of key=value pairs, with optional
// whitespace around '='. A value is either a bare word, or a single-quoted
// string that may contain spaces. In both forms a backslash escapes the
// following character. A key followed by '=' at the very end of the input
// gets an empty value.
//
// The parsing code is based on conninfo_parse from libpq's fe-connect.c
func parseOpts(name string, o values) error {
	s := newScanner(name)

	for {
		var (
			keyRunes, valRunes []rune
			r                  rune
			ok                 bool
		)

		if r, ok = s.SkipSpaces(); !ok {
			break
		}

		// Scan the key
		for !unicode.IsSpace(r) && r != '=' {
			keyRunes = append(keyRunes, r)
			if r, ok = s.Next(); !ok {
				break
			}
		}

		// Skip any whitespace if we're not at the = yet
		if r != '=' {
			r, ok = s.SkipSpaces()
		}

		// The current character should be =
		if r != '=' || !ok {
			// (The format string previously ended with a stray `"`.)
			return fmt.Errorf(`missing "=" after %q in connection info string`, string(keyRunes))
		}

		// Skip any whitespace after the =
		if r, ok = s.SkipSpaces(); !ok {
			// If we reach the end here, the last value is just an empty string as per libpq.
			o[string(keyRunes)] = ""
			break
		}

		if r != '\'' {
			// Bare value: runs until whitespace or end of input.
			for !unicode.IsSpace(r) {
				if r == '\\' {
					if r, ok = s.Next(); !ok {
						return fmt.Errorf(`missing character after backslash`)
					}
				}
				valRunes = append(valRunes, r)

				if r, ok = s.Next(); !ok {
					break
				}
			}
		} else {
			// Quoted value: runs until the closing quote.
		quote:
			for {
				if r, ok = s.Next(); !ok {
					return fmt.Errorf(`unterminated quoted string literal in connection string`)
				}
				switch r {
				case '\'':
					break quote
				case '\\':
					// A backslash at the end of input leaves the literal open.
					if r, ok = s.Next(); !ok {
						return fmt.Errorf(`unterminated quoted string literal in connection string`)
					}
				}
				valRunes = append(valRunes, r)
			}
		}

		o[string(keyRunes)] = string(valRunes)
	}

	return nil
}
547
548
func (cn *conn) isInTransaction() bool {
549
	return cn.txnStatus == txnStatusIdleInTransaction ||
550
		cn.txnStatus == txnStatusInFailedTransaction
551
}
552
553
func (cn *conn) checkIsInTransaction(intxn bool) {
554
	if cn.isInTransaction() != intxn {
555
		cn.err.set(driver.ErrBadConn)
556
		errorf("unexpected transaction status %v", cn.txnStatus)
557
	}
558
}
559
560
func (cn *conn) Begin() (_ driver.Tx, err error) {
561
	return cn.begin("")
562
}
563
564
func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
565
	if err := cn.err.get(); err != nil {
566
		return nil, err
567
	}
568
	defer cn.errRecover(&err)
569
570
	cn.checkIsInTransaction(false)
571
	_, commandTag, err := cn.simpleExec("BEGIN" + mode)
572
	if err != nil {
573
		return nil, err
574
	}
575
	if commandTag != "BEGIN" {
576
		cn.err.set(driver.ErrBadConn)
577
		return nil, fmt.Errorf("unexpected command tag %s", commandTag)
578
	}
579
	if cn.txnStatus != txnStatusIdleInTransaction {
580
		cn.err.set(driver.ErrBadConn)
581
		return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
582
	}
583
	return cn, nil
584
}
585
586
// closeTxn runs the txnFinish callback, if one is registered.
func (cn *conn) closeTxn() {
	finish := cn.txnFinish
	if finish == nil {
		return
	}
	finish()
}
591
592
// Commit commits the current transaction. A transaction that has already
// failed is rolled back instead, and ErrInFailedTransaction is returned.
func (cn *conn) Commit() (err error) {
	defer cn.closeTxn()
	if badErr := cn.err.get(); badErr != nil {
		return badErr
	}
	defer cn.errRecover(&err)

	cn.checkIsInTransaction(true)

	// Don't let the client believe a failed transaction committed. database/sql
	// returns this connection to the pool whatever we report, so the failed
	// transaction must be aborted here. A COMMIT issued inside a failed
	// transaction behaves the same way, so this is also the least surprising
	// choice.
	if cn.txnStatus == txnStatusInFailedTransaction {
		if rbErr := cn.rollback(); rbErr != nil {
			return rbErr
		}
		return ErrInFailedTransaction
	}

	_, tag, err := cn.simpleExec("COMMIT")
	if err != nil {
		// Still inside a transaction after a failed COMMIT: state is unknown.
		if cn.isInTransaction() {
			cn.err.set(driver.ErrBadConn)
		}
		return err
	}
	if tag != "COMMIT" {
		cn.err.set(driver.ErrBadConn)
		return fmt.Errorf("unexpected command tag %s", tag)
	}
	cn.checkIsInTransaction(false)
	return nil
}
627
628
func (cn *conn) Rollback() (err error) {
629
	defer cn.closeTxn()
630
	if err := cn.err.get(); err != nil {
631
		return err
632
	}
633
	defer cn.errRecover(&err)
634
	return cn.rollback()
635
}
636
637
func (cn *conn) rollback() (err error) {
638
	cn.checkIsInTransaction(true)
639
	_, commandTag, err := cn.simpleExec("ROLLBACK")
640
	if err != nil {
641
		if cn.isInTransaction() {
642
			cn.err.set(driver.ErrBadConn)
643
		}
644
		return err
645
	}
646
	if commandTag != "ROLLBACK" {
647
		return fmt.Errorf("unexpected command tag %s", commandTag)
648
	}
649
	cn.checkIsInTransaction(false)
650
	return nil
651
}
652
653
// gname returns a fresh statement name for this connection: the decimal form
// of an incrementing per-connection counter ("1", "2", ...).
func (cn *conn) gname() string {
	cn.namei++
	return strconv.Itoa(cn.namei)
}
657
658
func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
659
	b := cn.writeBuf('Q')
660
	b.string(q)
661
	cn.send(b)
662
663
	for {
664
		t, r := cn.recv1()
665
		switch t {
666
		case 'C':
667
			res, commandTag = cn.parseComplete(r.string())
668
		case 'Z':
669
			cn.processReadyForQuery(r)
670
			if res == nil && err == nil {
671
				err = errUnexpectedReady
672
			}
673
			// done
674
			return
675
		case 'E':
676
			err = parseError(r)
677
		case 'I':
678
			res = emptyRows
679
		case 'T', 'D':
680
			// ignore any results
681
		default:
682
			cn.err.set(driver.ErrBadConn)
683
			errorf("unknown response for simple query: %q", t)
684
		}
685
	}
686
}
687
688
func (cn *conn) simpleQuery(q string) (res *rows, err error) {
689
	defer cn.errRecover(&err)
690
691
	b := cn.writeBuf('Q')
692
	b.string(q)
693
	cn.send(b)
694
695
	for {
696
		t, r := cn.recv1()
697
		switch t {
698
		case 'C', 'I':
699
			// We allow queries which don't return any results through Query as
700
			// well as Exec.  We still have to give database/sql a rows object
701
			// the user can close, though, to avoid connections from being
702
			// leaked.  A "rows" with done=true works fine for that purpose.
703
			if err != nil {
704
				cn.err.set(driver.ErrBadConn)
705
				errorf("unexpected message %q in simple query execution", t)
706
			}
707
			if res == nil {
708
				res = &rows{
709
					cn: cn,
710
				}
711
			}
712
			// Set the result and tag to the last command complete if there wasn't a
713
			// query already run. Although queries usually return from here and cede
714
			// control to Next, a query with zero results does not.
715
			if t == 'C' {
716
				res.result, res.tag = cn.parseComplete(r.string())
717
				if res.colNames != nil {
718
					return
719
				}
720
			}
721
			res.done = true
722
		case 'Z':
723
			cn.processReadyForQuery(r)
724
			// done
725
			return
726
		case 'E':
727
			res = nil
728
			err = parseError(r)
729
		case 'D':
730
			if res == nil {
731
				cn.err.set(driver.ErrBadConn)
732
				errorf("unexpected DataRow in simple query execution")
733
			}
734
			// the query didn't fail; kick off to Next
735
			cn.saveMessage(t, r)
736
			return
737
		case 'T':
738
			// res might be non-nil here if we received a previous
739
			// CommandComplete, but that's fine; just overwrite it
740
			res = &rows{cn: cn}
741
			res.rowsHeader = parsePortalRowDescribe(r)
742
743
			// To work around a bug in QueryRow in Go 1.2 and earlier, wait
744
			// until the first DataRow has been received.
745
		default:
746
			cn.err.set(driver.ErrBadConn)
747
			errorf("unknown response for simple query: %q", t)
748
		}
749
	}
750
}
751
752
type noRows struct{}
753
754
var emptyRows noRows
755
756
var _ driver.Result = noRows{}
757
758
func (noRows) LastInsertId() (int64, error) {
759
	return 0, errNoLastInsertID
760
}
761
762
func (noRows) RowsAffected() (int64, error) {
763
	return 0, errNoRowsAffected
764
}
765
766
// Decides which column formats to use for a prepared statement.  The input is
767
// an array of type oids, one element per result column.
768
func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) {
769
	if len(colTyps) == 0 {
770
		return nil, colFmtDataAllText
771
	}
772
773
	colFmts = make([]format, len(colTyps))
774
	if forceText {
775
		return colFmts, colFmtDataAllText
776
	}
777
778
	allBinary := true
779
	allText := true
780
	for i, t := range colTyps {
781
		switch t.OID {
782
		// This is the list of types to use binary mode for when receiving them
783
		// through a prepared statement.  If a type appears in this list, it
784
		// must also be implemented in binaryDecode in encode.go.
785
		case oid.T_bytea:
786
			fallthrough
787
		case oid.T_int8:
788
			fallthrough
789
		case oid.T_int4:
790
			fallthrough
791
		case oid.T_int2:
792
			fallthrough
793
		case oid.T_uuid:
794
			colFmts[i] = formatBinary
795
			allText = false
796
797
		default:
798
			allBinary = false
799
		}
800
	}
801
802
	if allBinary {
803
		return colFmts, colFmtDataAllBinary
804
	} else if allText {
805
		return colFmts, colFmtDataAllText
806
	} else {
807
		colFmtData = make([]byte, 2+len(colFmts)*2)
808
		binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
809
		for i, v := range colFmts {
810
			binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v))
811
		}
812
		return colFmts, colFmtData
813
	}
814
}
815
816
// prepareTo prepares q under stmtName and returns the described statement.
// It sends Parse ('P') with zero declared parameter types, Describe ('D' of
// a statement, 'S'), and Sync ('S') in a single write. It then reads the
// parse and describe responses and waits for ReadyForQuery.
func (cn *conn) prepareTo(q, stmtName string) *stmt {
	st := &stmt{cn: cn, name: stmtName}

	msg := cn.writeBuf('P')
	msg.string(st.name)
	msg.string(q)
	msg.int16(0)

	msg.next('D')
	msg.byte('S')
	msg.string(st.name)

	msg.next('S')
	cn.send(msg)

	cn.readParseResponse()
	st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
	st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
	cn.readReadyForQuery()
	return st
}
837
838
// Prepare implements driver.Conn. A statement starting with "COPY" (any
// case) becomes a COPY-in statement and puts the connection into COPY mode.
// Anything else is prepared under a freshly generated name.
func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
	if badErr := cn.err.get(); badErr != nil {
		return nil, badErr
	}
	defer cn.errRecover(&err)

	isCopy := len(q) >= 4 && strings.EqualFold(q[:4], "COPY")
	if !isCopy {
		return cn.prepareTo(q, cn.gname()), nil
	}
	copyStmt, copyErr := cn.prepareCopyIn(q)
	if copyErr == nil {
		cn.inCopy = true
	}
	return copyStmt, copyErr
}
853
854
// Close sends the 'X' message and closes the socket. It deliberately skips
// the bad-connection check, because a connection must always be closable.
// The socket close error is reported only if nothing failed earlier.
func (cn *conn) Close() (err error) {
	defer cn.errRecover(&err)

	// Errors are signalled with panics (recovered by cn.errRecover above), so
	// the socket close must live in a defer to be guaranteed to run.
	defer func() {
		if closeErr := cn.c.Close(); err == nil {
			err = closeErr
		}
	}()

	// Bypass send(): ListenerConn relies on Close not scribbling on this
	// connection's scratch buffer.
	return cn.sendSimpleMessage('X')
}
871
872
// Implement the "Queryer" interface
873
func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
874
	return cn.query(query, args)
875
}
876
877
func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
878
	if err := cn.err.get(); err != nil {
879
		return nil, err
880
	}
881
	if cn.inCopy {
882
		return nil, errCopyInProgress
883
	}
884
	defer cn.errRecover(&err)
885
886
	// Check to see if we can use the "simpleQuery" interface, which is
887
	// *much* faster than going through prepare/exec
888
	if len(args) == 0 {
889
		return cn.simpleQuery(query)
890
	}
891
892
	if cn.binaryParameters {
893
		cn.sendBinaryModeQuery(query, args)
894
895
		cn.readParseResponse()
896
		cn.readBindResponse()
897
		rows := &rows{cn: cn}
898
		rows.rowsHeader = cn.readPortalDescribeResponse()
899
		cn.postExecuteWorkaround()
900
		return rows, nil
901
	}
902
	st := cn.prepareTo(query, "")
903
	st.exec(args)
904
	return &rows{
905
		cn:         cn,
906
		rowsHeader: st.rowsHeader,
907
	}, nil
908
}
909
910
// Implement the optional "Execer" interface for one-shot queries
911
func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
912
	if err := cn.err.get(); err != nil {
913
		return nil, err
914
	}
915
	defer cn.errRecover(&err)
916
917
	// Check to see if we can use the "simpleExec" interface, which is
918
	// *much* faster than going through prepare/exec
919
	if len(args) == 0 {
920
		// ignore commandTag, our caller doesn't care
921
		r, _, err := cn.simpleExec(query)
922
		return r, err
923
	}
924
925
	if cn.binaryParameters {
926
		cn.sendBinaryModeQuery(query, args)
927
928
		cn.readParseResponse()
929
		cn.readBindResponse()
930
		cn.readPortalDescribeResponse()
931
		cn.postExecuteWorkaround()
932
		res, _, err = cn.readExecuteResponse("Execute")
933
		return res, err
934
	}
935
	// Use the unnamed statement to defer planning until bind
936
	// time, or else value-based selectivity estimates cannot be
937
	// used.
938
	st := cn.prepareTo(query, "")
939
	r, err := st.Exec(args)
940
	if err != nil {
941
		panic(err)
942
	}
943
	return r, err
944
}
945
946
// safeRetryError wraps a write error that occurred before any bytes of the
// message reached the connection (see send).
type safeRetryError struct {
	Err error
}

// Error returns the wrapped error's message unchanged.
func (e *safeRetryError) Error() string { return e.Err.Error() }
953
954
func (cn *conn) send(m *writeBuf) {
955
	n, err := cn.c.Write(m.wrap())
956
	if err != nil {
957
		if n == 0 {
958
			err = &safeRetryError{Err: err}
959
		}
960
		panic(err)
961
	}
962
}
963
964
func (cn *conn) sendStartupPacket(m *writeBuf) error {
965
	_, err := cn.c.Write((m.wrap())[1:])
966
	return err
967
}
968
969
// Send a message of type typ to the server on the other end of cn.  The
970
// message should have no payload.  This method does not use the scratch
971
// buffer.
972
func (cn *conn) sendSimpleMessage(typ byte) (err error) {
973
	_, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'})
974
	return err
975
}
976
977
// saveMessage memorizes a message and its buffer in the conn struct.
978
// recvMessage will then return these values on the next call to it.  This
979
// method is useful in cases where you have to see what the next message is
980
// going to be (e.g. to see whether it's an error or not) but you can't handle
981
// the message yourself.
982
func (cn *conn) saveMessage(typ byte, buf *readBuf) {
983
	if cn.saveMessageType != 0 {
984
		cn.err.set(driver.ErrBadConn)
985
		errorf("unexpected saveMessageType %d", cn.saveMessageType)
986
	}
987
	cn.saveMessageType = typ
988
	cn.saveMessageBuffer = *buf
989
}
990
991
// recvMessage receives any message from the backend, or returns an error if
992
// a problem occurred while reading the message.
993
func (cn *conn) recvMessage(r *readBuf) (byte, error) {
994
	// workaround for a QueryRow bug, see exec
995
	if cn.saveMessageType != 0 {
996
		t := cn.saveMessageType
997
		*r = cn.saveMessageBuffer
998
		cn.saveMessageType = 0
999
		cn.saveMessageBuffer = nil
1000
		return t, nil
1001
	}
1002
1003
	x := cn.scratch[:5]
1004
	_, err := io.ReadFull(cn.buf, x)
1005
	if err != nil {
1006
		return 0, err
1007
	}
1008
1009
	// read the type and length of the message that follows
1010
	t := x[0]
1011
	n := int(binary.BigEndian.Uint32(x[1:])) - 4
1012
	var y []byte
1013
	if n <= len(cn.scratch) {
1014
		y = cn.scratch[:n]
1015
	} else {
1016
		y = make([]byte, n)
1017
	}
1018
	_, err = io.ReadFull(cn.buf, y)
1019
	if err != nil {
1020
		return 0, err
1021
	}
1022
	*r = y
1023
	return t, nil
1024
}
1025
1026
// recv receives a message from the backend, but if an error happened while
1027
// reading the message or the received message was an ErrorResponse, it panics.
1028
// NoticeResponses are ignored.  This function should generally be used only
1029
// during the startup sequence.
1030
func (cn *conn) recv() (t byte, r *readBuf) {
1031
	for {
1032
		var err error
1033
		r = &readBuf{}
1034
		t, err = cn.recvMessage(r)
1035
		if err != nil {
1036
			panic(err)
1037
		}
1038
		switch t {
1039
		case 'E':
1040
			panic(parseError(r))
1041
		case 'N':
1042
			if n := cn.noticeHandler; n != nil {
1043
				n(parseError(r))
1044
			}
1045
		case 'A':
1046
			if n := cn.notificationHandler; n != nil {
1047
				n(recvNotification(r))
1048
			}
1049
		default:
1050
			return
1051
		}
1052
	}
1053
}
1054
1055
// recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by
1056
// the caller to avoid an allocation.
1057
func (cn *conn) recv1Buf(r *readBuf) byte {
1058
	for {
1059
		t, err := cn.recvMessage(r)
1060
		if err != nil {
1061
			panic(err)
1062
		}
1063
1064
		switch t {
1065
		case 'A':
1066
			if n := cn.notificationHandler; n != nil {
1067
				n(recvNotification(r))
1068
			}
1069
		case 'N':
1070
			if n := cn.noticeHandler; n != nil {
1071
				n(parseError(r))
1072
			}
1073
		case 'S':
1074
			cn.processParameterStatus(r)
1075
		default:
1076
			return t
1077
		}
1078
	}
1079
}
1080
1081
// recv1 receives a message from the backend, panicking if an error occurs
1082
// while attempting to read it.  All asynchronous messages are ignored, with
1083
// the exception of ErrorResponse.
1084
func (cn *conn) recv1() (t byte, r *readBuf) {
1085
	r = &readBuf{}
1086
	t = cn.recv1Buf(r)
1087
	return t, r
1088
}
1089
1090
// ssl negotiates SSL when the options in o call for it. It sends the
// SSLRequest code 80877103 and reads the server's single-byte answer. On 'S'
// it upgrades cn.c with the function returned by the package-level ssl(o);
// any other answer yields ErrSSLNotSupported.
func (cn *conn) ssl(o values) error {
	upgrade, err := ssl(o)
	if err != nil || upgrade == nil {
		// Either configuration failed or SSL is not wanted: nothing to do.
		return err
	}

	req := cn.writeBuf(0)
	req.int32(80877103)
	if err = cn.sendStartupPacket(req); err != nil {
		return err
	}

	reply := cn.scratch[:1]
	if _, err = io.ReadFull(cn.c, reply); err != nil {
		return err
	}
	if reply[0] != 'S' {
		return ErrSSLNotSupported
	}

	cn.c, err = upgrade(cn.c)
	return err
}
1120
1121
// isDriverSetting reports whether key configures only the driver itself:
// connection target, credentials, SSL, timeouts and pq-specific switches.
// Such keys are not sent to the server in the connection startup packet.
func isDriverSetting(key string) bool {
	switch key {
	case "host", "port", "password",
		"sslmode", "sslcert", "sslkey", "sslrootcert", "sslinline",
		"fallback_application_name", "connect_timeout",
		"disable_prepared_binary_result", "binary_parameters",
		"krbsrvname", "krbspn":
		return true
	}
	return false
}
1148
1149
func (cn *conn) startup(o values) {
1150
	w := cn.writeBuf(0)
1151
	w.int32(196608)
1152
	// Send the backend the name of the database we want to connect to, and the
1153
	// user we want to connect as.  Additionally, we send over any run-time
1154
	// parameters potentially included in the connection string.  If the server
1155
	// doesn't recognize any of them, it will reply with an error.
1156
	for k, v := range o {
1157
		if isDriverSetting(k) {
1158
			// skip options which can't be run-time parameters
1159
			continue
1160
		}
1161
		// The protocol requires us to supply the database name as "database"
1162
		// instead of "dbname".
1163
		if k == "dbname" {
1164
			k = "database"
1165
		}
1166
		w.string(k)
1167
		w.string(v)
1168
	}
1169
	w.string("")
1170
	if err := cn.sendStartupPacket(w); err != nil {
1171
		panic(err)
1172
	}
1173
1174
	for {
1175
		t, r := cn.recv()
1176
		switch t {
1177
		case 'K':
1178
			cn.processBackendKeyData(r)
1179
		case 'S':
1180
			cn.processParameterStatus(r)
1181
		case 'R':
1182
			cn.auth(r, o)
1183
		case 'Z':
1184
			cn.processReadyForQuery(r)
1185
			return
1186
		default:
1187
			errorf("unknown response for startup: %q", t)
1188
		}
1189
	}
1190
}
1191
1192
func (cn *conn) auth(r *readBuf, o values) {
1193
	switch code := r.int32(); code {
1194
	case 0:
1195
		// OK
1196
	case 3:
1197
		w := cn.writeBuf('p')
1198
		w.string(o["password"])
1199
		cn.send(w)
1200
1201
		t, r := cn.recv()
1202
		if t != 'R' {
1203
			errorf("unexpected password response: %q", t)
1204
		}
1205
1206
		if r.int32() != 0 {
1207
			errorf("unexpected authentication response: %q", t)
1208
		}
1209
	case 5:
1210
		s := string(r.next(4))
1211
		w := cn.writeBuf('p')
1212
		w.string("md5" + md5s(md5s(o["password"]+o["user"])+s))
1213
		cn.send(w)
1214
1215
		t, r := cn.recv()
1216
		if t != 'R' {
1217
			errorf("unexpected password response: %q", t)
1218
		}
1219
1220
		if r.int32() != 0 {
1221
			errorf("unexpected authentication response: %q", t)
1222
		}
1223
	case 7: // GSSAPI, startup
1224
		if newGss == nil {
1225
			errorf("kerberos error: no GSSAPI provider registered (import github.com/lib/pq/auth/kerberos if you need Kerberos support)")
1226
		}
1227
		cli, err := newGss()
1228
		if err != nil {
1229
			errorf("kerberos error: %s", err.Error())
1230
		}
1231
1232
		var token []byte
1233
1234
		if spn, ok := o["krbspn"]; ok {
1235
			// Use the supplied SPN if provided..
1236
			token, err = cli.GetInitTokenFromSpn(spn)
1237
		} else {
1238
			// Allow the kerberos service name to be overridden
1239
			service := "postgres"
1240
			if val, ok := o["krbsrvname"]; ok {
1241
				service = val
1242
			}
1243
1244
			token, err = cli.GetInitToken(o["host"], service)
1245
		}
1246
1247
		if err != nil {
1248
			errorf("failed to get Kerberos ticket: %q", err)
1249
		}
1250
1251
		w := cn.writeBuf('p')
1252
		w.bytes(token)
1253
		cn.send(w)
1254
1255
		// Store for GSSAPI continue message
1256
		cn.gss = cli
1257
1258
	case 8: // GSSAPI continue
1259
1260
		if cn.gss == nil {
1261
			errorf("GSSAPI protocol error")
1262
		}
1263
1264
		b := []byte(*r)
1265
1266
		done, tokOut, err := cn.gss.Continue(b)
1267
		if err == nil && !done {
1268
			w := cn.writeBuf('p')
1269
			w.bytes(tokOut)
1270
			cn.send(w)
1271
		}
1272
1273
		// Errors fall through and read the more detailed message
1274
		// from the server..
1275
1276
	case 10:
1277
		sc := scram.NewClient(sha256.New, o["user"], o["password"])
1278
		sc.Step(nil)
1279
		if sc.Err() != nil {
1280
			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1281
		}
1282
		scOut := sc.Out()
1283
1284
		w := cn.writeBuf('p')
1285
		w.string("SCRAM-SHA-256")
1286
		w.int32(len(scOut))
1287
		w.bytes(scOut)
1288
		cn.send(w)
1289
1290
		t, r := cn.recv()
1291
		if t != 'R' {
1292
			errorf("unexpected password response: %q", t)
1293
		}
1294
1295
		if r.int32() != 11 {
1296
			errorf("unexpected authentication response: %q", t)
1297
		}
1298
1299
		nextStep := r.next(len(*r))
1300
		sc.Step(nextStep)
1301
		if sc.Err() != nil {
1302
			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1303
		}
1304
1305
		scOut = sc.Out()
1306
		w = cn.writeBuf('p')
1307
		w.bytes(scOut)
1308
		cn.send(w)
1309
1310
		t, r = cn.recv()
1311
		if t != 'R' {
1312
			errorf("unexpected password response: %q", t)
1313
		}
1314
1315
		if r.int32() != 12 {
1316
			errorf("unexpected authentication response: %q", t)
1317
		}
1318
1319
		nextStep = r.next(len(*r))
1320
		sc.Step(nextStep)
1321
		if sc.Err() != nil {
1322
			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1323
		}
1324
1325
	default:
1326
		errorf("unknown authentication response: %d", code)
1327
	}
1328
}
1329
1330
// format is a PostgreSQL wire-protocol format code.
type format int

const (
	formatText   format = 0 // text representation
	formatBinary format = 1 // binary representation
)

var (
	// colFmtDataAllBinary is a pre-encoded Bind result-format section: a
	// count of one format code, followed by code 1. This makes every result
	// column binary.
	colFmtDataAllBinary = []byte{0, 1, 0, 1}

	// colFmtDataAllText is a pre-encoded Bind result-format section with a
	// count of zero format codes. This makes every result column text.
	colFmtDataAllText = []byte{0, 0}
)
1340
1341
type stmt struct {
1342
	cn   *conn
1343
	name string
1344
	rowsHeader
1345
	colFmtData []byte
1346
	paramTyps  []oid.Oid
1347
	closed     bool
1348
}
1349
1350
// Close deallocates the prepared statement on the server. It sends Close
// followed by Sync, then waits for CloseComplete and ReadyForQuery. Closing
// an already closed statement is a no-op. Any unexpected reply marks the
// connection as bad.
func (st *stmt) Close() (err error) {
	if st.closed {
		return nil
	}
	cn := st.cn
	if err := cn.err.get(); err != nil {
		return err
	}
	defer cn.errRecover(&err)

	closeMsg := cn.writeBuf('C')
	closeMsg.byte('S') // close a prepared statement (not a portal)
	closeMsg.string(st.name)
	cn.send(closeMsg)
	cn.send(cn.writeBuf('S'))

	if typ, _ := cn.recv1(); typ != '3' {
		cn.err.set(driver.ErrBadConn)
		errorf("unexpected close response: %q", typ)
	}
	st.closed = true

	typ, r := cn.recv1()
	if typ != 'Z' {
		cn.err.set(driver.ErrBadConn)
		errorf("expected ready for query, but got: %q", typ)
	}
	cn.processReadyForQuery(r)
	return nil
}
1382
1383
// Query implements driver.Stmt: it runs the statement with arguments v and
// returns the resulting rows.
func (st *stmt) Query(v []driver.Value) (driver.Rows, error) {
	res, err := st.query(v)
	return res, err
}
1386
1387
func (st *stmt) query(v []driver.Value) (r *rows, err error) {
1388
	if err := st.cn.err.get(); err != nil {
1389
		return nil, err
1390
	}
1391
	defer st.cn.errRecover(&err)
1392
1393
	st.exec(v)
1394
	return &rows{
1395
		cn:         st.cn,
1396
		rowsHeader: st.rowsHeader,
1397
	}, nil
1398
}
1399
1400
// Exec implements driver.Stmt: it runs the statement with arguments v,
// discards any returned rows and reports the command's result.
func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
	cn := st.cn
	if err := cn.err.get(); err != nil {
		return nil, err
	}
	defer cn.errRecover(&err)

	st.exec(v)
	res, _, err = cn.readExecuteResponse("simple query")
	return res, err
}
1410
1411
// exec sends Bind (unnamed portal, this statement, arguments v), Execute and
// Sync in a single write. It then checks the Bind response and inspects the
// first execution message via postExecuteWorkaround. It panics if v has
// 65536 or more entries, or does not match the statement's parameter count.
func (st *stmt) exec(v []driver.Value) {
	nargs := len(v)
	if nargs >= 65536 {
		errorf("got %d parameters but PostgreSQL only supports 65535 parameters", nargs)
	}
	if want := len(st.paramTyps); nargs != want {
		errorf("got %d parameters but the statement requires %d", nargs, want)
	}

	cn := st.cn
	msg := cn.writeBuf('B')
	msg.byte(0) // unnamed portal
	msg.string(st.name)

	if cn.binaryParameters {
		cn.sendBinaryParameters(msg, v)
	} else {
		msg.int16(0) // no parameter format codes: everything is text
		msg.int16(nargs)
		for i, arg := range v {
			if arg == nil {
				msg.int32(-1) // SQL NULL
				continue
			}
			enc := encode(&cn.parameterStatus, arg, st.paramTyps[i])
			msg.int32(len(enc))
			msg.bytes(enc)
		}
	}
	msg.bytes(st.colFmtData)

	msg.next('E')
	msg.byte(0)  // unnamed portal
	msg.int32(0) // no row limit

	msg.next('S')
	cn.send(msg)

	cn.readBindResponse()
	cn.postExecuteWorkaround()
}
1452
1453
// NumInput implements driver.Stmt by reporting how many placeholder
// parameters the statement takes.
func (st *stmt) NumInput() int {
	n := len(st.paramTyps)
	return n
}
1456
1457
// parseComplete parses the "command tag" from a CommandComplete message, and
1458
// returns the number of rows affected (if applicable) and a string
1459
// identifying only the command that was executed, e.g. "ALTER TABLE".  If the
1460
// command tag could not be parsed, parseComplete panics.
1461
func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
1462
	commandsWithAffectedRows := []string{
1463
		"SELECT ",
1464
		// INSERT is handled below
1465
		"UPDATE ",
1466
		"DELETE ",
1467
		"FETCH ",
1468
		"MOVE ",
1469
		"COPY ",
1470
	}
1471
1472
	var affectedRows *string
1473
	for _, tag := range commandsWithAffectedRows {
1474
		if strings.HasPrefix(commandTag, tag) {
1475
			t := commandTag[len(tag):]
1476
			affectedRows = &t
1477
			commandTag = tag[:len(tag)-1]
1478
			break
1479
		}
1480
	}
1481
	// INSERT also includes the oid of the inserted row in its command tag.
1482
	// Oids in user tables are deprecated, and the oid is only returned when
1483
	// exactly one row is inserted, so it's unlikely to be of value to any
1484
	// real-world application and we can ignore it.
1485
	if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
1486
		parts := strings.Split(commandTag, " ")
1487
		if len(parts) != 3 {
1488
			cn.err.set(driver.ErrBadConn)
1489
			errorf("unexpected INSERT command tag %s", commandTag)
1490
		}
1491
		affectedRows = &parts[len(parts)-1]
1492
		commandTag = "INSERT"
1493
	}
1494
	// There should be no affected rows attached to the tag, just return it
1495
	if affectedRows == nil {
1496
		return driver.RowsAffected(0), commandTag
1497
	}
1498
	n, err := strconv.ParseInt(*affectedRows, 10, 64)
1499
	if err != nil {
1500
		cn.err.set(driver.ErrBadConn)
1501
		errorf("could not parse commandTag: %s", err)
1502
	}
1503
	return driver.RowsAffected(n), commandTag
1504
}
1505
1506
type rowsHeader struct {
1507
	colNames []string
1508
	colTyps  []fieldDesc
1509
	colFmts  []format
1510
}
1511
1512
type rows struct {
1513
	cn     *conn
1514
	finish func()
1515
	rowsHeader
1516
	done   bool
1517
	rb     readBuf
1518
	result driver.Result
1519
	tag    string
1520
1521
	next *rowsHeader
1522
}
1523
1524
func (rs *rows) Close() error {
1525
	if finish := rs.finish; finish != nil {
1526
		defer finish()
1527
	}
1528
	// no need to look at cn.bad as Next() will
1529
	for {
1530
		err := rs.Next(nil)
1531
		switch err {
1532
		case nil:
1533
		case io.EOF:
1534
			// rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row
1535
			// description, used with HasNextResultSet). We need to fetch messages until
1536
			// we hit a 'Z', which is done by waiting for done to be set.
1537
			if rs.done {
1538
				return nil
1539
			}
1540
		default:
1541
			return err
1542
		}
1543
	}
1544
}
1545
1546
// Columns implements driver.Rows by returning the current result set's
// column names.
func (rs *rows) Columns() []string {
	return rs.rowsHeader.colNames
}
1549
1550
// Result returns the result parsed from the most recent CommandComplete
// message, or emptyRows if none has been parsed yet.
func (rs *rows) Result() driver.Result {
	if res := rs.result; res != nil {
		return res
	}
	return emptyRows
}
1556
1557
// Tag returns the command name taken from the most recent CommandComplete
// message (e.g. "SELECT"), or "" if no such message has been read yet.
func (rs *rows) Tag() string {
	tag := rs.tag
	return tag
}
1560
1561
func (rs *rows) Next(dest []driver.Value) (err error) {
1562
	if rs.done {
1563
		return io.EOF
1564
	}
1565
1566
	conn := rs.cn
1567
	if err := conn.err.getForNext(); err != nil {
1568
		return err
1569
	}
1570
	defer conn.errRecover(&err)
1571
1572
	for {
1573
		t := conn.recv1Buf(&rs.rb)
1574
		switch t {
1575
		case 'E':
1576
			err = parseError(&rs.rb)
1577
		case 'C', 'I':
1578
			if t == 'C' {
1579
				rs.result, rs.tag = conn.parseComplete(rs.rb.string())
1580
			}
1581
			continue
1582
		case 'Z':
1583
			conn.processReadyForQuery(&rs.rb)
1584
			rs.done = true
1585
			if err != nil {
1586
				return err
1587
			}
1588
			return io.EOF
1589
		case 'D':
1590
			n := rs.rb.int16()
1591
			if err != nil {
1592
				conn.err.set(driver.ErrBadConn)
1593
				errorf("unexpected DataRow after error %s", err)
1594
			}
1595
			if n < len(dest) {
1596
				dest = dest[:n]
1597
			}
1598
			for i := range dest {
1599
				l := rs.rb.int32()
1600
				if l == -1 {
1601
					dest[i] = nil
1602
					continue
1603
				}
1604
				dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i])
1605
			}
1606
			return
1607
		case 'T':
1608
			next := parsePortalRowDescribe(&rs.rb)
1609
			rs.next = &next
1610
			return io.EOF
1611
		default:
1612
			errorf("unexpected message after execute: %q", t)
1613
		}
1614
	}
1615
}
1616
1617
// HasNextResultSet reports whether a further result set has been announced
// and the query has not yet reached ReadyForQuery.
func (rs *rows) HasNextResultSet() bool {
	return !rs.done && rs.next != nil
}
1621
1622
// NextResultSet switches to the column metadata of the announced next
// result set. It returns io.EOF if no further result set is pending.
func (rs *rows) NextResultSet() error {
	pending := rs.next
	if pending == nil {
		return io.EOF
	}
	rs.rowsHeader, rs.next = *pending, nil
	return nil
}
1630
1631
// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) so
// it can be used as part of an SQL statement. For example:
//
//	tblname := "my_table"
//	data := "my_data"
//	quoted := pq.QuoteIdentifier(tblname)
//	_, err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
//
// Any double quotes in name are escaped by doubling them. The quoted
// identifier is case sensitive when used in a query. If the input string
// contains a zero byte, the result is truncated immediately before it.
func QuoteIdentifier(name string) string {
	// A NUL byte cannot appear in a query string, so drop it and everything
	// after it.
	if end := strings.IndexByte(name, 0); end >= 0 {
		name = name[:end]
	}
	return `"` + strings.ReplaceAll(name, `"`, `""`) + `"`
}
1649
1650
// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass a
// literal to DDL and other statements that do not accept parameters) so it
// can be used as part of an SQL statement. For example:
//
//	expDate := pq.QuoteLiteral("2023-01-05 15:00:00Z")
//	_, err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", expDate))
//
// Any single quotes in literal are escaped by doubling them. If literal
// contains backslashes (i.e. "\"), each one is replaced by two backslashes
// (i.e. "\\"). In that case the C-style escape-string prefix that PostgreSQL
// provides ('E') is prepended, preceded by a space.
func QuoteLiteral(literal string) string {
	// This follows libpq's algorithm for quoting literals, implemented in
	// PQEscapeStringInternal in src/interfaces/libpq/fe-exec.c:
	// https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c
	literal = strings.ReplaceAll(literal, `'`, `''`)
	if !strings.Contains(literal, `\`) {
		return `'` + literal + `'`
	}
	// Backslashes need an E'' escape string. Like libpq, a space is added
	// before the E.
	return ` E'` + strings.ReplaceAll(literal, `\`, `\\`) + `'`
}
1682
1683
// md5s returns the lowercase hexadecimal MD5 digest of s.
func md5s(s string) string {
	sum := md5.Sum([]byte(s))
	return fmt.Sprintf("%x", sum[:])
}
1688
1689
// sendBinaryParameters appends the parameter section of a Bind message to b:
// the parameter format codes, followed by each value.
//
// Only []byte arguments are flagged as binary (format 1). When there are
// none, an empty format-code list is written, so every parameter defaults to
// text. nil arguments are sent as SQL NULL (length -1); every other argument
// is encoded with binaryEncode.
func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
	isBinary := func(v driver.Value) bool {
		_, ok := v.([]byte)
		return ok
	}

	anyBinary := false
	for _, arg := range args {
		if isBinary(arg) {
			anyBinary = true
			break
		}
	}

	if anyBinary {
		b.int16(len(args))
		for _, arg := range args {
			if isBinary(arg) {
				b.int16(1)
			} else {
				b.int16(0)
			}
		}
	} else {
		b.int16(0)
	}

	b.int16(len(args))
	for _, arg := range args {
		if arg == nil {
			b.int32(-1)
			continue
		}
		datum := binaryEncode(&cn.parameterStatus, arg)
		b.int32(len(datum))
		b.bytes(datum)
	}
}
1723
1724
// sendBinaryModeQuery sends, in a single write, the extended-protocol
// sequence Parse, Bind, Describe (portal), Execute and Sync:
//   - Parse uses the unnamed statement and declares no parameter types;
//   - Bind uses the unnamed portal, carries args, and requests text results.
//
// It panics if there are 65536 or more arguments.
func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
	if n := len(args); n >= 65536 {
		errorf("got %d parameters but PostgreSQL only supports 65535 parameters", n)
	}

	w := cn.writeBuf('P')
	w.byte(0) // unnamed statement
	w.string(query)
	w.int16(0) // no parameter types declared

	w.next('B')
	w.int16(0) // two empty names: unnamed portal and unnamed statement
	cn.sendBinaryParameters(w, args)
	w.bytes(colFmtDataAllText)

	w.next('D')
	w.byte('P') // describe a portal
	w.byte(0)   // unnamed portal

	w.next('E')
	w.byte(0)  // unnamed portal
	w.int32(0) // no row limit

	w.next('S')
	cn.send(w)
}
1750
1751
func (cn *conn) processParameterStatus(r *readBuf) {
1752
	var err error
1753
1754
	param := r.string()
1755
	switch param {
1756
	case "server_version":
1757
		var major1 int
1758
		var major2 int
1759
		_, err = fmt.Sscanf(r.string(), "%d.%d", &major1, &major2)
1760
		if err == nil {
1761
			cn.parameterStatus.serverVersion = major1*10000 + major2*100
1762
		}
1763
1764
	case "TimeZone":
1765
		cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
1766
		if err != nil {
1767
			cn.parameterStatus.currentLocation = nil
1768
		}
1769
1770
	default:
1771
		// ignore
1772
	}
1773
}
1774
1775
// processReadyForQuery records the transaction-status byte carried by a
// ReadyForQuery message.
func (cn *conn) processReadyForQuery(r *readBuf) {
	status := r.byte()
	cn.txnStatus = transactionStatus(status)
}
1778
1779
func (cn *conn) readReadyForQuery() {
1780
	t, r := cn.recv1()
1781
	switch t {
1782
	case 'Z':
1783
		cn.processReadyForQuery(r)
1784
		return
1785
	default:
1786
		cn.err.set(driver.ErrBadConn)
1787
		errorf("unexpected message %q; expected ReadyForQuery", t)
1788
	}
1789
}
1790
1791
// processBackendKeyData stores the backend process ID and secret key from a
// BackendKeyData message. They are read in that order.
func (cn *conn) processBackendKeyData(r *readBuf) {
	pid := r.int32()
	key := r.int32()
	cn.processID, cn.secretKey = pid, key
}
1795
1796
func (cn *conn) readParseResponse() {
1797
	t, r := cn.recv1()
1798
	switch t {
1799
	case '1':
1800
		return
1801
	case 'E':
1802
		err := parseError(r)
1803
		cn.readReadyForQuery()
1804
		panic(err)
1805
	default:
1806
		cn.err.set(driver.ErrBadConn)
1807
		errorf("unexpected Parse response %q", t)
1808
	}
1809
}
1810
1811
func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) {
1812
	for {
1813
		t, r := cn.recv1()
1814
		switch t {
1815
		case 't':
1816
			nparams := r.int16()
1817
			paramTyps = make([]oid.Oid, nparams)
1818
			for i := range paramTyps {
1819
				paramTyps[i] = r.oid()
1820
			}
1821
		case 'n':
1822
			return paramTyps, nil, nil
1823
		case 'T':
1824
			colNames, colTyps = parseStatementRowDescribe(r)
1825
			return paramTyps, colNames, colTyps
1826
		case 'E':
1827
			err := parseError(r)
1828
			cn.readReadyForQuery()
1829
			panic(err)
1830
		default:
1831
			cn.err.set(driver.ErrBadConn)
1832
			errorf("unexpected Describe statement response %q", t)
1833
		}
1834
	}
1835
}
1836
1837
func (cn *conn) readPortalDescribeResponse() rowsHeader {
1838
	t, r := cn.recv1()
1839
	switch t {
1840
	case 'T':
1841
		return parsePortalRowDescribe(r)
1842
	case 'n':
1843
		return rowsHeader{}
1844
	case 'E':
1845
		err := parseError(r)
1846
		cn.readReadyForQuery()
1847
		panic(err)
1848
	default:
1849
		cn.err.set(driver.ErrBadConn)
1850
		errorf("unexpected Describe response %q", t)
1851
	}
1852
	panic("not reached")
1853
}
1854
1855
func (cn *conn) readBindResponse() {
1856
	t, r := cn.recv1()
1857
	switch t {
1858
	case '2':
1859
		return
1860
	case 'E':
1861
		err := parseError(r)
1862
		cn.readReadyForQuery()
1863
		panic(err)
1864
	default:
1865
		cn.err.set(driver.ErrBadConn)
1866
		errorf("unexpected Bind response %q", t)
1867
	}
1868
}
1869
1870
func (cn *conn) postExecuteWorkaround() {
1871
	// Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores
1872
	// any errors from rows.Next, which masks errors that happened during the
1873
	// execution of the query.  To avoid the problem in common cases, we wait
1874
	// here for one more message from the database.  If it's not an error the
1875
	// query will likely succeed (or perhaps has already, if it's a
1876
	// CommandComplete), so we push the message into the conn struct; recv1
1877
	// will return it as the next message for rows.Next or rows.Close.
1878
	// However, if it's an error, we wait until ReadyForQuery and then return
1879
	// the error to our caller.
1880
	for {
1881
		t, r := cn.recv1()
1882
		switch t {
1883
		case 'E':
1884
			err := parseError(r)
1885
			cn.readReadyForQuery()
1886
			panic(err)
1887
		case 'C', 'D', 'I':
1888
			// the query didn't fail, but we can't process this message
1889
			cn.saveMessage(t, r)
1890
			return
1891
		default:
1892
			cn.err.set(driver.ErrBadConn)
1893
			errorf("unexpected message during extended query execution: %q", t)
1894
		}
1895
	}
1896
}
1897
1898
// Only for Exec(), since we ignore the returned data
1899
func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) {
1900
	for {
1901
		t, r := cn.recv1()
1902
		switch t {
1903
		case 'C':
1904
			if err != nil {
1905
				cn.err.set(driver.ErrBadConn)
1906
				errorf("unexpected CommandComplete after error %s", err)
1907
			}
1908
			res, commandTag = cn.parseComplete(r.string())
1909
		case 'Z':
1910
			cn.processReadyForQuery(r)
1911
			if res == nil && err == nil {
1912
				err = errUnexpectedReady
1913
			}
1914
			return res, commandTag, err
1915
		case 'E':
1916
			err = parseError(r)
1917
		case 'T', 'D', 'I':
1918
			if err != nil {
1919
				cn.err.set(driver.ErrBadConn)
1920
				errorf("unexpected %q after error %s", t, err)
1921
			}
1922
			if t == 'I' {
1923
				res = emptyRows
1924
			}
1925
			// ignore any results
1926
		default:
1927
			cn.err.set(driver.ErrBadConn)
1928
			errorf("unknown %s response: %q", protocolState, t)
1929
		}
1930
	}
1931
}
1932
1933
// parseStatementRowDescribe decodes a RowDescription received in reply to a
// Describe of a statement. It returns each column's name and type
// information. The per-column format code is skipped because it is not yet
// known when describing a statement.
func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) {
	count := r.int16()
	colNames = make([]string, count)
	colTyps = make([]fieldDesc, count)
	for i := 0; i < count; i++ {
		colNames[i] = r.string()
		// Skip table OID (4 bytes) and column attribute number (2 bytes).
		r.next(6)
		desc := &colTyps[i]
		desc.OID = r.oid()
		desc.Len = r.int16()
		desc.Mod = r.int32()
		// Format code: always 0 when describing a statement.
		r.next(2)
	}
	return colNames, colTyps
}
1948
1949
func parsePortalRowDescribe(r *readBuf) rowsHeader {
1950
	n := r.int16()
1951
	colNames := make([]string, n)
1952
	colFmts := make([]format, n)
1953
	colTyps := make([]fieldDesc, n)
1954
	for i := range colNames {
1955
		colNames[i] = r.string()
1956
		r.next(6)
1957
		colTyps[i].OID = r.oid()
1958
		colTyps[i].Len = r.int16()
1959
		colTyps[i].Mod = r.int32()
1960
		colFmts[i] = format(r.int16())
1961
	}
1962
	return rowsHeader{
1963
		colNames: colNames,
1964
		colFmts:  colFmts,
1965
		colTyps:  colTyps,
1966
	}
1967
}
1968
1969
// parseEnviron tries to mimic some of libpq's environment handling.
//
// To ease testing, it does not reference os.Environ directly, but it is
// designed to accept its output: each entry has the form "KEY=VALUE".
// Entries without an "=" carry no value and are skipped.
//
// Environment-set connection information is intended to have a higher
// precedence than a library default but lower than any explicitly
// passed information (such as in the URL or connection string).
//
// Recognized PG* variables are mapped to their connection-string keys.
// Well-defined but unsupported variables cause a panic, so they should be
// unset before calling. Unrelated variables are ignored.
func parseEnviron(env []string) (out map[string]string) {
	out = make(map[string]string)

	for _, v := range env {
		parts := strings.SplitN(v, "=", 2)
		if len(parts) != 2 {
			// Not a KEY=VALUE pair, so there is no value to use. Indexing
			// parts[1] below would otherwise panic with index out of range.
			continue
		}

		accrue := func(keyname string) {
			out[keyname] = parts[1]
		}
		unsupported := func() {
			panic(fmt.Sprintf("setting %v not supported", parts[0]))
		}

		// The order of these follows the PostgreSQL 9.1 manual. Unsupported
		// but well-defined keys cause a panic; these should be unset prior
		// to execution. Options that pq expects to have a certain value are
		// allowed, but must have that value if present (they can, of
		// course, be absent).
		switch parts[0] {
		case "PGHOST":
			accrue("host")
		case "PGHOSTADDR":
			unsupported()
		case "PGPORT":
			accrue("port")
		case "PGDATABASE":
			accrue("dbname")
		case "PGUSER":
			accrue("user")
		case "PGPASSWORD":
			accrue("password")
		case "PGSERVICE", "PGSERVICEFILE", "PGREALM":
			unsupported()
		case "PGOPTIONS":
			accrue("options")
		case "PGAPPNAME":
			accrue("application_name")
		case "PGSSLMODE":
			accrue("sslmode")
		case "PGSSLCERT":
			accrue("sslcert")
		case "PGSSLKEY":
			accrue("sslkey")
		case "PGSSLROOTCERT":
			accrue("sslrootcert")
		case "PGREQUIRESSL", "PGSSLCRL":
			unsupported()
		case "PGREQUIREPEER":
			unsupported()
		case "PGKRBSRVNAME", "PGGSSLIB":
			unsupported()
		case "PGCONNECT_TIMEOUT":
			accrue("connect_timeout")
		case "PGCLIENTENCODING":
			accrue("client_encoding")
		case "PGDATESTYLE":
			accrue("datestyle")
		case "PGTZ":
			accrue("timezone")
		case "PGGEQO":
			accrue("geqo")
		case "PGSYSCONFDIR", "PGLOCALEDIR":
			unsupported()
		}
	}

	return out
}
2046
2047
// isUTF8 returns whether name is a fuzzy variation of the string "UTF-8".
2048
func isUTF8(name string) bool {
2049
	// Recognize all sorts of silly things as "UTF-8", like Postgres does
2050
	s := strings.Map(alnumLowerASCII, name)
2051
	return s == "utf8" || s == "unicode"
2052
}
2053
2054
// alnumLowerASCII is a strings.Map mapping function. It lowercases ASCII
// letters, keeps ASCII digits, and drops every other rune by returning -1.
func alnumLowerASCII(ch rune) rune {
	switch {
	case 'A' <= ch && ch <= 'Z':
		return ch - 'A' + 'a'
	case 'a' <= ch && ch <= 'z', '0' <= ch && ch <= '9':
		return ch
	default:
		return -1 // discard
	}
}
2063