我的样本有组id。有没有办法将组id传递给我的keras模型,并根据组id计算损失?
发布于 2021-03-19 15:48:51
您可以将编组标签作为虚拟列插入到目标向量y中,并将其从损失中排除。下面的代码计算每组的均方误差并返回最大值。
def worst_case_group_mse(y_true, y_pred):
"""calculate mean squared error for each group separately and return worst value
Args:
y_true, y_pred (tf.Tensor):
last column corresponds to group index,
mean squared error calculated over all other columns
Returns:
tf.Tensor: maximum grouped mean squared error
"""
groups = tf.cast(y_true[:,-1], tf.int32)
y_true, y_pred = y_true[:,:-1], y_pred[:,:-1]
square = tf.math.square(y_pred - y_true)
unique, idx, count = tf.unique_with_counts(groups)
group_losses = tf.math.unsorted_segment_mean(square, idx, tf.size(unique))
group_losses = tf.math.reduce_mean(group_losses, axis=1)
return tf.math.reduce_max(group_losses)
https://stackoverflow.com/questions/49405228
复制