Passed
Pull Request — master (#23)
by Frank
06:05 queued 03:00
created

pq.isDriverSetting   C

Complexity

Conditions 9

Size

Total Lines 19
Code Lines 18

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 9
eloc 18
nop 1
dl 0
loc 19
rs 6.6666
c 0
b 0
f 0
1
package pq
2
3
import (
4
	"bufio"
5
	"context"
6
	"crypto/md5"
7
	"crypto/sha256"
8
	"database/sql"
9
	"database/sql/driver"
10
	"encoding/binary"
11
	"errors"
12
	"fmt"
13
	"io"
14
	"net"
15
	"os"
16
	"os/user"
17
	"path"
18
	"path/filepath"
19
	"strconv"
20
	"strings"
21
	"time"
22
	"unicode"
23
24
	"github.com/lib/pq/oid"
25
	"github.com/lib/pq/scram"
26
)
27
28
// Sentinel errors that callers of the driver may compare against.
var (
	ErrNotSupported              = errors.New("pq: Unsupported command")
	ErrInFailedTransaction       = errors.New("pq: Could not complete operation in a failed transaction")
	ErrSSLNotSupported           = errors.New("pq: SSL is not enabled on the server")
	ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less")
	ErrCouldNotDetectUsername    = errors.New("pq: Could not detect default username. Please provide one explicitly")
)

// Sentinel errors used only inside the package.
var (
	errUnexpectedReady = errors.New("unexpected ReadyForQuery")
	errNoRowsAffected  = errors.New("no RowsAffected available after the empty statement")
	errNoLastInsertID  = errors.New("no LastInsertId available after the empty statement")
)
40
41
// Driver is the Postgres database driver.
42
type Driver struct{}
43
44
// Open opens a new connection to the database. name is a connection string.
45
// Most users should only use it through database/sql package from the standard
46
// library.
47
func (d *Driver) Open(name string) (driver.Conn, error) {
48
	return Open(name)
49
}
50
51
func init() {
52
	sql.Register("postgres", &Driver{})
53
}
54
55
// parameterStatus collects the server-reported session parameters that the
// driver keeps track of.
type parameterStatus struct {
	// serverVersion is the server version in server_version_num format; it
	// is 0 when the version is not available.
	serverVersion int

	// currentLocation is the location derived from the session's TimeZone
	// value; it is nil when that is not available.
	currentLocation *time.Location
}
64
65
// transactionStatus is the one-byte transaction state indicator stored in
// conn.txnStatus.
type transactionStatus byte

// Known transaction states and the byte value that represents each of them.
const (
	txnStatusIdle                = transactionStatus('I')
	txnStatusIdleInTransaction   = transactionStatus('T')
	txnStatusInFailedTransaction = transactionStatus('E')
)
72
73
func (s transactionStatus) String() string {
74
	switch s {
75
	case txnStatusIdle:
76
		return "idle"
77
	case txnStatusIdleInTransaction:
78
		return "idle in transaction"
79
	case txnStatusInFailedTransaction:
80
		return "in a failed transaction"
81
	default:
82
		errorf("unknown transactionStatus %d", s)
83
	}
84
85
	panic("not reached")
86
}
87
88
// Dialer abstracts how pq creates its network connections. Supplying a
// custom implementation gives the caller more control over that step.
type Dialer interface {
	// Dial connects to address on the named network.
	Dial(network, address string) (net.Conn, error)
	// DialTimeout is like Dial but is given an upper bound on how long the
	// attempt may take.
	DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
}
94
95
// DialerContext is the context-aware counterpart of Dialer.
type DialerContext interface {
	// DialContext connects to address on the named network using ctx.
	DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
99
100
// defaultDialer is the dialer Open uses when the caller does not supply one.
// Every method delegates to the wrapped net.Dialer.
type defaultDialer struct {
	d net.Dialer
}

// Dial connects to address on the named network.
func (dd defaultDialer) Dial(network, address string) (net.Conn, error) {
	return dd.d.Dial(network, address)
}

// DialTimeout connects like DialContext, using a background context that
// expires after timeout.
func (dd defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
	timeoutCtx, stop := context.WithTimeout(context.Background(), timeout)
	defer stop()
	return dd.DialContext(timeoutCtx, network, address)
}

// DialContext connects to address on the named network using ctx.
func (dd defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
	return dd.d.DialContext(ctx, network, address)
}
115
116
type conn struct {
117
	c         net.Conn
118
	buf       *bufio.Reader
119
	namei     int
120
	scratch   [512]byte
121
	txnStatus transactionStatus
122
	txnFinish func()
123
124
	// Save connection arguments to use during CancelRequest.
125
	dialer Dialer
126
	opts   values
127
128
	// Cancellation key data for use with CancelRequest messages.
129
	processID int
130
	secretKey int
131
132
	parameterStatus parameterStatus
133
134
	saveMessageType   byte
135
	saveMessageBuffer []byte
136
137
	// If true, this connection is bad and all public-facing functions should
138
	// return ErrBadConn.
139
	bad bool
140
141
	// If set, this connection should never use the binary format when
142
	// receiving query results from prepared statements.  Only provided for
143
	// debugging.
144
	disablePreparedBinaryResult bool
145
146
	// Whether to always send []byte parameters over as binary.  Enables single
147
	// round-trip mode for non-prepared Query calls.
148
	binaryParameters bool
149
150
	// If true this connection is in the middle of a COPY
151
	inCopy bool
152
}
153
154
// Handle driver-side settings in parsed connection string.
155
func (cn *conn) handleDriverSettings(o values) (err error) {
156
	boolSetting := func(key string, val *bool) error {
157
		if value, ok := o[key]; ok {
158
			if value == "yes" {
159
				*val = true
160
			} else if value == "no" {
161
				*val = false
162
			} else {
163
				return fmt.Errorf("unrecognized value %q for %s", value, key)
164
			}
165
		}
166
		return nil
167
	}
168
169
	err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult)
170
	if err != nil {
171
		return err
172
	}
173
	return boolSetting("binary_parameters", &cn.binaryParameters)
174
}
175
176
func (cn *conn) handlePgpass(o values) {
177
	// if a password was supplied, do not process .pgpass
178
	if _, ok := o["password"]; ok {
179
		return
180
	}
181
	filename := os.Getenv("PGPASSFILE")
182
	if filename == "" {
183
		// XXX this code doesn't work on Windows where the default filename is
184
		// XXX %APPDATA%\postgresql\pgpass.conf
185
		// Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470
186
		userHome := os.Getenv("HOME")
187
		if userHome == "" {
188
			user, err := user.Current()
189
			if err != nil {
190
				return
191
			}
192
			userHome = user.HomeDir
193
		}
194
		filename = filepath.Join(userHome, ".pgpass")
195
	}
196
	fileinfo, err := os.Stat(filename)
197
	if err != nil {
198
		return
199
	}
200
	mode := fileinfo.Mode()
201
	if mode&(0x77) != 0 {
202
		// XXX should warn about incorrect .pgpass permissions as psql does
203
		return
204
	}
205
	file, err := os.Open(filename)
206
	if err != nil {
207
		return
208
	}
209
	defer file.Close()
210
	scanner := bufio.NewScanner(io.Reader(file))
211
	hostname := o["host"]
212
	ntw, _ := network(o)
213
	port := o["port"]
214
	db := o["dbname"]
215
	username := o["user"]
216
	// From: https://github.com/tg/pgpass/blob/master/reader.go
217
	getFields := func(s string) []string {
218
		fs := make([]string, 0, 5)
219
		f := make([]rune, 0, len(s))
220
221
		var esc bool
222
		for _, c := range s {
223
			switch {
224
			case esc:
225
				f = append(f, c)
226
				esc = false
227
			case c == '\\':
228
				esc = true
229
			case c == ':':
230
				fs = append(fs, string(f))
231
				f = f[:0]
232
			default:
233
				f = append(f, c)
234
			}
235
		}
236
		return append(fs, string(f))
237
	}
238
	for scanner.Scan() {
239
		line := scanner.Text()
240
		if len(line) == 0 || line[0] == '#' {
241
			continue
242
		}
243
		split := getFields(line)
244
		if len(split) != 5 {
245
			continue
246
		}
247
		if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
248
			o["password"] = split[4]
249
			return
250
		}
251
	}
252
}
253
254
func (cn *conn) writeBuf(b byte) *writeBuf {
255
	cn.scratch[0] = b
256
	return &writeBuf{
257
		buf: cn.scratch[:5],
258
		pos: 1,
259
	}
260
}
261
262
// Open opens a new connection to the database. dsn is a connection string.
263
// Most users should only use it through database/sql package from the standard
264
// library.
265
func Open(dsn string) (_ driver.Conn, err error) {
266
	return DialOpen(defaultDialer{}, dsn)
267
}
268
269
// DialOpen opens a new connection to the database using a dialer.
270
func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) {
271
	c, err := NewConnector(dsn)
272
	if err != nil {
273
		return nil, err
274
	}
275
	c.dialer = d
276
	return c.open(context.Background())
277
}
278
279
func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
280
	// Handle any panics during connection initialization.  Note that we
