训练模型
This commit is contained in:
@@ -1,10 +1,7 @@
|
|||||||
package com.sdm.data.controller;
|
package com.sdm.data.controller;
|
||||||
|
|
||||||
import com.sdm.common.common.SdmResponse;
|
import com.sdm.common.common.SdmResponse;
|
||||||
import com.sdm.data.model.req.AddModelReq;
|
import com.sdm.data.model.req.*;
|
||||||
import com.sdm.data.model.req.AlgorithmParamReq;
|
|
||||||
import com.sdm.data.model.req.HandleLoadDataReq;
|
|
||||||
import com.sdm.data.model.req.ModelPredictReq;
|
|
||||||
import com.sdm.data.service.IModelService;
|
import com.sdm.data.service.IModelService;
|
||||||
import io.swagger.v3.oas.annotations.Operation;
|
import io.swagger.v3.oas.annotations.Operation;
|
||||||
import io.swagger.v3.oas.annotations.Parameter;
|
import io.swagger.v3.oas.annotations.Parameter;
|
||||||
@@ -28,7 +25,7 @@ public class ModelTraningController {
|
|||||||
@PostMapping("/addModel")
|
@PostMapping("/addModel")
|
||||||
@Operation(summary = "新增训练模型", description = "新增训练模型")
|
@Operation(summary = "新增训练模型", description = "新增训练模型")
|
||||||
public SdmResponse addModel(@RequestBody AddModelReq addModelReq) {
|
public SdmResponse addModel(@RequestBody AddModelReq addModelReq) {
|
||||||
return modelService.addModel(addModelReq);
|
return modelService.addModel(addModelReq);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -46,10 +43,28 @@ public class ModelTraningController {
|
|||||||
*/
|
*/
|
||||||
@GetMapping("/getHandleLoadDataResult")
|
@GetMapping("/getHandleLoadDataResult")
|
||||||
@Operation(summary = "获取python脚本训练数据处理结果", description = "获取python脚本训练数据处理结果")
|
@Operation(summary = "获取python脚本训练数据处理结果", description = "获取python脚本训练数据处理结果")
|
||||||
public SdmResponse getHandleLoadDataResult(@RequestParam String modelId) {
|
public SdmResponse getHandleLoadDataResult(@RequestParam Integer modelId) {
|
||||||
return modelService.getHandleLoadDataResult(modelId);
|
return modelService.getHandleLoadDataResult(modelId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置训练数据输入输出特征列
|
||||||
|
*/
|
||||||
|
@PostMapping("/setTrainingDataInPutOutPutColumn")
|
||||||
|
@Operation(summary = "设置训练数据输入输出特征列", description = "设置训练数据输入输出特征列")
|
||||||
|
public SdmResponse setTrainingDataInPutOutPutColumn(@RequestBody SetTrainingDataInPutOutPutColumnReq req) {
|
||||||
|
return modelService.setTrainingDataInPutOutPutColumn(req);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取训练数据输入输出特征设置详情
|
||||||
|
*/
|
||||||
|
@GetMapping("/getTrainingDataInPutOutPutColumn")
|
||||||
|
@Operation(summary = "获取训练数据输入输出特征设置详情", description = "获取训练数据输入输出特征设置详情")
|
||||||
|
public SdmResponse getTrainingDataInPutOutPutColumn(@RequestParam Integer modelId) {
|
||||||
|
return modelService.getTrainingDataInPutOutPutColumn(modelId);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 算法参数设置并提交训练
|
* 算法参数设置并提交训练
|
||||||
*/
|
*/
|
||||||
@@ -61,20 +76,32 @@ public class ModelTraningController {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取训练曲线和训练日志
|
* 获取训练曲线和训练日志
|
||||||
|
*
|
||||||
* @param modelId
|
* @param modelId
|
||||||
*/
|
*/
|
||||||
@GetMapping("/getTrainingResult")
|
@GetMapping("/getTrainingResult")
|
||||||
@Operation(summary = "获取训练曲线和训练日志", description = "获取训练曲线和训练日志")
|
@Operation(summary = "获取训练曲线和训练日志", description = "获取训练曲线和训练日志")
|
||||||
public SdmResponse getTrainingResult(@RequestParam String modelId){
|
public SdmResponse getTrainingResult(@RequestParam Integer modelId) {
|
||||||
return modelService.getTrainingResult(modelId);
|
return modelService.getTrainingResult(modelId);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 数据预测
|
* 开始数据预测
|
||||||
*/
|
*/
|
||||||
@GetMapping("/predict")
|
@PostMapping("/startPredict")
|
||||||
@Operation(summary = "数据预测", description = "数据预测")
|
@Operation(summary = "数据预测", description = "数据预测")
|
||||||
public SdmResponse predict(@Parameter(description = "数据预测请求对象") @RequestBody ModelPredictReq modelPredictReq) throws IOException {
|
public SdmResponse startPredict(@Parameter(description = "数据预测请求对象") @RequestBody ModelPredictReq modelPredictReq) {
|
||||||
return modelService.predict(modelPredictReq);
|
return modelService.startPredict(modelPredictReq);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 进入模型预测页面,获取历史模型预测结果
|
||||||
|
*/
|
||||||
|
@GetMapping("/getModelPredictResult")
|
||||||
|
@Operation(summary = "进入模型预测页面,获取历史模型预测结果", description = "进入模型预测页面,获取历史模型预测结果")
|
||||||
|
public SdmResponse getModelPredictResult(@RequestParam Integer modelId) {
|
||||||
|
|
||||||
|
return modelService.getModelPredictResult(modelId);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -73,6 +73,39 @@ public class TrainingModel implements Serializable {
|
|||||||
@TableField("trainingDataHandleFileId")
|
@TableField("trainingDataHandleFileId")
|
||||||
private Integer trainingDataHandleFileId;
|
private Integer trainingDataHandleFileId;
|
||||||
|
|
||||||
|
@ApiModelProperty(value = " 归一化最大值json格式")
|
||||||
|
@TableField("normalizerMax")
|
||||||
|
private String normalizerMax;
|
||||||
|
|
||||||
|
@ApiModelProperty(value = " 归一化最小值json格式")
|
||||||
|
@TableField("normalizerMin")
|
||||||
|
private String normalizerMin;
|
||||||
|
|
||||||
|
|
||||||
|
@ApiModelProperty(value = "特征输入列数量")
|
||||||
|
@TableField("inputSize")
|
||||||
|
Integer inputSize;
|
||||||
|
|
||||||
|
@ApiModelProperty(value = "输入列的json格式")
|
||||||
|
@TableField("inputLabel")
|
||||||
|
private String inputLabel;
|
||||||
|
|
||||||
|
@ApiModelProperty(value = "输入的列-值的json格式")
|
||||||
|
@TableField("inputPredLabelValue")
|
||||||
|
private String inputPredLabelValue;
|
||||||
|
|
||||||
|
@ApiModelProperty(value = " 特征输出列数量")
|
||||||
|
@TableField("outputSize")
|
||||||
|
Integer outputSize;
|
||||||
|
|
||||||
|
@ApiModelProperty(value = "输出列的JSON格式")
|
||||||
|
@TableField("outputLabel")
|
||||||
|
private String outputLabel;
|
||||||
|
|
||||||
|
@ApiModelProperty(value = "输出的列-值的json格式")
|
||||||
|
@TableField("outputPredLabelValue")
|
||||||
|
private String outputPredLabelValue;
|
||||||
|
|
||||||
@ApiModelProperty(value = "training脚本处理数据结果文件id")
|
@ApiModelProperty(value = "training脚本处理数据结果文件id")
|
||||||
@TableField("trainingDataResultFileId")
|
@TableField("trainingDataResultFileId")
|
||||||
private Integer trainingDataResultFileId;
|
private Integer trainingDataResultFileId;
|
||||||
@@ -81,6 +114,10 @@ public class TrainingModel implements Serializable {
|
|||||||
@TableField("trainingDataLogFileId")
|
@TableField("trainingDataLogFileId")
|
||||||
private Integer trainingDataLogFileId;
|
private Integer trainingDataLogFileId;
|
||||||
|
|
||||||
|
@ApiModelProperty(value = "training脚本处理后模型文件id")
|
||||||
|
@TableField("trainingDataExportModelFileId")
|
||||||
|
private Integer trainingDataExportModelFileId;
|
||||||
|
|
||||||
@ApiModelProperty(value = "使用算法")
|
@ApiModelProperty(value = "使用算法")
|
||||||
@TableField("algorithmUsed")
|
@TableField("algorithmUsed")
|
||||||
private String algorithmUsed;
|
private String algorithmUsed;
|
||||||
@@ -88,6 +125,4 @@ public class TrainingModel implements Serializable {
|
|||||||
@ApiModelProperty(value = "说明")
|
@ApiModelProperty(value = "说明")
|
||||||
@TableField("description")
|
@TableField("description")
|
||||||
private String description;
|
private String description;
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,12 +13,6 @@ public class AlgorithmParamReq {
|
|||||||
@Schema(description = "模型id")
|
@Schema(description = "模型id")
|
||||||
private Integer modelId;
|
private Integer modelId;
|
||||||
|
|
||||||
@Schema(description = "输入大小", example = "10")
|
|
||||||
Integer inputSize;
|
|
||||||
|
|
||||||
@Schema(description = "输出大小", example = "1")
|
|
||||||
Integer outputSize;
|
|
||||||
|
|
||||||
@Schema(description = "算法类型", example = "多项式拟合")
|
@Schema(description = "算法类型", example = "多项式拟合")
|
||||||
private String algorithm;
|
private String algorithm;
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,14 @@ import lombok.Data;
|
|||||||
@Data
|
@Data
|
||||||
public class ModelPredictReq {
|
public class ModelPredictReq {
|
||||||
@Schema(description = "选择的预测模型ID")
|
@Schema(description = "选择的预测模型ID")
|
||||||
String modelId;
|
Integer modelId;
|
||||||
|
|
||||||
@Schema(description = "json数据")
|
@Schema(description = "json数据: [{\n" +
|
||||||
|
" \"name\": \"param1\",\n" +
|
||||||
|
" \"value\": 0.1\n" +
|
||||||
|
" }, {\n" +
|
||||||
|
" \"name\": \"param1\",\n" +
|
||||||
|
" \"value\": 371.6669936\n" +
|
||||||
|
" }]")
|
||||||
String jsonData;
|
String jsonData;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,19 @@
|
|||||||
|
package com.sdm.data.model.req;
|
||||||
|
|
||||||
|
import io.swagger.v3.oas.annotations.media.Schema;
|
||||||
|
import lombok.Data;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class SetTrainingDataInPutOutPutColumnReq {
|
||||||
|
@Schema(description = "模型ID")
|
||||||
|
Integer modelId;
|
||||||
|
|
||||||
|
@Schema(description = "输入特征列")
|
||||||
|
List<String> inputColumns;
|
||||||
|
|
||||||
|
@Schema(description = "输出特征列")
|
||||||
|
List<String> outputColumns;
|
||||||
|
}
|
||||||
@@ -1,12 +1,7 @@
|
|||||||
package com.sdm.data.service;
|
package com.sdm.data.service;
|
||||||
|
|
||||||
import com.sdm.common.common.SdmResponse;
|
import com.sdm.common.common.SdmResponse;
|
||||||
import com.sdm.data.model.req.AddModelReq;
|
import com.sdm.data.model.req.*;
|
||||||
import com.sdm.data.model.req.AlgorithmParamReq;
|
|
||||||
import com.sdm.data.model.req.HandleLoadDataReq;
|
|
||||||
import com.sdm.data.model.req.ModelPredictReq;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 模型训练服务接口
|
* 模型训练服务接口
|
||||||
@@ -25,7 +20,17 @@ public interface IModelService {
|
|||||||
/**
|
/**
|
||||||
* 获取python脚本训练数据处理结果
|
* 获取python脚本训练数据处理结果
|
||||||
*/
|
*/
|
||||||
SdmResponse getHandleLoadDataResult(String modelId);
|
SdmResponse getHandleLoadDataResult(Integer modelId);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置训练数据输入输出特征列
|
||||||
|
*/
|
||||||
|
SdmResponse setTrainingDataInPutOutPutColumn(SetTrainingDataInPutOutPutColumnReq req);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取训练数据输入输出特征设置详情
|
||||||
|
*/
|
||||||
|
SdmResponse getTrainingDataInPutOutPutColumn(Integer modelId);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 模型参数设置并提交训练
|
* 模型参数设置并提交训练
|
||||||
@@ -38,12 +43,19 @@ public interface IModelService {
|
|||||||
* @param modelId 模型ID
|
* @param modelId 模型ID
|
||||||
* @return 模型训练结果
|
* @return 模型训练结果
|
||||||
*/
|
*/
|
||||||
SdmResponse getTrainingResult(String modelId);
|
SdmResponse getTrainingResult(Integer modelId);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 数据预测
|
* 数据预测
|
||||||
* @param modelPredictReq 模型预测请求对象
|
* @param modelPredictReq 模型预测请求对象
|
||||||
* @return 模型预测结果
|
* @return 模型预测结果
|
||||||
*/
|
*/
|
||||||
SdmResponse predict(ModelPredictReq modelPredictReq);
|
SdmResponse startPredict(ModelPredictReq modelPredictReq);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取模型预测结果
|
||||||
|
* @param modelId
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
SdmResponse getModelPredictResult(Integer modelId);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,15 +6,11 @@ import com.sdm.common.common.SdmResponse;
|
|||||||
import com.sdm.data.model.entity.FileMetadataInfo;
|
import com.sdm.data.model.entity.FileMetadataInfo;
|
||||||
import com.sdm.data.model.entity.TrainingModel;
|
import com.sdm.data.model.entity.TrainingModel;
|
||||||
import com.sdm.data.model.entity.TrainingModelAlgorithmParam;
|
import com.sdm.data.model.entity.TrainingModelAlgorithmParam;
|
||||||
import com.sdm.data.model.req.AddModelReq;
|
import com.sdm.data.model.req.*;
|
||||||
import com.sdm.data.model.req.AlgorithmParamReq;
|
|
||||||
import com.sdm.data.model.req.HandleLoadDataReq;
|
|
||||||
import com.sdm.data.model.req.ModelPredictReq;
|
|
||||||
import com.sdm.data.service.*;
|
import com.sdm.data.service.*;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import org.springframework.beans.BeanUtils;
|
import org.springframework.beans.BeanUtils;
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
import org.springframework.beans.factory.annotation.Autowired;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
|
||||||
import org.springframework.mock.web.MockMultipartFile;
|
import org.springframework.mock.web.MockMultipartFile;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
import org.springframework.transaction.annotation.Transactional;
|
import org.springframework.transaction.annotation.Transactional;
|
||||||
@@ -26,45 +22,84 @@ import java.io.OutputStream;
|
|||||||
import java.nio.file.Files;
|
import java.nio.file.Files;
|
||||||
import java.util.Date;
|
import java.util.Date;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import java.util.Random;
|
||||||
|
|
||||||
@Service
|
@Service
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class ModelServiceImpl implements IModelService {
|
public class ModelServiceImpl implements IModelService {
|
||||||
/**
|
/**
|
||||||
* 数据处理文件存放路径
|
* 训练模型基础文件夹路径(脚本、训练数据、处理结果、训练模型)
|
||||||
*/
|
*/
|
||||||
private static final String TANING_MODEL_DATA_FILE_PATH = "/home/app/model/dataHandler/";
|
private static final String TANING_MODEL_BASE_DIR_PATH = "/home/app/model/";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 模型id前缀
|
||||||
|
*/
|
||||||
|
private static final String MODEL_ID = "/ModelId_";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 数据处理脚本存放路径
|
* 数据处理脚本存放路径
|
||||||
*/
|
*/
|
||||||
private static final String DATA_HANDLER_PYTHON_SCRIPT_PATH = "/home/app/model/ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Data_Handler/Data_Load.py";
|
private static final String DATA_HANDLER_PYTHON_SCRIPT_PATH = TANING_MODEL_BASE_DIR_PATH + "/ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Data_Handler/Data_Load.py";
|
||||||
|
|
||||||
/**
|
|
||||||
* 数据处理结果文件名
|
|
||||||
*/
|
|
||||||
private static final String SOURCE_JSON_FILE_NAME = "source.json";
|
|
||||||
|
|
||||||
/**
|
|
||||||
* 训练结果文件存放路径
|
|
||||||
*/
|
|
||||||
private static final String TRAIN_RESULT_FILE_PATH = "/home/app/model/trainResult/";
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 训练脚本存放路径
|
* 训练脚本存放路径
|
||||||
*/
|
*/
|
||||||
private static final String TRAINING_PYTHON_SCRIPTPATH = "/home/app/model/ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Data_Handler/Train_Proxy_Model.py";
|
private static final String TRAINING_PYTHON_SCRIPTPATH = TANING_MODEL_BASE_DIR_PATH + "/ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Train/Train_Proxy_Model.py";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 训练结果文件名
|
* 预测脚本存放路径
|
||||||
|
*/
|
||||||
|
private static final String PRED_PYTHON_SCRIPTPATH = TANING_MODEL_BASE_DIR_PATH + "/ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Predict/Model_Pred.py";
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 数据处理文件夹放路径
|
||||||
|
*/
|
||||||
|
private static final String DATA_HANDLE_DIR_PATH = "/dataHandler/";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 训练结果文件夹路径
|
||||||
|
*/
|
||||||
|
private static final String TRAIN_RESULT_DIR_PATH = "/trainResult/";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 预测结果文件夹路径
|
||||||
|
*/
|
||||||
|
private static final String PRED_RESULT_DIR_PATH = "/predResult/";
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* dataHandler 数据处理结果文件名
|
||||||
|
*/
|
||||||
|
private static final String SOURCE_JSON_FILE_NAME = "source.json";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* trainResult 训练结果文件名
|
||||||
*/
|
*/
|
||||||
private static final String TRAINING_JSON_FILE_NAME = "training.json";
|
private static final String TRAINING_JSON_FILE_NAME = "training.json";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 训练日志文件
|
* trainResult 训练日志文件
|
||||||
*/
|
*/
|
||||||
private static final String TRAINING_LOG_FILE_NAME = "training.log";
|
private static final String TRAINING_LOG_FILE_NAME = "training.log";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* trainResult 训练模型文件名前缀
|
||||||
|
*/
|
||||||
|
private static final String TRAINING_MODEL_PERFIX_NAME = "model.";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 预测结果文件名
|
||||||
|
*/
|
||||||
|
private static final String PRED_RESULT_FILE_NAME = "forecast.json";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 脚本 入参文件名
|
||||||
|
*/
|
||||||
|
private static final String PYTHON_SCRIPT_PARAM_FILE_NAME = "param.json";
|
||||||
|
|
||||||
|
|
||||||
@Autowired
|
@Autowired
|
||||||
IDataFileService dataFileService;
|
IDataFileService dataFileService;
|
||||||
@@ -115,11 +150,12 @@ public class ModelServiceImpl implements IModelService {
|
|||||||
log.info("训练模型关联训练数据文件成功,模型ID: {}", trainingModelId);
|
log.info("训练模型关联训练数据文件成功,模型ID: {}", trainingModelId);
|
||||||
|
|
||||||
|
|
||||||
// 直接将MultipartFile 训练数据文件 写入Linux服务器
|
// 直接将MultipartFile 训练数据文件 写入Linux服务器的对应模型目录下:/home/app/model/ModelId_1/dataHandler/,并返回文件路径
|
||||||
String taningModelDataFilePath = writeToLinuxServer(handleLoadDataReq.getFile(), TANING_MODEL_DATA_FILE_PATH);
|
String handleLoadModelDirPath = TANING_MODEL_BASE_DIR_PATH + MODEL_ID + trainingModelId + "/" + DATA_HANDLE_DIR_PATH;
|
||||||
|
String taningModelDataFilePath = writeToLinuxServer(handleLoadDataReq.getFile(), handleLoadModelDirPath);
|
||||||
|
|
||||||
// 根据生成的文件路径创建python脚本 的param.json参数文件
|
// 根据生成的文件路径创建python脚本 的param.json参数文件
|
||||||
String paramJsonPath = createParamJsonFile(taningModelDataFilePath);
|
String paramJsonPath = createDataHadleParamJsonFile(taningModelDataFilePath);
|
||||||
|
|
||||||
|
|
||||||
// 更新训练模型 训练状态
|
// 更新训练模型 训练状态
|
||||||
@@ -136,7 +172,7 @@ public class ModelServiceImpl implements IModelService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SdmResponse getHandleLoadDataResult(String modelId) {
|
public SdmResponse getHandleLoadDataResult(Integer modelId) {
|
||||||
// 根据模型ID查询处理结果,再根据结果获取文件ID,再从MinIO中下载json文件,处理json文件返回一个对象
|
// 根据模型ID查询处理结果,再根据结果获取文件ID,再从MinIO中下载json文件,处理json文件返回一个对象
|
||||||
TrainingModel trainingModel = trainingModelService.getById(modelId);
|
TrainingModel trainingModel = trainingModelService.getById(modelId);
|
||||||
if (trainingModel == null) {
|
if (trainingModel == null) {
|
||||||
@@ -202,13 +238,69 @@ public class ModelServiceImpl implements IModelService {
|
|||||||
log.info("Python脚本执行完成, 训练模型ID: {}, 耗时: {} ms", trainingModelId, duration);
|
log.info("Python脚本执行完成, 训练模型ID: {}, 耗时: {} ms", trainingModelId, duration);
|
||||||
|
|
||||||
// 上传处理完的python数据到minio
|
// 上传处理完的python数据到minio
|
||||||
String resultFilePath = TANING_MODEL_DATA_FILE_PATH + SOURCE_JSON_FILE_NAME;
|
String resultFilePath = TANING_MODEL_BASE_DIR_PATH + MODEL_ID + trainingModelId + "/" + DATA_HANDLE_DIR_PATH + SOURCE_JSON_FILE_NAME;
|
||||||
Integer fileId = uploadResultFileToMinio(resultFilePath);
|
Integer fileId = uploadResultFileToMinio(resultFilePath);
|
||||||
|
|
||||||
|
InputStream minioInputStream = dataFileService.getMinioInputStream(fileId);
|
||||||
|
|
||||||
|
// 直接在内存中读取并解析JSON
|
||||||
|
BufferedReader reader = new BufferedReader(new InputStreamReader(minioInputStream));
|
||||||
|
StringBuilder jsonString = new StringBuilder();
|
||||||
|
String line;
|
||||||
|
while ((line = reader.readLine()) != null) {
|
||||||
|
jsonString.append(line);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将JSON字符串转换为JSONObject或其他Java对象,获取表头总数量
|
||||||
|
JSONObject jsonObject = JSONObject.parseObject(jsonString.toString());
|
||||||
|
|
||||||
|
JSONArray averageDataArray = jsonObject.getJSONArray("average_data");
|
||||||
|
// 循环处理averageDataArray
|
||||||
|
|
||||||
|
JSONArray maxValues = null;
|
||||||
|
JSONArray minValues = null;
|
||||||
|
|
||||||
|
for (int i = 0; i < averageDataArray.size(); i++) {
|
||||||
|
JSONObject averageData = averageDataArray.getJSONObject(i);
|
||||||
|
if (averageData.getString("property").equals("最大值")) {
|
||||||
|
// 提取最大值数组,排除property字段
|
||||||
|
maxValues = new JSONArray();
|
||||||
|
// 获取所有字段名
|
||||||
|
java.util.Set<String> keySet = averageData.keySet();
|
||||||
|
for (String key : keySet) {
|
||||||
|
if (!"property".equals(key)) {
|
||||||
|
maxValues.add(averageData.getDoubleValue(key));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (averageData.getString("property").equals("最小值")) {
|
||||||
|
// 提取最小值数组,排除property字段
|
||||||
|
minValues = new JSONArray();
|
||||||
|
// 获取所有字段名
|
||||||
|
java.util.Set<String> keySet = averageData.keySet();
|
||||||
|
for (String key : keySet) {
|
||||||
|
if (!"property".equals(key)) {
|
||||||
|
minValues.add(averageData.getDoubleValue(key));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将数组转换为JSON字符串
|
||||||
|
String normalizerMaxJson = null;
|
||||||
|
String normalizerMinJson = null;
|
||||||
|
if (maxValues != null) {
|
||||||
|
normalizerMaxJson = maxValues.toJSONString();
|
||||||
|
}
|
||||||
|
if (minValues != null) {
|
||||||
|
normalizerMinJson = minValues.toJSONString();
|
||||||
|
}
|
||||||
|
|
||||||
// 更新训练模型 训练状态
|
// 更新训练模型 训练状态
|
||||||
trainingModelService.lambdaUpdate().eq(TrainingModel::getId, trainingModelId)
|
trainingModelService.lambdaUpdate().eq(TrainingModel::getId, trainingModelId)
|
||||||
.set(TrainingModel::getHandleStatus, "成功")
|
.set(TrainingModel::getHandleStatus, "成功")
|
||||||
.set(TrainingModel::getTrainingDataHandleFileId, fileId)
|
.set(TrainingModel::getTrainingDataHandleFileId, fileId)
|
||||||
|
.set(TrainingModel::getNormalizerMax, normalizerMaxJson)
|
||||||
|
.set(TrainingModel::getNormalizerMin, normalizerMinJson)
|
||||||
.update();
|
.update();
|
||||||
log.info("训练模型ID: {}数据处理完成,处理结果文件ID:{},总耗时: {} ms", trainingModelId, fileId, duration);
|
log.info("训练模型ID: {}数据处理完成,处理结果文件ID:{},总耗时: {} ms", trainingModelId, fileId, duration);
|
||||||
|
|
||||||
@@ -300,12 +392,12 @@ public class ModelServiceImpl implements IModelService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 根据数据文件路径创建param.json配置文件
|
* 根据数据文件路径创建数据处理的param.json配置文件
|
||||||
*
|
*
|
||||||
* @param dataFilePath 数据文件的完整路径
|
* @param dataFilePath 数据文件的完整路径
|
||||||
* @return param.json的完整路径
|
* @return param.json的完整路径
|
||||||
*/
|
*/
|
||||||
private String createParamJsonFile(String dataFilePath) {
|
private String createDataHadleParamJsonFile(String dataFilePath) {
|
||||||
try {
|
try {
|
||||||
// 提取文件名和目录路径
|
// 提取文件名和目录路径
|
||||||
File dataFile = new File(dataFilePath);
|
File dataFile = new File(dataFilePath);
|
||||||
@@ -313,7 +405,7 @@ public class ModelServiceImpl implements IModelService {
|
|||||||
String directoryPath = dataFile.getParent();
|
String directoryPath = dataFile.getParent();
|
||||||
|
|
||||||
// 构造param.json文件路径
|
// 构造param.json文件路径
|
||||||
String paramJsonPath = directoryPath + "/param.json";
|
String paramJsonPath = directoryPath + "/" + PYTHON_SCRIPT_PARAM_FILE_NAME;
|
||||||
|
|
||||||
// 创建JSON对象
|
// 创建JSON对象
|
||||||
JSONObject paramJson = new JSONObject();
|
JSONObject paramJson = new JSONObject();
|
||||||
@@ -339,9 +431,10 @@ public class ModelServiceImpl implements IModelService {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* 下载文件到服务器本地
|
* 下载文件到服务器本地
|
||||||
* @param fileId 文件ID
|
*
|
||||||
|
* @param fileId 文件ID
|
||||||
* @param localDirPath 本地文件夹路径
|
* @param localDirPath 本地文件夹路径
|
||||||
* @return 本地文件的完整路径
|
* @return 本地文件的完整路径 localDirPath / fileName
|
||||||
*/
|
*/
|
||||||
private String downloadToLocalFromMinIO(Integer fileId, String localDirPath) {
|
private String downloadToLocalFromMinIO(Integer fileId, String localDirPath) {
|
||||||
InputStream in = null;
|
InputStream in = null;
|
||||||
@@ -360,13 +453,13 @@ public class ModelServiceImpl implements IModelService {
|
|||||||
}
|
}
|
||||||
String localFilePath = localDirPath + "/" + fileMetadataInfo.getOriginalName();
|
String localFilePath = localDirPath + "/" + fileMetadataInfo.getOriginalName();
|
||||||
// 确保目标目录存在
|
// 确保目标目录存在
|
||||||
File localFile = new File(localDirPath);
|
File localFile = new File(localFilePath);
|
||||||
File parentDir = localFile.getParentFile();
|
File parentDir = localFile.getParentFile();
|
||||||
if (parentDir != null && !parentDir.exists()) {
|
if (parentDir != null && !parentDir.exists()) {
|
||||||
parentDir.mkdirs();
|
parentDir.mkdirs();
|
||||||
}
|
}
|
||||||
|
|
||||||
out = new FileOutputStream(localDirPath);
|
out = new FileOutputStream(localFilePath);
|
||||||
// 将输入流写入本地文件
|
// 将输入流写入本地文件
|
||||||
byte[] buffer = new byte[1024];
|
byte[] buffer = new byte[1024];
|
||||||
int len;
|
int len;
|
||||||
@@ -442,6 +535,37 @@ public class ModelServiceImpl implements IModelService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SdmResponse setTrainingDataInPutOutPutColumn(SetTrainingDataInPutOutPutColumnReq req) {
|
||||||
|
trainingModelService.lambdaUpdate()
|
||||||
|
.eq(TrainingModel::getId, req.getModelId())
|
||||||
|
.set(TrainingModel::getInputSize, req.getInputColumns().size())
|
||||||
|
.set(TrainingModel::getInputLabel, JSONArray.toJSONString(req.getInputColumns()))
|
||||||
|
.set(TrainingModel::getOutputSize, req.getOutputColumns().size())
|
||||||
|
.set(TrainingModel::getOutputLabel, JSONArray.toJSONString(req.getOutputColumns()))
|
||||||
|
.update();
|
||||||
|
return SdmResponse.success();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SdmResponse getTrainingDataInPutOutPutColumn(Integer modelId) {
|
||||||
|
TrainingModel trainingModel = trainingModelService.getById(modelId);
|
||||||
|
// 从数据库获取inputLabel字段值(JSON字符串)
|
||||||
|
String inputLabelJson = trainingModel.getInputLabel();
|
||||||
|
// 解析JSON字符串为列表
|
||||||
|
List<String> inputLabels = JSONArray.parseArray(inputLabelJson, String.class);
|
||||||
|
|
||||||
|
// 从数据库获取outputLable字段值(JSON字符串)
|
||||||
|
String outputLabelJson = trainingModel.getOutputLabel();
|
||||||
|
// 解析JSON字符串为列表
|
||||||
|
List<String> outputLabels = JSONArray.parseArray(outputLabelJson, String.class);
|
||||||
|
|
||||||
|
JSONObject result = new JSONObject();
|
||||||
|
result.put("inputLabels", inputLabels);
|
||||||
|
result.put("outputLabels", outputLabels);
|
||||||
|
return SdmResponse.success(result);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SdmResponse submitTraining(AlgorithmParamReq algorithmParamReq) {
|
public SdmResponse submitTraining(AlgorithmParamReq algorithmParamReq) {
|
||||||
// 保存或更新算法参数
|
// 保存或更新算法参数
|
||||||
@@ -469,7 +593,7 @@ public class ModelServiceImpl implements IModelService {
|
|||||||
String paramJsonPath = createTraningParamJsonFile(algorithmParamReq);
|
String paramJsonPath = createTraningParamJsonFile(algorithmParamReq);
|
||||||
|
|
||||||
// 异步调用python脚本进行训练,并上传结果文件到minio
|
// 异步调用python脚本进行训练,并上传结果文件到minio
|
||||||
trainModelAsync(paramJsonPath, algorithmParamReq.getModelId());
|
trainModelAsync(paramJsonPath, algorithmParamReq.getModelId(), algorithmParamReq.getExportFormat());
|
||||||
|
|
||||||
return SdmResponse.success("提交训练成功");
|
return SdmResponse.success("提交训练成功");
|
||||||
}
|
}
|
||||||
@@ -480,7 +604,7 @@ public class ModelServiceImpl implements IModelService {
|
|||||||
* @param paramJsonPath 参数文件路径
|
* @param paramJsonPath 参数文件路径
|
||||||
* @param modelId 模型ID
|
* @param modelId 模型ID
|
||||||
*/
|
*/
|
||||||
private void trainModelAsync(String paramJsonPath, Integer modelId) {
|
private void trainModelAsync(String paramJsonPath, Integer modelId, String exportFormat) {
|
||||||
new Thread(() -> {
|
new Thread(() -> {
|
||||||
try {
|
try {
|
||||||
long startTime = System.currentTimeMillis();
|
long startTime = System.currentTimeMillis();
|
||||||
@@ -491,10 +615,13 @@ public class ModelServiceImpl implements IModelService {
|
|||||||
long duration = endTime - startTime;
|
long duration = endTime - startTime;
|
||||||
log.info("Python脚本执行完成, 训练模型ID: {}, 耗时: {} ms", modelId, duration);
|
log.info("Python脚本执行完成, 训练模型ID: {}, 耗时: {} ms", modelId, duration);
|
||||||
|
|
||||||
// 将训练生成的生成training.json、training.log文件,上传到minio
|
|
||||||
Integer trainingDataResultFileId = uploadResultFileToMinio(TRAIN_RESULT_FILE_PATH + TRAINING_JSON_FILE_NAME);
|
// 将训练生成的生成 training.json、training.log、模型文件,上传到minio
|
||||||
Integer trainingDataLogFileId = uploadResultFileToMinio(TRAIN_RESULT_FILE_PATH + TRAINING_LOG_FILE_NAME);
|
Integer trainingDataResultFileId = uploadResultFileToMinio(TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + "/" + TRAIN_RESULT_DIR_PATH + TRAINING_JSON_FILE_NAME);
|
||||||
log.info("模型训练完成,结果文件: training.json、training.log文件上传成功,文件ID: {},{}", trainingDataResultFileId, trainingDataLogFileId);
|
Integer trainingDataLogFileId = uploadResultFileToMinio(TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + "/" + TRAIN_RESULT_DIR_PATH + TRAINING_LOG_FILE_NAME);
|
||||||
|
Integer trainingDataExportModelFileId = uploadResultFileToMinio(TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + "/" + TRAIN_RESULT_DIR_PATH + TRAINING_MODEL_PERFIX_NAME + exportFormat);
|
||||||
|
|
||||||
|
log.info("模型训练完成,结果文件: training.json文件ID: {}、training.log文件ID: {}、导出model文件ID: {}文件上传成功", trainingDataResultFileId, trainingDataLogFileId, trainingDataExportModelFileId);
|
||||||
|
|
||||||
// 更新模型状态为训练完成
|
// 更新模型状态为训练完成
|
||||||
trainingModelService.lambdaUpdate()
|
trainingModelService.lambdaUpdate()
|
||||||
@@ -504,8 +631,9 @@ public class ModelServiceImpl implements IModelService {
|
|||||||
.set(TrainingModel::getTrainingDuration, duration)
|
.set(TrainingModel::getTrainingDuration, duration)
|
||||||
.set(TrainingModel::getTrainingDataResultFileId, trainingDataResultFileId)
|
.set(TrainingModel::getTrainingDataResultFileId, trainingDataResultFileId)
|
||||||
.set(TrainingModel::getTrainingDataLogFileId, trainingDataLogFileId)
|
.set(TrainingModel::getTrainingDataLogFileId, trainingDataLogFileId)
|
||||||
|
.set(TrainingModel::getTrainingDataExportModelFileId, trainingDataExportModelFileId)
|
||||||
.update();
|
.update();
|
||||||
|
|
||||||
log.info("模型训练完成,模型ID: {}", modelId);
|
log.info("模型训练完成,模型ID: {}", modelId);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.error("模型训练失败,模型ID: {}", modelId, e);
|
log.error("模型训练失败,模型ID: {}", modelId, e);
|
||||||
@@ -520,15 +648,16 @@ public class ModelServiceImpl implements IModelService {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* 创建训练参数param.json文件
|
* 创建训练参数param.json文件
|
||||||
|
*
|
||||||
* @param algorithmParamReq 算法参数请求
|
* @param algorithmParamReq 算法参数请求
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
private String createTraningParamJsonFile(AlgorithmParamReq algorithmParamReq) {
|
private String createTraningParamJsonFile(AlgorithmParamReq algorithmParamReq) {
|
||||||
// 根据训练模型id 获取训练模型的训练数据文件id, 并下载到本地 trainingDataFilePath = /home/app/model/trainResult/sample.CSV
|
// 根据训练模型id 获取训练模型的训练数据文件id, 并下载到本地 trainingDataFilePath = /home/app/model/trainResult/sample.CSV
|
||||||
Integer modelId = algorithmParamReq.getModelId();
|
Integer modelId = algorithmParamReq.getModelId();
|
||||||
TrainingModel trainingModel= trainingModelService.getById(modelId);
|
TrainingModel trainingModel = trainingModelService.getById(modelId);
|
||||||
Integer trainingDataFileId = trainingModel.getTrainingDataFileId();
|
Integer trainingDataFileId = trainingModel.getTrainingDataFileId();
|
||||||
String trainingDataFilePath = downloadToLocalFromMinIO(trainingDataFileId, TRAIN_RESULT_FILE_PATH);
|
String trainingDataFilePath = downloadToLocalFromMinIO(trainingDataFileId, TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + "/" + TRAIN_RESULT_DIR_PATH);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// 提取文件名和目录路径
|
// 提取文件名和目录路径
|
||||||
@@ -537,7 +666,7 @@ public class ModelServiceImpl implements IModelService {
|
|||||||
String directoryPath = dataFile.getParent();
|
String directoryPath = dataFile.getParent();
|
||||||
|
|
||||||
// 构造param.json文件路径
|
// 构造param.json文件路径
|
||||||
String paramJsonPath = directoryPath + "/param.json";
|
String paramJsonPath = directoryPath + "/" + PYTHON_SCRIPT_PARAM_FILE_NAME;
|
||||||
|
|
||||||
// 创建JSON对象
|
// 创建JSON对象
|
||||||
JSONObject paramJson = new JSONObject();
|
JSONObject paramJson = new JSONObject();
|
||||||
@@ -549,8 +678,8 @@ public class ModelServiceImpl implements IModelService {
|
|||||||
|
|
||||||
// 构造algorithmParam对象
|
// 构造algorithmParam对象
|
||||||
JSONObject algorithmParam = new JSONObject();
|
JSONObject algorithmParam = new JSONObject();
|
||||||
algorithmParam.put("inputSize", algorithmParamReq.getInputSize());
|
algorithmParam.put("inputSize", trainingModel.getInputSize());
|
||||||
algorithmParam.put("outputSize", algorithmParamReq.getOutputSize());
|
algorithmParam.put("outputSize", trainingModel.getOutputSize());
|
||||||
algorithmParam.put("algorithm", algorithmParamReq.getAlgorithm());
|
algorithmParam.put("algorithm", algorithmParamReq.getAlgorithm());
|
||||||
algorithmParam.put("activateFun", algorithmParamReq.getActivateFun());
|
algorithmParam.put("activateFun", algorithmParamReq.getActivateFun());
|
||||||
algorithmParam.put("lossFun", algorithmParamReq.getLossFun());
|
algorithmParam.put("lossFun", algorithmParamReq.getLossFun());
|
||||||
@@ -583,47 +712,58 @@ public class ModelServiceImpl implements IModelService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SdmResponse getTrainingResult(String modelId) {
|
public SdmResponse getTrainingResult(Integer modelId) {
|
||||||
// 根据modelId 获取训练模型,获取训练模型结果文件id,日志文件id,
|
// 根据modelId 获取训练模型,获取训练模型结果文件id,日志文件id,
|
||||||
TrainingModel trainingModel = trainingModelService.getById(modelId);
|
TrainingModel trainingModel = trainingModelService.getById(modelId);
|
||||||
if (trainingModel == null) {
|
if (trainingModel == null) {
|
||||||
return SdmResponse.failed("模型不存在");
|
return SdmResponse.failed("模型不存在");
|
||||||
}
|
}
|
||||||
if(trainingModel.getTrainingStatus().equals("失败")){
|
if (trainingModel.getTrainingStatus().equals("失败")) {
|
||||||
return SdmResponse.failed("模型训练失败,请查看日志详情");
|
return SdmResponse.failed("模型训练失败,请查看日志详情");
|
||||||
}
|
}
|
||||||
|
|
||||||
if(trainingModel.getTrainingStatus().equals("待开始") || trainingModel.getTrainingStatus().equals("训练中")){
|
if (trainingModel.getTrainingStatus().equals("待开始") || trainingModel.getTrainingStatus().equals("训练中")) {
|
||||||
return SdmResponse.failed("模型正在训练中,请稍后再查询");
|
return SdmResponse.failed("模型正在训练中,请稍后再查询");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
Integer trainingDataResultFileId = trainingModel.getTrainingDataResultFileId();
|
Integer trainingDataResultFileId = trainingModel.getTrainingDataResultFileId();
|
||||||
Integer trainingDataLogFileId = trainingModel.getTrainingDataLogFileId();
|
Integer trainingDataLogFileId = trainingModel.getTrainingDataLogFileId();
|
||||||
|
|
||||||
if (trainingDataResultFileId == null || trainingDataLogFileId == null) {
|
if (trainingDataResultFileId == null && trainingDataLogFileId == null) {
|
||||||
return SdmResponse.failed("训练结果文件或日志文件不存在");
|
return SdmResponse.failed("训练结果文件和日志文件都不存在");
|
||||||
}
|
}
|
||||||
|
|
||||||
// 从minio获取结果json文件转换成对象
|
InputStream resultInputStream = null;
|
||||||
InputStream resultInputStream = dataFileService.getMinioInputStream(trainingDataResultFileId);
|
InputStream logInputStream = null;
|
||||||
// 从minio获取日志文件转换成对象
|
|
||||||
InputStream logInputStream = dataFileService.getMinioInputStream(trainingDataLogFileId);
|
|
||||||
|
|
||||||
if (resultInputStream == null || logInputStream == null) {
|
// 从minio获取结果json文件转换成对象
|
||||||
|
if (trainingDataResultFileId != null) {
|
||||||
|
resultInputStream = dataFileService.getMinioInputStream(trainingDataResultFileId);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从minio获取日志文件转换成对象
|
||||||
|
if (trainingDataLogFileId != null) {
|
||||||
|
logInputStream = dataFileService.getMinioInputStream(trainingDataLogFileId);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (resultInputStream == null && logInputStream == null) {
|
||||||
return SdmResponse.failed("无法从MinIO获取文件流");
|
return SdmResponse.failed("无法从MinIO获取文件流");
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// 处理结果JSON文件
|
|
||||||
JSONObject resultJsonObject = parseJsonFromStream(resultInputStream);
|
|
||||||
// 处理日志文件
|
|
||||||
String logContent = parseStringFromStream(logInputStream);
|
|
||||||
|
|
||||||
// 构造返回对象
|
|
||||||
JSONObject response = new JSONObject();
|
JSONObject response = new JSONObject();
|
||||||
response.put("result", resultJsonObject);
|
|
||||||
response.put("log", logContent);
|
// 处理结果JSON文件(如果存在)
|
||||||
|
if (resultInputStream != null) {
|
||||||
|
JSONObject resultJsonObject = parseJsonFromStream(resultInputStream);
|
||||||
|
response.put("result", resultJsonObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理日志文件(如果存在)
|
||||||
|
if (logInputStream != null) {
|
||||||
|
String logContent = parseStringFromStream(logInputStream);
|
||||||
|
response.put("log", logContent);
|
||||||
|
}
|
||||||
|
|
||||||
return SdmResponse.success(response);
|
return SdmResponse.success(response);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
@@ -641,6 +781,7 @@ public class ModelServiceImpl implements IModelService {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* 从输入流中解析JSON对象
|
* 从输入流中解析JSON对象
|
||||||
|
*
|
||||||
* @param inputStream 输入流
|
* @param inputStream 输入流
|
||||||
* @return 解析后的JSON对象
|
* @return 解析后的JSON对象
|
||||||
* @throws IOException IO异常
|
* @throws IOException IO异常
|
||||||
@@ -657,6 +798,7 @@ public class ModelServiceImpl implements IModelService {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* 从输入流中读取字符串内容
|
* 从输入流中读取字符串内容
|
||||||
|
*
|
||||||
* @param inputStream 输入流
|
* @param inputStream 输入流
|
||||||
* @return 字符串内容
|
* @return 字符串内容
|
||||||
* @throws IOException IO异常
|
* @throws IOException IO异常
|
||||||
@@ -672,7 +814,158 @@ public class ModelServiceImpl implements IModelService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public SdmResponse predict(ModelPredictReq modelPredictReq) {
|
public SdmResponse startPredict(ModelPredictReq modelPredictReq) {
|
||||||
return null;
|
JSONObject predResultObject = new JSONObject();
|
||||||
|
|
||||||
|
Integer modelId = modelPredictReq.getModelId();
|
||||||
|
// 将jsonData存储到与modelId关联的模型表的inputLabel字段中
|
||||||
|
trainingModelService.lambdaUpdate()
|
||||||
|
.eq(TrainingModel::getId, modelPredictReq.getModelId())
|
||||||
|
.set(TrainingModel::getInputPredLabelValue, modelPredictReq.getJsonData())
|
||||||
|
.update();
|
||||||
|
TrainingModel trainingModel = trainingModelService.lambdaQuery().eq(TrainingModel::getId, modelId).one();
|
||||||
|
TrainingModelAlgorithmParam trainingModelAlgorithmParam = trainingModelAlgorithmParamService.lambdaQuery().eq(TrainingModelAlgorithmParam::getModelId, modelId).one();
|
||||||
|
|
||||||
|
String predDirPath = TANING_MODEL_BASE_DIR_PATH + MODEL_ID + trainingModel.getId() + "/" + PRED_RESULT_DIR_PATH;
|
||||||
|
// 创建预测用的param.json文件
|
||||||
|
String predParamJsonFile = createPredParamJsonFile(modelPredictReq, trainingModel, trainingModelAlgorithmParam,predDirPath);
|
||||||
|
|
||||||
|
// 执行预测
|
||||||
|
try {
|
||||||
|
long startTime = System.currentTimeMillis();
|
||||||
|
log.info("开始执行Python预测脚本, 模型ID: {}", modelId);
|
||||||
|
|
||||||
|
// 调用python脚本进行预测
|
||||||
|
callPythonScript(PRED_PYTHON_SCRIPTPATH, predParamJsonFile);
|
||||||
|
|
||||||
|
long endTime = System.currentTimeMillis();
|
||||||
|
long duration = endTime - startTime;
|
||||||
|
log.info("Python预测脚本执行完成, 模型ID: {}, 耗时: {} ms", modelId, duration);
|
||||||
|
|
||||||
|
// 读取预测生成的结果文件并保存到数据库
|
||||||
|
String predResultPath = predDirPath + "/" + PRED_RESULT_FILE_NAME;
|
||||||
|
File predResultFile = new File(predResultPath);
|
||||||
|
|
||||||
|
if (predResultFile.exists()) {
|
||||||
|
try {
|
||||||
|
// 读取预测结果文件内容
|
||||||
|
String predResultJson = new String(Files.readAllBytes(predResultFile.toPath()));
|
||||||
|
predResultObject = JSONObject.parseObject(predResultJson);
|
||||||
|
|
||||||
|
// 将预测结果保存到数据库
|
||||||
|
trainingModelService.lambdaUpdate()
|
||||||
|
.eq(TrainingModel::getId, modelId)
|
||||||
|
.set(TrainingModel::getOutputPredLabelValue, predResultObject.toJSONString())
|
||||||
|
.update();
|
||||||
|
log.info("模型预测结果已保存到数据库,模型ID: {}", modelId);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("解析或保存预测结果失败,模型ID: {}", modelId, e);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.warn("预测结果文件不存在: {}", predResultPath);
|
||||||
|
}
|
||||||
|
|
||||||
|
log.info("模型预测完成,模型ID: {}", modelId);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("模型预测失败,模型ID: {}", modelId, e);
|
||||||
|
}
|
||||||
|
|
||||||
|
return SdmResponse.success(predResultObject);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建预测参数param.json文件
|
||||||
|
*
|
||||||
|
* @param modelPredictReq 预测请求参数
|
||||||
|
* @param trainingModel 训练模型
|
||||||
|
* @param trainingModelAlgorithmParam 训练模型算法参数
|
||||||
|
* @param predDirPath 预测目录路径
|
||||||
|
* @return param.json文件路径
|
||||||
|
*/
|
||||||
|
private String createPredParamJsonFile(ModelPredictReq modelPredictReq, TrainingModel trainingModel, TrainingModelAlgorithmParam trainingModelAlgorithmParam,String predDirPath) {
|
||||||
|
try {
|
||||||
|
// 确保目录存在
|
||||||
|
File predDir = new File(predDirPath);
|
||||||
|
if (!predDir.exists()) {
|
||||||
|
predDir.mkdirs();
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构造param.json文件路径
|
||||||
|
String paramJsonPath = predDirPath + "/" + PYTHON_SCRIPT_PARAM_FILE_NAME;
|
||||||
|
|
||||||
|
// 创建JSON对象
|
||||||
|
JSONObject paramJson = new JSONObject();
|
||||||
|
|
||||||
|
// 下载model模型文件到本地
|
||||||
|
downloadToLocalFromMinIO(trainingModel.getTrainingDataExportModelFileId(), predDirPath);
|
||||||
|
paramJson.put("modelFile", TRAINING_MODEL_PERFIX_NAME + trainingModelAlgorithmParam.getExportFormat());
|
||||||
|
|
||||||
|
// 设置路径
|
||||||
|
paramJson.put("path", predDirPath);
|
||||||
|
|
||||||
|
// 添加模型参数
|
||||||
|
JSONObject modelParams = new JSONObject();
|
||||||
|
modelParams.put("inputSize", trainingModel.getInputSize());
|
||||||
|
modelParams.put("outputSize", trainingModel.getOutputSize());
|
||||||
|
modelParams.put("normalizerType", trainingModelAlgorithmParam.getDisposeMethod());
|
||||||
|
|
||||||
|
// 添加归一化参数
|
||||||
|
if (trainingModel.getNormalizerMax() != null) {
|
||||||
|
JSONArray normalizerMaxArray = JSONArray.parseArray(trainingModel.getNormalizerMax());
|
||||||
|
modelParams.put("normalizerMax", normalizerMaxArray);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (trainingModel.getNormalizerMin() != null) {
|
||||||
|
JSONArray normalizerMinArray = JSONArray.parseArray(trainingModel.getNormalizerMin());
|
||||||
|
modelParams.put("normalizerMin", normalizerMinArray);
|
||||||
|
}
|
||||||
|
|
||||||
|
paramJson.put("modelParams", modelParams);
|
||||||
|
|
||||||
|
// 添加输入数据
|
||||||
|
// 直接使用传入的JSON数组作为input值
|
||||||
|
JSONArray inputArray = JSONArray.parseArray(modelPredictReq.getJsonData());
|
||||||
|
paramJson.put("input", inputArray);
|
||||||
|
|
||||||
|
// 添加输出标签
|
||||||
|
JSONObject output = new JSONObject();
|
||||||
|
if (trainingModel.getOutputLabel() != null) {
|
||||||
|
JSONArray outputLabels = JSONArray.parseArray(trainingModel.getOutputLabel());
|
||||||
|
output.put("names", outputLabels);
|
||||||
|
}
|
||||||
|
paramJson.put("output", output);
|
||||||
|
|
||||||
|
// 写入param.json文件
|
||||||
|
try (FileWriter fileWriter = new FileWriter(paramJsonPath)) {
|
||||||
|
fileWriter.write(paramJson.toJSONString());
|
||||||
|
}
|
||||||
|
|
||||||
|
log.info("成功创建预测param.json文件: {}", paramJsonPath);
|
||||||
|
return paramJsonPath;
|
||||||
|
} catch (IOException e) {
|
||||||
|
log.error("创建预测param.json文件失败", e);
|
||||||
|
throw new RuntimeException("创建预测param.json文件失败: " + e.getMessage(), e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public SdmResponse getModelPredictResult(Integer modelId) {
|
||||||
|
JSONObject result = new JSONObject();
|
||||||
|
TrainingModel model = trainingModelService.lambdaQuery().eq(TrainingModel::getId, modelId).one();
|
||||||
|
if (model == null) {
|
||||||
|
return SdmResponse.failed("模型不存在");
|
||||||
|
}
|
||||||
|
String inputPredLabelValue = model.getInputPredLabelValue();
|
||||||
|
if (inputPredLabelValue != null) {
|
||||||
|
JSONObject inputPredLabelValueJson = JSONObject.parseObject(inputPredLabelValue);
|
||||||
|
result.put("inputPredLabelValue", inputPredLabelValueJson);
|
||||||
|
}
|
||||||
|
String outputPredLabelValue = model.getOutputPredLabelValue();
|
||||||
|
if (outputPredLabelValue != null) {
|
||||||
|
JSONObject outputPredLabelValueJson = JSONObject.parseObject(outputPredLabelValue);
|
||||||
|
result.put("inputPredLabelValue", outputPredLabelValueJson);
|
||||||
|
}
|
||||||
|
|
||||||
|
return SdmResponse.success(result);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user