Passed
Push — master ( fe0a6e...5716a6 )
by
unknown
13:10 queued 15s
created

tabpy.tabpy_server.app.app.TabPyApp.run()   B

Complexity

Conditions 6

Size

Total Lines 45
Code Lines 33

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 1
CRAP Score 37.8504

Importance

Changes 0
Metric Value
eloc 33
dl 0
loc 45
ccs 1
cts 25
cp 0.04
rs 8.1546
c 0
b 0
f 0
cc 6
nop 1
crap 37.8504
1 1
import concurrent.futures
2 1
import configparser
3 1
import logging
4 1
import multiprocessing
5 1
import os
6 1
import shutil
7 1
import signal
8 1
import sys
9 1
import _thread
10
11 1
import tornado
12 1
from tornado.http1connection import HTTP1Connection
13
14 1
import tabpy
15 1
import tabpy.tabpy_server.app.arrow_server as pa
16 1
from tabpy.tabpy import __version__
17 1
from tabpy.tabpy_server.app.app_parameters import ConfigParameters, SettingsParameters
18 1
from tabpy.tabpy_server.app.util import parse_pwd_file
19 1
from tabpy.tabpy_server.handlers.basic_auth_server_middleware_factory import BasicAuthServerMiddlewareFactory
20 1
from tabpy.tabpy_server.handlers.no_op_auth_handler import NoOpAuthHandler
21 1
from tabpy.tabpy_server.management.state import TabPyState
22 1
from tabpy.tabpy_server.management.util import _get_state_from_file
23 1
from tabpy.tabpy_server.psws.callbacks import init_model_evaluator, init_ps_server
24 1
from tabpy.tabpy_server.psws.python_service import PythonService, PythonServiceHandler
25 1
from tabpy.tabpy_server.handlers import (
26
    EndpointHandler,
27
    EndpointsHandler,
28
    EvaluationPlaneHandler,
29
    EvaluationPlaneDisabledHandler,
30
    QueryPlaneHandler,
31
    ServiceInfoHandler,
32
    StatusHandler,
33
    UploadDestinationHandler,
34
)
35
36 1
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
37
38 1
def _init_asyncio_patch():
39
    """
40
    Select compatible event loop for Tornado 5+.
41
    As of Python 3.8, the default event loop on Windows is `proactor`,
42
    however Tornado requires the old default "selector" event loop.
43
    As Tornado has decided to leave this to users to set, MkDocs needs
44
    to set it. See https://github.com/tornadoweb/tornado/issues/2608.
45
    """
46 1
    if sys.platform.startswith("win") and sys.version_info >= (3, 8):
47
        import asyncio
48
        try:
49
            from asyncio import WindowsSelectorEventLoopPolicy
50
        except ImportError:
51
            pass  # Can't assign a policy which doesn't exist.
52
        else:
53
            if not isinstance(asyncio.get_event_loop_policy(), WindowsSelectorEventLoopPolicy):
54
                asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
55
56
57 1
class TabPyApp:
58
    """
59
    TabPy application class for keeping context like settings, state, etc.
60
    """
61
62 1
    settings = {}
63 1
    subdirectory = ""
64 1
    tabpy_state = None
65 1
    python_service = None
66 1
    credentials = {}
67 1
    arrow_server = None
68 1
    max_request_size = None
69
70 1
    def __init__(self, config_file):
        """
        Configure logging from the config file and load the settings.

        config_file: path to the configuration file. When None, the
        "default.conf" file in the sibling "common" directory is used.
        """
        if config_file is None:
            app_dir = os.path.dirname(__file__)
            config_file = os.path.join(
                app_dir, os.path.pardir, "common", "default.conf"
            )

        if os.path.isfile(config_file):
            # The same file is also handed to the logging configuration.
            try:
                from logging import config as logging_config

                logging_config.fileConfig(
                    config_file, disable_existing_loggers=False
                )
            except KeyError:
                # fileConfig could not use the file: fall back to plain
                # DEBUG-level logging.
                logging.basicConfig(level=logging.DEBUG)

        self._parse_config(config_file)
