Kaynağa Gözat

:fire: content

visuddhinanda 8 ay önce
ebeveyn
işleme
0e9e656866

+ 60 - 58
dashboard-v4/dashboard/src/components/chat/AiChat.tsx

@@ -27,18 +27,17 @@ const { TextArea } = Input;
 
 // 类型定义
 export interface MessageVersion {
+  id: number;
   content: string;
   model: string;
+  role: "system" | "user" | "assistant";
+  timestamp: string;
 }
 
 export interface Message {
   id: number;
   type: "user" | "ai" | "error";
-  content: string;
-  timestamp: string;
-  model?: string;
-  versions?: MessageVersion[];
-  currentVersionIndex?: number;
+  versions: MessageVersion[];
 }
 
 interface OpenAIMessage {
@@ -59,6 +58,10 @@ interface OpenAIStreamResponse {
   }>;
 }
 
+const endOfMsg = (msg: Message) => {
+  return msg.versions[msg.versions.length - 1];
+};
+
 interface IWidget {
   initMessage?: string;
   systemPrompt?: string;
@@ -149,6 +152,7 @@ const AIChatComponent = ({
   const callOpenAI = useCallback(
     async (
       messages: OpenAIMessage[],
+      modelId: string,
       isRegenerate: boolean = false,
       messageIndex?: number
     ): Promise<{ success: boolean; content?: string; error?: string }> => {
@@ -157,9 +161,10 @@ const AIChatComponent = ({
         console.error("no REACT_APP_OPENAI_PROXY");
         return { success: false, error: "API配置错误" };
       }
+      console.log("modelId", modelId);
       try {
         const payload = {
-          model: models?.find((value) => value.uid === selectedModel)?.model,
+          model: models?.find((value) => value.uid === modelId)?.model,
           messages: messages,
           stream: true,
           temperature: 0.7,
@@ -167,7 +172,7 @@ const AIChatComponent = ({
         };
         const url = process.env.REACT_APP_OPENAI_PROXY;
         const data = {
-          model_id: selectedModel,
+          model_id: modelId,
           payload: payload,
         };
         console.info("api request", url, data);
@@ -196,28 +201,22 @@ const AIChatComponent = ({
         const typeController = streamTypeWriter(
           (content: string) => {},
           (finalContent: string) => {
+            const newData: MessageVersion = {
+              id: Date.now(),
+              content: finalContent,
+              model: selectedModel,
+              role: "assistant",
+              timestamp: new Date().toLocaleTimeString(),
+            };
             if (isRegenerate && messageIndex !== undefined) {
               setMessages((prev) => {
                 const newMessages = [...prev];
                 const targetMessage = newMessages[messageIndex];
                 if (targetMessage) {
                   if (!targetMessage.versions) {
-                    targetMessage.versions = [
-                      {
-                        content: targetMessage.content,
-                        model: targetMessage.model || "",
-                      },
-                    ];
-                    targetMessage.currentVersionIndex = 0;
+                    targetMessage.versions = [];
                   }
-                  targetMessage.versions.push({
-                    content: finalContent,
-                    model: selectedModel,
-                  });
-                  targetMessage.currentVersionIndex =
-                    targetMessage.versions.length - 1;
-                  targetMessage.content = finalContent;
-                  targetMessage.model = selectedModel;
+                  targetMessage.versions.push(newData);
                 }
                 setRefreshingMessageId(null);
                 return newMessages;
@@ -226,11 +225,7 @@ const AIChatComponent = ({
               const aiMessage: Message = {
                 id: Date.now(),
                 type: "ai",
-                content: finalContent,
-                timestamp: new Date().toLocaleTimeString(),
-                model: selectedModel,
-                versions: [{ content: finalContent, model: selectedModel }],
-                currentVersionIndex: 0,
+                versions: [newData],
               };
               setMessages((prev) => [...prev, aiMessage]);
               setRefreshingMessageId(null);
@@ -291,13 +286,17 @@ const AIChatComponent = ({
     async (messageText: string = inputValue): Promise<void> => {
       if (!messageText.trim()) return;
 
-      const userMessage: Message = {
+      const newData: MessageVersion = {
         id: Date.now(),
-        type: "user",
         content: messageText,
+        model: "",
+        role: "user",
         timestamp: new Date().toLocaleTimeString(),
-        versions: [{ content: messageText, model: "" }],
-        currentVersionIndex: 0,
+      };
+      const userMessage: Message = {
+        id: Date.now(),
+        type: "user",
+        versions: [newData],
       };
 
       setMessages((prev) => [...prev, userMessage]);
@@ -314,14 +313,14 @@ const AIChatComponent = ({
           ...messages.map((msg) => {
             const data: OpenAIMessage = {
               role: msg.type === "user" ? "user" : "assistant",
-              content: msg.content,
+              content: msg.versions[msg.versions.length - 1].content,
             };
             return data;
           }),
           { role: "user", content: messageText },
         ];
 
-        const result = await callOpenAI(conversationHistory);
+        const result = await callOpenAI(conversationHistory, selectedModel);
         setIsLoading(false);
         if (!result.success) {
           setError("请求失败,请重试");
@@ -336,7 +335,7 @@ const AIChatComponent = ({
   );
 
   const refreshAIResponse = useCallback(
-    async (messageIndex: number): Promise<void> => {
+    async (messageIndex: number, modelId: string): Promise<void> => {
       console.debug("refresh", messageIndex);
       const userMessage = messages[messageIndex - 1];
       if (userMessage && userMessage.type === "user") {
@@ -346,16 +345,17 @@ const AIChatComponent = ({
           ...messages.slice(0, messageIndex - 1).map((msg) => {
             const data: OpenAIMessage = {
               role: msg.type === "user" ? "user" : "assistant",
-              content: msg.content,
+              content: endOfMsg(msg).content,
             };
             return data;
           }),
-          { role: "user", content: userMessage.content },
+          { role: "user", content: endOfMsg(userMessage).content },
         ];
 
         try {
           const result = await callOpenAI(
             conversationHistory,
+            modelId,
             true,
             messageIndex
           );
@@ -369,24 +369,18 @@ const AIChatComponent = ({
               const newMessages = [...prev];
               const targetMessage = newMessages[messageIndex];
               if (targetMessage) {
+                const newData: MessageVersion = {
+                  id: Date.now(),
+                  content: result.content || "",
+                  model: selectedModel,
+                  role: "assistant",
+                  timestamp: new Date().toLocaleTimeString(),
+                };
                 targetMessage.type = "ai"; // Update type to "ai"
-                targetMessage.content = result.content || "";
-                targetMessage.model = selectedModel;
                 if (!targetMessage.versions) {
-                  targetMessage.versions = [
-                    {
-                      content: targetMessage.content,
-                      model: targetMessage.model || "",
-                    },
-                  ];
-                  targetMessage.currentVersionIndex = 0;
+                  targetMessage.versions = [];
                 }
-                targetMessage.versions.push({
-                  content: result.content || "",
-                  model: selectedModel,
-                });
-                targetMessage.currentVersionIndex =
-                  targetMessage.versions.length - 1;
+                targetMessage.versions.push(newData);
               }
               setRefreshingMessageId(null);
               return newMessages;
@@ -410,12 +404,16 @@ const AIChatComponent = ({
       if (messageIndex !== -1) {
         const message = newMessages[messageIndex];
         if (!message.versions) {
-          message.versions = [{ content: message.content, model: "" }];
-          message.currentVersionIndex = 0;
+          message.versions = [];
         }
-        message.versions.push({ content: text, model: "" });
-        message.currentVersionIndex = message.versions.length - 1;
-        message.content = text;
+        const newData: MessageVersion = {
+          id: Date.now(),
+          content: text,
+          model: "",
+          role: "user",
+          timestamp: new Date().toLocaleTimeString(),
+        };
+        message.versions.push(newData);
       }
       return newMessages;
     });
@@ -467,7 +465,9 @@ const AIChatComponent = ({
                   <MsgAssistant
                     msg={msg}
                     models={models}
-                    onRefresh={() => refreshAIResponse(index)}
+                    onRefresh={(modelId: string) => {
+                      refreshAIResponse(index, modelId);
+                    }}
                   />
                 );
               } else {
@@ -478,7 +478,9 @@ const AIChatComponent = ({
           {error ? (
             <MsgError
               message={error}
-              onRefresh={() => refreshAIResponse(messages.length - 1)}
+              onRefresh={() =>
+                refreshAIResponse(messages.length - 1, selectedModel)
+              }
             />
           ) : (
             <></>