hyperactive._registry._lookup.all_objects()   C
last analyzed

Complexity

Conditions 9

Size

Total Lines 188
Code Lines 54

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
eloc 54
dl 0
loc 188
rs 6.1721
c 0
b 0
f 0
cc 9
nop 7

How to fix   Long Method   

Long Method

Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.

Commonly applied refactorings include:

1
"""Registry lookup methods.
2
3
This module exports the following methods for registry lookup:
4
5
all_objects(object_types, filter_tags)
6
    lookup and filtering of objects
7
"""
8
9
# based on the sktime module of same name
10
11
__author__ = ["fkiraly"]
12
# all_objects is based on the sklearn utility all_estimators
13
14
from inspect import isclass
15
from pathlib import Path
16
17
from skbase.lookup import all_objects as _all_objects
18
19
20
def _coerce_to_str(obj):
    """Replace classes by their ``object_type`` tag, recursing into lists/tuples.

    Anything that is neither a list/tuple nor a class is returned unchanged.
    """
    if isinstance(obj, (list, tuple)):
        return [_coerce_to_str(o) for o in obj]
    if isclass(obj):
        # Tags of a class must be read through the class-level accessor;
        # ``get_tag`` is an instance method on scikit-base objects and cannot
        # be called on the class itself.
        return obj.get_class_tag("object_type")
    return obj


def _coerce_to_list_of_str(obj):
    """Coerce via ``_coerce_to_str`` and wrap a single str into a list."""
    obj = _coerce_to_str(obj)
    if isinstance(obj, str):
        return [obj]
    return obj


