
src.denoiser.models.MachineLearningModel (rating: A)

Complexity
    Total Complexity: 12

Size/Duplication
    Total Lines: 92
    Duplicated Lines: 0%

Metric   Value
dl       0
loc      92
rs       10
wmc      12

4 Methods

Rating   Name                               Duplication   Size   Complexity
A        MachineLearningModel.__init__()    0             6      1
B        MachineLearningModel.train()       0             43     5
A        MachineLearningModel.load()        0             7      1
B        MachineLearningModel.correct()     0             28     5
"""Models package

.. Authors:
    Philippe Dessauw
    [email protected]

.. Sponsor:
    Alden Dima
    [email protected]
    Information Systems Group
    Software and Systems Division
    Information Technology Laboratory
    National Institute of Standards and Technology
    http://www.nist.gov/itl/ssd/is
"""
from __future__ import division
import logging
from os import unlink
from os.path import exists, join

from sklearn.linear_model import SGDClassifier

from apputils.fileop import file_checksum
from apputils.pickling import save, load
from denoiser.models.indicators.lists import StrongIndicatorList, CleanIndicatorList
from denoiser.models.inline import Unigrams, Dictionary, Bigrams, AltCaseMap, OcrKeyMap, AnagramMap
from denoiser.models.inline.ranking import rate_corrections
from denoiser.models.inline.utils import init_correction_map, select_anagrams, select_ocrsims, build_candidates_list, \
    correct_case, apply_bigram_boost, select_correction, extract_paragraph_bigrams, select_lower_edit_distance, \
    select_best_alphabetical_word
from denoiser.models.machine_learning import MachineLearningFeatures, MachineLearningAlgorithm

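# Illustrative sketch, not part of the original module: the app_config layout
# implied by the keys read throughout this file. Every key below appears in
# the code; the concrete file names and paths are hypothetical placeholders.
#
#   app_config = {
#       "root": "/path/to/app",
#       "dirs": {
#           "models_root": "models",
#           "models": {"inline": "inline", "learning": "learning"},
#       },
#       "models": {
#           "hashes": "hashes.pickle",
#           "inline": {
#               "dictionary": "dictionary.pickle",
#               "unigrams": "unigrams.pickle",
#               "bigrams": "bigrams.pickle",
#               "altcase": "altcase.pickle",
#               "ocr_keys": "ocr_keys.pickle",
#               "anagrams": "anagrams.pickle",
#           },
#           "learning": {
#               "training_set": "training_set.pickle",
#               "classifier": "classifier.pickle",
#           },
#       },
#       "exts": {"tmp": ".tmp"},
#   }
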
class AbstractModel(object):
    """Abstract model containing the functions common to all models
    """

    def __init__(self, app_config):
        self.config = app_config
        self.logger = logging.getLogger('local')

        self.hash_filename = join(app_config["dirs"]["models_root"], app_config["models"]["hashes"])
        self.hash_list = []

        if exists(self.hash_filename):
            self.hash_list = load(self.hash_filename)

    def is_preprocessed(self, filename):
        """Determine whether the given file has already been preprocessed (i.e. its data added to the models)

        If the file is unknown, its checksum is recorded, so subsequent calls for the same file return 1.

        Args:
            filename (str): Path of the given file

        Returns:
            int: 0 if the file has not been preprocessed yet, 1 otherwise
        """
        text_id = file_checksum(filename)

        if text_id not in self.hash_list:
            self.hash_list.append(text_id)
            save(self.hash_list, self.hash_filename)
            return 0

        return 1

    def load(self, text_data):
        """Load text data into the model

        Args:
            text_data (dict): Text data

        Raises:
            NotImplementedError: Not implemented by the subclass
        """
        raise NotImplementedError()

    def correct(self, text_data):
        """Correct text data using the model

        Args:
            text_data (dict): Text data

        Raises:
            NotImplementedError: Not implemented by the subclass
        """
        raise NotImplementedError()

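# Illustrative sketch, not part of the original package: the smallest concrete
# model satisfying the AbstractModel contract above. The class name and the
# no-op behaviour are hypothetical; only the load/correct signatures come from
# AbstractModel.
class _ExampleNoOpModel(AbstractModel):
    """Hypothetical model that accepts text data and leaves it untouched."""

    def load(self, text_data):
        pass  # nothing to record for this model

    def correct(self, text_data):
        pass  # no corrections to apply
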
class InlineModel(AbstractModel):
    """Model for inline data structures
    """

    def __init__(self, app_config):
        super(InlineModel, self).__init__(app_config)

        inline_models_dir = join(
            app_config["root"],
            app_config["dirs"]["models_root"],
            app_config["dirs"]["models"]["inline"]
        )
        inline_models_key = app_config["models"]["inline"]

        self.dictionary = Dictionary(join(inline_models_dir, inline_models_key["dictionary"]))

        self.unigrams = Unigrams(join(inline_models_dir, inline_models_key["unigrams"]))
        self.tmp_unigrams_filename = self.unigrams.filename + app_config["exts"]["tmp"]

        self.bigrams = Bigrams(join(inline_models_dir, inline_models_key["bigrams"]))

        self.altcase_map = AltCaseMap(join(inline_models_dir, inline_models_key["altcase"]))
        self.tmp_altcase_filename = self.altcase_map.filename + app_config["exts"]["tmp"]

        self.ocrkey_map = OcrKeyMap(join(inline_models_dir, inline_models_key["ocr_keys"]))
        self.anagram_map = AnagramMap(join(inline_models_dir, inline_models_key["anagrams"]))
    def load(self, text_data):
        """Load text data into the model

        Args:
            text_data (dict): Text data
        """
        if self.is_preprocessed(text_data.filename) != 0:
            self.logger.debug(text_data.filename + " already loaded: skipping it.")
            return

        tmp_u = Unigrams(self.tmp_unigrams_filename)
        word_list = tmp_u.append_data(text_data)

        self.bigrams.append_data(word_list)

        tmp_ac = AltCaseMap(self.tmp_altcase_filename)
        tmp_ac.append_data(tmp_u.raw_unigrams)

        tmp_u.generate_low_case(tmp_ac.altcase_map)

        self.ocrkey_map.append_data(tmp_u.raw_unigrams)

        # Updating data
        self.unigrams.raw_unigrams += tmp_u.raw_unigrams
        self.unigrams.ngrams += tmp_u.ngrams
        self.unigrams.prune(0.7)
        self.unigrams.save()

        # Merge the temporary alt-case map into the persistent one (set union per key)
        combine_struct = {key: set() for key in tmp_ac.altcase_map.keys() + self.altcase_map.altcase_map.keys()}
        for key, value in tmp_ac.altcase_map.items() + self.altcase_map.altcase_map.items():
            combine_struct[key] = combine_struct[key].union(value)

        self.altcase_map.altcase_map = combine_struct
        self.altcase_map.prune(self.unigrams.ngrams_pruned)
        self.altcase_map.save()

        unlink(self.tmp_unigrams_filename)
        unlink(self.tmp_altcase_filename)

        self.anagram_map.append_data(self.bigrams.ngrams_pruned, self.unigrams.ngrams_pruned)
        self.dictionary.append_data(self.unigrams.ngrams_pruned)

        self.logger.info(text_data.filename + "'s data structures loaded")
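    # Worked example (comment only, not in the original code) of the set-union
    # merge at the end of load() above, on toy values:
    #
    #   tmp_ac.altcase_map           = {"nist": {"NIST"}}
    #   self.altcase_map.altcase_map = {"nist": {"Nist"}, "ocr": {"OCR"}}
    #   combine_struct               = {"nist": {"NIST", "Nist"}, "ocr": {"OCR"}}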
    def correct(self, text_data):
        """Correct text data

        Args:
            text_data (dict): Text data
        """
        correction_data = self.correction_data()

        for paragraph in text_data.text:
            for line in paragraph:
                for token in line.tokens:
                    token[2] = init_correction_map(token[1], correction_data["dictionary"])

                    # Skip some correction steps if the token is too short, in the dictionary or already
                    # identified as garbage
                    if token[2] is not None and len(token[2]) == 0:
                        anagrams = select_anagrams(token[1], correction_data)
                        ocr_sims = select_ocrsims(token[1], correction_data)

                        token[2] = build_candidates_list(token[1], anagrams, ocr_sims, correction_data)
                        token[2] = correct_case(token[1], token[2], correction_data)

                        token[2] = rate_corrections(token[2])

                        if len(token[2]) == 0:  # No correction has been found
                            token[2] = None

            # Applying the bigram boost to the tokens
            bigrams = extract_paragraph_bigrams(paragraph)
            apply_bigram_boost(paragraph, bigrams, correction_data["occurence_map"])

            # Select the appropriate correction
            for line in paragraph:
                for token in line.tokens:
                    token[2] = select_correction(token[1], token[2])

                    if token[2] is not None and len(token[2]) > 1:
                        # Tie-breaking cascade: highest score, then lowest edit distance,
                        # then best alphabetical candidate
                        tkn_list = [tkn for tkn, sc in token[2].items() if sc == max(token[2].values())]

                        if len(tkn_list) != 1:
                            tkn_list = select_lower_edit_distance(token[1], {tkn: token[2][tkn] for tkn in tkn_list})

                        if len(tkn_list) != 1:
                            tkn_list = [select_best_alphabetical_word(token[1], tkn_list)]

                        token[2] = {tkn: token[2][tkn] for tkn in tkn_list}
    def correction_data(self):
        """Get the correction data

        Returns:
            dict: Correction data
        """
        return {
            "occurence_map": self.unigrams.ngrams + self.bigrams.ngrams,
            "altcase": self.altcase_map.altcase_map,
            "ocrkeys": self.ocrkey_map.ocrkey_map,
            "anagrams": self.anagram_map.anagram_hashmap,
            "alphabet": self.anagram_map.anagram_alphabet,
            "dictionary": self.dictionary.dictionary
        }

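# Illustrative sketch, not part of the original package: the score-based
# tie-breaking used in InlineModel.correct() above, on a plain {word: score}
# dict. The edit-distance step (select_lower_edit_distance) is omitted for
# brevity; sorted()[0] stands in for select_best_alphabetical_word.
def _example_tie_break(candidates):
    """Return one best candidate from a {word: score} mapping."""
    best_score = max(candidates.values())
    tied = [word for word, score in candidates.items() if score == best_score]
    # e.g. {"the": 0.9, "them": 0.9, "tho": 0.4} leaves "the" and "them" tied,
    # so a secondary criterion has to decide
    return sorted(tied)[0]
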
class IndicatorModel(AbstractModel):
    """Model for garbage string indicators
    """

    def __init__(self, app_config):
        super(IndicatorModel, self).__init__(app_config)

        self.model = {
            "strong": StrongIndicatorList(),
            "clean": CleanIndicatorList()
        }
    def load(self, text_data):
        """Load text data into the model

        Args:
            text_data (dict): Text data
        """
        for indicator_list in self.model.values():
            indicator_list.set_stats(text_data.stats)
    def correct(self, text_data):
        """Correct text data

        Args:
            text_data (dict): Text data
        """
        # =======================
        # Strong indicators
        # =======================
        lines = [line for paragraph in text_data.text for line in paragraph
                 if line.grade != 0 and self.model["strong"].match(line)]

        for line in lines:
            line.set_garbage()

        # =======================
        # Clean indicators
        # =======================
        lines = [line for paragraph in text_data.text for line in paragraph
                 if line.grade != 0 and self.model["clean"].match(line)]

        for line in lines:
            line.set_clean()

        # =======================
        # Post processing
        # =======================
        lines = [line for paragraph in text_data.text for line in paragraph]
        previous_line = None

        # Smoothing function: lower the grade of lines adjacent to garbage
        # lines (grade 0), unless they are marked clean (grade 5)
        for line in lines:
            # Decrease the grade if the previous line is a garbage string
            if previous_line is not None and previous_line.grade == 0 and line.grade != 5:
                line.decrease_grade()

            # Decrease the grade of the previous line if the current one is garbage
            if line.grade == 0 and previous_line is not None and previous_line.grade != 5:
                previous_line.decrease_grade()

            previous_line = line

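# Illustrative sketch, not part of the original package: the smoothing pass of
# IndicatorModel.correct() on bare integer grades, assuming 0 = garbage,
# 5 = clean, and a simple "decrease by one" in place of decrease_grade().
def _example_smooth(grades):
    """Lower the grades of lines adjacent to garbage lines (grade 0)."""
    smoothed = list(grades)
    for i in range(len(smoothed)):
        # Decrease this line if the previous one is garbage (unless clean)
        if i > 0 and smoothed[i - 1] == 0 and smoothed[i] != 5:
            smoothed[i] = max(smoothed[i] - 1, 0)
        # Decrease the previous line if this one is garbage (unless clean)
        if smoothed[i] == 0 and i > 0 and smoothed[i - 1] != 5:
            smoothed[i - 1] = max(smoothed[i - 1] - 1, 0)
    return smoothed  # e.g. [3, 0, 3, 5] -> [2, 0, 2, 5]
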
class MachineLearningModel(AbstractModel):
    """Model storing all machine learning data
    """

    def __init__(self, app_config):
        super(MachineLearningModel, self).__init__(app_config)

        self.model = {
            "algo": MachineLearningAlgorithm(),
            "features": MachineLearningFeatures()
        }
    def train(self, dataset):
        """Train the model with a dataset

        Args:
            dataset (list): List of training files
        """
        # Get the original training set
        training_set = self.model["algo"].training_set

        # Append the new data to it
        for text in dataset:
            self.logger.debug("Processing " + text.filename + "...")
            unigrams = Unigrams(join(self.config["root"],
                                     self.config["dirs"]["models_root"],
                                     self.config["dirs"]["models"]["inline"],
                                     self.config["models"]["inline"]["unigrams"]))

            for p in text.text:
                for line in p:
                    if line.grade % 5 != 0:  # Only lines graded 0 (garbage) or 5 (clean) are usable labels
                        continue

                    f = MachineLearningFeatures()
                    features = f.extract_features(line, unigrams.ngrams, text.stats)
                    result = int(line.grade / 5)  # 0 for garbage, 1 for clean

                    training_set["features"].append(features)
                    training_set["results"].append(result)

        self.logger.debug("Saving training set...")
        save(training_set, join(self.config["dirs"]["models_root"],
                                self.config["dirs"]["models"]["learning"],
                                self.config["models"]["learning"]["training_set"]))

        self.logger.debug("Training model...")
        ml_classifier = SGDClassifier(loss="log", class_weight="auto")
        self.model["algo"].set_classifier(ml_classifier)
        self.model["algo"].set_training_set(training_set["features"], training_set["results"])
        self.model["algo"].train()

        save(self.model["algo"].classifier, join(self.config["dirs"]["models_root"],
                                                 self.config["dirs"]["models"]["learning"],
                                                 self.config["models"]["learning"]["classifier"]))
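    # Illustrative sketch (comment only, not in the original code): the
    # scikit-learn calls train() relies on, with toy feature vectors.
    # loss="log" makes SGDClassifier a logistic-regression model fitted by
    # stochastic gradient descent; class_weight="auto" reweights classes by
    # their frequency (renamed "balanced" in later scikit-learn releases).
    #
    #   features = [[0.1, 3.0], [0.9, 0.0], [0.2, 2.5], [0.8, 0.1]]
    #   results = [1, 0, 1, 0]  # 1 = clean line, 0 = garbage line
    #   classifier = SGDClassifier(loss="log", class_weight="auto")
    #   classifier.fit(features, results)
    #   classifier.predict([[0.15, 2.8]])  # typically array([1]) here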
    def load(self, text_data):
        """Load text data into the model

        This model learns from labelled training sets only (see train()), so
        there is nothing to collect here.

        Args:
            text_data (dict): Text data
        """
        pass
    def correct(self, text_data):
        """Correct text data

        Args:
            text_data (dict): Text data
        """
        unigrams = Unigrams(join(self.config["root"],
                                 self.config["dirs"]["models_root"],
                                 self.config["dirs"]["models"]["inline"],
                                 self.config["models"]["inline"]["unigrams"]))

        ml_classifier = load(join(self.config["dirs"]["models_root"],
                                  self.config["dirs"]["models"]["learning"],
                                  self.config["models"]["learning"]["classifier"]))

        if ml_classifier is None:  # No trained classifier available
            return

        self.model["algo"].set_classifier(ml_classifier)

        for paragraph in text_data.text:
            for line in paragraph:
                if line.grade % 5 == 0:  # Line already classified as garbage (0) or clean (5)
                    continue

                f = MachineLearningFeatures()
                features = f.extract_features(line, unigrams.ngrams, text_data.stats)
                line.grade = self.model["algo"].classify(features) * 5
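
# Illustrative sketch, not part of the original package: one plausible way to
# chain the three models over a document. `app_config` follows the layout
# sketched near the imports; `training_texts` and `text_data` stand for the
# objects produced by the wider denoiser pipeline and are hypothetical here.
def _example_pipeline(app_config, training_texts, text_data):
    inline = InlineModel(app_config)
    indicators = IndicatorModel(app_config)
    learning = MachineLearningModel(app_config)

    # Build the inline data structures and train the classifier
    for text in training_texts:
        inline.load(text)
    learning.train(training_texts)

    # Grade the lines of a new document, then correct its tokens
    indicators.load(text_data)
    indicators.correct(text_data)  # rule-based line grading
    learning.correct(text_data)    # ML grading of still-unclassified lines
    inline.correct(text_data)      # token-level OCR corrections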