训练模型

This commit is contained in:
2025-10-22 18:06:32 +08:00
parent 87ef6e5c90
commit cfa09a1178
7 changed files with 481 additions and 95 deletions

View File

@@ -1,10 +1,7 @@
package com.sdm.data.controller;
import com.sdm.common.common.SdmResponse;
import com.sdm.data.model.req.AddModelReq;
import com.sdm.data.model.req.AlgorithmParamReq;
import com.sdm.data.model.req.HandleLoadDataReq;
import com.sdm.data.model.req.ModelPredictReq;
import com.sdm.data.model.req.*;
import com.sdm.data.service.IModelService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
@@ -28,7 +25,7 @@ public class ModelTraningController {
@PostMapping("/addModel")
@Operation(summary = "新增训练模型", description = "新增训练模型")
public SdmResponse addModel(@RequestBody AddModelReq addModelReq) {
return modelService.addModel(addModelReq);
return modelService.addModel(addModelReq);
}
/**
@@ -46,10 +43,28 @@ public class ModelTraningController {
*/
@GetMapping("/getHandleLoadDataResult")
@Operation(summary = "获取python脚本训练数据处理结果", description = "获取python脚本训练数据处理结果")
public SdmResponse getHandleLoadDataResult(@RequestParam String modelId) {
public SdmResponse getHandleLoadDataResult(@RequestParam Integer modelId) {
return modelService.getHandleLoadDataResult(modelId);
}
/**
 * Sets the input and output feature columns of a model's training data.
 *
 * <p>The request body is passed unchanged to
 * {@code modelService.setTrainingDataInPutOutPutColumn} and that call's response is returned.
 *
 * @param req request body carrying the model id and the selected input/output columns
 * @return the response produced by the model service
 */
@PostMapping("/setTrainingDataInPutOutPutColumn")
@Operation(summary = "设置训练数据输入输出特征列", description = "设置训练数据输入输出特征列")
public SdmResponse setTrainingDataInPutOutPutColumn(@RequestBody SetTrainingDataInPutOutPutColumnReq req) {
return modelService.setTrainingDataInPutOutPutColumn(req);
}
/**
 * Returns the input/output feature column settings stored for a model.
 *
 * <p>Delegates to {@code modelService.getTrainingDataInPutOutPutColumn}.
 *
 * @param modelId id of the training model, taken from the {@code modelId} request parameter
 * @return the response produced by the model service
 */
@GetMapping("/getTrainingDataInPutOutPutColumn")
@Operation(summary = "获取训练数据输入输出特征设置详情", description = "获取训练数据输入输出特征设置详情")
public SdmResponse getTrainingDataInPutOutPutColumn(@RequestParam Integer modelId) {
return modelService.getTrainingDataInPutOutPutColumn(modelId);
}
/**
* 算法参数设置并提交训练
*/
@@ -61,20 +76,32 @@ public class ModelTraningController {
/**
* 获取训练曲线和训练日志
*
* @param modelId
*/
@GetMapping("/getTrainingResult")
@Operation(summary = "获取训练曲线和训练日志", description = "获取训练曲线和训练日志")
public SdmResponse getTrainingResult(@RequestParam String modelId){
public SdmResponse getTrainingResult(@RequestParam Integer modelId) {
return modelService.getTrainingResult(modelId);
}
/**
* 数据预测
* 开始数据预测
*/
@GetMapping("/predict")
@PostMapping("/startPredict")
@Operation(summary = "数据预测", description = "数据预测")
public SdmResponse predict(@Parameter(description = "数据预测请求对象") @RequestBody ModelPredictReq modelPredictReq) throws IOException {
return modelService.predict(modelPredictReq);
public SdmResponse startPredict(@Parameter(description = "数据预测请求对象") @RequestBody ModelPredictReq modelPredictReq) {
return modelService.startPredict(modelPredictReq);
}
/**
 * Called when the model prediction page is opened: returns the prediction data
 * previously stored for the model.
 *
 * <p>Delegates to {@code modelService.getModelPredictResult}.
 *
 * @param modelId id of the training model, taken from the {@code modelId} request parameter
 * @return the response produced by the model service
 */
@GetMapping("/getModelPredictResult")
@Operation(summary = "进入模型预测页面,获取历史模型预测结果", description = "进入模型预测页面,获取历史模型预测结果")
public SdmResponse getModelPredictResult(@RequestParam Integer modelId) {
return modelService.getModelPredictResult(modelId);
}
}

View File

@@ -73,6 +73,39 @@ public class TrainingModel implements Serializable {
@TableField("trainingDataHandleFileId")
private Integer trainingDataHandleFileId;
@ApiModelProperty(value = " 归一化最大值json格式")
@TableField("normalizerMax")
private String normalizerMax;
@ApiModelProperty(value = " 归一化最小值json格式")
@TableField("normalizerMin")
private String normalizerMin;
@ApiModelProperty(value = "特征输入列数量")
@TableField("inputSize")
Integer inputSize;
@ApiModelProperty(value = "输入列的json格式")
@TableField("inputLabel")
private String inputLabel;
@ApiModelProperty(value = "输入的列-值的json格式")
@TableField("inputPredLabelValue")
private String inputPredLabelValue;
@ApiModelProperty(value = " 特征输出列数量")
@TableField("outputSize")
Integer outputSize;
@ApiModelProperty(value = "输出列的JSON格式")
@TableField("outputLabel")
private String outputLabel;
@ApiModelProperty(value = "输出的列-值的json格式")
@TableField("outputPredLabelValue")
private String outputPredLabelValue;
@ApiModelProperty(value = "training脚本处理数据结果文件id")
@TableField("trainingDataResultFileId")
private Integer trainingDataResultFileId;
@@ -81,6 +114,10 @@ public class TrainingModel implements Serializable {
@TableField("trainingDataLogFileId")
private Integer trainingDataLogFileId;
@ApiModelProperty(value = "training脚本处理后模型文件id")
@TableField("trainingDataExportModelFileId")
private Integer trainingDataExportModelFileId;
@ApiModelProperty(value = "使用算法")
@TableField("algorithmUsed")
private String algorithmUsed;
@@ -88,6 +125,4 @@ public class TrainingModel implements Serializable {
@ApiModelProperty(value = "说明")
@TableField("description")
private String description;
}

View File

@@ -13,12 +13,6 @@ public class AlgorithmParamReq {
@Schema(description = "模型id")
private Integer modelId;
@Schema(description = "输入大小", example = "10")
Integer inputSize;
@Schema(description = "输出大小", example = "1")
Integer outputSize;
@Schema(description = "算法类型", example = "多项式拟合")
private String algorithm;

View File

@@ -6,8 +6,14 @@ import lombok.Data;
@Data
public class ModelPredictReq {
@Schema(description = "选择的预测模型ID")
String modelId;
Integer modelId;
@Schema(description = "json数据")
@Schema(description = "json数据: [{\n" +
" \"name\": \"param1\",\n" +
" \"value\": 0.1\n" +
" }, {\n" +
" \"name\": \"param1\",\n" +
" \"value\": 371.6669936\n" +
" }]")
String jsonData;
}

View File

@@ -0,0 +1,19 @@
package com.sdm.data.model.req;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.util.List;
/**
 * Request body for choosing which training-data columns are used as model inputs
 * and which as model outputs.
 *
 * <p>Accessors are generated by Lombok's {@code @Data}.
 */
@Data
public class SetTrainingDataInPutOutPutColumnReq {
// Id of the training model the column selection applies to.
@Schema(description = "模型ID")
Integer modelId;
// Names of the columns selected as input features.
@Schema(description = "输入特征列")
List<String> inputColumns;
// Names of the columns selected as output features.
@Schema(description = "输出特征列")
List<String> outputColumns;
}

View File

@@ -1,12 +1,7 @@
package com.sdm.data.service;
import com.sdm.common.common.SdmResponse;
import com.sdm.data.model.req.AddModelReq;
import com.sdm.data.model.req.AlgorithmParamReq;
import com.sdm.data.model.req.HandleLoadDataReq;
import com.sdm.data.model.req.ModelPredictReq;
import java.io.IOException;
import com.sdm.data.model.req.*;
/**
* 模型训练服务接口
@@ -25,7 +20,17 @@ public interface IModelService {
/**
* 获取python脚本训练数据处理结果
*/
SdmResponse getHandleLoadDataResult(String modelId);
SdmResponse getHandleLoadDataResult(Integer modelId);
/**
 * Sets the input and output feature columns of a model's training data.
 *
 * @param req request carrying the model id and the selected input/output column names
 * @return operation result
 */
SdmResponse setTrainingDataInPutOutPutColumn(SetTrainingDataInPutOutPutColumnReq req);
/**
 * Returns the input/output feature column settings stored for a model.
 *
 * @param modelId id of the training model
 * @return response carrying the stored column settings
 */
SdmResponse getTrainingDataInPutOutPutColumn(Integer modelId);
/**
* 模型参数设置并提交训练
@@ -38,12 +43,19 @@ public interface IModelService {
* @param modelId 模型ID
* @return 模型训练结果
*/
SdmResponse getTrainingResult(String modelId);
SdmResponse getTrainingResult(Integer modelId);
/**
* 数据预测
* @param modelPredictReq 模型预测请求对象
* @return 模型预测结果
*/
SdmResponse predict(ModelPredictReq modelPredictReq);
SdmResponse startPredict(ModelPredictReq modelPredictReq);
/**
 * Returns the prediction data stored for a model.
 *
 * @param modelId id of the training model
 * @return response carrying the stored prediction data
 */
SdmResponse getModelPredictResult(Integer modelId);
}

View File

@@ -6,15 +6,11 @@ import com.sdm.common.common.SdmResponse;
import com.sdm.data.model.entity.FileMetadataInfo;
import com.sdm.data.model.entity.TrainingModel;
import com.sdm.data.model.entity.TrainingModelAlgorithmParam;
import com.sdm.data.model.req.AddModelReq;
import com.sdm.data.model.req.AlgorithmParamReq;
import com.sdm.data.model.req.HandleLoadDataReq;
import com.sdm.data.model.req.ModelPredictReq;
import com.sdm.data.model.req.*;
import com.sdm.data.service.*;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.mock.web.MockMultipartFile;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
@@ -26,45 +22,84 @@ import java.io.OutputStream;
import java.nio.file.Files;
import java.util.Date;
import java.util.List;
import java.util.Random;
@Service
@Slf4j
public class ModelServiceImpl implements IModelService {
/**
* 数据处理文件存放路径
* 训练模型基础文件夹路径(脚本、训练数据、处理结果、训练模型)
*/
private static final String TANING_MODEL_DATA_FILE_PATH = "/home/app/model/dataHandler/";
private static final String TANING_MODEL_BASE_DIR_PATH = "/home/app/model/";
/**
* 模型id前缀
*/
private static final String MODEL_ID = "/ModelId_";
/**
* 数据处理脚本存放路径
*/
private static final String DATA_HANDLER_PYTHON_SCRIPT_PATH = "/home/app/model/ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Data_Handler/Data_Load.py";
private static final String DATA_HANDLER_PYTHON_SCRIPT_PATH = TANING_MODEL_BASE_DIR_PATH + "/ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Data_Handler/Data_Load.py";
/**
* 数据处理结果文件名
*/
private static final String SOURCE_JSON_FILE_NAME = "source.json";
/**
* 训练结果文件存放路径
*/
private static final String TRAIN_RESULT_FILE_PATH = "/home/app/model/trainResult/";
/**
* 训练脚本存放路径
*/
private static final String TRAINING_PYTHON_SCRIPTPATH = "/home/app/model/ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Data_Handler/Train_Proxy_Model.py";
private static final String TRAINING_PYTHON_SCRIPTPATH = TANING_MODEL_BASE_DIR_PATH + "/ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Train/Train_Proxy_Model.py";
/**
* 训练结果文件名
* 预测脚本存放路径
*/
private static final String PRED_PYTHON_SCRIPTPATH = TANING_MODEL_BASE_DIR_PATH + "/ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Predict/Model_Pred.py";
/**
* 数据处理文件夹放路径
*/
private static final String DATA_HANDLE_DIR_PATH = "/dataHandler/";
/**
* 训练结果文件夹路径
*/
private static final String TRAIN_RESULT_DIR_PATH = "/trainResult/";
/**
* 预测结果文件夹路径
*/
private static final String PRED_RESULT_DIR_PATH = "/predResult/";
/**
* dataHandler 数据处理结果文件名
*/
private static final String SOURCE_JSON_FILE_NAME = "source.json";
/**
* trainResult 训练结果文件名
*/
private static final String TRAINING_JSON_FILE_NAME = "training.json";
/**
* 训练日志文件
* trainResult 训练日志文件
*/
private static final String TRAINING_LOG_FILE_NAME = "training.log";
/**
* trainResult 训练模型文件名前缀
*/
private static final String TRAINING_MODEL_PERFIX_NAME = "model.";
/**
* 预测结果文件名
*/
private static final String PRED_RESULT_FILE_NAME = "forecast.json";
/**
* 脚本 入参文件名
*/
private static final String PYTHON_SCRIPT_PARAM_FILE_NAME = "param.json";
@Autowired
IDataFileService dataFileService;
@@ -115,11 +150,12 @@ public class ModelServiceImpl implements IModelService {
log.info("训练模型关联训练数据文件成功模型ID: {}", trainingModelId);
// 直接将MultipartFile 训练数据文件 写入Linux服务器
String taningModelDataFilePath = writeToLinuxServer(handleLoadDataReq.getFile(), TANING_MODEL_DATA_FILE_PATH);
// 直接将MultipartFile 训练数据文件 写入Linux服务器的对应模型目录下:/home/app/model/ModelId_1/dataHandler/,并返回文件路径
String handleLoadModelDirPath = TANING_MODEL_BASE_DIR_PATH + MODEL_ID + trainingModelId + "/" + DATA_HANDLE_DIR_PATH;
String taningModelDataFilePath = writeToLinuxServer(handleLoadDataReq.getFile(), handleLoadModelDirPath);
// 根据生成的文件路径创建python脚本 的param.json参数文件
String paramJsonPath = createParamJsonFile(taningModelDataFilePath);
String paramJsonPath = createDataHadleParamJsonFile(taningModelDataFilePath);
// 更新训练模型 训练状态
@@ -136,7 +172,7 @@ public class ModelServiceImpl implements IModelService {
}
@Override
public SdmResponse getHandleLoadDataResult(String modelId) {
public SdmResponse getHandleLoadDataResult(Integer modelId) {
// 根据模型ID查询处理结果再根据结果获取文件ID再从MinIO中下载json文件处理json文件返回一个对象
TrainingModel trainingModel = trainingModelService.getById(modelId);
if (trainingModel == null) {
@@ -202,13 +238,69 @@ public class ModelServiceImpl implements IModelService {
log.info("Python脚本执行完成, 训练模型ID: {}, 耗时: {} ms", trainingModelId, duration);
// 上传处理完的python数据到minio
String resultFilePath = TANING_MODEL_DATA_FILE_PATH + SOURCE_JSON_FILE_NAME;
String resultFilePath = TANING_MODEL_BASE_DIR_PATH + MODEL_ID + trainingModelId + "/" + DATA_HANDLE_DIR_PATH + SOURCE_JSON_FILE_NAME;
Integer fileId = uploadResultFileToMinio(resultFilePath);
InputStream minioInputStream = dataFileService.getMinioInputStream(fileId);
// 直接在内存中读取并解析JSON
BufferedReader reader = new BufferedReader(new InputStreamReader(minioInputStream));
StringBuilder jsonString = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) {
jsonString.append(line);
}
// 将JSON字符串转换为JSONObject或其他Java对象获取表头总数量
JSONObject jsonObject = JSONObject.parseObject(jsonString.toString());
JSONArray averageDataArray = jsonObject.getJSONArray("average_data");
// 循环处理averageDataArray
JSONArray maxValues = null;
JSONArray minValues = null;
for (int i = 0; i < averageDataArray.size(); i++) {
JSONObject averageData = averageDataArray.getJSONObject(i);
if (averageData.getString("property").equals("最大值")) {
// 提取最大值数组排除property字段
maxValues = new JSONArray();
// 获取所有字段名
java.util.Set<String> keySet = averageData.keySet();
for (String key : keySet) {
if (!"property".equals(key)) {
maxValues.add(averageData.getDoubleValue(key));
}
}
} else if (averageData.getString("property").equals("最小值")) {
// 提取最小值数组排除property字段
minValues = new JSONArray();
// 获取所有字段名
java.util.Set<String> keySet = averageData.keySet();
for (String key : keySet) {
if (!"property".equals(key)) {
minValues.add(averageData.getDoubleValue(key));
}
}
}
}
// 将数组转换为JSON字符串
String normalizerMaxJson = null;
String normalizerMinJson = null;
if (maxValues != null) {
normalizerMaxJson = maxValues.toJSONString();
}
if (minValues != null) {
normalizerMinJson = minValues.toJSONString();
}
// 更新训练模型 训练状态
trainingModelService.lambdaUpdate().eq(TrainingModel::getId, trainingModelId)
.set(TrainingModel::getHandleStatus, "成功")
.set(TrainingModel::getTrainingDataHandleFileId, fileId)
.set(TrainingModel::getNormalizerMax, normalizerMaxJson)
.set(TrainingModel::getNormalizerMin, normalizerMinJson)
.update();
log.info("训练模型ID: {}数据处理完成,处理结果文件ID:{},总耗时: {} ms", trainingModelId, fileId, duration);
@@ -300,12 +392,12 @@ public class ModelServiceImpl implements IModelService {
}
/**
* 根据数据文件路径创建param.json配置文件
* 根据数据文件路径创建数据处理的param.json配置文件
*
* @param dataFilePath 数据文件的完整路径
* @return param.json的完整路径
*/
private String createParamJsonFile(String dataFilePath) {
private String createDataHadleParamJsonFile(String dataFilePath) {
try {
// 提取文件名和目录路径
File dataFile = new File(dataFilePath);
@@ -313,7 +405,7 @@ public class ModelServiceImpl implements IModelService {
String directoryPath = dataFile.getParent();
// 构造param.json文件路径
String paramJsonPath = directoryPath + "/param.json";
String paramJsonPath = directoryPath + "/" + PYTHON_SCRIPT_PARAM_FILE_NAME;
// 创建JSON对象
JSONObject paramJson = new JSONObject();
@@ -339,9 +431,10 @@ public class ModelServiceImpl implements IModelService {
/**
* 下载文件到服务器本地
* @param fileId 文件ID
*
* @param fileId 文件ID
* @param localDirPath 本地文件夹路径
* @return 本地文件的完整路径
* @return 本地文件的完整路径 localDirPath / fileName
*/
private String downloadToLocalFromMinIO(Integer fileId, String localDirPath) {
InputStream in = null;
@@ -360,13 +453,13 @@ public class ModelServiceImpl implements IModelService {
}
String localFilePath = localDirPath + "/" + fileMetadataInfo.getOriginalName();
// 确保目标目录存在
File localFile = new File(localDirPath);
File localFile = new File(localFilePath);
File parentDir = localFile.getParentFile();
if (parentDir != null && !parentDir.exists()) {
parentDir.mkdirs();
}
out = new FileOutputStream(localDirPath);
out = new FileOutputStream(localFilePath);
// 将输入流写入本地文件
byte[] buffer = new byte[1024];
int len;
@@ -442,6 +535,37 @@ public class ModelServiceImpl implements IModelService {
}
}
/**
 * Stores the input/output feature columns selected for a training model.
 *
 * <p>For the model identified by {@code req.getModelId()} this writes the number of input and
 * output columns ({@code inputSize}, {@code outputSize}) and the column names serialized as JSON
 * arrays ({@code inputLabel}, {@code outputLabel}).
 *
 * @param req request carrying the model id and the selected input/output column names
 * @return a failed response when the model id or either column list is missing; otherwise success
 */
@Override
public SdmResponse setTrainingDataInPutOutPutColumn(SetTrainingDataInPutOutPutColumnReq req) {
    // Reject incomplete requests up front; dereferencing a missing list below would
    // otherwise surface as a NullPointerException instead of a readable error.
    if (req == null || req.getModelId() == null) {
        return SdmResponse.failed("模型ID不能为空");
    }
    List<String> inputColumns = req.getInputColumns();
    List<String> outputColumns = req.getOutputColumns();
    if (inputColumns == null || outputColumns == null) {
        return SdmResponse.failed("输入输出特征列不能为空");
    }

    trainingModelService.lambdaUpdate()
            .eq(TrainingModel::getId, req.getModelId())
            .set(TrainingModel::getInputSize, inputColumns.size())
            .set(TrainingModel::getInputLabel, JSONArray.toJSONString(inputColumns))
            .set(TrainingModel::getOutputSize, outputColumns.size())
            .set(TrainingModel::getOutputLabel, JSONArray.toJSONString(outputColumns))
            .update();
    return SdmResponse.success();
}
/**
 * Returns the input/output feature columns stored for a training model.
 *
 * <p>The {@code inputLabel} and {@code outputLabel} columns hold JSON arrays of column names;
 * they are parsed back into lists and returned under the keys {@code inputLabels} and
 * {@code outputLabels}.
 *
 * @param modelId id of the training model
 * @return a failed response when the model does not exist; otherwise success with the two lists
 */
@Override
public SdmResponse getTrainingDataInPutOutPutColumn(Integer modelId) {
    TrainingModel trainingModel = trainingModelService.getById(modelId);
    // Same not-found handling as the other lookups in this service, instead of an NPE below.
    if (trainingModel == null) {
        return SdmResponse.failed("模型不存在");
    }

    // Both label columns are stored as JSON array strings.
    List<String> inputLabels = JSONArray.parseArray(trainingModel.getInputLabel(), String.class);
    List<String> outputLabels = JSONArray.parseArray(trainingModel.getOutputLabel(), String.class);

    JSONObject result = new JSONObject();
    result.put("inputLabels", inputLabels);
    result.put("outputLabels", outputLabels);
    return SdmResponse.success(result);
}
@Override
public SdmResponse submitTraining(AlgorithmParamReq algorithmParamReq) {
// 保存或更新算法参数
@@ -469,7 +593,7 @@ public class ModelServiceImpl implements IModelService {
String paramJsonPath = createTraningParamJsonFile(algorithmParamReq);
// 异步调用python脚本进行训练并上传结果文件到minio
trainModelAsync(paramJsonPath, algorithmParamReq.getModelId());
trainModelAsync(paramJsonPath, algorithmParamReq.getModelId(), algorithmParamReq.getExportFormat());
return SdmResponse.success("提交训练成功");
}
@@ -480,7 +604,7 @@ public class ModelServiceImpl implements IModelService {
* @param paramJsonPath 参数文件路径
* @param modelId 模型ID
*/
private void trainModelAsync(String paramJsonPath, Integer modelId) {
private void trainModelAsync(String paramJsonPath, Integer modelId, String exportFormat) {
new Thread(() -> {
try {
long startTime = System.currentTimeMillis();
@@ -491,10 +615,13 @@ public class ModelServiceImpl implements IModelService {
long duration = endTime - startTime;
log.info("Python脚本执行完成, 训练模型ID: {}, 耗时: {} ms", modelId, duration);
// 将训练生成的生成training.json、training.log文件上传到minio
Integer trainingDataResultFileId = uploadResultFileToMinio(TRAIN_RESULT_FILE_PATH + TRAINING_JSON_FILE_NAME);
Integer trainingDataLogFileId = uploadResultFileToMinio(TRAIN_RESULT_FILE_PATH + TRAINING_LOG_FILE_NAME);
log.info("模型训练完成,结果文件: training.json、training.log文件上传成功文件ID: {},{}", trainingDataResultFileId, trainingDataLogFileId);
// 将训练生成的生成 training.json、training.log、模型文件上传到minio
Integer trainingDataResultFileId = uploadResultFileToMinio(TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + "/" + TRAIN_RESULT_DIR_PATH + TRAINING_JSON_FILE_NAME);
Integer trainingDataLogFileId = uploadResultFileToMinio(TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + "/" + TRAIN_RESULT_DIR_PATH + TRAINING_LOG_FILE_NAME);
Integer trainingDataExportModelFileId = uploadResultFileToMinio(TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + "/" + TRAIN_RESULT_DIR_PATH + TRAINING_MODEL_PERFIX_NAME + exportFormat);
log.info("模型训练完成,结果文件: training.json文件ID: {}、training.log文件ID: {}、导出model文件ID: {}文件上传成功", trainingDataResultFileId, trainingDataLogFileId, trainingDataExportModelFileId);
// 更新模型状态为训练完成
trainingModelService.lambdaUpdate()
@@ -504,8 +631,9 @@ public class ModelServiceImpl implements IModelService {
.set(TrainingModel::getTrainingDuration, duration)
.set(TrainingModel::getTrainingDataResultFileId, trainingDataResultFileId)
.set(TrainingModel::getTrainingDataLogFileId, trainingDataLogFileId)
.set(TrainingModel::getTrainingDataExportModelFileId, trainingDataExportModelFileId)
.update();
log.info("模型训练完成模型ID: {}", modelId);
} catch (Exception e) {
log.error("模型训练失败模型ID: {}", modelId, e);
@@ -520,15 +648,16 @@ public class ModelServiceImpl implements IModelService {
/**
* 创建训练参数param.json文件
*
* @param algorithmParamReq 算法参数请求
* @return
*/
private String createTraningParamJsonFile(AlgorithmParamReq algorithmParamReq) {
// 根据训练模型id 获取训练模型的训练数据文件id, 并下载到本地 trainingDataFilePath = /home/app/model/trainResult/sample.CSV
Integer modelId = algorithmParamReq.getModelId();
TrainingModel trainingModel= trainingModelService.getById(modelId);
TrainingModel trainingModel = trainingModelService.getById(modelId);
Integer trainingDataFileId = trainingModel.getTrainingDataFileId();
String trainingDataFilePath = downloadToLocalFromMinIO(trainingDataFileId, TRAIN_RESULT_FILE_PATH);
String trainingDataFilePath = downloadToLocalFromMinIO(trainingDataFileId, TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + "/" + TRAIN_RESULT_DIR_PATH);
try {
// 提取文件名和目录路径
@@ -537,7 +666,7 @@ public class ModelServiceImpl implements IModelService {
String directoryPath = dataFile.getParent();
// 构造param.json文件路径
String paramJsonPath = directoryPath + "/param.json";
String paramJsonPath = directoryPath + "/" + PYTHON_SCRIPT_PARAM_FILE_NAME;
// 创建JSON对象
JSONObject paramJson = new JSONObject();
@@ -549,8 +678,8 @@ public class ModelServiceImpl implements IModelService {
// 构造algorithmParam对象
JSONObject algorithmParam = new JSONObject();
algorithmParam.put("inputSize", algorithmParamReq.getInputSize());
algorithmParam.put("outputSize", algorithmParamReq.getOutputSize());
algorithmParam.put("inputSize", trainingModel.getInputSize());
algorithmParam.put("outputSize", trainingModel.getOutputSize());
algorithmParam.put("algorithm", algorithmParamReq.getAlgorithm());
algorithmParam.put("activateFun", algorithmParamReq.getActivateFun());
algorithmParam.put("lossFun", algorithmParamReq.getLossFun());
@@ -583,47 +712,58 @@ public class ModelServiceImpl implements IModelService {
}
@Override
public SdmResponse getTrainingResult(String modelId) {
public SdmResponse getTrainingResult(Integer modelId) {
// 根据modelId 获取训练模型获取训练模型结果文件id日志文件id
TrainingModel trainingModel = trainingModelService.getById(modelId);
if (trainingModel == null) {
return SdmResponse.failed("模型不存在");
}
if(trainingModel.getTrainingStatus().equals("失败")){
if (trainingModel.getTrainingStatus().equals("失败")) {
return SdmResponse.failed("模型训练失败,请查看日志详情");
}
if(trainingModel.getTrainingStatus().equals("待开始") || trainingModel.getTrainingStatus().equals("训练中")){
if (trainingModel.getTrainingStatus().equals("待开始") || trainingModel.getTrainingStatus().equals("训练中")) {
return SdmResponse.failed("模型正在训练中,请稍后再查询");
}
Integer trainingDataResultFileId = trainingModel.getTrainingDataResultFileId();
Integer trainingDataLogFileId = trainingModel.getTrainingDataLogFileId();
if (trainingDataResultFileId == null || trainingDataLogFileId == null) {
return SdmResponse.failed("训练结果文件日志文件不存在");
if (trainingDataResultFileId == null && trainingDataLogFileId == null) {
return SdmResponse.failed("训练结果文件日志文件不存在");
}
// 从minio获取结果json文件转换成对象
InputStream resultInputStream = dataFileService.getMinioInputStream(trainingDataResultFileId);
// 从minio获取日志文件转换成对象
InputStream logInputStream = dataFileService.getMinioInputStream(trainingDataLogFileId);
InputStream resultInputStream = null;
InputStream logInputStream = null;
if (resultInputStream == null || logInputStream == null) {
// 从minio获取结果json文件转换成对象
if (trainingDataResultFileId != null) {
resultInputStream = dataFileService.getMinioInputStream(trainingDataResultFileId);
}
// 从minio获取日志文件转换成对象
if (trainingDataLogFileId != null) {
logInputStream = dataFileService.getMinioInputStream(trainingDataLogFileId);
}
if (resultInputStream == null && logInputStream == null) {
return SdmResponse.failed("无法从MinIO获取文件流");
}
try {
// 处理结果JSON文件
JSONObject resultJsonObject = parseJsonFromStream(resultInputStream);
// 处理日志文件
String logContent = parseStringFromStream(logInputStream);
// 构造返回对象
JSONObject response = new JSONObject();
response.put("result", resultJsonObject);
response.put("log", logContent);
// 处理结果JSON文件如果存在
if (resultInputStream != null) {
JSONObject resultJsonObject = parseJsonFromStream(resultInputStream);
response.put("result", resultJsonObject);
}
// 处理日志文件(如果存在)
if (logInputStream != null) {
String logContent = parseStringFromStream(logInputStream);
response.put("log", logContent);
}
return SdmResponse.success(response);
} catch (Exception e) {
@@ -641,6 +781,7 @@ public class ModelServiceImpl implements IModelService {
/**
* 从输入流中解析JSON对象
*
* @param inputStream 输入流
* @return 解析后的JSON对象
* @throws IOException IO异常
@@ -657,6 +798,7 @@ public class ModelServiceImpl implements IModelService {
/**
* 从输入流中读取字符串内容
*
* @param inputStream 输入流
* @return 字符串内容
* @throws IOException IO异常
@@ -672,7 +814,158 @@ public class ModelServiceImpl implements IModelService {
}
@Override
public SdmResponse predict(ModelPredictReq modelPredictReq) {
return null;
public SdmResponse startPredict(ModelPredictReq modelPredictReq) {
    /*
     * Runs the Python prediction script for a trained model and returns its result.
     *
     * Steps: look up the model and its algorithm parameters, remember the submitted input
     * values on the model row, write the script's param.json, run the script, then read the
     * generated result file, store it on the model row and return it.
     *
     * A failed response is returned when the model or its algorithm parameters are missing,
     * when the script fails, or when the result file is missing or unreadable, so callers
     * no longer receive "success" with an empty payload for a prediction that did not run.
     */
    Integer modelId = modelPredictReq.getModelId();

    // Validate before touching the database or the file system.
    TrainingModel trainingModel = trainingModelService.lambdaQuery().eq(TrainingModel::getId, modelId).one();
    if (trainingModel == null) {
        return SdmResponse.failed("模型不存在");
    }
    TrainingModelAlgorithmParam trainingModelAlgorithmParam = trainingModelAlgorithmParamService.lambdaQuery()
            .eq(TrainingModelAlgorithmParam::getModelId, modelId)
            .one();
    if (trainingModelAlgorithmParam == null) {
        return SdmResponse.failed("模型算法参数不存在");
    }

    // Remember the submitted input values so the prediction page can show them again later.
    trainingModelService.lambdaUpdate()
            .eq(TrainingModel::getId, modelId)
            .set(TrainingModel::getInputPredLabelValue, modelPredictReq.getJsonData())
            .update();

    String predDirPath = TANING_MODEL_BASE_DIR_PATH + MODEL_ID + trainingModel.getId() + "/" + PRED_RESULT_DIR_PATH;
    // Create the param.json consumed by the prediction script.
    String predParamJsonFile = createPredParamJsonFile(modelPredictReq, trainingModel, trainingModelAlgorithmParam, predDirPath);

    try {
        long startTime = System.currentTimeMillis();
        log.info("开始执行Python预测脚本, 模型ID: {}", modelId);

        callPythonScript(PRED_PYTHON_SCRIPTPATH, predParamJsonFile);

        long duration = System.currentTimeMillis() - startTime;
        log.info("Python预测脚本执行完成, 模型ID: {}, 耗时: {} ms", modelId, duration);
    } catch (Exception e) {
        log.error("模型预测失败模型ID: {}", modelId, e);
        return SdmResponse.failed("模型预测失败: " + e.getMessage());
    }

    // The script writes its result next to param.json.
    String predResultPath = predDirPath + "/" + PRED_RESULT_FILE_NAME;
    File predResultFile = new File(predResultPath);
    if (!predResultFile.exists()) {
        log.warn("预测结果文件不存在: {}", predResultPath);
        return SdmResponse.failed("预测结果文件不存在");
    }

    try {
        // Decode explicitly as UTF-8 rather than relying on the platform default charset.
        String predResultJson = new String(Files.readAllBytes(predResultFile.toPath()),
                java.nio.charset.StandardCharsets.UTF_8);
        JSONObject predResultObject = JSONObject.parseObject(predResultJson);

        trainingModelService.lambdaUpdate()
                .eq(TrainingModel::getId, modelId)
                .set(TrainingModel::getOutputPredLabelValue, predResultObject.toJSONString())
                .update();
        log.info("模型预测结果已保存到数据库模型ID: {}", modelId);
        log.info("模型预测完成模型ID: {}", modelId);
        return SdmResponse.success(predResultObject);
    } catch (Exception e) {
        log.error("解析或保存预测结果失败模型ID: {}", modelId, e);
        return SdmResponse.failed("解析或保存预测结果失败: " + e.getMessage());
    }
}
/**
 * Creates the {@code param.json} file consumed by the prediction script.
 *
 * <p>The exported model file is first downloaded from MinIO into {@code predDirPath}. The JSON
 * written next to it contains:
 * <ul>
 *   <li>{@code modelFile}: the model file name prefix followed by the export format</li>
 *   <li>{@code path}: the prediction directory</li>
 *   <li>{@code modelParams}: inputSize, outputSize, normalizerType and, when stored on the
 *       model, normalizerMax / normalizerMin</li>
 *   <li>{@code input}: the request's jsonData parsed as a JSON array</li>
 *   <li>{@code output}: {@code names} with the stored output labels, when present</li>
 * </ul>
 *
 * @param modelPredictReq             prediction request; its jsonData must be a JSON array
 * @param trainingModel               training model the prediction is run for
 * @param trainingModelAlgorithmParam algorithm parameters of the model
 * @param predDirPath                 directory the model file and param.json are written to
 * @return full path of the created param.json
 * @throws IllegalStateException when the algorithm parameters are missing
 * @throws RuntimeException      when the directory or the file cannot be written
 */
private String createPredParamJsonFile(ModelPredictReq modelPredictReq, TrainingModel trainingModel, TrainingModelAlgorithmParam trainingModelAlgorithmParam, String predDirPath) {
    if (trainingModelAlgorithmParam == null) {
        // Fail with a clear message rather than an anonymous NullPointerException further down.
        throw new IllegalStateException("模型算法参数不存在, 模型ID: " + trainingModel.getId());
    }
    try {
        // Make sure the target directory exists; a failed mkdirs() must not go unnoticed.
        File predDir = new File(predDirPath);
        if (!predDir.isDirectory() && !predDir.mkdirs() && !predDir.isDirectory()) {
            throw new IOException("无法创建预测目录: " + predDirPath);
        }

        String paramJsonPath = predDirPath + "/" + PYTHON_SCRIPT_PARAM_FILE_NAME;
        JSONObject paramJson = new JSONObject();

        // Download the exported model file so the script finds it in the prediction directory.
        downloadToLocalFromMinIO(trainingModel.getTrainingDataExportModelFileId(), predDirPath);
        paramJson.put("modelFile", TRAINING_MODEL_PERFIX_NAME + trainingModelAlgorithmParam.getExportFormat());
        paramJson.put("path", predDirPath);

        // Model parameters.
        JSONObject modelParams = new JSONObject();
        modelParams.put("inputSize", trainingModel.getInputSize());
        modelParams.put("outputSize", trainingModel.getOutputSize());
        modelParams.put("normalizerType", trainingModelAlgorithmParam.getDisposeMethod());
        // Normalization bounds are optional; they are stored on the model as JSON arrays.
        if (trainingModel.getNormalizerMax() != null) {
            modelParams.put("normalizerMax", JSONArray.parseArray(trainingModel.getNormalizerMax()));
        }
        if (trainingModel.getNormalizerMin() != null) {
            modelParams.put("normalizerMin", JSONArray.parseArray(trainingModel.getNormalizerMin()));
        }
        paramJson.put("modelParams", modelParams);

        // The submitted JSON array is passed through unchanged as the script input.
        paramJson.put("input", JSONArray.parseArray(modelPredictReq.getJsonData()));

        // Output labels, when configured.
        JSONObject output = new JSONObject();
        if (trainingModel.getOutputLabel() != null) {
            output.put("names", JSONArray.parseArray(trainingModel.getOutputLabel()));
        }
        paramJson.put("output", output);

        // Write as UTF-8 explicitly so non-ASCII labels do not depend on the platform charset.
        Files.write(new File(paramJsonPath).toPath(),
                paramJson.toJSONString().getBytes(java.nio.charset.StandardCharsets.UTF_8));

        log.info("成功创建预测param.json文件: {}", paramJsonPath);
        return paramJsonPath;
    } catch (IOException e) {
        log.error("创建预测param.json文件失败", e);
        throw new RuntimeException("创建预测param.json文件失败: " + e.getMessage(), e);
    }
}
/**
 * Returns the prediction input and output last stored for a model.
 *
 * <p>The payload contains {@code inputPredLabelValue} (the JSON array that {@code startPredict}
 * saved from the request's jsonData) and {@code outputPredLabelValue} (the JSON object saved
 * from the prediction result). A key is omitted when nothing has been stored for it.
 *
 * @param modelId id of the training model
 * @return a failed response when the model does not exist; otherwise success with the payload
 */
@Override
public SdmResponse getModelPredictResult(Integer modelId) {
    TrainingModel model = trainingModelService.lambdaQuery().eq(TrainingModel::getId, modelId).one();
    if (model == null) {
        return SdmResponse.failed("模型不存在");
    }

    JSONObject result = new JSONObject();

    String inputPredLabelValue = model.getInputPredLabelValue();
    if (inputPredLabelValue != null) {
        // The stored input is the request's jsonData, which createPredParamJsonFile parses as a
        // JSON array; it must be read back as an array, not as an object.
        result.put("inputPredLabelValue", JSONArray.parseArray(inputPredLabelValue));
    }

    String outputPredLabelValue = model.getOutputPredLabelValue();
    if (outputPredLabelValue != null) {
        // Stored under its own key; it previously overwrote "inputPredLabelValue".
        result.put("outputPredLabelValue", JSONObject.parseObject(outputPredLabelValue));
    }
    return SdmResponse.success(result);
}
}