Passed
Pull Request — master (#608)
by
unknown
13:03
created

TabPyApp._validate_transfer_protocol_settings()   A

Complexity

Conditions 4

Size

Total Lines 29
Code Lines 22

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 13
CRAP Score 4.1054

Importance

Changes 0
Metric Value
eloc 22
dl 0
loc 29
ccs 13
cts 16
cp 0.8125
rs 9.352
c 0
b 0
f 0
cc 4
nop 1
crap 4.1054
1 1
import concurrent.futures
2 1
import configparser
3 1
import logging
4 1
import multiprocessing
5 1
import os
6 1
import shutil
7 1
import signal
8 1
import sys
9 1
import tabpy
10 1
from tabpy.tabpy import __version__
11 1
from tabpy.tabpy_server.app.app_parameters import ConfigParameters, SettingsParameters
12 1
from tabpy.tabpy_server.app.util import parse_pwd_file
13 1
from tabpy.tabpy_server.handlers.basic_auth_server_middleware_factory import BasicAuthServerMiddlewareFactory
14 1
from tabpy.tabpy_server.handlers.no_op_auth_handler import NoOpAuthHandler
15 1
from tabpy.tabpy_server.management.state import TabPyState
16 1
from tabpy.tabpy_server.management.util import _get_state_from_file
17 1
from tabpy.tabpy_server.psws.callbacks import init_model_evaluator, init_ps_server
18 1
from tabpy.tabpy_server.psws.python_service import PythonService, PythonServiceHandler
19 1
from tabpy.tabpy_server.handlers import (
20
    EndpointHandler,
21
    EndpointsHandler,
22
    EvaluationPlaneHandler,
23
    EvaluationPlaneDisabledHandler,
24
    QueryPlaneHandler,
25
    ServiceInfoHandler,
26
    StatusHandler,
27
    UploadDestinationHandler,
28
)
29 1
import tornado
30 1
from tornado.http1connection import HTTP1Connection
31 1
import tabpy.tabpy_server.app.arrow_server as pa
32 1
import _thread
33
34 1
# Module-level logger, named after this module, used by everything below.
logger = logging.getLogger(__name__)
35
36 1
def _init_asyncio_patch():
37
    """
38
    Select compatible event loop for Tornado 5+.
39
    As of Python 3.8, the default event loop on Windows is `proactor`,
40
    however Tornado requires the old default "selector" event loop.
41
    As Tornado has decided to leave this to users to set, MkDocs needs
42
    to set it. See https://github.com/tornadoweb/tornado/issues/2608.
43
    """
44 1
    if sys.platform.startswith("win") and sys.version_info >= (3, 8):
45
        import asyncio
46
        try:
47
            from asyncio import WindowsSelectorEventLoopPolicy
48
        except ImportError:
49
            pass  # Can't assign a policy which doesn't exist.
50
        else:
51
            if not isinstance(asyncio.get_event_loop_policy(), WindowsSelectorEventLoopPolicy):
52
                asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
53
54
55 1
class TabPyApp:
    """
    TabPy application class for keeping context like settings, state, etc.
    """

    # Class-level defaults. _parse_config() rebinds settings, subdirectory,
    # tabpy_state, python_service and credentials on the instance, so the
    # mutable dict defaults below are placeholders only.

    # Parsed configuration values, keyed by SettingsParameters names.
    settings = {}
    # URL prefix ("/" + value) taken from the state file's
    # "Service Info" / "Subdirectory" option; empty when the option is absent.
    subdirectory = ""
    # TabPyState instance created by _build_tabpy_state().
    tabpy_state = None
    # PythonServiceHandler wrapping a PythonService, created in _parse_config().
    python_service = None
    # Credentials loaded from the passwords file by _parse_pwd_file().
    credentials = {}
    # Arrow Flight server; assigned in run() only when Arrow is enabled.
    arrow_server = None
    # Maximum request size in bytes; computed in run() from MaxRequestSizeInMb.
    max_request_size = None
67
68 1
    def __init__(self, config_file):
69 1
        if config_file is None:
70 1
            config_file = os.path.join(
71
                os.path.dirname(__file__), os.path.pardir, "common", "default.conf"
72
            )
73
74 1
        if os.path.isfile(config_file):
75 1
            try:
76 1
                from logging import config
77 1
                config.fileConfig(config_file, disable_existing_loggers=False)
78 1
            except KeyError:
79 1
                logging.basicConfig(level=logging.DEBUG)
80
81 1
        self._parse_config(config_file)
82
83 1
    def _get_tls_certificates(self, config):
        """
        Read the configured certificate and key files.

        Returns a one-element list holding a ``(certificate_bytes,
        private_key_bytes)`` tuple, both read in binary mode from the paths
        stored under `SettingsParameters.CertificateFile` and
        `SettingsParameters.KeyFile` in `config`.
        """

        def read_bytes(path):
            with open(path, "rb") as handle:
                return handle.read()

        cert_path = config[SettingsParameters.CertificateFile]
        key_path = config[SettingsParameters.KeyFile]
        cert_chain = read_bytes(cert_path)
        private_key = read_bytes(key_path)
        return [(cert_chain, private_key)]
93
    
94 1
    def _get_arrow_server(self, config):
        """
        Build the Arrow Flight server described by `config`.

        The server location uses the "grpc+tls" scheme with the configured
        certificate/key pair when the transfer protocol is "https", and
        "grpc+tcp" without certificates otherwise. It binds to 0.0.0.0 on the
        configured Arrow Flight port. When the "authentication" feature is
        present in the v1 API features, a basic-auth middleware built from the
        passwords file is attached.
        """
        use_tls = config[SettingsParameters.TransferProtocol] == "https"
        scheme = "grpc+tls" if use_tls else "grpc+tcp"
        tls_certificates = self._get_tls_certificates(config) if use_tls else None

        host = "0.0.0.0"
        port = config.get(SettingsParameters.ArrowFlightPort)
        location = f"{scheme}://{host}:{port}"

        v1_features = config[SettingsParameters.ApiVersions]["v1"]["features"]
        auth_middleware = None
        if "authentication" in v1_features:
            _, creds = parse_pwd_file(config[ConfigParameters.TABPY_PWD_FILE])
            auth_middleware = {"basic": BasicAuthServerMiddlewareFactory(creds)}

        return pa.FlightServer(
            host,
            location,
            tls_certificates=tls_certificates,
            verify_client=None,
            auth_handler=NoOpAuthHandler(),
            middleware=auth_middleware,
        )
118
119 1
    def run(self):
        """
        Start the web service and, when enabled, the Arrow Flight server.

        Creates the Tornado application, computes the maximum request size in
        bytes from the `MaxRequestSizeInMb` setting, initializes the model
        evaluator and starts listening on the configured port (with the
        certificate/key as SSL options when the transfer protocol is
        "https"). If Arrow is enabled the Arrow Flight server is started on
        a separate thread. Finally the call blocks in the Tornado IO loop.

        Raises RuntimeError when the transfer protocol is neither "http" nor
        "https".
        """
        application = self._create_tornado_web_app()
        self.max_request_size = (
            int(self.settings[SettingsParameters.MaxRequestSizeInMb]) * 1024 * 1024
        )
        print(f"Setting max request size to {self.max_request_size} bytes")
        logger.info(f"Setting max request size to {self.max_request_size} bytes")

        init_model_evaluator(self.settings, self.tabpy_state, self.python_service)

        protocol = self.settings[SettingsParameters.TransferProtocol]
        ssl_options = None
        if protocol == "https":
            ssl_options = {
                "certfile": self.settings[SettingsParameters.CertificateFile],
                "keyfile": self.settings[SettingsParameters.KeyFile],
            }
        elif protocol != "http":
            msg = f"Unsupported transfer protocol {protocol}."
            logger.critical(msg)
            raise RuntimeError(msg)

        settings = {}
        if self.settings[SettingsParameters.GzipEnabled] is True:
            settings["decompress_request"] = True

        application.listen(
            self.settings[SettingsParameters.Port],
            ssl_options=ssl_options,
            max_buffer_size=self.max_request_size,
            max_body_size=self.max_request_size,
            **settings,
        )

        logger.info(
            "Web service listening on port "
            f"{str(self.settings[SettingsParameters.Port])}"
        )

        if self.settings[SettingsParameters.ArrowEnabled]:

            def start_pyarrow():
                # This body runs on the new thread, so failures here can never
                # reach the try/except around start_new_thread() below; they
                # have to be caught and logged in place.
                try:
                    self.arrow_server = self._get_arrow_server(self.settings)
                    pa.start(self.arrow_server)
                except Exception as e:
                    logger.critical(f"Failed to start PyArrow server: {e}")

            try:
                _thread.start_new_thread(start_pyarrow, ())
            except Exception as e:
                # Only covers failure to create the thread itself.
                logger.critical(f"Failed to start PyArrow server: {e}")

        tornado.ioloop.IOLoop.instance().start()
169
170 1
    def _create_tornado_web_app(self):
        """
        Build and return the Tornado application with all TabPy routes.

        Runs `init_ps_server` synchronously on the IO loop first, creates a
        thread pool sized to the CPU count for the evaluation handler, and
        registers a SIGINT handler plus a 500 ms periodic check that stops the
        IO loop once the handler has flagged the application as closing.
        """

        class TabPyTornadoApp(tornado.web.Application):
            # Set by the signal handler; polled by try_exit().
            is_closing = False

            def signal_handler(self, signum, _):
                logger.critical(f"Exiting on signal {signum}...")
                self.is_closing = True

            def try_exit(self):
                if self.is_closing:
                    tornado.ioloop.IOLoop.instance().stop()
                    logger.info("Shutting down TabPy...")

        logger.info("Initializing TabPy...")
        tornado.ioloop.IOLoop.instance().run_sync(
            lambda: init_ps_server(self.settings, self.tabpy_state)
        )
        logger.info("Done initializing TabPy.")

        executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=multiprocessing.cpu_count()
        )

        # initialize Tornado application
        _init_asyncio_patch()
        application = TabPyTornadoApp(
            [
                (
                    self.subdirectory + r"/query/([^/]+)",
                    QueryPlaneHandler,
                    dict(app=self),
                ),
                (self.subdirectory + r"/status", StatusHandler, dict(app=self)),
                (self.subdirectory + r"/info", ServiceInfoHandler, dict(app=self)),
                (self.subdirectory + r"/endpoints", EndpointsHandler, dict(app=self)),
                (
                    self.subdirectory + r"/endpoints/([^/]+)?",
                    EndpointHandler,
                    dict(app=self),
                ),
                (
                    self.subdirectory + r"/evaluate",
                    EvaluationPlaneHandler
                    if self.settings[SettingsParameters.EvaluateEnabled]
                    else EvaluationPlaneDisabledHandler,
                    dict(executor=executor, app=self),
                ),
                (
                    self.subdirectory + r"/configurations/endpoint_upload_destination",
                    UploadDestinationHandler,
                    dict(app=self),
                ),
                (
                    self.subdirectory + r"/(.*)",
                    tornado.web.StaticFileHandler,
                    dict(
                        path=self.settings[SettingsParameters.StaticPath],
                        default_filename="index.html",
                    ),
                ),
            ],
            debug=False,
            **self.settings,
        )

        # Register the shutdown hooks exactly once: a second registration
        # would only add a redundant periodic callback doing the same check.
        signal.signal(signal.SIGINT, application.signal_handler)
        tornado.ioloop.PeriodicCallback(application.try_exit, 500).start()

        return application
