MySqlConstantWorker._get_old_columns()   C
last analyzed

Complexity

Conditions 10

Size

Total Lines 47
Code Lines 33

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 1
CRAP Score 99.3056

Importance

Changes 0
Metric Value
eloc 33
dl 0
loc 47
ccs 1
cts 27
cp 0.037
rs 5.9999
c 0
b 0
f 0
cc 10
nop 1
crap 99.3056

How to fix   Complexity   

Complexity

Complex methods like pystratum_mysql.backend.MySqlConstantWorker.MySqlConstantWorker._get_old_columns(), and the classes that contain them, often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to finding such a component is to look for fields or methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1 1
import os
2 1
import re
3 1
from configparser import ConfigParser
4 1
from typing import Any, Dict, Optional
5
6 1
from pystratum_backend.StratumIO import StratumIO
7 1
from pystratum_common.backend.CommonConstantWorker import CommonConstantWorker
8 1
from pystratum_common.Util import Util
9
10 1
from pystratum_mysql.backend.MySqlWorker import MySqlWorker
11
12
13 1
class MySqlConstantWorker(MySqlWorker, CommonConstantWorker):
    """
    Class for creating constants based on column widths, and auto increment columns and labels for MySQL databases.
    """

    # ------------------------------------------------------------------------------------------------------------------
    # Format of a non-blank line in the constants file: [schema.]table.column length [constant_name|*]
    # Compiled once here instead of once for every line read from the constants file.
    _OLD_COLUMN_PATTERN = re.compile(r'\s*(?:([a-zA-Z0-9_]+)\.)?([a-zA-Z0-9_]+)\.'
                                     r'([a-zA-Z0-9_]+)\s+(\d+)\s*(\*|[a-zA-Z0-9_]+)?\s*')

    # Data types whose field width is taken from the column's 'numeric_precision'.
    _NUMERIC_PRECISION_TYPES = frozenset({'tinyint', 'smallint', 'mediumint', 'int', 'bigint',
                                          'decimal', 'float', 'double'})

    # Data types whose field width is taken from the column's 'character_maximum_length'.
    # NOTE(review): for MySQL 'bit' columns information_schema usually reports the width in numeric_precision and
    # leaves character_maximum_length NULL — confirm against the query behind get_all_table_columns().
    _CHARACTER_LENGTH_TYPES = frozenset({'char', 'varchar', 'binary', 'varbinary',
                                         'tinytext', 'text', 'mediumtext', 'longtext',
                                         'tinyblob', 'blob', 'mediumblob', 'longblob',
                                         'bit'})

    # Data types with a fixed field width; None means the type has no width (no constant is written for it).
    # Shared by all instances: treat as read-only.
    _FIXED_LENGTH_TYPES: Dict[str, Optional[int]] = {'timestamp': 16,
                                                     'year':      4,
                                                     'time':      8,
                                                     'date':      10,
                                                     'datetime':  16,
                                                     'inet6':     39,
                                                     'enum':      None,
                                                     'set':       None}

    # ------------------------------------------------------------------------------------------------------------------
    def __init__(self, io: StratumIO, config: ConfigParser):
        """
        Object constructor.

        :param io: The output decorator.
        :param config: The configuration, passed on to both base class constructors.
        """
        MySqlWorker.__init__(self, io, config)
        CommonConstantWorker.__init__(self, io, config)

        # All columns in the MySQL schema, keyed by table name and then by column name.
        self._columns: Dict[str, Dict[str, Dict[str, Any]]] = {}

    # ------------------------------------------------------------------------------------------------------------------
    def _get_old_columns(self) -> None:
        """
        Reads from file constants_filename the previous table and column names, the width of the column,
        and the constant name (if assigned) and stores this data in old_columns.

        Blank lines (including lines with only whitespace) are skipped. When a column occurs more than once, the
        first occurrence is kept. Nothing happens when the file does not exist.

        :raises RuntimeError: When a non-blank line does not have the expected format.
        """
        if not os.path.exists(self._constants_filename):
            return

        with open(self._constants_filename, 'r') as file:
            for line_number, line in enumerate(file, start=1):
                if not line.strip():
                    continue

                column_info = self._parse_old_column(line)
                if column_info is None:
                    raise RuntimeError("Illegal format at line {0} in file {1}".
                                       format(line_number, self._constants_filename))

                table = self._old_columns.setdefault(column_info['table_name'], {})
                table.setdefault(column_info['column_name'], column_info)

    # ------------------------------------------------------------------------------------------------------------------
    @classmethod
    def _parse_old_column(cls, line: str) -> Optional[Dict[str, str]]:
        """
        Parses one non-blank line of the constants file.

        :param line: The line.

        :return: A dict with keys 'table_name' (prefixed with 'schema.' when the line has a schema), 'column_name',
                 'length' (as a string) and, only when the line names one, 'constant_name'. None when the line does
                 not match the expected format.
        """
        match = cls._OLD_COLUMN_PATTERN.search(line)
        if match is None:
            return None

        # Unmatched optional groups (schema and constant name) become empty strings.
        schema_name, table_name, column_name, length, constant_name = match.groups(default='')
        if schema_name:
            table_name = schema_name + '.' + table_name

        column_info = {'table_name':  table_name,
                       'column_name': column_name,
                       'length':      length}
        if constant_name:
            column_info['constant_name'] = constant_name

        return column_info

    # ------------------------------------------------------------------------------------------------------------------
    def _get_columns(self) -> None:
        """
        Retrieves metadata about all table columns in the MySQL schema.

        Each row is enhanced with key 'length' (see derive_field_length). When a (table, column) pair occurs more
        than once, the first row is kept.
        """
        for row in self._dl.get_all_table_columns():
            # Enhance row with the actual length of the column.
            row['length'] = self.derive_field_length(row)
            self._columns.setdefault(row['table_name'], {}).setdefault(row['column_name'], row)

    # ------------------------------------------------------------------------------------------------------------------
    def _enhance_columns(self) -> None:
        """
        Enhances old_columns as follows:
        If the constant name is *, it is replaced with the column name prefixed by prefix in uppercase.
        Otherwise, the constant name is set to uppercase.
        Columns without a constant name are left untouched.
        """
        if not self._old_columns:
            return

        for table in self._old_columns.values():
            for column in table.values():
                if 'constant_name' not in column:
                    continue

                if column['constant_name'].strip() == '*':
                    column['constant_name'] = str(self._prefix + column['column_name']).upper()
                else:
                    column['constant_name'] = str(column['constant_name']).upper()

    # ------------------------------------------------------------------------------------------------------------------
    def _merge_columns(self) -> None:
        """
        Preserves relevant data in old_columns into columns.

        Constant names of old columns that no longer exist are dropped with a warning.
        """
        if not self._old_columns:
            return

        # Sorted so that warnings are emitted in a deterministic order.
        for table_name, table in sorted(self._old_columns.items()):
            for column_name, column in sorted(table.items()):
                if 'constant_name' not in column:
                    continue

                try:
                    self._columns[table_name][column_name]['constant_name'] = column['constant_name']
                except KeyError:
                    # Either the column or table is not present anymore.
                    self._io.warning('Dropping constant {0} because column is not present anymore'.
                                     format(column['constant_name']))

    # ------------------------------------------------------------------------------------------------------------------
    def _write_columns(self) -> None:
        """
        Writes table and column names, the width of the column, and the constant name (if assigned) to
        constants_filename.

        Tables are written in order of table name, columns in order of ordinal position, and each table is followed
        by an empty line. Columns without a length are not written.
        """
        lines = []
        for _, table in sorted(self._columns.items()):
            # Column widths used for aligning the columns of the file.
            width1 = max((len(str(column['column_name'])) for column in table.values()), default=0)
            width2 = max((len(str(column['length'])) for column in table.values()), default=0)

            format_with_constant = "%s.%-{0:d}s %{1:d}d %s\n".format(width1, width2)
            format_without_constant = "%s.%-{0:d}s %{1:d}d\n".format(width1, width2)

            # On equal ordinal positions the column seen last wins.
            by_position = {column['ordinal_position']: column for column in table.values()}
            for position in sorted(by_position):
                column = by_position[position]
                if column['length'] is None:
                    continue

                if 'constant_name' in column:
                    lines.append(format_with_constant % (column['table_name'],
                                                         column['column_name'],
                                                         column['length'],
                                                         column['constant_name']))
                else:
                    lines.append(format_without_constant % (column['table_name'],
                                                            column['column_name'],
                                                            column['length']))

            lines.append("\n")

        # Save the columns, width and constants to the filesystem.
        Util.write_two_phases(self._constants_filename, ''.join(lines), self._io)

    # ------------------------------------------------------------------------------------------------------------------
    def _get_labels(self) -> None:
        """
        Gets all primary key labels from the MySQL database.
        """
        tables = self._dl.get_label_tables(self._label_regex)
        for table in tables:
            rows = self._dl.get_labels_from_table(table['table_name'], table['id'], table['label'])
            for row in rows:
                self._labels[row['label']] = row['id']

    # ------------------------------------------------------------------------------------------------------------------
    def _fill_constants(self) -> None:
        """
        Merges columns and labels (i.e. all known constants) into constants.

        Columns are processed in order of table name and column name, then labels in order of label; on a name
        clash the value processed last wins.
        """
        for _, table in sorted(self._columns.items()):
            for _, column in sorted(table.items()):
                if 'constant_name' in column:
                    self._constants[column['constant_name']] = column['length']

        for label, label_id in sorted(self._labels.items()):
            self._constants[label] = label_id

    # ------------------------------------------------------------------------------------------------------------------
    @staticmethod
    def derive_field_length(column: Dict[str, Any]) -> Optional[int]:
        """
        Returns the width of a field based on column.

        :param column: The column of which the field is based. Must have key 'data_type'; numeric types also need
                       'numeric_precision' and string/binary types 'character_maximum_length'.

        :raises Exception: When the data type of the column is not known.
        """
        data_type = column['data_type']

        if data_type in MySqlConstantWorker._NUMERIC_PRECISION_TYPES:
            return column['numeric_precision']

        if data_type in MySqlConstantWorker._CHARACTER_LENGTH_TYPES:
            return column['character_maximum_length']

        if data_type in MySqlConstantWorker._FIXED_LENGTH_TYPES:
            return MySqlConstantWorker._FIXED_LENGTH_TYPES[data_type]

        raise Exception("Unexpected type '{0!s}'.".format(data_type))
240
# ----------------------------------------------------------------------------------------------------------------------
241