Test Failed
Pull Request — master (#595)
by
unknown
13:42
created

TabPyApp._validate_transfer_protocol_settings()   A

Complexity

Conditions 4

Size

Total Lines 29
Code Lines 22

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 16
CRAP Score 4.0032

Importance

Changes 0
Metric Value
eloc 22
dl 0
loc 29
ccs 16
cts 17
cp 0.9412
rs 9.352
c 0
b 0
f 0
cc 4
nop 1
crap 4.0032
1 1
import concurrent.futures
2 1
import configparser
3 1
import logging
4 1
import multiprocessing
5 1
import os
6 1
import shutil
7 1
import signal
8 1
import sys
9 1
import tabpy
10 1
from tabpy.tabpy import __version__
11 1
from tabpy.tabpy_server.app.app_parameters import ConfigParameters, SettingsParameters
12 1
from tabpy.tabpy_server.app.util import parse_pwd_file
13 1
from tabpy.tabpy_server.handlers.basic_auth_server_middleware_factory import BasicAuthServerMiddlewareFactory
14 1
from tabpy.tabpy_server.handlers.no_op_auth_handler import NoOpAuthHandler
15 1
from tabpy.tabpy_server.management.state import TabPyState
16 1
from tabpy.tabpy_server.management.util import _get_state_from_file
17 1
from tabpy.tabpy_server.psws.callbacks import init_model_evaluator, init_ps_server
18
from tabpy.tabpy_server.psws.python_service import PythonService, PythonServiceHandler
19
from tabpy.tabpy_server.handlers import (
20
    EndpointHandler,
21
    EndpointsHandler,
22
    EvaluationPlaneHandler,
23
    EvaluationPlaneDisabledHandler,
24
    QueryPlaneHandler,
25
    ServiceInfoHandler,
26
    StatusHandler,
27 1
    UploadDestinationHandler,
28
)
29
import tornado
30 1
import tabpy.tabpy_server.app.arrow_server as pa
31
import _thread
32
33 1
logger = logging.getLogger(__name__)
34
35
def _init_asyncio_patch():
36
    """
37
    Select compatible event loop for Tornado 5+.
38
    As of Python 3.8, the default event loop on Windows is `proactor`,
39
    however Tornado requires the old default "selector" event loop.
40
    As Tornado has decided to leave this to users to set, MkDocs needs
41 1
    to set it. See https://github.com/tornadoweb/tornado/issues/2608.
42
    """
43
    if sys.platform.startswith("win") and sys.version_info >= (3, 8):
44
        import asyncio
45
        try:
46
            from asyncio import WindowsSelectorEventLoopPolicy
47
        except ImportError:
48
            pass  # Can't assign a policy which doesn't exist.
49
        else:
50
            if not isinstance(asyncio.get_event_loop_policy(), WindowsSelectorEventLoopPolicy):
51
                asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
52 1
53
54
class TabPyApp:
55
    """
56
    TabPy application class for keeping context like settings, state, etc.
57 1
    """
58 1
59 1
    settings = {}
60 1
    subdirectory = ""
61 1
    tabpy_state = None
62
    python_service = None
63 1
    credentials = {}
64 1
    arrow_server = None
65 1
66
    def __init__(self, config_file):
67
        if config_file is None:
68
            config_file = os.path.join(
69 1
                os.path.dirname(__file__), os.path.pardir, "common", "default.conf"
70 1
            )
71 1
72 1
        if os.path.isfile(config_file):
73 1
            try:
74 1
                from logging import config
75
                config.fileConfig(config_file, disable_existing_loggers=False)
76 1
            except KeyError:
77
                logging.basicConfig(level=logging.DEBUG)
78 1
79
        self._parse_config(config_file)
80
81
    def _get_tls_certificates(self, config):
        """
        Read the TLS certificate and private key files named in ``config``.

        :param config: Mapping that holds the certificate path under
            ``SettingsParameters.CertificateFile`` and the key path under
            ``SettingsParameters.KeyFile``.
        :return: A one-element list containing a
            ``(certificate_bytes, private_key_bytes)`` tuple.
        """
        cert_path = config[SettingsParameters.CertificateFile]
        key_path = config[SettingsParameters.KeyFile]

        # Both files are read as raw bytes.
        with open(cert_path, "rb") as cert_stream:
            cert_chain_bytes = cert_stream.read()
        with open(key_path, "rb") as key_stream:
            private_key_bytes = key_stream.read()

        return [(cert_chain_bytes, private_key_bytes)]
91
    
92
    def _get_arrow_server(self, config):
        """
        Build the Arrow Flight server object from the given settings.

        The server location is ``<scheme>://localhost:<port>`` where the
        scheme is ``grpc+tls`` (with certificates loaded from ``config``) when
        the transfer protocol is "https" and ``grpc+tcp`` otherwise. When the
        "authentication" feature is present in the v1 API features, the
        password file is parsed and a basic-auth middleware is attached.

        :param config: Settings mapping (the same shape as ``self.settings``).
        :return: The ``pa.FlightServer`` instance (not yet started).
        """
        uses_tls = config[SettingsParameters.TransferProtocol] == "https"
        scheme = "grpc+tls" if uses_tls else "grpc+tcp"
        tls_certificates = self._get_tls_certificates(config) if uses_tls else None

        host = "localhost"
        port = config.get(SettingsParameters.ArrowFlightPort)
        location = f"{scheme}://{host}:{port}"

        enabled_features = config[SettingsParameters.ApiVersions]["v1"]["features"]
        auth_middleware = None
        if "authentication" in enabled_features:
            # Only the parsed credentials are used; the success flag is ignored.
            _, creds = parse_pwd_file(config[ConfigParameters.TABPY_PWD_FILE])
            auth_middleware = {"basic": BasicAuthServerMiddlewareFactory(creds)}

        return pa.FlightServer(
            host,
            location,
            tls_certificates=tls_certificates,
            verify_client=None,
            auth_handler=NoOpAuthHandler(),
            middleware=auth_middleware,
        )
116 1
117 1
    def run(self):
        """
        Start the web service (and the Arrow Flight server, if enabled).

        Creates the Tornado application, initializes the model evaluator,
        starts listening on the configured port (with TLS options when the
        transfer protocol is "https"), optionally launches the Arrow Flight
        server on a separate thread, and finally starts the Tornado IOLoop.

        :raises RuntimeError: If the configured transfer protocol is neither
            "http" nor "https".
        """
        application = self._create_tornado_web_app()
        max_request_size = (
            int(self.settings[SettingsParameters.MaxRequestSizeInMb]) * 1024 * 1024
        )
        logger.info(f"Setting max request size to {max_request_size} bytes")

        init_model_evaluator(self.settings, self.tabpy_state, self.python_service)

        protocol = self.settings[SettingsParameters.TransferProtocol]
        ssl_options = None
        if protocol == "https":
            ssl_options = {
                "certfile": self.settings[SettingsParameters.CertificateFile],
                "keyfile": self.settings[SettingsParameters.KeyFile],
            }
        elif protocol != "http":
            msg = f"Unsupported transfer protocol {protocol}."
            logger.critical(msg)
            raise RuntimeError(msg)

        # Extra keyword arguments for listen(); request decompression is only
        # switched on when gzip support is explicitly enabled.
        settings = {}
        if self.settings[SettingsParameters.GzipEnabled] is True:
            settings["decompress_request"] = True

        application.listen(
            self.settings[SettingsParameters.Port],
            ssl_options=ssl_options,
            max_buffer_size=max_request_size,
            max_body_size=max_request_size,
            **settings,
        )

        logger.info(
            "Web service listening on port "
            f"{str(self.settings[SettingsParameters.Port])}"
        )

        if self.settings[SettingsParameters.ArrowEnabled]:
            def start_pyarrow():
                # This body runs on the new thread, so an exception raised
                # here never reaches the try/except around start_new_thread
                # below. Catch and log it here, otherwise the failure would
                # bypass the application log entirely.
                try:
                    self.arrow_server = self._get_arrow_server(self.settings)
                    pa.start(self.arrow_server)
                except Exception as e:
                    logger.critical(f"Failed to start PyArrow server: {e}")

            try:
                _thread.start_new_thread(start_pyarrow, ())
            except Exception as e:
                # Covers only failure to spawn the thread itself.
                logger.critical(f"Failed to start PyArrow server: {e}")

        tornado.ioloop.IOLoop.instance().start()
166
167
    def _create_tornado_web_app(self):
        """
        Initialize the service and build the Tornado application.

        Runs ``init_ps_server`` to completion on the IOLoop, creates a thread
        pool with one worker per CPU for the ``/evaluate`` handler, registers
        all URL routes under ``self.subdirectory``, and installs a SIGINT
        handler together with a periodic callback (every 500 ms) that stops
        the IOLoop once the signal has been received.

        :return: The configured Tornado application instance.
        """
        class TabPyTornadoApp(tornado.web.Application):
            # Set by signal_handler, polled by try_exit.
            is_closing = False

            def signal_handler(self, signal, _):
                logger.critical(f"Exiting on signal {signal}...")
                self.is_closing = True

            def try_exit(self):
                if self.is_closing:
                    tornado.ioloop.IOLoop.instance().stop()
                    logger.info("Shutting down TabPy...")

        logger.info("Initializing TabPy...")
        tornado.ioloop.IOLoop.instance().run_sync(
            lambda: init_ps_server(self.settings, self.tabpy_state)
        )
        logger.info("Done initializing TabPy.")

        executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=multiprocessing.cpu_count()
        )

        # initialize Tornado application
        _init_asyncio_patch()
        application = TabPyTornadoApp(
            [
                (
                    self.subdirectory + r"/query/([^/]+)",
                    QueryPlaneHandler,
                    dict(app=self),
                ),
                (self.subdirectory + r"/status", StatusHandler, dict(app=self)),
                (self.subdirectory + r"/info", ServiceInfoHandler, dict(app=self)),
                (self.subdirectory + r"/endpoints", EndpointsHandler, dict(app=self)),
                (
                    self.subdirectory + r"/endpoints/([^/]+)?",
                    EndpointHandler,
                    dict(app=self),
                ),
                (
                    self.subdirectory + r"/evaluate",
                    # Serve the real handler only when evaluation is enabled.
                    EvaluationPlaneHandler if self.settings[SettingsParameters.EvaluateEnabled]
                    else EvaluationPlaneDisabledHandler,
                    dict(executor=executor, app=self),
                ),
                (
                    self.subdirectory + r"/configurations/endpoint_upload_destination",
                    UploadDestinationHandler,
                    dict(app=self),
                ),
                (
                    # Catch-all route: must stay last so it does not shadow
                    # the API routes above.
                    self.subdirectory + r"/(.*)",
                    tornado.web.StaticFileHandler,
                    dict(
                        path=self.settings[SettingsParameters.StaticPath],
                        default_filename="index.html",
                    ),
                ),
            ],
            debug=False,
            **self.settings,
        )

        # Register the SIGINT handler and the shutdown poller exactly once.
        # Registering them twice starts two periodic callbacks, so shutdown
        # would be handled and logged twice.
        signal.signal(signal.SIGINT, application.signal_handler)
        tornado.ioloop.PeriodicCallback(application.try_exit, 500).start()

        return application
