| Metric | Value |
|---|---|
| Conditions | 3 |
| Total Lines | 89 |
| Code Lines | 29 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 0 |
Small methods make your code easier to understand, especially when combined with a good name. Besides, a small method is usually much easier to name well.
For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment as the starting point for the new method's name.
Commonly applied refactorings include:
If many parameters/temporary variables are present:
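One commonly used option here is Introduce Parameter Object: related arguments are bundled into a single value. The sketch below assumes a hypothetical frozen dataclass `SplitSpec`, which is not part of the original code. It holds the three split settings of `train_test_split` and also applies Replace Temp with Query, so that the `n_train`/`n_test`/`n_val` temporaries come from one `sizes` method.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class SplitSpec:
    """Hypothetical parameter object for ``train_test_split``'s split settings."""

    test_pct: float = 0.15
    val_set: bool = False
    val_pct: float = 0.15

    def __post_init__(self) -> None:
        # Validate once, when the settings are created.
        assert 0.0 <= self.test_pct <= 1.0, "`test_pct` must be in the range [0, 1]"
        if self.val_set:
            assert 0.0 <= self.val_pct <= 1.0, "`val_pct` must be in the range [0, 1]"
            assert self.test_pct + self.val_pct <= 1.0, "`test_pct + val_pct` must not exceed 1.0"

    def sizes(self, n: int) -> Tuple[int, int, int]:
        """Return ``(n_train, n_test, n_val)`` for ``n`` items; ``n_val`` is 0 without a val set."""
        n_test = int(n * self.test_pct)  # int() truncates, as in the original
        n_val = int(n * self.val_pct) if self.val_set else 0
        return n - n_test - n_val, n_test, n_val
```

A signature such as `train_test_split(*arrays, spec: SplitSpec = SplitSpec())` could then replace the three keyword arguments. A frozen dataclass whose fields are immutable is safe to use as a default value.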
```python
"""A small personal package created to store code and data I often reuse."""
import itertools as it
from typing import Tuple

import numpy as np


def train_test_split(
    *arrays, test_pct: float = 0.15, val_set: bool = False, val_pct: float = 0.15
) -> Tuple[np.ndarray, ...]:
    """Splits arrays into train & test sets.

    Splits arrays into train, test, and (optionally) validation sets using the supplied percentages.

    :param *arrays: An arbitrary number of sequences to be split
                    into train, test, and (optionally) validation sets. Must
                    have at least one array, and all arrays must have the same length.
    :param test_pct: Float in the range ``[0,1]``. Percent of total
                     ``n`` values to include in test set.

                     The train set will have ``1.0 - test_pct`` pct of
                     values (or ``1.0 - test_pct - val_pct`` pct of values
                     if ``val_set == True``).

    :param val_set: Whether or not to return a validation set,
                    in addition to a test set.

    :param val_pct: Float in the range ``[0,1]``. Percent
                    of total ``n`` values to include in validation set.

                    Ignored if ``val_set == False``.

                    The train set will have ``1.0 - test_pct - val_pct``
                    pct of values.

    :returns: ``splits``, a tuple of numpy arrays. For each input array, in order,
              its train and test splits, followed by its validation split if
              ``val_set == True``. Split sizes are truncated with ``int``, so the
              train set absorbs any remainder.

              ``len(splits) == 2 * len(arrays)`` if ``val_set == False``,
              or ``3 * len(arrays)`` if ``val_set == True``.

    Examples (indexes are shuffled randomly, so actual output varies between calls):
        >>> x = np.arange(10)
        >>> train_test_split(x)
        (array([3, 9, 4, 2, 1, 0, 7, 5, 8]), array([6]))

        >>> x = np.arange(10)
        >>> y = x[::-1]
        >>> x_train, x_test, y_train, y_test = train_test_split(x, y, test_pct=0.2)
        >>> x_train, x_test, y_train, y_test
        (array([1, 3, 5, 8, 4, 7, 6, 9]),
         array([0, 2]),
         array([8, 6, 4, 1, 5, 2, 3, 0]),
         array([9, 7]))

        >>> train_test_split(x, test_pct=0.3, val_set=True, val_pct=0.2)
        (array([0, 9, 5, 7, 6]),
         array([2, 8, 1]),
         array([3, 4]))

    """
    # Perform input checks
    assert arrays, "No arrays supplied"
    lens = [len(a) for a in arrays]
    assert len(set(lens)) == 1, "arrays have varying lengths"
    assert lens[0] > 0, "supplied arrays have `len == 0`"
    assert 0.0 <= test_pct <= 1.0, "`test_pct` must be in the range `0.0 <= test_pct <= 1.0`"
    if val_set:
        assert 0.0 <= val_pct <= 1.0, "`val_pct` must be in the range `0.0 <= val_pct <= 1.0`"
        assert test_pct + val_pct <= 1.0, "`test_pct + val_pct` must not exceed 1.0"
    # Calculate lengths (int() truncates, so the train set absorbs any remainder)
    n = lens[0]
    n_test = int(n * test_pct)
    # Shuffle the indexes
    indexes = np.arange(n)
    np.random.shuffle(indexes)
    # Split the data
    if val_set:
        n_val = int(n * val_pct)
        n_train = n - n_test - n_val
        splits = (
            (
                a[indexes[:n_train]],
                a[indexes[n_train:n_train + n_test]],
                # Slice from the end of the test set: indexes[-n_val:] would
                # return *all* indexes when n_val == 0, because -0 == 0.
                a[indexes[n_train + n_test:]],
            )
            for a in map(np.asarray, arrays)
        )
    else:
        n_train = n - n_test
        splits = (
            (a[indexes[:n_train]], a[indexes[n_train:]])
            for a in map(np.asarray, arrays)
        )
    # Flatten to (a0_train, a0_test[, a0_val], a1_train, ...)
    return tuple(it.chain.from_iterable(splits))
```
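The `# Split the data` step can be extracted in the same way. The sketch below uses a hypothetical helper, `split_data`, which is not in the original package. It cuts the shuffled indexes into consecutive chunks with `np.split` and applies every chunk to every array, returning results in the same flat order as `train_test_split`. The cut points are running totals of the sizes, so a zero-length validation set yields an empty chunk. A slice such as `indexes[-n_val:]`, in contrast, returns the whole array when `n_val == 0`.

```python
import itertools as it
from typing import Sequence, Tuple

import numpy as np


def split_data(arrays: Sequence, indexes: np.ndarray,
               sizes: Sequence[int]) -> Tuple[np.ndarray, ...]:
    """Split every array in ``arrays`` into chunks of ``indexes`` with lengths ``sizes``.

    Hypothetical helper: ``sizes`` is e.g. (n_train, n_test) or (n_train, n_test, n_val)
    and must sum to ``len(indexes)``. Output order matches ``train_test_split``:
    all chunks of the first array, then all chunks of the second, and so on.
    """
    assert sum(sizes) == len(indexes), "`sizes` must sum to `len(indexes)`"
    # np.split wants cut points: the running totals, minus the final one (== len).
    cuts = np.cumsum(sizes)[:-1]
    chunks = np.split(indexes, cuts)  # a size of 0 gives an empty chunk
    return tuple(it.chain.from_iterable(
        [np.asarray(a)[chunk] for chunk in chunks] for a in arrays
    ))


# Usage: shuffle once, then split two aligned arrays 5/3/2.
x = np.arange(10)
y = x[::-1]
indexes = np.random.permutation(len(x))
x_train, x_test, x_val, y_train, y_test, y_val = split_data((x, y), indexes, (5, 3, 2))
```

Because the same index chunks are applied to `x` and `y`, rows stay aligned across arrays. That is the property the original generator expressions preserve as well.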