Completed
Pull Request — master (#168)
by
unknown
03:32
created

SerModel.encode()   A

Complexity

Conditions 1

Size

Total Lines 5

Duplication

Lines 0
Ratio 0 %

Code Coverage

Tests 0
CRAP Score 2

Importance

Changes 1
Bugs 0 Features 1
Metric Value
cc 1
c 1
b 0
f 1
dl 0
loc 5
ccs 0
cts 4
cp 0
crap 2
rs 9.4285
1
"""
2
 Summary:
 Functions to save and load a model. The current Keras
 function for this does not work in Python 3. Therefore, we
 implemented our own functions until the Keras functionality has matured.
 Example function calls are in 'Tutorial mcfly on PAMAP2.ipynb'.
7
"""
8 1
from keras.models import model_from_json
9 1
import keras
10
11 1
import json
12 1
import numpy as np
13 1
import os
14 1
import uuid
15 1
from collections import namedtuple
16
17
18 1
# Immutable pair of a training ``history`` and the corresponding ``model``.
# Field names are given as a single whitespace-separated string.
TrainedModel = namedtuple('TrainedModel', 'history model')
20
21
22 1
def savemodel(model, filepath, modelname):
    """ Save model to json file and weights to npy file

    Parameters
    ----------
    model : Keras object
        model to save; must provide ``to_json()`` and ``get_weights()``
    filepath : str
        directory where the data will be stored
    modelname : str
        name of the model to be used in the filename

    Returns
    ----------
    json_path : str
        Path to json file with architecture
    numpy_path : str
        Path to npy file with weights, WITHOUT the ``.npy`` extension
        that ``np.save`` appends to it

    Notes
    -----
    The weights are stored as a 1-D object array (one element per weight
    array), which ``np.save`` pickles. Read the file back with
    ``np.load(..., allow_pickle=True)``.
    """
    json_string = model.to_json()  # architecture as a JSON string
    json_path = os.path.join(filepath, modelname + '_architecture.json')
    # The architecture string is itself dumped as JSON, so the file holds a
    # JSON-encoded string; loadmodel undoes this with json.load.
    # utf-8 is explicit because ensure_ascii=False may write non-ASCII chars.
    with open(json_path, 'w', encoding='utf-8') as outfile:
        json.dump(json_string, outfile, sort_keys=True, indent=4,
                  ensure_ascii=False)

    weights = model.get_weights()
    # Weight arrays generally differ in shape (e.g. kernel vs. bias).
    # Passing the list straight to np.save calls np.asarray on it, which
    # raises ValueError for such ragged input (or stacks equal-shape arrays
    # into one bigger array). Filling an object array element by element
    # stores every weight array unchanged.
    packed = np.empty(len(weights), dtype=object)
    for index, weight in enumerate(weights):
        packed[index] = weight
    numpy_path = os.path.join(filepath, modelname + '_weights')
    np.save(numpy_path, packed, allow_pickle=True)
    return json_path, numpy_path
51
52
53 1
def loadmodel(filepath, modelname):
    """ Load model + weights from json + npy file, respectively

    Parameters
    ----------
    filepath : str
        directory where the data is stored
    modelname : str
        name of the model used in the filename

    Returns
    ----------
    model_repro : Keras object
        reproduced model

    Notes
    -----
    The weights file is loaded with ``allow_pickle=True``. The weight arrays
    of a model usually differ in shape, so they are stored as an object
    array, which numpy pickles. Unpickling can execute arbitrary code:
    only load files from trusted sources.
    """
    json_path = os.path.join(filepath, modelname + '_architecture.json')
    with open(json_path, 'r', encoding='utf-8') as infile:
        # The file holds the architecture as a JSON-encoded string.
        json_string_loaded = json.load(infile)
    model_repro = model_from_json(json_string_loaded)
    # The rebuilt model only has fresh (untrained/default) weights; replace
    # them with the stored ones.
    weights_path = os.path.join(filepath, modelname + '_weights.npy')
    wweights_recovered = np.load(weights_path, allow_pickle=True)
    model_repro.set_weights(list(wweights_recovered))
    return model_repro
77
78
# If we used the standard Keras functions, which store the model and weights
# in HDF5 format, the code would look like the example below. We did not use
# them because, according to
# https://keras.io/getting-started/faq/#how-can-i-save-a-keras-model
# they are not compatible with the default Keras version in Python 3.
83
# from keras.models import load_model
84
# import h5py
85
# modelh5=models[0]
86
# modelh5.save(resultpath+'mymodel.h5')
87
# del modelh5
88
# modelh5 = load_model(resultpath+'mymodel.h5')
89
90
91 1
try:
    import noodles
    from noodles.serial.numpy import arrays_to_string
    from noodles.serial.namedtuple import SerNamedTuple
    from pathlib import Path
except ImportError:
    # noodles is optional: without it, SerModel and serial_registry are
    # simply not defined.
    pass
else:
    class SerModel(noodles.serial.Serialiser):
        """noodles serialiser for ``keras.models.Model`` instances.

        Each model is saved to a uniquely named (uuid4) ``.hdf5`` file in
        ``tmpdir``; only the file name is put in the serialised record.
        """

        def __init__(self, tmpdir):
            self.tmpdir = Path(tmpdir)
            super().__init__(keras.models.Model)

        def encode(self, obj, make_rec):
            """Save ``obj`` to a new file and return a record referring to it."""
            target = str(self.tmpdir / '{}.hdf5'.format(uuid.uuid4()))
            obj.save(target)
            return make_rec({'filename': target}, files=[target], ref=True)

        def decode(self, cls, data):
            """Load the model back from the file named in ``data``."""
            return keras.models.load_model(data['filename'])

    def serial_registry(tmpdir='.'):
        """Build a noodles Registry for Keras models and TrainedModel tuples.

        Parameters
        ----------
        tmpdir : str or path-like
            directory in which serialised models are written

        Returns
        -------
        noodles.serial.Registry
            registry layered on the base and numpy-array serialisers
        """
        parent = noodles.serial.base() + arrays_to_string()
        types = {
            keras.models.Model: SerModel(tmpdir),
            TrainedModel: SerNamedTuple(TrainedModel),
        }
        return noodles.serial.Registry(parent=parent, types=types)
125