Просмотр исходного кода

时间不足,重新放入队列

visuddhinanda 9 месяцев назад
Родитель
Сommit
886daf8862
1 измененных файлов с 141 добавлено и 148 удалено
  1. 141 148
      ai-translate/ai_translate/ai_translate.py

+ 141 - 148
ai-translate/ai_translate/ai_translate.py

@@ -4,10 +4,10 @@ import logging
 import requests
 from typing import List, Dict, Any, Optional
 from datetime import datetime
-import redis
 from requests.exceptions import RequestException
 from dataclasses import dataclass
-from abc import ABC, abstractmethod
+import pdb
+import time
 
 # 配置日志
 logging.basicConfig(level=logging.INFO)
@@ -15,8 +15,15 @@ logger = logging.getLogger(__name__)
 
 
 class DatabaseException(Exception):
-    """数据库异常"""
-    pass
+    def __init__(self, message="database api access exception"):
+        self.message = message
+        super().__init__(self.message)
+
+
+class SectionTimeout(Exception):
+    def __init__(self, message="片段超时"):
+        self.message = message
+        super().__init__(self.message)
 
 
 @dataclass
@@ -27,7 +34,7 @@ class TaskProgress:
 
 
 @dataclass
-class Task:
+class TaskInfo:
     """任务模型"""
     id: str
     title: str
@@ -36,6 +43,13 @@ class Task:
     status: str
 
 
+@dataclass
+class Task:
+    """任务模型"""
+    info: TaskInfo
+    progress: TaskProgress
+
+
 @dataclass
 class AiModel:
     """AI模型配置"""
@@ -61,7 +75,7 @@ class Sentence:
 
 
 @dataclass
-class Message:
+class Payload:
     """消息模型"""
     model: AiModel
     task: Task
@@ -69,180 +83,113 @@ class Message:
     sentence: Sentence
 
 
-class RedisClusters:
-    """Redis集群工具类"""
-
-    def __init__(self, host='localhost', port=6379, db=0):
-        self.redis_client = redis.Redis(host=host, port=port, db=db)
-
-    def put(self, key: str, value: Any, ttl: Optional[int] = None):
-        """存储数据"""
-        if isinstance(value, (dict, list)):
-            value = json.dumps(value)
-        if ttl:
-            self.redis_client.setex(key, ttl, value)
-        else:
-            self.redis_client.set(key, value)
-
-    def get(self, key: str) -> Any:
-        """获取数据"""
-        value = self.redis_client.get(key)
-        if value:
-            return value.decode('utf-8')
-        return None
-
-    def has(self, key: str) -> bool:
-        """检查键是否存在"""
-        return self.redis_client.exists(key) > 0
-
-    def forget(self, key: str):
-        """删除键"""
-        self.redis_client.delete(key)
-
-
-class JobInterface(ABC):
-    """作业接口"""
-
-    @abstractmethod
-    def is_stop(self) -> bool:
-        """检查作业是否停止"""
-        pass
+@dataclass
+class Message:
+    """消息模型"""
+    model: AiModel
+    task: Task
+    payload: List[Payload]
+    sentence: Sentence
 
 
 class AiTranslateService:
     """AI翻译服务"""
 
-    def __init__(self, app_url: str = "http://localhost:8000"):
+    def __init__(self, redis, ch, method, api_url, customer_timeout):
         self.queue = 'ai_translate'
         self.model_token = None
         self.task = None
-        self.redis_clusters = RedisClusters()
-        self.mq = RabbitMQService()
+        self.redis_clusters = redis[0]
+        self.redis_namespace = redis[1]
         self.api_timeout = 100
         self.llm_timeout = 300
         self.task_topic_id = None
-        self.app_url = app_url
+        self.api_url = api_url
+        self.customer_timeout = customer_timeout
+        self.channel = ch
+        self.maxProcessTime = 15 * 60  # 一个句子的最大处理时间
 
