Skip to content

Commit 335a8ba

Browse files
committed
overhaul boxplots once more
1 parent 38ca97c commit 335a8ba

File tree

1 file changed

+24
-25
lines changed

1 file changed

+24
-25
lines changed

Preprocessing/multilinguality/eval_lang_emb_approximation.py

Lines changed: 24 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
import torch
88
from huggingface_hub import hf_hub_download
99

10-
# matplotlib.rcParams['mathtext.fontset'] = 'stix'
11-
# matplotlib.rcParams['font.family'] = 'STIXGeneral'
10+
matplotlib.rcParams['mathtext.fontset'] = 'stix'
11+
matplotlib.rcParams['font.family'] = 'STIXGeneral'
1212
matplotlib.rcParams['font.size'] = 7
1313
import matplotlib.pyplot as plt
1414
from Utility.utils import load_json_from_path
1515

16+
1617
def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embeddings, weighted_avg=False, min_n_langs=5, max_n_langs=30, threshold_percentile=95, loss_fn="MSE"):
1718
df = pd.read_csv(csv_path, sep="|")
1819

@@ -23,7 +24,7 @@ def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embe
2324

2425
features_per_closest_lang = 2
2526
# for combined, df has up to 5 features (if containing individual distances) per closest lang + 1 target lang column
26-
if "combined_dist_0" in df.columns:
27+
if "combined_dist_0" in df.columns:
2728
if "map_dist_0" in df.columns:
2829
features_per_closest_lang += 1
2930
if "asp_dist_0" in df.columns:
@@ -77,7 +78,7 @@ def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embe
7778
lang_emb = language_embeddings[iso_lookup[-1][lang]]
7879
avg_emb += lang_emb
7980
normalization_factor = len(langs)
80-
avg_emb /= normalization_factor # normalize
81+
avg_emb /= normalization_factor # normalize
8182
current_loss = loss_fn(avg_emb, y).item()
8283
all_losses.append(current_loss)
8384

@@ -111,44 +112,42 @@ def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embe
111112
os.makedirs(OUT_DIR, exist_ok=True)
112113

113114
fig, ax = plt.subplots(figsize=(6, 4))
114-
plt.ylabel(f"{args.loss_fn} between Approximated and Real")
115+
plt.ylabel(args.loss_fn)
115116
for i, csv_path in enumerate(csv_paths):
116117
print(f"csv_path: {os.path.basename(csv_path)}")
117118
for condition in weighted:
118-
losses = compute_loss_for_approximated_embeddings(csv_path,
119-
iso_lookup,
120-
lang_embs,
121-
condition,
122-
min_n_langs=args.min_n_langs,
123-
max_n_langs=args.max_n_langs,
124-
threshold_percentile=args.threshold_percentile,
125-
loss_fn=args.loss_fn)
119+
losses = compute_loss_for_approximated_embeddings(csv_path,
120+
iso_lookup,
121+
lang_embs,
122+
condition,
123+
min_n_langs=args.min_n_langs,
124+
max_n_langs=args.max_n_langs,
125+
threshold_percentile=args.threshold_percentile,
126+
loss_fn=args.loss_fn)
126127
print(f"weighted average: {condition} | mean loss: {np.mean(losses)}")
127128
losses_of_multiple_datasets.append(losses)
128129

129130
bp_dict = ax.boxplot(losses_of_multiple_datasets,
130-
labels =[
131-
"Random Neighbors",
132-
"Nearest according \nto inverse ASPF",
133-
"Nearest according \nto Map Distance",
134-
"Nearest according \nto Tree Distance",
135-
"Nearest according \nto Learned Distance",
136-
"Actual Nearest\n(Oracle)",
137-
],
131+
labels=["Random",
132+
"Inverse ASP",
133+
"Map Distance",
134+
"Tree Distance",
135+
"Learned Distance",
136+
"Oracle"],
138137
patch_artist=True,
139-
boxprops=dict(facecolor = "lightblue",
138+
boxprops=dict(facecolor="lightblue",
140139
),
141-
showfliers=False,
140+
showfliers=False,
142141
widths=0.55
143-
)
142+
)
144143
# major ticks every 0.1, minor ticks every 0.05, between 0.0 and 0.6
145144
major_ticks = np.arange(0, 1.0, 0.1)
146145
minor_ticks = np.arange(0, 1.0, 0.05)
147146
ax.set_yticks(major_ticks)
148147
ax.set_yticks(minor_ticks, minor=True)
149148
# horizontal grid lines for minor and major ticks
150149
ax.grid(which='both', linestyle='-', color='lightgray', linewidth=0.3, axis='y')
151-
plt.title(f"Using between {args.min_n_langs} and {args.max_n_langs} Nearest Neighbors to approximate an unseen Embedding")
150+
# plt.title(f"Using between {args.min_n_langs} and {args.max_n_langs} Nearest Neighbors to approximate an unseen Embedding")
152151
plt.xticks(rotation=45)
153152
plt.tight_layout()
154153
plt.show()

0 commit comments

Comments
 (0)