切换git用户重新进行项目首次归档
This commit is contained in:
43
FC_ML_Loss_Function/Loss_Function_Selector.py
Normal file
43
FC_ML_Loss_Function/Loss_Function_Selector.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class LossFunctionSelector:
|
||||
def __init__(self):
|
||||
self.available_losses = {
|
||||
'mse': '均方误差',
|
||||
'l1': '平均绝对误差',
|
||||
'cross_entropy': '交叉熵',
|
||||
'bce': '二分类交叉熵',
|
||||
'smooth_l1': '平滑L1',
|
||||
'kl_div': 'KL散度',
|
||||
'hinge': '合页损失',
|
||||
'triplet': '三元组损失'
|
||||
}
|
||||
|
||||
def get_loss(self, name, **kwargs):
|
||||
"""获取配置好的损失函数实例"""
|
||||
if name == 'mse':
|
||||
return nn.MSELoss(**kwargs)
|
||||
elif name == 'l1':
|
||||
return nn.L1Loss(**kwargs)
|
||||
elif name == 'cross_entropy':
|
||||
return nn.CrossEntropyLoss(**kwargs)
|
||||
elif name == 'bce':
|
||||
return nn.BCELoss(**kwargs)
|
||||
elif name == 'smooth_l1':
|
||||
return nn.SmoothL1Loss(**kwargs)
|
||||
elif name == 'kl_div':
|
||||
return nn.KLDivLoss(**kwargs)
|
||||
elif name == 'hinge':
|
||||
return nn.HingeEmbeddingLoss(**kwargs)
|
||||
elif name == 'triplet':
|
||||
return nn.TripletMarginLoss(**kwargs)
|
||||
else:
|
||||
raise ValueError(f"不支持的损失函数类型,可选: {list(self.available_losses.keys())}")
|
||||
|
||||
def print_available(self):
|
||||
"""打印支持的损失函数列表"""
|
||||
print("可用损失函数:")
|
||||
for k, v in self.available_losses.items():
|
||||
print(f"{k.ljust(15)} -> {v}")
|
||||
Reference in New Issue
Block a user