ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Predict/Model_Pred.py

import argparse
import json
import torch
import sys
import os
# Directory containing this script (Model_Pred.py),
# e.g. /home/app/model/ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Predict
current_script_dir = os.path.dirname(__file__)
# Go up two levels to reach the project root, ModelTrainingPython
root_path = os.path.abspath(os.path.join(current_script_dir, "..", ".."))
# Add the project root to the Python module search path
sys.path.append(root_path)
from FC_ML_Data.FC_ML_Data_Process.Data_Process_Normalization import Normalizer
from FC_ML_NN_Model.Poly_Model import PolyModel
from FC_ML_Tool.Serialization import parse_json_file
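
# Expected structure of the JSON file passed via --param, reconstructed from the
# keys read below (a sketch; concrete values are project-specific):
# {
#     "path": "<working directory for the model and output files>",
#     "modelFile": "<saved state_dict of the trained PolyModel>",
#     "input": [{"name": "<feature name>", "value": <float>}, ...],
#     "output": {"names": ["<output name>", ...]},
#     "modelParams": {
#         "inputSize": <int>, "outputSize": <int>,
#         "normalizerType": "<normalization method>",
#         "normalizerMin": [<input mins..., output mins...>],
#         "normalizerMax": [<input maxes..., output maxes...>]
#     }
# }
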
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Surrogate model prediction parameter input')
    parser.add_argument('--param', default=r'D:\liyong\project\ModelTrainingPython\FC_ML_Baseline\FC_ML_Baseline_Test\pred\param.json',
                        help='Absolute path to the configuration parameter JSON file')
    args = parser.parse_args()
    params = parse_json_file(args.param)
    print(params)
    source_dir = params["path"] + "/"
    model_file = source_dir + params["modelFile"]
    inputs = []
    names = params["output"]["names"]
    # Collect the input feature values
    for input_value in params["input"]:
        inputs.append(input_value["value"])
        # names.append(input_value["name"])
    # Load the model for prediction
    input_size = params["modelParams"]["inputSize"]
    output_size = params["modelParams"]["outputSize"]
    model_path = params["path"] + "/" + params["modelFile"]
    device = torch.device('cpu')
    model = PolyModel(input_size, output_size).to(device)
    model.load_state_dict(torch.load(model_file, map_location=device))
    model.eval()
    # Load the data normalizer
    normalization_type = params["modelParams"]["normalizerType"]
    normalization_max = params["modelParams"]["normalizerMax"]
    normalization_min = params["modelParams"]["normalizerMin"]
    normalizer = Normalizer(method=normalization_type)
    normalizer.load_params(normalization_type, normalization_min[0:input_size], normalization_max[0:input_size])
    input_data = normalizer.transform(torch.tensor(inputs))
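    # Assumed convention (inferred from the slices above and below): normalizerMin /
    # normalizerMax list the input statistics in the first input_size entries and
    # the output statistics in the last output_size entries.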
    # Run the model prediction
    with torch.no_grad():
        output_data = model(input_data)
    # print(f"Prediction result: {output_data.tolist()}")
    normalizer.load_params(normalization_type, normalization_min[-output_size:], normalization_max[-output_size:])
    output_data_ori = normalizer.inverse_transform(output_data)
    # print(f"Prediction real result: {output_data_ori.tolist()}")
    # Write the prediction results to a file
    output_datas = output_data_ori.tolist()
    json_str = {}
    if len(output_datas) == len(names):
        for i in range(len(names)):
            json_str[names[i]] = output_datas[i]
    with open(source_dir + "forecast.json", "w") as f:
        f.write(json.dumps(json_str, indent=None, ensure_ascii=False))
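
# Example invocation (the path is illustrative):
#   python Model_Pred.py --param /path/to/pred/param.json
# The de-normalized predictions are written to <path>/forecast.json as a JSON
# object mapping each output name to its predicted value.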