Completed: push to master (394368...090fba) by Raphael (created 01:33)

finish_scan()   Rating: F

Complexity
    Conditions: 81

Size
    Total Lines: 409

Duplication
    Lines: 0
    Ratio: 0 %

Importance
    Changes: 1
    Bugs: 0
    Features: 0

Metric                    Value
cc   (conditions)            81
c    (changes)                1
b    (bugs)                   0
f    (features)               0
dl   (duplicated lines)       0
loc  (total lines)          409
rs                            2

1 Method

Rating   Name                  Duplication   Size   Complexity
B        remove_dimensions()   0             14     5

How to fix

Long Method

Small methods make your code easier to understand, particularly when combined with a good name. Besides, if your method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a good sign that the commented part should be extracted into a new method, with the comment serving as a starting point for naming the new method.

Commonly applied refactorings include Extract Method, as sketched below.
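
For instance, the try/except block that transfers a test value onto a freshly created variable appears three times in the listing at the end of this report. A minimal Extract Method sketch, assuming a hypothetical helper named _transfer_test_value, could read:

import logging

from theano import config, gof

_logger = logging.getLogger('scan_dummy_args')


def _transfer_test_value(src, dst):
    # Hypothetical helper: one home for the test-value transfer block
    # that the listing repeats for sequences, initial states and taps.
    if config.compute_test_value == 'off':
        return
    try:
        dst.tag.test_value = gof.Op._get_test_value(src)
    except AttributeError as e:
        if config.compute_test_value != 'ignore':
            # No need to print a warning or raise an error now;
            # that happens when the compiled function is called.
            _logger.info('Cannot compute test value for the inner '
                         'function of scan, input value missing %s', e)

Each duplicated site then shrinks to a single call such as _transfer_test_value(_seq_val_slice, nw_slice).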

Complexity

Complex code like finish_scan() often does a lot of different things. To break such a unit down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for fields or methods that share the same prefixes or suffixes.

Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
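
In the listing below, for instance, the mit_sot_* and sit_sot_* locals form families of parallel lists that always travel together with a counter. A minimal Extract Class sketch, assuming a hypothetical TapGroup class, might group one such family:

from collections import OrderedDict


class TapGroup(object):
    # Hypothetical extract-class target: one instance per tap family
    # (mit_sot, sit_sot, ...), replacing seven parallel locals.
    def __init__(self):
        self.scan_inputs = []
        self.inner_inputs = []
        self.inner_slices = []
        self.inner_outputs = []
        self.return_steps = OrderedDict()
        self.tap_array = []
        self.right_order = []

    @property
    def count(self):
        # Replaces the hand-maintained n_mit_sot / n_sit_sot counters.
        return len(self.right_order)

With mit_sot = TapGroup(), each mit_sot_scan_inputs.append(...) becomes mit_sot.scan_inputs.append(...) and n_mit_sot becomes mit_sot.count, so the related state stays in one place.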

#!/usr/bin/env python
# -*- coding: utf-8 -*-


"""
The code in this file is from the scan function in theano.
Never modify this file directly.
"""


import logging
import numpy
import warnings

from theano.compat import ifilter, izip
from six import iteritems, integer_types
from six.moves import xrange
from theano.compile import SharedVariable, function
from theano import compile
from theano import gof
from theano.tensor import opt
from theano import tensor
from theano import config
from theano.updates import OrderedUpdates
from theano.compile import ops
from theano.compat import OrderedDict

from theano.scan_module import scan_op
from theano.scan_module import scan_utils
from theano.scan_module.scan_utils import safe_new, traverse

# Logging function for sending warning or info
_logger = logging.getLogger('scan_dummy_args')


def get_dummy_args(sequences=None,
                   outputs_info=None,
                   non_sequences=None,
                   n_steps=None,
                   truncate_gradient=-1,
                   go_backwards=False,
                   mode=None,
                   name=None,
                   profile=False,
                   allow_gc=None,
                   strict=False):
    ################################################################## P1>
    # check if inputs are just single variables instead of lists
    def wrap_into_list(x):
        """
        Wrap the input into a list if it is not already a list.

        """
        if x is None:
            return []
        elif not isinstance(x, (list, tuple)):
            return [x]
        else:
            return list(x)

    seqs = wrap_into_list(sequences)
    outs_info = wrap_into_list(outputs_info)

    # Make sure we get rid of numpy arrays or ints or anything like that
    # passed as inputs to scan
    non_seqs = []
    for elem in wrap_into_list(non_sequences):
        if not isinstance(elem, gof.Variable):
            non_seqs.append(tensor.as_tensor_variable(elem))
        else:
            non_seqs.append(elem)

    # If we provided a known number of steps (before compilation)
    # and if that number is 1 or -1, then we can skip the Scan Op,
    # and just apply the inner function once.
    # To do that, we check here the nature of n_steps.
    n_fixed_steps = None

    if isinstance(n_steps, (float, integer_types)):
        n_fixed_steps = int(n_steps)
    else:
        try:
            n_fixed_steps = opt.get_scalar_constant_value(n_steps)
        except tensor.basic.NotScalarConstantError:
            n_fixed_steps = None

    # Check n_steps is an int
    if (hasattr(n_steps, 'dtype') and
        str(n_steps.dtype)[:3] not in ('uin', 'int')):
        raise ValueError(' n_steps must be an int. dtype provided '
                         'is %s' % n_steps.dtype)

    # compute number of sequences and number of outputs
    n_seqs = len(seqs)
    n_outs = len(outs_info)

    return_steps = OrderedDict()
    # wrap sequences in a dictionary if they are not already dictionaries
    for i in xrange(n_seqs):
        if not isinstance(seqs[i], dict):
            seqs[i] = OrderedDict([('input', seqs[i]), ('taps', [0])])
        elif seqs[i].get('taps', None) is not None:
            seqs[i]['taps'] = wrap_into_list(seqs[i]['taps'])
        elif seqs[i].get('taps', None) is None:
            # seqs dictionary does not have the ``taps`` key
            seqs[i]['taps'] = [0]

    # wrap outputs info in a dictionary if they are not already in one
    for i in xrange(n_outs):
        if outs_info[i] is not None:
            if isinstance(outs_info[i], dict):
                # DEPRECATED :
                if outs_info[i].get('return_steps', None) is not None:
                    raise ValueError(
                            "Using `return_steps` has been deprecated. "
                            "Simply select the entries you need using a "
                            "subtensor. Scan will optimize memory "
                            "consumption, so do not worry about that.")
                # END

            if not isinstance(outs_info[i], dict):
                # by default any output has a tap value of -1
                outs_info[i] = OrderedDict([('initial', outs_info[i]), ('taps', [-1])])
            elif (outs_info[i].get('initial', None) is None and
                    outs_info[i].get('taps', None) is not None):
                # ^ no initial state but taps provided
                raise ValueError(('If you are using slices of an output '
                                  'you need to provide an initial state '
                                  'for it'), outs_info[i])
            elif (outs_info[i].get('initial', None) is not None and
                  outs_info[i].get('taps', None) is None):
                # ^ initial state but taps not provided
                if 'taps' in outs_info[i]:
                    # ^ explicitly provided a None for taps
                    _logger.warning('Output %s (index %d) has an initial '
                            'state but taps is explicitly set to None ',
                             getattr(outs_info[i]['initial'], 'name', 'None'),
                             i)
                outs_info[i]['taps'] = [-1]
        else:
            # if a None is provided as the output info we replace it
            # with an empty OrderedDict() to simplify handling
            outs_info[i] = OrderedDict()

    ##
    # Step 2. Generate inputs and outputs of the inner functions
    # for compiling a dummy function (Iteration #1)
    ##

    # create theano inputs for the recursive function
    # note : this is a first batch of possible inputs that will
    #        be compiled in a dummy function; we use this dummy
    #        function to detect shared variables and their updates
    #        and to construct a new and complete list of inputs and
    #        outputs

    n_seqs = 0
    scan_seqs = []     # Variables passed as inputs to the scan op
    inner_seqs = []    # Variables passed as inputs to the inner function
    inner_slices = []  # Actual slices if scan is removed from the picture
    # go through sequences picking up time slices as needed
    for i, seq in enumerate(seqs):
        # Note that you can have something like no taps for
        # a sequence, though it is highly unlikely in practice
        if 'taps' in seq:
            # go through the indicated slice
            mintap = numpy.min(seq['taps'])
            maxtap = numpy.max(seq['taps'])
            for k in seq['taps']:
                # create one slice of the input
                # Later on, if we decide not to use scan because we are
                # going for just one step, it makes things easier if we
                # compute the correct outputs here. This way we can use
                # the output of the lambda expression directly to replace
                # the output of scan.

                # If not we need to use copies, that will be replaced at
                # each frame by the corresponding slice
                actual_slice = seq['input'][k - mintap]
                _seq_val = tensor.as_tensor_variable(seq['input'])
                _seq_val_slice = _seq_val[k - mintap]
                nw_slice = _seq_val_slice.type()

                # Try to transfer test_value to the new variable
                if config.compute_test_value != 'off':
                    try:
                        nw_slice.tag.test_value = gof.Op._get_test_value(
                            _seq_val_slice)
                    except AttributeError as e:
                        if config.compute_test_value != 'ignore':
                            # No need to print a warning or raise an error now,
                            # it will be done when fn will be called.
                            _logger.info(('Cannot compute test value for '
                                'the inner function of scan, input value '
                                'missing %s'), e)

                # Add names to slices for debugging and pretty printing ..
                # that is if the input already has a name
                if getattr(seq['input'], 'name', None) is not None:
                    if k > 0:
                        nw_name = seq['input'].name + '[t+%d]' % k
                    elif k == 0:
                        nw_name = seq['input'].name + '[t]'
                    else:
                        nw_name = seq['input'].name + '[t%d]' % k
                    nw_slice.name = nw_name

                # We cut the sequence such that seq[i] corresponds to
                # seq[i-k]. For the purposes of cutting the sequences, we
                # need to pretend tap 0 is used to avoid cutting the sequences
                # too long if the taps are all lower or all higher than 0.
                maxtap_proxy = max(maxtap, 0)
                mintap_proxy = min(mintap, 0)
                start = (k - mintap_proxy)
                if k == maxtap_proxy:
                    nw_seq = seq['input'][start:]
                else:
                    end = -(maxtap_proxy - k)
                    nw_seq = seq['input'][start:end]

                if go_backwards:
                    nw_seq = nw_seq[::-1]

                scan_seqs.append(nw_seq)
                inner_seqs.append(nw_slice)
                inner_slices.append(actual_slice)
                n_seqs += 1

    # Since we've added all sequences, we now need to level them up based on
    # n_steps or their different shapes
    lengths_vec = []
    for seq in scan_seqs:
        lengths_vec.append(seq.shape[0])

    if not scan_utils.isNaN_or_Inf_or_None(n_steps):
        # ^ n_steps should also be considered
        lengths_vec.append(tensor.as_tensor(n_steps))

    if len(lengths_vec) == 0:
        # ^ No information about the number of steps
        raise ValueError('No information about the number of steps '
                         'provided. Either provide a value for '
                         'n_steps argument of scan or provide an input '
                         'sequence')

    # If the user has provided the number of steps, do that regardless (and
    # raise an error if the sequences are not long enough)
    if scan_utils.isNaN_or_Inf_or_None(n_steps):
        actual_n_steps = lengths_vec[0]
        for contestant in lengths_vec[1:]:
            actual_n_steps = tensor.minimum(actual_n_steps, contestant)
    else:
        actual_n_steps = tensor.as_tensor(n_steps)

    # Add names -- it helps a lot when debugging

    for (nw_seq, seq) in zip(scan_seqs, seqs):
        if getattr(seq['input'], 'name', None) is not None:
            nw_seq.name = seq['input'].name + '[%d:]' % k

    scan_seqs = [seq[:actual_n_steps] for seq in scan_seqs]
    # Conventions :
    #   mit_mot = multiple input taps, multiple output taps ( only provided
    #             by the gradient function )
    #   mit_sot = multiple input taps, single output tap (t + 0)
    #   sit_sot = single input tap, single output tap (t + 0)
    #   nit_sot = no input tap, single output tap (t + 0)

    # MIT_MOT -- not provided by the user, only by the grad function
    n_mit_mot = 0
    n_mit_mot_outs = 0
    mit_mot_scan_inputs = []
    mit_mot_inner_inputs = []
    mit_mot_inner_outputs = []
    mit_mot_out_slices = []
    mit_mot_rightOrder = []

    # MIT_SOT -- provided by the user
    n_mit_sot = 0
    mit_sot_scan_inputs = []
    mit_sot_inner_inputs = []
    mit_sot_inner_slices = []
    mit_sot_inner_outputs = []
    mit_sot_return_steps = OrderedDict()
    mit_sot_tap_array = []
    mit_sot_rightOrder = []

    # SIT_SOT -- also provided by the user
    n_sit_sot = 0
    sit_sot_scan_inputs = []
    sit_sot_inner_inputs = []
    sit_sot_inner_slices = []
    sit_sot_inner_outputs = []
    sit_sot_return_steps = OrderedDict()
    sit_sot_rightOrder = []

    # go through outputs picking up time slices as needed
    for i, init_out in enumerate(outs_info):
        # Note that our convention dictates that if an output uses
        # just the previous time step, as an initial state we will only
        # provide a tensor of the same dimension as one time step; this
        # makes code much cleaner for those who do not use taps. Otherwise
        # they would always have had to shape_padleft the initial state ..
        # which is ugly
        if init_out.get('taps', None) == [-1]:

            actual_arg = init_out['initial']
            if not isinstance(actual_arg, tensor.Variable):
                actual_arg = tensor.as_tensor_variable(actual_arg)
            arg = safe_new(actual_arg)
            if isinstance(arg, tensor.Constant):
                # safe_new returns a clone of the constants, but that is not
                # what we need for initial states
                arg = arg.type()

            # Try to transfer test_value to the new variable
            if config.compute_test_value != 'off':
                try:
                    arg.tag.test_value = gof.Op._get_test_value(actual_arg)
                except AttributeError as e:
                    if config.compute_test_value != 'ignore':
                        # No need to print a warning or raise an error now,
                        # it will be done when fn will be called.
                        _logger.info(('Cannot compute test value for the '
                            'inner function of scan, input value missing %s'),
                                     e)

            if getattr(init_out['initial'], 'name', None) is not None:
                arg.name = init_out['initial'].name + '[t-1]'

            # We now need to allocate space for storing the output and copy
            # the initial state over. We do this using the expand function
            # defined in scan utils
            sit_sot_scan_inputs.append(
                scan_utils.expand_empty(
                    tensor.unbroadcast(
                        tensor.shape_padleft(actual_arg), 0),
                    actual_n_steps
                ))

            sit_sot_inner_slices.append(actual_arg)
            if i in return_steps:
                sit_sot_return_steps[n_sit_sot] = return_steps[i]
            sit_sot_inner_inputs.append(arg)
            sit_sot_rightOrder.append(i)
            n_sit_sot += 1

        elif init_out.get('taps', None):

            if numpy.any(numpy.array(init_out.get('taps', [])) > 0):
                # Make sure we do not have requests for future values of a
                # sequence; we cannot provide such values
                raise ValueError('Can not use future taps of outputs',
                                    init_out)
            # go through the taps
            mintap = abs(numpy.min(init_out['taps']))
            mit_sot_tap_array.append(init_out['taps'])
            idx_offset = abs(numpy.min(init_out['taps']))
            # Sequence
            mit_sot_scan_inputs.append(
                scan_utils.expand_empty(init_out['initial'][:mintap],
                                        actual_n_steps))

            if i in return_steps:
                mit_sot_return_steps[n_mit_sot] = return_steps[i]
            mit_sot_rightOrder.append(i)
            n_mit_sot += 1
            for k in init_out['taps']:
                # create a new slice
                actual_nw_slice = init_out['initial'][k + mintap]
                _init_out_var = tensor.as_tensor_variable(init_out['initial'])
                _init_out_var_slice = _init_out_var[k + mintap]
                nw_slice = _init_out_var_slice.type()

                # Try to transfer test_value to the new variable
                if config.compute_test_value != 'off':
                    try:
                        nw_slice.tag.test_value = gof.Op._get_test_value(
                            _init_out_var_slice)
                    except AttributeError as e:
                        if config.compute_test_value != 'ignore':
                            # No need to print a warning or raise an error now,
                            # it will be done when fn will be called.
                            _logger.info(('Cannot compute test value for '
                                'the inner function of scan, input value '
                                'missing. %s'), e)

                # give it a name for debugging and pretty printing
                if getattr(init_out['initial'], 'name', None) is not None:
                    if k > 0:
                        nw_slice.name = (init_out['initial'].name +
                                            '[t+%d]' % k)
                    elif k == 0:
                        nw_slice.name = init_out['initial'].name + '[t]'
                    else:
                        nw_slice.name = (init_out['initial'].name +
                                            '[t%d]' % k)
                mit_sot_inner_inputs.append(nw_slice)
                mit_sot_inner_slices.append(actual_nw_slice)
        # NOTE: there is another case, in which we do not want to provide
        #      any previous value of the output to the inner function (i.e.
        #      a map); in that case we do not have to do anything ..

    # Re-order args
    max_mit_sot = numpy.max([-1] + mit_sot_rightOrder) + 1
    max_sit_sot = numpy.max([-1] + sit_sot_rightOrder) + 1
    n_elems = numpy.max([max_mit_sot, max_sit_sot])
    _ordered_args = [[] for x in xrange(n_elems)]
    offset = 0
    for idx in xrange(n_mit_sot):
        n_inputs = len(mit_sot_tap_array[idx])
        if n_fixed_steps in [1, -1]:
            _ordered_args[mit_sot_rightOrder[idx]] = \
                            mit_sot_inner_slices[offset:offset + n_inputs]
        else:
            _ordered_args[mit_sot_rightOrder[idx]] = \
                            mit_sot_inner_inputs[offset:offset + n_inputs]
        offset += n_inputs

    for idx in xrange(n_sit_sot):
        if n_fixed_steps in [1, -1]:
            _ordered_args[sit_sot_rightOrder[idx]] = \
                                        [sit_sot_inner_slices[idx]]
        else:
            _ordered_args[sit_sot_rightOrder[idx]] = \
                                        [sit_sot_inner_inputs[idx]]

    ordered_args = []
    for ls in _ordered_args:
        ordered_args += ls
    if n_fixed_steps in [1, -1]:
        args = (inner_slices +
                ordered_args +
                non_seqs)

    else:
        args = (inner_seqs +
                ordered_args +
                non_seqs)

    # add only the non-shared variables and non-constants to the arguments of
    # the dummy function [ a function should not get shared variables or
    # constants as input ]
    dummy_args = [arg for arg in args
                  if (not isinstance(arg, SharedVariable) and
                      not isinstance(arg, tensor.Constant))]
    ################################################################## P1<
    return dummy_args, locals()


def finish_scan(fn_outputs, local_vars):

    n_fixed_steps = local_vars["n_fixed_steps"]
    return_steps = local_vars["return_steps"]
    non_seqs = local_vars["non_seqs"]
    dummy_args = local_vars["dummy_args"]
    args = local_vars["args"]
    outs_info = local_vars["outs_info"]
    n_outs = local_vars["n_outs"]
    mit_sot_inner_outputs = local_vars["mit_sot_inner_outputs"]
    sit_sot_inner_outputs = local_vars["sit_sot_inner_outputs"]
    sit_sot_scan_inputs = local_vars["sit_sot_scan_inputs"]
    sit_sot_inner_inputs = local_vars["sit_sot_inner_inputs"]
    actual_n_steps = local_vars["actual_n_steps"]
    sit_sot_rightOrder = local_vars["sit_sot_rightOrder"]
    strict = local_vars["strict"]
    non_sequences = local_vars["non_sequences"]
    inner_seqs = local_vars["inner_seqs"]
    mit_mot_inner_inputs = local_vars["mit_mot_inner_inputs"]
    mit_sot_inner_inputs = local_vars["mit_sot_inner_inputs"]
    mit_mot_inner_outputs = local_vars["mit_mot_inner_outputs"]
    mit_sot_tap_array = local_vars["mit_sot_tap_array"]
    allow_gc = local_vars["allow_gc"]
    n_seqs = local_vars["n_seqs"]
    n_mit_mot_outs = local_vars["n_mit_mot_outs"]
    mit_mot_out_slices = local_vars["mit_mot_out_slices"]
    truncate_gradient = local_vars["truncate_gradient"]
    name = local_vars["name"]
    mode = local_vars["mode"]
    profile = local_vars["profile"]
    scan_seqs = local_vars["scan_seqs"]
    mit_mot_scan_inputs = local_vars["mit_mot_scan_inputs"]
    mit_sot_scan_inputs = local_vars["mit_sot_scan_inputs"]
    n_mit_mot = local_vars["n_mit_mot"]
    mit_sot_return_steps = local_vars["mit_sot_return_steps"]
    n_mit_sot = local_vars["n_mit_sot"]
    sit_sot_return_steps = local_vars["sit_sot_return_steps"]
    mit_sot_rightOrder = local_vars["mit_sot_rightOrder"]

    condition, outputs, updates = scan_utils.get_updates_and_outputs(fn_outputs)
    ################################################################## P2>
    if condition is not None:
        as_while = True
    else:
        as_while = False
    ##
    # Step 3. Check if we actually need scan and remove it if we don't
    ##

    if n_fixed_steps in [1, -1]:
        # We do not need to use the scan op anymore, so we can just return
        # the outputs and updates we have
        if condition is not None:
            _logger.warning(('When the number of steps is fixed and equal '
                    'to 1, the provided stopping condition, ',
                    str(condition), ' is ignored'))

        for pos, inner_out in enumerate(outputs):
            # we need to see if we need to pad our sequences with an
            # unbroadcastable dimension; case example : we return an
            # output for which we want all intermediate. If n_steps is 1
            # then, if we return the output as given by the inner function
            # this will represent only a slice and it will have one
            # dimension less.
            if (isinstance(inner_out.type, tensor.TensorType) and
                return_steps.get(pos, 0) != 1):
                outputs[pos] = tensor.unbroadcast(
                    tensor.shape_padleft(inner_out), 0)
        if len(outputs) == 1:
            outputs = outputs[0]

        return (outputs, updates)

    ##
    # Step 4. Compile the dummy function
    ##

    # We can now compile a dummy function just to see what shared variables
    # we have and what their update rules are (note that the user has
    # the option not to pass the shared variable to scan, so we need to
    # pick them manually and add them to scan).
    # Make the compilation as fast as possible by not applying any
    # optimization or conversion to C [ note this region is not important
    # for performance, so we can do stuff as unoptimal as we wish ]

    # extract still missing inputs (there still might be some) and add them
    # as non sequences at the end of our args
    fake_nonseqs = [x.type() for x in non_seqs]
    fake_outputs = scan_utils.clone(outputs,
                                    replace=OrderedDict(izip(non_seqs,
                                                             fake_nonseqs)))
    all_inputs = ifilter(
        lambda x: (isinstance(x, gof.Variable) and
                   not isinstance(x, SharedVariable) and
                   not isinstance(x, gof.Constant)),
        gof.graph.inputs(fake_outputs))
    extra_inputs = [x for x in all_inputs if x not in args + fake_nonseqs]
    non_seqs += extra_inputs
    # Note we do not use all_inputs directly since the order of variables
    # in args is quite important
    dummy_args += extra_inputs

    dummy_outs = outputs
    if condition is not None:
        dummy_outs.append(condition)
    dummy_f = function(dummy_args,
                       dummy_outs,
                       updates=updates,
                       mode=compile.mode.Mode(linker='py',
                                              optimizer=None),
                       on_unused_input='ignore',
                       profile=False)

    ##
    # Step 5. Re-arrange inputs of scan into a more strict order
    ##

    # Step 5.0 Check the outputs of the dummy function to see if they
    # match with user provided data

    # if the number of outputs to the function does not match the number of
    # assumed outputs until now (provided by the user) there can be
    # only one explanation: no information is provided for any of the
    # outputs (i.e. we are dealing with a map)
    tmp_dummy_f_outs = len(dummy_f.maker.outputs)
    if as_while:
        tmp_dummy_f_outs -= 1
    if not (tmp_dummy_f_outs == n_outs or outs_info == []):
        raise ValueError('Please provide None as outputs_info for '
                         'any output that does not feed back into '
                         'scan (i.e. it behaves like a map) ')

    if outs_info == []:
        n_outs = len(dummy_f.maker.outputs)
        if as_while:
            n_outs = n_outs - 1
        outs_info = [OrderedDict() for x in xrange(n_outs)]

    # Step 5.1 Outputs with taps different than -1

    for i, out in enumerate(outs_info):
        if 'taps' in out and out['taps'] != [-1]:
            mit_sot_inner_outputs.append(outputs[i])

    # Step 5.2 Outputs with tap equal to -1
    for i, out in enumerate(outs_info):
        if 'taps' in out and out['taps'] == [-1]:
            sit_sot_inner_outputs.append(outputs[i])

    # Step 5.3 Outputs that correspond to update rules of shared variables
    givens = OrderedDict()
    n_shared_outs = 0
    shared_scan_inputs = []
    shared_inner_inputs = []
    shared_inner_outputs = []
    sit_sot_shared = []
    for input in dummy_f.maker.expanded_inputs:
        if isinstance(input.variable, SharedVariable) and input.update:
            new_var = safe_new(input.variable)
            if getattr(input.variable, 'name', None) is not None:
                new_var.name = input.variable.name + '_copy'
            if isinstance(new_var.type, ops.expandable_types):
                sit_sot_inner_inputs.append(new_var)
                sit_sot_scan_inputs.append(
                    scan_utils.expand_empty(
                        tensor.unbroadcast(
                            tensor.shape_padleft(input.variable), 0),
                        actual_n_steps))
                tensor_update = tensor.as_tensor_variable(input.update)
                sit_sot_inner_outputs.append(tensor_update)
                # Note that pos is not a negative index. The sign of pos is
                # used as a flag to indicate if this output should be part
                # of the update rules or part of the standard outputs of
                # scan. If `pos` is positive then it corresponds to the
                # standard outputs of scan and it refers to output of index
                # `pos`. If `pos` is negative then it corresponds to update
                # rules of scan and it refers to the update rule of index
                # -1 - `pos`.
                sit_sot_rightOrder.append(-1 - len(sit_sot_shared))
                sit_sot_shared.append(input.variable)
                givens[input.variable] = new_var

            else:
                shared_inner_inputs.append(new_var)
                shared_scan_inputs.append(input.variable)
                shared_inner_outputs.append(input.update)
                givens[input.variable] = new_var
                n_shared_outs += 1
    n_sit_sot = len(sit_sot_inner_inputs)
    # Step 5.4 Outputs with no taps used in the input
    n_nit_sot = 0
    nit_sot_inner_outputs = []
    nit_sot_return_steps = OrderedDict()
    nit_sot_rightOrder = []
    for i, out in enumerate(outs_info):
        if not 'taps' in out:
            nit_sot_inner_outputs.append(outputs[i])
            if i in return_steps:
                nit_sot_return_steps[n_nit_sot] = return_steps[i]
            nit_sot_rightOrder.append(i)
            n_nit_sot += 1

    # Step 5.5 all other arguments including extra inputs
    other_scan_args = []
    other_inner_args = []

    other_scan_args += [arg for arg in non_seqs
                        if (not isinstance(arg, SharedVariable) and
                            not isinstance(arg, tensor.Constant))]

    # Step 5.6 all shared variables with no update rules
    other_inner_args += [safe_new(arg, '_copy') for arg in non_seqs
                         if (not isinstance(arg, SharedVariable) and
                             not isinstance(arg, tensor.Constant))]

    givens.update(OrderedDict(izip(other_scan_args, other_inner_args)))

    if strict:
        non_seqs_set = set(non_sequences if non_sequences is not None else [])

        other_shared_scan_args = [arg.variable for arg
                            in dummy_f.maker.expanded_inputs
                            if (isinstance(arg.variable, SharedVariable) and
                                not arg.update and
                                arg.variable in non_seqs_set)]
        other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg
                            in dummy_f.maker.expanded_inputs
                            if (isinstance(arg.variable, SharedVariable) and
                                not arg.update and
                                arg.variable in non_seqs_set)]
    else:
        other_shared_scan_args = [arg.variable for arg
                            in dummy_f.maker.expanded_inputs
                            if (isinstance(arg.variable, SharedVariable) and
                                not arg.update)]
        other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg
                            in dummy_f.maker.expanded_inputs
                            if (isinstance(arg.variable, SharedVariable) and
                                not arg.update)]
    givens.update(OrderedDict(izip(other_shared_scan_args,
                                   other_shared_inner_args)))

    ##
    # Step 6. Re-order the outputs and clone them, replacing things
    # using the givens
    ##
    inner_inputs = (inner_seqs +
                    mit_mot_inner_inputs +
                    mit_sot_inner_inputs +
                    sit_sot_inner_inputs +
                    shared_inner_inputs +
                    other_shared_inner_args +
                    other_inner_args)

    inner_outs = (mit_mot_inner_outputs +
                  mit_sot_inner_outputs +
                  sit_sot_inner_outputs +
                  nit_sot_inner_outputs +
                  shared_inner_outputs)
    if condition is not None:
        inner_outs.append(condition)
    # Cuda and Gpuarray are imported here, instead of being imported on top
    # of the file, because that would force on the user some dependencies
    # that we might not want. Currently we are working on removing the
    # dependencies on sandbox code completely.
    from theano.sandbox import cuda, gpuarray
    if cuda.cuda_available or gpuarray.pygpu_activated:
        # very often we end up in this situation when we want to
        # replace w with w_copy, where w is a GPU variable
        # and w_copy is TensorType. This is caused because shared
        # variables are put on GPU right away >:| ,
        new_givens = OrderedDict()

        for w, w_copy in iteritems(givens):
            if ((isinstance(w.type, cuda.CudaNdarrayType) or
                 isinstance(w.type, gpuarray.GpuArrayType)) and
                isinstance(w_copy.type, tensor.TensorType)):
                for o in inner_outs:
                    new_givens = traverse(o, w, w_copy, new_givens)
            else:
                new_givens[w] = w_copy
    else:
        new_givens = givens

    new_outs = scan_utils.clone(inner_outs, replace=new_givens)

    ##
    # Step 7. Create the Scan Op
    ##

    tap_array = mit_sot_tap_array + [[-1] for x in xrange(n_sit_sot)]
    if allow_gc is None:
        allow_gc = config.scan.allow_gc
    info = OrderedDict()

    info['tap_array'] = tap_array
    info['n_seqs'] = n_seqs
    info['n_mit_mot'] = n_mit_mot
    info['n_mit_mot_outs'] = n_mit_mot_outs
    info['mit_mot_out_slices'] = mit_mot_out_slices
    info['n_mit_sot'] = n_mit_sot
    info['n_sit_sot'] = n_sit_sot
    info['n_shared_outs'] = n_shared_outs
    info['n_nit_sot'] = n_nit_sot
    info['truncate_gradient'] = truncate_gradient
    info['name'] = name
    info['mode'] = mode
    info['destroy_map'] = OrderedDict()
    info['gpu'] = False
    info['as_while'] = as_while
    info['profile'] = profile
    info['allow_gc'] = allow_gc
    info['strict'] = strict

    local_op = scan_op.Scan(inner_inputs, new_outs, info)

    ##
    # Step 8. Compute the outputs using the scan op
    ##
    _scan_inputs = (scan_seqs +
                    mit_mot_scan_inputs +
                    mit_sot_scan_inputs +
                    sit_sot_scan_inputs +
                    shared_scan_inputs +
                    [actual_n_steps for x in xrange(n_nit_sot)] +
                    other_shared_scan_args +
                    other_scan_args)

    scan_inputs = []
    for arg in [actual_n_steps] + _scan_inputs:
        try:
            arg = tensor.as_tensor_variable(arg)
        except TypeError:
            # This happens e.g. for random states, but it is a good way
            # to make sure no input is a cuda ndarray
            pass
        scan_inputs += [arg]
    scan_outs = local_op(*scan_inputs)
    if type(scan_outs) not in (list, tuple):
        scan_outs = [scan_outs]
    ##
    # Step 9. Figure out which outs are update rules for shared variables
    # and so on ...
    ##

    update_map = OrderedUpdates()

    def remove_dimensions(outs, steps_return, offsets=None):
        out_ls = []
        for idx, out in enumerate(outs):
            if idx in steps_return:
                if steps_return[idx] > 1:
                    out_ls.append(out[-steps_return[idx]:])
                else:
                    out_ls.append(out[-1])
            else:
                if offsets is None:
                    out_ls.append(out)
                else:
                    out_ls.append(out[offsets[idx]:])
        return out_ls

    offset = n_mit_mot
    offsets = [abs(numpy.min(x)) for x in mit_sot_tap_array]
    mit_sot_outs = remove_dimensions(
        scan_outs[offset:offset + n_mit_sot],
        mit_sot_return_steps,
        offsets)

    offset += n_mit_sot
    offsets = [1 for x in xrange(n_sit_sot)]
    sit_sot_outs = remove_dimensions(
        scan_outs[offset:offset + n_sit_sot],
        sit_sot_return_steps,
        offsets)

    offset += n_sit_sot
    nit_sot_outs = remove_dimensions(
        scan_outs[offset:offset + n_nit_sot],
        nit_sot_return_steps)

    offset += n_nit_sot
    for idx, update_rule in enumerate(
                scan_outs[offset:offset + n_shared_outs]):
        update_map[shared_scan_inputs[idx]] = update_rule

    _scan_out_list = (mit_sot_outs +
                      sit_sot_outs +
                      nit_sot_outs)
    # Step 10. We need to reorder the outputs to be in the order expected by
    # the user
    rightOrder = (mit_sot_rightOrder +
                  sit_sot_rightOrder +
                  nit_sot_rightOrder)
    scan_out_list = [None] * len(rightOrder)
    for idx, pos in enumerate(rightOrder):
        if pos >= 0:
            scan_out_list[pos] = _scan_out_list[idx]
        else:
            # Note that pos is not a negative index. The sign of pos is used
            # as a flag to indicate if this output should be part of the
            # update rules or part of the standard outputs of scan.
            # If `pos` is positive then it corresponds to the standard
            # outputs of scan and it refers to output of index `pos`. If
            # `pos` is negative then it corresponds to update rules of scan
            # and it refers to the update rule of index -1 - `pos`.
            update_map[sit_sot_shared[abs(pos) - 1]] = _scan_out_list[idx][-1]
    scan_out_list = [x for x in scan_out_list if x is not None]
    ################################################################## P2<
    return (scan_out_list, update_map)