# batik/utils/data_loading.py
# Commit ed29c11 ("add main program files") by ncoria (verified).
"""File for loading data into AnimalEditor"""
import io
from random import random
from os.path import splitext
from collections import OrderedDict
import numpy as np
from tempfile import NamedTemporaryFile
from .annot import Annotations
from .behavior import Behaviors
def has_extension(fname:str, extension:str|list[str]) -> bool:
"""
Checks to see if the passed in file name ends with an expected extension.
"""
_, ext = splitext(fname)
if isinstance(extension, str):
return ext == extension
elif isinstance(extension, list):
return ext in extension
def _clean_annotations(annotations):
    """
    Return a copy of the annotation dictionary that keeps only str behavior keys.

    While reading behaviors from an .annot file, entries whose key is not a
    string (i.e. not the name of a behavior) can end up in a channel; those
    are dropped here. Every channel is kept, even one that ends up empty.
    Both the outer mapping and each per-channel mapping are returned as
    OrderedDicts that preserve the input order.

    Raises:
        ValueError - if annotations is empty (or otherwise falsy).
    """
    if not annotations:
        raise ValueError('No annotations found.')
    return OrderedDict(
        (channel, OrderedDict(
            (name, bouts)
            for name, bouts in per_channel.items()
            if isinstance(name, str)
        ))
        for channel, per_channel in annotations.items()
    )
def load_annot_sheet_txt(fname, offset = 0):
    """
    Generate a dictionary of the beginning and end frames of behaviors from
    an .annot file.

    Note that 0:00:00 is frame 1.

    Args:
        fname - the path to the .annot file to be read (must be Caltech format)
        offset - a value added to the start and end frame of each bout in
            the sheet, as well as to the absolute start and end frame of the
            file. Optional; 0 by default.
    Returns:
        annotations - OrderedDict mapping channel name -> OrderedDict mapping
            behavior name -> numpy array with one [start_frame, end_frame]
            row per bout. Behaviors appear in order of first occurrence in
            the channel.
        start_time - the frame the movie started at (0:00:00 is 1)
        end_time - the frame the movie ended at (0:00:00 is 1)
        sample_rate - the sample rate reported within the file
    Raises:
        ValueError - if the sheet contains no channels.
    """
    # Parsing is done by the Annotations reader (from bento for python).
    behaviors = Behaviors()
    annot_sheet = Annotations(behaviors)
    annot_sheet.read(fname)
    sample_rate = annot_sheet.sample_rate()
    annotations = OrderedDict()
    for key in annot_sheet.channel_names():
        # Group bouts by behavior name in a single pass. An OrderedDict keyed
        # by first appearance keeps the output order deterministic (a set of
        # names would iterate in hash order, which varies between runs).
        bouts_by_name = OrderedDict()
        for bout in annot_sheet.channel(key):
            bout_frames = [bout.start().frames + offset, bout.end().frames + offset]
            bouts_by_name.setdefault(bout.name(), []).append(bout_frames)
        annotations[key] = OrderedDict(
            (name, np.array(frames)) for name, frames in bouts_by_name.items()
        )
    # Drops any non-str behavior keys and raises ValueError if there are no channels.
    annotations = _clean_annotations(annotations)
    start_time = annot_sheet.start_frame() + offset
    end_time = annot_sheet.end_frame() + offset
    return annotations, start_time, end_time, sample_rate
def load_multiple_annotations(fnames):
    """
    Generate a single annotation dictionary from multiple .annot files.

    The files are treated as consecutive segments. Each file after the first
    is loaded with an offset equal to the (already offset) end frame of the
    file before it. Bout arrays for the same channel/behavior are stacked in
    file order.

    Args:
        fnames - list of paths to .annot files, in the order they should be joined
    Returns:
        annotations - merged {channel: {behavior: bout array}} dictionary
        start_time - the start frame of the first file
        end_time - the end frame of the last file, including accumulated offsets
        sample_rate - the sample rate reported by the first file
    Raises:
        TypeError - if fnames is not a list
        ValueError - if fnames is empty
    """
    if not isinstance(fnames, list):
        raise TypeError(f'Expected list[str], got {type(fnames)} instead.')
    if not fnames:
        raise ValueError('No file names passed in.')
    if len(fnames) == 1:
        return load_annot_sheet_txt(fnames[0])
    head_annot, head_start_frame, end_frame, sample_rate = load_annot_sheet_txt(fnames[0])
    for fname in fnames[1:]:
        # NOTE(review): only the first file's sample rate is returned; later
        # files' rates are discarded without being compared. Confirm that all
        # inputs share one rate.
        curr_annot, _, end_frame, _ = load_annot_sheet_txt(fname, end_frame)
        for channel, curr_behaviors in curr_annot.items():
            # Channels first seen in a later file get an OrderedDict too, so
            # every channel map has the same type as the per-file output.
            head_behaviors = head_annot.setdefault(channel, OrderedDict())
            for behavior, curr_bouts in curr_behaviors.items():
                if behavior in head_behaviors:
                    head_behaviors[behavior] = np.vstack((head_behaviors[behavior], curr_bouts))
                else:
                    head_behaviors[behavior] = curr_bouts
    return head_annot, head_start_frame, end_frame, sample_rate
