Skip to content

Commit 4b26dfa

Browse files
authored
Merge pull request #135 from sarthakpati/134_regression_bugfix
simple bugfix for #134
2 parents 66288bf + 28399f5 commit 4b26dfa

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

GANDLF/utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import torch
1111
import torch.nn as nn
1212
import torchio
13+
from GANDLF.models.modelBase import get_final_layer
1314

1415

1516
def resample_image(
@@ -155,7 +156,7 @@ def reverse_one_hot(predmask_array, class_list):
155156
final_mask = 0
156157
special_cases_to_check = ["||"]
157158
special_case_detected = False
158-
max = 0
159+
max_current = 0
159160

160161
for _class in class_list:
161162
for case in special_cases_to_check:
@@ -166,8 +167,8 @@ def reverse_one_hot(predmask_array, class_list):
166167
case
167168
) # if present, then split the sub-class
168169
for i in class_split: # find the max for computation later on
169-
if int(i) > max:
170-
max = int(i)
170+
if int(i) > max_current:
171+
max_current = int(i)
171172

172173
if special_case_detected:
173174
start_idx = 0
@@ -183,7 +184,7 @@ def reverse_one_hot(predmask_array, class_list):
183184
predmask_array[0, :, :, :], dtype=int
184185
) # predmask_array[i,:,:,:].long()
185186
# temp_sum = torch.sum(output)
186-
# output_2 = (max - torch.sum(output)) % max
187+
# output_2 = (max_current - torch.sum(output)) % max_current
187188
# test_2 = 1
188189
else:
189190
for idx, _class in enumerate(class_list):
@@ -335,7 +336,7 @@ def populate_header_in_parameters(parameters, headers):
335336
if len(headers["predictionHeaders"]) > 0:
336337
parameters["model"]["num_classes"] = len(headers["predictionHeaders"])
337338
is_regression, _, _ = find_problem_type(
338-
parameters["headers"], parameters["model"]["final_layer"]
339+
parameters["headers"], get_final_layer(parameters["model"]["final_layer"])
339340
)
340341

341342
# if the problem type is classification/segmentation, ensure the number of classes are picked from the configuration

GANDLF/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
#!/usr/bin/env python3
22
# -*- coding: UTF-8 -*-
3-
__version__ = "0.0.10-dev"
3+
__version__ = "0.0.10"

testing/config_classification.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ model:
1717
- 1
1818
- 2
1919
dimension: 2
20-
final_layer: None
20+
final_layer: softmax
2121
nested_training:
2222
testing: -2
2323
validation: -2

0 commit comments

Comments
 (0)