Running the command
    python train.py runoja.txt --n_epochs 5000 --model lstm --cuda
gives
Traceback (most recent call last):
  File "train_.py", line 95, in <module>
    loss = train(*random_training_set(args.chunk_len, args.batch_size))
  File "train_.py", line 55, in train
    hidden = hidden.cuda()
AttributeError: 'tuple' object has no attribute 'cuda'
Judging by https://github.com/spro/char-rnn.pytorch/blob/master/model.py#L38-L39, the hidden state is indeed a tuple when the model is an LSTM. I'm not sure what the correct fix is, but I tried this:
    if args.cuda:
        if args.model == "gru":
            hidden = hidden.cuda()
        else:
            hidden = (hidden[0].cuda(), hidden[1].cuda())
and it appears to work. The same problem also exists in generate.py, though.
This is what I used there:
    if isinstance(hidden, tuple):
        hidden = (hidden[0].cuda(), hidden[1].cuda())
    else:
        hidden = hidden.cuda()
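Since train.py and generate.py both need the same check, a small shared helper might be cleaner than repeating it. This is only a sketch: `hidden_to_cuda` is a name I made up (it is not in the repo), and it assumes the hidden state is either a single tensor (GRU) or a tuple of tensors (LSTM).

```python
def hidden_to_cuda(hidden):
    """Move an RNN hidden state to the GPU.

    Hypothetical helper, not part of the repo. Assumes `hidden` is either
    a single tensor (GRU) or a tuple of tensors (LSTM: hidden and cell state).
    """
    if isinstance(hidden, tuple):
        # LSTM case: move each element of the tuple separately.
        return tuple(hidden_to_cuda(h) for h in hidden)
    # GRU case: a single tensor, which has a .cuda() method.
    return hidden.cuda()
```

Both scripts could then use `hidden = hidden_to_cuda(hidden)` inside their existing CUDA branch, without checking `args.model` at all.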