Passed
Pull Request — master (#608)
by
unknown
22:45 queued 09:47
created

TabPyApp._get_arrow_server()   A

Complexity

Conditions 3

Size

Total Lines 24
Code Lines 20

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 1
CRAP Score 10.4157

Importance

Changes 0
Metric Value
eloc 20
dl 0
loc 24
ccs 1
cts 16
cp 0.0625
rs 9.4
c 0
b 0
f 0
cc 3
nop 2
crap 10.4157
1 1
import concurrent.futures
2 1
import configparser
3 1
import logging
4 1
import multiprocessing
5 1
import os
6 1
import shutil
7 1
import signal
8 1
import sys
9 1
import tabpy
10 1
from tabpy.tabpy import __version__
11 1
from tabpy.tabpy_server.app.app_parameters import ConfigParameters, SettingsParameters
12 1
from tabpy.tabpy_server.app.util import parse_pwd_file
13 1
from tabpy.tabpy_server.handlers.basic_auth_server_middleware_factory import BasicAuthServerMiddlewareFactory
14 1
from tabpy.tabpy_server.handlers.no_op_auth_handler import NoOpAuthHandler
15 1
from tabpy.tabpy_server.management.state import TabPyState
16 1
from tabpy.tabpy_server.management.util import _get_state_from_file
17 1
from tabpy.tabpy_server.psws.callbacks import init_model_evaluator, init_ps_server
18 1
from tabpy.tabpy_server.psws.python_service import PythonService, PythonServiceHandler
19 1
from tabpy.tabpy_server.handlers import (
20
    EndpointHandler,
21
    EndpointsHandler,
22
    EvaluationPlaneHandler,
23
    EvaluationPlaneDisabledHandler,
24
    QueryPlaneHandler,
25
    ServiceInfoHandler,
26
    StatusHandler,
27
    UploadDestinationHandler,
28
)
29 1
import tornado
30 1
from tornado.http1connection import HTTP1Connection
31 1
import tabpy.tabpy_server.app.arrow_server as pa
32 1
import _thread
33
34 1
logger = logging.getLogger(__name__)
35
36 1
def _init_asyncio_patch():
    """
    Select a Tornado-compatible asyncio event loop policy on Windows.

    As of Python 3.8, the default event loop on Windows is `proactor`,
    however Tornado requires the old default "selector" event loop.
    Tornado leaves this choice to the embedding application, so TabPy
    sets it here. See https://github.com/tornadoweb/tornado/issues/2608.

    Does nothing on other platforms, on Python older than 3.8, or when
    the selector policy is already installed.
    """
    if not (sys.platform.startswith("win") and sys.version_info >= (3, 8)):
        return

    import asyncio

    try:
        from asyncio import WindowsSelectorEventLoopPolicy
    except ImportError:
        return  # Can't assign a policy which doesn't exist.

    if not isinstance(asyncio.get_event_loop_policy(), WindowsSelectorEventLoopPolicy):
        asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
53
54
55 1
class TabPyApp:
56
    """
57
    TabPy application class for keeping context like settings, state, etc.
58
    """
59
60 1
    settings = {}
61 1
    subdirectory = ""
62 1
    tabpy_state = None
63 1
    python_service = None
64 1
    credentials = {}
65 1
    arrow_server = None
66 1
    max_request_size = None
67
68 1
    def __init__(self, config_file):
        """
        Configure logging from ``config_file`` and load the TabPy settings.

        When ``config_file`` is None, ``common/default.conf`` (resolved
        relative to this module's parent directory) is used instead.
        """
        if config_file is None:
            module_dir = os.path.dirname(__file__)
            config_file = os.path.join(
                module_dir, os.path.pardir, "common", "default.conf"
            )

        if os.path.isfile(config_file):
            try:
                from logging import config as logging_config

                logging_config.fileConfig(
                    config_file, disable_existing_loggers=False
                )
            except KeyError:
                # fileConfig raised KeyError for this file: fall back to
                # basic DEBUG-level logging.
                logging.basicConfig(level=logging.DEBUG)

        self._parse_config(config_file)
82
83 1
    def _get_tls_certificates(self, config):
        """
        Read the certificate and key files named in ``config``.

        Returns a one-element list holding a
        ``(certificate_bytes, private_key_bytes)`` tuple.
        """
        cert_path = config[SettingsParameters.CertificateFile]
        key_path = config[SettingsParameters.KeyFile]

        with open(cert_path, "rb") as cert_file:
            cert_chain = cert_file.read()
        with open(key_path, "rb") as key_file:
            private_key = key_file.read()

        return [(cert_chain, private_key)]
93
    
94 1
    def _get_arrow_server(self, config):
        """
        Build the Arrow Flight server described by ``config``.

        The location uses scheme ``grpc+tls`` (with certificates loaded
        from the configured files) when the transfer protocol is
        ``https`` and ``grpc+tcp`` otherwise. Basic-auth middleware is
        attached when the v1 feature set lists ``authentication``.
        """
        use_tls = config[SettingsParameters.TransferProtocol] == "https"
        scheme = "grpc+tls" if use_tls else "grpc+tcp"
        tls_certificates = self._get_tls_certificates(config) if use_tls else None

        host = "0.0.0.0"
        port = config.get(SettingsParameters.ArrowFlightPort)
        location = f"{scheme}://{host}:{port}"

        v1_features = config[SettingsParameters.ApiVersions]["v1"]["features"]
        auth_middleware = None
        if "authentication" in v1_features:
            _, creds = parse_pwd_file(config[ConfigParameters.TABPY_PWD_FILE])
            auth_middleware = {"basic": BasicAuthServerMiddlewareFactory(creds)}

        return pa.FlightServer(
            host,
            location,
            tls_certificates=tls_certificates,
            verify_client=None,
            auth_handler=NoOpAuthHandler(),
            middleware=auth_middleware,
        )
118
119 1
    def run(self):
        """
        Start the web service (and optionally the Arrow server) and block.

        Creates the Tornado application, initializes the model evaluator,
        starts listening on the configured port and then runs the IOLoop.
        When Arrow is enabled, the Arrow Flight server is started on a
        separate thread before the IOLoop starts.

        Raises:
            RuntimeError: if the transfer protocol is neither ``http``
                nor ``https``.
        """
        application = self._create_tornado_web_app()
        self.max_request_size = (
            int(self.settings[SettingsParameters.MaxRequestSizeInMb]) * 1024 * 1024
        )
        logger.info(f"Setting max request size to {self.max_request_size} bytes")

        init_model_evaluator(self.settings, self.tabpy_state, self.python_service)

        protocol = self.settings[SettingsParameters.TransferProtocol]
        ssl_options = None
        if protocol == "https":
            ssl_options = {
                "certfile": self.settings[SettingsParameters.CertificateFile],
                "keyfile": self.settings[SettingsParameters.KeyFile],
            }
        elif protocol != "http":
            msg = f"Unsupported transfer protocol {protocol}."
            logger.critical(msg)
            raise RuntimeError(msg)

        listen_kwargs = {}
        if self.settings[SettingsParameters.GzipEnabled] is True:
            listen_kwargs["decompress_request"] = True

        application.listen(
            self.settings[SettingsParameters.Port],
            ssl_options=ssl_options,
            max_buffer_size=self.max_request_size,
            max_body_size=self.max_request_size,
            **listen_kwargs,
        )

        logger.info(
            "Web service listening on port "
            f"{str(self.settings[SettingsParameters.Port])}"
        )

        if self.settings[SettingsParameters.ArrowEnabled]:

            def start_pyarrow():
                # This body runs on the new thread, so an exception raised
                # here never reaches the try/except around
                # start_new_thread below; it has to be reported from here.
                try:
                    self.arrow_server = self._get_arrow_server(self.settings)
                    pa.start(self.arrow_server)
                except Exception as e:
                    logger.critical(f"Failed to start PyArrow server: {e}")

            try:
                _thread.start_new_thread(start_pyarrow, ())
            except Exception as e:
                # Only covers failure to create the thread itself.
                logger.critical(f"Failed to start PyArrow server: {e}")

        tornado.ioloop.IOLoop.instance().start()
168
169 1
    def _create_tornado_web_app(self):
        """
        Initialize the server state and build the Tornado application.

        Runs ``init_ps_server`` to completion on the IOLoop, creates the
        thread pool handed to the evaluate handler, registers the URL
        routes (all prefixed with ``self.subdirectory``) and installs a
        SIGINT handler that stops the IOLoop.

        Returns the configured ``tornado.web.Application`` subclass
        instance.
        """

        class TabPyTornadoApp(tornado.web.Application):
            is_closing = False

            def signal_handler(self, signal, _):
                # Only set a flag here; the loop is stopped from try_exit,
                # which runs as a periodic callback.
                logger.critical(f"Exiting on signal {signal}...")
                self.is_closing = True

            def try_exit(self):
                if self.is_closing:
                    tornado.ioloop.IOLoop.instance().stop()
                    logger.info("Shutting down TabPy...")

        logger.info("Initializing TabPy...")
        tornado.ioloop.IOLoop.instance().run_sync(
            lambda: init_ps_server(self.settings, self.tabpy_state)
        )
        logger.info("Done initializing TabPy.")

        executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=multiprocessing.cpu_count()
        )

        # initialize Tornado application
        _init_asyncio_patch()
        application = TabPyTornadoApp(
            [
                (
                    self.subdirectory + r"/query/([^/]+)",
                    QueryPlaneHandler,
                    dict(app=self),
                ),
                (self.subdirectory + r"/status", StatusHandler, dict(app=self)),
                (self.subdirectory + r"/info", ServiceInfoHandler, dict(app=self)),
                (self.subdirectory + r"/endpoints", EndpointsHandler, dict(app=self)),
                (
                    self.subdirectory + r"/endpoints/([^/]+)?",
                    EndpointHandler,
                    dict(app=self),
                ),
                (
                    self.subdirectory + r"/evaluate",
                    EvaluationPlaneHandler
                    if self.settings[SettingsParameters.EvaluateEnabled]
                    else EvaluationPlaneDisabledHandler,
                    dict(executor=executor, app=self),
                ),
                (
                    self.subdirectory + r"/configurations/endpoint_upload_destination",
                    UploadDestinationHandler,
                    dict(app=self),
                ),
                (
                    self.subdirectory + r"/(.*)",
                    tornado.web.StaticFileHandler,
                    dict(
                        path=self.settings[SettingsParameters.StaticPath],
                        default_filename="index.html",
                    ),
                ),
            ],
            debug=False,
            **self.settings,
        )

        # Register the SIGINT handler and the exit poll (every 500 ms)
        # exactly once; registering them twice ran two timers and
        # stopped/logged the shutdown twice.
        signal.signal(signal.SIGINT, application.signal_handler)
        tornado.ioloop.PeriodicCallback(application.try_exit, 500).start()

        return application
