27 changes: 19 additions & 8 deletions tensordict/_torch_func.py
@@ -928,16 +928,27 @@ def _grad(
             "torch.autograd.grad for TensorDict only supports TensorDictBase as grad_output"
         )
 
-    if grad_outputs is not None:
-        tup_grad_outputs = tuple(
-            grad_outputs._values_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS)
+    using_unstacked_outputs = (
+        isinstance(outputs, LazyStackedTensorDict) and outputs._has_exclusive_keys
+    )
+    try:
+        if not using_unstacked_outputs:
+            tup_outputs = tuple(outputs[k] for k in outputs.keys(True, True))
+    except RuntimeError:
+        # Cannot stack outputs
+        using_unstacked_outputs = True
+
+    if using_unstacked_outputs:
+        tup_outputs = outputs.values(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS)
+
+    if grad_outputs is None:
+        tup_grad_outputs = None
+    elif using_unstacked_outputs:
+        tup_grad_outputs = grad_outputs.values(
+            True, True, is_leaf=_NESTED_TENSORS_AS_LISTS
         )
     else:
-        tup_grad_outputs = None
-
-    tup_outputs = tuple(
-        outputs._values_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS)
-    )
+        tup_grad_outputs = tuple(grad_outputs[k] for k in outputs.keys(True, True))
 
     keys, all_inputs = inputs._items_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS)
 
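Note on the change above: a LazyStackedTensorDict with exclusive keys cannot yield one dense tensor per key, so _grad now detects that case up front and hands the leaves to autograd unstacked; the try/except additionally catches stacks whose entries have mismatched shapes, where outputs[k] raises a RuntimeError. A minimal sketch of the exclusive-keys scenario, mirroring test_autograd_grad_exclusive_keys added below (only the public TensorDict/LazyStackedTensorDict API is assumed):

import torch

from tensordict import LazyStackedTensorDict, TensorDict

# Exclusive keys: "a" lives only in the first entry and "b" only in the
# second, so the lazy stack cannot produce a single dense tensor per key.
inputs = LazyStackedTensorDict(
    TensorDict(a=torch.randn(3, requires_grad=True)),
    TensorDict(b=torch.randn(2, requires_grad=True)),
    stack_dim=0,
)
outputs = inputs + 1
grad_outputs = LazyStackedTensorDict(
    TensorDict(a=torch.ones(3)), TensorDict(b=torch.ones(2)), stack_dim=0
)
# With the fallback above, the leaves reach torch.autograd.grad as a flat,
# unstacked sequence instead of being stacked per key (which would raise).
grads = torch.autograd.grad(outputs, inputs, grad_outputs)
assert (grads == 1).all()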
83 changes: 83 additions & 0 deletions test/test_tensordict.py
@@ -10608,6 +10608,89 @@ def test_update_batch_size(self, source_is_lazy):
         assert td.batch_size == (2, 4)
         assert td.batch_size == td2.batch_size
 
+    @pytest.mark.parametrize(
+        "inputs,grad_outputs",
+        [
+            (
+                TensorDict(a=torch.randn(2, 3, requires_grad=True)),
+                LazyStackedTensorDict(
+                    TensorDict(a=torch.ones(3)),
+                    TensorDict(a=torch.ones(3)),
+                    stack_dim=0,
+                ),
+            ),
+            (
+                LazyStackedTensorDict(
+                    TensorDict(a=torch.randn(3, requires_grad=True)),
+                    TensorDict(a=torch.randn(3, requires_grad=True)),
+                    stack_dim=0,
+                ),
+                TensorDict(a=torch.ones(2, 3)),
+            ),
+            (
+                LazyStackedTensorDict(
+                    TensorDict(a=torch.randn(3, requires_grad=True)),
+                    TensorDict(a=torch.randn(3, requires_grad=True)),
+                    stack_dim=0,
+                ),
+                LazyStackedTensorDict(
+                    TensorDict(a=torch.ones(2)),
+                    TensorDict(a=torch.ones(2)),
+                    TensorDict(a=torch.ones(2)),
+                    stack_dim=1,
+                ),
+            ),
+        ],
+    )
+    def test_autograd_grad_mixed_types(self, inputs, grad_outputs):
+        outputs = inputs + 1
+        grads = torch.autograd.grad(outputs, inputs, grad_outputs)
+        assert (grads == 1).all()
+
+    def test_autograd_grad_non_stackable(self):
+        inputs = LazyStackedTensorDict(
+            TensorDict(a=torch.randn(3, requires_grad=True)),
+            TensorDict(a=torch.randn(2, requires_grad=True)),
+            stack_dim=0,
+        )
+        grad_outputs = LazyStackedTensorDict(
+            TensorDict(a=torch.ones(3)), TensorDict(a=torch.ones(2)), stack_dim=0
+        )
+        outputs = inputs + 1
+        grads = torch.autograd.grad(outputs, inputs, grad_outputs)
+        assert (grads == 1).all()
+
+    def test_autograd_grad_key_order(self):
+        inputs = LazyStackedTensorDict(
+            TensorDict(a=torch.randn(3, requires_grad=True)),
+            TensorDict(a=torch.randn(2, requires_grad=True)),
+            stack_dim=0,
+        )
+        inputs["b"] = [
+            torch.randn(3, requires_grad=True),
+            torch.randn(2, requires_grad=True),
+        ]
+        grad_outputs = LazyStackedTensorDict(
+            TensorDict(b=torch.ones(3)), TensorDict(b=torch.ones(2)), stack_dim=0
+        )
+        grad_outputs["a"] = [torch.ones(3), torch.ones(2)]
+        outputs = inputs + 1
+        grads = torch.autograd.grad(outputs, inputs, grad_outputs)
+        assert (grads == 1).all()
+
+    def test_autograd_grad_exclusive_keys(self):
+        inputs = LazyStackedTensorDict(
+            TensorDict(a=torch.randn(3, requires_grad=True)),
+            TensorDict(b=torch.randn(2, requires_grad=True)),
+            stack_dim=0,
+        )
+        grad_outputs = LazyStackedTensorDict(
+            TensorDict(a=torch.ones(3)), TensorDict(b=torch.ones(2)), stack_dim=0
+        )
+        outputs = inputs + 1
+        grads = torch.autograd.grad(outputs, inputs, grad_outputs)
+        assert (grads == 1).all()
+
 
 @pytest.mark.skipif(
     not _has_torchsnapshot, reason=f"torchsnapshot not found: err={TORCHSNAPSHOT_ERR}"
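For reference, test_autograd_grad_key_order above pins down that gradients are paired with outputs by key rather than by insertion order; on the stackable path this is explicit in _grad, which builds tup_grad_outputs from grad_outputs[k] for k in outputs.keys(True, True). A minimal sketch of that property on plain TensorDicts (an illustrative example, not one of the PR's tests):

import torch

from tensordict import TensorDict

# Keys are inserted in opposite orders; gradients must still be paired
# by key name, not by the position in which keys are enumerated.
inputs = TensorDict(
    a=torch.randn(3, requires_grad=True),
    b=torch.randn(2, requires_grad=True),
)
grad_outputs = TensorDict(b=torch.ones(2), a=torch.ones(3))
outputs = inputs + 1
grads = torch.autograd.grad(outputs, inputs, grad_outputs)
assert (grads == 1).all()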