Commit 1c15523
Merge branch 'poutyne'
2 parents 0a74b78 + de9b0a5 commit 1c15523

File tree

9 files changed (+99 −24 lines)


README.md

Lines changed: 14 additions & 11 deletions
````diff
@@ -57,7 +57,7 @@ Every deep learning project has at least three main steps:
 ## Project
 One good idea is to store all the paths to interesting locations, e.g. the dataset folder, in a shared class that can be accessed by anyone in the project. You should never hardcode paths: define them once and import them, so if you later change your structure you will only have to modify one file.
 If we have a look at `Project.py`, we can see how we defined the `data_dir` and the `checkpoint_dir` once and for all. We are using the 'new' [Path](https://docs.python.org/3/library/pathlib.html) APIs, which support different OSes out of the box and also make it easier to join and concatenate paths.
-![alt](https://raw.githubusercontent.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/develop/images/Project.png)
+![alt](https://raw.githubusercontent.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/master/images/Project.png)
 For example, if we want to know the data location we can:
 ```python3
 from Project import Project
@@ -70,13 +70,16 @@ In our example, we directly used `ImageDataset` from `torchvision` but we includ
 ### Transformation
 You usually have to do some preprocessing on the data, e.g. resize the images and apply data augmentation. All your transformations should go inside `.data.transformation`. In our template, we included a wrapper for
 [imgaug](https://imgaug.readthedocs.io/en/latest/)
-![alt](https://raw.githubusercontent.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/develop/images/transformation.png)
+![alt](https://raw.githubusercontent.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/master/images/transformation.png)
 ### Dataloaders
 As you know, you have to create a `Dataloader` to feed your data into the model. In the `data.__init__.py` file we expose a very simple function, `get_dataloaders`, to automatically configure the *train, val and test* data loaders using a few parameters.
-![alt](https://raw.githubusercontent.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/develop/images/data.png)
+![alt](https://raw.githubusercontent.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/master/images/data.png)
 ## Losses
 Sometimes you may need to define custom losses; you can include them in the `./losses` package. For example:
-![alt](https://raw.githubusercontent.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/develop/images/losses.png)
+![alt](https://raw.githubusercontent.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/master/images/losses.png)
+## Metrics
+Sometimes you may need to define custom metrics. For example:
+![alt](https://raw.githubusercontent.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/master/images/metrics.png)
 ## Logging
 We included the python [logging](https://docs.python.org/3/library/logging.html) module. You can import and use it by:
 
@@ -88,27 +91,27 @@ logger.info('print() is for noobs')
 ## Models
 All your models go inside `models`. In our case, we have a very basic CNN, and we override the `resnet18` function to provide a frozen model to finetune.
 
-![alt](https://github.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/blob/develop/images/resnet.png?raw=true)
+![alt](https://github.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/blob/master/images/resnet.png?raw=true)
 ## Train/Evaluation
 In our case we kept things simple: all the training and evaluation logic is inside `.main.py`, where we used [poutyne](https://pypi.org/project/Poutyne/) as the main library. We already defined a useful list of callbacks:
 - learning rate scheduler
 - auto-save of the best model
 - early stopping
 Usually, this is all you need!
-![alt](https://github.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/blob/develop/images/main.png?raw=true)
+![alt](https://github.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/blob/master/images/main.png?raw=true)
 ### Callbacks
 You may need to create custom callbacks; with [poutyne](https://pypi.org/project/Poutyne/) this is very easy since it supports a Keras-like API. Your custom callbacks should go inside `./callbacks`. For example, we have created one to update Comet every epoch.
-![alt](https://github.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/blob/develop/images/CometCallback.png?raw=true)
+![alt](https://github.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/blob/master/images/CometCallback.png?raw=true)
 
 ### Track your experiment
 We are using [comet](https://www.comet.ml/) to automatically track our models' results. This is what comet's board looks like after a few model runs.
-![alt](https://github.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/blob/develop/images/comet.jpg?raw=true)
+![alt](https://github.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/blob/master/images/comet.jpg?raw=true)
 Running `main.py` produces the following output:
-![alt](https://github.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/blob/develop/images/output.jpg?raw=true)
+![alt](https://github.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/blob/master/images/output.jpg?raw=true)
 ## Utils
 We also created different utility functions to plot both datasets and dataloaders. They are in `utils.py`. For example, calling `show_dl` on our train and val datasets produces the following outputs.
-![alt](https://github.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/blob/develop/images/Figure_1.png?raw=true)
-![alt](https://github.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/blob/develop/images/Figure_2.png?raw=true)
+![alt](https://github.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/blob/master/images/Figure_1.png?raw=true)
+![alt](https://github.com/FrancescoSaverioZuppichini/PyTorch-Deep-Learning-Skeletron/blob/master/images/Figure_2.png?raw=true)
 As you can see, data augmentation is correctly applied on the train set.
 ## Conclusions
 I hope you found some useful information and that this template will help you on your next amazing project :)
````
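For reference, a minimal sketch of what the `Project` class described above might look like, using `pathlib`; `data_dir` and `checkpoint_dir` are named in the README, while the folder layout here is an assumption:

```python
# a sketch of Project.py, not the repository's actual file
from dataclasses import dataclass
from pathlib import Path


@dataclass
class Project:
    """Single place where every interesting path is defined, so nothing is hardcoded twice."""
    # assumption: the repository root is the folder containing this file
    base_dir: Path = Path(__file__).parent

    @property
    def data_dir(self) -> Path:
        return self.base_dir / 'dataset'  # assumed dataset folder name

    @property
    def checkpoint_dir(self) -> Path:
        path = self.base_dir / 'checkpoint'
        path.mkdir(exist_ok=True)  # create the folder before the first checkpoint is saved
        return path
```

Any module can then do `from Project import Project; Project().data_dir`, which is the usage shown in the README snippet above.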

data/__init__.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -10,7 +10,8 @@ def get_dataloaders(
         train_transform=None,
         val_transform=None,
         split=(0.5, 0.5),
-        batch_size=32):
+        batch_size=32,
+        *args, **kwargs):
     """
     This function returns the train, val and test dataloaders.
     """
@@ -27,8 +28,8 @@ def get_dataloaders(
     val_ds, test_ds = random_split(val_ds, lengths.tolist())
     logging.info(f'Train samples={len(train_ds)}, Validation samples={len(val_ds)}, Test samples={len(test_ds)}')
 
-    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4)
-    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=4)
-    test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=4)
+    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, *args, **kwargs)
+    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, *args, **kwargs)
+    test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False, *args, **kwargs)
 
     return train_dl, val_dl, test_dl
```
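Because `*args, **kwargs` are now forwarded to `DataLoader`, callers can pick loader options such as `pin_memory` or `num_workers` per call instead of having them hardcoded. A usage sketch (the directory paths are illustrative):

```python
# hypothetical call site; extra keyword arguments flow straight into torch.utils.data.DataLoader
from data import get_dataloaders

train_dl, val_dl, test_dl = get_dataloaders(
    'dataset/train',   # illustrative paths
    'dataset/val',
    batch_size=32,
    pin_memory=True,   # forwarded via **kwargs
    num_workers=4,     # forwarded via **kwargs
)
```

This matches how `main.py` calls the function later in this commit.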

images/data.png  −57.8 KB

images/main.png  30.7 KB

images/metrics.png  178 KB

losses/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 import torch
-# define custom losses
 
+# define custom losses
 def my_loss(output, target):
     loss = torch.mean((output - target) ** 2)
     return loss
```
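A quick sanity check of `my_loss` on random tensors (shapes are illustrative):

```python
import torch

from losses import my_loss

output = torch.randn(8, 2)  # fake model outputs
target = torch.randn(8, 2)  # fake targets
loss = my_loss(output, target)
print(loss)  # a scalar tensor: the mean squared error
```

Since poutyne's `Model` accepts a callable loss, `my_loss` could be passed directly in place of the `"cross_entropy"` string used in `main.py`.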

main.py

Lines changed: 54 additions & 8 deletions
```diff
@@ -12,14 +12,60 @@
 from callbacks import CometCallback
 from logger import logging
 
-project = Project()
-# our hyperparameters
-params = {
-    'lr': 0.001,
-    'batch_size': 32,
-    'model': 'resnet18-finetune'
-}
-logging.info(f'Using device={device} 🚀')
+if __name__ == '__main__':
+    project = Project()
+    # our hyperparameters
+    params = {
+        'lr': 0.001,
+        'batch_size': 32,
+        'epochs': 10,
+        'model': 'resnet18-finetune'
+    }
+    logging.info(f'Using device={device} 🚀')
+    # everything starts with the data
+    train_dl, val_dl, test_dl = get_dataloaders(
+        project.data_dir / "train",
+        project.data_dir / "val",
+        val_transform=val_transform,
+        train_transform=train_transform,
+        batch_size=params['batch_size'],
+        pin_memory=True,
+        num_workers=4,
+    )
+    # it is always good practice to visualise some of the train and val images to be sure data-aug
+    # is applied properly
+    show_dl(train_dl)
+    show_dl(test_dl)
+    # define our comet experiment
+    experiment = Experiment(api_key="YOUR_KEY",
+                            project_name="dl-pytorch-template", workspace="francescosaveriozuppichini")
+    experiment.log_parameters(params)
+    # create our special resnet18
+    cnn = resnet18(2).to(device)
+    # print the model summary to show useful information
+    logging.info(summary(cnn, (3, 224, 224)))
+    # define custom optimizer and instantiate the trainer `Model`
+    optimizer = optim.Adam(cnn.parameters(), lr=params['lr'])
+    model = Model(cnn, optimizer, "cross_entropy",
+                  batch_metrics=["accuracy"]).to(device)
+    # usually you want to reduce the lr on plateau and store the best model
+    callbacks = [
+        ReduceLROnPlateau(monitor="val_acc", patience=5, verbose=True),
+        ModelCheckpoint(str(project.checkpoint_dir /
+                            f"{time.time()}-model.pt"), save_best_only=True, verbose=True),
+        EarlyStopping(monitor="val_acc", patience=10, mode='max'),
+        CometCallback(experiment)
+    ]
+    model.fit_generator(
+        train_dl,
+        val_dl,
+        epochs=params['epochs'],
+        callbacks=callbacks,
+    )
+    # get the results on the test set
+    loss, test_acc = model.evaluate_generator(test_dl)
+    logging.info(f'test_acc=({test_acc})')
+    experiment.log_metric('test_acc', test_acc)
 
 if __name__ == '__main__':
     # everything starts with the data
```
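The `CometCallback` imported at the top of this file is not part of the diff; a minimal sketch of what it might look like, assuming poutyne's Keras-like `Callback` base class with an `on_epoch_end(epoch, logs)` hook:

```python
# a sketch, not the repository's actual callbacks/__init__.py
from poutyne.framework import Callback


class CometCallback(Callback):
    """Pushes poutyne's per-epoch logs (loss, acc, val_loss, ...) to a Comet experiment."""

    def __init__(self, experiment):
        super().__init__()
        self.experiment = experiment

    def on_epoch_end(self, epoch, logs):
        # logs is a dict with the metrics poutyne tracked during this epoch
        self.experiment.log_metrics(logs, epoch=epoch)
```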

metrics/__init__.py

Lines changed: 22 additions & 0 deletions
```diff
@@ -0,0 +1,22 @@
+from poutyne.framework.metrics import EpochMetric
+
+# define a custom metric as a function
+def my_metric(y_true, y_pred):
+    pass
+
+# or as a class when we need to accumulate
+class MyEpochMetric(EpochMetric):
+    def forward(self, y_pred, y_true):
+        """
+        Defines the behavior of the metric when called.
+        Args:
+            y_pred: The prediction of the model.
+            y_true: Target to evaluate the model.
+        """
+        pass
+
+    def get_metric(self):
+        """
+        Compute and return the metric.
+        """
+        pass
```
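As a concrete example of the accumulating variant, here is a sketch of an epoch-level accuracy metric; it assumes the `EpochMetric` interface shown above plus a `reset()` hook that clears state between epochs:

```python
import torch
from poutyne.framework.metrics import EpochMetric


class EpochAccuracy(EpochMetric):
    """Accumulates correct predictions over the whole epoch and returns their ratio."""

    def __init__(self):
        super().__init__()
        self.correct, self.total = 0, 0

    def forward(self, y_pred, y_true):
        # called on every batch: accumulate instead of returning a value
        self.correct += (y_pred.argmax(dim=1) == y_true).sum().item()
        self.total += y_true.size(0)

    def get_metric(self):
        return self.correct / max(self.total, 1)

    def reset(self):
        self.correct, self.total = 0, 0
```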

secrets.json

Lines changed: 3 additions & 0 deletions
```diff
@@ -0,0 +1,3 @@
+{
+    "COMET_API_KEY" : "8THqoAxomFyzBgzkStlY95MOf"
+}
```
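The key can then be read at run time instead of being hardcoded in `main.py`; a sketch (the loading helper is illustrative, not part of this commit):

```python
# illustrative helper, not part of this commit
import json

from comet_ml import Experiment


def load_comet_key(path='secrets.json'):
    with open(path) as f:
        return json.load(f)['COMET_API_KEY']


experiment = Experiment(api_key=load_comet_key(),
                        project_name="dl-pytorch-template")
```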
