首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用占位符从MultivariateNormalDiag采样

使用占位符从MultivariateNormalDiag采样
EN

Stack Overflow用户
提问于 2020-02-08 08:59:21
回答 1查看 178关注 0票数 0

我运行的是tensorflow 2.1和tensorflow_probability 0.9。我想用占位符参数化多变量正态分布,并从中抽取样本。以下是我尝试过的方法

代码语言:javascript
复制
import tensorflow as tf
import tensorflow_probability as tfp

@tf.function()
def sample_vae(dist):
    return dist.sample()

vae_mu = tf.keras.layers.Input(shape=(5), dtype=tf.float16)
vae_logvar = tf.keras.layers.Input(shape=(5), dtype=tf.float16)
dist = tfp.distributions.MultivariateNormalDiag(loc=vae_mu, scale_diag=tf.exp(vae_logvar))
z = sample_vae(dist)

上面的代码给出了以下错误

代码语言:javascript
复制
TypeError                                 Traceback (most recent call last)

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     60                                                op_name, inputs, attrs,
---> 61                                                num_outputs)
     62   except core._NotOkStatusException as e:

TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
For example, the following function will fail:
  @tf.function
  def has_init_scope():
    my_constant = tf.constant(1.)
    with tf.init_scope():
      added = my_constant * 2
The graph tensor has name: BroadcastArgs_3:0


During handling of the above exception, another exception occurred:

_SymbolicException                        Traceback (most recent call last)

6 frames

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     73       raise core._SymbolicException(
     74           "Inputs to eager execution function cannot be Keras symbolic "
---> 75           "tensors, but found {}".format(keras_symbolic_tensors))
     76     raise e
     77   # pylint: enable=protected-access

_SymbolicException: Inputs to eager execution function cannot be Keras symbolic tensors, but found [<tf.Tensor 'BroadcastArgs_3:0' shape=(1,) dtype=int32>, <tf.Tensor 'Exp_3:0' shape=(None, 5) dtype=float16>, <tf.Tensor 'input_7:0' shape=(None, 5) dtype=float16>]
EN

回答 1

Stack Overflow用户

发布于 2020-02-08 11:26:15

我可以使用DistributionLambda解决这个问题

代码语言:javascript
复制
vae_mu = tf.keras.layers.Input(shape=(1, 5), dtype=tf.float16)
vae_logvar = tf.keras.layers.Input(shape=(1, 5), dtype=tf.float16)
mu_logvar = tf.concat([vae_mu, vae_logvar], axis=1)
l = vae_mu.shape[1]
dist = tfp.layers.DistributionLambda(lambda theta: tfp.distributions.MultivariateNormalDiag(loc=theta[:, :l], scale_diag=theta[:, l:]))
z = dist(mu_logvar)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60123001

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档