百度文心一言 Java 流式输出示例(Spring Boot + SSE)
要点:与其他接口不同,这里需要用 CompletionsResponse.Data 封装一下返回内容,否则前端页面需要兼容非 JSON 格式的数据。文中依次介绍:2、配置 API Key 和 Secret Key;3、WenxinEventSourceListener 事件监听器;4、返回的 JSON 格式。
·
参考:GitHub - mmciel/wenxin-api-java: 百度文心一言Java库,支持问答和对话,支持流式输出和同步输出。提供SpringBoot调用样例。提供拓展能力。
1、依赖
<!-- Maven dependency on com.baidu.aip:java-sdk.
     NOTE(review): the streaming code in this post uses EventSource / EventSources /
     EventSourceListener, which appear to be OkHttp's SSE API; that API is presumably
     provided by a separate okhttp-sse artifact rather than by this SDK. Confirm against
     the project's pom. -->
<dependency>
<groupId>com.baidu.aip</groupId>
<artifactId>java-sdk</artifactId>
<version>4.16.18</version>
</dependency>
2、配置apikey和secretkey
3、主要使用的接口
4、返回的json格式
3、WenxinEventSourceListener 事件监听器
与其他接口不同,这里需要用 CompletionsResponse.Data 封装一下返回内容,否则前端页面需要兼容非 JSON 格式的数据。
@Slf4j
public class WenxinEventSourceListener extends EventSourceListener {
private long tokens;
private SseEmitter sseEmitter;
public WenxinEventSourceListener(SseEmitter sseEmitter) {
this.sseEmitter = sseEmitter;
}
@Override
public void onOpen(EventSource eventSource, Response response) {
log.info("建立sse连接...");
}
@SneakyThrows
@Override
@JsonIgnoreProperties(ignoreUnknown = true)
public void onEvent(EventSource eventSource, String id, String type, String data) {
ChatResponse bean = JSONUtil.parseObj(data).toBean(ChatResponse.class);
log.info("返回数据:{}", data);
if (bean.getIs_end()) {
log.info("返回数据结束了");
sseEmitter.send(SseEmitter.event()
.id("[TOKENS]")
.data("<br/><br/>tokens:" + tokens())
.reconnectTime(3000));
sseEmitter.send(SseEmitter.event()
.id("[DONE]")
.data("[DONE]")
.reconnectTime(3000));
// 传输完成后自动关闭sse
sseEmitter.complete();
return;
}
log.info("OpenAI返回数据:{}", data);
tokens += 1;
if (data.equals("[DONE]")) {
log.info("OpenAI返回数据结束了");
sseEmitter.send(SseEmitter.event()
.id("[TOKENS]")
.data("<br/><br/>tokens:" + tokens())
.reconnectTime(3000));
sseEmitter.send(SseEmitter.event()
.id("[DONE]")
.data("[DONE]")
.reconnectTime(3000));
// 传输完成后自动关闭sse
sseEmitter.complete();
return;
}
CompletionsResponse completionResponse = new CompletionsResponse();
CompletionsResponse.Data dataResult = new CompletionsResponse.Data();
dataResult.setText(bean.getResult());
completionResponse.setData(dataResult);
try {
sseEmitter.send(SseEmitter.event()
.id(bean.getId())
.data(completionResponse.getData())
.reconnectTime(3000));
} catch (Exception e) {
log.error("sse信息推送失败!");
eventSource.cancel();
e.printStackTrace();
}
}
@Override
public void onClosed(EventSource eventSource) {
log.info("关闭sse连接...");
}
@SneakyThrows
@Override
public void onFailure(EventSource eventSource, Throwable t, Response response) {
if(Objects.isNull(response)){
log.error("sse连接异常:{}", t);
eventSource.cancel();
return;
}
ResponseBody body = response.body();
if (Objects.nonNull(body)) {
// 错误处理 {"error_code":110,"error_msg":"Access token invalid or no longer valid"},异常:{}
log.error("sse连接异常data:{},异常:{}", body.string(), t);
} else {
log.error("sse连接异常data:{},异常:{}", response, t);
}
eventSource.cancel();
}
/**
* tokens
* @return
*/
public long tokens() {
return tokens;
}
}
/**
 * Response envelope for chat results pushed to the front end.
 *
 * <p>The streaming listener wraps each text chunk in a {@link Data} so the browser always
 * receives JSON.
 *
 * <p>NOTE(review): {@code @JSONField} is a fastjson annotation. If Spring serializes these objects
 * with Jackson, JSON names come from the getters instead (e.g. {@code text}, which the demo page
 * reads). Confirm which mapper is in use before relying on the capitalized names.
 */