84
85 1
    def _get_tls_certificates(self, config):
        """
        Read the configured certificate and key files.

        config: settings mapping holding the certificate and key file paths.

        Returns a one-element list with a (certificate bytes, key bytes)
        tuple.
        """
        cert_path = config[SettingsParameters.CertificateFile]
        key_path = config[SettingsParameters.KeyFile]

        with open(cert_path, "rb") as cert_stream:
            cert_chain = cert_stream.read()
        with open(key_path, "rb") as key_stream:
            private_key = key_stream.read()

        return [(cert_chain, private_key)]
95
    
96 1
    def _get_arrow_server(self, config):
        """
        Build the Arrow Flight server described by the settings.

        config: settings mapping (transfer protocol, Arrow Flight port,
        API features and, when authentication is on, the password file).

        Returns the constructed ``pa.FlightServer``; it is not started here.
        """
        use_tls = config[SettingsParameters.TransferProtocol] == "https"
        scheme = "grpc+tls" if use_tls else "grpc+tcp"
        tls_certificates = self._get_tls_certificates(config) if use_tls else None

        host = "0.0.0.0"
        port = config.get(SettingsParameters.ArrowFlightPort)
        location = f"{scheme}://{host}:{port}"

        # Basic-auth middleware is only installed when the authentication
        # feature is listed for API v1.
        features = config[SettingsParameters.ApiVersions]["v1"]["features"]
        if "authentication" in features:
            _, creds = parse_pwd_file(config[ConfigParameters.TABPY_PWD_FILE])
            auth_middleware = {"basic": BasicAuthServerMiddlewareFactory(creds)}
        else:
            auth_middleware = None

        return pa.FlightServer(
            host,
            location,
            tls_certificates=tls_certificates,
            verify_client=None,
            auth_handler=NoOpAuthHandler(),
            middleware=auth_middleware,
        )
120
121 1
    def run(self):
        """
        Start the web service (and optionally the Arrow server) and block.

        Creates the Tornado application, initializes the model evaluator,
        starts listening on the configured port (with TLS options for
        "https"), starts the Arrow Flight server on a separate thread when
        it is enabled, and finally runs the Tornado IO loop.

        Raises RuntimeError when the transfer protocol is neither "http"
        nor "https".
        """
        application = self._create_tornado_web_app()

        init_model_evaluator(self.settings, self.tabpy_state, self.python_service)

        protocol = self.settings[SettingsParameters.TransferProtocol]
        ssl_options = None
        if protocol == "https":
            ssl_options = {
                "certfile": self.settings[SettingsParameters.CertificateFile],
                "keyfile": self.settings[SettingsParameters.KeyFile],
            }
        elif protocol != "http":
            msg = f"Unsupported transfer protocol {protocol}."
            logger.critical(msg)
            raise RuntimeError(msg)

        # Extra keyword arguments for listen(); only set when gzip is on.
        listen_kwargs = {}
        if self.settings[SettingsParameters.GzipEnabled] is True:
            listen_kwargs["decompress_request"] = True

        port = self.settings[SettingsParameters.Port]
        application.listen(
            port,
            ssl_options=ssl_options,
            max_buffer_size=self.max_request_size,
            max_body_size=self.max_request_size,
            **listen_kwargs,
        )

        logger.info(f"Web service listening on port {str(port)}")

        if self.settings[SettingsParameters.ArrowEnabled]:
            def start_pyarrow():
                # This runs on its own thread, so an exception raised here
                # never reaches the try/except around start_new_thread
                # below; report it from inside the thread instead.
                try:
                    self.arrow_server = self._get_arrow_server(self.settings)
                    pa.start(self.arrow_server)
                except Exception as e:
                    logger.critical(f"Failed to start PyArrow server: {e}")

            try:
                _thread.start_new_thread(start_pyarrow, ())
            except Exception as e:
                # Only covers failure to create the thread itself.
                logger.critical(f"Failed to start PyArrow server: {e}")

        tornado.ioloop.IOLoop.instance().start()
