useChatData.ts · 20 KB · 631 lines
  1. // dashboard-v4/dashboard/src/hooks/useChatData.ts
  2. import { useState, useCallback, useMemo, useEffect } from "react";
  3. import type {
  4. MessageNode,
  5. ChatState,
  6. ChatActions,
  7. PendingMessage,
  8. OpenAIMessage,
  9. TOpenAIRole,
  10. ParsedChunk,
  11. ToolCall,
  12. CreateMessageRequest,
  13. } from "../types/chat";
  14. import { useActivePath } from "./useActivePath";
  15. import { useSessionGroups } from "./useSessionGroups";
  16. //import { messageApi } from "../services/messageApi";
  17. import { mockMessageApi as messageApi } from "../services/mockMessageApi";
  18. import { getModelAdapter } from "../services/modelAdapters";
  19. import type { IAiModel } from "../api/ai";
  /**
   * React hook that owns the client-side state of a single chat.
   *
   * It loads the saved message tree for `chatId`, keeps optimistic "pending"
   * message groups until they are persisted, derives the displayed branch
   * (`useActivePath`) and the session groups (`useSessionGroups`), and drives
   * streamed AI replies, including tool-call round trips through the model
   * adapter returned by `getModelAdapter`.
   *
   * Changing `chatId` resets all local state and reloads messages. A load that
   * was started for a previous `chatId` is ignored when it settles.
   *
   * @param chatId - Id of the chat to load messages for and save messages to.
   * @returns `chatState` (a snapshot for rendering) and `actions` (the
   *   memoized callbacks the UI invokes).
   */
  export function useChatData(chatId: string): {
    chatState: ChatState;
    actions: ChatActions;
  } {
    // Messages currently come from the mock API (see the imports).
    const [rawMessages, setRawMessages] = useState<MessageNode[]>([]);
    const [pendingMessages, setPendingMessages] = useState<PendingMessage[]>([]);
    const [isLoading, setIsLoading] = useState(false);
    // True once a load for the current chatId has succeeded; reset on chat change and refresh.
    const [isInitialized, setIsInitialized] = useState(false);
    // Text streamed so far for the assistant reply that is in progress.
    const [streamingMessage, setStreamingMessage] = useState<string>();
    const [streamingSessionId, setStreamingSessionId] = useState<string>();
    const [error, setError] = useState<string>();
    // Message id handed to useActivePath to select the displayed branch; set by switchVersion.
    const [activePoint, setActivePoint] = useState<string>();
    const [currModel, setCurrModel] = useState<IAiModel>();

    // Saved messages followed by the messages of every pending group, so the
    // UI can show optimistic messages before they are persisted.
    const allMessages = useMemo(() => {
      const pending = pendingMessages.flatMap((p) => p.messages);
      return [...rawMessages, ...pending];
    }, [rawMessages, pendingMessages]);
    const activePath = useActivePath(allMessages, activePoint);
    const sessionGroups = useSessionGroups(activePath, allMessages);

    /**
     * Loads the message list of the current chat into `rawMessages`.
     *
     * @param isCancelled - Optional probe checked after the request settles.
     *   When it returns true, the result (or error) is discarded and
     *   `isLoading` is left alone, so a superseded request cannot overwrite
     *   the state of a newer one.
     */
    const loadMessages = useCallback(
      async (isCancelled: () => boolean = () => false) => {
        // NOTE: an empty/blank chatId is deliberately not short-circuited
        // (that guard was commented out in the original code).
        try {
          setIsLoading(true);
          setError(undefined); // clear any previous error
          const response = await messageApi.getMessages(chatId);
          if (isCancelled()) return;
          setRawMessages(response.data.rows);
          setIsInitialized(true);
        } catch (err) {
          if (isCancelled()) return;
          const errorMessage = err instanceof Error ? err.message : "加载消息失败";
          setError(errorMessage);
          console.error("加载消息失败:", err);
        } finally {
          if (!isCancelled()) setIsLoading(false);
        }
      },
      [chatId]
    );

    // When chatId changes: reset state that belongs to the previous chat, then load.
    useEffect(() => {
      let cancelled = false;
      setIsInitialized(false);
      setRawMessages([]);
      setPendingMessages([]);
      setError(undefined);
      setActivePoint(undefined);
      void loadMessages(() => cancelled);
      // Ignore the in-flight load if chatId changes again or the component unmounts.
      return () => {
        cancelled = true;
      };
    }, [chatId, loadMessages]);

    // Manual reload exposed to callers (as actions.loadMessages).
    const refreshMessages = useCallback(async () => {
      setIsInitialized(false);
      await loadMessages();
    }, [loadMessages]);

    /**
     * Builds the OpenAI-style message history for a model request.
     *
     * When `newUserMessage` is given, the history is the active path up to and
     * including that message's parent, followed by the new message. This
     * matters when an earlier turn is edited or regenerated: the previous
     * version of that turn, and anything after it, must not be sent. If the
     * parent is not found on the active path, the whole path is used.
     */
    const buildConversationHistory = useCallback(
      (newUserMessage?: MessageNode): OpenAIMessage[] => {
        let basePath = activePath;
        if (newUserMessage) {
          if (newUserMessage.parent_id == null) {
            // Root message: there is no earlier context.
            basePath = [];
          } else {
            const parentIndex = activePath.findIndex(
              (m) => m.uid === newUserMessage.parent_id
            );
            if (parentIndex !== -1) {
              basePath = activePath.slice(0, parentIndex + 1);
            }
          }
        }
        const history: OpenAIMessage[] = basePath.map((m) => ({
          role: m.role,
          content: m.content || "",
          tool_calls: m.tool_calls,
          tool_call_id: m.tool_call_id,
        }));
        if (newUserMessage) {
          history.push({
            role: "user",
            content: newUserMessage.content || "",
          });
        }
        return history;
      },
      [activePath]
    );

    /**
     * Persists a finished message group via messageApi.createMessages.
     * On success the group leaves `pendingMessages` and the saved messages
     * are appended to `rawMessages`. On failure every message in the group is
     * marked "failed" and the error text is stored on the group.
     */
    const saveMessageGroup = useCallback(
      async (tempId: string, messages: MessageNode[]) => {
        try {
          const data: CreateMessageRequest = {
            messages: messages.map((m) => ({
              uid: m.uid,
              parent_id: m.parent_id,
              role: m.role,
              content: m.content,
              session_id: m.session_id,
              model_id: m.model_id,
              tool_calls: m.tool_calls,
              tool_call_id: m.tool_call_id,
              metadata: m.metadata,
            })),
          };
          const savedMessages = await messageApi.createMessages(chatId, data);
          setPendingMessages((prev) => prev.filter((p) => p.temp_id !== tempId));
          setRawMessages((prev) => [...prev, ...savedMessages.data]);
        } catch (err) {
          console.error("保存消息组失败:", err);
          setPendingMessages((prev) =>
            prev.map((p) =>
              p.temp_id === tempId
                ? {
                    ...p,
                    error: err instanceof Error ? err.message : "保存失败",
                    messages: p.messages.map((m) => ({
                      ...m,
                      save_status: "failed" as const,
                    })),
                  }
                : p
            )
          );
        }
      },
      [chatId]
    );

    /**
     * Streams the model's reply to `userMessage`. Tool calls requested by the
     * model are executed with `adapter.handleFunctionCall` and their results
     * are fed back, for at most `maxToolRounds` model calls. The generated
     * assistant/tool messages are appended to the pending group, and the group
     * (user message included) is then saved. Failures are recorded on the
     * pending group and its retry_count is incremented.
     */
    const sendMessageToAI = useCallback(
      async (userMessage: MessageNode, pendingGroup: PendingMessage) => {
        // Upper bound on model calls for a single user turn.
        const maxToolRounds = 10;
        try {
          if (!currModel) {
            console.error("no model selected");
            // Put the error on the group so it does not sit in "pending" forever.
            setPendingMessages((prev) =>
              prev.map((p) =>
                p.temp_id === pendingGroup.temp_id
                  ? { ...p, error: "no model selected" }
                  : p
              )
            );
            return;
          }
          console.debug("ai chat send message current model", currModel);
          setIsLoading(true);
          setStreamingSessionId(pendingGroup.session_id);

          const adapter = getModelAdapter(currModel);
          // Request history; it grows with each assistant/tool exchange below.
          const currentMessages = buildConversationHistory(userMessage);
          const allAiMessages: MessageNode[] = [];

          for (let round = 0; round < maxToolRounds; round++) {
            let responseContent = "";
            let functionCalls: ToolCall[] = [];
            // Currently never populated; kept so the saved message carries a metadata object.
            const metadata: any = {};
            const streamResponse = await adapter.sendMessage(currentMessages, {
              temperature: 0.7,
              max_tokens: 2048,
            });

            // Tool-call fragments arrive across chunks; accumulate them by index.
            const toolCallBuffer = new Map<number, ToolCall>();
            for await (const chunk of streamResponse) {
              const parsed: ParsedChunk | null = adapter.parseStreamChunk(chunk);

              if (parsed?.content) {
                responseContent += parsed.content;
                setStreamingMessage(responseContent);
              }

              if (parsed?.tool_calls) {
                for (const call of parsed.tool_calls) {
                  const existing = toolCallBuffer.get(call.index);
                  if (existing) {
                    // Later fragments extend the arguments string.
                    existing.function.arguments += call.function.arguments || "";
                  } else {
                    toolCallBuffer.set(call.index, {
                      ...call,
                      function: {
                        ...call.function,
                        arguments: call.function.arguments || "",
                      },
                    });
                  }
                }
                console.info("ai chat function call (buffer)", toolCallBuffer);
              }

              // The model finished emitting tool calls; arguments are complete now.
              if (parsed?.finish_reason === "tool_calls") {
                const toolCalls: ToolCall[] = Array.from(toolCallBuffer.values());
                console.info("ai chat Final tool calls", toolCalls);
                functionCalls = toolCalls;
              }
            }

            // Assistant message for this round (text and/or tool-call request).
            const assistantMessage: MessageNode = {
              id: 0,
              uid: `temp_ai_${pendingGroup.temp_id}_${allAiMessages.length}`,
              temp_id: pendingGroup.temp_id,
              chat_id: chatId,
              session_id: pendingGroup.session_id,
              parent_id:
                allAiMessages.length === 0
                  ? userMessage.uid
                  : allAiMessages[allAiMessages.length - 1].uid,
              role: "assistant",
              content: responseContent,
              model_id: currModel.uid,
              tool_calls: functionCalls.length > 0 ? functionCalls : undefined,
              metadata,
              is_active: true,
              save_status: "pending",
              created_at: new Date().toISOString(),
              updated_at: new Date().toISOString(),
            };
            allAiMessages.push(assistantMessage);

            // No tool calls: this was the final answer.
            if (functionCalls.length === 0) break;

            const toolResults = await Promise.all(
              functionCalls.map((call) => adapter.handleFunctionCall(call))
            );
            // Offset by the running message count so tool uids stay unique
            // across rounds (a per-round index alone repeats _0, _1, ...).
            const toolUidOffset = allAiMessages.length;
            const toolMessages: MessageNode[] = functionCalls.map((call, index) => ({
              id: 0,
              uid: `temp_tool_${pendingGroup.temp_id}_${toolUidOffset + index}`,
              temp_id: pendingGroup.temp_id,
              chat_id: chatId,
              session_id: pendingGroup.session_id,
              parent_id: assistantMessage.uid,
              role: "tool" as const,
              content: JSON.stringify(toolResults[index]),
              tool_call_id: call.id,
              is_active: true,
              save_status: "pending" as const,
              created_at: new Date().toISOString(),
              updated_at: new Date().toISOString(),
            }));
            allAiMessages.push(...toolMessages);
            console.debug("ai chat allAiMessages", allAiMessages);

            // Feed the tool-call request and its results back for the next round.
            currentMessages.push(
              {
                role: "assistant",
                content: responseContent,
                tool_calls: functionCalls,
              },
              ...toolMessages.map((tm) => ({
                role: "tool" as const,
                content: tm.content || "",
                tool_call_id: tm.tool_call_id,
              }))
            );
            console.debug("ai chat currentMessages", currentMessages);
          }

          setPendingMessages((prev) =>
            prev.map((p) =>
              p.temp_id === pendingGroup.temp_id
                ? { ...p, messages: [...p.messages, ...allAiMessages] }
                : p
            )
          );
          await saveMessageGroup(pendingGroup.temp_id, [userMessage, ...allAiMessages]);
        } catch (err) {
          console.error("AI响应失败:", err);
          setPendingMessages((prev) =>
            prev.map((p) =>
              p.temp_id === pendingGroup.temp_id
                ? {
                    ...p,
                    error: err instanceof Error ? err.message : "未知错误",
                    retry_count: p.retry_count + 1,
                  }
                : p
            )
          );
        } finally {
          setIsLoading(false);
          setStreamingMessage(undefined);
          setStreamingSessionId(undefined);
        }
      },
      [currModel, buildConversationHistory, saveMessageGroup, chatId]
    );

    /**
     * Adds a new user message as a fresh session and asks the AI to answer it.
     *
     * @param sessionId - "new" appends after the last message on the active
     *   path; any other value creates a sibling version of that session (same
     *   parent as the session's first message).
     * @param content - Text of the new message.
     * @param role - Only "user" triggers an AI reply.
     */
    const editMessage = useCallback(
      async (sessionId: string, content: string, role: TOpenAIRole = "user") => {
        const tempId = `temp_${Date.now()}`;
        try {
          let parentId: string | undefined;
          if (sessionId === "new") {
            parentId = activePath[activePath.length - 1]?.uid;
          } else {
            const firstMessage = activePath.find((m) => m.session_id === sessionId);
            parentId = firstMessage?.parent_id;
          }

          // New messages and edits both start a fresh session.
          const newSessionId = `session_${tempId}`;
          // A 0 seed avoids Math.max() of an empty list (-Infinity) for a new
          // chat, and avoids spreading a potentially huge argument list.
          const maxId = rawMessages.reduce((max, msg) => Math.max(max, msg.id), 0);

          // NOTE(review): the message is always stored with role "user"; `role`
          // only decides whether the AI is called. For any other role the group
          // is never saved (saving happens in sendMessageToAI) — confirm intent.
          const newUserMessage: MessageNode = {
            id: maxId + 1,
            uid: `temp_user_${tempId}`,
            temp_id: tempId,
            chat_id: chatId,
            parent_id: parentId,
            session_id: newSessionId,
            role: "user",
            content,
            is_active: true,
            save_status: "pending",
            created_at: new Date().toISOString(),
            updated_at: new Date().toISOString(),
          };

          const pendingGroup: PendingMessage = {
            temp_id: tempId,
            session_id: newSessionId,
            messages: [newUserMessage],
            retry_count: 0,
            created_at: new Date().toISOString(),
          };
          setPendingMessages((prev) => [...prev, pendingGroup]);
          console.debug("ai chat", pendingGroup);

          if (role === "user") {
            await sendMessageToAI(newUserMessage, pendingGroup);
          }
        } catch (err) {
          console.error("编辑消息失败:", err);
          setPendingMessages((prev) =>
            prev.map((p) =>
              p.temp_id === tempId
                ? {
                    ...p,
                    messages: p.messages.map((m) => ({
                      ...m,
                      save_status: "failed" as const,
                    })),
                    error: err instanceof Error ? err.message : "编辑失败",
                  }
                : p
            )
          );
        }
      },
      [rawMessages, chatId, activePath, sendMessageToAI]
    );

    // Retry a failed pending group: keep only its user message and ask the AI again.
    const retryMessage = useCallback(
      async (tempId: string) => {
        const pendingGroup = pendingMessages.find((p) => p.temp_id === tempId);
        if (!pendingGroup) return;
        const userMessage = pendingGroup.messages.find((m) => m.role === "user");
        if (!userMessage) return;
        setPendingMessages((prev) =>
          prev.map((p) =>
            p.temp_id === tempId
              ? {
                  ...p,
                  messages: [{ ...userMessage, save_status: "pending" }],
                  error: undefined,
                }
              : p
          )
        );
        await sendMessageToAI(userMessage, {
          ...pendingGroup,
          messages: [userMessage],
        });
      },
      [pendingMessages, sendMessageToAI]
    );

    // Select which branch is displayed by anchoring the active path at a message.
    const switchVersion = useCallback((activeMsgId: string) => {
      console.debug("activeMsgId", activeMsgId);
      setActivePoint(activeMsgId);
    }, []);

    // Regenerate the AI answer for a session's user message in a new session.
    const refreshResponse = useCallback(
      async (sessionId: string, _modelId?: string) => {
        const session = sessionGroups.find((sg) => sg.session_id === sessionId);
        if (!session?.user_message) return;
        const tempId = `temp_refresh_${Date.now()}`;
        try {
          const userMsg = session.user_message;
          const newSessionId = `session_${tempId}`;
          // NOTE(review): the copy keeps the original user message's uid, so the
          // group is saved with a uid that presumably already exists in the
          // message tree — confirm the backend expects this.
          const pendingGroup: PendingMessage = {
            temp_id: tempId,
            session_id: newSessionId,
            messages: [
              {
                ...userMsg,
                temp_id: tempId,
                session_id: newSessionId,
                save_status: "pending",
              },
            ],
            retry_count: 0,
            created_at: new Date().toISOString(),
          };
          setPendingMessages((prev) => [...prev, pendingGroup]);
          await sendMessageToAI(pendingGroup.messages[0], pendingGroup);
        } catch (err) {
          console.error("刷新回答失败:", err);
          setError(err instanceof Error ? err.message : "刷新失败");
        }
      },
      [sessionGroups, sendMessageToAI]
    );

    // Message feedback; local state is not updated yet.
    const likeMessage = useCallback(async (messageId: string) => {
      try {
        await messageApi.likeMessage(messageId);
      } catch (err) {
        console.error("点赞失败:", err);
      }
    }, []);

    const dislikeMessage = useCallback(async (messageId: string) => {
      try {
        await messageApi.dislikeMessage(messageId);
      } catch (err) {
        console.error("点踩失败:", err);
      }
    }, []);

    // Copy a message's text (saved or pending) to the clipboard.
    const copyMessage = useCallback(
      (messageId: string) => {
        const message = allMessages.find((m) => m.uid === messageId);
        if (!message?.content) return;
        // writeText rejects when clipboard access is denied; handle it so it
        // is not reported as an unhandled promise rejection.
        navigator.clipboard.writeText(message.content).catch((err) => {
          console.error("复制失败:", err);
        });
      },
      [allMessages]
    );

    // Returns the share URL; errors are logged and rethrown to the caller.
    const shareMessage = useCallback(
      async (messageId: string): Promise<string> => {
        try {
          const response = await messageApi.shareMessage(messageId);
          return response.data.shareUrl;
        } catch (err) {
          console.error("分享失败:", err);
          throw err;
        }
      },
      []
    );

    const deleteMessage = useCallback(
      async (messageId: string) => {
        try {
          await messageApi.deleteMessage(messageId);
          // refreshMessages (not loadMessages) so isInitialized is reset too.
          await refreshMessages();
        } catch (err) {
          console.error("删除失败:", err);
          setError(err instanceof Error ? err.message : "删除失败");
        }
      },
      [refreshMessages]
    );

    const setModel = useCallback((model: IAiModel | undefined) => {
      setCurrModel(model);
    }, []);

    const actions = useMemo(
      () => ({
        switchVersion,
        editMessage,
        retryMessage,
        refreshResponse,
        loadMessages: refreshMessages, // callers get the resetting variant
        likeMessage,
        dislikeMessage,
        copyMessage,
        shareMessage,
        deleteMessage,
        setModel,
      }),
      [
        switchVersion,
        editMessage,
        retryMessage,
        refreshResponse,
        refreshMessages,
        likeMessage,
        dislikeMessage,
        copyMessage,
        shareMessage,
        deleteMessage,
        setModel,
      ]
    );

    return {
      chatState: {
        chat_id: chatId,
        title: "", // not managed by this hook
        raw_messages: rawMessages,
        active_path: activePath,
        session_groups: sessionGroups,
        pending_messages: pendingMessages,
        is_loading: isLoading,
        is_initialized: isInitialized,
        streaming_message: streamingMessage,
        streaming_session_id: streamingSessionId,
        current_model: currModel,
        error,
      },
      actions,
    };
  }