Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 158 additions & 2 deletions dlc2action/data/input_store.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#
# Copyright 2020-present by A. Mathis Group and contributors. All rights reserved.
#
# This project and all its files are licensed under GNU AGPLv3 or later version.
# This project and all its files are licensed under GNU AGPLv3 or later version.
# A copy is included in dlc2action/LICENSE.AGPL.
#
"""Specific realisations of `dlc2action.data.base_store.InputStore` are defined here."""
Expand Down Expand Up @@ -2480,4 +2480,160 @@ def _open_data(
.values
)
return output, None



class ESKTrackStore(FileInputStore):
"""DLC track data from EPFL Smart Kitchen, allows to choose specific set of keypoints.

Assumes the following file structure:
```
data_path
├── video1DLC1000.pickle
├── video2DLC400.pickle
├── video1_features.npy
└── video2_features.npy
```
Here `data_suffix` is `{'DLC1000.pickle', 'DLC400.pickle'}` and `feature_suffix` (optional) is `'_features.npy'`.

The feature files should to be dictionaries where keys are clip IDs (e.g. animal names) and values are
feature values (arrays of shape `(#frames, #features)`). If the arrays are shaped as `(#features, #frames)`,
set `transpose_features` to `True`.

The files can be saved with `numpy.save()` (with `.npy` extension), `torch.save()` (with `.pt` extension) or
with `pickle.dump()` (with `.pickle` or `.pkl` extension).
"""

def __init__(
self,
keypoint_type: str = "hand",
*args,
**kwargs,
):
"""Initialize a store."""
self.keypoint_type = keypoint_type
self.num_body_kpts = 17
self.num_hand_kpts = 42
self.num_eye_kpts = 10

super().__init__(
*args,
**kwargs,
)

def get_kpt_names(self):

kpt_names = [
"nose",
"left_eye",
"right_eye",
"left_ear",
"right_ear",
"left_shoulder",
"right_shoulder",
"left_elbow",
"right_elbow",
"left_wrist",
"right_wrist",
"left_hip",
"right_hip",
"left_knee",
"right_knee",
"left_ankle",
"right_ankle",
]

kpt_names = (
kpt_names
+ [
f"hand_{i}"
for i in range(
self.num_body_kpts, self.num_hand_kpts + self.num_body_kpts
)
]
+ [f"eye_gaze_{i}" for i in range(self.num_eye_kpts)]
)
return np.array(kpt_names)

def get_kpt_ind(self, default_num):
"""Get the indices of the keypoints to be used."""
body_ind = list(range(17))
count = len(body_ind)
hands_ind = list(range(count, count + self.num_hand_kpts))
count += len(hands_ind)
eye_ind = list(range(count, count + self.num_eye_kpts))
body_wo_arm_ind = [0, 1, 2, 3, 4, 11, 12, 13, 14, 15, 16]
default_ind = list(range(default_num))
switcher = {
"body": body_ind,
"hands": hands_ind,
"eyes": eye_ind,
"body_hands": body_ind + hands_ind,
"body_eyes": body_ind + eye_ind,
"hands_eyes": hands_ind + eye_ind,
"body_wo_arm": body_wo_arm_ind,
}
return (
switcher.get(self.keypoint_type, default_ind),
not self.keypoint_type in switcher.keys(),
)

def _open_data(
self, filename: str, default_agent_name: str
) -> Tuple[Dict, Optional[Dict]]:
"""Load the keypoints from filename and organize them in a dictionary.

In `data_dictionary`, the keys are clip ids and the values are `pandas` dataframes with two-level indices.
The first level is the frame numbers and the second is the body part names. The dataframes should have from
two to four columns labeled `"x"`, `"y"` and (optionally) `"z"` and `"likelihood"`. Each frame should have
information on all the body parts. You don't have to filter the data in any way or fill the nans, it will
be done automatically.

Parameters
----------
filename : str
path to the pose file
default_agent_name : str
the default agent name

Returns
-------
data dictionary : dict
a dictionary where the keys are clip ids and the values are keypoint dataframes (see above for details)
metadata_dictionary : dict
a dictionary where the keys are clip ids and the values are metadata objects (can be any additional information,
like the annotator tag; for no metadata pass `None`)

"""
if filename.endswith("h5"):
temp = pd.read_hdf(filename)
temp = temp.droplevel("scorer", axis=1)
elif filename.endswith(".csv"):
temp = pd.read_csv(filename, header=[1, 2])
temp.columns.names = ["bodyparts", "coords"]
else:
raise TypeError("Invalid file type, please use .csv or .h5")

if "individuals" not in temp.columns.names:
old_idx = temp.columns.to_frame()
old_idx.insert(0, "individuals", self.default_agent_name)
temp.columns = pd.MultiIndex.from_frame(old_idx)

df = temp.stack(["individuals", "bodyparts"], future_stack=True)
idx = pd.MultiIndex.from_product(
[df.index.levels[0], df.index.levels[1], df.index.levels[2]],
names=df.index.names,
)
df = df.reindex(idx).fillna(value=0)
animals = sorted(list(df.index.levels[1]))
dic = {}
default_num = len(df.index.levels[2])
kpt_ind, is_special = self.get_kpt_ind(default_num)
kpt_names = self.get_kpt_names()
for ind in animals:
coord = df.iloc[df.index.get_level_values(1) == ind].droplevel(1)
coord = coord[["x", "y", "z", "likelihood"]]
if not is_special:
coord = coord.loc[(slice(None), kpt_names[kpt_ind]), :]
dic[ind] = coord

return dic, None
3 changes: 3 additions & 0 deletions dlc2action/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
LoadedFeaturesInputStore,
Numpy3DInputStore,
SIMBAInputStore,
ESKTrackStore,
)
from dlc2action.feature_extraction import HeatmapExtractor, KinematicExtractor
from dlc2action.loss import MS_TCN_Loss
Expand Down Expand Up @@ -85,6 +86,8 @@
"np_3d": Numpy3DInputStore,
"features": LoadedFeaturesInputStore,
"simba": SIMBAInputStore,
"esk_track": ESKTrackStore,

}

annotation_stores = {
Expand Down