166
167 1
    def _create_tornado_web_app(self):
        """
        Initialize the service and build the Tornado web application.

        Runs ``init_ps_server`` to completion on the IO loop, creates the
        thread pool handed to the evaluation handler, builds the route
        table under ``self.subdirectory`` and installs SIGINT handling.

        Returns the Tornado application instance, not yet listening.
        """
        class TabPyTornadoApp(tornado.web.Application):
            is_closing = False

            def signal_handler(self, signal, _):
                # Only flag the shutdown here; the loop is stopped from
                # try_exit, which runs on the IO loop.
                logger.critical(f"Exiting on signal {signal}...")
                self.is_closing = True

            def try_exit(self):
                if self.is_closing:
                    tornado.ioloop.IOLoop.instance().stop()
                    logger.info("Shutting down TabPy...")

        logger.info("Initializing TabPy...")
        tornado.ioloop.IOLoop.instance().run_sync(
            lambda: init_ps_server(self.settings, self.tabpy_state)
        )
        logger.info("Done initializing TabPy.")

        executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=multiprocessing.cpu_count()
        )

        # initialize Tornado application
        _init_asyncio_patch()
        application = TabPyTornadoApp(
            [
                (
                    self.subdirectory + r"/query/([^/]+)",
                    QueryPlaneHandler,
                    dict(app=self),
                ),
                (self.subdirectory + r"/status", StatusHandler, dict(app=self)),
                (self.subdirectory + r"/info", ServiceInfoHandler, dict(app=self)),
                (self.subdirectory + r"/endpoints", EndpointsHandler, dict(app=self)),
                (
                    self.subdirectory + r"/endpoints/([^/]+)?",
                    EndpointHandler,
                    dict(app=self),
                ),
                (
                    self.subdirectory + r"/evaluate",
                    # The route always exists; when evaluation is turned off
                    # it is served by the "disabled" handler instead.
                    EvaluationPlaneHandler if self.settings[SettingsParameters.EvaluateEnabled]
                    else EvaluationPlaneDisabledHandler,
                    dict(executor=executor, app=self),
                ),
                (
                    self.subdirectory + r"/configurations/endpoint_upload_destination",
                    UploadDestinationHandler,
                    dict(app=self),
                ),
                (
                    # Catch-all: must stay last so it does not shadow the
                    # API routes above.
                    self.subdirectory + r"/(.*)",
                    tornado.web.StaticFileHandler,
                    dict(
                        path=self.settings[SettingsParameters.StaticPath],
                        default_filename="index.html",
                    ),
                ),
            ],
            debug=False,
            **self.settings,
        )

        # Register the SIGINT handler and the 500 ms shutdown poll exactly
        # once; registering them twice left two identical periodic
        # callbacks running.
        signal.signal(signal.SIGINT, application.signal_handler)
        tornado.ioloop.PeriodicCallback(application.try_exit, 500).start()

        return application
238
239 1
    def _set_parameter(self, parser, settings_key, config_key, default_val, parse_function):
        """
        Resolve one setting and store it in ``self.settings``.

        parser: config parser to look the value up in.
        settings_key: key under which the value is stored.
        config_key: option name in the "TabPy" section, or None to skip
            the parser lookup.
        default_val: value used when the parser has none; None means
            "no default", in which case the setting stays unset.
        parse_function: parser accessor called as
            ``parse_function("TabPy", config_key)``; None means
            ``parser.get``.
        """
        found_in_parser = (
            config_key is not None
            and parser.has_section("TabPy")
            and parser.has_option("TabPy", config_key)
        )

        if found_in_parser:
            read_value = parser.get if parse_function is None else parse_function
            self.settings[settings_key] = read_value("TabPy", config_key)
            logger.debug(
                f"Parameter {settings_key} set to "
                f'"{self.settings[settings_key]}" '
                "from config file or environment variable"
            )
            return

        if default_val is not None:
            self.settings[settings_key] = default_val
            logger.debug(
                f"Parameter {settings_key} set to "
                f'"{self.settings[settings_key]}" '
                "from default value"
            )
            return

        logger.debug(f"Parameter {settings_key} is not set")
