Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions pymc/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,11 +777,13 @@ def logp(
norm = 0.0

logp = _logprob(normal, (value,), None, None, None, mu, sigma) - norm
logp = at.switch(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually to be equivalent to what we had before we should do something like this:

if is_lower_bounded:
logp = at.switch(value < lower, -np.inf, logp)
if is_upper_bounded:
logp = at.switch(value <= upper, logp, -np.inf)

Copy link
Contributor Author

@adrn adrn Sep 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't that equivalent to what is implemented here because of the default values for lower and upper?

lower = at.as_tensor_variable(floatX(lower)) if lower is not None else at.constant(-np.inf)
upper = at.as_tensor_variable(floatX(upper)) if upper is not None else at.constant(np.inf)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We retrieve the None case here:

unbounded_lower = isinstance(lower, TensorConstant) and np.all(lower.value == -np.inf)
unbounded_upper = isinstance(upper, TensorConstant) and np.all(upper.value == np.inf)

So in those cases we avoid introducing the useless switch. It's a small optimization but I don't see any reason yo modify it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see. OK

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The last commit makes the implementation here more analogous to the general truncated case you linked above - thanks for that pointer!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @ricardoV94 for taking a look and helping out already! Let me know if the implementation in the latest few commits looks ok. Also, it looks like this PR is waiting for approval to run the full test suite.

at.or_(at.le(value, lower), at.ge(value, upper)),
-np.inf,
logp,
)

bounds = []
if not unbounded_lower:
bounds.append(value >= lower)
if not unbounded_upper:
bounds.append(value <= upper)
if not unbounded_lower and not unbounded_upper:
bounds.append(lower <= upper)
return check_parameters(logp, *bounds)
Expand Down
12 changes: 12 additions & 0 deletions pymc/tests/distributions/test_continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,18 @@ def test_lower_bounded(self):
assert lower_interval.value == -2
assert upper_interval is None

def test_bounded_value(self):
with pm.Model() as model:
dist = pm.TruncatedNormal(
'bounded_both', mu=1, sigma=2, lower=0, upper=3
)
value = pm.Deterministic('value', pm.logp(dist, [-2., 1., 4.]))

values = model.compile_fn(value)({})
assert np.isinf(values[0])
assert np.isfinite(values[1])
assert np.isinf(values[2])

def test_lower_bounded_vector(self):
bounded_rv_name = "upper_bounded"
with pm.Model() as model:
Expand Down