240
241 1
    def _set_parameter(self, parser, settings_key, config_key, default_val, parse_function):
        """
        Store one setting in ``self.settings`` and log where it came from.

        The value is read from the ``TabPy`` section of ``parser`` when
        ``config_key`` is given and present there (using
        ``parse_function``, or ``parser.get`` when it is None). Otherwise
        ``default_val`` is used if it is not None. If neither applies,
        the setting is left unset.
        """
        found_in_config = (
            config_key is not None
            and parser.has_section("TabPy")
            and parser.has_option("TabPy", config_key)
        )

        if found_in_config:
            read_value = parser.get if parse_function is None else parse_function
            self.settings[settings_key] = read_value("TabPy", config_key)
            logger.debug(
                f"Parameter {settings_key} set to "
                f'"{self.settings[settings_key]}" '
                "from config file or environment variable"
            )
            return

        if default_val is not None:
            self.settings[settings_key] = default_val
            logger.debug(
                f"Parameter {settings_key} set to "
                f'"{self.settings[settings_key]}" '
                "from default value"
            )
            return

        logger.debug(f"Parameter {settings_key} is not set")
270
271 1
    def _parse_config(self, config_file):
        """Provide consistent mechanism for pulling in configuration.

        Attempt to retain backward compatibility for
        existing implementations by grabbing port
        setting from CLI first.

        Take settings in the following order:

        1. CLI arguments if present
        2. config file
        3. OS environment variables (for ease of
           setting defaults if not present)
        4. current defaults if a setting is not present in any location

        Additionally provide similar configuration capabilities in between
        config file and environment variables.
        For consistency use the same variable name in the config file as
        in the os environment.
        For naming standards use all capitals and start with 'TABPY_'

        Resets ``settings``, ``subdirectory``, ``tabpy_state``,
        ``python_service`` and ``credentials`` on this instance before
        populating them.

        Raises:
            RuntimeError: if a passwords file is configured but cannot be
                read, or if the transfer protocol settings fail validation.
        """
        self.settings = {}
        self.subdirectory = ""
        self.tabpy_state = None
        self.python_service = None
        self.credentials = {}

        pkg_path = os.path.dirname(tabpy.__file__)

        # os.environ is passed as the parser defaults.
        parser = configparser.ConfigParser(os.environ)
        logger.info(f"Parsing config file {config_file}")

        file_exists = False
        if os.path.isfile(config_file):
            try:
                with open(config_file, 'r') as f:
                    parser.read_string(f.read())
                    file_exists = True
            except Exception as e:
                # Best effort: keep going with defaults, but record why the
                # file could not be used instead of hiding the error.
                logger.warning(f"Failed to read config file {config_file}: {e}")

        if not file_exists:
            logger.warning(
                f"Unable to open config file {config_file}, "
                "using default settings."
            )

        # (settings key, config key, default value, parse function)
        settings_parameters = [
            (SettingsParameters.Port, ConfigParameters.TABPY_PORT, 9004, None),
            (SettingsParameters.ServerVersion, None, __version__, None),
            (SettingsParameters.EvaluateEnabled, ConfigParameters.TABPY_EVALUATE_ENABLE,
             True, parser.getboolean),
            (SettingsParameters.EvaluateTimeout, ConfigParameters.TABPY_EVALUATE_TIMEOUT,
             30, parser.getfloat),
            (SettingsParameters.UploadDir, ConfigParameters.TABPY_QUERY_OBJECT_PATH,
             os.path.join(pkg_path, "tmp", "query_objects"), None),
            (SettingsParameters.TransferProtocol, ConfigParameters.TABPY_TRANSFER_PROTOCOL,
             "http", None),
            (SettingsParameters.CertificateFile, ConfigParameters.TABPY_CERTIFICATE_FILE,
             None, None),
            (SettingsParameters.KeyFile, ConfigParameters.TABPY_KEY_FILE, None, None),
            (SettingsParameters.StateFilePath, ConfigParameters.TABPY_STATE_PATH,
             os.path.join(pkg_path, "tabpy_server"), None),
            (SettingsParameters.StaticPath, ConfigParameters.TABPY_STATIC_PATH,
             os.path.join(pkg_path, "tabpy_server", "static"), None),
            (ConfigParameters.TABPY_PWD_FILE, ConfigParameters.TABPY_PWD_FILE, None, None),
            (SettingsParameters.LogRequestContext, ConfigParameters.TABPY_LOG_DETAILS,
             "false", None),
            (SettingsParameters.MaxRequestSizeInMb, ConfigParameters.TABPY_MAX_REQUEST_SIZE_MB,
             100, None),
            (SettingsParameters.GzipEnabled, ConfigParameters.TABPY_GZIP_ENABLE,
             True, parser.getboolean),
            (SettingsParameters.ArrowEnabled, ConfigParameters.TABPY_ARROW_ENABLE,
             False, parser.getboolean),
            (SettingsParameters.ArrowFlightPort, ConfigParameters.TABPY_ARROWFLIGHT_PORT,
             13622, parser.getint),
        ]

        for setting, parameter, default_val, parse_function in settings_parameters:
            self._set_parameter(parser, setting, parameter, default_val, parse_function)

        if not os.path.exists(self.settings[SettingsParameters.UploadDir]):
            os.makedirs(self.settings[SettingsParameters.UploadDir])

        # set and validate transfer protocol
        self.settings[SettingsParameters.TransferProtocol] = self.settings[
            SettingsParameters.TransferProtocol
        ].lower()

        self._validate_transfer_protocol_settings()

        # if state.ini does not exist try and create it - remove
        # last dependence on batch/shell script
        self.settings[SettingsParameters.StateFilePath] = os.path.realpath(
            os.path.normpath(
                os.path.expanduser(self.settings[SettingsParameters.StateFilePath])
            )
        )
        state_config, self.tabpy_state = self._build_tabpy_state()

        self.python_service = PythonServiceHandler(PythonService())
        self.settings["compress_response"] = True
        self.settings[SettingsParameters.StaticPath] = os.path.abspath(
            self.settings[SettingsParameters.StaticPath]
        )
        logger.debug(
            f"Static pages folder set to "
            f'"{self.settings[SettingsParameters.StaticPath]}"'
        )

        # Set subdirectory from config if applicable
        if state_config.has_option("Service Info", "Subdirectory"):
            self.subdirectory = "/" + state_config.get("Service Info", "Subdirectory")

        # If passwords file specified load credentials
        if ConfigParameters.TABPY_PWD_FILE in self.settings:
            if not self._parse_pwd_file():
                msg = (
                    "Failed to read passwords file "
                    f"{self.settings[ConfigParameters.TABPY_PWD_FILE]}"
                )
                logger.critical(msg)
                raise RuntimeError(msg)
        else:
            logger.info(
                "Password file is not specified: " "Authentication is not enabled"
            )

        features = self._get_features()
        self.settings[SettingsParameters.ApiVersions] = {"v1": {"features": features}}

        # Any value other than "false" (case-insensitive) enables
        # request-context logging.
        self.settings[SettingsParameters.LogRequestContext] = (
            self.settings[SettingsParameters.LogRequestContext].lower() != "false"
        )
        call_context_state = (
            "enabled"
            if self.settings[SettingsParameters.LogRequestContext]
            else "disabled"
        )
        logger.info(f"Call context logging is {call_context_state}")