def all_objects(
    object_types=None,
    filter_tags=None,
    exclude_objects=None,
    return_names=True,
    as_dataframe=False,
    return_tags=None,
    suppress_import_stdout=True,
):
    """Get a list of all objects from hyperactive.

    This function crawls the module and gets all classes that inherit
    from skbase compatible base classes.

    Not included are: the base classes themselves, classes defined in test
    modules.

    Parameters
    ----------
    object_types: str, list of str, optional (default=None)
        Which kind of objects should be returned.

        * if None, no filter is applied and all objects are returned.
        * if str or list of str, strings define scitypes specified in search
          only objects that are of (at least) one of the scitypes are returned

        Classes are also accepted in place of strings; a class is replaced
        by the value of its ``object_type`` tag.

        The resulting object types are merged into the ``"object_type"``
        entry of ``filter_tags`` before the lookup is carried out.

    return_names: bool, optional (default=True)

        * if True, estimator class name is included in the ``all_objects``
          return in the order: name, estimator class, optional tags, either as
          a tuple or as pandas.DataFrame columns
        * if False, estimator class name is removed from the ``all_objects`` return.

    filter_tags: dict of (str or list of str or re.Pattern), optional (default=None)
        For a list of valid tag strings, use the registry.all_tags utility.

        ``filter_tags`` subsets the returned objects as follows:

        * each key/value pair is a statement in "and"/conjunction
        * key is tag name to sub-set on
        * value str or list of string are tag values
        * condition is "key must be equal to value, or in set(value)"

        In detail, the return will be filtered to keep exactly the classes
        where tags satisfy all the filter conditions specified by ``filter_tags``.
        Filter conditions are as follows, for ``tag_name: search_value`` pairs in
        the ``filter_tags`` dict, applied to a class ``klass``:

        - If ``klass`` does not have a tag with name ``tag_name``, it is excluded.
          Otherwise, let ``tag_value`` be the value of the tag with name ``tag_name``.
        - If ``search_value`` is a string, and ``tag_value`` is a string,
          the filter condition is that ``search_value`` must match the tag value.
        - If ``search_value`` is a string, and ``tag_value`` is a list,
          the filter condition is that ``search_value`` is contained in ``tag_value``.
        - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a string,
          the filter condition is that ``search_value.fullmatch(tag_value)``
          is true, i.e., the regex matches the tag value.
        - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a list,
          the filter condition is that at least one element of ``tag_value``
          matches the regex.
        - If ``search_value`` is iterable, then the filter condition is that
          at least one element of ``search_value`` satisfies the above conditions,
          applied to ``tag_value``.

        Note: ``re.Pattern`` is supported only from ``scikit-base`` version 0.8.0.

        The passed dict is never modified; a copy is made when
        ``object_types`` has to be merged into it.

    exclude_objects: str, list of str, optional (default=None)
        Names of objects to exclude.

    as_dataframe: bool, optional (default=False)

        * True: ``all_objects`` will return a ``pandas.DataFrame`` with named
          columns for all of the attributes being returned.
        * False: ``all_objects`` will return a list (either a list of
          objects or a list of tuples, see Returns)

    return_tags: str or list of str, optional (default=None)
        Names of tags to fetch and return each estimator's value of.
        For a list of valid tag strings, use the ``registry.all_tags`` utility.
        if str or list of str,
        the tag values named in return_tags will be fetched for each
        estimator and will be appended as either columns or tuple entries.

    suppress_import_stdout : bool, optional. Default=True
        whether to suppress stdout printout upon import.

    Returns
    -------
    all_objects will return one of the following:

        1. list of objects, if ``return_names=False``, and ``return_tags`` is None

        2. list of tuples (optional estimator name, class, optional estimator
        tags), if ``return_names=True`` or ``return_tags`` is not ``None``.

        3. ``pandas.DataFrame`` if ``as_dataframe = True``

        if list of objects:
            entries are objects matching the query,
            in alphabetical order of estimator name

        if list of tuples:
            list of (optional estimator name, estimator, optional estimator
            tags) matching the query, in alphabetical order of estimator name,
            where
            ``name`` is the estimator name as string, and is an
            optional return
            ``estimator`` is the actual estimator
            ``tags`` are the estimator's values for each tag in return_tags
            and is an optional return.

        if ``DataFrame``:
            column names represent the attributes contained in each column.
            "objects" will be the name of the column of objects, "names"
            will be the name of the column of estimator class names and the string(s)
            passed in return_tags will serve as column names for all columns of
            tags that were optionally requested.

    Examples
    --------
    >>> from hyperactive._registry import all_objects
    >>> # return a complete list of objects as pd.DataFrame
    >>> all_objects(as_dataframe=True)  # doctest: +SKIP

    References
    ----------
    Adapted version of sktime's ``all_estimators``,
    which is an evolution of scikit-learn's ``all_estimators``
    """
    MODULES_TO_IGNORE = (
        "tests",
        "setup",
        "contrib",
        "utils",
        "all",
    )

    ROOT = str(Path(__file__).parent.parent)  # package root directory

    if object_types is not None:
        # normalise to a de-duplicated list of object type strings
        object_types = list(set(_coerce_to_list_of_str(object_types)))

        # work on a private dict so the caller's filter_tags is left untouched
        if filter_tags is None:
            filter_tags = {}
        elif isinstance(filter_tags, str):
            filter_tags = {filter_tags: True}
        else:
            filter_tags = filter_tags.copy()

        # object_types is expressed as a filter on the "object_type" tag;
        # an "object_type" filter that is already present is extended
        if "object_type" in filter_tags:
            obj_field = _coerce_to_list_of_str(filter_tags["object_type"])
            obj_field = obj_field + object_types
        else:
            obj_field = object_types

        filter_tags["object_type"] = obj_field

    # object_types is passed as None because the type filter has already been
    # folded into filter_tags above
    return _all_objects(
        object_types=None,
        filter_tags=filter_tags,
        exclude_objects=exclude_objects,
        return_names=return_names,
        as_dataframe=as_dataframe,
        return_tags=return_tags,
        suppress_import_stdout=suppress_import_stdout,
        package_name="hyperactive",
        path=ROOT,
        modules_to_ignore=MODULES_TO_IGNORE,
    )
208
209
210
def _check_list_of_str_or_error(arg_to_check, arg_name):
211
    """Check that certain arguments are str or list of str.
212
213
    Parameters
214
    ----------
215
    arg_to_check: argument we are testing the type of
216
    arg_name: str,
217
        name of the argument we are testing, will be added to the error if
218
        ``arg_to_check`` is not a str or a list of str
219
220
    Returns
221
    -------
222
    arg_to_check: list of str,
223
        if arg_to_check was originally a str it converts it into a list of str
224
        so that it can be iterated over.
225
226
    Raises
227
    ------
228
    TypeError if arg_to_check is not a str or list of str
229
    """
230
    # check that return_tags has the right type:
231
    if isinstance(arg_to_check, str):
232
        arg_to_check = [arg_to_check]
233
    if not isinstance(arg_to_check, list) or not all(
234
        isinstance(value, str) for value in arg_to_check
235
    ):
236
        raise TypeError(
237
            f"Error in all_objects!  Argument {arg_name} must be either\
238
             a str or list of str"
239
        )
240
    return arg_to_check
241