241
242 1
    def _set_parameter(self, parser, settings_key, config_key, default_val, parse_function):
        """
        Store one setting in `self.settings`.

        The value comes from the "TabPy" section of `parser` when
        `config_key` is given and present there, read with `parse_function`
        (or `parser.get` when that is None). Otherwise `default_val` is used
        if it is not None. When neither source provides a value the setting
        is left unset. Each outcome is logged at debug level.
        """
        section = "TabPy"
        present_in_config = (
            config_key is not None
            and parser.has_section(section)
            and parser.has_option(section, config_key)
        )

        if present_in_config:
            read_value = parser.get if parse_function is None else parse_function
            self.settings[settings_key] = read_value(section, config_key)
            logger.debug(
                f"Parameter {settings_key} set to "
                f'"{self.settings[settings_key]}" '
                "from config file or environment variable"
            )
            return

        if default_val is not None:
            self.settings[settings_key] = default_val
            logger.debug(
                f"Parameter {settings_key} set to "
                f'"{self.settings[settings_key]}" '
                "from default value"
            )
            return

        logger.debug(f"Parameter {settings_key} is not set")
271
272 1
    def _parse_config(self, config_file):
        """Provide consistent mechanism for pulling in configuration.

        Attempt to retain backward compatibility for
        existing implementations by grabbing port
        setting from CLI first.

        Take settings in the following order:

        1. CLI arguments if present
        2. config file
        3. OS environment variables (for ease of
           setting defaults if not present)
        4. current defaults if a setting is not present in any location

        Additionally provide similar configuration capabilities in between
        config file and environment variables.
        For consistency use the same variable name in the config file as
        in the os environment.
        For naming standards use all capitals and start with 'TABPY_'

        A config file that is missing or cannot be read/parsed is not fatal:
        the reason is logged and default settings are used.

        Raises RuntimeError when the transfer protocol settings are invalid
        or when a passwords file is configured but cannot be loaded.
        """
        self.settings = {}
        self.subdirectory = ""
        self.tabpy_state = None
        self.python_service = None
        self.credentials = {}

        pkg_path = os.path.dirname(tabpy.__file__)

        # Environment variables act as parser defaults, so an option missing
        # from the file can still be supplied through the environment.
        parser = configparser.ConfigParser(os.environ)
        logger.info(f"Parsing config file {config_file}")

        file_exists = False
        if os.path.isfile(config_file):
            try:
                with open(config_file, 'r') as f:
                    parser.read_string(f.read())
                    file_exists = True
            except Exception as e:
                # Best effort: keep going with defaults, but do not hide why
                # the file was rejected.
                logger.warning(f"Failed to read config file {config_file}: {e}")

        if not file_exists:
            logger.warning(
                f"Unable to open config file {config_file}, "
                "using default settings."
            )

        # (settings key, config/environment key, default value, parser method)
        settings_parameters = [
            (SettingsParameters.Port, ConfigParameters.TABPY_PORT, 9004, None),
            (SettingsParameters.ServerVersion, None, __version__, None),
            (SettingsParameters.EvaluateEnabled, ConfigParameters.TABPY_EVALUATE_ENABLE,
             True, parser.getboolean),
            (SettingsParameters.EvaluateTimeout, ConfigParameters.TABPY_EVALUATE_TIMEOUT,
             30, parser.getfloat),
            (SettingsParameters.UploadDir, ConfigParameters.TABPY_QUERY_OBJECT_PATH,
             os.path.join(pkg_path, "tmp", "query_objects"), None),
            (SettingsParameters.TransferProtocol, ConfigParameters.TABPY_TRANSFER_PROTOCOL,
             "http", None),
            (SettingsParameters.CertificateFile, ConfigParameters.TABPY_CERTIFICATE_FILE,
             None, None),
            (SettingsParameters.KeyFile, ConfigParameters.TABPY_KEY_FILE, None, None),
            (SettingsParameters.StateFilePath, ConfigParameters.TABPY_STATE_PATH,
             os.path.join(pkg_path, "tabpy_server"), None),
            (SettingsParameters.StaticPath, ConfigParameters.TABPY_STATIC_PATH,
             os.path.join(pkg_path, "tabpy_server", "static"), None),
            (ConfigParameters.TABPY_PWD_FILE, ConfigParameters.TABPY_PWD_FILE, None, None),
            (SettingsParameters.LogRequestContext, ConfigParameters.TABPY_LOG_DETAILS,
             "false", None),
            (SettingsParameters.MaxRequestSizeInMb, ConfigParameters.TABPY_MAX_REQUEST_SIZE_MB,
             100, None),
            (SettingsParameters.GzipEnabled, ConfigParameters.TABPY_GZIP_ENABLE,
             True, parser.getboolean),
            (SettingsParameters.ArrowEnabled, ConfigParameters.TABPY_ARROW_ENABLE,
             False, parser.getboolean),
            (SettingsParameters.ArrowFlightPort, ConfigParameters.TABPY_ARROWFLIGHT_PORT,
             13622, parser.getint),
        ]

        for setting, parameter, default_val, parse_function in settings_parameters:
            self._set_parameter(parser, setting, parameter, default_val, parse_function)

        if not os.path.exists(self.settings[SettingsParameters.UploadDir]):
            os.makedirs(self.settings[SettingsParameters.UploadDir])

        # set and validate transfer protocol
        self.settings[SettingsParameters.TransferProtocol] = self.settings[
            SettingsParameters.TransferProtocol
        ].lower()

        self._validate_transfer_protocol_settings()

        # if state.ini does not exist try and create it - remove
        # last dependence on batch/shell script
        self.settings[SettingsParameters.StateFilePath] = os.path.realpath(
            os.path.normpath(
                os.path.expanduser(self.settings[SettingsParameters.StateFilePath])
            )
        )
        state_config, self.tabpy_state = self._build_tabpy_state()

        self.python_service = PythonServiceHandler(PythonService())
        self.settings["compress_response"] = True
        self.settings[SettingsParameters.StaticPath] = os.path.abspath(
            self.settings[SettingsParameters.StaticPath]
        )
        logger.debug(
            f"Static pages folder set to "
            f'"{self.settings[SettingsParameters.StaticPath]}"'
        )

        # Set subdirectory from config if applicable
        if state_config.has_option("Service Info", "Subdirectory"):
            self.subdirectory = "/" + state_config.get("Service Info", "Subdirectory")

        # If passwords file specified load credentials
        if ConfigParameters.TABPY_PWD_FILE in self.settings:
            if not self._parse_pwd_file():
                msg = (
                    "Failed to read passwords file "
                    f"{self.settings[ConfigParameters.TABPY_PWD_FILE]}"
                )
                logger.critical(msg)
                raise RuntimeError(msg)
        else:
            logger.info(
                "Password file is not specified: " "Authentication is not enabled"
            )

        features = self._get_features()
        self.settings[SettingsParameters.ApiVersions] = {"v1": {"features": features}}

        # The setting arrives as text; anything other than "false"
        # (case-insensitive) enables request context logging.
        self.settings[SettingsParameters.LogRequestContext] = (
            self.settings[SettingsParameters.LogRequestContext].lower() != "false"
        )
        call_context_state = (
            "enabled"
            if self.settings[SettingsParameters.LogRequestContext]
            else "disabled"
        )
        logger.info(f"Call context logging is {call_context_state}")
