Skip to content

Commit 3e35e5d

Browse files
committed
fix bug in numerical column index for benchmark plot
1 parent 88cb0c1 commit 3e35e5d

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

domainlab/utils/generate_benchmark_plots.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ def gen_plots(dataframe: pd.DataFrame, output_dir: str, use_param_index: bool):
9090
['param_index','task',' algo',' epos',' te_d',' seed',' params',' acc','precision',...]
9191
"""
9292
os.makedirs(output_dir, exist_ok=True)
93-
obj = dataframe.columns[G_DF_PLOT_COL_METRIC_START:G_DF_PLOT_COL_METRIC_END]
93+
pos_numeric_end = min(G_DF_PLOT_COL_METRIC_END, dataframe.shape[1])
94+
obj = dataframe.columns[G_DF_PLOT_COL_METRIC_START:pos_numeric_end]
9495
# boxplots
9596
for objective in obj:
9697
boxplot(
@@ -267,7 +268,8 @@ def scatterplot_matrix(
267268
but also between the parameter setups
268269
"""
269270
dataframe = dataframe_in.copy()
270-
index = list(range(G_DF_PLOT_COL_METRIC_START, G_DF_PLOT_COL_METRIC_END))
271+
pos_numeric_end = min(G_DF_PLOT_COL_METRIC_END, dataframe.shape[1])
272+
index = list(range(G_DF_PLOT_COL_METRIC_START, pos_numeric_end))
271273
if distinguish_param_setups:
272274
dataframe_ = dataframe.iloc[:, index]
273275
dataframe_.insert(
@@ -280,7 +282,8 @@ def scatterplot_matrix(
280282

281283
g_p = sns.pairplot(data=dataframe_, hue="label", corner=True, kind=kind)
282284
else:
283-
index_ = list(range(G_DF_PLOT_COL_METRIC_START, G_DF_PLOT_COL_METRIC_END))
285+
pos_numeric_end = min(G_DF_PLOT_COL_METRIC_END, dataframe.shape[1])
286+
index_ = list(range(G_DF_PLOT_COL_METRIC_START, pos_numeric_end))
284287
index_.insert(0, G_DF_TASK_COL)
285288
dataframe_ = dataframe.iloc[:, index_]
286289

@@ -417,7 +420,8 @@ def radar_plot(dataframe_in, file=None, distinguish_hyperparam=True):
417420
else:
418421
dataframe.insert(0, "label", dataframe[COLNAME_METHOD])
419422
# we need "G_DF_PLOT_COL_METRIC_START + 1" as we did insert the columns 'label' at index 0
420-
index = list(range(G_DF_PLOT_COL_METRIC_START + 1, G_DF_PLOT_COL_METRIC_END))
423+
pos_numeric_end = min(G_DF_PLOT_COL_METRIC_END, dataframe.shape[1])
424+
index = list(range(G_DF_PLOT_COL_METRIC_START + 1, pos_numeric_end))
421425
num_lines = len(dataframe["label"].unique())
422426
_, axis = plt.subplots(
423427
figsize=(9, 9 + (0.28 * num_lines)), subplot_kw=dict(polar=True)

0 commit comments

Comments
 (0)