Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions pcpostprocess/detect_ramp_bounds.py
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Parameter name was a bit confusing here

Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
import numpy as np


def detect_ramp_bounds(times, voltage_sections, ramp_no=0):
def detect_ramp_bounds(times, voltage_sections, ramp_index=0):
"""
Extract the the times at the start and end of the nth ramp in the protocol.
Extract the timepoint indices at the start and end of the nth ramp in the protocol.

@param times: np.array containing the time at which each sample was taken
@param voltage_sections 2d np.array where each row describes a segment of the protocol: (tstart, tend, vstart, end)
@param ramp_no: the index of the ramp to select. Defaults to 0 - the first ramp
@param ramp_index: the index of the ramp to select. Defaults to 0 - the first ramp

@returns tstart, tend: the start and end times for the ramp_no+1^nth ramp
@returns istart, iend: the start and end timepoint indices for the specified ramp
"""

ramps = [(tstart, tend, vstart, vend) for tstart, tend, vstart, vend
in voltage_sections if vstart != vend]
try:
ramp = ramps[ramp_no]
ramp = ramps[ramp_index]
except IndexError:
print(f"Requested {ramp_no+1}th ramp (ramp_no={ramp_no}),"
print(f"Requested {ramp_index+1}th ramp (ramp_index={ramp_index}),"
" but there are only {len(ramps)} ramps")

tstart, tend = ramp[:2]
Expand Down
5 changes: 1 addition & 4 deletions pcpostprocess/infer_reversal.py
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These lines were causing infer reversal potential to fail for every single trace

Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ def infer_reversal_potential(current, times, voltage_segments, voltages,
"""

# Get ramp bounds, assuming final ramp is the reversal ramp
tstart, tend = detect_ramp_bounds(times, voltage_segments, -1)

istart = np.argmax(times > tstart)
iend = np.argmax(times > tend)
istart, iend = detect_ramp_bounds(times, voltage_segments, -1)

current = current[istart:iend]
voltages = voltages[istart:iend]
Expand Down
62 changes: 62 additions & 0 deletions tests/test_infer_reversal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#!/usr/bin/env python3
import os
import unittest

from syncropatch_export.trace import Trace

from pcpostprocess import infer_reversal, leak_correct
from pcpostprocess.detect_ramp_bounds import detect_ramp_bounds


class TestInferReversal(unittest.TestCase):
def setUp(self):
test_data_dir = os.path.join('tests', 'test_data', '13112023_MW2_FF',
"staircaseramp (2)_2kHz_15.01.07")
json_file = "staircaseramp (2)_2kHz_15.01.07.json"

self.output_dir = os.path.join('test_output', self.__class__.__name__)

if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)

self.test_trace = Trace(test_data_dir, json_file)

# get currents and QC from trace object
self.currents = self.test_trace.get_all_traces(leakcorrect=False)
self.currents['times'] = self.test_trace.get_times()
self.currents['voltages'] = self.test_trace.get_voltage()

self.protocol_desc = self.test_trace.get_voltage_protocol().get_all_sections()
self.leak_ramp_bound_indices = detect_ramp_bounds(self.currents['times'],
self.protocol_desc,
ramp_index=0)

self.voltages = self.test_trace.get_voltage()
self.correct_Erev = -89.57184330525791438049054704606533050537109375

def test_plot_leak_fit(self):
well = "A03"
sweep = 0

voltage = self.test_trace.get_voltage()
times = self.test_trace.get_times()

current = self.test_trace.get_trace_sweeps(sweeps=[sweep])[well][0, :]
params, Ileak = leak_correct.fit_linear_leak(current, voltage, times,
*self.leak_ramp_bound_indices,
output_dir=self.output_dir,
save_fname=f"{well}_sweep{sweep}_leak_correction")

I_corrected = current - Ileak

E_rev = infer_reversal.infer_reversal_potential(
I_corrected, times, self.protocol_desc,
self.voltages,
output_path=os.path.join(self.output_dir,
f"{well}_staircase"),
known_Erev=self.correct_Erev)
self.assertLess(abs(E_rev - self.correct_Erev), 1e-5)


if __name__ == "__main__":
pass
5 changes: 2 additions & 3 deletions tests/test_leak_correct.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ def setUp(self):
"staircaseramp (2)_2kHz_15.01.07")
json_file = "staircaseramp (2)_2kHz_15.01.07.json"

self.output_dir = os.path.join("test_output",
self.__class__.__name__)
self.output_dir = os.path.join('test_output', self.__class__.__name__)

os.makedirs(self.output_dir, exist_ok=True)

Expand All @@ -31,7 +30,7 @@ def setUp(self):
# Find first times ahead of these times
voltage_protocol = self.test_trace.get_voltage_protocol().get_all_sections()
times = self.currents['times'].flatten()
self.ramp_bound_indices = detect_ramp_bounds(times, voltage_protocol, ramp_no=0)
self.ramp_bound_indices = detect_ramp_bounds(times, voltage_protocol, ramp_index=0)

def test_plot_leak_fit(self):
well = 'A01'
Expand Down