410
411 1
    def _validate_transfer_protocol_settings(self):
        """
        Check that the transfer protocol settings are usable.

        "http" needs nothing further. "https" additionally requires the
        certificate and key file settings to be present and to point to
        existing files, after which the certificate itself is validated.

        Raises RuntimeError when the protocol is missing or unsupported, or
        when the certificate/key settings are absent or do not name files.
        """
        settings = self.settings

        if SettingsParameters.TransferProtocol not in settings:
            msg = "Missing transfer protocol information."
            logger.critical(msg)
            raise RuntimeError(msg)

        protocol = settings[SettingsParameters.TransferProtocol]
        if protocol == "http":
            return
        if protocol != "https":
            msg = f"Unsupported transfer protocol: {protocol}"
            logger.critical(msg)
            raise RuntimeError(msg)

        # HTTPS: both parameters must be configured...
        cert_is_set = SettingsParameters.CertificateFile in settings
        key_is_set = SettingsParameters.KeyFile in settings
        self._validate_cert_key_state(
            "The parameter(s) {} must be set.", cert_is_set, key_is_set
        )

        # ...and both must name existing files.
        cert_path = settings[SettingsParameters.CertificateFile]
        cert_is_file = os.path.isfile(cert_path)
        key_is_file = os.path.isfile(settings[SettingsParameters.KeyFile])
        self._validate_cert_key_state(
            "The parameter(s) {} must point to an existing file.",
            cert_is_file,
            key_is_file,
        )

        tabpy.tabpy_server.app.util.validate_cert(cert_path)
