TabPyApp._initialize_ssl_context()   A
last analyzed

Complexity

Conditions 2

Size

Total Lines 17
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 1
CRAP Score 4.916

Importance

Changes 0
Metric Value
eloc 12
dl 0
loc 17
ccs 1
cts 10
cp 0.1
rs 9.8
c 0
b 0
f 0
cc 2
nop 1
crap 4.916
1 1
import concurrent.futures
2 1
import configparser
3 1
import logging
4 1
import multiprocessing
5 1
import os
6 1
import shutil
7 1
import signal
8 1
import ssl
9 1
import sys
10 1
import _thread
11
12 1
import tornado
13 1
from tornado.http1connection import HTTP1Connection
14
15 1
import tabpy
16 1
import tabpy.tabpy_server.app.arrow_server as pa
17 1
from tabpy.tabpy import __version__
18 1
from tabpy.tabpy_server.app.app_parameters import ConfigParameters, SettingsParameters
19 1
from tabpy.tabpy_server.app.util import parse_pwd_file
20 1
from tabpy.tabpy_server.handlers.basic_auth_server_middleware_factory import BasicAuthServerMiddlewareFactory
21 1
from tabpy.tabpy_server.handlers.no_op_auth_handler import NoOpAuthHandler
22 1
from tabpy.tabpy_server.management.state import TabPyState
23 1
from tabpy.tabpy_server.management.util import _get_state_from_file
24 1
from tabpy.tabpy_server.psws.callbacks import init_model_evaluator, init_ps_server
25 1
from tabpy.tabpy_server.psws.python_service import PythonService, PythonServiceHandler
26 1
from tabpy.tabpy_server.handlers import (
27
    EndpointHandler,
28
    EndpointsHandler,
29
    EvaluationPlaneHandler,
30
    EvaluationPlaneDisabledHandler,
31
    QueryPlaneHandler,
32
    ServiceInfoHandler,
33
    StatusHandler,
34
    UploadDestinationHandler,
35
)
36
37 1
# Module-level logger, created at import time. TabPyApp.__init__ loads the
# logging config with disable_existing_loggers=False, so this logger is not
# disabled by that call.
logger = logging.getLogger(__name__)
38
39 1
def _init_asyncio_patch():
40
    """
41
    Select compatible event loop for Tornado 5+.
42
    As of Python 3.8, the default event loop on Windows is `proactor`,
43
    however Tornado requires the old default "selector" event loop.
44
    As Tornado has decided to leave this to users to set, MkDocs needs
45
    to set it. See https://github.com/tornadoweb/tornado/issues/2608.
46
    """
47 1
    if sys.platform.startswith("win") and sys.version_info >= (3, 8):
48
        import asyncio
49
        try:
50
            from asyncio import WindowsSelectorEventLoopPolicy
51
        except ImportError:
52
            pass  # Can't assign a policy which doesn't exist.
53
        else:
54
            if not isinstance(asyncio.get_event_loop_policy(), WindowsSelectorEventLoopPolicy):
55
                asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