238
239 1
    def _set_parameter(self, parser, settings_key, config_key, default_val, parse_function):
        """
        Store one setting in ``self.settings`` from the parser or a default.

        The value is taken from option ``config_key`` of the "TabPy" section
        when that option is available; otherwise ``default_val`` is used when
        it is not None; otherwise the setting is left unset. The outcome is
        logged at debug level in every case.

        :param parser: ConfigParser-like object queried for the "TabPy" section.
        :param settings_key: Key under which the value is stored in
            ``self.settings``.
        :param config_key: Option name to look up, or None to skip the lookup.
        :param default_val: Fallback value; None means "no default".
        :param parse_function: Callable taking ``(section, option)`` used to
            read the value; ``parser.get`` is used when this is None.
        """
        configured = (
            config_key is not None
            and parser.has_section("TabPy")
            and parser.has_option("TabPy", config_key)
        )

        if configured:
            reader = parser.get if parse_function is None else parse_function
            self.settings[settings_key] = reader("TabPy", config_key)
            logger.debug(
                f"Parameter {settings_key} set to "
                f'"{self.settings[settings_key]}" '
                "from config file or environment variable"
            )
            return

        if default_val is not None:
            self.settings[settings_key] = default_val
            logger.debug(
                f"Parameter {settings_key} set to "
                f'"{self.settings[settings_key]}" '
                "from default value"
            )
            return

        logger.debug(f"Parameter {settings_key} is not set")
