Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
eb2504d
adding global_forager_id as a property of dataObject. this solves kee…
tommybotch Jun 3, 2025
4263d5a
checking that column doesn't already exist
tommybotch Jun 3, 2025
f1ae1db
updating test to consider new global_forager_id column
tommybotch Jun 3, 2025
99863ee
passed all tests
tommybotch Jun 3, 2025
c37c478
updating workflows for action/cache@v4
tommybotch Jun 3, 2025
e016174
Merge branch 'tlb-workflow' into tlb-dev
tommybotch Jun 3, 2025
235e5dc
Merge branch 'main' into tlb-dev
dimkab Jun 5, 2025
479e4e7
Merge branch 'main' into tlb-dev
dimkab Jun 5, 2025
9c00b39
removing erroneous addition to gitignore
tommybotch Jun 7, 2025
f3f5c25
Merge branch 'main' into tlb-dev
dimkab Jun 9, 2025
24776f7
restoring old files for tests -- removing addition of extra column
tommybotch Jun 9, 2025
37bf2c1
updating utils w/ stored mappings and function to apply mappings
tommybotch Jun 9, 2025
2d184c9
updating imports to be inline w lint?
tommybotch Jun 9, 2025
fa5c6b2
forgot to run make statements
tommybotch Jun 9, 2025
6ffe1d5
updating with property of local/global forager id mapping
tommybotch Jun 11, 2025
8341e87
accidentally updated too much
tommybotch Jun 11, 2025
6da6c7f
restoring old dill file
tommybotch Jun 11, 2025
4bcdda9
removing redundancy in code
tommybotch Jun 11, 2025
247780d
adding back in setting type of forager
tommybotch Jun 11, 2025
7d7b740
adding make lint format
tommybotch Jun 11, 2025
c9b2d35
reorganized utils file -- moved local/global mapping step to before s…
tommybotch Jun 24, 2025
2abc286
forgot a self tag
tommybotch Jun 24, 2025
453477d
forgot a self tag again...
tommybotch Jun 24, 2025
643c6e5
fixing needs_forager_id_mapping, and indexing within apply_forager_id…
tommybotch Jun 24, 2025
429bac8
line too long for lint...
tommybotch Jun 24, 2025
ea30c41
updating after linting
tommybotch Jun 24, 2025
8e9facb
type hints and return value consistent
dimkab Jun 24, 2025
56d2a54
updating w/ test and demonstration of index conversion
tommybotch Jun 25, 2025
aa66b6b
Merge branch 'tlb-dev' of https://github.com/BasisResearch/collab-cre…
tommybotch Jun 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions collab/foraging/toolkit/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,25 @@ def __init__(
# ensure that forager index is saved as an integer
foragersDF.loc[:, "forager"] = foragersDF.loc[:, "forager"].astype(int)

# Get unique forager IDs from the DataFrame --> we need to store these to be able
# to map back to the original IDs
self._forager_ids = foragersDF.forager.unique()

# Save the original forager IDs and map to consecutive indices if needed
if self.needs_forager_id_mapping:
warnings.warn(
f"""
Original forager indices were converted to consecutive integers starting from 0.
To access the original forager IDs, use the apply_forager_id_mapping() method.
Original IDs were: {self._forager_ids}
"""
)

# By default, convert global to local IDs
foragersDF = self._apply_forager_id_mapping(
foragersDF, local_to_global=False
)

# group dfs by forager index
foragers = [group for _, group in foragersDF.groupby("forager")]
self.num_foragers = len(foragers)
Expand Down Expand Up @@ -119,6 +138,59 @@ def calculate_step_size_max(self):

self.step_size_max = max(step_maxes)

@property
def needs_forager_id_mapping(self) -> bool:
return set(self._forager_ids) != set(range(len(self._forager_ids)))

@property
def local_to_global_map(self) -> dict:
return {
local_id: global_id for local_id, global_id in enumerate(self._forager_ids)
}

@property
def global_to_local_map(self) -> dict:
return {
global_id: local_id for local_id, global_id in enumerate(self._forager_ids)
}

def _apply_forager_id_mapping(
self, foragersDF: pd.DataFrame, local_to_global: bool = False
) -> pd.DataFrame:
"""
Apply forager ID mapping to convert between local and global IDs. Applies
directly to the foragersDF attribute.

Args:
local_to_global: If True, converts from local to global IDs. If False, converts from global to local IDs
"""

# Find current forager IDs and grab mapping
current_ids = set(foragersDF["forager"].unique())

if local_to_global:
mapping = self.local_to_global_map
else:
mapping = self.global_to_local_map

# Check if already mapped
target_ids = set(mapping.values())
if current_ids.issubset(target_ids):
warnings.warn(
"IDs are already in target format. Returning DataFrame unchanged."
)
return foragersDF

# Ensure that all current IDs are in the mapping --> otherwise throw an error
source_ids = set(mapping.keys())
if not current_ids.issubset(source_ids):
unmapped = current_ids - source_ids
raise ValueError(f"Cannot map forager IDs: {unmapped}")

# Apply the mapping to the foragersDF (in place modification)
foragersDF.loc[:, "forager"] = foragersDF["forager"].map(mapping).astype(int)
return foragersDF


def foragers_to_forager_distances(obj: dataObject) -> List[List[pd.DataFrame]]:
"""
Expand Down
181 changes: 181 additions & 0 deletions docs/experimental/collab_tests/forager_index_conversion.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Test of forager index conversation\n",
"\n",
"When forager indices are non-consecutive (i.e., [0, 2]), the dataObject internally converts these indices to consecutive (i.e., [0,1]) and stores the mapping within two variables:\n",
"- local_to_global_map: mapping back to original, non-consecutive indices\n",
"- global_to_local_map: mapping to consecutive indices (applied internally at ```__init__``` when creating dataObject) "
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import pandas as pd\n",
"\n",
"import collab.foraging.toolkit as ftk\n",
"\n",
"smoke_test = \"CI\" in os.environ\n",
"\n",
"num_svi_iters = 400 if not smoke_test else 4\n",
"num_samples = 1000 if not smoke_test else 10"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load data\n",
"\n",
"Setup for dataObject with ground-truth indices"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/tommybotch/Documents/collab-creatures/collab/foraging/toolkit/utils.py:42: UserWarning: \n",
" NaN values in data. The default behavior of predictor/score generating functions is\n",
" to ignore foragers with missing positional data. To modify, see documentation of\n",
" `derive_predictors_and_scores` and `generate_local_windows`\n",
" \n",
" warnings.warn(\n"
]
}
],
"source": [
"# load data\n",
"fish_data = pd.read_csv(\"4wpf_test.csv\")\n",
"gridMin = 0\n",
"gridMax = 300\n",
"grid_size = 50\n",
"\n",
"# scaling and subsampling\n",
"fishDF_scaled = ftk.rescale_to_grid(\n",
" fish_data, size=grid_size, gridMin=gridMin, gridMax=gridMax\n",
")\n",
"# create a test foragers object with 20 frames\n",
"num_frames = 10\n",
"foragers_object = ftk.dataObject(\n",
" fishDF_scaled,\n",
" grid_size=grid_size,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Modify the original dataframe\n",
"\n",
"We subsample forager indices to test the global (non-consecutive) to local (consecutive) mapping"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/tommybotch/Documents/collab-creatures/collab/foraging/toolkit/utils.py:42: UserWarning: \n",
" NaN values in data. The default behavior of predictor/score generating functions is\n",
" to ignore foragers with missing positional data. To modify, see documentation of\n",
" `derive_predictors_and_scores` and `generate_local_windows`\n",
" \n",
" warnings.warn(\n",
"/Users/tommybotch/Documents/collab-creatures/collab/foraging/toolkit/utils.py:59: UserWarning: \n",
" Original forager indices were converted to consecutive integers starting from 0.\n",
" To access the original forager IDs, use the apply_forager_id_mapping() method.\n",
" Original IDs were: [0 2]\n",
" \n",
" warnings.warn(\n"
]
}
],
"source": [
"fishDF_nonconsecutive = fishDF_scaled[fishDF_scaled[\"forager\"].isin([0, 2])]\n",
"\n",
"# Given non-consecutive indices, forager indices are converted to consecutive integers and a warning is raised\n",
"foragers_nonconsecutive_obj = ftk.dataObject(fishDF_nonconsecutive, grid_size=grid_size)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Demonstrate application of local-global maps\n",
"\n",
"Whether mapping is needed is stored in a read-only property ```needs_forager_id_mapping```"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Original DF needed mapping: False\n",
"Non-consecutive DF needed mapping: True\n",
"Local to global map: {0: 0, 1: 2}\n",
"Global to local map: {0: 0, 2: 1}\n"
]
}
],
"source": [
"# Original object\n",
"print(f\"Original DF needed mapping: {foragers_object.needs_forager_id_mapping}\")\n",
"\n",
"# Non-consecutive object\n",
"print(\n",
" f\"Non-consecutive DF needed mapping: {foragers_nonconsecutive_obj.needs_forager_id_mapping}\"\n",
")\n",
"\n",
"# Maps are accessible as properties\n",
"# Local to global = maps back to original indices\n",
"# Global to local = maps back to consecutive indices\n",
"print(f\"Local to global map: {foragers_nonconsecutive_obj.local_to_global_map}\")\n",
"print(f\"Global to local map: {foragers_nonconsecutive_obj.global_to_local_map}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "collab",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}