wait_for_endpoint_loaded()   B
last analyzed

Complexity

Conditions 6

Size

Total Lines 19
Code Lines 14

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 1
CRAP Score 34.8262

Importance

Changes 0
Metric Value
eloc 14
dl 0
loc 19
ccs 1
cts 14
cp 0.0714
rs 8.6666
c 0
b 0
f 0
cc 6
nop 2
crap 34.8262
1 1
import logging
2 1
from tabpy.tabpy_server.app.app_parameters import SettingsParameters
3 1
from tabpy.tabpy_server.common.messages import (
4
    LoadObject,
5
    DeleteObjects,
6
    ListObjects,
7
    ObjectList,
8
)
9 1
from tabpy.tabpy_server.common.endpoint_file_mgr import cleanup_endpoint_files
10 1
from tabpy.tabpy_server.common.util import format_exception
11 1
from tabpy.tabpy_server.management.state import TabPyState, get_query_object_path
12 1
from tabpy.tabpy_server.management import util
13 1
from time import sleep
14 1
from tornado import gen
15
16
17 1
# Module-level logger for this module. on_state_change's default ``logger``
# argument is logging.getLogger(__name__) as well, i.e. this same named logger.
logger = logging.getLogger(__name__)
18
19
20 1
def wait_for_endpoint_loaded(python_service, object_uri, timeout=None, poll_interval=0.1):
    """
    Wait until the endpoint ``object_uri`` is no longer loading.

    Repeatedly sends a ``ListObjects`` request to ``python_service`` and
    returns as soon as the entry for ``object_uri`` reports a status other
    than ``"LoadInProgress"``.

    Parameters
    ----------
    python_service
        Service whose ``manage_request`` answers ``ListObjects`` messages.
    object_uri : str
        Name of the endpoint to wait for.
    timeout : float or None, optional
        Maximum number of seconds to wait. ``None`` (the default) waits
        without limit, which matches the historical behavior.
    poll_interval : float, optional
        Seconds to sleep between polls (default 0.1).

    Returns
    -------
    None
        Returns early after logging an error in two cases: the service
        answers with something other than an ``ObjectList``, or ``timeout``
        elapses first.

    Notes
    -----
    The wait uses the blocking ``time.sleep``. When this is called from a
    Tornado coroutine it presumably blocks the IOLoop while polling.
    TODO(review): confirm and consider an async variant.
    """
    from time import monotonic

    logger.info("Waiting for object to be loaded...")
    deadline = None if timeout is None else monotonic() + timeout

    while True:
        list_object_msg = python_service.manage_request(ListObjects())
        if not isinstance(list_object_msg, ObjectList):
            logger.error(f"Error loading endpoint {object_uri}: {list_object_msg}")
            return

        # Direct lookup instead of scanning every listed object.
        info = list_object_msg.objects.get(object_uri)
        if info is not None and info["status"] != "LoadInProgress":
            logger.info(f'Object load status: {info["status"]}')
            return

        # Without a deadline, an endpoint that never appears (or never leaves
        # LoadInProgress) would keep this loop spinning forever.
        if deadline is not None and monotonic() >= deadline:
            logger.error(
                f"Timed out after {timeout} seconds waiting for endpoint "
                f"{object_uri} to load"
            )
            return

        sleep(poll_interval)
39
40
41 1
@gen.coroutine
def init_ps_server(settings, tabpy_state):
    """
    Resolve the query-object path of every endpoint recorded in the state.

    For each endpoint returned by ``tabpy_state.get_endpoints()``, calls
    ``get_query_object_path`` with the configured state file path, the
    endpoint name and its version. The returned path is discarded.
    Presumably the call serves as a sanity check that each stored
    endpoint/version can be located (TODO confirm).

    Any exception raised for one endpoint (for example a missing
    ``"version"`` key) is logged, and processing continues with the next
    endpoint.

    Parameters
    ----------
    settings : dict
        Server settings. ``settings[SettingsParameters.StateFilePath]`` is read.
    tabpy_state : TabPyState
        Current state; its ``get_endpoints()`` mapping is iterated.
    """
    logger.info("Initializing TabPy Server...")
    existing_pos = tabpy_state.get_endpoints()
    for (object_name, obj_info) in existing_pos.items():
        try:
            object_version = obj_info["version"]
            get_query_object_path(
                settings[SettingsParameters.StateFilePath], object_name, object_version
            )
        except Exception as e:
            # Best effort: one broken endpoint must not stop the others.
            logger.error(
                f"Exception encountered when downloading object: {object_name}"
                f", error: {e}"
            )
56
57
58 1
@gen.coroutine
def init_model_evaluator(settings, tabpy_state, python_service):
    """
    Submit a load request for every endpoint the service currently has.

    For each endpoint in ``tabpy_state.get_endpoints()``, builds a
    ``LoadObject`` message and passes it to
    ``python_service.manage_request``.
    - Alias endpoints are loaded from their ``"target"`` entry.
    - All other endpoints are loaded from the query-object path computed by
      ``get_query_object_path``.

    All requests are sent with ``False`` in the position that
    ``on_state_change`` fills with ``is_update``.

    Parameters
    ----------
    settings : dict
        Server settings. ``settings[SettingsParameters.StateFilePath]`` is read.
    tabpy_state : TabPyState
        State whose endpoints are loaded.
    python_service
        Service that receives the ``LoadObject`` requests.
    """
    logger.info("Initializing models...")

    existing_pos = tabpy_state.get_endpoints()

    for (object_name, obj_info) in existing_pos.items():
        object_version = obj_info["version"]
        # Default to "model" when the entry has no type, matching how
        # _get_latest_service_state treats the same state entries; indexing
        # directly would abort initialization of every endpoint on KeyError.
        object_type = obj_info.get("type", "model")
        object_path = get_query_object_path(
            settings[SettingsParameters.StateFilePath], object_name, object_version
        )

        logger.info(
            f"Load endpoint: {object_name}, "
            f"version: {object_version}, "
            f"type: {object_type}"
        )
        if object_type == "alias":
            # Aliases are loaded from their "target" entry, not a file path.
            msg = LoadObject(
                object_name, obj_info["target"], object_version, False, "alias"
            )
        else:
            msg = LoadObject(
                object_name, object_path, object_version, False, object_type
            )
        python_service.manage_request(msg)
