Skip to content

Commit 204206a

Browse files
author
naghipourfar
committed
fixed ZINB loss computation bug, get_latent, and model_path
1 parent 55d421a commit 204206a

File tree

3 files changed

+7
-9
lines changed

3 files changed

+7
-9
lines changed

scarches/models/_losses.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,6 @@ def loss(self, y_true, y_pred, mean=True):
110110
final = tf.divide(tf.reduce_sum(final), nelem)
111111
else:
112112
final = tf.reduce_mean(final)
113-
else:
114-
final = tf.reduce_sum(final)
115113

116114
return final
117115

scarches/models/cvae.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,10 @@ def __init__(self, x_dimension: int, conditions: list, task_name: str = "unknown
130130
}
131131

132132
self.training_kwargs = {
133-
"learning_rate": self.lr,
133+
"lr": self.lr,
134134
"alpha": self.alpha,
135135
"eta": self.eta,
136136
"clip_value": self.clip_value,
137-
"model_path": self.model_base_path,
138137
}
139138

140139
self.init_w = keras.initializers.glorot_normal()
@@ -169,11 +168,10 @@ def update_kwargs(self):
169168
}
170169

171170
self.training_kwargs = {
172-
"learning_rate": self.lr,
171+
"lr": self.lr,
173172
"alpha": self.alpha,
174173
"eta": self.eta,
175174
"clip_value": self.clip_value,
176-
"model_path": self.model_base_path,
177175
}
178176

179177
@classmethod
@@ -422,9 +420,9 @@ def get_z_latent(self, adata, encoder_labels, return_mean=False):
422420

423421
encoder_inputs = [adata.X, encoder_labels]
424422
if return_mean:
425-
latent = self.encoder_model.predict(encoder_inputs)[0]
423+
latent = self.encoder.predict(encoder_inputs)[0]
426424
else:
427-
latent = self.encoder_model.predict(encoder_inputs)[2]
425+
latent = self.encoder.predict(encoder_inputs)[2]
428426

429427
latent = np.nan_to_num(latent, nan=0.0, posinf=0.0, neginf=0.0)
430428

scarches/models/scarcheszinb.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ def construct_network(self):
165165

166166
# Building the model via calling it with a random input
167167
input_arr = [tf.random.uniform((1, self.x_dim)), tf.ones((1, self.n_conditions)),
168-
tf.ones((1, self.n_conditions)), tf.ones(1, 1, dtype=tf.float32)]
168+
tf.ones((1, self.n_conditions)), tf.ones(1, dtype=tf.float32)]
169169
self(input_arr)
170170

171171
get_custom_objects().update(self.custom_objects)
@@ -185,6 +185,8 @@ def calc_losses(self, y_true, y_pred, z_mean, z_log_var, disp=None, pi=None):
185185

186186
def forward_with_loss(self, data):
187187
x, y = data
188+
189+
y = y['reconstruction']
188190
y_pred, z_mean, z_log_var, disp, pi = self.call(x)
189191
loss, recon_loss, kl_loss = self.calc_losses(y, y_pred, z_mean, z_log_var, disp, pi)
190192

0 commit comments

Comments
 (0)