| Conditions | 81 |
| Total Lines | 409 |
| Lines | 0 |
| Ratio | 0 % |
| Changes | 1 | ||
| Bugs | 0 | Features | 0 |
Small methods make your code easier to understand, in particular if combined with a good name. Besides, if your method is small, finding a good name is usually much easier.
For example, if you find yourself adding comments to a method's body, this is usually a good sign to extract the commented part to a new method, and use the comment as a starting point when coming up with a good name for this new method.
Commonly applied refactorings include:
If many parameters/temporary variables are present: consider Introduce Parameter Object and Replace Temp with Query.
Complex functions like finish_scan() often do a lot of different things. To break such a function down, we need to identify a cohesive component within it. A common approach to find such a component is to look for fields/methods that share the same prefixes, or suffixes.
Once you have determined the fields that belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
| 1 | #!/usr/bin/env python |
||
| 452 | def finish_scan(fn_outputs, local_vars): |
||
| 453 | |||
| 454 | n_fixed_steps = local_vars["n_fixed_steps"] |
||
| 455 | return_steps = local_vars["return_steps"] |
||
| 456 | non_seqs = local_vars["non_seqs"] |
||
| 457 | dummy_args = local_vars["dummy_args"] |
||
| 458 | args = local_vars["args"] |
||
| 459 | outs_info = local_vars["outs_info"] |
||
| 460 | n_outs = local_vars["n_outs"] |
||
| 461 | mit_sot_inner_outputs = local_vars["mit_sot_inner_outputs"] |
||
| 462 | sit_sot_inner_outputs = local_vars["sit_sot_inner_outputs"] |
||
| 463 | sit_sot_scan_inputs = local_vars["sit_sot_scan_inputs"] |
||
| 464 | sit_sot_inner_inputs = local_vars["sit_sot_inner_inputs"] |
||
| 465 | actual_n_steps = local_vars["actual_n_steps"] |
||
| 466 | sit_sot_rightOrder = local_vars["sit_sot_rightOrder"] |
||
| 467 | strict = local_vars["strict"] |
||
| 468 | non_sequences = local_vars["non_sequences"] |
||
| 469 | inner_seqs = local_vars["inner_seqs"] |
||
| 470 | mit_mot_inner_inputs = local_vars["mit_mot_inner_inputs"] |
||
| 471 | mit_sot_inner_inputs = local_vars["mit_sot_inner_inputs"] |
||
| 472 | mit_mot_inner_outputs = local_vars["mit_mot_inner_outputs"] |
||
| 473 | mit_sot_tap_array = local_vars["mit_sot_tap_array"] |
||
| 474 | allow_gc = local_vars["allow_gc"] |
||
| 475 | n_seqs = local_vars["n_seqs"] |
||
| 476 | n_mit_mot_outs = local_vars["n_mit_mot_outs"] |
||
| 477 | mit_mot_out_slices = local_vars["mit_mot_out_slices"] |
||
| 478 | truncate_gradient = local_vars["truncate_gradient"] |
||
| 479 | name = local_vars["name"] |
||
| 480 | mode = local_vars["mode"] |
||
| 481 | profile = local_vars["profile"] |
||
| 482 | scan_seqs = local_vars["scan_seqs"] |
||
| 483 | mit_mot_scan_inputs = local_vars["mit_mot_scan_inputs"] |
||
| 484 | mit_sot_scan_inputs = local_vars["mit_sot_scan_inputs"] |
||
| 485 | n_mit_mot = local_vars["n_mit_mot"] |
||
| 486 | mit_sot_return_steps = local_vars["mit_sot_return_steps"] |
||
| 487 | n_mit_sot = local_vars["n_mit_sot"] |
||
| 488 | sit_sot_return_steps = local_vars["sit_sot_return_steps"] |
||
| 489 | mit_sot_rightOrder = local_vars["mit_sot_rightOrder"] |
||
| 490 | |||
| 491 | condition, outputs, updates = scan_utils.get_updates_and_outputs(fn_outputs) |
||
| 492 | ################################################################## P2> |
||
| 493 | if condition is not None: |
||
| 494 | as_while = True |
||
| 495 | else: |
||
| 496 | as_while = False |
||
| 497 | ## |
||
| 498 | # Step 3. Check if we actually need scan and remove it if we don't |
||
| 499 | ## |
||
| 500 | |||
| 501 | if n_fixed_steps in [1, -1]: |
||
| 502 | # We do not need to use the scan op anymore, so we can just return |
||
| 503 | # the outputs and updates we have |
||
| 504 | if condition is not None: |
||
| 505 | _logger.warning(('When the number of steps is fixed and equal ' |
||
| 506 | 'to 1, the provided stopping condition, ', |
||
| 507 | str(condition), ' is ignored')) |
||
| 508 | |||
| 509 | for pos, inner_out in enumerate(outputs): |
||
| 510 | # we need to see if we need to pad our sequences with an |
||
| 511 | # unbroadcastable dimension; case example : we return an |
||
| 512 | # output for which we want all intermediate. If n_steps is 1 |
||
| 513 | # then, if we return the output as given by the innner function |
||
| 514 | # this will represent only a slice and it will have one |
||
| 515 | # dimension less. |
||
| 516 | if (isinstance(inner_out.type, tensor.TensorType) and |
||
| 517 | return_steps.get(pos, 0) != 1): |
||
| 518 | outputs[pos] = tensor.unbroadcast( |
||
| 519 | tensor.shape_padleft(inner_out), 0) |
||
| 520 | if len(outputs) == 1: |
||
| 521 | outputs = outputs[0] |
||
| 522 | |||
| 523 | return (outputs, updates) |
||
| 524 | |||
| 525 | ## |
||
| 526 | # Step 4. Compile the dummy function |
||
| 527 | ## |
||
| 528 | |||
| 529 | # We can now compile a dummy function just to see what shared variable |
||
| 530 | # we have and what are their update rules (note that the user has |
||
| 531 | # the option not to pass the shared variable to scan, so we need to |
||
| 532 | # pick them manually and add them to scan) |
||
| 533 | # make the compilation as fast as possible by not applying any |
||
| 534 | # optimization or conversion to C [ note this region is not important |
||
| 535 | # for performance so we can do stuff as unoptimal as we wish ] |
||
| 536 | |||
| 537 | # extract still missing inputs (there still might be so) and add them |
||
| 538 | # as non sequences at the end of our args |
||
| 539 | fake_nonseqs = [x.type() for x in non_seqs] |
||
| 540 | fake_outputs = scan_utils.clone(outputs, |
||
| 541 | replace=OrderedDict(izip(non_seqs, |
||
| 542 | fake_nonseqs))) |
||
| 543 | all_inputs = ifilter( |
||
| 544 | lambda x: (isinstance(x, gof.Variable) and |
||
| 545 | not isinstance(x, SharedVariable) and |
||
| 546 | not isinstance(x, gof.Constant)), |
||
| 547 | gof.graph.inputs(fake_outputs)) |
||
| 548 | extra_inputs = [x for x in all_inputs if x not in args + fake_nonseqs] |
||
| 549 | non_seqs += extra_inputs |
||
| 550 | # Note we do not use all_inputs directly since the order of variables |
||
| 551 | # in args is quite important |
||
| 552 | dummy_args += extra_inputs |
||
| 553 | |||
| 554 | dummy_outs = outputs |
||
| 555 | if condition is not None: |
||
| 556 | dummy_outs.append(condition) |
||
| 557 | dummy_f = function(dummy_args, |
||
| 558 | dummy_outs, |
||
| 559 | updates=updates, |
||
| 560 | mode=compile.mode.Mode(linker='py', |
||
| 561 | optimizer=None), |
||
| 562 | on_unused_input='ignore', |
||
| 563 | profile=False) |
||
| 564 | |||
| 565 | ## |
||
| 566 | # Step 5. Re-arange inputs of scan into a more strict order |
||
| 567 | ## |
||
| 568 | |||
| 569 | # Step 5.0 Check the outputs of the dummy function to see if they |
||
| 570 | # match with user provided data |
||
| 571 | |||
| 572 | # if the number of outputs to the function does not match the number of |
||
| 573 | # assumed outputs until now (provided by the user) there can be |
||
| 574 | # only one explanation: No information is provided for any of the |
||
| 575 | # outputs (i.e. we are dealing with a map) |
||
| 576 | tmp_dummy_f_outs = len(dummy_f.maker.outputs) |
||
| 577 | if as_while: |
||
| 578 | tmp_dummy_f_outs -= 1 |
||
| 579 | if not (tmp_dummy_f_outs == n_outs or outs_info == []): |
||
| 580 | raise ValueError('Please provide None as outputs_info for ' |
||
| 581 | 'any output that does not feed back into ' |
||
| 582 | 'scan (i.e. it behaves like a map) ') |
||
| 583 | |||
| 584 | if outs_info == []: |
||
| 585 | n_outs = len(dummy_f.maker.outputs) |
||
| 586 | if as_while: |
||
| 587 | n_outs = n_outs - 1 |
||
| 588 | outs_info = [OrderedDict() for x in xrange(n_outs)] |
||
| 589 | |||
| 590 | # Step 5.1 Outputs with taps different then -1 |
||
| 591 | |||
| 592 | for i, out in enumerate(outs_info): |
||
| 593 | if 'taps' in out and out['taps'] != [-1]: |
||
| 594 | mit_sot_inner_outputs.append(outputs[i]) |
||
| 595 | |||
| 596 | # Step 5.2 Outputs with tap equal to -1 |
||
| 597 | for i, out in enumerate(outs_info): |
||
| 598 | if 'taps' in out and out['taps'] == [-1]: |
||
| 599 | sit_sot_inner_outputs.append(outputs[i]) |
||
| 600 | |||
| 601 | # Step 5.3 Outputs that correspond to update rules of shared variables |
||
| 602 | givens = OrderedDict() |
||
| 603 | n_shared_outs = 0 |
||
| 604 | shared_scan_inputs = [] |
||
| 605 | shared_inner_inputs = [] |
||
| 606 | shared_inner_outputs = [] |
||
| 607 | sit_sot_shared = [] |
||
| 608 | for input in dummy_f.maker.expanded_inputs: |
||
| 609 | if isinstance(input.variable, SharedVariable) and input.update: |
||
| 610 | new_var = safe_new(input.variable) |
||
| 611 | if getattr(input.variable, 'name', None) is not None: |
||
| 612 | new_var.name = input.variable.name + '_copy' |
||
| 613 | if isinstance(new_var.type, ops.expandable_types): |
||
| 614 | sit_sot_inner_inputs.append(new_var) |
||
| 615 | sit_sot_scan_inputs.append( |
||
| 616 | scan_utils.expand_empty( |
||
| 617 | tensor.unbroadcast( |
||
| 618 | tensor.shape_padleft(input.variable), 0), |
||
| 619 | actual_n_steps)) |
||
| 620 | tensor_update = tensor.as_tensor_variable(input.update) |
||
| 621 | sit_sot_inner_outputs.append(tensor_update) |
||
| 622 | # Not that pos is not a negative index. The sign of pos is used |
||
| 623 | # as a flag to indicate if this output should be part of the |
||
| 624 | # update rules or part of the standard outputs of scan. |
||
| 625 | # If `pos` is positive than it corresponds to the standard |
||
| 626 | # outputs of scan and it refers to output of index `pos`. If `pos` |
||
| 627 | # is negative that it corresponds to update rules of scan and it |
||
| 628 | # refers to update rule of index -1 - `pos`. |
||
| 629 | sit_sot_rightOrder.append(-1 - len(sit_sot_shared)) |
||
| 630 | sit_sot_shared.append(input.variable) |
||
| 631 | givens[input.variable] = new_var |
||
| 632 | |||
| 633 | else: |
||
| 634 | shared_inner_inputs.append(new_var) |
||
| 635 | shared_scan_inputs.append(input.variable) |
||
| 636 | shared_inner_outputs.append(input.update) |
||
| 637 | givens[input.variable] = new_var |
||
| 638 | n_shared_outs += 1 |
||
| 639 | n_sit_sot = len(sit_sot_inner_inputs) |
||
| 640 | # Step 5.4 Outputs with no taps used in the input |
||
| 641 | n_nit_sot = 0 |
||
| 642 | nit_sot_inner_outputs = [] |
||
| 643 | nit_sot_return_steps = OrderedDict() |
||
| 644 | nit_sot_rightOrder = [] |
||
| 645 | for i, out in enumerate(outs_info): |
||
| 646 | if not 'taps' in out: |
||
| 647 | nit_sot_inner_outputs.append(outputs[i]) |
||
| 648 | if i in return_steps: |
||
| 649 | nit_sot_return_steps[n_nit_sot] = return_steps[i] |
||
| 650 | nit_sot_rightOrder.append(i) |
||
| 651 | n_nit_sot += 1 |
||
| 652 | |||
| 653 | # Step 5.5 all other arguments including extra inputs |
||
| 654 | other_scan_args = [] |
||
| 655 | other_inner_args = [] |
||
| 656 | |||
| 657 | other_scan_args += [arg for arg in non_seqs |
||
| 658 | if (not isinstance(arg, SharedVariable) and |
||
| 659 | not isinstance(arg, tensor.Constant))] |
||
| 660 | |||
| 661 | # Step 5.6 all shared variables with no update rules |
||
| 662 | other_inner_args += [safe_new(arg, '_copy') for arg in non_seqs |
||
| 663 | if (not isinstance(arg, SharedVariable) and |
||
| 664 | not isinstance(arg, tensor.Constant))] |
||
| 665 | |||
| 666 | givens.update(OrderedDict(izip(other_scan_args, other_inner_args))) |
||
| 667 | |||
| 668 | if strict: |
||
| 669 | non_seqs_set = set(non_sequences if non_sequences is not None else []) |
||
| 670 | |||
| 671 | other_shared_scan_args = [arg.variable for arg |
||
| 672 | in dummy_f.maker.expanded_inputs |
||
| 673 | if (isinstance(arg.variable, SharedVariable) and |
||
| 674 | not arg.update and |
||
| 675 | arg.variable in non_seqs_set)] |
||
| 676 | other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg |
||
| 677 | in dummy_f.maker.expanded_inputs |
||
| 678 | if (isinstance(arg.variable, SharedVariable) and |
||
| 679 | not arg.update and |
||
| 680 | arg.variable in non_seqs_set)] |
||
| 681 | else: |
||
| 682 | other_shared_scan_args = [arg.variable for arg |
||
| 683 | in dummy_f.maker.expanded_inputs |
||
| 684 | if (isinstance(arg.variable, SharedVariable) and |
||
| 685 | not arg.update)] |
||
| 686 | other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg |
||
| 687 | in dummy_f.maker.expanded_inputs |
||
| 688 | if (isinstance(arg.variable, SharedVariable) and |
||
| 689 | not arg.update)] |
||
| 690 | givens.update(OrderedDict(izip(other_shared_scan_args, |
||
| 691 | other_shared_inner_args))) |
||
| 692 | |||
| 693 | ## |
||
| 694 | # Step 6. Re-order the outputs and clone them replacing things |
||
| 695 | # using the givens |
||
| 696 | ## |
||
| 697 | inner_inputs = (inner_seqs + |
||
| 698 | mit_mot_inner_inputs + |
||
| 699 | mit_sot_inner_inputs + |
||
| 700 | sit_sot_inner_inputs + |
||
| 701 | shared_inner_inputs + |
||
| 702 | other_shared_inner_args + |
||
| 703 | other_inner_args) |
||
| 704 | |||
| 705 | inner_outs = (mit_mot_inner_outputs + |
||
| 706 | mit_sot_inner_outputs + |
||
| 707 | sit_sot_inner_outputs + |
||
| 708 | nit_sot_inner_outputs + |
||
| 709 | shared_inner_outputs) |
||
| 710 | if condition is not None: |
||
| 711 | inner_outs.append(condition) |
||
| 712 | # Cuda and Gpuarray are imported here, instead of being imported on top of |
||
| 713 | # the file because that would force on the user some dependencies that we |
||
| 714 | # might do not want to. Currently we are working on removing the |
||
| 715 | # dependencies on sandbox code completeley. |
||
| 716 | from theano.sandbox import cuda, gpuarray |
||
| 717 | if cuda.cuda_available or gpuarray.pygpu_activated: |
||
| 718 | # very often we end up in this situation when we want to |
||
| 719 | # replace w with w_copy, where w is a GPU variable |
||
| 720 | # and w_copy is TensorType. This is caused because shared |
||
| 721 | # variables are put on GPU right aways >:| , |
||
| 722 | new_givens = OrderedDict() |
||
| 723 | |||
| 724 | for w, w_copy in iteritems(givens): |
||
| 725 | if ((isinstance(w.type, cuda.CudaNdarrayType) or |
||
| 726 | isinstance(w.type, gpuarray.GpuArrayType)) and |
||
| 727 | isinstance(w_copy.type, tensor.TensorType)): |
||
| 728 | for o in inner_outs: |
||
| 729 | new_givens = traverse(o, w, w_copy, new_givens) |
||
| 730 | else: |
||
| 731 | new_givens[w] = w_copy |
||
| 732 | else: |
||
| 733 | new_givens = givens |
||
| 734 | |||
| 735 | new_outs = scan_utils.clone(inner_outs, replace=new_givens) |
||
| 736 | |||
| 737 | ## |
||
| 738 | # Step 7. Create the Scan Op |
||
| 739 | ## |
||
| 740 | |||
| 741 | tap_array = mit_sot_tap_array + [[-1] for x in xrange(n_sit_sot)] |
||
| 742 | if allow_gc is None: |
||
| 743 | allow_gc = config.scan.allow_gc |
||
| 744 | info = OrderedDict() |
||
| 745 | |||
| 746 | info['tap_array'] = tap_array |
||
| 747 | info['n_seqs'] = n_seqs |
||
| 748 | info['n_mit_mot'] = n_mit_mot |
||
| 749 | info['n_mit_mot_outs'] = n_mit_mot_outs |
||
| 750 | info['mit_mot_out_slices'] = mit_mot_out_slices |
||
| 751 | info['n_mit_sot'] = n_mit_sot |
||
| 752 | info['n_sit_sot'] = n_sit_sot |
||
| 753 | info['n_shared_outs'] = n_shared_outs |
||
| 754 | info['n_nit_sot'] = n_nit_sot |
||
| 755 | info['truncate_gradient'] = truncate_gradient |
||
| 756 | info['name'] = name |
||
| 757 | info['mode'] = mode |
||
| 758 | info['destroy_map'] = OrderedDict() |
||
| 759 | info['gpu'] = False |
||
| 760 | info['as_while'] = as_while |
||
| 761 | info['profile'] = profile |
||
| 762 | info['allow_gc'] = allow_gc |
||
| 763 | info['strict'] = strict |
||
| 764 | |||
| 765 | local_op = scan_op.Scan(inner_inputs, new_outs, info) |
||
| 766 | |||
| 767 | ## |
||
| 768 | # Step 8. Compute the outputs using the scan op |
||
| 769 | ## |
||
| 770 | _scan_inputs = (scan_seqs + |
||
| 771 | mit_mot_scan_inputs + |
||
| 772 | mit_sot_scan_inputs + |
||
| 773 | sit_sot_scan_inputs + |
||
| 774 | shared_scan_inputs + |
||
| 775 | [actual_n_steps for x in xrange(n_nit_sot)] + |
||
| 776 | other_shared_scan_args + |
||
| 777 | other_scan_args) |
||
| 778 | |||
| 779 | scan_inputs = [] |
||
| 780 | for arg in [actual_n_steps] + _scan_inputs: |
||
| 781 | try: |
||
| 782 | arg = tensor.as_tensor_variable(arg) |
||
| 783 | except TypeError: |
||
| 784 | # This happens for Random States for e.g. but it is a good way |
||
| 785 | # to make sure no input is a cuda ndarrays |
||
| 786 | pass |
||
| 787 | scan_inputs += [arg] |
||
| 788 | scan_outs = local_op(*scan_inputs) |
||
| 789 | if type(scan_outs) not in (list, tuple): |
||
| 790 | scan_outs = [scan_outs] |
||
| 791 | ## |
||
| 792 | # Step 9. Figure out which outs are update rules for shared variables |
||
| 793 | # and so on ... |
||
| 794 | ## |
||
| 795 | |||
| 796 | update_map = OrderedUpdates() |
||
| 797 | |||
| 798 | def remove_dimensions(outs, steps_return, offsets=None): |
||
| 799 | out_ls = [] |
||
| 800 | for idx, out in enumerate(outs): |
||
| 801 | if idx in steps_return: |
||
| 802 | if steps_return[idx] > 1: |
||
| 803 | out_ls.append(out[-steps_return[idx]:]) |
||
| 804 | else: |
||
| 805 | out_ls.append(out[-1]) |
||
| 806 | else: |
||
| 807 | if offsets is None: |
||
| 808 | out_ls.append(out) |
||
| 809 | else: |
||
| 810 | out_ls.append(out[offsets[idx]:]) |
||
| 811 | return out_ls |
||
| 812 | |||
| 813 | offset = n_mit_mot |
||
| 814 | offsets = [abs(numpy.min(x)) for x in mit_sot_tap_array] |
||
| 815 | mit_sot_outs = remove_dimensions( |
||
| 816 | scan_outs[offset:offset + n_mit_sot], |
||
| 817 | mit_sot_return_steps, |
||
| 818 | offsets) |
||
| 819 | |||
| 820 | offset += n_mit_sot |
||
| 821 | offsets = [1 for x in xrange(n_sit_sot)] |
||
| 822 | sit_sot_outs = remove_dimensions( |
||
| 823 | scan_outs[offset:offset + n_sit_sot], |
||
| 824 | sit_sot_return_steps, |
||
| 825 | offsets) |
||
| 826 | |||
| 827 | offset += n_sit_sot |
||
| 828 | nit_sot_outs = remove_dimensions( |
||
| 829 | scan_outs[offset:offset + n_nit_sot], |
||
| 830 | nit_sot_return_steps) |
||
| 831 | |||
| 832 | offset += n_nit_sot |
||
| 833 | for idx, update_rule in enumerate( |
||
| 834 | scan_outs[offset:offset + n_shared_outs]): |
||
| 835 | update_map[shared_scan_inputs[idx]] = update_rule |
||
| 836 | |||
| 837 | _scan_out_list = (mit_sot_outs + |
||
| 838 | sit_sot_outs + |
||
| 839 | nit_sot_outs) |
||
| 840 | # Step 10. I need to reorder the outputs to be in the order expected by |
||
| 841 | # the user |
||
| 842 | rightOrder = (mit_sot_rightOrder + |
||
| 843 | sit_sot_rightOrder + |
||
| 844 | nit_sot_rightOrder) |
||
| 845 | scan_out_list = [None] * len(rightOrder) |
||
| 846 | for idx, pos in enumerate(rightOrder): |
||
| 847 | if pos >= 0: |
||
| 848 | scan_out_list[pos] = _scan_out_list[idx] |
||
| 849 | else: |
||
| 850 | # Not that pos is not a negative index. The sign of pos is used |
||
| 851 | # as a flag to indicate if this output should be part of the |
||
| 852 | # update rules or part of the standard outputs of scan. |
||
| 853 | # If `pos` is positive than it corresponds to the standard |
||
| 854 | # outputs of scan and it refers to output of index `pos`. If `pos` |
||
| 855 | # is negative that it corresponds to update rules of scan and it |
||
| 856 | # refers to update rule of index -1 - `pos`. |
||
| 857 | update_map[sit_sot_shared[abs(pos) - 1]] = _scan_out_list[idx][-1] |
||
| 858 | scan_out_list = [x for x in scan_out_list if x is not None] |
||
| 859 | ################################################################## P2< |
||
| 860 | return (scan_out_list, update_map) |