Test Failed
Pull Request — master (#875)
by Daniil
03:45
created

savu.plugins.alignment.projection_2d_alignment   A

Complexity

Total Complexity 7

Size/Duplication

Total Lines 78
Duplicated Lines 0 %

Importance

Changes 0
Metric Value
eloc 43
dl 0
loc 78
rs 10
c 0
b 0
f 0
wmc 7

7 Methods

Rating   Name   Duplication   Size   Complexity  
A Projection2dAlignment.process_frames() 0 8 1
A Projection2dAlignment.nInput_datasets() 0 2 1
A Projection2dAlignment.post_process() 0 10 1
A Projection2dAlignment.nOutput_datasets() 0 2 1
A Projection2dAlignment.setup() 0 14 1
A Projection2dAlignment.__init__() 0 2 1
A Projection2dAlignment.get_max_frames() 0 2 1
1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: projection_2d_alignment
17
   :platform: Unix
18
   :synopsis: calculates horizontal-vertical shift vectors for fixing misaligned projection data
19
20
.. moduleauthor:: Daniil Kazantsev <[email protected]>
21
"""
22
23
from savu.plugins.plugin import Plugin
24
from savu.plugins.driver.cpu_plugin import CpuPlugin
25
from savu.plugins.utils import register_plugin
26
from skimage.registration import phase_cross_correlation
27
28
import numpy as np
29
30
@register_plugin
class Projection2dAlignment(Plugin, CpuPlugin):
    # Estimates a 2D shift for every projection by phase cross-correlating
    # each projection of the first input dataset (reference) with the
    # matching projection of the second input dataset (to be aligned).
    # The per-frame shifts are written to a (N, 2) METADATA output dataset
    # and, in post_process, accumulated into the 'projection_shifts' entry of
    # both the experiment metadata and the first input dataset's metadata.

    def __init__(self):
        super(Projection2dAlignment, self).__init__('Projection2dAlignment')

    def setup(self):
        """Set up both inputs for projection-wise access and create the
        output dataset that receives one shift pair per frame.

        The output has shape (len(first slice dimension of input 0), 2):
        one row per frame and two columns for the 2D shift of that frame.
        """
        in_dataset, out_dataset = self.get_datasets()
        in_pData, out_pData = self.get_plugin_datasets()
        in_pData[0].plugin_data_setup('PROJECTION', self.get_max_frames())
        in_pData[1].plugin_data_setup('PROJECTION', self.get_max_frames())

        # Create a metadata-style dataset for storing the shift vectors.
        slice_dirs = list(in_dataset[0].get_slice_dimensions())
        new_shape = (in_dataset[0].get_shape()[slice_dirs[0]], 2)
        out_dataset[0].create_dataset(shape=new_shape,
                                      axis_labels=['x.shifts', 'y.shifts'],
                                      remove=True)
        # Each frame writes one row (core dim 1) indexed along dim 0.
        out_dataset[0].add_pattern("METADATA", core_dims=(1,),
                                   slice_dims=(0,))
        out_pData[0].plugin_data_setup('METADATA', self.get_max_frames())

    def process_frames(self, data):
        """Return the shift that registers data[1] onto data[0].

        :param data: two frames; data[0] is the reference projection and
            data[1] the projection to be aligned to it.
        :returns: the shift vector from ``phase_cross_correlation``, whose
            axis order follows the input array axes (row shift, column
            shift).
        """
        projection = data[0]  # reference projection
        projection_align = data[1]  # projection to be aligned

        # Only the shift is needed; the error and phase-difference outputs
        # are deliberately discarded.
        shift, _error, _phasediff = phase_cross_correlation(
            projection, projection_align,
            upsample_factor=self.parameters['upsample_factor'])
        return shift

    def post_process(self):
        """Accumulate the newly computed shifts into 'projection_shifts'.

        The shifts written by process_frames are read back, their two columns
        swapped, and added to any shifts already recorded under
        'projection_shifts' in the experiment metadata (zeros if none were
        recorded yet). The total is stored in the experiment metadata and in
        the first input dataset's metadata, each as an independent copy.
        """
        out_data = self.get_out_datasets()[0]
        # np.array forces a private copy, so the column swap below can never
        # write back into the output dataset (for an in-memory ndarray,
        # [:, :] would be a view).
        shift_vector = np.array(out_data.data[:, :])
        shift_vector = shift_vector[:, ::-1]  # swap axes of the shift vector

        # Shifts recorded by earlier steps, if any; start from zero otherwise
        # instead of failing with a KeyError.
        shift_vector_prev = self.exp.meta_data.dict.get('projection_shifts')
        if shift_vector_prev is None:
            shift_vector_prev = np.zeros_like(shift_vector)

        # Out-of-place sum: do not mutate the array still referenced by the
        # metadata dict, and let integer-typed previous shifts upcast to
        # float instead of raising a same_kind casting error.
        total_shifts = np.asarray(shift_vector_prev) + shift_vector

        self.exp.meta_data.set('projection_shifts', total_shifts.copy())
        in_meta_data = self.get_in_meta_data()[0]
        in_meta_data.set('projection_shifts', total_shifts.copy())

    def get_max_frames(self):
        """Process one frame (projection pair) at a time."""
        return 'single'

    def nInput_datasets(self):
        """Two inputs: reference projections and projections to align."""
        return 2

    def nOutput_datasets(self):
        """One output: the per-frame shift vectors."""
        return 1
78