diff --git a/TODO.md b/TODO.md index 4842dcf..b728937 100644 --- a/TODO.md +++ b/TODO.md @@ -21,7 +21,7 @@ - [x] Bernoulli - [x] Beta - [x] Binomial - - [ ] Dirichlet + - [x] Dirichlet - [x] Gamma - [x] Student's t - [x] Uniform diff --git a/src/statistics/dist.rs b/src/statistics/dist.rs index f236f1d..88b8186 100644 --- a/src/statistics/dist.rs +++ b/src/statistics/dist.rs @@ -13,12 +13,14 @@ //! * Uniform //! * Weighted Uniform //! * Log Normal -//! * There are two enums to represent probability distribution +//! * There are three enums to represent probability distribution //! * `OPDist` : One parameter distribution (Bernoulli) //! * `TPDist` : Two parameter distribution (Uniform, Normal, Beta, Gamma) +//! * `MVDist` : Multivariate distribution (Dirichlet) //! * `T: PartialOrd + SampleUniform + Copy + Into` //! * There are some traits for pdf -//! * `RNG` trait - extract sample & calculate pdf +//! * `RNG` trait - extract sample & calculate pdf for 1D distributions +//! * `MVRNG` trait - extract sample & calculate pdf for multivariate distributions //! * `Statistics` trait - already shown above //! //! ### `RNG` trait @@ -239,6 +241,55 @@ //! * Mean: $e^{\mu + \frac{\sigma^2}{2}}$ //! * Var: $(e^{\sigma^2} - 1)e^{2\mu + \sigma^2}$ //! * To generate log-normal random samples, Peroxide uses the `rand_distr::LogNormal` distribution from the `rand_distr` crate. +//! ### `MVRNG` trait +//! +//! * `MVRNG` trait is composed of four fields +//! * `sample`: Extract samples +//! * `sample_with_rng`: Extract samples with specific rng +//! * `pdf` : Calculate pdf value at specific point +//! * `ln_pdf` : Calculate log-pdf value at specific point +//! ```no_run +//! use rand::Rng; +//! pub trait MVRNG { +//! /// Extract samples of multivariate distributions +//! fn sample(&self, n: usize) -> Matrix; +//! +//! /// Extract samples of distributions with specific rng +//! fn sample_with_rng(&self, rng: &mut R, n: usize) -> Matrix; +//! +//! /// Probability Density Function +//! fn pdf(&self, x: &[f64]) -> f64; +//! +//! /// Log Probability Density Function +//! fn ln_pdf(&self, x: &[f64]) -> f64; +//! } +//! ``` +//! +//! ### Dirichlet Distribution +//! +//! * Definition +//! $$ \text{Dir}(\mathbf{x} | \boldsymbol{\alpha}) = \frac{1}{\text{B}(\boldsymbol{\alpha})} \prod_{i=1}^K x_i^{\alpha_i - 1} $$ +//! where $\text{B}(\boldsymbol{\alpha}) = \frac{\prod_{i=1}^K \Gamma(\alpha_i)}{\Gamma(\sum_{i=1}^K \alpha_i)}$ +//! * Representative value +//! * Mean: $\frac{\alpha_i}{\alpha_0}$ +//! * Var : $\frac{\alpha_i(\alpha_0 - \alpha_i)}{\alpha_0^2(\alpha_0 + 1)}$ +//! * To generate Dirichlet random samples, Peroxide generates $K$ independent Gamma samples and normalizes them. +//! * **Caution**: `MVDist` utilizes the existing `Statistics` trait but outputs vectors and matrices. +//! +//! ```rust +//! use peroxide::fuga::*; +//! +//! fn main() { +//! let mut rng = smallrng_from_seed(42); +//! let a = Dirichlet(vec![1.0, 2.0, 3.0]); // Dir(x | 1.0, 2.0, 3.0) +//! a.sample(100).print(); // Generate 100 samples +//! a.sample_with_rng(&mut rng, 100).print(); // Generate 100 samples with specific rng +//! a.pdf(&[0.16, 0.33, 0.51]).print(); // Probability density +//! a.mean().print(); // Mean vector +//! a.var().print(); // Variance vector +//! a.cov().print(); // Covariance matrix +//! } +//! ``` extern crate rand; extern crate rand_distr; @@ -255,6 +306,7 @@ use self::WeightedUniformError::*; use crate::statistics::{ops::C, stat::Statistics}; use crate::util::non_macro::{linspace, seq}; use crate::util::useful::{auto_zip, find_interval}; +use crate::structure::matrix::{matrix, Matrix, Row}; use anyhow::{bail, Result}; use std::f64::consts::E; @@ -283,6 +335,15 @@ pub enum TPDist> { LogNormal(T, T), } +/// Multivariate Distribution +/// +/// # Distributions +/// * `Dirichlet(alpha)`: Dirichlet distribution +#[derive(Debug, Clone)] +pub enum MVDist> { + Dirichlet(Vec), +} + pub struct WeightedUniform> { weights: Vec, sum: T, @@ -1000,3 +1061,150 @@ impl Statistics for WeightedUniform { vec![1f64] } } + +/// Multivariate Random Number Generator Trait +pub trait MVRNG { + /// Extract samples of multivariate distributions (Returns an n x k Matrix) + fn sample(&self, n: usize) -> Matrix { + let mut rng = rand::rng(); + self.sample_with_rng(&mut rng, n) + } + + /// Extract samples of distributions with specific rng + fn sample_with_rng(&self, rng: &mut R, n: usize) -> Matrix; + + /// Probability Density Function + fn pdf(&self, x: &[f64]) -> f64 { + self.ln_pdf(x).exp() + } + + /// Log Probability Density Function + fn ln_pdf(&self, x: &[f64]) -> f64; +} + +impl> Statistics for MVDist { + type Array = Matrix; + type Value = Vec; + + fn mean(&self) -> Self::Value { + match self { + MVDist::Dirichlet(alpha_t) => { + let alpha: Vec = alpha_t.iter().map(|&a| a.into()).collect(); + let alpha0: f64 = alpha.iter().sum(); + alpha.iter().map(|&a| a / alpha0).collect() + } + } + } + + fn var(&self) -> Self::Value { + match self { + MVDist::Dirichlet(alpha_t) => { + let alpha: Vec = alpha_t.iter().map(|&a| a.into()).collect(); + let alpha0: f64 = alpha.iter().sum(); + let norm = alpha0.powi(2) * (alpha0 + 1.0); + alpha.iter().map(|&a| a * (alpha0 - a) / norm).collect() + } + } + } + + fn sd(&self) -> Self::Value { + self.var().into_iter().map(|v| v.sqrt()).collect() + } + + fn cov(&self) -> Self::Array { + match self { + MVDist::Dirichlet(alpha_t) => { + let alpha: Vec = alpha_t.iter().map(|&a| a.into()).collect(); + let alpha0: f64 = alpha.iter().sum(); + let k = alpha.len(); + let norm = alpha0.powi(2) * (alpha0 + 1.0); + let mut cov_data = vec![0f64; k * k]; + + for i in 0..k { + for j in 0..k { + let idx = i * k + j; + if i == j { + cov_data[idx] = alpha[i] * (alpha0 - alpha[i]) / norm; + } else { + cov_data[idx] = -alpha[i] * alpha[j] / norm; + } + } + } + + matrix(cov_data, k, k, Row) + } + } + } + + fn cor(&self) -> Self::Array { + let cov_matrix = self.cov(); + let sd_vec = self.sd(); + let k = sd_vec.len(); + + let mut cor_data = vec![0f64; k * k]; + + for i in 0..k { + for j in 0..k { + let idx = i * k + j; + cor_data[idx] = cov_matrix[(i, j)] / (sd_vec[i] * sd_vec[j]); + } + } + matrix(cor_data, k, k, Row) + } +} + +impl> MVRNG for MVDist { + fn sample_with_rng(&self, rng: &mut R, n: usize) -> Matrix { + match self { + MVDist::Dirichlet(alpha_t) => { + let alpha: Vec = alpha_t.iter().map(|&a| a.into()).collect(); + let k = alpha.len(); + let mut sample_data = vec![0f64; n * k]; + + for i in 0..n { + let mut sum = 0f64; + let mut y = vec![0f64; k]; + + for j in 0..k { + let gamma_dist = rand_distr::Gamma::new(alpha[j], 1.0).unwrap(); + y[j] = gamma_dist.sample(rng); + sum += y[j]; + } + + for j in 0..k { + sample_data[i * k + j] = y[j] / sum; + } + } + + matrix(sample_data, n, k, Row) + } + } + } + + fn ln_pdf(&self, x: &[f64]) -> f64 { + match self { + MVDist::Dirichlet(alpha_t) => { + let alpha: Vec = alpha_t.iter().map(|&a| a.into()).collect(); + assert_eq!(alpha.len(), x.len(), "Arguments must have correct dimensions."); + + let mut term = 0f64; + let mut sum_x = 0f64; + let mut sum_alpha_ln_gamma = 0f64; + let mut alpha0 = 0f64; + + for (&x_i, &alpha_i) in x.iter().zip(alpha.iter()) { + assert!(x_i > 0f64 && x_i < 1f64, "Arguments must be in (0, 1)"); + + term += (alpha_i - 1.0) * x_i.ln(); + sum_alpha_ln_gamma += gamma(alpha_i).ln(); + sum_x += x_i; + alpha0 += alpha_i; + } + + assert!((sum_x - 1.0).abs() < 1e-4, "Arguments must sum up to 1"); + + term + gamma(alpha0).ln() - sum_alpha_ln_gamma + } + } + } +} diff --git a/src/util/print.rs b/src/util/print.rs index c9dbc5e..928462d 100644 --- a/src/util/print.rs +++ b/src/util/print.rs @@ -367,6 +367,12 @@ impl> Printable for TPD } } +impl> Printable for MVDist { + fn print(&self) { + println!("{:?}", self); + } +} + //impl Printable for Number { // fn print(&self) { // println!("{:?}", self) diff --git a/tests/dist.rs b/tests/dist.rs index 883a1d3..1b27c3e 100644 --- a/tests/dist.rs +++ b/tests/dist.rs @@ -8,3 +8,22 @@ fn test_binomial() { assert!(nearly_eq(b.mean(), 80f64)); assert!(nearly_eq(b.var(), 16f64)); } + +#[test] +fn test_dirichlet() { + let dir = MVDist::Dirichlet(vec![1.0, 2.0, 3.0]); + dir.sample(10).print(); + + let m = dir.mean(); + assert!(nearly_eq(m[0], 1.0 / 6.0)); + assert!(nearly_eq(m[1], 1.0 / 3.0)); + assert!(nearly_eq(m[2], 0.5)); + + let v = dir.var(); + assert!(nearly_eq(v[0], 5.0 / 252.0)); // 1 * 5 / (36 * 7) + assert!(nearly_eq(v[1], 8.0 / 252.0)); // 2 * 4 / (36 * 7) + assert!(nearly_eq(v[2], 9.0 / 252.0)); // 3 * 3 / (36 * 7) + + let pdf_val = dir.pdf(&[0.33333, 0.33333, 0.33333]); + assert!(nearly_eq(pdf_val, 2.222155556222205)); +}