# ***************************************************************
# Copyright (c) 2023 Jittor. All Rights Reserved.
# Maintainers:
# Haoyang Peng <2247838039@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import math
import os
import numpy as np
import jittor as jt
from jittor import nn
from jittor.nn import binary_cross_entropy_with_logits
from jittor import lgamma, igamma
from jittor.math_util.gamma import gamma_grad, sample_gamma
[文档]
def simple_presum(x):
"""
对输入的 ``data`` 进行简单的前缀和计算。公式如下:
.. math::
y[i_0, ..., i_{n-1}, j] = \\sum_{k=0}^{j} x[i_0, ..., i_{n-1}, k]
其中, :math:`n` 为 :math:`x` 的维度数。
参数:
- x (Var): 输入的 ``data`` ,可以是任意维度。
代码示例:
>>> jt.flags.use_cuda = 1
>>> x = jt.array([1.0, 2.0, 3.0, 4.0])
>>> print(simple_presum(x))
jt.Var([0. 1. 3. 6. 10.])
返回值:
:math:`x` 的 ``shape`` 向后扩展一位的结果,数据类型与 :math:`x` 相同
"""
src = '''
__inline_static__
@python.jittor.auto_parallel(1)
void kernel(int n0, int i0, in0_type* x, in0_type* out, int nl) {
out[i0*(nl+1)] = 0;
for (int i=0; i<nl; i++)
out[i0*(nl+1)+i+1] = out[i0*(nl+1)+i] + x[i0*nl+i];
}
kernel(in0->num/in0->shape[in0->shape.size()-1], 0, in0_p, out0_p, in0->shape[in0->shape.size()-1]);
'''
return jt.code(x.shape[:-1]+(x.shape[-1]+1,), x.dtype, [x],
cpu_src=src, cuda_src=src)
[文档]
class OneHotCategorical:
"""
初始化一个“独热分类”的模型, 它是Categorical的一个子类。
参数:
- probs (Var): 是一个概率列表, 默认值为 ``None``。表示类别出现的概率。
- logits (Var): 是一个实数列表, 默认值为 ``None``。表示属于不同类别的 logit 值(未标准化的对数概率)。
logits 参数和 probs 参数应至少有一个不为 ``None``。
代码示例:
>>> from jittor.distributions import OneHotCategorical
>>> onehot = OneHotCategorical(jt.array([0.3, 0.7]))
>>> onehot.sample()
jt.Var([0. 1.], dtype=float32)
"""
def __init__(self, probs=None, logits=None):
Categorical.__init__(self, probs, logits)
def sample(self, sample_shape=[]):
shape = sample_shape + self.probs.shape[:-1] + (1,)
rand = jt.rand(shape)
one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r).float()
return one_hot
def log_prob(self, x):
x = jt.argmax(x, dim=-1)[0]
return Categorical.log_prob(self, x)
def entropy(self):
p_log_p = self.logits * self.probs
return -p_log_p.sum(-1)
[文档]
class Categorical:
"""
处理分类问题的概率分布。可以根据给定的概率或者 ``logits`` 初始化类别分布、根据初始化的类别分布进行采样、计算某个类别的对数概率、计算该类别分布的熵。类别分布的熵的计算形式为
.. math::
-\\sum p \\log p
参数:
- probs (Var 或 None): 概率分布向量,为 ``None`` 时,根据 logits 参数计算得出。默认值为 ``None`` 。 probs 的正规化后得到
.. math::
p_i^{\\prime} = \\log\\left(\\frac{p_i}{1-\\sum p_i}\\right)
- logits (Var 或 None): 逻辑值向量,默认值为 ``None`` 。logits 先作用 Sigmoid 函数转换为概率分布,然后再使用 probs 的正规化方法。
probs 和 logits 应有至少有一者不为 `None`。
"""
def __init__(self, probs=None, logits=None):
assert not (probs is None and logits is None)
if probs is None:
# cannot align to pytorch
probs = jt.sigmoid(logits)
probs = probs / probs.sum(-1, True)
if logits is None:
logits = jt.safe_log(probs)
with jt.no_grad():
self.probs = probs
self.logits = logits
self.cum_probs = simple_presum(self.probs)
self.cum_probs_l = self.cum_probs[..., :-1]
self.cum_probs_r = self.cum_probs[..., 1:]
def sample(self, sample_shape=()):
shape = sample_shape + self.probs.shape[:-1] + (1,)
rand = jt.rand(shape)
one_hot = jt.logical_and(self.cum_probs_l < rand, rand <= self.cum_probs_r)
index = one_hot.index(one_hot.ndim - 1)
return (one_hot * index).sum(-1)
def log_prob(self, x):
a = self.probs.ndim
b = x.ndim
indexes = tuple( f'i{i}' for i in range(b-a+1, b) )
indexes = indexes + (x,)
return jt.safe_log(self.probs).getitem(indexes)
def entropy(self):
p_log_p = self.logits * self.probs
return -p_log_p.sum(-1)
[文档]
class Normal:
"""
定义了一个正态分布对象,提供了样本抽样、计算对数概率以及熵计算等 。
参数:
- mu (Var): 正态分布的均值。没有默认值,必须明确给出。
- sigma (Var): 正态分布的标准差。没有默认值,必须明确给出。
形状:
- mu 和 sigma 为形状相同的任意维数 Var,表示一系列服从正态分布的变量。
代码示例:
>>> n = Normal(mu=jt.array([1.1, 4.]), sigma=jt.array([5., 1.4]))
>>> n.sample()
jt.Var([8.488743 4.159627], dtype=float32)
"""
def __init__(self, mu, sigma):
self.mu = mu
self.sigma = sigma
def sample(self, sample_shape=None):
return jt.normal(jt.array(self.mu), jt.array(self.sigma),size=sample_shape)
def log_prob(self, x):
var = self.sigma**2
log_scale = jt.safe_log(self.sigma)
return -((x-self.mu)**2) / (2*var) - log_scale-np.log(np.sqrt(2*np.pi))
def entropy(self):
return 0.5+0.5*np.log(2*np.pi)+jt.safe_log(self.sigma)
[文档]
class Geometric:
"""
这是一个处理几何概率分布的类,几何概率分布是离散概率分布的一种,描述了在进行一系列独立的、具有相同成功概率的伯努利试验中,首次成功需要的试验次数。类 Geometric 初始化时,需要提供成功概率或者对数几率(logit)其中之一,且成功概率 :math:`p` 的取值范围在 0 和 1 之间。
参数:
- p (float): 几何概率分布成功概率。默认值:None
- logits (float): 几何概率分布的对数几率。默认值:None
代码示例:
>>> from jittor.distributions import Geometric
>>> geom = Geometric(p=0.5)
>>> geom.entropy()
jt.Var([1.3862944], dtype=float32)
"""
def __init__(self,p=None,logits=None):
assert (p is not None) or (logits is not None)
assert 0 < p and p < 1
if p is None:
self.prob = jt.sigmoid(logits)
self.logits = logits
elif logits is None:
self.prob = p
self.logits = -jt.safe_log(1. / p - 1)
def sample(self, sample_shape):
u = jt.rand(sample_shape)
return (jt.safe_log(u) / (jt.safe_log(-self.probs+1))).floor_int()
def log_prob(self, x):
return x*jt.safe_log(-self.prob+1)+jt.safe_log(self.prob)
def entropy(self):
return binary_cross_entropy_with_logits(jt.array(self.logits),jt.array(self.prob)) / self.prob
[文档]
class GammaDistribution:
"""
实现了 Gamma 分布类,支持随机采样、计算分布函数值、对数概率、熵、平均值和方差。Gamma 分布是一种主要用于连续概率分布的两参数家族,即集中度和比率参数,广泛应用于阵列模型等。其概率密度函数可以表示为:
.. math::
p(x) = \\frac{x^{\\alpha-1}e^{-\\frac{x}{\\theta}}}{\\theta^\\alpha\\Gamma(\\alpha)}
其中, :math:`\\alpha > 0` 是形状(shape)参数, 具有构成分布形状的主导作用; :math:`\\theta > 0` 是比率(rate)参数; :math:`\\Gamma(\\alpha)` 是 Gamma 函数。
参数:
- concentration (Var): Gamma 分布的形状参数。
- rate (Var): Gamma 分布的比率参数。
代码示例:
>>> gamma_dist = GammaDistribution(concentration=2, rate=1.5)
>>> gamma_dist.mean() # 计算平均值
>>> 1.3333333333333333
"""
def __init__(self, concentration, rate):
self.concentration = concentration
self.rate = rate
self.lgamma_alpha = lgamma.apply(jt.array([concentration,]))
def sample(self, shape):
return sample_gamma(self.concentration, shape)
def cdf(self, value):
return igamma(self.concentration, value)
def log_prob(self, value):
return (self.concentration * jt.log(self.rate) +
(self.concentration - 1) * jt.log(value) -
self.rate * value - self.lgamma_alpha)
def mean(self):
return self.concentration / self.rate
def mode(self):
return np.minimum((self.concentration - 1) / self.rate, 1)
def variance(self):
return self.concentration / (self.rate * self.rate)
[文档]
def kl_divergence(cur_dist, old_dist):
"""
计算两个概率分布之间的 ``KL`` 散度。 ``KL`` 散度,也称为 ``Kullback-Leibler`` 散度,是一种测量两个概率分布之间差异的方法。如果两个分布完全相同,则 ``KL`` 散度为0。请注意, ``KL`` 散度不是对称的,也就是说, ``KL`` 散度( ``cur_dist`` , ``old_dist`` )并不等于 ``KL`` 散度( ``old_dist`` , ``cur_dist`` )。数学公式:
对于 ``Normal`` 分布, ``KL`` 散度为
.. math::
$0.5 * \\left(\\frac{\\sigma_{cur}^2}{\\sigma_{old}^2} + \\left(\\frac{\\mu_{cur} - \\mu_{old}}{\\sigma_{old}}\\right)^2 - 1 - \\log\\left(\\frac{\\sigma_{cur}^2}{\\sigma_{old}^2}\\right)\\right)$
对于 ``Categorical`` 或 ``OneHotCategorical`` 分布, ``KL`` 散度是
.. math::
$prob_{cur}*(logits_{cur}-logits_{old})$
对于 ``Uniform`` 分布,如果满足条件 ``old_dist.low`` > ``cur_dist.low`` or ``old_dist.high`` < ``cur_dist.high`` ,则结果为正无穷,否则为
.. math::
$log\\left(\\frac{old\\_high - old\\_low}{cur\\_high - cur_low}\\right)$
对于 ``Geometric`` 分布, ``KL`` 散度为
.. math::
$-entropy_{cur} - \\log(1-prob_{old})/prob_{cur} - logits_{old}
参数:
- cur_dist (Distribution类型): 当前的概率分布。必须是 ``Normal`` , ``Categorical`` , ``OneHotCategorical`` , ``Uniform`` 或 ``Geometric`` 的一个实例。
- old_dist(Distribution类型): 用于比较的旧概率分布。必须是与 ``cur_dist`` 同类型的实例。
返回值:
根据提供的分布类型返回两分布间的 ``KL`` 散度。
代码示例:
>>> from jittor.distributions import kl_divergence
>>> from jittor.distributions import OneHotCategorical
>>> cur_onehot = OneHotCategorical(jt.array([0.3, 0.7]))
>>> old_onehot = OneHotCategorical(jt.array([0.5, 0.5]))
>>> kl_divergence(cur_onehot, old_onehot)
jt.Var([0.08228284], dtype=float32)
"""
assert isinstance(cur_dist, type(old_dist))
if isinstance(cur_dist, Normal):
vr = (cur_dist.sigma / old_dist.sigma)**2
t1 = ((cur_dist.mu - old_dist.mu) / old_dist.sigma)**2
return 0.5*(vr+t1-1-jt.safe_log(vr))
if isinstance(cur_dist, Categorical) or isinstance(cur_dist,OneHotCategorical):
t = cur_dist.probs * (cur_dist.logits-old_dist.logits)
return t.sum(-1)
if isinstance(cur_dist, Uniform):
res = jt.safe_log((old_dist.high - old_dist.low) / (cur_dist.high - cur_dist.low))
if old_dist.low > cur_dist.low or old_dist.high < cur_dist.high:
res = math.inf
return res
if isinstance(cur_dist, Geometric):
return -cur_dist.entropy() - jt.safe_log(-old_dist.prob+1) / cur_dist.prob - old_dist.logits