1
|
|
|
from datetime import datetime, timedelta |
2
|
|
|
from typing import Any, Type, Dict |
3
|
|
|
|
4
|
|
|
import numpy |
5
|
|
|
from typish import ClsFunction |
6
|
|
|
|
7
|
|
|
from nptyping.functions._py_type import py_type |
8
|
|
|
from nptyping.types._bool import Bool |
9
|
|
|
from nptyping.types._datetime64 import Datetime64 |
10
|
|
|
from nptyping.types._ndarray import NDArray |
11
|
|
|
from nptyping.types._nptype import NPType |
12
|
|
|
from nptyping.types._number import ( |
13
|
|
|
Int, |
14
|
|
|
Float, |
15
|
|
|
UInt, |
16
|
|
|
Number, |
17
|
|
|
DEFAULT_INT_BITS, |
18
|
|
|
DEFAULT_FLOAT_BITS |
19
|
|
|
) |
20
|
|
|
from nptyping.types._object import Object |
21
|
|
|
from nptyping.types._timedelta64 import Timedelta64 |
22
|
|
|
from nptyping.types._unicode import Unicode |
23
|
|
|
|
24
|
|
|
|
25
|
|
|
def get_type(obj: Any) -> Type['NPType']:
    """
    Determine the nptyping type that belongs to the given obj.

    The work is handed to a typish ``ClsFunction`` built from the
    module-level ``_delegates`` table; that callable is invoked with obj and
    its result is returned unchanged. The given obj can be a numpy ndarray, a
    dtype or a Python type. If no corresponding nptyping type can be
    determined, a TypeError is raised.
    :param obj: the object for which an nptyping type is to be returned.
    :return: a subclass of NPType.
    """
    dispatcher = ClsFunction(_delegates)
    return dispatcher(obj)
34
|
|
|
|
35
|
|
|
|
36
|
|
|
def _get_type_type(type_: type) -> Type['NPType']:
    # Return the nptyping type of a type.
    #
    # The entries of _delegates are tried in order and the delegate of the
    # first entry of which type_ is a subclass is called with type_.
    #
    # :param type_: the class for which an nptyping type is to be returned.
    # :return: a subclass of NPType.
    # :raises TypeError: if no entry of _delegates matches type_.
    for super_type, delegate in _delegates:
        if delegate is _get_type_type:
            # A metaclass (e.g. ``type`` itself) is a subclass of ``type``;
            # dispatching to this very function again would recurse without
            # end, so that entry is skipped and a later entry handles it.
            continue
        if issubclass(type_, super_type):
            return delegate(type_)
    # Not reached while _delegates ends with a catch-all ``object`` entry;
    # kept so that a match is never silently replaced by the last delegate.
    raise TypeError('Unsupported type {}'.format(type_))
|
|
|
|
42
|
|
|
|
43
|
|
|
|
44
|
|
View Code Duplication |
def _get_type_dtype(dtype: numpy.dtype) -> Type['NPType']:
    # Return the nptyping type of a numpy dtype.
    #
    # py_type(dtype) is used as the key into the table below and the handler
    # found there is called with the dtype itself.
    # NOTE(review): py_type presumably yields the Python type that matches
    # the dtype; a result that is not a key of the table raises a KeyError.
    handler_per_py_type = {
        type: _get_type_type,
        bool: get_type_bool,
        int: get_type_int,
        float: get_type_float,
        str: get_type_str,
        datetime: get_type_datetime64,
        timedelta: get_type_timedelta64,
        object: lambda _: Object,
    }
    python_type = py_type(dtype)
    handler = handler_per_py_type[python_type]
    return handler(dtype)
57
|
|
|
|
58
|
|
|
|
59
|
|
|
def _get_type_arrary(arr: numpy.ndarray) -> Type['NPType']:
    # Return the nptyping type of a numpy array: an NDArray subscripted with
    # the shape of arr and the nptyping type of its dtype.
    element_type = get_type(arr.dtype)
    shape = arr.shape
    return NDArray[shape, element_type]
63
|
|
|
|
64
|
|
|
|
65
|
|
View Code Duplication |
def _get_type_of_number( |
|
|
|
|
66
|
|
|
cls: Type['Number'], |
67
|
|
|
obj: Any, |
68
|
|
|
bits_per_type: Dict[type, int]) -> Type[Number]: |
69
|
|
|
# Return the nptyping Number type of the given obj using cls and |
70
|
|
|
# bits_per_type. |
71
|
|
|
bits = (bits_per_type.get(obj) |
72
|
|
|
or bits_per_type.get(getattr(obj, 'type', None)) |
73
|
|
|
or bits_per_type.get(type(obj))) |
74
|
|
|
|
75
|
|
|
if not bits: |
76
|
|
|
raise TypeError('Unsupported type {} for {}' |
77
|
|
|
.format(type(obj).__name__, cls)) |
78
|
|
|
|
79
|
|
|
return cls[bits] |
80
|
|
|
|
81
|
|
|
|
82
|
|
|
# Library private. |
83
|
|
|
def get_type_bool(_: Any) -> Type[Bool]:
    """
    Return the NPType for a bool compatible object, which is always Bool.
    :param _: a bool compatible object; its value is not inspected.
    :return: the Bool type.
    """
    return Bool