56
57
58 1
class TabPyApp:
    """
    TabPy application class for keeping context like settings, state, etc.

    Construction loads logging configuration and parses the TabPy config file
    (plus environment variables) into ``self.settings``. :meth:`run` then
    starts the Tornado web service and, if enabled, the Arrow Flight server.
    """

    # Class-level defaults. _parse_config() rebinds settings, subdirectory,
    # tabpy_state, python_service and credentials on the instance, so the
    # mutable containers below are not shared between instances in practice.
    settings = {}
    subdirectory = ""
    tabpy_state = None
    python_service = None
    credentials = {}
    arrow_server = None
    max_request_size = None

    def __init__(self, config_file, disable_auth_warning=True):
        """
        Load logging configuration and application settings.

        Args:
            config_file: Path to a TabPy config file, or None to use the
                bundled ``common/default.conf`` next to this package.
            disable_auth_warning: When ``== True``, starting without a password
                file only logs a message; otherwise the user is prompted to
                confirm running without authentication.

        Raises:
            RuntimeError: Propagated from configuration validation.
        """
        self.disable_auth_warning = disable_auth_warning
        if config_file is None:
            config_file = os.path.join(
                os.path.dirname(__file__), os.path.pardir, "common", "default.conf"
            )

        if os.path.isfile(config_file):
            try:
                from logging import config as logging_config

                logging_config.fileConfig(config_file, disable_existing_loggers=False)
            except (KeyError, RuntimeError, configparser.Error):
                # The file has no usable logging sections. Depending on the
                # Python version fileConfig() reports this as KeyError
                # (missing sections), RuntimeError (empty/invalid file, 3.12+)
                # or a configparser error; fall back to basic logging.
                logging.basicConfig(level=logging.DEBUG)

        self._parse_config(config_file)

    def _initialize_ssl_context(self):
        """
        Build the server-side SSL context used when serving HTTPS.

        Loads the configured certificate chain and private key and applies
        the configured minimum TLS version. Values that are not members of
        :class:`ssl.TLSVersion` are logged and replaced with ``TLSv1_2``.

        Returns:
            ssl.SSLContext: Context passed as ``ssl_options`` to ``listen``.
        """
        ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
        ssl_context.load_cert_chain(
            certfile=self.settings[SettingsParameters.CertificateFile],
            keyfile=self.settings[SettingsParameters.KeyFile],
        )

        min_tls = self.settings[SettingsParameters.MinimumTLSVersion]
        # Check enum membership rather than hasattr(): hasattr() also accepts
        # non-member class attributes (e.g. "mro" or dunder names), which
        # would then fail the ssl.TLSVersion[...] lookup with a KeyError.
        if min_tls not in ssl.TLSVersion.__members__:
            logger.warning(f"Unrecognized value for TABPY_MINIMUM_TLS_VERSION: {min_tls}")
            min_tls = "TLSv1_2"

        logger.info(f"Setting minimum TLS version to {min_tls}")
        ssl_context.minimum_version = ssl.TLSVersion[min_tls]
        return ssl_context

    def _get_tls_certificates(self, config):
        """
        Read the certificate chain and private key for the Arrow Flight server.

        Args:
            config: Settings dict providing the certificate and key file paths.

        Returns:
            list: A single ``(cert_chain_bytes, private_key_bytes)`` tuple.
        """
        cert_path = config[SettingsParameters.CertificateFile]
        key_path = config[SettingsParameters.KeyFile]
        with open(cert_path, "rb") as cert_file:
            tls_cert_chain = cert_file.read()
        with open(key_path, "rb") as key_file:
            tls_private_key = key_file.read()
        return [(tls_cert_chain, tls_private_key)]

    def _get_arrow_server(self, config):
        """
        Create (but do not start) the Arrow Flight server.

        Uses ``grpc+tls`` when the transfer protocol is "https" and
        ``grpc+tcp`` otherwise, bound to all interfaces on the configured
        Arrow Flight port. When the "authentication" feature is enabled,
        basic-auth middleware is built from the password file.

        Args:
            config: The application settings dict.

        Returns:
            The ``pa.FlightServer`` instance.
        """
        verify_client = None
        tls_certificates = None
        scheme = "grpc+tcp"
        if config[SettingsParameters.TransferProtocol] == "https":
            scheme = "grpc+tls"
            tls_certificates = self._get_tls_certificates(config)

        host = "0.0.0.0"
        port = config.get(SettingsParameters.ArrowFlightPort)
        location = "{}://{}:{}".format(scheme, host, port)

        auth_middleware = None
        if "authentication" in config[SettingsParameters.ApiVersions]["v1"]["features"]:
            _, creds = parse_pwd_file(config[ConfigParameters.TABPY_PWD_FILE])
            auth_middleware = {"basic": BasicAuthServerMiddlewareFactory(creds)}

        return pa.FlightServer(
            host,
            location,
            tls_certificates=tls_certificates,
            verify_client=verify_client,
            auth_handler=NoOpAuthHandler(),
            middleware=auth_middleware,
        )

    def run(self):
        """
        Start the web service (and the Arrow Flight server if enabled).

        Blocks in the Tornado IOLoop until it is stopped.

        Raises:
            RuntimeError: If the transfer protocol is neither "http" nor "https".
        """
        application = self._create_tornado_web_app()

        init_model_evaluator(self.settings, self.tabpy_state, self.python_service)

        protocol = self.settings[SettingsParameters.TransferProtocol]
        ssl_options = None
        if protocol == "https":
            ssl_options = self._initialize_ssl_context()
        elif protocol != "http":
            msg = f"Unsupported transfer protocol {protocol}."
            logger.critical(msg)
            raise RuntimeError(msg)

        settings = {}
        if self.settings[SettingsParameters.GzipEnabled] is True:
            settings["decompress_request"] = True

        application.listen(
            self.settings[SettingsParameters.Port],
            ssl_options=ssl_options,
            max_buffer_size=self.max_request_size,
            max_body_size=self.max_request_size,
            **settings,
        )

        logger.info(
            "Web service listening on port "
            f"{str(self.settings[SettingsParameters.Port])}"
        )

        if self.settings[SettingsParameters.ArrowEnabled]:

            def start_pyarrow():
                # An exception raised inside the new thread never reaches the
                # try/except around start_new_thread(), so log it here.
                try:
                    self.arrow_server = self._get_arrow_server(self.settings)
                    pa.start(self.arrow_server)
                except Exception as e:
                    logger.critical(f"Failed to start PyArrow server: {e}")

            try:
                _thread.start_new_thread(start_pyarrow, ())
            except Exception as e:
                logger.critical(f"Failed to start PyArrow server: {e}")

        tornado.ioloop.IOLoop.instance().start()

    def _create_tornado_web_app(self):
        """
        Initialize the TabPy server state and build the Tornado application.

        Also installs a SIGINT handler plus a 500 ms periodic callback that
        stops the IOLoop once the signal has been received.

        Returns:
            The Tornado application (not yet listening).
        """

        class TabPyTornadoApp(tornado.web.Application):
            is_closing = False

            def signal_handler(self, signum, _):
                # Only set a flag here; the IOLoop is stopped from try_exit(),
                # which the PeriodicCallback below runs on the IOLoop.
                logger.critical(f"Exiting on signal {signum}...")
                self.is_closing = True

            def try_exit(self):
                if self.is_closing:
                    tornado.ioloop.IOLoop.instance().stop()
                    logger.info("Shutting down TabPy...")

        # Install the Tornado-compatible event loop policy before the IOLoop
        # is first used by run_sync() below; a policy installed afterwards
        # would not apply to the loop that run_sync() already created.
        _init_asyncio_patch()

        logger.info("Initializing TabPy...")
        tornado.ioloop.IOLoop.instance().run_sync(
            lambda: init_ps_server(self.settings, self.tabpy_state)
        )
        logger.info("Done initializing TabPy.")

        executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=multiprocessing.cpu_count()
        )

        evaluate_handler = (
            EvaluationPlaneHandler
            if self.settings[SettingsParameters.EvaluateEnabled]
            else EvaluationPlaneDisabledHandler
        )

        # initialize Tornado application
        application = TabPyTornadoApp(
            [
                (
                    self.subdirectory + r"/query/([^/]+)",
                    QueryPlaneHandler,
                    dict(app=self),
                ),
                (self.subdirectory + r"/status", StatusHandler, dict(app=self)),
                (self.subdirectory + r"/info", ServiceInfoHandler, dict(app=self)),
                (self.subdirectory + r"/endpoints", EndpointsHandler, dict(app=self)),
                (
                    self.subdirectory + r"/endpoints/([^/]+)?",
                    EndpointHandler,
                    dict(app=self),
                ),
                (
                    self.subdirectory + r"/evaluate",
                    evaluate_handler,
                    dict(executor=executor, app=self),
                ),
                (
                    self.subdirectory + r"/configurations/endpoint_upload_destination",
                    UploadDestinationHandler,
                    dict(app=self),
                ),
                (
                    self.subdirectory + r"/(.*)",
                    tornado.web.StaticFileHandler,
                    dict(
                        path=self.settings[SettingsParameters.StaticPath],
                        default_filename="index.html",
                    ),
                ),
            ],
            debug=False,
            **self.settings,
        )

        # Register the SIGINT hook and the shutdown poller exactly once.
        signal.signal(signal.SIGINT, application.signal_handler)
        tornado.ioloop.PeriodicCallback(application.try_exit, 500).start()

        return application

    def _set_parameter(self, parser, settings_key, config_key, default_val, parse_function):
        """
        Store one setting from the config parser, or fall back to a default.

        Args:
            parser: ConfigParser holding the config file, with environment
                variables as its defaults.
            settings_key: Key under which the value is stored in ``self.settings``.
            config_key: Option name looked up in the ``[TabPy]`` section, or
                None for settings that cannot be configured.
            default_val: Value used when the option is absent; None leaves
                the setting unset.
            parse_function: Callable ``(section, option)`` converting the raw
                value (e.g. ``parser.getboolean``); None means ``parser.get``.
        """
        # Environment variables are parser defaults, so they are only seen
        # here when the config has a [TabPy] section.
        if (
            config_key is not None
            and parser.has_section("TabPy")
            and parser.has_option("TabPy", config_key)
        ):
            getter = parser.get if parse_function is None else parse_function
            self.settings[settings_key] = getter("TabPy", config_key)
            logger.debug(
                f"Parameter {settings_key} set to "
                f'"{self.settings[settings_key]}" '
                "from config file or environment variable"
            )
            return

        if default_val is not None:
            self.settings[settings_key] = default_val
            logger.debug(
                f"Parameter {settings_key} set to "
                f'"{self.settings[settings_key]}" '
                "from default value"
            )
            return

        logger.debug(f"Parameter {settings_key} is not set")

    def _parse_config(self, config_file):
        """Provide consistent mechanism for pulling in configuration.

        Attempt to retain backward compatibility for
        existing implementations by grabbing port
        setting from CLI first.

        Take settings in the following order:

        1. CLI arguments if present
        2. config file
        3. OS environment variables (for ease of
           setting defaults if not present)
        4. current defaults if a setting is not present in any location

        Environment variables are used as ConfigParser defaults, so a value in
        the config file's ``[TabPy]`` section overrides the environment.
        For consistency use the same variable name in the config file as
        in the os environment.
        For naming standards use all capitals and start with 'TABPY_'

        Raises:
            RuntimeError: On invalid transfer protocol / certificate settings
                or an unreadable or empty password file.
        """
        self.settings = {}
        self.subdirectory = ""
        self.tabpy_state = None
        self.python_service = None
        self.credentials = {}

        pkg_path = os.path.dirname(tabpy.__file__)

        parser = configparser.ConfigParser(os.environ)
        logger.info(f"Parsing config file {config_file}")

        file_exists = False
        if os.path.isfile(config_file):
            try:
                with open(config_file, "r") as f:
                    parser.read_string(f.read())
                    file_exists = True
            except Exception as e:
                # Deliberate best-effort: fall back to defaults (warned below),
                # but keep the reason visible instead of discarding it.
                logger.warning(f"Failed to read config file {config_file}: {e}")

        if not file_exists:
            logger.warning(
                f"Unable to open config file {config_file}, "
                "using default settings."
            )

        settings_parameters = [
            (SettingsParameters.Port, ConfigParameters.TABPY_PORT, 9004, None),
            (SettingsParameters.ServerVersion, None, __version__, None),
            (SettingsParameters.EvaluateEnabled, ConfigParameters.TABPY_EVALUATE_ENABLE,
             True, parser.getboolean),
            (SettingsParameters.EvaluateTimeout, ConfigParameters.TABPY_EVALUATE_TIMEOUT,
             30, parser.getfloat),
            (SettingsParameters.UploadDir, ConfigParameters.TABPY_QUERY_OBJECT_PATH,
             os.path.join(pkg_path, "tmp", "query_objects"), None),
            (SettingsParameters.TransferProtocol, ConfigParameters.TABPY_TRANSFER_PROTOCOL,
             "http", None),
            (SettingsParameters.CertificateFile, ConfigParameters.TABPY_CERTIFICATE_FILE,
             None, None),
            (SettingsParameters.KeyFile, ConfigParameters.TABPY_KEY_FILE, None, None),
            (SettingsParameters.MinimumTLSVersion, ConfigParameters.TABPY_MINIMUM_TLS_VERSION,
             "TLSv1_2", None),
            (SettingsParameters.StateFilePath, ConfigParameters.TABPY_STATE_PATH,
             os.path.join(pkg_path, "tabpy_server"), None),
            (SettingsParameters.StaticPath, ConfigParameters.TABPY_STATIC_PATH,
             os.path.join(pkg_path, "tabpy_server", "static"), None),
            (ConfigParameters.TABPY_PWD_FILE, ConfigParameters.TABPY_PWD_FILE, None, None),
            (SettingsParameters.LogRequestContext, ConfigParameters.TABPY_LOG_DETAILS,
             "false", None),
            (SettingsParameters.MaxRequestSizeInMb, ConfigParameters.TABPY_MAX_REQUEST_SIZE_MB,
             100, None),
            (SettingsParameters.GzipEnabled, ConfigParameters.TABPY_GZIP_ENABLE,
             True, parser.getboolean),
            (SettingsParameters.ArrowEnabled, ConfigParameters.TABPY_ARROW_ENABLE,
             False, parser.getboolean),
            (SettingsParameters.ArrowFlightPort, ConfigParameters.TABPY_ARROWFLIGHT_PORT,
             13622, parser.getint),
        ]

        for setting, parameter, default_val, parse_function in settings_parameters:
            self._set_parameter(parser, setting, parameter, default_val, parse_function)

        if not os.path.exists(self.settings[SettingsParameters.UploadDir]):
            os.makedirs(self.settings[SettingsParameters.UploadDir])

        # set and validate transfer protocol
        self.settings[SettingsParameters.TransferProtocol] = self.settings[
            SettingsParameters.TransferProtocol
        ].lower()

        self._validate_transfer_protocol_settings()

        # Set max request size in bytes
        self.max_request_size = (
            int(self.settings[SettingsParameters.MaxRequestSizeInMb]) * 1024 * 1024
        )
        logger.info(f"Setting max request size to {self.max_request_size} bytes")

        # if state.ini does not exist try and create it - remove
        # last dependence on batch/shell script
        self.settings[SettingsParameters.StateFilePath] = os.path.realpath(
            os.path.normpath(
                os.path.expanduser(self.settings[SettingsParameters.StateFilePath])
            )
        )
        state_config, self.tabpy_state = self._build_tabpy_state()

        self.python_service = PythonServiceHandler(PythonService())
        self.settings["compress_response"] = True
        self.settings[SettingsParameters.StaticPath] = os.path.abspath(
            self.settings[SettingsParameters.StaticPath]
        )
        logger.debug(
            f"Static pages folder set to "
            f'"{self.settings[SettingsParameters.StaticPath]}"'
        )

        # Set subdirectory from config if applicable
        if state_config.has_option("Service Info", "Subdirectory"):
            self.subdirectory = "/" + state_config.get("Service Info", "Subdirectory")

        # If passwords file specified load credentials
        if ConfigParameters.TABPY_PWD_FILE in self.settings:
            if not self._parse_pwd_file():
                msg = (
                    "Failed to read passwords file "
                    f"{self.settings[ConfigParameters.TABPY_PWD_FILE]}"
                )
                logger.critical(msg)
                raise RuntimeError(msg)
        else:
            self._handle_configuration_without_authentication()

        features = self._get_features()
        self.settings[SettingsParameters.ApiVersions] = {"v1": {"features": features}}

        # Any value other than "false" (case-insensitive) enables context logging.
        self.settings[SettingsParameters.LogRequestContext] = (
            self.settings[SettingsParameters.LogRequestContext].lower() != "false"
        )
        call_context_state = (
            "enabled"
            if self.settings[SettingsParameters.LogRequestContext]
            else "disabled"
        )
        logger.info(f"Call context logging is {call_context_state}")

    def _validate_transfer_protocol_settings(self):
        """
        Validate the transfer protocol and, for HTTPS, the cert/key settings.

        Raises:
            RuntimeError: If the protocol is missing or unsupported, or if
                HTTPS is requested without existing certificate and key files
                (or with a certificate rejected by ``validate_cert``).
        """
        if SettingsParameters.TransferProtocol not in self.settings:
            msg = "Missing transfer protocol information."
            logger.critical(msg)
            raise RuntimeError(msg)

        protocol = self.settings[SettingsParameters.TransferProtocol]

        if protocol == "http":
            return

        if protocol != "https":
            msg = f"Unsupported transfer protocol: {protocol}"
            logger.critical(msg)
            raise RuntimeError(msg)

        self._validate_cert_key_state(
            "The parameter(s) {} must be set.",
            SettingsParameters.CertificateFile in self.settings,
            SettingsParameters.KeyFile in self.settings,
        )
        cert = self.settings[SettingsParameters.CertificateFile]

        self._validate_cert_key_state(
            "The parameter(s) {} must point to " "an existing file.",
            os.path.isfile(cert),
            os.path.isfile(self.settings[SettingsParameters.KeyFile]),
        )
        tabpy.tabpy_server.app.util.validate_cert(cert)

    @staticmethod
    def _validate_cert_key_state(msg, cert_valid, key_valid):
        """
        Raise if the certificate and/or key failed a validity check.

        Args:
            msg: Format string with one ``{}`` slot for the parameter name(s).
            cert_valid: Whether the certificate passed the check.
            key_valid: Whether the key passed the check.

        Raises:
            RuntimeError: Naming whichever parameter(s) failed.
        """
        cert_and_key_param = (
            f"{ConfigParameters.TABPY_CERTIFICATE_FILE} and "
            f"{ConfigParameters.TABPY_KEY_FILE}"
        )
        https_error = "Error using HTTPS: "
        err = None
        if not cert_valid and not key_valid:
            err = https_error + msg.format(cert_and_key_param)
        elif not cert_valid:
            err = https_error + msg.format(ConfigParameters.TABPY_CERTIFICATE_FILE)
        elif not key_valid:
            err = https_error + msg.format(ConfigParameters.TABPY_KEY_FILE)

        if err is not None:
            logger.critical(err)
            raise RuntimeError(err)

    def _parse_pwd_file(self):
        """
        Load ``self.credentials`` from the configured password file.

        Returns:
            bool: True if the file parsed and contained at least one credential.
        """
        succeeded, self.credentials = parse_pwd_file(
            self.settings[ConfigParameters.TABPY_PWD_FILE]
        )

        if succeeded and len(self.credentials) == 0:
            logger.error("No credentials found")
            succeeded = False

        return succeeded

    def _handle_configuration_without_authentication(self):
        """
        Decide whether to continue starting without a password file.

        If ``disable_auth_warning == True`` this only logs a message. Otherwise
        the user is warned and must answer "y" (case-insensitive) to continue;
        any other answer, or end of input (no interactive stdin), prints how
        to enable authentication and exits via ``sys.exit()``.
        """
        std_no_auth_msg = "Password file is not specified: Authentication is not enabled"

        # Deliberately an equality test rather than truthiness: only an
        # explicit True (or 1) skips the security prompt, not e.g. "false".
        if self.disable_auth_warning == True:  # noqa: E712
            logger.info(std_no_auth_msg)
            return

        confirm_no_auth_msg = "\nWARNING: This TabPy server is not currently configured for username/password authentication. "

        if self.settings[SettingsParameters.EvaluateEnabled]:
            confirm_no_auth_msg += ("This means that, because the TABPY_EVALUATE_ENABLE feature is enabled, there is "
                "the potential that unauthenticated individuals may be able to remotely execute code on this machine. ")

        confirm_no_auth_msg += ("We strongly advise against proceeding without authentication as it poses a significant security risk.\n\n"
            "Do you wish to proceed without authentication? (y/N): ")

        try:
            confirm_no_auth_input = input(confirm_no_auth_msg)
        except EOFError:
            # No interactive input available: treat as declining.
            confirm_no_auth_input = ""

        if confirm_no_auth_input.strip().lower() == "y":
            logger.info(std_no_auth_msg)
        else:
            print("\nAborting start up. To enable authentication for your TabPy server, see "
                "https://github.com/tableau/TabPy/blob/master/docs/server-config.md#authentication.")
            # sys.exit() rather than the site-provided exit(), which is absent
            # under `python -S` and in some frozen builds.
            sys.exit()

    def _get_features(self):
        """
        Describe enabled server features for the ``/info`` API versions entry.

        Returns:
            dict: ``authentication`` (only when a password file is configured)
            plus the ``evaluate_enabled``, ``gzip_enabled`` and
            ``arrow_enabled`` flags.
        """
        features = {}

        # Check for auth
        if ConfigParameters.TABPY_PWD_FILE in self.settings:
            features["authentication"] = {
                "required": True,
                "methods": {"basic-auth": {}},
            }

        features["evaluate_enabled"] = self.settings[SettingsParameters.EvaluateEnabled]
        features["gzip_enabled"] = self.settings[SettingsParameters.GzipEnabled]
        features["arrow_enabled"] = self.settings[SettingsParameters.ArrowEnabled]
        return features

    def _build_tabpy_state(self):
        """
        Load the state file, creating it from the bundled template if missing.

        Returns:
            tuple: ``(state_config, TabPyState)`` where ``state_config`` is what
            ``_get_state_from_file`` returns for the state directory.
        """
        pkg_path = os.path.dirname(tabpy.__file__)
        state_file_dir = self.settings[SettingsParameters.StateFilePath]
        state_file_path = os.path.join(state_file_dir, "state.ini")
        if not os.path.isfile(state_file_path):
            state_file_template_path = os.path.join(
                pkg_path, "tabpy_server", "state.ini.template"
            )
            logger.debug(
                f"File {state_file_path} not found, creating from "
                f"template {state_file_template_path}..."
            )
            # shutil.copy() cannot create the destination directory itself.
            os.makedirs(state_file_dir, exist_ok=True)
            shutil.copy(state_file_template_path, state_file_path)

        logger.info(f"Loading state from state file {state_file_path}")
        tabpy_state = _get_state_from_file(state_file_dir)
        return tabpy_state, TabPyState(config=tabpy_state, settings=self.settings)
548
549
550
# Override _read_body to allow content with size exceeding max_body_size
551
# This enables proper handling of 413 errors in base_handler
552 1
def _read_body_allow_max_size(self, code, headers, delegate):
553 1
    if "Content-Length" in headers:
554 1
        content_length = int(headers["Content-Length"])
555 1
        if content_length > self._max_body_size:
556
            return
557 1
    return self.original_read_body(code, headers, delegate)
558
559 1
# Keep a handle on Tornado's original implementation so the override can
# delegate to it. Guarded so that executing this module a second time (reload,
# or import under another module name) does not store the already-patched
# function as the "original", which would make the override recurse forever.
if not hasattr(HTTP1Connection, "original_read_body"):
    HTTP1Connection.original_read_body = HTTP1Connection._read_body
HTTP1Connection._read_body = _read_body_allow_max_size
561