PgSqlConstantWorker._get_old_columns()   C
last analyzed

Complexity

Conditions 10

Size

Total Lines 41
Code Lines 29

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 22
CRAP Score 10.0578

Importance

Changes 0
Metric Value
cc 10
eloc 29
nop 1
dl 0
loc 41
ccs 22
cts 24
cp 0.9167
crap 10.0578
rs 5.9999
c 0
b 0
f 0

How to fix   Complexity   

Complexity

Complex classes like pystratum_pgsql.backend.PgSqlConstantWorker.PgSqlConstantWorker._get_old_columns() often do a lot of different things. To break such a class down, we need to identify a cohesive component within that class. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.

1 1
import os
2 1
import re
3 1
from configparser import ConfigParser
4 1
from typing import Any, Dict, Optional
5
6 1
from pystratum_backend.StratumStyle import StratumStyle
7 1
from pystratum_common.backend.CommonConstantWorker import CommonConstantWorker
8 1
from pystratum_common.Util import Util
9
10 1
from pystratum_pgsql.backend.PgSqlWorker import PgSqlWorker
11
12
13 1
class PgSqlConstantWorker(PgSqlWorker, CommonConstantWorker):
14
    """
15
    Class for creating constants based on column widths, and auto increment columns and labels for PostgreSQL databases.
16
    """
17
18
    # ------------------------------------------------------------------------------------------------------------------
19 1
    def __init__(self, io: StratumStyle, config: ConfigParser):
        """
        Object constructor.

        :param StratumStyle io: The output decorator.
        :param ConfigParser config: The configuration object; handed unchanged to both parent constructors.
        """
        # Both parents are initialised explicitly, PgSqlWorker first, each with the same arguments.
        PgSqlWorker.__init__(self, io, config)
        CommonConstantWorker.__init__(self, io, config)

        # Metadata of all columns in the current schema and information_schema.
        self._columns: Dict[str, Any] = {}
34
35
    # ------------------------------------------------------------------------------------------------------------------
36 1
    def _get_old_columns(self) -> None:
37
        """
38
        Reads from file constants_filename the previous table and column names, the width of the column,
39
        and the constant name (if assigned) and stores this data in old_columns.
40
        """
41 1
        if os.path.exists(self._constants_filename):
42 1
            with open(self._constants_filename, 'r') as stream:
43 1
                for line in stream:
44 1
                    if line != "\n":
45 1
                        prog = re.compile(r'\s*(?:([a-zA-Z0-9_]+)\.)?([a-zA-Z0-9_]+)\.'
46
                                          r'([a-zA-Z0-9_]+)\s+(\d+)\s*(\*|[a-zA-Z0-9_]+)?\s*')
47 1
                        matches = prog.findall(line)
48
49 1
                        if matches:
50 1
                            matches = matches[0]
51 1
                            schema_name = str(matches[0])
52 1
                            table_name = str(matches[1])
53 1
                            column_name = str(matches[2])
54 1
                            length = str(matches[3])
55 1
                            constant_name = str(matches[4])
56
57 1
                            if schema_name:
58
                                table_name = schema_name + '.' + table_name
59
60 1
                            if constant_name:
61 1
                                column_info = {'table_name':    table_name,
62
                                               'column_name':   column_name,
63
                                               'length':        length,
64
                                               'constant_name': constant_name}
65
                            else:
66 1
                                column_info = {'table_name':  table_name,
67
                                               'column_name': column_name,
68
                                               'length':      length}
69
70 1
                            if table_name in self._old_columns:
71 1
                                if column_name in self._old_columns[table_name]:
72
                                    pass
73
                                else:
74 1
                                    self._old_columns[table_name][column_name] = column_info
75
                            else:
76 1
                                self._old_columns[table_name] = {column_name: column_info}
77
78
    # ------------------------------------------------------------------------------------------------------------------
79 1
    def _get_columns(self) -> None:
80
        """
81
        Retrieves metadata all columns in the current schema and information_schema.
82
        """
83 1
        rows = self._dl.get_all_table_columns()
84 1
        for row in rows:
85
            # Enhance row with the actual length of the column.
86 1
            row['length'] = self.derive_field_length(row)
87
88 1
            if row['table_name'] in self._columns:
89 1
                if row['column_name'] in self._columns[row['table_name']]:
90
                    pass
91
                else:
92 1
                    self._columns[row['table_name']][row['column_name']] = row
93
            else:
94 1
                self._columns[row['table_name']] = {row['column_name']: row}
95
96
    # ------------------------------------------------------------------------------------------------------------------
97 1
    def _enhance_columns(self) -> None:
98
        """
99
        Enhances old_columns as follows:
100
        If the constant name is *, is is replaced with the column name prefixed by prefix in uppercase.
101
        Otherwise the constant name is set to uppercase.
102
        """
103 1
        if self._old_columns:
104 1
            for table_name, table in sorted(self._old_columns.items()):
105 1
                for column_name, column in sorted(table.items()):
106 1
                    table_name = column['table_name']
107 1
                    column_name = column['column_name']
108
109 1
                    if 'constant_name' in column:
110 1
                        if column['constant_name'].strip() == '*':
111
                            constant_name = str(self._prefix + column['column_name']).upper()
