Source code for maxent_disaggregation.shares

"""
This module provides functions for sampling from Dirichlet and generalized Dirichlet distributions,
as well as hybrid approaches, given specified means (shares) and standard deviations (sds) for the shares.
It supports maximum entropy Dirichlet sampling, bias correction, and robust handling of edge cases such as
missing or partially specified parameters.

Functions
---------
- generalized_dirichlet(n, shares, sds):
    Generate random samples from a Generalised Dirichlet distribution with given shares and standard deviations.

- dirichlet_max_ent(n, shares, **kwargs):
    Generate samples from a Dirichlet distribution with maximum entropy given input shares.

- sample_shares(n, shares, sds=None, grad_based=False, threshold_shares=0.1, threshold_sd=0.2, **kwargs):
    This is the main function which handles all the different cases and samples from a distribution of
    shares based on given means and standard deviations, using the appropriate distribution or a hybrid
    approach depending on the completeness of the input information.

- hybrid_dirichlet(shares, size=None, sds=None, max_rel_bias=0.10, max_iter_bias_fix=20, max_iter_beta_sampling=1e3, **kwargs):
    Sample shares in the case of partial mean and sd information using a hybrid Dirichlet distribution with
    iterative bias correction.

- sample_dirichlet(shares, size=None, gamma_par=None, threshold_dirichlet=0.01, force_nonzero_samples=True, **kwargs):
    Wrapper to sample from a Dirichlet distribution with given shares and gamma concentration parameter,
    with pragmatic handling of small shape parameters to avoid numerical issues.

- check_sample_means_and_sds(sample, shares, sds, threshold_shares=0.1, threshold_sd=0.2, suppress_warnings=False):
    Check if the sample means and standard deviations deviate more than the specified thresholds from the
    specified shares and standard deviations, raising warnings if so.

- sample_from_beta(n, shares, sds, fix=True, max_iter=1e3):
    Generate random samples from independent Beta distributions with specified means and standard deviations,
    ensuring that the sum of samples across columns does not exceed 1 for each row.

- The module is robust to missing or partially specified input parameters, using uniform priors or hybrid
  approaches as needed.
- Warnings are raised if input parameters are inconsistent or if generated samples deviate significantly
  from specified means or standard deviations.
- The module is intended for probabilistic modeling of compositional data, such as branching ratios or shares
  that sum to one.

"""

import warnings
import numpy as np
from scipy.stats import gamma
import scipy.stats as stats
from .maxent_direchlet import find_gamma_maxent, dirichlet_entropy

# set the warning filter to always show warnings
warnings.simplefilter("always", UserWarning)


