I ran into the problem above while running some image-processing code. My code is:
def print_metric(y_true,y_pred,weighted_error=False):
    cz = pycm.ConfusionMatrix(actual_vector=y_true.argmax(y_true.logits, axis=1), predict_vector=y_pred.argmax(axis=1))
    # Accuracy
    acc = cz.Overall_ACC
    print("Average Accuracy : "+str(acc*100)+'%')
    # Specificity
    specificity = cz.TNR
    totalprecision = 0
    for key, value in specificity.items():
        totalprecision = totalprecision + value
    print('Average Specificity : '+str(totalprecision*100/4.0)+'%')
Error message:
File "/root/autodl-tmp/OpticNet-71/src/metrics.py", line 30, in print_metric
cz = pycm.ConfusionMatrix(actual_vector=y_true.argmax(axis=1), predict_vector=y_pred.argmax(axis=1))
AttributeError: 'list' object has no attribute 'argmax'
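The error says that `y_true` reached `print_metric` as a plain Python list. Lists have no `argmax` method; NumPy arrays do. So one fix is to convert both inputs with `np.asarray` before reducing them. The extra `y_true.logits` argument in the posted code is not needed: `argmax(axis=1)` alone turns each row into its class index. Below is a minimal sketch of that fix. It assumes `y_true` and `y_pred` are either 2-D (samples x classes, e.g. one-hot labels or softmax scores) or already 1-D lists of class indices. `to_labels` is a hypothetical helper name, not part of pycm or NumPy. The sketch also averages specificity over the number of classes pycm actually finds instead of a hard-coded `4.0`.

```python
import numpy as np


def to_labels(y):
    """Convert one-hot targets or model scores to a 1-D vector of class indices.

    Accepts a Python list, a NumPy array, or anything np.asarray can convert.
    A 2-D input (samples x classes) is reduced with argmax over the class axis;
    a 1-D input is assumed to already hold class indices.
    """
    arr = np.asarray(y)
    if arr.ndim == 2:
        return arr.argmax(axis=1)
    if arr.ndim == 1:
        return arr.astype(int)
    raise ValueError(f"expected a 1-D or 2-D input, got shape {arr.shape}")


def print_metric(y_true, y_pred, weighted_error=False):
    # weighted_error is kept only to match the original signature; it is unused here.
    import pycm  # imported lazily so to_labels works even without pycm installed

    cz = pycm.ConfusionMatrix(
        actual_vector=to_labels(y_true).tolist(),
        predict_vector=to_labels(y_pred).tolist(),
    )
    # Accuracy
    acc = cz.Overall_ACC
    print("Average Accuracy : " + str(acc * 100) + "%")
    # Specificity: average the per-class TNR over the classes pycm found,
    # skipping any entry pycm could not compute (it reports those as "None").
    values = [v for v in cz.TNR.values() if isinstance(v, (int, float))]
    avg_spec = sum(values) / len(values)
    print("Average Specificity : " + str(avg_spec * 100) + "%")
```

For example, `to_labels([[0.1, 0.9], [0.8, 0.2]])` gives the array `[1, 0]`, whether the input is a list or an array. If `y_true` is produced by appending batches to a list, `np.concatenate(batches)` before calling `print_metric` has the same effect.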