diff --git a/pcpostprocess/detect_ramp_bounds.py b/pcpostprocess/detect_ramp_bounds.py index a924ba99..72f81b5c 100644 --- a/pcpostprocess/detect_ramp_bounds.py +++ b/pcpostprocess/detect_ramp_bounds.py @@ -1,23 +1,23 @@ import numpy as np -def detect_ramp_bounds(times, voltage_sections, ramp_no=0): +def detect_ramp_bounds(times, voltage_sections, ramp_index=0): """ - Extract the the times at the start and end of the nth ramp in the protocol. + Extract the timepoint indices at the start and end of the nth ramp in the protocol. @param times: np.array containing the time at which each sample was taken @param voltage_sections 2d np.array where each row describes a segment of the protocol: (tstart, tend, vstart, end) - @param ramp_no: the index of the ramp to select. Defaults to 0 - the first ramp + @param ramp_index: the index of the ramp to select. Defaults to 0 - the first ramp - @returns tstart, tend: the start and end times for the ramp_no+1^nth ramp + @returns istart, iend: the start and end timepoint indices for the specified ramp """ ramps = [(tstart, tend, vstart, vend) for tstart, tend, vstart, vend in voltage_sections if vstart != vend] try: - ramp = ramps[ramp_no] + ramp = ramps[ramp_index] except IndexError: - print(f"Requested {ramp_no+1}th ramp (ramp_no={ramp_no})," + print(f"Requested {ramp_index+1}th ramp (ramp_index={ramp_index})," " but there are only {len(ramps)} ramps") tstart, tend = ramp[:2] diff --git a/pcpostprocess/infer_reversal.py b/pcpostprocess/infer_reversal.py index fd5b892b..26fcf285 100644 --- a/pcpostprocess/infer_reversal.py +++ b/pcpostprocess/infer_reversal.py @@ -33,10 +33,7 @@ def infer_reversal_potential(current, times, voltage_segments, voltages, """ # Get ramp bounds, assuming final ramp is the reversal ramp - tstart, tend = detect_ramp_bounds(times, voltage_segments, -1) - - istart = np.argmax(times > tstart) - iend = np.argmax(times > tend) + istart, iend = detect_ramp_bounds(times, voltage_segments, -1) current = current[istart:iend] voltages = voltages[istart:iend] diff --git a/tests/test_infer_reversal.py b/tests/test_infer_reversal.py new file mode 100755 index 00000000..75070e9d --- /dev/null +++ b/tests/test_infer_reversal.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +import os +import unittest + +from syncropatch_export.trace import Trace + +from pcpostprocess import infer_reversal, leak_correct +from pcpostprocess.detect_ramp_bounds import detect_ramp_bounds + + +class TestInferReversal(unittest.TestCase): + def setUp(self): + test_data_dir = os.path.join('tests', 'test_data', '13112023_MW2_FF', + "staircaseramp (2)_2kHz_15.01.07") + json_file = "staircaseramp (2)_2kHz_15.01.07.json" + + self.output_dir = os.path.join('test_output', self.__class__.__name__) + + if not os.path.exists(self.output_dir): + os.makedirs(self.output_dir) + + self.test_trace = Trace(test_data_dir, json_file) + + # get currents and QC from trace object + self.currents = self.test_trace.get_all_traces(leakcorrect=False) + self.currents['times'] = self.test_trace.get_times() + self.currents['voltages'] = self.test_trace.get_voltage() + + self.protocol_desc = self.test_trace.get_voltage_protocol().get_all_sections() + self.leak_ramp_bound_indices = detect_ramp_bounds(self.currents['times'], + self.protocol_desc, + ramp_index=0) + + self.voltages = self.test_trace.get_voltage() + self.correct_Erev = -89.57184330525791438049054704606533050537109375 + + def test_plot_leak_fit(self): + well = "A03" + sweep = 0 + + voltage = self.test_trace.get_voltage() + times = self.test_trace.get_times() + + current = self.test_trace.get_trace_sweeps(sweeps=[sweep])[well][0, :] + params, Ileak = leak_correct.fit_linear_leak(current, voltage, times, + *self.leak_ramp_bound_indices, + output_dir=self.output_dir, + save_fname=f"{well}_sweep{sweep}_leak_correction") + + I_corrected = current - Ileak + + E_rev = infer_reversal.infer_reversal_potential( + I_corrected, times, self.protocol_desc, + self.voltages, + output_path=os.path.join(self.output_dir, + f"{well}_staircase"), + known_Erev=self.correct_Erev) + self.assertLess(abs(E_rev - self.correct_Erev), 1e-5) + + +if __name__ == "__main__": + pass diff --git a/tests/test_leak_correct.py b/tests/test_leak_correct.py index a764b206..4b4ed399 100755 --- a/tests/test_leak_correct.py +++ b/tests/test_leak_correct.py @@ -14,8 +14,7 @@ def setUp(self): "staircaseramp (2)_2kHz_15.01.07") json_file = "staircaseramp (2)_2kHz_15.01.07.json" - self.output_dir = os.path.join("test_output", - self.__class__.__name__) + self.output_dir = os.path.join('test_output', self.__class__.__name__) os.makedirs(self.output_dir, exist_ok=True) @@ -31,7 +30,7 @@ def setUp(self): # Find first times ahead of these times voltage_protocol = self.test_trace.get_voltage_protocol().get_all_sections() times = self.currents['times'].flatten() - self.ramp_bound_indices = detect_ramp_bounds(times, voltage_protocol, ramp_no=0) + self.ramp_bound_indices = detect_ramp_bounds(times, voltage_protocol, ramp_index=0) def test_plot_leak_fit(self): well = 'A01'