Passed
Push — master ( 47c3cf...c030a9 )
by P.R.
01:09
created

MsSqlRoutineLoaderWorker.__init__()   A

Complexity

Conditions 1

Size

Total Lines 8
Code Lines 3

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 3
dl 0
loc 8
rs 10
c 0
b 0
f 0
cc 1
nop 3
1
from configparser import ConfigParser
2
from typing import Any, Dict, Optional
3
4
from pystratum_backend.StratumStyle import StratumStyle
5
from pystratum_common.backend.CommonRoutineLoaderWorker import CommonRoutineLoaderWorker
6
7
from pystratum_mssql.backend.MsSqlWorker import MsSqlWorker
8
from pystratum_mssql.helper.MsSqlRoutineLoaderHelper import MsSqlRoutineLoaderHelper
9
10
11
class MsSqlRoutineLoaderWorker(MsSqlWorker, CommonRoutineLoaderWorker):
    """
    Class for loading stored routines into a SQL Server instance from pseudo SQL files.
    """

    # Data types whose SQL declaration is just the type name (no length, precision, or scale is required).
    _PLAIN_DATA_TYPES = frozenset(('bigint',
                                   'int',
                                   'smallint',
                                   'tinyint',
                                   'bit',
                                   'money',
                                   'smallmoney',
                                   'float',
                                   'real',
                                   'date',
                                   'datetime',
                                   'datetime2',
                                   'datetimeoffset',
                                   'smalldatetime',
                                   'time',
                                   'text',
                                   'ntext',
                                   'image',
                                   'xml',
                                   'geography',
                                   'geometry',
                                   'hierarchyid',
                                   'sql_variant',
                                   'sysname',
                                   'uniqueidentifier'))

    # ------------------------------------------------------------------------------------------------------------------
    def __init__(self, io: StratumStyle, config: ConfigParser):
        """
        Object constructor.

        :param StratumStyle io: The output decorator.
        :param ConfigParser config: The configuration.
        """
        # Both base constructors are invoked explicitly (rather than via super()) so each parent is initialized with
        # the same arguments regardless of how the bases cooperate.
        MsSqlWorker.__init__(self, io, config)
        CommonRoutineLoaderWorker.__init__(self, io, config)

    # ------------------------------------------------------------------------------------------------------------------
    def _get_column_type(self) -> None:
        """
        Selects schema, table, column names and the column types from the SQL Server instance and saves them as replace
        pairs.

        Each replace pair maps the lower-cased placeholder '@schema.table.column%type@' onto the SQL declaration of
        that column's data type.
        """
        rows = self._dl.get_all_table_columns()
        for row in rows:
            key = '@{0}.{1}.{2}%type@'.format(row['schema_name'], row['table_name'], row['column_name']).lower()
            self._replace_pairs[key] = self._derive_data_type(row)

        self._io.text('Selected {0} column types for substitution'.format(len(rows)))

    # ------------------------------------------------------------------------------------------------------------------
    def create_routine_loader_helper(self,
                                     routine_name: str,
                                     pystratum_old_metadata: Optional[Dict],
                                     rdbms_old_metadata: Optional[Dict]) -> MsSqlRoutineLoaderHelper:
        """
        Creates a Routine Loader Helper object.

        :param str routine_name: The name of the routine.
        :param dict pystratum_old_metadata: The old metadata of the stored routine from PyStratum.
        :param dict rdbms_old_metadata:  The old metadata of the stored routine from MS SQL Server.

        :rtype: MsSqlRoutineLoaderHelper
        """
        return MsSqlRoutineLoaderHelper(self._io,
                                        self._dl,
                                        self._source_file_names[routine_name],
                                        self._source_file_encoding,
                                        pystratum_old_metadata,
                                        self._replace_pairs,
                                        rdbms_old_metadata)

    # ------------------------------------------------------------------------------------------------------------------
    def _get_old_stored_routine_info(self) -> None:
        """
        Retrieves information about all stored routines.

        The rows are stored in self._rdbms_old_metadata keyed by 'schema_name.routine_name'.
        """
        rows = self._dl.get_routines()
        self._rdbms_old_metadata = {row['schema_name'] + '.' + row['routine_name']: row for row in rows}

    # ------------------------------------------------------------------------------------------------------------------
    def _drop_obsolete_routines(self) -> None:
        """
        Drops obsolete stored routines (i.e. stored routines that exist but for which we don't have a source file).

        :raises Exception: If an obsolete routine has a routine type this method does not know how to drop.
        """
        for routine_name, values in self._rdbms_old_metadata.items():
            if routine_name in self._source_file_names:
                continue

            # The routine type is stripped, presumably because it is a fixed-width char(2) code as in sys.objects.type
            # (e.g. 'P ') -- TODO confirm against the query in get_routines().
            type_code = values['routine_type'].strip()
            if type_code == 'P':
                routine_type = 'procedure'
            elif type_code in ('FN', 'IF', 'TF'):
                # Scalar, inline table-valued, and multi-statement table-valued functions are all dropped with
                # DROP FUNCTION.
                routine_type = 'function'
            else:
                raise Exception("Unknown routine type '{0}'".format(values['routine_type']))

            self._io.writeln("Dropping {0} <dbo>{1}.{2}</dbo>".format(routine_type,
                                                                      values['schema_name'],
                                                                      values['routine_name']))
            self._dl.drop_stored_routine(routine_type, values['schema_name'], values['routine_name'])

    # ------------------------------------------------------------------------------------------------------------------
    @staticmethod
    def _derive_data_type(column: Dict[str, Any]) -> str:
        """
        Returns the proper SQL declaration of a data type of a column.

        :param dict column: The column of which the field is based. Must hold key 'data_type' and, depending on the
                            data type, 'precision' and 'scale' (decimal/numeric) or 'max_length' (character and binary
                            types; presumably in bytes with -1 meaning (max), as in sys.columns -- the code below
                            relies on this by halving the length for nchar/nvarchar).

        :rtype: str

        :raises Exception: If the data type is not supported.
        """
        data_type = column['data_type']

        if data_type in MsSqlRoutineLoaderWorker._PLAIN_DATA_TYPES:
            return data_type

        # decimal and numeric are synonyms in SQL Server; both are declared as decimal.
        if data_type in ('decimal', 'numeric'):
            return 'decimal({0:d},{1:d})'.format(column['precision'], column['scale'])

        # Fixed-length single-byte types: the length must be kept, otherwise SQL Server defaults it to 1.
        if data_type in ('char', 'binary'):
            return '{0}({1:d})'.format(data_type, column['max_length'])

        # Variable-length single-byte types: a max_length of -1 denotes the (max) variant.
        if data_type in ('varchar', 'varbinary'):
            if column['max_length'] == -1:
                return '{0}(max)'.format(data_type)

            return '{0}({1:d})'.format(data_type, column['max_length'])

        # Unicode types: max_length counts bytes, the declaration counts (two-byte) characters.
        if data_type == 'nchar':
            return 'nchar({0:d})'.format(column['max_length'] // 2)

        if data_type == 'nvarchar':
            if column['max_length'] == -1:
                return 'nvarchar(max)'

            return 'nvarchar({0:d})'.format(column['max_length'] // 2)

        raise Exception("Unexpected data type '{0}'".format(data_type))
# ----------------------------------------------------------------------------------------------------------------------
205