Test Failed
Push — add-integration-tests ( dde5b7...e98ac5 )
by Frank
02:26
created

database.*Postgresql.Version   A

Complexity

Conditions 2

Size

Total Lines 7
Code Lines 6

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 0
CRAP Score 6

Importance

Changes 0
Metric Value
cc 2
eloc 6
dl 0
loc 7
c 0
b 0
f 0
ccs 0
cts 1
cp 0
crap 6
rs 10
nop 0
1
package database
2
3
import (
4
	"fmt"
5
	"strings"
6
7
	"github.com/fraenky8/tables-to-go/pkg/settings"
8
9
	// postgres database driver
10
	_ "github.com/lib/pq"
11
)
12
13
// Marker strings used by this file to classify Postgres columns.
const (
14
	// PostgresqlColumnTypePrimaryKey is the substring IsPrimaryKey looks for
	// in a column's constraint type.
	PostgresqlColumnTypePrimaryKey    = "PRIMARY KEY"
15
	// PostgresqlColumnTypeAutoIncrement is the substring IsAutoIncrement looks
	// for in a column's default value.
	PostgresqlColumnTypeAutoIncrement = "nextval"
16
	// PostgresqlColumnTypeSerial is the "serial" type name.
	// NOTE(review): not referenced in the visible part of this file - confirm usage elsewhere.
	PostgresqlColumnTypeSerial        = "serial"
17
)
18
19
// Postgresql implements the Database interface with help of GeneralDatabase,
// which it embeds, so the embedded type's fields and methods are promoted.
20
type Postgresql struct {
21
	// GeneralDatabase carries the settings and driver set up in NewPostgresql.
	*GeneralDatabase
22
23
	// defaultUserName is the user name DSN falls back to when Settings.User
	// is empty; NewPostgresql sets it to "postgres".
	defaultUserName string
24
}
25
26
// NewPostgresql creates a new Postgresql database
27
func NewPostgresql(s *settings.Settings) *Postgresql {
28 1
	return &Postgresql{
29
		GeneralDatabase: &GeneralDatabase{
30
			Settings: s,
31
			driver:   dbTypeToDriverMap[s.DbType],
32
		},
33
		defaultUserName: "postgres",
34
	}
35
}
36
37
// Connect connects to the database by the given data source name (dsn) of the concrete database
38
func (pg *Postgresql) Connect() error {
39
	return pg.GeneralDatabase.Connect(pg.DSN())
40
}
41
42
// DSN creates the DSN String to connect to this database
43
func (pg *Postgresql) DSN() string {
44 1
	user := pg.defaultUserName
45 1
	if pg.Settings.User != "" {
46 1
		user = pg.Settings.User
47
	}
48 1
	return fmt.Sprintf("host=%v port=%v user=%v dbname=%v password=%v sslmode=disable",
49
		pg.Settings.Host, pg.Settings.Port, user, pg.Settings.DbName, pg.Settings.Pswd)
50
}
51
52
// Version reports the actual version of the Postgres database.
53
func (pg *Postgresql) Version() (string, error) {
54
	var version string
55
	err := pg.Get(&version, `SELECT version() as version`)
56
	if err != nil {
57
		return "", err
58
	}
59
	return version, nil
60
}
61
62
// GetDriverImportLibrary returns the golang sql driver specific fot the Postgres database
63
func (pg *Postgresql) GetDriverImportLibrary() string {
64
	return "pg \"github.com/lib/pq\""
65
}
66
67
// GetTables gets all tables for a given schema by name
68
func (pg *Postgresql) GetTables() (tables []*Table, err error) {
69
70
	err = pg.Select(&tables, `
71
		SELECT table_name
72
		FROM information_schema.tables
73
		WHERE table_type = 'BASE TABLE'
74
		AND table_schema = $1
75
		ORDER BY table_name
76
	`, pg.Schema)
77
78
	if pg.Verbose {
79
		if err != nil {
80
			fmt.Println("> Error at GetTables()")
81
			fmt.Printf("> schema: %q\r\n", pg.Schema)
82
		}
83
	}
84
85
	return tables, err
86
}
87
88
// PrepareGetColumnsOfTableStmt prepares the statement for retrieving the columns of a specific table for a given database
89
func (pg *Postgresql) PrepareGetColumnsOfTableStmt() (err error) {
90
91
	pg.GetColumnsOfTableStmt, err = pg.Preparex(`
92
		SELECT
93
			ic.ordinal_position,
94
			ic.column_name,
95
			ic.data_type,
96
			ic.column_default,
97
			ic.is_nullable,
98
			ic.character_maximum_length,
99
			ic.numeric_precision,
100
			itc.constraint_name,
101
			itc.constraint_type
102
		FROM information_schema.columns AS ic
103
			LEFT JOIN information_schema.key_column_usage AS ikcu ON ic.table_name = ikcu.table_name
104
			AND ic.table_schema = ikcu.table_schema
105
			AND ic.column_name = ikcu.column_name
106
			LEFT JOIN information_schema.table_constraints AS itc ON ic.table_name = itc.table_name
107
			AND ic.table_schema = itc.table_schema
108
			AND ikcu.constraint_name = itc.constraint_name
109
		WHERE ic.table_name = $1
110
		AND ic.table_schema = $2
111
		ORDER BY ic.ordinal_position
112
	`)
113
114
	return err
115
}
116
117
// GetColumnsOfTable executes the statement for retrieving the columns of a specific table in a given schema
118
func (pg *Postgresql) GetColumnsOfTable(table *Table) (err error) {
119
120
	err = pg.GetColumnsOfTableStmt.Select(&table.Columns, table.Name, pg.Schema)
121
122
	if pg.Verbose {
123
		if err != nil {
124
			fmt.Printf("> Error at GetColumnsOfTable(%v)\r\n", table.Name)
125
			fmt.Printf("> schema: %q\r\n", pg.Schema)
126
		}
127
	}
128
129
	return err
130
}
131
132
// IsPrimaryKey checks if column belongs to primary key
133
func (pg *Postgresql) IsPrimaryKey(column Column) bool {
134
	return strings.Contains(column.ConstraintType.String, PostgresqlColumnTypePrimaryKey)
135
}
136
137
// IsAutoIncrement checks if column is a serial column
138
func (pg *Postgresql) IsAutoIncrement(column Column) bool {
139
	return strings.Contains(column.DefaultValue.String, PostgresqlColumnTypeAutoIncrement)
140
}
141
142
// GetStringDatatypes returns the string data types for the Postgres database
143
func (pg *Postgresql) GetStringDatatypes() []string {
144
	return []string{
145
		"character varying",
146
		"varchar",
147
		"character",
148
		"char",
149
	}
150
}
151
152
// IsString returns true if column is of type string for the Postgres database
153
func (pg *Postgresql) IsString(column Column) bool {
154
	return pg.IsStringInSlice(column.DataType, pg.GetStringDatatypes())
155
}
156
157
// GetTextDatatypes returns the text data types for the Postgres database
158
func (pg *Postgresql) GetTextDatatypes() []string {
159
	return []string{
160
		"text",
161
	}
162
}
163
164
// IsText returns true if column is of type text for the Postgres database
165
func (pg *Postgresql) IsText(column Column) bool {
166
	return pg.IsStringInSlice(column.DataType, pg.GetTextDatatypes())
167
}
168
169
// GetIntegerDatatypes returns the integer data types for the Postgres database
170
// TODO remove these methods
171
func (pg *Postgresql) GetIntegerDatatypes() []string {
172
	return []string{
173
		"smallint", "int2",
174
		"integer", "int4",
175
		"bigint", "int8",
176
		"smallserial", "serial2",
177
		"serial", "serial4",
178
		"bigserial", "serial8",
179
	}
180
}
181
182
// IsInteger returns true if column is of type integer for the Postgres database
183
func (pg *Postgresql) IsInteger(column Column) bool {
184
	return pg.IsStringInSlice(column.DataType, pg.GetIntegerDatatypes())
185
}
186
187
// GetFloatDatatypes returns the float data types for the Postgres database
188
func (pg *Postgresql) GetFloatDatatypes() []string {
189
	return []string{
190
		"float",
191
		"float4",
192
		"float8",
193
		"numeric",
194
		"decimal",
195
		"real",
196
		"double precision",
197
	}
198
}
199
200
// IsFloat returns true if column is of type float for the Postgres database
201
func (pg *Postgresql) IsFloat(column Column) bool {
202
	return pg.IsStringInSlice(column.DataType, pg.GetFloatDatatypes())
203
}
204
205
// GetTemporalDatatypes returns the temporal data types for the Postgres database
206
func (pg *Postgresql) GetTemporalDatatypes() []string {
207
	return []string{
208
		"time",
209
		"timestamp",
210
		"time with time zone",
211
		"timestamp with time zone",
212
		"timestamptz",
213
		"time without time zone",
214
		"timestamp without time zone",
215
		"date",
216
	}
217
}
218
219
// IsTemporal returns true if column is of type temporal for the Postgres database
220
func (pg *Postgresql) IsTemporal(column Column) bool {
221
	return pg.IsStringInSlice(column.DataType, pg.GetTemporalDatatypes())
222
}
223
224
// GetTemporalDriverDataType returns the time data type specific for the Postgres database
225
func (pg *Postgresql) GetTemporalDriverDataType() string {
226
	return "pg.NullTime"
227
}
228