@@ -2,7 +2,10 @@ package com.sdm.data.service.impl;
import com.alibaba.fastjson2.JSONArray ;
import com.alibaba.fastjson2.JSONObject ;
import com.github.pagehelper.PageHelper ;
import com.github.pagehelper.PageInfo ;
import com.sdm.common.common.SdmResponse ;
import com.sdm.common.utils.PageUtils ;
import com.sdm.data.model.entity.FileMetadataInfo ;
import com.sdm.data.model.entity.TrainingModel ;
import com.sdm.data.model.entity.TrainingModelAlgorithmParam ;
@@ -22,7 +25,9 @@ import java.io.OutputStream;
import java.nio.file.Files ;
import java.util.Date ;
import java.util.List ;
import java.util.Map ;
import java.util.Random ;
import java.util.concurrent.ConcurrentHashMap ;
@Service
@Slf4j
@@ -100,6 +105,8 @@ public class ModelServiceImpl implements IModelService {
*/
private static final String PYTHON_SCRIPT_PARAM_FILE_NAME = " param.json " ;
// Training processes that are currently running, keyed by model ID.
// An entry is put by trainModelAsync once the Python process has been started and is
// removed when that process finishes, when training throws, or when stopTraining is called.
private final Map < Integer , Process > runningProcesses = new ConcurrentHashMap < > ( ) ;
@Autowired
IDataFileService dataFileService ;
@@ -133,6 +140,41 @@ public class ModelServiceImpl implements IModelService {
}
}
@Override
@Transactional ( rollbackFor = Exception . class )
public SdmResponse deleteModel ( Integer modelId ) {
// 删除 training_model 和 training_model_algorithm_param 表数据
trainingModelService . removeById ( modelId ) ;
trainingModelAlgorithmParamService . lambdaUpdate ( ) . eq ( TrainingModelAlgorithmParam : : getModelId , modelId ) . remove ( ) ;
return SdmResponse . success ( " 删除训练模型成功 " ) ;
}
/**
 * Returns one page of training models.
 *
 * @param baseReq paging request; its current page number and page size are handed to PageHelper
 * @return the response that PageUtils builds from the fetched rows and their paging metadata
 */
