Test Failed
Pull Request — master (#638)
by
unknown
15:46
created

TabPyApp._initialize_ssl_context()   A

Complexity

Conditions 2

Size

Total Lines 16
Code Lines 12

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 1
CRAP Score 5.2516

Importance

Changes 0
Metric Value
eloc 12
dl 0
loc 16
ccs 1
cts 15
cp 0.0667
rs 9.8
c 0
b 0
f 0
cc 2
nop 1
crap 5.2516
1 1
import concurrent.futures
2 1
import configparser
3 1
import logging
4 1
import multiprocessing
5 1
import os
6 1
import shutil
7 1
import signal
8 1
import ssl
9 1
import sys
10
import _thread
11 1
12 1
import tornado
13
from tornado.http1connection import HTTP1Connection
14 1
15 1
import tabpy
16 1
import tabpy.tabpy_server.app.arrow_server as pa
17 1
from tabpy.tabpy import __version__
18 1
from tabpy.tabpy_server.app.app_parameters import ConfigParameters, SettingsParameters
19 1
from tabpy.tabpy_server.app.util import parse_pwd_file
20 1
from tabpy.tabpy_server.handlers.basic_auth_server_middleware_factory import BasicAuthServerMiddlewareFactory
21 1
from tabpy.tabpy_server.handlers.no_op_auth_handler import NoOpAuthHandler
22 1
from tabpy.tabpy_server.management.state import TabPyState
23 1
from tabpy.tabpy_server.management.util import _get_state_from_file
24 1
from tabpy.tabpy_server.psws.callbacks import init_model_evaluator, init_ps_server
25 1
from tabpy.tabpy_server.psws.python_service import PythonService, PythonServiceHandler
26
from tabpy.tabpy_server.handlers import (
27
    EndpointHandler,
28
    EndpointsHandler,
29
    EvaluationPlaneHandler,
30
    EvaluationPlaneDisabledHandler,
31
    QueryPlaneHandler,
32
    ServiceInfoHandler,
33
    StatusHandler,
34
    UploadDestinationHandler,
35
)
36 1
37
logger = logging.getLogger(__name__)
38 1
39
def _init_asyncio_patch():
40
    """
41
    Select compatible event loop for Tornado 5+.
42
    As of Python 3.8, the default event loop on Windows is `proactor`,
43
    however Tornado requires the old default "selector" event loop.
44
    As Tornado has decided to leave this to users to set, MkDocs needs
45
    to set it. See https://github.com/tornadoweb/tornado/issues/2608.
46 1
    """
47
    if sys.platform.startswith("win") and sys.version_info >= (3, 8):
48
        import asyncio
49
        try:
50
            from asyncio import WindowsSelectorEventLoopPolicy
51
        except ImportError:
52
            pass  # Can't assign a policy which doesn't exist.
53
        else:
54
            if not isinstance(asyncio.get_event_loop_policy(), WindowsSelectorEventLoopPolicy):
55
                asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
56
57 1
58
class TabPyApp:
59
    """
60
    TabPy application class for keeping context like settings, state, etc.
61
    """
62 1
63 1
    # Parsed configuration values; replaced with a fresh dict by _parse_config().
    settings = {}
64 1
    # URL prefix prepended to every route; set from the state file's
    # "Service Info" / "Subdirectory" option when that option exists.
    subdirectory = ""
65 1
    # TabPyState object produced by _build_tabpy_state().
    tabpy_state = None
66 1
    # PythonServiceHandler created in _parse_config().
    python_service = None
67 1
    # Credentials returned by parse_pwd_file() when a password file is configured.
    credentials = {}
68 1
    # Arrow Flight server; assigned inside run() only when Arrow is enabled.
    arrow_server = None
69
    # Maximum request size in bytes, computed in _parse_config().
    max_request_size = None
70 1
71 1
    def __init__(self, config_file, disable_auth_warning=True):
        """
        Create the application and load its configuration.

        Args:
            config_file: path to the configuration file; when None the
                packaged ``common/default.conf`` is used.
            disable_auth_warning: stored on the instance and consulted when
                no password file is configured.
        """
        self.disable_auth_warning = disable_auth_warning

        if config_file is None:
            app_dir = os.path.dirname(__file__)
            config_file = os.path.join(
                app_dir, os.path.pardir, "common", "default.conf"
            )

        # The same file may also carry logging configuration; fall back to
        # basic DEBUG logging when the logging sections are missing.
        if os.path.isfile(config_file):
            try:
                from logging import config as logging_config
                logging_config.fileConfig(
                    config_file, disable_existing_loggers=False
                )
            except KeyError:
                logging.basicConfig(level=logging.DEBUG)

        self._parse_config(config_file)
86 1
87
    def _initialize_ssl_context(self):
        """
        Build the SSL context used by the HTTPS listener.

        Loads the certificate chain and private key named in the settings and
        applies the configured minimum TLS version. An unrecognized (or
        missing) minimum version is logged and the context default is kept.

        Returns:
            ssl.SSLContext: the context handed to Tornado as ``ssl_options``.
        """
        # Purpose.CLIENT_AUTH is the purpose for *server-side* sockets (the
        # server is the party that may authenticate clients).
        # Purpose.SERVER_AUTH produces a client-side context that verifies
        # the peer, which is not what a listening server needs.
        ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)

        ssl_context.load_cert_chain(
            certfile=self.settings[SettingsParameters.CertificateFile],
            keyfile=self.settings[SettingsParameters.KeyFile],
        )

        min_tls = self.settings.get(SettingsParameters.MinimumTLSVersion)
        try:
            ssl_context.minimum_version = ssl.TLSVersion[min_tls]
            logger.info(f"Setting minimum TLS version to: {min_tls}")
        except KeyError:
            # Name is not a member of ssl.TLSVersion: keep the default.
            logger.warning(f"Unrecognized value for TABPY_MINIMUM_TLS_VERSION: {min_tls}")

        return ssl_context