public class CompletionsResponse implements Serializable {
    /**
     * Whether the request was processed successfully (defaults to {@code true}).
     */
    @JSONField(name = "Success")
    private Boolean success = true;
    /**
     * Error code when the request failed.
     */
    @JSONField(name = "Code")
    private String code;
    /**
     * Error description when the request failed.
     */
    @JSONField(name = "Message")
    private String message;
    /**
     * Request id.
     */
    @JSONField(name = "RequestId")
    private String requestId;
    /**
     * Processing result.
     */
    @JSONField(name = "Data")
    private Data data;

    /**
     * Result payload. Accessors are declared explicitly because callers set fields on it
     * (e.g. {@code setText}) and JSON serializers need getters to discover its properties.
     */
    public static class Data implements Serializable {
        private static final long serialVersionUID = -2717404558710025579L;
        /**
         * Id of the underlying model request.
         */
        @JSONField(name = "responseId")
        private String responseId;
        /**
         * Conversation session id.
         */
        @JSONField(name = "SessionId")
        private String sessionId;
        /**
         * Generated text.
         */
        @JSONField(name = "Text")
        private String text;
        /**
         * Application-level token usage.
         */
        @JSONField(name = "Usage")
        private List<Usage> usage;

        public String getResponseId() {
            return responseId;
        }

        public void setResponseId(String responseId) {
            this.responseId = responseId;
        }

        public String getSessionId() {
            return sessionId;
        }

        public void setSessionId(String sessionId) {
            this.sessionId = sessionId;
        }

        public String getText() {
            return text;
        }

        public void setText(String text) {
            this.text = text;
        }

        public List<Usage> getUsage() {
            return usage;
        }

        public void setUsage(List<Usage> usage) {
            this.usage = usage;
        }

        @Override
        public String toString() {
            final StringBuilder sb = new StringBuilder("Data{");
            sb.append("responseId='").append(responseId).append('\'');
            sb.append(", sessionId='").append(sessionId).append('\'');
            sb.append(", text='").append(text).append('\'');
            sb.append(", usage=").append(usage);
            sb.append('}');
            return sb.toString();
        }
    }

    public Boolean isSuccess() {
        return success;
    }

    public void setSuccess(Boolean success) {
        this.success = success;
    }

    public String getCode() {
        return code;
    }

    public void setCode(String code) {
        this.code = code;
    }

    public String getMessage() {
        return message;
    }

    public void setMessage(String message) {
        this.message = message;
    }

    public String getRequestId() {
        return requestId;
    }

    public void setRequestId(String requestId) {
        this.requestId = requestId;
    }

    public Data getData() {
        return data;
    }

    public void setData(Data data) {
        this.data = data;
    }

    @Override
    public String toString() {
        return "CompletionsResponse{" + "success=" + success +
                ", code='" + code + '\'' +
                ", message='" + message + '\'' +
                ", requestId='" + requestId + '\'' +
                ", data=" + data +
                '}';
    }

