Passed
Pull Request — master (#136)
by
unknown
03:23
created

datatables.datatables.DataTables.default_escape()   A

Complexity

Conditions 1

Size

Total Lines 4
Code Lines 2

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 2
dl 0
loc 4
rs 10
c 0
b 0
f 0
cc 1
nop 2
1
from __future__ import absolute_import
2
3
import math
4
5
from sqlalchemy import Text, func, or_
6
from sqlalchemy.dialects import mysql, postgresql, sqlite
7
8
from datatables.clean_regex import clean_regex
9
from datatables.search_methods import SEARCH_METHODS
10
11
12
class DataTables:
    """Class defining a DataTables object.

    Reads the server-side processing parameters sent by a DataTables
    (>= 1.10) client, applies filtering, sorting and paging to a SQLAlchemy
    query, and exposes the outcome through :meth:`output_result`.

    :param request: request containing the GET values, specified by the
        datatable for filtering, sorting and paging; it is converted with
        ``dict()``
    :type request: pyramid.request
    :param query: the query wanted to be seen in the table
    :type query: sqlalchemy.orm.query.Query
    :param columns: columns specification for the datatables
    :type columns: list
    :param allow_regex_searches: honour ``search[regex]=true`` for the global
        search; when False the global search is always a substring match
    :type allow_regex_searches: bool
    :param escape: callable applied to every output cell value; defaults to
        :meth:`default_escape`, which returns values unchanged
    :type escape: callable or None

    :returns: a DataTables object
    """

    def __init__(
        self, request, query, columns, allow_regex_searches=False, escape=None
    ):
        """Initialize object and run the query.

        Any exception raised while running the query is caught and its
        message stored in ``self.error``, which :meth:`output_result` then
        reports under the ``error`` key.

        :raises ValueError: if the parameters contain ``sEcho``, i.e. they
            come from a legacy (pre-1.10) DataTables client
        """
        self.params = dict(request)
        if "sEcho" in self.params:
            raise ValueError("Legacy datatables not supported, upgrade to >=1.10")
        self.query = query
        self.columns = columns
        self.results = None
        self.allow_regex_searches = allow_regex_searches

        # callable to escape data
        self.escape = escape or self.default_escape

        # total in the table after filtering
        self.cardinality_filtered = 0

        # total in the table unfiltered
        self.cardinality = 0

        # (key, value) pairs merged into the output for yadcf filters
        self.yadcf_params = []
        # one entry per column (None when that column is not filtered),
        # optionally followed by the global search expression
        self.filter_expressions = []
        # ORDER BY expressions, filled by _set_sort_expressions()
        self.sort_expressions = []
        self.error = None
        try:
            self.run()
        except Exception as exc:
            # Deliberately broad: the message is returned to the client via
            # output_result()'s "error" key instead of propagating.
            self.error = str(exc)

    def default_escape(self, data):
        """Escape data before output.

        The default performs no escaping and returns ``data`` unchanged; pass
        ``escape=`` to the constructor to override it.
        """
        return data

    def output_result(self):
        """Output results in the format needed by DataTables.

        :returns: dict with ``draw``, ``recordsTotal`` and ``recordsFiltered``
            (all as strings), plus ``error`` when running failed, otherwise
            ``data`` and any yadcf entries
        """
        output = {
            "draw": str(int(self.params.get("draw", 1))),
            "recordsTotal": str(self.cardinality),
            "recordsFiltered": str(self.cardinality_filtered),
        }
        if self.error:
            output["error"] = self.error
            return output

        output["data"] = self.results
        output.update(self.yadcf_params)
        return output

    def _query_with_all_filters_except_one(self, query, exclude):
        """Return ``query`` filtered by every filter expression but one.

        :param exclude: index into ``self.filter_expressions`` to skip, so a
            yadcf select list is not narrowed by its own column's filter
        """
        return query.filter(
            *[
                e
                for i, e in enumerate(self.filter_expressions)
                # compare by value; ``is not`` on ints only works for the
                # small integers CPython happens to cache
                if e is not None and i != exclude
            ]
        )

    def _set_yadcf_data(self, query):
        """Compute the data yadcf needs to render its column filters.

        Range sliders get ``(floor(min), ceil(max))`` of the column over the
        given (unfiltered) ``query``; select, multi-select and autocomplete
        filters get the distinct column values under all *other* filters.
        """
        for i, col in enumerate(self.columns):
            key = "yadcf_data_{:d}".format(i)
            # exact match: the former substring test (`in "..."`) also
            # matched values such as "" or "number"
            if col.search_method == "yadcf_range_number_slider":
                v = query.add_columns(
                    func.min(col.sqla_expr), func.max(col.sqla_expr)
                ).one()
                # NOTE(review): min/max are None when there are no rows and
                # math.floor(None) raises; run()'s caller reports it as error.
                self.yadcf_params.append(
                    (key, (math.floor(v[0]), math.ceil(v[1])))
                )
            elif col.search_method in (
                "yadcf_select",
                "yadcf_multi_select",
                "yadcf_autocomplete",
            ):
                filtered = self._query_with_all_filters_except_one(
                    query=query, exclude=i
                )
                v = filtered.add_columns(col.sqla_expr).distinct().all()
                self.yadcf_params.append((key, [r[0] for r in v]))

    def run(self):
        """Launch filtering, sorting and paging to output results.

        Populates ``cardinality``, ``cardinality_filtered``, ``yadcf_params``
        and ``results``.

        :raises ValueError: if ``length`` is negative but not -1, or an order
            direction / nulls order is invalid
        """
        query = self.query

        # count before filtering
        self.cardinality = query.add_columns(self.columns[0].sqla_expr).count()

        self._set_column_filter_expressions()
        self._set_global_filter_expression()
        self._set_sort_expressions()
        # yadcf data is computed from the unfiltered query on purpose
        self._set_yadcf_data(query)

        # apply filters
        query = query.filter(*[e for e in self.filter_expressions if e is not None])

        self.cardinality_filtered = query.add_columns(self.columns[0].sqla_expr).count()

        # apply sorts
        query = query.order_by(*[e for e in self.sort_expressions if e is not None])

        # add paging options; length == -1 means "no limit"
        length = int(self.params.get("length"))
        if length >= 0:
            query = query.limit(length)
        elif length != -1:
            raise ValueError("Length should be a positive integer or -1 to disable")
        query = query.offset(int(self.params.get("start")))

        # add columns to query
        query = query.add_columns(*[c.sqla_expr for c in self.columns])

        # fetch the result of the queries; columns without mData are keyed
        # by their index as a string
        column_names = [
            col.mData if col.mData else str(i) for i, col in enumerate(self.columns)
        ]
        self.results = [
            {k: self.escape(v) for k, v in zip(column_names, row)} for row in query.all()
        ]

    def _set_column_filter_expressions(self):
        """Construct the query: filtering.

        Add filtering when per column searching is used. Exactly one entry
        (None when the column has no search value) is appended per column, so
        indices in ``self.filter_expressions`` line up with column indices.
        """
        for i, col in enumerate(self.columns):
            filter_expr = None
            value = self.params.get("columns[{:d}][search][value]".format(i), "")
            if value:
                search_func = SEARCH_METHODS[col.search_method]
                filter_expr = search_func(col.sqla_expr, value)
            self.filter_expressions.append(filter_expr)

    def _set_global_filter_expression(self):
        """Append the global search filter, if any, to the filter list.

        The filter ORs a match over every column with ``global_search`` set:
        a regex match when regex searches are allowed and requested, else a
        case-insensitive substring match on the column cast to text.
        """
        global_search = self.params.get("search[value]", "")
        if global_search == "":
            return

        searchable = [col.sqla_expr for col in self.columns if col.global_search]
        if not searchable:
            # nothing to search in; avoid building an empty or_()
            return

        # Each branch builds its own list so ``op`` is only referenced where
        # it is bound (static analysis flagged the former closure version).
        if self.allow_regex_searches and self.params.get("search[regex]") == "true":
            op = self._get_regex_operator()
            val = clean_regex(global_search)
            global_filter = [expr.op(op)(val) for expr in searchable]
        else:
            # NOTE(review): LIKE wildcards (%, _) in the user's search text
            # are not escaped, so they act as wildcards.
            val = "%" + global_search + "%"
            global_filter = [expr.cast(Text).ilike(val) for expr in searchable]

        self.filter_expressions.append(or_(*global_filter))

    def _set_sort_expressions(self):
        """Construct the query: sorting.

        Add sorting(ORDER BY) on the columns needed to be applied on, reading
        ``order[i][column]`` / ``order[i][dir]`` for i = 0, 1, ... until a
        missing or empty column entry is found.

        :raises ValueError: on a direction other than ``asc``/``desc`` or a
            column ``nulls_order`` other than ``nullsfirst``/``nullslast``
        """
        sort_expressions = []
        i = 0
        while self.params.get("order[{:d}][column]".format(i), False):
            column_nr = int(self.params.get("order[{:d}][column]".format(i)))
            column = self.columns[column_nr]
            direction = self.params.get("order[{:d}][dir]".format(i))
            sort_expr = column.sqla_expr
            if direction == "asc":
                sort_expr = sort_expr.asc()
            elif direction == "desc":
                sort_expr = sort_expr.desc()
            else:
                raise ValueError("Invalid order direction: {}".format(direction))
            if column.nulls_order:
                if column.nulls_order == "nullsfirst":
                    sort_expr = sort_expr.nullsfirst()
                elif column.nulls_order == "nullslast":
                    sort_expr = sort_expr.nullslast()
                else:
                    # report the offending nulls order, not the direction
                    raise ValueError(
                        "Invalid nulls order: {}".format(column.nulls_order)
                    )

            sort_expressions.append(sort_expr)
            i += 1
        self.sort_expressions = sort_expressions

    def _get_regex_operator(self):
        """Return the SQL regex-match operator for the session's dialect.

        :raises NotImplementedError: for dialects other than PostgreSQL,
            MySQL and SQLite
        """
        dialect = self.query.session.bind.dialect
        if isinstance(dialect, postgresql.dialect):
            return "~"
        # NOTE(review): SQLite only evaluates REGEXP when a regexp() function
        # has been registered on the connection.
        if isinstance(dialect, (mysql.dialect, sqlite.dialect)):
            return "REGEXP"
        raise NotImplementedError(
            "Regex searches are not implemented for this dialect"
        )
226