useSessionGroups.ts 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. // dashboard-v4/dashboard/src/hooks/useSessionGroups.ts
  2. import { useMemo, useCallback } from "react";
  3. import type { MessageNode, SessionInfo, VersionInfo } from "../types/chat"
  /** A `uid` or `parent_id` value; the two are compared to link messages into a tree. */
  type MessageKey = MessageNode["uid"] | MessageNode["parent_id"];

  /**
   * Lookup tables over the raw message list, built in a single pass so that
   * per-session queries do not have to rescan every message.
   */
  interface MessageIndex {
    /** First message (in array order) carrying each `uid`. */
    byUid: Map<MessageKey, MessageNode>;
    /** Messages grouped by `parent_id`, each list in array order. */
    childrenByParent: Map<MessageKey, MessageNode[]>;
    /** Message with the smallest `id` per `session_id` (earliest in array order on ties). */
    firstBySession: Map<string, MessageNode>;
  }

  /** Builds a {@link MessageIndex} over `rawMessages` in one O(n) pass. */
  function buildMessageIndex(rawMessages: readonly MessageNode[]): MessageIndex {
    const byUid = new Map<MessageKey, MessageNode>();
    const childrenByParent = new Map<MessageKey, MessageNode[]>();
    const firstBySession = new Map<string, MessageNode>();

    for (const msg of rawMessages) {
      // Keep the first occurrence, matching `Array.prototype.find` semantics.
      if (!byUid.has(msg.uid)) {
        byUid.set(msg.uid, msg);
      }

      const siblings = childrenByParent.get(msg.parent_id);
      if (siblings) {
        siblings.push(msg);
      } else {
        childrenByParent.set(msg.parent_id, [msg]);
      }

      // Strict `<` keeps the earlier entry on equal ids, as a stable sort would.
      const earliest = firstBySession.get(msg.session_id);
      if (earliest === undefined || msg.id < earliest.id) {
        firstBySession.set(msg.session_id, msg);
      }
    }

    return { byUid, childrenByParent, firstBySession };
  }

  /**
   * Lists the alternative versions of a session.
   *
   * The versions are all children of the parent of the session's earliest
   * message (smallest `id`), in raw-message order. The session's own earliest
   * message is therefore always one of them.
   *
   * @returns An empty array when the session has no raw messages, or when the
   *   parent of its earliest message is not present.
   */
  function computeSessionVersions(
    sessionId: string,
    index: MessageIndex
  ): VersionInfo[] {
    const firstMsg = index.firstBySession.get(sessionId);
    if (firstMsg === undefined || !index.byUid.has(firstMsg.parent_id)) {
      return [];
    }
    const siblings = index.childrenByParent.get(firstMsg.parent_id) ?? [];
    return siblings.map((msg, versionIndex) => ({
      version_index: versionIndex,
      message_id: msg.uid,
    }));
  }

  /**
   * Picks which entry of `versions` is currently shown for a session.
   *
   * @param sessionMessages The session's messages taken from the active path.
   * @param versions The result of `computeSessionVersions` for the same session.
   * @returns The index of the version whose `message_id` equals the uid of the
   *   first active assistant message. If no assistant message is active, returns
   *   `versions.length - 1` (which is -1 when there are no versions). If the
   *   active assistant message is not among the versions, returns 0.
   */
  function findCurrentVersion(
    sessionMessages: readonly MessageNode[],
    versions: readonly VersionInfo[]
  ): number {
    const activeAiMsg = sessionMessages.find(
      (m) => m.role === "assistant" && m.is_active
    );
    if (!activeAiMsg) {
      return versions.length - 1;
    }
    // Matched by uid. NOTE(review): versions are the siblings of the session's
    // *earliest* message; if that is usually the user message, this lookup
    // misses and silently falls back to 0. Confirm against the tree model.
    const versionIndex = versions.findIndex(
      (v) => v.message_id === activeAiMsg.uid
    );
    return Math.max(0, versionIndex);
  }

  /**
   * Groups the non-`system` messages of `activePath` by `session_id` and
   * attaches version information to each group.
   *
   * @returns One SessionInfo per session, ordered by the smallest message `id`
   *   in each group.
   */
  function buildSessionGroups(
    activePath: readonly MessageNode[],
    index: MessageIndex
  ): SessionInfo[] {
    // Map iteration follows insertion order: the order sessions first appear.
    const bySession = new Map<string, MessageNode[]>();
    for (const msg of activePath) {
      if (msg.role === "system") {
        continue;
      }
      const bucket = bySession.get(msg.session_id);
      if (bucket) {
        bucket.push(msg);
      } else {
        bySession.set(msg.session_id, [msg]);
      }
    }

    const decorated = Array.from(bySession, ([sessionId, messages]) => {
      const versions = computeSessionVersions(sessionId, index);
      const group: SessionInfo = {
        session_id: sessionId,
        messages,
        versions,
        current_version: findCurrentVersion(messages, versions),
        user_message: messages.find((m) => m.role === "user"),
        // Every message that is not from the user (system was filtered above).
        ai_messages: messages.filter((m) => m.role !== "user"),
      };
      // Buckets are never empty, so this is always a real id. Computing it once
      // here avoids re-scanning the messages inside the sort comparator.
      const firstId = messages.reduce(
        (min, m) => Math.min(min, m.id),
        Infinity
      );
      return { firstId, group };
    });

    decorated.sort((a, b) => a.firstId - b.firstId);
    return decorated.map(({ group }) => group);
  }

  /**
   * Groups the active conversation path into sessions, each carrying its
   * alternative versions and the index of the version currently shown.
   *
   * @param activePath Messages to group; `system` messages are skipped.
   * @param rawMessages Every known message; version lists are derived from it.
   * @returns Session groups sorted by their smallest message `id`, memoized on
   *   the identities of `activePath` and `rawMessages`.
   */
  export function useSessionGroups(
    activePath: MessageNode[],
    rawMessages: MessageNode[]
  ) {
    // Rebuild the lookup tables only when the raw message list changes.
    const index = useMemo(() => buildMessageIndex(rawMessages), [rawMessages]);
    const computeSessionGroups = useCallback(
      (): SessionInfo[] => buildSessionGroups(activePath, index),
      [activePath, index]
    );
    return useMemo(() => computeSessionGroups(), [computeSessionGroups]);
  }