我遇到了一个问题,不管我传递给TF.Metrics.Mean_Squared_Error的是什么标签/预测,它总是返回一个0值。
下面是重复问题的代码:
a = tf.constant([0,0,0,0])
b = tf.constant([1,1,1,1])
mse, update = tf.metrics.mean_squared_error(a,b)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
mse.eval(session=sess)
%%返回0.0
发布于 2017-10-30 09:39:32
我真的不知道它为什么会这样工作,但是在考虑到您的数据之前,您实际上需要运行update
:
a = tf.constant([0,0,0,0])
b = tf.constant([1,1,1,1])
mse, update = tf.metrics.mean_squared_error(a,b)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
mse.eval(session=sess) # Gives 0.0, the initial MSE value
update.eval(session=sess) # Gives 1.0, the value of the update for the mse
mse.eval(session=sess) # Gives 1.0, which is 0.0+1.0, the updated mse value
例如,tf.metrics.mean_squared_error()
是用来计算整个数据集上的MSE的,所以如果您想要独立于批处理的结果,则不应该使用它。为此,例如使用tf.losses.mean_squared_error(a, b, loss_collection=None)
。
https://stackoverflow.com/questions/47020901
复制相似问题