Passed
Pull Request — master (#638)
by
unknown
13:06
created

TabPyApp._create_tornado_web_app()   B

Complexity

Conditions 4

Size

Total Lines 71
Code Lines 52

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 16
CRAP Score 4.2159

Importance

Changes 0
Metric Value
eloc 52
dl 0
loc 71
ccs 16
cts 21
cp 0.7619
rs 8.5709
c 0
b 0
f 0
cc 4
nop 1
crap 4.2159

How to fix: Long Method

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

- Extract Method

1 1
import concurrent.futures
2 1
import configparser
3 1
import logging
4 1
import multiprocessing
5 1
import os
6 1
import shutil
7 1
import signal
8 1
import ssl
9 1
import sys
10 1
import _thread
11
12 1
import tornado
13 1
from tornado.http1connection import HTTP1Connection
14
15 1
import tabpy
16 1
import tabpy.tabpy_server.app.arrow_server as pa
17 1
from tabpy.tabpy import __version__
18 1
from tabpy.tabpy_server.app.app_parameters import ConfigParameters, SettingsParameters
19 1
from tabpy.tabpy_server.app.util import parse_pwd_file
20 1
from tabpy.tabpy_server.handlers.basic_auth_server_middleware_factory import BasicAuthServerMiddlewareFactory
21 1
from tabpy.tabpy_server.handlers.no_op_auth_handler import NoOpAuthHandler
22 1
from tabpy.tabpy_server.management.state import TabPyState
23 1
from tabpy.tabpy_server.management.util import _get_state_from_file
24 1
from tabpy.tabpy_server.psws.callbacks import init_model_evaluator, init_ps_server
25 1
from tabpy.tabpy_server.psws.python_service import PythonService, PythonServiceHandler
26 1
from tabpy.tabpy_server.handlers import (
27
    EndpointHandler,
28
    EndpointsHandler,
29
    EvaluationPlaneHandler,
30
    EvaluationPlaneDisabledHandler,
31
    QueryPlaneHandler,
32
    ServiceInfoHandler,
33
    StatusHandler,
34
    UploadDestinationHandler,
35
)
36
37 1
# Logger named after this module; all diagnostics below go through it.
logger = logging.getLogger(name=__name__)
38
39 1
def _init_asyncio_patch():
40
    """
41
    Select compatible event loop for Tornado 5+.
42
    As of Python 3.8, the default event loop on Windows is `proactor`,
43
    however Tornado requires the old default "selector" event loop.
44
    As Tornado has decided to leave this to users to set, MkDocs needs
45
    to set it. See https://github.com/tornadoweb/tornado/issues/2608.
46
    """
47 1
    if sys.platform.startswith("win") and sys.version_info >= (3, 8):
48
        import asyncio
49
        try:
50
            from asyncio import WindowsSelectorEventLoopPolicy
51
        except ImportError:
52
            pass  # Can't assign a policy which doesn't exist.
53
        else:
54
            if not isinstance(asyncio.get_event_loop_policy(), WindowsSelectorEventLoopPolicy):
55
                asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
56
57
58 1
class TabPyApp:
    """
    TabPy application class for keeping context like settings, state, etc.

    Constructing an instance does two things:
    - it loads the logging configuration;
    - it parses the TabPy config file (see ``_parse_config``).

    ``run()`` then builds the Tornado application and starts listening. It
    blocks in the Tornado IOLoop.
    """

    # Settings from the config file / environment / defaults (see _parse_config).
    settings = {}
    # URL prefix for every route ("Service Info" -> "Subdirectory" in state.ini).
    subdirectory = ""
    # TabPyState built from state.ini.
    tabpy_state = None
    # PythonServiceHandler wrapping a PythonService instance.
    python_service = None
    # Credentials loaded from TABPY_PWD_FILE, if one is configured.
    credentials = {}
    # Arrow Flight server, created on a background thread when Arrow is enabled.
    arrow_server = None
    # Maximum accepted request size in bytes (TABPY_MAX_REQUEST_SIZE_MB * 1024 * 1024).
    max_request_size = None

    def __init__(self, config_file, disable_auth_warning=True):
        """
        Load logging configuration and parse TabPy settings.

        Parameters
        ----------
        config_file : str or None
            Path to the TabPy config file. When None, ``common/default.conf``
            next to this module's parent directory is used.
        disable_auth_warning : bool
            When true, starting without a password file only logs an info
            message. Otherwise the user is asked interactively to confirm.
        """
        self.disable_auth_warning = disable_auth_warning
        if config_file is None:
            config_file = os.path.join(
                os.path.dirname(__file__), os.path.pardir, "common", "default.conf"
            )

        if os.path.isfile(config_file):
            try:
                from logging import config

                config.fileConfig(config_file, disable_existing_loggers=False)
            except KeyError:
                # fileConfig raises KeyError when the logging sections are
                # missing; fall back to a basic DEBUG configuration.
                logging.basicConfig(level=logging.DEBUG)

        self._parse_config(config_file)

    def _initialize_ssl_context(self):
        """
        Build the server-side SSL context used when serving HTTPS.

        Loads the configured certificate/key pair and applies the configured
        minimum TLS version. An unrecognized version name is logged as a
        warning and the library default is kept.
        """
        ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)

        ssl_context.load_cert_chain(
            certfile=self.settings[SettingsParameters.CertificateFile],
            keyfile=self.settings[SettingsParameters.KeyFile],
        )

        min_tls = self.settings.get(SettingsParameters.MinimumTLSVersion)
        try:
            ssl_context.minimum_version = ssl.TLSVersion[min_tls]
            logger.info(f"Setting minimum TLS version to: {min_tls}")
        except KeyError:
            logger.warning(f"Unrecognized value for TABPY_MINIMUM_TLS_VERSION: {min_tls}")

        return ssl_context

    def _get_tls_certificates(self, config):
        """
        Read the certificate chain and private key configured in ``config``.

        Returns
        -------
        list of (bytes, bytes)
            A one-element list holding ``(cert_chain, private_key)``, both
            read in binary mode.
        """
        cert = config[SettingsParameters.CertificateFile]
        key = config[SettingsParameters.KeyFile]
        with open(cert, "rb") as cert_file:
            tls_cert_chain = cert_file.read()
        with open(key, "rb") as key_file:
            tls_private_key = key_file.read()
        return [(tls_cert_chain, tls_private_key)]

    def _get_arrow_server(self, config):
        """
        Create the Arrow Flight server described by ``config``.

        The server binds 0.0.0.0 on the configured Arrow Flight port. It uses
        ``grpc+tls`` (with the configured certificate/key) when the transfer
        protocol is https, and ``grpc+tcp`` otherwise. When the
        "authentication" feature is enabled, basic-auth middleware is
        attached, built from the password file.
        """
        verify_client = None
        tls_certificates = None
        scheme = "grpc+tcp"
        if config[SettingsParameters.TransferProtocol] == "https":
            scheme = "grpc+tls"
            tls_certificates = self._get_tls_certificates(config)

        host = "0.0.0.0"
        port = config.get(SettingsParameters.ArrowFlightPort)
        location = "{}://{}:{}".format(scheme, host, port)

        auth_middleware = None
        if "authentication" in config[SettingsParameters.ApiVersions]["v1"]["features"]:
            _, creds = parse_pwd_file(config[ConfigParameters.TABPY_PWD_FILE])
            auth_middleware = {"basic": BasicAuthServerMiddlewareFactory(creds)}

        return pa.FlightServer(
            host,
            location,
            tls_certificates=tls_certificates,
            verify_client=verify_client,
            auth_handler=NoOpAuthHandler(),
            middleware=auth_middleware,
        )

    def run(self):
        """
        Start TabPy and block in the Tornado IOLoop.

        Steps:
        1. Build the web application and initialize the model evaluator.
        2. Listen on the configured port over HTTP or HTTPS.
        3. When enabled, launch the Arrow Flight server on a background
           thread.

        Raises
        ------
        RuntimeError
            If the configured transfer protocol is neither http nor https.
        """
        application = self._create_tornado_web_app()

        init_model_evaluator(self.settings, self.tabpy_state, self.python_service)

        protocol = self.settings[SettingsParameters.TransferProtocol]
        ssl_options = None
        if protocol == "https":
            ssl_options = self._initialize_ssl_context()
        elif protocol != "http":
            msg = f"Unsupported transfer protocol {protocol}."
            logger.critical(msg)
            raise RuntimeError(msg)

        settings = {}
        if self.settings[SettingsParameters.GzipEnabled] is True:
            settings["decompress_request"] = True

        application.listen(
            self.settings[SettingsParameters.Port],
            ssl_options=ssl_options,
            max_buffer_size=self.max_request_size,
            max_body_size=self.max_request_size,
            **settings,
        )

        logger.info(
            "Web service listening on port "
            f"{str(self.settings[SettingsParameters.Port])}"
        )

        if self.settings[SettingsParameters.ArrowEnabled]:
            self._start_arrow_server_thread()

        tornado.ioloop.IOLoop.instance().start()

    def _start_arrow_server_thread(self):
        """
        Build and start the Arrow Flight server on a background thread.

        Failures are logged as critical in two places:
        - when the thread cannot be created;
        - when building or starting the server fails inside the thread.

        The second case needs its own handler: an exception raised on the
        new thread never propagates back to this method.
        """

        def start_pyarrow():
            try:
                self.arrow_server = self._get_arrow_server(self.settings)
                pa.start(self.arrow_server)
            except Exception as e:
                logger.critical(f"Failed to start PyArrow server: {e}")

        try:
            _thread.start_new_thread(start_pyarrow, ())
        except Exception as e:
            logger.critical(f"Failed to start PyArrow server: {e}")

    def _create_tornado_web_app(self):
        """
        Create the Tornado application that serves the TabPy REST API.

        Steps:
        1. Select a Tornado-compatible asyncio policy.
        2. Run the one-time ``init_ps_server`` initialization on the IOLoop.
        3. Build the route table.
        4. Install the SIGINT-driven shutdown handling.

        Returns
        -------
        tornado.web.Application
            The application. It is not listening yet (``run`` does that).
        """
        # Choose the event loop policy before the IOLoop is used for the first
        # time. Switching policies after run_sync could leave initialization
        # on a different event loop than the one that later serves requests
        # (Windows, Python 3.8+).
        _init_asyncio_patch()

        logger.info("Initializing TabPy...")
        tornado.ioloop.IOLoop.instance().run_sync(
            lambda: init_ps_server(self.settings, self.tabpy_state)
        )
        logger.info("Done initializing TabPy.")

        executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=multiprocessing.cpu_count()
        )

        app_class = self._tornado_app_class()
        application = app_class(
            self._build_routes(executor),
            debug=False,
            **self.settings,
        )

        self._install_shutdown_handlers(application)
        return application

    @staticmethod
    def _tornado_app_class():
        """
        Return a ``tornado.web.Application`` subclass with shutdown hooks.

        The class is defined at call time, so ``tornado.web`` is only
        resolved when an application is actually built.
        """

        class TabPyTornadoApp(tornado.web.Application):
            is_closing = False

            def signal_handler(self, signum, _):
                """Signal callback: mark the application for shutdown."""
                logger.critical(f"Exiting on signal {signum}...")
                self.is_closing = True

            def try_exit(self):
                """Periodic callback: stop the IOLoop once shutdown was requested."""
                if self.is_closing:
                    tornado.ioloop.IOLoop.instance().stop()
                    logger.info("Shutting down TabPy...")

        return TabPyTornadoApp

    def _build_routes(self, executor):
        """
        Return the Tornado route table, each path prefixed with ``self.subdirectory``.

        ``executor`` is passed to the /evaluate handler. The handler class
        serving /evaluate depends on the EvaluateEnabled setting. Tornado
        matches rules in order, so the catch-all static-file route must stay
        last; otherwise it would shadow the API routes.
        """
        prefix = self.subdirectory
        evaluate_handler = (
            EvaluationPlaneHandler
            if self.settings[SettingsParameters.EvaluateEnabled]
            else EvaluationPlaneDisabledHandler
        )
        return [
            (prefix + r"/query/([^/]+)", QueryPlaneHandler, dict(app=self)),
            (prefix + r"/status", StatusHandler, dict(app=self)),
            (prefix + r"/info", ServiceInfoHandler, dict(app=self)),
            (prefix + r"/endpoints", EndpointsHandler, dict(app=self)),
            (prefix + r"/endpoints/([^/]+)?", EndpointHandler, dict(app=self)),
            (
                prefix + r"/evaluate",
                evaluate_handler,
                dict(executor=executor, app=self),
            ),
            (
                prefix + r"/configurations/endpoint_upload_destination",
                UploadDestinationHandler,
                dict(app=self),
            ),
            (
                prefix + r"/(.*)",
                tornado.web.StaticFileHandler,
                dict(
                    path=self.settings[SettingsParameters.StaticPath],
                    default_filename="index.html",
                ),
            ),
        ]

    @staticmethod
    def _install_shutdown_handlers(application):
        """
        Route SIGINT to ``application`` and poll its ``try_exit`` every 500 ms.

        Registered exactly once. A second PeriodicCallback would only repeat
        the stop() call and the "Shutting down" log on exit.
        """
        signal.signal(signal.SIGINT, application.signal_handler)
        tornado.ioloop.PeriodicCallback(application.try_exit, 500).start()

    def _set_parameter(self, parser, settings_key, config_key, default_val, parse_function):
        """
        Populate ``self.settings[settings_key]`` from the parser or a default.

        Precedence:
        1. The value of ``config_key`` in the [TabPy] section, when
           ``config_key`` is not None and the option is present. It is read
           with ``parse_function``, which falls back to ``parser.get``.
        2. Otherwise ``default_val``, when it is not None.
        3. Otherwise the key is left unset and a debug message is logged.
        """
        key_is_set = False

        if (
            config_key is not None
            and parser.has_section("TabPy")
            and parser.has_option("TabPy", config_key)
        ):
            if parse_function is None:
                parse_function = parser.get
            self.settings[settings_key] = parse_function("TabPy", config_key)
            key_is_set = True
            logger.debug(
                f"Parameter {settings_key} set to "
                f'"{self.settings[settings_key]}" '
                "from config file or environment variable"
            )

        if not key_is_set and default_val is not None:
            self.settings[settings_key] = default_val
            key_is_set = True
            logger.debug(
                f"Parameter {settings_key} set to "
                f'"{self.settings[settings_key]}" '
                "from default value"
            )

        if not key_is_set:
            logger.debug(f"Parameter {settings_key} is not set")

    def _parse_config(self, config_file):
        """Provide consistent mechanism for pulling in configuration.

        Attempt to retain backward compatibility for
        existing implementations by grabbing port
        setting from CLI first.

        Take settings in the following order:

        1. CLI arguments if present
        2. config file
        3. OS environment variables (for ease of
           setting defaults if not present)
        4. current defaults if a setting is not present in any location

        Additionally provide similar configuration capabilities in between
        config file and environment variables.
        For consistency use the same variable name in the config file as
        in the os environment.
        For naming standards use all capitals and start with 'TABPY_'

        Raises
        ------
        RuntimeError
            If the transfer protocol/certificate settings are invalid, or a
            configured password file cannot be read or contains no credentials.
        """
        self.settings = {}
        self.subdirectory = ""
        self.tabpy_state = None
        self.python_service = None
        self.credentials = {}

        pkg_path = os.path.dirname(tabpy.__file__)

        # os.environ becomes the parser's DEFAULT section, so environment
        # variables are visible to has_option/get for the [TabPy] section.
        # Values from the config file take precedence over them.
        parser = configparser.ConfigParser(os.environ)
        logger.info(f"Parsing config file {config_file}")

        file_exists = False
        if os.path.isfile(config_file):
            try:
                with open(config_file, 'r') as f:
                    parser.read_string(f.read())
                    file_exists = True
            except Exception as e:
                # Best effort: continue with defaults below, but record why.
                # Silently ignoring a malformed file would also silently drop
                # settings such as TABPY_PWD_FILE.
                logger.warning(f"Failed to read config file {config_file}: {e}")

        if not file_exists:
            logger.warning(
                f"Unable to open config file {config_file}, "
                "using default settings."
            )

        # (settings key, config/env key, default value, parser method or None)
        settings_parameters = [
            (SettingsParameters.Port, ConfigParameters.TABPY_PORT, 9004, None),
            (SettingsParameters.ServerVersion, None, __version__, None),
            (SettingsParameters.EvaluateEnabled, ConfigParameters.TABPY_EVALUATE_ENABLE,
             True, parser.getboolean),
            (SettingsParameters.EvaluateTimeout, ConfigParameters.TABPY_EVALUATE_TIMEOUT,
             30, parser.getfloat),
            (SettingsParameters.UploadDir, ConfigParameters.TABPY_QUERY_OBJECT_PATH,
             os.path.join(pkg_path, "tmp", "query_objects"), None),
            (SettingsParameters.TransferProtocol, ConfigParameters.TABPY_TRANSFER_PROTOCOL,
             "http", None),
            (SettingsParameters.CertificateFile, ConfigParameters.TABPY_CERTIFICATE_FILE,
             None, None),
            (SettingsParameters.KeyFile, ConfigParameters.TABPY_KEY_FILE, None, None),
            (SettingsParameters.MinimumTLSVersion, ConfigParameters.TABPY_MINIMUM_TLS_VERSION,
             "TLSv1_2", None),
            (SettingsParameters.StateFilePath, ConfigParameters.TABPY_STATE_PATH,
             os.path.join(pkg_path, "tabpy_server"), None),
            (SettingsParameters.StaticPath, ConfigParameters.TABPY_STATIC_PATH,
             os.path.join(pkg_path, "tabpy_server", "static"), None),
            (ConfigParameters.TABPY_PWD_FILE, ConfigParameters.TABPY_PWD_FILE, None, None),
            (SettingsParameters.LogRequestContext, ConfigParameters.TABPY_LOG_DETAILS,
             "false", None),
            (SettingsParameters.MaxRequestSizeInMb, ConfigParameters.TABPY_MAX_REQUEST_SIZE_MB,
             100, None),
            (SettingsParameters.GzipEnabled, ConfigParameters.TABPY_GZIP_ENABLE,
             True, parser.getboolean),
            (SettingsParameters.ArrowEnabled, ConfigParameters.TABPY_ARROW_ENABLE,
             False, parser.getboolean),
            (SettingsParameters.ArrowFlightPort, ConfigParameters.TABPY_ARROWFLIGHT_PORT,
             13622, parser.getint),
        ]

        for setting, parameter, default_val, parse_function in settings_parameters:
            self._set_parameter(parser, setting, parameter, default_val, parse_function)

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.settings[SettingsParameters.UploadDir], exist_ok=True)

        # set and validate transfer protocol
        self.settings[SettingsParameters.TransferProtocol] = self.settings[
            SettingsParameters.TransferProtocol
        ].lower()

        self._validate_transfer_protocol_settings()

        # Set max request size in bytes
        self.max_request_size = (
            int(self.settings[SettingsParameters.MaxRequestSizeInMb]) * 1024 * 1024
        )
        logger.info(f"Setting max request size to {self.max_request_size} bytes")

        # if state.ini does not exist try and create it - remove
        # last dependence on batch/shell script
        self.settings[SettingsParameters.StateFilePath] = os.path.realpath(
            os.path.normpath(
                os.path.expanduser(self.settings[SettingsParameters.StateFilePath])
            )
        )
        state_config, self.tabpy_state = self._build_tabpy_state()

        self.python_service = PythonServiceHandler(PythonService())
        self.settings["compress_response"] = True
        self.settings[SettingsParameters.StaticPath] = os.path.abspath(
            self.settings[SettingsParameters.StaticPath]
        )
        logger.debug(
            f"Static pages folder set to "
            f'"{self.settings[SettingsParameters.StaticPath]}"'
        )

        # Set subdirectory from config if applicable
        if state_config.has_option("Service Info", "Subdirectory"):
            self.subdirectory = "/" + state_config.get("Service Info", "Subdirectory")

        # If passwords file specified load credentials
        if ConfigParameters.TABPY_PWD_FILE in self.settings:
            if not self._parse_pwd_file():
                msg = (
                    "Failed to read passwords file "
                    f"{self.settings[ConfigParameters.TABPY_PWD_FILE]}"
                )
                logger.critical(msg)
                raise RuntimeError(msg)
        else:
            self._handle_configuration_without_authentication()

        features = self._get_features()
        self.settings[SettingsParameters.ApiVersions] = {"v1": {"features": features}}

        # Any value other than "false" (case-insensitive) enables request
        # context logging.
        self.settings[SettingsParameters.LogRequestContext] = (
            self.settings[SettingsParameters.LogRequestContext].lower() != "false"
        )
        call_context_state = (
            "enabled"
            if self.settings[SettingsParameters.LogRequestContext]
            else "disabled"
        )
        logger.info(f"Call context logging is {call_context_state}")

    def _validate_transfer_protocol_settings(self):
        """
        Check that the transfer protocol and its HTTPS prerequisites are valid.

        "http" needs nothing more. "https" requires the certificate and key
        settings to be present and to point to existing files, and the
        certificate must pass ``validate_cert``.

        Raises
        ------
        RuntimeError
            If the protocol is missing or unsupported, or the HTTPS
            prerequisites are not met.
        """
        if SettingsParameters.TransferProtocol not in self.settings:
            msg = "Missing transfer protocol information."
            logger.critical(msg)
            raise RuntimeError(msg)

        protocol = self.settings[SettingsParameters.TransferProtocol]

        if protocol == "http":
            return

        if protocol != "https":
            msg = f"Unsupported transfer protocol: {protocol}"
            logger.critical(msg)
            raise RuntimeError(msg)

        self._validate_cert_key_state(
            "The parameter(s) {} must be set.",
            SettingsParameters.CertificateFile in self.settings,
            SettingsParameters.KeyFile in self.settings,
        )
        cert = self.settings[SettingsParameters.CertificateFile]

        self._validate_cert_key_state(
            "The parameter(s) {} must point to an existing file.",
            os.path.isfile(cert),
            os.path.isfile(self.settings[SettingsParameters.KeyFile]),
        )
        tabpy.tabpy_server.app.util.validate_cert(cert)

    @staticmethod
    def _validate_cert_key_state(msg, cert_valid, key_valid):
        """
        Raise when the certificate and/or key check failed.

        ``msg`` is a format string with one ``{}`` placeholder. It is filled
        with the name(s) of the failing config parameter(s) and prefixed with
        "Error using HTTPS: ".

        Raises
        ------
        RuntimeError
            If ``cert_valid`` or ``key_valid`` is false.
        """
        cert_and_key_param = (
            f"{ConfigParameters.TABPY_CERTIFICATE_FILE} and "
            f"{ConfigParameters.TABPY_KEY_FILE}"
        )
        https_error = "Error using HTTPS: "
        err = None
        if not cert_valid and not key_valid:
            err = https_error + msg.format(cert_and_key_param)
        elif not cert_valid:
            err = https_error + msg.format(ConfigParameters.TABPY_CERTIFICATE_FILE)
        elif not key_valid:
            err = https_error + msg.format(ConfigParameters.TABPY_KEY_FILE)

        if err is not None:
            logger.critical(err)
            raise RuntimeError(err)

    def _parse_pwd_file(self):
        """
        Load ``self.credentials`` from the configured password file.

        Returns
        -------
        bool
            True when the file was parsed and contains at least one
            credential entry; False otherwise.
        """
        succeeded, self.credentials = parse_pwd_file(
            self.settings[ConfigParameters.TABPY_PWD_FILE]
        )

        if succeeded and len(self.credentials) == 0:
            logger.error("No credentials found")
            succeeded = False

        return succeeded

    def _handle_configuration_without_authentication(self):
        """
        Decide whether to continue when no password file is configured.

        With ``disable_auth_warning`` set, this only logs an info message.
        Otherwise it prints a security warning (stronger when evaluation is
        enabled) and asks for confirmation. Any answer other than "y" aborts
        start-up via ``SystemExit``.
        """
        std_no_auth_msg = "Password file is not specified: Authentication is not enabled"

        if self.disable_auth_warning:
            logger.info(std_no_auth_msg)
            return

        confirm_no_auth_msg = "\nWARNING: This TabPy server is not currently configured for username/password authentication. "

        if self.settings[SettingsParameters.EvaluateEnabled]:
            confirm_no_auth_msg += (
                "This means that, because the TABPY_EVALUATE_ENABLE feature is enabled, there is "
                "the potential that unauthenticated individuals may be able to remotely execute code on this machine. "
            )

        confirm_no_auth_msg += (
            "We strongly advise against proceeding without authentication as it poses a significant security risk.\n\n"
            "Do you wish to proceed without authentication? (y/N): "
        )

        confirm_no_auth_input = input(confirm_no_auth_msg)

        # Accept "y" in either case, ignoring surrounding whitespace. Anything
        # else, including an empty answer, takes the "N" default.
        if confirm_no_auth_input.strip().lower() == "y":
            logger.info(std_no_auth_msg)
        else:
            print(
                "\nAborting start up. To enable authentication for your TabPy server, see "
                "https://github.com/tableau/TabPy/blob/master/docs/server-config.md#authentication."
            )
            # sys.exit rather than the site-module ``exit`` helper, which is
            # not guaranteed to exist (e.g. ``python -S``, frozen executables).
            sys.exit()

    def _get_features(self):
        """
        Return the feature flags advertised under API version "v1".

        The "authentication" entry (basic-auth, required) is present only
        when a password file is configured. The evaluate, gzip and arrow
        flags mirror the corresponding settings.
        """
        features = {}

        # Check for auth
        if ConfigParameters.TABPY_PWD_FILE in self.settings:
            features["authentication"] = {
                "required": True,
                "methods": {"basic-auth": {}},
            }

        features["evaluate_enabled"] = self.settings[SettingsParameters.EvaluateEnabled]
        features["gzip_enabled"] = self.settings[SettingsParameters.GzipEnabled]
        features["arrow_enabled"] = self.settings[SettingsParameters.ArrowEnabled]
        return features

    def _build_tabpy_state(self):
        """
        Load state.ini from the configured state directory.

        If state.ini is missing, it is first created by copying the bundled
        ``tabpy_server/state.ini.template``.

        Returns
        -------
        tuple
            ``(state_config, TabPyState)``. ``state_config`` is the parsed
            state file, as returned by ``_get_state_from_file``.
        """
        pkg_path = os.path.dirname(tabpy.__file__)
        state_file_dir = self.settings[SettingsParameters.StateFilePath]
        state_file_path = os.path.join(state_file_dir, "state.ini")
        if not os.path.isfile(state_file_path):
            state_file_template_path = os.path.join(
                pkg_path, "tabpy_server", "state.ini.template"
            )
            logger.debug(
                f"File {state_file_path} not found, creating from "
                f"template {state_file_template_path}..."
            )
            shutil.copy(state_file_template_path, state_file_path)

        logger.info(f"Loading state from state file {state_file_path}")
        tabpy_state = _get_state_from_file(state_file_dir)
        return tabpy_state, TabPyState(config=tabpy_state, settings=self.settings)
547
548
549
# Override _read_body to allow content with size exceeding max_body_size
550
# This enables proper handling of 413 errors in base_handler
551 1
def _read_body_allow_max_size(self, code, headers, delegate):
552 1
    if "Content-Length" in headers:
553 1
        content_length = int(headers["Content-Length"])
554 1
        if content_length > self._max_body_size:
555
            return
556 1
    return self.original_read_body(code, headers, delegate)
557
558 1
# Keep a handle on Tornado's original implementation so the override can
# delegate to it. The guard prevents a re-import or reload of this module
# from capturing the already-patched function as the "original", which would
# make _read_body_allow_max_size call itself forever.
if not hasattr(HTTP1Connection, "original_read_body"):
    HTTP1Connection.original_read_body = HTTP1Connection._read_body
HTTP1Connection._read_body = _read_body_allow_max_size
560