268
269 1
    def _parse_config(self, config_file):
        """Provide consistent mechanism for pulling in configuration.

        The parser is seeded with the OS environment as its defaults and
        the config file, when it can be read, is loaded on top of it. Each
        known setting is then resolved through ``_set_parameter``, which
        uses the built-in default from the table below when the parser has
        no value for it.

        For consistency use the same variable name in the config file as
        in the os environment.
        For naming standards use all capitals and start with 'TABPY_'

        Besides collecting the raw values this method creates the upload
        directory if missing, validates the transfer protocol settings,
        computes the max request size in bytes, builds the TabPy state and
        the Python service handler, loads credentials when a password file
        is configured, and publishes the API feature list.

        config_file: path of the configuration file to read.

        Raises RuntimeError when the transfer protocol settings are invalid
        or the configured password file cannot be read.
        """
        self.settings = {}
        self.subdirectory = ""
        self.tabpy_state = None
        self.python_service = None
        self.credentials = {}

        pkg_path = os.path.dirname(tabpy.__file__)

        parser = configparser.ConfigParser(os.environ)
        logger.info(f"Parsing config file {config_file}")

        file_exists = False
        if os.path.isfile(config_file):
            try:
                with open(config_file, 'r') as f:
                    parser.read_string(f.read())
                    file_exists = True
            except Exception as e:
                # Still best-effort (defaults are used below), but record
                # why the file could not be used instead of hiding it.
                logger.warning(f"Failed to read config file {config_file}: {e}")

        if not file_exists:
            logger.warning(
                f"Unable to open config file {config_file}, "
                "using default settings."
            )

        # (settings key, config/environment name, default, parser accessor)
        settings_parameters = [
            (SettingsParameters.Port, ConfigParameters.TABPY_PORT, 9004, None),
            (SettingsParameters.ServerVersion, None, __version__, None),
            (SettingsParameters.EvaluateEnabled, ConfigParameters.TABPY_EVALUATE_ENABLE,
             True, parser.getboolean),
            (SettingsParameters.EvaluateTimeout, ConfigParameters.TABPY_EVALUATE_TIMEOUT,
             30, parser.getfloat),
            (SettingsParameters.UploadDir, ConfigParameters.TABPY_QUERY_OBJECT_PATH,
             os.path.join(pkg_path, "tmp", "query_objects"), None),
            (SettingsParameters.TransferProtocol, ConfigParameters.TABPY_TRANSFER_PROTOCOL,
             "http", None),
            (SettingsParameters.CertificateFile, ConfigParameters.TABPY_CERTIFICATE_FILE,
             None, None),
            (SettingsParameters.KeyFile, ConfigParameters.TABPY_KEY_FILE, None, None),
            (SettingsParameters.StateFilePath, ConfigParameters.TABPY_STATE_PATH,
             os.path.join(pkg_path, "tabpy_server"), None),
            (SettingsParameters.StaticPath, ConfigParameters.TABPY_STATIC_PATH,
             os.path.join(pkg_path, "tabpy_server", "static"), None),
            (ConfigParameters.TABPY_PWD_FILE, ConfigParameters.TABPY_PWD_FILE, None, None),
            (SettingsParameters.LogRequestContext, ConfigParameters.TABPY_LOG_DETAILS,
             "false", None),
            (SettingsParameters.MaxRequestSizeInMb, ConfigParameters.TABPY_MAX_REQUEST_SIZE_MB,
             100, None),
            (SettingsParameters.GzipEnabled, ConfigParameters.TABPY_GZIP_ENABLE,
             True, parser.getboolean),
            (SettingsParameters.ArrowEnabled, ConfigParameters.TABPY_ARROW_ENABLE,
             False, parser.getboolean),
            (SettingsParameters.ArrowFlightPort, ConfigParameters.TABPY_ARROWFLIGHT_PORT,
             13622, parser.getint),
        ]

        for setting, parameter, default_val, parse_function in settings_parameters:
            self._set_parameter(parser, setting, parameter, default_val, parse_function)

        upload_dir = self.settings[SettingsParameters.UploadDir]
        if not os.path.exists(upload_dir):
            # exist_ok covers the directory appearing between the check
            # above and the creation.
            os.makedirs(upload_dir, exist_ok=True)

        # set and validate transfer protocol
        self.settings[SettingsParameters.TransferProtocol] = self.settings[
            SettingsParameters.TransferProtocol
        ].lower()

        self._validate_transfer_protocol_settings()

        # Set max request size in bytes
        self.max_request_size = (
            int(self.settings[SettingsParameters.MaxRequestSizeInMb]) * 1024 * 1024
        )
        logger.info(f"Setting max request size to {self.max_request_size} bytes")

        # if state.ini does not exist try and create it - remove
        # last dependence on batch/shell script
        self.settings[SettingsParameters.StateFilePath] = os.path.realpath(
            os.path.normpath(
                os.path.expanduser(self.settings[SettingsParameters.StateFilePath])
            )
        )
        state_config, self.tabpy_state = self._build_tabpy_state()

        self.python_service = PythonServiceHandler(PythonService())
        self.settings["compress_response"] = True
        self.settings[SettingsParameters.StaticPath] = os.path.abspath(
            self.settings[SettingsParameters.StaticPath]
        )
        logger.debug(
            f"Static pages folder set to "
            f'"{self.settings[SettingsParameters.StaticPath]}"'
        )

        # Set subdirectory from config if applicable
        if state_config.has_option("Service Info", "Subdirectory"):
            self.subdirectory = "/" + state_config.get("Service Info", "Subdirectory")

        # If passwords file specified load credentials
        if ConfigParameters.TABPY_PWD_FILE in self.settings:
            if not self._parse_pwd_file():
                msg = (
                    "Failed to read passwords file "
                    f"{self.settings[ConfigParameters.TABPY_PWD_FILE]}"
                )
                logger.critical(msg)
                raise RuntimeError(msg)
        else:
            logger.info(
                "Password file is not specified: " "Authentication is not enabled"
            )

        features = self._get_features()
        self.settings[SettingsParameters.ApiVersions] = {"v1": {"features": features}}

        # Any value other than "false" (case-insensitive) turns request
        # context logging on.
        self.settings[SettingsParameters.LogRequestContext] = (
            self.settings[SettingsParameters.LogRequestContext].lower() != "false"
        )
        call_context_state = (
            "enabled"
            if self.settings[SettingsParameters.LogRequestContext]
            else "disabled"
        )
        logger.info(f"Call context logging is {call_context_state}")
