Passed
Pull Request — master (#511)
by Osma
01:38
created

annif.parallel   A

Complexity

Total Complexity 7

Size/Duplication

Total Lines 66
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 38
dl 0
loc 66
rs 10
c 0
b 0
f 0
wmc 7

3 Methods

Rating   Name   Duplication   Size   Complexity  
A ProjectSuggestMap.__init__() 0 12 1
A ProjectSuggestMap.suggest() 0 8 2
A BaseWorker.init() 0 3 1

1 Function

Rating   Name   Duplication   Size   Complexity  
A get_pool() 0 14 3
1
"""Parallel processing functionality for Annif"""
2
3
4
import multiprocessing
5
import multiprocessing.dummy
6
7
8
class BaseWorker:
    """Common base for task workers run through multiprocessing.

    Call ``init`` (e.g. as a pool initializer) to stash whatever data
    objects the work requires; they are kept as a class attribute so the
    static worker method can reach them without passing them per task.
    Approach adapted from:
    https://thelaziestprogrammer.com/python/multiprocessing-pool-a-global-solution # noqa
    """

    # Class-level storage for worker data; stays None until init() is called.
    args = None

    @classmethod
    def init(cls, args):
        """Save *args* on the class (or subclass) this is invoked on."""
        setattr(cls, 'args', args)
22
23
24
class ProjectSuggestMap:
    """Wrap one or more projects and map Document objects to suggestions.

    An instance carries everything needed to run suggestions (registry,
    project IDs, backend parameters, limit and threshold), so that its
    ``suggest`` method can be handed to the multiprocessing module."""

    def __init__(self, registry, project_ids, backend_params, limit,
                 threshold):
        # Keep every constructor argument as-is for later use in suggest().
        self.registry, self.project_ids = registry, project_ids
        self.backend_params = backend_params
        self.limit, self.threshold = limit, threshold

    def suggest(self, doc):
        """Run each wrapped project on *doc* and filter its hits.

        Returns a ``(hits_by_project, doc.uris, doc.labels)`` tuple, where
        ``hits_by_project`` maps each project ID to the result of calling
        ``filter(project.subjects, limit, threshold)`` on that project's
        suggestion hits."""
        hits_by_project = {}
        for pid in self.project_ids:
            proj = self.registry.get_project(pid)
            raw_hits = proj.suggest(doc.text, self.backend_params)
            hits_by_project[pid] = raw_hits.filter(
                proj.subjects, self.limit, self.threshold)
        return hits_by_project, doc.uris, doc.labels
50
51
52
def get_pool(n_jobs):
    """Return the jobs argument and pool class for *n_jobs* parallel jobs.

    Returns a ``(n_jobs, pool_class)`` tuple: the first item is the value to
    pass as the number of processes to ``pool_class``'s constructor.

    - ``n_jobs == 1``: ``(1, multiprocessing.dummy.Pool)`` -- a thread-based
      pool, avoiding subprocess overhead for a single job.
    - ``n_jobs`` is None or less than 1: ``(None, multiprocessing.Pool)``;
      a None process count makes the pool use all available CPUs.
    - otherwise: ``(n_jobs, multiprocessing.Pool)``.
    """

    if n_jobs == 1:
        # use the dummy wrapper around threading to avoid subprocess overhead
        return n_jobs, multiprocessing.dummy.Pool
    if n_jobs is None or n_jobs < 1:
        # Pool(None) sizes itself to os.cpu_count()
        return None, multiprocessing.Pool
    return n_jobs, multiprocessing.Pool
66