Passed
Push — master ( 5716a6...96aa26 )
by
unknown
12:43 queued 15s
created

TabPyApp._create_tornado_web_app()   B

Complexity

Conditions 4

Size

Total Lines 71
Code Lines 52

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 16
CRAP Score 4.2159

Importance

Changes 0
Metric Value
eloc 52
dl 0
loc 71
ccs 16
cts 21
cp 0.7619
rs 8.5709
c 0
b 0
f 0
cc 4
nop 1
crap 4.2159

How to fix: Long Method

Long Method

Small methods make your code easier to understand, especially when combined with a good name. Moreover, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

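Commonly applied refactorings include Extract Method: move a cohesive part of the method body into a new, well-named method and call it from the original one.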

1 1
import concurrent.futures
2 1
import configparser
3 1
import logging
4 1
import multiprocessing
5 1
import os
6 1
import shutil
7 1
import signal
8 1
import sys
9 1
import _thread
10
11 1
import tornado
12 1
from tornado.http1connection import HTTP1Connection
13
14 1
import tabpy
15 1
import tabpy.tabpy_server.app.arrow_server as pa
16 1
from tabpy.tabpy import __version__
17 1
from tabpy.tabpy_server.app.app_parameters import ConfigParameters, SettingsParameters
18 1
from tabpy.tabpy_server.app.util import parse_pwd_file
19 1
from tabpy.tabpy_server.handlers.basic_auth_server_middleware_factory import BasicAuthServerMiddlewareFactory
20 1
from tabpy.tabpy_server.handlers.no_op_auth_handler import NoOpAuthHandler
21 1
from tabpy.tabpy_server.management.state import TabPyState
22 1
from tabpy.tabpy_server.management.util import _get_state_from_file
23 1
from tabpy.tabpy_server.psws.callbacks import init_model_evaluator, init_ps_server
24 1
from tabpy.tabpy_server.psws.python_service import PythonService, PythonServiceHandler
25 1
from tabpy.tabpy_server.handlers import (
26
    EndpointHandler,
27
    EndpointsHandler,
28
    EvaluationPlaneHandler,
29
    EvaluationPlaneDisabledHandler,
30
    QueryPlaneHandler,
31
    ServiceInfoHandler,
32
    StatusHandler,
33
    UploadDestinationHandler,
34
)
35
36 1
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
37
38 1
def _init_asyncio_patch():
39
    """
40
    Select compatible event loop for Tornado 5+.
41
    As of Python 3.8, the default event loop on Windows is `proactor`,
42
    however Tornado requires the old default "selector" event loop.
43
    As Tornado has decided to leave this to users to set, MkDocs needs
44
    to set it. See https://github.com/tornadoweb/tornado/issues/2608.
45
    """
46 1
    if sys.platform.startswith("win") and sys.version_info >= (3, 8):
47
        import asyncio
48
        try:
49
            from asyncio import WindowsSelectorEventLoopPolicy
50
        except ImportError:
51
            pass  # Can't assign a policy which doesn't exist.
52
        else:
53
            if not isinstance(asyncio.get_event_loop_policy(), WindowsSelectorEventLoopPolicy):
54
                asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
55
56
57 1
class TabPyApp:
    """
    TabPy application class for keeping context like settings, state, etc.

    Instantiating the class configures logging and parses the configuration;
    ``run()`` then builds the Tornado application, starts listening and
    blocks in the Tornado IOLoop.
    """

    # Class-level defaults. _parse_config() rebinds settings, subdirectory,
    # tabpy_state, python_service, credentials and max_request_size on the
    # instance; run() sets arrow_server when Arrow is enabled.
    settings = {}
    subdirectory = ""
    tabpy_state = None
    python_service = None
    credentials = {}
    arrow_server = None
    max_request_size = None

    def __init__(self, config_file, disable_auth_warning=True):
        """Configure logging and parse settings from ``config_file``.

        Args:
            config_file: path to a TabPy config file; when None,
                ``../common/default.conf`` relative to this module is used.
            disable_auth_warning: when truthy, a missing password file is only
                logged; otherwise the user is asked interactively whether to
                continue without authentication.
        """
        self.disable_auth_warning = disable_auth_warning
        if config_file is None:
            config_file = os.path.join(
                os.path.dirname(__file__), os.path.pardir, "common", "default.conf"
            )

        if os.path.isfile(config_file):
            try:
                from logging import config

                config.fileConfig(config_file, disable_existing_loggers=False)
            except (KeyError, RuntimeError):
                # The file carries no usable logging configuration: a missing
                # [formatters]/[handlers] section raises KeyError, and newer
                # Python versions can raise RuntimeError for empty/invalid
                # files. Either way fall back to a basic configuration.
                logging.basicConfig(level=logging.DEBUG)

        self._parse_config(config_file)

    def _get_tls_certificates(self, config):
        """Read the TLS certificate chain and private key named in ``config``.

        Returns:
            A one-element list ``[(cert_chain_bytes, private_key_bytes)]`` as
            passed to the Arrow Flight server.
        """
        cert = config[SettingsParameters.CertificateFile]
        key = config[SettingsParameters.KeyFile]
        with open(cert, "rb") as cert_file:
            tls_cert_chain = cert_file.read()
        with open(key, "rb") as key_file:
            tls_private_key = key_file.read()
        return [(tls_cert_chain, tls_private_key)]

    def _get_arrow_server(self, config):
        """Build (but do not start) the Arrow Flight server.

        The server binds to 0.0.0.0 on the configured Arrow Flight port, uses
        TLS when the transfer protocol is "https", and installs basic-auth
        middleware when the "authentication" feature is enabled.
        """
        verify_client = None
        tls_certificates = None
        scheme = "grpc+tcp"
        if config[SettingsParameters.TransferProtocol] == "https":
            scheme = "grpc+tls"
            tls_certificates = self._get_tls_certificates(config)

        host = "0.0.0.0"
        port = config.get(SettingsParameters.ArrowFlightPort)
        location = "{}://{}:{}".format(scheme, host, port)

        auth_middleware = None
        if "authentication" in config[SettingsParameters.ApiVersions]["v1"]["features"]:
            _, creds = parse_pwd_file(config[ConfigParameters.TABPY_PWD_FILE])
            auth_middleware = {"basic": BasicAuthServerMiddlewareFactory(creds)}

        server = pa.FlightServer(
            host,
            location,
            tls_certificates=tls_certificates,
            verify_client=verify_client,
            auth_handler=NoOpAuthHandler(),
            middleware=auth_middleware,
        )
        return server

    def run(self):
        """Start the web service (and the Arrow Flight server if enabled).

        Blocks in the Tornado IOLoop until it is stopped.

        Raises:
            RuntimeError: if the transfer protocol is neither http nor https.
        """
        application = self._create_tornado_web_app()

        init_model_evaluator(self.settings, self.tabpy_state, self.python_service)

        protocol = self.settings[SettingsParameters.TransferProtocol]
        ssl_options = None
        if protocol == "https":
            ssl_options = {
                "certfile": self.settings[SettingsParameters.CertificateFile],
                "keyfile": self.settings[SettingsParameters.KeyFile],
            }
        elif protocol != "http":
            msg = f"Unsupported transfer protocol {protocol}."
            logger.critical(msg)
            raise RuntimeError(msg)

        settings = {}
        if self.settings[SettingsParameters.GzipEnabled] is True:
            settings["decompress_request"] = True

        application.listen(
            self.settings[SettingsParameters.Port],
            ssl_options=ssl_options,
            max_buffer_size=self.max_request_size,
            max_body_size=self.max_request_size,
            **settings,
        )

        logger.info(
            "Web service listening on port "
            f"{str(self.settings[SettingsParameters.Port])}"
        )

        if self.settings[SettingsParameters.ArrowEnabled]:
            def start_pyarrow():
                # Runs on a separate thread: failures here cannot reach the
                # try/except around start_new_thread below, so log them here.
                try:
                    self.arrow_server = self._get_arrow_server(self.settings)
                    pa.start(self.arrow_server)
                except Exception as e:
                    logger.critical(f"Failed to start PyArrow server: {e}")

            try:
                _thread.start_new_thread(start_pyarrow, ())
            except Exception as e:
                logger.critical(f"Failed to start PyArrow server: {e}")

        tornado.ioloop.IOLoop.instance().start()

    def _create_tornado_web_app(self):
        """Initialize server state and build the Tornado application.

        Returns:
            The configured Tornado application, with a SIGINT handler and a
            periodic shutdown check (every 500 ms) installed.
        """

        class TabPyTornadoApp(tornado.web.Application):
            is_closing = False

            def signal_handler(self, signum, _):
                # Only flag the shutdown; try_exit stops the loop from
                # inside the IOLoop.
                logger.critical(f"Exiting on signal {signum}...")
                self.is_closing = True

            def try_exit(self):
                if self.is_closing:
                    tornado.ioloop.IOLoop.instance().stop()
                    logger.info("Shutting down TabPy...")

        # An event loop policy only applies to loops created after it is set,
        # so select it before the first IOLoop.instance() call below.
        _init_asyncio_patch()

        logger.info("Initializing TabPy...")
        tornado.ioloop.IOLoop.instance().run_sync(
            lambda: init_ps_server(self.settings, self.tabpy_state)
        )
        logger.info("Done initializing TabPy.")

        executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=multiprocessing.cpu_count()
        )

        application = TabPyTornadoApp(
            self._get_tornado_routes(executor),
            debug=False,
            **self.settings,
        )

        # Register the shutdown hooks exactly once (registering them twice
        # would start two periodic callbacks).
        signal.signal(signal.SIGINT, application.signal_handler)
        tornado.ioloop.PeriodicCallback(application.try_exit, 500).start()

        return application

    def _get_tornado_routes(self, executor):
        """Return the URL routing table for the Tornado application.

        Every route is prefixed with ``self.subdirectory``. The static file
        route's pattern matches any path, so it is kept last where it cannot
        shadow the API routes.

        Args:
            executor: thread pool handed to the /evaluate handler.
        """
        prefix = self.subdirectory
        evaluate_handler = (
            EvaluationPlaneHandler
            if self.settings[SettingsParameters.EvaluateEnabled]
            else EvaluationPlaneDisabledHandler
        )
        return [
            (prefix + r"/query/([^/]+)", QueryPlaneHandler, dict(app=self)),
            (prefix + r"/status", StatusHandler, dict(app=self)),
            (prefix + r"/info", ServiceInfoHandler, dict(app=self)),
            (prefix + r"/endpoints", EndpointsHandler, dict(app=self)),
            (prefix + r"/endpoints/([^/]+)?", EndpointHandler, dict(app=self)),
            (
                prefix + r"/evaluate",
                evaluate_handler,
                dict(executor=executor, app=self),
            ),
            (
                prefix + r"/configurations/endpoint_upload_destination",
                UploadDestinationHandler,
                dict(app=self),
            ),
            (
                prefix + r"/(.*)",
                tornado.web.StaticFileHandler,
                dict(
                    path=self.settings[SettingsParameters.StaticPath],
                    default_filename="index.html",
                ),
            ),
        ]

    def _set_parameter(self, parser, settings_key, config_key, default_val, parse_function):
        """Populate ``self.settings[settings_key]`` from config or a default.

        A value for ``config_key`` in the [TabPy] section wins (converted with
        ``parse_function``, or ``parser.get`` when it is None). Otherwise
        ``default_val`` is used when it is not None. Otherwise the key is left
        unset. ``config_key`` None means "default only".
        """
        key_is_set = False

        if (
            config_key is not None
            and parser.has_section("TabPy")
            and parser.has_option("TabPy", config_key)
        ):
            if parse_function is None:
                parse_function = parser.get
            self.settings[settings_key] = parse_function("TabPy", config_key)
            key_is_set = True
            logger.debug(
                f"Parameter {settings_key} set to "
                f'"{self.settings[settings_key]}" '
                "from config file or environment variable"
            )

        if not key_is_set and default_val is not None:
            self.settings[settings_key] = default_val
            key_is_set = True
            logger.debug(
                f"Parameter {settings_key} set to "
                f'"{self.settings[settings_key]}" '
                "from default value"
            )

        if not key_is_set:
            logger.debug(f"Parameter {settings_key} is not set")

    def _parse_config(self, config_file):
        """Provide consistent mechanism for pulling in configuration.

        Attempt to retain backward compatibility for
        existing implementations by grabbing port
        setting from CLI first.

        Take settings in the following order:

        1. CLI arguments if present
        2. config file
        3. OS environment variables (for ease of
           setting defaults if not present)
        4. current defaults if a setting is not present in any location

        Additionally provide similar configuration capabilities in between
        config file and environment variables.
        For consistency use the same variable name in the config file as
        in the os environment.
        For naming standards use all capitals and start with 'TABPY_'

        Raises:
            RuntimeError: on invalid transfer protocol / certificate settings
                or an unreadable password file.
        """
        self.settings = {}
        self.subdirectory = ""
        self.tabpy_state = None
        self.python_service = None
        self.credentials = {}

        pkg_path = os.path.dirname(tabpy.__file__)

        # os.environ is passed as the parser defaults, which is how
        # environment variables become visible to _set_parameter.
        parser = configparser.ConfigParser(os.environ)
        logger.info(f"Parsing config file {config_file}")

        file_exists = False
        if os.path.isfile(config_file):
            try:
                with open(config_file, 'r') as f:
                    parser.read_string(f.read())
                    file_exists = True
            except Exception as e:
                # Best effort: fall back to defaults, but say why.
                logger.warning(f"Failed to read config file {config_file}: {e}")

        if not file_exists:
            logger.warning(
                f"Unable to open config file {config_file}, "
                "using default settings."
            )

        # (settings key, config/env key, default value, parse function)
        settings_parameters = [
            (SettingsParameters.Port, ConfigParameters.TABPY_PORT, 9004, None),
            (SettingsParameters.ServerVersion, None, __version__, None),
            (SettingsParameters.EvaluateEnabled, ConfigParameters.TABPY_EVALUATE_ENABLE,
             True, parser.getboolean),
            (SettingsParameters.EvaluateTimeout, ConfigParameters.TABPY_EVALUATE_TIMEOUT,
             30, parser.getfloat),
            (SettingsParameters.UploadDir, ConfigParameters.TABPY_QUERY_OBJECT_PATH,
             os.path.join(pkg_path, "tmp", "query_objects"), None),
            (SettingsParameters.TransferProtocol, ConfigParameters.TABPY_TRANSFER_PROTOCOL,
             "http", None),
            (SettingsParameters.CertificateFile, ConfigParameters.TABPY_CERTIFICATE_FILE,
             None, None),
            (SettingsParameters.KeyFile, ConfigParameters.TABPY_KEY_FILE, None, None),
            (SettingsParameters.StateFilePath, ConfigParameters.TABPY_STATE_PATH,
             os.path.join(pkg_path, "tabpy_server"), None),
            (SettingsParameters.StaticPath, ConfigParameters.TABPY_STATIC_PATH,
             os.path.join(pkg_path, "tabpy_server", "static"), None),
            (ConfigParameters.TABPY_PWD_FILE, ConfigParameters.TABPY_PWD_FILE, None, None),
            (SettingsParameters.LogRequestContext, ConfigParameters.TABPY_LOG_DETAILS,
             "false", None),
            (SettingsParameters.MaxRequestSizeInMb, ConfigParameters.TABPY_MAX_REQUEST_SIZE_MB,
             100, None),
            (SettingsParameters.GzipEnabled, ConfigParameters.TABPY_GZIP_ENABLE,
             True, parser.getboolean),
            (SettingsParameters.ArrowEnabled, ConfigParameters.TABPY_ARROW_ENABLE,
             False, parser.getboolean),
            (SettingsParameters.ArrowFlightPort, ConfigParameters.TABPY_ARROWFLIGHT_PORT,
             13622, parser.getint),
        ]

        for setting, parameter, default_val, parse_function in settings_parameters:
            self._set_parameter(parser, setting, parameter, default_val, parse_function)

        if not os.path.exists(self.settings[SettingsParameters.UploadDir]):
            os.makedirs(self.settings[SettingsParameters.UploadDir])

        # set and validate transfer protocol
        self.settings[SettingsParameters.TransferProtocol] = self.settings[
            SettingsParameters.TransferProtocol
        ].lower()

        self._validate_transfer_protocol_settings()

        # Set max request size in bytes
        self.max_request_size = (
            int(self.settings[SettingsParameters.MaxRequestSizeInMb]) * 1024 * 1024
        )
        logger.info(f"Setting max request size to {self.max_request_size} bytes")

        # if state.ini does not exist try and create it - remove
        # last dependence on batch/shell script
        self.settings[SettingsParameters.StateFilePath] = os.path.realpath(
            os.path.normpath(
                os.path.expanduser(self.settings[SettingsParameters.StateFilePath])
            )
        )
        state_config, self.tabpy_state = self._build_tabpy_state()

        self.python_service = PythonServiceHandler(PythonService())
        self.settings["compress_response"] = True
        self.settings[SettingsParameters.StaticPath] = os.path.abspath(
            self.settings[SettingsParameters.StaticPath]
        )
        logger.debug(
            f"Static pages folder set to "
            f'"{self.settings[SettingsParameters.StaticPath]}"'
        )

        # Set subdirectory from config if applicable
        if state_config.has_option("Service Info", "Subdirectory"):
            self.subdirectory = "/" + state_config.get("Service Info", "Subdirectory")

        # If passwords file specified load credentials
        if ConfigParameters.TABPY_PWD_FILE in self.settings:
            if not self._parse_pwd_file():
                msg = (
                    "Failed to read passwords file "
                    f"{self.settings[ConfigParameters.TABPY_PWD_FILE]}"
                )
                logger.critical(msg)
                raise RuntimeError(msg)
        else:
            self._handle_configuration_without_authentication()

        features = self._get_features()
        self.settings[SettingsParameters.ApiVersions] = {"v1": {"features": features}}

        # Any value other than "false" (case-insensitive) enables it.
        self.settings[SettingsParameters.LogRequestContext] = (
            self.settings[SettingsParameters.LogRequestContext].lower() != "false"
        )
        call_context_state = (
            "enabled"
            if self.settings[SettingsParameters.LogRequestContext]
            else "disabled"
        )
        logger.info(f"Call context logging is {call_context_state}")

    def _validate_transfer_protocol_settings(self):
        """Check the transfer protocol and, for https, the cert/key files.

        Raises:
            RuntimeError: if the protocol is missing or unsupported, or if the
                https certificate/key settings are missing or not files.
        """
        if SettingsParameters.TransferProtocol not in self.settings:
            msg = "Missing transfer protocol information."
            logger.critical(msg)
            raise RuntimeError(msg)

        protocol = self.settings[SettingsParameters.TransferProtocol]

        if protocol == "http":
            return

        if protocol != "https":
            msg = f"Unsupported transfer protocol: {protocol}"
            logger.critical(msg)
            raise RuntimeError(msg)

        self._validate_cert_key_state(
            "The parameter(s) {} must be set.",
            SettingsParameters.CertificateFile in self.settings,
            SettingsParameters.KeyFile in self.settings,
        )
        cert = self.settings[SettingsParameters.CertificateFile]

        self._validate_cert_key_state(
            "The parameter(s) {} must point to " "an existing file.",
            os.path.isfile(cert),
            os.path.isfile(self.settings[SettingsParameters.KeyFile]),
        )
        tabpy.tabpy_server.app.util.validate_cert(cert)

    @staticmethod
    def _validate_cert_key_state(msg, cert_valid, key_valid):
        """Raise a RuntimeError naming whichever of cert/key failed a check.

        Args:
            msg: format string with one ``{}`` placeholder for the parameter
                name(s).
            cert_valid: result of the check for the certificate file.
            key_valid: result of the check for the key file.
        """
        cert_and_key_param = (
            f"{ConfigParameters.TABPY_CERTIFICATE_FILE} and "
            f"{ConfigParameters.TABPY_KEY_FILE}"
        )
        https_error = "Error using HTTPS: "
        err = None
        if not cert_valid and not key_valid:
            err = https_error + msg.format(cert_and_key_param)
        elif not cert_valid:
            err = https_error + msg.format(ConfigParameters.TABPY_CERTIFICATE_FILE)
        elif not key_valid:
            err = https_error + msg.format(ConfigParameters.TABPY_KEY_FILE)

        if err is not None:
            logger.critical(err)
            raise RuntimeError(err)

    def _parse_pwd_file(self):
        """Load ``self.credentials`` from the configured password file.

        Returns:
            True when the file was parsed and contains at least one credential.
        """
        succeeded, self.credentials = parse_pwd_file(
            self.settings[ConfigParameters.TABPY_PWD_FILE]
        )

        if succeeded and not self.credentials:
            logger.error("No credentials found")
            succeeded = False

        return succeeded

    def _handle_configuration_without_authentication(self):
        """Decide whether to continue when no password file is configured.

        With ``disable_auth_warning`` set, only an info message is logged.
        Otherwise the user is warned (more strongly when evaluate is enabled)
        and asked to confirm; any answer other than "y" exits the process.
        """
        std_no_auth_msg = "Password file is not specified: Authentication is not enabled"

        if self.disable_auth_warning:
            logger.info(std_no_auth_msg)
            return

        confirm_no_auth_msg = "\nWARNING: This TabPy server is not currently configured for username/password authentication. "

        if self.settings[SettingsParameters.EvaluateEnabled]:
            confirm_no_auth_msg += ("This means that, because the TABPY_EVALUATE_ENABLE feature is enabled, there is "
                "the potential that unauthenticated individuals may be able to remotely execute code on this machine. ")

        confirm_no_auth_msg += ("We strongly advise against proceeding without authentication as it poses a significant security risk.\n\n"
            "Do you wish to proceed without authentication? (y/N): ")

        confirm_no_auth_input = input(confirm_no_auth_msg)

        if confirm_no_auth_input == 'y':
            logger.info(std_no_auth_msg)
        else:
            print("\nAborting start up. To enable authentication for your TabPy server, see "
                "https://github.com/tableau/TabPy/blob/master/docs/server-config.md#authentication.")
            # sys.exit() instead of the site-provided exit() builtin, which
            # is not guaranteed to exist (e.g. under `python -S`).
            sys.exit()

    def _get_features(self):
        """Return the feature flags advertised under API version v1."""
        features = {}

        # Check for auth
        if ConfigParameters.TABPY_PWD_FILE in self.settings:
            features["authentication"] = {
                "required": True,
                "methods": {"basic-auth": {}},
            }

        features["evaluate_enabled"] = self.settings[SettingsParameters.EvaluateEnabled]
        features["gzip_enabled"] = self.settings[SettingsParameters.GzipEnabled]
        features["arrow_enabled"] = self.settings[SettingsParameters.ArrowEnabled]
        return features

    def _build_tabpy_state(self):
        """Load ``state.ini``, creating it from the bundled template if missing.

        Returns:
            A tuple ``(state_config, TabPyState)``, where ``state_config`` is
            the value returned by ``_get_state_from_file``.
        """
        pkg_path = os.path.dirname(tabpy.__file__)
        state_file_dir = self.settings[SettingsParameters.StateFilePath]
        state_file_path = os.path.join(state_file_dir, "state.ini")
        if not os.path.isfile(state_file_path):
            state_file_template_path = os.path.join(
                pkg_path, "tabpy_server", "state.ini.template"
            )
            logger.debug(
                f"File {state_file_path} not found, creating from "
                f"template {state_file_template_path}..."
            )
            shutil.copy(state_file_template_path, state_file_path)

        logger.info(f"Loading state from state file {state_file_path}")
        tabpy_state = _get_state_from_file(state_file_dir)
        return tabpy_state, TabPyState(config=tabpy_state, settings=self.settings)
530
531
532
# Override _read_body to allow content with size exceeding max_body_size
533
# This enables proper handling of 413 errors in base_handler
534 1
def _read_body_allow_max_size(self, code, headers, delegate):
535 1
    if "Content-Length" in headers:
536 1
        content_length = int(headers["Content-Length"])
537 1
        if content_length > self._max_body_size:
538
            return
539 1
    return self.original_read_body(code, headers, delegate)
540
541 1
# Save Tornado's implementation only once. If this module is executed again
# (e.g. reloaded), re-saving would capture the already-patched function and
# make _read_body_allow_max_size call itself forever.
if not hasattr(HTTP1Connection, "original_read_body"):
    HTTP1Connection.original_read_body = HTTP1Connection._read_body
HTTP1Connection._read_body = _read_body_allow_max_size
543