413
414 1
    def _validate_transfer_protocol_settings(self):
        """
        Check the transfer protocol and, for HTTPS, the certificate setup.

        "http" needs nothing further. "https" requires both the
        certificate and key parameters to be set and to point at existing
        files, after which the certificate itself is validated.

        Raises RuntimeError when the protocol is missing or unsupported,
        or when the certificate/key settings are incomplete.
        """
        settings = self.settings

        if SettingsParameters.TransferProtocol not in settings:
            msg = "Missing transfer protocol information."
            logger.critical(msg)
            raise RuntimeError(msg)

        protocol = settings[SettingsParameters.TransferProtocol]
        if protocol not in ("http", "https"):
            msg = f"Unsupported transfer protocol: {protocol}"
            logger.critical(msg)
            raise RuntimeError(msg)

        if protocol == "https":
            # First make sure both parameters are present at all ...
            self._validate_cert_key_state(
                "The parameter(s) {} must be set.",
                SettingsParameters.CertificateFile in settings,
                SettingsParameters.KeyFile in settings,
            )

            # ... then that each of them names an existing file.
            cert_path = settings[SettingsParameters.CertificateFile]
            key_path = settings[SettingsParameters.KeyFile]
            self._validate_cert_key_state(
                "The parameter(s) {} must point to an existing file.",
                os.path.isfile(cert_path),
                os.path.isfile(key_path),
            )

            tabpy.tabpy_server.app.util.validate_cert(cert_path)
