Conditions | 81 |
Total Lines | 409 |
Lines | 0 |
Ratio | 0 % |
Changes | 1 | ||
Bugs | 0 | Features | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
Commonly applied refactorings include:
If many parameters/temporary variables are present:
Complex functions like finish_scan() often do a lot of different things. To break such a function down, we need to identify a cohesive component within it. A common approach to find such a component is to look for variables/statements that share the same prefixes, or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
1 | #!/usr/bin/env python |
||
452 | def finish_scan(fn_outputs, local_vars): |
||
453 | |||
454 | n_fixed_steps = local_vars["n_fixed_steps"] |
||
455 | return_steps = local_vars["return_steps"] |
||
456 | non_seqs = local_vars["non_seqs"] |
||
457 | dummy_args = local_vars["dummy_args"] |
||
458 | args = local_vars["args"] |
||
459 | outs_info = local_vars["outs_info"] |
||
460 | n_outs = local_vars["n_outs"] |
||
461 | mit_sot_inner_outputs = local_vars["mit_sot_inner_outputs"] |
||
462 | sit_sot_inner_outputs = local_vars["sit_sot_inner_outputs"] |
||
463 | sit_sot_scan_inputs = local_vars["sit_sot_scan_inputs"] |
||
464 | sit_sot_inner_inputs = local_vars["sit_sot_inner_inputs"] |
||
465 | actual_n_steps = local_vars["actual_n_steps"] |
||
466 | sit_sot_rightOrder = local_vars["sit_sot_rightOrder"] |
||
467 | strict = local_vars["strict"] |
||
468 | non_sequences = local_vars["non_sequences"] |
||
469 | inner_seqs = local_vars["inner_seqs"] |
||
470 | mit_mot_inner_inputs = local_vars["mit_mot_inner_inputs"] |
||
471 | mit_sot_inner_inputs = local_vars["mit_sot_inner_inputs"] |
||
472 | mit_mot_inner_outputs = local_vars["mit_mot_inner_outputs"] |
||
473 | mit_sot_tap_array = local_vars["mit_sot_tap_array"] |
||
474 | allow_gc = local_vars["allow_gc"] |
||
475 | n_seqs = local_vars["n_seqs"] |
||
476 | n_mit_mot_outs = local_vars["n_mit_mot_outs"] |
||
477 | mit_mot_out_slices = local_vars["mit_mot_out_slices"] |
||
478 | truncate_gradient = local_vars["truncate_gradient"] |
||
479 | name = local_vars["name"] |
||
480 | mode = local_vars["mode"] |
||
481 | profile = local_vars["profile"] |
||
482 | scan_seqs = local_vars["scan_seqs"] |
||
483 | mit_mot_scan_inputs = local_vars["mit_mot_scan_inputs"] |
||
484 | mit_sot_scan_inputs = local_vars["mit_sot_scan_inputs"] |
||
485 | n_mit_mot = local_vars["n_mit_mot"] |
||
486 | mit_sot_return_steps = local_vars["mit_sot_return_steps"] |
||
487 | n_mit_sot = local_vars["n_mit_sot"] |
||
488 | sit_sot_return_steps = local_vars["sit_sot_return_steps"] |
||
489 | mit_sot_rightOrder = local_vars["mit_sot_rightOrder"] |
||
490 | |||
491 | condition, outputs, updates = scan_utils.get_updates_and_outputs(fn_outputs) |
||
492 | ################################################################## P2> |
||
493 | if condition is not None: |
||
494 | as_while = True |
||
495 | else: |
||
496 | as_while = False |
||
497 | ## |
||
498 | # Step 3. Check if we actually need scan and remove it if we don't |
||
499 | ## |
||
500 | |||
501 | if n_fixed_steps in [1, -1]: |
||
502 | # We do not need to use the scan op anymore, so we can just return |
||
503 | # the outputs and updates we have |
||
504 | if condition is not None: |
||
505 | _logger.warning(('When the number of steps is fixed and equal ' |
||
506 | 'to 1, the provided stopping condition, ', |
||
507 | str(condition), ' is ignored')) |
||
508 | |||
509 | for pos, inner_out in enumerate(outputs): |
||
510 | # we need to see if we need to pad our sequences with an |
||
511 | # unbroadcastable dimension; case example : we return an |
||
512 | # output for which we want all intermediate. If n_steps is 1 |
||
513 | # then, if we return the output as given by the innner function |
||
514 | # this will represent only a slice and it will have one |
||
515 | # dimension less. |
||
516 | if (isinstance(inner_out.type, tensor.TensorType) and |
||
517 | return_steps.get(pos, 0) != 1): |
||
518 | outputs[pos] = tensor.unbroadcast( |
||
519 | tensor.shape_padleft(inner_out), 0) |
||
520 | if len(outputs) == 1: |
||
521 | outputs = outputs[0] |
||
522 | |||
523 | return (outputs, updates) |
||
524 | |||
525 | ## |
||
526 | # Step 4. Compile the dummy function |
||
527 | ## |
||
528 | |||
529 | # We can now compile a dummy function just to see what shared variable |
||
530 | # we have and what are their update rules (note that the user has |
||
531 | # the option not to pass the shared variable to scan, so we need to |
||
532 | # pick them manually and add them to scan) |
||
533 | # make the compilation as fast as possible by not applying any |
||
534 | # optimization or conversion to C [ note this region is not important |
||
535 | # for performance so we can do stuff as unoptimal as we wish ] |
||
536 | |||
537 | # extract still missing inputs (there still might be so) and add them |
||
538 | # as non sequences at the end of our args |
||
539 | fake_nonseqs = [x.type() for x in non_seqs] |
||
540 | fake_outputs = scan_utils.clone(outputs, |
||
541 | replace=OrderedDict(izip(non_seqs, |
||
542 | fake_nonseqs))) |
||
543 | all_inputs = ifilter( |
||
544 | lambda x: (isinstance(x, gof.Variable) and |
||
545 | not isinstance(x, SharedVariable) and |
||
546 | not isinstance(x, gof.Constant)), |
||
547 | gof.graph.inputs(fake_outputs)) |
||
548 | extra_inputs = [x for x in all_inputs if x not in args + fake_nonseqs] |
||
549 | non_seqs += extra_inputs |
||
550 | # Note we do not use all_inputs directly since the order of variables |
||
551 | # in args is quite important |
||
552 | dummy_args += extra_inputs |
||
553 | |||
554 | dummy_outs = outputs |
||
555 | if condition is not None: |
||
556 | dummy_outs.append(condition) |
||
557 | dummy_f = function(dummy_args, |
||
558 | dummy_outs, |
||
559 | updates=updates, |
||
560 | mode=compile.mode.Mode(linker='py', |
||
561 | optimizer=None), |
||
562 | on_unused_input='ignore', |
||
563 | profile=False) |
||
564 | |||
565 | ## |
||
566 | # Step 5. Re-arange inputs of scan into a more strict order |
||
567 | ## |
||
568 | |||
569 | # Step 5.0 Check the outputs of the dummy function to see if they |
||
570 | # match with user provided data |
||
571 | |||
572 | # if the number of outputs to the function does not match the number of |
||
573 | # assumed outputs until now (provided by the user) there can be |
||
574 | # only one explanation: No information is provided for any of the |
||
575 | # outputs (i.e. we are dealing with a map) |
||
576 | tmp_dummy_f_outs = len(dummy_f.maker.outputs) |
||
577 | if as_while: |
||
578 | tmp_dummy_f_outs -= 1 |
||
579 | if not (tmp_dummy_f_outs == n_outs or outs_info == []): |
||
580 | raise ValueError('Please provide None as outputs_info for ' |
||
581 | 'any output that does not feed back into ' |
||
582 | 'scan (i.e. it behaves like a map) ') |
||
583 | |||
584 | if outs_info == []: |
||
585 | n_outs = len(dummy_f.maker.outputs) |
||
586 | if as_while: |
||
587 | n_outs = n_outs - 1 |
||
588 | outs_info = [OrderedDict() for x in xrange(n_outs)] |
||
589 | |||
590 | # Step 5.1 Outputs with taps different then -1 |
||
591 | |||
592 | for i, out in enumerate(outs_info): |
||
593 | if 'taps' in out and out['taps'] != [-1]: |
||
594 | mit_sot_inner_outputs.append(outputs[i]) |
||
595 | |||
596 | # Step 5.2 Outputs with tap equal to -1 |
||
597 | for i, out in enumerate(outs_info): |
||
598 | if 'taps' in out and out['taps'] == [-1]: |
||
599 | sit_sot_inner_outputs.append(outputs[i]) |
||
600 | |||
601 | # Step 5.3 Outputs that correspond to update rules of shared variables |
||
602 | givens = OrderedDict() |
||
603 | n_shared_outs = 0 |
||
604 | shared_scan_inputs = [] |
||
605 | shared_inner_inputs = [] |
||
606 | shared_inner_outputs = [] |
||
607 | sit_sot_shared = [] |
||
608 | for input in dummy_f.maker.expanded_inputs: |
||
609 | if isinstance(input.variable, SharedVariable) and input.update: |
||
610 | new_var = safe_new(input.variable) |
||
611 | if getattr(input.variable, 'name', None) is not None: |
||
612 | new_var.name = input.variable.name + '_copy' |
||
613 | if isinstance(new_var.type, ops.expandable_types): |
||
614 | sit_sot_inner_inputs.append(new_var) |
||
615 | sit_sot_scan_inputs.append( |
||
616 | scan_utils.expand_empty( |
||
617 | tensor.unbroadcast( |
||
618 | tensor.shape_padleft(input.variable), 0), |
||
619 | actual_n_steps)) |
||
620 | tensor_update = tensor.as_tensor_variable(input.update) |
||
621 | sit_sot_inner_outputs.append(tensor_update) |
||
622 | # Not that pos is not a negative index. The sign of pos is used |
||
623 | # as a flag to indicate if this output should be part of the |
||
624 | # update rules or part of the standard outputs of scan. |
||
625 | # If `pos` is positive than it corresponds to the standard |
||
626 | # outputs of scan and it refers to output of index `pos`. If `pos` |
||
627 | # is negative that it corresponds to update rules of scan and it |
||
628 | # refers to update rule of index -1 - `pos`. |
||
629 | sit_sot_rightOrder.append(-1 - len(sit_sot_shared)) |
||
630 | sit_sot_shared.append(input.variable) |
||
631 | givens[input.variable] = new_var |
||
632 | |||
633 | else: |
||
634 | shared_inner_inputs.append(new_var) |
||
635 | shared_scan_inputs.append(input.variable) |
||
636 | shared_inner_outputs.append(input.update) |
||
637 | givens[input.variable] = new_var |
||
638 | n_shared_outs += 1 |
||
639 | n_sit_sot = len(sit_sot_inner_inputs) |
||
640 | # Step 5.4 Outputs with no taps used in the input |
||
641 | n_nit_sot = 0 |
||
642 | nit_sot_inner_outputs = [] |
||
643 | nit_sot_return_steps = OrderedDict() |
||
644 | nit_sot_rightOrder = [] |
||
645 | for i, out in enumerate(outs_info): |
||
646 | if not 'taps' in out: |
||
647 | nit_sot_inner_outputs.append(outputs[i]) |
||
648 | if i in return_steps: |
||
649 | nit_sot_return_steps[n_nit_sot] = return_steps[i] |
||
650 | nit_sot_rightOrder.append(i) |
||
651 | n_nit_sot += 1 |
||
652 | |||
653 | # Step 5.5 all other arguments including extra inputs |
||
654 | other_scan_args = [] |
||
655 | other_inner_args = [] |
||
656 | |||
657 | other_scan_args += [arg for arg in non_seqs |
||
658 | if (not isinstance(arg, SharedVariable) and |
||
659 | not isinstance(arg, tensor.Constant))] |
||
660 | |||
661 | # Step 5.6 all shared variables with no update rules |
||
662 | other_inner_args += [safe_new(arg, '_copy') for arg in non_seqs |
||
663 | if (not isinstance(arg, SharedVariable) and |
||
664 | not isinstance(arg, tensor.Constant))] |
||
665 | |||
666 | givens.update(OrderedDict(izip(other_scan_args, other_inner_args))) |
||
667 | |||
668 | if strict: |
||
669 | non_seqs_set = set(non_sequences if non_sequences is not None else []) |
||
670 | |||
671 | other_shared_scan_args = [arg.variable for arg |
||
672 | in dummy_f.maker.expanded_inputs |
||
673 | if (isinstance(arg.variable, SharedVariable) and |
||
674 | not arg.update and |
||
675 | arg.variable in non_seqs_set)] |
||
676 | other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg |
||
677 | in dummy_f.maker.expanded_inputs |
||
678 | if (isinstance(arg.variable, SharedVariable) and |
||
679 | not arg.update and |
||
680 | arg.variable in non_seqs_set)] |
||
681 | else: |
||
682 | other_shared_scan_args = [arg.variable for arg |
||
683 | in dummy_f.maker.expanded_inputs |
||
684 | if (isinstance(arg.variable, SharedVariable) and |
||
685 | not arg.update)] |
||
686 | other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg |
||
687 | in dummy_f.maker.expanded_inputs |
||
688 | if (isinstance(arg.variable, SharedVariable) and |
||
689 | not arg.update)] |
||
690 | givens.update(OrderedDict(izip(other_shared_scan_args, |
||
691 | other_shared_inner_args))) |
||
692 | |||
693 | ## |
||
694 | # Step 6. Re-order the outputs and clone them replacing things |
||
695 | # using the givens |
||
696 | ## |
||
697 | inner_inputs = (inner_seqs + |
||
698 | mit_mot_inner_inputs + |
||
699 | mit_sot_inner_inputs + |
||
700 | sit_sot_inner_inputs + |
||
701 | shared_inner_inputs + |
||
702 | other_shared_inner_args + |
||
703 | other_inner_args) |
||
704 | |||
705 | inner_outs = (mit_mot_inner_outputs + |
||
706 | mit_sot_inner_outputs + |
||
707 | sit_sot_inner_outputs + |
||
708 | nit_sot_inner_outputs + |
||
709 | shared_inner_outputs) |
||
710 | if condition is not None: |
||
711 | inner_outs.append(condition) |
||
712 | # Cuda and Gpuarray are imported here, instead of being imported on top of |
||
713 | # the file because that would force on the user some dependencies that we |
||
714 | # might do not want to. Currently we are working on removing the |
||
715 | # dependencies on sandbox code completeley. |
||
716 | from theano.sandbox import cuda, gpuarray |
||
717 | if cuda.cuda_available or gpuarray.pygpu_activated: |
||
718 | # very often we end up in this situation when we want to |
||
719 | # replace w with w_copy, where w is a GPU variable |
||
720 | # and w_copy is TensorType. This is caused because shared |
||
721 | # variables are put on GPU right aways >:| , |
||
722 | new_givens = OrderedDict() |
||
723 | |||
724 | for w, w_copy in iteritems(givens): |
||
725 | if ((isinstance(w.type, cuda.CudaNdarrayType) or |
||
726 | isinstance(w.type, gpuarray.GpuArrayType)) and |
||
727 | isinstance(w_copy.type, tensor.TensorType)): |
||
728 | for o in inner_outs: |
||
729 | new_givens = traverse(o, w, w_copy, new_givens) |
||
730 | else: |
||
731 | new_givens[w] = w_copy |
||
732 | else: |
||
733 | new_givens = givens |
||
734 | |||
735 | new_outs = scan_utils.clone(inner_outs, replace=new_givens) |
||
736 | |||
737 | ## |
||
738 | # Step 7. Create the Scan Op |
||
739 | ## |
||
740 | |||
741 | tap_array = mit_sot_tap_array + [[-1] for x in xrange(n_sit_sot)] |
||
742 | if allow_gc is None: |
||
743 | allow_gc = config.scan.allow_gc |
||
744 | info = OrderedDict() |
||
745 | |||
746 | info['tap_array'] = tap_array |
||
747 | info['n_seqs'] = n_seqs |
||
748 | info['n_mit_mot'] = n_mit_mot |
||
749 | info['n_mit_mot_outs'] = n_mit_mot_outs |
||
750 | info['mit_mot_out_slices'] = mit_mot_out_slices |
||
751 | info['n_mit_sot'] = n_mit_sot |
||
752 | info['n_sit_sot'] = n_sit_sot |
||
753 | info['n_shared_outs'] = n_shared_outs |
||
754 | info['n_nit_sot'] = n_nit_sot |
||
755 | info['truncate_gradient'] = truncate_gradient |
||
756 | info['name'] = name |
||
757 | info['mode'] = mode |
||
758 | info['destroy_map'] = OrderedDict() |
||
759 | info['gpu'] = False |
||
760 | info['as_while'] = as_while |
||
761 | info['profile'] = profile |
||
762 | info['allow_gc'] = allow_gc |
||
763 | info['strict'] = strict |
||
764 | |||
765 | local_op = scan_op.Scan(inner_inputs, new_outs, info) |
||
766 | |||
767 | ## |
||
768 | # Step 8. Compute the outputs using the scan op |
||
769 | ## |
||
770 | _scan_inputs = (scan_seqs + |
||
771 | mit_mot_scan_inputs + |
||
772 | mit_sot_scan_inputs + |
||
773 | sit_sot_scan_inputs + |
||
774 | shared_scan_inputs + |
||
775 | [actual_n_steps for x in xrange(n_nit_sot)] + |
||
776 | other_shared_scan_args + |
||
777 | other_scan_args) |
||
778 | |||
779 | scan_inputs = [] |
||
780 | for arg in [actual_n_steps] + _scan_inputs: |
||
781 | try: |
||
782 | arg = tensor.as_tensor_variable(arg) |
||
783 | except TypeError: |
||
784 | # This happens for Random States for e.g. but it is a good way |
||
785 | # to make sure no input is a cuda ndarrays |
||
786 | pass |
||
787 | scan_inputs += [arg] |
||
788 | scan_outs = local_op(*scan_inputs) |
||
789 | if type(scan_outs) not in (list, tuple): |
||
790 | scan_outs = [scan_outs] |
||
791 | ## |
||
792 | # Step 9. Figure out which outs are update rules for shared variables |
||
793 | # and so on ... |
||
794 | ## |
||
795 | |||
796 | update_map = OrderedUpdates() |
||
797 | |||
798 | def remove_dimensions(outs, steps_return, offsets=None): |
||
799 | out_ls = [] |
||
800 | for idx, out in enumerate(outs): |
||
801 | if idx in steps_return: |
||
802 | if steps_return[idx] > 1: |
||
803 | out_ls.append(out[-steps_return[idx]:]) |
||
804 | else: |
||
805 | out_ls.append(out[-1]) |
||
806 | else: |
||
807 | if offsets is None: |
||
808 | out_ls.append(out) |
||
809 | else: |
||
810 | out_ls.append(out[offsets[idx]:]) |
||
811 | return out_ls |
||
812 | |||
813 | offset = n_mit_mot |
||
814 | offsets = [abs(numpy.min(x)) for x in mit_sot_tap_array] |
||
815 | mit_sot_outs = remove_dimensions( |
||
816 | scan_outs[offset:offset + n_mit_sot], |
||
817 | mit_sot_return_steps, |
||
818 | offsets) |
||
819 | |||
820 | offset += n_mit_sot |
||
821 | offsets = [1 for x in xrange(n_sit_sot)] |
||
822 | sit_sot_outs = remove_dimensions( |
||
823 | scan_outs[offset:offset + n_sit_sot], |
||
824 | sit_sot_return_steps, |
||
825 | offsets) |
||
826 | |||
827 | offset += n_sit_sot |
||
828 | nit_sot_outs = remove_dimensions( |
||
829 | scan_outs[offset:offset + n_nit_sot], |
||
830 | nit_sot_return_steps) |
||
831 | |||
832 | offset += n_nit_sot |
||
833 | for idx, update_rule in enumerate( |
||
834 | scan_outs[offset:offset + n_shared_outs]): |
||
835 | update_map[shared_scan_inputs[idx]] = update_rule |
||
836 | |||
837 | _scan_out_list = (mit_sot_outs + |
||
838 | sit_sot_outs + |
||
839 | nit_sot_outs) |
||
840 | # Step 10. I need to reorder the outputs to be in the order expected by |
||
841 | # the user |
||
842 | rightOrder = (mit_sot_rightOrder + |
||
843 | sit_sot_rightOrder + |
||
844 | nit_sot_rightOrder) |
||
845 | scan_out_list = [None] * len(rightOrder) |
||
846 | for idx, pos in enumerate(rightOrder): |
||
847 | if pos >= 0: |
||
848 | scan_out_list[pos] = _scan_out_list[idx] |
||
849 | else: |
||
850 | # Not that pos is not a negative index. The sign of pos is used |
||
851 | # as a flag to indicate if this output should be part of the |
||
852 | # update rules or part of the standard outputs of scan. |
||
853 | # If `pos` is positive than it corresponds to the standard |
||
854 | # outputs of scan and it refers to output of index `pos`. If `pos` |
||
855 | # is negative that it corresponds to update rules of scan and it |
||
856 | # refers to update rule of index -1 - `pos`. |
||
857 | update_map[sit_sot_shared[abs(pos) - 1]] = _scan_out_list[idx][-1] |
||
858 | scan_out_list = [x for x in scan_out_list if x is not None] |
||
859 | ################################################################## P2< |
||
860 | return (scan_out_list, update_map) |