Skip to content

Commit 3fce385

Browse files
new version
1 parent 2a8dfc0 commit 3fce385

File tree

8 files changed

+550
-2
lines changed

8 files changed

+550
-2
lines changed

build/lib/cplvm/models/cplvm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def target_log_prob_fn(deltax, size_factor_x, size_factor_y, s, zx, zy):
182182
)
183183

184184
else:
185-
185+
# import ipdb; ipdb.set_trace()
186186
def target_log_prob_fn(size_factor_x, size_factor_y, s, zx, zy):
187187
return model.log_prob(
188188
(size_factor_x, size_factor_y, s, zx, zy, X, Y)

cplvm/models/cplvm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def target_log_prob_fn(deltax, size_factor_x, size_factor_y, s, zx, zy):
182182
)
183183

184184
else:
185-
185+
# import ipdb; ipdb.set_trace()
186186
def target_log_prob_fn(size_factor_x, size_factor_y, s, zx, zy):
187187
return model.log_prob(
188188
(size_factor_x, size_factor_y, s, zx, zy, X, Y)
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
library(splatter)
2+
library(scater)
3+
library(magrittr)
4+
library(ggplot2)
5+
set.seed(1)
6+
7+
8+
## ----groups-------------------------------------------------------------------
9+
## Simulate a two-group scRNA-seq count matrix with splatter and split it into
## a background set and a foreground set for contrastive analysis:
##   * Group1 cells are treated as "nonresponsive", Group2 as "responsive".
##   * Half of the nonresponsive cells (rounded) form the background.
##   * The remaining nonresponsive cells plus all responsive cells form the
##     foreground; fg_labels marks them 0 (nonresponsive) / 1 (responsive).

# Directory that receives the three CSV files.
# NOTE(review): machine-specific location kept from the original script;
# change it here (one place) when running elsewhere.
out_dir <- "~/Documents/beehive/cplvm/data/splatter/two_clusters"

sim.groups <- splatSimulate(
  batchCells = 200,
  nGenes     = 500,
  group.prob = c(0.75, 0.25),
  method     = "groups",
  verbose    = FALSE  # spelled out: `F` is an ordinary variable that can be reassigned
)

# Cells in rows, genes in columns. Use the public accessor rather than
# reaching into the object's internal slots (@assays@data), whose layout is
# an implementation detail of the container class.
count_matrix <- as.data.frame(t(counts(sim.groups)))

# Split into responsive and nonresponsive cells
nonresponsive_cells <- count_matrix[which(sim.groups$Group == "Group1"), ]
responsive_cells    <- count_matrix[which(sim.groups$Group == "Group2"), ]

# Split nonresponsive cells into background and foreground
n_nonresponsive <- nrow(nonresponsive_cells)
n_bg <- round(n_nonresponsive / 2)
# seq_len() is safe when the count is 0 (seq(0) would yield c(1, 0)).
all_idx <- seq_len(n_nonresponsive)
bg_idx <- sample(all_idx, size = n_bg, replace = FALSE)
fg_nonresponsive_idx <- setdiff(all_idx, bg_idx)
# Sanity check: background and foreground must not share cells.
stopifnot(length(intersect(bg_idx, fg_nonresponsive_idx)) == 0)

bg_data <- nonresponsive_cells[bg_idx, ]
fg_nonresponsive_data <- nonresponsive_cells[fg_nonresponsive_idx, ]
fg_data <- rbind(fg_nonresponsive_data, responsive_cells)
# Labels follow the row order of fg_data: nonresponsive first, then responsive.
fg_labels <- c(rep(0, nrow(fg_nonresponsive_data)), rep(1, nrow(responsive_cells)))

# Save
# Create the target directory up front so write.csv() does not fail on a
# fresh checkout; a no-op when it already exists.
dir.create(path.expand(out_dir), recursive = TRUE, showWarnings = FALSE)
write.csv(bg_data, file.path(out_dir, "bg.csv"))
write.csv(fg_data, file.path(out_dir, "fg.csv"))
write.csv(data.frame(fg_label = fg_labels), file.path(out_dir, "fg_labels.csv"))
39+
40+

dist/cplvm-0.1-py3.7.egg

22 Bytes
Binary file not shown.
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
from cplvm import CPLVM
2+
from cplvm import CPLVMLogNormalApprox
3+
4+
import functools
5+
import warnings
6+
7+
import matplotlib.pyplot as plt
8+
import numpy as np
9+
import seaborn as sns
10+
import pandas as pd
11+
12+
import tensorflow.compat.v2 as tf
13+
import tensorflow_probability as tfp
14+
15+
from tensorflow_probability import distributions as tfd
16+
17+
import matplotlib
18+
import time
19+
20+
font = {"size": 30}
21+
matplotlib.rc("font", **font)
22+
matplotlib.rcParams["text.usetex"] = True
23+
24+
tf.enable_v2_behavior()
25+
26+
warnings.filterwarnings("ignore")
27+
28+
29+
if __name__ == "__main__":
    # Benchmark: wall-clock time needed to fit the CPLVM with variational
    # inference as the number of genes grows. For every gene count the
    # experiment is repeated NUM_REPEATS times on freshly sampled data, and
    # the timings are plotted and saved as a PNG.

    # Local import: `os` is not among this script's top-level imports.
    import os

    n_genes_list = [10, 50, 100, 500, 1000]
    num_datapoints_x, num_datapoints_y = 200, 200
    NUM_REPEATS = 2
    latent_dim_shared, latent_dim_foreground = 3, 3
    out_path = "../out/time_performance_num_genes_cplvm.png"

    # times[jj, ii] holds the fit time (seconds) of repeat jj at n_genes_list[ii].
    times = np.empty((NUM_REPEATS, len(n_genes_list)))

    for ii, n_genes in enumerate(n_genes_list):

        for jj in range(NUM_REPEATS):

            # ------- generate data ---------

            cplvm_for_data = CPLVM(
                k_shared=latent_dim_shared, k_foreground=latent_dim_foreground
            )

            concrete_cplvm_model = functools.partial(
                cplvm_for_data.model,
                data_dim=n_genes,
                num_datapoints_x=num_datapoints_x,
                num_datapoints_y=num_datapoints_y,
                counts_per_cell_X=1,
                counts_per_cell_Y=1,
                is_H0=False,
            )

            model = tfd.JointDistributionCoroutineAutoBatched(concrete_cplvm_model)
            # Only the two observed matrices (the last two sampled values) are
            # needed; the sampled latent variables are discarded.
            *_, X_sampled, Y_sampled = model.sample()
            X, Y = X_sampled.numpy(), Y_sampled.numpy()

            # ------- fit model ---------

            # perf_counter() is monotonic and high-resolution, so the measured
            # interval cannot be distorted by system clock adjustments.
            t0 = time.perf_counter()

            cplvm = CPLVM(
                k_shared=latent_dim_shared, k_foreground=latent_dim_foreground
            )
            approx_model = CPLVMLogNormalApprox(
                X, Y, latent_dim_shared, latent_dim_foreground
            )
            model_fit = cplvm._fit_model_vi(
                X, Y, approx_model, compute_size_factors=True, is_H0=False
            )

            times[jj, ii] = time.perf_counter() - t0

    # Long format for seaborn: "variable" = number of genes, "value" = seconds.
    times_df = pd.DataFrame(times, columns=n_genes_list)
    times_df_melted = pd.melt(times_df)

    plt.figure(figsize=(7, 7))
    # NOTE(review): newer seaborn releases deprecate `ci=` in favour of
    # `errorbar=("ci", 95)`; `ci=95` is kept so the script still runs on the
    # seaborn version it was written against -- confirm before upgrading.
    sns.lineplot(
        data=times_df_melted,
        x="variable",
        y="value",
        ci=95,
        err_style="bars",
        color="black",
    )
    plt.xlabel("Number of genes")
    plt.ylabel("Time (s)")
    plt.xscale("log")
    plt.tight_layout()
    # Make sure the output directory exists so savefig() does not fail after
    # the whole benchmark has already been run.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path)
    plt.show()
    # The trailing `ipdb.set_trace()` breakpoint was removed: it halted every
    # run at the end and made the script depend on the ipdb package.
93+
94+
95+
96+
97+
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import matplotlib
2+
from cplvm import CPLVM
3+
from cplvm import CPLVMLogNormalApprox
4+
5+
import functools
6+
import warnings
7+
8+
import matplotlib.pyplot as plt
9+
import numpy as np
10+
import seaborn as sns
11+
import pandas as pd
12+
import os
13+
from scipy.stats import poisson
14+
from scipy.special import logsumexp
15+
16+
17+
import tensorflow.compat.v2 as tf
18+
import tensorflow_probability as tfp
19+
20+
from tensorflow_probability import distributions as tfd
21+
from tensorflow_probability import bijectors as tfb
22+
23+
import matplotlib
24+
import time
25+
26+
font = {"size": 30}
27+
matplotlib.rc("font", **font)
28+
matplotlib.rcParams["text.usetex"] = True
29+
30+
tf.enable_v2_behavior()
31+
32+
warnings.filterwarnings("ignore")
33+
34+
35+
if __name__ == "__main__":
    # Benchmark: wall-clock time needed to fit the CPLVM with variational
    # inference as the number of samples per condition grows. For every
    # sample count the experiment is repeated NUM_REPEATS times on freshly
    # sampled data, and the timings are plotted and saved as a PNG.

    n_samples_per_condition_list = [10, 100, 1000]
    n_genes = 200
    NUM_REPEATS = 2
    latent_dim_shared, latent_dim_foreground = 3, 3
    out_path = "../out/time_performance_num_samples_cplvm.png"

    # times[jj, ii] holds the fit time (seconds) of repeat jj at
    # n_samples_per_condition_list[ii].
    times = np.empty((NUM_REPEATS, len(n_samples_per_condition_list)))

    for ii, n_samples in enumerate(n_samples_per_condition_list):

        for jj in range(NUM_REPEATS):

            # Both conditions get the same number of samples.
            num_datapoints_x, num_datapoints_y = n_samples, n_samples

            # ------- generate data ---------

            cplvm_for_data = CPLVM(
                k_shared=latent_dim_shared, k_foreground=latent_dim_foreground
            )

            concrete_cplvm_model = functools.partial(
                cplvm_for_data.model,
                data_dim=n_genes,
                num_datapoints_x=num_datapoints_x,
                num_datapoints_y=num_datapoints_y,
                counts_per_cell_X=1,
                counts_per_cell_Y=1,
                is_H0=False,
            )

            model = tfd.JointDistributionCoroutineAutoBatched(concrete_cplvm_model)
            # Only the two observed matrices (the last two sampled values) are
            # needed; the sampled latent variables are discarded.
            *_, X_sampled, Y_sampled = model.sample()
            X, Y = X_sampled.numpy(), Y_sampled.numpy()

            # ------- fit model ---------

            # perf_counter() is monotonic and high-resolution, so the measured
            # interval cannot be distorted by system clock adjustments.
            t0 = time.perf_counter()

            cplvm = CPLVM(
                k_shared=latent_dim_shared, k_foreground=latent_dim_foreground
            )
            approx_model = CPLVMLogNormalApprox(
                X, Y, latent_dim_shared, latent_dim_foreground
            )
            model_fit = cplvm._fit_model_vi(
                X, Y, approx_model, compute_size_factors=True, is_H0=False
            )

            times[jj, ii] = time.perf_counter() - t0

    # Long format for seaborn: "variable" = samples per condition,
    # "value" = seconds.
    times_df = pd.DataFrame(times, columns=n_samples_per_condition_list)
    times_df_melted = pd.melt(times_df)

    plt.figure(figsize=(7, 7))
    # NOTE(review): newer seaborn releases deprecate `ci=` in favour of
    # `errorbar=("ci", 95)`; `ci=95` is kept so the script still runs on the
    # seaborn version it was written against -- confirm before upgrading.
    sns.lineplot(
        data=times_df_melted, x="variable", y="value", ci=95, err_style="bars"
    )
    plt.xlabel("Number of samples\nin each condition")
    plt.ylabel("Time (s)")
    plt.tight_layout()
    # Make sure the output directory exists so savefig() does not fail after
    # the whole benchmark has already been run.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    plt.savefig(out_path)
    plt.show()
    # The trailing `ipdb.set_trace()` breakpoint was removed: it halted every
    # run at the end and made the script depend on the ipdb package.
102+
103+
104+
105+
106+

0 commit comments

Comments
 (0)