38 changes: 29 additions & 9 deletions pymc3/backends/base.py
@@ -4,12 +4,16 @@
creating custom backends).
"""
import itertools as itl
import logging

import numpy as np
import warnings
import theano.tensor as tt

from ..model import modelcontext
from .report import SamplerReport, merge_reports

logger = logging.getLogger('pymc3')


class BackendError(Exception):
@@ -61,6 +65,10 @@ def __init__(self, name, model=None, vars=None, test_point=None):
self.chain = None
self._is_base_setup = False
self.sampler_vars = None
self._warnings = []

def _add_warnings(self, warnings):
self._warnings.extend(warnings)

# Sampling methods

@@ -174,7 +182,7 @@ def get_sampler_stats(self, varname, sampler_idx=None, burn=0, thin=1):
return self._get_sampler_stats(varname, sampler_idx, burn, thin)

sampler_idxs = [i for i, s in enumerate(self.sampler_vars)
if varname in s]
if varname in s]
if not sampler_idxs:
raise KeyError("Unknown sampler stat %s" % varname)

@@ -185,20 +193,19 @@ def get_sampler_stats(self, varname, sampler_idx=None, burn=0, thin=1):
else:
return vals


def _get_sampler_stats(self, varname, sampler_idx, burn, thin):
"""Get sampler statistics."""
raise NotImplementedError()

def _slice(self, idx):
"""Slice trace object."""
raise NotImplementedError
raise NotImplementedError()

def point(self, idx):
"""Return dictionary of point values at `idx` for current chain
with variables names as keys.
"""
raise NotImplementedError
raise NotImplementedError()

@property
def stat_names(self):
@@ -258,6 +265,11 @@ def __init__(self, straces):
raise ValueError("Chains are not unique.")
self._straces[strace.chain] = strace

self._report = SamplerReport()
for strace in straces:
if hasattr(strace, '_warnings'):
self._report._add_warnings(strace._warnings, strace.chain)

def __repr__(self):
template = '<{}: {} chains, {} iterations, {} variables>'
return template.format(self.__class__.__name__,
@@ -271,6 +283,10 @@ def nchains(self):
def chains(self):
return list(sorted(self._straces.keys()))

@property
def report(self):
return self._report

def __getitem__(self, idx):
if isinstance(idx, slice):
return self._slice(idx)
@@ -303,7 +319,7 @@ def __getitem__(self, idx):
raise KeyError("Unknown variable %s" % var)

_attrs = set(['_straces', 'varnames', 'chains', 'stat_names',
'supports_sampler_stats'])
'supports_sampler_stats', '_report'])

def __getattr__(self, name):
# Avoid infinite recursion when called before __init__
@@ -447,10 +463,13 @@ def get_sampler_stats(self, varname, burn=0, thin=1, combine=True,
for chain in chains]
return _squeeze_cat(results, combine, squeeze)

def _slice(self, idx):
"""Return a new MultiTrace object sliced according to `idx`."""
new_traces = [trace._slice(idx) for trace in self._straces.values()]
return MultiTrace(new_traces)
def _slice(self, slice):
"""Return a new MultiTrace object sliced according to `slice`."""
new_traces = [trace._slice(slice) for trace in self._straces.values()]
trace = MultiTrace(new_traces)
idxs = slice.indices(len(self))
trace._report = self._report._slice(*idxs)
return trace

def point(self, idx, chain=None):
"""Return a dictionary of point values at `idx`.
@@ -502,6 +521,7 @@ def merge_traces(mtraces):
if new_chain in base_mtrace._straces:
raise ValueError("Chains are not unique.")
base_mtrace._straces[new_chain] = strace
base_mtrace.report = merge_reports([trace.report for trace in mtraces])
return base_mtrace


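A minimal sketch of how the new report property on MultiTrace might be used after sampling; the model, the variable name mu, and the draw counts are illustrative, not part of this PR:

import pymc3 as pm

with pm.Model():
    mu = pm.Normal('mu', mu=0, sd=1)
    trace = pm.sample(1000, tune=500)

# Warnings collected by each strace are merged into the SamplerReport.
print(trace.report.ok)          # True only if no warning reaches 'warn' level
trace.report.raise_ok('error')  # raises ValueError on 'error'-level warnings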
5 changes: 3 additions & 2 deletions pymc3/backends/ndarray.py
@@ -116,8 +116,9 @@ def close(self):
self.samples = {var: vtrace[:self.draw_idx]
for var, vtrace in self.samples.items()}
if self._stats is not None:
self._stats = [{var: trace[:self.draw_idx] for var, trace in stats.items()}
for stats in self._stats]
self._stats = [
{var: trace[:self.draw_idx] for var, trace in stats.items()}
for stats in self._stats]

# Selection methods

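The reflowed close() block above is behavior-preserving; a toy illustration of the truncation it performs, with assumed values:

import numpy as np

draw_idx = 3
_stats = [{'depth': np.arange(10), 'energy': np.arange(10.0)}]
_stats = [
    {var: trace[:draw_idx] for var, trace in stats.items()}
    for stats in _stats]
# Each sampler-stat array is now cut down to the draws actually taken.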
162 changes: 162 additions & 0 deletions pymc3/backends/report.py
@@ -0,0 +1,162 @@
from collections import namedtuple
import logging
import enum


logger = logging.getLogger('pymc3')


@enum.unique
class WarningType(enum.Enum):
# For HMC and NUTS
DIVERGENCE = 1
TUNING_DIVERGENCE = 2
DIVERGENCES = 3
TREEDEPTH = 4
# Problematic sampler parameters
BAD_PARAMS = 5
# Indications that chains did not converge, eg Rhat
CONVERGENCE = 6
BAD_ACCEPTANCE = 7


SamplerWarning = namedtuple(
'SamplerWarning',
"kind, message, level, step, exec_info, extra")


_LEVELS = {
'info': logging.INFO,
'error': logging.ERROR,
'warn': logging.WARN,
'debug': logging.DEBUG,
'critical': logging.CRITICAL,
}


class SamplerReport(object):
def __init__(self):
self._chain_warnings = {}
self._global_warnings = []
self._effective_n = None
self._gelman_rubin = None

@property
def _warnings(self):
chains = sum(self._chain_warnings.values(), [])
return chains + self._global_warnings

@property
def ok(self):
"""Whether the automatic convergence checks found serious problems."""
return all(_LEVELS[warn.level] < _LEVELS['warn']
for warn in self._warnings)

def raise_ok(self, level='error'):
errors = [warn for warn in self._warnings
if _LEVELS[warn.level] >= _LEVELS[level]]
if errors:
raise ValueError('Serious convergence issues during sampling.')

def _run_convergence_checks(self, trace):
if trace.nchains == 1:
msg = ("Only one chain was sampled, this makes it impossible to "
"run some convergence checks")
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info',
None, None, None)
self._add_warnings([warn])
return

from pymc3 import diagnostics

self._effective_n = effective_n = diagnostics.effective_n(trace)
self._gelman_rubin = gelman_rubin = diagnostics.gelman_rubin(trace)

warnings = []
rhat_max = max(val.max() for val in gelman_rubin.values())
if rhat_max > 1.4:
msg = ("The gelman-rubin statistic is larger than 1.4 for some "
"parameters. The sampler did not converge.")
warn = SamplerWarning(
WarningType.CONVERGENCE, msg, 'error', None, None, gelman_rubin)
warnings.append(warn)
elif rhat_max > 1.2:
msg = ("The gelman-rubin statistic is larger than 1.2 for some "
"parameters.")
warn = SamplerWarning(
WarningType.CONVERGENCE, msg, 'warn', None, None, gelman_rubin)
warnings.append(warn)
elif rhat_max > 1.05:
msg = ("The gelman-rubin statistic is larger than 1.05 for some "
"parameters. This indicates slight problems during "
"sampling.")
warn = SamplerWarning(
WarningType.CONVERGENCE, msg, 'info', None, None, gelman_rubin)
warnings.append(warn)

eff_min = min(val.min() for val in effective_n.values())
n_samples = len(trace) * trace.nchains
if eff_min < 200 and n_samples >= 500:
msg = ("The estimated number of effective samples is smaller than "
"200 for some parameters.")
warn = SamplerWarning(
WarningType.CONVERGENCE, msg, 'error', None, None, effective_n)
warnings.append(warn)
elif eff_min / n_samples < 0.25:
msg = ("The number of effective samples is smaller than "
"25% for some parameters.")
warn = SamplerWarning(
WarningType.CONVERGENCE, msg, 'warn', None, None, effective_n)
warnings.append(warn)

self._add_warnings(warnings)

def _add_warnings(self, warnings, chain=None):
if chain is None:
warn_list = self._global_warnings
else:
warn_list = self._chain_warnings.setdefault(chain, [])
warn_list.extend(warnings)

def _log_summary(self):

def log_warning(warn):
level = _LEVELS[warn.level]
logger.log(level, warn.message)

for chain, warns in self._chain_warnings.items():
for warn in warns:
log_warning(warn)
for warn in self._global_warnings:
log_warning(warn)

def _slice(self, start, stop, step):
report = SamplerReport()

def filter_warns(warnings):
filtered = []
for warn in warnings:
if warn.step is None:
filtered.append(warn)
elif (start <= warn.step < stop and
(warn.step - start) % step == 0):
warn = warn._replace(step=warn.step - start)
filtered.append(warn)
return filtered

report._add_warnings(filter_warns(self._global_warnings))
for chain in self._chain_warnings:
report._add_warnings(
filter_warns(self._chain_warnings[chain]),
chain)

return report


def merge_reports(reports):
report = SamplerReport()
for rep in reports:
report._add_warnings(rep._global_warnings)
for chain in rep._chain_warnings:
report._add_warnings(rep._chain_warnings[chain], chain)
return report
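A sketch of the SamplerReport mechanics defined above; the divergence message and the step index are invented for illustration:

from pymc3.backends.report import (SamplerReport, SamplerWarning,
                                   WarningType, merge_reports)

report = SamplerReport()
warn = SamplerWarning(WarningType.DIVERGENCE, 'divergence after tuning',
                      'warn', 120, None, None)
report._add_warnings([warn], chain=0)

print(report.ok)  # False: a 'warn'-level warning is present

# _slice keeps step-indexed warnings inside [start, stop) and re-bases them.
sliced = report._slice(100, 200, 1)
print(sliced._chain_warnings[0][0].step)  # 20

# merge_reports pools per-chain and global warnings from several reports.
merged = merge_reports([report, sliced])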
4 changes: 2 additions & 2 deletions pymc3/diagnostics.py
@@ -250,8 +250,8 @@ def get_neff(x, Vhat):
if t % 2:
t -= 1

return min(num_chains * num_samples,
int(num_chains * num_samples / (1. + 2 * rho[1:t-1].sum())))
neff = num_chains * num_samples / (1. + 2 * rho[1:t-1].sum())
return min(num_chains * num_samples, np.floor(neff))

def generate_neff(trace_values):
x = np.array(trace_values)
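For reference, get_neff implements the capped estimate n_eff = min(m*n, floor(m*n / (1 + 2 * sum_t rho_t))) for m chains of n samples; a sketch of the public wrappers that _run_convergence_checks calls, assuming a trace obtained from pm.sample:

from pymc3 import diagnostics

neff = diagnostics.effective_n(trace)   # {varname: effective sample size}
rhat = diagnostics.gelman_rubin(trace)  # {varname: potential scale reduction}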
14 changes: 7 additions & 7 deletions pymc3/model.py
@@ -10,7 +10,7 @@
from theano import theano, tensor as tt
from theano.tensor.var import TensorVariable

from pymc3.theanof import set_theano_conf
from pymc3.theanof import set_theano_conf, floatX
import pymc3 as pm
from pymc3.math import flatten_list
from .memoize import memoize, WithMemoization
@@ -1061,13 +1061,13 @@ def _get_scaling(total_size, shape, ndim):
scalar
"""
if total_size is None:
coef = pm.floatX(1)
coef = floatX(1)
elif isinstance(total_size, int):
if ndim >= 1:
denom = shape[0]
else:
denom = 1
coef = pm.floatX(total_size) / pm.floatX(denom)
coef = floatX(total_size) / floatX(denom)
elif isinstance(total_size, (list, tuple)):
if not all(isinstance(i, int) for i in total_size if (i is not Ellipsis and i is not None)):
raise TypeError('Unrecognized `total_size` type, expected '
@@ -1085,20 +1085,20 @@ def _get_scaling(total_size, shape, ndim):
raise ValueError('Length of `total_size` is too big, '
'number of scalings is bigger that ndim, got %r' % total_size)
elif (len(begin) + len(end)) == 0:
return pm.floatX(1)
return floatX(1)
if len(end) > 0:
shp_end = shape[-len(end):]
else:
shp_end = np.asarray([])
shp_begin = shape[:len(begin)]
begin_coef = [pm.floatX(t) / shp_begin[i] for i, t in enumerate(begin) if t is not None]
end_coef = [pm.floatX(t) / shp_end[i] for i, t in enumerate(end) if t is not None]
begin_coef = [floatX(t) / shp_begin[i] for i, t in enumerate(begin) if t is not None]
end_coef = [floatX(t) / shp_end[i] for i, t in enumerate(end) if t is not None]
coefs = begin_coef + end_coef
coef = tt.prod(coefs)
else:
raise TypeError('Unrecognized `total_size` type, expected '
'int or list of ints, got %r' % total_size)
return tt.as_tensor(pm.floatX(coef))
return tt.as_tensor(floatX(coef))


class FreeRV(Factor, TensorVariable):
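A minimal sketch of what _get_scaling's total_size handling means in practice; the dataset and batch sizes are illustrative:

import numpy as np
import pymc3 as pm

batch = np.random.randn(100)
with pm.Model():
    # The logp of x is scaled by coef = floatX(total_size) / floatX(batch size),
    # here 10000 / 100 = 100, so the minibatch stands in for the full dataset.
    x = pm.Normal('x', mu=0, sd=1, observed=batch, total_size=10000)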