77import torch
88from huggingface_hub import hf_hub_download
99
10- # matplotlib.rcParams['mathtext.fontset'] = 'stix'
11- # matplotlib.rcParams['font.family'] = 'STIXGeneral'
10+ matplotlib .rcParams ['mathtext.fontset' ] = 'stix'
11+ matplotlib .rcParams ['font.family' ] = 'STIXGeneral'
1212matplotlib .rcParams ['font.size' ] = 7
1313import matplotlib .pyplot as plt
1414from Utility .utils import load_json_from_path
1515
16+
1617def compute_loss_for_approximated_embeddings (csv_path , iso_lookup , language_embeddings , weighted_avg = False , min_n_langs = 5 , max_n_langs = 30 , threshold_percentile = 95 , loss_fn = "MSE" ):
1718 df = pd .read_csv (csv_path , sep = "|" )
1819
@@ -23,7 +24,7 @@ def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embe
2324
2425 features_per_closest_lang = 2
2526 # for combined, df has up to 5 features (if containing individual distances) per closest lang + 1 target lang column
26- if "combined_dist_0" in df .columns :
27+ if "combined_dist_0" in df .columns :
2728 if "map_dist_0" in df .columns :
2829 features_per_closest_lang += 1
2930 if "asp_dist_0" in df .columns :
@@ -77,7 +78,7 @@ def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embe
7778 lang_emb = language_embeddings [iso_lookup [- 1 ][lang ]]
7879 avg_emb += lang_emb
7980 normalization_factor = len (langs )
80- avg_emb /= normalization_factor # normalize
81+ avg_emb /= normalization_factor # normalize
8182 current_loss = loss_fn (avg_emb , y ).item ()
8283 all_losses .append (current_loss )
8384
@@ -111,44 +112,42 @@ def compute_loss_for_approximated_embeddings(csv_path, iso_lookup, language_embe
111112 os .makedirs (OUT_DIR , exist_ok = True )
112113
113114 fig , ax = plt .subplots (figsize = (6 , 4 ))
114- plt .ylabel (f" { args .loss_fn } between Approximated and Real" )
115+ plt .ylabel (args .loss_fn )
115116 for i , csv_path in enumerate (csv_paths ):
116117 print (f"csv_path: { os .path .basename (csv_path )} " )
117118 for condition in weighted :
118- losses = compute_loss_for_approximated_embeddings (csv_path ,
119- iso_lookup ,
120- lang_embs ,
121- condition ,
122- min_n_langs = args .min_n_langs ,
123- max_n_langs = args .max_n_langs ,
124- threshold_percentile = args .threshold_percentile ,
125- loss_fn = args .loss_fn )
119+ losses = compute_loss_for_approximated_embeddings (csv_path ,
120+ iso_lookup ,
121+ lang_embs ,
122+ condition ,
123+ min_n_langs = args .min_n_langs ,
124+ max_n_langs = args .max_n_langs ,
125+ threshold_percentile = args .threshold_percentile ,
126+ loss_fn = args .loss_fn )
126127 print (f"weighted average: { condition } | mean loss: { np .mean (losses )} " )
127128 losses_of_multiple_datasets .append (losses )
128129
129130 bp_dict = ax .boxplot (losses_of_multiple_datasets ,
130- labels = [
131- "Random Neighbors" ,
132- "Nearest according \n to inverse ASPF" ,
133- "Nearest according \n to Map Distance" ,
134- "Nearest according \n to Tree Distance" ,
135- "Nearest according \n to Learned Distance" ,
136- "Actual Nearest\n (Oracle)" ,
137- ],
131+ labels = ["Random" ,
132+ "Inverse ASP" ,
133+ "Map Distance" ,
134+ "Tree Distance" ,
135+ "Learned Distance" ,
136+ "Oracle" ],
138137 patch_artist = True ,
139- boxprops = dict (facecolor = "lightblue" ,
138+ boxprops = dict (facecolor = "lightblue" ,
140139 ),
141- showfliers = False ,
140+ showfliers = False ,
142141 widths = 0.55
143- )
142+ )
144143 # major ticks every 0.1, minor ticks every 0.05, from 0.0 up to (but not including) 1.0
145144 major_ticks = np .arange (0 , 1.0 , 0.1 )
146145 minor_ticks = np .arange (0 , 1.0 , 0.05 )
147146 ax .set_yticks (major_ticks )
148147 ax .set_yticks (minor_ticks , minor = True )
149148 # horizontal grid lines for minor and major ticks
150149 ax .grid (which = 'both' , linestyle = '-' , color = 'lightgray' , linewidth = 0.3 , axis = 'y' )
151- plt .title (f"Using between { args .min_n_langs } and { args .max_n_langs } Nearest Neighbors to approximate an unseen Embedding" )
150+ # plt.title(f"Using between {args.min_n_langs} and {args.max_n_langs} Nearest Neighbors to approximate an unseen Embedding")
152151 plt .xticks (rotation = 45 )
153152 plt .tight_layout ()
154153 plt .show ()
0 commit comments