deepreg.util.save_array()   F

Complexity
  Conditions: 16

Size
  Total Lines: 65
  Code Lines: 39

Duplication
  Lines: 0
  Ratio: 0 %

Importance
  Changes: 0

Metric                            Value
eloc (effective lines of code)    39
dl (duplicated lines)             0
loc (lines of code)               65
rs                                2.4
c                                 0
b                                 0
f                                 0
cc (cyclomatic complexity)        16
nop (number of parameters)        7
How to fix

Long Method

Small methods make your code easier to understand, especially when combined with a good name. And if a method is small, finding a good name is usually much easier.

For example, if you find yourself adding comments to a method's body, that is usually a sign that the commented part should be extracted into a new method, with the comment serving as a starting point for the new method's name.

Commonly applied refactorings include Extract Method: pull a cohesive block of the long method into its own method with a descriptive name, as in the sketch below.
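A minimal sketch of that move, using entirely hypothetical names (summarize and drop_incomplete are not DeepReg code): the commented block becomes a helper whose name is taken from the comment.

def summarize(records):
    # the filtering that used to sit here under a "drop incomplete records"
    # comment is now a separate, self-describing method
    complete = drop_incomplete(records)
    return sum(r["value"] for r in complete) / max(len(complete), 1)


def drop_incomplete(records):
    """Keep only records that carry a value (extracted from summarize)."""
    return [r for r in records if r.get("value") is not None]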

Complexity

Complex functions like deepreg.util.save_array() often do a lot of different things. To break such a function down, we need to identify a cohesive component within it. A common approach to finding such a component is to look for parameters or statements that share the same prefixes or suffixes.

Once you have determined which pieces belong together, you can apply the Extract Class refactoring. If the component makes sense as a sub-class, Extract Subclass is also a candidate, and is often faster.
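For a standalone function like this one, the same grouping idea usually takes the form of Introduce Parameter Object rather than a literal Extract Class: the save_nifti, save_png and overwrite flags in the listing below all describe how the array should be written out, so they can travel together. A minimal sketch, assuming a hypothetical SaveOptions type that is not part of DeepReg:

from dataclasses import dataclass


@dataclass(frozen=True)
class SaveOptions:
    """Hypothetical bundle of the three boolean flags save_array takes today."""

    save_nifti: bool = True  # write the array as a .nii.gz volume
    save_png: bool = True  # write one png per depth slice
    overwrite: bool = True  # replace files that already exist


def save_array(save_dir: str, arr, name: str, normalize: bool,
               options: SaveOptions = SaveOptions()) -> None:
    # refactored signature only: the body would branch on options.save_nifti and
    # options.save_png instead of separate booleans, dropping nop from 7 to 5
    ...

This keeps callers explicit about what they want saved while shrinking the parameter list that the nop metric above flags.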

Source listing (deepreg/util.py):

import os
from datetime import datetime
from typing import Optional, Tuple, Union

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import tensorflow as tf

import deepreg.loss.label as label_loss
from deepreg import log
from deepreg.dataset.load import get_data_loader
from deepreg.dataset.loader.interface import DataLoader
from deepreg.dataset.loader.util import normalize_array

logger = log.get(__name__)


def build_dataset(
    dataset_config: dict,
    preprocess_config: dict,
    split: str,
    training: bool,
    repeat: bool,
) -> Tuple[Optional[DataLoader], Optional[tf.data.Dataset], Optional[int]]:
    """
    Function to prepare dataset for training and validation.
    :param dataset_config: configuration for dataset
    :param preprocess_config: configuration for preprocess
    :param split: train or valid or test
    :param training: bool, if true, data augmentation and shuffling will be added
    :param repeat: bool, if true, dataset will be repeated,
        true for train/valid dataset during model.fit

    :return:
    - (data_loader_train, dataset_train, steps_per_epoch_train)
    - (data_loader_val, dataset_val, steps_per_epoch_valid)

    Cannot move this function into deepreg/dataset/util.py
    as we need DataLoader to define the output
    """
    assert split in ["train", "valid", "test"]
    data_loader = get_data_loader(dataset_config, split)
    if data_loader is None:
        return None, None, None

    dataset = data_loader.get_dataset_and_preprocess(
        training=training, repeat=repeat, **preprocess_config
    )
    dataset_size = data_loader.num_samples
    steps_per_epoch = max(dataset_size // preprocess_config["batch_size"], 1)
    return data_loader, dataset, steps_per_epoch


def build_log_dir(log_dir: str, exp_name: str) -> str:
    """
    Build a log directory for the experiment.

    :param log_dir: path of the log directory.
    :param exp_name: name of the experiment.
    :return: the path of directory to save logs.
    """
    log_dir = os.path.join(
        os.path.expanduser(log_dir),
        datetime.now().strftime("%Y%m%d-%H%M%S") if exp_name == "" else exp_name,
    )
    if os.path.exists(log_dir):
        logger.warning("Log directory %s exists already.", log_dir)
    else:
        os.makedirs(log_dir)
    return log_dir


def save_array(
    save_dir: str,
    arr: Union[np.ndarray, tf.Tensor],
    name: str,
    normalize: bool,
    save_nifti: bool = True,
    save_png: bool = True,
    overwrite: bool = True,
):
    """
    :param save_dir: path of the directory to save
    :param arr: 3D or 4D array to be saved
    :param name: name of the array, e.g. image, label, etc.
    :param normalize: true if the array's value has to be normalized when saving pngs,
        false means the value is between [0, 1].
    :param save_nifti: if true, array will be saved in nifti
    :param save_png: if true, array will be saved in png
    :param overwrite: if false, will not save the file in case the file exists
    """
    if isinstance(arr, tf.Tensor):
        arr = arr.numpy()
    if len(arr.shape) not in [3, 4]:
        raise ValueError(f"arr must be 3d or 4d numpy array or tf tensor, got {arr}")
    is_4d = len(arr.shape) == 4
    if is_4d:
        # if 4D array, it must be 3 channels
        if arr.shape[3] != 3:
            raise ValueError(
                f"4d arr must have 3 channels as last dimension, "
                f"got arr.shape = {arr.shape}"
            )

    # save in nifti format
    if save_nifti:
        nifti_file_path = os.path.join(save_dir, name + ".nii.gz")
        if overwrite or (not os.path.exists(nifti_file_path)):
            # save only if need to overwrite or doesn't exist
            os.makedirs(save_dir, exist_ok=True)
            # output with Nifti1Image can be loaded by
            # - https://www.slicer.org/
            # - http://www.itksnap.org/
            # - http://ric.uthscsa.edu/mango/
            # However, outputs with Nifti2Image couldn't be loaded
            nib.save(
                img=nib.Nifti1Image(arr, affine=np.eye(4)), filename=nifti_file_path
            )

    # save in png
    if save_png:
        png_dir = os.path.join(save_dir, name)
        dir_existed = os.path.exists(png_dir)
        if normalize:
            # normalize arr such that it has only values between 0, 1
            arr = normalize_array(arr=arr)
        for depth_index in range(arr.shape[2]):
            png_file_path = os.path.join(png_dir, f"depth{depth_index}_{name}.png")
            if overwrite or (not os.path.exists(png_file_path)):
                if not dir_existed:
                    os.makedirs(png_dir, exist_ok=True)
                plt.imsave(
                    fname=png_file_path,
                    arr=arr[:, :, depth_index, :] if is_4d else arr[:, :, depth_index],
                    vmin=0,
                    vmax=1,
                    cmap="PiYG" if is_4d else "gray",
                )


def calculate_metrics(
    fixed_image: tf.Tensor,
    fixed_label: Optional[tf.Tensor],
    pred_fixed_image: Optional[tf.Tensor],
    pred_fixed_label: Optional[tf.Tensor],
    fixed_grid_ref: tf.Tensor,
    sample_index: int,
) -> dict:
    """
    Calculate image/label based metrics.
    :param fixed_image: shape=(batch, f_dim1, f_dim2, f_dim3)
    :param fixed_label: shape=(batch, f_dim1, f_dim2, f_dim3) or None
    :param pred_fixed_image: shape=(batch, f_dim1, f_dim2, f_dim3)
    :param pred_fixed_label: shape=(batch, f_dim1, f_dim2, f_dim3) or None
    :param fixed_grid_ref: shape=(1, f_dim1, f_dim2, f_dim3, 3)
    :param sample_index: int,
    :return: dictionary of metrics
    """

    if pred_fixed_image is not None:
        y_true = fixed_image[sample_index : (sample_index + 1), :, :, :]
        y_pred = pred_fixed_image[sample_index : (sample_index + 1), :, :, :]
        y_true = tf.expand_dims(y_true, axis=4)
        y_pred = tf.expand_dims(y_pred, axis=4)
        ssd = label_loss.SumSquaredDifference()(y_true=y_true, y_pred=y_pred).numpy()
    else:
        ssd = None

    if fixed_label is not None and pred_fixed_label is not None:
        y_true = fixed_label[sample_index : (sample_index + 1), :, :, :]
        y_pred = pred_fixed_label[sample_index : (sample_index + 1), :, :, :]
        dice = label_loss.DiceScore(binary=True)(y_true=y_true, y_pred=y_pred).numpy()
        tre = label_loss.compute_centroid_distance(
            y_true=y_true, y_pred=y_pred, grid=fixed_grid_ref
        ).numpy()[0]
    else:
        dice = None
        tre = None

    return dict(image_ssd=ssd, label_binary_dice=dice, label_tre=tre)


def save_metric_dict(save_dir: str, metrics: list):
    """
    :param save_dir: directory to save outputs
    :param metrics: list of dicts, dict must have key pair_index and label_index
    """
    os.makedirs(name=save_dir, exist_ok=True)

    # build dataframe
    # column is pair_index, label_index, and metrics
    df = pd.DataFrame(metrics)

    # save overall dataframe
    df.to_csv(os.path.join(save_dir, "metrics.csv"), index=False)

    # calculate mean/median/std per label
    df_per_label = df.drop(["pair_index"], axis=1)
    df_per_label = df_per_label.fillna(value=np.nan)
    df_per_label = df_per_label.groupby(["label_index"])
    df_per_label = pd.concat(
        [
            df_per_label.mean().add_suffix("_mean"),
            df_per_label.median().add_suffix("_median"),
            df_per_label.std().add_suffix("_std"),
        ],
        axis=1,
        sort=True,
    )
    df_per_label = df_per_label.reindex(
        sorted(df_per_label.columns), axis=1
    )  # sort columns
    df_per_label.to_csv(
        os.path.join(save_dir, "metrics_stats_per_label.csv"), index=True
    )

    # calculate overall mean/median/std
    df_all = df.drop(["pair_index", "label_index"], axis=1)
    df_all.describe().to_csv(
        os.path.join(save_dir, "metrics_stats_overall.csv"), index=True
    )
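Applied to the findings above, one way the 16 conditions and 65 lines of save_array could be split is to give each output format its own helper and leave save_array to validate the input, normalize when requested, and delegate. This is a sketch only: _save_nifti and _save_png_slices are hypothetical names, not part of DeepReg, and their bodies simply reuse the calls from the listing above.

import os

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np


def _save_nifti(save_dir: str, name: str, arr: np.ndarray, overwrite: bool) -> None:
    """Write arr as save_dir/name.nii.gz unless it exists and overwrite is False."""
    nifti_file_path = os.path.join(save_dir, name + ".nii.gz")
    if overwrite or not os.path.exists(nifti_file_path):
        os.makedirs(save_dir, exist_ok=True)
        nib.save(img=nib.Nifti1Image(arr, affine=np.eye(4)), filename=nifti_file_path)


def _save_png_slices(
    save_dir: str, name: str, arr: np.ndarray, overwrite: bool, is_4d: bool
) -> None:
    """Write one png per depth slice into save_dir/name/, values assumed in [0, 1]."""
    png_dir = os.path.join(save_dir, name)
    for depth_index in range(arr.shape[2]):
        png_file_path = os.path.join(png_dir, f"depth{depth_index}_{name}.png")
        if overwrite or not os.path.exists(png_file_path):
            os.makedirs(png_dir, exist_ok=True)
            plt.imsave(
                fname=png_file_path,
                arr=arr[:, :, depth_index, :] if is_4d else arr[:, :, depth_index],
                vmin=0,
                vmax=1,
                cmap="PiYG" if is_4d else "gray",
            )

With the two branches extracted, what remains of save_array is the tensor-to-numpy conversion, the shape checks and the optional normalize_array call, and each resulting function stays well below the 16 conditions reported above.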