Test Failed
Pull Request — master (#611)
by
unknown
15:21
created

TabPyApp._validate_transfer_protocol_settings()   A

Complexity

Conditions 4

Size

Total Lines 29
Code Lines 22

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 12
CRAP Score 4.128

Importance

Changes 0
Metric Value
eloc 22
dl 0
loc 29
ccs 12
cts 15
cp 0.8
rs 9.352
c 0
b 0
f 0
cc 4
nop 1
crap 4.128
1 1
import concurrent.futures
2 1
import configparser
3 1
import logging
4 1
import multiprocessing
5 1
import os
6 1
import shutil
7 1
import signal
8 1
import sys
9 1
import _thread
10
11 1
import tornado
12 1
from tornado.http1connection import HTTP1Connection
13
14 1
import tabpy
15 1
import tabpy.tabpy_server.app.arrow_server as pa
16 1
from tabpy.tabpy import __version__
17 1
from tabpy.tabpy_server.app.app_parameters import ConfigParameters, SettingsParameters
18 1
from tabpy.tabpy_server.app.util import parse_pwd_file
19 1
from tabpy.tabpy_server.handlers.basic_auth_server_middleware_factory import BasicAuthServerMiddlewareFactory
20 1
from tabpy.tabpy_server.handlers.no_op_auth_handler import NoOpAuthHandler
21 1
from tabpy.tabpy_server.management.state import TabPyState
22 1
from tabpy.tabpy_server.management.util import _get_state_from_file
23 1
from tabpy.tabpy_server.psws.callbacks import init_model_evaluator, init_ps_server
24 1
from tabpy.tabpy_server.psws.python_service import PythonService, PythonServiceHandler
25 1
from tabpy.tabpy_server.handlers import (
26
    EndpointHandler,
27
    EndpointsHandler,
28
    EvaluationPlaneHandler,
29
    EvaluationPlaneDisabledHandler,
30
    QueryPlaneHandler,
31
    ServiceInfoHandler,
32
    StatusHandler,
33
    UploadDestinationHandler,
34
)
35
36 1
# Module-level logger named after this module; used by every function and
# method in this file.
logger = logging.getLogger(__name__)
37
38 1
def _init_asyncio_patch():
    """
    Select a Tornado-compatible asyncio event loop policy on Windows.

    As of Python 3.8, the default event loop on Windows is `proactor`,
    however Tornado requires the old default "selector" event loop.
    Tornado leaves this choice to the application, so it is made here.
    See https://github.com/tornadoweb/tornado/issues/2608.

    Does nothing on other platforms or on Python older than 3.8.
    """
    if not (sys.platform.startswith("win") and sys.version_info >= (3, 8)):
        return

    import asyncio

    try:
        from asyncio import WindowsSelectorEventLoopPolicy
    except ImportError:
        return  # Can't assign a policy which doesn't exist.

    # Only replace the policy when it is not already the selector policy.
    if not isinstance(asyncio.get_event_loop_policy(), WindowsSelectorEventLoopPolicy):
        asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
55
56
57 1
class TabPyApp:
    """
    TabPy application class for keeping context like settings, state, etc.

    The constructor parses the configuration; :meth:`run` starts the
    web service.
    """

    # Class-level fallbacks. ``_parse_config`` (called from ``__init__``)
    # rebinds settings, subdirectory, tabpy_state, python_service,
    # credentials and max_request_size on the instance; ``run`` sets
    # arrow_server when the Arrow server is enabled.
    settings = {}
    subdirectory = ""
    tabpy_state = None
    python_service = None
    credentials = {}
    arrow_server = None
    max_request_size = None

    def __init__(self, config_file, show_no_auth_warning=False):
        """
        Configure logging and load settings from ``config_file``.

        :param config_file: path of the configuration file; ``None`` selects
            ``../common/default.conf`` relative to this module.
        :param show_no_auth_warning: when true and no password file is
            configured, ask for interactive confirmation before continuing.
        """
        self.show_no_auth_warning = show_no_auth_warning
        if config_file is None:
            config_file = os.path.join(
                os.path.dirname(__file__), os.path.pardir, "common", "default.conf"
            )

        if os.path.isfile(config_file):
            try:
                from logging import config

                config.fileConfig(config_file, disable_existing_loggers=False)
            except KeyError:
                # The file carries no usable logging configuration.
                logging.basicConfig(level=logging.DEBUG)

        self._parse_config(config_file)

    def _get_tls_certificates(self, config):
        """
        Read the certificate and key files named in ``config``.

        :param config: settings mapping holding the certificate and key paths.
        :return: a one-element list with a ``(cert_bytes, key_bytes)`` tuple.
        """
        cert = config[SettingsParameters.CertificateFile]
        key = config[SettingsParameters.KeyFile]
        with open(cert, "rb") as cert_file:
            tls_cert_chain = cert_file.read()
        with open(key, "rb") as key_file:
            tls_private_key = key_file.read()
        return [(tls_cert_chain, tls_private_key)]

    def _get_arrow_server(self, config):
        """
        Build the Arrow Flight server described by ``config``.

        Uses the ``grpc+tls`` scheme with the configured certificates when
        the transfer protocol is "https", ``grpc+tcp`` otherwise, and adds
        basic-auth middleware when the "authentication" feature is present.

        :param config: settings mapping (as built by ``_parse_config``).
        :return: the ``pa.FlightServer`` instance (not yet started).
        """
        verify_client = None
        tls_certificates = None
        scheme = "grpc+tcp"
        if config[SettingsParameters.TransferProtocol] == "https":
            scheme = "grpc+tls"
            tls_certificates = self._get_tls_certificates(config)

        host = "0.0.0.0"
        port = config.get(SettingsParameters.ArrowFlightPort)
        location = f"{scheme}://{host}:{port}"

        auth_middleware = None
        if "authentication" in config[SettingsParameters.ApiVersions]["v1"]["features"]:
            _, creds = parse_pwd_file(config[ConfigParameters.TABPY_PWD_FILE])
            auth_middleware = {"basic": BasicAuthServerMiddlewareFactory(creds)}

        return pa.FlightServer(
            host,
            location,
            tls_certificates=tls_certificates,
            verify_client=verify_client,
            auth_handler=NoOpAuthHandler(),
            middleware=auth_middleware,
        )

    def run(self):
        """
        Start the web service and, if enabled, the Arrow Flight server.

        Ends by starting the Tornado IO loop.

        :raises RuntimeError: if the transfer protocol is neither "http"
            nor "https".
        """
        application = self._create_tornado_web_app()

        init_model_evaluator(self.settings, self.tabpy_state, self.python_service)

        protocol = self.settings[SettingsParameters.TransferProtocol]
        ssl_options = None
        if protocol == "https":
            ssl_options = {
                "certfile": self.settings[SettingsParameters.CertificateFile],
                "keyfile": self.settings[SettingsParameters.KeyFile],
            }
        elif protocol != "http":
            msg = f"Unsupported transfer protocol {protocol}."
            logger.critical(msg)
            raise RuntimeError(msg)

        listen_options = {}
        if self.settings[SettingsParameters.GzipEnabled] is True:
            listen_options["decompress_request"] = True

        application.listen(
            self.settings[SettingsParameters.Port],
            ssl_options=ssl_options,
            max_buffer_size=self.max_request_size,
            max_body_size=self.max_request_size,
            **listen_options,
        )

        logger.info(
            "Web service listening on port "
            f"{self.settings[SettingsParameters.Port]}"
        )

        if self.settings[SettingsParameters.ArrowEnabled]:
            def start_pyarrow():
                self.arrow_server = self._get_arrow_server(self.settings)
                pa.start(self.arrow_server)

            # NOTE: this only catches a failure to create the thread;
            # errors raised inside start_pyarrow happen on the new thread.
            try:
                _thread.start_new_thread(start_pyarrow, ())
            except Exception as e:
                logger.critical(f"Failed to start PyArrow server: {e}")

        tornado.ioloop.IOLoop.instance().start()

    def _create_tornado_web_app(self):
        """
        Initialize the query-object server and build the Tornado application.

        Also installs a SIGINT handler that flags the application as closing
        and a 500 ms periodic callback that stops the IO loop once flagged.

        :return: the configured Tornado application.
        """
        class TabPyTornadoApp(tornado.web.Application):
            is_closing = False

            def signal_handler(self, signum, _):
                logger.critical(f"Exiting on signal {signum}...")
                self.is_closing = True

            def try_exit(self):
                if self.is_closing:
                    tornado.ioloop.IOLoop.instance().stop()
                    logger.info("Shutting down TabPy...")

        logger.info("Initializing TabPy...")
        tornado.ioloop.IOLoop.instance().run_sync(
            lambda: init_ps_server(self.settings, self.tabpy_state)
        )
        logger.info("Done initializing TabPy.")

        executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=multiprocessing.cpu_count()
        )

        # initialize Tornado application
        _init_asyncio_patch()
        application = TabPyTornadoApp(
            [
                (
                    self.subdirectory + r"/query/([^/]+)",
                    QueryPlaneHandler,
                    dict(app=self),
                ),
                (self.subdirectory + r"/status", StatusHandler, dict(app=self)),
                (self.subdirectory + r"/info", ServiceInfoHandler, dict(app=self)),
                (self.subdirectory + r"/endpoints", EndpointsHandler, dict(app=self)),
                (
                    self.subdirectory + r"/endpoints/([^/]+)?",
                    EndpointHandler,
                    dict(app=self),
                ),
                (
                    self.subdirectory + r"/evaluate",
                    EvaluationPlaneHandler
                    if self.settings[SettingsParameters.EvaluateEnabled]
                    else EvaluationPlaneDisabledHandler,
                    dict(executor=executor, app=self),
                ),
                (
                    self.subdirectory + r"/configurations/endpoint_upload_destination",
                    UploadDestinationHandler,
                    dict(app=self),
                ),
                (
                    self.subdirectory + r"/(.*)",
                    tornado.web.StaticFileHandler,
                    dict(
                        path=self.settings[SettingsParameters.StaticPath],
                        default_filename="index.html",
                    ),
                ),
            ],
            debug=False,
            **self.settings,
        )

        # Register the shutdown hooks exactly once: a second registration
        # would only add a redundant periodic timer.
        signal.signal(signal.SIGINT, application.signal_handler)
        tornado.ioloop.PeriodicCallback(application.try_exit, 500).start()

        return application

    def _set_parameter(self, parser, settings_key, config_key, default_val, parse_function):
        """
        Store one setting, taken from the config parser or from a default.

        :param parser: ``configparser`` instance to read section "TabPy" from.
        :param settings_key: key under which the value is stored in
            ``self.settings``.
        :param config_key: option name in the "TabPy" section, or ``None``
            when the setting cannot come from the config.
        :param default_val: value used when the option is absent; ``None``
            leaves the setting unset.
        :param parse_function: callable ``(section, option)`` returning the
            parsed value; ``None`` selects ``parser.get``.
        """
        key_is_set = False

        if (
            config_key is not None
            and parser.has_section("TabPy")
            and parser.has_option("TabPy", config_key)
        ):
            if parse_function is None:
                parse_function = parser.get
            self.settings[settings_key] = parse_function("TabPy", config_key)
            key_is_set = True
            logger.debug(
                f"Parameter {settings_key} set to "
                f'"{self.settings[settings_key]}" '
                "from config file or environment variable"
            )

        if not key_is_set and default_val is not None:
            self.settings[settings_key] = default_val
            key_is_set = True
            logger.debug(
                f"Parameter {settings_key} set to "
                f'"{self.settings[settings_key]}" '
                "from default value"
            )

        if not key_is_set:
            logger.debug(f"Parameter {settings_key} is not set")

    def _parse_config(self, config_file):
        """Provide consistent mechanism for pulling in configuration.

        Attempt to retain backward compatibility for
        existing implementations by grabbing port
        setting from CLI first.

        Take settings in the following order:

        1. CLI arguments if present
        2. config file
        3. OS environment variables (for ease of
           setting defaults if not present)
        4. current defaults if a setting is not present in any location

        Additionally provide similar configuration capabilities in between
        config file and environment variables.
        For consistency use the same variable name in the config file as
        in the os environment.
        For naming standards use all capitals and start with 'TABPY_'

        :param config_file: path of the configuration file to read.
        :raises RuntimeError: if the transfer protocol settings are invalid
            or the configured passwords file cannot be used.
        """
        self.settings = {}
        self.subdirectory = ""
        self.tabpy_state = None
        self.python_service = None
        self.credentials = {}

        pkg_path = os.path.dirname(tabpy.__file__)

        # Environment variables act as parser defaults.
        parser = configparser.ConfigParser(os.environ)
        logger.info(f"Parsing config file {config_file}")

        file_exists = False
        if os.path.isfile(config_file):
            try:
                with open(config_file, "r") as f:
                    parser.read_string(f.read())
                    file_exists = True
            except Exception as e:
                # Best effort: defaults are used below, but record why the
                # file was rejected instead of hiding the cause.
                logger.warning(f"Failed to read config file {config_file}: {e}")

        if not file_exists:
            logger.warning(
                f"Unable to open config file {config_file}, "
                "using default settings."
            )

        # (settings key, config option, default value, parse function)
        settings_parameters = [
            (SettingsParameters.Port, ConfigParameters.TABPY_PORT, 9004, None),
            (SettingsParameters.ServerVersion, None, __version__, None),
            (SettingsParameters.EvaluateEnabled, ConfigParameters.TABPY_EVALUATE_ENABLE,
             True, parser.getboolean),
            (SettingsParameters.EvaluateTimeout, ConfigParameters.TABPY_EVALUATE_TIMEOUT,
             30, parser.getfloat),
            (SettingsParameters.UploadDir, ConfigParameters.TABPY_QUERY_OBJECT_PATH,
             os.path.join(pkg_path, "tmp", "query_objects"), None),
            (SettingsParameters.TransferProtocol, ConfigParameters.TABPY_TRANSFER_PROTOCOL,
             "http", None),
            (SettingsParameters.CertificateFile, ConfigParameters.TABPY_CERTIFICATE_FILE,
             None, None),
            (SettingsParameters.KeyFile, ConfigParameters.TABPY_KEY_FILE, None, None),
            (SettingsParameters.StateFilePath, ConfigParameters.TABPY_STATE_PATH,
             os.path.join(pkg_path, "tabpy_server"), None),
            (SettingsParameters.StaticPath, ConfigParameters.TABPY_STATIC_PATH,
             os.path.join(pkg_path, "tabpy_server", "static"), None),
            (ConfigParameters.TABPY_PWD_FILE, ConfigParameters.TABPY_PWD_FILE, None, None),
            (SettingsParameters.LogRequestContext, ConfigParameters.TABPY_LOG_DETAILS,
             "false", None),
            (SettingsParameters.MaxRequestSizeInMb, ConfigParameters.TABPY_MAX_REQUEST_SIZE_MB,
             100, None),
            (SettingsParameters.GzipEnabled, ConfigParameters.TABPY_GZIP_ENABLE,
             True, parser.getboolean),
            (SettingsParameters.ArrowEnabled, ConfigParameters.TABPY_ARROW_ENABLE,
             False, parser.getboolean),
            (SettingsParameters.ArrowFlightPort, ConfigParameters.TABPY_ARROWFLIGHT_PORT,
             13622, parser.getint),
        ]

        for setting, parameter, default_val, parse_function in settings_parameters:
            self._set_parameter(parser, setting, parameter, default_val, parse_function)

        os.makedirs(self.settings[SettingsParameters.UploadDir], exist_ok=True)

        # set and validate transfer protocol
        self.settings[SettingsParameters.TransferProtocol] = self.settings[
            SettingsParameters.TransferProtocol
        ].lower()

        self._validate_transfer_protocol_settings()

        # Set max request size in bytes
        self.max_request_size = (
            int(self.settings[SettingsParameters.MaxRequestSizeInMb]) * 1024 * 1024
        )
        logger.info(f"Setting max request size to {self.max_request_size} bytes")

        # if state.ini does not exist try and create it - remove
        # last dependence on batch/shell script
        self.settings[SettingsParameters.StateFilePath] = os.path.realpath(
            os.path.normpath(
                os.path.expanduser(self.settings[SettingsParameters.StateFilePath])
            )
        )
        state_config, self.tabpy_state = self._build_tabpy_state()

        self.python_service = PythonServiceHandler(PythonService())
        self.settings["compress_response"] = True
        self.settings[SettingsParameters.StaticPath] = os.path.abspath(
            self.settings[SettingsParameters.StaticPath]
        )
        logger.debug(
            f"Static pages folder set to "
            f'"{self.settings[SettingsParameters.StaticPath]}"'
        )

        # Set subdirectory from config if applicable
        if state_config.has_option("Service Info", "Subdirectory"):
            self.subdirectory = "/" + state_config.get("Service Info", "Subdirectory")

        # If passwords file specified load credentials
        if ConfigParameters.TABPY_PWD_FILE in self.settings:
            if not self._parse_pwd_file():
                msg = (
                    "Failed to read passwords file "
                    f"{self.settings[ConfigParameters.TABPY_PWD_FILE]}"
                )
                logger.critical(msg)
                raise RuntimeError(msg)
        else:
            self._handle_configuration_without_authentication()

        features = self._get_features()
        self.settings[SettingsParameters.ApiVersions] = {"v1": {"features": features}}

        # Any value other than "false" (case-insensitive) enables the context log.
        self.settings[SettingsParameters.LogRequestContext] = (
            self.settings[SettingsParameters.LogRequestContext].lower() != "false"
        )
        call_context_state = (
            "enabled"
            if self.settings[SettingsParameters.LogRequestContext]
            else "disabled"
        )
        logger.info(f"Call context logging is {call_context_state}")

    def _validate_transfer_protocol_settings(self):
        """
        Check the transfer protocol and, for "https", the certificate and key.

        :raises RuntimeError: if the protocol is missing or is neither
            "http" nor "https", or if for "https" the certificate/key
            settings are missing or do not point to existing files.
        """
        if SettingsParameters.TransferProtocol not in self.settings:
            msg = "Missing transfer protocol information."
            logger.critical(msg)
            raise RuntimeError(msg)

        protocol = self.settings[SettingsParameters.TransferProtocol]

        if protocol == "http":
            return

        if protocol != "https":
            msg = f"Unsupported transfer protocol: {protocol}"
            logger.critical(msg)
            raise RuntimeError(msg)

        self._validate_cert_key_state(
            "The parameter(s) {} must be set.",
            SettingsParameters.CertificateFile in self.settings,
            SettingsParameters.KeyFile in self.settings,
        )
        cert = self.settings[SettingsParameters.CertificateFile]

        self._validate_cert_key_state(
            "The parameter(s) {} must point to an existing file.",
            os.path.isfile(cert),
            os.path.isfile(self.settings[SettingsParameters.KeyFile]),
        )
        tabpy.tabpy_server.app.util.validate_cert(cert)

    @staticmethod
    def _validate_cert_key_state(msg, cert_valid, key_valid):
        """
        Raise when the certificate and/or key check failed.

        :param msg: message template with one ``{}`` placeholder that
            receives the offending parameter name(s).
        :param cert_valid: result of the certificate check.
        :param key_valid: result of the key check.
        :raises RuntimeError: if either check is false.
        """
        cert_and_key_param = (
            f"{ConfigParameters.TABPY_CERTIFICATE_FILE} and "
            f"{ConfigParameters.TABPY_KEY_FILE}"
        )
        https_error = "Error using HTTPS: "
        err = None
        if not cert_valid and not key_valid:
            err = https_error + msg.format(cert_and_key_param)
        elif not cert_valid:
            err = https_error + msg.format(ConfigParameters.TABPY_CERTIFICATE_FILE)
        elif not key_valid:
            err = https_error + msg.format(ConfigParameters.TABPY_KEY_FILE)

        if err is not None:
            logger.critical(err)
            raise RuntimeError(err)

    def _parse_pwd_file(self):
        """
        Load credentials from the configured passwords file.

        :return: ``True`` when the file was parsed and yielded at least one
            credential, ``False`` otherwise.
        """
        succeeded, self.credentials = parse_pwd_file(
            self.settings[ConfigParameters.TABPY_PWD_FILE]
        )

        if succeeded and len(self.credentials) == 0:
            logger.error("No credentials found")
            succeeded = False

        return succeeded

    def _handle_configuration_without_authentication(self):
        """
        Log, or interactively confirm, that no authentication is configured.

        Without ``show_no_auth_warning`` only an info message is logged.
        Otherwise the user is prompted on stdin; any answer other than
        "y" (case-insensitive) terminates the process via ``sys.exit()``.
        """
        std_no_auth_msg = "Password file is not specified: Authentication is not enabled"

        if not self.show_no_auth_warning:
            logger.info(std_no_auth_msg)
            return

        confirm_no_auth_msg = "\nWARNING: No username/password authentication is configured for this TabPy server. "

        if self.settings[SettingsParameters.EvaluateEnabled]:
            confirm_no_auth_msg += ("Since TABPY_EVALUATE_ENABLE is on, unauthenticated users can execute "
                "remote code on this machine, posing significant security risks. ")

        confirm_no_auth_msg += ("Proceeding in this insecure state is strongly discouraged.\n\n"
            "Are you sure you want to continue without authentication? (y/N): ")

        confirm_no_auth_input = input(confirm_no_auth_msg)

        # Accept "y"/"Y" with surrounding whitespace; everything else aborts.
        if confirm_no_auth_input.strip().lower() == "y":
            logger.info(std_no_auth_msg)
        else:
            print("\nAborting start up. To enable authentication for your TabPy server, see "
                "https://github.com/tableau/TabPy/blob/master/docs/server-config.md#authentication.")
            # sys.exit() rather than the site-provided exit(), which is not
            # guaranteed to be defined.
            sys.exit()

    def _get_features(self):
        """
        Describe the enabled features for the API version info.

        :return: dict with "evaluate_enabled", "gzip_enabled" and
            "arrow_enabled", plus "authentication" when a passwords file
            is configured.
        """
        features = {}

        # Check for auth
        if ConfigParameters.TABPY_PWD_FILE in self.settings:
            features["authentication"] = {
                "required": True,
                "methods": {"basic-auth": {}},
            }

        features["evaluate_enabled"] = self.settings[SettingsParameters.EvaluateEnabled]
        features["gzip_enabled"] = self.settings[SettingsParameters.GzipEnabled]
        features["arrow_enabled"] = self.settings[SettingsParameters.ArrowEnabled]
        return features

    def _build_tabpy_state(self):
        """
        Load the state file, creating it from the packaged template if absent.

        :return: tuple ``(state_config, tabpy_state)`` where ``state_config``
            is the value returned by ``_get_state_from_file`` and
            ``tabpy_state`` is the ``TabPyState`` built from it.
        """
        pkg_path = os.path.dirname(tabpy.__file__)
        state_file_dir = self.settings[SettingsParameters.StateFilePath]
        state_file_path = os.path.join(state_file_dir, "state.ini")
        if not os.path.isfile(state_file_path):
            state_file_template_path = os.path.join(
                pkg_path, "tabpy_server", "state.ini.template"
            )
            logger.debug(
                f"File {state_file_path} not found, creating from "
                f"template {state_file_template_path}..."
            )
            shutil.copy(state_file_template_path, state_file_path)

        logger.info(f"Loading state from state file {state_file_path}")
        state_config = _get_state_from_file(state_file_dir)
        return state_config, TabPyState(config=state_config, settings=self.settings)
