|
1
|
|
|
"""Registry lookup methods. |
|
2
|
|
|
|
|
3
|
|
|
This module exports the following methods for registry lookup: |
|
4
|
|
|
|
|
5
|
|
|
all_objects(object_types, filter_tags) |
|
6
|
|
|
lookup and filtering of objects |
|
7
|
|
|
""" |
|
8
|
|
|
|
|
9
|
|
|
# based on the sktime module of same name |
|
10
|
|
|
|
|
11
|
|
|
__author__ = ["fkiraly"] |
|
12
|
|
|
# all_objects is based on the sklearn utility all_estimators |
|
13
|
|
|
|
|
14
|
|
|
from inspect import isclass |
|
15
|
|
|
from pathlib import Path |
|
16
|
|
|
|
|
17
|
|
|
from skbase.lookup import all_objects as _all_objects |
|
18
|
|
|
|
|
19
|
|
|
|
|
20
|
|
|
def all_objects(
    object_types=None,
    filter_tags=None,
    exclude_objects=None,
    return_names=True,
    as_dataframe=False,
    return_tags=None,
    suppress_import_stdout=True,
):
    """Get a list of all objects from hyperactive.

    This function crawls the module and gets all classes that inherit
    from skbase compatible base classes.

    Not included are: the base classes themselves, classes defined in test
    modules.

    Parameters
    ----------
    object_types: str, class, or list of str/class, optional (default=None)
        Which kind of objects should be returned.

        * if None, no filter is applied and all objects are returned.
        * if str or list of str, strings define the scitypes searched for;
          only objects that are of (at least) one of the scitypes are returned.
        * classes are replaced by the value of their ``object_type`` tag.

        If ``filter_tags`` also contains an ``object_type`` condition, the two
        are combined: objects matching either ``object_types`` or the
        ``filter_tags["object_type"]`` value are kept.

    return_names: bool, optional (default=True)

        * if True, estimator class name is included in the ``all_objects``
          return in the order: name, estimator class, optional tags, either as
          a tuple or as pandas.DataFrame columns
        * if False, estimator class name is removed from the ``all_objects`` return.

    filter_tags: dict of (str or list of str or re.Pattern), optional (default=None)
        For a list of valid tag strings, use the registry.all_tags utility.
        If ``object_types`` is given and ``filter_tags`` is a str, it is
        converted to ``{filter_tags: True}`` before being combined.

        ``filter_tags`` subsets the returned objects as follows:

        * each key/value pair is statement in "and"/conjunction
        * key is tag name to sub-set on
        * value str or list of string are tag values
        * condition is "key must be equal to value, or in set(value)"

        In detail, the return will be filtered to keep exactly the classes
        where tags satisfy all the filter conditions specified by ``filter_tags``.
        Filter conditions are as follows, for ``tag_name: search_value`` pairs in
        the ``filter_tags`` dict, applied to a class ``klass``:

        - If ``klass`` does not have a tag with name ``tag_name``, it is excluded.
          Otherwise, let ``tag_value`` be the value of the tag with name ``tag_name``.
        - If ``search_value`` is a string, and ``tag_value`` is a string,
          the filter condition is that ``search_value`` must match the tag value.
        - If ``search_value`` is a string, and ``tag_value`` is a list,
          the filter condition is that ``search_value`` is contained in ``tag_value``.
        - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a string,
          the filter condition is that ``search_value.fullmatch(tag_value)``
          is true, i.e., the regex matches the tag value.
        - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a list,
          the filter condition is that at least one element of ``tag_value``
          matches the regex.
        - If ``search_value`` is iterable, then the filter condition is that
          at least one element of ``search_value`` satisfies the above conditions,
          applied to ``tag_value``.

        Note: ``re.Pattern`` is supported only from ``scikit-base`` version 0.8.0.

    exclude_objects: str, list of str, optional (default=None)
        Names of objects to exclude.

    as_dataframe: bool, optional (default=False)

        * True: ``all_objects`` will return a ``pandas.DataFrame`` with named
          columns for all of the attributes being returned.
        * False: ``all_objects`` will return a list (either a list of
          objects or a list of tuples, see Returns)

    return_tags: str or list of str, optional (default=None)
        Names of tags to fetch and return each estimator's value of.
        For a list of valid tag strings, use the ``registry.all_tags`` utility.
        if str or list of str,
        the tag values named in return_tags will be fetched for each
        estimator and will be appended as either columns or tuple entries.

    suppress_import_stdout : bool, optional. Default=True
        whether to suppress stdout printout upon import.

    Returns
    -------
    all_objects will return one of the following:

        1. list of objects, if ``return_names=False``, and ``return_tags`` is None

        2. list of tuples (optional estimator name, class, optional estimator
           tags), if ``return_names=True`` or ``return_tags`` is not ``None``.

        3. ``pandas.DataFrame`` if ``as_dataframe = True``

        if list of objects:
            entries are objects matching the query,
            in alphabetical order of estimator name

        if list of tuples:
            list of (optional estimator name, estimator, optional estimator
            tags) matching the query, in alphabetical order of estimator name,
            where
            ``name`` is the estimator name as string, and is an
            optional return
            ``estimator`` is the actual estimator
            ``tags`` are the estimator's values for each tag in return_tags
            and is an optional return.

        if ``DataFrame``:
            column names represent the attributes contained in each column.
            "objects" will be the name of the column of objects, "names"
            will be the name of the column of estimator class names and the string(s)
            passed in return_tags will serve as column names for all columns of
            tags that were optionally requested.

    Examples
    --------
    >>> from hyperactive._registry import all_objects
    >>> # return a complete list of objects as pd.Dataframe
    >>> all_objects(as_dataframe=True)  # doctest: +SKIP

    References
    ----------
    Adapted version of sktime's ``all_estimators``,
    which is an evolution of scikit-learn's ``all_estimators``
    """
    # modules (sub-package names) that the crawler skips
    MODULES_TO_IGNORE = (
        "tests",
        "setup",
        "contrib",
        "utils",
        "all",
    )

    ROOT = str(Path(__file__).parent.parent)  # package root directory

    def _coerce_to_str(obj):
        """Replace classes by their ``object_type`` tag, recursing into collections."""
        if isinstance(obj, (list, tuple, set, frozenset)):
            return [_coerce_to_str(o) for o in obj]
        if isclass(obj):
            obj = obj.get_tag("object_type")
        return obj

    def _coerce_to_list_of_str(obj):
        """Coerce ``obj`` to a flat list of object type specifiers.

        Scalars (str, ``re.Pattern``, ...) are wrapped in a one-element list so
        that list concatenation below always works; nested lists, e.g., from a
        class whose ``object_type`` tag is itself a list, are flattened.
        """
        obj = _coerce_to_str(obj)
        if not isinstance(obj, list):
            return [obj]
        flat = []
        for item in obj:
            flat.extend(_coerce_to_list_of_str(item))
        return flat

    if object_types is not None:
        object_types = _coerce_to_list_of_str(object_types)
        # deduplicate; dict.fromkeys keeps a deterministic (first-seen) order
        object_types = list(dict.fromkeys(object_types))

        if filter_tags is None:
            filter_tags = {}
        elif isinstance(filter_tags, str):
            filter_tags = {filter_tags: True}
        else:
            # copy so the caller's dict is not mutated by the update below
            filter_tags = filter_tags.copy()

        # an existing object_type condition is merged (union) with object_types,
        # since skbase keeps a class if any element of an iterable value matches
        if "object_type" in filter_tags:
            obj_field = _coerce_to_list_of_str(filter_tags["object_type"])
            obj_field = obj_field + object_types
        else:
            obj_field = object_types

        filter_tags["object_type"] = obj_field

    # object type filtering has been folded into filter_tags above
    result = _all_objects(
        object_types=None,
        filter_tags=filter_tags,
        exclude_objects=exclude_objects,
        return_names=return_names,
        as_dataframe=as_dataframe,
        return_tags=return_tags,
        suppress_import_stdout=suppress_import_stdout,
        package_name="hyperactive",
        path=ROOT,
        modules_to_ignore=MODULES_TO_IGNORE,
    )

    return result