281
	// specifically do *not* want to use errRecover(), as that would turn any
282
	// connection errors into ErrBadConns, hiding the real error message from
283
	// the user.
284
	defer errRecoverNoErrBadConn(&err)
285
286
	o := c.opts
287
288
	cn = &conn{
289
		opts:   o,
290
		dialer: c.dialer,
291
	}
292
	err = cn.handleDriverSettings(o)
293
	if err != nil {
294
		return nil, err
295
	}
296
	cn.handlePgpass(o)
297
298
	cn.c, err = dial(ctx, c.dialer, o)
299
	if err != nil {
300
		return nil, err
301
	}
302
303
	err = cn.ssl(o)
304
	if err != nil {
305
		if cn.c != nil {
306
			cn.c.Close()
307
		}
308
		return nil, err
309
	}
310
311
	// cn.startup panics on error. Make sure we don't leak cn.c.
312
	panicking := true
313
	defer func() {
314
		if panicking {
315
			cn.c.Close()
316
		}
317
	}()
318
319
	cn.buf = bufio.NewReader(cn.c)
320
	cn.startup(o)
321
322
	// reset the deadline, in case one was set (see dial)
323
	if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
324
		err = cn.c.SetDeadline(time.Time{})
325
	}
326
	panicking = false
327
	return cn, err
328
}
329
330
// dial opens the network connection described by the options o using the
// dialer d. If d also implements DialerContext, its DialContext method is
// preferred so that ctx is honoured.
//
// For Unix domain sockets o["sslmode"] is forced to "disable". A non-zero
// connect_timeout (in seconds) bounds the dial itself and is also installed
// as a deadline on the returned connection; the caller is expected to clear
// that deadline once startup is done. If the deadline cannot be set, the
// freshly dialed connection is closed and the error is returned.
func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) {
	ntw, address := network(o)
	// SSL is not necessary or supported over UNIX domain sockets
	if ntw == "unix" {
		o["sslmode"] = "disable"
	}

	// Zero or not specified means wait indefinitely.
	timeout, ok := o["connect_timeout"]
	if !ok || timeout == "0" {
		if dctx, ok := d.(DialerContext); ok {
			return dctx.DialContext(ctx, ntw, address)
		}
		return d.Dial(ntw, address)
	}

	seconds, err := strconv.ParseInt(timeout, 10, 0)
	if err != nil {
		return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
	}
	duration := time.Duration(seconds) * time.Second

	// connect_timeout should apply to the entire connection establishment
	// procedure, so we both use a timeout for the TCP connection
	// establishment and set a deadline for doing the initial handshake.
	// The deadline is then reset after startup() is done.
	deadline := time.Now().Add(duration)
	var conn net.Conn
	if dctx, ok := d.(DialerContext); ok {
		ctx, cancel := context.WithTimeout(ctx, duration)
		defer cancel()
		conn, err = dctx.DialContext(ctx, ntw, address)
	} else {
		conn, err = d.DialTimeout(ntw, address, duration)
	}
	if err != nil {
		return nil, err
	}
	if err = conn.SetDeadline(deadline); err != nil {
		// The caller drops the connection when it sees an error, so close
		// it here to avoid leaking the socket.
		conn.Close()
		return nil, err
	}
	return conn, nil
}
369
370
// network derives the network type and address to dial from the "host" and
// "port" options. A host beginning with "/" is taken to be a directory that
// holds a Unix domain socket named ".s.PGSQL.<port>"; every other host is
// dialed over TCP.
func network(o values) (string, string) {
	host, port := o["host"], o["port"]
	if !strings.HasPrefix(host, "/") {
		return "tcp", net.JoinHostPort(host, port)
	}
	return "unix", path.Join(host, ".s.PGSQL."+port)
}
380
381
// values holds connection options keyed by option name.
type values map[string]string

// scanner implements a tokenizer for libpq-style option strings.
type scanner struct {
	s []rune // the option string, decoded into runes
	i int    // index of the next rune to hand out
}

// newScanner returns a new scanner initialized with the option string s.
func newScanner(s string) *scanner {
	return &scanner{[]rune(s), 0}
}

// Next returns the next rune.
// It returns 0, false if the end of the text has been reached.
func (s *scanner) Next() (rune, bool) {
	if s.i >= len(s.s) {
		return 0, false
	}
	r := s.s[s.i]
	s.i++
	return r, true
}

// SkipSpaces returns the next non-whitespace rune.
// It returns 0, false if the end of the text has been reached.
func (s *scanner) SkipSpaces() (rune, bool) {
	r, ok := s.Next()
	for unicode.IsSpace(r) && ok {
		r, ok = s.Next()
	}
	return r, ok
}

