Source code for choix.ep

import functools
from collections.abc import Callable, Sequence
from math import exp, log, pi, sqrt  # Faster than numpy equivalents.
from typing import Literal

import numpy as np
import numpy.random as nprand
from numpy.linalg import norm
from numpy.typing import NDArray
from scipy.special import logsumexp

from .typing import PairwiseData
from .utils import SQRT2, SQRT2PI, inv_posdef, normal_cdf

# EP-related settings.
THRESHOLD = 1e-4

MAT_ONE = np.array([[1.0, -1.0], [-1.0, 1.0]])
MAT_ONE_FLAT = MAT_ONE.ravel()


# Some magic constants for a stable computation of _log_phi(z).
CS = [
    0.00048204,
    -0.00142906,
    0.0013200243174,
    0.0009461589032,
    -0.0045563339802,
    0.00556964649138,
    0.00125993961762116,
    -0.01621575378835404,
    0.02629651521057465,
    -0.001829764677455021,
    2 * (1 - pi / 3),
    (4 - pi) / 3,
    1,
    1,
]
RS = [
    1.2753666447299659525,
    5.019049726784267463450,
    6.1602098531096305441,
    7.409740605964741794425,
    2.9788656263939928886,
]
QS = [
    2.260528520767326969592,
    9.3960340162350541504,
    12.048951927855129036034,
    17.081440747466004316,
    9.608965327192787870698,
    3.3690752069827527677,
]


