forked from toolchaintechnologycenter/spdm-backend
训练模型
This commit is contained in:
@@ -4,8 +4,10 @@ import com.sdm.common.common.SdmResponse;
|
||||
import com.sdm.data.model.req.AddModelReq;
|
||||
import com.sdm.data.model.req.AlgorithmParamReq;
|
||||
import com.sdm.data.model.req.HandleLoadDataReq;
|
||||
import com.sdm.data.model.req.ModelPredictReq;
|
||||
import com.sdm.data.service.IModelService;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.Parameter;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.*;
|
||||
|
||||
@@ -53,7 +55,26 @@ public class ModelTraningController {
|
||||
*/
|
||||
@PostMapping("/submitTraining")
|
||||
@Operation(summary = "算法参数设置并提交训练", description = "算法参数设置并提交训练")
|
||||
public void submitTraining(@RequestBody AlgorithmParamReq algorithmParamReq) {
|
||||
modelService.submitTraining(algorithmParamReq);
|
||||
public SdmResponse submitTraining(@RequestBody AlgorithmParamReq algorithmParamReq) {
|
||||
return modelService.submitTraining(algorithmParamReq);
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取训练曲线和训练日志
|
||||
* @param modelId
|
||||
*/
|
||||
@GetMapping("/getTrainingResult")
|
||||
@Operation(summary = "获取训练曲线和训练日志", description = "获取训练曲线和训练日志")
|
||||
public SdmResponse getTrainingResult(@RequestParam String modelId){
|
||||
return modelService.getTrainingResult(modelId);
|
||||
}
|
||||
|
||||
/**
|
||||
* 数据预测
|
||||
*/
|
||||
@GetMapping("/predict")
|
||||
@Operation(summary = "数据预测", description = "数据预测")
|
||||
public SdmResponse predict(@Parameter(description = "数据预测请求对象") @RequestBody ModelPredictReq modelPredictReq) throws IOException {
|
||||
return modelService.predict(modelPredictReq);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,7 +70,7 @@ public class TrainingModelAlgorithmParam implements Serializable {
|
||||
|
||||
@ApiModelProperty(value = "学习率(如:0.001)")
|
||||
@TableField("studyPercent")
|
||||
private BigDecimal studyPercent;
|
||||
private Double studyPercent;
|
||||
|
||||
@ApiModelProperty(value = "步数(如:3)")
|
||||
@TableField("stepCounts")
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.sdm.data.model.req;
|
||||
|
||||
import io.swagger.v3.oas.annotations.media.Schema;
|
||||
import lombok.Data;
|
||||
|
||||
@Data
|
||||
public class ModelPredictReq {
|
||||
@Schema(description = "选择的预测模型ID")
|
||||
String modelId;
|
||||
|
||||
@Schema(description = "json数据")
|
||||
String jsonData;
|
||||
}
|
||||
@@ -4,6 +4,7 @@ import com.sdm.common.common.SdmResponse;
|
||||
import com.sdm.data.model.req.AddModelReq;
|
||||
import com.sdm.data.model.req.AlgorithmParamReq;
|
||||
import com.sdm.data.model.req.HandleLoadDataReq;
|
||||
import com.sdm.data.model.req.ModelPredictReq;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
@@ -30,5 +31,19 @@ public interface IModelService {
|
||||
* 模型参数设置并提交训练
|
||||
* @param algorithmParamReq 模型参数
|
||||
*/
|
||||
void submitTraining(AlgorithmParamReq algorithmParamReq);
|
||||
SdmResponse submitTraining(AlgorithmParamReq algorithmParamReq);
|
||||
|
||||
/**
|
||||
* 模型训练结果查询
|
||||
* @param modelId 模型ID
|
||||
* @return 模型训练结果
|
||||
*/
|
||||
SdmResponse getTrainingResult(String modelId);
|
||||
|
||||
/**
|
||||
* 数据预测
|
||||
* @param modelPredictReq 模型预测请求对象
|
||||
* @return 模型预测结果
|
||||
*/
|
||||
SdmResponse predict(ModelPredictReq modelPredictReq);
|
||||
}
|
||||
|
||||
@@ -9,6 +9,7 @@ import com.sdm.data.model.entity.TrainingModelAlgorithmParam;
|
||||
import com.sdm.data.model.req.AddModelReq;
|
||||
import com.sdm.data.model.req.AlgorithmParamReq;
|
||||
import com.sdm.data.model.req.HandleLoadDataReq;
|
||||
import com.sdm.data.model.req.ModelPredictReq;
|
||||
import com.sdm.data.service.*;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.beans.BeanUtils;
|
||||
@@ -442,7 +443,7 @@ public class ModelServiceImpl implements IModelService {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void submitTraining(AlgorithmParamReq algorithmParamReq) {
|
||||
public SdmResponse submitTraining(AlgorithmParamReq algorithmParamReq) {
|
||||
// 保存或更新算法参数
|
||||
TrainingModelAlgorithmParam queryTrainingModelAlgorithmParam = trainingModelAlgorithmParamService.lambdaQuery()
|
||||
.eq(TrainingModelAlgorithmParam::getModelId, algorithmParamReq.getModelId()).one();
|
||||
@@ -469,6 +470,8 @@ public class ModelServiceImpl implements IModelService {
|
||||
|
||||
// 异步调用python脚本进行训练,并上传结果文件到minio
|
||||
trainModelAsync(paramJsonPath, algorithmParamReq.getModelId());
|
||||
|
||||
return SdmResponse.success("提交训练成功");
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -578,4 +581,98 @@ public class ModelServiceImpl implements IModelService {
|
||||
throw new RuntimeException("创建训练param.json文件失败: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public SdmResponse getTrainingResult(String modelId) {
|
||||
// 根据modelId 获取训练模型,获取训练模型结果文件id,日志文件id,
|
||||
TrainingModel trainingModel = trainingModelService.getById(modelId);
|
||||
if (trainingModel == null) {
|
||||
return SdmResponse.failed("模型不存在");
|
||||
}
|
||||
if(trainingModel.getTrainingStatus().equals("失败")){
|
||||
return SdmResponse.failed("模型训练失败,请查看日志详情");
|
||||
}
|
||||
|
||||
if(trainingModel.getTrainingStatus().equals("待开始") || trainingModel.getTrainingStatus().equals("训练中")){
|
||||
return SdmResponse.failed("模型正在训练中,请稍后再查询");
|
||||
}
|
||||
|
||||
|
||||
Integer trainingDataResultFileId = trainingModel.getTrainingDataResultFileId();
|
||||
Integer trainingDataLogFileId = trainingModel.getTrainingDataLogFileId();
|
||||
|
||||
if (trainingDataResultFileId == null || trainingDataLogFileId == null) {
|
||||
return SdmResponse.failed("训练结果文件或日志文件不存在");
|
||||
}
|
||||
|
||||
// 从minio获取结果json文件转换成对象
|
||||
InputStream resultInputStream = dataFileService.getMinioInputStream(trainingDataResultFileId);
|
||||
// 从minio获取日志文件转换成对象
|
||||
InputStream logInputStream = dataFileService.getMinioInputStream(trainingDataLogFileId);
|
||||
|
||||
if (resultInputStream == null || logInputStream == null) {
|
||||
return SdmResponse.failed("无法从MinIO获取文件流");
|
||||
}
|
||||
|
||||
try {
|
||||
// 处理结果JSON文件
|
||||
JSONObject resultJsonObject = parseJsonFromStream(resultInputStream);
|
||||
// 处理日志文件
|
||||
String logContent = parseStringFromStream(logInputStream);
|
||||
|
||||
// 构造返回对象
|
||||
JSONObject response = new JSONObject();
|
||||
response.put("result", resultJsonObject);
|
||||
response.put("log", logContent);
|
||||
|
||||
return SdmResponse.success(response);
|
||||
} catch (Exception e) {
|
||||
log.error("处理训练结果文件失败", e);
|
||||
return SdmResponse.failed("处理训练结果文件失败: " + e.getMessage());
|
||||
} finally {
|
||||
try {
|
||||
if (resultInputStream != null) resultInputStream.close();
|
||||
if (logInputStream != null) logInputStream.close();
|
||||
} catch (IOException e) {
|
||||
log.error("关闭输入流失败", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 从输入流中解析JSON对象
|
||||
* @param inputStream 输入流
|
||||
* @return 解析后的JSON对象
|
||||
* @throws IOException IO异常
|
||||
*/
|
||||
private JSONObject parseJsonFromStream(InputStream inputStream) throws IOException {
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream));
|
||||
StringBuilder jsonString = new StringBuilder();
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
jsonString.append(line);
|
||||
}
|
||||
return JSONObject.parseObject(jsonString.toString());
|
||||
}
|
||||
|
||||
/**
|
||||
* 从输入流中读取字符串内容
|
||||
* @param inputStream 输入流
|
||||
* @return 字符串内容
|
||||
* @throws IOException IO异常
|
||||
*/
|
||||
private String parseStringFromStream(InputStream inputStream) throws IOException {
|
||||
BufferedReader reader = new BufferedReader(new InputStreamReader(inputStream));
|
||||
StringBuilder content = new StringBuilder();
|
||||
String line;
|
||||
while ((line = reader.readLine()) != null) {
|
||||
content.append(line).append("\n");
|
||||
}
|
||||
return content.toString();
|
||||
}
|
||||
|
||||
@Override
|
||||
public SdmResponse predict(ModelPredictReq modelPredictReq) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user