    /**
     * Token consumption of one model call.
     */
    public static class Usage implements Serializable {
        private static final long serialVersionUID = 1169681422970782335L;
        /**
         * Tokens consumed by the input.
         */
        @JSONField(name = "InputTokens")
        private Integer inputTokens;
        /**
         * Tokens consumed by the output.
         */
        @JSONField(name = "OutputTokens")
        private Integer outputTokens;
        /**
         * Model id.
         */
        @JSONField(name = "ModelId")
        private String modelId;

        public Integer getInputTokens() {
            return inputTokens;
        }

        public void setInputTokens(Integer inputTokens) {
            this.inputTokens = inputTokens;
        }

        public Integer getOutputTokens() {
            return outputTokens;
        }

        public void setOutputTokens(Integer outputTokens) {
            this.outputTokens = outputTokens;
        }

        public String getModelId() {
            return modelId;
        }

        public void setModelId(String modelId) {
            this.modelId = modelId;
        }

        @Override
        public String toString() {
            return "Usage{" + "inputTokens=" + inputTokens +
                    ", outputTokens=" + outputTokens +
                    ", modelId='" + modelId + '\'' +
                    '}';
        }
    }
}
4、WenXinClient:流式调用主要看 streamChat 方法。之前在千帆上找到的流式示例返回的 type 是 JSON,所以我自己手写的 demo 总是报异常。
/**
 * Starts a streaming (SSE) chat completion against the given Wenxin model.
 *
 * <p>The request body is serialized with Jackson and sent through an OkHttp event source; the
 * streamed events are delivered to {@code eventSourceListener}.
 *
 * @param chatBody            request payload; its {@code stream} flag is forced to {@code true}
 * @param eventSourceListener receives the streamed events; must not be {@code null}
 * @param modelE              target model, which supplies the API host; must not be {@code null}
 * @throws WenXinException if an argument is missing or the request cannot be built or sent
 */
public void streamChat(ChatBody chatBody, EventSourceListener eventSourceListener, ModelE modelE) {
    if (Objects.isNull(eventSourceListener)) {
        throw new WenXinException("参数异常:EventSourceListener不能为空");
    }
    if (Objects.isNull(chatBody)) {
        throw new WenXinException("参数异常:ChatBody不能为空");
    }
    if (Objects.isNull(modelE)) {
        throw new WenXinException("参数异常:ModelE不能为空");
    }
    chatBody.setStream(true);
    try {
        EventSource.Factory factory = EventSources.createFactory(this.okHttpClient);
        Request request = new Request.Builder().url(assembleUrl(modelE))
                .post(RequestBody.create(MediaType.parse(ContentType.JSON.getValue()),
                        new ObjectMapper().writeValueAsString(chatBody)))
                .build();
        factory.newEventSource(request, eventSourceListener);
    } catch (Exception e) {
        log.error("请求参数解析异常:", e);
        // Fail loudly: swallowing the error would leave the caller's SSE emitter waiting forever
        // for events that will never arrive. The cause is preserved in the log line above.
        throw new WenXinException("请求参数解析异常:" + e.getMessage());
    }
}
/**
 * Builds the model endpoint URL, with an access token appended as the {@code access_token}
 * query parameter.
 *
 * <p>Side effect: the token returned by {@code WenXinConfig.refreshAccessToken()} is stored in
 * the {@code accessToken} field on every call.
 *
 * @param modelE model whose API host is used as the base URL
 * @return {@code <apiHost>?access_token=<token>}
 */
private String assembleUrl(ModelE modelE) {
    accessToken = WenXinConfig.refreshAccessToken();
    return new StringBuilder()
            .append(modelE.getApiHost())
            .append("?access_token=")
            .append(accessToken)
            .toString();
}
5、定义 SSE 的接口和实现类
/**
 * Per-user server-sent-event (SSE) channels and the chat requests whose answers are streamed
 * over them.
 */
public interface SseService {

    /**
     * Opens an SSE channel for a user.
     *
     * @param uid identifier of the user / browser session
     * @return the emitter that answers will be pushed through
     */
    SseEmitter createSse(String uid);

