Passed
Pull Request — master (#595)
by
unknown
13:06
created

TabPyApp._get_features()   A

Complexity

Conditions 2

Size

Total Lines 13
Code Lines 9

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 7
CRAP Score 2

Importance

Changes 0
Metric Value
eloc 9
dl 0
loc 13
ccs 7
cts 7
cp 1
rs 9.95
c 0
b 0
f 0
cc 2
nop 1
crap 2
1 1
import concurrent.futures
2 1
import configparser
3 1
import logging
4 1
import multiprocessing
5 1
import os
6 1
import shutil
7 1
import signal
8 1
import sys
9 1
import tabpy
10 1
from tabpy.tabpy import __version__
11 1
from tabpy.tabpy_server.app.app_parameters import ConfigParameters, SettingsParameters
12 1
from tabpy.tabpy_server.app.util import parse_pwd_file
13 1
from tabpy.tabpy_server.handlers.basic_auth_server_middleware_factory import BasicAuthServerMiddlewareFactory
14 1
from tabpy.tabpy_server.handlers.no_op_auth_handler import NoOpAuthHandler
15 1
from tabpy.tabpy_server.management.state import TabPyState
16 1
from tabpy.tabpy_server.management.util import _get_state_from_file
17 1
from tabpy.tabpy_server.psws.callbacks import init_model_evaluator, init_ps_server
18 1
from tabpy.tabpy_server.psws.python_service import PythonService, PythonServiceHandler
19 1
from tabpy.tabpy_server.handlers import (
20
    EndpointHandler,
21
    EndpointsHandler,
22
    EvaluationPlaneHandler,
23
    EvaluationPlaneDisabledHandler,
24
    QueryPlaneHandler,
25
    ServiceInfoHandler,
26
    StatusHandler,
27
    UploadDestinationHandler,
28
)
29 1
import tornado
30 1
import tabpy.tabpy_server.app.arrow_server as pa
31 1
import _thread
32
33 1
logger = logging.getLogger(__name__)
34
35 1
def _init_asyncio_patch():
36
    """
37
    Select compatible event loop for Tornado 5+.
38
    As of Python 3.8, the default event loop on Windows is `proactor`,
39
    however Tornado requires the old default "selector" event loop.
40
    As Tornado has decided to leave this to users to set, MkDocs needs
41
    to set it. See https://github.com/tornadoweb/tornado/issues/2608.
42
    """
43 1
    if sys.platform.startswith("win") and sys.version_info >= (3, 8):
44
        import asyncio
45
        try:
46
            from asyncio import WindowsSelectorEventLoopPolicy
47
        except ImportError:
48
            pass  # Can't assign a policy which doesn't exist.
49
        else:
50
            if not isinstance(asyncio.get_event_loop_policy(), WindowsSelectorEventLoopPolicy):
51
                asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
52
53
54 1
class TabPyApp:
    """
    TabPy application context: holds the settings, state and services
    shared by the request handlers.
    """

    # Class-level defaults; _parse_config() rebinds most of these on the
    # instance during construction.

    # Server settings, filled in by _parse_config()/_set_parameter().
    settings = {}
    # Credentials loaded by _parse_pwd_file() when a password file is set.
    credentials = {}
    # URL prefix prepended to every route in _create_tornado_web_app().
    subdirectory = ""
    # TabPyState instance produced by _build_tabpy_state().
    tabpy_state = None
    # PythonServiceHandler created in _parse_config().
    python_service = None
    # Arrow Flight server; assigned in run() only when Arrow is enabled.
    arrow_server = None
65
66 1
    def __init__(self, config_file):
67 1
        if config_file is None:
68 1
            config_file = os.path.join(
69
                os.path.dirname(__file__), os.path.pardir, "common", "default.conf"
70
            )
71
72 1
        if os.path.isfile(config_file):
73 1
            try:
74 1
                from logging import config
75 1
                config.fileConfig(config_file, disable_existing_loggers=False)
76 1
            except KeyError:
77 1
                logging.basicConfig(level=logging.DEBUG)
78
79 1
        self._parse_config(config_file)
80
81 1
    def _get_tls_certificates(self, config):
        """
        Load the TLS certificate chain and private key named in ``config``.

        Returns a single-element list holding one
        ``(certificate_chain_bytes, private_key_bytes)`` tuple.
        """

        def _read_bytes(path):
            with open(path, "rb") as handle:
                return handle.read()

        cert_path = config[SettingsParameters.CertificateFile]
        key_path = config[SettingsParameters.KeyFile]
        cert_chain = _read_bytes(cert_path)
        private_key = _read_bytes(key_path)
        return [(cert_chain, private_key)]
91
    
92 1
    def _get_arrow_server(self, config):
        """
        Create the Arrow Flight server described by ``config``.

        The server is bound to localhost on the configured Arrow Flight
        port. It uses the ``grpc+tls`` scheme with the configured
        certificate when the transfer protocol is "https" and
        ``grpc+tcp`` otherwise. When the "authentication" feature is
        present, a basic-auth middleware built from the password file is
        attached.
        """
        use_tls = config[SettingsParameters.TransferProtocol] == "https"
        scheme = "grpc+tls" if use_tls else "grpc+tcp"
        tls_certificates = self._get_tls_certificates(config) if use_tls else None

        host = "localhost"
        port = config.get(SettingsParameters.ArrowFlightPort)
        location = "{}://{}:{}".format(scheme, host, port)

        features = config[SettingsParameters.ApiVersions]["v1"]["features"]
        auth_middleware = None
        if "authentication" in features:
            _, creds = parse_pwd_file(config[ConfigParameters.TABPY_PWD_FILE])
            auth_middleware = {"basic": BasicAuthServerMiddlewareFactory(creds)}

        return pa.FlightServer(
            host,
            location,
            tls_certificates=tls_certificates,
            verify_client=None,
            auth_handler=NoOpAuthHandler(),
            middleware=auth_middleware,
        )
116
117 1
    def run(self):
        """
        Start the TabPy web service and block until the IOLoop stops.

        Builds the Tornado application, initializes the model evaluator
        and starts listening on the configured port (with TLS options
        when the transfer protocol is "https"). When Arrow support is
        enabled, the Arrow Flight server is launched on a background
        thread; a failure there is logged and does not stop the web
        service.

        Raises
        ------
        RuntimeError
            If the configured transfer protocol is neither "http" nor
            "https".
        """
        application = self._create_tornado_web_app()
        max_request_size = (
            int(self.settings[SettingsParameters.MaxRequestSizeInMb]) * 1024 * 1024
        )
        logger.info(f"Setting max request size to {max_request_size} bytes")

        init_model_evaluator(self.settings, self.tabpy_state, self.python_service)

        protocol = self.settings[SettingsParameters.TransferProtocol]
        ssl_options = None
        if protocol == "https":
            ssl_options = {
                "certfile": self.settings[SettingsParameters.CertificateFile],
                "keyfile": self.settings[SettingsParameters.KeyFile],
            }
        elif protocol != "http":
            msg = f"Unsupported transfer protocol {protocol}."
            logger.critical(msg)
            raise RuntimeError(msg)

        settings = {}
        if self.settings[SettingsParameters.GzipEnabled] is True:
            settings["decompress_request"] = True

        application.listen(
            self.settings[SettingsParameters.Port],
            ssl_options=ssl_options,
            max_buffer_size=max_request_size,
            max_body_size=max_request_size,
            **settings,
        )

        logger.info(
            "Web service listening on port "
            f"{str(self.settings[SettingsParameters.Port])}"
        )

        if self.settings[SettingsParameters.ArrowEnabled]:
            def start_pyarrow():
                # This runs on the background thread. An exception raised
                # here never reaches the try/except around
                # start_new_thread below, so it is reported from inside
                # the thread through the module logger.
                try:
                    self.arrow_server = self._get_arrow_server(self.settings)
                    pa.start(self.arrow_server)
                except Exception:
                    logger.exception(
                        "Error: pyarrow server failed to start "
                        "or stopped unexpectedly"
                    )

            try:
                _thread.start_new_thread(start_pyarrow, ())
            except Exception:
                # Best effort: the web service keeps running without Arrow.
                logger.exception("Error: unable to start pyarrow server")

        tornado.ioloop.IOLoop.instance().start()
167
168 1
    def _create_tornado_web_app(self):
        """
        Build the Tornado application that serves all TabPy routes.

        Initializes the query-plane server, creates the thread pool used
        by the evaluation handler, registers every route under
        ``self.subdirectory`` and installs SIGINT handling: the signal
        handler only sets a flag, and a periodic callback polls that flag
        every 500 ms to stop the IOLoop.

        Returns
        -------
        tornado.web.Application
            The configured application; the caller is responsible for
            calling ``listen`` and starting the IOLoop.
        """

        class TabPyTornadoApp(tornado.web.Application):
            is_closing = False

            def signal_handler(self, signum, _):
                # Only record the request here; try_exit() acts on it.
                logger.critical(f"Exiting on signal {signum}...")
                self.is_closing = True

            def try_exit(self):
                if self.is_closing:
                    tornado.ioloop.IOLoop.instance().stop()
                    logger.info("Shutting down TabPy...")

        logger.info("Initializing TabPy...")
        tornado.ioloop.IOLoop.instance().run_sync(
            lambda: init_ps_server(self.settings, self.tabpy_state)
        )
        logger.info("Done initializing TabPy.")

        executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=multiprocessing.cpu_count()
        )

        # initialize Tornado application
        _init_asyncio_patch()
        application = TabPyTornadoApp(
            [
                (
                    self.subdirectory + r"/query/([^/]+)",
                    QueryPlaneHandler,
                    dict(app=self),
                ),
                (self.subdirectory + r"/status", StatusHandler, dict(app=self)),
                (self.subdirectory + r"/info", ServiceInfoHandler, dict(app=self)),
                (self.subdirectory + r"/endpoints", EndpointsHandler, dict(app=self)),
                (
                    self.subdirectory + r"/endpoints/([^/]+)?",
                    EndpointHandler,
                    dict(app=self),
                ),
                (
                    self.subdirectory + r"/evaluate",
                    # Serve the "disabled" handler when evaluation is off.
                    EvaluationPlaneHandler
                    if self.settings[SettingsParameters.EvaluateEnabled]
                    else EvaluationPlaneDisabledHandler,
                    dict(executor=executor, app=self),
                ),
                (
                    self.subdirectory + r"/configurations/endpoint_upload_destination",
                    UploadDestinationHandler,
                    dict(app=self),
                ),
                (
                    # Catch-all route: must stay last so it does not shadow
                    # the API routes above.
                    self.subdirectory + r"/(.*)",
                    tornado.web.StaticFileHandler,
                    dict(
                        path=self.settings[SettingsParameters.StaticPath],
                        default_filename="index.html",
                    ),
                ),
            ],
            debug=False,
            **self.settings,
        )

        # Register the SIGINT handler and the exit poller exactly once; a
        # second PeriodicCallback would poll try_exit() twice per interval.
        signal.signal(signal.SIGINT, application.signal_handler)
        tornado.ioloop.PeriodicCallback(application.try_exit, 500).start()

        return application
