1010import torch
1111import torch .nn as nn
1212import torchio
13+ from GANDLF .models .modelBase import get_final_layer
1314
1415
1516def resample_image (
@@ -155,7 +156,7 @@ def reverse_one_hot(predmask_array, class_list):
155156 final_mask = 0
156157 special_cases_to_check = ["||" ]
157158 special_case_detected = False
158- max = 0
159+ max_current = 0
159160
160161 for _class in class_list :
161162 for case in special_cases_to_check :
@@ -166,8 +167,8 @@ def reverse_one_hot(predmask_array, class_list):
166167 case
167168 ) # if present, then split the sub-class
168169 for i in class_split : # find the max for computation later on
169- if int (i ) > max :
170- max = int (i )
170+ if int (i ) > max_current :
171+ max_current = int (i )
171172
172173 if special_case_detected :
173174 start_idx = 0
@@ -183,7 +184,7 @@ def reverse_one_hot(predmask_array, class_list):
183184 predmask_array [0 , :, :, :], dtype = int
184185 ) # predmask_array[i,:,:,:].long()
185186 # temp_sum = torch.sum(output)
186- # output_2 = (max - torch.sum(output)) % max
187+ # output_2 = (max_current - torch.sum(output)) % max_current
187188 # test_2 = 1
188189 else :
189190 for idx , _class in enumerate (class_list ):
@@ -335,7 +336,7 @@ def populate_header_in_parameters(parameters, headers):
335336 if len (headers ["predictionHeaders" ]) > 0 :
336337 parameters ["model" ]["num_classes" ] = len (headers ["predictionHeaders" ])
337338 is_regression , _ , _ = find_problem_type (
338- parameters ["headers" ], parameters ["model" ]["final_layer" ]
339+ parameters ["headers" ], get_final_layer ( parameters ["model" ]["final_layer" ])
339340 )
340341
341342 # if the problem type is classification/segmentation, ensure the number of classes are picked from the configuration
0 commit comments