112
                            self._old_columns[table_name][column_name]['constant_name'] = constant_name
113
                        else:
114 1
                            constant_name = str(self._old_columns[table_name][column_name]['constant_name']).upper()
115 1
                            self._old_columns[table_name][column_name]['constant_name'] = constant_name
116
117
    # ------------------------------------------------------------------------------------------------------------------
118 1
    def _merge_columns(self) -> None:
119
        """
120
        Preserves relevant data in old_columns into columns.
121
        """
122 1
        if self._old_columns:
123 1
            for table_name, table in sorted(self._old_columns.items()):
124 1
                for column_name, column in sorted(table.items()):
125 1
                    if 'constant_name' in column:
126 1
                        try:
127 1
                            self._columns[table_name][column_name]['constant_name'] = column['constant_name']
128
                        except KeyError:
129
                            # Either the column or table is not present anymore.
130
                            self._io.warning('Dropping constant {0} because column is not present anymore'.
131
                                             format(column['constant_name']))
132
133
    # ------------------------------------------------------------------------------------------------------------------
134 1
    def _write_columns(self) -> None:
        """
        Writes table and column names, the width of the column, and the constant name (if assigned) to
        constants_filename.

        Tables are written in sorted order and the columns of a table in order of their ordinal position. Columns
        whose length is None are not written. Each table is followed by an empty line.
        """
        chunks = []
        for _, table in sorted(self._columns.items()):
            # Column widths used to align the names and lengths of this table.
            name_width = max((len(str(column['column_name'])) for column in table.values()), default=0)
            length_width = max((len(str(column['length'])) for column in table.values()), default=0)

            names_by_position = {column['ordinal_position']: column_name for column_name, column in table.items()}

            # Both line formats depend on the widths of the table only, hence they are built once per table.
            format_with_constant = "%s.%-{0:d}s %{1:d}d %s\n".format(int(name_width), int(length_width))
            format_without_constant = "%s.%-{0:d}s %{1:d}d\n".format(int(name_width), int(length_width))

            for position in sorted(names_by_position):
                column = table[names_by_position[position]]
                if column['length'] is None:
                    continue

                if 'constant_name' in column:
                    chunks.append(format_with_constant % (column['table_name'],
                                                          column['column_name'],
                                                          column['length'],
                                                          column['constant_name']))
                else:
                    chunks.append(format_without_constant % (column['table_name'],
                                                             column['column_name'],
                                                             column['length']))

            chunks.append("\n")

        # Save the columns, width and constants to the filesystem.
        Util.write_two_phases(self._constants_filename, ''.join(chunks), self._io)
168
169
    # ------------------------------------------------------------------------------------------------------------------
170 1
    def _get_labels(self) -> None:
171
        """
172
        Gets all primary key labels from the current schema.
173
        """
174 1
        tables = self._dl.get_label_tables(self._label_regex)
175 1
        for table in tables:
176 1
            rows = self._dl.get_labels_from_table(table['table_name'], table['id'], table['label'])
177 1
            for row in rows:
178 1
                self._labels[row['label']] = row['id']
179
180
    # ------------------------------------------------------------------------------------------------------------------
181 1
    def _fill_constants(self) -> None:
182
        """
183
        Merges columns and labels (i.e. all known constants) into constants.
184
        """
185 1
        for _, table in sorted(self._columns.items()):
186 1
            for _, column in sorted(table.items()):
187 1
                if 'constant_name' in column:
188 1
                    self._constants[column['constant_name']] = column['length']
189
190 1
        for label, label_id in sorted(self._labels.items()):
191 1
            self._constants[label] = label_id
192
193
    # ------------------------------------------------------------------------------------------------------------------
194 1
    @staticmethod
195 1
    def derive_field_length(column: Dict[str, Any]) -> Optional[int]:
196
        """
197
        Returns the width of a field based on the data type of column.
198
199
        :param dict column: The column of which the field is based.
200
201
        :rtype: int|None
202
        """
203 1
        types_length = {'bigint':                      21,
204
                        'integer':                     11,
205
                        'smallint':                    6,
206
                        'bit':                         column['character_maximum_length'],
207
                        'money':                       None,
208
                        'boolean':                     None,
209
                        'double':                      column['numeric_precision'],
210
                        'numeric':                     column['numeric_precision'],
211
                        'real':                        None,
212
                        'character':                   column['character_maximum_length'],
213
                        'character varying':           column['character_maximum_length'],
214
                        'point':                       None,
215
                        'polygon':                     None,
216
                        'text':                        None,
217
                        'bytea':                       None,
218
                        'xml':                         None,
219
                        'USER-DEFINED':                None,
220
                        'timestamp without time zone': 16,
221
                        'time without time zone':      8,
222
                        'date':                        10}
223
224 1
        if column['data_type'] in types_length:
225 1
            return types_length[column['data_type']]
226
227
        raise Exception("Unexpected type '{0!s}'.".format(column['data_type']))
228
229
    # ------------------------------------------------------------------------------------------------------------------
230 1
    def _read_configuration_file(self) -> None:
        """
        Reads parameters from the configuration file.

        The work is delegated explicitly to the implementation of CommonConstantWorker.
        """
        CommonConstantWorker._read_configuration_file(self)
235
236
# ----------------------------------------------------------------------------------------------------------------------
237