Passed
Push — master ( c030a9...68a85e )
by P.R.
01:11
created

MsSqlConstantWorker._fill_constants()   B

Complexity

Conditions 6

Size

Total Lines 12
Code Lines 8

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 8
dl 0
loc 12
rs 8.6666
c 0
b 0
f 0
cc 6
nop 1
1
"""
2
PyStratum
3
"""
4
import os
5
import re
6
from configparser import ConfigParser
7
from typing import Any, Dict
8
9
from pystratum_backend.StratumStyle import StratumStyle
10
from pystratum_common.backend.CommonConstantWorker import CommonConstantWorker
11
from pystratum_common.Util import Util
12
13
from pystratum_mssql.backend.MsSqlWorker import MsSqlWorker
14
15
16
class MsSqlConstantWorker(MsSqlWorker, CommonConstantWorker):
17
    """
18
    Class for creating constants based on column widths, and auto increment columns and labels for SQL Server
19
    databases.
20
    """
21
22
    # ------------------------------------------------------------------------------------------------------------------
23
    def __init__(self, io: StratumStyle, config: ConfigParser):
        """
        Object constructor.

        :param io: The output decorator.
        :param config: The configuration object; it is passed on unchanged to both parent constructors.
        """
        MsSqlWorker.__init__(self, io, config)
        CommonConstantWorker.__init__(self, io, config)

        # All columns in the database, nested as schema name -> table name -> column name -> column metadata.
        self._columns: Dict[str, Any] = dict()
36
37
    # ------------------------------------------------------------------------------------------------------------------
38
    def _get_old_columns(self) -> None:
39
        """
40
        Reads from file constants_filename the previous table and column names, the width of the column,
41
        and the constant name (if assigned) and stores this data in old_columns.
42
        """
43
        if os.path.exists(self._constants_filename):
44
            with open(self._constants_filename, 'r') as f:
45
                line_number = 0
46
                for line in f:
47
                    line_number += 1
48
                    if line != "\n":
49
                        p = re.compile(r'\s*(?:([a-zA-Z0-9_]+)\.)?([a-zA-Z0-9_]+)\.'
50
                                       r'([a-zA-Z0-9_]+)\s+(\d+)\s*(\*|[a-zA-Z0-9_]+)?\s*')
51
                        matches = p.findall(line)
52
53
                        if matches:
54
                            matches = matches[0]
55
                            schema_name = str(matches[0])
56
                            table_name = str(matches[1])
57
                            column_name = str(matches[2])
58
                            length = str(matches[3])
59
                            constant_name = str(matches[4])
60
61
                            if constant_name:
62
                                column_info = {'schema_name':   schema_name,
63
                                               'table_name':    table_name,
64
                                               'column_name':   column_name,
65
                                               'length':        length,
66
                                               'constant_name': constant_name}
67
                            else:
68
                                column_info = {'schema_name': schema_name,
69
                                               'table_name':  table_name,
70
                                               'column_name': column_name,
71
                                               'length':      length}
72
73
                            if schema_name in self._old_columns:
74
                                if table_name in self._old_columns[schema_name]:
75
                                    if column_name in self._old_columns[schema_name][table_name]:
76
                                        pass
77
                                    else:
78
                                        self._old_columns[schema_name][table_name][column_name] = column_info
79
                                else:
80
                                    self._old_columns[schema_name][table_name] = {column_name: column_info}
81
                            else:
82
                                self._old_columns[schema_name] = {table_name: {column_name: column_info}}
83
84
    # ------------------------------------------------------------------------------------------------------------------
85
    def _get_columns(self) -> None:
86
        """
87
        Retrieves metadata all columns in the database.
88
        """
89
        rows = self._dl.get_all_table_columns()
90
        for row in rows:
91
            row['length'] = MsSqlConstantWorker.derive_field_length(row)
92
93
            if row['schema_name'] in self._columns:
94
                if row['table_name'] in self._columns[row['schema_name']]:
95
                    if row['column_name'] in self._columns[row['schema_name']][row['table_name']]:
96
                        pass
97
                    else:
98
                        self._columns[row['schema_name']][row['table_name']][row['column_name']] = row
99
                else:
100
                    self._columns[row['schema_name']][row['table_name']] = {row['column_name']: row}
101
            else:
102
                self._columns[row['schema_name']] = {row['table_name']: {row['column_name']: row}}
103
104
    # ------------------------------------------------------------------------------------------------------------------
105
    def _enhance_columns(self) -> None:
106
        """
107
        Enhances old_columns as follows:
108
        If the constant name is *, is is replaced with the column name prefixed by prefix in uppercase.
109
        Otherwise the constant name is set to uppercase.
110
        """
111
        if self._old_columns:
112
            for schema_name, schema in sorted(self._old_columns.items()):
113
                for table_name, table in sorted(schema.items()):
114
                    for column_name, column in sorted(table.items()):
115
                        if 'constant_name' in column:
116
                            if column['constant_name'].strip() == '*':
117
                                constant_name = str(self._prefix + column['column_name']).upper()
118
                                self._old_columns[schema_name][table_name][column_name]['constant_name'] = constant_name
119
                            else:
120
                                constant_name = str(
121
                                        self._old_columns[schema_name][table_name][column_name][
122
                                            'constant_name']).upper()
123
                                self._old_columns[schema_name][table_name][column_name]['constant_name'] = constant_name
