@@ -333,11 +333,22 @@ from jax.numpy import (exp, log, log1p, expm1, abs, mean,
333
333
drawsStatement <- " samples = mcmc.get_samples(group_by_chain=True) # dict: name -> [chains, draws, *axes]"
334
334
335
335
# # Build labeled DataFrame; drop deterministic RVs; 1-based fallback indices
336
+ # # Column naming: <rv>_<label1>[_<label2>...]; labels: spaces→'.', only [A-Za-z0-9._], collapse repeated dots, no trailing '.'
336
337
drawsDFStatement <- paste(
337
- " import numpy as np, pandas as pd" ,
338
+ " import numpy as np, pandas as pd, string " ,
338
339
" # Flatten (chains*draws) and expand RVs using rv_to_axes + axis_labels" ,
339
340
" flat = {name: np.reshape(val, (-1,) + val.shape[2:]) for name, val in list(samples.items())}" ,
340
341
" out = {}" ,
342
+ " allowed = set(string.ascii_letters + string.digits + '._')" ,
343
+ " def sanitize_label(s):" ,
344
+ " s = str(s).strip().replace(' ', '.') # spaces -> dots" ,
345
+ " s = ''.join((ch if ch in allowed else '.') for ch in s)" ,
346
+ " # collapse repeated dots without regex" ,
347
+ " while '..' in s:" ,
348
+ " s = s.replace('..', '.')" ,
349
+ " s = s.strip('.')" ,
350
+ " return s if s else '1'" ,
351
+ " " ,
341
352
" for name, arr in list(flat.items()):" ,
342
353
" # Drop deterministic RVs entirely" ,
343
354
" if name in drop_names:" ,
@@ -352,24 +363,28 @@ from jax.numpy import (exp, log, log1p, expm1, abs, mean,
352
363
" arr2 = arr.reshape(arr.shape[0], int(np.prod(trailing)))" ,
353
364
" for j in range(arr2.shape[1]):" ,
354
365
" idx = np.unravel_index(j, trailing)" ,
355
- " name_parts = []" ,
366
+ " parts = []" ,
356
367
" for axis, i in enumerate(idx):" ,
357
368
" if axis < len(axes):" ,
358
369
" axis_name = axes[axis]" ,
359
370
" labels = axis_labels.get(axis_name)" ,
360
371
" if labels is not None:" ,
361
372
" labs = np.asarray(labels).astype(str)" ,
362
- " name_parts.append(labs[i] if i < labs.shape[0] else str(i + 1))" , # 1-based fallback
373
+ " lab = labs[i] if i < labs.shape[0] else str(i + 1)" ,
374
+ " parts.append(sanitize_label(lab))" ,
363
375
" else:" ,
364
- " name_parts .append(str(i + 1))" , # 1-based fallback when mapped but unlabeled",
376
+ " parts .append(str(i + 1))" , # 1-based fallback when mapped but unlabeled",
365
377
" else:" ,
366
- " name_parts.append(str(i + 1))" , # 1-based fallback for unmapped extra axis",
367
- " col = f'{name}[{','.join(name_parts)}]'" ,
378
+ " parts.append(str(i + 1))" , # 1-based fallback for extra axis",
379
+ " # No brackets; join labels with '_' so tibble won't append trailing dots" ,
380
+ " col = name if not parts else name + '_' + '_'.join(parts)" ,
368
381
" out[col] = arr2[:, j]" ,
369
382
" drawsDF = pd.DataFrame(out)" ,
370
383
sep = " \n "
371
384
)
372
385
386
+
387
+
373
388
# # Aggregate and execute
374
389
codeStatements = c(
375
390
importStatements ,
0 commit comments