Skip to content

Extract mlapdv #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: nateTaskChanges
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 59 additions & 33 deletions iblrig_custom_tasks/nate_optoBiasedChoiceWorld/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
Additionally the state machine is modified to add output TTLs for optogenetic stimulation
"""
import logging
import time
import sys
from argparse import ArgumentTypeError
from pathlib import Path
from typing import Literal
import warnings

import numpy as np
import yaml
Expand All @@ -20,14 +22,18 @@
from importlib import reload
import random

import sys
sys.path.append('C:\zapit-tcp-bridge\python')
import Python_TCP_Utils as ptu
from TCPclient import TCPclient
ZAPIT_PYTHON = r'C:\zapit-tcp-bridge\python'

num_cond = 52 #will need to change later - is there a function to automatically detect this?>
try:
assert Path(ZAPIT_PYTHON).exists()
sys.path.append(ZAPIT_PYTHON)
import Python_TCP_Utils as ptu
from TCPclient import TCPclient
except (AssertionError, ModuleNotFoundError):
warnings.warn(
'Please clone https://github.com/Zapit-Optostim/zapit-tcp-bridge to '
f'{Path(ZAPIT_PYTHON).parents[1]}', RuntimeWarning)

stim_location_history = []

log = logging.getLogger('iblrig.task')

Expand Down Expand Up @@ -74,9 +80,6 @@ class Session(BiasedChoiceWorldSession):
protocol_name = 'nate_optoBiasedChoiceWorld'
extractor_tasks = ['TrialRegisterRaw', 'ChoiceWorldTrials', 'TrainingStatus']




def __init__(
self,
*args,
Expand All @@ -98,58 +101,70 @@ def __init__(
p=[1 - probability_opto_stim, probability_opto_stim],
size=NTRIALS_INIT,
).astype(bool)
self.trials_table['laser_location_idx'] = np.zeros(NTRIALS_INIT, dtype=int)

def draw_next_trial_info(self, **kwargs):
"""Draw next trial variables.

This is called by the `next_trial` method before updating the Bpod state machine. This
subclass method generates the stimulation index which is sent to Zapit when arming the
laser on stimulation trials.
"""
if self.trials_table.at[self.trial_num, 'opto_stimulation']:
N = int(self.task_params.get('NUM_OPTO_COND', 52))
self.trials_table.at[self.trial_num, 'laser_location_idx'] = random.randrange(1, N)

def start_hardware(self):


self.client = TCPclient(tcp_port=1488, tcp_ip='127.0.0.1')

self.client.close() # need to ensure is closed first; currently nowhere that this is defined at end of task!
self.client.close() # need to ensure is closed first; currently nowhere that this is defined at end of task!
self.client.connect()
super().start_hardware()
# add the softcodes for the zapit opto stimulation
soft_code_dict = self.bpod.softcodes
soft_code_dict.update({SOFTCODE_STOP_ZAPIT: self.zapit_stop_laser})
soft_code_dict.update({SOFTCODE_FIRE_ZAPIT: self.zapit_fire_laser})
self.bpod.register_softcodes(soft_code_dict)


def zapit_arm_laser(self):
log.warning('Arming laser')
#this is where you define the laser stim (i.e., arm the laser)
# this is where you define the laser stim (i.e., arm the laser)

self.current_location_idx = random.randrange(1,int(num_cond))
current_location_idx = self.trials_table.at[self.trial_num, 'laser_location_idx']

#hZP.send_samples(
# conditionNum=current_location_idx, hardwareTriggered=True, logging=True
#)

zapit_byte_tuple, zapit_int_tuple = ptu.gen_Zapit_byte_tuple(trial_state_command = 1,
arg_keys_dict = {'conditionNum_channel': True, 'laser_channel': True,
'hardwareTriggered_channel': True, 'logging_channel': False,
'verbose_channel': False},
arg_values_dict = {'conditionNum': self.current_location_idx, 'laser_ON': True,
'hardwareTriggered_ON': True, 'logging_ON': False,
'verbose_ON': False})
zapit_byte_tuple, zapit_int_tuple = ptu.gen_Zapit_byte_tuple(
trial_state_command=1,
arg_keys_dict={'conditionNum_channel': True, 'laser_channel': True,
'hardwareTriggered_channel': True, 'logging_channel': False,
'verbose_channel': False},
arg_values_dict={'conditionNum': current_location_idx, 'laser_ON': True,
'hardwareTriggered_ON': True, 'logging_ON': False,
'verbose_ON': False}
)
response = self.client.send_receive(zapit_byte_tuple)
log.warning(response)
stim_location_history.append(self.current_location_idx)

def zapit_fire_laser(self):
# just logging - actual firing will be triggered by the state machine via TTL
#this really only triggers a ttl and sends a log entry - no need to plug in code here
# this really only triggers a ttl and sends a log entry - no need to plug in code here
log.warning('Firing laser')


def zapit_stop_laser(self):
log.warning('Stopping laser')
zapit_byte_tuple, zapit_int_tuple = ptu.gen_Zapit_byte_tuple(trial_state_command = 0,
arg_keys_dict = {'conditionNum_channel': True, 'laser_channel': True,
'hardwareTriggered_channel': True, 'logging_channel': False,
'verbose_channel': False},
arg_values_dict = {'conditionNum': self.current_location_idx, 'laser_ON': True,
'hardwareTriggered_ON': False, 'logging_ON': False,
'verbose_ON': False})
current_location_idx = self.trials_table.at[self.trial_num, 'laser_location_idx']
zapit_byte_tuple, zapit_int_tuple = ptu.gen_Zapit_byte_tuple(
trial_state_command=0,
arg_keys_dict={'conditionNum_channel': True, 'laser_channel': True,
'hardwareTriggered_channel': True, 'logging_channel': False,
'verbose_channel': False},
arg_values_dict={'conditionNum': current_location_idx, 'laser_ON': True,
'hardwareTriggered_ON': False, 'logging_ON': False,
'verbose_ON': False}
)
response = self.client.send_receive(zapit_byte_tuple)

def _instantiate_state_machine(self, trial_number=None):
Expand All @@ -172,6 +187,11 @@ def _instantiate_state_machine(self, trial_number=None):
@staticmethod
def extra_parser():
""":return: argparse.parser()"""
def positive_int(value):
if (value := int(value)) <= 0:
raise ArgumentTypeError(f'"{value}" is an invalid positive int value')
return value

parser = super(Session, Session).extra_parser()
parser.add_argument(
'--probability_opto_stim',
Expand Down Expand Up @@ -208,6 +228,12 @@ def extra_parser():
type=str,
help='list of the state machine states where opto stim should be stopped',
)
parser.add_argument(
'--n_opto_cond',
default=DEFAULTS['NUM_OPTO_COND'],
type=positive_int,
help='the number (N) of preset conditions to draw from, where N > x > 0',
)
return parser


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
- error
- reward
'PROBABILITY_OPTO_STIM': 0.2 # probability of optogenetic stimulation
'NUM_OPTO_COND': 52 # the number (N) of preset conditions to draw from, where N > x > 0
31 changes: 31 additions & 0 deletions projects/nate_optoBiasedChoiceWorld.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
The pipeline task subclasses, OptoTrialsBpod and OptoTrialsNidq, aren't strictly necessary. They simply assert that the
laserStimulation datasets were indeed saved and registered by the Bpod extractor class.
"""
import yaml
import numpy as np
from packaging import version
import ibllib.io.raw_data_loaders as raw
from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes
from ibllib.io.extractors.bpod_trials import BiasedTrials
Expand Down Expand Up @@ -37,6 +39,13 @@ class TrialsOpto(BaseBpodTrialsExtractor):
var_names = BiasedTrials.var_names + ('laser_intervals',)
save_names = BiasedTrials.save_names + ('_ibl_laserStimulation.intervals.npy',)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.proj_version = version.parse(self.settings.get('PROJECT_EXTRACTION_VERSION', '0.0.0'))
if self.proj_version >= version.parse('0.3.0'):
self.var_names = BaseBpodTrialsExtractor.var_names + ('laser_mplapdv',)
self.save_names = BaseBpodTrialsExtractor.var_names + ('_ibl_laserStimulation.mlapdv.npy',)

