Passed
Push — add-integration-tests ( a9cb16...a7d4dd )
by Frank
02:17
created

database.*Postgresql.IsAutoIncrement   A

Complexity

Conditions 1

Size

Total Lines 2
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 0
CRAP Score 2

Importance

Changes 0
Metric Value
cc 1
eloc 2
nop 1
dl 0
loc 2
ccs 0
cts 1
cp 0
crap 2
rs 10
c 0
b 0
f 0
1
package database
2
3
import (
4
	"fmt"
5
	"strings"
6
7
	"github.com/fraenky8/tables-to-go/pkg/settings"
8
9
	// postgres database driver
10
	_ "github.com/lib/pq"
11
)
12
13
const (
	// PostgresqlColumnTypePrimaryKey is the constraint type that IsPrimaryKey
	// looks for in a column's constraint type.
	PostgresqlColumnTypePrimaryKey = "PRIMARY KEY"

	// PostgresqlColumnTypeAutoIncrement is the marker that IsAutoIncrement
	// looks for in a column's default value expression.
	PostgresqlColumnTypeAutoIncrement = "nextval"

	// PostgresqlColumnTypeSerial holds the name of the PostgreSQL serial type.
	PostgresqlColumnTypeSerial = "serial"
)
18
19
// Postgresql implements the Database interface with help of generalDatabase
20
type Postgresql struct {
21
	*GeneralDatabase
22
23
	defaultUserName string
24
}
25
26
// NewPostgresql creates a new Postgresql database
27
func NewPostgresql(s *settings.Settings) *Postgresql {
28 1
	return &Postgresql{
29
		GeneralDatabase: &GeneralDatabase{
30
			Settings: s,
31
			driver:   dbTypeToDriverMap[s.DbType],
32
		},
33
		defaultUserName: "postgres",
34
	}
35
}
36
37
// Connect connects to the database by the given data source name (dsn) of the concrete database
38
func (pg *Postgresql) Connect() error {
39
	return pg.GeneralDatabase.Connect(pg.DSN())
40
}
41
42
// DSN creates the DSN String to connect to this database
43
func (pg *Postgresql) DSN() string {
44 1
	user := pg.defaultUserName
45 1
	if pg.Settings.User != "" {
46 1
		user = pg.Settings.User
47
	}
48 1
	return fmt.Sprintf("host=%v port=%v user=%v dbname=%v password=%v sslmode=disable",
49
		pg.Settings.Host, pg.Settings.Port, user, pg.Settings.DbName, pg.Settings.Pswd)
50
}
51
52
// GetDriverImportLibrary returns the golang sql driver specific fot the Postgres database
53
func (pg *Postgresql) GetDriverImportLibrary() string {
54
	return "pg \"github.com/lib/pq\""
55
}
56
57
// GetTables gets all tables for a given schema by name
58
func (pg *Postgresql) GetTables() (tables []*Table, err error) {
59
60
	err = pg.Select(&tables, `
61
		SELECT table_name
62
		FROM information_schema.tables
63
		WHERE table_type = 'BASE TABLE'
64
		AND table_schema = $1
65
		ORDER BY table_name
66
	`, pg.Schema)
67
68
	if pg.Verbose {
69
		if err != nil {
70
			fmt.Println("> Error at GetTables()")
71
			fmt.Printf("> schema: %q\r\n", pg.Schema)
72
		}
73
	}
74
75
	return tables, err
76
}
77
78
// PrepareGetColumnsOfTableStmt prepares the statement for retrieving the columns of a specific table for a given database
79
func (pg *Postgresql) PrepareGetColumnsOfTableStmt() (err error) {
80
81
	pg.GetColumnsOfTableStmt, err = pg.Preparex(`
82
		SELECT
83
			ic.ordinal_position,
84
			ic.column_name,
85
			ic.data_type,
86
			ic.column_default,
87
			ic.is_nullable,
88
			ic.character_maximum_length,
89
			ic.numeric_precision,
90
			itc.constraint_name,
91
			itc.constraint_type
92
		FROM information_schema.columns AS ic
93
			LEFT JOIN information_schema.key_column_usage AS ikcu ON ic.table_name = ikcu.table_name
94
			AND ic.table_schema = ikcu.table_schema
95
			AND ic.column_name = ikcu.column_name
96
			LEFT JOIN information_schema.table_constraints AS itc ON ic.table_name = itc.table_name
97
			AND ic.table_schema = itc.table_schema
98
			AND ikcu.constraint_name = itc.constraint_name
99
		WHERE ic.table_name = $1
100
		AND ic.table_schema = $2
101
		ORDER BY ic.ordinal_position
102
	`)
103
104
	return err
105
}
106
107
// GetColumnsOfTable executes the statement for retrieving the columns of a specific table in a given schema
108
func (pg *Postgresql) GetColumnsOfTable(table *Table) (err error) {
109
110
	err = pg.GetColumnsOfTableStmt.Select(&table.Columns, table.Name, pg.Schema)
111
112
	if pg.Verbose {
113
		if err != nil {
114
			fmt.Printf("> Error at GetColumnsOfTable(%v)\r\n", table.Name)
115
			fmt.Printf("> schema: %q\r\n", pg.Schema)
116
		}
117
	}
118
119
	return err
120
}
121
122
// IsPrimaryKey checks if column belongs to primary key
123
func (pg *Postgresql) IsPrimaryKey(column Column) bool {
124
	return strings.Contains(column.ConstraintType.String, PostgresqlColumnTypePrimaryKey)
125
}
126
127
// IsAutoIncrement checks if column is a serial column
128
func (pg *Postgresql) IsAutoIncrement(column Column) bool {
129
	return strings.Contains(column.DefaultValue.String, PostgresqlColumnTypeAutoIncrement)
130
}
131
132
// GetStringDatatypes returns the string datatypes for the Postgres database
133
func (pg *Postgresql) GetStringDatatypes() []string {
134
	return []string{
135
		"character varying",
136
		"varchar",
137
		"character",
138
		"char",
139
	}
140
}
141
142
// IsString returns true if column is of type string for the Postgres database
143
func (pg *Postgresql) IsString(column Column) bool {
144
	return pg.IsStringInSlice(column.DataType, pg.GetStringDatatypes())
145
}
146
147
// GetTextDatatypes returns the text datatypes for the Postgres database
148
func (pg *Postgresql) GetTextDatatypes() []string {
149
	return []string{
150
		"text",
151
	}
152
}
153
154
// IsText returns true if column is of type text for the Postgres database
155
func (pg *Postgresql) IsText(column Column) bool {
156
	return pg.IsStringInSlice(column.DataType, pg.GetTextDatatypes())
157
}
158
159
// GetIntegerDatatypes returns the integer datatypes for the Postgres database
160
// TODO remove these methods
161
func (pg *Postgresql) GetIntegerDatatypes() []string {
162
	return []string{
163
		"smallint", "int2",
164
		"integer", "int4",
165
		"bigint", "int8",
166
		"smallserial", "serial2",
167
		"serial", "serial4",
168
		"bigserial", "serial8",
169
	}
170
}
171
172
// IsInteger returns true if column is of type integer for the Postgres database
173
func (pg *Postgresql) IsInteger(column Column) bool {
174
	return pg.IsStringInSlice(column.DataType, pg.GetIntegerDatatypes())
175
}
176
177
// GetFloatDatatypes returns the float datatypes for the Postgres database
178
func (pg *Postgresql) GetFloatDatatypes() []string {
179
	return []string{
180
		"float",
181
		"float4",
182
		"float8",
183
		"numeric",
184
		"decimal",
185
		"real",
186
		"double precision",
187
	}
188
}
189
190
// IsFloat returns true if column is of type float for the Postgres database
191
func (pg *Postgresql) IsFloat(column Column) bool {
192
	return pg.IsStringInSlice(column.DataType, pg.GetFloatDatatypes())
193
}
194
195
// GetTemporalDatatypes returns the temporal datatypes for the Postgres database
196
func (pg *Postgresql) GetTemporalDatatypes() []string {
197
	return []string{
198
		"time",
199
		"timestamp",
200
		"time with time zone",
201
		"timestamp with time zone",
202
		"timestamptz",
203
		"time without time zone",
204
		"timestamp without time zone",
205
		"date",
206
	}
207
}
208
209
// IsTemporal returns true if column is of type temporal for the Postgres database
210
func (pg *Postgresql) IsTemporal(column Column) bool {
211
	return pg.IsStringInSlice(column.DataType, pg.GetTemporalDatatypes())
212
}
213
214
// GetTemporalDriverDataType returns the time data type specific for the Postgres database
215
func (pg *Postgresql) GetTemporalDriverDataType() string {
216
	return "pg.NullTime"
217
}
218