-    def process_translate(self, message_id: str, messages: List[Message], job: JobInterface) -> bool:
+    def process_translate(self, message_id: str, body: Message) -> bool:
         """处理翻译任务"""
 
-        if not messages or len(messages) == 0:
-            logger.error('message is not array')
-            return False
+        taskStartAt = int(time.time())
 
-        first = messages[0]
-        self.task = first.task
-        task_id = self.task.id
+        self.task = body.task
 
-        self.redis_clusters.put(f"/task/{task_id}/message_id", message_id)
-        pointer_key = f"/task/{task_id}/pointer"
+        self.redis_clusters.set(
+            f"{self.redis_namespace}/task/{self.task.id}/message_id", message_id)
+        pointer_key = f"{self.redis_namespace}/task/{message_id}/pointer"
         pointer = 0
 
-        if self.redis_clusters.has(pointer_key):
+        if self.redis_clusters.exists(pointer_key):
             # 回到上次中断的点
             pointer = int(self.redis_clusters.get(pointer_key))
             logger.info(f"last break point {pointer}")
+            if pointer >= len(body.payload):
+                self.redis_clusters.delete(pointer_key)
+                return True
 
         # 获取model token
-        self.model_token = first.model.token
-        logger.debug(f'{self.queue} ai assistant token: {self.model_token}')
+        self.model_token = body.model.token
 
         self._set_task_status(self.task.id, 'running')
 
         # 设置task discussion topic
-        self.task_topic_id = self._task_discussion(
-            self.task.id,
-            'task',
-            self.task.title,
-            self.task.category,
-            None
-        )
-
-        for i in range(pointer, len(messages)):
-            # 获取当前内存使用量(Python版本的内存监控)
-            try:
-                import psutil
-                process = psutil.Process()
-                memory_info = process.memory_info()
-                logger.debug(
-                    f"memory usage: {memory_info.rss / 1024 / 1024:.2f} MB")
-            except ImportError:
-                logger.debug(
-                    "psutil not installed, skipping memory monitoring")
-
-            if job.is_stop():
-                logger.info(f"收到退出信号 pointer={i}")
-                return False
+        taskTopicKey = f'{self.redis_namespace}/message/{message_id}/topic'
+        if self.redis_clusters.exists(taskTopicKey):
+            # 获取上次的task topic id
+            self.task_topic_id = self.redis_clusters.get(taskTopicKey)
+        else:
+            self.task_topic_id = self._task_discussion(
+                self.task.id,
+                'task',
+                self.task.title,
+                self.task.category,
+                None
+            )
+        times = [self.maxProcessTime]
+        # breakpoint()
+        for i in range(pointer, len(body.payload)):
+            startAt = int(time.time())
 
             # 检测停止标记的工具函数需要实现
             # if Tools.is_stop():
             #     return False
 
-            self.redis_clusters.put(pointer_key, i)
-            message = messages[i]
+            message = body.payload[i]
             task_discussion_content = []
 
             # 推理
-            try:
-                response_llm = self._request_llm(message)
-                task_discussion_content.append('- LLM request successful')
-            except RequestException as e:
-                raise e
+
+            response_llm = self._request_llm(message)
+            task_discussion_content.append('- LLM request successful')
 
             if self.task.category == 'translate':
                 # 写入句子库
                 message.sentence.content = response_llm['content']
-                try:
-                    self._save_sentence(message.sentence)
-                except Exception as e:
-                    logger.error(f'sentence error: {e}')
-                    continue
+                self._save_sentence(message.sentence)
 
             if self.task.category == 'suggest':
                 # 写入pr
-                try:
-                    self._save_pr(message.sentence, response_llm['content'])
-                except Exception as e:
-                    logger.error(f'sentence error: {e}')
-                    continue
+                self._save_pr(message.sentence, response_llm['content'])
 
             # 获取句子id
             s_uid = self._get_sentence_id(message.sentence)
 
             # 写入句子 discussion
