"""A small personal package created to store code and data I often reuse.

I'll continue to update it with useful functions that I find myself reusing.
The `apoor.data` module has some common datasets and functions for reading
them in as pandas DataFrames.
"""

# Version string
__version__ = "1.3.2"

import itertools as it
from typing import Any, List, Callable, Tuple, Iterable

import numpy as np

from . import data


def fdir(o: Any) -> List[str]:
    """Filtered dir(). Same as builtin dir()
    function without private attributes.

    :param o: Object being inspected
    :return: "Public attributes" of o
    """
    return [a for a in dir(o) if a[0] != "_"]


def set_seed(n: int) -> None:
    """Sets numpy's random seed.

    :param n: The value used to set numpy's random seed.
    :type n: int
    """
    np.random.seed(n)


def make_scale(
    dmin: float, dmax: float, rmin: float, rmax: float, clamp: bool = False
) -> Callable[[float], float]:
    """Scale function factory.

    Creates a scale function to map a number from a domain to a range.

    :param dmin: Domain's start value
    :param dmax: Domain's end value
    :param rmin: Range's start value
    :param rmax: Range's end value
    :param clamp: If the result is outside the range, return the
        clamped value (default: False)
    :return: A scale function that takes one numeric argument and
        returns the value mapped from the domain to the range
        (and clamped if the `clamp` flag is set).

    Examples:
        >>> s = make_scale(0, 1, 0, 10)
        >>> s(0.1)
        1.0

        >>> s = make_scale(0, 10, 10, 0)
        >>> s(1.0)
        9.0

        >>> s = make_scale(0, 1, 0, 1, clamp=True)
        >>> s(100)
        1
    """
    drange = dmax - dmin
    rrange = rmax - rmin
    scale_factor = rrange / drange
    # Order the bounds so clamping also works for inverted ranges (rmin > rmax)
    lo, hi = min(rmin, rmax), max(rmin, rmax)

    def scale(n):
        n_ = (n - dmin) * scale_factor + rmin
        if clamp:
            return min(max(n_, lo), hi)
        return n_

    return scale


def train_test_split(
    *arrays, test_pct: float = 0.15, val_set: bool = False, val_pct: float = 0.15
) -> Tuple[np.ndarray, ...]:
    """Splits arrays into train & test sets.

    Splits arrays into train, test, and (optionally) validation sets
    using the supplied percentages.

    :param *arrays: An arbitrary number of sequences to be split
        into train, test, and (optionally) validation sets. Must
        have at least one array.
    :param test_pct: Float in the range ``[0,1]``. Percent of total
        ``n`` values to include in test set.

        The train set will have ``1.0 - test_pct`` pct of
        values (or ``1.0 - test_pct - val_pct`` pct of values
        if ``val_set == True``).

    :param val_set: Whether or not to return a validation set,
        in addition to a test set.

    :param val_pct: Float in the range ``[0,1]``. Percent
        of total ``n`` values to include in validation set.

        Ignored if ``val_set == False``.

        The train set will have ``1.0 - test_pct - val_pct``
        pct of values.

    :returns: A tuple of numpy arrays (``splits``): the input arrays
        split into train, test, and (optionally) val sets.

        If ``val_set == False``, ``len(splits) == 2 * len(arrays)``,
        or if ``val_set == True``, ``len(splits) == 3 * len(arrays)``.

    Example (the arrays are shuffled randomly, so exact values vary):
        >>> x = np.arange(10)
        >>> train_test_split(x)
        (array([3, 9, 4, 2, 1, 0, 7, 5, 8]), array([6]))

        >>> x = np.arange(10)
        >>> y = x[::-1]
        >>> x_train, x_test, y_train, y_test = train_test_split(x, y, test_pct=0.2)
        >>> x_train, x_test, y_train, y_test
        (array([1, 3, 5, 8, 4, 7, 6, 9]),
         array([0, 2]),
         array([8, 6, 4, 1, 5, 2, 3, 0]),
         array([9, 7]))

        >>> train_test_split(x, test_pct=0.3, val_set=True, val_pct=0.2)
        (array([0, 9, 5, 7, 6]),
         array([1, 3, 4]),
         array([2, 8]))

    """
    # Perform input checks
    assert arrays, "No arrays supplied"
    lens = [len(a) for a in arrays]
    assert len(set(lens)) == 1, "arrays have varying lengths"
    assert lens[0] > 0, "supplied arrays have `len == 0`"
    assert 0.0 <= test_pct <= 1.0, "`test_pct` must be in the range `0.0 <= test_pct <= 1.0`"
    if val_set:
        assert 0.0 <= val_pct <= 1.0, "`val_pct` must be in the range `0.0 <= val_pct <= 1.0`"
        assert test_pct + val_pct <= 1.0, "Can't have `test_pct + val_pct > 1.0`"
    # Calculate lengths
    n = lens[0]
    n_test = int(n * test_pct)
    # Shuffle the indexes
    indexes = np.arange(n)
    np.random.shuffle(indexes)
    # Split the data
    if val_set:
        n_val = int(n * val_pct)
        n_train = n - n_test - n_val
        splits = (
            (
                a[indexes[:n_train]],
                a[indexes[n_train:n_train + n_test]],
                # Slice from the front: `indexes[-n_val:]` would return
                # every index when `n_val == 0`
                a[indexes[n_train + n_test:]],
            )
            for a in map(np.asarray, arrays)
        )
    else:
        n_train = n - n_test
        splits = (
            (a[indexes[:n_train]], a[indexes[n_train:]])
            for a in map(np.asarray, arrays)
        )
    return tuple(it.chain(*splits))


def to_onehot(y: np.ndarray, num_classes: int = None, dtype="float32") -> np.ndarray:
    """Expands a 1D categorical vector to
    a 2D, onehot-encoded categorical matrix.

    :param y: 1D categorical vector of integer class labels
    :param num_classes: Number of categories in (and width of)
        the output matrix.

        If ``num_classes`` is ``None``, defaults to ``max(y) + 1``.
    :param dtype: Data type of output matrix
    :returns: 2D one-hot encoded category matrix

    Examples:
        >>> data = np.array([0, 2, 1, 3])
        >>> apoor.to_onehot(data)
        array([[1., 0., 0., 0.],
               [0., 0., 1., 0.],
               [0., 1., 0., 0.],
               [0., 0., 0., 1.]], dtype=float32)

    """
    if num_classes is None:
        num_classes = np.max(y) + 1
    return np.identity(num_classes, dtype=dtype)[y]


def ibuff(itr: Iterable, bsize: int = 1) -> Iterable[List]:
    """Creates an iterable that yields elements
    from ``itr`` grouped into lists of size ``bsize``.

    If ``itr`` can't evenly be grouped into lists of size
    ``bsize``, the final list will have the remaining
    elements.

    :param itr: The iterable to be buffered.
    :param bsize: Positive integer, representing the number of
        values from ``itr`` to be yielded together.

        The final list yielded may not be of size ``bsize`` if
        the number of elements in ``itr`` doesn't evenly divide
        into groups of ``bsize``.
    :yields: Buffered elements from ``itr``, grouped into lists
        of size up to ``bsize``.
    :raises TypeError: If ``bsize`` isn't an integer.
    :raises ValueError: If ``bsize`` isn't positive.

    Examples:
        >>> for b in apoor.ibuff(range(10), 3):
        ...     print(b)
        [0, 1, 2]
        [3, 4, 5]
        [6, 7, 8]
        [9]
    """
    # Perform checks
    if not isinstance(bsize, int):
        raise TypeError("bsize needs to be a positive integer.")
    if bsize < 1:
        raise ValueError("bsize needs to be a positive integer.")
    # Initialize the buffer
    buff = []
    for v in itr:
        buff.append(v)
        if len(buff) == bsize:  # Buffer is full: yield and reinit
            yield buff
            buff = []
    # Yield anything left in the buffer
    if buff:
        yield buff
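
# The buffering done by ibuff() could also be sketched with itertools.islice,
# which some readers may find more direct. This is only an illustration under
# assumptions: `ibuff_islice` and the type variable `T` are hypothetical names
# that are not part of this package, and the argument checks simply mirror the
# ones in ibuff().

```python
import itertools
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")  # Hypothetical helper type, not part of the package


def ibuff_islice(itr: Iterable[T], bsize: int = 1) -> Iterator[List[T]]:
    """Hypothetical variant of ``ibuff`` built on ``itertools.islice``."""
    if not isinstance(bsize, int):
        raise TypeError("bsize needs to be a positive integer.")
    if bsize < 1:
        raise ValueError("bsize needs to be a positive integer.")
    iterator = iter(itr)  # One shared iterator, so chunks never overlap
    while True:
        # Pull up to `bsize` elements; the last chunk may be shorter
        chunk = list(itertools.islice(iterator, bsize))
        if not chunk:  # Source exhausted
            return
        yield chunk


if __name__ == "__main__":
    for b in ibuff_islice(range(10), 3):
        print(b)
```

# Because this sketch is a generator, the TypeError / ValueError checks only
# run once iteration starts, not when ibuff_islice() is called.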