90
91
92 1
def _get_latest_service_state(settings, tabpy_state, new_ps_state, python_service):
    """
    Diff the endpoints in the latest state against the ones currently loaded.

    Parameters
    ----------
    settings : dict
        Server settings. ``settings[SettingsParameters.StateFilePath]`` is
        used to build the path of new or re-versioned endpoints.
    tabpy_state
        Not used by this function; kept for interface compatibility.
    new_ps_state : TabPyState
        Freshly loaded state; its ``get_endpoints()`` is the desired state.
    python_service
        Service whose ``ps.query_objects`` holds the currently loaded endpoints.

    Returns
    --------
    (has_changes, changes):
        has_changes: True when at least one endpoint was added, re-versioned
            or removed, False otherwise.
        changes: ``{"endpoints": diff}``, where ``diff`` maps each changed
            endpoint name to ``(type, version, path)``. Removed endpoints are
            reported as ``(type, None, None)``. A missing type defaults to
            ``"model"``.
    """
    new_endpoints = new_ps_state.get_endpoints()
    current_endpoints = python_service.ps.query_objects
    diff = {}

    # New endpoints, or existing endpoints whose version changed.
    for endpoint_name, endpoint_info in new_endpoints.items():
        existing_endpoint = current_endpoints.get(endpoint_name)
        if (
            existing_endpoint is None
            or endpoint_info["version"] != existing_endpoint["version"]
        ):
            path_to_new_version = get_query_object_path(
                settings[SettingsParameters.StateFilePath],
                endpoint_name,
                endpoint_info["version"],
            )
            endpoint_type = endpoint_info.get("type", "model")
            diff[endpoint_name] = (
                endpoint_type,
                endpoint_info["version"],
                path_to_new_version,
            )

    # Endpoints that are loaded but no longer present in the new state.
    for endpoint_name, endpoint_info in current_endpoints.items():
        if endpoint_name not in new_endpoints:
            diff[endpoint_name] = (endpoint_info.get("type", "model"), None, None)

    # Report the real outcome; returning True unconditionally made the
    # caller's "nothing changed" shortcut unreachable.
    return (bool(diff), {"endpoints": diff})
137
138
139 1
@gen.coroutine
def on_state_change(
    settings, tabpy_state, python_service, logger=logging.getLogger(__name__)
):
    """
    Reload the state file and apply the resulting endpoint changes.

    Steps:
    1. Read the state file at ``settings[SettingsParameters.StateFilePath]``
       and build a new ``TabPyState``.
    2. Diff it against the endpoints loaded in ``python_service``.
    3. For every changed endpoint:
       - Removed endpoints: issue ``DeleteObjects`` and clean up their
         uploaded files.
       - Added or re-versioned endpoints: issue ``LoadObject``, wait for the
         load to finish, and, from version 3 on, clean up files of all but
         the current and previous version.

    Any exception is caught, formatted with ``format_exception`` and logged
    at ERROR level through ``logger``; it is not propagated.

    Parameters
    ----------
    settings : dict
        Server settings. ``StateFilePath`` and ``UploadDir`` entries are read.
    tabpy_state : TabPyState
        Current state, forwarded to ``_get_latest_service_state``.
    python_service
        Service receiving ``DeleteObjects`` / ``LoadObject`` requests.
    logger : logging.Logger, optional
        Logger used for progress and error messages.
    """
    try:
        logger.log(logging.INFO, "Loading state from state file")
        latest_state = TabPyState(
            config=util._get_state_from_file(
                settings[SettingsParameters.StateFilePath], logger=logger
            ),
            settings=settings,
        )

        has_changes, changes = _get_latest_service_state(
            settings, tabpy_state, latest_state, python_service
        )
        if not has_changes:
            logger.info("Nothing changed, return.")
            return

        latest_endpoints = latest_state.get_endpoints()
        for name, (kind, version, path) in changes["endpoints"].items():
            # A (type, None, None) entry marks an endpoint that was removed.
            if not (path or version):
                logger.info(f"Removing object: URI={name}")
                python_service.manage_request(DeleteObjects([name]))
                cleanup_endpoint_files(
                    name, settings[SettingsParameters.UploadDir], logger=logger
                )
                continue

            endpoint_entry = latest_endpoints[name]
            is_update = version > 1
            # Aliases load from their "target"; everything else from its file path.
            source = endpoint_entry["target"] if kind == "alias" else path
            python_service.manage_request(
                LoadObject(name, source, version, is_update, kind)
            )
            wait_for_endpoint_loaded(python_service, name)

            # Keep only the current and the previous version's files on disk.
            if version > 2:
                cleanup_endpoint_files(
                    name,
                    settings[SettingsParameters.UploadDir],
                    logger=logger,
                    retain_versions=[version, version - 1],
                )

    except Exception as e:
        err_msg = format_exception(e, "on_state_change")
        logger.log(
            logging.ERROR, f"Error submitting update model request: error={err_msg}"
        )
206