Completed
Push — master ( 48255b...bf2b0c )
by Raphael
01:13
created

pad_sequence()   B

Complexity

Conditions 6

Size

Total Lines 20

Duplication

Lines 0
Ratio 0 %

Importance

Changes 1
Bugs 0 Features 0
Metric Value
cc 6
c 1
b 0
f 0
dl 0
loc 20
rs 8
1
#!/usr/bin/env python
2
# -*- coding: utf-8 -*-
3
4
import numpy as np
5
from itertools import izip, izip_longest
6
7
def pad_sequence(batch, pad_value=0, output_mask=True, length=None):
    """Pad (or truncate) every sequence in ``batch`` to a common length.

    Args:
        batch: A sequence of sequences (e.g. a list of token-id lists).
        pad_value: Value appended to sequences shorter than the target length.
        output_mask: When true, also build a float32 mask with 1 at positions
            holding original data and 0 at padded positions.
        length: Target length for every row. A falsy value (``None`` or ``0``)
            means "use the length of the longest sequence in ``batch``".
            Sequences longer than ``length`` are truncated to ``length``.

    Returns:
        A ``(new_batch, mask)`` tuple. ``new_batch`` is a numpy array with one
        row per input sequence; ``mask`` is a float32 numpy array with
        ``len(batch)`` rows and one column per target position, or ``None``
        when ``output_mask`` is false.
    """
    rows = [list(seq) for seq in batch]

    if length:
        max_len = length
    else:
        # Guard the empty batch: max() of an empty sequence raises ValueError.
        max_len = max(map(len, rows)) if rows else 0

    # One padding path for both modes. Slicing first keeps rows rectangular
    # when ``length`` is shorter than a sequence; a negative repeat count
    # yields an empty list, so over-long rows get no padding appended.
    padded = [seq[:max_len] + [pad_value] * (max_len - len(seq)) for seq in rows]

    if padded:
        new_batch = np.array(padded)
    else:
        # np.array([]) would be 1-D; keep the (rows, columns) layout instead.
        new_batch = np.empty((0, max_len), dtype=np.asarray(pad_value).dtype)

    mask = None
    if output_mask:
        mask = np.zeros((len(rows), max_len), dtype="float32")
        for i, seq in enumerate(rows):
            # The slice is clipped to the row width, so truncated sequences
            # simply produce an all-ones mask row.
            mask[i, :len(seq)] = 1

    return new_batch, mask