// parseOpts parses the options from name and adds them to the values.
//
// name is a whitespace-separated list of key=value pairs. Whitespace around
// the "=" is allowed. A value may be written bare, in which case it ends at
// the next whitespace, or enclosed in single quotes; in both forms a
// backslash makes the following character literal. A key followed by "=" at
// the very end of the string gets the empty value.
//
// An error is returned when a key is not followed by "=", when a bare value
// ends in a backslash, or when a quoted value is not terminated. Options
// parsed before the error was detected remain in o.
//
// The parsing code is based on conninfo_parse from libpq's fe-connect.c
func parseOpts(name string, o values) error {
	s := newScanner(name)

	for {
		var (
			keyRunes, valRunes []rune
			r                  rune
			ok                 bool
		)

		if r, ok = s.SkipSpaces(); !ok {
			break
		}

		// Scan the key
		for !unicode.IsSpace(r) && r != '=' {
			keyRunes = append(keyRunes, r)
			if r, ok = s.Next(); !ok {
				break
			}
		}

		// Skip any whitespace if we're not at the = yet
		if r != '=' {
			r, ok = s.SkipSpaces()
		}

		// The current character should be =
		if r != '=' || !ok {
			return fmt.Errorf(`missing "=" after %q in connection info string`, string(keyRunes))
		}

		// Skip any whitespace after the =
		if r, ok = s.SkipSpaces(); !ok {
			// If we reach the end here, the last value is just an empty string as per libpq.
			o[string(keyRunes)] = ""
			break
		}

		if r != '\'' {
			// Bare value: runs up to the next whitespace or the end of input.
			for !unicode.IsSpace(r) {
				if r == '\\' {
					if r, ok = s.Next(); !ok {
						return fmt.Errorf(`missing character after backslash`)
					}
				}
				valRunes = append(valRunes, r)

				if r, ok = s.Next(); !ok {
					break
				}
			}
		} else {
			// Quoted value: runs up to the closing, unescaped single quote.
		quote:
			for {
				if r, ok = s.Next(); !ok {
					return fmt.Errorf(`unterminated quoted string literal in connection string`)
				}
				switch r {
				case '\'':
					break quote
				case '\\':
					// A backslash at the very end is caught as an
					// unterminated literal on the next iteration.
					r, _ = s.Next()
					fallthrough
				default:
					valRunes = append(valRunes, r)
				}
			}
		}

		o[string(keyRunes)] = string(valRunes)
	}

	return nil
}
493
494
// isInTransaction reports whether the connection's recorded transaction
// status is "idle in transaction" or "in a failed transaction".
func (cn *conn) isInTransaction() bool {
	switch cn.txnStatus {
	case txnStatusIdleInTransaction, txnStatusInFailedTransaction:
		return true
	}
	return false
}
498
499
// checkIsInTransaction verifies that the connection is inside a transaction
// exactly when intxn says it should be. On a mismatch the connection is
// marked bad and the problem is reported through errorf.
func (cn *conn) checkIsInTransaction(intxn bool) {
	if cn.isInTransaction() == intxn {
		return
	}
	cn.bad = true
	errorf("unexpected transaction status %v", cn.txnStatus)
}
505
506
// Begin starts a transaction without any extra mode options.
func (cn *conn) Begin() (_ driver.Tx, err error) {
	const noMode = ""
	return cn.begin(noMode)
}
509
510
// begin starts a transaction by executing "BEGIN" with mode appended
// verbatim. It returns driver.ErrBadConn if the connection is already marked
// bad. If the server's reply does not carry the BEGIN command tag, or the
// connection is not idle in a transaction afterwards, the connection is
// marked bad and an error is returned. On success the connection itself
// serves as the driver.Tx.
func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
	if cn.bad {
		return nil, driver.ErrBadConn
	}
	defer cn.errRecover(&err)

	cn.checkIsInTransaction(false)

	var tag string
	if _, tag, err = cn.simpleExec("BEGIN" + mode); err != nil {
		return nil, err
	}
	switch {
	case tag != "BEGIN":
		cn.bad = true
		return nil, fmt.Errorf("unexpected command tag %s", tag)
	case cn.txnStatus != txnStatusIdleInTransaction:
		cn.bad = true
		return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
	}
	return cn, nil
}
531
532
// closeTxn invokes the connection's txnFinish callback if one is set.
func (cn *conn) closeTxn() {
	finish := cn.txnFinish
	if finish == nil {
		return
	}
	finish()
}
537
538
// Commit commits the current transaction. closeTxn is always run on the way
// out. It returns driver.ErrBadConn if the connection is already marked bad.
//
// If the transaction has already failed, it is rolled back instead and
// ErrInFailedTransaction is returned (or the rollback's own error, should
// the rollback fail).
func (cn *conn) Commit() (err error) {
	defer cn.closeTxn()
	if cn.bad {
		return driver.ErrBadConn
	}
	defer cn.errRecover(&err)

	cn.checkIsInTransaction(true)

	// The client must not be led to believe that committing a failed
	// transaction worked.  database/sql releases the connection back into
	// the free pool regardless of what is returned here, so the transaction
	// has to be aborted now.  Issuing COMMIT in a failed transaction gives
	// the same behaviour, which makes this the least surprising choice.
	if cn.txnStatus == txnStatusInFailedTransaction {
		if rbErr := cn.rollback(); rbErr != nil {
			return rbErr
		}
		return ErrInFailedTransaction
	}

	var tag string
	_, tag, err = cn.simpleExec("COMMIT")
	switch {
	case err != nil:
		// Still inside a transaction after a failed COMMIT: the connection
		// cannot be reused.
		if cn.isInTransaction() {
			cn.bad = true
		}
		return err
	case tag != "COMMIT":
		cn.bad = true
		return fmt.Errorf("unexpected command tag %s", tag)
	}
	cn.checkIsInTransaction(false)
	return nil
}
573
574
// Rollback aborts the current transaction. closeTxn is always run on the way
// out, and driver.ErrBadConn is returned if the connection is already marked
// bad.
func (cn *conn) Rollback() (err error) {
	defer cn.closeTxn()
	if cn.bad {
		return driver.ErrBadConn
	}
	defer cn.errRecover(&err)

	err = cn.rollback()
	return err
}
582
583
// rollback executes ROLLBACK on a connection that must currently be inside a
// transaction and verifies that the transaction has ended afterwards.
//
// The connection is marked bad when ROLLBACK fails while the connection is
// still inside a transaction, and when the server answers with an unexpected
// command tag. The latter matches how begin and Commit treat the same
// situation.
func (cn *conn) rollback() (err error) {
	cn.checkIsInTransaction(true)
	_, commandTag, err := cn.simpleExec("ROLLBACK")
	if err != nil {
		if cn.isInTransaction() {
			cn.bad = true
		}
		return err
	}
	if commandTag != "ROLLBACK" {
		// The reply does not belong to our ROLLBACK, so the state of the
		// session is unknown; do not let the connection be reused.
		cn.bad = true
		return fmt.Errorf("unexpected command tag %s", commandTag)
	}
	cn.checkIsInTransaction(false)
	return nil
}
598
599
// gname returns a fresh name for this connection: it increments the
// connection's name counter and returns the new value in decimal.
func (cn *conn) gname() string {
	cn.namei++
	return strconv.Itoa(cn.namei)
}
603
604
func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
605
	b := cn.writeBuf('Q')
606
	b.string(q)
607
	cn.send(b)
608
609
	for {
610
		t, r := cn.recv1()
611
		switch t {
612
		case 'C':
613
			res, commandTag = cn.parseComplete(r.string())
614
		case 'Z':
615
			cn.processReadyForQuery(r)
616
			if res == nil && err == nil {
617
				err = errUnexpectedReady
618
			}
619
			// done
620
			return
621
		case 'E':
622
			err = parseError(r)
623
		case 'I':
624
			res = emptyRows
625
		case 'T', 'D':
626
			// ignore any results
627
		default:
628
			cn.bad = true
629
			errorf("unknown response for simple query: %q", t)
630
		}
631
	}
632
}
633
634
func (cn *conn) simpleQuery(q string) (res *rows, err error) {
635
	defer cn.errRecover(&err)
636
637
	b := cn.writeBuf('Q')
638
	b.string(q)
639
	cn.send(b)
640
641
	for {
642
		t, r := cn.recv1()
643
		switch t {
644
		case 'C', 'I':
645
			// We allow queries which don't return any results through Query as
646
			// well as Exec.  We still have to give database/sql a rows object
647
			// the user can close, though, to avoid connections from being
648
			// leaked.  A "rows" with done=true works fine for that purpose.
649
			if err != nil {
650
				cn.bad = true
651
				errorf("unexpected message %q in simple query execution", t)
652
			}
653
			if res == nil {
654
				res = &rows{
655
					cn: cn,
656
				}
657
			}
658
			// Set the result and tag to the last command complete if there wasn't a
659
			// query already run. Although queries usually return from here and cede
660
			// control to Next, a query with zero results does not.
661
			if t == 'C' && res.colNames == nil {
662
				res.result, res.tag = cn.parseComplete(r.string())
663
			}
664
			res.done = true
665
		case 'Z':
666
			cn.processReadyForQuery(r)
667
			// done
668
			return
669
		case 'E':
670
			res = nil
671
			err = parseError(r)
672
		case 'D':
673
			if res == nil {
674
				cn.bad = true
675
				errorf("unexpected DataRow in simple query execution")
676
			}
677
			// the query didn't fail; kick off to Next
678
			cn.saveMessage(t, r)
679
			return
680
		case 'T':
681
			// res might be non-nil here if we received a previous
682
			// CommandComplete, but that's fine; just overwrite it
683
			res = &rows{cn: cn}
684
			res.rowsHeader = parsePortalRowDescribe(r)
685
686
			// To work around a bug in QueryRow in Go 1.2 and earlier, wait
687
			// until the first DataRow has been received.
688
		default:
689
			cn.bad = true
690
			errorf("unknown response for simple query: %q", t)
691
		}
692
	}
693
}
694
695
type noRows struct{}
696
697
var emptyRows noRows
698
699
var _ driver.Result = noRows{}
700
701
func (noRows) LastInsertId() (int64, error) {
702
	return 0, errNoLastInsertID
703
}
704
705
func (noRows) RowsAffected() (int64, error) {
706
	return 0, errNoRowsAffected
707
}
708
709
// Decides which column formats to use for a prepared statement.  The input is
710
// an array of type oids, one element per result column.
711
func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) {
712
	if len(colTyps) == 0 {
713
		return nil, colFmtDataAllText
714
	}
715
716
	colFmts = make([]format, len(colTyps))
717
	if forceText {
718
		return colFmts, colFmtDataAllText
719
	}
720
721
	allBinary := true
722
	allText := true
723
	for i, t := range colTyps {
724
		switch t.OID {
725
		// This is the list of types to use binary mode for when receiving them
726
		// through a prepared statement.  If a type appears in this list, it
727
		// must also be implemented in binaryDecode in encode.go.
728
		case oid.T_bytea:
729
			fallthrough
730
		case oid.T_int8:
731
			fallthrough
732
		case oid.T_int4:
733
			fallthrough
734
		case oid.T_int2:
735
			fallthrough
736
		case oid.T_uuid:
737
			colFmts[i] = formatBinary
738
			allText = false
739
740
		default:
741
			allBinary = false
742
		}
743
	}
744
745
	if allBinary {
746
		return colFmts, colFmtDataAllBinary
747
	} else if allText {
748
		return colFmts, colFmtDataAllText
749
	} else {
750
		colFmtData = make([]byte, 2+len(colFmts)*2)
751
		binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
752
		for i, v := range colFmts {
753
			binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v))
754
		}