103
104
    def _get_tls_certificates(self, config):
        """
        Read the configured certificate chain and private key from disk.

        Args:
            config: settings mapping holding the certificate and key paths.

        Returns:
            list: a single ``(cert_chain_bytes, private_key_bytes)`` tuple.
        """
        def _slurp(path):
            # Read the whole file as raw bytes.
            with open(path, "rb") as handle:
                return handle.read()

        cert_path = config[SettingsParameters.CertificateFile]
        key_path = config[SettingsParameters.KeyFile]
        chain_bytes = _slurp(cert_path)
        key_bytes = _slurp(key_path)
        return [(chain_bytes, key_bytes)]
114
    
115
    def _get_arrow_server(self, config):
        """
        Construct the Arrow Flight server described by ``config``.

        TLS is used when the transfer protocol is "https"; basic-auth
        middleware is attached when the v1 API features include
        "authentication".

        Args:
            config: the application settings mapping.

        Returns:
            The ``pa.FlightServer`` instance (not yet started).
        """
        use_tls = config[SettingsParameters.TransferProtocol] == "https"
        scheme = "grpc+tls" if use_tls else "grpc+tcp"
        tls_certificates = self._get_tls_certificates(config) if use_tls else None

        host = "0.0.0.0"
        port = config.get(SettingsParameters.ArrowFlightPort)
        location = f"{scheme}://{host}:{port}"

        v1_features = config[SettingsParameters.ApiVersions]["v1"]["features"]
        if "authentication" in v1_features:
            _, creds = parse_pwd_file(config[ConfigParameters.TABPY_PWD_FILE])
            auth_middleware = {"basic": BasicAuthServerMiddlewareFactory(creds)}
        else:
            auth_middleware = None

        return pa.FlightServer(
            host,
            location,
            tls_certificates=tls_certificates,
            verify_client=None,
            auth_handler=NoOpAuthHandler(),
            middleware=auth_middleware,
        )
