NaNs when computing the gradient #742
Replies: 2 comments 12 replies
Dear Hannah, thanks for reaching out! A quick question upfront: I am quite surprised that the issue occurs even for the single-compartment neuron. To me, this hints at an instability in your channels: did you implement them yourself, or did you use channels that are available in Jaxley?
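If the channels are your own implementation, a quick sanity check (before worrying about gradients at all) is to run a short forward simulation and look for NaNs or diverging voltages. A minimal sketch along the lines of the reproduction script further down in this thread, with `Leak` standing in for your custom channels, could look like this:

```python
import jax.numpy as jnp
import jaxley as jx
from jaxley.channels import Leak

# Minimal single-compartment sanity check: swap in (or add) your own
# channels in place of Leak, run a short forward simulation, and look
# for NaNs or runaway voltages before taking any gradients.
cell = jx.Cell()
cell.insert(Leak())  # replace / extend with your custom channels
cell.stimulate(jx.step_current(10.0, 80.0, 0.01, 0.025, 100.0))
cell.record("v")

v = jx.integrate(cell)
print("any NaN in the voltage trace:", bool(jnp.any(jnp.isnan(v))))
print("max |v|:", float(jnp.max(jnp.abs(v))))
```

If the forward pass already produces NaNs or unphysically large voltages, the problem is in the channel dynamics rather than in the gradient computation itself.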
Michael
I had a look, and I think I found the issue. It is indeed in the sodium channel. TL;DR: please modify the sodium channel as follows:

```python
class Na_Berger(Channel):
    def __init__(self, name=None):
        ...
        self.channel_states = {"h": 0.146, "m": 0.0}
        ...

    def update_states(self, states, dt, v, params):
        ...
        # `m` is recomputed directly from the voltage via update_minf at every step.
        return {"h": new_h, "m": update_minf(v)}

    def compute_current(self, states, v, params):
        hs = states["h"]
        ms = states["m"]
        ...
        return current
```

The reason that the channel broke is that the previous code updated the gating variable for `m`; in the version above, `m` is instead recomputed directly from the voltage via `update_minf(v)` in `update_states`. Please let me know if this solves the issue for you!

Appendix: Further info and reproducibility

For posterity, here is the code I used to reproduce and solve the issue:

```python
from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")
import jax
import jax.numpy as jnp
import numpy as np
import jaxley as jx
from jaxley.channels import Leak
from jaxley.optimize.transforms import SigmoidTransform
from Berger import Shab, Na_Berger, Calcium
cell = jx.Cell()
cell.stimulate(jx.step_current(10.0, 80.0, 0.01, 0.025, 100.0))
cell.insert(Na_Berger())
cell.insert(Shab())
cell.insert(Calcium())
cell.insert(Leak())
cell.record("v")
def simulate(params):
    # Chain the four conductances into a single param_state and integrate.
    pstate = cell.data_set("Na_Berger_gNa", params[0], None)
    pstate = cell.data_set("Shab_gShab", params[1], pstate)
    pstate = cell.data_set("Calcium_gCa", params[2], pstate)
    pstate = cell.data_set("Leak_gLeak", params[3], pstate)
    return jx.integrate(cell, param_state=pstate)
bounds = {}
bounds["Na_Berger_gNa"] = [0, 1.0]
bounds["Shab_gShab"] = [0, 1.0]
bounds["Calcium_gCa"] = [0, 0.1]
bounds["Leak_gLeak"] = [0, 0.002]
lower_bounds = jnp.asarray(list(bounds.values()))[:, 0]
upper_bounds = jnp.asarray(list(bounds.values()))[:, 1]
transform = SigmoidTransform(
    lower=lower_bounds,
    upper=upper_bounds,
)

def loss_fn(opt_params):
    # Map the unconstrained optimization parameters back into the bounded
    # parameter space before simulating.
    params = transform.forward(opt_params)
    v = simulate(params)
    return jnp.mean(v)
grad_fn = jax.jit(jax.value_and_grad(loss_fn))
def sample_randomly():
    # Draw one parameter vector uniformly within the bounds.
    return jnp.asarray(
        np.random.rand(len(upper_bounds)) * (upper_bounds - lower_bounds) + lower_bounds
    )

for i in range(200):
    np.random.seed(i)
    initial_params = sample_randomly()
    opt_params = transform.inverse(initial_params)
    l, gradient = grad_fn(opt_params)
    if i % 100 == 0:
        print("i", i)
    assert ~np.any(np.isnan(gradient))
```

For example, the loss function becomes NaN for:

```python
initial_params = jnp.asarray(
    [0.03210001, 0.75575518, 0.0087802, 0.00138649]
)
```
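Independently of this particular fix, a handy way to localize such NaNs is JAX's built-in NaN debugging flag (a plain JAX feature, nothing Jaxley-specific); with it enabled, JAX errors out at the first operation that produces a NaN instead of silently propagating it into the gradient:

```python
from jax import config

# Plain JAX debugging switch: raise an error at the first NaN-producing
# operation instead of propagating NaNs. Jitted functions are re-run
# op-by-op when a NaN is detected, which is slow, so only enable this
# while debugging.
config.update("jax_debug_nans", True)

l, gradient = grad_fn(opt_params)  # now fails loudly where the NaN first appears
```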
Dear Jaxley Team,
I am trying to fit a multicompartment model with a somewhat complex morphology (~1200 branches) and 4 channels (Na, K, Calcium, Leak) to a synthetic voltage trace. I've been roughly following your tutorial on fitting the L5PC. Here is the voltage trace I am trying to optimise. I'm using four different stimulus strengths as input, and for each stimulus I am computing the mean and standard deviation in each time window (shaded in blue in the figure below).
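Concretely, the per-window loss I am describing looks roughly like the sketch below; the window boundaries, the target values, and the names `window_edges`, `target_stats`, `summary_stats`, and `window_loss` are placeholders for illustration rather than my actual code, and the sketch assumes a single one-dimensional voltage trace (the same statistics are computed for each of the four stimuli):

```python
import jax.numpy as jnp

# Hypothetical window boundaries (in time steps) and per-window targets;
# these names and numbers are placeholders, not my actual values.
window_edges = [(0, 400), (400, 800), (800, 1200)]
target_stats = jnp.asarray(
    [[-65.0, 2.0], [-20.0, 15.0], [-60.0, 3.0]]  # (mean, std) per window
)

def summary_stats(v):
    # Mean and standard deviation of the voltage in each time window,
    # assuming `v` is a one-dimensional voltage trace.
    stats = [jnp.stack([jnp.mean(v[a:b]), jnp.std(v[a:b])]) for a, b in window_edges]
    return jnp.stack(stats)

def window_loss(v):
    # Squared error between simulated and target summary statistics.
    return jnp.sum((summary_stats(v) - target_stats) ** 2)
```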
However, I am running into a lot of problems when trying to compute the gradient: jax.value_and_grad() returns NaNs.
I had similar issues when trying to optimise a single-compartment model, but managed to resolve them by reducing the bounds and adding more stimuli to the loss function. In the multicompartment model, this approach has so far not been successful.
When reading the Jaxley publication, I noticed that, unlike in the bioRxiv version, you now use dynamic time warping to fit the Allen Cell models rather than the mean/standard-deviation approach. Why did you decide to change the loss function, and is that something you would also recommend in my case?
Thanks for your help!