755
		return colFmts, colFmtData
756
	}
757
}
758
759
// prepareTo prepares the query q on the server under the statement name
// stmtName and returns the resulting statement. A single write carries three
// messages ('P', 'D' and 'S'); the replies are then read in order and used
// to fill in the statement's parameter types, column names, column types and
// column formats.
func (cn *conn) prepareTo(q, stmtName string) *stmt {
	msg := cn.writeBuf('P')
	msg.string(stmtName)
	msg.string(q)
	msg.int16(0)

	msg.next('D')
	msg.byte('S')
	msg.string(stmtName)

	msg.next('S')
	cn.send(msg)

	st := &stmt{cn: cn, name: stmtName}
	cn.readParseResponse()
	st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
	st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
	cn.readReadyForQuery()
	return st
}
780
781
// Prepare creates a prepared statement for q. A query whose first four
// characters spell COPY (in any letter case) is handed to prepareCopyIn
// and, on success, puts the connection into COPY mode. Every other query is
// prepared under a freshly generated statement name. It returns
// driver.ErrBadConn if the connection is already marked bad.
func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
	if cn.bad {
		return nil, driver.ErrBadConn
	}
	defer cn.errRecover(&err)

	isCopy := len(q) >= 4 && strings.EqualFold(q[:4], "COPY")
	if !isCopy {
		return cn.prepareTo(q, cn.gname()), nil
	}

	copyStmt, copyErr := cn.prepareCopyIn(q)
	if copyErr == nil {
		cn.inCopy = true
	}
	return copyStmt, copyErr
}
796
797
// Close sends the 'X' message to the server and closes the network
// connection. The network connection is closed even when sending fails; an
// error from sending takes precedence over one from closing.
func (cn *conn) Close() (err error) {
	// There is deliberately no early return for cn.bad: a connection must
	// always be closed.
	defer cn.errRecover(&err)

	// Errors are signalled with panics and handled in cn.errRecover, so the
	// only way to guarantee that cn.c.Close runs is to defer it.
	defer func() {
		if cerr := cn.c.Close(); err == nil {
			err = cerr
		}
	}()

	// send() is avoided on purpose: ListenerConn relies on the scratch
	// buffer of this connection being left alone.
	return cn.sendSimpleMessage('X')
}
814
815
// Implement the "Queryer" interface
816
func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
817
	return cn.query(query, args)
818
}
819
820
func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
821
	if cn.bad {
822
		return nil, driver.ErrBadConn
823
	}
824
	if cn.inCopy {
825
		return nil, errCopyInProgress
826
	}
827
	defer cn.errRecover(&err)
828
829
	// Check to see if we can use the "simpleQuery" interface, which is
830
	// *much* faster than going through prepare/exec
831
	if len(args) == 0 {
832
		return cn.simpleQuery(query)
833
	}
834
835
	if cn.binaryParameters {
836
		cn.sendBinaryModeQuery(query, args)
837
838
		cn.readParseResponse()
839
		cn.readBindResponse()
840
		rows := &rows{cn: cn}
841
		rows.rowsHeader = cn.readPortalDescribeResponse()
842
		cn.postExecuteWorkaround()
843
		return rows, nil
844
	}
845
	st := cn.prepareTo(query, "")
846
	st.exec(args)
847
	return &rows{
848
		cn:         cn,
849
		rowsHeader: st.rowsHeader,
850
	}, nil
851
}
852
853
// Implement the optional "Execer" interface for one-shot queries
854
func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
855
	if cn.bad {
856
		return nil, driver.ErrBadConn
857
	}
858
	defer cn.errRecover(&err)
859
860
	// Check to see if we can use the "simpleExec" interface, which is
861
	// *much* faster than going through prepare/exec
862
	if len(args) == 0 {
863
		// ignore commandTag, our caller doesn't care
864
		r, _, err := cn.simpleExec(query)
865
		return r, err
866
	}
867
868
	if cn.binaryParameters {
869
		cn.sendBinaryModeQuery(query, args)
870
871
		cn.readParseResponse()
872
		cn.readBindResponse()
873
		cn.readPortalDescribeResponse()
874
		cn.postExecuteWorkaround()
875
		res, _, err = cn.readExecuteResponse("Execute")
876
		return res, err
877
	}
878
	// Use the unnamed statement to defer planning until bind
879
	// time, or else value-based selectivity estimates cannot be
880
	// used.
881
	st := cn.prepareTo(query, "")
882
	r, err := st.Exec(args)
883
	if err != nil {
884
		panic(err)
885
	}
886
	return r, err