239
240 1
    def _set_parameter(self, parser, settings_key, config_key, default_val, parse_function):
        """
        Resolve one setting and store it in ``self.settings``.

        The value is taken from the "TabPy" section of ``parser`` when
        ``config_key`` is given and present there, read with
        ``parse_function`` (``parser.get`` when None). Otherwise
        ``default_val`` is used if it is not None. If neither applies the
        setting is left unset. Each outcome is logged at debug level.
        """
        found_in_config = (
            config_key is not None
            and parser.has_section("TabPy")
            and parser.has_option("TabPy", config_key)
        )

        if found_in_config:
            reader = parser.get if parse_function is None else parse_function
            self.settings[settings_key] = reader("TabPy", config_key)
            logger.debug(
                f"Parameter {settings_key} set to "
                f'"{self.settings[settings_key]}" '
                "from config file or environment variable"
            )
            return

        if default_val is not None:
            self.settings[settings_key] = default_val
            logger.debug(
                f"Parameter {settings_key} set to "
                f'"{self.settings[settings_key]}" '
                "from default value"
            )
            return

        logger.debug(f"Parameter {settings_key} is not set")
269
270 1
    def _parse_config(self, config_file):
        """Provide consistent mechanism for pulling in configuration.

        Attempt to retain backward compatibility for
        existing implementations by grabbing port
        setting from CLI first.

        Take settings in the following order:

        1. CLI arguments if present
        2. config file
        3. OS environment variables (for ease of
           setting defaults if not present)
        4. current defaults if a setting is not present in any location

        Additionally provide similar configuration capabilities in between
        config file and environment variables.
        For consistency use the same variable name in the config file as
        in the os environment.
        For naming standards use all capitals and start with 'TABPY_'

        A config file that is missing or cannot be read or parsed is not
        fatal: the reason is logged and the defaults are used.

        Raises:
            RuntimeError: if the transfer protocol settings are invalid or
                the configured passwords file cannot be loaded.
        """
        self.settings = {}
        self.subdirectory = ""
        self.tabpy_state = None
        self.python_service = None
        self.credentials = {}

        pkg_path = os.path.dirname(tabpy.__file__)

        # Environment variables are passed as parser defaults, so they
        # supply a value when the config file does not.
        parser = configparser.ConfigParser(os.environ)
        logger.info(f"Parsing config file {config_file}")

        file_exists = False
        if os.path.isfile(config_file):
            try:
                with open(config_file, 'r') as f:
                    parser.read_string(f.read())
                    file_exists = True
            except Exception as e:
                # Best effort: keep going with defaults, but record why
                # the file was rejected instead of dropping the error.
                logger.warning(f"Failed to read config file {config_file}: {e}")

        if not file_exists:
            logger.warning(
                f"Unable to open config file {config_file}, "
                "using default settings."
            )

        # (settings key, config/environment key, default, parse function)
        settings_parameters = [
            (SettingsParameters.Port, ConfigParameters.TABPY_PORT, 9004, None),
            (SettingsParameters.ServerVersion, None, __version__, None),
            (SettingsParameters.EvaluateEnabled, ConfigParameters.TABPY_EVALUATE_ENABLE,
             True, parser.getboolean),
            (SettingsParameters.EvaluateTimeout, ConfigParameters.TABPY_EVALUATE_TIMEOUT,
             30, parser.getfloat),
            (SettingsParameters.UploadDir, ConfigParameters.TABPY_QUERY_OBJECT_PATH,
             os.path.join(pkg_path, "tmp", "query_objects"), None),
            (SettingsParameters.TransferProtocol, ConfigParameters.TABPY_TRANSFER_PROTOCOL,
             "http", None),
            (SettingsParameters.CertificateFile, ConfigParameters.TABPY_CERTIFICATE_FILE,
             None, None),
            (SettingsParameters.KeyFile, ConfigParameters.TABPY_KEY_FILE, None, None),
            (SettingsParameters.StateFilePath, ConfigParameters.TABPY_STATE_PATH,
             os.path.join(pkg_path, "tabpy_server"), None),
            (SettingsParameters.StaticPath, ConfigParameters.TABPY_STATIC_PATH,
             os.path.join(pkg_path, "tabpy_server", "static"), None),
            (ConfigParameters.TABPY_PWD_FILE, ConfigParameters.TABPY_PWD_FILE, None, None),
            (SettingsParameters.LogRequestContext, ConfigParameters.TABPY_LOG_DETAILS,
             "false", None),
            (SettingsParameters.MaxRequestSizeInMb, ConfigParameters.TABPY_MAX_REQUEST_SIZE_MB,
             100, None),
            (SettingsParameters.GzipEnabled, ConfigParameters.TABPY_GZIP_ENABLE,
             True, parser.getboolean),
            (SettingsParameters.ArrowEnabled, ConfigParameters.TABPY_ARROW_ENABLE,
             False, parser.getboolean),
            (SettingsParameters.ArrowFlightPort, ConfigParameters.TABPY_ARROWFLIGHT_PORT,
             13622, parser.getint),
        ]

        for setting, parameter, default_val, parse_function in settings_parameters:
            self._set_parameter(parser, setting, parameter, default_val, parse_function)

        if not os.path.exists(self.settings[SettingsParameters.UploadDir]):
            os.makedirs(self.settings[SettingsParameters.UploadDir])

        # set and validate transfer protocol
        self.settings[SettingsParameters.TransferProtocol] = self.settings[
            SettingsParameters.TransferProtocol
        ].lower()

        self._validate_transfer_protocol_settings()

        # if state.ini does not exist try and create it - remove
        # last dependence on batch/shell script
        self.settings[SettingsParameters.StateFilePath] = os.path.realpath(
            os.path.normpath(
                os.path.expanduser(self.settings[SettingsParameters.StateFilePath])
            )
        )
        state_config, self.tabpy_state = self._build_tabpy_state()

        self.python_service = PythonServiceHandler(PythonService())
        self.settings["compress_response"] = True
        self.settings[SettingsParameters.StaticPath] = os.path.abspath(
            self.settings[SettingsParameters.StaticPath]
        )
        logger.debug(
            f"Static pages folder set to "
            f'"{self.settings[SettingsParameters.StaticPath]}"'
        )

        # Set subdirectory from config if applicable
        if state_config.has_option("Service Info", "Subdirectory"):
            self.subdirectory = "/" + state_config.get("Service Info", "Subdirectory")

        # If passwords file specified load credentials
        if ConfigParameters.TABPY_PWD_FILE in self.settings:
            if not self._parse_pwd_file():
                msg = (
                    "Failed to read passwords file "
                    f"{self.settings[ConfigParameters.TABPY_PWD_FILE]}"
                )
                logger.critical(msg)
                raise RuntimeError(msg)
        else:
            logger.info(
                "Password file is not specified: " "Authentication is not enabled"
            )

        features = self._get_features()
        self.settings[SettingsParameters.ApiVersions] = {"v1": {"features": features}}

        # The setting arrives as text; anything other than "false"
        # (case-insensitive) enables request-context logging.
        self.settings[SettingsParameters.LogRequestContext] = (
            self.settings[SettingsParameters.LogRequestContext].lower() != "false"
        )
        call_context_state = (
            "enabled"
            if self.settings[SettingsParameters.LogRequestContext]
            else "disabled"
        )
        logger.info(f"Call context logging is {call_context_state}")