@Override
public SdmResponse getModelList(BaseReq baseReq) {
    // Paging is registered first so that it applies to the list() query that follows.
    PageHelper.startPage(baseReq.getCurrent(), baseReq.getSize());
    List<TrainingModel> pagedModels = trainingModelService.list();

    // PageInfo wraps the query result to expose the paging metadata alongside the rows.
    PageInfo<TrainingModel> pageInfo = new PageInfo<TrainingModel>(pagedModels);
    return PageUtils.getJsonObjectSdmResponse(pagedModels, pageInfo);
}
@Override
public SdmResponse getModelDetail ( Integer modelId ) {
TrainingModel model = trainingModelService . getById ( modelId ) ;
if ( model ! = null ) {
return SdmResponse . success ( model ) ;
}
return SdmResponse . success ( ) ;
}
@Override
public SdmResponse updateModel ( AddModelReq addModelReq ) {
trainingModelService . lambdaUpdate ( )
. set ( TrainingModel : : getModelName , addModelReq . getModelName ( ) )
. set ( TrainingModel : : getDescription , addModelReq . getDescription ( ) )
. eq ( TrainingModel : : getId , addModelReq . getModelId ( ) ) . update ( ) ;
return SdmResponse . success ( " 更新训练模型成功 " ) ;
}
@Override
public SdmResponse < String > handleLoadData ( HandleLoadDataReq handleLoadDataReq ) {
try {
@@ -489,50 +531,49 @@ public class ModelServiceImpl implements IModelService {
}
}
private void callPythonScript ( String pythonScriptPath , String paramJsonPath ) {
try {
// 直接拼接完整命令(注意空格分隔)
String command = String . format (
" python %s --param %s " ,
pythonScript Path ,
paramJsonPath
) ;
private Process callPythonScript ( String pythonScriptPath , String paramJsonPath ) throws Exception {
// 直接拼接完整命令(注意空格分隔)
String command = String . format (
" python %s --param %s " ,
pythonScriptPath ,
paramJson Path
) ;
// 打印执行的命令(便于调试)
log . info ( " 执行的Python命令: {} " , command ) ;
// 打印执行的命令(便于调试)
log . info ( " 执行的Python命令: {} " , command ) ;
// 使用Runtime执行命令
Process process = Runtime . getRuntime ( ) . exec ( command ) ;
// 读取输出流
// 使用Runtime执行命令
Process process = Runtime . getRuntime ( ) . exec ( command ) ;
// 异步 读取输出流,避免阻塞
Thread outputThread = new Thread ( ( ) - > {
try ( BufferedReader reader = new BufferedReader (
new InputStreamReader ( process . getInputStream ( ) ) ) ) {
String line ;
while ( ( line = reader . readLine ( ) ) ! = null ) {
log . info ( " Python脚本输出: {} " , line ) ;
}
} catch ( IOException e ) {
log . error ( " 读取Python脚本输出异常 " , e ) ;
}
// 读取错误流(单独处理错误输出,避免阻塞)
} ) ;
outputThread . start ( ) ;
// 异步读取错误流,避免阻塞
Thread errorThread = new Thread ( ( ) - > {
try ( BufferedReader errorReader = new BufferedReader (
new InputStreamReader ( process . getErrorStream ( ) ) ) ) {
String errorLine ;
while ( ( errorLine = errorReader . readLine ( ) ) ! = null ) {
log . error ( " Python脚本错误输出: {} " , errorLine ) ;
}
} catch ( IOException e ) {
log . error ( " 读取Python脚本错误输出异常 " , e ) ;
}
} ) ;
errorThread . start ( ) ;
// 等待执行完成
int exitCode = process . waitFor ( ) ;
if ( exitCode = = 0 ) {
log . info ( " Python 脚本执行成功 " ) ;
} else {
log . error ( " Python 脚本执行失败,退出码: {} " , exitCode ) ;
}
} catch ( Exception e ) {
log . error ( " 调用Python脚本失败 " , e ) ;
throw new RuntimeException ( e ) ;
}
return process ;
}
@Override
@@ -606,27 +647,63 @@ public class ModelServiceImpl implements IModelService {
*/
private void trainModelAsync ( String paramJsonPath , Integer modelId , String exportFormat ) {
new Thread ( ( ) - > {
Process process = null ;
try {
long startTime = System . currentTimeMillis ( ) ;
log . info ( " 开始执行Python脚本, 训练模型ID: {} " , modelId ) ;
// 调用python脚本进行训练
callPythonScript ( TRAINING_PYTHON_SCRIPTPATH , paramJsonPath ) ;
process = callPythonScript( TRAINING_PYTHON_SCRIPTPATH , paramJsonPath ) ;
// 将进程添加到运行进程映射中
runningProcesses . put ( modelId , process ) ;
// 等待进程执行完成
int exitCode = process . waitFor ( ) ;
long endTime = System . currentTimeMillis ( ) ;
long duration = endTime - startTime ;
log . info ( " Python脚本执行完成, 训练模型ID: {}, 耗时: {} ms " , modelId , duration ) ;
// 根据退出码判断具体执行结果
if ( exitCode = = 0 ) {
log . info ( " Python脚本执行完成, 训练模型ID: {}, 耗时: {} ms " , modelId , duration ) ;
} else if ( exitCode = = 137 ) {
// 退出码137表示进程被SIGKILL信号强制终止( 如调用destroyForcibly())
log . warn ( " Python脚本执行被强制终止, 退出码: {}, 训练模型ID: {}, 耗时: {} ms " , exitCode , modelId , duration ) ;
} else {
log . error ( " Python脚本执行失败, 退出码: {}, 训练模型ID: {}, 耗时: {} ms " , exitCode , modelId , duration ) ;
}
// 从运行进程映射中移除已完成的进程
runningProcesses . remove ( modelId ) ;
// 将训练生成的生成 training.json、training.log、模型文件, 上传 到minio
Integer trainingDataResultFileId = uploadResultFileToMinio ( TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + " / " + TRAIN_RESULT_DIR_PATH + TRAINING_JSON_FILE_NAME ) ;
Integer trainingDataLogFileId = uploadResultFileToMinio ( TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + " / " + TRAIN_RESULT_DIR_PATH + TRAINING_LOG_FILE_NAME ) ;
Integer trainingDataExportModelFileId = uploadResultFileToMinio ( TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + " / " + TRAIN_RESULT_DIR_PATH + TRAINING_MODEL_PERFIX_NAME + exportFormat ) ;
// 只有在训练成功时才上传结果文件 到minio
Integer trainingDataResultFileId = null ;
Integer trainingDataLogFileId = null ;
Integer trainingDataExportModelFileId = null ;
if ( exitCode = = 0 ) {
// 将训练生成的生成 training.json、training.log、模型文件, 上传到minio
trainingDataResultFileId = uploadResultFileToMinio ( TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + " / " + TRAIN_RESULT_DIR_PATH + TRAINING_JSON_FILE_NAME ) ;
trainingDataLogFileId = uploadResultFileToMinio ( TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + " / " + TRAIN_RESULT_DIR_PATH + TRAINING_LOG_FILE_NAME ) ;
trainingDataExportModelFileId = uploadResultFileToMinio ( TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + " / " + TRAIN_RESULT_DIR_PATH + TRAINING_MODEL_PERFIX_NAME + exportFormat ) ;
log . info ( " 模型训练完成,结果文件: training.json文件ID: {}、training.log文件ID: {}、导出model文件ID: {}文件上传成功 " , trainingDataResultFileId , trainingDataLogFileId , trainingDataExportModelFileId ) ;
log . info ( " 模型训练完成,结果文件: training.json文件ID: {}、training.log文件ID: {}、导出model文件ID: {}文件上传成功 " , trainingDataResultFileId , trainingDataLogFileId , trainingDataExportModelFileId ) ;
}
// 根据退出码设置训练状态
String trainingStatus ;
if ( exitCode = = 0 ) {
trainingStatus = " 成功 " ;
} else if ( exitCode = = 137 ) {
trainingStatus = " 已终止 " ;
} else {
trainingStatus = " 失败 " ;
}
// 更新模型状态为训练完成
trainingModelService . lambdaUpdate ( )
. eq ( TrainingModel : : getId , modelId )
. set ( TrainingModel : : getTrainingStatus , " 成功 " )
. set ( TrainingModel : : getTrainingStatus , trainingStatus )
. set ( TrainingModel : : getTrainingTime , new Date ( ) )
. set ( TrainingModel : : getTrainingDuration , duration )
. set ( TrainingModel : : getTrainingDataResultFileId , trainingDataResultFileId )
@@ -636,6 +713,9 @@ public class ModelServiceImpl implements IModelService {
log . info ( " 模型训练完成, 模型ID: {} " , modelId ) ;
} catch ( Exception e ) {
// 从运行进程映射中移除失败的进程
runningProcesses . remove ( modelId ) ;
log . error ( " 模型训练失败, 模型ID: {} " , modelId , e ) ;
// 更新模型状态为训练失败
trainingModelService . lambdaUpdate ( )
@@ -718,8 +798,8 @@ public class ModelServiceImpl implements IModelService {
if ( trainingModel = = null ) {
return SdmResponse . failed ( " 模型不存在 " ) ;
}
if ( trainingModel . getTrainingStatus ( ) . equals ( " 失败 " ) ) {
return SdmResponse . failed ( " 模型训练失败,请查看日志详情 " ) ;
if ( trainingModel . getTrainingStatus ( ) . equals ( " 失败 " ) | | trainingModel . getTrainingStatus ( ) . equals ( " 已终止 " ) ) {
return SdmResponse . failed ( " 模型训练失败或已终止 ,请查看日志详情 " ) ;
}
if ( trainingModel . getTrainingStatus ( ) . equals ( " 待开始 " ) | | trainingModel . getTrainingStatus ( ) . equals ( " 训练中 " ) ) {
@@ -828,8 +908,8 @@ public class ModelServiceImpl implements IModelService {
String predDirPath = TANING_MODEL_BASE_DIR_PATH + MODEL_ID + trainingModel . getId ( ) + " / " + PRED_RESULT_DIR_PATH ;
// 创建预测用的param.json文件
String predParamJsonFile = createPredParamJsonFile ( modelPredictReq , trainingModel , trainingModelAlgorithmParam , predDirPath ) ;
String predParamJsonFile = createPredParamJsonFile ( modelPredictReq , trainingModel , trainingModelAlgorithmParam , predDirPath ) ;
// 执行预测
try {
long startTime = System . currentTimeMillis ( ) ;
@@ -845,13 +925,13 @@ public class ModelServiceImpl implements IModelService {
// 读取预测生成的结果文件并保存到数据库
String predResultPath = predDirPath + " / " + PRED_RESULT_FILE_NAME ;
File predResultFile = new File ( predResultPath ) ;
if ( predResultFile . exists ( ) ) {
try {
// 读取预测结果文件内容
String predResultJson = new String ( Files . readAllBytes ( predResultFile . toPath ( ) ) ) ;
predResultObject = JSONObject . parseObject ( predResultJson ) ;
// 将预测结果保存到数据库
trainingModelService . lambdaUpdate ( )
. eq ( TrainingModel : : getId , modelId )
@@ -869,20 +949,20 @@ public class ModelServiceImpl implements IModelService {
} catch ( Exception e ) {
log . error ( " 模型预测失败, 模型ID: {} " , modelId , e ) ;
}
return SdmResponse . success ( predResultObject ) ;
}
/**
* 创建预测参数param.json文件
*
* @param modelPredictReq 预测请求参数
* @param trainingModel 训练模型
* @param modelPredictReq 预测请求参数
* @param trainingModel 训练模型
* @param trainingModelAlgorithmParam 训练模型算法参数
* @param predDirPath 预测目录路径
* @param predDirPath 预测目录路径
* @return param.json文件路径
*/
private String createPredParamJsonFile ( ModelPredictReq modelPredictReq , TrainingModel trainingModel , TrainingModelAlgorithmParam trainingModelAlgorithmParam , String predDirPath ) {
private String createPredParamJsonFile ( ModelPredictReq modelPredictReq , TrainingModel trainingModel , TrainingModelAlgorithmParam trainingModelAlgorithmParam , String predDirPath ) {
try {
// 确保目录存在
File predDir = new File ( predDirPath ) ;
@@ -908,18 +988,18 @@ public class ModelServiceImpl implements IModelService {
modelParams . put ( " inputSize " , trainingModel . getInputSize ( ) ) ;
modelParams . put ( " outputSize " , trainingModel . getOutputSize ( ) ) ;
modelParams . put ( " normalizerType " , trainingModelAlgorithmParam . getDisposeMethod ( ) ) ;
// 添加归一化参数
if ( trainingModel . getNormalizerMax ( ) ! = null ) {
JSONArray normalizerMaxArray = JSONArray . parseArray ( trainingModel . getNormalizerMax ( ) ) ;
modelParams . put ( " normalizerMax " , normalizerMaxArray ) ;
}
if ( trainingModel . getNormalizerMin ( ) ! = null ) {
JSONArray normalizerMinArray = JSONArray . parseArray ( trainingModel . getNormalizerMin ( ) ) ;
modelParams . put ( " normalizerMin " , normalizerMinArray ) ;
}
paramJson . put ( " modelParams " , modelParams ) ;
// 添加输入数据
@@ -948,6 +1028,34 @@ public class ModelServiceImpl implements IModelService {
}
}
/**
* 终止指定模型的训练进程
*
* @param modelId 模型ID
* @return 是否成功终止
*/
public SdmResponse stopTraining ( Integer modelId ) {
Process process = runningProcesses . get ( modelId ) ;
if ( process ! = null ) {
// 强制终止进程
process . destroyForcibly ( ) ;
// 从映射中移除
runningProcesses . remove ( modelId ) ;
// 更新数据库状态
trainingModelService . lambdaUpdate ( )
. eq ( TrainingModel : : getId , modelId )
. set ( TrainingModel : : getTrainingStatus , " 已终止 " )
. update ( ) ;
log . info ( " 模型训练已终止, 模型ID: {} " , modelId ) ;
} else {
log . warn ( " 未找到正在运行的模型训练进程, 模型ID: {} " , modelId ) ;
}
return SdmResponse . success ( ) ;
}
@Override
public SdmResponse getModelPredictResult ( Integer modelId ) {
JSONObject result = new JSONObject ( ) ;