268
269
    def _parse_config(self, config_file):
        """
        Load all settings and derived state from the configuration file.

        Resets the instance's settings, subdirectory, state, Python service
        and credentials, then for every known setting takes the value from:

        1. the "TabPy" section of the config file, with OS environment
           variables supplied to the parser as defaults (so the same
           variable name is used in the config file and in the environment);
        2. otherwise the built-in default, when one exists.

        After the settings are collected this method creates the upload
        directory if needed, validates the transfer protocol settings, builds
        the TabPy state, reads the subdirectory from the state config, loads
        credentials when a password file is configured, records the enabled
        features, and converts the request-context logging flag to a bool.

        :param config_file: Path to the configuration file. A missing or
            unreadable file is not fatal: a warning is logged and defaults
            (plus environment values) are used.
        :raises RuntimeError: If the transfer protocol settings are invalid
            or the configured password file cannot be read.
        """
        self.settings = {}
        self.subdirectory = ""
        self.tabpy_state = None
        self.python_service = None
        self.credentials = {}

        pkg_path = os.path.dirname(tabpy.__file__)

        # Environment variables act as parser defaults.
        parser = configparser.ConfigParser(os.environ)
        logger.info(f"Parsing config file {config_file}")

        file_exists = False
        if os.path.isfile(config_file):
            try:
                with open(config_file, 'r') as f:
                    parser.read_string(f.read())
                    file_exists = True
            except Exception as e:
                # Best effort: keep falling back to defaults, but record why
                # the file could not be used instead of dropping the reason.
                logger.warning(f"Failed to read config file {config_file}: {e}")

        if not file_exists:
            logger.warning(
                f"Unable to open config file {config_file}, "
                "using default settings."
            )

        # (settings key, config option name, default value, parse function)
        settings_parameters = [
            (SettingsParameters.Port, ConfigParameters.TABPY_PORT, 9004, None),
            (SettingsParameters.ServerVersion, None, __version__, None),
            (SettingsParameters.EvaluateEnabled, ConfigParameters.TABPY_EVALUATE_ENABLE,
             True, parser.getboolean),
            (SettingsParameters.EvaluateTimeout, ConfigParameters.TABPY_EVALUATE_TIMEOUT,
             30, parser.getfloat),
            (SettingsParameters.UploadDir, ConfigParameters.TABPY_QUERY_OBJECT_PATH,
             os.path.join(pkg_path, "tmp", "query_objects"), None),
            (SettingsParameters.TransferProtocol, ConfigParameters.TABPY_TRANSFER_PROTOCOL,
             "http", None),
            (SettingsParameters.CertificateFile, ConfigParameters.TABPY_CERTIFICATE_FILE,
             None, None),
            (SettingsParameters.KeyFile, ConfigParameters.TABPY_KEY_FILE, None, None),
            (SettingsParameters.StateFilePath, ConfigParameters.TABPY_STATE_PATH,
             os.path.join(pkg_path, "tabpy_server"), None),
            (SettingsParameters.StaticPath, ConfigParameters.TABPY_STATIC_PATH,
             os.path.join(pkg_path, "tabpy_server", "static"), None),
            (ConfigParameters.TABPY_PWD_FILE, ConfigParameters.TABPY_PWD_FILE, None, None),
            (SettingsParameters.LogRequestContext, ConfigParameters.TABPY_LOG_DETAILS,
             "false", None),
            (SettingsParameters.MaxRequestSizeInMb, ConfigParameters.TABPY_MAX_REQUEST_SIZE_MB,
             100, None),
            (SettingsParameters.GzipEnabled, ConfigParameters.TABPY_GZIP_ENABLE,
             True, parser.getboolean),
            (SettingsParameters.ArrowEnabled, ConfigParameters.TABPY_ARROW_ENABLE,
             False, parser.getboolean),
            (SettingsParameters.ArrowFlightPort, ConfigParameters.TABPY_ARROWFLIGHT_PORT,
             13622, parser.getint),
        ]

        for setting, parameter, default_val, parse_function in settings_parameters:
            self._set_parameter(parser, setting, parameter, default_val, parse_function)

        upload_dir = self.settings[SettingsParameters.UploadDir]
        if not os.path.exists(upload_dir):
            # exist_ok avoids a failure if the directory appears between the
            # existence check and the creation.
            os.makedirs(upload_dir, exist_ok=True)

        # set and validate transfer protocol
        self.settings[SettingsParameters.TransferProtocol] = self.settings[
            SettingsParameters.TransferProtocol
        ].lower()

        self._validate_transfer_protocol_settings()

        # if state.ini does not exist try and create it - remove
        # last dependence on batch/shell script
        self.settings[SettingsParameters.StateFilePath] = os.path.realpath(
            os.path.normpath(
                os.path.expanduser(self.settings[SettingsParameters.StateFilePath])
            )
        )
        state_config, self.tabpy_state = self._build_tabpy_state()

        self.python_service = PythonServiceHandler(PythonService())
        self.settings["compress_response"] = True
        self.settings[SettingsParameters.StaticPath] = os.path.abspath(
            self.settings[SettingsParameters.StaticPath]
        )
        logger.debug(
            f"Static pages folder set to "
            f'"{self.settings[SettingsParameters.StaticPath]}"'
        )

        # Set subdirectory from config if applicable
        if state_config.has_option("Service Info", "Subdirectory"):
            self.subdirectory = "/" + state_config.get("Service Info", "Subdirectory")

        # If passwords file specified load credentials
        if ConfigParameters.TABPY_PWD_FILE in self.settings:
            if not self._parse_pwd_file():
                msg = (
                    "Failed to read passwords file "
                    f"{self.settings[ConfigParameters.TABPY_PWD_FILE]}"
                )
                logger.critical(msg)
                raise RuntimeError(msg)
        else:
            logger.info(
                "Password file is not specified: " "Authentication is not enabled"
            )

        features = self._get_features()
        self.settings[SettingsParameters.ApiVersions] = {"v1": {"features": features}}

        # Any value other than "false" (case-insensitive) enables the flag.
        self.settings[SettingsParameters.LogRequestContext] = (
            self.settings[SettingsParameters.LogRequestContext].lower() != "false"
        )
        call_context_state = (
            "enabled"
            if self.settings[SettingsParameters.LogRequestContext]
            else "disabled"
        )
        logger.info(f"Call context logging is {call_context_state}")