[docs] def ep_pairwise( n_items: int, data: PairwiseData, alpha: float, model: Literal["logit", "probit"] = "logit", max_iter: int = 100, initial_state: tuple | None = None, ) -> tuple[NDArray[np.float64], NDArray[np.float64]]: """Compute a distribution of model parameters using the EP algorithm. This function computes an approximate Bayesian posterior probability distribution over model parameters, given pairwise-comparison data (see :ref:`data-pairwise`). It uses the expectation propagation algorithm, as presented, e.g., in [CG05]_. The prior distribution is assumed to be isotropic Gaussian with variance ``1 / alpha``. The posterior is approximated by a a general multivariate Gaussian distribution, described by a mean vector and a covariance matrix. Two different observation models are available. ``logit`` (default) assumes that pairwise-comparison outcomes follow from a Bradley-Terry model. ``probit`` assumes that the outcomes follow from Thurstone's model. Parameters ---------- n_items : int Number of distinct items. data : list of lists Pairwise-comparison data. alpha : float Inverse variance of the (isotropic) prior. model : str, optional Observation model. Either "logit" or "probit". max_iter : int, optional Maximum number of iterations allowed. initial_state : tuple of array_like, optional Natural parameters used to initialize the EP algorithm. Returns ------- mean : numpy.ndarray The mean vector of the approximate Gaussian posterior. cov : numpy.ndarray The covariance matrix of the approximate Gaussian posterior. Raises ------ ValueError If the observation model is not "logit" or "probit". """ if model == "logit": match_moments = _match_moments_logit elif model == "probit": match_moments = _match_moments_probit else: raise ValueError("unknown model '{}'".format(model)) return _ep_pairwise(n_items, data, alpha, match_moments, max_iter, initial_state)
def _ep_pairwise( n_items: int, comparisons: PairwiseData, alpha: float, match_moments: Callable[[float, float], tuple[float, float, float]], max_iter: int, initial_state: tuple | None, ) -> tuple[NDArray[np.float64], NDArray[np.float64]]: """Compute a distribution of model parameters using the EP algorithm. Raises ------ RuntimeError If the algorithm does not converge after ``max_iter`` iterations. """ # Static variable that allows to check the # of iterations after the call. _ep_pairwise.iterations = 0 # pyright: ignore[reportFunctionMemberAccess] m = len(comparisons) prior_inv = alpha * np.eye(n_items) if initial_state is None: # Initially, mean and covariance come from the prior. mean = np.zeros(n_items) cov = (1 / alpha) * np.eye(n_items) # Initialize the natural params in the function space. tau = np.zeros(m) nu = np.zeros(m) # Initialize the natural params in the space of thetas. prec = np.zeros((n_items, n_items)) xs = np.zeros(n_items) else: tau, nu = initial_state mean, cov, xs, prec = _init_ws(n_items, comparisons, prior_inv, tau, nu) for _ in range(max_iter): _ep_pairwise.iterations += 1 # pyright: ignore[reportFunctionMemberAccess] # Keep a copy of the old parameters for convergence testing. tau_old = np.array(tau, copy=True) nu_old = np.array(nu, copy=True) for i in nprand.permutation(m): a, b = comparisons[i] # Update mean and variance in function space. f_var = cov[a, a] + cov[b, b] - 2 * cov[a, b] f_mean = mean[a] - mean[b] # Cavity distribution. tau_tot = 1.0 / f_var nu_tot = tau_tot * f_mean tau_cav: float = tau_tot - tau[i] # pyright: ignore[reportAssignmentType] nu_cav: float = nu_tot - nu[i] # pyright: ignore[reportAssignmentType] cov_cav: float = 1.0 / tau_cav mean_cav: float = cov_cav * nu_cav # Moment matching. logpart, dlogpart, d2logpart = match_moments(mean_cav, cov_cav) # Update factor params in the function space. tau[i] = -d2logpart / (1 + d2logpart / tau_cav) delta_tau = tau[i] - tau_old[i] nu[i] = (dlogpart - (nu_cav / tau_cav) * d2logpart) / (1 + d2logpart / tau_cav) delta_nu = nu[i] - nu_old[i] # Update factor params in the weight space. prec[(a, a, b, b), (a, b, a, b)] += delta_tau * MAT_ONE_FLAT xs[a] += delta_nu xs[b] -= delta_nu # Update mean and covariance. if abs(delta_tau) > 0: phi = -1.0 / ((1.0 / delta_tau) + f_var) * MAT_ONE upd_mat = cov.take([a, b], axis=0) cov = cov + upd_mat.T.dot(phi).dot(upd_mat) mean = cov.dot(xs) # Recompute the global parameters for stability. cov = inv_posdef(prior_inv + prec) mean = cov.dot(xs) if _converged((tau, nu), (tau_old, nu_old)): return mean, cov raise RuntimeError("EP did not converge after {} iterations".format(max_iter)) def _log_phi(z: float) -> tuple[float, float]: """Stable computation of the log of the Normal CDF and its derivative.""" # Adapted from the GPML function `logphi.m`. if z * z < 0.0492: # First case: z close to zero. coef = -z / SQRT2PI val = functools.reduce(lambda acc, c: coef * (c + acc), CS, 0) res = -2 * val - log(2) dres = exp(-(z * z) / 2 - res) / SQRT2PI elif z < -11.3137: # Second case: z very small. num = functools.reduce(lambda acc, r: -z * acc / SQRT2 + r, RS, 0.5641895835477550741) den = functools.reduce(lambda acc, q: -z * acc / SQRT2 + q, QS, 1.0) res = log(num / (2 * den)) - (z * z) / 2 dres = abs(den / num) * sqrt(2.0 / pi) else: res = log(normal_cdf(z)) dres = exp(-(z * z) / 2 - res) / SQRT2PI return res, dres def _match_moments_logit(mean_cav: float, cov_cav: float) -> tuple[float, float, float]: # Adapted from the GPML function `likLogistic.m`. # First use a scale mixture. lambdas = sqrt(2) * np.array([0.44, 0.41, 0.40, 0.39, 0.36]) cs = np.array( [ 1.146480988574439e02, -1.508871030070582e03, 2.676085036831241e03, -1.356294962039222e03, 7.543285642111850e01, ] ) arr1, arr2, arr3 = np.zeros(5), np.zeros(5), np.zeros(5) for i, x in enumerate(lambdas): arr1[i], arr2[i], arr3[i] = _match_moments_probit(x * mean_cav, x * x * cov_cav) logpart1: float = logsumexp(arr1, b=cs) # pyright: ignore[reportAssignmentType] dlogpart1 = np.dot(np.exp(arr1) * arr2, cs * lambdas) / np.dot(np.exp(arr1), cs) d2logpart1 = ( np.dot(np.exp(arr1) * (arr2 * arr2 + arr3), cs * lambdas * lambdas) / np.dot(np.exp(arr1), cs) ) - (dlogpart1 * dlogpart1) # Tail decays linearly in the log domain (and not quadratically). exponent = -10.0 * (abs(mean_cav) - (196.0 / 200.0) * cov_cav - 4.0) if exponent < 500: lambd = 1.0 / (1.0 + exp(exponent)) logpart2 = min(cov_cav / 2.0 - abs(mean_cav), -0.1) dlogpart2 = 1.0 if mean_cav > 0: logpart2 = log(1 - exp(logpart2)) dlogpart2 = 0.0 d2logpart2 = 0.0 else: lambd, logpart2, dlogpart2, d2logpart2 = 0.0, 0.0, 0.0, 0.0 logpart = (1 - lambd) * logpart1 + lambd * logpart2 dlogpart = (1 - lambd) * dlogpart1 + lambd * dlogpart2 d2logpart = (1 - lambd) * d2logpart1 + lambd * d2logpart2 return logpart, dlogpart, d2logpart def _match_moments_probit(mean_cav: float, cov_cav: float) -> tuple[float, float, float]: # Adapted from the GPML function `likErf.m`. z = mean_cav / sqrt(1 + cov_cav) logpart, val = _log_phi(z) dlogpart = val / sqrt(1 + cov_cav) # 1st derivative w.r.t. mean. d2logpart = -val * (z + val) / (1 + cov_cav) return logpart, dlogpart, d2logpart def _init_ws( n_items: int, comparisons: PairwiseData, prior_inv: NDArray[np.float64], tau: NDArray[np.float64], nu: NDArray[np.float64], ) -> tuple[NDArray[np.float64], NDArray[np.float64], NDArray[np.float64], NDArray[np.float64]]: """Initialize parameters in the weight space.""" prec = np.zeros((n_items, n_items)) xs = np.zeros(n_items) for i, (a, b) in enumerate(comparisons): prec[(a, a, b, b), (a, b, a, b)] += tau[i] * MAT_ONE_FLAT xs[a] += nu[i] xs[b] -= nu[i] cov = inv_posdef(prior_inv + prec) mean = cov.dot(xs) return mean, cov, xs, prec def _converged( new: Sequence[NDArray[np.float64]], old: Sequence[NDArray[np.float64]], threshold: float = THRESHOLD, ) -> bool: for param_new, param_old in zip(new, old): if norm(param_new - param_old, ord=np.inf) > threshold: return False return True