408
409 1
    def _validate_transfer_protocol_settings(self):
        """
        Check that the transfer protocol settings are usable.

        "http" needs nothing further. "https" additionally requires the
        certificate and key file settings to be present, to point to
        existing files, and the certificate to pass ``validate_cert``.

        Raises:
            RuntimeError: if the protocol is missing or unsupported, or
                the certificate/key settings are missing or do not point
                to existing files.
        """
        settings = self.settings

        if SettingsParameters.TransferProtocol not in settings:
            msg = "Missing transfer protocol information."
            logger.critical(msg)
            raise RuntimeError(msg)

        protocol = settings[SettingsParameters.TransferProtocol]

        if protocol not in ("http", "https"):
            msg = f"Unsupported transfer protocol: {protocol}"
            logger.critical(msg)
            raise RuntimeError(msg)

        if protocol == "http":
            # Plain HTTP needs no certificate checks.
            return

        has_cert = SettingsParameters.CertificateFile in settings
        has_key = SettingsParameters.KeyFile in settings
        self._validate_cert_key_state(
            "The parameter(s) {} must be set.", has_cert, has_key
        )

        cert = settings[SettingsParameters.CertificateFile]
        cert_is_file = os.path.isfile(cert)
        key_is_file = os.path.isfile(settings[SettingsParameters.KeyFile])
        self._validate_cert_key_state(
            "The parameter(s) {} must point to an existing file.",
            cert_is_file,
            key_is_file,
        )

        tabpy.tabpy_server.app.util.validate_cert(cert)
