diff --git a/examples/finetuning/finetune_hela.py b/examples/finetuning/finetune_hela.py index c54d462f9..683150b0a 100644 --- a/examples/finetuning/finetune_hela.py +++ b/examples/finetuning/finetune_hela.py @@ -1,13 +1,11 @@ import os import numpy as np -import torch - import torch_em from torch_em.transform.label import PerObjectDistanceTransform import micro_sam.training as sam_training -from micro_sam.util import export_custom_sam_model +from micro_sam.util import export_custom_sam_model, get_device from micro_sam.sample_data import fetch_tracking_example_data, fetch_tracking_segmentation_data @@ -82,7 +80,7 @@ def run_training(checkpoint_name, model_type, train_instance_segmentation): batch_size = 1 # the training batch size patch_shape = (1, 512, 512) # the size of patches for training n_objects_per_batch = 25 # the number of objects per batch that will be sampled - device = torch.device("cuda") # the device used for training + device = get_device() # the device used for training # Get the dataloaders. 
train_loader = get_dataloader("train", patch_shape, batch_size, train_instance_segmentation) diff --git a/notebooks/sam_finetuning.ipynb b/notebooks/sam_finetuning.ipynb index d1aff68fd..8d4d65cba 100644 --- a/notebooks/sam_finetuning.ipynb +++ b/notebooks/sam_finetuning.ipynb @@ -152,6 +152,7 @@ "from torch_em.data import MinInstanceSampler\n", "from torch_em.util.util import get_random_colors\n", "\n", + "from micro_sam.util import get_device\n", "import micro_sam.training as sam_training\n", "from micro_sam.sample_data import fetch_tracking_example_data, fetch_tracking_segmentation_data\n", "from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation" @@ -645,7 +646,7 @@ "source": [ "# All hyperparameters for training.\n", "n_objects_per_batch = 5 # the number of objects per batch that will be sampled\n", - "device = \"cuda\" if torch.cuda.is_available() else \"mps\" if torch.backends.mps.is_available() else \"cpu\" # checks if cuda (NVIDIA GPU) or mps (Apple GPU) is available and sets the device accordingly, fallback to cpu if neither is available\n", + "device = get_device()\n", "n_epochs = 3 # how long we train (in epochs)\n", "\n", "# The model_type determines which base model is used to initialize the weights that are finetuned.\n", @@ -1063,7 +1064,7 @@ "sourceType": "notebook" }, "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -1077,7 +1078,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.11" + "version": "3.11.11" } }, "nbformat": 4,