1 | """Parallel processing functionality for Annif""" |
||
2 | |||
3 | from __future__ import annotations |
||
4 | |||
5 | import multiprocessing |
||
6 | import multiprocessing.dummy |
||
7 | from typing import TYPE_CHECKING, Any |
||
8 | |||
9 | if TYPE_CHECKING: |
||
10 | from collections import defaultdict |
||
11 | from collections.abc import Iterator |
||
12 | from typing import Callable |
||
13 | |||
14 | from annif.corpus import Document, SubjectSet |
||
15 | from annif.registry import AnnifRegistry |
||
16 | from annif.suggestion import SuggestionBatch, SuggestionResult |
||
17 | |||
18 | |||
19 | # Start method for processes created by the multiprocessing module. |
||
20 | # A value of None means using the platform-specific default. |
||
21 | # Intended to be overridden in unit tests. |
||
22 | MP_START_METHOD = None |
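# (Illustrative note: a test could force a specific start method by setting
# annif.parallel.MP_START_METHOD = "spawn" before get_pool() is called.)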


class BaseWorker:
    """Base class for workers that implement tasks executed via
    multiprocessing. The init method can be used to store data objects that
    are necessary for the operation. They will be stored in a class
    attribute that is accessible to the static worker method. The storage
    solution is inspired by this blog post:
    https://thelaziestprogrammer.com/python/multiprocessing-pool-a-global-solution # noqa
    """

    args = None

    @classmethod
    def init(cls, args) -> None:
        cls.args = args  # pragma: no cover
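
# Illustrative usage sketch (hypothetical names, not part of this module):
# a BaseWorker subclass receives shared data once per worker process via
# init() and reads it back from the class attribute in its worker method.
#
#     class SquareWorker(BaseWorker):
#         @classmethod
#         def square(cls, n):
#             offset = cls.args["offset"]  # data stored by init()
#             return n * n + offset
#
#     jobs, pool_class = get_pool(4)
#     with pool_class(jobs, initializer=SquareWorker.init,
#                     initargs=({"offset": 1},)) as pool:
#         results = pool.map(SquareWorker.square, range(10))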


class ProjectSuggestMap:
    """A utility class that can be used to wrap one or more projects and
    provide a mapping method that converts Document objects to suggestions.
    Intended to be used with the multiprocessing module."""

    def __init__(
        self,
        registry: AnnifRegistry,
        project_ids: list[str],
        backend_params: defaultdict[str, Any] | None,
        limit: int | None,
        threshold: float,
    ) -> None:
        self.registry = registry
        self.project_ids = project_ids
        self.backend_params = backend_params
        self.limit = limit
        self.threshold = threshold

    def suggest(self, doc: Document) -> tuple[dict[str, SuggestionResult], SubjectSet]:
        filtered_hits = {}
        for project_id in self.project_ids:
            project = self.registry.get_project(project_id)
            batch = project.suggest([doc], self.backend_params)
            filtered_hits[project_id] = batch.filter(self.limit, self.threshold)[0]
        return (filtered_hits, doc.subject_set)
67 | |||
68 | def suggest_batch( |
||
69 | self, batch |
||
70 | ) -> tuple[dict[str, SuggestionBatch], Iterator[SubjectSet]]: |
||
71 | filtered_hit_sets = {} |
||
72 | subject_sets = [doc.subject_set for doc in batch] |
||
73 | |||
74 | for project_id in self.project_ids: |
||
75 | project = self.registry.get_project(project_id) |
||
76 | suggestion_batch = project.suggest(batch, self.backend_params) |
||
77 | filtered_hit_sets[project_id] = suggestion_batch.filter( |
||
78 | self.limit, self.threshold |
||
79 | ) |
||
80 | return (filtered_hit_sets, subject_sets) |
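
# Illustrative sketch of the intended use with a multiprocessing pool
# ("registry", "project_ids" and "documents" are assumed to be prepared by
# the caller; this is not an API defined in this module):
#
#     psmap = ProjectSuggestMap(
#         registry, project_ids, backend_params=None, limit=10, threshold=0.0
#     )
#     jobs, pool_class = get_pool(n_jobs=4)
#     with pool_class(jobs) as pool:
#         for hits, subject_set in pool.imap_unordered(psmap.suggest, documents):
#             ...  # evaluate hits per project against the gold subject_set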


def get_pool(n_jobs: int) -> tuple[int | None, Callable]:
    """Return the appropriate number-of-jobs argument and a suitable
    constructor for a multiprocessing pool class, for the given number of
    parallel jobs"""

    ctx = multiprocessing.get_context(MP_START_METHOD)

    if n_jobs < 1:
        n_jobs = None
        pool_constructor: Callable = ctx.Pool
    elif n_jobs == 1:
        # use the dummy wrapper around threading to avoid subprocess overhead
        pool_constructor = multiprocessing.dummy.Pool
    else:
        pool_constructor = ctx.Pool

    return n_jobs, pool_constructor
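
# Illustrative behaviour, summarizing the branches above ("items" and
# "some_picklable_function" are hypothetical placeholders):
#
#     jobs, pool_class = get_pool(0)  # (None, ctx.Pool): one job per CPU core
#     jobs, pool_class = get_pool(1)  # (1, dummy.Pool): threads, no fork cost
#     jobs, pool_class = get_pool(8)  # (8, ctx.Pool): eight worker processes
#
#     with pool_class(jobs) as pool:
#         results = pool.map(some_picklable_function, items)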