| Conditions | 6 |
| Total Lines | 20 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 1 | ||
| Bugs | 0 | Features | 0 |
| 1 | #!/usr/bin/env python |
||
def pad_sequence(batch, pad_value=0, output_mask=True, length=None):
    """Pad a batch of variable-length sequences to a common length.

    Args:
        batch: indexable collection of sequences (anything supporting
            ``len()`` and iteration, e.g. lists or 1-D arrays).
        pad_value: value appended to sequences shorter than the target length.
        output_mask: when true, also build a float32 mask marking real
            (1.0) versus padded (0.0) positions.
        length: target length. A falsy value (``None`` or ``0``) means
            "pad to the longest sequence in the batch". Sequences longer
            than an explicit ``length`` are truncated so every row has
            exactly ``length`` items.

    Returns:
        ``(new_batch, mask)``. ``new_batch`` is a NumPy array built from
        the padded rows (dtype inferred by NumPy). ``mask`` is a float32
        array of the same row/column layout, or ``None`` when
        ``output_mask`` is false.

    Raises:
        ValueError: if ``batch`` is empty and no ``length`` is given
            (there is no longest sequence to pad to).
    """
    # Truthiness test kept on purpose: the original treated length=0 the
    # same as None, and callers may depend on that.
    max_len = length if length else max(map(len, batch))

    # Number of real (non-padding) items kept from each sequence; clamped so
    # an explicit `length` shorter than a sequence truncates instead of
    # producing ragged rows.
    kept = [min(len(seq), max_len) for seq in batch]

    # One explicit padding path for both cases; this avoids the Python 2-only
    # itertools.izip / izip_longest names, which do not exist on Python 3.
    new_batch = np.array([
        list(seq)[:n] + [pad_value] * (max_len - n)
        for seq, n in zip(batch, kept)
    ])

    mask = None
    if output_mask:
        mask = np.array(
            [[1] * n + [0] * (max_len - n) for n in kept],
            dtype="float32",
        )
    return new_batch, mask