409
410 1
    def _validate_transfer_protocol_settings(self):
        """
        Check the transfer protocol and, for HTTPS, the certificate/key.

        ``http`` passes with no further checks. ``https`` requires both
        the certificate and key settings to be present and to point to
        existing files; the certificate is then passed to
        ``validate_cert``.

        Raises:
            RuntimeError: if the protocol is missing or unsupported, or
                if the certificate/key settings are missing or do not
                point to existing files.
        """
        settings = self.settings

        if SettingsParameters.TransferProtocol not in settings:
            msg = "Missing transfer protocol information."
            logger.critical(msg)
            raise RuntimeError(msg)

        protocol = settings[SettingsParameters.TransferProtocol]
        if protocol == "http":
            return

        if protocol != "https":
            msg = f"Unsupported transfer protocol: {protocol}"
            logger.critical(msg)
            raise RuntimeError(msg)

        has_cert = SettingsParameters.CertificateFile in settings
        has_key = SettingsParameters.KeyFile in settings
        self._validate_cert_key_state(
            "The parameter(s) {} must be set.", has_cert, has_key
        )

        cert = settings[SettingsParameters.CertificateFile]
        key = settings[SettingsParameters.KeyFile]
        self._validate_cert_key_state(
            "The parameter(s) {} must point to " "an existing file.",
            os.path.isfile(cert),
            os.path.isfile(key),
        )

        tabpy.tabpy_server.app.util.validate_cert(cert)
