| Metric | Value |
|---|---|
| Conditions | 3 |
| Total Lines | 89 |
| Code Lines | 29 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
Commonly applied refactorings include:
If many parameters/temporary variables are present:
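One commonly used option for this case is Introduce Parameter Object from Martin Fowler's refactoring catalog: `test_pct`, `val_set` and `val_pct` in the `train_test_split` listing are only meaningful together, so they can be bundled into one value. The sketch below assumes a frozen dataclass named `SplitSpec` (hypothetical, not part of the original package) that validates itself and replaces the `n_train`/`n_test`/`n_val` temporaries with a query method, in the spirit of Replace Temp with Query.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class SplitSpec:
    """Hypothetical parameter object bundling train_test_split's settings."""
    test_pct: float = 0.15
    val_set: bool = False
    val_pct: float = 0.15

    def __post_init__(self) -> None:
        # Validate once, when the object is built, instead of inside the long method
        if not 0.0 <= self.test_pct <= 1.0:
            raise ValueError("`test_pct` must be in the range [0, 1]")
        if self.val_set and not 0.0 <= self.val_pct <= 1.0:
            raise ValueError("`val_pct` must be in the range [0, 1]")
        if self.val_set and self.test_pct + self.val_pct > 1.0:
            raise ValueError("`test_pct + val_pct` must not exceed 1.0")

    def sizes(self, n: int) -> Tuple[int, int, int]:
        """Return (n_train, n_test, n_val) for n rows; n_val is 0 without a val set."""
        n_test = int(n * self.test_pct)
        n_val = int(n * self.val_pct) if self.val_set else 0
        # int() truncates, so the train set absorbs any rounding remainder
        return n - n_test - n_val, n_test, n_val
```

For example, `SplitSpec(test_pct=0.3, val_set=True, val_pct=0.2).sizes(10)` returns `(5, 3, 2)`, and `SplitSpec().sizes(10)` returns `(9, 1, 0)`.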
```python
"""A small personal package created to store code and data I often reuse."""
# ... (rest of the module header not shown; the function below needs these
#      imports, assuming `it` is the usual alias for itertools)
import itertools as it
from typing import Tuple

import numpy as np


def train_test_split(*arrays, test_pct: float = 0.15, val_set: bool = False, val_pct: float = 0.15) -> Tuple[np.ndarray, ...]:
    """Splits arrays into train & test sets.

    Splits arrays into train, test, and (optionally) validation sets using the supplied percentages.

    :param *arrays: An arbitrary number of sequences to be split
        into train, test, and (optionally) validation sets. Must
        have at least one array, and all arrays must have the same length.
    :param test_pct: Float in the range ``[0,1]``. Fraction of the total
        ``n`` values to include in the test set.

        The train set will have ``1.0 - test_pct`` of the
        values (or ``1.0 - test_pct - val_pct`` of the values
        if ``val_set == True``).

    :param val_set: Whether or not to return a validation set,
        in addition to a test set.

    :param val_pct: Float in the range ``[0,1]``. Fraction
        of the total ``n`` values to include in the validation set.

        Ignored if ``val_set == False``.

        The train set will have ``1.0 - test_pct - val_pct``
        of the values.

    :returns: splits, a tuple of numpy arrays: each input array
        split into its train, test (and val) sets, in that order.

        If ``val_set == False``, ``len(splits) == 2 * len(arrays)``,
        or if ``val_set == True``, ``len(splits) == 3 * len(arrays)``.

    Example (the split is random, so only sizes and pairing are shown):
        >>> x = np.arange(10)
        >>> x_train, x_test = train_test_split(x)
        >>> len(x_train), len(x_test)
        (9, 1)

        >>> y = x[::-1]
        >>> x_train, x_test, y_train, y_test = train_test_split(x, y)
        >>> bool(((x_train + y_train) == 9).all())  # rows stay paired
        True

        >>> [len(s) for s in train_test_split(x, test_pct=0.3, val_set=True, val_pct=0.2)]
        [5, 3, 2]

    """
    # Perform input checks (raise instead of assert, which `python -O` skips)
    if not arrays:
        raise ValueError("No arrays supplied")
    lens = [len(a) for a in arrays]
    if len(set(lens)) != 1:
        raise ValueError("arrays have varying lengths")
    if lens[0] == 0:
        raise ValueError("supplied arrays have `len == 0`")
    if not 0.0 <= test_pct <= 1.0:
        raise ValueError("`test_pct` must be in the range `0.0 <= test_pct <= 1.0`")
    if val_set:
        if not 0.0 <= val_pct <= 1.0:
            raise ValueError("`val_pct` must be in the range `0.0 <= val_pct <= 1.0`")
        if test_pct + val_pct > 1.0:
            raise ValueError("Can't have `test_pct + val_pct > 1.0`")
    # Calculate lengths
    n = lens[0]
    n_test = int(n * test_pct)
    # Shuffle the indexes
    indexes = np.arange(n)
    np.random.shuffle(indexes)
    # Split the data
    if val_set:
        n_val = int(n * val_pct)
        n_train = n - n_test - n_val
        splits = (
            (
                a[indexes[:n_train]],
                a[indexes[n_train:n_train + n_test]],
                # Slice from the end of the test block: `indexes[-n_val:]`
                # would return every row when n_val == 0
                a[indexes[n_train + n_test:]],
            )
            for a in map(np.asarray, arrays)
        )
    else:
        n_train = n - n_test
        splits = (
            (a[indexes[:n_train]], a[indexes[n_train:]])
            for a in map(np.asarray, arrays)
        )
    return tuple(it.chain(*splits))
```