"""Registry lookup methods.

This module exports the following functions for registry lookup:

all_objects(object_types, filter_tags, ...)
    lookup and filtering of hyperactive objects, with optional tag-based
    filtering and optional return of names, tags, or a pandas.DataFrame
"""

# based on the sktime module of same name

__author__ = ["fkiraly"]
# all_objects is based on the sklearn utility all_estimators (via sktime's
# all_estimators); the package crawling itself is delegated to skbase

from inspect import isclass
from pathlib import Path

from skbase.lookup import all_objects as _all_objects


def all_objects(
    object_types=None,
    filter_tags=None,
    exclude_objects=None,
    return_names=True,
    as_dataframe=False,
    return_tags=None,
    suppress_import_stdout=True,
):
    """Get a list of all objects from hyperactive.

    This function crawls the ``hyperactive`` package and gets all classes that
    inherit from skbase compatible base classes.

    Not included are: the base classes themselves, and classes found in modules
    named ``tests``, ``setup``, ``contrib``, ``utils`` or ``all`` (these names
    are passed to skbase as ``modules_to_ignore``).

    Parameters
    ----------
    object_types : str, class, or list of (str or class), optional (default=None)
        Which kind of objects should be returned.

        * if None, no filter is applied and all objects are returned.
        * if str or list of str, strings define the object types (scitypes) to
          search for; only objects that are of (at least) one of the object
          types are returned.
        * a class may be passed in place of a str; it is replaced by the value
          of its ``object_type`` tag.

        If ``filter_tags`` also contains an ``"object_type"`` entry, the two are
        combined: objects matching either of them are returned.

    filter_tags : dict of (str or list of str or re.Pattern), optional (default=None)
        For a list of valid tag strings, use the registry.all_tags utility.

        ``filter_tags`` subsets the returned objects as follows:

        * each key/value pair is a statement in "and"/conjunction
        * key is the tag name to sub-set on
        * value str or list of str are tag values
        * condition is "key must be equal to value, or in set(value)"

        In detail, the return will be filtered to keep exactly the classes
        where tags satisfy all the filter conditions specified by
        ``filter_tags``. Filter conditions are as follows, for
        ``tag_name: search_value`` pairs in the ``filter_tags`` dict, applied
        to a class ``klass``:

        - If ``klass`` does not have a tag with name ``tag_name``, it is
          excluded. Otherwise, let ``tag_value`` be the value of the tag with
          name ``tag_name``.
        - If ``search_value`` is a string, and ``tag_value`` is a string,
          the filter condition is that ``search_value`` must match the tag
          value.
        - If ``search_value`` is a string, and ``tag_value`` is a list,
          the filter condition is that ``search_value`` is contained in
          ``tag_value``.
        - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a
          string, the filter condition is that
          ``search_value.fullmatch(tag_value)`` is true, i.e., the regex
          matches the tag value.
        - If ``search_value`` is a ``re.Pattern``, and ``tag_value`` is a
          list, the filter condition is that at least one element of
          ``tag_value`` matches the regex.
        - If ``search_value`` is iterable, then the filter condition is that
          at least one element of ``search_value`` satisfies the above
          conditions, applied to ``tag_value``.

        When ``object_types`` is given, a single str ``tag_name`` is also
        accepted and treated as ``{tag_name: True}``.

        Note: ``re.Pattern`` is supported only from ``scikit-base`` version
        0.8.0.

    exclude_objects : str, list of str, optional (default=None)
        Names of objects to exclude.

    return_names : bool, optional (default=True)

        * if True, estimator class name is included in the ``all_objects``
          return in the order: name, estimator class, optional tags, either as
          a tuple or as pandas.DataFrame columns
        * if False, estimator class name is removed from the ``all_objects``
          return.

    as_dataframe : bool, optional (default=False)

        * True: ``all_objects`` will return a ``pandas.DataFrame`` with named
          columns for all of the attributes being returned.
        * False: ``all_objects`` will return a list (either a list of
          objects or a list of tuples, see Returns)

    return_tags : str or list of str, optional (default=None)
        Names of tags to fetch and return each estimator's value of.
        For a list of valid tag strings, use the ``registry.all_tags`` utility.
        If str or list of str, the tag values named in ``return_tags`` will be
        fetched for each estimator and will be appended as either columns or
        tuple entries.

    suppress_import_stdout : bool, optional (default=True)
        Whether to suppress stdout printout upon import.

    Returns
    -------
    all_objects will return one of the following:

    1. list of objects, if ``return_names=False``, and ``return_tags`` is None

    2. list of tuples (optional estimator name, class, optional estimator
       tags), if ``return_names=True`` or ``return_tags`` is not ``None``.

    3. ``pandas.DataFrame`` if ``as_dataframe = True``

    if list of objects:
        entries are objects matching the query,
        in alphabetical order of estimator name

    if list of tuples:
        list of (optional estimator name, estimator, optional estimator
        tags) matching the query, in alphabetical order of estimator name,
        where
        ``name`` is the estimator name as string, and is an
        optional return
        ``estimator`` is the actual estimator
        ``tags`` are the estimator's values for each tag in return_tags
        and is an optional return.

    if ``DataFrame``:
        column names represent the attributes contained in each column.
        "objects" will be the name of the column of objects, "names"
        will be the name of the column of estimator class names and the
        string(s) passed in return_tags will serve as column names for all
        columns of tags that were optionally requested.

    Examples
    --------
    >>> from hyperactive._registry import all_objects
    >>> # return a complete list of objects as pd.Dataframe
    >>> all_objects(as_dataframe=True)  # doctest: +SKIP

    References
    ----------
    Adapted version of sktime's ``all_estimators``,
    which is an evolution of scikit-learn's ``all_estimators``
    """
    MODULES_TO_IGNORE = (
        "tests",
        "setup",
        "contrib",
        "utils",
        "all",
    )

    ROOT = str(Path(__file__).parent.parent)  # package root directory

    def _coerce_to_list(obj):
        """Coerce an object type specification to a flat list without duplicates.

        Scalars (str, ``re.Pattern``, class) are wrapped in a list; lists,
        tuples and sets are flattened one level. Classes are replaced by the
        value of their ``object_type`` tag, which may itself be a str or list.
        """
        items = obj if isinstance(obj, (list, tuple, set, frozenset)) else [obj]
        flat = []
        for item in items:
            if isclass(item):
                item = item.get_tag("object_type")
            # a class's tag may be a list; flatten so the result is hashable
            if isinstance(item, (list, tuple)):
                flat.extend(item)
            else:
                flat.append(item)
        # dict.fromkeys removes duplicates but, unlike set, keeps a stable order
        return list(dict.fromkeys(flat))

    if object_types is not None:
        object_types = _coerce_to_list(object_types)

        # work on a copy so that the caller's filter_tags dict is not mutated
        if filter_tags is None:
            filter_tags = {}
        elif isinstance(filter_tags, str):
            filter_tags = {filter_tags: True}
        else:
            filter_tags = filter_tags.copy()

        # a list of search values matches if any element matches, so merging
        # object_types into an existing "object_type" filter keeps objects that
        # match either of them
        if "object_type" in filter_tags:
            obj_field = _coerce_to_list(filter_tags["object_type"]) + object_types
            obj_field = list(dict.fromkeys(obj_field))
        else:
            obj_field = object_types

        filter_tags["object_type"] = obj_field

    # object type filtering is passed via the "object_type" entry of
    # filter_tags built above, hence object_types=None here
    return _all_objects(
        object_types=None,
        filter_tags=filter_tags,
        exclude_objects=exclude_objects,
        return_names=return_names,
        as_dataframe=as_dataframe,
        return_tags=return_tags,
        suppress_import_stdout=suppress_import_stdout,
        package_name="hyperactive",
        path=ROOT,
        modules_to_ignore=MODULES_TO_IGNORE,
    )


