首页
学习
活动
专区
圈层
工具
发布
首页
学习
活动
专区
圈层
工具
社区首页 >问答首页 >多元正态分布的Numpy向量化

多元正态分布的Numpy向量化
EN

Stack Overflow用户
提问于 2018-02-08 21:42:38
回答 1查看 1.7K关注 0票数 3

我有两个二维numpy数组A,B。我想使用scipy.stats.multivariate_normal来计算A中每行的联合logpdf,使用B中的每一行作为协方差矩阵。有没有什么方法可以在不显式循环行的情况下做到这一点?将scipy.stats.multivariate_normal直接应用于A和B确实会计算A中每一行的logpdf (这是我想要的),但使用整个2D数组A作为协方差矩阵,这不是我想要的(我需要B的每一行创建一个不同的协方差矩阵)。我正在寻找一种使用numpy向量化并避免在两个数组上显式循环的解决方案。

EN

回答 1

Stack Overflow用户

发布于 2018-06-12 11:57:39

我也在尝试完成类似的事情。下面是我的代码,它接受三个NxD矩阵。X的每一行是一个数据点,means的每一行是一个均值向量,covariances的每一行是一个对角协方差矩阵的对角向量。结果是对数概率的长度为N的向量。

代码语言:javascript
运行
复制
def vectorized_gaussian_logpdf(X, means, covariances):
    """
    Compute log N(x_i; mu_i, sigma_i) for each x_i, mu_i, sigma_i
    Args:
        X : shape (n, d)
            Data points
        means : shape (n, d)
            Mean vectors
        covariances : shape (n, d)
            Diagonal covariance matrices
    Returns:
        logpdfs : shape (n,)
            Log probabilities
    """
    _, d = X.shape
    constant = d * np.log(2 * np.pi)
    log_determinants = np.log(np.prod(covariances, axis=1))
    deviations = X - means
    inverses = 1 / covariances
    return -0.5 * (constant + log_determinants +
        np.sum(deviations * inverses * deviations, axis=1))

请注意,此代码仅适用于对角线协方差矩阵。在这种特殊情况下,下面的数学定义被简化了:行列式变成元素的乘积,逆变成元素的倒数,矩阵乘法变成元素的乘法。

快速测试正确性和运行时间:

代码语言:javascript
运行
复制
def test_vectorized_gaussian_logpdf():
    n = 128**2
    d = 64

    means = np.random.uniform(-1, 1, (n, d))
    covariances = np.random.uniform(0, 2, (n, d))
    X = np.random.uniform(-1, 1, (n, d))

    refs = []

    ref_start = time.time()
    for x, mean, covariance in zip(X, means, covariances):
        refs.append(scipy.stats.multivariate_normal.logpdf(x, mean, covariance))
    ref_time = time.time() - ref_start

    fast_start = time.time()
    results = vectorized_gaussian_logpdf(X, means, covariances)
    fast_time = time.time() - fast_start

    print("Reference time:", ref_time)
    print("Vectorized time:", fast_time)
    print("Speedup:", ref_time / fast_time)

    assert np.allclose(results, refs)

我得到了250倍的加速。(是的,我的应用程序要求我计算16384种不同的高斯分布。)

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48686934

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档