Source code for probeye.inference.emcee.solver

# standard library imports
from typing import TYPE_CHECKING, Optional
import time
import random
import contextlib

# third party imports
import numpy as np
import emcee
import arviz as az
from loguru import logger
from tabulate import tabulate

# local imports
from probeye.subroutines import pretty_time_delta
from probeye.subroutines import check_for_uninformative_priors
from probeye.inference.scipy.solver import ScipySolver
from probeye.subroutines import stream_to_logger
from probeye.subroutines import print_dict_in_rows
from probeye.subroutines import extract_true_values

# imports only needed for type hints
if TYPE_CHECKING:  # pragma: no cover
    from probeye.definition.inverse_problem import InverseProblem


class EmceeSolver(ScipySolver):
    """
    Provides emcee-sampler which is a pure-Python implementation of Goodman & Weare's
    Affine Invariant Markov chain Monte Carlo (MCMC) Ensemble sampler. For more
    information, check out https://emcee.readthedocs.io/en/stable/.

    Parameters
    ----------
    problem
        Describes the inverse problem including e.g. parameters and data.
    seed
        Random state used for random number generation.
    show_progress
        When True, the progress of a solver routine will be shown (for example as a
        progress-bar) if such a feature is available. Otherwise, the progress will
        not be shown.
    """

    def __init__(
        self,
        problem: "InverseProblem",
        seed: Optional[int] = None,
        show_progress: bool = True,
    ):
        logger.debug(f"Initializing {self.__class__.__name__}")
        # check that the problem does not contain an uninformative prior
        check_for_uninformative_priors(problem)
        # initialize the scipy-based solver (ScipySolver)
        super().__init__(problem, seed=seed, show_progress=show_progress)

    def emcee_summary(
        self, posterior_samples: np.ndarray, true_values: Optional[dict] = None
    ) -> dict:
        """
        Computes and prints a summary of the posterior samples containing mean,
        median, standard deviation, 5th percentile and 95th percentile. Note, that
        this method was based on code from the taralli package:
        https://gitlab.com/tno-bim/taralli.

        Parameters
        ----------
        posterior_samples
            The generated samples in an array with as many columns as there are
            latent parameters, and n rows, where n = n_chains * n_steps.
        true_values
            True parameter values, if known. When given (and not empty), an
            additional 'true' column/key is included in the summary.

        Returns
        -------
            Keys are the different statistics 'mean', 'median', 'sd' (standard
            deviation), 'q05' and 'q95' (0.05- and 0.95-quantile), preceded by
            'true' if true values were given. The values are dictionaries with the
            parameter names as keys and the respective statistics as values.
        """

        # used for the names in the first column
        var_names = self.problem.get_theta_names(tex=False, components=True)
        row_names = np.array(var_names)

        # compute some stats for each column (i.e., each parameter)
        mean = np.mean(posterior_samples, axis=0)
        quantiles = np.quantile(posterior_samples, [0.50, 0.05, 0.95], axis=0)
        median = quantiles[0, :]
        quantile_05 = quantiles[1, :]
        quantile_95 = quantiles[2, :]

        # compute the sample standard deviations for each parameter; np.cov returns
        # a 0-d array for a single parameter, hence the np.atleast_2d
        cov_matrix = np.atleast_2d(np.cov(posterior_samples.T))
        sd = np.sqrt(np.diag(cov_matrix))

        # statistics in the order in which they appear in the table and the dict
        stats = {
            "mean": mean,
            "median": median,
            "sd": sd,
            "q05": quantile_05,
            "q95": quantile_95,
        }
        col_names = ["", "mean", "median", "sd", "5%", "95%"]

        # the true values (if given) go directly after the parameter names
        if true_values:
            true = extract_true_values(true_values, var_names)
            stats = {"true": true, **stats}
            col_names.insert(1, "true")

        # assemble the summary array: one row per parameter, one column per stat
        tab = np.hstack(
            [row_names.reshape(-1, 1)]
            + [values.reshape(-1, 1) for values in stats.values()]
        )

        # print the generated table, and return a summary dict for later use
        print(tabulate(tab, headers=col_names, floatfmt=".2f"))
        return {key: dict(zip(row_names, values)) for key, values in stats.items()}

    def run(
        self,
        n_walkers: int = 20,
        n_steps: int = 1000,
        n_initial_steps: int = 100,
        true_values: Optional[dict] = None,
        **kwargs,
    ) -> az.data.inference_data.InferenceData:
        """
        Runs the emcee-sampler for the InverseProblem the EmceeSolver was
        initialized with and returns the results as an arviz InferenceData obj.

        Besides returning the results, this method stores the printed summary
        statistics in self.summary and the emcee sampler in self.raw_results.

        Parameters
        ----------
        n_walkers
            Number of walkers used by the estimator.
        n_steps
            Number of steps to run.
        n_initial_steps
            Number of steps for initial (burn-in) sampling.
        true_values
            True parameter values, if known.
        kwargs
            Additional key-word arguments channeled to emcee.EnsembleSampler.

        Returns
        -------
        inference_data
            Contains the results of the sampling procedure.
        """

        # log which solver is used
        logger.info(
            f"Solving problem using emcee sampler with {n_initial_steps} + {n_steps} "
            f"samples and {n_walkers} walkers"
        )
        if kwargs:
            logger.info("Additional options:")
            print_dict_in_rows(kwargs, printer=logger.info)
        else:
            logger.info("No additional options specified")

        # draw initial samples from the parameter's priors
        logger.debug("Drawing initial samples")
        if self.seed is not None:
            np.random.seed(self.seed)
        sampling_initial_positions = np.zeros(
            (n_walkers, self.problem.n_latent_prms_dim)
        )
        theta_names = self.problem.get_theta_names(tex=False, components=False)
        for parameter_name in theta_names:
            idx = self.problem.parameters[parameter_name].index
            idx_end = self.problem.parameters[parameter_name].index_end
            samples = self.sample_from_prior(parameter_name, n_walkers)
            if (idx_end - idx) == 1:
                sampling_initial_positions[:, idx] = samples
            else:
                sampling_initial_positions[:, idx:idx_end] = samples

        # The following code is based on taralli and merely adjusted to the variables
        # in the probeye setup; see https://gitlab.com/tno-bim/taralli

        # ............................................................................ #
        #                                 Pre-process                                  #
        # ............................................................................ #

        def logprob(x):
            # Skip loglikelihood evaluation if logprior is equal
            # to negative infinity
            logprior = self.logprior(x)
            if logprior == -np.inf:
                return logprior

            # Otherwise return logprior + loglikelihood
            return logprior + self.loglike(x)

        logger.debug("Setting up EnsembleSampler")
        sampler = emcee.EnsembleSampler(
            nwalkers=n_walkers,
            ndim=self.problem.n_latent_prms_dim,
            log_prob_fn=logprob,
            **kwargs,
        )

        if self.seed is not None:
            random.seed(self.seed)
            # emcee documents 'random_state' as the result of get_state() on a
            # RandomState and states that a failing assignment is ignored silently;
            # hence the state (not the RandomState object itself) must be assigned
            # for the seed to take effect
            sampler.random_state = np.random.mtrand.RandomState(self.seed).get_state()

        # ............................................................................ #
        #        Initial sampling, burn-in: used to avoid a poor starting point        #
        # ............................................................................ #

        logger.debug("Starting sampling (initial + main)")
        # perf_counter is not affected by adjustments of the system clock
        start = time.perf_counter()
        state = sampler.run_mcmc(
            initial_state=sampling_initial_positions,
            nsteps=n_initial_steps,
            progress=self.show_progress,
        )
        # discard the burn-in samples; the main run continues from 'state'
        sampler.reset()

        # ............................................................................ #
        #                          Sampling of the posterior                           #
        # ............................................................................ #
        sampler.run_mcmc(
            initial_state=state, nsteps=n_steps, progress=self.show_progress
        )
        end = time.perf_counter()
        runtime_str = pretty_time_delta(end - start)
        logger.info(
            f"Sampling of the posterior distribution completed: {n_steps} steps and "
            f"{n_walkers} walkers."
        )
        logger.info(f"Total run-time (including initial sampling): {runtime_str}.")
        logger.info("")
        logger.info("Summary of sampling results (emcee)")
        posterior_samples = sampler.get_chain(flat=True)
        # the summary table is printed to stdout, which is redirected to the logger
        with contextlib.redirect_stdout(stream_to_logger("INFO")):  # type: ignore
            self.summary = self.emcee_summary(
                posterior_samples, true_values=true_values
            )
        logger.info("")  # empty line for visual buffer
        self.raw_results = sampler

        # translate the results to a common data structure and return it
        var_names = self.problem.get_theta_names(tex=True, components=True)
        inference_data = az.from_emcee(sampler, var_names=var_names)

        return inference_data