class FocalLoss(nn.Module):
def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.logits = logits
self.reduce = reduce
def forward(self, inputs, targets):
if self.logits:
BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False)
else:
BCE_loss = F.binary_cross_entropy(inputs, targets, reduce=False)
pt = torch.exp(-BCE_loss)
F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
if self.reduce:
return torch.mean(F_loss)
else:
return F_loss
扫码关注腾讯云开发者
领取腾讯云代金券
Copyright © 2013 - 2025 Tencent Cloud. All Rights Reserved. 腾讯云 版权所有
深圳市腾讯计算机系统有限公司 ICP备案/许可证号:粤B2-20090059 深公网安备号 44030502008569
腾讯云计算(北京)有限责任公司 京ICP证150476号 | 京ICP备11018762号 | 京公网安备号11010802020287
Copyright © 2013 - 2025 Tencent Cloud.
All Rights Reserved. 腾讯云 版权所有