# standard library imports
from typing import TYPE_CHECKING, Optional
import time
import contextlib
# third party imports
import numpy as np
import dynesty
import arviz as az
from loguru import logger
from tabulate import tabulate
from scipy.stats import norm
# local imports
from probeye.subroutines import pretty_time_delta
from probeye.subroutines import check_for_uninformative_priors
from probeye.inference.scipy.solver import ScipySolver
from probeye.subroutines import stream_to_logger
from probeye.subroutines import extract_true_values
# imports only needed for type hints
if TYPE_CHECKING: # pragma: no cover
from probeye.definition.inverse_problem import InverseProblem
[docs]class DynestySolver(ScipySolver):
"""
A static and dynamic nested parameter estimator. It facilitates the use of the
python package `dynesty`. The default is set to a static parameter estimator.
_Note:_ For full details on the `dynesty` library see:
https://dynesty.readthedocs.io/en/latest/index.html.
Parameters
----------
problem
Describes the inverse problem including e.g. parameters and data.
seed
Random state used for random number generation.
show_progress
When True, the progress of a solver routine will be shown (for example as a
progress-bar) if such a feature is available. Otherwise, the progress will
not shown.
"""
def __init__(
self,
problem: "InverseProblem",
seed: Optional[int] = None,
show_progress: bool = True,
):
logger.debug(f"Initializing {self.__class__.__name__}")
# check that the problem does not contain a uninformative prior
check_for_uninformative_priors(problem)
# initialize the scipy-based solver (ScipySolver)
super().__init__(problem, seed=seed, show_progress=show_progress)
[docs] def get_summary(
self, posterior_samples: np.ndarray, true_values: Optional[dict] = None
) -> dict:
"""
Computes and prints a summary of the posterior samples containing mean, median,
standard deviation, 5th percentile and 95th percentile. Note, that this method
was based on code from the taralli package: https://gitlab.com/tno-bim/taralli.
Parameters
----------
posterior_samples
The generated samples in an array with as many columns as there are latent
parameters, and n rows, where n = n_chains * n_steps.
true_values
True parameter values, if known.
Returns
-------
Keys are the different statistics 'mean', 'median', 'sd' (standard
deviation), 'q05' and 'q95' (0.05- and 0.95-quantile). The values are
dictionaries with the parameter names as keys and the respective statistics
as values.
"""
# used for the names in the first column
var_names = self.problem.get_theta_names(tex=False, components=True)
# compute some stats for each column (i.e., each parameter)
mean = np.mean(posterior_samples, axis=0)
quantiles = np.quantile(posterior_samples, [0.50, 0.05, 0.95], axis=0)
median = quantiles[0, :]
quantile_05 = quantiles[1, :]
quantile_95 = quantiles[2, :]
# compute the sample standard deviations for each parameter
cov_matrix = np.atleast_2d(np.cov(posterior_samples.T))
sd = np.sqrt(np.diag(cov_matrix))
# assemble the summary array
if true_values:
col_names = ["", "true", "mean", "median", "sd", "5%", "95%"]
true = extract_true_values(true_values, var_names)
row_names = np.array(var_names)
tab = np.hstack(
(
row_names.reshape(-1, 1),
true.reshape(-1, 1),
mean.reshape(-1, 1),
median.reshape(-1, 1),
sd.reshape(-1, 1),
quantile_05.reshape(-1, 1),
quantile_95.reshape(-1, 1),
)
)
# print the generated table, and return a summary dict for later use
print(tabulate(tab, headers=col_names, floatfmt=".2f"))
return {
"true": {name: val for name, val in zip(row_names, true)},
"mean": {name: val for name, val in zip(row_names, mean)},
"median": {name: val for name, val in zip(row_names, median)},
"sd": {name: val for name, val in zip(row_names, sd)},
"q05": {name: val for name, val in zip(row_names, quantile_05)},
"q95": {name: val for name, val in zip(row_names, quantile_95)},
}
else:
col_names = ["", "mean", "median", "sd", "5%", "95%"]
row_names = np.array(var_names)
tab = np.hstack(
(
row_names.reshape(-1, 1),
mean.reshape(-1, 1),
median.reshape(-1, 1),
sd.reshape(-1, 1),
quantile_05.reshape(-1, 1),
quantile_95.reshape(-1, 1),
)
)
# print the generated table, and return a summary dict for later use
print(tabulate(tab, headers=col_names, floatfmt=".2f"))
return {
"mean": {name: val for name, val in zip(row_names, mean)},
"median": {name: val for name, val in zip(row_names, median)},
"sd": {name: val for name, val in zip(row_names, sd)},
"q05": {name: val for name, val in zip(row_names, quantile_05)},
"q95": {name: val for name, val in zip(row_names, quantile_95)},
}
[docs] def run(
self,
estimation_method: str = "dynamic",
nlive: int = 250,
true_values: Optional[dict] = None,
**kwargs,
) -> az.data.inference_data.InferenceData:
"""
Runs the dynesty-sampler for the InverseProblem the DynestySolver was
initialized with and returns the results as an arviz InferenceData obj.
Parameters
----------
estimation_method
"dynamic" or "static"
nlive
number of live points
true_values
True parameter values, if known.
kwargs
Additional key-word arguments channeled to emcee.EnsembleSampler.
Returns
-------
inference_data or dynesty sampler
Contains the results of the sampling procedure.
"""
start = time.time()
logger.info(
f"Solving problem using dynesty sampler with keyword arguments: {kwargs}"
)
# The following code is based on taralli and merely adjusted to the variables
# in the probeye setup; see https://gitlab.com/tno-bim/taralli
# ............................................................................ #
# Pre-process #
# ............................................................................ #
rstate = None
if self.seed is not None:
rstate = np.random.default_rng(self.seed)
if estimation_method == "dynamic":
sampler = dynesty.DynamicNestedSampler(
loglikelihood=self.loglike,
prior_transform=self.prior_transform,
ndim=self.problem.n_latent_prms_dim,
rstate=rstate,
nlive=nlive,
)
sampler.run_nested(print_progress=self.show_progress, **kwargs)
elif estimation_method == "static":
sampler = dynesty.NestedSampler(
loglikelihood=self.loglike,
prior_transform=self.prior_transform,
ndim=self.problem.n_latent_prms_dim,
rstate=rstate,
nlive=nlive,
)
sampler.run_nested(print_progress=self.show_progress, **kwargs)
else:
raise RuntimeError(
"Choose 'dynamic' or 'static' as the estimation_method parameter!"
)
end = time.time()
runtime_str = pretty_time_delta(end - start)
logger.info(f"Total run-time: {runtime_str}.")
logger.info("Resample weighted samples to equal samples for post")
logger.info("processing. Access the original dynesty results via .raw_results")
weights = np.exp(sampler.results.logwt - sampler.results.logz[-1])
samples = dynesty.utils.resample_equal(sampler.results.samples, weights)
logger.info("") # empty line for visual buffer
logger.info("Summary of sampling results (dynesty)")
with contextlib.redirect_stdout(stream_to_logger("INFO")): # type: ignore
self.summary = self.get_summary(samples, true_values=true_values)
logger.info("") # empty line for visual buffer
logger.info(f"Posterior log evidence: {sampler.results.logz[-1]}")
logger.info("") # empty line for visual buffer
var_names = self.problem.get_theta_names(tex=True, components=True)
posterior_dict = {name: samples[:, i] for i, name in enumerate(var_names)}
inference_data = az.convert_to_inference_data(posterior_dict)
# TODO maybe add information like "source = dynesty", timestamp, etc ...
self.raw_results = sampler.results
return inference_data