440
441 1
    @staticmethod
    def _validate_cert_key_state(msg, cert_valid, key_valid):
        """
        Raise if the certificate and/or key check failed.

        `msg` is a format string with one placeholder that receives the name
        of the offending parameter (or both names). Returns None when both
        flags are truthy; otherwise logs the error at critical level and
        raises RuntimeError with the same text.
        """
        if cert_valid and key_valid:
            return

        if not cert_valid and not key_valid:
            offending = (
                f"{ConfigParameters.TABPY_CERTIFICATE_FILE} and "
                f"{ConfigParameters.TABPY_KEY_FILE}"
            )
        elif not cert_valid:
            offending = ConfigParameters.TABPY_CERTIFICATE_FILE
        else:
            offending = ConfigParameters.TABPY_KEY_FILE

        err = "Error using HTTPS: " + msg.format(offending)
        logger.critical(err)
        raise RuntimeError(err)
459
460 1
    def _parse_pwd_file(self):
        """
        Load credentials from the configured passwords file.

        Stores the parsed credentials on `self.credentials` and returns the
        success flag reported by `parse_pwd_file`. A file that parses but
        yields no credentials is logged as an error and reported as failure.
        """
        pwd_file = self.settings[ConfigParameters.TABPY_PWD_FILE]
        parsed_ok, self.credentials = parse_pwd_file(pwd_file)

        if parsed_ok and len(self.credentials) == 0:
            logger.error("No credentials found")
            return False

        return parsed_ok
