训练模型

This commit is contained in:
2025-10-21 20:55:25 +08:00
parent 1c38be1aef
commit 85acefc8af
5 changed files with 151 additions and 5 deletions

View File

@@ -4,8 +4,10 @@ import com.sdm.common.common.SdmResponse;
import com.sdm.data.model.req.AddModelReq;
import com.sdm.data.model.req.AlgorithmParamReq;
import com.sdm.data.model.req.HandleLoadDataReq;
import com.sdm.data.model.req.ModelPredictReq;
import com.sdm.data.service.IModelService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
@@ -53,7 +55,26 @@ public class ModelTraningController {
*/
@PostMapping("/submitTraining")
@Operation(summary = "算法参数设置并提交训练", description = "算法参数设置并提交训练")
public void submitTraining(@RequestBody AlgorithmParamReq algorithmParamReq) {
modelService.submitTraining(algorithmParamReq);
public SdmResponse submitTraining(@RequestBody AlgorithmParamReq algorithmParamReq) {
return modelService.submitTraining(algorithmParamReq);
}
/**
* 获取训练曲线和训练日志
* @param modelId
*/
@GetMapping("/getTrainingResult")
@Operation(summary = "获取训练曲线和训练日志", description = "获取训练曲线和训练日志")
public SdmResponse getTrainingResult(@RequestParam String modelId){
return modelService.getTrainingResult(modelId);
}
/**
* 数据预测
*/
@GetMapping("/predict")
@Operation(summary = "数据预测", description = "数据预测")
public SdmResponse predict(@Parameter(description = "数据预测请求对象") @RequestBody ModelPredictReq modelPredictReq) throws IOException {
return modelService.predict(modelPredictReq);
}
}

View File

@@ -70,7 +70,7 @@ public class TrainingModelAlgorithmParam implements Serializable {
@ApiModelProperty(value = "学习率0.001")
@TableField("studyPercent")
private BigDecimal studyPercent;
private Double studyPercent;
@ApiModelProperty(value = "步数3")
@TableField("stepCounts")

View File

@@ -0,0 +1,13 @@
package com.sdm.data.model.req;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
@Data
public class ModelPredictReq {
@Schema(description = "选择的预测模型ID")
String modelId;
@Schema(description = "json数据")
String jsonData;
}

View File

@@ -4,6 +4,7 @@ import com.sdm.common.common.SdmResponse;
import com.sdm.data.model.req.AddModelReq;
import com.sdm.data.model.req.AlgorithmParamReq;
import com.sdm.data.model.req.HandleLoadDataReq;
import com.sdm.data.model.req.ModelPredictReq;
import java.io.IOException;
@@ -30,5 +31,19 @@ public interface IModelService {
* 模型参数设置并提交训练
* @param algorithmParamReq 模型参数
*/
void submitTraining(AlgorithmParamReq algorithmParamReq);
SdmResponse submitTraining(AlgorithmParamReq algorithmParamReq);
/**
* 模型训练结果查询
* @param modelId 模型ID
* @return 模型训练结果
*/
SdmResponse getTrainingResult(String modelId);
/**
* 数据预测
* @param modelPredictReq 模型预测请求对象
* @return 模型预测结果
*/
SdmResponse predict(ModelPredictReq modelPredictReq);
}

View File

@@ -9,6 +9,7 @@ import com.sdm.data.model.entity.TrainingModelAlgorithmParam;
import com.sdm.data.model.req.AddModelReq;
import com.sdm.data.model.req.AlgorithmParamReq;
import com.sdm.data.model.req.HandleLoadDataReq;
import com.sdm.data.model.req.ModelPredictReq;
import com.sdm.data.service.*;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeanUtils;
@@ -442,7 +443,7 @@ public class ModelServiceImpl implements IModelService {
}
@Override
public void submitTraining(AlgorithmParamReq algorithmParamReq) {
public SdmResponse submitTraining(AlgorithmParamReq algorithmParamReq) {
// 保存或更新算法参数
TrainingModelAlgorithmParam queryTrainingModelAlgorithmParam = trainingModelAlgorithmParamService.lambdaQuery()
.eq(TrainingModelAlgorithmParam::getModelId, algorithmParamReq.getModelId()).one();
@@ -469,6 +470,8 @@ public class ModelServiceImpl implements IModelService {
// 异步调用python脚本进行训练并上传结果文件到minio
trainModelAsync(paramJsonPath, algorithmParamReq.getModelId());
return SdmResponse.success("提交训练成功");
}
/**
@@ -578,4 +581,98 @@ public class ModelServiceImpl implements IModelService {
throw new RuntimeException("创建训练param.json文件失败: " + e.getMessage(), e);
}
}
@Override
public SdmResponse getTrainingResult(String modelId) {
// 根据modelId 获取训练模型获取训练模型结果文件id日志文件id
TrainingModel trainingModel = trainingModelService.getById(modelId);
if (trainingModel == null) {
return SdmResponse.failed("模型不存在");
}
if(trainingModel.getTrainingStatus().equals("失败")){
return SdmResponse.failed("模型训练失败,请查看日志详情");
}
if(trainingModel.getTrainingStatus().equals("待开始") || trainingModel.getTrainingStatus().equals("训练中")){
return SdmResponse.failed("模型正在训练中,请稍后再查询");
}
Integer trainingDataResultFileId = trainingModel.getTrainingDataResultFileId();
Integer trainingDataLogFileId = trainingModel.getTrainingDataLogFileId();
if (trainingDataResultFileId == null || trainingDataLogFileId == null) {
return SdmResponse.failed("训练结果文件或日志文件不存在");
}
// 从minio获取结果json文件转换成对象
InputStream resultInputStream = dataFileService.getMinioInputStream(trainingDataResultFileId);
// 从minio获取日志文件转换成对象
InputStream logInputStream = dataFileService.getMinioInputStream(trainingDataLogFileId);
if (resultInputStream == null || logInputStream == null) {
return SdmResponse.failed("无法从MinIO获取文件流");
}
try {
// 处理结果JSON文件
JSONObject resultJsonObject = parseJsonFromStream(resultInputStream);
// 处理日志文件
String logContent = parseStringFromStream(logInputStream);
// 构造返回对象
JSONObject response = new JSONObject();
response.put("result", resultJsonObject);
response.put("log", logContent);
return SdmResponse.success(response);
} catch (Exception e) {
log.error("处理训练结果文件失败", e);
return SdmResponse.failed("处理训练结果文件失败: " + e.getMessage());
} finally {
try {
if (resultInputStream != null) resultInputStream.close();
if (logInputStream != null) logInputStream.close();
} catch (IOException e) {
log.error("关闭输入流失败", e);
}
}
}
/**
* 从输入流中解析JSON对象
* @param inputStream 输入流
* @return 解析后的JSON对象
* @throws IOException IO异常
*/
private JSONObject parseJsonFromStream(InputStream inputStream) throws IOException {
BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream));
StringBuilder jsonString = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) {
jsonString.append(line);
}
return JSONObject.parseObject(jsonString.toString());
}
/**
* 从输入流中读取字符串内容
* @param inputStream 输入流
* @return 字符串内容
* @throws IOException IO异常
*/
private String parseStringFromStream(InputStream inputStream) throws IOException {
BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream));
StringBuilder content = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) {
content.append(line).append("\n");
}
return content.toString();
}
@Override
public SdmResponse predict(ModelPredictReq modelPredictReq) {
return null;
}
}