Test Failed
Pull Request — master (#878)
by
unknown
04:38
created

Comparison.post_process()   A

Complexity

Conditions 3

Size

Total Lines 14
Code Lines 11

Duplication

Lines 0
Ratio 0 %

Importance

Changes 0
Metric Value
cc 3
eloc 11
nop 1
dl 0
loc 14
rs 9.85
c 0
b 0
f 0
1
# Copyright 2014 Diamond Light Source Ltd.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
15
"""
16
.. module:: comparison
17
   :platform: Unix
18
   :synopsis: A plugin to compare two datasets, given as input datasets, and print the RMSD between the two.
19
              The data is unchanged.
20
21
.. moduleauthor:: Jacob Williamson <[email protected]>
22
"""
23
24
from savu.plugins.utils import register_plugin
25
from savu.plugins.plugin import Plugin
26
from savu.plugins.driver.cpu_plugin import CpuPlugin
27
from savu.core.iterate_plugin_group_utils import enable_iterative_loop, \
28
    check_if_end_plugin_in_iterate_group, setup_extra_plugin_data_padding
29
30
import numpy as np
31
32
# This decorator is required for the configurator to recognise the plugin
33
@register_plugin
class Comparison(Plugin, CpuPlugin):
    # Compares two input datasets frame by frame. In post_process it prints
    # the normalised RMSD between them (also with the first dataset's contrast
    # flipped) and their Pearson correlation coefficient. Both datasets are
    # passed through to the outputs unchanged.
    # NOTE(review): self.stats_obj is not set in this class; presumably it is
    # provided by the Plugin base class -- confirm.

    def __init__(self):
        super(Comparison, self).__init__("Comparison")

    def nInput_datasets(self):
        return 2

    def nOutput_datasets(self):
        # An extra output is needed when this plugin ends an iterative group.
        return 3 if check_if_end_plugin_in_iterate_group(self.exp) else 2

    def nClone_datasets(self):
        return 1 if check_if_end_plugin_in_iterate_group(self.exp) else 0

    @enable_iterative_loop
    def setup(self):
        """Create the output datasets and declare how frames are accessed.

        Each output is created from the corresponding input, so the data
        passes through unchanged. The per-frame accumulators read by
        process_frames and post_process are also reset here.
        """
        in_dataset, out_dataset = self.get_datasets()

        # see https://savu.readthedocs.io/en/latest/api/savu.data.data_structures.data_create/
        # for more information on creating datasets.
        out_dataset[0].create_dataset(in_dataset[0])
        out_dataset[1].create_dataset(in_dataset[1])

        # Per-frame partial results, combined in post_process.
        self.rss_list = []
        self.flipped_rss_list = []
        self.data_points_list = []
        self.partial_cc_top = []
        self.partial_cc_bottom = ([], [])
        self.skipped_frames = 0  # frames whose shapes did not match

        in_pData, out_pData = self.get_plugin_datasets()
        pattern = self.parameters['pattern']

        # One frame ('single') at a time, sliced according to 'pattern'.
        for pData in in_pData:
            pData.plugin_data_setup(pattern, 'single')
        # NOTE(review): only the first two outputs are configured here; when a
        # third output exists (end of an iterate group) it is presumably
        # handled by enable_iterative_loop -- confirm.
        out_pData[0].plugin_data_setup(pattern, 'single')
        out_pData[1].plugin_data_setup(pattern, 'single')

    def pre_process(self):
        """Look up display names and the stats metadata of both inputs.

        A dataset without a group name is called "dataset1"/"dataset2". If
        either dataset lacks usable stats, its entries in self.stats and
        self.ranges stay None and no comparison is performed.
        """
        in_datasets = self.get_in_datasets()
        default_names = ("dataset1", "dataset2")
        self.names = [in_datasets[i].group_name or default_names[i]
                      for i in range(2)]

        self.stats = [None, None]
        self.ranges = [None, None]
        for i in range(2):
            self.stats[i], self.ranges[i] = self._load_stats(
                in_datasets[i], self.names[i])

    def _load_stats(self, dataset, name):
        """Return (stats, max - min) for dataset, or (None, None) if unusable.

        The stats and the range are returned together, so a dataset can never
        end up with stats but no range. "min", "max" and "mean" are all
        required because process_frames reads every one of them.
        """
        try:
            stats = self.stats_obj.get_stats_from_dataset(dataset)  # stats dictionary
        except KeyError:
            stats = None
        if stats is None or any(key not in stats for key in ("min", "max", "mean")):
            print(f"Can't find stats metadata in {name}, cannot do comparison")
            return None, None
        return stats, stats["max"] - stats["min"]

    def process_frames(self, data):
        """Accumulate comparison statistics for one pair of frames.

        :param data: list holding one frame from each input dataset, sliced
            according to the 'pattern' parameter.
        :returns: both frames unchanged, one for each output dataset.
        """
        frame0, frame1 = data[0], data[1]
        if frame0.shape != frame1.shape:
            # Report the problem once rather than on every frame; the total
            # count is printed in post_process.
            if self.skipped_frames == 0:
                print("Arrays different sizes, can't calculate residuals.")
            self.skipped_frames += 1
        elif self.stats[0] is not None and self.stats[1] is not None:
            stats0, stats1 = self.stats

            # RSS is computed on data rescaled to [0, 1] using each
            # dataset's global min/range, so the two volumes are comparable.
            scaled0 = self._scale_data(frame0, stats0["min"], self.ranges[0])
            scaled1 = self._scale_data(frame1, stats1["min"], self.ranges[1])
            self.rss_list.append(self.stats_obj.calc_rss(scaled0, scaled1))
            self.data_points_list.append(frame0.size)
            # Inverting the first dataset within [0, 1] catches data that
            # match up to a contrast flip.
            self.flipped_rss_list.append(self.stats_obj.calc_rss(1 - scaled0, scaled1))

            # Pearson terms use the whole-dataset means, so summing these
            # per-frame partial sums gives the coefficient for the full volumes.
            dev0 = frame0 - stats0["mean"]
            dev1 = frame1 - stats1["mean"]
            self.partial_cc_top.append(np.sum(dev0 * dev1))
            self.partial_cc_bottom[0].append(np.sum(dev0 ** 2))
            self.partial_cc_bottom[1].append(np.sum(dev1 ** 2))

        return [frame0, frame1]

    def _scale_data(self, data, vol_min, vol_range, new_min=0, new_range=1):
        """Linearly map data from [vol_min, vol_min + vol_range] onto
        [new_min, new_min + new_range] (by default [0, 1]).

        A zero vol_range (constant volume) would divide by zero. In that case
        every value should equal vol_min, so the data is mapped to new_min.
        """
        scale = new_range / vol_range if vol_range else 0
        return (data - vol_min) * scale + new_min

    def post_process(self):
        """Combine the per-frame partial sums and print the comparison results."""
        if self.stats[0] is None or self.stats[1] is None:
            return

        if self.skipped_frames:
            print(f"{self.skipped_frames} frame(s) skipped because the arrays "
                  "had different sizes")

        total_data = sum(self.data_points_list)
        if total_data == 0:
            # Nothing was compared; RMSD and correlation would divide by zero.
            print(f"No frames of {self.names[0]} and {self.names[1]} could be "
                  "compared, cannot do comparison")
            return

        name0, name1 = self.names
        rmsd = self.stats_obj.rmsd_from_rss(sum(self.rss_list), total_data)
        print(f"Normalised root mean square deviation between {name0} and {name1} is {rmsd}")

        frmsd = self.stats_obj.rmsd_from_rss(sum(self.flipped_rss_list), total_data)
        # Adjacent literals are concatenated, so the message has no stray
        # indentation (the original backslash continuation embedded it).
        print(f"Normalised root mean square deviation between {name0} and {name1} is {frmsd}, "
              "when the contrast is flipped")

        denominator = np.sqrt(np.sum(self.partial_cc_bottom[0]) * np.sum(self.partial_cc_bottom[1]))
        # The coefficient is undefined (0/0) when either dataset is constant.
        pcc = np.sum(self.partial_cc_top) / denominator if denominator else np.nan
        print(f"Pearson correlation coefficient between {name0} and {name1} is {pcc}")
179