// AiChat.tsx
  1. import React, { useState, useRef, useEffect, useCallback } from "react";
  2. import {
  3. Input,
  4. Button,
  5. Dropdown,
  6. Tooltip,
  7. Space,
  8. type MenuProps,
  9. Card,
  10. Affix,
  11. } from "antd";
  12. import {
  13. SendOutlined,
  14. DownOutlined,
  15. PaperClipOutlined,
  16. } from "@ant-design/icons";
  17. import type { IAiModel, IAiModelListResponse } from "../../api/ai"; // eslint-disable-line
  18. import { get } from "../../request";
  19. import MsgUser from "./MsgUser";
  20. import MsgAssistant from "./MsgAssistant";
  21. import MsgTyping from "./MsgTyping";
  22. import MsgLoading from "./MsgLoading";
  23. import MsgSystem from "./MsgSystem";
  24. import MsgError from "./MsgError";
  25. import PromptButtonGroup from "./PromptButtonGroup";
  26. import { useAppSelector } from "../../hooks";
  27. import { currentUser } from "../../reducers/current-user";
  28. import type { IFtsResponse } from "../fts/FullTextSearchResult";
  29. import { siteInfo } from "../../reducers/layout";
  // Multi-line text input exported as a member of antd's Input.
  const TextArea = Input.TextArea;
  /** Role tag carried by every message exchanged with the chat endpoint. */
  type AIRole =
    | "system"
    | "user"
    | "assistant"
    | "function";
  // Type definitions

  /**
   * One revision of a chat message.
   *
   * Within this file `id` is filled from `Date.now()`, `timestamp` from
   * `new Date().toLocaleTimeString()`, and `model` holds the model uid for
   * assistant output or an empty string for user-authored text.
   */
  export interface MessageVersion {
    id: number;
    content: string;
    model: string;
    role: AIRole;
    timestamp: string;
  }
  /**
   * A transcript entry. `versions` keeps every revision of the entry; the
   * last element is the one treated as current (see `endOfMsg`).
   */
  export interface Message {
    id: number;
    type: "user" | "ai" | "error";
    versions: MessageVersion[];
  }
  /**
   * Wire format of one entry in the `messages` array sent to the chat
   * endpoint. `name` is set in this file only on `function`-role entries.
   */
  interface OpenAIMessage {
    role: AIRole;
    content: string;
    name?: string;
  }
  /** Handle returned by `streamTypeWriter` for driving the typing indicator. */
  interface StreamTypeController {
    /** Appends one streamed fragment to the text being displayed. */
    addToken: (token: string) => void;
    /** Ends the typing animation and reports the accumulated text. */
    complete: () => void;
  }
  /**
   * Minimal shape of one streamed completion chunk: only the text delta of
   * each choice is described.
   */
  interface OpenAIStreamResponse {
    // eslint-disable-line
    choices?: {
      delta?: {
        content?: string;
      };
    }[];
  }
  /**
   * Returns the most recent version of a message, i.e. the last element of
   * `msg.versions`. Yields `undefined` at runtime when `versions` is empty.
   */
  const endOfMsg = (msg: Message) => {
    const { versions } = msg;
    return versions[versions.length - 1];
  };
  /** Props accepted by the chat component. */
  interface IWidget {
    /** When set (or changed), the transcript is cleared and the input box is prefilled with this text. */
    initMessage?: string;
    /** System prompt sent as the first message of every request and shown at the top of the transcript. */
    systemPrompt?: string;
    /** Invoked after the send button is clicked. */
    onChat?: () => void;
  }
  70. const AIChatComponent = ({
  71. initMessage,
  72. systemPrompt = `你是一个巴利语专家和佛教术语解释助手。当用户询问佛教术语时,你可以调用 searchTerm 函数来查询详细信息。
  73. 使用方法:
  74. - 当用户输入类似"术语:dhamma"、"查询:karma"、"什么是 buddha"等时,调用函数查询
  75. - 查询结果会以结构化的方式展示,包含定义、词源、分类、详细说明等信息`,
  76. onChat,
  77. }: IWidget) => {
  // Transcript and the draft currently in the input box.
  const [messages, setMessages] = useState<Message[]>([]);
  const [inputValue, setInputValue] = useState("");
  // True while a request is in flight.
  const [isLoading, setIsLoading] = useState(false);
  // uid of the model picked in the dropdown, and uid of the model used by the
  // most recent request (shown next to the typing/loading indicators).
  const [selectedModel, setSelectedModel] = useState("");
  const [fetchModel, setFetchModel] = useState("");
  // id of the message being regenerated; that entry is hidden while it is set.
  const [refreshingMessageId, setRefreshingMessageId] = useState<
    number | null
  >(null);
  // Sentinel element after the transcript, used as the scroll target.
  const messagesEndRef = useRef<HTMLDivElement>(null);
  // Streaming state: whether tokens are arriving and the text received so far.
  const [isTyping, setIsTyping] = useState(false);
  const [currentTypingMessage, setCurrentTypingMessage] = useState("");
  // Chat models offered by the site settings, and the last error to display.
  const [models, setModels] = useState<IAiModel[]>();
  const [error, setError] = useState<string>();
  const user = useAppSelector(currentUser);
  const site = useAppSelector(siteInfo);
  /** Smooth-scrolls the sentinel element below the transcript into view. */
  const scrollToBottom = useCallback(() => {
    const anchor = messagesEndRef.current;
    if (anchor) {
      anchor.scrollIntoView({ behavior: "smooth", block: "center" });
    }
  }, []);
  // Mirror the site's chat-model list into local state and preselect the
  // first model whenever that list changes.
  useEffect(() => {
    const chatModels = site?.settings?.models?.chat;
    setModels(chatModels ?? []);
    if (chatModels && chatModels.length > 0) {
      setSelectedModel(chatModels[0].uid);
    }
  }, [site?.settings?.models?.chat]);
  // Keep the newest content visible as messages arrive or tokens stream in.
  useEffect(() => {
    scrollToBottom();
  }, [messages, currentTypingMessage, scrollToBottom]);

  // A new initial message starts a fresh conversation with the input prefilled.
  useEffect(() => {
    if (!initMessage) return;
    setMessages([]);
    setInputValue(initMessage);
  }, [initMessage]);
  /**
   * Starts a typing session: turns the typing indicator on, clears the
   * displayed text and returns a controller for feeding streamed tokens.
   *
   * The running text is accumulated in a closure variable instead of being
   * read back through `setState` updater functions. Updaters must be pure;
   * invoking the callbacks from inside them makes `onToken`/`onComplete`
   * fire once per updater invocation, which may be more than once.
   *
   * @param onToken    Called after every token with the text accumulated so far.
   * @param onComplete Called once per `complete()` with the final text.
   * @returns Controller with `addToken` and `complete`.
   */
  const streamTypeWriter = useCallback(
    (
      onToken?: (content: string) => void,
      onComplete?: (finalContent: string) => void
    ): StreamTypeController => {
      let accumulated = "";
      setIsTyping(true);
      setCurrentTypingMessage("");
      return {
        addToken: (token: string) => {
          accumulated += token;
          setCurrentTypingMessage(accumulated);
          onToken?.(accumulated);
        },
        complete: () => {
          setIsTyping(false);
          setCurrentTypingMessage("");
          onComplete?.(accumulated);
        },
      };
    },
    []
  );
  /**
   * Looks a term up through the `/v2/search-pali-wbw` endpoint.
   *
   * The term is percent-encoded before being placed in the query string so
   * that characters such as `&`, `?`, `#` or spaces cannot alter the request.
   *
   * @param term Term to search for.
   * @returns The `rows` of the response when `res.ok` is truthy, otherwise an
   *          `{ error }` object.
   */
  const searchTerm = async (term: string) => {
    // Example: query the backend encyclopedia API.
    const url = `/v2/search-pali-wbw?view=pali&key=${encodeURIComponent(
      term
    )}&limit=20&offset=0`;
    console.info("search api request", url);
    const res = await get<IFtsResponse>(url);
    if (res.ok) {
      console.info("search 搜索结果", res.data.rows);
      return res.data.rows;
    }
    return { error: "没有找到相关术语" };
  };
  /**
   * Sends a streaming chat request through the configured proxy, feeds the
   * streamed text to the typing indicator and, when the stream finishes,
   * stores the answer in the transcript.
   *
   * If the model answers with a `searchTerm` function call, the function is
   * executed, its JSON result is appended to `messages` as a `function`-role
   * entry and the request is issued again (recursively).
   *
   * @param messages     Conversation to send, including the system prompt.
   * @param modelId      uid of the model to use; resolved to a model name via `models`.
   * @param isRegenerate When true (and `messageIndex` is given) the answer is
   *                     appended as a new version of an existing message
   *                     instead of a new transcript entry.
   * @param messageIndex Index in the transcript of the message being regenerated.
   * @param _depth       Number of function-call round trips already made;
   *                     internal, callers leave it at the default.
   * @returns `{ success: true, content: "" }` when the stream completed (the
   *          text is delivered to the transcript, not through the return value),
   *          or `{ success: false, error }` on failure.
   */
  const callOpenAI = useCallback(
    async (
      messages: OpenAIMessage[],
      modelId: string,
      isRegenerate: boolean = false,
      messageIndex?: number,
      _depth: number = 0
    ): Promise<{ success: boolean; content?: string; error?: string }> => {
      type CallResult = { success: boolean; content?: string; error?: string };
      // Fields this parser reads from one streamed chunk.
      type StreamChunk = {
        choices?: Array<{
          delta?: {
            content?: string;
            function_call?: { name?: string; arguments?: string };
          };
          finish_reason?: string | null;
        }>;
      };
      // Upper bound on chained function-call round trips, so a model that keeps
      // requesting functions cannot recurse forever. The value is a safety
      // margin, not a protocol requirement.
      const MAX_FUNCTION_CALL_DEPTH = 3;

      setError(undefined);
      if (typeof import.meta.env.VITE_REACT_APP_OPENAI_PROXY === "undefined") {
        console.error("no REACT_APP_OPENAI_PROXY");
        return { success: false, error: "API配置错误" };
      }
      const functions = [
        {
          name: "searchTerm",
          description: "查询佛教术语,返回百科词条",
          parameters: {
            type: "object",
            properties: {
              term: {
                type: "string",
                description: "要查询的巴利语或佛学术语",
              },
            },
            required: ["term"],
          },
        },
      ];
      try {
        setFetchModel(modelId);
        const payload = {
          model: models?.find((value) => value.uid === modelId)?.model,
          messages,
          stream: true,
          temperature: 0.5,
          max_tokens: 3000,
          functions,
          function_call: "auto", // let the model decide whether to call a function
        };
        const url = import.meta.env.VITE_REACT_APP_OPENAI_PROXY;
        const data = {
          model_id: modelId,
          payload,
        };
        console.info("api request", url, data);
        setIsLoading(true);
        // NOTE(review): VITE_-prefixed variables are bundled into client code, so
        // this key is presumably visible to every visitor — confirm it is not a secret.
        const response = await fetch(url, {
          method: "POST",
          headers: {
            "Content-Type": "application/json",
            Authorization: `Bearer ${import.meta.env.VITE_REACT_APP_OPENAI_KEY}`,
          },
          body: JSON.stringify(data),
        });
        if (!response.ok) {
          throw new Error(`HTTP error! status: ${response.status}`);
        }
        const reader = response.body?.getReader();
        if (!reader) {
          throw new Error("无法获取响应流");
        }
        const decoder = new TextDecoder();
        let buffer = "";
        // Buffers for a function_call that arrives split across several chunks.
        let functionCallName: string | null = null;
        let functionCallArgsBuffer = "";
        let functionCallInProgress = false;

        const typeController = streamTypeWriter(
          undefined,
          (finalContent: string) => {
            console.log("newData in callOpenAI", finalContent);
            const newData: MessageVersion = {
              id: Date.now(),
              content: finalContent,
              model: modelId,
              role: "assistant",
              timestamp: new Date().toLocaleTimeString(),
            };
            if (isRegenerate && messageIndex !== undefined) {
              // Copy the target message instead of pushing into it: the updater
              // must not mutate existing state, otherwise a repeated updater
              // invocation would add the version twice.
              setMessages((prev) =>
                prev.map((msg, idx) =>
                  idx === messageIndex
                    ? { ...msg, versions: [...(msg.versions ?? []), newData] }
                    : msg
                )
              );
            } else {
              const aiMessage: Message = {
                id: Date.now(),
                type: "ai",
                versions: [newData],
              };
              setMessages((prev) => [...prev, aiMessage]);
            }
            setRefreshingMessageId(null);
          }
        );

        // Parses the accumulated arguments; anything that is not a JSON object
        // is treated as "no arguments".
        const safeParseArgs = (s: string): Record<string, unknown> => {
          try {
            const value: unknown = JSON.parse(s);
            return value !== null && typeof value === "object"
              ? (value as Record<string, unknown>)
              : {};
          } catch {
            return {};
          }
        };

        // Turns the typing indicator off without adding a transcript entry.
        const stopTypingWithoutCommit = () => {
          setIsTyping(false);
          setCurrentTypingMessage("");
        };

        // Runs the requested function and asks the model again with its result.
        const handleFunctionCallAndReask = async (): Promise<CallResult> => {
          if (_depth >= MAX_FUNCTION_CALL_DEPTH) {
            stopTypingWithoutCommit();
            return { success: false, error: "函数调用嵌套过深" };
          }
          const argsObj = safeParseArgs(functionCallArgsBuffer);
          console.log("完整 arguments:", functionCallArgsBuffer, argsObj);
          if (functionCallName === "searchTerm") {
            const rawTerm = argsObj.term || argsObj.query || "";
            const result = await searchTerm(
              typeof rawTerm === "string" ? rawTerm : String(rawTerm)
            );
            const followUp: OpenAIMessage[] = [
              ...messages,
              {
                role: "function",
                name: "searchTerm",
                content: JSON.stringify(result),
              },
            ];
            console.log("search 再次请求", followUp);
            return await callOpenAI(
              followUp,
              modelId,
              isRegenerate,
              messageIndex,
              _depth + 1
            );
          }
          // No follow-up request will reset the indicator, so do it here.
          stopTypingWithoutCommit();
          return { success: false, error: "未知函数: " + functionCallName };
        };

        // Shared handling for both end-of-stream signals (reader done, "[DONE]").
        const finishStream = async (): Promise<CallResult> => {
          if (functionCallInProgress && functionCallName) {
            const res = await handleFunctionCallAndReask();
            setIsLoading(false);
            return res;
          }
          typeController.complete();
          setIsLoading(false);
          return { success: true, content: "" };
        };

        try {
          while (true) {
            const { done, value } = await reader.read();
            if (done) {
              return await finishStream();
            }
            buffer += decoder.decode(value, { stream: true });
            const lines = buffer.split("\n");
            // The last piece may be an incomplete line; keep it for the next read.
            buffer = lines.pop() || "";
            for (const line of lines) {
              if (!line.trim() || !line.startsWith("data: ")) continue;
              const data = line.slice(6);
              if (data === "[DONE]") {
                return await finishStream();
              }
              let parsed: StreamChunk;
              try {
                parsed = JSON.parse(data) as StreamChunk;
              } catch {
                continue;
              }
              const delta = parsed.choices?.[0]?.delta;
              const finish_reason = parsed.choices?.[0]?.finish_reason;
              // Accumulate the pieces of a function_call.
              if (delta?.function_call) {
                if (delta.function_call.name) {
                  functionCallName = delta.function_call.name;
                }
                if (typeof delta.function_call.arguments === "string") {
                  functionCallInProgress = true;
                  functionCallArgsBuffer += delta.function_call.arguments;
                }
              }
              // Ordinary text output.
              if (delta?.content && !functionCallInProgress) {
                typeController.addToken(delta.content);
              }
              // The function_call is complete.
              if (finish_reason === "function_call") {
                const res = await handleFunctionCallAndReask();
                setIsLoading(false);
                return res;
              }
            }
          }
        } catch (error) {
          console.error("读取流数据失败:", error);
          typeController.complete();
          return { success: false, error: "读取响应流失败" };
        }
      } catch (error) {
        console.error("API调用失败:", error);
        return { success: false, error: "API调用失败,请重试" };
      }
    },
    // `currentTypingMessage` is not read here; listing it rebuilt this callback
    // (and everything depending on it) on every streamed token.
    [models, streamTypeWriter]
  );
  /**
   * Appends the user's text to the transcript and requests an answer from the
   * currently selected model.
   *
   * Does nothing for blank text or while another request is in flight. The
   * send button is already disabled in that state; this guard gives the
   * Enter-key path the same protection so two streams cannot run at once.
   *
   * @param messageText Text to send; defaults to the current input value.
   */
  const sendMessage = useCallback(
    async (messageText: string = inputValue): Promise<void> => {
      if (!messageText.trim() || isLoading) return;
      const newData: MessageVersion = {
        id: Date.now(),
        content: messageText,
        model: "",
        role: "user",
        timestamp: new Date().toLocaleTimeString(),
      };
      const userMessage: Message = {
        id: Date.now(),
        type: "user",
        versions: [newData],
      };
      setMessages((prev) => [...prev, userMessage]);
      setInputValue("");
      setIsLoading(true);
      // Scroll to the new user message
      scrollToBottom();
      try {
        // History = system prompt + latest version of every earlier message
        // + the text being sent now.
        const conversationHistory: OpenAIMessage[] = [
          { role: "system", content: systemPrompt },
          ...messages.map(
            (msg): OpenAIMessage => ({
              role: msg.type === "user" ? "user" : "assistant",
              content: endOfMsg(msg).content,
            })
          ),
          { role: "user", content: messageText },
        ];
        const result = await callOpenAI(conversationHistory, selectedModel);
        if (!result.success) {
          setError("请求失败,请重试");
        }
      } catch (error) {
        console.error("发送消息失败:", error);
        setError("请求失败,请重试");
      } finally {
        setIsLoading(false);
      }
    },
    [
      inputValue,
      isLoading,
      scrollToBottom,
      systemPrompt,
      messages,
      callOpenAI,
      selectedModel,
    ]
  );
  /**
   * Regenerates the answer at `messageIndex` with the given model.
   *
   * Only acts when the entry directly before `messageIndex` is a user
   * message. The history sent consists of the system prompt, the latest
   * version of every message before that user message, and the user message
   * itself. The new answer is stored by `callOpenAI` (regenerate mode).
   *
   * @param messageIndex Index in the transcript of the answer to regenerate.
   * @param modelId      uid of the model to use.
   */
  const refreshAIResponse = useCallback(
    async (messageIndex: number, modelId: string): Promise<void> => {
      console.debug("refresh", messageIndex);
      const promptMessage = messages[messageIndex - 1];
      if (!promptMessage || promptMessage.type !== "user") {
        return;
      }

      // Hide the answer being replaced while the new one streams in.
      setRefreshingMessageId(messages[messageIndex].id);

      const toHistoryEntry = (msg: Message): OpenAIMessage => ({
        role: msg.type === "user" ? "user" : "assistant",
        content: endOfMsg(msg).content,
      });
      const earlierTurns = messages
        .slice(0, messageIndex - 1)
        .map((msg) => toHistoryEntry(msg));
      const conversationHistory: OpenAIMessage[] = [
        { role: "system", content: systemPrompt },
        ...earlierTurns,
        { role: "user", content: endOfMsg(promptMessage).content },
      ];

      try {
        const result = await callOpenAI(
          conversationHistory,
          modelId,
          true,
          messageIndex
        );
        setIsLoading(false);
        if (!result.success) {
          setError("重新生成失败,请重试");
          setRefreshingMessageId(null);
        }
      } catch (error) {
        console.error("刷新回答失败:", error);
        setIsLoading(false);
        setError("请求失败,请重试");
        setRefreshingMessageId(null);
      }
    },
    [messages, systemPrompt, callOpenAI]
  );
  /**
   * Records an edited user text as a new version of the message whose `id`
   * matches.
   *
   * The message is copied rather than mutated: pushing into the existing
   * `versions` array from inside the state updater would add the version
   * again whenever the updater is re-invoked, and would keep the message's
   * object identity unchanged.
   *
   * @param id   `Message.id` of the message to edit (not its array index).
   * @param text New content for the message.
   */
  const confirmEdit = useCallback((id: number, text: string): void => {
    const edited: MessageVersion = {
      id: Date.now(),
      content: text,
      model: "",
      role: "user",
      timestamp: new Date().toLocaleTimeString(),
    };
    setMessages((prev) => {
      const targetIndex = prev.findIndex((m) => m.id === id);
      if (targetIndex === -1) {
        // Unknown id: leave the transcript untouched.
        return prev;
      }
      return prev.map((message, idx) =>
        idx === targetIndex
          ? { ...message, versions: [...(message.versions ?? []), edited] }
          : message
      );
    });
  }, []);
  /** Sends the message on Enter; Shift+Enter keeps its default (newline). */
  const handleKeyPress = useCallback(
    (e: React.KeyboardEvent<HTMLTextAreaElement>): void => {
      const isPlainEnter = e.key === "Enter" && !e.shiftKey;
      if (!isPlainEnter) return;
      e.preventDefault();
      sendMessage();
    },
    [sendMessage]
  );
  // Dropdown menu listing the available models; clicking an entry selects it.
  const modelMenu: MenuProps = {
    items: models?.map(({ uid, name }) => ({ key: uid, label: name })),
    selectedKeys: [selectedModel],
    onClick: (info) => {
      console.log("setSelectedModel", info.key);
      setSelectedModel(info.key);
    },
  };
  // Render: transcript (system prompt, messages, error, typing/loading
  // indicators) followed by the input card pinned near the bottom. Nothing is
  // shown when there is no current user.
  return user ? (
    <div
      style={{
        display: "flex",
        flexDirection: "column",
        width: "100%",
      }}
    >
      <div style={{ flex: 1, overflowY: "auto", padding: "16px" }}>
        <Space orientation="vertical" size="middle" style={{ width: "100%" }}>
          <MsgSystem value={systemPrompt} />
          {messages.map((msg, index) => {
            // The answer being regenerated is hidden until its replacement arrives.
            if (msg.id === refreshingMessageId) {
              return null;
            }
            if (msg.type === "user") {
              return (
                <MsgUser
                  key={msg.id}
                  msg={msg}
                  // confirmEdit looks the message up by id, so pass the id
                  // rather than the array index.
                  onChange={(value: string) => confirmEdit(msg.id, value)}
                />
              );
            }
            if (msg.type === "ai") {
              return (
                <MsgAssistant
                  key={msg.id}
                  msg={msg}
                  models={models}
                  onRefresh={(modelId: string) => {
                    refreshAIResponse(index, modelId);
                  }}
                />
              );
            }
            return <React.Fragment key={msg.id}>unknown</React.Fragment>;
          })}
          {error ? (
            <MsgError
              message={error}
              // NOTE(review): refreshAIResponse only acts when the entry before
              // the given index is a user message; after a failed send the last
              // entry is presumably the user message itself — confirm this
              // retry does what is intended.
              onRefresh={() =>
                refreshAIResponse(messages.length - 1, fetchModel)
              }
            />
          ) : null}
          {isTyping && (
            <MsgTyping
              text={currentTypingMessage}
              model={models?.find((m) => m.uid === fetchModel)}
            />
          )}
          {isLoading && !isTyping && (
            <MsgLoading model={models?.find((m) => m.uid === fetchModel)} />
          )}
        </Space>
        <div ref={messagesEndRef} />
      </div>
      <Affix offsetBottom={10}>
        <Card style={{ borderRadius: "10px", borderColor: "#d9d9d9" }}>
          <div style={{ maxWidth: "1200px", margin: "0 auto" }}>
            <div style={{ display: "flex", marginBottom: "8px" }}>
              <TextArea
                value={inputValue}
                onChange={(e) => setInputValue(e.target.value)}
                onKeyPress={handleKeyPress}
                placeholder="提出你的问题,如:总结下面的内容..."
                autoSize={{ minRows: 1, maxRows: 6 }}
                style={{ resize: "none", paddingRight: "48px" }}
              />
            </div>
            <div
              style={{
                display: "flex",
                justifyContent: "space-between",
                alignItems: "center",
              }}
            >
              <Space>
                <Tooltip title="附加文件">
                  <Button
                    size="small"
                    type="text"
                    icon={<PaperClipOutlined />}
                  />
                </Tooltip>
                <PromptButtonGroup onText={setInputValue} />
              </Space>
              <Space>
                <Dropdown menu={modelMenu} trigger={["click"]}>
                  <Button size="small" type="text">
                    {models?.find((m) => m.uid === selectedModel)?.name}
                    <DownOutlined />
                  </Button>
                </Dropdown>
                <Button
                  type="primary"
                  icon={<SendOutlined />}
                  onClick={() => {
                    sendMessage();
                    onChat && onChat();
                  }}
                  disabled={!inputValue.trim() || isLoading}
                />
              </Space>
            </div>
          </div>
        </Card>
      </Affix>
    </div>
  ) : (
    <></>
  );
  637. };
  638. export default AIChatComponent;