1 | from datetime import datetime, timedelta |
||
2 | from typing import ( |
||
3 | Any, |
||
4 | Dict, |
||
5 | Type, |
||
6 | ) |
||
7 | |||
8 | import numpy |
||
9 | from typish import ClsFunction |
||
10 | |||
11 | from nptyping.functions._py_type import py_type |
||
12 | from nptyping.types._bool import Bool |
||
13 | from nptyping.types._complex import Complex128 |
||
14 | from nptyping.types._datetime64 import Datetime64 |
||
15 | from nptyping.types._ndarray import NDArray |
||
16 | from nptyping.types._nptype import NPType |
||
17 | from nptyping.types._number import ( |
||
18 | DEFAULT_FLOAT_BITS, |
||
19 | DEFAULT_INT_BITS, |
||
20 | Float, |
||
21 | Int, |
||
22 | Number, |
||
23 | UInt, |
||
24 | ) |
||
25 | from nptyping.types._object import Object |
||
26 | from nptyping.types._structured_type import StructuredType, is_structured_type |
||
27 | from nptyping.types._subarray_type import SubArrayType, is_subarray_type |
||
28 | from nptyping.types._timedelta64 import Timedelta64 |
||
29 | from nptyping.types._unicode import Unicode |
||
30 | |||
31 | |||
def get_type(obj: Any) -> Type['NPType']:
    """
    Determine the nptyping type that describes obj.

    obj may be a numpy ndarray, a numpy dtype or a Python type; the actual
    resolution is delegated to the handler registered for obj's class in the
    module-level _delegates table, dispatched through a typish ClsFunction.
    :param obj: the object to find an nptyping type for.
    :return: a subclass of NPType.
    :raises TypeError: if no corresponding nptyping type can be determined.
    """
    dispatch = ClsFunction(_delegates)
    return dispatch(obj)
||
41 | |||
42 | |||
def _get_type_type(type_: type) -> Type['NPType']:
    # Return the nptyping type of a type: hand type_ to the delegate of the
    # first _delegates entry whose key is a superclass of type_.
    #
    # The entry that points back at this very function (keyed on `type`) is
    # skipped. For a metaclass such as `type` itself, issubclass(type_, type)
    # holds, and following that entry would call _get_type_type(type_) again
    # with the same argument, recursing until RecursionError. Skipping it lets
    # the later catch-all entry handle such classes instead.
    for super_type, delegate in _delegates:
        if delegate is not _get_type_type and issubclass(type_, super_type):
            return delegate(type_)
    # Only reachable if _delegates lacks a catch-all entry.
    raise TypeError('Unsupported type {}'.format(type_))
|||
49 | |||
50 | |||
def _get_type_dtype(dtype: numpy.dtype) -> Type['NPType']:
    # Return the nptyping type of a numpy dtype.
    #
    # Subarray and structured dtypes have dedicated nptyping types. Unsigned
    # integer dtypes are routed to get_type_uint. Every other dtype is
    # dispatched on the Python type that py_type reports for it. A KeyError
    # propagates if that Python type is not one of the table's keys.
    if is_subarray_type(dtype):
        return get_subarray_type(dtype)
    if is_structured_type(dtype):
        return get_structured_type(dtype)
    if numpy.issubdtype(dtype, numpy.unsignedinteger):
        # The py_type table below has no unsigned entry. An unsigned dtype
        # could at best reach get_type_int, whose table lists only signed
        # numpy types, so it would be rejected. Route it explicitly instead.
        return get_type_uint(dtype)
    np_type_per_py_type = {
        type: _get_type_type,
        bool: get_type_bool,
        int: get_type_int,
        float: get_type_float,
        str: get_type_str,
        complex: get_type_complex,
        datetime: get_type_datetime64,
        timedelta: get_type_timedelta64,
        object: lambda _: Object,
    }
    return np_type_per_py_type[py_type(dtype)](dtype)
||
69 | |||
70 | |||
def _get_type_arrary(arr: numpy.ndarray) -> Type['NPType']:
    # Describe a numpy array as an NDArray type: its shape paired with the
    # nptyping type of its dtype.
    shape, element_type = arr.shape, get_type(arr.dtype)
    return NDArray[shape, element_type]
||
75 | |||
76 | |||
77 | View Code Duplication | def _get_type_of_number( |
|
0 ignored issues
–
show
|
|||
78 | cls: Type['Number'], |
||
79 | obj: Any, |
||
80 | bits_per_type: Dict[type, int]) -> Type[Number]: |
||
81 | # Return the nptyping Number type of the given obj using cls and |
||
82 | # bits_per_type. |
||
83 | bits = (bits_per_type.get(obj) |
||
84 | or bits_per_type.get(getattr(obj, 'type', None)) |
||
85 | or bits_per_type.get(type(obj))) |
||
86 | |||
87 | if not bits: |
||
88 | raise TypeError('Unsupported type {} for {}' |
||
89 | .format(type(obj).__name__, cls)) |
||
90 | |||
91 | return cls[bits] |
||
92 | |||
93 | |||
# Library private.
def get_type_bool(_: Any) -> Type[Bool]:
    """
    Map any bool compatible object onto the nptyping Bool type.
    :param _: a bool compatible object; its value is never inspected.
    :return: the Bool type, whatever the argument.
    """
    return Bool
