Passed
Push — master ( 53016e...38fde4 )
by Oleksandr
02:52
created

TabPyApp._set_parameter()   B

Complexity

Conditions 7

Size

Total Lines 27
Code Lines 17

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 17
dl 0
loc 27
rs 8
c 0
b 0
f 0
cc 7
nop 5
1
import concurrent.futures
2
import configparser
3
import logging
4
from logging import config
5
import multiprocessing
6
import os
7
import shutil
8
import signal
9
import sys
10
import tabpy.tabpy_server
11
from tabpy.tabpy import __version__
12
from tabpy.tabpy_server.app.ConfigParameters import ConfigParameters
13
from tabpy.tabpy_server.app.SettingsParameters import SettingsParameters
14
from tabpy.tabpy_server.app.util import parse_pwd_file
15
from tabpy.tabpy_server.management.state import TabPyState
16
from tabpy.tabpy_server.management.util import _get_state_from_file
17
from tabpy.tabpy_server.psws.callbacks import init_model_evaluator, init_ps_server
18
from tabpy.tabpy_server.psws.python_service import PythonService, PythonServiceHandler
19
from tabpy.tabpy_server.handlers import (
20
    EndpointHandler,
21
    EndpointsHandler,
22
    EvaluationPlaneHandler,
23
    QueryPlaneHandler,
24
    ServiceInfoHandler,
25
    StatusHandler,
26
    UploadDestinationHandler,
27
)
28
import tornado
29
30
31
logger = logging.getLogger(__name__)
32
33
34
def _init_asyncio_patch():
35
    """
36
    Select compatible event loop for Tornado 5+.
37
    As of Python 3.8, the default event loop on Windows is `proactor`,
38
    however Tornado requires the old default "selector" event loop.
39
    As Tornado has decided to leave this to users to set, MkDocs needs
40
    to set it. See https://github.com/tornadoweb/tornado/issues/2608.
41
    """
42
    if sys.platform.startswith("win") and sys.version_info >= (3, 8):
43
        import asyncio
44
        try:
45
            from asyncio import WindowsSelectorEventLoopPolicy
46
        except ImportError:
47
            pass  # Can't assign a policy which doesn't exist.
48
        else:
49
            if not isinstance(asyncio.get_event_loop_policy(), WindowsSelectorEventLoopPolicy):
50
                asyncio.set_event_loop_policy(WindowsSelectorEventLoopPolicy())