124
125
    # ------------------------------------------------------------------------------------------------------------------
126
    def _merge_columns(self) -> None:
127
        """
128
        Preserves relevant data in old_columns into columns.
129
        """
130
        if self._old_columns:
131
            for schema_name, schema in sorted(self._old_columns.items()):
132
                for table_name, table in sorted(schema.items()):
133
                    for column_name, column in sorted(table.items()):
134
                        if 'constant_name' in column:
135
                            try:
136
                                self._columns[schema_name][table_name][column_name]['constant_name'] = \
137
                                    column['constant_name']
138
                            except KeyError:
139
                                # Either the column or table is not present anymore.
140
                                self._io.warning('Dropping constant {0} because column is not present anymore'.
141
                                                 format(column['constant_name']))
142
143
    # ------------------------------------------------------------------------------------------------------------------
144
    def _write_columns(self) -> None:
        """
        Writes table and column names, the width of the column, and the constant name (if assigned) to
        constants_filename.

        Schemas and tables are written in sorted order, the columns of a table in order of their column ID. Per
        table the column names are left aligned and the lengths are right aligned. Columns with length None are
        not written. Each table is followed by an empty line.
        """
        lines = []

        for schema_name, schema in sorted(self._columns.items()):
            for table_name, table in sorted(schema.items()):
                width1 = 0
                width2 = 0

                key_map = {}
                for column_name, column in table.items():
                    key_map[column['column_id']] = column_name
                    width1 = max(len(str(column['column_name'])), width1)
                    width2 = max(len(str(column['length'])), width2)

                # The widths are fixed per table, hence the format of a line is fixed per table as well.
                line_format = "%s.%s.%-{0:d}s %{1:d}d".format(int(width1), int(width2))

                for col_id, column_name in sorted(key_map.items()):
                    column = table[column_name]
                    if column['length'] is None:
                        continue

                    line = line_format % (schema_name,
                                          column['table_name'],
                                          column['column_name'],
                                          column['length'])
                    if 'constant_name' in column:
                        line += " %s" % column['constant_name']

                    lines.append(line + "\n")

                lines.append("\n")

        # Save the columns, width, and constants to the filesystem.
        Util.write_two_phases(self._constants_filename, ''.join(lines), self._io)
182
183
    # ------------------------------------------------------------------------------------------------------------------
184
    def _get_labels(self) -> None:
185
        """
186
        Gets all primary key labels from the database.
187
        """
188
        tables = self._dl.get_label_tables(self._label_regex)
189
190
        for table in tables:
191
            rows = self._dl.get_labels_from_table(table['database'],
192
                                                  table['schema_name'],
193
                                                  table['table_name'],
194
                                                  table['id'],
195
                                                  table['label'])
196
            for row in rows:
197
                if row['label'] not in self._labels:
198
                    self._labels[row['label']] = row['id']
199
                else:
200
                    # todo improve exception.
201
                    Exception("Duplicate label '%s'")
202
203
    # ------------------------------------------------------------------------------------------------------------------
204
    def _fill_constants(self) -> None:
205
        """
206
        Merges columns and labels (i.e. all known constants) into constants.
207
        """
208
        for schema_name, schema in sorted(self._columns.items()):
209
            for table_name, table in sorted(schema.items()):
210
                for column_name, column in sorted(table.items()):
211
                    if 'constant_name' in column:
212
                        self._constants[column['constant_name']] = column['length']
213
214
        for label, label_id in sorted(self._labels.items()):
215
            self._constants[label] = label_id
216
217
    # ------------------------------------------------------------------------------------------------------------------
218
    @staticmethod
219
    def derive_field_length(column: Dict[str, Any]) -> int:
220
        """
221
        Returns the width of a field based based on the data type of column.
222
223
        :param dict column: Info about the column.
224
225
        :rtype: int
226
        """
227
        data_type = column['data_type']
228
229
        if data_type in ['bigint',
230
                         'int',
231
                         'smallint',
232
                         'tinyint',
233
                         'money',
234
                         'smallmoney',
235
                         'decimal',
236
                         'numeric',
237
                         'float',
238
                         'real',
239
                         'date',
240
                         'datetime',
241
                         'datetime2',
242
                         'datetimeoffset',
243
                         'smalldatetime',
244
                         'time']:
245
            return column['precision']
246
247
        if data_type in ['bit',
248
                         'char',
249
                         'binary',
250
                         'varbinary',
251
                         'sysname']:
252
            return column['max_length']
253
254
        if data_type == 'varchar':
255
            if column['max_length'] == -1:
256
                # This is a varchar(max) data type.
257
                return 2147483647
258
259
            return column['max_length']
260
261
        if data_type == 'nvarchar':
262
            if column['max_length'] == -1:
263
                # This is a nvarchar(max) data type.
264
                return 1073741823
265
266
            return column['max_length'] / 2
267
268
        if data_type in ['text', 'image', 'xml']:
269
            return 2147483647
270
271
        if data_type == 'nchar':
272
            return column['max_length'] / 2
273
274
        if data_type == 'ntext':
275
            return 1073741823
276
277
        if data_type in ['geography', 'geometry']:
278
            if column['max_length'] == -1:
279
                return 2147483647
280
281
        raise Exception("Unexpected data type '{0}'".format(data_type))
282
283
# ----------------------------------------------------------------------------------------------------------------------
284