1
|
|
|
# Copyright 2014 Diamond Light Source Ltd. |
2
|
|
|
# |
3
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
4
|
|
|
# you may not use this file except in compliance with the License. |
5
|
|
|
# You may obtain a copy of the License at |
6
|
|
|
# |
7
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0 |
8
|
|
|
# |
9
|
|
|
# Unless required by applicable law or agreed to in writing, software |
10
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS, |
11
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12
|
|
|
# See the License for the specific language governing permissions and |
13
|
|
|
# limitations under the License. |
14
|
|
|
|
15
|
|
|
""" |
16
|
|
|
.. module:: comparison |
17
|
|
|
:platform: Unix |
18
|
|
|
:synopsis: A plugin to compare two datasets, given as input datasets, and print the RMSD between the two. |
19
|
|
|
The data is unchanged. |
20
|
|
|
|
21
|
|
|
.. moduleauthor:: Jacob Williamson <[email protected]> |
22
|
|
|
""" |
23
|
|
|
|
24
|
|
|
from savu.plugins.utils import register_plugin |
25
|
|
|
from savu.plugins.plugin import Plugin |
26
|
|
|
from savu.plugins.driver.cpu_plugin import CpuPlugin |
27
|
|
|
from savu.core.iterate_plugin_group_utils import enable_iterative_loop, \ |
28
|
|
|
check_if_end_plugin_in_iterate_group, setup_extra_plugin_data_padding |
29
|
|
|
|
30
|
|
|
import numpy as np |
31
|
|
|
|
32
|
|
|
# This decorator is required for the configurator to recognise the plugin
@register_plugin
class Comparison(Plugin, CpuPlugin):
    """Compare two input datasets and report how similar they are.

    The two input datasets are passed through unchanged; as a side effect
    the plugin prints the normalised root mean square deviation (RMSD)
    between the two datasets, the RMSD with the contrast of the first
    dataset flipped, and the Pearson correlation coefficient (PCC).
    """

    def __init__(self):
        super(Comparison, self).__init__("Comparison")

    def nInput_datasets(self):
        """Two input datasets are required: the pair to be compared."""
        return 2

    def nOutput_datasets(self):
        """Return the number of output datasets.

        An extra output dataset is needed when this plugin is the final
        plugin of an iterative plugin group.
        """
        return 3 if check_if_end_plugin_in_iterate_group(self.exp) else 2

    def nClone_datasets(self):
        """A clone dataset is only needed at the end of an iterative group."""
        return 1 if check_if_end_plugin_in_iterate_group(self.exp) else 0

    @enable_iterative_loop
    def setup(self):
        """Tell the framework about the data transport to/from the plugin.

        Both outputs mirror their corresponding inputs (the data itself is
        not modified), and every dataset is processed a single frame at a
        time using the user-selected pattern.
        """
        # in_dataset/out_dataset are Data instances; the outputs already
        # exist but are empty and must be populated here.
        in_dataset, out_dataset = self.get_datasets()

        # Outputs are straight copies of the input dataset descriptions.
        out_dataset[0].create_dataset(in_dataset[0])
        out_dataset[1].create_dataset(in_dataset[1])

        # Per-frame accumulators, reduced in post_process.
        self.rss_list = []          # residual sum of squares per frame pair
        self.flipped_rss_list = []  # RSS with dataset 1's contrast flipped
        self.data_points_list = []  # number of data points per frame
        self.partial_cc_top = []    # PCC numerator contributions
        self.partial_cc_bottom = ([], [])  # PCC denominator contributions

        # PluginData instances define the access pattern and frame count.
        in_pData, out_pData = self.get_plugin_datasets()
        pattern = self.parameters['pattern']
        for pdata in in_pData:
            pdata.plugin_data_setup(pattern, 'single')
        out_pData[0].plugin_data_setup(pattern, 'single')
        out_pData[1].plugin_data_setup(pattern, 'single')

    def pre_process(self):
        """Fetch the global stats (min/max/mean) for both input datasets.

        Called once before any frame processing. If a dataset is missing
        stats metadata its entry in ``self.stats`` stays ``None`` and the
        comparison is skipped in ``process_frames``/``post_process``.
        """
        in_datasets = self.get_in_datasets()
        # Fall back to generic names when a dataset has no group name.
        self.names = [in_datasets[0].group_name or "dataset1",
                      in_datasets[1].group_name or "dataset2"]

        self.stats = [None, None]
        self.ranges = [None, None]
        for i in range(2):
            try:
                # stats dictionary with (at least) "max", "min", "mean"
                self.stats[i] = self.stats_obj.get_stats_from_dataset(
                    in_datasets[i])
                self.ranges[i] = self.stats[i]["max"] - self.stats[i]["min"]
            except KeyError:
                print(f"Can't find stats metadata in {self.names[i]}, "
                      f"cannot do comparison")

    def process_frames(self, data):
        """Accumulate comparison statistics for one pair of frames.

        :param data: list of two numpy arrays, one frame from each input.
        :returns: the two frames, unchanged.
        """
        if data[0].shape == data[1].shape:
            if self.stats[0] is not None and self.stats[1] is not None:
                # Normalise both frames to [0, 1] using each dataset's own
                # global range so the RMSD is scale-independent.
                scaled_data = [
                    self._scale_data(data[0], self.stats[0]["min"],
                                     self.ranges[0]),
                    self._scale_data(data[1], self.stats[1]["min"],
                                     self.ranges[1])]
                self.rss_list.append(
                    self.stats_obj.calc_rss(scaled_data[0], scaled_data[1]))
                self.data_points_list.append(data[0].size)

                # Flip the contrast of the first frame (scaled data lies in
                # [0, 1]) to also detect inversely-correlated datasets.
                flipped_data = 1 - scaled_data[0]
                self.flipped_rss_list.append(
                    self.stats_obj.calc_rss(flipped_data, scaled_data[1]))

                # Partial sums for the Pearson correlation coefficient,
                # combined in post_process.
                self.partial_cc_top.append(np.sum(
                    (data[0] - self.stats[0]["mean"])
                    * (data[1] - self.stats[1]["mean"])))
                self.partial_cc_bottom[0].append(
                    np.sum((data[0] - self.stats[0]["mean"]) ** 2))
                self.partial_cc_bottom[1].append(
                    np.sum((data[1] - self.stats[1]["mean"]) ** 2))
        else:
            # Fixed message typo ("can't calculated" -> "can't calculate").
            print("Arrays different sizes, can't calculate residuals.")
        return [data[0], data[1]]

    def _scale_data(self, data, vol_min, vol_range, new_min=0, new_range=1):
        """Linearly rescale a data slice.

        Maps values from [vol_min, vol_min + vol_range] onto
        [new_min, new_min + new_range] (by default [0, 1]).

        NOTE(review): assumes vol_range is non-zero — a constant-valued
        dataset would raise ZeroDivisionError; confirm the stats metadata
        guarantees max > min.
        """
        return (data - vol_min) * (new_range / vol_range) + new_min

    def post_process(self):
        """Reduce the per-frame accumulators and print the comparison.

        Prints the normalised RMSD, the contrast-flipped normalised RMSD
        and the Pearson correlation coefficient. Does nothing if stats
        metadata was missing for either dataset.
        """
        if self.stats[0] is None or self.stats[1] is None:
            return

        total_rss = sum(self.rss_list)
        total_data = sum(self.data_points_list)
        rmsd = self.stats_obj.rmsd_from_rss(total_rss, total_data)
        print(f"Normalised root mean square deviation between "
              f"{self.names[0]} and {self.names[1]} is {rmsd}")

        total_flipped_rss = sum(self.flipped_rss_list)
        frmsd = self.stats_obj.rmsd_from_rss(total_flipped_rss, total_data)
        # Adjacent f-string literals (not a backslash continuation inside
        # the string) so no stray indentation ends up in the message.
        print(f"Normalised root mean square deviation between "
              f"{self.names[0]} and {self.names[1]} is {frmsd}, "
              f"when the contrast is flipped")

        pcc = np.sum(self.partial_cc_top) / np.sqrt(
            np.sum(self.partial_cc_bottom[0])
            * np.sum(self.partial_cc_bottom[1]))
        print(f"Pearson correlation coefficient between {self.names[0]} "
              f"and {self.names[1]} is {pcc}")