439
440 1
    @staticmethod
    def _validate_cert_key_state(msg, cert_valid, key_valid):
        """
        Raise if the certificate and/or key check failed.

        ``msg`` is a format string with one placeholder, which is filled
        with the name(s) of the offending config parameter(s).

        Raises:
            RuntimeError: when ``cert_valid`` or ``key_valid`` is falsy.
        """
        if cert_valid and key_valid:
            return

        if not cert_valid and not key_valid:
            offending = (
                f"{ConfigParameters.TABPY_CERTIFICATE_FILE} and "
                f"{ConfigParameters.TABPY_KEY_FILE}"
            )
        elif not cert_valid:
            offending = ConfigParameters.TABPY_CERTIFICATE_FILE
        else:
            offending = ConfigParameters.TABPY_KEY_FILE

        err = "Error using HTTPS: " + msg.format(offending)
        logger.critical(err)
        raise RuntimeError(err)
458
459 1
    def _parse_pwd_file(self):
        """
        Load credentials from the configured passwords file.

        Stores the parsed credentials in ``self.credentials``. Returns
        the success flag from ``parse_pwd_file``, or False when parsing
        succeeded but produced no credentials.
        """
        pwd_file = self.settings[ConfigParameters.TABPY_PWD_FILE]
        succeeded, self.credentials = parse_pwd_file(pwd_file)

        if not succeeded:
            return succeeded

        if len(self.credentials) == 0:
            # A readable file with no entries counts as a failure.
            logger.error("No credentials found")
            return False

        return succeeded