887
}
888
889
// send writes the wrapped message m to the network connection and panics
// with the write error if the write fails.
func (cn *conn) send(m *writeBuf) {
	if _, err := cn.c.Write(m.wrap()); err != nil {
		panic(err)
	}
}
895
896
// sendStartupPacket writes the wrapped message m to the network connection
// without its first byte and returns the write error, if any.
func (cn *conn) sendStartupPacket(m *writeBuf) error {
	packet := m.wrap()
	_, err := cn.c.Write(packet[1:])
	return err
}
900
901
// Send a message of type typ to the server on the other end of cn.  The
902
// message should have no payload.  This method does not use the scratch
903
// buffer.
904
func (cn *conn) sendSimpleMessage(typ byte) (err error) {
905
	_, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'})
906
	return err
907
}
908
909
// saveMessage memorizes a message and its buffer in the conn struct.
910
// recvMessage will then return these values on the next call to it.  This
911
// method is useful in cases where you have to see what the next message is
912
// going to be (e.g. to see whether it's an error or not) but you can't handle
913
// the message yourself.
914
func (cn *conn) saveMessage(typ byte, buf *readBuf) {
915
	if cn.saveMessageType != 0 {
916
		cn.bad = true
917
		errorf("unexpected saveMessageType %d", cn.saveMessageType)
918
	}
919
	cn.saveMessageType = typ
920
	cn.saveMessageBuffer = *buf
921
}
922
923
// recvMessage receives any message from the backend, or returns an error if
924
// a problem occurred while reading the message.
925
func (cn *conn) recvMessage(r *readBuf) (byte, error) {
926
	// workaround for a QueryRow bug, see exec
927
	if cn.saveMessageType != 0 {
928
		t := cn.saveMessageType
929
		*r = cn.saveMessageBuffer
930
		cn.saveMessageType = 0
931
		cn.saveMessageBuffer = nil
932
		return t, nil
933
	}
934
935
	x := cn.scratch[:5]
936
	_, err := io.ReadFull(cn.buf, x)
937
	if err != nil {
938
		return 0, err
939
	}
940
941
	// read the type and length of the message that follows
942
	t := x[0]
943
	n := int(binary.BigEndian.Uint32(x[1:])) - 4
944
	var y []byte
945
	if n <= len(cn.scratch) {
946
		y = cn.scratch[:n]
947
	} else {
948
		y = make([]byte, n)
949
	}
950
	_, err = io.ReadFull(cn.buf, y)
951
	if err != nil {
952
		return 0, err
953
	}
954
	*r = y
955
	return t, nil
956
}
957
958
// recv receives a message from the backend, but if an error happened while
959
// reading the message or the received message was an ErrorResponse, it panics.
960
// NoticeResponses are ignored.  This function should generally be used only
961
// during the startup sequence.
962
func (cn *conn) recv() (t byte, r *readBuf) {
963
	for {
964
		var err error
965
		r = &readBuf{}
966
		t, err = cn.recvMessage(r)
967
		if err != nil {
968
			panic(err)
969
		}
970
		switch t {
971
		case 'E':
972
			panic(parseError(r))
973
		case 'N':
974
			// ignore
975
		default:
976
			return
977
		}
978
	}
979
}
980
981
// recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by
982
// the caller to avoid an allocation.
983
func (cn *conn) recv1Buf(r *readBuf) byte {
984
	for {
985
		t, err := cn.recvMessage(r)
986
		if err != nil {
987
			panic(err)
988
		}
989
990
		switch t {
991
		case 'A', 'N':
992
			// ignore
993
		case 'S':
994
			cn.processParameterStatus(r)
995
		default:
996
			return t
997
		}
998
	}
999
}
1000
1001
// recv1 receives a message from the backend, panicking if an error occurs
1002
// while attempting to read it.  All asynchronous messages are ignored, with
1003
// the exception of ErrorResponse.
1004
func (cn *conn) recv1() (t byte, r *readBuf) {
1005
	r = &readBuf{}
1006
	t = cn.recv1Buf(r)
1007
	return t, r
1008
}
1009
1010
// ssl negotiates SSL for the connection when the options in o ask for it.
// The package-level ssl function decides whether an upgrade is wanted; if it
// is not, nothing is sent. Otherwise a request packet carrying the code
// 80877103 is written and a one-byte answer is read: 'S' means the server
// accepts and the network connection is replaced by the upgraded one, while
// anything else results in ErrSSLNotSupported.
func (cn *conn) ssl(o values) error {
	upgrade, err := ssl(o)
	if err != nil || upgrade == nil {
		// Either the options were invalid or there is nothing to do; err is
		// nil in the latter case.
		return err
	}

	req := cn.writeBuf(0)
	req.int32(80877103)
	if err = cn.sendStartupPacket(req); err != nil {
		return err
	}

	reply := cn.scratch[:1]
	if _, err = io.ReadFull(cn.c, reply); err != nil {
		return err
	}
	if reply[0] != 'S' {
		return ErrSSLNotSupported
	}

	cn.c, err = upgrade(cn.c)
	return err
}
1040
1041
// isDriverSetting reports whether key names an option that only configures
// the driver itself. Such options must not be forwarded to the server in the
// connection startup packet.
func isDriverSetting(key string) bool {
	switch key {
	case "host", "port",
		"password",
		"sslmode", "sslcert", "sslkey", "sslrootcert",
		"fallback_application_name",
		"connect_timeout",
		"disable_prepared_binary_result",
		"binary_parameters":
		return true
	}
	return false
}
1065
1066
func (cn *conn) startup(o values) {
1067
	w := cn.writeBuf(0)
1068
	w.int32(196608)
1069
	// Send the backend the name of the database we want to connect to, and the
1070
	// user we want to connect as.  Additionally, we send over any run-time
1071
	// parameters potentially included in the connection string.  If the server
1072
	// doesn't recognize any of them, it will reply with an error.
1073
	for k, v := range o {
1074
		if isDriverSetting(k) {
1075
			// skip options which can't be run-time parameters
1076
			continue
1077
		}
1078
		// The protocol requires us to supply the database name as "database"
1079
		// instead of "dbname".
1080
		if k == "dbname" {
1081
			k = "database"
1082
		}
1083
		w.string(k)
1084
		w.string(v)
1085
	}
1086
	w.string("")
1087
	if err := cn.sendStartupPacket(w); err != nil {
1088
		panic(err)
1089
	}
1090
1091
	for {
1092
		t, r := cn.recv()
1093
		switch t {
1094
		case 'K':
1095
			cn.processBackendKeyData(r)
1096
		case 'S':
1097
			cn.processParameterStatus(r)
1098
		case 'R':
1099
			cn.auth(r, o)
1100
		case 'Z':
1101
			cn.processReadyForQuery(r)
1102
			return
1103
		default:
1104
			errorf("unknown response for startup: %q", t)
1105
		}
1106
	}
1107
}
1108
1109
func (cn *conn) auth(r *readBuf, o values) {
1110
	switch code := r.int32(); code {
1111
	case 0:
1112
		// OK
1113
	case 3:
1114
		w := cn.writeBuf('p')
1115
		w.string(o["password"])
1116
		cn.send(w)
1117
1118
		t, r := cn.recv()
1119
		if t != 'R' {
1120
			errorf("unexpected password response: %q", t)
1121
		}
1122
1123
		if r.int32() != 0 {
1124
			errorf("unexpected authentication response: %q", t)
1125
		}
1126
	case 5:
1127
		s := string(r.next(4))
1128
		w := cn.writeBuf('p')
1129
		w.string("md5" + md5s(md5s(o["password"]+o["user"])+s))
1130
		cn.send(w)
1131
1132
		t, r := cn.recv()
1133
		if t != 'R' {
1134
			errorf("unexpected password response: %q", t)
1135
		}
1136
1137
		if r.int32() != 0 {
1138
			errorf("unexpected authentication response: %q", t)
1139
		}
1140
	case 10:
1141
		sc := scram.NewClient(sha256.New, o["user"], o["password"])
1142
		sc.Step(nil)
1143
		if sc.Err() != nil {
1144
			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1145
		}
1146
		scOut := sc.Out()
1147
1148
		w := cn.writeBuf('p')
1149
		w.string("SCRAM-SHA-256")
1150
		w.int32(len(scOut))
1151
		w.bytes(scOut)
1152
		cn.send(w)
1153
1154
		t, r := cn.recv()
1155
		if t != 'R' {
1156
			errorf("unexpected password response: %q", t)
1157
		}
1158
1159
		if r.int32() != 11 {
1160
			errorf("unexpected authentication response: %q", t)
1161
		}
1162
1163
		nextStep := r.next(len(*r))
1164
		sc.Step(nextStep)
1165
		if sc.Err() != nil {
1166
			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1167
		}
1168
1169
		scOut = sc.Out()
1170
		w = cn.writeBuf('p')
1171
		w.bytes(scOut)
1172
		cn.send(w)
1173
1174
		t, r = cn.recv()
1175
		if t != 'R' {
1176
			errorf("unexpected password response: %q", t)
1177
		}
1178
1179
		if r.int32() != 12 {
1180
			errorf("unexpected authentication response: %q", t)
1181
		}
1182
1183
		nextStep = r.next(len(*r))
1184
		sc.Step(nextStep)
1185
		if sc.Err() != nil {
1186
			errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
1187
		}
1188
1189
	default:
1190
		errorf("unknown authentication response: %d", code)
1191
	}
1192
}
1193
1194
// format is a wire format code for a parameter or result column.
type format int

const (
	formatText   format = 0
	formatBinary format = 1
)

var (
	// One result-column format code with the value 1 (i.e. all binary).
	colFmtDataAllBinary = []byte{0, 1, 0, 1}

	// No result-column format codes (i.e. all text).
	colFmtDataAllText = []byte{0, 0}
)
1204
1205
// stmt is a prepared statement bound to a single connection.
type stmt struct {
	// cn is the connection the statement was prepared on; all messages for
	// the statement are sent through it.
	cn *conn
	// name is the statement name sent to the server when binding and
	// closing the statement.
	name string
	// rowsHeader describes the result columns; it is copied into the rows
	// value returned by Query.
	rowsHeader
	// colFmtData is appended verbatim to each Bind message after the
	// parameter values.
	colFmtData []byte
	// paramTyps holds one type OID per statement parameter; its length is
	// the number of arguments the statement requires.
	paramTyps []oid.Oid
	// closed is set once Close has succeeded, making later calls no-ops.
	closed bool
}
1213
1214
// Close releases the prepared statement on the server.  Closing an already
// closed statement is a no-op; on a connection marked bad it returns
// driver.ErrBadConn.  An unexpected reply marks the connection as bad.
func (st *stmt) Close() (err error) {
	switch {
	case st.closed:
		return nil
	case st.cn.bad:
		return driver.ErrBadConn
	}
	cn := st.cn
	defer cn.errRecover(&err)

	// Close message for a statement ('S') with our name, followed by a
	// separate 'S' message.
	w := cn.writeBuf('C')
	w.byte('S')
	w.string(st.name)
	cn.send(w)

	cn.send(cn.writeBuf('S'))

	// The server must acknowledge the close with '3' ...
	if t, _ := cn.recv1(); t != '3' {
		cn.bad = true
		errorf("unexpected close response: %q", t)
	}
	st.closed = true

	// ... and then report ready for query ('Z').
	t, r := cn.recv1()
	if t != 'Z' {
		cn.bad = true
		errorf("expected ready for query, but got: %q", t)
	}
	cn.processReadyForQuery(r)

	return nil
}
1246
1247
func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
1248
	if st.cn.bad {
1249
		return nil, driver.ErrBadConn
1250
	}