||
102 | |||
103 | |||
# Library private.
def get_type_str(obj: Any) -> Type[Unicode]:
    """
    Return the NPType that corresponds to obj.
    :param obj: a string compatible object: a numpy dtype, the str type
        itself, or a str instance.
    :return: a Unicode type. For a dtype or a str instance it is
        parameterized with the number of characters; for the str type it is
        the bare Unicode type.
    :raises TypeError: if obj is none of the above.
    """
    if isinstance(obj, numpy.dtype):
        # NumPy unicode ('U') dtypes use 4 bytes per character. Integer
        # division keeps the character count an int (itemsize / 4 would give
        # e.g. 5.0 and parameterize Unicode with a float).
        return Unicode[obj.itemsize // 4]
    if obj is str:
        return Unicode
    if not isinstance(obj, str):
        raise TypeError('Unsupported type {}'.format(type(obj)))
    return Unicode[len(obj)]
||
118 | |||
119 | |||
# Library private.
def get_type_int(obj: Any) -> Type[Int]:
    """
    Resolve obj to a signed nptyping Int type of the matching bit width.
    :param obj: an int compatible object, e.g. a numpy signed integer type
        or dtype, or a Python int (which maps to DEFAULT_INT_BITS).
    :return: an Int type.
    """
    widths = {
        numpy.int8: 8,
        numpy.int16: 16,
        numpy.int32: 32,
        numpy.int64: 64,
        int: DEFAULT_INT_BITS,
    }
    return _get_type_of_number(Int, obj, widths)
||
134 | |||
135 | |||
# Library private.
def get_type_uint(obj: Any) -> Type[UInt]:
    """
    Resolve obj to an unsigned nptyping UInt type of the matching bit width.
    :param obj: an uint compatible object, e.g. a numpy unsigned integer type
        or dtype, or a Python int (which maps to DEFAULT_INT_BITS).
    :return: an UInt type.
    """
    widths = {
        numpy.uint8: 8,
        numpy.uint16: 16,
        numpy.uint32: 32,
        numpy.uint64: 64,
        int: DEFAULT_INT_BITS,
    }
    return _get_type_of_number(UInt, obj, widths)
||
150 | |||
151 | |||
# Library private.
def get_type_float(obj: Any) -> Type[Float]:
    """
    Resolve obj to an nptyping Float type of the matching bit width.
    :param obj: a float compatible object, e.g. a numpy floating type or
        dtype, or a Python float (which maps to DEFAULT_FLOAT_BITS).
    :return: a Float type.
    """
    widths = {
        numpy.float16: 16,
        numpy.float32: 32,
        numpy.float64: 64,
        float: DEFAULT_FLOAT_BITS,
    }
    return _get_type_of_number(Float, obj, widths)
||
165 | |||
166 | |||
# Library private.
def get_type_datetime64(_: Any) -> Type[Datetime64]:
    """
    Map any datetime compatible object onto the nptyping Datetime64 type.
    :param _: a datetime compatible object; it is not inspected.
    :return: the Datetime64 type, whatever the argument.
    """
    return Datetime64
||
175 | |||
176 | |||
# Library private.
def get_type_timedelta64(_: Any) -> Type[Timedelta64]:
    """
    Map any timedelta compatible object onto the nptyping Timedelta64 type.
    :param _: a timedelta compatible object; it is not inspected.
    :return: the Timedelta64 type, whatever the argument.
    """
    return Timedelta64
||
185 | |||
186 | |||
# Library private.
def get_type_complex(_: Any) -> Type[Complex128]:
    """
    Map any complex compatible object onto the nptyping Complex128 type.
    :param _: a complex128 compatible object; it is not inspected.
    :return: the Complex128 type, whatever the argument.
    """
    return Complex128
||
195 | |||
196 | |||
# Library private.
def get_structured_type(dtype: numpy.dtype) -> Type[StructuredType]:
    """
    Wrap the dtype of a structured NumPy array in a StructuredType.
    :param dtype: the dtype of a structured NumPy array.
    :return: StructuredType parameterized with that dtype.
    """
    return StructuredType[dtype]
||
205 | |||
206 | |||
# Library private.
def get_subarray_type(dtype: numpy.dtype) -> Type[SubArrayType]:
    """
    Wrap the dtype of a NumPy subarray in a SubArrayType.
    :param dtype: the dtype of a NumPy subarray.
    :return: SubArrayType parameterized with that dtype.
    """
    return SubArrayType[dtype]
||
215 | |||
216 | |||
# Dispatch table pairing a (super)type with the function that produces the
# corresponding nptyping type. get_type wraps it in a typish ClsFunction, and
# _get_type_type scans it front to back, taking the first entry for which
# issubclass(type_, super_type) holds. ORDER THEREFORE MATTERS:
#  - bool must precede int, because bool is a subclass of int;
#  - (object, ...) must stay last as the catch-all, since every class is a
#    subclass of object.
# NOTE(review): presumably ClsFunction also resolves by first matching entry
# in list order — confirm against the typish documentation.
_delegates = [
    # Values that already are nptyping types are returned unchanged.
    (NPType, lambda x: x),
    # Classes are resolved by looking up their own superclasses in this table.
    (type, _get_type_type),
    (bool, get_type_bool),
    (int, get_type_int),
    (float, get_type_float),
    (str, get_type_str),
    (complex, get_type_complex),
    (datetime, get_type_datetime64),
    (timedelta, get_type_timedelta64),
    (numpy.datetime64, get_type_datetime64),
    (numpy.timedelta64, get_type_timedelta64),
    (numpy.signedinteger, get_type_int),
    (numpy.unsignedinteger, get_type_uint),
    (numpy.floating, get_type_float),
    (numpy.bool_, get_type_bool),
    (numpy.dtype, _get_type_dtype),
    (numpy.ndarray, _get_type_arrary),
    # Catch-all: anything not matched above maps to Object.
    (object, lambda _: Object),
]
||
237 |