hyperactive.registry._lookup   A

"""Registry lookup methods.

This module exports the following methods for registry lookup:

all_objects(object_types, filter_tags)
    lookup and filtering of objects
"""

# based on the sktime module of same name

__author__ = ["fkiraly"]
# all_objects is based on the sklearn utility all_estimators

from inspect import isclass
from pathlib import Path

from skbase.lookup import all_objects as _all_objects


def all_objects(
    object_types=None,
    filter_tags=None,
    exclude_objects=None,
    return_names=True,
    as_dataframe=False,
    return_tags=None,
    suppress_import_stdout=True,
):
    """Get a list of all objects from hyperactive.

    This function crawls the ``hyperactive`` package and gets all classes that
    inherit from skbase-compatible base classes.

    Not included are the base classes themselves, classes defined in test
    modules, and classes in the other modules listed in ``MODULES_TO_IGNORE``.

    To filter by object type, use the ``object_types`` parameter:

    * ``"optimizer"``: optimizers, e.g., the Hill Climbing optimizer,
      or the Particle Swarm optimizer
    * ``"experiment"``: experiments, e.g., the Ackley function, or an ``sklearn``
      cross-validation experiment
    * if ``None``, no filter is applied and all objects are returned.

    To filter by tag, use the ``filter_tags`` parameter.

    Parameters
    ----------
    object_types : str, list of str, optional (default=None)
        Which kind of objects should be returned.

        * if None, no filter is applied and all objects are returned.
        * if str or list of str, the strings define the object types to search
          for; only objects that are of (at least) one of these types are
          returned. Classes are also accepted, in which case their
          ``object_type`` tag is used.

    return_names : bool, optional (default=True)

        * if True, the object class name is included in the ``all_objects``
          return in the order: name, object class, optional tags, either as
          a tuple or as pandas.DataFrame columns
        * if False, the object class name is removed from the ``all_objects``
          return.

    filter_tags : dict of (str or list of str or re.Pattern), optional (default=None)
        For a list of valid tag strings, use the ``registry.all_tags`` utility.

        ``filter_tags`` subsets the returned objects as follows:

        * each key/value pair is a statement in an "and" conjunction
        * the key is the tag name to subset on
        * the value (a str or list of str) gives the tag values to match
        * the condition is "key must be equal to value, or in set(value)"

        In detail, the return will be filtered to keep exactly the classes
        whose tags satisfy all the filter conditions specified by ``filter_tags``.
        Filter conditions are as follows, for ``tag_name: search_value`` pairs in
        the ``filter_tags`` dict, applied to a class ``klass``:

        - If ``klass`` does not have a tag with name ``tag_name``, it is excluded.
          Otherwise, let ``tag_value`` be the value of the tag with name ``tag_name``.
        - If ``search_value`` is a string, and ``tag_value`` is a string,
          the filter condition is that ``search_value`` must equal ``tag_value``.
        - If ``search_value`` is a string, and ``tag_value`` is a list,
          the filter condition is that ``search_value`` is contained in ``tag_value``.
        - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a string,
          the filter condition is that ``search_value.fullmatch(tag_value)``
          is true, i.e., the regex matches the tag value.
        - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a list,
          the filter condition is that at least one element of ``tag_value``
          matches the regex.
        - If ``search_value`` is iterable, then the filter condition is that
          at least one element of ``search_value`` satisfies the above conditions,
          applied to ``tag_value``.

        Note: ``re.Pattern`` is supported only from ``scikit-base`` version 0.8.0.

    exclude_objects : str, list of str, optional (default=None)
        Names of objects to exclude.

    as_dataframe : bool, optional (default=False)

        * True: ``all_objects`` will return a ``pandas.DataFrame`` with named
          columns for all of the attributes being returned.
        * False: ``all_objects`` will return a list (either a list of
          objects or a list of tuples, see Returns)

    return_tags : str or list of str, optional (default=None)
        Names of tags whose values are fetched and returned for each object.
        For a list of valid tag strings, use the ``registry.all_tags`` utility.
        If str or list of str, the values of the tags named in ``return_tags``
        are fetched for each object and appended as either columns or
        tuple entries.

    suppress_import_stdout : bool, optional (default=True)
        Whether to suppress stdout printout when modules are imported during
        the crawl.

    Returns
    -------
    all_objects will return one of the following:

        1. list of objects, if ``return_names=False`` and ``return_tags`` is None

        2. list of tuples (optional object name, class, optional object
        tags), if ``return_names=True`` or ``return_tags`` is not ``None``.

        3. ``pandas.DataFrame`` if ``as_dataframe=True``

        if list of objects:
            entries are objects matching the query,
            in alphabetical order of object name

        if list of tuples:
            list of (optional object name, object, optional object
            tags) matching the query, in alphabetical order of object name,
            where
            ``name`` is the object name as string, and is an
            optional return;
            ``object`` is the actual object class;
            ``tags`` are the object's values for each tag in ``return_tags``,
            and are an optional return.

        if ``DataFrame``:
            column names represent the attributes contained in each column.
            "objects" will be the name of the column of objects, "names"
            will be the name of the column of object class names and the string(s)
            passed in ``return_tags`` will serve as column names for all columns of
            tags that were optionally requested.

    Examples
    --------
    >>> from hyperactive.registry import all_objects
    >>> # return a complete list of objects as pd.DataFrame
    >>> all_objects(as_dataframe=True)  # doctest: +SKIP
    >>> # return only optimizers, as a list of (name, class) tuples
    >>> all_objects(object_types="optimizer")  # doctest: +SKIP

    References
    ----------
    Adapted version of sktime's ``all_estimators``,
    which is an evolution of scikit-learn's ``all_estimators``
    """
    MODULES_TO_IGNORE = (
        "tests",
        "setup",
        "contrib",
        "utils",
        "all",
    )

    ROOT = str(Path(__file__).parent.parent)  # package root directory

    def _coerce_to_str(obj):
        # recurse into lists/tuples; classes are replaced by their object_type tag
        if isinstance(obj, (list, tuple)):
            return [_coerce_to_str(o) for o in obj]
        if isclass(obj):
            obj = obj.get_tag("object_type")
        return obj

    def _coerce_to_list_of_str(obj):
        # like _coerce_to_str, but a single str is wrapped into a list
        obj = _coerce_to_str(obj)
        if isinstance(obj, str):
            return [obj]
        return obj

    # normalize object_types to a deduplicated list of str
    if object_types is not None:
        object_types = _coerce_to_list_of_str(object_types)
        object_types = list(set(object_types))

    # object_types is passed on as a condition on the "object_type" tag;
    # filter_tags is copied, so the caller's dict is not mutated
    if object_types is not None:
        if filter_tags is None:
            filter_tags = {}
        elif isinstance(filter_tags, str):
            filter_tags = {filter_tags: True}
        else:
            filter_tags = filter_tags.copy()

        # concatenating the lists allows objects of either type (a union),
        # since list-valued tag filters mean "value in set(list)"
        if "object_type" in filter_tags:
            obj_field = filter_tags["object_type"]
            obj_field = _coerce_to_list_of_str(obj_field)
            obj_field = obj_field + object_types
        else:
            obj_field = object_types

        filter_tags["object_type"] = obj_field

    # crawling and filtering are delegated to skbase; any object_types
    # restriction is already encoded in filter_tags, so it is not passed directly
    result = _all_objects(
        object_types=None,
        filter_tags=filter_tags,
        exclude_objects=exclude_objects,
        return_names=return_names,
        as_dataframe=as_dataframe,
        return_tags=return_tags,
        suppress_import_stdout=suppress_import_stdout,
        package_name="hyperactive",
        path=ROOT,
        modules_to_ignore=MODULES_TO_IGNORE,
    )

    return result


def _check_list_of_str_or_error(arg_to_check, arg_name):
    """Check that certain arguments are str or list of str.

    Parameters
    ----------
    arg_to_check : object
        Argument whose type is being checked.
    arg_name : str
        Name of the argument being checked; it is included in the error message
        if ``arg_to_check`` is not a str or a list of str.

    Returns
    -------
    arg_to_check : list of str
        If ``arg_to_check`` was originally a str, it is converted into a list
        of str, so that it can be iterated over.

    Raises
    ------
    TypeError
        If ``arg_to_check`` is not a str or list of str.
    """
    # check that arg_to_check has the right type:
    if isinstance(arg_to_check, str):
        arg_to_check = [arg_to_check]
    if not isinstance(arg_to_check, list) or not all(
        isinstance(value, str) for value in arg_to_check
    ):
        # implicit string concatenation keeps the message free of stray whitespace
        raise TypeError(
            f"Error in all_objects! Argument {arg_name} must be either "
            "a str or list of str"
        )
    return arg_to_check

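The ``filter_tags`` rules documented in the ``all_objects`` docstring can be sketched in plain Python. This is a minimal, hypothetical re-implementation for illustration only: the names ``tag_matches`` and ``passes_filter``, the plain-dict representation of a class's tags, and the equality fallback for non-str, non-iterable values (such as the ``True`` produced by the ``{filter_tags: True}`` shorthand) are assumptions. The actual filtering happens inside ``skbase.lookup.all_objects`` and may differ in details.

```python
import re
from collections.abc import Iterable


def tag_matches(search_value, tag_value):
    """Return True if one tag value satisfies one filter condition.

    Hypothetical sketch of the documented rules, not the skbase code.
    """
    # re.Pattern: full match against a str tag, or against any list element
    if isinstance(search_value, re.Pattern):
        if isinstance(tag_value, str):
            return search_value.fullmatch(tag_value) is not None
        if isinstance(tag_value, list):
            return any(
                isinstance(v, str) and search_value.fullmatch(v) is not None
                for v in tag_value
            )
        return False
    # str: equality with a str tag, membership in a list tag
    if isinstance(search_value, str):
        if isinstance(tag_value, list):
            return search_value in tag_value
        return search_value == tag_value
    # other iterables: at least one element must satisfy the rules above
    if isinstance(search_value, Iterable):
        return any(tag_matches(s, tag_value) for s in search_value)
    # assumption: plain equality for anything else, e.g. bool tag values
    return search_value == tag_value


def passes_filter(tags, filter_tags):
    """Return True if a dict of tags satisfies every condition in filter_tags."""
    # conditions are combined with "and"; a missing tag excludes the class
    return all(
        name in tags and tag_matches(value, tags[name])
        for name, value in filter_tags.items()
    )


# hypothetical tags of a single class
tags = {"object_type": "optimizer", "hypothetical_tag": ["a", "b"]}
print(passes_filter(tags, {"object_type": ["optimizer", "experiment"]}))  # True
print(passes_filter(tags, {"hypothetical_tag": re.compile("b")}))  # True
print(passes_filter(tags, {"object_type": "experiment"}))  # False
print(passes_filter(tags, {"missing_tag": "x"}))  # False
```

Note that a regex must match the whole tag value: ``re.compile("opt")`` does not select ``"optimizer"``, because ``fullmatch`` rather than ``search`` is used.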
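The step in ``all_objects`` that folds ``object_types`` into the ``"object_type"`` entry of ``filter_tags`` can be isolated as a pure function, which makes its behaviour easy to check. The sketch below mirrors that branch but is hypothetical: ``merge_object_types`` is an illustrative name, and it only handles strings (the coercion of classes through their ``object_type`` tag is left out).

```python
def merge_object_types(object_types, filter_tags):
    """Fold object_types into the "object_type" entry of filter_tags.

    Hypothetical standalone version of the merging step in all_objects;
    strings only, no coercion of classes via their object_type tag.
    """
    if object_types is None:
        return filter_tags

    # normalize to a deduplicated list of str
    if isinstance(object_types, str):
        object_types = [object_types]
    object_types = list(set(object_types))

    # copy, so that the caller's dict is not mutated
    if filter_tags is None:
        filter_tags = {}
    elif isinstance(filter_tags, str):
        filter_tags = {filter_tags: True}
    else:
        filter_tags = filter_tags.copy()

    # list-valued filters mean "value in set(list)", so concatenation is a union
    if "object_type" in filter_tags:
        existing = filter_tags["object_type"]
        if isinstance(existing, str):
            existing = [existing]
        filter_tags["object_type"] = list(existing) + object_types
    else:
        filter_tags["object_type"] = object_types
    return filter_tags


user_filter = {"object_type": "experiment"}
merged = merge_object_types("optimizer", user_filter)
print(merged)  # {'object_type': ['experiment', 'optimizer']}
print(user_filter)  # {'object_type': 'experiment'}
```

As in ``all_objects``, the string shorthand for ``filter_tags`` is only expanded when ``object_types`` is given; otherwise ``filter_tags`` is passed through unchanged.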