443
444 1
    @staticmethod
    def _validate_cert_key_state(msg, cert_valid, key_valid):
        """
        Raise if the certificate and/or key check failed.

        msg: message template with one ``{}`` placeholder that receives
            the name(s) of the offending parameter(s).
        cert_valid: result of the check for the certificate parameter.
        key_valid: result of the check for the key parameter.

        Raises RuntimeError (after logging the same text as critical)
        when either check failed; returns None otherwise.
        """
        if cert_valid and key_valid:
            return

        if not cert_valid and not key_valid:
            offending = (
                f"{ConfigParameters.TABPY_CERTIFICATE_FILE} and "
                f"{ConfigParameters.TABPY_KEY_FILE}"
            )
        elif not cert_valid:
            offending = ConfigParameters.TABPY_CERTIFICATE_FILE
        else:
            offending = ConfigParameters.TABPY_KEY_FILE

        err = "Error using HTTPS: " + msg.format(offending)
        logger.critical(err)
        raise RuntimeError(err)
462
463 1
    def _parse_pwd_file(self):
        """
        Load credentials from the configured password file.

        Stores the parsed credentials in ``self.credentials``.

        Returns the success flag reported by ``parse_pwd_file``, or False
        when parsing succeeded but produced no credentials.
        """
        pwd_file = self.settings[ConfigParameters.TABPY_PWD_FILE]
        parsed_ok, self.credentials = parse_pwd_file(pwd_file)

        if not parsed_ok:
            return parsed_ok

        if len(self.credentials) == 0:
            # An empty password file counts as a failure.
            logger.error("No credentials found")
            return False

        return parsed_ok
473
474 1
    def _get_features(self):
        """
        Describe the enabled features.

        Returns a dict with the evaluate/gzip/arrow flags taken from the
        settings, preceded by an "authentication" entry when a password
        file is configured.
        """
        settings = self.settings
        features = {}

        # Check for auth
        if ConfigParameters.TABPY_PWD_FILE in settings:
            features["authentication"] = {
                "required": True,
                "methods": {"basic-auth": {}},
            }

        features.update(
            evaluate_enabled=settings[SettingsParameters.EvaluateEnabled],
            gzip_enabled=settings[SettingsParameters.GzipEnabled],
            arrow_enabled=settings[SettingsParameters.ArrowEnabled],
        )
        return features
488
489 1
    def _build_tabpy_state(self):
        """
        Load the service state, creating state.ini from the template first
        when it does not exist.

        The state directory is taken from the StateFilePath setting and is
        created if it is missing, so that the template can be copied in.

        Returns a tuple ``(state_config, tabpy_state)``: the object
        returned by ``_get_state_from_file`` and the ``TabPyState`` built
        on top of it.
        """
        state_file_dir = self.settings[SettingsParameters.StateFilePath]
        state_file_path = os.path.join(state_file_dir, "state.ini")

        if not os.path.isfile(state_file_path):
            pkg_path = os.path.dirname(tabpy.__file__)
            state_file_template_path = os.path.join(
                pkg_path, "tabpy_server", "state.ini.template"
            )
            logger.debug(
                f"File {state_file_path} not found, creating from "
                f"template {state_file_template_path}..."
            )
            # A custom state path may not exist yet; without this the copy
            # fails with FileNotFoundError.
            os.makedirs(state_file_dir, exist_ok=True)
            shutil.copy(state_file_template_path, state_file_path)

        logger.info(f"Loading state from state file {state_file_path}")
        state_config = _get_state_from_file(state_file_dir)
        return state_config, TabPyState(config=state_config, settings=self.settings)
506
507
508
# Override _read_body to allow content with size exceeding max_body_size
509
# This enables proper handling of 413 errors in base_handler
510 1
def _read_body_allow_max_size(self, code, headers, delegate):
511 1
    if "Content-Length" in headers:
512 1
        content_length = int(headers["Content-Length"])
513 1
        if content_length > self._max_body_size:
514
            return
515 1
    return self.original_read_body(code, headers, delegate)
516
517 1
# Save the unpatched reader only once. If this module is executed again
# (e.g. reloaded), _read_body is already the override, and saving it as
# "original" would make the override call itself recursively.
if not hasattr(HTTP1Connection, "original_read_body"):
    HTTP1Connection.original_read_body = HTTP1Connection._read_body
HTTP1Connection._read_body = _read_body_allow_max_size
519