Files
ModelTrainingPython/FC_ML_Loss_Function/Loss_Function_Selector.py

43 lines
1.4 KiB
Python

import torch
import torch.nn as nn
class LossFunctionSelector:
def __init__(self):
self.available_losses = {
'mse': '均方误差',
'l1': '平均绝对误差',
'cross_entropy': '交叉熵',
'bce': '二分类交叉熵',
'smooth_l1': '平滑L1',
'kl_div': 'KL散度',
'hinge': '合页损失',
'triplet': '三元组损失'
}
def get_loss(self, name, **kwargs):
"""获取配置好的损失函数实例"""
if name == 'mse':
return nn.MSELoss(**kwargs)
elif name == 'l1':
return nn.L1Loss(**kwargs)
elif name == 'cross_entropy':
return nn.CrossEntropyLoss(**kwargs)
elif name == 'bce':
return nn.BCELoss(**kwargs)
elif name == 'smooth_l1':
return nn.SmoothL1Loss(**kwargs)
elif name == 'kl_div':
return nn.KLDivLoss(**kwargs)
elif name == 'hinge':
return nn.HingeEmbeddingLoss(**kwargs)
elif name == 'triplet':
return nn.TripletMarginLoss(**kwargs)
else:
raise ValueError(f"不支持的损失函数类型,可选: {list(self.available_losses.keys())}")
def print_available(self):
"""打印支持的损失函数列表"""
print("可用损失函数:")
for k, v in self.available_losses.items():
print(f"{k.ljust(15)} -> {v}")