530
531
532
# Override _read_body to allow content with size exceeding max_body_size
# This enables proper handling of 413 errors in base_handler
def _read_body_allow_max_size(self, code, headers, delegate):
    """
    Replacement for ``HTTP1Connection._read_body``.

    When the declared Content-Length exceeds ``self._max_body_size`` the
    body is not read and ``None`` is returned; every other request is
    passed to ``self.original_read_body`` unchanged.

    :param self: the connection object; must provide ``_max_body_size``
        and ``original_read_body``.
    :param code: forwarded untouched to ``original_read_body``.
    :param headers: mapping of request headers.
    :param delegate: forwarded untouched to ``original_read_body``.
    :return: ``None`` for an oversized body, otherwise whatever
        ``original_read_body`` returns.
    """
    if "Content-Length" in headers:
        try:
            content_length = int(headers["Content-Length"])
        except ValueError:
            # Malformed header: let the original reader validate and
            # report it instead of failing here with a bare ValueError.
            return self.original_read_body(code, headers, delegate)
        if content_length > self._max_body_size:
            return
    return self.original_read_body(code, headers, delegate)
540
541
# Install the override. Capture the original reader only once: if this code
# runs again (e.g. on module reload), saving the already-patched function as
# "original" would make the override call itself recursively.
if not hasattr(HTTP1Connection, "original_read_body"):
    HTTP1Connection.original_read_body = HTTP1Connection._read_body
HTTP1Connection._read_body = _read_body_allow_max_size
543