Passed
Push — main ( 9e2022...6ef1c9 )
by Christian
02:04 queued 14s
created

connector.*pgSqlConn.WithTLS   B

Complexity

Conditions 6

Size

Total Lines 27
Code Lines 18

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 6
eloc 18
nop 6
dl 0
loc 27
rs 8.5666
c 0
b 0
f 0
1
package connector
2
3
import (
4
	"crypto/tls"
5
	"crypto/x509"
6
	"database/sql"
7
	"fmt"
8
	"strings"
9
10
	"github.com/cdleo/go-commons/logger"
11
	"github.com/cdleo/go-commons/sqlcommons"
12
	pgx "github.com/jackc/pgx/v4"
13
	stdlib "github.com/jackc/pgx/v4/stdlib"
14
)
15
16
// pgSqlConn holds the settings that Open turns into a pgx connection to a
// PostgreSQL server. Instances are created by NewPostgreSqlConnector.
type pgSqlConn struct {
	// Endpoint written into the DSN as host= / port=.
	host string
	port int

	// Credentials, target database (dbname=) and libpq sslmode. sslMode starts
	// as "disable" in NewPostgreSqlConnector and is replaced by WithTLS.
	user, password, database, sslMode string

	// TLSConfig is the custom TLS configuration built by WithTLS; it stays nil
	// until WithTLS succeeds. Open applies it to the pgx connection config.
	TLSConfig *tls.Config
}
25
26
// postgresProxyName is the driver name passed to registerProxy and then to
// sql.Open, so every *sql.DB returned by Open goes through that proxy driver.
const (
	postgresProxyName = "pgx-proxy"
)
27
28
func NewPostgreSqlConnector(host string, port int, user string, password string, database string) sqlcommons.SQLConnector {
29
30
	return &pgSqlConn{
31
		host:     host,
32
		port:     port,
33
		user:     user,
34
		password: password,
35
		database: database,
36
		sslMode:  "disable",
37
	}
38
}
39
40
func (s *pgSqlConn) WithTLS(sslMode string, allowInsecure bool, serverName string, serverCertificate string, clientCertificate string, clientKey string) error {
41
42
	config := &tls.Config{
43
		InsecureSkipVerify: allowInsecure,
44
		ServerName:         serverName,
45
	}
46
47
	if serverCertificate != "" {
48
		caCertPool := x509.NewCertPool()
49
		ok := caCertPool.AppendCertsFromPEM([]byte(serverCertificate))
50
		if !ok {
51
			return fmt.Errorf("unable to append Certs from PEM")
52
		}
53
		config.RootCAs = caCertPool
54
	}
55
56
	if clientCertificate != "" && clientKey != "" {
57
		keypair, err := tls.X509KeyPair([]byte(clientCertificate), []byte(clientKey))
58
		if err != nil {
59
			return fmt.Errorf("unable to create keypair of client [%v]", err)
60
		}
61
		config.Certificates = []tls.Certificate{keypair}
62
	}
63
64
	s.TLSConfig = config
65
	s.sslMode = sslMode
66
	return nil
67
}
68
69
// Open registers the proxy driver named postgresProxyName around the pgx
// stdlib driver and returns a *sql.DB opened through it, using the
// connector's host, port, credentials, database, sslmode and TLS settings.
//
// logger and translator are handed to registerProxy unchanged. sql.Open may
// not dial the server, so connectivity problems can surface only on first
// use (e.g. Ping).
//
// An error is returned when the connection settings cannot be parsed by pgx
// or when sql.Open fails; in the latter case the pgx config registered for
// this call is unregistered again.
func (s *pgSqlConn) Open(logger logger.Logger, translator sqlcommons.SQLAdapter) (*sql.DB, error) {
	registerProxy(postgresProxyName, logger, translator, stdlib.GetDefaultDriver())

	// Quote every value per the libpq key/value syntax ('...' with \ and '
	// backslash-escaped). Unquoted, a space or quote in a password breaks the
	// DSN or injects extra keywords, and an empty value makes the parser
	// consume the next key=value pair as its value.
	escaper := strings.NewReplacer(`\`, `\\`, `'`, `\'`)
	quote := func(v string) string { return "'" + escaper.Replace(v) + "'" }

	psqlConn := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
		quote(s.host), s.port, quote(s.user), quote(s.password), quote(s.database), quote(s.sslMode))

	config, err := pgx.ParseConfig(psqlConn)
	if err != nil {
		return nil, err
	}

	// Apply the custom TLS settings only where pgx planned a TLS connection for
	// the configured sslmode: the primary config and any TLS fallbacks (e.g.
	// sslmode=allow). Connections pgx planned as plaintext (sslmode=disable,
	// unix sockets, plaintext fallbacks) stay plaintext.
	if s.TLSConfig != nil {
		if config.TLSConfig != nil {
			config.TLSConfig = s.TLSConfig
		}
		for _, fb := range config.Fallbacks {
			if fb.TLSConfig != nil {
				fb.TLSConfig = s.TLSConfig
			}
		}
	}

	dbURI := stdlib.RegisterConnConfig(config)
	dbPool, err := sql.Open(postgresProxyName, dbURI)
	if err != nil {
		// Nothing will ever use this registration; release it.
		stdlib.UnregisterConnConfig(dbURI)
		return nil, fmt.Errorf("sql.Open: %w", err)
	}
	return dbPool, nil
}
88
89
// GetNextSequenceQuery returns the SQL statement that advances the given
// sequence with nextval. The name is lower-cased and embedded as a SQL string
// literal. Single quotes in it are doubled so the literal cannot terminate
// early and the statement stays well-formed; this assumes
// standard_conforming_strings is on (the PostgreSQL default), under which
// backslashes are literal.
func (s *pgSqlConn) GetNextSequenceQuery(sequenceName string) string {
	name := strings.ReplaceAll(strings.ToLower(sequenceName), "'", "''")
	return fmt.Sprintf("SELECT nextval('%s')", name)
}
92