| 
                    1
                 | 
                                    
                                                     | 
                
                 | 
                import math  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    2
                 | 
                                    
                                                     | 
                
                 | 
                import random  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    3
                 | 
                                    
                                                     | 
                
                 | 
                import logging  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    4
                 | 
                                    
                                                     | 
                
                 | 
                import collections  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    5
                 | 
                                    
                                                     | 
                
                 | 
                import numpy as np  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    6
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    7
                 | 
                                    
                                                     | 
                
                 | 
                logger = logging.getLogger(__name__)  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    8
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    9
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    10
                 | 
                                    
                                                     | 
                
                 | 
                class BatchInjector:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    11
                 | 
                                    
                                                     | 
                
                 | 
                    """Retrieving dataset values in batches  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    12
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    13
                 | 
                                    
                                                     | 
                
                 | 
                    Args:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    14
                 | 
                                    
                                                     | 
                
                 | 
                        data_x (:obj:`numpy.ndarray`): Input feature array.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    15
                 | 
                                    
                                                     | 
                
                 | 
                        data_y (:obj:`numpy.ndarray`): Input label array.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    16
                 | 
                                    
                                                     | 
                
                 | 
                        batch_size (:obj:`int`): Batch size.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    17
                 | 
                                    
                                                     | 
                
                 | 
                        num_batches (:obj:`int`): The number of batches in the input data.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    18
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    19
                 | 
                                    
                                                     | 
                
                 | 
                    Attributes:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    20
                 | 
                                    
                                                     | 
                
                 | 
                        size (:obj:`int`): Number of input vectors.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    21
                 | 
                                    
                                                     | 
                
                 | 
                        batch_size (:obj:`int`): Batch size.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    22
                 | 
                                    
                                                     | 
                
                 | 
                        num_batches (:obj:`int`): Number of batches in the input data.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    23
                 | 
                                    
                                                     | 
                
                 | 
                        num_epochs (:obj:`int`): Number of epoch of current iteration.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    24
                 | 
                                    
                                                     | 
                
                 | 
                        cur_batch (:obj:`int`): Current batch index.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    25
                 | 
                                    
                                                     | 
                
                 | 
                        data_x (:obj:`numpy.ndarray`): Reference to input feature array.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    26
                 | 
                                    
                                                     | 
                
                 | 
                        data_y (:obj:`numpy.ndarray`): Reference to input label array.s  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    27
                 | 
                                    
                                                     | 
                
                 | 
                    """  | 
            
            
                                                                                                            
                            
            
                                                                    
                                                                                                        
            
            
                | 
                    28
                 | 
                                    
                                                     | 
                
                View Code Duplication | 
                    def __init__(self, data_x, data_y=None, batch_size=-1, num_batches=-1):  | 
            
                            
                    | 
                        
                     | 
                     | 
                     | 
                    
                                                                                                    
                        
                         
                                                                                        
                                                                                     
                     | 
                
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    29
                 | 
                                    
                                                     | 
                
                 | 
                        self.size = data_x.shape[0]  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    30
                 | 
                                    
                                                     | 
                
                 | 
                        if 0 < batch_size <= self.size:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    31
                 | 
                                    
                                                     | 
                
                 | 
                            self.batch_size = batch_size  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    32
                 | 
                                    
                                                     | 
                
                 | 
                            self.num_batches = math.floor(self.size / self.batch_size)  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    33
                 | 
                                    
                                                     | 
                
                 | 
                        elif num_batches > 0:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    34
                 | 
                                    
                                                     | 
                
                 | 
                            self.batch_size = math.floor(self.size / num_batches)  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    35
                 | 
                                    
                                                     | 
                
                 | 
                            self.num_batches = num_batches  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    36
                 | 
                                    
                                                     | 
                
                 | 
                        else:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    37
                 | 
                                    
                                                     | 
                
                 | 
                            raise ValueError('Invalid batch_size or num_batches.') | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    38
                 | 
                                    
                                                     | 
                
                 | 
                        self.num_epochs = 0  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    39
                 | 
                                    
                                                     | 
                
                 | 
                        self.cur_batch = 0  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    40
                 | 
                                    
                                                     | 
                
                 | 
                        self.data_x = data_x  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    41
                 | 
                                    
                                                     | 
                
                 | 
                        self.data_y = data_y  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    42
                 | 
                                    
                                                     | 
                
                 | 
                        if data_y is not None:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    43
                 | 
                                    
                                                     | 
                
                 | 
                            if self.data_x.shape[0] != self.data_y.shape[0]:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    44
                 | 
                                    
                                                     | 
                
                 | 
                                raise ValueError('data_x, data_y provided have different number of rows.') | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    45
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    46
                 | 
                                    
                                                     | 
                
                 | 
                    def next_batch(self):  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    47
                 | 
                                    
                                                     | 
                
                 | 
                        """Get Next Batch  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    48
                 | 
                                    
                                                     | 
                
                 | 
                        """  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    49
                 | 
                                    
                                                     | 
                
                 | 
                        if self.cur_batch == self.num_batches - 1:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    50
                 | 
                                    
                                                     | 
                
                 | 
                            start = self.batch_size * self.cur_batch  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    51
                 | 
                                    
                                                     | 
                
                 | 
                            end = self.size  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    52
                 | 
                                    
                                                     | 
                
                 | 
                            self.cur_batch = 0  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    53
                 | 
                                    
                                                     | 
                
                 | 
                            self.num_epochs += 1  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    54
                 | 
                                    
                                                     | 
                
                 | 
                        else:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    55
                 | 
                                    
                                                     | 
                
                 | 
                            start = self.batch_size * self.cur_batch  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    56
                 | 
                                    
                                                     | 
                
                 | 
                            end = start + self.batch_size  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    57
                 | 
                                    
                                                     | 
                
                 | 
                            self.cur_batch += 1  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    58
                 | 
                                    
                                                     | 
                
                 | 
                        if self.data_y is None:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    59
                 | 
                                    
                                                     | 
                
                 | 
                            return self.data_x[start:end, :]  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    60
                 | 
                                    
                                                     | 
                
                 | 
                        else:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    61
                 | 
                                    
                                                     | 
                
                 | 
                            return self.data_x[start:end, :], self.data_y[start:end, :]  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    62
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    63
                 | 
                                    
                                                     | 
                
                 | 
                    def reset(self):  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    64
                 | 
                                    
                                                     | 
                
                 | 
                        """Reset all counters  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    65
                 | 
                                    
                                                     | 
                
                 | 
                        """  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    66
                 | 
                                    
                                                     | 
                
                 | 
                        self.cur_batch = 0  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    67
                 | 
                                    
                                                     | 
                
                 | 
                        self.num_epochs = 0  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    68
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    69
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    70
                 | 
                                    
                                                     | 
                
                 | 
                class BatchSequenceInjector:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    71
                 | 
                                    
                                                     | 
                
                 | 
                    """Retrieving dataset values in batches and form a sequence of events  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    72
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    73
                 | 
                                    
                                                     | 
                
                 | 
                    Args:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    74
                 | 
                                    
                                                     | 
                
                 | 
                        data_x (:obj:`numpy.ndarray`): Input feature array.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    75
                 | 
                                    
                                                     | 
                
                 | 
                        data_y (:obj:`numpy.ndarray`): Input label array.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    76
                 | 
                                    
                                                     | 
                
                 | 
                        seq_len (:obj:`int`): Length of sequence.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    77
                 | 
                                    
                                                     | 
                
                 | 
                        batch_size (:obj:`int`): Batch size.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    78
                 | 
                                    
                                                     | 
                
                 | 
                        num_batches (:obj:`int`): The number of batches in the input data.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    79
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    80
                 | 
                                    
                                                     | 
                
                 | 
                    Attributes:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    81
                 | 
                                    
                                                     | 
                
                 | 
                        length (:obj:`int`): Length of sequence.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    82
                 | 
                                    
                                                     | 
                
                 | 
                        size (:obj:`int`): Number of input vectors.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    83
                 | 
                                    
                                                     | 
                
                 | 
                        batch_size (:obj:`int`): Batch size.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    84
                 | 
                                    
                                                     | 
                
                 | 
                        num_batches (:obj:`int`): Number of batches in the input data.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    85
                 | 
                                    
                                                     | 
                
                 | 
                        num_epochs (:obj:`int`): Number of epoch of current iteration.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    86
                 | 
                                    
                                                     | 
                
                 | 
                        cur_batch (:obj:`int`): Current batch index.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    87
                 | 
                                    
                                                     | 
                
                 | 
                        data_x (:obj:`numpy.ndarray`): Reference to input feature array.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    88
                 | 
                                    
                                                     | 
                
                 | 
                        data_y (:obj:`numpy.ndarray`): Reference to input label array.s  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    89
                 | 
                                    
                                                     | 
                
                 | 
                    """  | 
            
            
                                                                                                            
                            
            
                                                                    
                                                                                                        
            
            
                | 
                    90
                 | 
                                    
                                                     | 
                
                View Code Duplication | 
                    def __init__(self, data_x, data_y=None, length=100, batch_size=-1, num_batches=-1, with_seq=False):  | 
            
                            
                    | 
                        
                     | 
                     | 
                     | 
                    
                                                                                                    
                        
                         
                                                                                        
                                                                                     
                     | 
                
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    91
                 | 
                                    
                                                     | 
                
                 | 
                        self.with_seq = with_seq  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    92
                 | 
                                    
                                                     | 
                
                 | 
                        self.length = length  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    93
                 | 
                                    
                                                     | 
                
                 | 
                        self.size = data_x.shape[0] - length  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    94
                 | 
                                    
                                                     | 
                
                 | 
                        if 0 < batch_size <= self.size:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    95
                 | 
                                    
                                                     | 
                
                 | 
                            self.batch_size = batch_size  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    96
                 | 
                                    
                                                     | 
                
                 | 
                            self.num_batches = math.floor(self.size / self.batch_size)  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    97
                 | 
                                    
                                                     | 
                
                 | 
                        elif num_batches > 0:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    98
                 | 
                                    
                                                     | 
                
                 | 
                            self.batch_size = math.floor(self.size / num_batches)  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    99
                 | 
                                    
                                                     | 
                
                 | 
                            self.num_batches = num_batches  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    100
                 | 
                                    
                                                     | 
                
                 | 
                        else:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    101
                 | 
                                    
                                                     | 
                
                 | 
                            raise ValueError('Invalid batch_size or num_batches.') | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    102
                 | 
                                    
                                                     | 
                
                 | 
                        self.num_epochs = 0  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    103
                 | 
                                    
                                                     | 
                
                 | 
                        self.cur_batch = 0  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    104
                 | 
                                    
                                                     | 
                
                 | 
                        self.data_x = data_x  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    105
                 | 
                                    
                                                     | 
                
                 | 
                        self.data_y = data_y  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    106
                 | 
                                    
                                                     | 
                
                 | 
                        if data_y is not None:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    107
                 | 
                                    
                                                     | 
                
                 | 
                            if self.data_x.shape[0] != self.data_y.shape[0]:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    108
                 | 
                                    
                                                     | 
                
                 | 
                                raise ValueError('data_x, data_y provided have different number of rows.') | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    109
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    110
                 | 
                                    
                                                     | 
                
                 | 
                    def next_batch(self, skip=1):  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    111
                 | 
                                    
                                                     | 
                
                 | 
                        """Get Next Batch  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    112
                 | 
                                    
                                                     | 
                
                 | 
                        """  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    113
                 | 
                                    
                                                     | 
                
                 | 
                        self.cur_batch += skip-1  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    114
                 | 
                                    
                                                     | 
                
                 | 
                        if self.cur_batch > self.num_batches - 1:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    115
                 | 
                                    
                                                     | 
                
                 | 
                            self.cur_batch = 0  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    116
                 | 
                                    
                                                     | 
                
                 | 
                            self.num_epochs += 1  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    117
                 | 
                                    
                                                     | 
                
                 | 
                        if self.cur_batch == self.num_batches - 1:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    118
                 | 
                                    
                                                     | 
                
                 | 
                            start = self.batch_size * self.cur_batch  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    119
                 | 
                                    
                                                     | 
                
                 | 
                            end = self.size  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    120
                 | 
                                    
                                                     | 
                
                 | 
                            self.cur_batch = 0  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    121
                 | 
                                    
                                                     | 
                
                 | 
                            self.num_epochs += 1  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    122
                 | 
                                    
                                                     | 
                
                 | 
                        else:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    123
                 | 
                                    
                                                     | 
                
                 | 
                            start = self.batch_size * self.cur_batch  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    124
                 | 
                                    
                                                     | 
                
                 | 
                            end = start + self.batch_size  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    125
                 | 
                                    
                                                     | 
                
                 | 
                            self.cur_batch += 1  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    126
                 | 
                                    
                                                     | 
                
                 | 
                        return self.to_sequence(self.length, self.data_x, self.data_y, start, end, with_seq=self.with_seq)  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    127
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    128
                 | 
                                    
                                                     | 
                
                 | 
                    def reset(self):  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    129
                 | 
                                    
                                                     | 
                
                 | 
                        """Reset all counters  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    130
                 | 
                                    
                                                     | 
                
                 | 
                        """  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    131
                 | 
                                    
                                                     | 
                
                 | 
                        self.cur_batch = 0  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    132
                 | 
                                    
                                                     | 
                
                 | 
                        self.num_epochs = 0  | 
            
            
                                                                                                            
                                                                
            
                                    
            
            
                | 
                    133
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    134
                 | 
                                    
                                                     | 
                
                 | 
                    @staticmethod  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    135
                 | 
                                    
                                                     | 
                
                 | 
                    def to_sequence(length, x, y=None, start=None, end=None, with_seq=False):  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    136
                 | 
                                    
                                                     | 
                
                 | 
                        """Turn feature array as a sequence array where each new feature contains seq_len number of original features.  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    137
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    138
                 | 
                                    
                                                     | 
                
                 | 
                        Args:  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    139
                 | 
                                    
                                                     | 
                
                 | 
                            length (:obj:`int`): Length of the sequence.  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    140
                 | 
                                    
                                                     | 
                
                 | 
                            x (:obj:`numpy.ndarray`): Feature array, with shape (num_samples, num_features).  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    141
                 | 
                                    
                                                     | 
                
                 | 
                            y (:obj:`numpy.ndarray`): Label array, with shape (num_samples. num_classes).  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    142
                 | 
                                    
                                                     | 
                
                 | 
                            start (:obj:`int`): Start index.  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    143
                 | 
                                    
                                                     | 
                
                 | 
                            end (:obj:`int`): End index  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    144
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    145
                 | 
                                    
                                                     | 
                
                 | 
                        Returns:  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    146
                 | 
                                    
                                                     | 
                
                 | 
                            (seq_x, seq_y) if y is provided, or seq_x if y is not provided.  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    147
                 | 
                                    
                                                     | 
                
                 | 
                            seq_x is a numpy array of shape (num_samples, seq_len, num_features), and seq_y is a numpy array  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    148
                 | 
                                    
                                                     | 
                
                 | 
                            of shape (num_samples, num_classes).  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    149
                 | 
                                    
                                                     | 
                
                 | 
                            num_samples is bounded by the value of start and end.  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    150
                 | 
                                    
                                                     | 
                
                 | 
                            If start or end are not specified, the code will use the full data provided, so that the  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    151
                 | 
                                    
                                                     | 
                
                 | 
                            array returned has (num_samples - seq_len) of samples.  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    152
                 | 
                                    
                                                     | 
                
                 | 
                        """  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    153
                 | 
                                    
                                                     | 
                
                 | 
                        if start is None or end is None:  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    154
                 | 
                                    
                                                     | 
                
                 | 
                            start = 0  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    155
                 | 
                                    
                                                     | 
                
                 | 
                            end = x.shape[0] - length  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    156
                 | 
                                    
                                                     | 
                
                 | 
                        if (start+length) > x.shape[0] or (end+length) > x.shape[0]:  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    157
                 | 
                                    
                                                     | 
                
                 | 
                            logger.error('start/end out of bound.') | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    158
                 | 
                                    
                                                     | 
                
                 | 
                            return None  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    159
                 | 
                                    
                                                     | 
                
                 | 
                        batch_x = np.zeros((end - start, length, x.shape[1]), np.float32)  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    160
                 | 
                                    
                                                     | 
                
                 | 
                        for i in range(start, end):  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    161
                 | 
                                    
                                                     | 
                
                 | 
                            batch_x[i-start, :, :] = x[i:i + length, :]  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    162
                 | 
                                    
                                                     | 
                
                 | 
                        return_tuple = tuple([batch_x])  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    163
                 | 
                                    
                                                     | 
                
                 | 
                        if y is not None:  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    164
                 | 
                                    
                                                     | 
                
                 | 
                            batch_y = np.zeros((end - start, length, y.shape[1]), np.float32)  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    165
                 | 
                                    
                                                     | 
                
                 | 
                            for i in range(start, end):  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    166
                 | 
                                    
                                                     | 
                
                 | 
                                batch_y[i-start, :, :] = y[i:i + length, :]  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    167
                 | 
                                    
                                                     | 
                
                 | 
                            return_tuple += tuple([batch_y])  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    168
                 | 
                                    
                                                     | 
                
                 | 
                        if with_seq:  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    169
                 | 
                                    
                                                     | 
                
                 | 
                            seq_ar = np.zeros((end - start,), np.float32)  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    170
                 | 
                                    
                                                     | 
                
                 | 
                            seq_ar[:] = length  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    171
                 | 
                                    
                                                     | 
                
                 | 
                            return_tuple += tuple([seq_ar])  | 
            
            
                                                                        
                            
            
                                    
            
            
                | 
                    172
                 | 
                                    
                                                     | 
                
                 | 
                        return return_tuple  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    173
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    174
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    175
                 | 
                                    
                                                     | 
                
                 | 
                class SkipGramInjector:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    176
                 | 
                                    
                                                     | 
                
                 | 
                    """Skip-Gram Batch Injector  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    177
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    178
                 | 
                                    
                                                     | 
                
                 | 
                    It generates a k-skip-2-gram sets based on input sequence  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    179
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    180
                 | 
                                    
                                                     | 
                
                 | 
                    Args:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    181
                 | 
                                    
                                                     | 
                
                 | 
                        data_x (:obj:`np.ndarray`): 1D array of integer index.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    182
                 | 
                                    
                                                     | 
                
                 | 
                        batch_size (:obj:`int`): Size of each batch to be generated.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    183
                 | 
                                    
                                                     | 
                
                 | 
                        num_skips (:obj:`int`): How many times to re-use an input to generate a label.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    184
                 | 
                                    
                                                     | 
                
                 | 
                        skip_window (:obj:`int`): How many items to consider left or right.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    185
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    186
                 | 
                                    
                                                     | 
                
                 | 
                    Attributes:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    187
                 | 
                                    
                                                     | 
                
                 | 
                        data_x (:obj:`np.ndarray`): 1D array of integer index.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    188
                 | 
                                    
                                                     | 
                
                 | 
                        batch_size (:obj:`int`): Size of each batch to be generated.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    189
                 | 
                                    
                                                     | 
                
                 | 
                        num_skips (:obj:`int`): How many times to re-use an input to generate a label.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    190
                 | 
                                    
                                                     | 
                
                 | 
                        skip_window (:obj:`int`): How many items to consider left or right.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    191
                 | 
                                    
                                                     | 
                
                 | 
                        data_index (:obj:`int`): Current index used to generate next batch.  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    192
                 | 
                                    
                                                     | 
                
                 | 
                    """  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    193
                 | 
                                    
                                                     | 
                
                 | 
                    def __init__(self, data_x, batch_size, num_skips, skip_window):  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    194
                 | 
                                    
                                                     | 
                
                 | 
                        assert batch_size % num_skips == 0  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    195
                 | 
                                    
                                                     | 
                
                 | 
                        assert num_skips <= 2 * skip_window  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    196
                 | 
                                    
                                                     | 
                
                 | 
                        self.data_x = data_x  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    197
                 | 
                                    
                                                     | 
                
                 | 
                        self.batch_size = batch_size  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    198
                 | 
                                    
                                                     | 
                
                 | 
                        self.num_skips = num_skips  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    199
                 | 
                                    
                                                     | 
                
                 | 
                        self.skip_window = skip_window  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    200
                 | 
                                    
                                                     | 
                
                 | 
                        self.data_index = 0  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    201
                 | 
                                    
                                                     | 
                
                 | 
                 | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    202
                 | 
                                    
                                                     | 
                
                 | 
                    def next_batch(self):  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    203
                 | 
                                    
                                                     | 
                
                 | 
                        """Get Next Batch  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    204
                 | 
                                    
                                                     | 
                
                 | 
                        """  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    205
                 | 
                                    
                                                     | 
                
                 | 
                        # Initialize batch and label array  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    206
                 | 
                                    
                                                     | 
                
                 | 
                        batch = np.ndarray(shape=(self.batch_size), dtype=np.int32)  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    207
                 | 
                                    
                                                     | 
                
                 | 
                        labels = np.ndarray(shape=(self.batch_size, 1), dtype=np.int32)  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    208
                 | 
                                    
                                                     | 
                
                 | 
                        # span is the size of window we are sampling from  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    209
                 | 
                                    
                                                     | 
                
                 | 
                        span = 2 * self.skip_window + 1  # [ skip_window target skip_window ]  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    210
                 | 
                                    
                                                     | 
                
                 | 
                        # Add data in the buffer to a queue  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    211
                 | 
                                    
                                                     | 
                
                 | 
                        buffer = collections.deque(maxlen=span)  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    212
                 | 
                                    
                                                     | 
                
                 | 
                        for _ in range(span):  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    213
                 | 
                                    
                                                     | 
                
                 | 
                            buffer.append(self.data_x[self.data_index])  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    214
                 | 
                                    
                                                     | 
                
                 | 
                            self.data_index = (self.data_index + 1) % len(self.data_x)  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    215
                 | 
                                    
                                                     | 
                
                 | 
                        # Now, populate the k-skip-2-gram data-label pair with random sampling  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    216
                 | 
                                    
                                                     | 
                
                 | 
                        for i in range(self.batch_size // self.num_skips):  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    217
                 | 
                                    
                                                     | 
                
                 | 
                            target = self.skip_window  # target label at the center of the buffer  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    218
                 | 
                                    
                                                     | 
                
                 | 
                            targets_to_avoid = [self.skip_window]  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    219
                 | 
                                    
                                                     | 
                
                 | 
                            for j in range(self.num_skips):  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    220
                 | 
                                    
                                                     | 
                
                 | 
                                while target in targets_to_avoid:  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    221
                 | 
                                    
                                                     | 
                
                 | 
                                    target = random.randint(0, span - 1)  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    222
                 | 
                                    
                                                     | 
                
                 | 
                                targets_to_avoid.append(target)  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    223
                 | 
                                    
                                                     | 
                
                 | 
                                batch[i * self.num_skips + j] = buffer[self.skip_window]  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    224
                 | 
                                    
                                                     | 
                
                 | 
                                labels[i * self.num_skips + j, 0] = buffer[target]  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    225
                 | 
                                    
                                                     | 
                
                 | 
                            buffer.append(self.data_x[self.data_index])  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    226
                 | 
                                    
                                                     | 
                
                 | 
                            self.data_index = (self.data_index + 1) % len(self.data_x)  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    227
                 | 
                                    
                                                     | 
                
                 | 
                        # Backtrack a little bit to avoid skipping words in the end of a batch  | 
            
            
                                                                                                            
                            
            
                                    
            
            
                | 
                    228
                 | 
                                    
                                                     | 
                
                 | 
                        self.data_index = (self.data_index + len(self.data_x) - span) % len(self.data_x)  | 
            
            
                                                                                                            
                                                                
            
                                    
            
            
                | 
                    229
                 | 
                                    
                                                     | 
                
                 | 
                        return batch, labels  | 
            
            
                                                        
            
                                    
            
            
                | 
                    230
                 | 
                                    
                                                     | 
                
                 | 
                 |