407
408
    def _validate_transfer_protocol_settings(self):
        """
        Check that the transfer protocol and, for HTTPS, its files are usable.

        "http" needs nothing further. "https" additionally requires that the
        certificate and key file settings are present and point to existing
        files, after which the certificate is passed to ``validate_cert``.

        :raises RuntimeError: If the protocol setting is missing, is neither
            "http" nor "https", or the HTTPS certificate/key settings are
            missing or do not point to existing files.
        """
        settings = self.settings

        if SettingsParameters.TransferProtocol not in settings:
            msg = "Missing transfer protocol information."
            logger.critical(msg)
            raise RuntimeError(msg)

        protocol = settings[SettingsParameters.TransferProtocol]

        if protocol == "http":
            return

        if protocol != "https":
            msg = f"Unsupported transfer protocol: {protocol}"
            logger.critical(msg)
            raise RuntimeError(msg)

        # HTTPS: first make sure both settings exist...
        cert_is_set = SettingsParameters.CertificateFile in settings
        key_is_set = SettingsParameters.KeyFile in settings
        self._validate_cert_key_state(
            "The parameter(s) {} must be set.", cert_is_set, key_is_set
        )

        # ...then that both point at real files.
        cert = settings[SettingsParameters.CertificateFile]
        cert_is_file = os.path.isfile(cert)
        key_is_file = os.path.isfile(settings[SettingsParameters.KeyFile])
        self._validate_cert_key_state(
            "The parameter(s) {} must point to " "an existing file.",
            cert_is_file,
            key_is_file,
        )

        tabpy.tabpy_server.app.util.validate_cert(cert)
