Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions examples/finetuning/finetune_hela.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import os
import numpy as np

import torch

import torch_em
from torch_em.transform.label import PerObjectDistanceTransform

import micro_sam.training as sam_training
from micro_sam.util import export_custom_sam_model
from micro_sam.util import export_custom_sam_model, get_device
from micro_sam.sample_data import fetch_tracking_example_data, fetch_tracking_segmentation_data


Expand Down Expand Up @@ -82,7 +80,7 @@ def run_training(checkpoint_name, model_type, train_instance_segmentation):
batch_size = 1 # the training batch size
patch_shape = (1, 512, 512) # the size of patches for training
n_objects_per_batch = 25 # the number of objects per batch that will be sampled
device = torch.device("cuda") # the device used for training
device = get_device() # the device used for training

# Get the dataloaders.
train_loader = get_dataloader("train", patch_shape, batch_size, train_instance_segmentation)
Expand Down
7 changes: 4 additions & 3 deletions notebooks/sam_finetuning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@
"from torch_em.data import MinInstanceSampler\n",
"from torch_em.util.util import get_random_colors\n",
"\n",
"from micro_sam.util import get_device\n",
"import micro_sam.training as sam_training\n",
"from micro_sam.sample_data import fetch_tracking_example_data, fetch_tracking_segmentation_data\n",
"from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation"
Expand Down Expand Up @@ -645,7 +646,7 @@
"source": [
"# All hyperparameters for training.\n",
"n_objects_per_batch = 5 # the number of objects per batch that will be sampled\n",
"device = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\" # checks if cuda (NVIDIA GPU) or mps (Apple GPU) is available and sets the device accordingly, fallback to cpu if neither is available\n",
"device = get_device()\n",
"n_epochs = 3 # how long we train (in epochs)\n",
"\n",
"# The model_type determines which base model is used to initialize the weights that are finetuned.\n",
Expand Down Expand Up @@ -1063,7 +1064,7 @@
"sourceType": "notebook"
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "super",
"language": "python",
"name": "python3"
},
Expand All @@ -1077,7 +1078,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
"version": "3.11.12"
}
},
"nbformat": 4,
Expand Down
Loading