def load_annot_sheet_txt_io(uploaded_file, offset = 0):
    """
    Generate a dictionary of the beginning and end frames of behaviors from
    an uploaded .annot file.

    Note that 0:00:00 is frame 1.

    Args:
        uploaded_file - the uploaded .annot file (must be Caltech format),
            passed as-is to Annotations.read_io. Presumably a file-like
            object; TODO confirm against callers.
        offset - a value added to the start and end frame of each bout in
            the sheet, as well as to the absolute start and end frame of the
            file. Optional; 0 by default.
    Returns:
        annotations - OrderedDict mapping channel name -> OrderedDict mapping
            behavior name -> numpy array with one [start_frame, end_frame]
            row per bout. Behaviors appear in order of first occurrence in
            the channel.
        start_time - the frame the movie started at (0:00:00 is 1)
        end_time - the frame the movie ended at (0:00:00 is 1)
        sample_rate - the sample rate reported within the file
    Raises:
        ValueError - if the sheet contains no channels.
    """
    # Parsing is done by the Annotations reader (from bento for python).
    behaviors = Behaviors()
    annot_sheet = Annotations(behaviors)
    annot_sheet.read_io(uploaded_file)
    sample_rate = annot_sheet.sample_rate()
    annotations = OrderedDict()
    for key in annot_sheet.channel_names():
        # Group bouts by behavior name in a single pass. An OrderedDict keyed
        # by first appearance keeps the output order deterministic (a set of
        # names would iterate in hash order, which varies between runs).
        bouts_by_name = OrderedDict()
        for bout in annot_sheet.channel(key):
            bout_frames = [bout.start().frames + offset, bout.end().frames + offset]
            bouts_by_name.setdefault(bout.name(), []).append(bout_frames)
        annotations[key] = OrderedDict(
            (name, np.array(frames)) for name, frames in bouts_by_name.items()
        )
    # Drops any non-str behavior keys and raises ValueError if there are no channels.
    annotations = _clean_annotations(annotations)
    start_time = annot_sheet.start_frame() + offset
    end_time = annot_sheet.end_frame() + offset
    return annotations, start_time, end_time, sample_rate
def load_multiple_annotations_io(uploaded_files):
    """
    Generate a single annotation dictionary from multiple uploaded .annot files.

    The files are treated as consecutive segments. Each file after the first
    is loaded with an offset equal to the (already offset) end frame of the
    file before it. Bout arrays for the same channel/behavior are stacked in
    file order.

    Args:
        uploaded_files - list of uploaded .annot files, in the order they
            should be joined; each is passed to load_annot_sheet_txt_io
    Returns:
        annotations - merged {channel: {behavior: bout array}} dictionary
        start_time - the start frame of the first file
        end_time - the end frame of the last file, including accumulated offsets
        sample_rate - the sample rate reported by the first file
    Raises:
        TypeError - if uploaded_files is not a list
        ValueError - if uploaded_files is empty
    """
    if not isinstance(uploaded_files, list):
        raise TypeError(f'Expected list, got {type(uploaded_files)} instead.')
    if not uploaded_files:
        raise ValueError('No file names passed in.')
    if len(uploaded_files) == 1:
        return load_annot_sheet_txt_io(uploaded_files[0])
    head_annot, head_start_frame, end_frame, sample_rate = load_annot_sheet_txt_io(uploaded_files[0])
    for uploaded_file in uploaded_files[1:]:
        # NOTE(review): only the first file's sample rate is returned; later
        # files' rates are discarded without being compared. Confirm that all
        # inputs share one rate.
        curr_annot, _, end_frame, _ = load_annot_sheet_txt_io(uploaded_file, end_frame)
        for channel, curr_behaviors in curr_annot.items():
            # Channels first seen in a later file get an OrderedDict too, so
            # every channel map has the same type as the per-file output.
            head_behaviors = head_annot.setdefault(channel, OrderedDict())
            for behavior, curr_bouts in curr_behaviors.items():
                if behavior in head_behaviors:
                    head_behaviors[behavior] = np.vstack((head_behaviors[behavior], curr_bouts))
                else:
                    head_behaviors[behavior] = curr_bouts
    return head_annot, head_start_frame, end_frame, sample_rate