Completed
Push — master ( ca4d61...4b03e7 )
by Juho
16s queued 13s
created

annif.parallel.ProjectSuggestMap.suggest_batch()   A

Complexity

Conditions 3

Size

Total Lines 12
Code Lines 10

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 10
nop 2
dl 0
loc 12
rs 9.9
c 0
b 0
f 0
1
"""Parallel processing functionality for Annif"""
2
3
4
import multiprocessing
5
import multiprocessing.dummy
6
from collections import defaultdict
7
8
# Start method for processes created by the multiprocessing module.
9
# A value of None means using the platform-specific default.
10
# Intended to be overridden in unit tests.
11
MP_START_METHOD = None
12
13
14
class BaseWorker:
    """Base class for multiprocessing task workers.

    Subclasses stash the data objects their static worker methods need by
    calling ``init``; the data lands in a class attribute that every worker
    in the pool can reach. The storage approach follows this blog post:
    https://thelaziestprogrammer.com/python/multiprocessing-pool-a-global-solution # noqa
    """

    # shared storage for worker data, populated via init()
    args = None

    @classmethod
    def init(cls, args):
        cls.args = args  # pragma: no cover
class ProjectSuggestMap:
    """A utility class that can be used to wrap one or more projects and
    provide a mapping method that converts Document objects to suggestions.
    Intended to be used with the multiprocessing module."""

    def __init__(self, registry, project_ids, backend_params, limit, threshold):
        """Store the lookup registry and the suggestion filter parameters.

        registry: object providing get_project(project_id)
        project_ids: iterable of project identifiers to query
        backend_params: backend-specific parameters forwarded to suggest calls
        limit: maximum number of suggestions kept per project
        threshold: minimum score a suggestion must reach to be kept
        """
        self.registry = registry
        self.project_ids = project_ids
        self.backend_params = backend_params
        self.limit = limit
        self.threshold = threshold

    def suggest(self, doc):
        """Suggest subjects for a single document.

        Returns a tuple (filtered_hits, subject_set) where filtered_hits maps
        each project_id to that project's filtered suggestion results and
        subject_set is the document's gold-standard subject set."""
        filtered_hits = {}
        for project_id in self.project_ids:
            project = self.registry.get_project(project_id)
            # suggest() operates on batches; wrap the single text and unwrap
            # the single result
            hits = project.suggest([doc.text], self.backend_params)[0]
            filtered_hits[project_id] = hits.filter(
                project.subjects, self.limit, self.threshold
            )
        return (filtered_hits, doc.subject_set)

    def suggest_batch(self, batch):
        """Suggest subjects for a batch of documents.

        Returns a tuple (filtered_hit_sets, subject_sets) where
        filtered_hit_sets maps each project_id to a list of filtered
        suggestion results (one per document, in batch order) and
        subject_sets holds the documents' gold-standard subject sets."""
        filtered_hit_sets = defaultdict(list)
        # Guard: zip(*[]) yields nothing, so the tuple unpacking below would
        # raise ValueError on an empty batch; return empty results instead.
        if not batch:
            return (filtered_hit_sets, ())
        texts, subject_sets = zip(*[(doc.text, doc.subject_set) for doc in batch])

        for project_id in self.project_ids:
            project = self.registry.get_project(project_id)
            hit_sets = project.suggest(texts, self.backend_params)
            for hits in hit_sets:
                filtered_hit_sets[project_id].append(
                    hits.filter(project.subjects, self.limit, self.threshold)
                )
        return (filtered_hit_sets, subject_sets)
def get_pool(n_jobs):
    """return a suitable multiprocessing pool class, and the correct jobs
    argument for its constructor, for the given amount of parallel jobs"""

    ctx = multiprocessing.get_context(MP_START_METHOD)

    if n_jobs == 1:
        # a single job: use the threading-backed dummy pool and skip the
        # overhead of spawning subprocesses
        pool_class = multiprocessing.dummy.Pool
    else:
        pool_class = ctx.Pool
        if n_jobs < 1:
            # non-positive means "let the pool decide" (one per CPU)
            n_jobs = None

    return n_jobs, pool_class