diff --git a/bmtk/analyzer/spike_trains.py b/bmtk/analyzer/spike_trains.py index 58cd9937..aae90e5a 100644 --- a/bmtk/analyzer/spike_trains.py +++ b/bmtk/analyzer/spike_trains.py @@ -331,7 +331,7 @@ def calc_stats(r): return pd.Series(d, index=['count', 'isi']) - spike_counts_df = spike_trains.to_dataframe().groupby(['population', 'node_ids']).apply(calc_stats) + spike_counts_df = spike_trains.to_dataframe().groupby(['population', 'node_ids'])[['timestamps']].apply(calc_stats) spike_counts_df = spike_counts_df.rename({'timestamps': 'counts'}, axis=1) spike_counts_df.index.names = ['population', 'node_id'] @@ -343,7 +343,7 @@ def calc_stats(r): vals_df = pd.merge(nodes_df, spike_counts_df, left_index=True, right_index=True, how='left') vals_df = vals_df.fillna({'count': 0.0, 'firing_rate': 0.0, 'isi': 0.0}) - vals_df = vals_df.groupby(group_by)[['firing_rate', 'count', 'isi']].agg([np.mean, np.std]) + vals_df = vals_df.groupby(group_by)[['firing_rate', 'count', 'isi']].agg(['mean', 'std']) return vals_df else: return spike_counts_df diff --git a/bmtk/utils/brain_observatory/__init__.py b/bmtk/utils/brain_observatory/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bmtk/utils/brain_observatory/brain_observatory_cache.py b/bmtk/utils/brain_observatory/brain_observatory_cache.py new file mode 100644 index 00000000..ac987889 --- /dev/null +++ b/bmtk/utils/brain_observatory/brain_observatory_cache.py @@ -0,0 +1,1161 @@ +import os +import h5py +import numpy as np +import pandas as pd +import six +import dateutil +import re +from pkg_resources import parse_version + +from .cache import Cache, get_default_manifest_file +from .rma_template import RmaTemplate +from . import stimulus_info as si +from .manifest import ManifestBuilder + +class NoEyeTrackingException(Exception): + pass + + +class BrainObservatoryNwbDataSet(object): + PIPELINE_DATASET = 'brain_observatory_pipeline' + SUPPORTED_PIPELINE_VERSION = "3.0" + + FILE_METADATA_MAPPING = { + 'age': 'general/subject/age', + 'sex': 'general/subject/sex', + 'imaging_depth': 'general/optophysiology/imaging_plane_1/imaging depth', + 'targeted_structure': 'general/optophysiology/imaging_plane_1/location', + 'ophys_experiment_id': 'general/session_id', + 'experiment_container_id': 'general/experiment_container_id', + 'device_string': 'general/devices/2-photon microscope', + 'excitation_lambda': 'general/optophysiology/imaging_plane_1/excitation_lambda', + 'indicator': 'general/optophysiology/imaging_plane_1/indicator', + 'fov': 'general/fov', + 'genotype': 'general/subject/genotype', + 'session_start_time': 'session_start_time', + 'session_type': 'general/session_type', + 'specimen_name': 'general/specimen_name', + 'generated_by': 'general/generated_by' + } + + STIMULUS_TABLE_TYPES = { + 'abstract_feature_series': [si.DRIFTING_GRATINGS, si.STATIC_GRATINGS], + 'indexed_time_series': [si.NATURAL_SCENES, si.LOCALLY_SPARSE_NOISE, + si.LOCALLY_SPARSE_NOISE_4DEG, si.LOCALLY_SPARSE_NOISE_8DEG], + 'repeated_indexed_time_series':[si.NATURAL_MOVIE_ONE, si.NATURAL_MOVIE_TWO, si.NATURAL_MOVIE_THREE] + + } + + # this array was moved before file versioning was in place + MOTION_CORRECTION_DATASETS = [ "MotionCorrection/2p_image_series/xy_translations", + "MotionCorrection/2p_image_series/xy_translation" ] + + def __init__(self, nwb_file): + + self.nwb_file = nwb_file + self.pipeline_version = None + + if os.path.exists(self.nwb_file): + meta = self.get_metadata() + if meta and 'pipeline_version' in meta: + pipeline_version_str = meta['pipeline_version'] + self.pipeline_version = parse_version(pipeline_version_str) + + self._stimulus_search = None + + def get_stimulus_epoch_table(self): + '''Returns a pandas dataframe that summarizes the stimulus epoch duration for each acquisition time index in + the experiment + + Parameters + ---------- + None + + Returns + ------- + timestamps: 2D numpy array + Timestamp for each fluorescence sample + + traces: 2D numpy array + Fluorescence traces for each cell + ''' + + + # These are thresholds used by get_epoch_mask_list to set a maximum limit on the delta aqusistion frames to + # count as different trials (rows in the stim table). This helps account for dropped frames, so that they dont + # cause the cutting of an entire experiment into too many stimulus epochs. If these thresholds are too low, + # the assert statment in get_epoch_mask_list will halt execution. In that case, make a bug report!. + threshold_dict = {si.THREE_SESSION_A:32+7, + si.THREE_SESSION_B:15, + si.THREE_SESSION_C:7, + si.THREE_SESSION_C2:7} + + stimulus_table_dict = {} + for stimulus in self.list_stimuli(): + + stimulus_table_dict[stimulus] = self.get_stimulus_table(stimulus) + + if stimulus == si.SPONTANEOUS_ACTIVITY: + stimulus_table_dict[stimulus]['frame'] = 0 + + interval_list = [] + interval_stimulus_dict = {} + for stimulus in self.list_stimuli(): + stimulus_interval_list = get_epoch_mask_list(stimulus_table_dict[stimulus], threshold=threshold_dict.get(self.get_session_type(), None)) + for stimulus_interval in stimulus_interval_list: + interval_stimulus_dict[stimulus_interval] = stimulus + interval_list += stimulus_interval_list + interval_list.sort(key=lambda x: x[0]) + + stimulus_signature_list = ['gap'] + duration_signature_list = [int(interval_list[0][0])] + interval_signature_list = [(0,int(interval_list[0][0]))] + for ii, interval in enumerate(interval_list): + stimulus_signature_list.append(interval_stimulus_dict[interval]) + duration_signature_list.append(int(interval[1] - interval[0])) + interval_signature_list.append((int(interval[0]), int(interval[1]))) + + if ii != len(interval_list)-1: + stimulus_signature_list.append('gap') + duration_signature_list.append((int(interval_list[ii+1][0] - interval_list[ii][1]))) + interval_signature_list.append((int(interval_list[ii][1]), int(interval_list[ii+1][0]))) + + stimulus_signature_list.append('gap') + interval_signature_list.append((int(interval_list[-1][1]), len(self.get_fluorescence_timestamps()))) + duration_signature_list.append(interval_signature_list[-1][1]-interval_signature_list[-1][0]) + + interval_df = pd.DataFrame({'stimulus':stimulus_signature_list, + 'duration':duration_signature_list, + 'interval':interval_signature_list}) + + # Gaps are uninformative; remove them: + interval_df = interval_df[interval_df.stimulus != 'gap'] + interval_df['start'] = [x[0] for x in interval_df['interval'].values] + interval_df['end'] = [x[1] for x in interval_df['interval'].values] + + interval_df.reset_index(inplace=True, drop=True) + interval_df.drop(['interval', 'duration'], axis=1, inplace=True) + return interval_df + + + def get_fluorescence_traces(self, cell_specimen_ids=None): + ''' Returns an array of fluorescence traces for all ROI and + the timestamps for each datapoint + + Parameters + ---------- + cell_specimen_ids: list or array (optional) + List of cell IDs to return traces for. If this is None (default) + then all are returned + + Returns + ------- + timestamps: 2D numpy array + Timestamp for each fluorescence sample + + traces: 2D numpy array + Fluorescence traces for each cell + ''' + timestamps = self.get_fluorescence_timestamps() + with h5py.File(self.nwb_file, 'r') as f: + ds = f['processing'][self.PIPELINE_DATASET][ + 'Fluorescence']['imaging_plane_1']['data'] + + if cell_specimen_ids is None: + cell_traces = ds[()] + else: + inds = self.get_cell_specimen_indices(cell_specimen_ids) + cell_traces = ds[inds, :] + + return timestamps, cell_traces + + def get_fluorescence_timestamps(self): + ''' Returns an array of timestamps in seconds for the fluorescence traces ''' + + with h5py.File(self.nwb_file, 'r') as f: + timestamps = f['processing'][self.PIPELINE_DATASET][ + 'Fluorescence']['imaging_plane_1']['timestamps'][()] + return timestamps + + def get_neuropil_traces(self, cell_specimen_ids=None): + ''' Returns an array of neuropil fluorescence traces for all ROIs + and the timestamps for each datapoint + + Parameters + ---------- + cell_specimen_ids: list or array (optional) + List of cell IDs to return traces for. If this is None (default) + then all are returned + + Returns + ------- + timestamps: 2D numpy array + Timestamp for each fluorescence sample + + traces: 2D numpy array + Neuropil fluorescence traces for each cell + ''' + + timestamps = self.get_fluorescence_timestamps() + + with h5py.File(self.nwb_file, 'r') as f: + if self.pipeline_version >= parse_version("2.0"): + ds = f['processing'][self.PIPELINE_DATASET][ + 'Fluorescence']['imaging_plane_1_neuropil_response']['data'] + else: + ds = f['processing'][self.PIPELINE_DATASET][ + 'Fluorescence']['imaging_plane_1']['neuropil_traces'] + + if cell_specimen_ids is None: + np_traces = ds[()] + else: + inds = self.get_cell_specimen_indices(cell_specimen_ids) + np_traces = ds[inds, :] + + return timestamps, np_traces + + + def get_neuropil_r(self, cell_specimen_ids=None): + ''' Returns a scalar value of r for neuropil correction of flourescence traces + + Parameters + ---------- + cell_specimen_ids: list or array (optional) + List of cell IDs to return traces for. If this is None (default) + then results for all are returned + + Returns + ------- + r: 1D numpy array, len(r)=len(cell_specimen_ids) + Scalar for neuropil subtraction for each cell + ''' + + with h5py.File(self.nwb_file, 'r') as f: + if self.pipeline_version >= parse_version("2.0"): + r_ds = f['processing'][self.PIPELINE_DATASET][ + 'Fluorescence']['imaging_plane_1_neuropil_response']['r'] + else: + r_ds = f['processing'][self.PIPELINE_DATASET][ + 'Fluorescence']['imaging_plane_1']['r'] + + if cell_specimen_ids is None: + r = r_ds[()] + else: + inds = self.get_cell_specimen_indices(cell_specimen_ids) + r = r_ds[inds] + + return r + + def get_demixed_traces(self, cell_specimen_ids=None): + ''' Returns an array of demixed fluorescence traces for all ROIs + and the timestamps for each datapoint + + Parameters + ---------- + cell_specimen_ids: list or array (optional) + List of cell IDs to return traces for. If this is None (default) + then all are returned + + Returns + ------- + timestamps: 2D numpy array + Timestamp for each fluorescence sample + + traces: 2D numpy array + Demixed fluorescence traces for each cell + ''' + + timestamps = self.get_fluorescence_timestamps() + + with h5py.File(self.nwb_file, 'r') as f: + ds = f['processing'][self.PIPELINE_DATASET][ + 'Fluorescence']['imaging_plane_1_demixed_signal']['data'] + if cell_specimen_ids is None: + traces = ds[()] + else: + inds = self.get_cell_specimen_indices(cell_specimen_ids) + traces = ds[inds, :] + + return timestamps, traces + + def get_corrected_fluorescence_traces(self, cell_specimen_ids=None): + ''' Returns an array of demixed and neuropil-corrected fluorescence traces + for all ROIs and the timestamps for each datapoint + + Parameters + ---------- + cell_specimen_ids: list or array (optional) + List of cell IDs to return traces for. If this is None (default) + then all are returned + + Returns + ------- + timestamps: 2D numpy array + Timestamp for each fluorescence sample + + traces: 2D numpy array + Corrected fluorescence traces for each cell + ''' + + # starting in version 2.0, neuropil correction follows trace demixing + if self.pipeline_version >= parse_version("2.0"): + timestamps, cell_traces = self.get_demixed_traces(cell_specimen_ids) + else: + timestamps, cell_traces = self.get_fluorescence_traces(cell_specimen_ids) + + r = self.get_neuropil_r(cell_specimen_ids) + + _, neuropil_traces = self.get_neuropil_traces(cell_specimen_ids) + + fc = cell_traces - neuropil_traces * r[:, np.newaxis] + + return timestamps, fc + + def get_cell_specimen_indices(self, cell_specimen_ids): + ''' Given a list of cell specimen ids, return their index based on their order in this file. + + Parameters + ---------- + cell_specimen_ids: list of cell specimen ids + + ''' + + all_cell_specimen_ids = list(self.get_cell_specimen_ids()) + + try: + inds = [list(all_cell_specimen_ids).index(i) + for i in cell_specimen_ids] + except ValueError as e: + raise ValueError("Cell specimen not found (%s)" % str(e)) + + return inds + + def get_dff_traces(self, cell_specimen_ids=None): + ''' Returns an array of dF/F traces for all ROIs and + the timestamps for each datapoint + + Parameters + ---------- + cell_specimen_ids: list or array (optional) + List of cell IDs to return data for. If this is None (default) + then all are returned + + Returns + ------- + timestamps: 2D numpy array + Timestamp for each fluorescence sample + + dF/F: 2D numpy array + dF/F values for each cell + ''' + with h5py.File(self.nwb_file, 'r') as f: + dff_ds = f['processing'][self.PIPELINE_DATASET][ + 'DfOverF']['imaging_plane_1'] + + timestamps = dff_ds['timestamps'][()] + + if cell_specimen_ids is None: + cell_traces = dff_ds['data'][()] + else: + inds = self.get_cell_specimen_indices(cell_specimen_ids) + cell_traces = dff_ds['data'][inds, :] + + return timestamps, cell_traces + + def get_roi_ids(self): + ''' Returns an array of IDs for all ROIs in the file + + Returns + ------- + ROI IDs: list + ''' + with h5py.File(self.nwb_file, 'r') as f: + roi_id = f['processing'][self.PIPELINE_DATASET][ + 'ImageSegmentation']['roi_ids'][()] + return roi_id + + def get_cell_specimen_ids(self): + ''' Returns an array of cell IDs for all cells in the file + + Returns + ------- + cell specimen IDs: list + ''' + with h5py.File(self.nwb_file, 'r') as f: + cell_id = f['processing'][self.PIPELINE_DATASET][ + 'ImageSegmentation']['cell_specimen_ids'][()] + return cell_id + + def get_session_type(self): + ''' Returns the type of experimental session, presently one of the + following: three_session_A, three_session_B, three_session_C + + Returns + ------- + session type: string + ''' + with h5py.File(self.nwb_file, 'r') as f: + session_type = f['general/session_type'][()] + return session_type.decode('utf-8') + + def get_max_projection(self): + '''Returns the maximum projection image for the 2P movie. + + Returns + ------- + max projection: np.ndarray + ''' + + with h5py.File(self.nwb_file, 'r') as f: + max_projection = f['processing'][self.PIPELINE_DATASET]['ImageSegmentation'][ + 'imaging_plane_1']['reference_images']['maximum_intensity_projection_image']['data'][()] + return max_projection + + def list_stimuli(self): + ''' Return a list of the stimuli presented in the experiment. + + Returns + ------- + stimuli: list of strings + ''' + + with h5py.File(self.nwb_file, 'r') as f: + keys = list(f["stimulus/presentation/"].keys()) + return [ k.replace('_stimulus', '') for k in keys ] + + + def _get_master_stimulus_table(self): + ''' Builds a table for all stimuli by concatenating (vertically) the + sub-tables describing presentation of each stimulus + ''' + + epoch_table = self.get_stimulus_epoch_table() + + stimulus_table_dict = {} + for stimulus in self.list_stimuli(): + stimulus_table_dict[stimulus] = self.get_stimulus_table(stimulus) + + table_list = [] + for stimulus in self.list_stimuli(): + curr_stimtable = stimulus_table_dict[stimulus] + + for _, row in epoch_table[epoch_table['stimulus'] == stimulus].iterrows(): + + epoch_start_ind, epoch_end_ind = row['start'], row['end'] + curr_subtable = curr_stimtable[(epoch_start_ind <= curr_stimtable['start']) & + (curr_stimtable['end'] <= epoch_end_ind)].copy() + curr_subtable['stimulus'] = stimulus + table_list.append(curr_subtable) + + new_table = pd.concat(table_list, sort=True) + new_table.reset_index(drop=True, inplace=True) + + return new_table + + + def get_stimulus_table(self, stimulus_name): + ''' Return a stimulus table given a stimulus name + + Notes + ----- + For more information, see: + http://help.brain-map.org/display/observatory/Documentation?preview=/10616846/10813485/VisualCoding_VisualStimuli.pdf + + ''' + + if stimulus_name == 'master': + return self._get_master_stimulus_table() + + with h5py.File(self.nwb_file, 'r') as nwb_file: + + stimulus_group = _find_stimulus_presentation_group(nwb_file, stimulus_name) + + if stimulus_name in self.STIMULUS_TABLE_TYPES['abstract_feature_series']: + datasets = h5_utilities.load_datasets_by_relnames( + ['data', 'features', 'frame_duration'], nwb_file, stimulus_group) + return _make_abstract_feature_series_stimulus_table( + datasets['data'], h5_utilities.decode_bytes(datasets['features']), datasets['frame_duration']) + + if stimulus_name in self.STIMULUS_TABLE_TYPES['indexed_time_series']: + datasets = h5_utilities.load_datasets_by_relnames(['data', 'frame_duration'], nwb_file, stimulus_group) + return _make_indexed_time_series_stimulus_table(datasets['data'], datasets['frame_duration']) + + if stimulus_name in self.STIMULUS_TABLE_TYPES['repeated_indexed_time_series']: + datasets = h5_utilities.load_datasets_by_relnames(['data', 'frame_duration'], nwb_file, stimulus_group) + return _make_repeated_indexed_time_series_stimulus_table(datasets['data'], datasets['frame_duration']) + + if stimulus_name == 'spontaneous': + datasets = h5_utilities.load_datasets_by_relnames(['data', 'frame_duration'], nwb_file, stimulus_group) + return _make_spontaneous_activity_stimulus_table(datasets['data'], datasets['frame_duration']) + + raise IOError("Could not find a stimulus table named '%s'" % stimulus_name) + + + # @memoize + def get_stimulus_template(self, stimulus_name): + ''' Return an array of the stimulus template for the specified stimulus. + + Parameters + ---------- + stimulus_name: string + Must be one of the strings returned by list_stimuli(). + + Returns + ------- + stimulus table: pd.DataFrame + ''' + stim_name = stimulus_name + "_image_stack" + with h5py.File(self.nwb_file, 'r') as f: + image_stack = f['stimulus']['templates'][stim_name]['data'][()] + return image_stack + + def get_locally_sparse_noise_stimulus_template(self, + stimulus, + mask_off_screen=True): + ''' Return an array of the stimulus template for the specified stimulus. + + Parameters + ---------- + stimulus: string + Which locally sparse noise stimulus to retrieve. Must be one of: + stimulus_info.LOCALLY_SPARSE_NOISE + stimulus_info.LOCALLY_SPARSE_NOISE_4DEG + stimulus_info.LOCALLY_SPARSE_NOISE_8DEG + + mask_off_screen: boolean + Set off-screen regions of the stimulus to LocallySparseNoise.LSN_OFF_SCREEN. + + Returns + ------- + tuple: (template, off-screen mask) + ''' + + if stimulus not in si.LOCALLY_SPARSE_NOISE_DIMENSIONS: + raise KeyError("%s is not a known locally sparse noise stimulus" % stimulus) + + template = self.get_stimulus_template(stimulus) + + # build mapping from template coordinates to display coordinates + template_shape = si.LOCALLY_SPARSE_NOISE_DIMENSIONS[stimulus] + template_shape = [ template_shape[1], template_shape[0] ] + + template_display_shape = (1260, 720) + display_shape = (1920, 1200) + + scale = [ + float(template_shape[0]) / float(template_display_shape[0]), + float(template_shape[1]) / float(template_display_shape[1]) + ] + offset = [ + -(display_shape[0] - template_display_shape[0]) * 0.5, + -(display_shape[1] - template_display_shape[1]) * 0.5 + ] + + x, y = np.meshgrid(np.arange(display_shape[0]), np.arange( + display_shape[1]), indexing='ij') + template_display_coords = np.array([(x + offset[0]) * scale[0] - 0.5, + (y + offset[1]) * scale[1] - 0.5], + dtype=float) + template_display_coords = np.rint(template_display_coords).astype(int) + + # build mask + template_mask, template_frac = si_mask_stimulus_template( + template_display_coords, template_shape) + + if mask_off_screen: + template[:, ~template_mask.T] = LocallySparseNoise.LSN_OFF_SCREEN + + return template, template_mask.T + + def get_roi_mask_array(self, cell_specimen_ids=None): + ''' Return a numpy array containing all of the ROI masks for requested cells. + If cell_specimen_ids is omitted, return all masks. + + Parameters + ---------- + cell_specimen_ids: list + List of cell specimen ids. Default None. + + Returns + ------- + np.ndarray: NxWxH array, where N is number of cells + ''' + + roi_masks = self.get_roi_mask(cell_specimen_ids) + + if len(roi_masks) == 0: + raise IOError("no masks found for given cell specimen ids") + + roi_arr = roi.create_roi_mask_array(roi_masks) + + return roi_arr + + def get_roi_mask(self, cell_specimen_ids=None): + ''' Returns an array of all the ROI masks + + Parameters + ---------- + cell specimen IDs: list or array (optional) + List of cell IDs to return traces for. If this is None (default) + then all are returned + + Returns + ------- + List of ROI_Mask objects + ''' + + with h5py.File(self.nwb_file, 'r') as f: + mask_loc = f['processing'][self.PIPELINE_DATASET][ + 'ImageSegmentation']['imaging_plane_1'] + roi_list = f['processing'][self.PIPELINE_DATASET][ + 'ImageSegmentation']['imaging_plane_1']['roi_list'][()] + + inds = None + if cell_specimen_ids is None: + inds = range(self.number_of_cells) + else: + inds = self.get_cell_specimen_indices(cell_specimen_ids) + + roi_array = [] + for i in inds: + v = roi_list[i] + roi_mask = mask_loc[v]["img_mask"][()] + m = roi.create_roi_mask(roi_mask.shape[1], roi_mask.shape[0], + [0, 0, 0, 0], roi_mask=roi_mask, label=v) + roi_array.append(m) + + return roi_array + + @property + def number_of_cells(self): + '''Number of cells in the experiment''' + + # Replace here is there is a better way to get this info: + return len(self.get_cell_specimen_ids()) + + + def get_metadata(self): + ''' Returns a dictionary of meta data associated with each + experiment, including Cre line, specimen number, + visual area imaged, imaging depth + + Returns + ------- + metadata: dictionary + ''' + + meta = {} + + with h5py.File(self.nwb_file, 'r') as f: + for memory_key, disk_key in BrainObservatoryNwbDataSet.FILE_METADATA_MAPPING.items(): + try: + v = f[disk_key][()] + + # convert numpy strings to python strings + if v.dtype.type is np.bytes_: + if len(v.shape) == 0: + v = v.decode('UTF-8') + elif len(v.shape) == 1: + v = [ s.decode('UTF-8') for s in v ] + else: + raise Exception("Unrecognized metadata formatting for field %s" % disk_key) + + meta[memory_key] = v + except KeyError as e: + logging.warning("could not find key %s", disk_key) + + # extract cre line from genotype string + genotype = meta.get('genotype') + meta['cre_line'] = meta['genotype'].split(';')[0] if genotype else None + + imaging_depth = meta.pop('imaging_depth', None) + meta['imaging_depth_um'] = int(imaging_depth.split()[0]) if imaging_depth else None + + ophys_experiment_id = meta.get('ophys_experiment_id') + meta['ophys_experiment_id'] = int(ophys_experiment_id) if ophys_experiment_id else None + + experiment_container_id = meta.get('experiment_container_id') + meta['experiment_container_id'] = int(experiment_container_id) if experiment_container_id else None + + # convert start time to a date object + session_start_time = meta.get('session_start_time') + if isinstance( session_start_time, six.string_types ): + meta['session_start_time'] = dateutil.parser.parse(session_start_time) + + age = meta.pop('age', None) + if age: + # parse the age in days + m = re.match("(.*?) days", age) + if m: + meta['age_days'] = int(m.groups()[0]) + else: + raise IOError("Could not parse age.") + + + # parse the device string (ugly, sorry) + device_string = meta.pop('device_string', None) + if device_string: + m = re.match("(.*?)\.\s(.*?)\sPlease*", device_string) + if m: + device, device_name = m.groups() + meta['device'] = device + meta['device_name'] = device_name + else: + raise IOError("Could not parse device string.") + + # file version + generated_by = meta.pop('generated_by', None) + version = generated_by[-1] if generated_by else "0.9" + meta["pipeline_version"] = version + + return meta + + def get_running_speed(self): + ''' Returns the mouse running speed in cm/s + ''' + with h5py.File(self.nwb_file, 'r') as f: + dx_ds = f['processing'][self.PIPELINE_DATASET][ + 'BehavioralTimeSeries']['running_speed'] + dxcm = dx_ds['data'][()] + dxtime = dx_ds['timestamps'][()] + + timestamps = self.get_fluorescence_timestamps() + + # v0.9 stored this as an Nx1 array instead of a flat 1-d array + if len(dxcm.shape) == 2: + dxcm = dxcm[:, 0] + + dxcm, dxtime = align_running_speed(dxcm, dxtime, timestamps) + + return dxcm, dxtime + + def get_pupil_location(self, as_spherical=True): + '''Returns the x, y pupil location. + + Parameters + ---------- + as_spherical : bool + Whether to return the location as spherical (default) or + not. If true, the result is altitude and azimuth in + degrees, otherwise it is x, y in centimeters. (0,0) is + the center of the monitor. + + Returns + ------- + (timestamps, location) + Timestamps is an (Nx1) array of timestamps in seconds. + Location is an (Nx2) array of spatial location. + ''' + if as_spherical: + location_key = "pupil_location_spherical" + else: + location_key = "pupil_location" + try: + with h5py.File(self.nwb_file, 'r') as f: + eye_tracking = f['processing'][self.PIPELINE_DATASET][ + 'EyeTracking'][location_key] + pupil_location = eye_tracking['data'][()] + pupil_times = eye_tracking['timestamps'][()] + except KeyError: + raise NoEyeTrackingException("No eye tracking for this experiment.") + + return pupil_times, pupil_location + + def get_pupil_size(self): + '''Returns the pupil area in pixels. + + Returns + ------- + (timestamps, areas) + Timestamps is an (Nx1) array of timestamps in seconds. + Areas is an (Nx1) array of pupil areas in pixels. + ''' + try: + with h5py.File(self.nwb_file, 'r') as f: + pupil_tracking = f['processing'][self.PIPELINE_DATASET][ + 'PupilTracking']['pupil_size'] + pupil_size = pupil_tracking['data'][()] + pupil_times = pupil_tracking['timestamps'][()] + except KeyError: + raise NoEyeTrackingException("No pupil tracking for this experiment.") + + return pupil_times, pupil_size + + def get_motion_correction(self): + ''' Returns a Panda DataFrame containing the x- and y- translation of each image used for image alignment + ''' + + motion_correction = None + with h5py.File(self.nwb_file, 'r') as f: + pipeline_ds = f['processing'][self.PIPELINE_DATASET] + + # pipeline 0.9 stores this in xy_translations + # pipeline 1.0 stores this in xy_translation + for mc_ds_name in self.MOTION_CORRECTION_DATASETS: + try: + mc_ds = pipeline_ds[mc_ds_name] + + motion_log = mc_ds['data'][()] + motion_time = mc_ds['timestamps'][()] + motion_names = mc_ds['feature_description'][()] + + motion_correction = pd.DataFrame(motion_log, columns=motion_names) + motion_correction['timestamp'] = motion_time + + # break out if we found it + break + except KeyError as e: + pass + + if motion_correction is None: + raise KeyError("Could not find motion correction data.") + + # Python3 compatibility: + rename_dict = {} + for c in motion_correction.columns: + if not isinstance(c, str): + rename_dict[c] = c.decode("utf-8") + motion_correction.rename(columns=rename_dict, inplace=True) + + return motion_correction + + def save_analysis_dataframes(self, *tables): + store = pd.HDFStore(self.nwb_file, mode='a') + for k, v in tables: + store.put('analysis/%s' % (k), v) + store.close() + + def save_analysis_arrays(self, *datasets): + with h5py.File(self.nwb_file, 'a') as f: + for k, v in datasets: + if k in f['analysis']: + del f['analysis'][k] + f.create_dataset('analysis/%s' % k, data=v) + + @property + def stimulus_search(self): + + if self._stimulus_search is None: + self._stimulus_search = si.StimulusSearch(self) + return self._stimulus_search + + def get_stimulus(self, frame_ind): + + search_result = self.stimulus_search.search(frame_ind) + + if search_result is None or search_result[2]['stimulus'] == si.SPONTANEOUS_ACTIVITY: + return None, None + + else: + + curr_stimulus = search_result[2]['stimulus'] + if curr_stimulus in si.LOCALLY_SPARSE_NOISE_STIMULUS_TYPES + si.NATURAL_MOVIE_STIMULUS_TYPES + [si.NATURAL_SCENES]: + curr_frame = search_result[2]['frame'] + return search_result, self.get_stimulus_template(curr_stimulus)[int(curr_frame), :, :] + elif curr_stimulus == si.STATIC_GRATINGS or curr_stimulus == si.DRIFTING_GRATINGS: + return search_result, None + + + + +class BrainObservatoryApi(RmaTemplate): + + OPHYS_EVENTS_FILE_TYPE = "ObservatoryEventsFile" + NWB_FILE_TYPE = "NWBOphys" + OPHYS_ANALYSIS_FILE_TYPE = "OphysExperimentCellRoiMetricsFile" + + rma_templates = { + "brain_observatory_queries": [ + { + "name": "list_isi_experiments", + "description": "see name", + "model": "IsiExperiment", + "num_rows": "all", + "count": False, + "criteria_params": [], + }, + { + "name": "isi_experiment_by_ids", + "description": "see name", + "model": "IsiExperiment", + "criteria": "[id$in{{ isi_experiment_ids }}]", + "include": "experiment_container(ophys_experiments,targeted_structure)", # noqa e501 + "num_rows": "all", + "count": False, + "criteria_params": ["isi_experiment_ids"], + }, + { + "name": "ophys_experiment_by_ids", + "description": "see name", + "model": "OphysExperiment", + "criteria": "{% if ophys_experiment_ids is defined %}[id$in{{ ophys_experiment_ids }}]{%endif%}", # noqa e501 + "include": "experiment_container,well_known_files(well_known_file_type),targeted_structure,specimen(donor(age,transgenic_lines))", # noqa e501 + "num_rows": "all", + "count": False, + "criteria_params": ["ophys_experiment_ids"], + }, + { + "name": "ophys_experiment_data", + "description": "see name", + "model": "WellKnownFile", + "criteria": "[attachable_id$eq{{ ophys_experiment_id }}],well_known_file_type[name$eq%s]" # noqa e501 + % NWB_FILE_TYPE, + "num_rows": "all", + "count": False, + "criteria_params": ["ophys_experiment_id"], + }, + { + "name": "ophys_analysis_file", + "description": "see name", + "model": "WellKnownFile", + "criteria": "[attachable_id$eq{{ ophys_experiment_id }}],well_known_file_type[name$eq%s]" # noqa e501 + % OPHYS_ANALYSIS_FILE_TYPE, + "num_rows": "all", + "count": False, + "criteria_params": ["ophys_experiment_id"], + }, + { + "name": "ophys_events_file", + "description": "see name", + "model": "WellKnownFile", + "criteria": "[attachable_id$eq{{ ophys_experiment_id }}],well_known_file_type[name$eq%s]" # noqa e501 + % OPHYS_EVENTS_FILE_TYPE, + "num_rows": "all", + "count": False, + "criteria_params": ["ophys_experiment_id"], + }, + { + "name": "column_definitions", + "description": "see name", + "model": "ApiColumnDefinition", + "criteria": "[api_class_name$eq{{ api_class_name }}]", + "num_rows": "all", + "count": False, + "criteria_params": ["api_class_name"], + }, + { + "name": "column_definition_class_names", + "description": "see name", + "model": "ApiColumnDefinition", + "only": ["api_class_name"], + "num_rows": "all", + "count": False, + }, + { + "name": "stimulus_mapping", + "description": "see name", + "model": "ApiCamStimulusMapping", + "criteria": "{% if stimulus_mapping_ids is defined %}[id$in{{ stimulus_mapping_ids }}]{%endif%}", # noqa e501 + "num_rows": "all", + "count": False, + "criteria_params": ["stimulus_mapping_ids"], + }, + { + "name": "experiment_container", + "description": "see name", + "model": "ExperimentContainer", + "criteria": "{% if experiment_container_ids is defined %}[id$in{{ experiment_container_ids }}]{%endif%}", # noqa e501 + "include": "ophys_experiments,isi_experiment,specimen(donor(conditions,age,transgenic_lines)),targeted_structure", # noqa e501 + "num_rows": "all", + "count": False, + "criteria_params": ["experiment_container_ids"], + }, + { + "name": "experiment_container_metric", + "description": "see name", + "model": "ApiCamExperimentContainerMetric", + "criteria": "{% if experiment_container_metric_ids is defined %}[id$in{{ experiment_container_metric_ids }}]{%endif%}", # noqa e501 + "num_rows": "all", + "count": False, + "criteria_params": ["experiment_container_metric_ids"], + }, + { + "name": "cell_metric", + "description": "see name", + "model": "ApiCamCellMetric", + "criteria": "{% if cell_specimen_ids is defined %}[cell_specimen_id$in{{ cell_specimen_ids }}]{%endif%}", # noqa e501 + "criteria_params": ["cell_specimen_ids"], + }, + { + "name": "cell_specimen_id_mapping_table", + "description": "see name", + "model": "WellKnownFile", + "criteria": "[id$eq{{ mapping_table_id }}],well_known_file_type[name$eqOphysCellSpecimenIdMapping]", # noqa e501 + "num_rows": "all", + "count": False, + "criteria_params": ["mapping_table_id"], + }, + { + "name": "eye_gaze_mapping_file", + "description": "h5 file containing mouse eye gaze mapped onto screen coordinates (as well as pupil and eye sizes)", # noqa e501 + "model": "WellKnownFile", + "criteria": "[attachable_id$eq{{ ophys_session_id }}],well_known_file_type[name$eqEyeDlcScreenMapping]", # noqa e501 + "num_rows": "all", + "count": False, + "criteria_params": ["ophys_session_id"], + }, + # NOTE: 'all_eye_mapping_files' query is for facilitating an ugly + # hack to get around lack of relationship between experiment id + # and session id in current warehouse. This should be removed when + # the relationship is added. + { + "name": "all_eye_mapping_files", + "description": "Get a list of dictionaries for all eye mapping wkfs", # noqa e501 + "model": "WellKnownFile", + "criteria": "well_known_file_type[name$eqEyeDlcScreenMapping]", + "num_rows": "all", + "count": False, + }, + ] + } + + def __init__(self, base_uri=None, datacube_uri=None): + super(BrainObservatoryApi, self).__init__( + base_uri, query_manifest=BrainObservatoryApi.rma_templates + ) + + self.datacube_uri = datacube_uri + + def save_ophys_experiment_data(self, ophys_experiment_id, file_name): + data = self.template_query( + "brain_observatory_queries", + "ophys_experiment_data", + ophys_experiment_id=ophys_experiment_id, + ) + + try: + file_url = data[0]["download_link"] + except Exception: + raise Exception( + "ophys experiment %d has no data file" % ophys_experiment_id + ) + + # self._log.warning( + # "Downloading ophys_experiment %d NWB. This can take some time." + # % ophys_experiment_id + # ) + + self.retrieve_file_over_http(self.api_url + file_url, file_name) + + + +class BrainObservatoryCache(Cache): + EXPERIMENT_CONTAINERS_KEY = "EXPERIMENT_CONTAINERS" + EXPERIMENTS_KEY = "EXPERIMENTS" + CELL_SPECIMENS_KEY = "CELL_SPECIMENS" + EXPERIMENT_DATA_KEY = "EXPERIMENT_DATA" + ANALYSIS_DATA_KEY = "ANALYSIS_DATA" + EVENTS_DATA_KEY = "EVENTS_DATA" + STIMULUS_MAPPINGS_KEY = "STIMULUS_MAPPINGS" + EYE_GAZE_DATA_KEY = "EYE_GAZE_DATA" + MANIFEST_VERSION = "1.3" + + def __init__(self, cache=True, manifest_file=None, base_uri=None, api=None): + + if manifest_file is None: + manifest_file = get_default_manifest_file("brain_observatory") + + super(BrainObservatoryCache, self).__init__( + manifest=manifest_file, cache=cache, version=self.MANIFEST_VERSION + ) + + if api is None: + self.api = BrainObservatoryApi(base_uri=base_uri) + else: + self.api = api + + + def get_ophys_experiment_data(self, ophys_experiment_id, file_name=None): + """Download the NWB file for an ophys_experiment (if it hasn't + already been + downloaded) and return a data accessor object. + + Parameters + ---------- + file_name: string + File name to save/read the data set. If file_name is None, + the file_name will be pulled out of the manifest. If caching + is disabled, no file will be saved. Default is None. + + ophys_experiment_id: integer + id of the ophys_experiment to retrieve + + Returns + ------- + BrainObservatoryNwbDataSet + """ + file_name = self.get_cache_path( + file_name, self.EXPERIMENT_DATA_KEY, ophys_experiment_id + ) + + self.api.save_ophys_experiment_data( + ophys_experiment_id, file_name + ) + + return BrainObservatoryNwbDataSet(file_name) + + def build_manifest(self, file_name): + """ + Construct a manifest for this Cache class and save it in a file. + + Parameters + ---------- + + file_name: string + File location to save the manifest. + + """ + + mb = ManifestBuilder() + mb.set_version(self.MANIFEST_VERSION) + mb.add_path("BASEDIR", ".") + mb.add_path( + self.EXPERIMENT_CONTAINERS_KEY, + "experiment_containers.json", + typename="file", + parent_key="BASEDIR", + ) + mb.add_path( + self.EXPERIMENTS_KEY, + "ophys_experiments.json", + typename="file", + parent_key="BASEDIR", + ) + mb.add_path( + self.EXPERIMENT_DATA_KEY, + "ophys_experiment_data/%d.nwb", + typename="file", + parent_key="BASEDIR", + ) + mb.add_path( + self.ANALYSIS_DATA_KEY, + "ophys_experiment_analysis/%d_%s_analysis.h5", + typename="file", + parent_key="BASEDIR", + ) + mb.add_path( + self.EVENTS_DATA_KEY, + "ophys_experiment_events/%d_events.npz", + typename="file", + parent_key="BASEDIR", + ) + mb.add_path( + self.CELL_SPECIMENS_KEY, + "cell_specimens.json", + typename="file", + parent_key="BASEDIR", + ) + mb.add_path( + self.STIMULUS_MAPPINGS_KEY, + "stimulus_mappings.json", + typename="file", + parent_key="BASEDIR", + ) + mb.add_path( + self.EYE_GAZE_DATA_KEY, + "ophys_eye_gaze_mapping/%d_eyetracking_dlc_to_screen_mapping.h5", + typename="file", + parent_key="BASEDIR", + ) + + mb.write_json_file(file_name) diff --git a/bmtk/utils/brain_observatory/cache.py b/bmtk/utils/brain_observatory/cache.py new file mode 100644 index 00000000..5abf9eb0 --- /dev/null +++ b/bmtk/utils/brain_observatory/cache.py @@ -0,0 +1,102 @@ +import os +import json + +from .manifest import Manifest, ManifestBuilder, ManifestVersionError + + +class Cache: + def __init__(self, + manifest=None, + cache=True, + version=None, + **kwargs): + self.cache = cache + if version is None and hasattr(self, 'MANIFEST_VERSION'): + version = self.MANIFEST_VERSION + self.load_manifest(manifest, version) + + def load_manifest(self, file_name, version=None): + if file_name is not None: + if not os.path.exists(file_name): + + # make the directory if it doesn't exist already + dirname = os.path.dirname(file_name) + if dirname: + Manifest.safe_mkdir(dirname) + + self.build_manifest(file_name) + + try: + with open(file_name, "rb") as f: + json_string = f.read().decode("utf-8") + if len(json_string) == 0: + json_string = "{}" + json_obj = json.loads(json_string) + + self.manifest = Manifest( + json_obj['manifest'], + os.path.dirname(file_name), + version=version) + except ManifestVersionError as e: + if e.outdated is True: + intro = "is out of date" + elif e.outdated is False: + intro = "was made with a newer version of the AllenSDK" + elif e.outdated is None: + intro = "version did not match the expected version" + + ref_url = "https://github.com/alleninstitute/allensdk/wiki" + raise ManifestVersionError(("Your manifest file (%s) %s" + + " (its version is '%s', but" + + " version '%s' is expected). " + + " Please remove this file" + + " and it will be regenerated for" + + " you the next time you" + + " instantiate this class." + + " WARNING: There may be new data" + + " files available that replace" + + " the ones you already have" + + " downloaded. Read the notes" + + " for this release for more" + + " details on what has changed" + + " (%s).") % + (file_name, intro, + e.found_version, e.version, + ref_url), + e.version, e.found_version) + + self.manifest_path = file_name + + else: + self.manifest = None + + + def build_manifest(self, file_name): + manifest_builder = ManifestBuilder() + manifest_builder.set_version(self.MANIFEST_VERSION) + manifest_builder = self.add_manifest_paths(manifest_builder) + manifest_builder.write_json_file(file_name) + + + def add_manifest_paths(self, manifest_builder): + manifest_builder.add_path('BASEDIR', '.') + if hasattr(self, 'MANIFEST_CONFIG'): + for key, config in self.MANIFEST_CONFIG.items(): + manifest_builder.add_path(key, **config) + return manifest_builder + + + def get_cache_path(self, file_name, manifest_key, *args): + if self.cache: + if file_name: + return file_name + elif self.manifest: + return self.manifest.get_path(manifest_key, *args) + + return None + +def get_default_manifest_file(cache_name): + return os.environ.get( + '{}_MANIFEST'.format(cache_name.upper()), + '{}/manifest.json'.format(cache_name.lower()) + ) \ No newline at end of file diff --git a/bmtk/utils/brain_observatory/ecephys/__init__.py b/bmtk/utils/brain_observatory/ecephys/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/bmtk/utils/brain_observatory/ecephys/ecephys_project_cache.py b/bmtk/utils/brain_observatory/ecephys/ecephys_project_cache.py new file mode 100644 index 00000000..d613a341 --- /dev/null +++ b/bmtk/utils/brain_observatory/ecephys/ecephys_project_cache.py @@ -0,0 +1,325 @@ +import pandas as pd +import functools +from pathlib import Path + +from ..cache import Cache +from ..utils import write_from_stream +from ..rma_engine import RmaEngine + + +class EcephysSession: + def __init__(self, nwb_path): + self.nwb_path = nwb_path + + +class EcephysProjectWarehouseApi: + def __init__(self, rma_engine=None): + if rma_engine is None: + rma_engine = RmaEngine(scheme="http", host="api.brain-map.org") + self.rma_engine = rma_engine + + @classmethod + def default(cls, asynchronous=False, **rma_kwargs): + _rma_kwargs = {"scheme": "http", "host": "api.brain-map.org"} + _rma_kwargs.update(rma_kwargs) + + engine_cls = AsyncRmaEngine if asynchronous else RmaEngine + return cls(engine_cls(**_rma_kwargs)) + + def get_session_data(self, session_id, **kwargs): + query = "criteria=model::WellKnownFile" \ + ",rma::criteria,well_known_file_type[name$eq'EcephysNwb']" \ + "[attachable_type$eq'EcephysSession']" \ + f"[attachable_id$eq{session_id}]" + well_known_files = self.rma_engine.get_rma_tabular(query) + + if well_known_files.shape[0] != 1: + raise ValueError( + f"expected exactly 1 nwb file for session {session_id}, found: {well_known_files}" # noqa: E501 + ) + + download_link = well_known_files.iloc[0]["download_link"] + return self.rma_engine.stream(download_link) + + + def get_sessions( + self, session_ids=None, has_eye_tracking=None, stimulus_names=None + ): + response = build_and_execute( + ( + "{% import 'rma_macros' as rm %}" + "{% import 'macros' as m %}" + "criteria=model::EcephysSession" + r"{{rm.optional_contains('id',session_ids)}}" + r"{%if has_eye_tracking is not none%}[fail_eye_tracking$eq{{m.str(not has_eye_tracking).lower()}}]{%endif%}" # noqa: E501 + r"{{rm.optional_contains('stimulus_name',stimulus_names,True)}}" # noqa: E501 + ",rma::include,specimen(donor(age))" + ",well_known_files(well_known_file_type)" + ), + base=rma_macros(), + engine=self.rma_engine.get_rma_tabular, + session_ids=session_ids, + has_eye_tracking=has_eye_tracking, + stimulus_names=stimulus_names, + ) + + response.set_index("id", inplace=True) + + age_in_days = [] + sex = [] + genotype = [] + has_nwb = [] + + for idx, row in response.iterrows(): + age_in_days.append(row["specimen"]["donor"]["age"]["days"]) + sex.append(row["specimen"]["donor"]["sex"]) + + gt = row["specimen"]["donor"]["full_genotype"] + if gt is None: + gt = "wt" + genotype.append(gt) + + current_has_nwb = False + for wkf in row["well_known_files"]: + if wkf["well_known_file_type"]["name"] == "EcephysNwb": + current_has_nwb = True + has_nwb.append(current_has_nwb) + + response["age_in_days"] = age_in_days + response["sex"] = sex + response["genotype"] = genotype + response["has_nwb"] = has_nwb + + response.drop( + columns=["specimen", "fail_eye_tracking", "well_known_files"], + inplace=True, + ) + response.rename( + columns={"stimulus_name": "session_type"}, inplace=True + ) + + return response + + def get_probes(self, probe_ids=None, session_ids=None): + raise NotImplementedError() + + + def get_channels(self, channel_ids=None, probe_ids=None): + raise NotImplementedError() + + + def get_rig_metadata(self): + raise NotImplementedError() + + + def get_units(self, unit_ids=None, channel_ids=None, probe_ids=None, session_ids=None, *a, **k): + raise NotImplementedError() + + + def get_unit_analysis_metrics(self, unit_ids=None, ecephys_session_ids=None, session_types=None): + raise NotImplementedError() + + + def get_probe_lfp_data(self, probe_id): + raise NotImplementedError() + + + + + +class EcephysProjectCache(Cache): + SESSIONS_KEY = 'sessions' + PROBES_KEY = 'probes' + CHANNELS_KEY = 'channels' + UNITS_KEY = 'units' + + SESSION_DIR_KEY = 'session_data' + SESSION_NWB_KEY = 'session_nwb' + PROBE_LFP_NWB_KEY = "probe_lfp_nwb" + + NATURAL_MOVIE_DIR_KEY = "movie_dir" + NATURAL_MOVIE_KEY = "natural_movie" + + NATURAL_SCENE_DIR_KEY = "natural_scene_dir" + NATURAL_SCENE_KEY = "natural_scene" + + SESSION_ANALYSIS_METRICS_KEY = "session_analysis_metrics" + TYPEWISE_ANALYSIS_METRICS_KEY = "typewise_analysis_metrics" + + MANIFEST_VERSION = '0.3.0' + + SUPPRESS_FROM_PROBES = ( + "air_channel_index", "surface_channel_index", + "date_of_acquisition", "published_at", "specimen_id", "session_type", "isi_experiment_id", "age_in_days", + "sex", "genotype", "has_nwb", "lfp_temporal_subsampling_factor" + ) + + def __init__( + self, + fetch_api=None, + fetch_tries=2, + stream_writer=None, + manifest=None, + version=None, + cache=True): + + manifest_ = manifest or "ecephys_project_manifest.json" + version_ = version or self.MANIFEST_VERSION + + super(EcephysProjectCache, self).__init__(manifest=manifest_, + version=version_, + cache=cache) + self.fetch_api = (EcephysProjectWarehouseApi.default() + if fetch_api is None else fetch_api) + self.fetch_tries = fetch_tries + self.stream_writer = (stream_writer + or self.fetch_api.rma_engine.write_bytes) + if stream_writer is not None: + self.stream_writer = stream_writer + else: + if hasattr(self.fetch_api, "rma_engine"): # EcephysProjectWarehouseApi # noqa + self.stream_writer = self.fetch_api.rma_engine.write_bytes + # TODO: Make these names consistent in the different fetch apis + elif hasattr(self.fetch_api, "app_engine"): # EcephysProjectLimsApi # noqa + self.stream_writer = self.fetch_api.app_engine.write_bytes + else: + raise ValueError( + "Must either set value for `stream_writer`, or use a " + "`fetch_api` with an rma_engine or app_engine attribute " + "that implements `write_bytes`. See `HttpEngine` and " + "`AsyncHttpEngine` from " + "allensdk.brain_observatory.ecephys.ecephys_project_api." + "http_engine for examples.") + + + def add_manifest_paths(self, manifest_builder): + manifest_builder = super(EcephysProjectCache, self).add_manifest_paths(manifest_builder) + + manifest_builder.add_path( + self.SESSIONS_KEY, 'sessions.csv', parent_key='BASEDIR', typename='file' + ) + + manifest_builder.add_path( + self.PROBES_KEY, 'probes.csv', parent_key='BASEDIR', typename='file' + ) + + manifest_builder.add_path( + self.CHANNELS_KEY, 'channels.csv', parent_key='BASEDIR', typename='file' + ) + + manifest_builder.add_path( + self.UNITS_KEY, 'units.csv', parent_key='BASEDIR', typename='file' + ) + + manifest_builder.add_path( + self.SESSION_DIR_KEY, 'session_%d', parent_key='BASEDIR', typename='dir' + ) + + manifest_builder.add_path( + self.SESSION_NWB_KEY, 'session_%d.nwb', parent_key=self.SESSION_DIR_KEY, typename='file' + ) + + manifest_builder.add_path( + self.SESSION_ANALYSIS_METRICS_KEY, 'session_%d_analysis_metrics.csv', parent_key=self.SESSION_DIR_KEY, typename='file' + ) + + manifest_builder.add_path( + self.PROBE_LFP_NWB_KEY, 'probe_%d_lfp.nwb', parent_key=self.SESSION_DIR_KEY, typename='file' + ) + + manifest_builder.add_path( + self.NATURAL_MOVIE_DIR_KEY, "natural_movie_templates", parent_key="BASEDIR", typename="dir" + ) + + manifest_builder.add_path( + self.TYPEWISE_ANALYSIS_METRICS_KEY, "%s_analysis_metrics.csv", parent_key='BASEDIR', typename="file" + ) + + manifest_builder.add_path( + self.NATURAL_MOVIE_KEY, "natural_movie_%d.h5", parent_key=self.NATURAL_MOVIE_DIR_KEY, typename="file" + ) + + manifest_builder.add_path( + self.NATURAL_SCENE_DIR_KEY, "natural_scene_templates", parent_key="BASEDIR", typename="dir" + ) + + manifest_builder.add_path( + self.NATURAL_SCENE_KEY, "natural_scene_%d.tiff", parent_key=self.NATURAL_SCENE_DIR_KEY, typename="file" + ) + + return manifest_builder + + @classmethod + def from_warehouse(cls, + scheme=None, + host=None, + asynchronous=False, + manifest=None, + version=None, + cache=True, + fetch_tries=2, + timeout=1200): + if scheme and host: + app_kwargs = {"scheme": scheme, "host": host, + "asynchronous": asynchronous} + else: + app_kwargs = {"asynchronous": asynchronous} + app_kwargs['timeout'] = timeout + return cls._from_http_source_default( + EcephysProjectWarehouseApi, app_kwargs, manifest=manifest, + version=version, cache=cache, fetch_tries=fetch_tries + ) + + @classmethod + def _from_http_source_default(cls, fetch_api_cls, fetch_api_kwargs, **kwargs): + fetch_api_kwargs = { + "asynchronous": True + } if fetch_api_kwargs is None else fetch_api_kwargs + + if kwargs.get("stream_writer") is None: + if fetch_api_kwargs.get("asynchronous", True): + kwargs["stream_writer"] = write_bytes_from_coroutine + else: + kwargs["stream_writer"] = write_from_stream + + return cls( + fetch_api=fetch_api_cls.default(**fetch_api_kwargs), + **kwargs + ) + + def get_session_data(self, session_id, force_overwrite=False): + """ Obtain an EcephysSession object containing detailed data for a single session + """ + + path = self.get_cache_path(None, self.SESSION_NWB_KEY, session_id, session_id) + fetch = functools.partial(self.fetch_api.get_session_data, session_id) + write = self.stream_writer + self._fetch_cached_session(path, fetch, write, force_overwrite) + return EcephysSession(path) + + + def _fetch_cached_session(self, path, fetch, write, force_overwrite=False): + path = path if isinstance(path, Path) else Path(path) + if not force_overwrite and path.exists(): + return + + path.parent.mkdir(parents=True, exist_ok=True) + + data = fetch() + write(path, data) + + + def get_channels(self, suppress=None): + raise NotImplementedError() + + + def get_probes(self, suppress=None): + raise NotImplementedError() + + + def get_unit_analysis_metrics_for_session(self, session_id, annotate: bool = True, filter_by_validity: bool = True, **unit_filter_kwargs): + raise NotImplementedError() + + + def get_unit_analysis_metrics_for_session(self, session_id, annotate: bool = True, filter_by_validity: bool = True, **unit_filter_kwargs): + raise NotImplementedError() diff --git a/bmtk/utils/brain_observatory/manifest.py b/bmtk/utils/brain_observatory/manifest.py new file mode 100644 index 00000000..37132ea7 --- /dev/null +++ b/bmtk/utils/brain_observatory/manifest.py @@ -0,0 +1,342 @@ +import os +import sys +import errno +import json +from pathlib import Path + +from .utils import json_handler + + + +class ManifestVersionError(Exception): + @property + def outdated(self): + try: + return self.found_version < self.version + except TypeError: + return + + def __init__(self, message, version, found_version): + super(ManifestVersionError, self).__init__(message) + self.found_version = found_version + self.version = version + + +class ManifestBuilder(object): + df_columns = ['key', 'parent_key', 'spec', 'type', 'format'] + + def __init__(self): + self.path_info = [] + self.sections = {} + + def set_version(self, value): + self.path_info.append({'type': Manifest.VERSION, 'value': value}) + + def add_path(self, key, spec, + typename='dir', + parent_key=None, + format=None): + entry = { + 'key': key, + 'type': typename, + 'spec': spec} + + if format is not None: + entry['format'] = format + + if parent_key is not None: + entry['parent_key'] = parent_key + + self.path_info.append(entry) + + def write_json_file(self, path, overwrite=False): + mode = 'wb' + + if overwrite is True: + mode = 'wb+' + + json_string = self.write_json_string() + + with open(path, mode) as f: + try: + f.write(json_string) # Python 2.7 + except TypeError: + f.write(bytes(json_string, 'utf-8')) # Python 3 + + def write_json_string(self): + config = self.get_config() + + return json.dumps( + config, + indent=2, + # ignore_nan=True, + default=json_handler, + # iterable_as_array=True, + ) + + + + + def get_config(self): + wrapper = {"manifest": self.path_info} + for section in self.sections.values(): + wrapper.update(section) + + return wrapper + + + +class Manifest(object): + DIR = 'dir' + FILE = 'file' + DIRNAME = 'dir_name' + VERSION = 'manifest_version' + + def __init__(self, config=None, relative_base_dir='.', version=None): + self.path_info = {} + self.relative_base_dir = relative_base_dir + + if config is not None: + self.load_config(config, version=version) + + def load_config(self, config, version=None): + ''' Load paths into the manifest from an Allen SDK config section. + + Parameters + ---------- + config : Config + Manifest section of an Allen SDK config. + ''' + found_version = None + for path_info in config: + path_type = path_info['type'] + path_format = None + if 'format' in path_info: + path_format = path_info['format'] + + if path_type == 'file': + try: + parent_key = path_info['parent_key'] + except: + parent_key = None + + self.add_file(path_info['key'], + path_info['spec'], + parent_key, + path_format) + elif path_type == 'dir': + try: + parent_key = path_info['parent_key'] + except: + parent_key = None + + spec = path_info['spec'] + absolute = False + if spec[0] == '/': + absolute = True + self.add_path(path_info['key'], + path_info['spec'], + path_type, + absolute, + path_format, + parent_key) + + elif path_type == self.VERSION: + found_version = path_info['value'] + else: + Manifest.log.warning("Unknown path type in manifest: %s" % + (path_type)) + + + if found_version != version: + raise ManifestVersionError("", version, found_version) + self.version = version + + def add_path(self, key, path, path_type=DIR, + absolute=True, path_format=None, parent_key=None): + '''Insert a new entry. + + Parameters + ---------- + key : string + Identifier for referencing the entry. + path : string + Specification for a path using %s, %d style substitution. + path_type : string enumeration + 'dir' (default) or 'file' + absolute : boolean + Is the spec relative to the process current directory. + path_format : string, optional + Indicate a known file type for further parsing. + parent_key : string + Refer to another entry. + ''' + if parent_key: + path_args = [] + + try: + parent_path = self.path_info[parent_key]['spec'] + path_args.append(parent_path) + except: + Manifest.log.error( + "cannot resolve directory key %s" % (parent_key)) + raise + path_args.extend(path.split('/')) + path = os.path.join(*path_args) + + # TODO: relative paths need to be considered better + if absolute is True: + path = os.path.abspath(path) + else: + path = os.path.abspath(os.path.join(self.relative_base_dir, path)) + + if path_type == Manifest.DIRNAME: + path = os.path.dirname(path) + + self.path_info[key] = {'type': path_type, + 'spec': path} + + if path_type == Manifest.FILE and path_format is not None: + self.path_info[key]['format'] = path_format + + def add_file(self, + file_key, + file_name, + dir_key=None, + path_format=None): + '''Insert a new file entry. + + Parameters + ---------- + file_key : string + Reference to the entry. + file_name : string + Subtitutions of the %s, %d style allowed. + dir_key : string + Reference to the parent directory entry. + path_format : string, optional + File type for further parsing. + ''' + path_args = [] + + if dir_key: + try: + dir_path = self.path_info[dir_key]['spec'] + path_args.append(dir_path) + except: + Manifest.log.error( + "cannot resolve directory key %s" % (dir_key)) + raise + elif not file_name.startswith('/'): + path_args.append(os.curdir) + else: + path_args.append(os.path.sep) + + path_args.extend(file_name.split('/')) + file_path = os.path.join(*path_args) + + self.path_info[file_key] = {'type': Manifest.FILE, + 'spec': file_path} + + if path_format: + self.path_info[file_key]['format'] = path_format + + @classmethod + def safe_mkdir(cls, directory): + '''Create path if not already there. + + Parameters + ---------- + directory : string + create it if it doesn't exist + + Returns + ------- + leftmost : string + most rootward directory created + + ''' + + parts = Path(directory).parts + sub_paths = [Path(parts[0])] + for part in parts[1:]: + sub_paths.append(sub_paths[-1] / part) + + leftmost = None + for sub_path in sub_paths: + if not sub_path.exists(): + leftmost = str(sub_path) + + try: + os.makedirs(directory) + except OSError as e: + if ((sys.platform == "darwin") and (e.errno == errno.EISDIR) and \ + (e.filename == "/")): + # undocumented behavior of mkdir on OSX where for / it raises + # EISDIR and not EEXIST + # https://bugs.python.org/issue24231 (old but still holds true) + pass + elif sys.platform == "win32" and e.errno == errno.EACCES: + root_path = os.path.abspath(os.sep) + if e.filename == root_path or \ + e.filename == root_path.replace("\\", "/"): + # When attempting to os.makedirs the root drive letter on + # Windows, EACCES is raised, not EEXIST + pass + else: + raise + elif e.errno == errno.EEXIST: + pass + else: + raise + + return leftmost + + @classmethod + def safe_make_parent_dirs(cls, file_name): + ''' Create a parent directories for file. + + Parameters + ---------- + file_name : string + + Returns + ------- + leftmost : string + most rootward directory created + + ''' + + dirname = os.path.dirname(file_name) + + # do nothing if there are no parent directories + if not dirname: + return + + return Manifest.safe_mkdir(dirname) + + + def get_path(self, path_key, *args): + '''Retrieve an entry with substitutions. + + Parameters + ---------- + path_key : string + Refer to the entry to retrieve. + args : any types, optional + arguments to be substituted into the path spec for %s, %d, etc. + + Returns + ------- + string + Path with parent structure and substitutions applied. + ''' + path_spec = self.path_info[path_key]['spec'] + + if args is not None and len(args) != 0: + path = path_spec % args + else: + path = path_spec + + return path + diff --git a/bmtk/utils/brain_observatory/rma_engine.py b/bmtk/utils/brain_observatory/rma_engine.py new file mode 100644 index 00000000..8a88b2dd --- /dev/null +++ b/bmtk/utils/brain_observatory/rma_engine.py @@ -0,0 +1,136 @@ +import pandas as pd +import time +import requests + +from .utils import infer_column_types + +try: + from tqdm import tqdm +except ImportError: + from .utils import FakeTqdm as tqdm + + + +DEFAULT_TIMEOUT = 20 * 60 # seconds +DEFAULT_CHUNKSIZE = 1024 * 10 # bytes + + +class HttpEngine: + def __init__( + self, + scheme: str, + host: str, + timeout: float = DEFAULT_TIMEOUT, + chunksize: int = DEFAULT_CHUNKSIZE, + **kwargs + ): + self.scheme = scheme + self.host = host + self.timeout = timeout + self.chunksize = chunksize + + + def stream(self, route): + """ Makes an http request and returns an iterator over the response. + + Parameters + ---------- + route : + the http route (under this object's host) to request against. + + """ + + url = self._build_url(route) + + start_time = time.perf_counter() + response = requests.get(url, stream=True) + response_b = None + if "Content-length" in response.headers: + response_b = float(response.headers["Content-length"]) + + size_message = f"{response_b / 1024 ** 2:3.3f}MiB" if response_b is not None else "potentially large" + # logging.warning(f"downloading a {size_message} file from {url}") + progress = tqdm(unit="B", total=response_b, unit_scale=True, desc="Downloading") + + for chunk in response.iter_content(self.chunksize): + if chunk: # filter out keep-alive new chunks + progress.update(len(chunk)) + yield chunk + + elapsed = time.perf_counter() - start_time + if elapsed > self.timeout: + raise requests.Timeout(f"Download took {elapsed} seconds, but timeout was set to {self.timeout}") + + def _build_url(self, route): + return f"{self.scheme}://{self.host}/{route}" + + + +class RmaEngine(HttpEngine): + def __init__( + self, + scheme, + host, + rma_prefix: str = "api/v2/data", + rma_format: str = "json", + page_size: int = 5000, + **kwargs + ): + super(RmaEngine, self).__init__(scheme, host, **kwargs) + self.rma_prefix = rma_prefix + self.rma_format = rma_format + self.page_size = page_size + + @property + def format_query_string(self): + return f"query.{self.rma_format}" + + def add_page_params(self, url, start, count=None): + if count is None: + count = self.page_size + return f"{url},rma::options[start_row$eq{start}][num_rows$eq{count}][order$eq'id']" + + + def get_rma(self, query: str): + """ Makes a paging rma query + + Parameters + ---------- + query : + The RMA query parameters + + """ + url = f"{self.scheme}://{self.host}/{self.rma_prefix}/{self.format_query_string}?{query}" + # logging.debug(url) + + start_row = 0 + total_rows = None + + start_time = time.time() + while total_rows is None or start_row < total_rows: + current_url = self.add_page_params(url, start_row) + response_json = requests.get(current_url).json() + if not response_json["success"]: + raise Exception(response_json["msg"]) + + start_row += response_json["num_rows"] + if total_rows is None: + total_rows = response_json["total_rows"] + + # logging.debug(f"downloaded {start_row} of {total_rows} records ({time.time() - start_time:.3f} seconds)") + yield response_json["msg"] + + + def get_rma_list(self, query): + response = [] + for chunk in self.get_rma(query): + response.extend(chunk) + return response + + def get_rma_tabular(self, query, try_infer_dtypes=True): + response = pd.DataFrame(self.get_rma_list(query)) + + if try_infer_dtypes: + response = infer_column_types(response) + + return response \ No newline at end of file diff --git a/bmtk/utils/brain_observatory/rma_template.py b/bmtk/utils/brain_observatory/rma_template.py new file mode 100644 index 00000000..9877b567 --- /dev/null +++ b/bmtk/utils/brain_observatory/rma_template.py @@ -0,0 +1,1048 @@ +import os +import warnings +import pandas as pd +import requests +from contextlib import closing +import urllib +import json +from pathlib import Path + +from jinja2 import Template + + +class Api(object): + # _log = logging.getLogger('allensdk.api.api') + # _file_download_log = logging.getLogger('allensdk.api.api.retrieve_file_over_http') + default_api_url = 'http://api.brain-map.org' + download_url = 'http://download.alleninstitute.org' + + def __init__(self, api_base_url_string=None): + if api_base_url_string is None: + api_base_url_string = Api.default_api_url + + self.set_api_urls(api_base_url_string) + self.default_working_directory = os.getcwd() + + def set_api_urls(self, api_base_url_string): + '''Set the internal RMA and well known file download endpoint urls + based on a api server endpoint. + + Parameters + ---------- + api_base_url_string : string + url of the api to point to + ''' + self.api_url = api_base_url_string + + # http://help.brain-map.org/display/api/Downloading+a+WellKnownFile + self.well_known_file_endpoint = api_base_url_string + \ + '/api/v2/well_known_file_download' + + # http://help.brain-map.org/display/api/Downloading+3-D+Expression+Grid+Data + self.grid_data_endpoint = api_base_url_string + '/grid_data' + + # http://help.brain-map.org/display/api/Downloading+and+Displaying+SVG + self.svg_endpoint = api_base_url_string + '/api/v2/svg' + self.svg_download_endpoint = api_base_url_string + '/api/v2/svg_download' + + # http://help.brain-map.org/display/api/Downloading+an+Ontology%27s+Structure+Graph + self.structure_graph_endpoint = api_base_url_string + \ + '/api/v2/structure_graph_download' + + # http://help.brain-map.org/display/api/Searching+a+Specimen+or+Structure+Tree + self.tree_search_endpoint = api_base_url_string + '/api/v2/tree_search' + + # http://help.brain-map.org/display/api/Searching+Annotated+SectionDataSets + self.annotated_section_data_sets_endpoint = api_base_url_string + \ + '/api/v2/annotated_section_data_sets' + self.compound_annotated_section_data_sets_endpoint = api_base_url_string + \ + '/api/v2/compound_annotated_section_data_sets' + + # http://help.brain-map.org/display/api/Image-to-Image+Synchronization#Image-to-ImageSynchronization-ImagetoImage + self.image_to_atlas_endpoint = api_base_url_string + '/api/v2/image_to_atlas' + self.image_to_image_endpoint = api_base_url_string + '/api/v2/image_to_image' + self.image_to_image_2d_endpoint = api_base_url_string + '/api/v2/image_to_image_2d' + self.reference_to_image_endpoint = api_base_url_string + '/api/v2/reference_to_image' + self.image_to_reference_endpoint = api_base_url_string + '/api/v2/image_to_reference' + self.structure_to_image_endpoint = api_base_url_string + '/api/v2/structure_to_image' + + # http://help.brain-map.org/display/mouseconnectivity/API + self.section_image_download_endpoint = api_base_url_string + \ + '/api/v2/section_image_download' + self.atlas_image_download_endpoint = api_base_url_string + \ + '/api/v2/atlas_image_download' + self.projection_image_download_endpoint = api_base_url_string + \ + '/api/v2/projection_image_download' + self.image_download_endpoint = api_base_url_string + \ + '/api/v2/image_download' + self.informatics_archive_endpoint = Api.download_url + '/informatics-archive' + + self.rma_endpoint = api_base_url_string + '/api/v2/data' + + def set_default_working_directory(self, working_directory): + '''Set the working directory where files will be saved. + + Parameters + ---------- + working_directory : string + the absolute path string of the working directory. + ''' + self.default_working_directory = working_directory + + def read_data(self, parsed_json): + '''Return the message data from the parsed query. + + Parameters + ---------- + parsed_json : dict + A python structure corresponding to the JSON data returned from the API. + + Notes + ----- + See `API Response Formats - Response Envelope `_ + for additional documentation. + ''' + return parsed_json['msg'] + + def json_msg_query(self, url, dataframe=False): + ''' Common case where the url is fully constructed + and the response data is stored in the 'msg' field. + + Parameters + ---------- + url : string + Where to get the data in json form + dataframe : boolean + True converts to a pandas dataframe, False (default) doesn't + + Returns + ------- + dict or DataFrame + returned data; type depends on dataframe option + ''' + + data = self.do_query(lambda *a, **k: url, + self.read_data) + + if dataframe is True: + warnings.warn("dataframe argument is deprecated", DeprecationWarning) + data = pd.DataFrame(data) + + return data + + def do_query(self, url_builder_fn, json_traversal_fn, *args, **kwargs): + '''Bundle an query url construction function + with a corresponding response json traversal function. + + Parameters + ---------- + url_builder_fn : function + A function that takes parameters and returns an rma url. + json_traversal_fn : function + A function that takes a json-parsed python data structure and returns data from it. + post : boolean, optional kwarg + True does an HTTP POST, False (default) does a GET + args : arguments + Arguments to be passed to the url builder function. + kwargs : keyword arguments + Keyword arguments to be passed to the rma builder function. + + Returns + ------- + any type + The data extracted from the json response. + + Examples + -------- + `A simple Api subclass example + `_. + ''' + api_url = url_builder_fn(*args, **kwargs) + + post = kwargs.get('post', False) + + json_parsed_data = self.retrieve_parsed_json_over_http(api_url, post) + + return json_traversal_fn(json_parsed_data) + + def do_rma_query(self, rma_builder_fn, json_traversal_fn, *args, **kwargs): + '''Bundle an RMA query url construction function + with a corresponding response json traversal function. + + ..note:: Deprecated in AllenSDK 0.9.2 + `do_rma_query` will be removed in AllenSDK 1.0, it is replaced by + `do_query` because the latter is more general. + + Parameters + ---------- + rma_builder_fn : function + A function that takes parameters and returns an rma url. + json_traversal_fn : function + A function that takes a json-parsed python data structure and returns data from it. + args : arguments + Arguments to be passed to the rma builder function. + kwargs : keyword arguments + Keyword arguments to be passed to the rma builder function. + + Returns + ------- + any type + The data extracted from the json response. + + Examples + -------- + `A simple Api subclass example + `_. + ''' + return self.do_query(rma_builder_fn, json_traversal_fn, *args, **kwargs) + + def load_api_schema(self): + '''Download the RMA schema from the current RMA endpoint + + Returns + ------- + dict + the parsed json schema message + + Notes + ----- + This information and other + `Allen Brain Atlas Data Portal Data Model `_ + documentation is also available as a + `Class Hierarchy `_ + and `Class List `_. + + ''' + schema_url = self.rma_endpoint + '/enumerate.json' + json_parsed_schema_data = self.retrieve_parsed_json_over_http( + schema_url) + + return json_parsed_schema_data + + def construct_well_known_file_download_url(self, well_known_file_id): + '''Join data api endpoint and id. + + Parameters + ---------- + well_known_file_id : integer or string representing an integer + well known file id + + Returns + ------- + string + the well-known-file download url for the current api api server + + See Also + -------- + retrieve_file_over_http: Can be used to retrieve the file from the url. + ''' + return self.well_known_file_endpoint + '/' + str(well_known_file_id) + + def cleanup_truncated_file(self, file_path): + '''Helper for removing files. + + Parameters + ---------- + file_path : string + Absolute path including the file name to remove.''' + try: + os.remove(file_path) + except OSError as e: + warnings(f'{e}') + + def retrieve_file_over_http(self, url, file_path, zipped=False): + '''Get a file from the data api and save it. + + Parameters + ---------- + url : string + Url[1]_ from which to get the file. + file_path : string + Absolute path including the file name to save. + zipped : bool, optional + If true, assume that the response is a zipped directory and attempt + to extract contained files into the directory containing file_path. + Default is False. + + See Also + -------- + construct_well_known_file_download_url: Can be used to construct the url. + + References + ---------- + .. [1] Allen Brain Atlas Data Portal: `Downloading a WellKnownFile `_. + ''' + + # self._file_download_log.info("Downloading URL: %s", url) + + try: + if zipped: + stream_zip_directory_over_http(url, os.path.dirname(file_path)) + else: + stream_file_over_http(url, file_path) + + except Exception as e: + # self._file_download_log.error("Couldn't retrieve file %s from %s" % (file_path, url)) + # self.cleanup_truncated_file(file_path) + raise e + + + def retrieve_parsed_json_over_http(self, url, post=False): + '''Get the document and put it in a Python data structure + + Parameters + ---------- + url : string + Full API query url. + post : boolean + True does an HTTP POST, False (default) encodes the URL and does a GET + + Returns + ------- + dict + Result document as parsed by the JSON library. + ''' + # self._log.info("Downloading URL: %s", url) + + if post is False: + url = requests.utils.quote(url, ';/?:@&=+$,') + response = urllib.request.urlopen(url) + json_string = response.read().decode("utf-8") + data = json.loads(json_string) + else: + data = json_utilities.read_url_post(url) + + return data + + def retrieve_xml_over_http(self, url): + '''Get the document and put it in a Python data structure + + Parameters + ---------- + url : string + Full API query url. + + Returns + ------- + string + Unparsed xml string. + ''' + self._log.info("Downloading URL: %s", url) + + response = requests.get(url) + + return response.content + + +def stream_zip_directory_over_http(url, directory, members=None, timeout=(9.05, 31.1)): + ''' Supply an http get request and stream the response to a file. + + Parameters + ---------- + url : str + Send the request to this url + directory : str + Extract the response to this directory + members : list of str, optional + Extract only these files + timeout : float or tuple of float, optional + Specify a timeout for the request. If a tuple, specify seperate connect + and read timeouts. + + ''' + + buf = io.BytesIO() + + with closing( requests.get(url, stream=True, timeout=timeout) ) as request: + stream.stream_response_to_file( request, buf ) + + zipper = zipfile.ZipFile(buf) + zipper.extractall(path=directory, members=members) + zipper.close() + + +def stream_file_over_http(url, file_path, timeout=(9.05, 31.1)): + ''' Supply an http get request and stream the response to a file. + + Parameters + ---------- + url : str + Send the request to this url + file_path : str + Stream the response to this path + timeout : float or tuple of float, optional + Specify a timeout for the request. If a tuple, specify seperate connect + and read timeouts. + + ''' + Path(file_path).parent.mkdir(parents=True, exist_ok=True) + with closing(requests.get(url, stream=True, timeout=timeout)) as response: + + response.raise_for_status() + with open(file_path, 'wb') as fil: + for chunk in response.iter_content(chunk_size=8192): # Adjust chunk_size as needed + if chunk: # Filter out keep-alive new chunks + fil.write(chunk) + # stream.stream_response_to_file(response, path=fil) + + + +class RmaApi(Api): + ''' + See: `RESTful Model Access (RMA) `_ + ''' + MODEL = 'model::' + PIPE = 'pipe::' + SERVICE = 'service::' + CRITERIA = 'rma::criteria' + INCLUDE = 'rma::include' + OPTIONS = 'rma::options' + ORDER = 'order' + NUM_ROWS = 'num_rows' + ALL = 'all' + START_ROW = 'start_row' + COUNT = 'count' + ONLY = 'only' + EXCEPT = 'except' + EXCPT = 'excpt' + TABULAR = 'tabular' + DEBUG = 'debug' + PREVIEW = 'preview' + TRUE = 'true' + FALSE = 'false' + IS = '$is' + EQ = '$eq' + + def __init__(self, base_uri=None): + super(RmaApi, self).__init__(base_uri) + + def build_query_url(self, + stage_clauses, + fmt='json'): + '''Combine one or more RMA query stages into a single RMA query. + + Parameters + ---------- + stage_clauses : list of strings + subqueries + fmt : string, optional + json (default), xml, or csv + + Returns + ------- + string + complete RMA url + ''' + if not type(stage_clauses) is list: + stage_clauses = [stage_clauses] + + url = ''.join([ + self.rma_endpoint, + '/query.', + fmt, + '?q=', + ','.join(stage_clauses)]) + + return url + + def model_stage(self, + model, + **kwargs): + '''Construct a model stage of an RMA query string. + + Parameters + ---------- + model : string + The top level data type + filters : dict + key, value comparisons applied to the top-level model to narrow the results. + criteria : string + raw RMA criteria clause to choose what object are returned + include : string + raw RMA include clause to return associated objects + only : list of strings, optional + to be joined into an rma::options only filter to limit what data is returned + except : list of strings, optional + to be joined into an rma::options except filter to limit what data is returned + tabular : list of string, optional + return columns as a tabular data structure rather than a nested tree. + count : boolean, optional + False to skip the extra database count query. + debug : string, optional + 'true', 'false' or 'preview' + num_rows : int or string, optional + how many database rows are returned (may not correspond directly to JSON tree structure) + start_row : int or string, optional + which database row is start of returned data (may not correspond directly to JSON tree structure) + + + Notes + ----- + See `RMA Path Syntax `_ + for a brief overview of the normalized RMA syntax. + Normalized RMA syntax differs from the legacy syntax + used in much of the RMA documentation. + Using the &debug=true option with an RMA URL will include debugging information in the + response, including the normalized query. + ''' + clauses = [RmaApi.MODEL + model] + + filters = kwargs.get('filters', None) + + if filters is not None: + clauses.append(self.filters(filters)) + + criteria = kwargs.get('criteria', None) + + if criteria is not None: + clauses.append(',') + clauses.append(RmaApi.CRITERIA) + clauses.append(',') + clauses.extend(criteria) + + include = kwargs.get('include', None) + + if include is not None: + clauses.append(',') + clauses.append(RmaApi.INCLUDE) + clauses.append(',') + clauses.extend(include) + + options_clause = self.options_clause(**kwargs) + + if options_clause != '': + clauses.append(',') + clauses.append(options_clause) + + stage = ''.join(clauses) + + return stage + + def pipe_stage(self, + pipe_name, + parameters): + '''Connect model and service stages via their JSON responses. + + Notes + ----- + See: `Service Pipelines `_ + and + `Connected Services and Pipes `_ + ''' + clauses = [RmaApi.PIPE + pipe_name] + + clauses.append(self.tuple_filters(parameters)) + + stage = ''.join(clauses) + + return stage + + def service_stage(self, + service_name, + parameters=None): + '''Construct an RMA query fragment to send a request to a connected service. + + Parameters + ---------- + service_name : string + Name of a documented connected service. + parameters : dict + key-value pairs as in the online documentation. + + Notes + ----- + See: `Service Pipelines `_ + and + `Connected Services and Pipes `_ + ''' + clauses = [RmaApi.SERVICE + service_name] + + if parameters is not None: + clauses.append(self.tuple_filters(parameters)) + + stage = ''.join(clauses) + + return stage + + def model_query(self, *args, **kwargs): + '''Construct and execute a model stage of an RMA query string. + + Parameters + ---------- + model : string + The top level data type + filters : dict + key, value comparisons applied to the top-level model to narrow the results. + criteria : string + raw RMA criteria clause to choose what object are returned + include : string + raw RMA include clause to return associated objects + only : list of strings, optional + to be joined into an rma::options only filter to limit what data is returned + except : list of strings, optional + to be joined into an rma::options except filter to limit what data is returned + excpt : list of strings, optional + synonym for except parameter to avoid a reserved word conflict. + tabular : list of string, optional + return columns as a tabular data structure rather than a nested tree. + count : boolean, optional + False to skip the extra database count query. + debug : string, optional + 'true', 'false' or 'preview' + num_rows : int or string, optional + how many database rows are returned (may not correspond directly to JSON tree structure) + start_row : int or string, optional + which database row is start of returned data (may not correspond directly to JSON tree structure) + + + Notes + ----- + See `RMA Path Syntax `_ + for a brief overview of the normalized RMA syntax. + Normalized RMA syntax differs from the legacy syntax + used in much of the RMA documentation. + Using the &debug=true option with an RMA URL will include debugging information in the + response, including the normalized query. + ''' + return self.json_msg_query( + self.build_query_url( + self.model_stage(*args, **kwargs))) + + def service_query(self, *args, **kwargs): + '''Construct and Execute a single-stage RMA query + to send a request to a connected service. + + Parameters + ---------- + service_name : string + Name of a documented connected service. + parameters : dict + key-value pairs as in the online documentation. + + Notes + ----- + See: `Service Pipelines `_ + and + `Connected Services and Pipes `_ + ''' + return self.json_msg_query( + self.build_query_url( + self.service_stage(*args, **kwargs))) + + def options_clause(self, **kwargs): + '''build rma:: options clause. + + Parameters + ---------- + only : list of strings, optional + except : list of strings, optional + tabular : list of string, optional + count : boolean, optional + debug : string, optional + 'true', 'false' or 'preview' + num_rows : int or string, optional + start_row : int or string, optional + ''' + clause = '' + options_params = [] + + only = kwargs.get(RmaApi.ONLY, None) + + if only is not None: + options_params.append( + self.only_except_tabular_clause(RmaApi.ONLY, + only)) + + # handle alternate 'except' spelling to avoid reserved word conflict + excpt = kwargs.get(RmaApi.EXCEPT, None) + excpt2 = kwargs.get(RmaApi.EXCPT, None) + + if excpt is not None and excpt2 is not None: + warnings.warn('excpt and except options should not be used together', + Warning) + elif excpt2 is not None: + excpt = excpt2 + + if excpt is not None: + options_params.append( + self.only_except_tabular_clause(RmaApi.EXCEPT, + excpt)) + + tabular = kwargs.get(RmaApi.TABULAR, None) + + if tabular is not None: + options_params.append( + self.only_except_tabular_clause(RmaApi.TABULAR, + tabular)) + + num_rows = kwargs.get(RmaApi.NUM_ROWS, None) + + if num_rows is not None: + if num_rows == RmaApi.ALL: + options_params.append("[%s$eq'all']" % (RmaApi.NUM_ROWS)) + else: + options_params.append('[%s$eq%d]' % (RmaApi.NUM_ROWS, + num_rows)) + + start_row = kwargs.get(RmaApi.START_ROW, None) + + if start_row is not None: + options_params.append('[%s$eq%d]' % (RmaApi.START_ROW, + start_row)) + + order = kwargs.get(RmaApi.ORDER, None) + + if order is not None: + options_params.append(self.order_clause(order)) + + debug = kwargs.get(RmaApi.DEBUG, None) + + if debug is not None: + options_params.append(self.debug_clause(debug)) + + cnt = kwargs.get(RmaApi.COUNT, None) + + if cnt is not None: + if cnt is True or cnt == 'true': + options_params.append('[%s$eq%s]' % (RmaApi.COUNT, + RmaApi.TRUE)) + elif cnt is False or cnt == 'false': + options_params.append('[%s$eq%s]' % (RmaApi.COUNT, + RmaApi.FALSE)) + else: + pass + + if len(options_params) > 0: + clause = RmaApi.OPTIONS + ''.join(options_params) + + return clause + + def only_except_tabular_clause(self, filter_type, attribute_list): + '''Construct a clause to filter which attributes are returned + for use in an rma::options clause. + + Parameters + ---------- + filter_type : string + 'only', 'except', or 'tabular' + attribute_list : list of strings + for example ['acronym', 'products.name', 'structure.id'] + + Returns + ------- + clause : string + The query clause for inclusion in an RMA query URL. + + Notes + ----- + The title of tabular columns can be set by adding '+as+' + to the attribute. + The tabular filter type requests a response that is row-oriented + rather than a nested structure. + Because of this, the tabular option can mask the lazy query behavior + of an rma::include clause. + The tabular option does not mask the inner-join behavior of an rma::include + clause. + The tabular filter is required for .csv format RMA requests. + ''' + clause = '' + + if attribute_list is not None: + clause = '[%s$eq%s]' % (filter_type, + ','.join(attribute_list)) + + return clause + + def order_clause(self, order_list=None): + '''Construct a debug clause for use in an rma::options clause. + + Parameters + ---------- + order_list : list of strings + for example ['acronym', 'products.name+asc', 'structure.id+desc'] + + Returns + ------- + clause : string + The query clause for inclusion in an RMA query URL. + + Notes + ----- + Optionally adding '+asc' (default) or '+desc' after an attribute + will change the sort order. + ''' + clause = '' + + if order_list is not None: + clause = '[order$eq%s]' % (','.join(order_list)) + + return clause + + def debug_clause(self, debug_value=None): + '''Construct a debug clause for use in an rma::options clause. + Parameters + ---------- + debug_value : string or boolean + True, False, None (default) or 'preview' + + Returns + ------- + clause : string + The query clause for inclusion in an RMA query URL. + + Notes + ----- + True will request debugging information in the response. + False will request no debugging information. + None will return an empty clause. + 'preview' will request debugging information without the query being run. + + ''' + clause = '' + + if debug_value is None: + clause = '' + if debug_value is True or debug_value == 'true': + clause = '[debug$eqtrue]' + elif debug_value is False or debug_value == 'false': + clause = '[debug$eqfalse]' + elif debug_value == 'preview': + clause = "[debug$eq'preview']" + + return clause + + # TODO: deprecate for something that can preserve order + def filters(self, filters): + '''serialize RMA query filter clauses. + + Parameters + ---------- + filters : dict + keys and values for narrowing a query. + + Returns + ------- + string + filter clause for an RMA query string. + ''' + filters_builder = [] + + for (key, value) in filters.items(): + filters_builder.append(self.filter(key, value)) + + return ''.join(filters_builder) + + # TODO: this needs to be more rigorous. + def tuple_filters(self, filters): + '''Construct an RMA filter clause. + + Notes + ----- + + See `RMA Path Syntax - Square Brackets for Filters <http://help.brain-map.org/display/api/RMA+Path+Syntax#RMAPathSyntax-SquareBracketsforFilters>`_ for additional documentation. + ''' + filters_builder = [] + + for filt in sorted(filters): + if filt[-1] is None: + continue + if len(filt) == 2: + val = filt[1] + if type(val) is list: + val_array = [] + for v in val: + if type(v) is str: + val_array.append(v) + else: + val_array.append(str(v)) + val = ','.join(val_array) + filters_builder.append("[%s$eq%s]" % (filt[0], val)) + elif type(val) is int: + filters_builder.append("[%s$eq%d]" % (filt[0], val)) + elif type(val) is bool: + if val: + filters_builder.append("[%s$eqtrue]" % (filt[0])) + else: + filters_builder.append("[%s$eqfalse]" % (filt[0])) + elif type(val) is str: + filters_builder.append("[%s$eq%s]" % (filt[0], filt[1])) + elif len(filt) == 3: + filters_builder.append("[%s%s%s]" % (filt[0], + filt[1], + str(filt[2]))) + + return ''.join(filters_builder) + + def quote_string(self, the_string): + '''Wrap a clause in single quotes. + + Parameters + ---------- + the_string : string + a clause to be included in an rma query that needs to be quoted + + Returns + ------- + string + input wrapped in single quotes + ''' + return ''.join(["'", the_string, "'"]) + + def filter(self, key, value): + '''serialize a single RMA query filter clause. + + Parameters + ---------- + key : string + keys for narrowing a query. + value : string + value for narrowing a query. + + Returns + ------- + string + a single filter clause for an RMA query string. + ''' + return "".join(['[', + key, + RmaApi.EQ, + str(value), + ']']) + + def build_schema_query(self, clazz=None, fmt='json'): + '''Build the URL that will fetch the data schema. + + Parameters + ---------- + clazz : string, optional + Name of a specific class or None (default). + fmt : string, optional + json (default) or xml + + Returns + ------- + url : string + The constructed URL + + Notes + ----- + If a class is specified, only the schema information for that class + will be requested, otherwise the url requests the entire schema. + ''' + if clazz is not None: + class_clause = '/' + clazz + else: + class_clause = '' + + url = ''.join([self.rma_endpoint, + class_clause, + '.', + fmt]) + + return url + + def get_schema(self, clazz=None): + '''Retrieve schema information.''' + schema_data = self.do_query(self.build_schema_query, + self.read_data, + clazz) + + return schema_data + + + +class RmaTemplate(RmaApi): + ''' + See: `Atlas Drawings and Ontologies + <http://help.brain-map.org/display/api/Atlas+Drawings+and+Ontologies>`_ + ''' + + def __init__(self, base_uri=None, query_manifest=None): + super(RmaTemplate, self).__init__(base_uri) + self.templates = query_manifest + + def to_filter_rhs(self, rhs): + if type(rhs) == list: + return ','.join(str(r) for r in rhs) + + return rhs + + def template_query(self, template_name, entry_name, **kwargs): + cb = self.templates[template_name] + templates = [e for e in cb if e['name'] == entry_name] + + if len(templates) > 0: + template = templates[0] + else: + raise Exception('Entry %s not found.' % (entry_name)) + + query_args = {'model': template['model']} + + if 'criteria' in template: + criteria_template = Template(template['criteria']) + + if 'criteria_params' in template: + criteria_params = {key: self.to_filter_rhs(kwargs.get(key)) + for key in template['criteria_params'] + if key in kwargs and kwargs.get(key) is not None} + else: + criteria_params = {} + + criteria_str = str(criteria_template.render(**criteria_params)) + if criteria_str: + query_args['criteria'] = criteria_str + + if 'include' in template: + include_template = Template(template['include']) + + if 'include_params' in template: + include_params = {key: self.to_filter_rhs(kwargs.get(key)) + for key in template['include_params'] + if key in kwargs and kwargs.get(key) is not None} + else: + include_params = {} + + include_str = str(include_template.render(**include_params)) + if include_str: + query_args['include'] = include_str + + if 'only' in kwargs: + if kwargs.get('only') is not None: + query_args['only'] = [self.quote_string( + ','.join(kwargs.get('only')))] + elif 'only' in template: + query_args['only'] = [ + self.quote_string(','.join(template['only']))] + + if 'except' in kwargs: + if kwargs.get('except') is not None: + query_args['except'] = [self.quote_string( + ','.join(kwargs.get('except')))] + elif 'except' in template: + query_args['except'] = template['except'] + + if 'start_row' in kwargs: + query_args['start_row'] = kwargs.get('start_row') + elif 'start_row' in template: + query_args['start_row'] = template['start_row'] + + if 'num_rows' in kwargs: + query_args['num_rows'] = kwargs.get('num_rows') + elif 'num_rows' in template: + query_args['num_rows'] = template['num_rows'] + + if 'count' in kwargs: + query_args['count'] = kwargs.get('count') + elif 'count' in template: + query_args['count'] = template['count'] + + if 'order' in kwargs: + query_args['order'] = kwargs.get('order') + elif 'order' in template: + query_args['order'] = template['order'] + + query_args.update(kwargs) + + data = self.model_query(**query_args) + + return data diff --git a/bmtk/utils/brain_observatory/stimulus_info.py b/bmtk/utils/brain_observatory/stimulus_info.py new file mode 100755 index 00000000..2f5a145f --- /dev/null +++ b/bmtk/utils/brain_observatory/stimulus_info.py @@ -0,0 +1,1143 @@ +# Allen Institute Software License - This software license is the 2-clause BSD +# license plus a third clause that prohibits redistribution for commercial +# purposes without further permission. +# +# Copyright 2017. Allen Institute. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Redistributions for commercial purposes are not permitted without the +# Allen Institute's written permission. +# For purposes of this license, commercial purposes is the incorporation of the +# Allen Institute's software into anything for which you will charge fees or +# other compensation. Contact terms@alleninstitute.org for commercial licensing +# opportunities. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. +# +import itertools + +import numpy as np +import scipy.ndimage.interpolation as spndi +import six +# from allensdk.api.warehouse_cache.cache import memoize +from PIL import Image + +# some handles for stimulus types +DRIFTING_GRATINGS = "drifting_gratings" +DRIFTING_GRATINGS_SHORT = "dg" +DRIFTING_GRATINGS_COLOR = "#a6cee3" + +STATIC_GRATINGS = "static_gratings" +STATIC_GRATINGS_SHORT = "sg" +STATIC_GRATINGS_COLOR = "#1f78b4" + +NATURAL_MOVIE_ONE = "natural_movie_one" +NATURAL_MOVIE_ONE_SHORT = "nm1" +NATURAL_MOVIE_ONE_COLOR = "#b2df8a" + +NATURAL_MOVIE_TWO = "natural_movie_two" +NATURAL_MOVIE_TWO_SHORT = "nm2" +NATURAL_MOVIE_TWO_COLOR = "#33a02c" + +NATURAL_MOVIE_THREE = "natural_movie_three" +NATURAL_MOVIE_THREE_SHORT = "nm3" +NATURAL_MOVIE_THREE_COLOR = "#fb9a99" + +NATURAL_SCENES = "natural_scenes" +NATURAL_SCENES_SHORT = "ns" +NATURAL_SCENES_COLOR = "#e31a1c" + +# note that this stimulus is equivalent to LOCALLY_SPARSE_NOISE_4DEG in session +# C2 files +LOCALLY_SPARSE_NOISE = "locally_sparse_noise" +LOCALLY_SPARSE_NOISE_SHORT = "lsn" +LOCALLY_SPARSE_NOISE_COLOR = "#fdbf6f" + +LOCALLY_SPARSE_NOISE_4DEG = "locally_sparse_noise_4deg" +LOCALLY_SPARSE_NOISE_4DEG_SHORT = "lsn4" +LOCALLY_SPARSE_NOISE_4DEG_COLOR = "#fdbf6f" + +LOCALLY_SPARSE_NOISE_8DEG = "locally_sparse_noise_8deg" +LOCALLY_SPARSE_NOISE_8DEG_SHORT = "lsn8" +LOCALLY_SPARSE_NOISE_8DEG_COLOR = "#ff7f00" + +SPONTANEOUS_ACTIVITY = "spontaneous" +SPONTANEOUS_ACTIVITY_SHORT = "sp" +SPONTANEOUS_ACTIVITY_COLOR = "#cab2d6" + +# handles for stimulus names +THREE_SESSION_A = "three_session_A" +THREE_SESSION_B = "three_session_B" +THREE_SESSION_C = "three_session_C" +THREE_SESSION_C2 = "three_session_C2" + +SESSION_LIST = [ + THREE_SESSION_A, + THREE_SESSION_B, + THREE_SESSION_C, + THREE_SESSION_C2, +] + +SESSION_STIMULUS_MAP = { + THREE_SESSION_A: [ + DRIFTING_GRATINGS, + NATURAL_MOVIE_ONE, + NATURAL_MOVIE_THREE, + SPONTANEOUS_ACTIVITY, + ], + THREE_SESSION_B: [ + STATIC_GRATINGS, + NATURAL_SCENES, + NATURAL_MOVIE_ONE, + SPONTANEOUS_ACTIVITY, + ], + THREE_SESSION_C: [ + LOCALLY_SPARSE_NOISE, + NATURAL_MOVIE_ONE, + NATURAL_MOVIE_TWO, + SPONTANEOUS_ACTIVITY, + ], + THREE_SESSION_C2: [ + LOCALLY_SPARSE_NOISE_4DEG, + LOCALLY_SPARSE_NOISE_8DEG, + NATURAL_MOVIE_ONE, + NATURAL_MOVIE_TWO, + SPONTANEOUS_ACTIVITY, + ], +} + +LOCALLY_SPARSE_NOISE_STIMULUS_TYPES = [ + LOCALLY_SPARSE_NOISE, + LOCALLY_SPARSE_NOISE_4DEG, + LOCALLY_SPARSE_NOISE_8DEG, +] +NATURAL_MOVIE_STIMULUS_TYPES = [ + NATURAL_MOVIE_ONE, + NATURAL_MOVIE_TWO, + NATURAL_MOVIE_THREE, +] + +LOCALLY_SPARSE_NOISE_DIMENSIONS = { + LOCALLY_SPARSE_NOISE: [16, 28], + LOCALLY_SPARSE_NOISE_4DEG: [16, 28], + LOCALLY_SPARSE_NOISE_8DEG: [8, 14], +} + +LOCALLY_SPARSE_NOISE_PIXELS = { + LOCALLY_SPARSE_NOISE: 45, + LOCALLY_SPARSE_NOISE_4DEG: 45, + LOCALLY_SPARSE_NOISE_8DEG: 90, +} + +NATURAL_SCENES_PIXELS = (918, 1174) +NATURAL_MOVIE_PIXELS = (1080, 1920) +NATURAL_MOVIE_DIMENSIONS = (304, 608) + +MONITOR_DIMENSIONS = (1200, 1920) +MONITOR_DISTANCE = 15 + +STIMULUS_GRAY = 127 +STIMULUS_BITDEPTH = 8 + +# Note: the "8deg" stimulus is actually 9.3 visual degrees on a side +LOCALLY_SPARSE_NOISE_PIXEL_SIZE = { + LOCALLY_SPARSE_NOISE: 4.65, + LOCALLY_SPARSE_NOISE_4DEG: 4.65, + LOCALLY_SPARSE_NOISE_8DEG: 9.3, +} + +RADIANS_TO_DEGREES = 57.2958 + + +def sessions_with_stimulus(stimulus): + """Return the names of the sessions that contain a given stimulus.""" + + sessions = set() + for session, session_stimuli in six.iteritems(SESSION_STIMULUS_MAP): + if stimulus in session_stimuli: + sessions.add(session) + + return sorted(list(sessions)) + + +def stimuli_in_session(session, allow_unknown=True): + """Return a list what stimuli are available in a given session. + + Parameters + ---------- + session: string + Must be one of: [ + stimulus_info.THREE_SESSION_A, + stimulus_info.THREE_SESSION_B, + stimulus_info.THREE_SESSION_C, + stimulus_info.THREE_SESSION_C2 + ] + """ + try: + return SESSION_STIMULUS_MAP[session] + except KeyError as e: + if allow_unknown: + return [] + else: + raise e + + +def all_stimuli(): + """Return a list of all stimuli in the data set""" + return set( + [v for k, vl in six.iteritems(SESSION_STIMULUS_MAP) for v in vl] + ) + + +class BinaryIntervalSearchTree(object): + @staticmethod + def from_df(input_df): + search_list = input_df.to_dict("records") + + new_list = [] + for x in search_list: + if x["start"] == x["end"]: + new_list.append((x["start"], x["end"], x)) + else: + # -.01 prevents endpoint-overlapping intervals; assigns ties to + # intervals that start at requested index + new_list.append((x["start"], x["end"] - 0.01, x)) + return BinaryIntervalSearchTree(new_list) + + def __init__(self, search_list): + """Create a binary tree to search for a point within a list of + intervals. Assumes that the intervals are non-overlapping. If two + intervals share an endpoint, the left-side wins the tie. + + :param search_list: list of interval tuples; in the tuple, first + element is interval start, then interval end (inclusive), then the + return value for the lookup + + Example: + bist = BinaryIntervalSearchTree([(0,.5,'A'), (1,2,'B')]) + print(bist.search(1.5)) + """ + + # Double-check that the list is sorted + search_list = sorted(search_list, key=lambda x: x[0]) + + # Check that the intervals are non-overlapping (except potentially at + # the end point) + for x, y in zip(search_list[:-1], search_list[1:]): + assert x[1] <= y[0] + + self.data = {} + self.add(search_list) + + def add(self, input_list, tmp=None): + if tmp is None: + tmp = [] + + if len(input_list) == 1: + self.data[tuple(tmp)] = input_list[0] + else: + self.add(input_list[: int(len(input_list) / 2)], tmp=tmp + [0]) + self.add(input_list[int(len(input_list) / 2) :], tmp=tmp + [1]) + self.data[tuple(tmp)] = input_list[int(len(input_list) / 2) - 1] + + def search(self, fi, tmp=None): + if tmp is None: + tmp = [] + + if (self.data[tuple(tmp)][0] <= fi) and ( + fi <= self.data[tuple(tmp)][1] + ): + return_val = self.data[tuple(tmp)] + elif fi < self.data[tuple(tmp)][1]: + return_val = self.search(fi, tmp=tmp + [0]) + else: + return_val = self.search(fi, tmp=tmp + [1]) + + assert (return_val[0] <= fi) and (fi <= return_val[1]) + return return_val + + +class StimulusSearch(object): + def __init__(self, nwb_dataset): + self.nwb_data = nwb_dataset + self.epoch_df = nwb_dataset.get_stimulus_epoch_table() + self.master_df = nwb_dataset.get_stimulus_table("master") + self.epoch_bst = BinaryIntervalSearchTree.from_df(self.epoch_df) + self.master_bst = BinaryIntervalSearchTree.from_df(self.master_df) + + def search(self, fi): + try: + # Look in fine-grain tree: + search_result = self.master_bst.search(fi) + return search_result + except KeyError: + # Current frame not found in a fine-grain interval; + # see if it is unregistered to a coarse-grain epoch: + try: + # THis will thow KeyError if not in coarse-grain epoch + self.epoch_bst.search(fi) + + # Frame is in a coarse-grain epoch, but not a fine grain + # interval; look backwards to find most recent find nearest + # matching interval + if fi < self.epoch_df.iloc[0]["start"]: + return None + else: + return self.search(fi - 1) + + except KeyError: + # Frame is unregistered at the coarse level; return None + return None + + +def rotate(X, Y, theta): + x = np.array([X, Y]) + M = np.array( + [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] + ) + if len(x.shape) in [1, 2]: + assert x.shape[0] == 2 + return M.dot(x) + elif len(x.shape) == 3: + M2 = M[:, :, np.newaxis, np.newaxis] + x2 = x[np.newaxis, :, :] + return (M2 * x2).sum(axis=1) + else: + raise NotImplementedError + + +def get_spatial_grating( + height=None, + aspect_ratio=None, + ori=None, + pix_per_cycle=None, + phase=None, + p2p_amp=2, + baseline=0, +): + aspect_ratio = float(aspect_ratio) + _height_prime = 100 + + sf = 1.0 / (float(pix_per_cycle) / (height / float(_height_prime))) + + # Final height set by zoom below: + y, x = (_height_prime, _height_prime * aspect_ratio) + + theta = ori * np.pi / 180.0 # convert to radians + + ph = phase * np.pi * 2.0 + + X, Y = np.meshgrid(np.arange(x), np.arange(y)) + X = X - x / 2 + Y = Y - y / 2 + Xp, Yp = rotate(X, Y, theta) + + img = np.cos(2.0 * np.pi * Xp * sf + ph) + + return (p2p_amp / 2.0) * spndi.zoom( + img, height / float(_height_prime) + ) + baseline + + +# def grating_to_screen(self, phase, spatial_frequency, orientation, **kwargs): + + +def get_spatio_temporal_grating(t, temporal_frequency=None, **kwargs): + kwargs["phase"] = ( + kwargs.pop("phase", 0) + (float(t) * temporal_frequency) % 1 + ) + + return get_spatial_grating(**kwargs) + + +def map_template_coordinate_to_monitor_coordinate( + template_coord, monitor_shape, template_shape +): + rx, cx = template_coord + n_pixels_r, n_pixels_c = monitor_shape + tr, tc = template_shape + + rx_new = float((n_pixels_r - tr) / 2) + rx + cx_new = float((n_pixels_c - tc) / 2) + cx + + return rx_new, cx_new + + +def map_monitor_coordinate_to_template_coordinate( + monitor_coord, monitor_shape, template_shape +): + rx, cx = monitor_coord + n_pixels_r, n_pixels_c = monitor_shape + tr, tc = template_shape + + rx_new = rx - float((n_pixels_r - tr) / 2) + cx_new = cx - float((n_pixels_c - tc) / 2) + + return rx_new, cx_new + + +def lsn_coordinate_to_monitor_coordinate( + lsn_coordinate, monitor_shape, stimulus_type +): + template_shape = LOCALLY_SPARSE_NOISE_DIMENSIONS[stimulus_type] + pixels_per_patch = LOCALLY_SPARSE_NOISE_PIXELS[stimulus_type] + + rx, cx = lsn_coordinate + tr, tc = template_shape + + return map_template_coordinate_to_monitor_coordinate( + (rx * pixels_per_patch, cx * pixels_per_patch), + monitor_shape, + (tr * pixels_per_patch, tc * pixels_per_patch), + ) + + +def monitor_coordinate_to_lsn_coordinate( + monitor_coordinate, monitor_shape, stimulus_type +): + pixels_per_patch = LOCALLY_SPARSE_NOISE_PIXELS[stimulus_type] + tr, tc = LOCALLY_SPARSE_NOISE_DIMENSIONS[stimulus_type] + + rx, cx = map_monitor_coordinate_to_template_coordinate( + monitor_coordinate, + monitor_shape, + (tr * pixels_per_patch, tc * pixels_per_patch), + ) + + return (rx / pixels_per_patch, cx / pixels_per_patch) + + +def natural_scene_coordinate_to_monitor_coordinate( + natural_scene_coordinate, monitor_shape +): + return map_template_coordinate_to_monitor_coordinate( + natural_scene_coordinate, monitor_shape, NATURAL_SCENES_PIXELS + ) + + +def natural_movie_coordinate_to_monitor_coordinate( + natural_movie_coordinate, monitor_shape +): + local_y = ( + 1.0 + * NATURAL_MOVIE_PIXELS[0] + * natural_movie_coordinate[0] + / NATURAL_MOVIE_DIMENSIONS[0] + ) + local_x = ( + 1.0 + * NATURAL_MOVIE_PIXELS[1] + * natural_movie_coordinate[1] + / NATURAL_MOVIE_DIMENSIONS[1] + ) + + return map_template_coordinate_to_monitor_coordinate( + (local_y, local_x), monitor_shape, NATURAL_MOVIE_PIXELS + ) + + +def map_stimulus_coordinate_to_monitor_coordinate( + template_coordinate, monitor_shape, stimulus_type +): + if stimulus_type in LOCALLY_SPARSE_NOISE_STIMULUS_TYPES: + return lsn_coordinate_to_monitor_coordinate( + template_coordinate, monitor_shape, stimulus_type + ) + elif stimulus_type in NATURAL_MOVIE_STIMULUS_TYPES: + return natural_movie_coordinate_to_monitor_coordinate( + template_coordinate, monitor_shape + ) + elif stimulus_type == NATURAL_SCENES: + return natural_scene_coordinate_to_monitor_coordinate( + template_coordinate, monitor_shape + ) + elif stimulus_type in [ + DRIFTING_GRATINGS, + STATIC_GRATINGS, + SPONTANEOUS_ACTIVITY, + ]: + return template_coordinate + else: + raise NotImplementedError # pragma: no cover + + +def monitor_coordinate_to_natural_movie_coordinate( + monitor_coordinate, monitor_shape +): + local_y, local_x = map_monitor_coordinate_to_template_coordinate( + monitor_coordinate, monitor_shape, NATURAL_MOVIE_PIXELS + ) + + return ( + float(NATURAL_MOVIE_DIMENSIONS[0]) * local_y / NATURAL_MOVIE_PIXELS[0], + float(NATURAL_MOVIE_DIMENSIONS[1]) * local_x / NATURAL_MOVIE_PIXELS[1], + ) + + +def map_monitor_coordinate_to_stimulus_coordinate( + monitor_coordinate, monitor_shape, stimulus_type +): + if stimulus_type in LOCALLY_SPARSE_NOISE_STIMULUS_TYPES: + return monitor_coordinate_to_lsn_coordinate( + monitor_coordinate, monitor_shape, stimulus_type + ) + elif stimulus_type == NATURAL_SCENES: + return map_monitor_coordinate_to_template_coordinate( + monitor_coordinate, monitor_shape, NATURAL_SCENES_PIXELS + ) + elif stimulus_type in NATURAL_MOVIE_STIMULUS_TYPES: + return monitor_coordinate_to_natural_movie_coordinate( + monitor_coordinate, monitor_shape + ) + elif stimulus_type in [ + DRIFTING_GRATINGS, + STATIC_GRATINGS, + SPONTANEOUS_ACTIVITY, + ]: + return monitor_coordinate + else: + raise NotImplementedError # pragma: no cover + + +def map_stimulus( + source_stimulus_coordinate, + source_stimulus_type, + target_stimulus_type, + monitor_shape, +): + mc = map_stimulus_coordinate_to_monitor_coordinate( + source_stimulus_coordinate, monitor_shape, source_stimulus_type + ) + return map_monitor_coordinate_to_stimulus_coordinate( + mc, monitor_shape, target_stimulus_type + ) + + +def translate_image_and_fill(img, translation=(0, 0)): + # first coordinate is horizontal, second is vertical + + roll = (int(translation[0]), -int(translation[1])) + + im2 = np.roll(img, roll, (1, 0)) + + if roll[1] >= 0: + im2[: roll[1], :] = STIMULUS_GRAY + else: + im2[roll[1] :, :] = STIMULUS_GRAY + + if roll[0] >= 0: + im2[:, : roll[0]] = STIMULUS_GRAY + else: + im2[:, roll[0] :] = STIMULUS_GRAY + + return im2 + + +class Monitor(object): + def __init__(self, n_pixels_r, n_pixels_c, panel_size, spatial_unit): + self.spatial_unit = spatial_unit + if spatial_unit == "cm": + self.spatial_conversion_factor = 1.0 + else: + raise NotImplementedError # pragma: no cover + + self._panel_size = panel_size + self.n_pixels_r = n_pixels_r + self.n_pixels_c = n_pixels_c + self._mask = None + + @property + def mask(self): + if self._mask is None: + self._mask = self.get_mask() + return self._mask + + @property + def panel_size(self): + return self._panel_size * self.spatial_conversion_factor + + @property + def aspect_ratio(self): + return float(self.n_pixels_c) / self.n_pixels_r + + @property + def height(self): + return self.spatial_conversion_factor * np.sqrt( + self.panel_size**2 / (1 + self.aspect_ratio**2) + ) + + @property + def width(self): + return self.height * self.aspect_ratio + + def set_spatial_unit(self, new_unit): + if new_unit == self.spatial_unit: + pass + elif new_unit == "inch" and self.spatial_unit == "cm": + self.spatial_conversion_factor *= 0.393701 + elif new_unit == "cm" and self.spatial_unit == "inch": + self.spatial_conversion_factor *= 1.0 / 0.393701 + else: + raise NotImplementedError # pragma: no cover + self.spatial_unit = new_unit + + @property + def pixel_size(self): + return float(self.width) / self.n_pixels_c + + def pixels_to_visual_degrees( + self, n, distance_from_monitor, small_angle_approximation=True + ): + if small_angle_approximation: + return ( + n + * self.pixel_size + / distance_from_monitor + * RADIANS_TO_DEGREES + ) # radians to degrees + else: + return ( + 2 + * np.arctan( + n * 1.0 / 2 * self.pixel_size / distance_from_monitor + ) + * RADIANS_TO_DEGREES + ) # radians to degrees + + def visual_degrees_to_pixels( + self, vd, distance_from_monitor, small_angle_approximation=True + ): + if small_angle_approximation: + return vd * ( + distance_from_monitor / self.pixel_size / RADIANS_TO_DEGREES + ) + else: + raise NotImplementedError + + def lsn_image_to_screen( + self, + img, + stimulus_type, + origin="lower", + background_color=STIMULUS_GRAY, + translation=(0, 0), + ): + # assert img.dtype == np.uint8 + + full_image = np.full( + (self.n_pixels_r, self.n_pixels_c), + background_color, + dtype=np.uint8, + ) + + pixels_per_patch = float(LOCALLY_SPARSE_NOISE_PIXELS[stimulus_type]) + target_size = tuple( + int(pixels_per_patch * dimsize) for dimsize in img.shape[::-1] + ) + img_full_res = np.array( + Image.fromarray(img).resize(target_size, 0) + ) # 0 -> nearest neighbor interpolator + + mr, mc = lsn_coordinate_to_monitor_coordinate( + (0, 0), (self.n_pixels_r, self.n_pixels_c), stimulus_type + ) + Mr, Mc = lsn_coordinate_to_monitor_coordinate( + img.shape, (self.n_pixels_r, self.n_pixels_c), stimulus_type + ) + full_image[int(mr) : int(Mr), int(mc) : int(Mc)] = img_full_res + + full_image = translate_image_and_fill( + full_image, translation=translation + ) + + if origin == "lower": + return full_image + elif origin == "upper": + return np.flipud(full_image) + else: + raise Exception + + return full_image + + def natural_scene_image_to_screen( + self, img, origin="lower", translation=(0, 0) + ): + full_image = np.full( + (self.n_pixels_r, self.n_pixels_c), 127, dtype=np.uint8 + ) + mr, mc = natural_scene_coordinate_to_monitor_coordinate( + (0, 0), (self.n_pixels_r, self.n_pixels_c) + ) + Mr, Mc = natural_scene_coordinate_to_monitor_coordinate( + (img.shape[0], img.shape[1]), (self.n_pixels_r, self.n_pixels_c) + ) + full_image[int(mr) : int(Mr), int(mc) : int(Mc)] = img + + full_image = translate_image_and_fill( + full_image, translation=translation + ) + + if origin == "lower": + return np.flipud(full_image) + elif origin == "upper": + return full_image + else: + raise Exception + + def natural_movie_image_to_screen( + self, img, origin="lower", translation=(0, 0) + ): + img = np.array( + Image.fromarray(img).resize(NATURAL_MOVIE_PIXELS[::-1], 2) + ).astype( + np.uint8 + ) # 2 -> bilinear interpolator + + assert img.dtype == np.uint8 + + full_image = np.full( + (self.n_pixels_r, self.n_pixels_c), 127, dtype=np.uint8 + ) + mr, mc = map_template_coordinate_to_monitor_coordinate( + (0, 0), (self.n_pixels_r, self.n_pixels_c), NATURAL_MOVIE_PIXELS + ) + Mr, Mc = map_template_coordinate_to_monitor_coordinate( + (img.shape[0], img.shape[1]), + (self.n_pixels_r, self.n_pixels_c), + NATURAL_MOVIE_PIXELS, + ) + + full_image[int(mr) : int(Mr), int(mc) : int(Mc)] = img + + full_image = translate_image_and_fill( + full_image, translation=translation + ) + + if origin == "lower": + return np.flipud(full_image) + elif origin == "upper": + return full_image + else: + raise Exception + + def spatial_frequency_to_pix_per_cycle( + self, spatial_frequency, distance_from_monitor + ): + # How many cycles do I want to see post warp: + number_of_cycles = ( + spatial_frequency + * 2 + * np.degrees(np.arctan(self.width / 2.0 / distance_from_monitor)) + ) + + # How many pixels to I have pre-warp to place my cycles on: + _, m_col = np.where(self.mask != 0) + number_of_pixels = m_col.max() - m_col.min() + + return float(number_of_pixels) / number_of_cycles + + def grating_to_screen( + self, + phase, + spatial_frequency, + orientation, + distance_from_monitor, + p2p_amp=256, + baseline=127, + translation=(0, 0), + ): + pix_per_cycle = self.spatial_frequency_to_pix_per_cycle( + spatial_frequency, distance_from_monitor + ) + + full_image = get_spatial_grating( + height=self.n_pixels_r, + aspect_ratio=self.aspect_ratio, + ori=orientation, + pix_per_cycle=pix_per_cycle, + phase=phase, + p2p_amp=p2p_amp, + baseline=baseline, + ) + + full_image = translate_image_and_fill( + full_image, translation=translation + ) + + return full_image + + def get_mask(self): + mask = make_display_mask( + display_shape=(self.n_pixels_c, self.n_pixels_r) + ).T + assert mask.shape[0] == self.n_pixels_r + assert mask.shape[1] == self.n_pixels_c + + return mask + + def show_image( + self, img, ax=None, show=True, mask=False, warp=False, origin="lower" + ): + import matplotlib.pyplot as plt + + assert img.shape == ( + self.n_pixels_r, + self.n_pixels_c, + ) or img.shape == (self.n_pixels_r, self.n_pixels_c, 4) + + if ax is None: + fig, ax = plt.subplots(1, 1) + + if warp: + img = self.warp_image(img) + + if warp: + assert mask is False + + ax.imshow(img, origin=origin, cmap=plt.cm.gray, interpolation="none") + + if mask: + mask = make_display_mask( + display_shape=(self.n_pixels_c, self.n_pixels_r) + ).T + alpha_mask = np.zeros((mask.shape[0], mask.shape[1], 4)) + alpha_mask[:, :, 2] = 1 - mask + alpha_mask[:, :, 3] = 0.4 + ax.imshow(alpha_mask, origin=origin, interpolation="none") + + ax.axes.get_xaxis().set_visible(False) + ax.axes.get_yaxis().set_visible(False) + + if origin == "upper": + ax.set_ylim((img.shape[0], 0)) + elif origin == "lower": + ax.set_ylim((0, img.shape[0])) + else: + raise Exception + ax.set_xlim((0, img.shape[1])) + + if show: + plt.show() + + def map_stimulus( + self, + source_stimulus_coordinate, + source_stimulus_type, + target_stimulus_type, + ): + monitor_shape = (self.n_pixels_r, self.n_pixels_c) + return map_stimulus( + source_stimulus_coordinate, + source_stimulus_type, + target_stimulus_type, + monitor_shape, + ) + + +class ExperimentGeometry(object): + def __init__( + self, distance, mon_height_cm, mon_width_cm, mon_res, eyepoint + ): + self.distance = distance + self.mon_height_cm = mon_height_cm + self.mon_width_cm = mon_width_cm + self.mon_res = mon_res + self.eyepoint = eyepoint + + self._warp_coordinates = None + + @property + def warp_coordinates(self): + if self._warp_coordinates is None: + self._warp_coordinates = self.generate_warp_coordinates() + + return self._warp_coordinates + + def generate_warp_coordinates(self): + display_shape = self.mon_res + x = np.array(range(display_shape[0])) - display_shape[0] / 2 + y = np.array(range(display_shape[1])) - display_shape[1] / 2 + display_coords = np.array(list(itertools.product(y, x))) + + warp_coorinates = warp_stimulus_coords( + display_coords, + distance=self.distance, + mon_height_cm=self.mon_height_cm, + mon_width_cm=self.mon_width_cm, + mon_res=self.mon_res, + eyepoint=self.eyepoint, + ) + + warp_coorinates[:, 0] += display_shape[1] / 2 + warp_coorinates[:, 1] += display_shape[0] / 2 + + return warp_coorinates + + +class BrainObservatoryMonitor(Monitor): + """ + http://help.brain-map.org/display/observatory/Documentation?preview=/10616846/10813485/VisualCoding_VisualStimuli.pdf # noqa: E501 + https://www.cnet.com/products/asus-pa248q/specs/ + """ + + def __init__(self, experiment_geometry=None): + height, width = MONITOR_DIMENSIONS + + super(BrainObservatoryMonitor, self).__init__( + height, width, 61.214, "cm" + ) + + if experiment_geometry is None: + self.experiment_geometry = ExperimentGeometry( + distance=float(MONITOR_DISTANCE), + mon_height_cm=self.height, + mon_width_cm=self.width, + mon_res=(self.n_pixels_c, self.n_pixels_r), + eyepoint=(0.5, 0.5), + ) + else: + self.experiment_geometry = experiment_geometry + + def lsn_image_to_screen(self, img, **kwargs): + if img.shape == tuple( + LOCALLY_SPARSE_NOISE_DIMENSIONS[LOCALLY_SPARSE_NOISE] + ): + return super(BrainObservatoryMonitor, self).lsn_image_to_screen( + img, LOCALLY_SPARSE_NOISE, **kwargs + ) + elif img.shape == tuple( + LOCALLY_SPARSE_NOISE_DIMENSIONS[LOCALLY_SPARSE_NOISE_4DEG] + ): + return super(BrainObservatoryMonitor, self).lsn_image_to_screen( + img, LOCALLY_SPARSE_NOISE_4DEG, **kwargs + ) + elif img.shape == tuple( + LOCALLY_SPARSE_NOISE_DIMENSIONS[LOCALLY_SPARSE_NOISE_8DEG] + ): + return super(BrainObservatoryMonitor, self).lsn_image_to_screen( + img, LOCALLY_SPARSE_NOISE_8DEG, **kwargs + ) + else: # pragma: no cover + raise RuntimeError # pragma: no cover + + def warp_image(self, img, **kwargs): + assert img.shape == (self.n_pixels_r, self.n_pixels_c) + assert self.spatial_unit == "cm" + + return spndi.map_coordinates( + img, self.experiment_geometry.warp_coordinates.T + ).reshape((self.n_pixels_r, self.n_pixels_c)) + + def grating_to_screen( + self, phase, spatial_frequency, orientation, **kwargs + ): + return super(BrainObservatoryMonitor, self).grating_to_screen( + phase, + spatial_frequency, + orientation, + self.experiment_geometry.distance, + p2p_amp=256, + baseline=127, + **kwargs, + ) + + def pixels_to_visual_degrees(self, n, **kwargs): + return super(BrainObservatoryMonitor, self).pixels_to_visual_degrees( + n, self.experiment_geometry.distance, **kwargs + ) + + def visual_degrees_to_pixels(self, vd, **kwargs): + return super(BrainObservatoryMonitor, self).visual_degrees_to_pixels( + vd, self.experiment_geometry.distance, **kwargs + ) + + +def warp_stimulus_coords( + vertices, + distance=15.0, + mon_height_cm=32.5, + mon_width_cm=51.0, + mon_res=(1920, 1200), + eyepoint=(0.5, 0.5), +): + """ + For a list of screen vertices, provides a corresponding list of texture + coordinates. + + Parameters + ---------- + vertices: numpy.ndarray + [[x0,y0], [x1,y1], ...] A set of vertices to convert to texture + positions. + distance: float + distance from the monitor in cm. + mon_height_cm: float + monitor height in cm + mon_width_cm: float + monitor width in cm + mon_res: tuple + monitor resolution (x,y) + eyepoint: tuple + + Returns + ------- + np.ndarray + x,y coordinates shaped like the input that describe what pixel + coordinates are displayed an the input coordinates after warping the + stimulus. + + """ + + mon_width_cm = float(mon_width_cm) + mon_height_cm = float(mon_height_cm) + distance = float(distance) + mon_res_x, mon_res_y = float(mon_res[0]), float(mon_res[1]) + + vertices = vertices.astype("float") + + # from pixels (-1920/2 -> 1920/2) to stimulus space (-0.5->0.5) + vertices[:, 0] = vertices[:, 0] / mon_res_x + vertices[:, 1] = vertices[:, 1] / mon_res_y + + x = (vertices[:, 0] + 0.5) * mon_width_cm + y = (vertices[:, 1] + 0.5) * mon_height_cm + + xEye = eyepoint[0] * mon_width_cm + yEye = eyepoint[1] * mon_height_cm + + x = x - xEye + y = y - yEye + + r = np.sqrt(np.square(x) + np.square(y) + np.square(distance)) + + azimuth = np.arctan(x / distance) + altitude = np.arcsin(y / r) + + # calculate the texture coordinates + tx = distance * (1 + x / r) - distance + ty = distance * (1 + y / r) - distance + + # prevent div0 + azimuth[azimuth == 0] = np.finfo(np.float32).eps + altitude[altitude == 0] = np.finfo(np.float32).eps + + # the texture coordinates (which are now lying on the sphere) + # need to be remapped back onto the plane of the display. + # This effectively stretches the coordinates away from the eyepoint. + + centralAngle = np.arccos(np.cos(altitude) * np.cos(np.abs(azimuth))) + # distance from eyepoint to texture vertex + arcLength = centralAngle * distance + # remap the texture coordinate + theta = np.arctan2(ty, tx) + tx = arcLength * np.cos(theta) + ty = arcLength * np.sin(theta) + + u_coords = tx / mon_width_cm + v_coords = ty / mon_height_cm + + retCoords = np.column_stack((u_coords, v_coords)) + + # back to pixels + retCoords[:, 0] = retCoords[:, 0] * mon_res_x + retCoords[:, 1] = retCoords[:, 1] * mon_res_y + + return retCoords + + +def make_display_mask(display_shape=(1920, 1200)): + """Build a display-shaped mask that indicates which pixels are on screen + after warping the stimulus. + """ + x = np.array(range(display_shape[0])) - display_shape[0] / 2 + y = np.array(range(display_shape[1])) - display_shape[1] / 2 + display_coords = np.array(list(itertools.product(x, y))) + + warped_coords = warp_stimulus_coords(display_coords).astype(int) + + off_warped_coords = np.array( + [ + warped_coords[:, 0] + display_shape[0] / 2, + warped_coords[:, 1] + display_shape[1] / 2, + ] + ) + + used_coords = set() + for i in range(off_warped_coords.shape[1]): + used_coords.add((off_warped_coords[0, i], off_warped_coords[1, i])) + + used_coords = ( + np.array([x for (x, y) in used_coords]).astype(int), + np.array([y for (x, y) in used_coords]).astype(int), + ) + + mask = np.zeros(display_shape) + + mask[used_coords] = 1 + + return mask + + +def mask_stimulus_template( + template_display_coords, template_shape, display_mask=None, threshold=1.0 +): + """Build a mask for a stimulus template of a given shape and display + coordinates that indicates which part of the template is on screen after + warping. + + Parameters + ---------- + template_display_coords: list + list of (x,y) display coordinates + + template_shape: tuple + (width,height) of the display template + + display_mask: np.ndarray + boolean 2D mask indicating which display coordinates are on screen + after warping. + + threshold: float + Fraction of pixels associated with a template display coordinate that + should remain on screen to count as belonging to the mask. + + Returns + ------- + tuple: (template mask, pixel fraction) + """ + if display_mask is None: + display_mask = make_display_mask() + + frac = np.zeros(template_shape) + mask = np.zeros(template_shape, dtype=bool) + for y in range(template_shape[1]): + for x in range(template_shape[0]): + tdcm = np.where( + (template_display_coords[0, :, :] == x) + & (template_display_coords[1, :, :] == y) + ) + v = display_mask[tdcm] + f = np.sum(v) / len(v) + frac[x, y] = f + mask[x, y] = f >= threshold + + return mask, frac diff --git a/bmtk/utils/brain_observatory/utils.py b/bmtk/utils/brain_observatory/utils.py new file mode 100644 index 00000000..e220261b --- /dev/null +++ b/bmtk/utils/brain_observatory/utils.py @@ -0,0 +1,51 @@ +import ast +import numpy as np + + +class FakeTqdm: + def __init__(self, *args, **kwargs): + pass + + def update(self, *args, **kwargs): + pass + + +def write_from_stream(path: str, stream): + with open(path, "wb") as fil: + for chunk in stream: + fil.write(chunk) + + +def json_handler(obj): + """Used by write_json convert a few non-standard types to things that the + json package can handle.""" + if hasattr(obj, "to_dict"): + return obj.to_dict() + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, np.floating): + return float(obj) + elif isinstance(obj, np.integer): + return int(obj) + elif isinstance(obj, bool) or isinstance(obj, np.bool_): + return bool(obj) + elif hasattr(obj, "isoformat"): + return obj.isoformat() + else: + raise TypeError( + "Object of type %s with value of %s is not JSON serializable" + % (type(obj), repr(obj)) + ) + + +def infer_column_types(dataframe): + dataframe = dataframe.copy() + + for colname in dataframe.columns: + try: + dataframe[colname] = dataframe[colname].apply(ast.literal_eval) + except (ValueError, SyntaxError): + continue + + dataframe = dataframe.infer_objects() + return dataframe \ No newline at end of file diff --git a/bmtk/utils/sonata/population.py b/bmtk/utils/sonata/population.py index e9e5c07c..a439d666 100644 --- a/bmtk/utils/sonata/population.py +++ b/bmtk/utils/sonata/population.py @@ -242,7 +242,7 @@ def to_dataframe(self, index_by_id=True): else: ret_df = pd.DataFrame() for grp_id in self.group_ids: - ret_df = ret_df.append(self.get_group(grp_id).to_dataframe(), sort=False) + ret_df = pd.concat([ret_df, self.get_group(grp_id).to_dataframe()], sort=False) if index_by_id: ret_df = ret_df.set_index('node_id')