Skip to content

Commit 52fbfc9

Browse files
Fixed bug when conditions are bigger than latent dim
1 parent c5ec7a9 commit 52fbfc9

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

scarches/models/trvae/modules.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def forward(self, x: torch.Tensor):
2222
if self.n_cond == 0:
2323
out = self.expr_L(x)
2424
else:
25-
expr, cond = torch.split(x, x.shape[1] - self.n_cond, dim=1)
25+
expr, cond = torch.split(x, [x.shape[1] - self.n_cond, self.n_cond], dim=1)
2626
out = self.expr_L(expr) + self.cond_L(cond)
2727
return out
2828

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import scanpy as sc
2+
import os
3+
import scarches as sca
4+
from scarches.dataset.trvae.data_handling import remove_sparsity
5+
import numpy as np
6+
import time
7+
import matplotlib.pyplot as plt
8+
9+
n_epochs_vae = 500
10+
early_stopping_kwargs = {
11+
"early_stopping_metric": "val_unweighted_loss",
12+
"threshold": 0,
13+
"patience": 20,
14+
"reduce_lr": True,
15+
"lr_patience": 13,
16+
"lr_factor": 0.1,
17+
}
18+
batch_key = "study"
19+
cell_type_key = "cell_type"
20+
21+
adata_all = sc.read(os.path.expanduser(f'~/Documents/benchmarking_datasets/pancreas_normalized.h5ad'))
22+
adata = adata_all.raw.to_adata()
23+
adata = remove_sparsity(adata)
24+
adata_conditions = adata.obs[batch_key].tolist()
25+
26+
trvae = sca.models.TRVAE(
27+
adata=adata,
28+
condition_key=batch_key,
29+
conditions=adata_conditions,
30+
hidden_layer_sizes=[128,128],
31+
)
32+
33+
trvae.train(
34+
n_epochs=n_epochs_vae,
35+
alpha_epoch_anneal=200,
36+
early_stopping_kwargs=early_stopping_kwargs
37+
)

0 commit comments

Comments
 (0)