90
|
|
|
|
91
|
|
|
|
92
|
|
|
# Library private. |
93
|
|
View Code Duplication |
def get_type_str(obj: Any) -> Type[Unicode]:
    """
    Return the NPType that corresponds to obj.

    A numpy dtype yields Unicode subscripted with its number of characters,
    the class ``str`` yields the unsubscripted Unicode and a str instance
    yields Unicode subscripted with its length.
    :param obj: a string compatible object.
    :return: a Unicode type.
    :raises TypeError: if obj is neither a dtype, ``str`` nor a str instance.
    """
    if isinstance(obj, numpy.dtype):
        # numpy stores unicode as 4 bytes per character; floor division keeps
        # the character count an int instead of a float such as 5.0.
        return Unicode[obj.itemsize // 4]
    if obj == str:
        return Unicode
    if not isinstance(obj, str):
        raise TypeError('Unsupported type {}'.format(type(obj)))
    return Unicode[len(obj)]
106
|
|
|
|
107
|
|
|
|
108
|
|
|
# Library private. |
109
|
|
|
def get_type_int(obj: Any) -> Type[Int]:
    """
    Return the Int type whose number of bits matches obj.

    The numpy signed integer types map to 8, 16, 32 and 64 bits; the Python
    int maps to DEFAULT_INT_BITS.
    :param obj: an int compatible object.
    :return: a Int type.
    """
    bits_per_int_type = {
        numpy.int8: 8,
        numpy.int16: 16,
        numpy.int32: 32,
        numpy.int64: 64,
        int: DEFAULT_INT_BITS,
    }
    return _get_type_of_number(Int, obj, bits_per_int_type)
122
|
|
|
|
123
|
|
|
|
124
|
|
|
# Library private. |
125
|
|
|
def get_type_uint(obj: Any) -> Type[UInt]:
    """
    Return the UInt type whose number of bits matches obj.

    The numpy unsigned integer types map to 8, 16, 32 and 64 bits; the
    Python int maps to DEFAULT_INT_BITS.
    :param obj: an uint compatible object.
    :return: an UInt type.
    """
    bits_per_uint_type = {
        numpy.uint8: 8,
        numpy.uint16: 16,
        numpy.uint32: 32,
        numpy.uint64: 64,
        int: DEFAULT_INT_BITS,
    }
    return _get_type_of_number(UInt, obj, bits_per_uint_type)
138
|
|
|
|
139
|
|
|
|
140
|
|
|
# Library private. |
141
|
|
|
def get_type_float(obj: Any) -> Type[Float]:
    """
    Return the Float type whose number of bits matches obj.

    The numpy floating types map to 16, 32 and 64 bits; the Python float
    maps to DEFAULT_FLOAT_BITS.
    :param obj: a float compatible object.
    :return: a Float type.
    """
    bits_per_float_type = {
        numpy.float16: 16,
        numpy.float32: 32,
        numpy.float64: 64,
        float: DEFAULT_FLOAT_BITS,
    }
    return _get_type_of_number(Float, obj, bits_per_float_type)
153
|
|
|
|
154
|
|
|
|
155
|
|
|
# Library private. |
156
|
|
|
def get_type_datetime64(_: Any) -> Type[Datetime64]:
    """
    Return the NPType for a datetime compatible object, which is always
    Datetime64.
    :param _: a datetime compatible object; its value is not inspected.
    :return: the Datetime64 type.
    """
    return Datetime64
163
|
|
|
|
164
|
|
|
|
165
|
|
|
# Library private. |
166
|
|
|
def get_type_timedelta64(_: Any) -> Type[Timedelta64]:
    """
    Return the NPType for a timedelta compatible object, which is always
    Timedelta64.
    :param _: a timedelta compatible object; its value is not inspected.
    :return: the Timedelta64 type.
    """
    return Timedelta64
173
|
|
|
|
174
|
|
|
|
175
|
|
|
# Pairs of (type, delegate) used to find the nptyping type of an object or
# class. The order is significant: _get_type_type walks this list from the
# top and uses the first entry that matches, so specific types must precede
# more general ones and the catch-all ``object`` must stay last.
_delegates = [
    # An nptyping type is returned as is.
    (NPType, lambda x: x),
    # Classes are resolved by _get_type_type.
    (type, _get_type_type),
    # Python builtins; bool comes before int because bool subclasses int.
    (bool, get_type_bool),
    (int, get_type_int),
    (float, get_type_float),
    (str, get_type_str),
    (datetime, get_type_datetime64),
    (timedelta, get_type_timedelta64),
    # numpy scalar types.
    (numpy.datetime64, get_type_datetime64),
    (numpy.timedelta64, get_type_timedelta64),
    (numpy.signedinteger, get_type_int),
    (numpy.unsignedinteger, get_type_uint),
    (numpy.floating, get_type_float),
    (numpy.bool_, get_type_bool),
    # numpy dtypes and arrays.
    (numpy.dtype, _get_type_dtype),
    (numpy.ndarray, _get_type_arrary),
    # Anything else is an Object.
    (object, lambda _: Object),
]
194
|
|
|
|