139
140
    def run(self):
        """
        Start the web service and block in the Tornado IO loop.

        Creates the Tornado application, initializes the model evaluator,
        starts listening on the configured port (with TLS for "https") and,
        when Arrow is enabled, launches the Arrow Flight server on a
        separate thread.

        Raises:
            RuntimeError: if the transfer protocol is neither "http" nor
                "https".
        """
        web_app = self._create_tornado_web_app()

        init_model_evaluator(self.settings, self.tabpy_state, self.python_service)

        protocol = self.settings[SettingsParameters.TransferProtocol]
        if protocol == "https":
            ssl_options = self._initialize_ssl_context()
        elif protocol == "http":
            ssl_options = None
        else:
            msg = f"Unsupported transfer protocol {protocol}."
            logger.critical(msg)
            raise RuntimeError(msg)

        # Request decompression is switched on only for an explicit True.
        gzip_on = self.settings[SettingsParameters.GzipEnabled] is True
        listen_kwargs = {"decompress_request": True} if gzip_on else {}

        web_app.listen(
            self.settings[SettingsParameters.Port],
            ssl_options=ssl_options,
            max_buffer_size=self.max_request_size,
            max_body_size=self.max_request_size,
            **listen_kwargs,
        )

        port_text = str(self.settings[SettingsParameters.Port])
        logger.info(f"Web service listening on port {port_text}")

        if self.settings[SettingsParameters.ArrowEnabled]:
            def _serve_arrow():
                # Runs on its own thread alongside the Tornado IO loop.
                self.arrow_server = self._get_arrow_server(self.settings)
                pa.start(self.arrow_server)

            try:
                _thread.start_new_thread(_serve_arrow, ())
            except Exception as exc:
                logger.critical(f"Failed to start PyArrow server: {exc}")

        tornado.ioloop.IOLoop.instance().start()
182 1
183
    def _create_tornado_web_app(self):
        """
        Build the Tornado application with all TabPy routes.

        Initializes the PS server on the IO loop, creates the thread pool
        passed to the /evaluate handler, registers every route under
        ``self.subdirectory`` and installs SIGINT handling.

        Returns:
            The configured Tornado application instance.
        """
        class TabPyTornadoApp(tornado.web.Application):
            is_closing = False

            def signal_handler(self, signum, _):
                # Only set a flag here; the IO loop is stopped from
                # try_exit(), which runs as a periodic callback.
                logger.critical(f"Exiting on signal {signum}...")
                self.is_closing = True

            def try_exit(self):
                if self.is_closing:
                    tornado.ioloop.IOLoop.instance().stop()
                    logger.info("Shutting down TabPy...")

        logger.info("Initializing TabPy...")
        tornado.ioloop.IOLoop.instance().run_sync(
            lambda: init_ps_server(self.settings, self.tabpy_state)
        )
        logger.info("Done initializing TabPy.")

        executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=multiprocessing.cpu_count()
        )

        # initialize Tornado application
        _init_asyncio_patch()
        application = TabPyTornadoApp(
            [
                (
                    self.subdirectory + r"/query/([^/]+)",
                    QueryPlaneHandler,
                    dict(app=self),
                ),
                (self.subdirectory + r"/status", StatusHandler, dict(app=self)),
                (self.subdirectory + r"/info", ServiceInfoHandler, dict(app=self)),
                (self.subdirectory + r"/endpoints", EndpointsHandler, dict(app=self)),
                (
                    self.subdirectory + r"/endpoints/([^/]+)?",
                    EndpointHandler,
                    dict(app=self),
                ),
                (
                    self.subdirectory + r"/evaluate",
                    # A dedicated handler answers when evaluation is disabled.
                    EvaluationPlaneHandler if self.settings[SettingsParameters.EvaluateEnabled]
                    else EvaluationPlaneDisabledHandler,
                    dict(executor=executor, app=self),
                ),
                (
                    self.subdirectory + r"/configurations/endpoint_upload_destination",
                    UploadDestinationHandler,
                    dict(app=self),
                ),
                (
                    # Catch-all: everything else is served as a static file.
                    self.subdirectory + r"/(.*)",
                    tornado.web.StaticFileHandler,
                    dict(
                        path=self.settings[SettingsParameters.StaticPath],
                        default_filename="index.html",
                    ),
                ),
            ],
            debug=False,
            **self.settings,
        )

        # Register the SIGINT handler and the shutdown poller exactly once;
        # registering them twice only leaves a redundant second poller.
        signal.signal(signal.SIGINT, application.signal_handler)
        tornado.ioloop.PeriodicCallback(application.try_exit, 500).start()

        return application
