训练模型
This commit is contained in:
@@ -111,6 +111,15 @@ public class ModelTraningController {
|
||||
return modelService.submitTraining(algorithmParamReq);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取算法参数设置详情
|
||||
*/
|
||||
@GetMapping("/getAlgorithmParam")
|
||||
@Operation(summary = "获取算法参数设置详情", description = "获取算法参数设置详情")
|
||||
public SdmResponse getAlgorithmParam(@RequestParam Integer modelId) {
|
||||
return modelService.getAlgorithmParam(modelId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 停止模型训练
|
||||
*/
|
||||
|
||||
@@ -58,6 +58,11 @@ public interface IModelService {
|
||||
*/
|
||||
SdmResponse submitTraining(AlgorithmParamReq algorithmParamReq);
|
||||
|
||||
/**
 * Gets the algorithm parameter setting details for a model.
 *
 * @param modelId ID of the model whose algorithm parameter settings are requested
 * @return a response carrying the lookup result; how a missing record is
 *         reported is decided by the implementation
 */
|
||||
SdmResponse getAlgorithmParam(Integer modelId);
|
||||
|
||||
/**
|
||||
* 停止模型训练
|
||||
* @param modelId 模型ID
|
||||
|
||||
@@ -157,7 +157,7 @@ public class ModelServiceImpl implements IModelService {
|
||||
PageHelper.startPage(baseReq.getCurrent(), baseReq.getSize());
|
||||
List<TrainingModel> models = trainingModelService.list();
|
||||
PageInfo<TrainingModel> page = new PageInfo<>(models);
|
||||
return PageUtils.getJsonObjectSdmResponse(models,page);
|
||||
return PageUtils.getJsonObjectSdmResponse(models, page);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -207,7 +207,7 @@ public class ModelServiceImpl implements IModelService {
|
||||
trainingModelService.lambdaUpdate().eq(TrainingModel::getId, trainingModelId).set(TrainingModel::getHandleStatus, "处理中").update();
|
||||
|
||||
// 异步调用Python脚本处理数据
|
||||
processDataAsync(DATA_HANDLER_PYTHON_SCRIPT_PATH, paramJsonPath, trainingModelId, handleLoadDataReq.getUserId(),handleLoadDataReq.getTrainingModelId());
|
||||
processDataAsync(DATA_HANDLER_PYTHON_SCRIPT_PATH, paramJsonPath, trainingModelId, handleLoadDataReq.getUserId(), handleLoadDataReq.getTrainingModelId());
|
||||
return SdmResponse.success("数据处理中");
|
||||
} catch (Exception e) {
|
||||
log.error("处理上传数据失败", e);
|
||||
@@ -268,8 +268,8 @@ public class ModelServiceImpl implements IModelService {
|
||||
*
|
||||
* @param pythonScriptPath Python脚本路径
|
||||
* @param paramJsonPath 参数文件路径
|
||||
* @param userId 用户ID
|
||||
* @param modelId 模型ID
|
||||
* @param userId 用户ID
|
||||
* @param modelId 模型ID
|
||||
*/
|
||||
private void processDataAsync(String pythonScriptPath, String paramJsonPath, Integer trainingModelId, Integer userId, Integer modelId) {
|
||||
new Thread(() -> {
|
||||
@@ -350,12 +350,12 @@ public class ModelServiceImpl implements IModelService {
|
||||
log.info("训练模型ID: {}数据处理完成,处理结果文件ID:{},总耗时: {} ms", trainingModelId, fileId, duration);
|
||||
|
||||
// 通过WebSocket通知前端处理完成
|
||||
webSocketService.sendDataProcessingNotification(userId, modelId,true, "数据处理完成");
|
||||
webSocketService.sendDataProcessingNotification(userId, modelId, true, "数据处理完成");
|
||||
} catch (Exception e) {
|
||||
trainingModelService.lambdaUpdate().eq(TrainingModel::getId, trainingModelId).set(TrainingModel::getHandleStatus, "失败").update();
|
||||
log.error("异步处理数据失败", e);
|
||||
// 通过WebSocket通知前端处理失败
|
||||
webSocketService.sendDataProcessingNotification(userId, modelId,false, "数据处理失败: " + e.getMessage());
|
||||
webSocketService.sendDataProcessingNotification(userId, modelId, false, "数据处理失败: " + e.getMessage());
|
||||
}
|
||||
}).start();
|
||||
}
|
||||
@@ -543,7 +543,7 @@ public class ModelServiceImpl implements IModelService {
|
||||
|
||||
// 使用Runtime执行命令
|
||||
Process process = Runtime.getRuntime().exec(command);
|
||||
|
||||
|
||||
// 异步读取输出流,避免阻塞
|
||||
Thread outputThread = new Thread(() -> {
|
||||
try (BufferedReader reader = new BufferedReader(
|
||||
@@ -557,7 +557,7 @@ public class ModelServiceImpl implements IModelService {
|
||||
}
|
||||
});
|
||||
outputThread.start();
|
||||
|
||||
|
||||
// 异步读取错误流,避免阻塞
|
||||
Thread errorThread = new Thread(() -> {
|
||||
try (BufferedReader errorReader = new BufferedReader(
|
||||
@@ -652,16 +652,16 @@ public class ModelServiceImpl implements IModelService {
|
||||
log.info("开始执行Python脚本, 训练模型ID: {}", modelId);
|
||||
// 调用python脚本进行训练
|
||||
process = callPythonScript(TRAINING_PYTHON_SCRIPTPATH, paramJsonPath);
|
||||
|
||||
|
||||
// 将进程添加到运行进程映射中
|
||||
runningProcesses.put(modelId, process);
|
||||
|
||||
|
||||
// 等待进程执行完成
|
||||
int exitCode = process.waitFor();
|
||||
|
||||
|
||||
long endTime = System.currentTimeMillis();
|
||||
long duration = endTime - startTime;
|
||||
|
||||
|
||||
// 根据退出码判断具体执行结果
|
||||
if (exitCode == 0) {
|
||||
log.info("Python脚本执行完成, 训练模型ID: {}, 耗时: {} ms", modelId, duration);
|
||||
@@ -679,7 +679,7 @@ public class ModelServiceImpl implements IModelService {
|
||||
Integer trainingDataResultFileId = null;
|
||||
Integer trainingDataLogFileId = null;
|
||||
Integer trainingDataExportModelFileId = null;
|
||||
|
||||
|
||||
if (exitCode == 0) {
|
||||
// 将训练生成的生成 training.json、training.log、模型文件,上传到minio
|
||||
trainingDataResultFileId = uploadResultFileToMinio(TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + "/" + TRAIN_RESULT_DIR_PATH + TRAINING_JSON_FILE_NAME);
|
||||
@@ -698,7 +698,7 @@ public class ModelServiceImpl implements IModelService {
|
||||
} else {
|
||||
trainingStatus = "失败";
|
||||
}
|
||||
|
||||
|
||||
// 更新模型状态为训练完成
|
||||
trainingModelService.lambdaUpdate()
|
||||
.eq(TrainingModel::getId, modelId)
|
||||
@@ -714,7 +714,7 @@ public class ModelServiceImpl implements IModelService {
|
||||
} catch (Exception e) {
|
||||
// 从运行进程映射中移除失败的进程
|
||||
runningProcesses.remove(modelId);
|
||||
|
||||
|
||||
log.error("模型训练失败,模型ID: {}", modelId, e);
|
||||
// 更新模型状态为训练失败
|
||||
trainingModelService.lambdaUpdate()
|
||||
@@ -797,8 +797,12 @@ public class ModelServiceImpl implements IModelService {
|
||||
if (trainingModel == null) {
|
||||
return SdmResponse.failed("模型不存在");
|
||||
}
|
||||
if (trainingModel.getTrainingStatus().equals("失败") || trainingModel.getTrainingStatus().equals("已终止") ) {
|
||||
return SdmResponse.failed("模型训练失败或已终止,请查看日志详情");
|
||||
if (trainingModel.getTrainingStatus().equals("已终止")) {
|
||||
return SdmResponse.failed("模型训练已终止,请重新提交训练");
|
||||
}
|
||||
|
||||
if (trainingModel.getTrainingStatus().equals("失败")) {
|
||||
return SdmResponse.failed("模型训练失败,请查看日志详情");
|
||||
}
|
||||
|
||||
if (trainingModel.getTrainingStatus().equals("待开始") || trainingModel.getTrainingStatus().equals("训练中")) {
|
||||
@@ -1027,9 +1031,26 @@ public class ModelServiceImpl implements IModelService {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取算法参数设置详情
|
||||
*
|
||||
* @param modelId 模型ID
|
||||
* @return
|
||||
*/
|
||||
@Override
|
||||
public SdmResponse getAlgorithmParam(Integer modelId) {
|
||||
TrainingModelAlgorithmParam modelAlgorithmParam = trainingModelAlgorithmParamService.lambdaQuery()
|
||||
.eq(TrainingModelAlgorithmParam::getModelId, modelId)
|
||||
.one();
|
||||
if (modelAlgorithmParam == null) {
|
||||
return SdmResponse.failed("未找到模型算法参数设置详情");
|
||||
}
|
||||
return SdmResponse.success(modelAlgorithmParam);
|
||||
}
|
||||
|
||||
/**
|
||||
* 终止指定模型的训练进程
|
||||
*
|
||||
*
|
||||
* @param modelId 模型ID
|
||||
* @return 是否成功终止
|
||||
*/
|
||||
@@ -1038,16 +1059,16 @@ public class ModelServiceImpl implements IModelService {
|
||||
if (process != null) {
|
||||
// 强制终止进程
|
||||
process.destroyForcibly();
|
||||
|
||||
|
||||
// 从映射中移除
|
||||
runningProcesses.remove(modelId);
|
||||
|
||||
|
||||
// 更新数据库状态
|
||||
trainingModelService.lambdaUpdate()
|
||||
.eq(TrainingModel::getId, modelId)
|
||||
.set(TrainingModel::getTrainingStatus, "已终止")
|
||||
.update();
|
||||
|
||||
|
||||
log.info("模型训练已终止,模型ID: {}", modelId);
|
||||
} else {
|
||||
log.warn("未找到正在运行的模型训练进程,模型ID: {}", modelId);
|
||||
|
||||
Reference in New Issue
Block a user