1 | import pickle |
||
2 | import os |
||
3 | import numpy as np |
||
4 | from flask_restful import Resource, request |
||
5 | from flask import url_for, current_app |
||
6 | from .errors import NoPredictMethod |
||
0 ignored issues
–
show
Bug
introduced
by
Loading history...
|
|||
7 | import cf_predict |
||
8 | |||
9 | |||
def get_db():
    """Return the Redis client registered on the current Flask application.

    The client is read from ``current_app.extensions`` under the key
    ``"redis"``; a missing entry raises ``KeyError``.
    """
    registered = current_app.extensions
    return registered["redis"]
||
13 | |||
14 | |||
class Catalogue(Resource):
    """Root resource that lists the API's endpoints and version."""

    def get(self):
        """Show a catalogue of available endpoints."""
        predict_url = url_for("api.predict", _external=True)
        return {"predict_url": predict_url, "api_version": cf_predict.__version__}
||
22 | |||
23 | |||
class Predict(Resource):
    """Serve predictions from a pickled model stored in Redis.

    The model version comes from the ``MODEL_VERSION`` environment variable.
    When it is unset, empty, or ``"latest"``, the highest key found in Redis
    is used instead.
    """

    def __init__(self):
        """Resolve the model version and load the model from Redis.

        Raises:
            TypeError, ValueError: no version could be resolved (for
                example, Redis holds no keys).
            pickle.UnpicklingError, IOError, AttributeError, EOFError,
            ImportError, IndexError, TypeError: the stored model could not be
                loaded. ``TypeError`` also covers a key that does not exist.
            NoPredictMethod: the loaded object has no ``predict`` attribute.
        """
        self.r = get_db()
        self.version = os.getenv("MODEL_VERSION") or "latest"
        if self.version == "latest":
            try:
                self.version = self.find_latest_version(self.version)
            except (TypeError, ValueError):
                current_app.logger.error("No model {} found".format(self.version))
                raise
        try:
            self.model = self.load_model(self.version)
        except (pickle.UnpicklingError, IOError, AttributeError, EOFError,
                ImportError, IndexError, TypeError):
            # TypeError: self.r.get() returned None for a missing key and
            # pickle.loads(None) rejects it. Log it like any other load failure.
            current_app.logger.error("Model {} could not be unpickled".format(self.version))
            raise
        if not hasattr(self.model, "predict"):
            raise NoPredictMethod

    def find_latest_version(self, version):
        """Find model with the highest version number in Redis.

        Keys are compared in natural order: runs of digits are compared as
        integers, so "1.10" ranks above "1.9" (a plain string ``max`` gets
        this wrong).

        Args:
            version: unused; kept for interface compatibility.

        Returns:
            The highest key, as ``str``.

        Raises:
            ValueError: Redis contains no keys.
        """
        keys = [key.decode("utf-8") if isinstance(key, bytes) else key
                for key in self.r.scan_iter()]
        return max(keys, key=self._version_sort_key)

    @staticmethod
    def _version_sort_key(key):
        """Split ``key`` into alternating text/number runs for natural ordering."""
        import re
        # With one capture group, re.split alternates text at even indices and
        # digit runs at odd indices. Equal positions therefore always hold the
        # same type, so two keys can always be compared.
        parts = re.split(r"(\d+)", key)
        return [int(part) if index % 2 else part for index, part in enumerate(parts)]

    def load_model(self, version):
        """Deserialize and load the model stored under ``version``.

        SECURITY: ``pickle.loads`` can execute arbitrary code embedded in the
        payload. Only use a Redis instance that untrusted parties cannot
        write to.
        """
        return pickle.loads(self.r.get(version))

    def get(self):
        """Get current model version."""
        return {"model_version": self.version}

    def post(self):
        """Get prediction from model.

        Input: JSON body of the form ``{"features": [...]}``. A flat
        (1-D) feature list is treated as a single sample.

        Returns:
            ``{"model_version": ..., "prediction": [...]}`` on success, or
            ``({"message": ...}, 400)`` when the body has no ``features``
            entry or the features do not fit the model.
        """
        payload = request.get_json()
        try:
            raw_features = payload["features"]
        except (KeyError, TypeError):
            # TypeError: the body is not a JSON object (e.g. a list or None).
            return {"message": "Features not found in {}".format(payload)}, 400
        try:
            features = np.array(raw_features)
            if features.ndim == 1:
                # Single sample -> shape (1, n_features). reshape returns a
                # new array, so the result must be assigned.
                features = features.reshape(1, -1)
            prediction = self.model.predict(features)
            return {
                "model_version": self.version,
                # tolist() converts NumPy scalars (e.g. int64) and nested
                # arrays into native Python types that JSON can encode.
                "prediction": np.asarray(prediction).tolist(),
            }
        except ValueError:
            return {"message": "Features {} do not match expected input for model version {}".format(raw_features, self.version)}, 400
||
76 |