1251
	defer st.cn.errRecover(&err)
1252
1253
	st.exec(v)
1254
	return &rows{
1255
		cn:         st.cn,
1256
		rowsHeader: st.rowsHeader,
1257
	}, nil
1258
}
1259
1260
// Exec executes the prepared statement with the arguments v and returns the
// result parsed from the server's reply, discarding any returned rows.  It
// returns driver.ErrBadConn when the connection is already marked bad.
func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
	cn := st.cn
	if cn.bad {
		return nil, driver.ErrBadConn
	}
	defer cn.errRecover(&err)

	st.exec(v)
	// Only the result is needed here; the command tag is dropped.
	res, _, err = cn.readExecuteResponse("simple query")
	return res, err
}
1270
1271
// exec binds the arguments v to the prepared statement, asks the server to
// execute it and waits until the first reply after the bind acknowledgement
// has been received.  The number of arguments must match the statement's
// parameter count and stay below 65536; violations are reported via errorf.
func (st *stmt) exec(v []driver.Value) {
	nargs := len(v)
	if nargs >= 65536 {
		errorf("got %d parameters but PostgreSQL only supports 65535 parameters", nargs)
	}
	if want := len(st.paramTyps); nargs != want {
		errorf("got %d parameters but the statement requires %d", nargs, want)
	}

	cn := st.cn

	// Bind ('B') message.
	w := cn.writeBuf('B')
	w.byte(0) // unnamed portal
	w.string(st.name)

	if cn.binaryParameters {
		cn.sendBinaryParameters(w, v)
	} else {
		w.int16(0) // no parameter format codes
		w.int16(nargs)
		for i, arg := range v {
			if arg == nil {
				// NULL is sent as length -1 with no value bytes
				w.int32(-1)
				continue
			}
			enc := encode(&cn.parameterStatus, arg, st.paramTyps[i])
			w.int32(len(enc))
			w.bytes(enc)
		}
	}
	w.bytes(st.colFmtData)

	// 'E' message: a zero byte followed by a zero int32.
	w.next('E')
	w.byte(0)
	w.int32(0)

	// 'S' message with an empty payload, then send everything at once.
	w.next('S')
	cn.send(w)

	cn.readBindResponse()
	cn.postExecuteWorkaround()
}
1312
1313
// NumInput returns the number of parameters the prepared statement takes,
// i.e. the number of parameter types recorded for it.
func (st *stmt) NumInput() int {
	nparams := len(st.paramTyps)
	return nparams
}
1316
1317
// parseComplete parses the "command tag" from a CommandComplete message, and
1318
// returns the number of rows affected (if applicable) and a string
1319
// identifying only the command that was executed, e.g. "ALTER TABLE".  If the
1320
// command tag could not be parsed, parseComplete panics.
1321
func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
1322
	commandsWithAffectedRows := []string{
1323
		"SELECT ",
1324
		// INSERT is handled below
1325
		"UPDATE ",
1326
		"DELETE ",
1327
		"FETCH ",
1328
		"MOVE ",
1329
		"COPY ",
1330
	}
1331
1332
	var affectedRows *string
1333
	for _, tag := range commandsWithAffectedRows {
1334
		if strings.HasPrefix(commandTag, tag) {
1335
			t := commandTag[len(tag):]
1336
			affectedRows = &t
1337
			commandTag = tag[:len(tag)-1]
1338
			break
1339
		}
1340
	}
1341
	// INSERT also includes the oid of the inserted row in its command tag.
1342
	// Oids in user tables are deprecated, and the oid is only returned when
1343
	// exactly one row is inserted, so it's unlikely to be of value to any
1344
	// real-world application and we can ignore it.
1345
	if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
1346
		parts := strings.Split(commandTag, " ")
1347
		if len(parts) != 3 {
1348
			cn.bad = true
1349
			errorf("unexpected INSERT command tag %s", commandTag)
1350
		}
1351
		affectedRows = &parts[len(parts)-1]
1352
		commandTag = "INSERT"
1353
	}
1354
	// There should be no affected rows attached to the tag, just return it
1355
	if affectedRows == nil {
1356
		return driver.RowsAffected(0), commandTag
1357
	}
1358
	n, err := strconv.ParseInt(*affectedRows, 10, 64)
1359
	if err != nil {
1360
		cn.bad = true
1361
		errorf("could not parse commandTag: %s", err)
1362
	}
1363
	return driver.RowsAffected(n), commandTag
1364
}
1365
1366
// rowsHeader describes the columns of a result set.  The three slices are
// parallel: index i in each refers to the same column.
type rowsHeader struct {
	// colNames holds the column names, as returned by rows.Columns.
	colNames []string
	// colTyps holds the type description (OID, length, modifier) of each
	// column.
	colTyps []fieldDesc
	// colFmts holds the format code of each column, used when decoding
	// row values.
	colFmts []format
}
1371
1372
// rows reads the results of a query from the connection.  Query returns it
// as a driver.Rows.
type rows struct {
	// cn is the connection the results are read from.
	cn *conn
	// finish, when non-nil, is called once Close returns.
	finish func()
	// rowsHeader describes the columns of the current result set.
	rowsHeader
	// done is set once ReadyForQuery has been received; Next then only
	// returns io.EOF.
	done bool
	// rb is the buffer each incoming message is received into.
	rb readBuf
	// result and tag are filled in from the CommandComplete message.
	result driver.Result
	tag    string

	// next is the header of a following result set, recorded when a row
	// description arrives while reading rows; NextResultSet switches to it.
	next *rowsHeader
}
1383
1384
func (rs *rows) Close() error {
1385
	if finish := rs.finish; finish != nil {
1386
		defer finish()
1387
	}
1388
	// no need to look at cn.bad as Next() will
1389
	for {
1390
		err := rs.Next(nil)
1391
		switch err {
1392
		case nil:
1393
		case io.EOF:
1394
			// rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row
1395
			// description, used with HasNextResultSet). We need to fetch messages until
1396
			// we hit a 'Z', which is done by waiting for done to be set.
1397
			if rs.done {
1398
				return nil
1399
			}
1400
		default:
1401
			return err
1402
		}
1403
	}
1404
}
1405
1406
// Columns returns the column names of the current result set.
func (rs *rows) Columns() []string {
	names := rs.colNames
	return names
}
1409
1410
// Result returns the result recorded from the CommandComplete message, or
// emptyRows when none has been recorded.
func (rs *rows) Result() driver.Result {
	if rs.result != nil {
		return rs.result
	}
	return emptyRows
}
1416
1417
// Tag returns the command name recorded from the CommandComplete message; it
// is empty until such a message has been read.
func (rs *rows) Tag() string {
	tag := rs.tag
	return tag
}
1420
1421
func (rs *rows) Next(dest []driver.Value) (err error) {
1422
	if rs.done {
1423
		return io.EOF
1424
	}
1425
1426
	conn := rs.cn
1427
	if conn.bad {
1428
		return driver.ErrBadConn
1429
	}
1430
	defer conn.errRecover(&err)
1431
1432
	for {
1433
		t := conn.recv1Buf(&rs.rb)
1434
		switch t {
1435
		case 'E':
1436
			err = parseError(&rs.rb)
1437
		case 'C', 'I':
1438
			if t == 'C' {
1439
				rs.result, rs.tag = conn.parseComplete(rs.rb.string())
1440
			}
1441
			continue
1442
		case 'Z':
1443
			conn.processReadyForQuery(&rs.rb)
1444
			rs.done = true
1445
			if err != nil {
1446
				return err
1447
			}
1448
			return io.EOF
1449
		case 'D':
1450
			n := rs.rb.int16()
1451
			if err != nil {
1452
				conn.bad = true
1453
				errorf("unexpected DataRow after error %s", err)
1454
			}
1455
			if n < len(dest) {
1456
				dest = dest[:n]
1457
			}
1458
			for i := range dest {
1459
				l := rs.rb.int32()
1460
				if l == -1 {
1461
					dest[i] = nil
1462
					continue
1463
				}
1464
				dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i])
1465
			}
1466
			return
1467
		case 'T':
1468
			next := parsePortalRowDescribe(&rs.rb)
1469
			rs.next = &next
1470
			return io.EOF
1471
		default:
1472
			errorf("unexpected message after execute: %q", t)
1473
		}