254
255
    def _set_parameter(self, parser, settings_key, config_key, default_val, parse_function):
        """
        Store one setting, taken from the parser or from a default.

        Args:
            parser: ConfigParser to query; only its "TabPy" section is used.
            settings_key: key under which the value lands in ``self.settings``.
            config_key: option name to look up, or None to skip the lookup.
            default_val: value used when the option is absent; None means
                "no default".
            parse_function: parser accessor called as
                ``parse_function("TabPy", config_key)``; ``parser.get`` when None.
        """
        in_parser = (
            config_key is not None
            and parser.has_section("TabPy")
            and parser.has_option("TabPy", config_key)
        )

        if in_parser:
            reader = parser.get if parse_function is None else parse_function
            self.settings[settings_key] = reader("TabPy", config_key)
            logger.debug(
                f"Parameter {settings_key} set to "
                f'"{self.settings[settings_key]}" '
                "from config file or environment variable"
            )
            return

        if default_val is not None:
            self.settings[settings_key] = default_val
            logger.debug(
                f"Parameter {settings_key} set to "
                f'"{self.settings[settings_key]}" '
                "from default value"
            )
            return

        logger.debug(f"Parameter {settings_key} is not set")
284
285
    def _parse_config(self, config_file):
        """Provide consistent mechanism for pulling in configuration.

        Resets the per-instance state, then fills ``self.settings`` from:

        1. the "TabPy" section of the config file;
        2. OS environment variables, which are supplied to the parser as
           its default values;
        3. built-in defaults when a setting is found in neither place.

        For consistency the same variable name is used in the config file
        and in the OS environment (all capitals, starting with 'TABPY_').

        Besides the settings this also creates the upload directory if it
        is missing, validates the transfer protocol, computes the maximum
        request size in bytes, builds the TabPy state, loads credentials
        when a password file is configured and records the feature list.

        Args:
            config_file: path of the configuration file to read. A missing
                or unreadable file is not fatal; defaults are used.

        Raises:
            RuntimeError: if the configured password file cannot be read,
                or if transfer protocol validation fails.
        """
        self.settings = {}
        self.subdirectory = ""
        self.tabpy_state = None
        self.python_service = None
        self.credentials = {}

        pkg_path = os.path.dirname(tabpy.__file__)

        parser = configparser.ConfigParser(os.environ)
        logger.info(f"Parsing config file {config_file}")

        file_exists = False
        if os.path.isfile(config_file):
            try:
                with open(config_file, 'r') as f:
                    parser.read_string(f.read())
                    file_exists = True
            except Exception as e:
                # Still best effort (startup continues with defaults), but
                # record the real cause instead of swallowing it silently.
                logger.warning(f"Failed to read config file {config_file}: {e}")

        if not file_exists:
            logger.warning(
                f"Unable to open config file {config_file}, "
                "using default settings."
            )

        # (settings key, config option, default value, parser accessor)
        settings_parameters = [
            (SettingsParameters.Port, ConfigParameters.TABPY_PORT, 9004, None),
            (SettingsParameters.ServerVersion, None, __version__, None),
            (SettingsParameters.EvaluateEnabled, ConfigParameters.TABPY_EVALUATE_ENABLE,
             True, parser.getboolean),
            (SettingsParameters.EvaluateTimeout, ConfigParameters.TABPY_EVALUATE_TIMEOUT,
             30, parser.getfloat),
            (SettingsParameters.UploadDir, ConfigParameters.TABPY_QUERY_OBJECT_PATH,
             os.path.join(pkg_path, "tmp", "query_objects"), None),
            (SettingsParameters.TransferProtocol, ConfigParameters.TABPY_TRANSFER_PROTOCOL,
             "http", None),
            (SettingsParameters.CertificateFile, ConfigParameters.TABPY_CERTIFICATE_FILE,
             None, None),
            (SettingsParameters.KeyFile, ConfigParameters.TABPY_KEY_FILE, None, None),
            (SettingsParameters.MinimumTLSVersion, ConfigParameters.TABPY_MINIMUM_TLS_VERSION,
             "TLSv1_2", None),
            (SettingsParameters.StateFilePath, ConfigParameters.TABPY_STATE_PATH,
             os.path.join(pkg_path, "tabpy_server"), None),
            (SettingsParameters.StaticPath, ConfigParameters.TABPY_STATIC_PATH,
             os.path.join(pkg_path, "tabpy_server", "static"), None),
            (ConfigParameters.TABPY_PWD_FILE, ConfigParameters.TABPY_PWD_FILE, None, None),
            (SettingsParameters.LogRequestContext, ConfigParameters.TABPY_LOG_DETAILS,
             "false", None),
            (SettingsParameters.MaxRequestSizeInMb, ConfigParameters.TABPY_MAX_REQUEST_SIZE_MB,
             100, None),
            (SettingsParameters.GzipEnabled, ConfigParameters.TABPY_GZIP_ENABLE,
             True, parser.getboolean),
            (SettingsParameters.ArrowEnabled, ConfigParameters.TABPY_ARROW_ENABLE, False, parser.getboolean),
            (SettingsParameters.ArrowFlightPort, ConfigParameters.TABPY_ARROWFLIGHT_PORT, 13622, parser.getint),
        ]

        for setting, parameter, default_val, parse_function in settings_parameters:
            self._set_parameter(parser, setting, parameter, default_val, parse_function)

        if not os.path.exists(self.settings[SettingsParameters.UploadDir]):
            os.makedirs(self.settings[SettingsParameters.UploadDir])

        # set and validate transfer protocol
        self.settings[SettingsParameters.TransferProtocol] = self.settings[
            SettingsParameters.TransferProtocol
        ].lower()

        self._validate_transfer_protocol_settings()

        # Set max request size in bytes
        self.max_request_size = (
            int(self.settings[SettingsParameters.MaxRequestSizeInMb]) * 1024 * 1024
        )
        logger.info(f"Setting max request size to {self.max_request_size} bytes")

        # if state.ini does not exist try and create it - remove
        # last dependence on batch/shell script
        self.settings[SettingsParameters.StateFilePath] = os.path.realpath(
            os.path.normpath(
                os.path.expanduser(self.settings[SettingsParameters.StateFilePath])
            )
        )
        state_config, self.tabpy_state = self._build_tabpy_state()

        self.python_service = PythonServiceHandler(PythonService())
        self.settings["compress_response"] = True
        self.settings[SettingsParameters.StaticPath] = os.path.abspath(
            self.settings[SettingsParameters.StaticPath]
        )
        logger.debug(
            f"Static pages folder set to "
            f'"{self.settings[SettingsParameters.StaticPath]}"'
        )

        # Set subdirectory from config if applicable
        if state_config.has_option("Service Info", "Subdirectory"):
            self.subdirectory = "/" + state_config.get("Service Info", "Subdirectory")

        # If passwords file specified load credentials
        if ConfigParameters.TABPY_PWD_FILE in self.settings:
            if not self._parse_pwd_file():
                msg = (
                    "Failed to read passwords file "
                    f"{self.settings[ConfigParameters.TABPY_PWD_FILE]}"
                )
                logger.critical(msg)
                raise RuntimeError(msg)
        else:
            self._handle_configuration_without_authentication()

        features = self._get_features()
        self.settings[SettingsParameters.ApiVersions] = {"v1": {"features": features}}

        # The option is read as text; anything other than "false"
        # (case-insensitive) enables request-context logging.
        self.settings[SettingsParameters.LogRequestContext] = (
            self.settings[SettingsParameters.LogRequestContext].lower() != "false"
        )
        call_context_state = (
            "enabled"
            if self.settings[SettingsParameters.LogRequestContext]
            else "disabled"
        )
        logger.info(f"Call context logging is {call_context_state}")