438
439 1
    @staticmethod
    def _validate_cert_key_state(msg, cert_valid, key_valid):
        """
        Raise if the certificate and/or key check failed.

        ``msg`` is a format string with one ``{}`` placeholder. It is
        filled with the name(s) of the offending config parameter(s) and
        prefixed with "Error using HTTPS: ".

        Raises:
            RuntimeError: when ``cert_valid`` or ``key_valid`` is falsy.
        """
        if cert_valid and key_valid:
            return

        if not cert_valid and not key_valid:
            offending = (
                f"{ConfigParameters.TABPY_CERTIFICATE_FILE} and "
                f"{ConfigParameters.TABPY_KEY_FILE}"
            )
        elif not cert_valid:
            offending = ConfigParameters.TABPY_CERTIFICATE_FILE
        else:
            offending = ConfigParameters.TABPY_KEY_FILE

        err = "Error using HTTPS: " + msg.format(offending)
        logger.critical(err)
        raise RuntimeError(err)
457
458 1
    def _parse_pwd_file(self):
        """
        Load credentials from the configured passwords file.

        Stores the parsed credentials in ``self.credentials`` and returns
        the success flag reported by ``parse_pwd_file``. A file that
        parses successfully but yields no credentials counts as a failure.
        """
        pwd_file = self.settings[ConfigParameters.TABPY_PWD_FILE]
        succeeded, self.credentials = parse_pwd_file(pwd_file)

        if not succeeded:
            return succeeded

        if len(self.credentials) == 0:
            logger.error("No credentials found")
            return False

        return succeeded