1474
	}
1475
}
1476
1477
// HasNextResultSet reports whether the header of a further result set has
// been received and the query has not finished yet.
func (rs *rows) HasNextResultSet() bool {
	if rs.next == nil {
		return false
	}
	return !rs.done
}
1481
1482
// NextResultSet switches to the pending result set by installing its header
// as the current one.  It returns io.EOF when no further result set has been
// announced.
func (rs *rows) NextResultSet() error {
	pending := rs.next
	if pending == nil {
		return io.EOF
	}
	rs.rowsHeader, rs.next = *pending, nil
	return nil
}
1490
1491
// QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) so
// that it can be embedded in an SQL statement.  For example:
//
//    tblname := "my_table"
//    data := "my_data"
//    quoted := pq.QuoteIdentifier(tblname)
//    _, err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
//
// Double quotes inside name are escaped by doubling them.  A quoted
// identifier is case sensitive when used in a query.  If name contains a zero
// byte, the result is truncated immediately before it.
func QuoteIdentifier(name string) string {
	if nul := strings.IndexByte(name, 0); nul >= 0 {
		name = name[:nul]
	}
	escaped := strings.Replace(name, `"`, `""`, -1)
	return `"` + escaped + `"`
}
1509
1510
// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass
// literals to DDL and other statements that do not accept parameters) so that
// it can be embedded in an SQL statement.  For example:
//
//    exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z")
//    _, err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date))
//
// Single quotes in literal are escaped by doubling them.  If literal contains
// backslashes (i.e. "\"), each is replaced by two backslashes (i.e. "\\") and
// the result is written as a PostgreSQL C-style escape string, prefixed with
// a space and 'E'.
func QuoteLiteral(literal string) string {
	// This mirrors libpq's "PQEscapeStringInternal", found in
	// src/interfaces/libpq/fe-exec.c:
	// https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c
	escaped := strings.Replace(literal, `'`, `''`, -1)

	// Without backslashes a plain single-quoted string is enough.
	if !strings.Contains(escaped, `\`) {
		return `'` + escaped + `'`
	}

	// Otherwise double every backslash and use the E'...' form; like
	// "PQEscapeStringInternal", put a space in front of the "E".
	return ` E'` + strings.Replace(escaped, `\`, `\\`, -1) + `'`
}
1542
1543
// md5s returns the MD5 digest of s as a lowercase hexadecimal string.
func md5s(s string) string {
	digest := md5.Sum([]byte(s))
	return fmt.Sprintf("%x", digest[:])
}
1548
1549
// sendBinaryParameters appends the parameter section of a Bind message to b:
// the parameter format codes followed by the parameter values of args.
// []byte arguments are flagged with format code 1; when no argument is a
// []byte, no format codes are written at all.
func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
	// paramFormats stays nil unless at least one argument is a []byte; once
	// allocated it has one entry per argument, zero meaning the default.
	var paramFormats []int
	for i := range args {
		if _, isBytes := args[i].([]byte); !isBytes {
			continue
		}
		if paramFormats == nil {
			paramFormats = make([]int, len(args))
		}
		paramFormats[i] = 1
	}

	// A nil slice has length zero, which writes a count of 0 and no codes.
	b.int16(len(paramFormats))
	for _, code := range paramFormats {
		b.int16(code)
	}

	b.int16(len(args))
	for _, arg := range args {
		if arg == nil {
			// NULL is sent as length -1 with no value bytes
			b.int32(-1)
			continue
		}
		datum := binaryEncode(&cn.parameterStatus, arg)
		b.int32(len(datum))
		b.bytes(datum)
	}
}
1583
1584
// sendBinaryModeQuery sends query with its arguments args in one batch made
// of five messages ('P', 'B', 'D', 'E', 'S') that use the unnamed statement
// and the unnamed portal.  More than 65535 arguments are rejected via errorf.
func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
	if nargs := len(args); nargs >= 65536 {
		errorf("got %d parameters but PostgreSQL only supports 65535 parameters", nargs)
	}

	// 'P': the query text for the unnamed statement, followed by a zero
	// int16.
	w := cn.writeBuf('P')
	w.byte(0) // unnamed statement
	w.string(query)
	w.int16(0)

	// 'B': bind the arguments; results are requested with the all-text
	// column format data.
	w.next('B')
	w.int16(0) // unnamed portal and statement
	cn.sendBinaryParameters(w, args)
	w.bytes(colFmtDataAllText)

	// 'D': describe the portal ('P').
	w.next('D')
	w.byte('P')
	w.byte(0) // unnamed portal

	// 'E': a zero byte followed by a zero int32.
	w.next('E')
	w.byte(0)
	w.int32(0)

	// 'S': empty payload, then send the whole batch.
	w.next('S')
	cn.send(w)
}
1610
1611
func (cn *conn) processParameterStatus(r *readBuf) {
1612
	var err error
1613
1614
	param := r.string()
1615
	switch param {
1616
	case "server_version":
1617
		var major1 int
1618
		var major2 int
1619
		var minor int
1620
		_, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor)
1621
		if err == nil {
1622
			cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor
1623
		}
1624
1625
	case "TimeZone":
1626
		cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
1627
		if err != nil {
1628
			cn.parameterStatus.currentLocation = nil
1629
		}
1630
1631
	default:
1632
		// ignore
1633
	}
1634
}
1635
1636
// processReadyForQuery records the transaction status carried in the single
// payload byte of a ReadyForQuery message.
func (cn *conn) processReadyForQuery(r *readBuf) {
	status := r.byte()
	cn.txnStatus = transactionStatus(status)
}
1639
1640
func (cn *conn) readReadyForQuery() {
1641
	t, r := cn.recv1()
1642
	switch t {
1643
	case 'Z':
1644
		cn.processReadyForQuery(r)
1645
		return
1646
	default:
1647
		cn.bad = true
1648
		errorf("unexpected message %q; expected ReadyForQuery", t)
1649
	}
1650
}
1651
1652
// processBackendKeyData stores the two int32 values carried by a 'K' message
// received during startup: first the process ID, then the secret key.
func (cn *conn) processBackendKeyData(r *readBuf) {
	// The reads are order-dependent: each int32 call consumes from r.
	cn.processID = r.int32()
	cn.secretKey = r.int32()
}
1656
1657
func (cn *conn) readParseResponse() {
1658
	t, r := cn.recv1()
1659
	switch t {
1660
	case '1':
1661
		return
1662
	case 'E':
1663
		err := parseError(r)
1664
		cn.readReadyForQuery()
1665
		panic(err)
1666
	default:
1667
		cn.bad = true
1668
		errorf("unexpected Parse response %q", t)
1669
	}
1670
}
1671
1672
func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) {
1673
	for {
1674
		t, r := cn.recv1()
1675
		switch t {
1676
		case 't':
1677
			nparams := r.int16()
1678
			paramTyps = make([]oid.Oid, nparams)
1679
			for i := range paramTyps {
1680
				paramTyps[i] = r.oid()
1681
			}
1682
		case 'n':
1683
			return paramTyps, nil, nil
1684
		case 'T':
1685
			colNames, colTyps = parseStatementRowDescribe(r)
1686
			return paramTyps, colNames, colTyps
1687
		case 'E':
1688
			err := parseError(r)
1689
			cn.readReadyForQuery()
1690
			panic(err)
1691
		default:
1692
			cn.bad = true
1693
			errorf("unexpected Describe statement response %q", t)
1694
		}
1695
	}
1696
}
1697
1698
func (cn *conn) readPortalDescribeResponse() rowsHeader {
1699
	t, r := cn.recv1()
1700
	switch t {
1701
	case 'T':
1702
		return parsePortalRowDescribe(r)
1703
	case 'n':
1704
		return rowsHeader{}
1705
	case 'E':
1706
		err := parseError(r)
1707
		cn.readReadyForQuery()
1708
		panic(err)
1709
	default:
1710
		cn.bad = true
1711
		errorf("unexpected Describe response %q", t)
1712
	}
1713
	panic("not reached")
1714
}
1715
1716
func (cn *conn) readBindResponse() {
1717
	t, r := cn.recv1()
1718
	switch t {
1719
	case '2':
1720
		return
1721
	case 'E':
1722
		err := parseError(r)
1723
		cn.readReadyForQuery()
1724
		panic(err)
1725
	default:
1726
		cn.bad = true
1727
		errorf("unexpected Bind response %q", t)
1728
	}
1729
}
1730
1731
func (cn *conn) postExecuteWorkaround() {
1732
	// Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores
1733
	// any errors from rows.Next, which masks errors that happened during the
1734
	// execution of the query.  To avoid the problem in common cases, we wait
1735
	// here for one more message from the database.  If it's not an error the
1736
	// query will likely succeed (or perhaps has already, if it's a
1737
	// CommandComplete), so we push the message into the conn struct; recv1
1738
	// will return it as the next message for rows.Next or rows.Close.
1739
	// However, if it's an error, we wait until ReadyForQuery and then return
1740
	// the error to our caller.
1741
	for {
1742
		t, r := cn.recv1()
1743
		switch t {
1744
		case 'E':
1745
			err := parseError(r)
1746
			cn.readReadyForQuery()
1747
			panic(err)
1748
		case 'C', 'D', 'I':
1749
			// the query didn't fail, but we can't process this message
1750
			cn.saveMessage(t, r)
1751
			return
1752
		default:
1753
			cn.bad = true
1754
			errorf("unexpected message during extended query execution: %q", t)
1755
		}
1756
	}
1757
}
1758
1759
// Only for Exec(), since we ignore the returned data
1760
func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) {
1761
	for {
1762
		t, r := cn.recv1()
1763
		switch t {
1764
		case 'C':
1765
			if err != nil {
1766
				cn.bad = true
1767
				errorf("unexpected CommandComplete after error %s", err)
1768
			}
1769
			res, commandTag = cn.parseComplete(r.string())
1770
		case 'Z':
1771
			cn.processReadyForQuery(r)
1772
			if res == nil && err == nil {
1773
				err = errUnexpectedReady
1774
			}
1775
			return res, commandTag, err
1776
		case 'E':
1777
			err = parseError(r)
1778
		case 'T', 'D', 'I':
1779
			if err != nil {
1780
				cn.bad = true
1781
				errorf("unexpected %q after error %s", t, err)
1782
			}
1783
			if t == 'I' {
1784
				res = emptyRows
1785
			}
1786
			// ignore any results
1787
		default:
1788
			cn.bad = true
1789
			errorf("unknown %s response: %q", protocolState, t)
1790
		}
1791
	}
1792
}
1793
1794
// parseStatementRowDescribe parses a row description received when
// describing a statement and returns the name and type description of every
// column.  The format codes in the message are skipped.
func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) {
	ncols := r.int16()
	colNames = make([]string, ncols)
	colTyps = make([]fieldDesc, ncols)
	for i := 0; i < ncols; i++ {
		colNames[i] = r.string()
		r.next(6) // six bytes of the field description are not used
		field := &colTyps[i]
		field.OID = r.oid()
		field.Len = r.int16()
		field.Mod = r.int32()
		// format code not known when describing a statement; always 0
		r.next(2)
	}
	return
}
1809
1810
func parsePortalRowDescribe(r *readBuf) rowsHeader {
1811
	n := r.int16()
1812
	colNames := make([]string, n)
1813
	colFmts := make([]format, n)
1814
	colTyps := make([]fieldDesc, n)
1815
	for i := range colNames {
1816
		colNames[i] = r.string()
1817
		r.next(6)
1818
		colTyps[i].OID = r.oid()
1819
		colTyps[i].Len = r.int16()
1820
		colTyps[i].Mod = r.int32()
1821
		colFmts[i] = format(r.int16())
1822
	}
1823
	return rowsHeader{
1824
		colNames: colNames,
1825
		colFmts:  colFmts,
1826
		colTyps:  colTyps,
1827
	}
1828
}
1829
1830
// parseEnviron tries to mimic some of libpq's environment handling.
//
// To ease testing, it does not directly reference os.Environ, but is
// designed to accept its output: a list of "NAME=value" strings.
//
// Environment-set connection information is intended to have a higher
// precedence than a library default but lower than any explicitly
// passed information (such as in the URL or connection string).
//
// The returned map is keyed by connection option name (e.g. "host" for
// PGHOST).  Entries that do not contain "=" are skipped, as are variables
// that are not recognized.  Well-defined libpq variables that pq does not
// support cause a panic; they should be unset prior to execution.
func parseEnviron(env []string) (out map[string]string) {
	out = make(map[string]string)

	for _, v := range env {
		parts := strings.SplitN(v, "=", 2)
		if len(parts) < 2 {
			// Not a "NAME=value" pair; there is no value to record, and
			// indexing parts[1] would panic.
			continue
		}
		name, value := parts[0], parts[1]

		// The order of the supported variables is the same as is seen in
		// the PostgreSQL 9.1 manual.
		var key string
		switch name {
		case "PGHOST":
			key = "host"
		case "PGPORT":
			key = "port"
		case "PGDATABASE":
			key = "dbname"
		case "PGUSER":
			key = "user"
		case "PGPASSWORD":
			key = "password"
		case "PGOPTIONS":
			key = "options"
		case "PGAPPNAME":
			key = "application_name"
		case "PGSSLMODE":
			key = "sslmode"
		case "PGSSLCERT":
			key = "sslcert"
		case "PGSSLKEY":
			key = "sslkey"
		case "PGSSLROOTCERT":
			key = "sslrootcert"
		case "PGCONNECT_TIMEOUT":
			key = "connect_timeout"
		case "PGCLIENTENCODING":
			key = "client_encoding"
		case "PGDATESTYLE":
			key = "datestyle"
		case "PGTZ":
			key = "timezone"
		case "PGGEQO":
			key = "geqo"
		case "PGHOSTADDR",
			"PGSERVICE", "PGSERVICEFILE", "PGREALM",
			"PGREQUIRESSL", "PGSSLCRL",
			"PGREQUIREPEER",
			"PGKRBSRVNAME", "PGGSSLIB",
			"PGSYSCONFDIR", "PGLOCALEDIR":
			// Unsupported but well-defined variables.
			panic(fmt.Sprintf("setting %v not supported", name))
		default:
			// Not a variable this function handles.
			continue
		}
		out[key] = value
	}

	return out
}
1907
1908
// isUTF8 returns whether name is a fuzzy variation of the string "UTF-8".
1909
func isUTF8(name string) bool {
1910
	// Recognize all sorts of silly things as "UTF-8", like Postgres does
1911
	s := strings.Map(alnumLowerASCII, name)
1912
	return s == "utf8" || s == "unicode"
1913
}
1914
1915
// alnumLowerASCII is a mapping function for strings.Map: ASCII uppercase
// letters are lowercased, ASCII lowercase letters and digits are returned
// unchanged, and every other rune is mapped to -1 so that it is dropped.
func alnumLowerASCII(ch rune) rune {
	switch {
	case ch >= 'A' && ch <= 'Z':
		return ch - 'A' + 'a'
	case ch >= 'a' && ch <= 'z', ch >= '0' && ch <= '9':
		return ch
	default:
		return -1 // discard
	}
}
1924