jittor.distributions
这里是Jittor的随机分布模块的API文档,您可以通过from jittor import distributions
来获取该模块。
- class jittor.distributions.Categorical(probs=None, logits=None)[源代码]
处理分类问题的概率分布。可以根据给定的概率或者
logits
初始化类别分布、根据初始化的类别分布进行采样、计算某个类别的对数概率、计算该类别分布的熵。类别分布的熵的计算形式为\[-\sum p \log p\]- 参数:
probs (Var 或 None): 概率分布向量,为
None
时,根据 logits 参数计算得出。默认值为None
。 probs 的正规化后得到
\[p_i^{\prime} = \log\left(\frac{p_i}{1-\sum p_i}\right)\]logits (Var 或 None): 逻辑值向量,默认值为
None
。logits 先作用 Sigmoid 函数转换为概率分布,然后再使用 probs 的正规化方法。
probs 和 logits 应有至少有一者不为 None。
- class jittor.distributions.GammaDistribution(concentration, rate)[源代码]
实现了 Gamma 分布类,支持随机采样、计算分布函数值、对数概率、熵、平均值和方差。Gamma 分布是一种主要用于连续概率分布的两参数家族,即集中度和比率参数,广泛应用于阵列模型等。其概率密度函数可以表示为:
\[p(x) = \frac{x^{\alpha-1}e^{-\frac{x}{\theta}}}{\theta^\alpha\Gamma(\alpha)}\]其中, \(\alpha > 0\) 是形状(shape)参数, 具有构成分布形状的主导作用; \(\theta > 0\) 是比率(rate)参数; \(\Gamma(\alpha)\) 是 Gamma 函数。
- 参数:
concentration (Var): Gamma 分布的形状参数。
rate (Var): Gamma 分布的比率参数。
- 代码示例:
>>> gamma_dist = GammaDistribution(concentration=2, rate=1.5) >>> gamma_dist.mean() # 计算平均值 >>> 1.3333333333333333
- class jittor.distributions.Geometric(p=None, logits=None)[源代码]
这是一个处理几何概率分布的类,几何概率分布是离散概率分布的一种,描述了在进行一系列独立的、具有相同成功概率的伯努利试验中,首次成功需要的试验次数。类 Geometric 初始化时,需要提供成功概率或者对数几率(logit)其中之一,且成功概率 \(p\) 的取值范围在 0 和 1 之间。
- 参数:
p (float): 几何概率分布成功概率。默认值:None
logits (float): 几何概率分布的对数几率。默认值:None
- 代码示例:
>>> from jittor.distributions import Geometric >>> geom = Geometric(p=0.5) >>> geom.entropy() jt.Var([1.3862944], dtype=float32)
- class jittor.distributions.Normal(mu, sigma)[源代码]
定义了一个正态分布对象,提供了样本抽样、计算对数概率以及熵计算等 。
- 参数:
mu (Var): 正态分布的均值。没有默认值,必须明确给出。
sigma (Var): 正态分布的标准差。没有默认值,必须明确给出。
- 形状:
mu 和 sigma 为形状相同的任意维数 Var,表示一系列服从正态分布的变量。
- 代码示例:
>>> n = Normal(mu=jt.array([1.1, 4.]), sigma=jt.array([5., 1.4])) >>> n.sample() jt.Var([8.488743 4.159627], dtype=float32)
- class jittor.distributions.OneHotCategorical(probs=None, logits=None)[源代码]
初始化一个“独热分类”的模型, 它是Categorical的一个子类。
- 参数:
probs (Var): 是一个概率列表, 默认值为
None
。表示类别出现的概率。logits (Var): 是一个实数列表, 默认值为
None
。表示属于不同类别的 logit 值(未标准化的对数概率)。
logits 参数和 probs 参数应至少有一个不为
None
。- 代码示例:
>>> from jittor.distributions import OneHotCategorical >>> onehot = OneHotCategorical(jt.array([0.3, 0.7])) >>> onehot.sample() jt.Var([0. 1.], dtype=float32)
- class jittor.distributions.Uniform(low, high)[源代码]
生成指定范围内的均匀分布的随机数。如果需要产生一定形状(shape)的均匀分布样本,或者计算在指定位置的对数概率和熵,可以分别使用
sample
和log_prob
,entropy
三个方法。- 参数:
low (float): 均匀分布的下界。
high (float): 均匀分布的上界,必须大于参数
low
。
- 代码示例:
>>> unif = Uniform(1, 3) >>> unif.entropy() jt.Var([0.6931472], dtype=float32)
- jittor.distributions.kl_divergence(cur_dist, old_dist)[源代码]
计算两个概率分布之间的
KL
散度。KL
散度,也称为Kullback-Leibler
散度,是一种测量两个概率分布之间差异的方法。如果两个分布完全相同,则KL
散度为0。请注意,KL
散度不是对称的,也就是说,KL
散度(cur_dist
,old_dist
)并不等于KL
散度(old_dist
,cur_dist
)。数学公式:对于
Normal
分布,KL
散度为\[$0.5 * \left(\frac{\sigma_{cur}^2}{\sigma_{old}^2} + \left(\frac{\mu_{cur} - \mu_{old}}{\sigma_{old}}\right)^2 - 1 - \log\left(\frac{\sigma_{cur}^2}{\sigma_{old}^2}\right)\right)$\]对于
Categorical
或OneHotCategorical
分布,KL
散度是\[$prob_{cur}*(logits_{cur}-logits_{old})$\]对于
Uniform
分布,如果满足条件old_dist.low
>cur_dist.low
orold_dist.high
<cur_dist.high
,则结果为正无穷,否则为\[$log\left(\frac{old\_high - old\_low}{cur\_high - cur_low}\right)$\]对于
Geometric
分布,KL
散度为\[$-entropy_{cur} - \log(1-prob_{old})/prob_{cur} - logits_{old}\]- 参数:
cur_dist (Distribution类型): 当前的概率分布。必须是
Normal
,Categorical
,OneHotCategorical
,Uniform
或Geometric
的一个实例。old_dist(Distribution类型): 用于比较的旧概率分布。必须是与
cur_dist
同类型的实例。
- 返回值:
根据提供的分布类型返回两分布间的
KL
散度。- 代码示例:
>>> from jittor.distributions import kl_divergence >>> from jittor.distributions import OneHotCategorical >>> cur_onehot = OneHotCategorical(jt.array([0.3, 0.7])) >>> old_onehot = OneHotCategorical(jt.array([0.5, 0.5])) >>> kl_divergence(cur_onehot, old_onehot) jt.Var([0.08228284], dtype=float32)
- jittor.distributions.simple_presum(x)[源代码]
对输入的
data
进行简单的前缀和计算。公式如下:\[y[i_0, ..., i_{n-1}, j] = \sum_{k=0}^{j} x[i_0, ..., i_{n-1}, k]\]其中, \(n\) 为 \(x\) 的维度数。
- 参数:
x (Var): 输入的
data
,可以是任意维度。
- 代码示例:
>>> jt.flags.use_cuda = 1 >>> x = jt.array([1.0, 2.0, 3.0, 4.0]) >>> print(simple_presum(x)) jt.Var([0. 1. 3. 6. 10.])
- 返回值:
\(x\) 的
shape
向后扩展一位的结果,数据类型与 \(x\) 相同