|
208
|
|
|
|
|
209
|
|
|
|
|
210
|
|
|
def _check_list_of_str_or_error(arg_to_check, arg_name): |
|
211
|
|
|
"""Check that certain arguments are str or list of str. |
|
212
|
|
|
|
|
213
|
|
|
Parameters |
|
214
|
|
|
---------- |
|
215
|
|
|
arg_to_check: argument we are testing the type of |
|
216
|
|
|
arg_name: str, |
|
217
|
|
|
name of the argument we are testing, will be added to the error if |
|
218
|
|
|
``arg_to_check`` is not a str or a list of str |
|
219
|
|
|
|
|
220
|
|
|
Returns |
|
221
|
|
|
------- |
|
222
|
|
|
arg_to_check: list of str, |
|
223
|
|
|
if arg_to_check was originally a str it converts it into a list of str |
|
224
|
|
|
so that it can be iterated over. |
|
225
|
|
|
|
|
226
|
|
|
Raises |
|
227
|
|
|
------ |
|
228
|
|
|
TypeError if arg_to_check is not a str or list of str |
|
229
|
|
|
""" |
|
230
|
|
|
# check that return_tags has the right type: |
|
231
|
|
|
if isinstance(arg_to_check, str): |
|
232
|
|
|
arg_to_check = [arg_to_check] |
|
233
|
|
|
if not isinstance(arg_to_check, list) or not all( |
|
234
|
|
|
isinstance(value, str) for value in arg_to_check |
|
235
|
|
|
): |
|
236
|
|
|
raise TypeError( |
|
237
|
|
|
f"Error in all_objects! Argument {arg_name} must be either\ |
|
238
|
|
|
a str or list of str" |
|
239
|
|
|
) |
|
240
|
|
|
return arg_to_check |
|
241
|
|
|
|