Conditions | 6 |
Total Lines | 20 |
Lines | 0 |
Ratio | 0 % |
Changes | 1 | ||
Bugs | 0 | Features | 0 |
1 | #!/usr/bin/env python |
||
def pad_sequence(batch, pad_value=0, output_mask=True, length=None):
    """Pad (and, when needed, truncate) variable-length sequences to one width.

    Args:
        batch: Iterable of sized iterables (e.g. lists of token ids).
        pad_value: Value appended to rows shorter than the target width.
        output_mask: When true, also build a float32 mask that holds 1 at
            positions taken from the input and 0 at padded positions.
        length: Target row width. A falsy value (``None`` or ``0``) means
            "use the longest sequence in the batch".

    Returns:
        A ``(padded, mask)`` tuple. ``padded`` is a NumPy array with one row
        per input sequence; ``mask`` is a float32 NumPy array with one row
        per input sequence, or ``None`` when ``output_mask`` is false.
    """
    # Materialise every row once so slicing and concatenation work for any
    # sized iterable (tuples, arrays, ...), not only lists.
    rows = [list(seq) for seq in batch]

    if length:
        max_len = length
    else:
        # An empty batch has no longest row; use zero width instead of
        # letting max() raise on an empty sequence.
        max_len = max(len(row) for row in rows) if rows else 0

    # Rows longer than an explicit `length` are cut to size; otherwise the
    # result would be ragged and could not form a rectangular array.
    rows = [row[:max_len] for row in rows]

    mask = None
    if output_mask:
        mask = np.array(
            [[1] * len(row) + [0] * (max_len - len(row)) for row in rows],
            dtype="float32",
        )

    # Single padding path for both the explicit-length and the
    # longest-in-batch cases.
    new_batch = np.array(
        [row + [pad_value] * (max_len - len(row)) for row in rows]
    )
    return new_batch, mask