@@ -6,15 +6,11 @@ import com.sdm.common.common.SdmResponse;
import com.sdm.data.model.entity.FileMetadataInfo ;
import com.sdm.data.model.entity.TrainingModel ;
import com.sdm.data.model.entity.TrainingModelAlgorithmParam ;
import com.sdm.data.model.req.AddModelReq ;
import com.sdm.data.model.req.AlgorithmParamReq ;
import com.sdm.data.model.req.HandleLoadDataReq ;
import com.sdm.data.model.req.ModelPredictReq ;
import com.sdm.data.model.req.* ;
import com.sdm.data.service.* ;
import lombok.extern.slf4j.Slf4j ;
import org.springframework.beans.BeanUtils ;
import org.springframework.beans.factory.annotation.Autowired ;
import org.springframework.beans.factory.annotation.Value ;
import org.springframework.mock.web.MockMultipartFile ;
import org.springframework.stereotype.Service ;
import org.springframework.transaction.annotation.Transactional ;
@@ -26,45 +22,84 @@ import java.io.OutputStream;
import java.nio.file.Files ;
import java.util.Date ;
import java.util.List ;
import java.util.Random ;
@Service
@Slf4j
public class ModelServiceImpl implements IModelService {
/**
* 数据处理文件存放路径
* 训练模型基础文件夹路径(脚本、训练数据、处理结果、训练模型)
*/
private static final String TANING_MODEL_DATA_FILE _PATH = " /home/app/model/dataHandler/ " ;
private static final String TANING_MODEL_BASE_DIR _PATH = " /home/app/model/ " ;
/**
* 模型id前缀
*/
private static final String MODEL_ID = " /ModelId_ " ;
/**
* 数据处理脚本存放路径
*/
private static final String DATA_HANDLER_PYTHON_SCRIPT_PATH = " /home/app/model /ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Data_Handler/Data_Load.py" ;
private static final String DATA_HANDLER_PYTHON_SCRIPT_PATH = TANING_MODEL_BASE_DIR_PATH + " /ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Data_Handler/Data_Load.py " ;
/**
* 数据处理结果文件名
*/
private static final String SOURCE_JSON_FILE_NAME = " source.json " ;
/**
* 训练结果文件存放路径
*/
private static final String TRAIN_RESULT_FILE_PATH = " /home/app/model/trainResult/ " ;
/**
* 训练脚本存放路径
*/
private static final String TRAINING_PYTHON_SCRIPTPATH = " /home/app/model /ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Data_Handler /Train_Proxy_Model.py" ;
private static final String TRAINING_PYTHON_SCRIPTPATH = TANING_MODEL_BASE_DIR_PATH + " /ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Train /Train_Proxy_Model.py " ;
/**
* 训练结果文件名
* 预测脚本存放路径
*/
private static final String PRED_PYTHON_SCRIPTPATH = TANING_MODEL_BASE_DIR_PATH + " /ModelTrainingPython/FC_ML_Baseline/FC_ML_Baseline_Predict/Model_Pred.py " ;
/**
* 数据处理文件夹放路径
*/
private static final String DATA_HANDLE_DIR_PATH = " /dataHandler/ " ;
/**
* 训练结果文件夹路径
*/
private static final String TRAIN_RESULT_DIR_PATH = " /trainResult/ " ;
/**
* 预测结果文件夹路径
*/
private static final String PRED_RESULT_DIR_PATH = " /predResult/ " ;
/**
* dataHandler 数据处理结果文件名
*/
private static final String SOURCE_JSON_FILE_NAME = " source.json " ;
/**
* trainResult 训练结果文件名
*/
private static final String TRAINING_JSON_FILE_NAME = " training.json " ;
/**
* 训练日志文件
* trainResult 训练日志文件
*/
private static final String TRAINING_LOG_FILE_NAME = " training.log " ;
/**
* trainResult 训练模型文件名前缀
*/
private static final String TRAINING_MODEL_PERFIX_NAME = " model. " ;
/**
* 预测结果文件名
*/
private static final String PRED_RESULT_FILE_NAME = " forecast.json " ;
/**
* 脚本 入参文件名
*/
private static final String PYTHON_SCRIPT_PARAM_FILE_NAME = " param.json " ;
@Autowired
IDataFileService dataFileService ;
@@ -115,11 +150,12 @@ public class ModelServiceImpl implements IModelService {
log . info ( " 训练模型关联训练数据文件成功, 模型ID: {} " , trainingModelId ) ;
// 直接将MultipartFile 训练数据文件 写入Linux服务器
String t aningModelDataFilePath = writeToLinuxServer ( handleLoadDataReq . getFile ( ) , TANING_MODEL_DATA_FILE _PATH) ;
// 直接将MultipartFile 训练数据文件 写入Linux服务器的对应模型目录下:/home/app/model/ModelId_1/dataHandler/,并返回文件路径
String h andleLoadModelDirPath = TANING_MODEL_BASE_DIR_PATH + MODEL_ID + trainingModelId + " / " + DATA_HANDLE_DIR _PATH;
String taningModelDataFilePath = writeToLinuxServer ( handleLoadDataReq . getFile ( ) , handleLoadModelDirPath ) ;
// 根据生成的文件路径创建python脚本 的param.json参数文件
String paramJsonPath = createParamJsonFile ( taningModelDataFilePath ) ;
String paramJsonPath = createDataHadle ParamJsonFile ( taningModelDataFilePath ) ;
// 更新训练模型 训练状态
@@ -136,7 +172,7 @@ public class ModelServiceImpl implements IModelService {
}
@Override
public SdmResponse getHandleLoadDataResult ( String modelId ) {
public SdmResponse getHandleLoadDataResult ( Integer modelId ) {
// 根据模型ID查询处理结果, 再根据结果获取文件ID, 再从MinIO中下载json文件, 处理json文件返回一个对象
TrainingModel trainingModel = trainingModelService . getById ( modelId ) ;
if ( trainingModel = = null ) {
@@ -202,13 +238,69 @@ public class ModelServiceImpl implements IModelService {
log . info ( " Python脚本执行完成, 训练模型ID: {}, 耗时: {} ms " , trainingModelId , duration ) ;
// 上传处理完的python数据到minio
String resultFilePath = TANING_MODEL_DATA_FILE _PATH + SOURCE_JSON_FILE_NAME ;
String resultFilePath = TANING_MODEL_BASE_DIR_PATH + MODEL_ID + trainingModelId + " / " + DATA_HANDLE_DIR _PATH + SOURCE_JSON_FILE_NAME ;
Integer fileId = uploadResultFileToMinio ( resultFilePath ) ;
InputStream minioInputStream = dataFileService . getMinioInputStream ( fileId ) ;
// 直接在内存中读取并解析JSON
BufferedReader reader = new BufferedReader ( new InputStreamReader ( minioInputStream ) ) ;
StringBuilder jsonString = new StringBuilder ( ) ;
String line ;
while ( ( line = reader . readLine ( ) ) ! = null ) {
jsonString . append ( line ) ;
}
// 将JSON字符串转换为JSONObject或其他Java对象, 获取表头总数量
JSONObject jsonObject = JSONObject . parseObject ( jsonString . toString ( ) ) ;
JSONArray averageDataArray = jsonObject . getJSONArray ( " average_data " ) ;
// 循环处理averageDataArray
JSONArray maxValues = null ;
JSONArray minValues = null ;
for ( int i = 0 ; i < averageDataArray . size ( ) ; i + + ) {
JSONObject averageData = averageDataArray . getJSONObject ( i ) ;
if ( averageData . getString ( " property " ) . equals ( " 最大值 " ) ) {
// 提取最大值数组, 排除property字段
maxValues = new JSONArray ( ) ;
// 获取所有字段名
java . util . Set < String > keySet = averageData . keySet ( ) ;
for ( String key : keySet ) {
if ( ! " property " . equals ( key ) ) {
maxValues . add ( averageData . getDoubleValue ( key ) ) ;
}
}
} else if ( averageData . getString ( " property " ) . equals ( " 最小值 " ) ) {
// 提取最小值数组, 排除property字段
minValues = new JSONArray ( ) ;
// 获取所有字段名
java . util . Set < String > keySet = averageData . keySet ( ) ;
for ( String key : keySet ) {
if ( ! " property " . equals ( key ) ) {
minValues . add ( averageData . getDoubleValue ( key ) ) ;
}
}
}
}
// 将数组转换为JSON字符串
String normalizerMaxJson = null ;
String normalizerMinJson = null ;
if ( maxValues ! = null ) {
normalizerMaxJson = maxValues . toJSONString ( ) ;
}
if ( minValues ! = null ) {
normalizerMinJson = minValues . toJSONString ( ) ;
}
// 更新训练模型 训练状态
trainingModelService . lambdaUpdate ( ) . eq ( TrainingModel : : getId , trainingModelId )
. set ( TrainingModel : : getHandleStatus , " 成功 " )
. set ( TrainingModel : : getTrainingDataHandleFileId , fileId )
. set ( TrainingModel : : getNormalizerMax , normalizerMaxJson )
. set ( TrainingModel : : getNormalizerMin , normalizerMinJson )
. update ( ) ;
log . info ( " 训练模型ID: {}数据处理完成,处理结果文件ID:{},总耗时: {} ms " , trainingModelId , fileId , duration ) ;
@@ -300,12 +392,12 @@ public class ModelServiceImpl implements IModelService {
}
/**
* 根据数据文件路径创建param.json配置文件
* 根据数据文件路径创建数据处理的 param.json配置文件
*
* @param dataFilePath 数据文件的完整路径
* @return param.json的完整路径
*/
private String createParamJsonFile ( String dataFilePath ) {
private String createDataHadle ParamJsonFile ( String dataFilePath ) {
try {
// 提取文件名和目录路径
File dataFile = new File ( dataFilePath ) ;
@@ -313,7 +405,7 @@ public class ModelServiceImpl implements IModelService {
String directoryPath = dataFile . getParent ( ) ;
// 构造param.json文件路径
String paramJsonPath = directoryPath + " /param.json " ;
String paramJsonPath = directoryPath + " / " + PYTHON_SCRIPT_PARAM_FILE_NAME ;
// 创建JSON对象
JSONObject paramJson = new JSONObject ( ) ;
@@ -339,9 +431,10 @@ public class ModelServiceImpl implements IModelService {
/**
* 下载文件到服务器本地
* @param fileId 文件ID
*
* @param fileId 文件ID
* @param localDirPath 本地文件夹路径
* @return 本地文件的完整路径
* @return 本地文件的完整路径 localDirPath / fileName
*/
private String downloadToLocalFromMinIO ( Integer fileId , String localDirPath ) {
InputStream in = null ;
@@ -360,13 +453,13 @@ public class ModelServiceImpl implements IModelService {
}
String localFilePath = localDirPath + " / " + fileMetadataInfo . getOriginalName ( ) ;
// 确保目标目录存在
File localFile = new File ( localDir Path ) ;
File localFile = new File ( localFile Path ) ;
File parentDir = localFile . getParentFile ( ) ;
if ( parentDir ! = null & & ! parentDir . exists ( ) ) {
parentDir . mkdirs ( ) ;
}
out = new FileOutputStream ( localDir Path ) ;
out = new FileOutputStream ( localFile Path ) ;
// 将输入流写入本地文件
byte [ ] buffer = new byte [ 1024 ] ;
int len ;
@@ -442,6 +535,37 @@ public class ModelServiceImpl implements IModelService {
}
}
@Override
public SdmResponse setTrainingDataInPutOutPutColumn ( SetTrainingDataInPutOutPutColumnReq req ) {
trainingModelService . lambdaUpdate ( )
. eq ( TrainingModel : : getId , req . getModelId ( ) )
. set ( TrainingModel : : getInputSize , req . getInputColumns ( ) . size ( ) )
. set ( TrainingModel : : getInputLabel , JSONArray . toJSONString ( req . getInputColumns ( ) ) )
. set ( TrainingModel : : getOutputSize , req . getOutputColumns ( ) . size ( ) )
. set ( TrainingModel : : getOutputLabel , JSONArray . toJSONString ( req . getOutputColumns ( ) ) )
. update ( ) ;
return SdmResponse . success ( ) ;
}
@Override
public SdmResponse getTrainingDataInPutOutPutColumn ( Integer modelId ) {
TrainingModel trainingModel = trainingModelService . getById ( modelId ) ;
// 从数据库获取inputLabel字段值( JSON字符串)
String inputLabelJson = trainingModel . getInputLabel ( ) ;
// 解析JSON字符串为列表
List < String > inputLabels = JSONArray . parseArray ( inputLabelJson , String . class ) ;
// 从数据库获取outputLable字段值( JSON字符串)
String outputLabelJson = trainingModel . getOutputLabel ( ) ;
// 解析JSON字符串为列表
List < String > outputLabels = JSONArray . parseArray ( outputLabelJson , String . class ) ;
JSONObject result = new JSONObject ( ) ;
result . put ( " inputLabels " , inputLabels ) ;
result . put ( " outputLabels " , outputLabels ) ;
return SdmResponse . success ( result ) ;
}
@Override
public SdmResponse submitTraining ( AlgorithmParamReq algorithmParamReq ) {
// 保存或更新算法参数
@@ -469,7 +593,7 @@ public class ModelServiceImpl implements IModelService {
String paramJsonPath = createTraningParamJsonFile ( algorithmParamReq ) ;
// 异步调用python脚本进行训练, 并上传结果文件到minio
trainModelAsync ( paramJsonPath , algorithmParamReq . getModelId ( ) ) ;
trainModelAsync ( paramJsonPath , algorithmParamReq . getModelId ( ) , algorithmParamReq . getExportFormat ( ) );
return SdmResponse . success ( " 提交训练成功 " ) ;
}
@@ -480,7 +604,7 @@ public class ModelServiceImpl implements IModelService {
* @param paramJsonPath 参数文件路径
* @param modelId 模型ID
*/
private void trainModelAsync ( String paramJsonPath , Integer modelId ) {
private void trainModelAsync ( String paramJsonPath , Integer modelId , String exportFormat ) {
new Thread ( ( ) - > {
try {
long startTime = System . currentTimeMillis ( ) ;
@@ -491,10 +615,13 @@ public class ModelServiceImpl implements IModelService {
long duration = endTime - startTime ;
log . info ( " Python脚本执行完成, 训练模型ID: {}, 耗时: {} ms " , modelId , duration ) ;
// 将训练生成的生成training.json、training.log文件, 上传到minio
Integer trainingDataResultFileId = uploadResultFileToMinio ( TRAIN_RESULT_FILE_PATH + TRAINING_JSON_FILE_NAME ) ;
Integer trainingDataLog FileId = uploadResultFileToMinio ( TRAIN_RESULT_FILE _PATH + TRAINING_LOG _FILE_NAME ) ;
log . info ( " 模型训练完成,结果文件: training.json、training.log文件上传成功, 文件ID: {},{} " , trainingDataResultFileId , trainingDataLogFileId ) ;
// 将训练生成的生成 training.json、training.log、模型文件, 上传到minio
Integer trainingDataResult FileId = uploadResultFileToMinio ( TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + " / " + TRAIN_RESULT_DIR _PATH + TRAINING_JSON _FILE_NAME ) ;
Integer trainingDataLogFileId = uploadResultFileToMinio ( TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + " / " + TRAIN_RESULT_DIR_PATH + TRAINING_LOG_FILE_NAME ) ;
Integer trainingDataExportModelFileId = uploadResultFileToMinio ( TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + " / " + TRAIN_RESULT_DIR_PATH + TRAINING_MODEL_PERFIX_NAME + exportFormat ) ;
log . info ( " 模型训练完成,结果文件: training.json文件ID: {}、training.log文件ID: {}、导出model文件ID: {}文件上传成功 " , trainingDataResultFileId , trainingDataLogFileId , trainingDataExportModelFileId ) ;
// 更新模型状态为训练完成
trainingModelService . lambdaUpdate ( )
@@ -504,8 +631,9 @@ public class ModelServiceImpl implements IModelService {
. set ( TrainingModel : : getTrainingDuration , duration )
. set ( TrainingModel : : getTrainingDataResultFileId , trainingDataResultFileId )
. set ( TrainingModel : : getTrainingDataLogFileId , trainingDataLogFileId )
. set ( TrainingModel : : getTrainingDataExportModelFileId , trainingDataExportModelFileId )
. update ( ) ;
log . info ( " 模型训练完成, 模型ID: {} " , modelId ) ;
} catch ( Exception e ) {
log . error ( " 模型训练失败, 模型ID: {} " , modelId , e ) ;
@@ -520,15 +648,16 @@ public class ModelServiceImpl implements IModelService {
/**
* 创建训练参数param.json文件
*
* @param algorithmParamReq 算法参数请求
* @return
*/
private String createTraningParamJsonFile ( AlgorithmParamReq algorithmParamReq ) {
// 根据训练模型id 获取训练模型的训练数据文件id, 并下载到本地 trainingDataFilePath = /home/app/model/trainResult/sample.CSV
Integer modelId = algorithmParamReq . getModelId ( ) ;
TrainingModel trainingModel = trainingModelService . getById ( modelId ) ;
TrainingModel trainingModel = trainingModelService . getById ( modelId ) ;
Integer trainingDataFileId = trainingModel . getTrainingDataFileId ( ) ;
String trainingDataFilePath = downloadToLocalFromMinIO ( trainingDataFileId , TRAIN_RESULT_FILE _PATH ) ;
String trainingDataFilePath = downloadToLocalFromMinIO ( trainingDataFileId , TANING_MODEL_BASE_DIR_PATH + MODEL_ID + modelId + " / " + TRAIN_RESULT_DIR _PATH) ;
try {
// 提取文件名和目录路径
@@ -537,7 +666,7 @@ public class ModelServiceImpl implements IModelService {
String directoryPath = dataFile . getParent ( ) ;
// 构造param.json文件路径
String paramJsonPath = directoryPath + " /param.json " ;
String paramJsonPath = directoryPath + " / " + PYTHON_SCRIPT_PARAM_FILE_NAME ;
// 创建JSON对象
JSONObject paramJson = new JSONObject ( ) ;
@@ -549,8 +678,8 @@ public class ModelServiceImpl implements IModelService {
// 构造algorithmParam对象
JSONObject algorithmParam = new JSONObject ( ) ;
algorithmParam . put ( " inputSize " , algorithmParamReq . getInputSize ( ) ) ;
algorithmParam . put ( " outputSize " , algorithmParamReq . getOutputSize ( ) ) ;
algorithmParam . put ( " inputSize " , trainingModel . getInputSize ( ) ) ;
algorithmParam . put ( " outputSize " , trainingModel . getOutputSize ( ) ) ;
algorithmParam . put ( " algorithm " , algorithmParamReq . getAlgorithm ( ) ) ;
algorithmParam . put ( " activateFun " , algorithmParamReq . getActivateFun ( ) ) ;
algorithmParam . put ( " lossFun " , algorithmParamReq . getLossFun ( ) ) ;
@@ -583,47 +712,58 @@ public class ModelServiceImpl implements IModelService {
}
@Override
public SdmResponse getTrainingResult ( String modelId ) {
public SdmResponse getTrainingResult ( Integer modelId ) {
// 根据modelId 获取训练模型, 获取训练模型结果文件id, 日志文件id,
TrainingModel trainingModel = trainingModelService . getById ( modelId ) ;
if ( trainingModel = = null ) {
return SdmResponse . failed ( " 模型不存在 " ) ;
}
if ( trainingModel . getTrainingStatus ( ) . equals ( " 失败 " ) ) {
if ( trainingModel . getTrainingStatus ( ) . equals ( " 失败 " ) ) {
return SdmResponse . failed ( " 模型训练失败,请查看日志详情 " ) ;
}
if ( trainingModel . getTrainingStatus ( ) . equals ( " 待开始 " ) | | trainingModel . getTrainingStatus ( ) . equals ( " 训练中 " ) ) {
if ( trainingModel . getTrainingStatus ( ) . equals ( " 待开始 " ) | | trainingModel . getTrainingStatus ( ) . equals ( " 训练中 " ) ) {
return SdmResponse . failed ( " 模型正在训练中,请稍后再查询 " ) ;
}
Integer trainingDataResultFileId = trainingModel . getTrainingDataResultFileId ( ) ;
Integer trainingDataLogFileId = trainingModel . getTrainingDataLogFileId ( ) ;
if ( trainingDataResultFileId = = null | | trainingDataLogFileId = = null ) {
return SdmResponse . failed ( " 训练结果文件或 日志文件不存在 " ) ;
if ( trainingDataResultFileId = = null & & trainingDataLogFileId = = null ) {
return SdmResponse . failed ( " 训练结果文件和 日志文件都 不存在 " ) ;
}
// 从minio获取结果json文件转换成对象
InputStream result InputStream = dataFileService . getMinioInputStream ( trainingDataResultFileId ) ;
// 从minio获取日志文件转换成对象
InputStream logInputStream = dataFileService . getMinioInputStream ( trainingDataLogFileId ) ;
InputStream resultInputStream = null ;
InputStream log InputStream = null ;
if ( resultInputStream = = null | | logInputStream = = null ) {
// 从minio获取结果json文件转换成对象
if ( trainingDataResultFileId ! = null ) {
resultInputStream = dataFileService . getMinioInputStream ( trainingDataResultFileId ) ;
}
// 从minio获取日志文件转换成对象
if ( trainingDataLogFileId ! = null ) {
logInputStream = dataFileService . getMinioInputStream ( trainingDataLogFileId ) ;
}
if ( resultInputStream = = null & & logInputStream = = null ) {
return SdmResponse . failed ( " 无法从MinIO获取文件流 " ) ;
}
try {
// 处理结果JSON文件
JSONObject resultJsonObject = parseJsonFromStream ( resultInputStream ) ;
// 处理日志文件
String logContent = parseStringFromStream ( logInputStream ) ;
// 构造返回对象
JSONObject response = new JSONObject ( ) ;
response . put ( " result " , resultJsonObject ) ;
response . put ( " log " , logContent ) ;
// 处理结果JSON文件( 如果存在)
if ( resultInputStream ! = null ) {
JSONObject resultJsonObject = parseJsonFromStream ( resultInputStream ) ;
response . put ( " result " , resultJsonObject ) ;
}
// 处理日志文件(如果存在)
if ( logInputStream ! = null ) {
String logContent = parseStringFromStream ( logInputStream ) ;
response . put ( " log " , logContent ) ;
}
return SdmResponse . success ( response ) ;
} catch ( Exception e ) {
@@ -641,6 +781,7 @@ public class ModelServiceImpl implements IModelService {
/**
* 从输入流中解析JSON对象
*
* @param inputStream 输入流
* @return 解析后的JSON对象
* @throws IOException IO异常
@@ -657,6 +798,7 @@ public class ModelServiceImpl implements IModelService {
/**
* 从输入流中读取字符串内容
*
* @param inputStream 输入流
* @return 字符串内容
* @throws IOException IO异常
@@ -672,7 +814,158 @@ public class ModelServiceImpl implements IModelService {
}
@Override
public SdmResponse p redict( ModelPredictReq modelPredictReq ) {
return null ;
public SdmResponse startP redict( ModelPredictReq modelPredictReq ) {
JSONObject predResultObject = new JSONObject ( ) ;
Integer modelId = modelPredictReq . getModelId ( ) ;
// 将jsonData存储到与modelId关联的模型表的inputLabel字段中
trainingModelService . lambdaUpdate ( )
. eq ( TrainingModel : : getId , modelPredictReq . getModelId ( ) )
. set ( TrainingModel : : getInputPredLabelValue , modelPredictReq . getJsonData ( ) )
. update ( ) ;
TrainingModel trainingModel = trainingModelService . lambdaQuery ( ) . eq ( TrainingModel : : getId , modelId ) . one ( ) ;
TrainingModelAlgorithmParam trainingModelAlgorithmParam = trainingModelAlgorithmParamService . lambdaQuery ( ) . eq ( TrainingModelAlgorithmParam : : getModelId , modelId ) . one ( ) ;
String predDirPath = TANING_MODEL_BASE_DIR_PATH + MODEL_ID + trainingModel . getId ( ) + " / " + PRED_RESULT_DIR_PATH ;
// 创建预测用的param.json文件
String predParamJsonFile = createPredParamJsonFile ( modelPredictReq , trainingModel , trainingModelAlgorithmParam , predDirPath ) ;
// 执行预测
try {
long startTime = System . currentTimeMillis ( ) ;
log . info ( " 开始执行Python预测脚本, 模型ID: {} " , modelId ) ;
// 调用python脚本进行预测
callPythonScript ( PRED_PYTHON_SCRIPTPATH , predParamJsonFile ) ;
long endTime = System . currentTimeMillis ( ) ;
long duration = endTime - startTime ;
log . info ( " Python预测脚本执行完成, 模型ID: {}, 耗时: {} ms " , modelId , duration ) ;
// 读取预测生成的结果文件并保存到数据库
String predResultPath = predDirPath + " / " + PRED_RESULT_FILE_NAME ;
File predResultFile = new File ( predResultPath ) ;
if ( predResultFile . exists ( ) ) {
try {
// 读取预测结果文件内容
String predResultJson = new String ( Files . readAllBytes ( predResultFile . toPath ( ) ) ) ;
predResultObject = JSONObject . parseObject ( predResultJson ) ;
// 将预测结果保存到数据库
trainingModelService . lambdaUpdate ( )
. eq ( TrainingModel : : getId , modelId )
. set ( TrainingModel : : getOutputPredLabelValue , predResultObject . toJSONString ( ) )
. update ( ) ;
log . info ( " 模型预测结果已保存到数据库, 模型ID: {} " , modelId ) ;
} catch ( Exception e ) {
log . error ( " 解析或保存预测结果失败, 模型ID: {} " , modelId , e ) ;
}
} else {
log . warn ( " 预测结果文件不存在: {} " , predResultPath ) ;
}
log . info ( " 模型预测完成, 模型ID: {} " , modelId ) ;
} catch ( Exception e ) {
log . error ( " 模型预测失败, 模型ID: {} " , modelId , e ) ;
}
return SdmResponse . success ( predResultObject ) ;
}
/**
* 创建预测参数param.json文件
*
* @param modelPredictReq 预测请求参数
* @param trainingModel 训练模型
* @param trainingModelAlgorithmParam 训练模型算法参数
* @param predDirPath 预测目录路径
* @return param.json文件路径
*/
private String createPredParamJsonFile ( ModelPredictReq modelPredictReq , TrainingModel trainingModel , TrainingModelAlgorithmParam trainingModelAlgorithmParam , String predDirPath ) {
try {
// 确保目录存在
File predDir = new File ( predDirPath ) ;
if ( ! predDir . exists ( ) ) {
predDir . mkdirs ( ) ;
}
// 构造param.json文件路径
String paramJsonPath = predDirPath + " / " + PYTHON_SCRIPT_PARAM_FILE_NAME ;
// 创建JSON对象
JSONObject paramJson = new JSONObject ( ) ;
// 下载model模型文件到本地
downloadToLocalFromMinIO ( trainingModel . getTrainingDataExportModelFileId ( ) , predDirPath ) ;
paramJson . put ( " modelFile " , TRAINING_MODEL_PERFIX_NAME + trainingModelAlgorithmParam . getExportFormat ( ) ) ;
// 设置路径
paramJson . put ( " path " , predDirPath ) ;
// 添加模型参数
JSONObject modelParams = new JSONObject ( ) ;
modelParams . put ( " inputSize " , trainingModel . getInputSize ( ) ) ;
modelParams . put ( " outputSize " , trainingModel . getOutputSize ( ) ) ;
modelParams . put ( " normalizerType " , trainingModelAlgorithmParam . getDisposeMethod ( ) ) ;
// 添加归一化参数
if ( trainingModel . getNormalizerMax ( ) ! = null ) {
JSONArray normalizerMaxArray = JSONArray . parseArray ( trainingModel . getNormalizerMax ( ) ) ;
modelParams . put ( " normalizerMax " , normalizerMaxArray ) ;
}
if ( trainingModel . getNormalizerMin ( ) ! = null ) {
JSONArray normalizerMinArray = JSONArray . parseArray ( trainingModel . getNormalizerMin ( ) ) ;
modelParams . put ( " normalizerMin " , normalizerMinArray ) ;
}
paramJson . put ( " modelParams " , modelParams ) ;
// 添加输入数据
// 直接使用传入的JSON数组作为input值
JSONArray inputArray = JSONArray . parseArray ( modelPredictReq . getJsonData ( ) ) ;
paramJson . put ( " input " , inputArray ) ;
// 添加输出标签
JSONObject output = new JSONObject ( ) ;
if ( trainingModel . getOutputLabel ( ) ! = null ) {
JSONArray outputLabels = JSONArray . parseArray ( trainingModel . getOutputLabel ( ) ) ;
output . put ( " names " , outputLabels ) ;
}
paramJson . put ( " output " , output ) ;
// 写入param.json文件
try ( FileWriter fileWriter = new FileWriter ( paramJsonPath ) ) {
fileWriter . write ( paramJson . toJSONString ( ) ) ;
}
log . info ( " 成功创建预测param.json文件: {} " , paramJsonPath ) ;
return paramJsonPath ;
} catch ( IOException e ) {
log . error ( " 创建预测param.json文件失败 " , e ) ;
throw new RuntimeException ( " 创建预测param.json文件失败: " + e . getMessage ( ) , e ) ;
}
}
@Override
public SdmResponse getModelPredictResult ( Integer modelId ) {
JSONObject result = new JSONObject ( ) ;
TrainingModel model = trainingModelService . lambdaQuery ( ) . eq ( TrainingModel : : getId , modelId ) . one ( ) ;
if ( model = = null ) {
return SdmResponse . failed ( " 模型不存在 " ) ;
}
String inputPredLabelValue = model . getInputPredLabelValue ( ) ;
if ( inputPredLabelValue ! = null ) {
JSONObject inputPredLabelValueJson = JSONObject . parseObject ( inputPredLabelValue ) ;
result . put ( " inputPredLabelValue " , inputPredLabelValueJson ) ;
}
String outputPredLabelValue = model . getOutputPredLabelValue ( ) ;
if ( outputPredLabelValue ! = null ) {
JSONObject outputPredLabelValueJson = JSONObject . parseObject ( outputPredLabelValue ) ;
result . put ( " inputPredLabelValue " , outputPredLabelValueJson ) ;
}
return SdmResponse . success ( result ) ;
}
}