429 1
430
    def _validate_transfer_protocol_settings(self):
        """
        Check the transfer protocol and, for HTTPS, the certificate and key.

        "http" needs nothing further. "https" requires both the certificate
        and key settings to be present and to point to existing files, after
        which the certificate itself is validated.

        Raises:
            RuntimeError: if the protocol is missing or unsupported, or if
                the certificate/key settings are absent or not files.
        """
        def _fail(message):
            # Log at critical level, then abort with the same message.
            logger.critical(message)
            raise RuntimeError(message)

        if SettingsParameters.TransferProtocol not in self.settings:
            _fail("Missing transfer protocol information.")

        protocol = self.settings[SettingsParameters.TransferProtocol]
        if protocol == "http":
            return
        if protocol != "https":
            _fail(f"Unsupported transfer protocol: {protocol}")

        has_cert = SettingsParameters.CertificateFile in self.settings
        has_key = SettingsParameters.KeyFile in self.settings
        self._validate_cert_key_state(
            "The parameter(s) {} must be set.", has_cert, has_key
        )

        cert = self.settings[SettingsParameters.CertificateFile]
        key = self.settings[SettingsParameters.KeyFile]
        cert_is_file = os.path.isfile(cert)
        key_is_file = os.path.isfile(key)
        self._validate_cert_key_state(
            "The parameter(s) {} must point to an existing file.",
            cert_is_file,
            key_is_file,
        )

        tabpy.tabpy_server.app.util.validate_cert(cert)
