22 "cells" : [
33 {
44 "cell_type" : " code" ,
5- "execution_count" : 1 ,
5+ "execution_count" : 9 ,
66 "metadata" : {},
77 "outputs" : [
88 {
99 "name" : " stdout" ,
1010 "output_type" : " stream" ,
11- "text" : [
12- " MONAI version: 0.0.1\n " ,
13- " Python version: 3.7.3 (default, Mar 27 2019, 22:11:17) [GCC 7.3.0]\n " ,
14- " Numpy version: 1.16.4\n " ,
15- " Pytorch version: 1.3.1\n " ,
16- " Ignite version: 0.2.1\n "
17- ]
11+ "text" : " MONAI version: 0.0.1\n Python version: 3.8.1 (default, Jan 8 2020, 22:29:32) [GCC 7.3.0]\n Numpy version: 1.18.1\n Pytorch version: 1.4.0\n Ignite version: 0.3.0\n "
1812 }
1913 ],
2014 "source" : [
2822 " import torch\n " ,
2923 " import torch.nn as nn\n " ,
3024 " from torch.utils.data import DataLoader\n " ,
31- " import monai.data.transforms.compose as transforms\n " ,
3225 " \n " ,
3326 " import numpy as np\n " ,
3427 " import matplotlib.pyplot as plt\n " ,
3528 " import nibabel as nib\n " ,
3629 " \n " ,
37- " from ignite.engine import Events, create_supervised_trainer\n " ,
38- " from ignite.handlers import ModelCheckpoint\n " ,
30+ " from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator \n " ,
31+ " from ignite.handlers import ModelCheckpoint, EarlyStopping \n " ,
3932 " \n " ,
4033 " # assumes the framework is found here, change as necessary\n " ,
4134 " sys.path.append(\" ..\" )\n " ,
4235 " \n " ,
36+ " \n " ,
37+ " import monai.data.transforms.compose as transforms\n " ,
4338 " from monai import application, data, networks, utils\n " ,
4439 " from monai.data.readers import NiftiDataset\n " ,
4540 " from monai.data.transforms import AddChannel, Transpose, Rescale, ToTensor, UniformRandomPatch, GridPatchDataset\n " ,
41+ " from monai.networks.metrics.mean_dice import MeanDice\n " ,
42+ " from monai.utils.stopperutils import stopping_fn_from_metric\n " ,
4643 " \n " ,
4744 " \n " ,
4845 " application.config.print_config()"
4946 ]
5047 },
5148 {
5249 "cell_type" : " code" ,
53- "execution_count" : 2 ,
50+ "execution_count" : 10 ,
5451 "metadata" : {},
5552 "outputs" : [],
5653 "source" : [
8178 },
8279 {
8380 "cell_type" : " code" ,
84- "execution_count" : 3 ,
81+ "execution_count" : 11 ,
8582 "metadata" : {},
8683 "outputs" : [],
8784 "source" : [
9996 },
10097 {
10198 "cell_type" : " code" ,
102- "execution_count" : 4 ,
99+ "execution_count" : 12 ,
103100 "metadata" : {},
104101 "outputs" : [
105102 {
106103 "name" : " stdout" ,
107104 "output_type" : " stream" ,
108- "text" : [
109- " torch.Size([10, 1, 64, 64, 64]) torch.Size([10, 1, 64, 64, 64])\n "
110- ]
105+ "text" : " torch.Size([10, 1, 64, 64, 64]) torch.Size([10, 1, 64, 64, 64])\n "
111106 }
112107 ],
113108 "source" : [
136131 },
137132 {
138133 "cell_type" : " code" ,
139- "execution_count" : 5 ,
134+ "execution_count" : 13 ,
140135 "metadata" : {},
141136 "outputs" : [],
142137 "source" : [
157152 },
158153 {
159154 "cell_type" : " code" ,
160- "execution_count" : 7 ,
155+ "execution_count" : 14 ,
161156 "metadata" : {},
162- "outputs" : [
163- {
164- "name" : " stdout" ,
165- "output_type" : " stream" ,
166- "text" : [
167- " Epoch 1 Loss: 0.8619852662086487\n " ,
168- " Epoch 2 Loss: 0.8307779431343079\n " ,
169- " Epoch 3 Loss: 0.8064168691635132\n " ,
170- " Epoch 4 Loss: 0.7981672883033752\n " ,
171- " Epoch 5 Loss: 0.7950631976127625\n " ,
172- " Epoch 6 Loss: 0.7949732542037964\n " ,
173- " Epoch 7 Loss: 0.7963427901268005\n " ,
174- " Epoch 8 Loss: 0.7939450144767761\n " ,
175- " Epoch 9 Loss: 0.7926643490791321\n " ,
176- " Epoch 10 Loss: 0.7911991477012634\n " ,
177- " Epoch 11 Loss: 0.7886414527893066\n " ,
178- " Epoch 12 Loss: 0.7867528796195984\n " ,
179- " Epoch 13 Loss: 0.7857398390769958\n " ,
180- " Epoch 14 Loss: 0.7833380699157715\n " ,
181- " Epoch 15 Loss: 0.7791398763656616\n " ,
182- " Epoch 16 Loss: 0.7720394730567932\n " ,
183- " Epoch 17 Loss: 0.7671006917953491\n " ,
184- " Epoch 18 Loss: 0.7646064758300781\n " ,
185- " Epoch 19 Loss: 0.7672612071037292\n " ,
186- " Epoch 20 Loss: 0.7600041627883911\n " ,
187- " Epoch 21 Loss: 0.7583478689193726\n " ,
188- " Epoch 22 Loss: 0.7571365833282471\n " ,
189- " Epoch 23 Loss: 0.7545363306999207\n " ,
190- " Epoch 24 Loss: 0.7499511241912842\n " ,
191- " Epoch 25 Loss: 0.7481640577316284\n " ,
192- " Epoch 26 Loss: 0.7469437122344971\n " ,
193- " Epoch 27 Loss: 0.7460543513298035\n " ,
194- " Epoch 28 Loss: 0.74577796459198\n " ,
195- " Epoch 29 Loss: 0.7429620027542114\n " ,
196- " Epoch 30 Loss: 0.7424858808517456\n "
197- ]
198- }
199- ],
157+ "outputs" : [],
200158 "source" : [
201159 " trainEpochs = 30\n " ,
202160 " \n " ,
218176 " \n " ,
219177 " \n " ,
220178 " loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available())\n " ,
221- " \n " ,
179+ " val_loader = DataLoader(ds, batch_size=20, num_workers=8, pin_memory=torch.cuda.is_available())\n " ,
180+ " \n "
181+ ]
182+ },
183+ {
184+ "cell_type" : " code" ,
185+ "execution_count" : 15 ,
186+ "metadata" : {},
187+ "outputs" : [],
188+ "source" : [
189+ " validation_every_n_epochs = 1\n " ,
190+ " \n " ,
191+ " val_metrics = {'Mean Dice': MeanDice(add_sigmoid=True)}\n " ,
192+ " evaluator = create_supervised_evaluator(net, val_metrics, device, True,\n " ,
193+ " output_transform=lambda x, y, y_pred: (y_pred[0], y))\n " ,
194+ " \n " ,
195+ " \n " ,
196+ " early_stopper = EarlyStopping(patience=4, \n " ,
197+ " score_function=stopping_fn_from_metric('Mean Dice'),\n " ,
198+ " trainer=trainer)\n " ,
199+ " evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)\n " ,
200+ " \n " ,
201+ " @evaluator.on(Events.EPOCH_COMPLETED)\n " ,
202+ " def log_validation_metrics(engine):\n " ,
203+ " for name, value in engine.state.metrics.items():\n " ,
204+ " print(\" Validation --\" , name, \" :\" , value)\n " ,
205+ " \n " ,
206+ " @trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))\n " ,
207+ " def run_validation(engine):\n " ,
208+ " evaluator.run(val_loader)\n " ,
209+ " \n "
210+ ]
211+ },
212+ {
213+ "cell_type" : " code" ,
214+ "execution_count" : 16 ,
215+ "metadata" : {},
216+ "outputs" : [
217+ {
218+ "name" : " stdout" ,
219+ "output_type" : " stream" ,
220+ "text": "Epoch 1 Loss: 0.8975554704666138\nValidation -- Mean Dice : 0.11846490800380707\nEpoch 2 Loss: 0.8451039791107178\nValidation -- Mean Dice : 0.12091563045978546\nEpoch 3 Loss: 0.9355515241622925\nValidation -- Mean Dice : 0.12139833569526673\nEpoch 4 Loss: 0.843208909034729\nValidation -- Mean Dice : 0.12108306288719177\nEpoch 5 Loss: 0.8225834965705872\nValidation -- Mean Dice : 0.12179622799158096\nEpoch 6 Loss: 0.957372784614563\nValidation -- Mean Dice : 0.12193384170532226\nEpoch 7 Loss: 0.9011092782020569\nValidation -- Mean Dice : 0.1230143740773201\nEpoch 8 Loss: 0.8651387691497803\nValidation -- Mean Dice : 0.1254110112786293\nEpoch 9 Loss: 0.8767974972724915\nValidation -- Mean Dice : 0.12633273899555206\nEpoch 10 Loss: 0.8193061947822571\nValidation -- Mean Dice : 0.12657881826162337\nEpoch 11 Loss: 0.9466649293899536\nValidation -- Mean Dice : 0.12699378579854964\nEpoch 12 Loss: 0.8258659243583679\nValidation -- Mean Dice : 0.12790720015764237\nEpoch 13 Loss: 0.8661612868309021\nValidation -- Mean Dice : 0.12980296313762665\nEpoch 14 Loss: 0.8039132356643677\nValidation -- Mean Dice : 0.1311295285820961\nEpoch 15 Loss: 0.8050084114074707\nValidation -- Mean Dice : 0.13225494623184203\nEpoch 16 Loss: 0.9048625230789185\nValidation -- Mean Dice : 0.1330576255917549\nEpoch 17 Loss: 0.9179995656013489\nValidation -- Mean Dice : 0.13361359685659407\nEpoch 18 Loss: 0.8956605195999146\nValidation -- Mean Dice : 0.13432369381189346\nEpoch 19 Loss: 0.8029189705848694\nValidation -- Mean Dice : 0.13532216250896453\nEpoch 20 Loss: 0.8359838128089905\nValidation -- Mean Dice : 0.13622953295707702\nEpoch 21 Loss: 0.9225850105285645\nValidation -- Mean Dice : 0.13677610754966735\nEpoch 22 Loss: 0.7023072242736816\nValidation -- Mean Dice : 0.13693425357341765\nEpoch 23 Loss: 0.8776397705078125\nValidation -- Mean Dice : 0.13710424304008484\nEpoch 24 Loss: 0.9571539163589478\nValidation -- Mean Dice : 0.1370883911848068\nEpoch 25 Loss: 0.8877002596855164\nValidation -- Mean Dice : 0.13701471388339997\nEpoch 26 Loss: 0.817417562007904\nValidation -- Mean Dice : 0.13696834743022918\nEpoch 27 Loss: 0.8971314430236816\nValidation -- Mean Dice : 0.1371448516845703\nEpoch 28 Loss: 0.9443905353546143\nValidation -- Mean Dice : 0.13739778995513915\nEpoch 29 Loss: 0.7578094005584717\nValidation -- Mean Dice : 0.137495020031929\nEpoch 30 Loss: 0.7037953734397888\nValidation -- Mean Dice : 0.13759489357471466\n"
221+ }
222+ ],
223+ "source" : [
222224 " state = trainer.run(loader, trainEpochs)"
223225 ]
224226 }
225227 ],
226228 "metadata" : {
227229 "kernelspec" : {
228- "display_name" : " Python 3" ,
230+ "display_name" : " Python 3.7.5 64-bit ('pytorch': conda) " ,
229231 "language" : " python" ,
230- "name" : " python3 "
232+ "name" : " python37564bitpytorchconda9e7dd2186ac2430b947ee08d8eff35b4 "
231233 },
232234 "language_info" : {
233235 "codemirror_mode" : {
239241 "name" : " python" ,
240242 "nbconvert_exporter" : " python" ,
241243 "pygments_lexer" : " ipython3" ,
242- "version" : " 3.7.3 "
244+ "version" : " 3.8.1-final "
243245 }
244246 },
245247 "nbformat" : 4 ,
246248 "nbformat_minor" : 4
247- }
249+ }
0 commit comments