Source code for neurotic.datasets.data

# -*- coding: utf-8 -*-
"""
The :mod:`neurotic.datasets.data` module implements a function for loading a
dataset from selected metadata.

.. autofunction:: load_dataset
"""

import inspect
from packaging import version
import numpy as np
import pandas as pd
import quantities as pq
import neo

from ..datasets.metadata import _abs_path
from ..elephant_tools import _butter, _isi, _peak_detection

def load_dataset(metadata, lazy=False, signal_group_mode='split-all', filter_events_from_epochs=False):
    """
    Load a dataset.

    ``metadata`` may be a :class:`MetadataSelector
    <neurotic.datasets.metadata.MetadataSelector>` or a simple dictionary
    containing the appropriate data.

    The ``data_file`` in ``metadata`` is read into a Neo :class:`Block
    <neo.core.Block>` using an automatically detected :mod:`neo.io` class if
    ``lazy=False`` or a :mod:`neo.rawio` class if ``lazy=True``. Epochs and
    events loaded from ``annotations_file`` and ``epoch_encoder_file`` and
    spike trains loaded from ``tridesclous_file`` are added to the Neo Block.

    If ``lazy=False``, filters given in ``metadata`` are applied to the
    signals and amplitude discriminators are run to detect spikes.
    """

    # read in the electrophysiology data
    blk = _read_data_file(metadata, lazy, signal_group_mode)

    # apply filters to signals if not using lazy loading of signals
    if not lazy:
        blk = _apply_filters(metadata, blk)

    # copy events into epochs and vice versa
    epochs_from_events = [neo.Epoch(name=ev.name, times=ev.times, labels=ev.labels, durations=np.zeros_like(ev.times)) for ev in blk.segments[0].events]
    events_from_epochs = [neo.Event(name=ep.name, times=ep.times, labels=ep.labels) for ep in blk.segments[0].epochs]
    if not filter_events_from_epochs:
        blk.segments[0].epochs += epochs_from_events
    blk.segments[0].events += events_from_epochs

    # read in annotations
    annotations_dataframe = _read_annotations_file(metadata)
    blk.segments[0].epochs += _create_neo_epochs_from_dataframe(annotations_dataframe, metadata, _abs_path(metadata, 'annotations_file'), filter_events_from_epochs)
    blk.segments[0].events += _create_neo_events_from_dataframe(annotations_dataframe, metadata, _abs_path(metadata, 'annotations_file'))

    # read in epoch encoder file
    epoch_encoder_dataframe = _read_epoch_encoder_file(metadata)
    blk.segments[0].epochs += _create_neo_epochs_from_dataframe(epoch_encoder_dataframe, metadata, _abs_path(metadata, 'epoch_encoder_file'), filter_events_from_epochs)
    blk.segments[0].events += _create_neo_events_from_dataframe(epoch_encoder_dataframe, metadata, _abs_path(metadata, 'epoch_encoder_file'))

    # classify spikes by amplitude if not using lazy loading of signals
    if not lazy:
        blk.segments[0].spiketrains += _run_amplitude_discriminators(metadata, blk)

    # read in spikes identified by spike sorting using tridesclous
    t_start = blk.segments[0].analogsignals[0].t_start
    t_stop = blk.segments[0].analogsignals[0].t_stop
    sampling_period = blk.segments[0].analogsignals[0].sampling_period
    spikes_dataframe = _read_spikes_file(metadata, blk)
    blk.segments[0].spiketrains += _create_neo_spike_trains_from_dataframe(spikes_dataframe, metadata, t_start, t_stop, sampling_period)

    # identify bursts from spike trains if not using lazy loading of signals
    if not lazy:
        blk.segments[0].epochs += _run_burst_detectors(metadata, blk)

    # alphabetize epoch and event channels by name
    blk.segments[0].epochs.sort(key=lambda ep: ep.name)
    blk.segments[0].events.sort(key=lambda ev: ev.name)

    return blk

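# Illustrative usage sketch (not part of the original module): load_dataset()
# can be driven by a plain dictionary instead of a MetadataSelector. The file
# name and values below are hypothetical placeholders; only keys referenced in
# this module are shown, and unused features are set to None. A real metadata
# dictionary may need additional keys (e.g. for path resolution by _abs_path).
#
#     metadata = {
#         'data_file': 'example.abf',          # hypothetical recording file
#         'io_class': None,                    # let Neo pick the IO class
#         'io_args': None,
#         'annotations_file': None,
#         'epoch_encoder_file': None,
#         'tridesclous_file': None,
#         'tridesclous_channels': None,
#         'tridesclous_merge': None,
#         'filters': None,
#         'amplitude_discriminators': None,
#         'burst_detectors': None,
#     }
#     blk = load_dataset(metadata, lazy=False)
#     seg = blk.segments[0]
#     print(seg.analogsignals, seg.epochs, seg.events, seg.spiketrains)
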
def _get_io(metadata):
    """
    Return a :mod:`neo.io` object for reading the ``data_file`` in
    ``metadata``.

    An appropriate :mod:`neo.io` class is typically determined automatically
    from the file extension, but this can be overridden with the optional
    ``io_class`` metadata parameter. Arbitrary arguments can be passed to the
    :mod:`neo.io` class using the optional ``io_args`` metadata parameter.
    """

    # prepare arguments for instantiating a Neo IO class
    if metadata['io_args'] is not None:
        io_args = metadata['io_args'].copy()
        if 'sampling_rate' in io_args:
            # AsciiSignalIO's sampling_rate must be a Quantity
            io_args['sampling_rate'] *= pq.Hz
    else:
        io_args = {}

    if metadata['io_class'] is None:
        try:
            # detect the class automatically using the file extension
            io = neo.io.get_io(_abs_path(metadata, 'data_file'), **io_args)
        except IOError as e:
            if len(e.args) > 0 and type(e.args[0]) is str and e.args[0].startswith('File extension'):
                # provide a useful error message when format detection fails
                raise IOError("Could not find an appropriate neo.io class "
                              f"for data_file \"{metadata['data_file']}\". "
                              "Try specifying one in your metadata using "
                              "the io_class parameter.")
            else:
                # something else has gone wrong, like the file not being found
                raise e
    else:
        # use a user-specified class
        io_list = [io.__name__ for io in neo.io.iolist]
        if metadata['io_class'] not in io_list:
            raise ValueError(f"specified io_class \"{metadata['io_class']}\" was not found in neo.io.iolist: {io_list}")
        io_class_index = io_list.index(metadata['io_class'])
        io_class = neo.io.iolist[io_class_index]
        io = io_class(_abs_path(metadata, 'data_file'), **io_args)

    return io

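# Illustrative sketch (assumed values, not from the original source):
# overriding automatic IO detection via the optional io_class and io_args
# metadata parameters, e.g. for a plain-text signal file read with Neo's
# AsciiSignalIO. The file name and sampling rate are hypothetical; note that
# _get_io() converts 'sampling_rate' to a Hz quantity automatically.
#
#     metadata['data_file'] = 'signals.txt'            # hypothetical file
#     metadata['io_class'] = 'AsciiSignalIO'
#     metadata['io_args'] = {'sampling_rate': 1000}    # interpreted as 1000 Hz
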
""" # get a Neo IO object appropriate for the data file type io = _get_io(metadata) # force lazy=False if lazy is not supported by the reader class if lazy and not io.support_lazy: lazy = False print(f'NOTE: Not reading signals in lazy mode because Neo\'s {io.__class__.__name__} reader does not support it.') if 'signal_group_mode' in inspect.signature(io.read_block).parameters.keys(): # - signal_group_mode='split-all' is the default because this ensures # every channel gets its own AnalogSignal, which is important for # indexing in EphyviewerConfigurator blk = io.read_block(lazy=lazy, signal_group_mode=signal_group_mode) else: # some IOs do not have signal_group_mode blk = io.read_block(lazy=lazy) # load all objects except analog signals if lazy: if version.parse(neo.__version__) >= version.parse('0.8.0'): # Neo >= 0.8.0 has proxy objects with load method for i in range(len(blk.segments[0].epochs)): epoch = blk.segments[0].epochs[i] if hasattr(epoch, 'load'): blk.segments[0].epochs[i] = epoch.load() for i in range(len(blk.segments[0].events)): event = blk.segments[0].events[i] if hasattr(event, 'load'): blk.segments[0].events[i] = event.load() for i in range(len(blk.segments[0].spiketrains)): spiketrain = blk.segments[0].spiketrains[i] if hasattr(spiketrain, 'load'): blk.segments[0].spiketrains[i] = spiketrain.load() else: # Neo < 0.8.0 does not have proxy objects neorawioclass = neo.rawio.get_rawio_class(_abs_path(metadata, 'data_file')) if neorawioclass is not None: neorawio = neorawioclass(_abs_path(metadata, 'data_file')) neorawio.parse_header() for i in range(len(blk.segments[0].epochs)): epoch = blk.segments[0].epochs[i] channel_index = next((i for i, chan in enumerate(neorawio.header['event_channels']) if chan['name'] == epoch.name and chan['type'] == b'epoch'), None) if channel_index is not None: ep_raw_times, ep_raw_durations, ep_labels = neorawio.get_event_timestamps(event_channel_index=channel_index) ep_times = neorawio.rescale_event_timestamp(ep_raw_times, dtype='float64') ep_durations = neorawio.rescale_epoch_duration(ep_raw_durations, dtype='float64') ep = neo.Epoch(times=ep_times*pq.s, durations=ep_durations*pq.s, labels=ep_labels, name=epoch.name) blk.segments[0].epochs[i] = ep for i in range(len(blk.segments[0].events)): event = blk.segments[0].events[i] channel_index = next((i for i, chan in enumerate(neorawio.header['event_channels']) if chan['name'] == event.name and chan['type'] == b'event'), None) if channel_index is not None: ev_raw_times, _, ev_labels = neorawio.get_event_timestamps(event_channel_index=channel_index) ev_times = neorawio.rescale_event_timestamp(ev_raw_times, dtype='float64') ev = neo.Event(times=ev_times*pq.s, labels=ev_labels, name=event.name) blk.segments[0].events[i] = ev for i in range(len(blk.segments[0].spiketrains)): spiketrain = blk.segments[0].spiketrains[i] channel_index = next((i for i, chan in enumerate(neorawio.header['unit_channels']) if chan['name'] == spiketrain.name), None) if channel_index is not None: st_raw_times = neorawio.get_spike_timestamps(unit_index=channel_index) st_times = neorawio.rescale_spike_timestamp(st_raw_times, dtype='float64') st = neo.SpikeTrain(times=st_times*pq.s, name=st.name) blk.segments[0].spiketrains[i] = st # convert byte labels to Unicode strings for epoch in blk.segments[0].epochs: epoch.labels = epoch.labels.astype('U') for event in blk.segments[0].events: event.labels = event.labels.astype('U') return blk def _read_annotations_file(metadata): """ Read in epochs and events from the ``annotations_file`` 
in ``metadata`` and return a dataframe. """ if metadata['annotations_file'] is None: return None else: # data types for each column in the file dtypes = { 'Start (s)': float, 'End (s)': float, 'Type': str, 'Label': str, } # parse the file and create a dataframe df = pd.read_csv(_abs_path(metadata, 'annotations_file'), dtype = dtypes) # increment row labels by 2 so they match the source file # which is 1-indexed and has a header df.index += 2 # discard entries with missing or negative start times bad_start = df['Start (s)'].isnull() | (df['Start (s)'] < 0) if bad_start.any(): print('NOTE: These rows will be discarded because their Start times are missing or negative:') print(df[bad_start]) df = df[~bad_start] # discard entries with end time preceding start time bad_end = df['End (s)'] < df['Start (s)'] if bad_end.any(): print('NOTE: These rows will be discarded because their End times precede their Start times:') print(df[bad_end]) df = df[~bad_end] # compute durations df.insert( column = 'Duration (s)', value = df['End (s)'] - df['Start (s)'], loc = 2, # insert after 'End (s)' ) # replace some NaNs df.fillna({ 'Duration (s)': 0, 'Type': 'Other', 'Label': '', }, inplace = True) # sort entries by time df.sort_values([ 'Start (s)', 'Duration (s)', ], inplace = True) # return the dataframe return df def _read_epoch_encoder_file(metadata): """ Read in epochs from the ``epoch_encoder_file`` in ``metadata`` and return a dataframe. """ if metadata['epoch_encoder_file'] is None: return None else: # data types for each column in the file dtypes = { 'Start (s)': float, 'End (s)': float, 'Type': str, } # parse the file and create a dataframe df = pd.read_csv(_abs_path(metadata, 'epoch_encoder_file'), dtype = dtypes) # increment row labels by 2 so they match the source file # which is 1-indexed and has a header df.index += 2 # discard entries with missing or negative start times bad_start = df['Start (s)'].isnull() | (df['Start (s)'] < 0) if bad_start.any(): print('NOTE: These rows will be discarded because their Start times are missing or negative:') print(df[bad_start]) df = df[~bad_start] # discard entries with end time preceding start time bad_end = df['End (s)'] < df['Start (s)'] if bad_end.any(): print('NOTE: These rows will be discarded because their End times precede their Start times:') print(df[bad_end]) df = df[~bad_end] # compute durations df.insert( column = 'Duration (s)', value = df['End (s)'] - df['Start (s)'], loc = 2, # insert after 'End (s)' ) # replace some NaNs df.fillna({ 'Duration (s)': 0, 'Type': 'Other', }, inplace = True) # sort entries by time df.sort_values([ 'Start (s)', 'Duration (s)', ], inplace = True) # add 'Label' column to indicate where these epochs came from df.insert( column = 'Label', value = '(from epoch encoder file)', loc = 4, # insert after 'Type' ) # return the dataframe return df def _read_spikes_file(metadata, blk): """ Read in spikes identified by spike sorting with tridesclous and return a dataframe. 
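# Illustrative example (hypothetical contents, not from the original source)
# of the CSV layouts the two readers above expect. annotations_file has a
# header plus the columns 'Start (s)', 'End (s)', 'Type', 'Label';
# epoch_encoder_file has only 'Start (s)', 'End (s)', 'Type'. Row labels in
# the discard messages match the 1-indexed file including its header, which
# is why the readers apply df.index += 2.
#
#     # annotations.csv
#     # Start (s),End (s),Type,Label
#     # 10.0,12.5,Activity,first bout
#     # 30.0,30.0,Marker,stimulus on
#
#     # epoch_encoder.csv
#     # Start (s),End (s),Type
#     # 5.0,8.0,Swallow
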
""" if metadata['tridesclous_file'] is None or metadata['tridesclous_channels'] is None: return None else: # parse the file and create a dataframe df = pd.read_csv(_abs_path(metadata, 'tridesclous_file'), names = ['index', 'label']) # drop clusters with negative labels df = df[df['label'] >= 0] if metadata['tridesclous_merge']: # merge some clusters and drop all others new_labels = [] for clusters_to_merge in metadata['tridesclous_merge']: new_label = clusters_to_merge[0] new_labels.append(new_label) df.loc[df['label'].isin(clusters_to_merge), 'label'] = new_label df = df[df['label'].isin(new_labels)] # return the dataframe return df def _create_neo_epochs_from_dataframe(dataframe, metadata, file_origin, filter_events_from_epochs=False): """ Convert the contents of a dataframe into Neo :class:`Epochs <neo.core.Epoch>`. """ epochs_list = [] if dataframe is not None: if filter_events_from_epochs: # keep only rows with a positive duration dataframe = dataframe[dataframe['Duration (s)'] > 0] # group epochs by type for type_name, df in dataframe.groupby('Type'): # create a Neo Epoch for each type epoch = neo.Epoch( name = type_name, file_origin = file_origin, times = df['Start (s)'].values * pq.s, durations = df['Duration (s)'].values * pq.s, labels = df['Label'].values, ) epochs_list.append(epoch) # return the list of Neo Epochs return epochs_list def _create_neo_events_from_dataframe(dataframe, metadata, file_origin): """ Convert the contents of a dataframe into Neo :class:`Events <neo.core.Event>`. """ events_list = [] if dataframe is not None: # group events by type for type_name, df in dataframe.groupby('Type'): # create a Neo Event for each type event = neo.Event( name = type_name, file_origin = file_origin, times = df['Start (s)'].values * pq.s, labels = df['Label'].values, ) events_list.append(event) # return the list of Neo Events return events_list def _create_neo_spike_trains_from_dataframe(dataframe, metadata, t_start, t_stop, sampling_period): """ Convert the contents of a dataframe into Neo :class:`SpikeTrains <neo.core.SpikeTrain>`. """ spiketrain_list = [] if dataframe is not None: # group spikes by cluster label for spike_label, df in dataframe.groupby('label'): # look up the channels that this unit was found on channels = metadata['tridesclous_channels'][spike_label] # create a Neo SpikeTrain for each cluster label st = neo.SpikeTrain( name = str(spike_label), file_origin = _abs_path(metadata, 'tridesclous_file'), channels = channels, # custom annotation amplitude = None, # custom annotation times = t_start + sampling_period * df['index'].values, t_start = t_start, t_stop = t_stop, ) spiketrain_list.append(st) return spiketrain_list def _apply_filters(metadata, blk): """ Apply filters specified in ``metadata`` to the signals in ``blk``. 
""" if metadata['filters'] is not None: signalNameToIndex = {sig.name:i for i, sig in enumerate(blk.segments[0].analogsignals)} for sig_filter in metadata['filters']: index = signalNameToIndex.get(sig_filter['channel'], None) if index is None: print('Warning: skipping filter with channel name {} because channel was not found!'.format(sig_filter['channel'])) else: high = sig_filter.get('highpass', None) low = sig_filter.get('lowpass', None) if high: high *= pq.Hz if low: low *= pq.Hz blk.segments[0].analogsignals[index] = _butter( signal = blk.segments[0].analogsignals[index], highpass_freq = high, lowpass_freq = low, ) return blk def _run_amplitude_discriminators(metadata, blk): """ Run all amplitude discriminators for spike detection given in ``metadata`` on the signals in ``blk``. """ spiketrain_list = [] if metadata['amplitude_discriminators'] is not None: signalNameToIndex = {sig.name:i for i, sig in enumerate(blk.segments[0].analogsignals)} epochs = blk.segments[0].epochs # classify spikes by amplitude for discriminator in metadata['amplitude_discriminators']: index = signalNameToIndex.get(discriminator['channel'], None) if index is None: print('Warning: skipping amplitude discriminator with channel name {} because channel was not found!'.format(discriminator['channel'])) else: sig = blk.segments[0].analogsignals[index] st = _detect_spikes(sig, discriminator, epochs) spiketrain_list.append(st) return spiketrain_list def _detect_spikes(sig, discriminator, epochs): """ Detect spikes in the amplitude window given by ``discriminator`` and optionally filter them by coincidence with epochs of a given name. """ assert sig.name == discriminator['channel'], 'sig name "{}" does not match amplitude discriminator channel "{}"'.format(sig.name, discriminator['channel']) min_threshold = min(discriminator['amplitude']) max_threshold = max(discriminator['amplitude']) if min_threshold >= 0 and max_threshold > 0: sign = 'above' elif min_threshold < 0 and max_threshold <= 0: sign = 'below' else: raise ValueError('amplitude discriminator must have two nonnegative thresholds or two nonpositive thresholds: {}'.format(discriminator)) spikes_crossing_min = _peak_detection(sig, pq.Quantity(min_threshold, discriminator['units']), sign, 'raw') spikes_crossing_max = _peak_detection(sig, pq.Quantity(max_threshold, discriminator['units']), sign, 'raw') if sign == 'above': spikes_between_min_and_max = np.setdiff1d(spikes_crossing_min, spikes_crossing_max) elif sign == 'below': spikes_between_min_and_max = np.setdiff1d(spikes_crossing_max, spikes_crossing_min) else: raise ValueError('sign should be "above" or "below": {}'.format(sign)) st = neo.SpikeTrain( name = discriminator['name'], channels = [discriminator['channel']], # custom annotation amplitude = discriminator['amplitude'], # custom annotation times = spikes_between_min_and_max * pq.s, t_start = sig.t_start, t_stop = sig.t_stop, ) if 'epoch' in discriminator: time_masks = [] if isinstance(discriminator['epoch'], str): # search for matching epochs ep = next((ep for ep in epochs if ep.name == discriminator['epoch']), None) if ep is not None: # select spike times that fall within each epoch for t_start, duration in zip(ep.times, ep.durations): t_stop = t_start + duration time_masks.append((t_start <= st) & (st < t_stop)) else: # no matching epochs found time_masks.append([False] * len(st)) else: # may eventually implement lists of ordered pairs, but # for now raise an error raise ValueError('amplitude discriminator epoch could not be handled: 
def _detect_spikes(sig, discriminator, epochs):
    """
    Detect spikes in the amplitude window given by ``discriminator`` and
    optionally filter them by coincidence with epochs of a given name.
    """

    assert sig.name == discriminator['channel'], 'sig name "{}" does not match amplitude discriminator channel "{}"'.format(sig.name, discriminator['channel'])

    min_threshold = min(discriminator['amplitude'])
    max_threshold = max(discriminator['amplitude'])
    if min_threshold >= 0 and max_threshold > 0:
        sign = 'above'
    elif min_threshold < 0 and max_threshold <= 0:
        sign = 'below'
    else:
        raise ValueError('amplitude discriminator must have two nonnegative thresholds or two nonpositive thresholds: {}'.format(discriminator))

    spikes_crossing_min = _peak_detection(sig, pq.Quantity(min_threshold, discriminator['units']), sign, 'raw')
    spikes_crossing_max = _peak_detection(sig, pq.Quantity(max_threshold, discriminator['units']), sign, 'raw')
    if sign == 'above':
        spikes_between_min_and_max = np.setdiff1d(spikes_crossing_min, spikes_crossing_max)
    elif sign == 'below':
        spikes_between_min_and_max = np.setdiff1d(spikes_crossing_max, spikes_crossing_min)
    else:
        raise ValueError('sign should be "above" or "below": {}'.format(sign))

    st = neo.SpikeTrain(
        name = discriminator['name'],
        channels = [discriminator['channel']],   # custom annotation
        amplitude = discriminator['amplitude'],  # custom annotation
        times = spikes_between_min_and_max * pq.s,
        t_start = sig.t_start,
        t_stop = sig.t_stop,
    )

    if 'epoch' in discriminator:

        time_masks = []
        if isinstance(discriminator['epoch'], str):
            # search for matching epochs
            ep = next((ep for ep in epochs if ep.name == discriminator['epoch']), None)
            if ep is not None:
                # select spike times that fall within each epoch
                for t_start, duration in zip(ep.times, ep.durations):
                    t_stop = t_start + duration
                    time_masks.append((t_start <= st) & (st < t_stop))
            else:
                # no matching epochs found
                time_masks.append([False] * len(st))
        else:
            # may eventually implement lists of ordered pairs, but
            # for now raise an error
            raise ValueError('amplitude discriminator epoch could not be handled: {}'.format(discriminator['epoch']))

        # select the subset of spikes that fall within the epoch windows
        st = st[np.any(time_masks, axis=0)]

    return st

def _run_burst_detectors(metadata, blk):
    """
    Run all burst detectors given in ``metadata`` on the spike trains in
    ``blk``.
    """

    burst_list = []
    if metadata['burst_detectors'] is not None:

        spikeTrainNameToIndex = {st.name:i for i, st in enumerate(blk.segments[0].spiketrains)}

        # detect bursts of spikes using frequency thresholds
        for detector in metadata['burst_detectors']:

            index = spikeTrainNameToIndex.get(detector['spiketrain'], None)
            if index is None:

                print("Warning: skipping burst detector for spike train named "
                      f"\"{detector['spiketrain']}\" because spike train was "
                      "not found!")

            else:

                st = blk.segments[0].spiketrains[index]
                start_freq, stop_freq = detector['thresholds']*pq.Hz
                burst = _find_bursts(st, start_freq, stop_freq)
                burst.name = detector.get('name', detector['spiketrain'] + ' burst')
                burst_list.append(burst)

    return burst_list

def _find_bursts(st, start_freq, stop_freq):
    """
    Find every period of time during which the instantaneous firing frequency
    (IFF) of the Neo :class:`SpikeTrain <neo.core.SpikeTrain>` ``st`` meets
    the criteria for bursting. Return the set of bursts as a Neo
    :class:`Epoch <neo.core.Epoch>`, with ``array_annotations['spikes']``
    listing the number of spikes contained in each burst.

    A burst is defined as a period beginning when the IFF exceeds
    ``start_freq`` and ending when the IFF subsequently drops below
    ``stop_freq``. Note that in general ``stop_freq`` should not exceed
    ``start_freq``, since otherwise bursts may not be detected.
    """

    isi = _isi(st).rescale('s')
    iff = 1/isi

    start_mask = iff > start_freq
    stop_mask = iff < stop_freq

    times = []
    durations = []
    n_spikes = []
    scan_index = -1
    while scan_index < iff.size:
        start_index = None
        stop_index = None

        start_mask_indexes = np.where(start_mask)[0]
        start_mask_indexes = start_mask_indexes[start_mask_indexes > scan_index]
        if start_mask_indexes.size == 0:
            break
        start_index = start_mask_indexes[0]  # first time that iff rises above start threshold

        stop_mask_indexes = np.where(stop_mask)[0]
        stop_mask_indexes = stop_mask_indexes[stop_mask_indexes > start_index]
        if stop_mask_indexes.size > 0:
            stop_index = stop_mask_indexes[0]  # first time after start that iff drops below stop threshold
        else:
            stop_index = -1  # end of spike train (include all spikes after start)

        times.append(st[start_index].rescale('s').magnitude)
        durations.append((st[stop_index] - st[start_index]).rescale('s').magnitude)
        n_spikes.append(stop_index - start_index + 1 if stop_index > 0 else st.size - start_index)

        if stop_index == -1:
            break
        else:
            scan_index = stop_index

    bursts = neo.Epoch(
        times = times*pq.s,
        durations = durations*pq.s,
        labels = [''] * len(times),
        array_annotations = {'spikes': n_spikes},
    )

    return bursts
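
# Illustrative sketch (hypothetical values, not from the original source) of
# a burst_detectors metadata entry for the functions above. 'thresholds'
# holds [start_freq, stop_freq] in Hz: a burst opens when the instantaneous
# firing frequency (1/ISI) exceeds start_freq and closes when it next falls
# below stop_freq. With the values below, a pair of spikes 10 ms apart
# (100 Hz > 75 Hz) opens a burst, which ends at the first interspike interval
# longer than 0.1 s (below 10 Hz).
#
#     metadata['burst_detectors'] = [
#         {'spiketrain': 'Unit A', 'thresholds': [75, 10], 'name': 'Unit A burst'},
#     ]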