Passed
Pull Request — master (#132)
by
unknown
03:30
created

DataTables._map_columns_with_params()   B

Complexity

Conditions 6

Size

Total Lines 26
Code Lines 15

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 15
dl 0
loc 26
rs 8.6666
c 0
b 0
f 0
cc 6
nop 3
1
from __future__ import absolute_import
2
3
import math
4
import re
5
6
from sqlalchemy import Text, func, or_
7
from sqlalchemy.dialects import mysql, postgresql, sqlite
8
9
from datatables.clean_regex import clean_regex
10
from datatables.search_methods import SEARCH_METHODS
11
12
13
class DataTables:
    """Class defining a DataTables object.

    :param request: request containing the GET values, specified by the
        datatable for filtering, sorting and paging
    :type request: pyramid.request
    :param query: the query to be shown in the table
    :type query: sqlalchemy.orm.query.Query
    :param columns: columns specification for the datatables
    :type columns: list
    :param allow_regex_searches: honour ``search[regex]=true`` sent by the
        client for the global search (default ``False``)
    :type allow_regex_searches: bool

    :returns: a DataTables object
    """

    # Matches request keys of the form ``columns[<n>][data]``.  Raw string so
    # the ``\[`` escapes reach the regex engine instead of being invalid
    # Python string escapes (a SyntaxWarning on Python 3.12+).
    _COLUMN_DATA_PATTERN = re.compile(r"columns\[(.*?)\]\[data\]")

    def __init__(self, request, query, columns, allow_regex_searches=False):
        """Initialize object and run the query.

        Any exception raised while running the query is stored as a string
        in ``self.error`` (and reported by ``output_result``) instead of
        propagating to the caller.

        :raises ValueError: if the request uses the legacy (pre-1.10)
            DataTables protocol, detected by an ``sEcho`` parameter.
        """
        self.params = dict(request)
        if "sEcho" in self.params:
            raise ValueError("Legacy datatables not supported, upgrade to >=1.10")
        self.query = query
        self.columns = columns
        self.results = None
        self.allow_regex_searches = allow_regex_searches

        # total in the table after filtering
        self.cardinality_filtered = 0

        # total in the table unfiltered
        self.cardinality = 0

        self.yadcf_params = []
        # One entry per mapped column (None when that column has no search
        # value), optionally followed by the global search expression.
        self.filter_expressions = []
        self.sort_expressions = []
        self.error = None
        try:
            self.run()
        except Exception as exc:
            # Deliberate catch-all: the message is sent back to the client
            # through output_result() rather than raised to the caller.
            self.error = str(exc)

    def output_result(self):
        """Output results in the format needed by DataTables.

        :returns: dict with ``draw``, ``recordsTotal`` and ``recordsFiltered``
            plus either ``error`` or ``data`` and any ``yadcf_data_<n>`` keys.
        """
        output = {}
        output["draw"] = str(int(self.params.get("draw", 1)))
        output["recordsTotal"] = str(self.cardinality)
        output["recordsFiltered"] = str(self.cardinality_filtered)
        if self.error:
            output["error"] = self.error
            return output

        output["data"] = self.results
        for k, v in self.yadcf_params:
            output[k] = v
        return output

    def _map_columns_with_params(self, columns, params):
        """Map frontend column numbers to backend column specifications.

        Frontend columns are often not in the same order as the backend
        ``columns`` list, and the frontend may add extra dummy columns, so
        the position alone cannot be trusted.  For every non-empty
        ``columns[<n>][data]`` parameter, the backend column whose ``mData``
        equals that value is used.  If none matches, the backend column at
        the same position among the non-empty entries is used as a fallback.
        Frontend columns with an empty ``data`` value are skipped, and so are
        unmatched columns for which no positional fallback exists.

        :param columns: backend column specifications (objects with ``mData``)
        :param params: request parameters
        :returns: list of ``(frontend_column_number, column)`` tuples, ordered
            by frontend column number
        """
        pattern = self._COLUMN_DATA_PATTERN
        # Extract only the keys of type columns[i][data] from the params.
        # Order them by column number so that the positional fallback below
        # does not depend on the order in which the parameters arrived.
        column_data = []
        for key, value in params.items():
            match = pattern.match(key)
            if match:
                column_data.append((int(match.group(1)), str(value)))
        column_data.sort(key=lambda item: item[0])

        column_params_map = []
        position = 0
        for column_number, data in column_data:
            if not data:
                continue
            for column in columns:
                # If mData is specified in the backend and data is specified
                # in the frontend, map the column by matching them.
                # See https://datatables.net/reference/option/columns.data
                if data == column.mData:
                    column_params_map.append((column_number, column))
                    break
            else:
                # No matching data: fall back to the column at the same
                # position, if the backend has one (avoids an IndexError
                # when the frontend adds extra dummy columns).
                if position < len(columns):
                    column_params_map.append((column_number, columns[position]))
            position += 1
        return column_params_map

    def _query_with_all_filters_except_one(self, query, exclude):
        """Return ``query`` filtered by every filter expression except one.

        :param query: query to filter
        :param exclude: index into ``self.filter_expressions`` of the
            expression to leave out
        """
        return query.filter(
            *[
                expr
                for index, expr in enumerate(self.filter_expressions)
                # ``!=`` rather than ``is not``: int identity only holds
                # within CPython's small-int cache.
                if expr is not None and index != exclude
            ]
        )

    def _set_yadcf_data(self, query):
        """Collect the data needed by yadcf filters into ``self.yadcf_params``.

        Entries are keyed ``yadcf_data_<n>``, where ``<n>`` is the frontend
        column number.  Must run after ``_set_column_filter_expressions``.

        :param query: the base query (``run`` passes it unfiltered)
        """
        column_params_map = self._map_columns_with_params(self.columns, self.params)
        # ``position`` is the column's index in self.filter_expressions, which
        # _set_column_filter_expressions fills in this same map order.  It
        # differs from the frontend ``column_number`` whenever the mapping
        # skipped a frontend column.
        for position, (column_number, col) in enumerate(column_params_map):
            if not col.yadcf_data:
                continue
            key = "yadcf_data_{:d}".format(column_number)

            # NOTE(review): this is a substring test, so e.g.
            # "yadcf_range_number" also matches; kept to preserve behavior.
            if col.search_method in "yadcf_range_number_slider":
                v = query.with_entities(
                    func.min(col.sqla_expr), func.max(col.sqla_expr)
                ).one()
                # MIN/MAX are NULL for an empty table; floor/ceil would raise.
                if v[0] is not None and v[1] is not None:
                    self.yadcf_params.append(
                        (key, (math.floor(v[0]), math.ceil(v[1])))
                    )

            if col.search_method in [
                "yadcf_select",
                "yadcf_multi_select",
                "yadcf_autocomplete",
            ]:
                filtered = self._query_with_all_filters_except_one(
                    query=query, exclude=position
                )
                v = filtered.with_entities(col.sqla_expr).distinct().all()
                # Do not return data consisting of a single null value.
                if not (len(v) == 1 and v[0][0] is None):
                    self.yadcf_params.append((key, [r[0] for r in v]))

    def run(self):
        """Launch filtering, sorting and paging to output results.

        Sets ``cardinality``, ``cardinality_filtered``, ``yadcf_params``,
        ``filter_expressions``, ``sort_expressions`` and ``results``.

        :raises ValueError: if ``length`` is smaller than -1.
        """
        query = self.query

        # count before filtering
        self.cardinality = query.add_columns(self.columns[0].sqla_expr).count()

        self._set_column_filter_expressions()
        self._set_global_filter_expression()
        self._set_sort_expressions()
        # Runs after the filter expressions are built (it indexes into them)
        # and is given the unfiltered query.
        self._set_yadcf_data(query)

        # apply filters
        query = query.filter(*[e for e in self.filter_expressions if e is not None])

        self.cardinality_filtered = query.with_entities(
            self.columns[0].sqla_expr
        ).count()

        # apply sorts
        query = query.order_by(*[e for e in self.sort_expressions if e is not None])

        # add paging options; a length of -1 disables the limit
        length = int(self.params.get("length"))
        if length >= 0:
            query = query.limit(length)
        elif length != -1:
            raise ValueError("Length should be a positive integer or -1 to disable")
        query = query.offset(int(self.params.get("start")))

        # add columns to query
        query = query.with_entities(*[c.sqla_expr for c in self.columns])

        # fetch the rows, keyed by mData (or the column index when unset)
        column_names = [
            col.mData if col.mData else str(i) for i, col in enumerate(self.columns)
        ]
        self.results = [dict(zip(column_names, row)) for row in query.all()]

    def _set_column_filter_expressions(self):
        """Construct the query: filtering.

        Add filtering when per column searching is used.  Appends exactly one
        entry per mapped column to ``self.filter_expressions``: None when the
        column has no search value, otherwise the expression produced by the
        column's search method.
        """
        # per columns filters:
        column_params_map = self._map_columns_with_params(self.columns, self.params)
        for i, col in column_params_map:
            filter_expr = None
            # Backslashes are stripped from the value before searching
            # (presumably client-side escaping — TODO confirm).
            value = self.params.get(
                "columns[{:d}][search][value]".format(i), ""
            ).replace("\\", "")

            if value:
                search_func = SEARCH_METHODS[col.search_method]
                filter_expr = search_func(col.sqla_expr, value)
            self.filter_expressions.append(filter_expr)

    def _set_global_filter_expression(self):
        """Append the global search filter to ``self.filter_expressions``.

        The columns with ``global_search`` set are OR-ed together.  When
        regex searches are allowed and the client sent
        ``search[regex]=true``, the dialect's regex operator is used.
        Otherwise each column is cast to Text and matched case-insensitively
        with ``ilike``.  Nothing is appended when the search value is empty
        or no column takes part in the global search.
        """
        global_search = self.params.get("search[value]", "")
        if global_search == "":
            return

        searchable = [col for col in self.columns if col.global_search]
        if self.allow_regex_searches and self.params.get("search[regex]") == "true":
            op = self._get_regex_operator()
            pattern = clean_regex(global_search)
            global_filter = [col.sqla_expr.op(op)(pattern) for col in searchable]
        else:
            like = "%" + global_search + "%"
            global_filter = [
                col.sqla_expr.cast(Text).ilike(like) for col in searchable
            ]

        # or_() with no arguments is deprecated since SQLAlchemy 1.4; with
        # nothing to OR there is no filter to add.
        if global_filter:
            self.filter_expressions.append(or_(*global_filter))

    def _set_sort_expressions(self):
        """Construct the query: sorting.

        Add sorting (ORDER BY) for each ``order[<i>][column]`` /
        ``order[<i>][dir]`` pair in the request, in order.  Stops at the
        first missing or empty ``order[<i>][column]``.

        :raises ValueError: for an invalid direction, an invalid
            ``nulls_order``, or an order column that could not be mapped.
        """
        column_params_map = dict(
            self._map_columns_with_params(self.columns, self.params)
        )
        sort_expressions = []
        i = 0
        while self.params.get("order[{:d}][column]".format(i), False):
            column_nr = int(self.params.get("order[{:d}][column]".format(i)))
            column = column_params_map.get(column_nr)
            if column is None:
                raise ValueError("Invalid order column: {}".format(column_nr))
            direction = self.params.get("order[{:d}][dir]".format(i))
            sort_expr = column.sqla_expr
            if direction == "asc":
                sort_expr = sort_expr.asc()
            elif direction == "desc":
                sort_expr = sort_expr.desc()
            else:
                raise ValueError("Invalid order direction: {}".format(direction))
            if column.nulls_order:
                if column.nulls_order == "nullsfirst":
                    sort_expr = sort_expr.nullsfirst()
                elif column.nulls_order == "nullslast":
                    sort_expr = sort_expr.nullslast()
                else:
                    raise ValueError(
                        "Invalid nulls order: {}".format(column.nulls_order)
                    )

            sort_expressions.append(sort_expr)
            i += 1
        self.sort_expressions = sort_expressions

    def _get_regex_operator(self):
        """Return the SQL regex match operator for the query's dialect.

        :raises NotImplementedError: for dialects other than PostgreSQL,
            MySQL and SQLite.
        """
        dialect = self.query.session.bind.dialect
        if isinstance(dialect, postgresql.dialect):
            return "~"
        if isinstance(dialect, (mysql.dialect, sqlite.dialect)):
            return "REGEXP"
        raise NotImplementedError(
            "Regex searches are not implemented for this dialect"
        )
262