Procházet zdrojové kódy

:sparkles: function call

visuddhinanda před 6 měsíci
rodič
revize
467e9b0edb
1 změnil soubory, kde provedl 139 přidání a 27 odebrání
  1. 139 27
      dashboard-v4/dashboard/src/components/chat/AiChat.tsx

+ 139 - 27
dashboard-v4/dashboard/src/components/chat/AiChat.tsx

@@ -25,15 +25,17 @@ import MsgError from "./MsgError";
 import PromptButtonGroup from "./PromptButtonGroup";
 import { useAppSelector } from "../../hooks";
 import { currentUser } from "../../reducers/current-user";
+import { IFtsResponse } from "../fts/FullTextSearchResult";
 
 const { TextArea } = Input;
 
+type AIRole = "system" | "user" | "assistant" | "function";
 // 类型定义
 export interface MessageVersion {
   id: number;
   content: string;
   model: string;
-  role: "system" | "user" | "assistant";
+  role: AIRole;
   timestamp: string;
 }
 
@@ -44,8 +46,9 @@ export interface Message {
 }
 
 interface OpenAIMessage {
-  role: "system" | "user" | "assistant";
+  role: AIRole;
   content: string;
+  name?: string;
 }
 
 interface StreamTypeController {
@@ -73,7 +76,11 @@ interface IWidget {
 
 const AIChatComponent = ({
   initMessage,
-  systemPrompt = "你是一个巴利语专家",
+  systemPrompt = `你是一个巴利语专家和佛教术语解释助手。当用户询问佛教术语时,你可以调用 searchTerm 函数来查询详细信息。
+
+  使用方法:
+  - 当用户输入类似"术语:dhamma"、"查询:karma"、"什么是 buddha"等时,调用函数查询
+  - 查询结果会以结构化的方式展示,包含定义、词源、分类、详细说明等信息`,
   onChat,
 }: IWidget) => {
   const [messages, setMessages] = useState<Message[]>([]);
@@ -155,12 +162,25 @@ const AIChatComponent = ({
     []
   );
 
+  const searchTerm = async (term: string) => {
+    // 示例:请求你后端的百科 API
+    const url = `/v2/search-pali-wbw?view=pali&key=${term}&limit=20&offset=0`;
+    console.info("search api request", url);
+    const res = await get<IFtsResponse>(url);
+    if (res.ok) {
+      console.info("search 搜索结果", res.data.rows);
+      return res.data.rows;
+    }
+    return { error: "没有找到相关术语" };
+  };
+
   const callOpenAI = useCallback(
     async (
       messages: OpenAIMessage[],
       modelId: string,
       isRegenerate: boolean = false,
-      messageIndex?: number
+      messageIndex?: number,
+      _depth: number = 0
     ): Promise<{ success: boolean; content?: string; error?: string }> => {
       setError(undefined);
       if (typeof process.env.REACT_APP_OPENAI_PROXY === "undefined") {
@@ -168,27 +188,46 @@ const AIChatComponent = ({
         return { success: false, error: "API配置错误" };
       }
 
+      const functions = [
+        {
+          name: "searchTerm",
+          description: "查询佛教术语,返回百科词条",
+          parameters: {
+            type: "object",
+            properties: {
+              term: {
+                type: "string",
+                description: "要查询的巴利语或佛学术语",
+              },
+            },
+            required: ["term"],
+          },
+        },
+      ];
       try {
         setFetchModel(modelId);
-        const payload = {
+        const payload: any = {
           model: models?.find((value) => value.uid === modelId)?.model,
-          messages: messages,
+          messages,
           stream: true,
-          temperature: 0.7,
-          max_tokens: 3000, //本次回复”最大输出长度
+          temperature: 0.5,
+          max_tokens: 3000,
+          functions,
+          function_call: "auto", // 让模型决定是否调用函数
         };
         const url = process.env.REACT_APP_OPENAI_PROXY;
         const data = {
           model_id: modelId,
-          payload: payload,
+          payload,
         };
         console.info("api request", url, data);
         setIsLoading(true);
+
         const response = await fetch(url, {
           method: "POST",
           headers: {
             "Content-Type": "application/json",
-            Authorization: `Bearer AIzaSyCzr8KqEdaQ3cRCxsFwSHh8c7kF3RZTZWw`,
+            Authorization: `Bearer ${process.env.REACT_APP_OPENAI_KEY}`,
           },
           body: JSON.stringify(data),
         });
@@ -205,6 +244,11 @@ const AIChatComponent = ({
         const decoder = new TextDecoder();
         let buffer = "";
 
+        // 🔑 新增 function_call 拼接缓冲
+        let functionCallName: string | null = null;
+        let functionCallArgsBuffer = "";
+        let functionCallInProgress = false;
+
         const typeController = streamTypeWriter(
           (content: string) => {},
           (finalContent: string) => {
@@ -241,13 +285,56 @@ const AIChatComponent = ({
           }
         );
 
+        // ✅ 安全 parse
+        const safeParseArgs = (s: string) => {
+          try {
+            return JSON.parse(s);
+          } catch {
+            return {};
+          }
+        };
+
+        // 🔑 处理 function_call 完成时执行函数 + 再次请求
+        const handleFunctionCallAndReask = async () => {
+          const argsObj = safeParseArgs(functionCallArgsBuffer);
+          console.log("完整 arguments:", functionCallArgsBuffer, argsObj);
+
+          if (functionCallName === "searchTerm") {
+            const result = await searchTerm(
+              argsObj.term || argsObj.query || ""
+            );
+            const followUp: OpenAIMessage[] = [
+              ...messages,
+              {
+                role: "function",
+                name: "searchTerm",
+                content: JSON.stringify(result),
+              },
+            ];
+            console.log("search 再次请求", followUp);
+            return await callOpenAI(
+              followUp,
+              modelId,
+              isRegenerate,
+              messageIndex,
+              _depth + 1
+            );
+          }
+          return { success: false, error: "未知函数: " + functionCallName };
+        };
+
         try {
           while (true) {
             const { done, value } = await reader.read();
-
             if (done) {
+              if (functionCallInProgress && functionCallName) {
+                const res = await handleFunctionCallAndReask();
+                setIsLoading(false);
+                return res;
+              }
               typeController.complete();
-              return { success: true, content: currentTypingMessage };
+              setIsLoading(false);
+              return { success: true, content: "" };
             }
 
             buffer += decoder.decode(value, { stream: true });
@@ -255,26 +342,51 @@ const AIChatComponent = ({
             buffer = lines.pop() || "";
 
             for (const line of lines) {
-              if (line.trim() === "") continue;
-              if (line.startsWith("data: ")) {
-                const data = line.slice(6);
-
-                if (data === "[DONE]") {
-                  typeController.complete();
-                  return { success: true, content: currentTypingMessage };
+              if (!line.trim() || !line.startsWith("data: ")) continue;
+              const data = line.slice(6);
+              if (data === "[DONE]") {
+                if (functionCallInProgress && functionCallName) {
+                  const res = await handleFunctionCallAndReask();
+                  setIsLoading(false);
+                  return res;
                 }
+                typeController.complete();
+                setIsLoading(false);
+                return { success: true, content: "" };
+              }
 
-                try {
-                  const parsed: OpenAIStreamResponse = JSON.parse(data);
-                  const delta = parsed.choices?.[0]?.delta;
+              let parsed: any = null;
+              try {
+                parsed = JSON.parse(data);
+              } catch {
+                continue;
+              }
 
-                  if (delta?.content) {
-                    typeController.addToken(delta.content);
-                  }
-                } catch (e) {
-                  console.warn("解析SSE数据失败:", e);
+              const delta = parsed.choices?.[0]?.delta;
+              const finish_reason = parsed.choices?.[0]?.finish_reason;
+
+              // 🔑 拼接 function_call
+              if (delta?.function_call) {
+                if (delta.function_call.name) {
+                  functionCallName = delta.function_call.name;
+                }
+                if (typeof delta.function_call.arguments === "string") {
+                  functionCallInProgress = true;
+                  functionCallArgsBuffer += delta.function_call.arguments;
                 }
               }
+
+              // 正常文本输出
+              if (delta?.content && !functionCallInProgress) {
+                typeController.addToken(delta.content);
+              }
+
+              // function_call 完成
+              if (finish_reason === "function_call") {
+                const res = await handleFunctionCallAndReask();
+                setIsLoading(false);
+                return res;
+              }
             }
           }
         } catch (error) {