diff --git a/data/src/main/java/com/sdm/data/controller/ModelTraningController.java b/data/src/main/java/com/sdm/data/controller/ModelTraningController.java index 7d911a69..dcf87f58 100644 --- a/data/src/main/java/com/sdm/data/controller/ModelTraningController.java +++ b/data/src/main/java/com/sdm/data/controller/ModelTraningController.java @@ -1,10 +1,7 @@ package com.sdm.data.controller; import com.sdm.common.common.SdmResponse; -import com.sdm.data.model.req.AddModelReq; -import com.sdm.data.model.req.AlgorithmParamReq; -import com.sdm.data.model.req.HandleLoadDataReq; -import com.sdm.data.model.req.ModelPredictReq; +import com.sdm.data.model.req.*; import com.sdm.data.service.IModelService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.Parameter; @@ -28,7 +25,7 @@ public class ModelTraningController { @PostMapping("/addModel") @Operation(summary = "新增训练模型", description = "新增训练模型") public SdmResponse addModel(@RequestBody AddModelReq addModelReq) { - return modelService.addModel(addModelReq); + return modelService.addModel(addModelReq); } /** @@ -46,10 +43,28 @@ public class ModelTraningController { */ @GetMapping("/getHandleLoadDataResult") @Operation(summary = "获取python脚本训练数据处理结果", description = "获取python脚本训练数据处理结果") - public SdmResponse getHandleLoadDataResult(@RequestParam String modelId) { + public SdmResponse getHandleLoadDataResult(@RequestParam Integer modelId) { return modelService.getHandleLoadDataResult(modelId); } + /** + * 设置训练数据输入输出特征列 + */ + @PostMapping("/setTrainingDataInPutOutPutColumn") + @Operation(summary = "设置训练数据输入输出特征列", description = "设置训练数据输入输出特征列") + public SdmResponse setTrainingDataInPutOutPutColumn(@RequestBody SetTrainingDataInPutOutPutColumnReq req) { + return modelService.setTrainingDataInPutOutPutColumn(req); + } + + /** + * 获取训练数据输入输出特征设置详情 + */ + @GetMapping("/getTrainingDataInPutOutPutColumn") + @Operation(summary = "获取训练数据输入输出特征设置详情", description = "获取训练数据输入输出特征设置详情") + public SdmResponse getTrainingDataInPutOutPutColumn(@RequestParam Integer modelId) { + return modelService.getTrainingDataInPutOutPutColumn(modelId); + } + /** * 算法参数设置并提交训练 */ @@ -61,20 +76,32 @@ public class ModelTraningController { /** * 获取训练曲线和训练日志 + * * @param modelId */ @GetMapping("/getTrainingResult") @Operation(summary = "获取训练曲线和训练日志", description = "获取训练曲线和训练日志") - public SdmResponse getTrainingResult(@RequestParam String modelId){ + public SdmResponse getTrainingResult(@RequestParam Integer modelId) { return modelService.getTrainingResult(modelId); } /** - * 数据预测 + * 开始数据预测 */ - @GetMapping("/predict") + @PostMapping("/startPredict") @Operation(summary = "数据预测", description = "数据预测") - public SdmResponse predict(@Parameter(description = "数据预测请求对象") @RequestBody ModelPredictReq modelPredictReq) throws IOException { - return modelService.predict(modelPredictReq); + public SdmResponse startPredict(@Parameter(description = "数据预测请求对象") @RequestBody ModelPredictReq modelPredictReq) { + return modelService.startPredict(modelPredictReq); } + + /** + * 进入模型预测页面,获取历史模型预测结果 + */ + @GetMapping("/getModelPredictResult") + @Operation(summary = "进入模型预测页面,获取历史模型预测结果", description = "进入模型预测页面,获取历史模型预测结果") + public SdmResponse getModelPredictResult(@RequestParam Integer modelId) { + + return modelService.getModelPredictResult(modelId); + } + } diff --git a/data/src/main/java/com/sdm/data/model/entity/TrainingModel.java b/data/src/main/java/com/sdm/data/model/entity/TrainingModel.java index 417e0d21..aa85bf23 100644 --- a/data/src/main/java/com/sdm/data/model/entity/TrainingModel.java +++ b/data/src/main/java/com/sdm/data/model/entity/TrainingModel.java @@ -73,6 +73,39 @@ public class TrainingModel implements Serializable { @TableField("trainingDataHandleFileId") private Integer trainingDataHandleFileId; + @ApiModelProperty(value = " 归一化最大值json格式") + @TableField("normalizerMax") + private String normalizerMax; + + @ApiModelProperty(value = " 归一化最小值json格式") + @TableField("normalizerMin") + private String normalizerMin; + + + @ApiModelProperty(value = "特征输入列数量") + @TableField("inputSize") + Integer inputSize; + + @ApiModelProperty(value = "输入列的json格式") + @TableField("inputLabel") + private String inputLabel; + + @ApiModelProperty(value = "输入的列-值的json格式") + @TableField("inputPredLabelValue") + private String inputPredLabelValue; + + @ApiModelProperty(value = " 特征输出列数量") + @TableField("outputSize") + Integer outputSize; + + @ApiModelProperty(value = "输出列的JSON格式") + @TableField("outputLabel") + private String outputLabel; + + @ApiModelProperty(value = "输出的列-值的json格式") + @TableField("outputPredLabelValue") + private String outputPredLabelValue; + @ApiModelProperty(value = "training脚本处理数据结果文件id") @TableField("trainingDataResultFileId") private Integer trainingDataResultFileId; @@ -81,6 +114,10 @@ public class TrainingModel implements Serializable { @TableField("trainingDataLogFileId") private Integer trainingDataLogFileId; + @ApiModelProperty(value = "training脚本处理后模型文件id") + @TableField("trainingDataExportModelFileId") + private Integer trainingDataExportModelFileId; + @ApiModelProperty(value = "使用算法") @TableField("algorithmUsed") private String algorithmUsed; @@ -88,6 +125,4 @@ public class TrainingModel implements Serializable { @ApiModelProperty(value = "说明") @TableField("description") private String description; - - } diff --git a/data/src/main/java/com/sdm/data/model/req/AlgorithmParamReq.java b/data/src/main/java/com/sdm/data/model/req/AlgorithmParamReq.java index fe7072a5..cbe62a0b 100644 --- a/data/src/main/java/com/sdm/data/model/req/AlgorithmParamReq.java +++ b/data/src/main/java/com/sdm/data/model/req/AlgorithmParamReq.java @@ -13,12 +13,6 @@ public class AlgorithmParamReq { @Schema(description = "模型id") private Integer modelId; - @Schema(description = "输入大小", example = "10") - Integer inputSize; - - @Schema(description = "输出大小", example = "1") - Integer outputSize; - @Schema(description = "算法类型", example = "多项式拟合") private String algorithm; diff --git a/data/src/main/java/com/sdm/data/model/req/ModelPredictReq.java b/data/src/main/java/com/sdm/data/model/req/ModelPredictReq.java index 41c88fde..55d3e5c0 100644 --- a/data/src/main/java/com/sdm/data/model/req/ModelPredictReq.java +++ b/data/src/main/java/com/sdm/data/model/req/ModelPredictReq.java @@ -6,8 +6,14 @@ import lombok.Data; @Data public class ModelPredictReq { @Schema(description = "选择的预测模型ID") - String modelId; + Integer modelId; - @Schema(description = "json数据") + @Schema(description = "json数据: [{\n" + + " \"name\": \"param1\",\n" + + " \"value\": 0.1\n" + + " }, {\n" + + " \"name\": \"param1\",\n" + + " \"value\": 371.6669936\n" + + " }]") String jsonData; } diff --git a/data/src/main/java/com/sdm/data/model/req/SetTrainingDataInPutOutPutColumnReq.java b/data/src/main/java/com/sdm/data/model/req/SetTrainingDataInPutOutPutColumnReq.java new file mode 100644 index 00000000..ea0e6b5b --- /dev/null +++ b/data/src/main/java/com/sdm/data/model/req/SetTrainingDataInPutOutPutColumnReq.java @@ -0,0 +1,19 @@ +package com.sdm.data.model.req; + +import io.swagger.v3.oas.annotations.media.Schema; +import lombok.Data; + +import java.util.List; + + +@Data +public class SetTrainingDataInPutOutPutColumnReq { + @Schema(description = "模型ID") + Integer modelId; + + @Schema(description = "输入特征列") + List inputColumns; + + @Schema(description = "输出特征列") + List outputColumns; +} diff --git a/data/src/main/java/com/sdm/data/service/IModelService.java b/data/src/main/java/com/sdm/data/service/IModelService.java index 6e6258bd..4a5f128a 100644 --- a/data/src/main/java/com/sdm/data/service/IModelService.java +++ b/data/src/main/java/com/sdm/data/service/IModelService.java @@ -1,12 +1,7 @@ package com.sdm.data.service; import com.sdm.common.common.SdmResponse; -import com.sdm.data.model.req.AddModelReq; -import com.sdm.data.model.req.AlgorithmParamReq; -import com.sdm.data.model.req.HandleLoadDataReq; -import com.sdm.data.model.req.ModelPredictReq; - -import java.io.IOException; +import com.sdm.data.model.req.*; /** * 模型训练服务接口 @@ -25,7 +20,17 @@ public interface IModelService { /** * 获取python脚本训练数据处理结果 */ - SdmResponse getHandleLoadDataResult(String modelId); + SdmResponse getHandleLoadDataResult(Integer modelId); + + /** + * 设置训练数据输入输出特征列 + */ + SdmResponse setTrainingDataInPutOutPutColumn(SetTrainingDataInPutOutPutColumnReq req); + + /** + * 获取训练数据输入输出特征设置详情 + */ + SdmResponse getTrainingDataInPutOutPutColumn(Integer modelId); /** * 模型参数设置并提交训练 @@ -38,12 +43,19 @@ public interface IModelService { * @param modelId 模型ID * @return 模型训练结果 */ - SdmResponse getTrainingResult(String modelId); + SdmResponse getTrainingResult(Integer modelId); /** * 数据预测 * @param modelPredictReq 模型预测请求对象 * @return 模型预测结果 */ - SdmResponse predict(ModelPredictReq modelPredictReq); + SdmResponse startPredict(ModelPredictReq modelPredictReq); + + /** + * 获取模型预测结果 + * @param modelId + * @return + */ + SdmResponse getModelPredictResult(Integer modelId); } diff --git a/data/src/main/java/com/sdm/data/service/impl/ModelServiceImpl.java b/data/src/main/java/com/sdm/data/service/impl/ModelServiceImpl.java index c4035ac8..736cbf99 100644 --- a/data/src/main/java/com/sdm/data/service/impl/ModelServiceImpl.java +++ b/data/src/main/java/com/sdm/data/service/impl/ModelServiceImpl.java @@ -6,15 +6,11 @@ import com.sdm.common.common.SdmResponse; import com.sdm.data.model.entity.FileMetadataInfo; import com.sdm.data.model.entity.TrainingModel; import com.sdm.data.model.entity.TrainingModelAlgorithmParam; -import com.sdm.data.model.req.AddModelReq; -import com.sdm.data.model.req.AlgorithmParamReq; -import com.sdm.data.model.req.HandleLoadDataReq; -import com.sdm.data.model.req.ModelPredictReq; +import com.sdm.data.model.req.*; import com.sdm.data.service.*; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.BeanUtils; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Value; import org.springframework.mock.web.MockMultipartFile; import org.springframework.stereotype.Service; import org.springframework.transaction.annotation.Transactional; @@ -26,45 +22,84 @@ import java.io.OutputStream; import java.nio.file.Files; import java.util.Date; import java.util.List; +import java.util.Random; @Service @Slf4j public class ModelServiceImpl implements IModelService { /** - * 数据处理文件存放路径 + * 训练模型基础文件夹路径(脚本、训练数据、处理结果、训练模型) */ - private static final String TANING_MODEL_DATA_FILE_PATH = "/home/app/model/dataHandler/"; + private static final String TANING_MODEL_BASE_DIR_PATH = "/home/app/model/"; + + /** + * 模型id前缀 + */ + private static final String MODEL_ID = "/ModelId_"; /** * 数据处理脚本存放路径 */ - private static final String DATA_HANDLER_PYTHON_SCRIPT_PATH = "/home/app/model/ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Data_Handler/Data_Load.py"; + private static final String DATA_HANDLER_PYTHON_SCRIPT_PATH = TANING_MODEL_BASE_DIR_PATH + "/ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Data_Handler/Data_Load.py"; - /** - * 数据处理结果文件名 - */ - private static final String SOURCE_JSON_FILE_NAME = "source.json"; - - /** - * 训练结果文件存放路径 - */ - private static final String TRAIN_RESULT_FILE_PATH = "/home/app/model/trainResult/"; /** * 训练脚本存放路径 */ - private static final String TRAINING_PYTHON_SCRIPTPATH = "/home/app/model/ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Data_Handler/Train_Proxy_Model.py"; + private static final String TRAINING_PYTHON_SCRIPTPATH = TANING_MODEL_BASE_DIR_PATH + "/ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Train/Train_Proxy_Model.py"; /** - * 训练结果文件名 + * 预测脚本存放路径 + */ + private static final String PRED_PYTHON_SCRIPTPATH = TANING_MODEL_BASE_DIR_PATH + "/ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Predict/Model_Pred.py"; + + + /** + * 数据处理文件夹放路径 + */ + private static final String DATA_HANDLE_DIR_PATH = "/dataHandler/"; + + /** + * 训练结果文件夹路径 + */ + private static final String TRAIN_RESULT_DIR_PATH = "/trainResult/"; + + /** + * 预测结果文件夹路径 + */ + private static final String PRED_RESULT_DIR_PATH = "/predResult/"; + + + /** + * dataHandler 数据处理结果文件名 + */ + private static final String SOURCE_JSON_FILE_NAME = "source.json"; + + /** + * trainResult 训练结果文件名 */ private static final String TRAINING_JSON_FILE_NAME = "training.json"; /** - * 训练日志文件 + * trainResult 训练日志文件 */ private static final String TRAINING_LOG_FILE_NAME = "training.log"; + /** + * trainResult 训练模型文件名前缀 + */ + private static final String TRAINING_MODEL_PERFIX_NAME = "model."; + + /** + * 预测结果文件名 + */ + private static final String PRED_RESULT_FILE_NAME = "forecast.json"; + + /** + * 脚本 入参文件名 + */ + private static final String PYTHON_SCRIPT_PARAM_FILE_NAME = "param.json"; + @Autowired IDataFileService dataFileService; @@ -115,11 +150,12 @@ public class ModelServiceImpl implements IModelService { log.info("训练模型关联训练数据文件成功,模型ID: {}", trainingModelId); - // 直接将MultipartFile 训练数据文件 写入Linux服务器 - String taningModelDataFilePath = writeToLinuxServer(handleLoadDataReq.getFile(), TANING_MODEL_DATA_FILE_PATH); + // 直接将MultipartFile 训练数据文件 写入Linux服务器的对应模型目录下:/home/app/model/ModelId_1/dataHandler/,并返回文件路径 + String handleLoadModelDirPath = TANING_MODEL_BASE_DIR_PATH + MODEL_ID + trainingModelId + "/" + DATA_HANDLE_DIR_PATH; + String taningModelDataFilePath = writeToLinuxServer(handleLoadDataReq.getFile(), handleLoadModelDirPath); // 根据生成的文件路径创建python脚本 的param.json参数文件 - String paramJsonPath = createParamJsonFile(taningModelDataFilePath); + String paramJsonPath = createDataHadleParamJsonFile(taningModelDataFilePath); // 更新训练模型 训练状态 @@ -136,7 +172,7 @@ public class ModelServiceImpl implements IModelService { } @Override - public SdmResponse getHandleLoadDataResult(String modelId) { + public SdmResponse getHandleLoadDataResult(Integer modelId) { // 根据模型ID查询处理结果,再根据结果获取文件ID,再从MinIO中下载json文件,处理json文件返回一个对象 TrainingModel trainingModel = trainingModelService.getById(modelId); if (trainingModel == null) { @@ -202,13 +238,69 @@ public class ModelServiceImpl implements IModelService { log.info("Python脚本执行完成, 训练模型ID: {}, 耗时: {} ms", trainingModelId, duration); // 上传处理完的python数据到minio - String resultFilePath = TANING_MODEL_DATA_FILE_PATH + SOURCE_JSON_FILE_NAME; + String resultFilePath = TANING_MODEL_BASE_DIR_PATH + MODEL_ID + trainingModelId + "/" + DATA_HANDLE_DIR_PATH + SOURCE_JSON_FILE_NAME; Integer fileId = uploadResultFileToMinio(resultFilePath); + InputStream minioInputStream = dataFileService.getMinioInputStream(fileId); + + // 直接在内存中读取并解析JSON + BufferedReader reader = new BufferedReader(new InputStreamReader(minioInputStream)); + StringBuilder jsonString = new StringBuilder(); + String line; + while ((line = reader.readLine()) != null) { + jsonString.append(line); + } + + // 将JSON字符串转换为JSONObject或其他Java对象,获取表头总数量 + JSONObject jsonObject = JSONObject.parseObject(jsonString.toString()); + + JSONArray averageDataArray = jsonObject.getJSONArray("average_data"); + // 循环处理averageDataArray + + JSONArray maxValues = null; + JSONArray minValues = null; + + for (int i = 0; i < averageDataArray.size(); i++) { + JSONObject averageData = averageDataArray.getJSONObject(i); + if (averageData.getString("property").equals("最大值")) { + // 提取最大值数组,排除property字段 + maxValues = new JSONArray(); + // 获取所有字段名 + java.util.Set keySet = averageData.keySet(); + for (String key : keySet) { + if (!"property".equals(key)) { + maxValues.add(averageData.getDoubleValue(key)); + } + } + } else if (averageData.getString("property").equals("最小值")) { + // 提取最小值数组,排除property字段 + minValues = new JSONArray(); + // 获取所有字段名 + java.util.Set keySet = averageData.keySet(); + for (String key : keySet) { + if (!"property".equals(key)) { + minValues.add(averageData.getDoubleValue(key)); + } + } + } + } + + // 将数组转换为JSON字符串 + String normalizerMaxJson = null; + String normalizerMinJson = null; + if (maxValues != null) { + normalizerMaxJson = maxValues.toJSONString(); + } + if (minValues != null) { + normalizerMinJson = minValues.toJSONString(); + } + // 更新训练模型 训练状态 trainingModelService.lambdaUpdate().eq(TrainingModel::getId, trainingModelId) .set(TrainingModel::getHandleStatus, "成功") .set(TrainingModel::getTrainingDataHandleFileId, fileId) + .set(TrainingModel::getNormalizerMax, normalizerMaxJson) + .set(TrainingModel::getNormalizerMin, normalizerMinJson) .update(); log.info("训练模型ID: {}数据处理完成,处理结果文件ID:{},总耗时: {} ms", trainingModelId, fileId, duration); @@ -300,12 +392,12 @@ public class ModelServiceImpl implements IModelService { } /** - * 根据数据文件路径创建param.json配置文件 + * 根据数据文件路径创建数据处理的param.json配置文件 * * @param dataFilePath 数据文件的完整路径 * @return param.json的完整路径 */ - private String createParamJsonFile(String dataFilePath) { + private String createDataHadleParamJsonFile(String dataFilePath) { try { // 提取文件名和目录路径 File dataFile = new File(dataFilePath); @@ -313,7 +405,7 @@ public class ModelServiceImpl implements IModelService { String directoryPath = dataFile.getParent(); // 构造param.json文件路径 - String paramJsonPath = directoryPath + "/param.json"; + String paramJsonPath = directoryPath + "/" + PYTHON_SCRIPT_PARAM_FILE_NAME; // 创建JSON对象 JSONObject paramJson = new JSONObject(); @@ -339,9 +431,10 @@ public class ModelServiceImpl implements IModelService { /** * 下载文件到服务器本地 - * @param fileId 文件ID + * + * @param fileId 文件ID * @param localDirPath 本地文件夹路径 - * @return 本地文件的完整路径 + * @return 本地文件的完整路径 localDirPath / fileName */ private String downloadToLocalFromMinIO(Integer fileId, String localDirPath) { InputStream in = null; @@ -360,13 +453,13 @@ public class ModelServiceImpl implements IModelService { } String localFilePath = localDirPath + "/" + fileMetadataInfo.getOriginalName(); // 确保目标目录存在 - File localFile = new File(localDirPath); + File localFile = new File(localFilePath); File parentDir = localFile.getParentFile(); if (parentDir != null && !parentDir.exists()) { parentDir.mkdirs(); } - out = new FileOutputStream(localDirPath); + out = new FileOutputStream(localFilePath); // 将输入流写入本地文件 byte[] buffer = new byte[1024]; int len; @@ -442,6 +535,37 @@ public class ModelServiceImpl implements IModelService { } } + @Override + public SdmResponse setTrainingDataInPutOutPutColumn(SetTrainingDataInPutOutPutColumnReq req) { + trainingModelService.lambdaUpdate() + .eq(TrainingModel::getId, req.getModelId()) + .set(TrainingModel::getInputSize, req.getInputColumns().size()) + .set(TrainingModel::getInputLabel, JSONArray.toJSONString(req.getInputColumns())) + .set(TrainingModel::getOutputSize, req.getOutputColumns().size()) + .set(TrainingModel::getOutputLabel, JSONArray.toJSONString(req.getOutputColumns())) + .update(); + return SdmResponse.success(); + } + + @Override + public SdmResponse getTrainingDataInPutOutPutColumn(Integer modelId) { + TrainingModel trainingModel = trainingModelService.getById(modelId); + // 从数据库获取inputLabel字段值(JSON字符串) + String inputLabelJson = trainingModel.getInputLabel(); + // 解析JSON字符串为列表 + List inputLabels = JSONArray.parseArray(inputLabelJson, String.class); + + // 从数据库获取outputLable字段值(JSON字符串) + String outputLabelJson = trainingModel.getOutputLabel(); + // 解析JSON字符串为列表 + List outputLabels = JSONArray.parseArray(outputLabelJson, String.class); + + JSONObject result = new JSONObject(); + result.put("inputLabels", inputLabels); + result.put("outputLabels", outputLabels); + return SdmResponse.success(result); + } + @Override public SdmResponse submitTraining(AlgorithmParamReq algorithmParamReq) { // 保存或更新算法参数 @@ -469,7 +593,7 @@ public class ModelServiceImpl implements IModelService { String paramJsonPath = createTraningParamJsonFile(algorithmParamReq); // 异步调用python脚本进行训练,并上传结果文件到minio - trainModelAsync(paramJsonPath, algorithmParamReq.getModelId()); + trainModelAsync(paramJsonPath, algorithmParamReq.getModelId(), algorithmParamReq.getExportFormat()); return SdmResponse.success("提交训练成功"); } @@ -480,7 +604,7 @@ public class ModelServiceImpl implements IModelService { * @param paramJsonPath 参数文件路径 * @param modelId 模型ID */ - private void trainModelAsync(String paramJsonPath, Integer modelId) { + private void trainModelAsync(String paramJsonPath, Integer modelId, String exportFormat) { new Thread(() -> { try { long startTime = System.currentTimeMillis(); @@ -491,10 +615,13 @@ public class ModelServiceImpl implements IModelService { long duration = endTime - startTime; log.info("Python脚本执行完成, 训练模型ID: {}, 耗时: {} ms", modelId, duration); - // 将训练生成的生成training.json、training.log文件,上传到minio - Integer trainingDataResultFileId = uploadResultFileToMinio(TRAIN_RESULT_FILE_PATH + TRAINING_JSON_FILE_NAME); - Integer trainingDataLogFileId = uploadResultFileToMinio(TRAIN_RESULT_FILE_PATH + TRAINING_LOG_FILE_NAME); - log.info("模型训练完成,结果文件: training.json、training.log文件上传成功,文件ID: {},{}", trainingDataResultFileId, trainingDataLogFileId); + + // 将训练生成的生成 training.json、training.log、模型文件,上传到minio + Integer trainingDataResultFileId = uploadResultFileToMinio(TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + "/" + TRAIN_RESULT_DIR_PATH + TRAINING_JSON_FILE_NAME); + Integer trainingDataLogFileId = uploadResultFileToMinio(TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + "/" + TRAIN_RESULT_DIR_PATH + TRAINING_LOG_FILE_NAME); + Integer trainingDataExportModelFileId = uploadResultFileToMinio(TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + "/" + TRAIN_RESULT_DIR_PATH + TRAINING_MODEL_PERFIX_NAME + exportFormat); + + log.info("模型训练完成,结果文件: training.json文件ID: {}、training.log文件ID: {}、导出model文件ID: {}文件上传成功", trainingDataResultFileId, trainingDataLogFileId, trainingDataExportModelFileId); // 更新模型状态为训练完成 trainingModelService.lambdaUpdate() @@ -504,8 +631,9 @@ public class ModelServiceImpl implements IModelService { .set(TrainingModel::getTrainingDuration, duration) .set(TrainingModel::getTrainingDataResultFileId, trainingDataResultFileId) .set(TrainingModel::getTrainingDataLogFileId, trainingDataLogFileId) + .set(TrainingModel::getTrainingDataExportModelFileId, trainingDataExportModelFileId) .update(); - + log.info("模型训练完成,模型ID: {}", modelId); } catch (Exception e) { log.error("模型训练失败,模型ID: {}", modelId, e); @@ -520,15 +648,16 @@ public class ModelServiceImpl implements IModelService { /** * 创建训练参数param.json文件 + * * @param algorithmParamReq 算法参数请求 * @return */ private String createTraningParamJsonFile(AlgorithmParamReq algorithmParamReq) { // 根据训练模型id 获取训练模型的训练数据文件id, 并下载到本地 trainingDataFilePath = /home/app/model/trainResult/sample.CSV Integer modelId = algorithmParamReq.getModelId(); - TrainingModel trainingModel= trainingModelService.getById(modelId); + TrainingModel trainingModel = trainingModelService.getById(modelId); Integer trainingDataFileId = trainingModel.getTrainingDataFileId(); - String trainingDataFilePath = downloadToLocalFromMinIO(trainingDataFileId, TRAIN_RESULT_FILE_PATH); + String trainingDataFilePath = downloadToLocalFromMinIO(trainingDataFileId, TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + "/" + TRAIN_RESULT_DIR_PATH); try { // 提取文件名和目录路径 @@ -537,7 +666,7 @@ public class ModelServiceImpl implements IModelService { String directoryPath = dataFile.getParent(); // 构造param.json文件路径 - String paramJsonPath = directoryPath + "/param.json"; + String paramJsonPath = directoryPath + "/" + PYTHON_SCRIPT_PARAM_FILE_NAME; // 创建JSON对象 JSONObject paramJson = new JSONObject(); @@ -549,8 +678,8 @@ public class ModelServiceImpl implements IModelService { // 构造algorithmParam对象 JSONObject algorithmParam = new JSONObject(); - algorithmParam.put("inputSize", algorithmParamReq.getInputSize()); - algorithmParam.put("outputSize", algorithmParamReq.getOutputSize()); + algorithmParam.put("inputSize", trainingModel.getInputSize()); + algorithmParam.put("outputSize", trainingModel.getOutputSize()); algorithmParam.put("algorithm", algorithmParamReq.getAlgorithm()); algorithmParam.put("activateFun", algorithmParamReq.getActivateFun()); algorithmParam.put("lossFun", algorithmParamReq.getLossFun()); @@ -583,47 +712,58 @@ public class ModelServiceImpl implements IModelService { } @Override - public SdmResponse getTrainingResult(String modelId) { + public SdmResponse getTrainingResult(Integer modelId) { // 根据modelId 获取训练模型,获取训练模型结果文件id,日志文件id, TrainingModel trainingModel = trainingModelService.getById(modelId); if (trainingModel == null) { return SdmResponse.failed("模型不存在"); } - if(trainingModel.getTrainingStatus().equals("失败")){ + if (trainingModel.getTrainingStatus().equals("失败")) { return SdmResponse.failed("模型训练失败,请查看日志详情"); } - if(trainingModel.getTrainingStatus().equals("待开始") || trainingModel.getTrainingStatus().equals("训练中")){ + if (trainingModel.getTrainingStatus().equals("待开始") || trainingModel.getTrainingStatus().equals("训练中")) { return SdmResponse.failed("模型正在训练中,请稍后再查询"); } - Integer trainingDataResultFileId = trainingModel.getTrainingDataResultFileId(); Integer trainingDataLogFileId = trainingModel.getTrainingDataLogFileId(); - if (trainingDataResultFileId == null || trainingDataLogFileId == null) { - return SdmResponse.failed("训练结果文件或日志文件不存在"); + if (trainingDataResultFileId == null && trainingDataLogFileId == null) { + return SdmResponse.failed("训练结果文件和日志文件都不存在"); } - // 从minio获取结果json文件转换成对象 - InputStream resultInputStream = dataFileService.getMinioInputStream(trainingDataResultFileId); - // 从minio获取日志文件转换成对象 - InputStream logInputStream = dataFileService.getMinioInputStream(trainingDataLogFileId); + InputStream resultInputStream = null; + InputStream logInputStream = null; - if (resultInputStream == null || logInputStream == null) { + // 从minio获取结果json文件转换成对象 + if (trainingDataResultFileId != null) { + resultInputStream = dataFileService.getMinioInputStream(trainingDataResultFileId); + } + + // 从minio获取日志文件转换成对象 + if (trainingDataLogFileId != null) { + logInputStream = dataFileService.getMinioInputStream(trainingDataLogFileId); + } + + if (resultInputStream == null && logInputStream == null) { return SdmResponse.failed("无法从MinIO获取文件流"); } try { - // 处理结果JSON文件 - JSONObject resultJsonObject = parseJsonFromStream(resultInputStream); - // 处理日志文件 - String logContent = parseStringFromStream(logInputStream); - - // 构造返回对象 JSONObject response = new JSONObject(); - response.put("result", resultJsonObject); - response.put("log", logContent); + + // 处理结果JSON文件(如果存在) + if (resultInputStream != null) { + JSONObject resultJsonObject = parseJsonFromStream(resultInputStream); + response.put("result", resultJsonObject); + } + + // 处理日志文件(如果存在) + if (logInputStream != null) { + String logContent = parseStringFromStream(logInputStream); + response.put("log", logContent); + } return SdmResponse.success(response); } catch (Exception e) { @@ -641,6 +781,7 @@ public class ModelServiceImpl implements IModelService { /** * 从输入流中解析JSON对象 + * * @param inputStream 输入流 * @return 解析后的JSON对象 * @throws IOException IO异常 @@ -657,6 +798,7 @@ public class ModelServiceImpl implements IModelService { /** * 从输入流中读取字符串内容 + * * @param inputStream 输入流 * @return 字符串内容 * @throws IOException IO异常 @@ -672,7 +814,158 @@ public class ModelServiceImpl implements IModelService { } @Override - public SdmResponse predict(ModelPredictReq modelPredictReq) { - return null; + public SdmResponse startPredict(ModelPredictReq modelPredictReq) { + JSONObject predResultObject = new JSONObject(); + + Integer modelId = modelPredictReq.getModelId(); + // 将jsonData存储到与modelId关联的模型表的inputLabel字段中 + trainingModelService.lambdaUpdate() + .eq(TrainingModel::getId, modelPredictReq.getModelId()) + .set(TrainingModel::getInputPredLabelValue, modelPredictReq.getJsonData()) + .update(); + TrainingModel trainingModel = trainingModelService.lambdaQuery().eq(TrainingModel::getId, modelId).one(); + TrainingModelAlgorithmParam trainingModelAlgorithmParam = trainingModelAlgorithmParamService.lambdaQuery().eq(TrainingModelAlgorithmParam::getModelId, modelId).one(); + + String predDirPath = TANING_MODEL_BASE_DIR_PATH + MODEL_ID + trainingModel.getId() + "/" + PRED_RESULT_DIR_PATH; + // 创建预测用的param.json文件 + String predParamJsonFile = createPredParamJsonFile(modelPredictReq, trainingModel, trainingModelAlgorithmParam,predDirPath); + + // 执行预测 + try { + long startTime = System.currentTimeMillis(); + log.info("开始执行Python预测脚本, 模型ID: {}", modelId); + + // 调用python脚本进行预测 + callPythonScript(PRED_PYTHON_SCRIPTPATH, predParamJsonFile); + + long endTime = System.currentTimeMillis(); + long duration = endTime - startTime; + log.info("Python预测脚本执行完成, 模型ID: {}, 耗时: {} ms", modelId, duration); + + // 读取预测生成的结果文件并保存到数据库 + String predResultPath = predDirPath + "/" + PRED_RESULT_FILE_NAME; + File predResultFile = new File(predResultPath); + + if (predResultFile.exists()) { + try { + // 读取预测结果文件内容 + String predResultJson = new String(Files.readAllBytes(predResultFile.toPath())); + predResultObject = JSONObject.parseObject(predResultJson); + + // 将预测结果保存到数据库 + trainingModelService.lambdaUpdate() + .eq(TrainingModel::getId, modelId) + .set(TrainingModel::getOutputPredLabelValue, predResultObject.toJSONString()) + .update(); + log.info("模型预测结果已保存到数据库,模型ID: {}", modelId); + } catch (Exception e) { + log.error("解析或保存预测结果失败,模型ID: {}", modelId, e); + } + } else { + log.warn("预测结果文件不存在: {}", predResultPath); + } + + log.info("模型预测完成,模型ID: {}", modelId); + } catch (Exception e) { + log.error("模型预测失败,模型ID: {}", modelId, e); + } + + return SdmResponse.success(predResultObject); + } + + /** + * 创建预测参数param.json文件 + * + * @param modelPredictReq 预测请求参数 + * @param trainingModel 训练模型 + * @param trainingModelAlgorithmParam 训练模型算法参数 + * @param predDirPath 预测目录路径 + * @return param.json文件路径 + */ + private String createPredParamJsonFile(ModelPredictReq modelPredictReq, TrainingModel trainingModel, TrainingModelAlgorithmParam trainingModelAlgorithmParam,String predDirPath) { + try { + // 确保目录存在 + File predDir = new File(predDirPath); + if (!predDir.exists()) { + predDir.mkdirs(); + } + + // 构造param.json文件路径 + String paramJsonPath = predDirPath + "/" + PYTHON_SCRIPT_PARAM_FILE_NAME; + + // 创建JSON对象 + JSONObject paramJson = new JSONObject(); + + // 下载model模型文件到本地 + downloadToLocalFromMinIO(trainingModel.getTrainingDataExportModelFileId(), predDirPath); + paramJson.put("modelFile", TRAINING_MODEL_PERFIX_NAME + trainingModelAlgorithmParam.getExportFormat()); + + // 设置路径 + paramJson.put("path", predDirPath); + + // 添加模型参数 + JSONObject modelParams = new JSONObject(); + modelParams.put("inputSize", trainingModel.getInputSize()); + modelParams.put("outputSize", trainingModel.getOutputSize()); + modelParams.put("normalizerType", trainingModelAlgorithmParam.getDisposeMethod()); + + // 添加归一化参数 + if (trainingModel.getNormalizerMax() != null) { + JSONArray normalizerMaxArray = JSONArray.parseArray(trainingModel.getNormalizerMax()); + modelParams.put("normalizerMax", normalizerMaxArray); + } + + if (trainingModel.getNormalizerMin() != null) { + JSONArray normalizerMinArray = JSONArray.parseArray(trainingModel.getNormalizerMin()); + modelParams.put("normalizerMin", normalizerMinArray); + } + + paramJson.put("modelParams", modelParams); + + // 添加输入数据 + // 直接使用传入的JSON数组作为input值 + JSONArray inputArray = JSONArray.parseArray(modelPredictReq.getJsonData()); + paramJson.put("input", inputArray); + + // 添加输出标签 + JSONObject output = new JSONObject(); + if (trainingModel.getOutputLabel() != null) { + JSONArray outputLabels = JSONArray.parseArray(trainingModel.getOutputLabel()); + output.put("names", outputLabels); + } + paramJson.put("output", output); + + // 写入param.json文件 + try (FileWriter fileWriter = new FileWriter(paramJsonPath)) { + fileWriter.write(paramJson.toJSONString()); + } + + log.info("成功创建预测param.json文件: {}", paramJsonPath); + return paramJsonPath; + } catch (IOException e) { + log.error("创建预测param.json文件失败", e); + throw new RuntimeException("创建预测param.json文件失败: " + e.getMessage(), e); + } + } + + @Override + public SdmResponse getModelPredictResult(Integer modelId) { + JSONObject result = new JSONObject(); + TrainingModel model = trainingModelService.lambdaQuery().eq(TrainingModel::getId, modelId).one(); + if (model == null) { + return SdmResponse.failed("模型不存在"); + } + String inputPredLabelValue = model.getInputPredLabelValue(); + if (inputPredLabelValue != null) { + JSONObject inputPredLabelValueJson = JSONObject.parseObject(inputPredLabelValue); + result.put("inputPredLabelValue", inputPredLabelValueJson); + } + String outputPredLabelValue = model.getOutputPredLabelValue(); + if (outputPredLabelValue != null) { + JSONObject outputPredLabelValueJson = JSONObject.parseObject(outputPredLabelValue); + result.put("inputPredLabelValue", outputPredLabelValueJson); + } + + return SdmResponse.success(result); } }