    /**
     * Submits a chat message from the client; the answer is streamed over the user's SSE channel.
     *
     * @param uid         identifier whose SSE channel receives the answer
     * @param chatRequest the message sent by the client
     * @return an immediate acknowledgement for the HTTP caller
     */
    ChatResponse sseChat(String uid, ChatRequest chatRequest);

    /**
     * Closes the SSE channel of a user.
     *
     * @param uid identifier whose channel should be closed
     */
    void closeSse(String uid);
}
public class WenXinSseServiceImpl implements SseService {
@Value("${chat.accessKeyId}")
private String accessKeyId;
@Value("${chat.accessKeySecret}")
private String accessKeySecret;
@Value("${chat.agentKey}")
private String agentKey;
@Value("${chat.appId}")
private String appId;
@Autowired
WenXinClient wenXinClient;
@Override
public SseEmitter createSse(String uid) {
//默认30秒超时,设置为0L则永不超时
SseEmitter sseEmitter = new SseEmitter(0l);
//完成后回调
sseEmitter.onCompletion(() -> {
log.info("[{}]结束连接...................", uid);
LocalCache.CACHE.remove(uid);
});
//超时回调
sseEmitter.onTimeout(() -> {
log.info("[{}]连接超时...................", uid);
});
//异常回调
sseEmitter.onError(
throwable -> {
try {
log.info("[{}]连接异常,{}", uid, throwable.toString());
sseEmitter.send(SseEmitter.event()
.id(uid)
.name("发生异常!")
.data(Message.builder().content("发生异常请重试!").build())
.reconnectTime(3000));
LocalCache.CACHE.put(uid, sseEmitter);
} catch (IOException e) {
e.printStackTrace();
}
}
);
try {
sseEmitter.send(SseEmitter.event().reconnectTime(5000));
} catch (IOException e) {
e.printStackTrace();
}
LocalCache.CACHE.put(uid, sseEmitter);
log.info("[{}]创建sse连接成功!", uid);
return sseEmitter;
}
@Override
public void closeSse(String uid) {
SseEmitter sse = (SseEmitter) LocalCache.CACHE.get(uid);
if (sse != null) {
sse.complete();
//移除
LocalCache.CACHE.remove(uid);
}
}
@Override
public ChatResponse sseChat(String uid, ChatRequest chatRequest) {
if (StringUtils.isBlank(chatRequest.getMsg())) {
log.error("参数异常,msg为null", uid);
throw new BaseException("参数异常,msg不能为空~");
}
SseEmitter sseEmitter = (SseEmitter) LocalCache.CACHE.get(uid);
if (sseEmitter == null) {
log.info("聊天消息推送失败uid:[{}],没有创建连接,请重试。", uid);
throw new BaseException("聊天消息推送失败uid:[{}],没有创建连接,请重试。~");
}
WenxinEventSourceListener openAIEventSourceListener = new WenxinEventSourceListener(sseEmitter);
List<MessageItem> messages = new ArrayList<>();
messages.add(MessageItem.builder().role(MessageItem.Role.USER).content(chatRequest.getMsg()).build());
wenXinClient.streamChat(messages, openAIEventSourceListener, ModelE.ERNIE_Bot);
LocalCache.CACHE.put("msg" + uid, JSONUtil.toJsonStr(messages), LocalCache.TIMEOUT);
ChatResponse response = new ChatResponse();
response.setQuestionTokens(1);
return response;
}
}
6、主要的controller接口
/**
 * Opens a server-sent-events stream for the caller.
 *
 * @param headers all request headers; the caller's uid is resolved from them via {@code getUid}
 * @return the emitter that subsequent chat answers are pushed to
 */
@CrossOrigin
@GetMapping("/createSse")
public SseEmitter createConnect(@RequestHeader Map<String, String> headers) {
    return sseService.createSse(getUid(headers));
}
/**
 * Accepts a chat question; the answer itself is streamed over the caller's SSE connection.
 *
 * @param chatRequest the question sent by the page
 * @param headers     all request headers; the caller's uid is resolved from them via {@code getUid}
 * @param response    servlet response (not used by this handler)
 * @return the acknowledgement produced by the service
 */