468
469 1
    def _get_features(self):
        """
        Describe the features this server instance exposes.

        The returned dict always contains "evaluate_enabled" and
        "gzip_enabled". It contains "authentication" (basic auth,
        required) only when a passwords file is configured.
        """
        auth_feature = {}

        # Check for auth
        if ConfigParameters.TABPY_PWD_FILE in self.settings:
            auth_feature = {
                "authentication": {
                    "required": True,
                    "methods": {"basic-auth": {}},
                }
            }

        return {
            **auth_feature,
            "evaluate_enabled": self.settings[SettingsParameters.EvaluateEnabled],
            "gzip_enabled": self.settings[SettingsParameters.GzipEnabled],
        }
482
483 1
    def _build_tabpy_state(self):
        """
        Load the TabPy state, creating ``state.ini`` from the packaged
        template when it does not exist yet.

        Returns a ``(state_config, tabpy_state)`` tuple: the value
        returned by ``_get_state_from_file`` and the ``TabPyState``
        wrapping it together with the current settings.
        """
        state_dir = self.settings[SettingsParameters.StateFilePath]
        state_file = os.path.join(state_dir, "state.ini")

        if not os.path.isfile(state_file):
            template_file = os.path.join(
                os.path.dirname(tabpy.__file__), "tabpy_server", "state.ini.template"
            )
            logger.debug(
                f"File {state_file} not found, creating from "
                f"template {template_file}..."
            )
            shutil.copy(template_file, state_file)

        logger.info(f"Loading state from state file {state_file}")
        state_config = _get_state_from_file(state_dir)
        return state_config, TabPyState(config=state_config, settings=self.settings)
500