Skip to content

Commit c9cfa4c

Browse files
committed
completed arviz replacement
1 parent 6c2b714 commit c9cfa4c

File tree

1 file changed

+21
-6
lines changed

1 file changed

+21
-6
lines changed

R/dag_numpyro.R

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -333,11 +333,22 @@ from jax.numpy import (exp, log, log1p, expm1, abs, mean,
333333
drawsStatement <- "samples = mcmc.get_samples(group_by_chain=True) # dict: name -> [chains, draws, *axes]"
334334

335335
## Build labeled DataFrame; drop deterministic RVs; 1-based fallback indices
336+
## Column naming: <rv>_<label1>[_<label2>...]; labels: spaces→'.', only [A-Za-z0-9._], collapse repeated dots, no trailing '.'
336337
drawsDFStatement <- paste(
337-
"import numpy as np, pandas as pd",
338+
"import numpy as np, pandas as pd, string",
338339
"# Flatten (chains*draws) and expand RVs using rv_to_axes + axis_labels",
339340
"flat = {name: np.reshape(val, (-1,) + val.shape[2:]) for name, val in list(samples.items())}",
340341
"out = {}",
342+
"allowed = set(string.ascii_letters + string.digits + '._')",
343+
"def sanitize_label(s):",
344+
" s = str(s).strip().replace(' ', '.') # spaces -> dots",
345+
" s = ''.join((ch if ch in allowed else '.') for ch in s)",
346+
" # collapse repeated dots without regex",
347+
" while '..' in s:",
348+
" s = s.replace('..', '.')",
349+
" s = s.strip('.')",
350+
" return s if s else '1'",
351+
"",
341352
"for name, arr in list(flat.items()):",
342353
" # Drop deterministic RVs entirely",
343354
" if name in drop_names:",
@@ -352,24 +363,28 @@ from jax.numpy import (exp, log, log1p, expm1, abs, mean,
352363
" arr2 = arr.reshape(arr.shape[0], int(np.prod(trailing)))",
353364
" for j in range(arr2.shape[1]):",
354365
" idx = np.unravel_index(j, trailing)",
355-
" name_parts = []",
366+
" parts = []",
356367
" for axis, i in enumerate(idx):",
357368
" if axis < len(axes):",
358369
" axis_name = axes[axis]",
359370
" labels = axis_labels.get(axis_name)",
360371
" if labels is not None:",
361372
" labs = np.asarray(labels).astype(str)",
362-
" name_parts.append(labs[i] if i < labs.shape[0] else str(i + 1))", # 1-based fallback
373+
" lab = labs[i] if i < labs.shape[0] else str(i + 1)",
374+
" parts.append(sanitize_label(lab))",
363375
" else:",
364-
" name_parts.append(str(i + 1))", # 1-based fallback when mapped but unlabeled",
376+
" parts.append(str(i + 1))", # 1-based fallback when mapped but unlabeled",
365377
" else:",
366-
" name_parts.append(str(i + 1))", # 1-based fallback for unmapped extra axis",
367-
" col = f'{name}[{','.join(name_parts)}]'",
378+
" parts.append(str(i + 1))", # 1-based fallback for extra axis",
379+
" # No brackets; join labels with '_' so tibble won't append trailing dots",
380+
" col = name if not parts else name + '_' + '_'.join(parts)",
368381
" out[col] = arr2[:, j]",
369382
"drawsDF = pd.DataFrame(out)",
370383
sep = "\n"
371384
)
372385

386+
387+
373388
## Aggregate and execute
374389
codeStatements = c(
375390
importStatements,

0 commit comments

Comments
 (0)