|
@@ 168-187 (lines=20) @@
|
| 165 |
|
top_n (:obj:`int`): Back annotate top n probabilities. |
| 166 |
|
""" |
| 167 |
|
# Verify split id first |
| 168 |
|
if split_id == -1: |
| 169 |
|
if type(split_name) is tuple: |
| 170 |
|
time_array = [] |
| 171 |
|
for each_split in split_name: |
| 172 |
|
each_split_id = self.info['split_sets'].index(each_split) |
| 173 |
|
if 0 < each_split_id < len(self.info['split_sets']): |
| 174 |
|
time_array += self.info['split_timearray'][each_split_id] |
| 175 |
|
else: |
| 176 |
|
if split_name in self.info['split_sets']: |
| 177 |
|
split_id = self.info['split_sets'].index(split_name) |
| 178 |
|
if 0 < split_id < len(self.info['split_sets']): |
| 179 |
|
time_array = self.info['split_timearray'][split_id] |
| 180 |
|
else: |
| 181 |
|
logger.error('Failed to find split set with name %s.' % split_name) |
| 182 |
|
return |
| 183 |
|
elif 0 < split_id < len(self.info['split_sets']): |
| 184 |
|
time_array = self.info['split_timearray'][split_id] |
| 185 |
|
else: |
| 186 |
|
logger.error('Split set index %d out of bound.' % split_id) |
| 187 |
|
return |
| 188 |
|
# Check length of prediction and time array |
| 189 |
|
if prediction_proba.shape[0] != len(time_array): |
| 190 |
|
logger.error('Prediction size miss-match. There are %d time points with only %d labels given.' % |
|
@@ 125-144 (lines=20) @@
|
| 122 |
|
split_name (:obj:`str`): The name of the split set to be annotated (required if split_id is not specified). |
| 123 |
|
""" |
| 124 |
|
# Verify split id first |
| 125 |
|
if split_id == -1: |
| 126 |
|
if type(split_name) is tuple: |
| 127 |
|
time_array = [] |
| 128 |
|
for each_split in split_name: |
| 129 |
|
each_split_id = self.info['split_sets'].index(each_split) |
| 130 |
|
if 0 < each_split_id < len(self.info['split_sets']): |
| 131 |
|
time_array += self.info['split_timearray'][each_split_id] |
| 132 |
|
else: |
| 133 |
|
if split_name in self.info['split_sets']: |
| 134 |
|
split_id = self.info['split_sets'].index(split_name) |
| 135 |
|
if 0 < split_id < len(self.info['split_sets']): |
| 136 |
|
time_array = self.info['split_timearray'][split_id] |
| 137 |
|
else: |
| 138 |
|
logger.error('Failed to find split set with name %s.' % split_name) |
| 139 |
|
return |
| 140 |
|
elif 0 < split_id < len(self.info['split_sets']): |
| 141 |
|
time_array = self.info['split_timearray'][split_id] |
| 142 |
|
else: |
| 143 |
|
logger.error('Split set index %d out of bound.' % split_id) |
| 144 |
|
return |
| 145 |
|
# Check length of prediction and time array |
| 146 |
|
if prediction.shape[0] != len(time_array): |
| 147 |
|
logger.error('Prediction size miss-match. There are %d time points with only %d labels given.' % |