def _check_list_of_str_or_error(arg_to_check, arg_name): |
211
|
|
|
"""Check that certain arguments are str or list of str. |
212
|
|
|
|
213
|
|
|
Parameters |
214
|
|
|
---------- |
215
|
|
|
arg_to_check: argument we are testing the type of |
216
|
|
|
arg_name: str, |
217
|
|
|
name of the argument we are testing, will be added to the error if |
218
|
|
|
``arg_to_check`` is not a str or a list of str |
219
|
|
|
|
220
|
|
|
Returns |
221
|
|
|
------- |
222
|
|
|
arg_to_check: list of str, |
223
|
|
|
if arg_to_check was originally a str it converts it into a list of str |
224
|
|
|
so that it can be iterated over. |
225
|
|
|
|
226
|
|
|
Raises |
227
|
|
|
------ |
228
|
|
|
TypeError if arg_to_check is not a str or list of str |
229
|
|
|
""" |
230
|
|
|
# check that return_tags has the right type: |
231
|
|
|
if isinstance(arg_to_check, str): |
232
|
|
|
arg_to_check = [arg_to_check] |
233
|
|
|
if not isinstance(arg_to_check, list) or not all( |
234
|
|
|
isinstance(value, str) for value in arg_to_check |
235
|
|
|
): |
236
|
|
|
raise TypeError( |
237
|
|
|
f"Error in all_objects! Argument {arg_name} must be either\ |
238
|
|
|
a str or list of str" |
239
|
|
|
) |
240
|
|
|
return arg_to_check |
241
|
|
|
|