459 1
460 1
    @staticmethod
    def _validate_cert_key_state(msg, cert_valid, key_valid):
        """
        Raise if the certificate and/or key check failed.

        Args:
            msg: message template with one ``{}`` placeholder that receives
                the name(s) of the offending parameter(s).
            cert_valid: truthy when the certificate check passed.
            key_valid: truthy when the key check passed.

        Raises:
            RuntimeError: when either check failed; the message is also
                logged at critical level.
        """
        if cert_valid and key_valid:
            return

        if not cert_valid and not key_valid:
            offending = (
                f"{ConfigParameters.TABPY_CERTIFICATE_FILE} and "
                f"{ConfigParameters.TABPY_KEY_FILE}"
            )
        elif not cert_valid:
            offending = ConfigParameters.TABPY_CERTIFICATE_FILE
        else:
            offending = ConfigParameters.TABPY_KEY_FILE

        err = "Error using HTTPS: " + msg.format(offending)
        logger.critical(err)
        raise RuntimeError(err)
478 1
479
    def _parse_pwd_file(self):
        """
        Load credentials from the configured password file.

        Stores the parsed credentials on ``self.credentials``.

        Returns:
            The success flag from ``parse_pwd_file``, or False when parsing
            succeeded but produced no credentials.
        """
        pwd_file = self.settings[ConfigParameters.TABPY_PWD_FILE]
        parsed_ok, self.credentials = parse_pwd_file(pwd_file)

        if not parsed_ok:
            return parsed_ok

        if len(self.credentials) == 0:
            # A readable but empty password file counts as a failure.
            logger.error("No credentials found")
            return False

        return parsed_ok
