修复预测脚本和训练脚本的执行bug

This commit is contained in:
2025-10-21 19:49:21 +08:00
parent b9cce1d733
commit 4fb2da1366
10 changed files with 87 additions and 45 deletions

View File

@@ -2,7 +2,6 @@ import argparse
import json
import torch
from openpyxl.styles.builtins import output
from FC_ML_Data.FC_ML_Data_Process.Data_Process_Normalization import Normalizer
from FC_ML_NN_Model.Poly_Model import PolyModel
@@ -10,18 +9,19 @@ from FC_ML_Tool.Serialization import parse_json_file
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='代理模型训练参数输入')
parser.add_argument('--param', default='D:\liyong\project\TVS_ML\FC_ML_Baseline\FC_ML_Baseline_Test\Train\param.json',
parser.add_argument('--param', default='D:\liyong\project\ModelTrainingPython\FC_ML_Baseline\FC_ML_Baseline_Test\pred\param.json',
help='配置参数文件绝对路径')
args = parser.parse_args()
params = parse_json_file(args.param)
print(params)
source_dir = params["path"] + "/"
model_file = source_dir + params["modelFile"]
inputs = []
names = []
names = params["output"]["names"]
#获取输入特征
for input_value in params["input"]:
inputs.append(input_value["value"])
names.append(input_value["name"])
# names.append(input_value["name"])
#记载模型进行预测
input_size = params["modelParams"]["inputSize"]
output_size = params["modelParams"]["outputSize"]
@@ -36,21 +36,23 @@ if __name__ == "__main__":
normalization_max = params["modelParams"]["normalizerMax"]
normalization_min = params["modelParams"]["normalizerMin"]
normalizer = Normalizer(method=normalization_type)
normalizer.load_params(normalization_type,normalization_min,normalization_max)
normalizer.load_params(normalization_type,normalization_min[0:input_size],normalization_max[0:input_size])
input_data = normalizer.transform(torch.tensor(inputs))
#执行模型预测
with torch.no_grad():
output_data = model(input_data)
print(f"Prediction result: {output_data.item():.4f}")
# print(f"Prediction result: {output_data.item().tolist():.4f}")
normalizer.load_params(normalization_type, normalization_min[-output_size:], normalization_max[-output_size:])
output_data_ori = normalizer.inverse_transform(output_data)
print(f"Prediction real result: {output_data_ori.item():.4f}")
# print(f"Prediction real result: {output_data_ori.item().tolist():.4f}")
#输出预测结果到文件中
output_datas = output_data_ori.tolist()
json_str = {}
if len(output_datas) == len(names):
for i in range(len(names)):
json_str[names[i]] = output_datas[i]
with open(source_dir + "forecast.json", ) as f:
with open(source_dir + "forecast.json","w") as f:
f.write(json.dumps(json_str, indent=None, ensure_ascii=False))