-            topic_id = self._task_discussion(
-                s_uid,
-                'sentence',
-                self.task.title,
-                self.task.category,
-                None
-            )
-
-            if topic_id:
-                logger.info(f'{self.queue} discussion create topic successful')
-                topic_children = []
-                # 提示词
-                topic_children.append(message.prompt)
-                # 任务结果
-                topic_children.append(response_llm['content'])
-                # 推理过程写入discussion
-                if response_llm.get('reasoningContent'):
-                    topic_children.append(response_llm['reasoningContent'])
-
-                for content in topic_children:
-                    logger.debug(f'{self.queue} discussion child request')
-                    d_id = self._task_discussion(
-                        s_uid, 'sentence', self.task.title, content, topic_id)
-                    if d_id:
-                        logger.info(
-                            f'{self.queue} discussion child successful')
-            else:
-                logger.error(
-                    f'{self.queue} discussion create topic response is null')
+            topic_children = []
+            # 提示词
+            topic_children.append(message.prompt)
+            # 任务结果
+            topic_children.append(response_llm['content'])
+            # 推理过程写入discussion
+            if response_llm.get('reasoningContent'):
+                topic_children.append(response_llm['reasoningContent'])
+            self._sentence_discussion(s_uid, topic_children)
 
             # 修改task 完成度
             progress = self._set_task_progress(
-                TaskProgress(i + 1, len(messages)))
+                TaskProgress(i + 1, len(body.payload)))
             task_discussion_content.append(f"- progress={progress}")
 
             # 写入task discussion
@@ -258,19 +205,53 @@ class AiTranslateService:
             else:
                 logger.error('no task discussion root')
 
+            if i + 1 < len(body.payload):
+                self.redis_clusters.set(pointer_key, i+1)
+                # 计算本次时间和剩余时间
+                # breakpoint()
+                onceTime = int(time.time())-startAt
+                times.append(onceTime)
+                times.sort(reverse=True)
+                # 取出第一个元素
+                maxTime = times[0]
+                # 计算剩余时间
+                remain = self.customer_timeout-(int(time.time())-taskStartAt)
+                if remain < maxTime:
+                    # 时间不足
+                    raise SectionTimeout
         # 任务完成 修改任务状态为 done
-        if i + 1 == len(messages):
-            self._set_task_status(self.task.id, 'done')
-
-        self.redis_clusters.forget(pointer_key)
+        self._set_task_status(self.task.id, 'done')
+        self.redis_clusters.delete(pointer_key)
         logger.info('ai translate task complete')
         return True
 
+    def _sentence_discussion(self, id, discussions):
+        topic_id = self._task_discussion(
+            id,
+            'sentence',
+            self.task.title,
+            self.task.category,
+            None
+        )
+
+        if topic_id:
+            logger.info(f'{self.queue} discussion create topic successful')
+
+            for content in discussions:
+                logger.debug(f'{self.queue} discussion child request')
+                d_id = self._task_discussion(
+                    id, 'sentence', self.task.title, content, topic_id)
+                if d_id:
+                    logger.info(
+                        f'{self.queue} discussion child successful')
+        else:
+            logger.error(
+                f'{self.queue} discussion create topic response is null')
+
     def _set_task_status(self, task_id: str, status: str):
         """设置任务状态"""
-        url = f"{self.app_url}/api/v2/task-status/{task_id}"
+        url = f"{self.api_url}/v2/task-status/{task_id}"
         data = {'status': status}
-
         logger.debug(f'ai_translate task status request: {url}, data: {data}')
 
         headers = {'Authorization': f'Bearer {self.model_token}'}
@@ -280,16 +261,16 @@ class AiTranslateService:
         if not response.ok:
             logger.error(f'ai_translate task status error: {response.json()}')
         else:
-            logger.info('ai_translate task status done')
+            logger.info(f'ai_translate task status successful ({status})')
 
     def _save_model_log(self, token: str, data: Dict[str, Any]) -> bool:
         """保存模型日志"""
