训练模型

This commit is contained in:
2025-10-28 17:08:51 +08:00
parent 2f93f88df1
commit a7e0a1907a
3 changed files with 56 additions and 21 deletions

View File

@@ -111,6 +111,15 @@ public class ModelTraningController {
return modelService.submitTraining(algorithmParamReq);
}
/**
* 获取算法参数设置详情
*/
@GetMapping("/getAlgorithmParam")
@Operation(summary = "获取算法参数设置详情", description = "获取算法参数设置详情")
public SdmResponse getAlgorithmParam(@RequestParam Integer modelId) {
return modelService.getAlgorithmParam(modelId);
}
/**
* 停止模型训练
*/

View File

@@ -58,6 +58,11 @@ public interface IModelService {
*/
SdmResponse submitTraining(AlgorithmParamReq algorithmParamReq);
/**
* 获取算法参数设置详情
*/
SdmResponse getAlgorithmParam(Integer modelId);
/**
* 停止模型训练
* @param modelId 模型ID

View File

@@ -157,7 +157,7 @@ public class ModelServiceImpl implements IModelService {
PageHelper.startPage(baseReq.getCurrent(), baseReq.getSize());
List<TrainingModel> models = trainingModelService.list();
PageInfo<TrainingModel> page = new PageInfo<>(models);
return PageUtils.getJsonObjectSdmResponse(models,page);
return PageUtils.getJsonObjectSdmResponse(models, page);
}
@Override
@@ -207,7 +207,7 @@ public class ModelServiceImpl implements IModelService {
trainingModelService.lambdaUpdate().eq(TrainingModel::getId, trainingModelId).set(TrainingModel::getHandleStatus, "处理中").update();
// 异步调用Python脚本处理数据
processDataAsync(DATA_HANDLER_PYTHON_SCRIPT_PATH, paramJsonPath, trainingModelId, handleLoadDataReq.getUserId(),handleLoadDataReq.getTrainingModelId());
processDataAsync(DATA_HANDLER_PYTHON_SCRIPT_PATH, paramJsonPath, trainingModelId, handleLoadDataReq.getUserId(), handleLoadDataReq.getTrainingModelId());
return SdmResponse.success("数据处理中");
} catch (Exception e) {
log.error("处理上传数据失败", e);
@@ -268,8 +268,8 @@ public class ModelServiceImpl implements IModelService {
*
* @param pythonScriptPath Python脚本路径
* @param paramJsonPath 参数文件路径
* @param userId 用户ID
* @param modelId 模型ID
* @param userId 用户ID
* @param modelId 模型ID
*/
private void processDataAsync(String pythonScriptPath, String paramJsonPath, Integer trainingModelId, Integer userId, Integer modelId) {
new Thread(() -> {
@@ -350,12 +350,12 @@ public class ModelServiceImpl implements IModelService {
log.info("训练模型ID: {}数据处理完成,处理结果文件ID:{},总耗时: {} ms", trainingModelId, fileId, duration);
// 通过WebSocket通知前端处理完成
webSocketService.sendDataProcessingNotification(userId, modelId,true, "数据处理完成");
webSocketService.sendDataProcessingNotification(userId, modelId, true, "数据处理完成");
} catch (Exception e) {
trainingModelService.lambdaUpdate().eq(TrainingModel::getId, trainingModelId).set(TrainingModel::getHandleStatus, "失败").update();
log.error("异步处理数据失败", e);
// 通过WebSocket通知前端处理失败
webSocketService.sendDataProcessingNotification(userId, modelId,false, "数据处理失败: " + e.getMessage());
webSocketService.sendDataProcessingNotification(userId, modelId, false, "数据处理失败: " + e.getMessage());
}
}).start();
}
@@ -543,7 +543,7 @@ public class ModelServiceImpl implements IModelService {
// 使用Runtime执行命令
Process process = Runtime.getRuntime().exec(command);
// 异步读取输出流,避免阻塞
Thread outputThread = new Thread(() -> {
try (BufferedReader reader = new BufferedReader(
@@ -557,7 +557,7 @@ public class ModelServiceImpl implements IModelService {
}
});
outputThread.start();
// 异步读取错误流,避免阻塞
Thread errorThread = new Thread(() -> {
try (BufferedReader errorReader = new BufferedReader(
@@ -652,16 +652,16 @@ public class ModelServiceImpl implements IModelService {
log.info("开始执行Python脚本, 训练模型ID: {}", modelId);
// 调用python脚本进行训练
process = callPythonScript(TRAINING_PYTHON_SCRIPTPATH, paramJsonPath);
// 将进程添加到运行进程映射中
runningProcesses.put(modelId, process);
// 等待进程执行完成
int exitCode = process.waitFor();
long endTime = System.currentTimeMillis();
long duration = endTime - startTime;
// 根据退出码判断具体执行结果
if (exitCode == 0) {
log.info("Python脚本执行完成, 训练模型ID: {}, 耗时: {} ms", modelId, duration);
@@ -679,7 +679,7 @@ public class ModelServiceImpl implements IModelService {
Integer trainingDataResultFileId = null;
Integer trainingDataLogFileId = null;
Integer trainingDataExportModelFileId = null;
if (exitCode == 0) {
// 将训练生成的生成 training.json、training.log、模型文件上传到minio
trainingDataResultFileId = uploadResultFileToMinio(TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + "/" + TRAIN_RESULT_DIR_PATH + TRAINING_JSON_FILE_NAME);
@@ -698,7 +698,7 @@ public class ModelServiceImpl implements IModelService {
} else {
trainingStatus = "失败";
}
// 更新模型状态为训练完成
trainingModelService.lambdaUpdate()
.eq(TrainingModel::getId, modelId)
@@ -714,7 +714,7 @@ public class ModelServiceImpl implements IModelService {
} catch (Exception e) {
// 从运行进程映射中移除失败的进程
runningProcesses.remove(modelId);
log.error("模型训练失败模型ID: {}", modelId, e);
// 更新模型状态为训练失败
trainingModelService.lambdaUpdate()
@@ -797,8 +797,12 @@ public class ModelServiceImpl implements IModelService {
if (trainingModel == null) {
return SdmResponse.failed("模型不存在");
}
if (trainingModel.getTrainingStatus().equals("失败") || trainingModel.getTrainingStatus().equals("已终止") ) {
return SdmResponse.failed("模型训练失败或已终止,请查看日志详情");
if (trainingModel.getTrainingStatus().equals("已终止")) {
return SdmResponse.failed("模型训练已终止,请重新提交训练");
}
if (trainingModel.getTrainingStatus().equals("失败")) {
return SdmResponse.failed("模型训练失败,请查看日志详情");
}
if (trainingModel.getTrainingStatus().equals("待开始") || trainingModel.getTrainingStatus().equals("训练中")) {
@@ -1027,9 +1031,26 @@ public class ModelServiceImpl implements IModelService {
}
}
/**
* 获取算法参数设置详情
*
* @param modelId 模型ID
* @return
*/
@Override
public SdmResponse getAlgorithmParam(Integer modelId) {
TrainingModelAlgorithmParam modelAlgorithmParam = trainingModelAlgorithmParamService.lambdaQuery()
.eq(TrainingModelAlgorithmParam::getModelId, modelId)
.one();
if (modelAlgorithmParam == null) {
return SdmResponse.failed("未找到模型算法参数设置详情");
}
return SdmResponse.success(modelAlgorithmParam);
}
/**
* 终止指定模型的训练进程
*
*
* @param modelId 模型ID
* @return 是否成功终止
*/
@@ -1038,16 +1059,16 @@ public class ModelServiceImpl implements IModelService {
if (process != null) {
// 强制终止进程
process.destroyForcibly();
// 从映射中移除
runningProcesses.remove(modelId);
// 更新数据库状态
trainingModelService.lambdaUpdate()
.eq(TrainingModel::getId, modelId)
.set(TrainingModel::getTrainingStatus, "已终止")
.update();
log.info("模型训练已终止模型ID: {}", modelId);
} else {
log.warn("未找到正在运行的模型训练进程模型ID: {}", modelId);