database.New   A
last analyzed

Complexity

Conditions 5

Size

Total Lines 16
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 0
CRAP Score 30

Importance

Changes 0
Metric Value
cc 5
eloc 12
nop 1
dl 0
loc 16
ccs 0
cts 7
cp 0
crap 30
rs 9.3333
c 0
b 0
f 0
1
package database
2
3
import (
4
	"database/sql"
5
	"fmt"
6
7
	"github.com/jmoiron/sqlx"
8
9
	"github.com/fraenky8/tables-to-go/pkg/settings"
10
)
11
12
var (
13
	// dbTypeToDriverMap maps the database type to the driver names
14
	dbTypeToDriverMap = map[settings.DbType]string{
15
		settings.DbTypePostgresql: "postgres",
16
		settings.DbTypeMySQL:      "mysql",
17
		settings.DbTypeSQLite:     "sqlite3",
18
	}
19
)
20
21
// Database interface for the concrete databases
22
type Database interface {
23
	DSN() string
24
	Connect() (err error)
25
	Close() (err error)
26
	GetDriverImportLibrary() string
27
28
	GetTables() (tables []*Table, err error)
29
	PrepareGetColumnsOfTableStmt() (err error)
30
	GetColumnsOfTable(table *Table) (err error)
31
32
	IsPrimaryKey(column Column) bool
33
	IsAutoIncrement(column Column) bool
34
	IsNullable(column Column) bool
35
36
	GetStringDatatypes() []string
37
	IsString(column Column) bool
38
39
	GetTextDatatypes() []string
40
	IsText(column Column) bool
41
42
	GetIntegerDatatypes() []string
43
	IsInteger(column Column) bool
44
45
	GetFloatDatatypes() []string
46
	IsFloat(column Column) bool
47
48
	GetTemporalDatatypes() []string
49
	IsTemporal(column Column) bool
50
	GetTemporalDriverDataType() string
51
52
	// TODO pg: bitstrings, enum, range, other special types
53
	// TODO mysql: bit, enums, set
54
}
55
56
// Table has a name and a set (slice) of columns
57
type Table struct {
58
	Name    string `db:"table_name"`
59
	Columns []Column
60
}
61
62
// Column holds the metadata of a single table column. Every field carries a
// db tag naming the result column it is mapped from.
type Column struct {
	OrdinalPosition        int            `db:"ordinal_position"`
	Name                   string         `db:"column_name"`
	DataType               string         `db:"data_type"`
	DefaultValue           sql.NullString `db:"column_default"`
	IsNullable             string         `db:"is_nullable"`
	CharacterMaximumLength sql.NullInt64  `db:"character_maximum_length"`
	NumericPrecision       sql.NullInt64  `db:"numeric_precision"`

	// MySQL specific fields.
	ColumnKey string `db:"column_key"`
	Extra     string `db:"extra"`

	// PostgreSQL specific fields.
	ConstraintName sql.NullString `db:"constraint_name"`
	ConstraintType sql.NullString `db:"constraint_type"`
}
76
77
// GeneralDatabase represents a base "class" database - for all other concrete databases
78
// it implements partly the Database interface
79
type GeneralDatabase struct {
80
	GetColumnsOfTableStmt *sqlx.Stmt
81
	*sqlx.DB
82
	*settings.Settings
83
	driver string
84
}
85
86
// New creates a new Database based on the given type in the settings.
87
func New(s *settings.Settings) Database {
88
89
	var db Database
90
91
	switch s.DbType {
92
	case settings.DbTypeSQLite:
93
		db = NewSQLite(s)
94
	case settings.DbTypeMySQL:
95
		db = NewMySQL(s)
96
	case settings.DbTypePostgresql:
97
		fallthrough
98
	default:
99
		db = NewPostgresql(s)
100
	}
101
102
	return db
103
}
104
105
// Connect establishes a connection to the database with the given DSN.
106
// It pings the database to ensure it is reachable.
107
func (gdb *GeneralDatabase) Connect(dsn string) (err error) {
108
	gdb.DB, err = sqlx.Connect(gdb.driver, dsn)
109
	if err != nil {
110
		usingPswd := "no"
111
		if gdb.Settings.Pswd != "" {
112
			usingPswd = "yes"
113
		}
114
		return fmt.Errorf(
115
			"could not connect to database (type=%q, user=%q, database=%q, host='%v:%v', using password: %v): %v",
116
			gdb.DbType, gdb.User, gdb.DbName, gdb.Host, gdb.Port, usingPswd, err,
117
		)
118
	}
119
120
	return gdb.Ping()
121
}
122
123
// Close closes the database connection
124
func (gdb *GeneralDatabase) Close() error {
125
	return gdb.DB.Close()
126
}
127
128
// IsNullable returns true if column is a nullable one
129
func (gdb *GeneralDatabase) IsNullable(column Column) bool {
130
	return column.IsNullable == "YES"
131
}
132
133
// IsStringInSlice checks if needle (string) is in haystack ([]string)
134
func (gdb *GeneralDatabase) IsStringInSlice(needle string, haystack []string) bool {
135
	for _, s := range haystack {
136
		if s == needle {
137
			return true
138
		}
139
	}
140
	return false
141
}
142