437
438
    @staticmethod
    def _validate_cert_key_state(msg, cert_valid, key_valid):
        """
        Raise a descriptive error when the certificate and/or key check failed.

        :param msg: Message template with one ``{}`` placeholder that receives
            the name(s) of the offending configuration parameter(s).
        :param cert_valid: Result of the check for the certificate file.
        :param key_valid: Result of the check for the key file.
        :raises RuntimeError: If either check failed; the message is prefixed
            with "Error using HTTPS: " and is also logged as critical.
        """
        if cert_valid and key_valid:
            return

        if not cert_valid and not key_valid:
            offending = (
                f"{ConfigParameters.TABPY_CERTIFICATE_FILE} and "
                f"{ConfigParameters.TABPY_KEY_FILE}"
            )
        elif not cert_valid:
            offending = ConfigParameters.TABPY_CERTIFICATE_FILE
        else:
            offending = ConfigParameters.TABPY_KEY_FILE

        err = "Error using HTTPS: " + msg.format(offending)
        logger.critical(err)
        raise RuntimeError(err)
456
457
    def _parse_pwd_file(self):
        """
        Load credentials from the configured password file.

        Stores the parsed credentials in ``self.credentials``. A file that
        parses successfully but yields no credentials is treated as a failure
        and logged as an error.

        :return: The success flag from ``parse_pwd_file``, or False when
            parsing succeeded but no credentials were found.
        """
        pwd_file = self.settings[ConfigParameters.TABPY_PWD_FILE]
        succeeded, self.credentials = parse_pwd_file(pwd_file)

        if not succeeded:
            return succeeded

        if len(self.credentials) == 0:
            logger.error("No credentials found")
            return False

        return succeeded