[docs] def generalized_dirichlet(n, shares, sds, seed=None): """ Generate random samples from a Generalised Dirichlet distribution with given shares and standard deviations. Reference: ---------------- Plessis, Sylvain, Nathalie Carrasco, and Pascal Pernot. “Knowledge-Based Probabilistic Representations of Branching Ratios in Chemical Networks: The Case of Dissociative Recombinations.” The Journal of Chemical Physics 133, no. 13 (October 7, 2010): 134110. https://doi.org/10.1063/1.3479907. Parameters: ------------ n (int): Number of samples to generate. shares (array-like): best-guess (mean) values for the shares. Must sum to 1!y. sds (array-like): Array of standard deviations for the shares. Returns: ------------ tuple: A tuple containing: - sample (ndarray): An array of shape (n, lentgh(shares)) containing the generated samples. - None: Placeholder for compatibility with other functions (always returns None). """ shares = np.asarray(shares) sds = np.asarray(sds) if not np.isclose(shares.sum(), 1): raise ValueError("The shares must sum to 1. Please check your input values.") if not np.all(np.isfinite(sds)): raise ValueError( "The standard deviations must be finite. Please check your input values." ) if shares.shape != sds.shape: raise ValueError( "The shares and standard deviations must have the same shape. Please check your input values." ) if np.any(sds < 0): raise ValueError( "The standard deviations must be non-negative. Please check your input values." ) if np.any(shares < 0): raise ValueError( "The shares must be non-negative. Please check your input values." ) if np.any(shares > 1): raise ValueError( "The shares must be less than or equal to 1. Please check your input values." ) alpha2 = (shares / sds) ** 2 beta2 = shares / (sds) ** 2 k = len(alpha2) rng = np.random.default_rng(seed) x = np.zeros((n, k)) for i in range(k): x[:, i] = gamma.rvs(alpha2[i], scale=1 / beta2[i], size=n, random_state=rng) sample = x / x.sum(axis=1, keepdims=True) return sample
[docs] def dirichlet_max_ent(n: int, shares: np.ndarray | list, seed=None, **kwargs): """ Generate samples from a Dirichlet distribution with maximum entropy. This function computes the gamma parameter that maximizes the entropy of the Dirichlet distribution given the input shares. It then generates `n` samples from the resulting Dirichlet distribution. Parameters: n (int): The number of samples to generate. shares (array-like): The input shares (probabilities) that define the Dirichlet distribution. **kwargs: Additional keyword arguments passed to the `find_gamma_maxent` function. Returns: tuple: A tuple containing: - sample (ndarray): An array of shape (n, len(shares)) containing the generated samples. - gamma_par (float): The computed gamma parameter that maximizes the entropy of the Dirichlet distribution. """ gamma_par = find_gamma_maxent(shares, eval_f=dirichlet_entropy, **kwargs) sample = sample_dirichlet(shares * gamma_par, size=n, seed=seed, **kwargs) return sample, gamma_par
[docs] def sample_shares( n: int, shares: np.ndarray | list, sds: np.ndarray | list = None, grad_based: bool = False, threshold_shares: float = 0.1, threshold_sd: float = 0.2, suppress_warnings: bool = False, seed: int = None, **kwargs, ): """ Samples from a distribution of shares based on given means and standard deviations. This function generates samples of shares using either a generalized Dirichlet distribution, a maximum entropy Dirichlet distribution, or a combination of both, depending on the availability of mean and standard deviation inputs. Parameters: ---------- n : int Number of samples to generate. shares : np.ndarray | list Array or list of mean values for the shares. These should sum to 1 if fully specified. sds : np.ndarray | list, optional Array or list of standard deviations for the shares. If not provided, defaults to NaN. grad_based : bool, optional Whether to use gradient-based optimization for maximum entropy Dirichlet sampling. Default is False. threshold_shares : float, optional Threshold for the relative difference between the sample mean and the specified shares. If the difference exceeds this threshold, a warning is raised. Default is 0.1 (10%). threshold_sd : float, optional Threshold for the relative difference between the sample standard deviation and the specified sds. If the difference exceeds this threshold, a warning is raised. Default is 0.2 (20%). suppress_warnings : bool, optional If True, suppress warnings about sample means and standard deviations deviating from the specified values. Default is False. seed : int, optional Random seed for reproducibility. Default is None. **kwargs : dict Additional keyword arguments passed to the underlying sampling functions. Returns: ------- sample : np.ndarray A 2D array of shape (n, K), where K is the number of shares, containing the sampled values. gamma_par : np.ndarray Parameters of the Dirichlet or generalized Dirichlet distribution used for sampling. Notes: ----- - If both means and standard deviations are provided for all shares, the generalized Dirichlet distribution is used. - If only means are provided, the maximum entropy Dirichlet distribution is used. - If no means are provided, a uniform Dirichlet distribution is used. - If a mix of known and unknown means/standard deviations is provided, a hierarchical approach is used to sample the shares (function called `hybrid_dirichlet`). - The function raises warnings if standard deviations are provided without corresponding mean values, as this is not recommended. """ gamma_par = None # set default value for gamma_par if sds is None: sds = np.full_like(shares, np.nan) shares = np.asarray(shares) sds = np.asarray(sds) K = len(shares) have_mean = np.isfinite(shares) have_sd = np.isfinite(sds) have_mean_only = np.isfinite(shares) & ~np.isfinite(sds) have_sd_only = np.isfinite(sds) & ~np.isfinite(shares) if np.sum(have_sd_only) > 0: warnings.warn( "You have standard deviations for shares without a best guess estimate. This is not recommended, please check your inputs. These will be treated as missing values and ignored for the calculation." ) have_both = np.isfinite(shares) & np.isfinite(sds) if np.all(have_both): # use generalized dirichlet sample = generalized_dirichlet(n, shares, sds, seed=seed) elif np.all(have_mean_only): # maximize entropy for dirichlet sample, gamma_par = dirichlet_max_ent( n, shares, grad_based=grad_based, seed=seed, **kwargs ) elif np.isfinite(shares).sum() == 0: # no information on the shares, use uniform dirichlet shares = np.asarray([1 / len(shares)] * len(shares)) # The maximum entropy concentration parameter for a uniform Dirichlet distribution is gammapar = K sample = sample_dirichlet(shares=shares, gamma_par=K, size=n, seed=seed) # break out because it does not need to check the means and sd's return sample, None else: # If we have a mix of known and unknown shares, we handle this case using the Hybrid Dirichlet logic. sample = hybrid_dirichlet(shares=shares, size=n, sds=sds, seed=seed) # check if sample means and standard deviations deviate more than the threshold values: check_sample_means_and_sds( sample, shares, sds, threshold_shares=threshold_shares, threshold_sd=threshold_sd, suppress_warnings=suppress_warnings, ) return sample, gamma_par
[docs] def hybrid_dirichlet( shares, size=None, sds=None, max_rel_bias=0.10, max_iter_bias_fix=20, max_iter_beta_sampling=1e3, seed=None, **kwargs, ): """ Function to sample in the case of partial mean and sd information using a hybrid Dirichlet distribution with iterative bias correction. Samples are generated from a combination of beta distributions for shares with both mean and sd, and a maximum-entropy Dirichlet distribution for shares with only mean values. This function iteratively adjusts the standard deviations of shares that exceed a specified relative bias threshold, ensuring that the final samples meet the desired accuracy in terms of relative bias. Parameters: ---------- shares : array-like Array of mean values for the shares. These should sum to 1 if fully specified. size : int Number of samples to generate. sds : array-like, optional Array of standard deviations for the shares. If not provided, defaults to NaN. max_rel_bias : float, optional Maximum relative bias allowed for the generated samples. Default is 0.10 (10%). max_iter_bias_fix : int, optional Maximum number of iterations for bias correction. Default is 20. max_iter_rbeta3 : int, optional Maximum number of iterations for beta sampling. Default is 1e3. **kwargs : dict Additional keyword arguments passed to the underlying sampling functions. Returns: ------- sample : np.ndarray A 2D array of shape (size, len(shares)) containing the sampled values. """ # make sure that shares and sds are numpy arrays and exist shares = np.asarray(shares) if sds is None: sds = np.full_like(shares, np.nan) sds = np.asarray(sds) K = len(shares) if np.isnan(shares).sum() > 0: # fill the unkown means with a uniform prior shares[np.isnan(shares)] = (1 - np.nansum(shares)) / np.isnan(shares).sum() rng = np.random.default_rng(seed) iter_count = 0 while iter_count < max_iter_bias_fix: iter_count += 1 sample = np.zeros((size, K)) have_both = np.isfinite(shares) & np.isfinite(sds) have_mean_only = np.isfinite(shares) & ~np.isfinite(sds) # Derive child seeds for sub-function calls seed_beta = int(rng.integers(0, 2**31)) seed_dirichlet = int(rng.integers(0, 2**31)) # 1) Components with mean *and* SD --> Beta-truncated sampling if np.sum(have_both) > 0: sample[:, have_both] = sample_from_beta( size, shares[have_both], sds[have_both], fix=True, max_iter=max_iter_beta_sampling, seed=seed_beta, ) # 2) Components with mean only --> MaxEnt Dirichlet (rescaled afterwards) if np.sum(have_mean_only) > 0: alpha2 = shares[have_mean_only] / np.sum(shares[have_mean_only]) sample_temp, _ = dirichlet_max_ent(size, alpha2, seed=seed_dirichlet, **kwargs) sample[:, have_mean_only] = sample_temp * ( 1 - sample[:, have_both].sum(axis=1, keepdims=True) ) # Calculate the relative bias for each share rel_bias = np.abs(sample.mean(axis=0) - shares) / shares # Check if the relative bias is within the allowed threshold if np.all(rel_bias <= max_rel_bias): return sample ## 3) Bias check on the “have_both” block if np.any(have_both): if np.all(rel_bias[have_both] <= max_rel_bias): return sample # Success – exit the loop # → At least one component is too far off: mark its SD as NaN sds[(have_both) & (rel_bias > max_rel_bias)] = np.nan warnings.warn( f"Relative bias exceeded {max_rel_bias} for component(s): " f"{np.where(rel_bias > max_rel_bias)[0]}. Their standard deviations have been set to NaN " "and will be handled with a Maximum-Entropy Dirichlet on the next iteration." ) else: # If no components with both mean and SD left, we can exit the loop return sample
[docs] def sample_dirichlet( shares, size=None, gamma_par=None, threshold_dirichlet=0.01, force_nonzero_samples=True, seed=None, **kwargs, ): """ A wrapper function to sample from a Dirichlet distribution with a given set of shares and gamma concentration parameter. It differs from the default Dirichlet distroibution in that when the For each variable i whose mean value (alpha_i = gamma_par * share_i) that is below a `threshold`, a fallback parametrization of the Gamma distribution (which is used for sampling from the Dirichlet distribution) is applied to avoid zero or near-zero sampling. This is especially useful for very small shape parameters, which can cause numerical issues in in the dirichlet sampling. The following pragmatic workaround is used that sets: - alpha_i = 1 (shape) for shares below `threshold` - rate = 1 / alpha_i ensuring less extreme values. For more details, see the discussion in [rgamma()] under "small shape values" and the references there. This approach helps mitigate issues where numeric precision can push small Gamma-distributed values to zero (see https://stat.ethz.ch/R-manual/R-devel/library/stats/html/GammaDist.html). Note however that fix changes the expectation values (means) of the sampled parameters such that they can deviate from the inputed shares. If this is undesired set force_nonzero_samples=False. Parameters: ----------- size : int The number of samples to generate. shares : array-like The input shares (probabilities) that define the Dirichlet distribution. gamma_par : float The gamma parameter that scales the shares for the Dirichlet distribution. threshold : float The threshold below which the shares are adjusted to avoid zero sampling. force_nonzero_samples : bool If True, forces non-zero samples by adjusting alphas and rate/scale within the gamma distribution. This may lead to biased means of the samples. If False, uses the original scipy implementation of the Dirichlet distribution. Note that in the case of very small alphas, this may lead to a large number zeros in the samples due to numerical issues. The means are unbiased though. Methods: -------- sample(): Generates samples from the Dirichlet distribution. """ if gamma_par is None: alpha = np.asarray(shares) else: alpha = np.asarray(shares) * gamma_par if not force_nonzero_samples: print("Using scipy dirichlet!!!!!") rng = np.random.default_rng(seed) return stats.dirichlet.rvs(alpha, size=size, random_state=rng) else: l = len(alpha) rate = np.ones(l) rate[alpha < threshold_dirichlet] = 1 / alpha[alpha < threshold_dirichlet] alpha[alpha < threshold_dirichlet] = 1 rng = np.random.default_rng(seed) x = gamma.rvs(alpha, scale=1 / rate, size=(size, l), random_state=rng) sample = x / x.sum(axis=1, keepdims=True) return sample
[docs] def check_sample_means_and_sds( sample, shares, sds, threshold_shares=0.1, threshold_sd=0.2, suppress_warnings=False, ): """ Check if the sample means and standard deviations deviate more than the specified thresholds from the specified shares and standard deviations. If they do, a warning is raised. Parameters: ---------- sample : np.ndarray The generated samples from the Dirichlet distribution. shares : np.ndarray The specified shares (mean values) for the Dirichlet distribution. sds : np.ndarray The specified standard deviations for the shares. threshold_shares : float, optional The threshold for the relative difference between the sample mean and the specified shares. If the difference exceeds this threshold, a warning is raised. Default is 0.1 (10%). threshold_sd : float, optional The threshold for the relative difference between the sample standard deviation and the specified sds. If the difference exceeds this threshold, a warning is raised. Default is 0.2 (20%). suppress_warnings : bool, optional If True, suppress all warnings from this function. Default is False. """ if not isinstance(sample, np.ndarray): raise TypeError("Sample must be a numpy array.") if not isinstance(shares, np.ndarray): shares = np.asarray(shares) if not isinstance(sds, np.ndarray): sds = np.asarray(sds) sample_mean = np.mean(sample, axis=0) sample_sd = np.std(sample, axis=0) diff_mean = np.abs(sample_mean - shares) / shares diff_sd = np.abs(sample_sd - sds) / sds means_above_threshold = diff_mean > threshold_shares indices_above_threshold = np.where(means_above_threshold)[0] if np.any(means_above_threshold) and not suppress_warnings: warnings.warn( f"The generated samples for the shares have a mean that is more than {threshold_shares*100}% different from the specified shares. " f"Please check your inputs. Reasons for this could be large relative uncertainties for the shares, or a small number of samples. " f"To suppress this warning you can set suppress_warnings=True or set a higher threshold_shares.\n" f"Shares above threshold: {diff_mean[means_above_threshold]}\n" f"Shares: {shares[means_above_threshold]}\n" f"Sample mean: {sample_mean[means_above_threshold]}\n" f"Indices: {indices_above_threshold}" ) sds_above_threshold = diff_sd > threshold_sd indices_above_threshold = np.where(sds_above_threshold)[0] if np.any(sds_above_threshold) and not suppress_warnings: warnings.warn( f"The generated samples for the shares have a standard deviation that is more than {threshold_sd*100}% different from the specified sd's. " f"Please note that the specified sd's might be incompatible with the other constraints. " f"Please check your inputs. To suppress this warning you can set suppress_warnings=True or set a higher threshold_sd.\n" f"Sds above threshold: {diff_sd[sds_above_threshold]}\n" f"Sds: {sds[sds_above_threshold]}\n" f"Sample sd: {sample_sd[sds_above_threshold]}\n" f"Indices: {indices_above_threshold}" )
[docs] def sample_from_beta(n, shares, sds, fix=True, max_iter=1e3, seed=None): """ Generate random samples from independent Beta distributions with specified means (shares) and standard deviations (sds), ensuring that the sum of samples across columns does not exceed 1 for each row. Parameters ---------- n : int Number of samples to generate. shares : array-like Array of mean values (between 0 and 1) for each Beta distribution. sds : array-like Array of standard deviations for each Beta distribution. fix : bool, optional (default=True) If True, automatically adjust invalid variance values to the maximum allowed for the given mean. If False, raise a ValueError when invalid parameter combinations are detected. max_iter : int, optional (default=1e3) Maximum number of iterations to attempt resampling rows where the sum exceeds 1. Returns ------- x : ndarray An (n, k) array of samples, where k is the length of `shares`, such that each row sums to less than or equal to 1. Raises ------ ValueError If the provided standard deviation is too large for the given mean (unless `fix=True`), or if a valid sample cannot be generated within `max_iter` iterations. Notes ----- The function ensures that for each sample (row), the sum across all Beta-distributed variables does not exceed 1 by resampling as needed. """ var = sds**2 undef_comb = (shares * (1 - shares)) < var if not np.all(~undef_comb): if fix: var[undef_comb] = shares[undef_comb] ** 2 else: raise ValueError( "The beta distribution is not defined for the parameter combination you provided! sd must be smaller or equal sqrt(shares*(1-shares))" ) alpha = shares * (((shares * (1 - shares)) / var) - 1) beta = (1 - shares) * (((shares * (1 - shares)) / var) - 1) rng = np.random.default_rng(seed) k = len(shares) x = np.zeros((n, k)) for i in range(k): x[:, i] = rng.beta(alpha[i], beta[i], size=n) larger_one = x.sum(axis=1) > 1 count = 0 while np.sum(larger_one) > 0: for i in range(k): x[larger_one, i] = rng.beta( alpha[i], beta[i], size=np.sum(larger_one) ) larger_one = x.sum(axis=1) > 1 count += 1 if count > max_iter: raise ValueError( "max_iter is reached. the combinations of shares and sds you provided does allow to generate `n` random samples that are not larger than 1. Either increase max_iter, or change parameter combination." ) return x