| Metric | Value |
| --- | --- |
| Conditions | 25 |
| Total Lines | 74 |
| Code Lines | 50 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 0 |
Small methods make your code easier to understand, especially when combined with a good name. And when a method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for its name.
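As a hedged sketch of that idea (the function names `evaluate` and `_doc_averaged_scores` are hypothetical; the scoring calls are scikit-learn's, the same ones used in the listing below), extracting a commented block might look like this:

```python
from sklearn.metrics import f1_score, precision_score, recall_score


# Before, the caller carried an inline block explained by the comment
# "compute document-averaged precision, recall and F1".
# After extraction, the comment's wording becomes the helper's name.
def _doc_averaged_scores(y_true, y_pred_binary):
    return {
        "Precision (doc avg)": precision_score(y_true, y_pred_binary, average="samples"),
        "Recall (doc avg)": recall_score(y_true, y_pred_binary, average="samples"),
        "F1 score (doc avg)": f1_score(y_true, y_pred_binary, average="samples"),
    }


def evaluate(y_true, y_pred):
    y_pred_binary = y_pred > 0.0
    return _doc_averaged_scores(y_true, y_pred_binary)
```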
Commonly applied refactorings include:

- Extract Method

If many parameters or temporary variables are present:

- Replace Temp with Query
- Introduce Parameter Object (see the sketch after this list)
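A minimal sketch of Introduce Parameter Object, using a hypothetical dataclass (`EvaluationData` and its field names are illustrative, not part of Annif):

```python
from dataclasses import dataclass

from scipy.sparse import csr_array


@dataclass
class EvaluationData:
    """Bundles the values that are always passed around together."""

    y_true: csr_array
    y_pred: csr_array

    @property
    def y_pred_binary(self):
        # Replace Temp with Query: derive the value on demand instead of
        # keeping a temporary variable in the calling method.
        return self.y_pred > 0.0
```

A method such as `_evaluate_samples` would then take a single `EvaluationData` argument instead of separate `y_true`/`y_pred` parameters and a locally computed `y_pred_binary`.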
Complex classes like annif.eval.EvaluationBatch, whose _evaluate_samples() method is flagged here, often do a lot of different things. To break such a class down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields or methods that share the same prefixes or suffixes.
Once you have determined which fields belong together, you can apply the Extract Class refactoring. If the component makes sense as a subclass, Extract Subclass is also a candidate, and is often faster.
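In this method, one cohesive component is the metric table itself: the mapping from metric names to lazily evaluated scoring functions. A hedged sketch of Extract Class (the `MetricRegistry` name and API are hypothetical; only a few of the metrics from the listing below are repeated here):

```python
from sklearn.metrics import f1_score, precision_score, recall_score


class MetricRegistry:
    """Hypothetical extracted class that owns the metric-name -> callable mapping."""

    def __init__(self, y_true, y_pred_binary):
        self._metrics = {
            "Precision (doc avg)": lambda: precision_score(
                y_true, y_pred_binary, average="samples"
            ),
            "Recall (doc avg)": lambda: recall_score(
                y_true, y_pred_binary, average="samples"
            ),
            "F1 score (doc avg)": lambda: f1_score(
                y_true, y_pred_binary, average="samples"
            ),
            # ... the remaining metrics from the listing below ...
        }

    def evaluate(self, names=None):
        # Compute only the requested metrics, or all of them by default.
        names = names or self._metrics.keys()
        return {name: self._metrics[name]() for name in names}
```

With such a component, `_evaluate_samples` would shrink to building a `MetricRegistry` and delegating to its `evaluate` method.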
1 | """Evaluation metrics for Annif""" |
||
110 | def _evaluate_samples( |
||
111 | self, |
||
112 | y_true: csr_array, |
||
113 | y_pred: csr_array, |
||
114 | metrics: Iterable[str] = [], |
||
115 | ) -> dict[str, float]: |
||
116 | y_pred_binary = y_pred > 0.0 |
||
117 | |||
118 | # define the available metrics as lazy lambda functions |
||
119 | # so we can execute only the ones actually requested |
||
120 | all_metrics = { |
||
121 | "Precision (doc avg)": lambda: precision_score( |
||
122 | y_true, y_pred_binary, average="samples" |
||
123 | ), |
||
124 | "Recall (doc avg)": lambda: recall_score( |
||
125 | y_true, y_pred_binary, average="samples" |
||
126 | ), |
||
127 | "F1 score (doc avg)": lambda: f1_score( |
||
128 | y_true, y_pred_binary, average="samples" |
||
129 | ), |
||
130 | "Precision (subj avg)": lambda: precision_score( |
||
131 | y_true, y_pred_binary, average="macro" |
||
132 | ), |
||
133 | "Recall (subj avg)": lambda: recall_score( |
||
134 | y_true, y_pred_binary, average="macro" |
||
135 | ), |
||
136 | "F1 score (subj avg)": lambda: f1_score( |
||
137 | y_true, y_pred_binary, average="macro" |
||
138 | ), |
||
139 | "Precision (weighted subj avg)": lambda: precision_score( |
||
140 | y_true, y_pred_binary, average="weighted" |
||
141 | ), |
||
142 | "Recall (weighted subj avg)": lambda: recall_score( |
||
143 | y_true, y_pred_binary, average="weighted" |
||
144 | ), |
||
145 | "F1 score (weighted subj avg)": lambda: f1_score( |
||
146 | y_true, y_pred_binary, average="weighted" |
||
147 | ), |
||
148 | "Precision (microavg)": lambda: precision_score( |
||
149 | y_true, y_pred_binary, average="micro" |
||
150 | ), |
||
151 | "Recall (microavg)": lambda: recall_score( |
||
152 | y_true, y_pred_binary, average="micro" |
||
153 | ), |
||
154 | "F1 score (microavg)": lambda: f1_score( |
||
155 | y_true, y_pred_binary, average="micro" |
||
156 | ), |
||
157 | "F1@5": lambda: f1_score( |
||
158 | y_true, filter_suggestion(y_pred, 5) > 0.0, average="samples" |
||
159 | ), |
||
160 | "NDCG": lambda: ndcg_score(y_true, y_pred), |
||
161 | "NDCG@5": lambda: ndcg_score(y_true, y_pred, limit=5), |
||
162 | "NDCG@10": lambda: ndcg_score(y_true, y_pred, limit=10), |
||
163 | "Precision@1": lambda: precision_score( |
||
164 | y_true, filter_suggestion(y_pred, 1) > 0.0, average="samples" |
||
165 | ), |
||
166 | "Precision@3": lambda: precision_score( |
||
167 | y_true, filter_suggestion(y_pred, 3) > 0.0, average="samples" |
||
168 | ), |
||
169 | "Precision@5": lambda: precision_score( |
||
170 | y_true, filter_suggestion(y_pred, 5) > 0.0, average="samples" |
||
171 | ), |
||
172 | "True positives": lambda: true_positives(y_true, y_pred_binary), |
||
173 | "False positives": lambda: false_positives(y_true, y_pred_binary), |
||
174 | "False negatives": lambda: false_negatives(y_true, y_pred_binary), |
||
175 | } |
||
176 | |||
177 | if not metrics: |
||
178 | metrics = all_metrics.keys() |
||
179 | |||
180 | with warnings.catch_warnings(): |
||
181 | warnings.simplefilter("ignore") |
||
182 | |||
183 | return {metric: all_metrics[metric]() for metric in metrics} |
||
184 | |||