51
52
53
class TabPyApp:
    """
    TabPy application class for keeping context like settings, state, etc.

    Construction configures logging from the config file, then reads the
    config file (with the process environment as fallback values) into
    ``self.settings`` and loads the server state and optional credentials.
    ``run()`` then builds the Tornado application and blocks on the IO loop.
    """

    # Class-level placeholders. _parse_config() rebinds every one of these on
    # the instance, so instances created through __init__ never share the
    # mutable dicts declared here.
    settings = {}
    subdirectory = ""
    tabpy_state = None
    python_service = None
    credentials = {}

    def __init__(self, config_file=None):
        """Configure logging and load all settings.

        Args:
            config_file: Path to a TabPy configuration file. When ``None``,
                ``../common/default.conf`` relative to this module is used.
        """
        if config_file is None:
            config_file = os.path.join(
                os.path.dirname(__file__), os.path.pardir, "common", "default.conf"
            )

        if os.path.isfile(config_file):
            try:
                logging.config.fileConfig(config_file, disable_existing_loggers=False)
            except (KeyError, RuntimeError):
                # The file has no usable logging sections. KeyError covers
                # missing [formatters]/[handlers]/[loggers]; Python 3.12+
                # raises RuntimeError for e.g. an empty or invalid file.
                # Fall back to a basic configuration instead of aborting;
                # _parse_config() below still reads the file for TabPy settings.
                logging.basicConfig(level=logging.DEBUG)

        self._parse_config(config_file)

    def run(self):
        """Create the web application, start listening and run the IO loop.

        Blocks until the IO loop is stopped.

        Raises:
            RuntimeError: if the transfer protocol is neither "http" nor "https".
        """
        application = self._create_tornado_web_app()
        max_request_size = (
            int(self.settings[SettingsParameters.MaxRequestSizeInMb]) * 1024 * 1024
        )
        logger.info(f"Setting max request size to {max_request_size} bytes")

        init_model_evaluator(self.settings, self.tabpy_state, self.python_service)

        protocol = self.settings[SettingsParameters.TransferProtocol]
        ssl_options = None
        if protocol == "https":
            ssl_options = {
                "certfile": self.settings[SettingsParameters.CertificateFile],
                "keyfile": self.settings[SettingsParameters.KeyFile],
            }
        elif protocol != "http":
            msg = f"Unsupported transfer protocol {protocol}."
            logger.critical(msg)
            raise RuntimeError(msg)

        application.listen(
            self.settings[SettingsParameters.Port],
            ssl_options=ssl_options,
            max_buffer_size=max_request_size,
            max_body_size=max_request_size,
        )

        logger.info(
            "Web service listening on port "
            f"{self.settings[SettingsParameters.Port]}"
        )
        tornado.ioloop.IOLoop.instance().start()

    def _create_tornado_web_app(self):
        """Build the Tornado application with all TabPy routes.

        Besides constructing the application this method:
        - selects a Tornado-compatible event loop policy (Windows),
        - runs ``init_ps_server`` to completion on the IO loop,
        - installs a SIGINT handler plus a periodic callback that stops the
          IO loop once the handler has fired.

        Returns:
            The configured ``tornado.web.Application`` subclass instance.
        """

        class TabPyTornadoApp(tornado.web.Application):
            # Set by signal_handler; polled by try_exit.
            is_closing = False

            def signal_handler(self, signum, _):
                logger.critical(f"Exiting on signal {signum}...")
                self.is_closing = True

            def try_exit(self):
                if self.is_closing:
                    tornado.ioloop.IOLoop.instance().stop()
                    logger.info("Shutting down TabPy...")

        # The event loop policy must be chosen before the first IOLoop is
        # created (run_sync below). Switching it afterwards can leave the
        # initialisation and the server running on different loops.
        _init_asyncio_patch()

        logger.info("Initializing TabPy...")
        tornado.ioloop.IOLoop.instance().run_sync(
            lambda: init_ps_server(self.settings, self.tabpy_state)
        )
        logger.info("Done initializing TabPy.")

        executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=multiprocessing.cpu_count()
        )

        # initialize Tornado application
        application = TabPyTornadoApp(
            [
                # skip MainHandler to use StaticFileHandler .* page requests and
                # default to index.html
                # (r"/", MainHandler),
                (
                    self.subdirectory + r"/query/([^/]+)",
                    QueryPlaneHandler,
                    dict(app=self),
                ),
                (self.subdirectory + r"/status", StatusHandler, dict(app=self)),
                (self.subdirectory + r"/info", ServiceInfoHandler, dict(app=self)),
                (self.subdirectory + r"/endpoints", EndpointsHandler, dict(app=self)),
                (
                    self.subdirectory + r"/endpoints/([^/]+)?",
                    EndpointHandler,
                    dict(app=self),
                ),
                (
                    self.subdirectory + r"/evaluate",
                    EvaluationPlaneHandler,
                    dict(executor=executor, app=self),
                ),
                (
                    self.subdirectory + r"/configurations/endpoint_upload_destination",
                    UploadDestinationHandler,
                    dict(app=self),
                ),
                (
                    self.subdirectory + r"/(.*)",
                    tornado.web.StaticFileHandler,
                    dict(
                        path=self.settings[SettingsParameters.StaticPath],
                        default_filename="index.html",
                    ),
                ),
            ],
            debug=False,
            **self.settings,
        )

        # Register the shutdown machinery exactly once: SIGINT only sets a
        # flag, and this single callback checks it every 500 ms.
        signal.signal(signal.SIGINT, application.signal_handler)
        tornado.ioloop.PeriodicCallback(application.try_exit, 500).start()

        return application

    def _set_parameter(self, parser, settings_key, config_key, default_val):
        """Resolve one setting and store it in ``self.settings``.

        The value is taken from the ``[TabPy]`` section of ``parser`` when
        ``config_key`` is given and present there. Environment variables are
        visible as well because ``_parse_config`` passes them to the parser
        as defaults. Otherwise ``default_val`` is used unless it is ``None``,
        in which case the setting is left unset.

        Args:
            parser: ``configparser.ConfigParser`` holding the configuration.
            settings_key: Key to write into ``self.settings``.
            config_key: Option name to look up, or ``None`` to skip the lookup.
            default_val: Fallback value; ``None`` means "no default".
        """
        if (
            config_key is not None
            and parser.has_section("TabPy")
            and parser.has_option("TabPy", config_key)
        ):
            self.settings[settings_key] = parser.get("TabPy", config_key)
            source = "from config file or environment variable"
        elif default_val is not None:
            self.settings[settings_key] = default_val
            source = "from default value"
        else:
            logger.debug(f"Parameter {settings_key} is not set")
            return

        logger.debug(
            f"Parameter {settings_key} set to "
            f'"{self.settings[settings_key]}" '
            f"{source}"
        )

    def _parse_config(self, config_file):
        """Provide a consistent mechanism for pulling in configuration.

        Each setting is resolved in the following order:

        1. the ``[TabPy]`` section of ``config_file``;
        2. an OS environment variable with the same name. Environment
           variables are passed to ``configparser`` as defaults, so they are
           only consulted when the config file has a ``[TabPy]`` section;
        3. the built-in default listed in ``settings_parameters`` below.

        Command-line handling is not done here; callers decide which
        ``config_file`` is passed in. For consistency the config file and the
        environment use the same variable names: all capitals, starting with
        'TABPY_'.

        Besides filling ``self.settings`` this method also:
        - creates the upload directory if missing,
        - validates the transfer protocol,
        - loads server state,
        - sets the URL subdirectory,
        - loads credentials when a password file is configured.

        Raises:
            RuntimeError: on an unsupported transfer protocol, missing or
                invalid certificate/key settings, or an unreadable or empty
                password file.
        """
        self.settings = {}
        self.subdirectory = ""
        self.tabpy_state = None
        self.python_service = None
        self.credentials = {}

        pkg_path = os.path.dirname(tabpy.__file__)

        # Environment variables become configparser defaults, i.e. fallback
        # values for every section (including [TabPy]).
        parser = configparser.ConfigParser(os.environ)

        if os.path.isfile(config_file):
            with open(config_file) as f:
                parser.read_string(f.read())
        else:
            logger.warning(
                f"Unable to find config file at {config_file}, "
                "using default settings."
            )

        # (settings key, config/env option name or None, default value)
        settings_parameters = [
            (SettingsParameters.Port, ConfigParameters.TABPY_PORT, 9004),
            (SettingsParameters.ServerVersion, None, __version__),
            (SettingsParameters.EvaluateTimeout, ConfigParameters.TABPY_EVALUATE_TIMEOUT, 30),
            (SettingsParameters.UploadDir, ConfigParameters.TABPY_QUERY_OBJECT_PATH,
             os.path.join(pkg_path, "tmp", "query_objects")),
            (SettingsParameters.TransferProtocol, ConfigParameters.TABPY_TRANSFER_PROTOCOL,
             "http"),
            (SettingsParameters.CertificateFile, ConfigParameters.TABPY_CERTIFICATE_FILE,
             None),
            (SettingsParameters.KeyFile, ConfigParameters.TABPY_KEY_FILE, None),
            (SettingsParameters.StateFilePath, ConfigParameters.TABPY_STATE_PATH,
             os.path.join(pkg_path, "tabpy_server")),
            (SettingsParameters.StaticPath, ConfigParameters.TABPY_STATIC_PATH,
             os.path.join(pkg_path, "tabpy_server", "static")),
            (ConfigParameters.TABPY_PWD_FILE, ConfigParameters.TABPY_PWD_FILE, None),
            (SettingsParameters.LogRequestContext, ConfigParameters.TABPY_LOG_DETAILS,
             "false"),
            (SettingsParameters.MaxRequestSizeInMb, ConfigParameters.TABPY_MAX_REQUEST_SIZE_MB,
             100),
        ]

        for setting, parameter, default_val in settings_parameters:
            self._set_parameter(parser, setting, parameter, default_val)

        try:
            self.settings[SettingsParameters.EvaluateTimeout] = float(
                self.settings[SettingsParameters.EvaluateTimeout]
            )
        except ValueError:
            logger.warning(
                "Evaluate timeout must be a float type. Defaulting "
                "to evaluate timeout of 30 seconds."
            )
            self.settings[SettingsParameters.EvaluateTimeout] = 30

        upload_dir = self.settings[SettingsParameters.UploadDir]
        if not os.path.exists(upload_dir):
            # exist_ok guards against the directory appearing between the
            # check above and creation (e.g. a concurrently started server).
            os.makedirs(upload_dir, exist_ok=True)

        # set and validate transfer protocol
        self.settings[SettingsParameters.TransferProtocol] = self.settings[
            SettingsParameters.TransferProtocol
        ].lower()

        self._validate_transfer_protocol_settings()

        # if state.ini does not exist try and create it - remove
        # last dependence on batch/shell script
        self.settings[SettingsParameters.StateFilePath] = os.path.realpath(
            os.path.normpath(
                os.path.expanduser(self.settings[SettingsParameters.StateFilePath])
            )
        )
        state_config, self.tabpy_state = self._build_tabpy_state()

        self.python_service = PythonServiceHandler(PythonService())
        self.settings["compress_response"] = True
        self.settings[SettingsParameters.StaticPath] = os.path.abspath(
            self.settings[SettingsParameters.StaticPath]
        )
        logger.debug(
            f"Static pages folder set to "
            f'"{self.settings[SettingsParameters.StaticPath]}"'
        )

        # Set subdirectory from config if applicable
        if state_config.has_option("Service Info", "Subdirectory"):
            self.subdirectory = "/" + state_config.get("Service Info", "Subdirectory")

        # If passwords file specified load credentials
        if ConfigParameters.TABPY_PWD_FILE in self.settings:
            if not self._parse_pwd_file():
                msg = (
                    "Failed to read passwords file "
                    f"{self.settings[ConfigParameters.TABPY_PWD_FILE]}"
                )
                logger.critical(msg)
                raise RuntimeError(msg)
        else:
            logger.info(
                "Password file is not specified: Authentication is not enabled"
            )

        features = self._get_features()
        self.settings[SettingsParameters.ApiVersions] = {"v1": {"features": features}}

        # Any value other than "false" (case-insensitive) enables logging.
        self.settings[SettingsParameters.LogRequestContext] = (
            self.settings[SettingsParameters.LogRequestContext].lower() != "false"
        )
        call_context_state = (
            "enabled"
            if self.settings[SettingsParameters.LogRequestContext]
            else "disabled"
        )
        logger.info(f"Call context logging is {call_context_state}")

    def _validate_transfer_protocol_settings(self):
        """Check that the transfer protocol and, for HTTPS, cert/key are usable.

        Raises:
            RuntimeError: if the protocol is missing or unsupported, or if
                HTTPS is selected and the certificate/key settings are
                missing or do not point to existing files.
        """
        if SettingsParameters.TransferProtocol not in self.settings:
            msg = "Missing transfer protocol information."
            logger.critical(msg)
            raise RuntimeError(msg)

        protocol = self.settings[SettingsParameters.TransferProtocol]

        if protocol == "http":
            return

        if protocol != "https":
            msg = f"Unsupported transfer protocol: {protocol}"
            logger.critical(msg)
            raise RuntimeError(msg)

        # Both parameters must be present before their paths can be checked.
        self._validate_cert_key_state(
            "The parameter(s) {} must be set.",
            SettingsParameters.CertificateFile in self.settings,
            SettingsParameters.KeyFile in self.settings,
        )
        cert = self.settings[SettingsParameters.CertificateFile]

        self._validate_cert_key_state(
            "The parameter(s) {} must point to an existing file.",
            os.path.isfile(cert),
            os.path.isfile(self.settings[SettingsParameters.KeyFile]),
        )
        tabpy.tabpy_server.app.util.validate_cert(cert)

    @staticmethod
    def _validate_cert_key_state(msg, cert_valid, key_valid):
        """Raise if the certificate and/or key failed a validity check.

        Args:
            msg: Message template with one ``{}`` placeholder for the names
                of the failing parameter(s).
            cert_valid: Whether the certificate check passed.
            key_valid: Whether the key check passed.

        Raises:
            RuntimeError: naming whichever of the two parameters failed.
        """
        cert_and_key_param = (
            f"{ConfigParameters.TABPY_CERTIFICATE_FILE} and "
            f"{ConfigParameters.TABPY_KEY_FILE}"
        )
        https_error = "Error using HTTPS: "
        err = None
        if not cert_valid and not key_valid:
            err = https_error + msg.format(cert_and_key_param)
        elif not cert_valid:
            err = https_error + msg.format(ConfigParameters.TABPY_CERTIFICATE_FILE)
        elif not key_valid:
            err = https_error + msg.format(ConfigParameters.TABPY_KEY_FILE)

        if err is not None:
            logger.critical(err)
            raise RuntimeError(err)

    def _parse_pwd_file(self):
        """Load credentials from the configured password file.

        Returns:
            bool: True if the file was parsed and contains at least one
            credential, False otherwise.
        """
        succeeded, self.credentials = parse_pwd_file(
            self.settings[ConfigParameters.TABPY_PWD_FILE]
        )

        if succeeded and not self.credentials:
            logger.error("No credentials found")
            succeeded = False

        return succeeded

    def _get_features(self):
        """Return the API feature description advertised by the server.

        Currently only reports basic authentication, and only when a
        password file is configured.
        """
        features = {}

        # Check for auth
        if ConfigParameters.TABPY_PWD_FILE in self.settings:
            features["authentication"] = {
                "required": True,
                "methods": {"basic-auth": {}},
            }

        return features

    def _build_tabpy_state(self):
        """Load server state, creating ``state.ini`` from the template if missing.

        Returns:
            tuple: ``(state_config, tabpy_state)``, where ``state_config`` is
            what ``_get_state_from_file`` returns for the state directory and
            ``tabpy_state`` is a ``TabPyState`` built from it.
        """
        pkg_path = os.path.dirname(tabpy.__file__)
        state_file_dir = self.settings[SettingsParameters.StateFilePath]
        state_file_path = os.path.join(state_file_dir, "state.ini")
        if not os.path.isfile(state_file_path):
            state_file_template_path = os.path.join(
                pkg_path, "tabpy_server", "state.ini.template"
            )
            logger.debug(
                f"File {state_file_path} not found, creating from "
                f"template {state_file_template_path}..."
            )
            shutil.copy(state_file_template_path, state_file_path)

        logger.info(f"Loading state from state file {state_file_path}")
        state_config = _get_state_from_file(state_file_dir)
        return state_config, TabPyState(config=state_config, settings=self.settings)
439