1 | from abc import ABCMeta, abstractmethod |
||
2 | |||
3 | import theano |
||
4 | from theano import tensor |
||
5 | from six import add_metaclass |
||
6 | |||
7 | from blocks.bricks.base import application, Brick |
||
8 | |||
9 | |||
@add_metaclass(ABCMeta)
class Cost(Brick):
    """Abstract base class for cost bricks.

    Subclasses implement :meth:`apply`, which builds and returns a
    (symbolic) cost expression from its arguments.

    """
    @abstractmethod
    @application
    def apply(self, *args, **kwargs):
        """Return the cost expression; must be overridden by subclasses."""
17 | |||
@add_metaclass(ABCMeta)
class CostMatrix(Cost):
    """Base class for costs which can be calculated element-wise.

    Assumes that the data has format (batch, features).

    """
    @application(outputs=["cost"])
    def apply(self, *args, **kwargs):
        """Reduce the element-wise cost matrix to a scalar.

        Sums the per-feature costs for each example (axis 1), then
        averages over the batch.
        """
        per_example = self.cost_matrix(*args, **kwargs).sum(axis=1)
        return per_example.mean()

    @abstractmethod
    @application
    def cost_matrix(self, *args, **kwargs):
        """Return the element-wise cost with shape (batch, features)."""
34 | |||
class BinaryCrossEntropy(CostMatrix):
    """Element-wise binary cross-entropy of predictions against targets."""
    @application
    def cost_matrix(self, y, y_hat):
        # Note the argument order expected by Theano: predictions
        # (y_hat) come first, targets (y) second.
        return tensor.nnet.binary_crossentropy(y_hat, y)
41 | |||
class AbsoluteError(CostMatrix):
    """Element-wise absolute difference (L1 cost) between y and y_hat."""
    @application
    def cost_matrix(self, y, y_hat):
        return abs(y - y_hat)
48 | |||
class SquaredError(CostMatrix):
    """Element-wise squared difference (L2 cost) between y and y_hat."""
    @application
    def cost_matrix(self, y, y_hat):
        return tensor.sqr(y - y_hat)
55 | |||
class CategoricalCrossEntropy(Cost):
    """Mean categorical cross-entropy over a mini-batch."""
    @application(outputs=["cost"])
    def apply(self, y, y_hat):
        # Predictions (y_hat) first, targets (y) second, per Theano's
        # convention; average the per-example losses over the batch.
        return tensor.nnet.categorical_crossentropy(y_hat, y).mean()
62 | |||
class MisclassificationRate(Cost):
    """Calculates the misclassification rate for a mini-batch.

    Parameters
    ----------
    top_k : int, optional
        If the ground truth class is within the `top_k` highest
        responses for a given example, the model is considered
        to have predicted correctly. Default: 1.

    Notes
    -----
    Ties for `top_k`-th place are broken pessimistically, i.e.
    in the (in practice, rare) case that there is a tie for `top_k`-th
    highest output for a given example, it is considered an incorrect
    prediction.

    """
    def __init__(self, top_k=1):
        self.top_k = top_k
        super(MisclassificationRate, self).__init__()

    @application(outputs=["error_rate"])
    def apply(self, y, y_hat):
        # Checkpoints pickled before `top_k` existed lack the
        # attribute; fall back to plain top-1 error in that case.
        k = getattr(self, 'top_k', 1)
        if k == 1:
            mistakes = tensor.neq(y, y_hat.argmax(axis=1))
        else:
            flat_scores = y_hat.flatten()
            # Flat index of the first element of each row of y_hat.
            offsets = theano.tensor.arange(0, flat_scores.shape[0],
                                           y_hat.shape[1])
            # Score the model assigned to each example's true label.
            target_scores = flat_scores[offsets + y]
            # We use greater than _or equals_ here so that the model
            # _must_ have its guess in the top k, and cannot extend
            # its effective "list of predictions" by tying lots of things
            # for k-th place.
            at_least_as_high = tensor.ge(y_hat,
                                         target_scores.dimshuffle(0, 'x'))
            # Because we used greater-than-or-equal we have to correct
            # for having counted the true label itself.
            num_higher = at_least_as_high.sum(axis=1) - 1
            mistakes = tensor.ge(num_higher, k)
        return mistakes.mean(dtype=theano.config.floatX)