470
471 1
    def _get_features(self):
        """
        Describe the enabled features as a dictionary.

        Contains an "authentication" entry (basic-auth, required) only when a
        passwords file is configured, followed by the evaluate, gzip and
        arrow flags copied from the settings.
        """
        settings = self.settings
        features = {}

        # Check for auth
        if ConfigParameters.TABPY_PWD_FILE in settings:
            features["authentication"] = {
                "required": True,
                "methods": {"basic-auth": {}},
            }

        features.update(
            evaluate_enabled=settings[SettingsParameters.EvaluateEnabled],
            gzip_enabled=settings[SettingsParameters.GzipEnabled],
            arrow_enabled=settings[SettingsParameters.ArrowEnabled],
        )
        return features
485
486 1
    def _build_tabpy_state(self):
        """
        Load the state file and wrap it in a TabPyState.

        When ``state.ini`` is missing from the configured state directory it
        is first created by copying ``state.ini.template`` from the package's
        ``tabpy_server`` folder. Returns a ``(state_config, tabpy_state)``
        tuple: the object returned by `_get_state_from_file` and the
        `TabPyState` built from it and the current settings.
        """
        state_dir = self.settings[SettingsParameters.StateFilePath]
        state_ini = os.path.join(state_dir, "state.ini")

        if not os.path.isfile(state_ini):
            template = os.path.join(
                os.path.dirname(tabpy.__file__), "tabpy_server", "state.ini.template"
            )
            logger.debug(
                f"File {state_ini} not found, creating from "
                f"template {template}..."
            )
            shutil.copy(template, state_ini)

        logger.info(f"Loading state from state file {state_ini}")
        state_config = _get_state_from_file(state_dir)
        state = TabPyState(config=state_config, settings=self.settings)
        return state_config, state
503
504
505
# Override _read_body to allow content with size exceeding max_body_size
506
# This enables proper handling of 413 errors in base_handler
507 1
def _read_body_allow_max_size(self, code, headers, delegate):
508 1
    if "Content-Length" in headers:
509 1
        content_length = int(headers["Content-Length"])
510 1
        if content_length > self._max_body_size:
511
            return
512 1
    return self.original_read_body(code, headers, delegate)
513
514 1
HTTP1Connection.original_read_body = HTTP1Connection._read_body
515
HTTP1Connection._read_body = _read_body_allow_max_size
516