467
468
    def _get_features(self):
        """
        Describe the features enabled by the current settings.

        :return: Dict with the ``evaluate_enabled``, ``gzip_enabled`` and
            ``arrow_enabled`` flags copied from the settings, preceded by an
            ``authentication`` entry (basic-auth, required) when a password
            file is configured.
        """
        # Check for auth
        auth_section = {}
        if ConfigParameters.TABPY_PWD_FILE in self.settings:
            auth_section["authentication"] = {
                "required": True,
                "methods": {"basic-auth": {}},
            }

        return {
            **auth_section,
            "evaluate_enabled": self.settings[SettingsParameters.EvaluateEnabled],
            "gzip_enabled": self.settings[SettingsParameters.GzipEnabled],
            "arrow_enabled": self.settings[SettingsParameters.ArrowEnabled],
        }
482
483
    def _build_tabpy_state(self):
        """
        Load the service state, creating ``state.ini`` from the template first
        when it does not exist in the configured state directory.

        :return: Tuple ``(state_config, tabpy_state)`` where ``state_config``
            is the object returned by ``_get_state_from_file`` and
            ``tabpy_state`` is a ``TabPyState`` built from it and the current
            settings.
        """
        package_dir = os.path.dirname(tabpy.__file__)
        state_dir = self.settings[SettingsParameters.StateFilePath]
        state_file = os.path.join(state_dir, "state.ini")

        if not os.path.isfile(state_file):
            template_file = os.path.join(
                package_dir, "tabpy_server", "state.ini.template"
            )
            logger.debug(
                f"File {state_file} not found, creating from "
                f"template {template_file}..."
            )
            shutil.copy(template_file, state_file)

        logger.info(f"Loading state from state file {state_file}")
        state_config = _get_state_from_file(state_dir)
        state = TabPyState(config=state_config, settings=self.settings)
        return state_config, state
500