@CrossOrigin
@PostMapping("/chat")
@ResponseBody
public ChatResponse sseChat(@RequestBody ChatRequest chatRequest, @RequestHeader Map<String, String> headers, HttpServletResponse response) {
    final String callerUid = getUid(headers);
    final ChatResponse acknowledgement = sseService.sseChat(callerUid, chatRequest);
    return acknowledgement;
}
/**
 * Closes the caller's server-sent-events stream.
 *
 * @param headers all request headers; the caller's uid is resolved from them via {@code getUid}
 */
@CrossOrigin
@GetMapping("/closeSse")
public void closeConnect(@RequestHeader Map<String, String> headers) {
    sseService.closeSse(getUid(headers));
}
7、主要的页面代码
<!DOCTYPE html>
<html lang="zh-CN">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>智能问答</title>
    <link rel="stylesheet" href="styles.css"> <!-- external stylesheet -->
    <script src="HZRecorder.js"></script>
    <script src="https://cdn.bootcdn.net/ajax/libs/jquery/3.6.0/jquery.min.js"></script>
    <script src="js/markdown.min.js"></script>
    <script src="js/eventsource.min.js"></script>
    <script>
        /**
         * Renders the accumulated markdown answer into the answer box with the given id.
         * The box is created in the /chat success callback, so chunks can arrive before it exists;
         * those are skipped here and shown once the box is created or the next chunk arrives.
         * NOTE(review): marked() output is inserted unsanitized; consider a sanitizer for model output.
         */
        function setText(text, uuid_str) {
            const content = document.getElementById(uuid_str);
            if (!content) {
                return;
            }
            content.innerHTML = marked(text);
        }

        /** Escapes user-typed text before it is inserted through innerHTML. */
        function escapeHtml(str) {
            return String(str)
                .replace(/&/g, "&amp;")
                .replace(/</g, "&lt;")
                .replace(/>/g, "&gt;")
                .replace(/"/g, "&quot;")
                .replace(/'/g, "&#39;");
        }

        /** Returns a random version-4 style UUID string. */
        function uuid() {
            var s = [];
            var hexDigits = "0123456789abcdef";
            for (var i = 0; i < 36; i++) {
                s[i] = hexDigits.substr(Math.floor(Math.random() * 0x10), 1);
            }
            s[14] = "4"; // version nibble: 0100 (v4)
            // Variant bits 10xx. Parse the hex digit first: "a" & 0x3 coerces "a" to NaN, giving 0.
            s[19] = hexDigits.substr((parseInt(s[19], 16) & 0x3) | 0x8, 1);
            s[8] = s[13] = s[18] = s[23] = "-";
            var uuid = s.join("");
            console.log(uuid)
            return uuid;
        }

        window.onload = function () {
            /*let disconnectBtn = document.getElementById("disconnectSSE");*/
            let messageElement = document.getElementById("messageInput");
            let chat = document.getElementById("chat-messages");
            let sse;
            let uid = window.localStorage.getItem("uid");
            if (uid == null || uid == "" || uid == "null") {
                uid = uuid();
            }
            // Persist the uid so the same browser keeps its identity across reloads.
            window.localStorage.setItem("uid", uid);

            /** Sends the current input (if any) and clears the box; returns false when empty. */
            function submitInput() {
                const userInput = messageElement.value.trim();
                if (!userInput) {
                    return false;
                }
                sseOneTurn(userInput);
                messageElement.value = ''; // clear the input box
                return true;
            }

            // Send-button click
            document.getElementById('sendTextButton').addEventListener('click', function () {
                try {
                    if (!submitInput()) {
                        alert('请输入文字消息!');
                    }
                } catch (error) {
                    alert('发送消息时发生错误: ' + error.message);
                }
            });

            // Enter key
            messageElement.onkeydown = function (event) {
                if (event.key === "Enter" || event.keyCode === 13) {
                    submitInput();
                }
            };

            /** Runs one question/answer turn: opens an SSE stream, then posts the question. */
            function sseOneTurn(InputText) {
                // Per-turn state, so a second question cannot mix its text into an earlier answer.
                const answerId = uuid();
                let text = "";
                let questionSent = false;

                const eventSource = new EventSourcePolyfill("/createSse", {
                    headers: {
                        uid: uid,
                    },
                });

                eventSource.onopen = () => {
                    console.log("开始输出后端返回值");
                    sse = eventSource;
                    // Post the question only once the stream is open: /chat fails when the server has
                    // no emitter registered for this uid yet. The flag stops reconnects from re-asking.
                    if (!questionSent) {
                        questionSent = true;
                        postQuestion();
                    }
                };

                eventSource.onmessage = (event) => {
                    if (event.lastEventId == "[TOKENS]") {
                        text = text + event.data;
                        setText(text, answerId);
                        text = "";
                        return;
                    }
                    if (event.data == "[DONE]") {
                        text = "";
                        eventSource.close();
                        return;
                    }
                    const json_data = JSON.parse(event.data);
                    console.log(json_data);
                    if (json_data.text == null || json_data.text == "null") {
                        return;
                    }
                    text = text + json_data.text;
                    setText(text, answerId);
                };

                eventSource.onerror = (event) => {
                    console.log("onerror", event);
                    alert("服务异常请重试并联系开发者!");
                    // readyState lives on the EventSource, not on the error event; 2 === CLOSED.
                    if (eventSource.readyState === 2) {
                        console.log("connection is closed");
                    } else {
                        console.log("Error occured", event);
                    }
                    eventSource.close();
                };

                eventSource.addEventListener("customEventName", (event) => {
                    console.log("Message id is " + event.lastEventId);
                });

                /** Posts the question; on success adds the question box and an empty answer box. */
                function postQuestion() {
                    $.ajax({
                        type: "post",
                        url: "/chat",
                        data: JSON.stringify({
                            msg: InputText,
                        }),
                        contentType: "application/json;charset=UTF-8",
                        dataType: "json",
                        headers: {
                            uid: uid,
                        },
                        success: function (result) {
                            // Question box; the user's text is escaped because it goes through innerHTML.
                            chat.innerHTML +=
                                '<tr><td style="height: 30px;">' +
                                escapeHtml(InputText) +
                                "<br/><br/> tokens:" +
                                result.question_tokens +
                                "</td></tr>";
                            // Answer box that setText() streams into.
                            chat.innerHTML +=
                                '<tr><td><article id="' +
                                answerId +
                                '" class="markdown-body"></article></td></tr>';
                            // Show any chunks that arrived before the answer box existed.
                            if (text) {
                                setText(text, answerId);
                            }
                        },
                        error: function () {
                            console.info("发送问题失败!");
                            // The server will not stream an answer, so release the SSE connection.
                            eventSource.close();
                        },
                    });
                }
            }

            /*disconnectBtn.onclick = function () {
            if (sse) {
            sse.close();
            }
            };*/
        };
    </script>
</head>
<body>
<div class="chat-container">
    <div class="chat-header">
        <h1>智能问答</h1>
    </div>
    <div class="chat-messages" id="chat-messages">
        <!-- chat messages are appended here -->
    </div>
    <form class="message-form" onsubmit="return false;">
        <input type="text" id="messageInput" placeholder="输入消息..." autocomplete="off">
        <button type="button" id="sendTextButton">发送文字</button>
        <button type="button" id="recordAndUploadButton">按住录音</button>
        <progress id="uploadProgress" value="0" max="100" style="display:none;"></progress>
    </form>
</div>
</body>
</html>
最后的呈现效果如下:
更多推荐
已为社区贡献6条内容
所有评论(0)