469
470 1
    def _get_features(self):
        """
        Build the feature dictionary from the current settings.

        Includes an ``authentication`` entry (basic-auth, required) only
        when a passwords file is configured, followed by the evaluate,
        gzip and arrow enabled flags.
        """
        features = {}

        # Check for auth
        if ConfigParameters.TABPY_PWD_FILE in self.settings:
            features["authentication"] = {
                "required": True,
                "methods": {"basic-auth": {}},
            }

        features.update(
            evaluate_enabled=self.settings[SettingsParameters.EvaluateEnabled],
            gzip_enabled=self.settings[SettingsParameters.GzipEnabled],
            arrow_enabled=self.settings[SettingsParameters.ArrowEnabled],
        )
        return features
484
485 1
    def _build_tabpy_state(self):
        """
        Load the state file, creating it from the template if missing.

        ``state.ini`` is looked up in the configured state directory;
        when absent it is copied from ``state.ini.template`` in the
        package's ``tabpy_server`` directory.

        Returns a ``(state_config, TabPyState)`` tuple, where
        ``state_config`` is the value returned by ``_get_state_from_file``.
        """
        state_dir = self.settings[SettingsParameters.StateFilePath]
        state_file = os.path.join(state_dir, "state.ini")

        if not os.path.isfile(state_file):
            template_file = os.path.join(
                os.path.dirname(tabpy.__file__), "tabpy_server", "state.ini.template"
            )
            logger.debug(
                f"File {state_file} not found, creating from "
                f"template {template_file}..."
            )
            shutil.copy(template_file, state_file)

        logger.info(f"Loading state from state file {state_file}")
        state_config = _get_state_from_file(state_dir)
        return state_config, TabPyState(config=state_config, settings=self.settings)
502
503
504
# Override _read_body to allow content with size exceeding max_body_size
# This enables proper handling of 413 errors in base_handler
def _read_body_allow_max_size(self, code, headers, delegate):
    """
    Skip reading the body when Content-Length exceeds the size limit.

    Returns None (the body is not read) when the declared
    ``Content-Length`` is larger than ``self._max_body_size``. In every
    other case the call is delegated to ``self.original_read_body``.

    A ``Content-Length`` value that is not a plain integer is also
    delegated, so the original implementation decides how to handle it
    instead of this override failing with ``ValueError``.
    """
    if "Content-Length" in headers:
        try:
            content_length = int(headers["Content-Length"])
        except ValueError:
            content_length = None

        if content_length is not None and content_length > self._max_body_size:
            return None

    return self.original_read_body(code, headers, delegate)
512
513 1
# Save the original only once: if this module body runs again (for example
# on reload), _read_body is already the override, and saving it as
# original_read_body would make the override call itself forever.
if not hasattr(HTTP1Connection, "original_read_body"):
    HTTP1Connection.original_read_body = HTTP1Connection._read_body
HTTP1Connection._read_body = _read_body_allow_max_size
515