训练模型

This commit is contained in:
2025-10-24 16:50:40 +08:00
parent cfa09a1178
commit 431870d579
24 changed files with 475 additions and 104 deletions

View File

@@ -5,6 +5,7 @@ import com.sdm.data.model.req.*;
import com.sdm.data.service.IModelService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
@@ -14,7 +15,8 @@ import java.io.IOException;
* 模型训练控制器
*/
@RestController
@RequestMapping("/modelTraning")
@RequestMapping("/modelTraining")
@Tag(name = "模型训练")
public class ModelTraningController {
@Autowired
private IModelService modelService;
@@ -28,6 +30,41 @@ public class ModelTraningController {
return modelService.addModel(addModelReq);
}
/**
 * Deletes a training model.
 *
 * @param modelId ID of the model to delete, read from the {@code modelId} request parameter
 * @return the response produced by the model service
 */
@GetMapping("/deleteModel")
@Operation(summary = "删除模型", description = "删除模型")
public SdmResponse deleteModel(@RequestParam Integer modelId) {
    // Pure delegation: all deletion logic lives in the service layer.
    final SdmResponse response = modelService.deleteModel(modelId);
    return response;
}
/**
 * Returns a page of training models.
 *
 * @param baseReq request body forwarded unchanged to the model service
 * @return the response produced by the model service
 */
@PostMapping("/getModelList")
@Operation(summary = "获取模型列表", description = "获取模型列表")
public SdmResponse getModelList(@RequestBody BaseReq baseReq) {
    // Pure delegation: paging is handled by the service layer.
    final SdmResponse response = modelService.getModelList(baseReq);
    return response;
}
/**
 * Returns the details of a single training model.
 *
 * @param modelId ID of the model to look up, read from the {@code modelId} request parameter
 * @return the response produced by the model service
 */
@GetMapping("/getModelDetail")
@Operation(summary = "获取模型详情", description = "获取模型详情")
public SdmResponse getModelDetail(@RequestParam Integer modelId) {
    // Pure delegation: the lookup is performed by the service layer.
    final SdmResponse response = modelService.getModelDetail(modelId);
    return response;
}
/**
 * Updates an existing training model.
 *
 * @param addModelReq request body forwarded unchanged to the model service
 * @return the response produced by the model service
 */
@PostMapping("/updateModel")
@Operation(summary = "修改模型", description = "修改模型")
public SdmResponse updateModel(@RequestBody AddModelReq addModelReq) {
    // Pure delegation: the update is performed by the service layer.
    final SdmResponse response = modelService.updateModel(addModelReq);
    return response;
}
/**
* 调用python脚本处理导入数据
*
@@ -74,6 +111,15 @@ public class ModelTraningController {
return modelService.submitTraining(algorithmParamReq);
}
/**
 * Stops the training of a model.
 *
 * @param modelId ID of the model whose training should be stopped, read from the
 *                {@code modelId} request parameter
 * @return the response produced by the model service
 */
@GetMapping("/stopTraining")
@Operation(summary = "停止模型训练", description = "停止模型训练")
public SdmResponse stopTraining(@RequestParam Integer modelId) {
    // Pure delegation: process handling is done by the service layer.
    final SdmResponse response = modelService.stopTraining(modelId);
    return response;
}
/**
* 获取训练曲线和训练日志
*

View File

@@ -8,6 +8,9 @@ import java.time.LocalDateTime;
@Data
public class AddModelReq {
@Schema(description = "模型ID")
private Integer modelId;
@Schema(description = "模型名称")
private String modelName;

View File

@@ -12,6 +12,26 @@ public interface IModelService {
*/
SdmResponse addModel(AddModelReq addModelReq);
/**
 * Deletes the training model with the given ID.
 *
 * @param modelId ID of the model to delete
 * @return response wrapper describing the outcome
 */
SdmResponse deleteModel(Integer modelId);
/**
 * Returns a page of training models.
 *
 * @param baseReq request carrying the paging parameters
 * @return response wrapper containing the requested page
 */
SdmResponse getModelList(BaseReq baseReq);
/**
 * Returns the details of a single training model.
 *
 * @param modelId ID of the model to look up
 * @return response wrapper carrying the model
 */
SdmResponse getModelDetail(Integer modelId);
/**
 * Updates an existing training model.
 *
 * @param addModelReq request carrying the model ID and the new values
 * @return response wrapper describing the outcome
 */
SdmResponse updateModel(AddModelReq addModelReq);
/**
* 调用python脚本处理导入数据
*/
@@ -38,6 +58,13 @@ public interface IModelService {
*/
SdmResponse submitTraining(AlgorithmParamReq algorithmParamReq);
/**
 * Stops the training of a model.
 *
 * @param modelId ID of the model whose training should be stopped
 * @return response wrapper for the stop request
 */
SdmResponse stopTraining(Integer modelId);
/**
* 模型训练结果查询
* @param modelId 模型ID

View File

@@ -2,7 +2,10 @@ package com.sdm.data.service.impl;
import com.alibaba.fastjson2.JSONArray;
import com.alibaba.fastjson2.JSONObject;
import com.github.pagehelper.PageHelper;
import com.github.pagehelper.PageInfo;
import com.sdm.common.common.SdmResponse;
import com.sdm.common.utils.PageUtils;
import com.sdm.data.model.entity.FileMetadataInfo;
import com.sdm.data.model.entity.TrainingModel;
import com.sdm.data.model.entity.TrainingModelAlgorithmParam;
@@ -22,7 +25,9 @@ import java.io.OutputStream;
import java.nio.file.Files;
import java.util.Date;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
@Service
@Slf4j
@@ -100,6 +105,8 @@ public class ModelServiceImpl implements IModelService {
*/
private static final String PYTHON_SCRIPT_PARAM_FILE_NAME = "param.json";
// Training processes that are currently running, keyed by model ID.
private final Map<Integer, Process> runningProcesses = new ConcurrentHashMap<>();
@Autowired
IDataFileService dataFileService;
@@ -133,6 +140,41 @@ public class ModelServiceImpl implements IModelService {
}
}
/**
 * Deletes a training model together with its algorithm parameters.
 *
 * <p>Both deletions run in one transaction that rolls back on any {@link Exception}.
 *
 * @param modelId ID of the model to delete
 * @return a success response carrying a confirmation message
 */
@Override
@Transactional(rollbackFor = Exception.class)
public SdmResponse deleteModel(Integer modelId) {
    // Step 1: remove the training_model row itself.
    trainingModelService.removeById(modelId);

    // Step 2: remove the training_model_algorithm_param rows whose model ID matches.
    trainingModelAlgorithmParamService
            .lambdaUpdate()
            .eq(TrainingModelAlgorithmParam::getModelId, modelId)
            .remove();

    final String confirmation = "删除训练模型成功";
    return SdmResponse.success(confirmation);
}
/**
 * Returns one page of training models.
 *
 * @param baseReq request supplying the page number ({@code getCurrent()}) and the
 *                page size ({@code getSize()})
 * @return the paged response built by {@code PageUtils}
 */
@Override
public SdmResponse getModelList(BaseReq baseReq) {
    // startPage must stay directly in front of the query it is meant to page.
    PageHelper.startPage(baseReq.getCurrent(), baseReq.getSize());
    final List<TrainingModel> trainingModels = trainingModelService.list();

    final PageInfo<TrainingModel> pageInfo = new PageInfo<>(trainingModels);
    return PageUtils.getJsonObjectSdmResponse(trainingModels, pageInfo);
}
/**
 * Returns the details of a single training model.
 *
 * @param modelId ID of the model to look up
 * @return a success response carrying the model, or a failed response when no model
 *         with that ID exists
 */
@Override
public SdmResponse getModelDetail(Integer modelId) {
    TrainingModel model = trainingModelService.getById(modelId);
    if (model == null) {
        // Report a missing model explicitly instead of an empty success, matching the
        // "模型不存在" response other methods of this service give for the same case.
        return SdmResponse.failed("模型不存在");
    }
    return SdmResponse.success(model);
}
/**
 * Updates the name and description of an existing training model.
 *
 * @param addModelReq request carrying the model ID plus the new name and description
 * @return a success response when a row was updated; a failed response when the
 *         model ID is missing or the update affected no row
 */
@Override
public SdmResponse updateModel(AddModelReq addModelReq) {
    // Without an ID the WHERE clause would compare against NULL, match nothing,
    // and the caller would still be told the update succeeded.
    if (addModelReq == null || addModelReq.getModelId() == null) {
        return SdmResponse.failed("模型ID不能为空");
    }

    boolean updated = trainingModelService.lambdaUpdate()
            .set(TrainingModel::getModelName, addModelReq.getModelName())
            .set(TrainingModel::getDescription, addModelReq.getDescription())
            .eq(TrainingModel::getId, addModelReq.getModelId())
            .update();
    if (!updated) {
        // update() reports false when no row was affected.
        // NOTE(review): assumes the JDBC driver counts matched rows (the MySQL default) — confirm.
        return SdmResponse.failed("模型不存在");
    }
    return SdmResponse.success("更新训练模型成功");
}
@Override
public SdmResponse<String> handleLoadData(HandleLoadDataReq handleLoadDataReq) {
try {
@@ -489,50 +531,49 @@ public class ModelServiceImpl implements IModelService {
}
}
private void callPythonScript(String pythonScriptPath, String paramJsonPath) {
try {
// 直接拼接完整命令(注意空格分隔)
String command = String.format(
"python %s --param %s",
pythonScriptPath,
paramJsonPath
);
private Process callPythonScript(String pythonScriptPath, String paramJsonPath) throws Exception {
// 直接拼接完整命令(注意空格分隔)
String command = String.format(
"python %s --param %s",
pythonScriptPath,
paramJsonPath
);
// 打印执行的命令(便于调试)
log.info("执行的Python命令: {}", command);
// 打印执行的命令(便于调试)
log.info("执行的Python命令: {}", command);
// 使用Runtime执行命令
Process process = Runtime.getRuntime().exec(command);
// 读取输出流
// 使用Runtime执行命令
Process process = Runtime.getRuntime().exec(command);
// 异步读取输出流,避免阻塞
Thread outputThread = new Thread(() -> {
try (BufferedReader reader = new BufferedReader(
new InputStreamReader(process.getInputStream()))) {
String line;
while ((line = reader.readLine()) != null) {
log.info("Python脚本输出: {}", line);
}
} catch (IOException e) {
log.error("读取Python脚本输出异常", e);
}
// 读取错误流(单独处理错误输出,避免阻塞)
});
outputThread.start();
// 异步读取错误流,避免阻塞
Thread errorThread = new Thread(() -> {
try (BufferedReader errorReader = new BufferedReader(
new InputStreamReader(process.getErrorStream()))) {
String errorLine;
while ((errorLine = errorReader.readLine()) != null) {
log.error("Python脚本错误输出: {}", errorLine);
}
} catch (IOException e) {
log.error("读取Python脚本错误输出异常", e);
}
});
errorThread.start();
// 等待执行完成
int exitCode = process.waitFor();
if (exitCode == 0) {
log.info("Python 脚本执行成功");
} else {
log.error("Python 脚本执行失败,退出码: {}", exitCode);
}
} catch (Exception e) {
log.error("调用Python脚本失败", e);
throw new RuntimeException(e);
}
return process;
}
@Override
@@ -606,27 +647,63 @@ public class ModelServiceImpl implements IModelService {
*/
private void trainModelAsync(String paramJsonPath, Integer modelId, String exportFormat) {
new Thread(() -> {
Process process = null;
try {
long startTime = System.currentTimeMillis();
log.info("开始执行Python脚本, 训练模型ID: {}", modelId);
// 调用python脚本进行训练
callPythonScript(TRAINING_PYTHON_SCRIPTPATH, paramJsonPath);
process = callPythonScript(TRAINING_PYTHON_SCRIPTPATH, paramJsonPath);
// 将进程添加到运行进程映射中
runningProcesses.put(modelId, process);
// 等待进程执行完成
int exitCode = process.waitFor();
long endTime = System.currentTimeMillis();
long duration = endTime - startTime;
log.info("Python脚本执行完成, 训练模型ID: {}, 耗时: {} ms", modelId, duration);
// 根据退出码判断具体执行结果
if (exitCode == 0) {
log.info("Python脚本执行完成, 训练模型ID: {}, 耗时: {} ms", modelId, duration);
} else if (exitCode == 137) {
// 退出码137表示进程被SIGKILL信号强制终止如调用destroyForcibly()
log.warn("Python脚本执行被强制终止退出码: {}, 训练模型ID: {}, 耗时: {} ms", exitCode, modelId, duration);
} else {
log.error("Python脚本执行失败退出码: {}, 训练模型ID: {}, 耗时: {} ms", exitCode, modelId, duration);
}
// 从运行进程映射中移除已完成的进程
runningProcesses.remove(modelId);
// 将训练生成的生成 training.json、training.log、模型文件上传到minio
Integer trainingDataResultFileId = uploadResultFileToMinio(TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + "/" + TRAIN_RESULT_DIR_PATH + TRAINING_JSON_FILE_NAME);
Integer trainingDataLogFileId = uploadResultFileToMinio(TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + "/" + TRAIN_RESULT_DIR_PATH + TRAINING_LOG_FILE_NAME);
Integer trainingDataExportModelFileId = uploadResultFileToMinio(TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + "/" + TRAIN_RESULT_DIR_PATH + TRAINING_MODEL_PERFIX_NAME + exportFormat);
// 只有在训练成功时才上传结果文件到minio
Integer trainingDataResultFileId = null;
Integer trainingDataLogFileId = null;
Integer trainingDataExportModelFileId = null;
if (exitCode == 0) {
// 将训练生成的生成 training.json、training.log、模型文件上传到minio
trainingDataResultFileId = uploadResultFileToMinio(TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + "/" + TRAIN_RESULT_DIR_PATH + TRAINING_JSON_FILE_NAME);
trainingDataLogFileId = uploadResultFileToMinio(TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + "/" + TRAIN_RESULT_DIR_PATH + TRAINING_LOG_FILE_NAME);
trainingDataExportModelFileId = uploadResultFileToMinio(TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + "/" + TRAIN_RESULT_DIR_PATH + TRAINING_MODEL_PERFIX_NAME + exportFormat);
log.info("模型训练完成,结果文件: training.json文件ID: {}、training.log文件ID: {}、导出model文件ID: {}文件上传成功", trainingDataResultFileId, trainingDataLogFileId, trainingDataExportModelFileId);
log.info("模型训练完成,结果文件: training.json文件ID: {}、training.log文件ID: {}、导出model文件ID: {}文件上传成功", trainingDataResultFileId, trainingDataLogFileId, trainingDataExportModelFileId);
}
// 根据退出码设置训练状态
String trainingStatus;
if (exitCode == 0) {
trainingStatus = "成功";
} else if (exitCode == 137) {
trainingStatus = "已终止";
} else {
trainingStatus = "失败";
}
// 更新模型状态为训练完成
trainingModelService.lambdaUpdate()
.eq(TrainingModel::getId, modelId)
.set(TrainingModel::getTrainingStatus, "成功")
.set(TrainingModel::getTrainingStatus, trainingStatus)
.set(TrainingModel::getTrainingTime, new Date())
.set(TrainingModel::getTrainingDuration, duration)
.set(TrainingModel::getTrainingDataResultFileId, trainingDataResultFileId)
@@ -636,6 +713,9 @@ public class ModelServiceImpl implements IModelService {
log.info("模型训练完成模型ID: {}", modelId);
} catch (Exception e) {
// 从运行进程映射中移除失败的进程
runningProcesses.remove(modelId);
log.error("模型训练失败模型ID: {}", modelId, e);
// 更新模型状态为训练失败
trainingModelService.lambdaUpdate()
@@ -718,8 +798,8 @@ public class ModelServiceImpl implements IModelService {
if (trainingModel == null) {
return SdmResponse.failed("模型不存在");
}
if (trainingModel.getTrainingStatus().equals("失败")) {
return SdmResponse.failed("模型训练失败,请查看日志详情");
if (trainingModel.getTrainingStatus().equals("失败") || trainingModel.getTrainingStatus().equals("已终止") ) {
return SdmResponse.failed("模型训练失败或已终止,请查看日志详情");
}
if (trainingModel.getTrainingStatus().equals("待开始") || trainingModel.getTrainingStatus().equals("训练中")) {
@@ -828,8 +908,8 @@ public class ModelServiceImpl implements IModelService {
String predDirPath = TANING_MODEL_BASE_DIR_PATH + MODEL_ID + trainingModel.getId() + "/" + PRED_RESULT_DIR_PATH;
// 创建预测用的param.json文件
String predParamJsonFile = createPredParamJsonFile(modelPredictReq, trainingModel, trainingModelAlgorithmParam,predDirPath);
String predParamJsonFile = createPredParamJsonFile(modelPredictReq, trainingModel, trainingModelAlgorithmParam, predDirPath);
// 执行预测
try {
long startTime = System.currentTimeMillis();
@@ -845,13 +925,13 @@ public class ModelServiceImpl implements IModelService {
// 读取预测生成的结果文件并保存到数据库
String predResultPath = predDirPath + "/" + PRED_RESULT_FILE_NAME;
File predResultFile = new File(predResultPath);
if (predResultFile.exists()) {
try {
// 读取预测结果文件内容
String predResultJson = new String(Files.readAllBytes(predResultFile.toPath()));
predResultObject = JSONObject.parseObject(predResultJson);
// 将预测结果保存到数据库
trainingModelService.lambdaUpdate()
.eq(TrainingModel::getId, modelId)
@@ -869,20 +949,20 @@ public class ModelServiceImpl implements IModelService {
} catch (Exception e) {
log.error("模型预测失败模型ID: {}", modelId, e);
}
return SdmResponse.success(predResultObject);
}
/**
* 创建预测参数param.json文件
*
* @param modelPredictReq 预测请求参数
* @param trainingModel 训练模型
* @param modelPredictReq 预测请求参数
* @param trainingModel 训练模型
* @param trainingModelAlgorithmParam 训练模型算法参数
* @param predDirPath 预测目录路径
* @param predDirPath 预测目录路径
* @return param.json文件路径
*/
private String createPredParamJsonFile(ModelPredictReq modelPredictReq, TrainingModel trainingModel, TrainingModelAlgorithmParam trainingModelAlgorithmParam,String predDirPath) {
private String createPredParamJsonFile(ModelPredictReq modelPredictReq, TrainingModel trainingModel, TrainingModelAlgorithmParam trainingModelAlgorithmParam, String predDirPath) {
try {
// 确保目录存在
File predDir = new File(predDirPath);
@@ -908,18 +988,18 @@ public class ModelServiceImpl implements IModelService {
modelParams.put("inputSize", trainingModel.getInputSize());
modelParams.put("outputSize", trainingModel.getOutputSize());
modelParams.put("normalizerType", trainingModelAlgorithmParam.getDisposeMethod());
// 添加归一化参数
if (trainingModel.getNormalizerMax() != null) {
JSONArray normalizerMaxArray = JSONArray.parseArray(trainingModel.getNormalizerMax());
modelParams.put("normalizerMax", normalizerMaxArray);
}
if (trainingModel.getNormalizerMin() != null) {
JSONArray normalizerMinArray = JSONArray.parseArray(trainingModel.getNormalizerMin());
modelParams.put("normalizerMin", normalizerMinArray);
}
paramJson.put("modelParams", modelParams);
// 添加输入数据
@@ -948,6 +1028,34 @@ public class ModelServiceImpl implements IModelService {
}
}
/**
 * Stops the running training process of the given model.
 *
 * <p>The process is taken out of the running-process map, killed forcibly, and the
 * model's training status is set to "已终止".
 *
 * @param modelId ID of the model whose training should be stopped
 * @return a success response when a running process was found and terminated; a failed
 *         response when the ID is missing or no training process is running for it
 */
@Override
public SdmResponse stopTraining(Integer modelId) {
    // ConcurrentHashMap rejects null keys with a NullPointerException, so guard first.
    if (modelId == null) {
        return SdmResponse.failed("模型ID不能为空");
    }

    // remove() looks up and unregisters in one atomic step, so a process registered
    // for the same model between a separate get() and remove() cannot be dropped.
    Process process = runningProcesses.remove(modelId);
    if (process == null) {
        log.warn("未找到正在运行的模型训练进程模型ID: {}", modelId);
        return SdmResponse.failed("未找到正在运行的模型训练进程");
    }

    // Kill the process forcibly.
    process.destroyForcibly();

    // Persist the terminated status.
    trainingModelService.lambdaUpdate()
            .eq(TrainingModel::getId, modelId)
            .set(TrainingModel::getTrainingStatus, "已终止")
            .update();

    log.info("模型训练已终止模型ID: {}", modelId);
    return SdmResponse.success();
}
@Override
public SdmResponse getModelPredictResult(Integer modelId) {
JSONObject result = new JSONObject();