-        url = f"{self.app_url}/api/v2/model-log"
+        url = f"{self.api_url}/v2/model-log"
 
         headers = {'Authorization': f'Bearer {token}'}
         response = requests.post(
             url, json=data, headers=headers, timeout=self.api_timeout)
-
+        # breakpoint()
         if not response.ok:
             logger.error(
                 f'ai-translate model log create failed: {response.json()}')
@@ -298,7 +279,7 @@ class AiTranslateService:
 
     def _task_discussion(self, res_id: str, res_type: str, title: str, content: str, parent_id: Optional[str] = None):
         """创建任务讨论"""
-        url = f"{self.app_url}/api/v2/discussion"
+        url = f"{self.api_url}/v2/discussion"
 
         task_discussion_data = {
             'res_id': res_id,
@@ -377,15 +358,21 @@ class AiTranslateService:
                 logger.info(f'{self.queue} LLM request successful')
 
                 model_log_data.update({
+                    'request_headers': json.dumps(dict(response.request.headers), ensure_ascii=False),
                     'response_headers': json.dumps(dict(response.headers), ensure_ascii=False),
                     'status': response.status_code,
                     'response_data': json.dumps(response.json(), ensure_ascii=False),
                     'success': True
                 })
-                self._save_model_log(self.model_token, model_log_data)
                 break
-
             except requests.exceptions.RequestException as e:
+                model_log_data.update({
+                    'response_headers': json.dumps(dict(e.response.request.headers), ensure_ascii=False),
+                    'response_headers': json.dumps(dict(e.response.headers), ensure_ascii=False),
+                    'status': e.response.status_code,
+                    'response_data': json.dumps(e.response.json(), ensure_ascii=False),
+                    'success': False
+                })
                 attempt += 1
                 status = getattr(e.response, 'status_code',
                                  0) if hasattr(e, 'response') else 0
@@ -403,8 +390,14 @@ class AiTranslateService:
                 else:
                     logger.error("达到最大重试次数,请求最终失败")
                     raise e
-
-        logger.info(f'{self.queue} model log saved')
+            except Exception as e:
+                raise e
+            finally:
+                try:
+                    self._save_model_log(self.model_token, model_log_data)
+                    logger.info(f'{self.queue} model log saved')
+                except Exception as e:
+                    logger.error(e)
 
         ai_data = response.json()
         logger.debug(f'{self.queue} LLM http response: {response.json()}')
@@ -426,7 +419,7 @@ class AiTranslateService:
 
     def _save_sentence(self, sentence: Sentence):
         """写入句子库"""
-        url = f"{self.app_url}/api/v2/sentence"
+        url = f"{self.api_url}/v2/sentence"
 
         logger.info(f"{self.queue} sentence update {url}")
 
@@ -446,7 +439,7 @@ class AiTranslateService:
 
     def _save_pr(self, sentence: Sentence, content: str):
         """保存PR"""
-        url = f"{self.app_url}/api/v2/sentpr"
+        url = f"{self.api_url}/v2/sentpr"
         logger.info(f"{self.queue} sentence update {url}")
 
         data = {
@@ -477,7 +470,7 @@ class AiTranslateService:
 
     def _get_sentence_id(self, sentence: Sentence) -> str:
         """获取句子ID"""
-        url = f"{self.app_url}/api/v2/sentence-info/aa"
+        url = f"{self.api_url}/v2/sentence-info/aa"
         logger.info(f'ai translate: {url}')
 
         params = {
@@ -509,7 +502,7 @@ class AiTranslateService:
             logger.error(
                 f'{self.queue} progress total is zero, task_id: {self.task.id}')
 
-        url = f"{self.app_url}/api/v2/task/{self.task.id}"
+        url = f"{self.api_url}/v2/task/{self.task.id}"
         data = {'progress': progress}
 
         logger.debug(