Skip to content

Commit 5e12c5d

Browse files
madil90wyli
andauthored
40-validation-early-stop (#73)
* Adding MVP example for Validation every epoch with Early Stopping. * Add every_n_epochs and update python example. * Fix flake8 errors. * Update example to work with StatsHandler. Co-authored-by: Wenqi Li <[email protected]>
1 parent d5e610f commit 5e12c5d

File tree

3 files changed

+120
-69
lines changed

3 files changed

+120
-69
lines changed

examples/unet_segmentation_3d.ipynb

Lines changed: 64 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,13 @@
22
"cells": [
33
{
44
"cell_type": "code",
5-
"execution_count": 1,
5+
"execution_count": 9,
66
"metadata": {},
77
"outputs": [
88
{
99
"name": "stdout",
1010
"output_type": "stream",
11-
"text": [
12-
"MONAI version: 0.0.1\n",
13-
"Python version: 3.7.3 (default, Mar 27 2019, 22:11:17) [GCC 7.3.0]\n",
14-
"Numpy version: 1.16.4\n",
15-
"Pytorch version: 1.3.1\n",
16-
"Ignite version: 0.2.1\n"
17-
]
11+
"text": "MONAI version: 0.0.1\nPython version: 3.8.1 (default, Jan 8 2020, 22:29:32) [GCC 7.3.0]\nNumpy version: 1.18.1\nPytorch version: 1.4.0\nIgnite version: 0.3.0\n"
1812
}
1913
],
2014
"source": [
@@ -28,29 +22,32 @@
2822
"import torch\n",
2923
"import torch.nn as nn\n",
3024
"from torch.utils.data import DataLoader\n",
31-
"import monai.data.transforms.compose as transforms\n",
3225
"\n",
3326
"import numpy as np\n",
3427
"import matplotlib.pyplot as plt\n",
3528
"import nibabel as nib\n",
3629
"\n",
37-
"from ignite.engine import Events, create_supervised_trainer\n",
38-
"from ignite.handlers import ModelCheckpoint\n",
30+
"from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator\n",
31+
"from ignite.handlers import ModelCheckpoint, EarlyStopping\n",
3932
"\n",
4033
"# assumes the framework is found here, change as necessary\n",
4134
"sys.path.append(\"..\")\n",
4235
"\n",
36+
"\n",
37+
"import monai.data.transforms.compose as transforms\n",
4338
"from monai import application, data, networks, utils\n",
4439
"from monai.data.readers import NiftiDataset\n",
4540
"from monai.data.transforms import AddChannel, Transpose, Rescale, ToTensor, UniformRandomPatch, GridPatchDataset\n",
41+
"from monai.networks.metrics.mean_dice import MeanDice\n",
42+
"from monai.utils.stopperutils import stopping_fn_from_metric\n",
4643
"\n",
4744
"\n",
4845
"application.config.print_config()"
4946
]
5047
},
5148
{
5249
"cell_type": "code",
53-
"execution_count": 2,
50+
"execution_count": 10,
5451
"metadata": {},
5552
"outputs": [],
5653
"source": [
@@ -81,7 +78,7 @@
8178
},
8279
{
8380
"cell_type": "code",
84-
"execution_count": 3,
81+
"execution_count": 11,
8582
"metadata": {},
8683
"outputs": [],
8784
"source": [
@@ -99,15 +96,13 @@
9996
},
10097
{
10198
"cell_type": "code",
102-
"execution_count": 4,
99+
"execution_count": 12,
103100
"metadata": {},
104101
"outputs": [
105102
{
106103
"name": "stdout",
107104
"output_type": "stream",
108-
"text": [
109-
"torch.Size([10, 1, 64, 64, 64]) torch.Size([10, 1, 64, 64, 64])\n"
110-
]
105+
"text": "torch.Size([10, 1, 64, 64, 64]) torch.Size([10, 1, 64, 64, 64])\n"
111106
}
112107
],
113108
"source": [
@@ -136,7 +131,7 @@
136131
},
137132
{
138133
"cell_type": "code",
139-
"execution_count": 5,
134+
"execution_count": 13,
140135
"metadata": {},
141136
"outputs": [],
142137
"source": [
@@ -157,46 +152,9 @@
157152
},
158153
{
159154
"cell_type": "code",
160-
"execution_count": 7,
155+
"execution_count": 14,
161156
"metadata": {},
162-
"outputs": [
163-
{
164-
"name": "stdout",
165-
"output_type": "stream",
166-
"text": [
167-
"Epoch 1 Loss: 0.8619852662086487\n",
168-
"Epoch 2 Loss: 0.8307779431343079\n",
169-
"Epoch 3 Loss: 0.8064168691635132\n",
170-
"Epoch 4 Loss: 0.7981672883033752\n",
171-
"Epoch 5 Loss: 0.7950631976127625\n",
172-
"Epoch 6 Loss: 0.7949732542037964\n",
173-
"Epoch 7 Loss: 0.7963427901268005\n",
174-
"Epoch 8 Loss: 0.7939450144767761\n",
175-
"Epoch 9 Loss: 0.7926643490791321\n",
176-
"Epoch 10 Loss: 0.7911991477012634\n",
177-
"Epoch 11 Loss: 0.7886414527893066\n",
178-
"Epoch 12 Loss: 0.7867528796195984\n",
179-
"Epoch 13 Loss: 0.7857398390769958\n",
180-
"Epoch 14 Loss: 0.7833380699157715\n",
181-
"Epoch 15 Loss: 0.7791398763656616\n",
182-
"Epoch 16 Loss: 0.7720394730567932\n",
183-
"Epoch 17 Loss: 0.7671006917953491\n",
184-
"Epoch 18 Loss: 0.7646064758300781\n",
185-
"Epoch 19 Loss: 0.7672612071037292\n",
186-
"Epoch 20 Loss: 0.7600041627883911\n",
187-
"Epoch 21 Loss: 0.7583478689193726\n",
188-
"Epoch 22 Loss: 0.7571365833282471\n",
189-
"Epoch 23 Loss: 0.7545363306999207\n",
190-
"Epoch 24 Loss: 0.7499511241912842\n",
191-
"Epoch 25 Loss: 0.7481640577316284\n",
192-
"Epoch 26 Loss: 0.7469437122344971\n",
193-
"Epoch 27 Loss: 0.7460543513298035\n",
194-
"Epoch 28 Loss: 0.74577796459198\n",
195-
"Epoch 29 Loss: 0.7429620027542114\n",
196-
"Epoch 30 Loss: 0.7424858808517456\n"
197-
]
198-
}
199-
],
157+
"outputs": [],
200158
"source": [
201159
"trainEpochs = 30\n",
202160
"\n",
@@ -218,16 +176,60 @@
218176
"\n",
219177
"\n",
220178
"loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available())\n",
221-
" \n",
179+
"val_loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available())\n",
180+
" \n"
181+
]
182+
},
183+
{
184+
"cell_type": "code",
185+
"execution_count": 15,
186+
"metadata": {},
187+
"outputs": [],
188+
"source": [
189+
"validation_every_n_epochs = 1\n",
190+
"\n",
191+
"val_metrics = {'Mean Dice': MeanDice(add_sigmoid=True)}\n",
192+
"evaluator = create_supervised_evaluator(net, val_metrics, device, True,\n",
193+
" output_transform=lambda x, y, y_pred: (y_pred[0], y))\n",
194+
"\n",
195+
"\n",
196+
"early_stopper = EarlyStopping(patience=4, \n",
197+
" score_function=stopping_fn_from_metric('Mean Dice'),\n",
198+
" trainer=trainer)\n",
199+
"evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)\n",
200+
"\n",
201+
"@evaluator.on(Events.EPOCH_COMPLETED)\n",
202+
"def log_validation_metrics(engine):\n",
203+
" for name, value in engine.state.metrics.items():\n",
204+
" print(\"Validation --\", name, \":\", value)\n",
205+
"\n",
206+
"@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))\n",
207+
"def run_validation(engine):\n",
208+
" evaluator.run(val_loader)\n",
209+
"\n"
210+
]
211+
},
212+
{
213+
"cell_type": "code",
214+
"execution_count": 16,
215+
"metadata": {},
216+
"outputs": [
217+
{
218+
"name": "stdout",
219+
"output_type": "stream",
220+
"text": "Epoch 1 Loss: 0.8975554704666138\nValidation -- Mean Dice : 0.11846490800380707\nEpoch 2 Loss: 0.8451039791107178\nValidation -- Mean Dice : 0.12091563045978546\nEpoch 3 Loss: 0.9355515241622925\nValidation -- Mean Dice : 0.12139833569526673\nEpoch 4 Loss: 0.843208909034729\nValidation -- Mean Dice : 0.12108306288719177\nEpoch 5 Loss: 0.8225834965705872\nValidation -- Mean Dice : 0.12179622799158096\nEpoch 6 Loss: 0.957372784614563\nValidation -- Mean Dice : 0.12193384170532226\nEpoch 7 Loss: 0.9011092782020569\nValidation -- Mean Dice : 0.1230143740773201\nEpoch 8 Loss: 0.8651387691497803\nValidation -- Mean Dice : 0.1254110112786293\nEpoch 9 Loss: 0.8767974972724915\nValidation -- Mean Dice : 0.12633273899555206\nEpoch 10 Loss: 0.8193061947822571\nValidation -- Mean Dice : 0.12657881826162337\nEpoch 11 Loss: 0.9466649293899536\nValidation -- Mean Dice : 0.12699378579854964\nEpoch 12 Loss: 0.8258659243583679\nValidation -- Mean Dice : 0.12790720015764237\nEpoch 13 Loss: 0.8661612868309021\nValidation -- Mean Dice : 0.12980296313762665\nEpoch 14 Loss: 0.8039132356643677\nValidation -- Mean Dice : 0.1311295285820961\nEpoch 15 Loss: 0.8050084114074707\nValidation -- Mean Dice : 0.13225494623184203\nEpoch 16 Loss: 0.9048625230789185\nValidation -- Mean Dice : 0.1330576255917549\nEpoch 17 Loss: 0.9179995656013489\nValidation -- Mean Dice : 0.13361359685659407\nEpoch 18 Loss: 0.8956605195999146\nValidation -- Mean Dice : 0.13432369381189346\nEpoch 19 Loss: 0.8029189705848694\nValidation -- Mean Dice : 0.13532216250896453\nEpoch 20 Loss: 0.8359838128089905\nValidation -- Mean Dice : 0.13622953295707702\nEpoch 21 Loss: 0.9225850105285645\nValidation -- Mean Dice : 0.13677610754966735\nEpoch 22 Loss: 0.7023072242736816\nValidation -- Mean Dice : 0.13693425357341765\nEpoch 23 Loss: 0.8776397705078125\nValidation -- Mean Dice : 0.13710424304008484\nEpoch 24 Loss: 0.9571539163589478\nValidation -- Mean Dice : 0.1370883911848068\nEpoch 25 Loss: 
0.8877002596855164\nValidation -- Mean Dice : 0.13701471388339997\nEpoch 26 Loss: 0.817417562007904\nValidation -- Mean Dice : 0.13696834743022918\nEpoch 27 Loss: 0.8971314430236816\nValidation -- Mean Dice : 0.1371448516845703\nEpoch 28 Loss: 0.9443905353546143\nValidation -- Mean Dice : 0.13739778995513915\nEpoch 29 Loss: 0.7578094005584717\nValidation -- Mean Dice : 0.137495020031929\nEpoch 30 Loss: 0.7037953734397888\nValidation -- Mean Dice : 0.13759489357471466\n"
221+
}
222+
],
223+
"source": [
222224
"state = trainer.run(loader, trainEpochs)"
223225
]
224226
}
225227
],
226228
"metadata": {
227229
"kernelspec": {
228-
"display_name": "Python 3",
230+
"display_name": "Python 3.7.5 64-bit ('pytorch': conda)",
229231
"language": "python",
230-
"name": "python3"
232+
"name": "python37564bitpytorchconda9e7dd2186ac2430b947ee08d8eff35b4"
231233
},
232234
"language_info": {
233235
"codemirror_mode": {
@@ -239,9 +241,9 @@
239241
"name": "python",
240242
"nbconvert_exporter": "python",
241243
"pygments_lexer": "ipython3",
242-
"version": "3.7.3"
244+
"version": "3.8.1-final"
243245
}
244246
},
245247
"nbformat": 4,
246248
"nbformat_minor": 4
247-
}
249+
}

examples/unet_segmentation_3d.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
import torch
2121
import monai.transforms.compose as transforms
2222
from torch.utils.tensorboard import SummaryWriter
23-
from ignite.engine import Events, create_supervised_trainer
24-
from ignite.handlers import ModelCheckpoint
23+
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
24+
from ignite.handlers import ModelCheckpoint, EarlyStopping
2525
from torch.utils.data import DataLoader
2626

2727
import monai
@@ -32,12 +32,14 @@
3232
from monai.handlers.mean_dice import MeanDice
3333
from monai.visualize import img2tensorboard
3434
from monai.data.synthetic import create_test_image_3d
35+
from monai.handlers.utils import stopping_fn_from_metric
3536

3637
# assumes the framework is found here, change as necessary
3738
sys.path.append("..")
3839

3940
config.print_config()
4041

42+
# Create a temporary directory and 50 random image, mask pairs
4143
tempdir = tempfile.mkdtemp()
4244

4345
for i in range(50):
@@ -52,18 +54,20 @@
5254
images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
5355
segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
5456

57+
# Define transforms for image and segmentation
5558
imtrans = transforms.Compose([Rescale(), AddChannel(), UniformRandomPatch((64, 64, 64)), ToTensor()])
56-
5759
segtrans = transforms.Compose([AddChannel(), UniformRandomPatch((64, 64, 64)), ToTensor()])
5860

61+
# Define nifti dataset, dataloader.
5962
ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans)
60-
6163
loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
6264
im, seg = monai.utils.misc.first(loader)
6365
print(im.shape, seg.shape)
6466

67+
6568
lr = 1e-3
6669

70+
# Create UNet, DiceLoss and Adam optimizer.
6771
net = monai.networks.nets.UNet(
6872
dimensions=3,
6973
in_channels=1,
@@ -78,13 +82,12 @@
7882

7983
train_epochs = 3
8084

81-
85+
# Since network outputs logits and segmentation, we need a custom function.
8286
def _loss_fn(i, j):
8387
return loss(i[0], j)
8488

85-
89+
# Create trainer
8690
device = torch.device("cuda:0")
87-
8891
trainer = create_supervised_trainer(net, opt, _loss_fn, device, False,
8992
output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y])
9093

@@ -133,6 +136,28 @@ def log_training_loss(engine):
133136

134137

135138
loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available())
139+
val_loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available())
136140
writer = SummaryWriter()
137141

142+
# Define mean dice metric and Evaluator.
143+
validation_every_n_epochs = 1
144+
145+
val_metrics = {'Mean Dice': MeanDice(add_sigmoid=True)}
146+
evaluator = create_supervised_evaluator(net, val_metrics, device, True,
147+
output_transform=lambda x, y, y_pred: (y_pred[0], y))
148+
149+
val_stats_handler = StatsHandler()
150+
val_stats_handler.attach(evaluator)
151+
152+
# Add early stopping handler to evaluator.
153+
early_stopper = EarlyStopping(patience=4,
154+
score_function=stopping_fn_from_metric('Mean Dice'),
155+
trainer=trainer)
156+
evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
157+
158+
@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
159+
def run_validation(engine):
160+
evaluator.run(val_loader)
161+
162+
138163
state = trainer.run(loader, train_epochs)

monai/handlers/utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright 2020 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
13+
def stopping_fn_from_metric(metric_name):
    """Build a score function for ignite.handlers.EarlyStopping from a named metric.

    Args:
        metric_name: key looked up in ``engine.state.metrics`` each time the
            returned function is called.

    Returns:
        A callable that takes an engine and returns
        ``engine.state.metrics[metric_name]``.
    """
    def _metric_score(engine):
        # The lookup happens at call time, so the value reflects whatever the
        # engine's metrics hold when the handler fires.
        metrics = engine.state.metrics
        return metrics[metric_name]

    return _metric_score
18+
19+
20+
def stopping_fn_from_loss(output_transform=None):
    """Returns a stopping function for ignite.handlers.EarlyStopping using the loss value.

    Args:
        output_transform: optional callable applied to ``engine.state.output``
            to extract the loss before it is negated. With the default
            ``None`` the output itself is taken to be the loss, so existing
            callers are unaffected. Supply a transform when the engine's
            output is a container (for example a list that holds the loss
            alongside predictions) rather than a bare number.

    Returns:
        A callable that takes an engine and returns the negated loss.
    """
    def stopping_fn(engine):
        output = engine.state.output
        loss = output if output_transform is None else output_transform(output)
        # Negated so that a decreasing loss yields an increasing score.
        return -loss

    return stopping_fn

0 commit comments

Comments
 (0)