def _extract(self, extractor_classes=None, **kwargs) -> dict:
settings = self.settings.copy()
if 'OPTO_STIM_STATES' in settings:
Expand All @@ -53,12 +62,34 @@ def _extract(self, extractor_classes=None, **kwargs) -> dict:

# Extract laser dataset
laser_intervals = []
location_index = []
for trial in filter(lambda t: t['opto_stimulation'], self.bpod_trials):
location_index.append(trial.get('laser_location_idx', 0))
states = trial['behavior_data']['States timestamps']
# Assumes one of these states per trial: takes the timestamp of the first matching state
start = next((v[0][0] for k, v in states.items() if k in settings['OPTO_TTL_STATES']), np.nan)
stop = next((v[0][0] for k, v in states.items() if k in settings['OPTO_STOP_STATES']), np.nan)
laser_intervals.append((start, stop))
out['laser_intervals'] = np.array(laser_intervals, dtype=np.float64)

# Extract laser coordinates
if self.proj_version >= version.parse('0.3.0'):
location_index = np.fromiter(filter(None, location_index), dtype=int)
assert len(location_index) == out['laser_intervals'].shape[0]
out['laser_mplapdv'] = np.full((out['laser_intervals'].shape[0], 3), np.NaN)
# Load lookup table
try:
zapit_file = next(self.alf_path.glob('zapit_log_*.yml'))
except StopIteration:
raise FileNotFoundError('Failed to load zapit log file.')

with open(zapit_file, 'r') as fp:
zapit = yaml.safe_load(fp)
if any(x['Type'] != 'unilateral_points' for x in (v for k, v in zapit.items() if k.startswith('stimLocations'))):
raise NotImplementedError # TODO verify and document
for i in np.unique(location_index):
location = zapit[f'stimLocations{i:02}']
mlapdv = (location['ML'][0], location['AP'][0], 0.) # TODO ensure len == 3
out['laser_mplapdv'][location_index == i, :] = mlapdv

return {k: out[k] for k in self.var_names} # Ensures all datasets present and ordered
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "project_extraction"
version = "0.4.1"
version = "0.3.0"
description = "Custom extractors for satellite tasks"
dynamic = [ "readme" ]
keywords = [ "IBL", "neuro-science" ]
Expand Down