489 1
490
    def _handle_configuration_without_authentication(self):
        """
        Deal with a configuration that has no password file.

        When ``self.disable_auth_warning`` is set, only an informational
        message is logged. Otherwise the operator is asked on stdin to
        confirm running without authentication; any answer other than "y"
        (case-insensitive, surrounding whitespace ignored), or a closed
        stdin, aborts start up.

        Raises:
            SystemExit: when the operator does not confirm.
        """
        std_no_auth_msg = "Password file is not specified: Authentication is not enabled"

        if self.disable_auth_warning:
            logger.info(std_no_auth_msg)
            return

        confirm_no_auth_msg = "\nWARNING: This TabPy server is not currently configured for username/password authentication. "

        if self.settings[SettingsParameters.EvaluateEnabled]:
            confirm_no_auth_msg += ("This means that, because the TABPY_EVALUATE_ENABLE feature is enabled, there is "
                "the potential that unauthenticated individuals may be able to remotely execute code on this machine. ")

        confirm_no_auth_msg += ("We strongly advise against proceeding without authentication as it poses a significant security risk.\n\n"
            "Do you wish to proceed without authentication? (y/N): ")

        try:
            confirm_no_auth_input = input(confirm_no_auth_msg)
        except EOFError:
            # No interactive stdin available: take the safe default ("N").
            confirm_no_auth_input = ""

        if confirm_no_auth_input.strip().lower() == "y":
            logger.info(std_no_auth_msg)
        else:
            print("\nAborting start up. To enable authentication for your TabPy server, see "
                "https://github.com/tableau/TabPy/blob/master/docs/server-config.md#authentication.")
            # sys.exit() rather than the site-provided exit() helper, which
            # is meant for the interactive interpreter and may be absent.
            sys.exit()
514 1
515 1
    def _get_features(self):
        """
        Describe the enabled server features.

        Returns:
            dict: an "authentication" entry (only when a password file is
            configured) followed by the evaluate, gzip and arrow flags
            copied from the settings.
        """
        settings = self.settings
        features = {}

        # Check for auth
        if ConfigParameters.TABPY_PWD_FILE in settings:
            features["authentication"] = {
                "required": True,
                "methods": {"basic-auth": {}},
            }

        features.update(
            evaluate_enabled=settings[SettingsParameters.EvaluateEnabled],
            gzip_enabled=settings[SettingsParameters.GzipEnabled],
            arrow_enabled=settings[SettingsParameters.ArrowEnabled],
        )
        return features
529 1
530
    def _build_tabpy_state(self):
        """
        Load the state file, creating it from the template when missing.

        Returns:
            tuple: ``(state_config, tabpy_state)`` where ``state_config`` is
            what ``_get_state_from_file`` returned and ``tabpy_state`` is
            the ``TabPyState`` built on top of it.
        """
        state_dir = self.settings[SettingsParameters.StateFilePath]
        state_file = os.path.join(state_dir, "state.ini")

        if not os.path.isfile(state_file):
            template_file = os.path.join(
                os.path.dirname(tabpy.__file__), "tabpy_server", "state.ini.template"
            )
            logger.debug(
                f"File {state_file} not found, creating from "
                f"template {template_file}..."
            )
            shutil.copy(template_file, state_file)

        logger.info(f"Loading state from state file {state_file}")
        state_config = _get_state_from_file(state_dir)
        state = TabPyState(config=state_config, settings=self.settings)
        return state_config, state
547
548
549
# Override _read_body to allow content with size exceeding max_body_size
550
# This enables proper handling of 413 errors in base_handler
551
def _read_body_allow_max_size(self, code, headers, delegate):
552
    if "Content-Length" in headers:
553
        content_length = int(headers["Content-Length"])
554
        if content_length > self._max_body_size:
555
            return
556
    return self.original_read_body(code, headers, delegate)
557
558
# Install the override only once. If this code ran a second time (for
# example on a module reload), the already-patched function would be saved
# as the "original" and every request would recurse without end.
if not hasattr(HTTP1Connection, "original_read_body"):
    HTTP1Connection.original_read_body = HTTP1Connection._read_body
    HTTP1Connection._read_body = _read_body_allow_max_size
560