|
|
@@ -4,10 +4,10 @@ import logging
|
|
|
import requests
|
|
|
from typing import List, Dict, Any, Optional
|
|
|
from datetime import datetime
|
|
|
-import redis
|
|
|
from requests.exceptions import RequestException
|
|
|
from dataclasses import dataclass
|
|
|
-from abc import ABC, abstractmethod
|
|
|
+import pdb
|
|
|
+import time
|
|
|
|
|
|
# 配置日志
|
|
|
logging.basicConfig(level=logging.INFO)
|
|
|
@@ -15,8 +15,15 @@ logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
class DatabaseException(Exception):
|
|
|
- """数据库异常"""
|
|
|
- pass
|
|
|
+ def __init__(self, message="database api access exception"):
|
|
|
+ self.message = message
|
|
|
+ super().__init__(self.message)
|
|
|
+
|
|
|
+
|
|
|
+class SectionTimeout(Exception):
|
|
|
+ def __init__(self, message="片段超时"):
|
|
|
+ self.message = message
|
|
|
+ super().__init__(self.message)
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
@@ -27,7 +34,7 @@ class TaskProgress:
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
-class Task:
|
|
|
+class TaskInfo:
|
|
|
"""任务模型"""
|
|
|
id: str
|
|
|
title: str
|
|
|
@@ -36,6 +43,13 @@ class Task:
|
|
|
status: str
|
|
|
|
|
|
|
|
|
+@dataclass
|
|
|
+class Task:
|
|
|
+ """任务模型"""
|
|
|
+ info: TaskInfo
|
|
|
+ progress: TaskProgress
|
|
|
+
|
|
|
+
|
|
|
@dataclass
|
|
|
class AiModel:
|
|
|
"""AI模型配置"""
|
|
|
@@ -61,7 +75,7 @@ class Sentence:
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
-class Message:
|
|
|
+class Payload:
|
|
|
"""消息模型"""
|
|
|
model: AiModel
|
|
|
task: Task
|
|
|
@@ -69,180 +83,113 @@ class Message:
|
|
|
sentence: Sentence
|
|
|
|
|
|
|
|
|
-class RedisClusters:
|
|
|
- """Redis集群工具类"""
|
|
|
-
|
|
|
- def __init__(self, host='localhost', port=6379, db=0):
|
|
|
- self.redis_client = redis.Redis(host=host, port=port, db=db)
|
|
|
-
|
|
|
- def put(self, key: str, value: Any, ttl: Optional[int] = None):
|
|
|
- """存储数据"""
|
|
|
- if isinstance(value, (dict, list)):
|
|
|
- value = json.dumps(value)
|
|
|
- if ttl:
|
|
|
- self.redis_client.setex(key, ttl, value)
|
|
|
- else:
|
|
|
- self.redis_client.set(key, value)
|
|
|
-
|
|
|
- def get(self, key: str) -> Any:
|
|
|
- """获取数据"""
|
|
|
- value = self.redis_client.get(key)
|
|
|
- if value:
|
|
|
- return value.decode('utf-8')
|
|
|
- return None
|
|
|
-
|
|
|
- def has(self, key: str) -> bool:
|
|
|
- """检查键是否存在"""
|
|
|
- return self.redis_client.exists(key) > 0
|
|
|
-
|
|
|
- def forget(self, key: str):
|
|
|
- """删除键"""
|
|
|
- self.redis_client.delete(key)
|
|
|
-
|
|
|
-
|
|
|
-class JobInterface(ABC):
|
|
|
- """作业接口"""
|
|
|
-
|
|
|
- @abstractmethod
|
|
|
- def is_stop(self) -> bool:
|
|
|
- """检查作业是否停止"""
|
|
|
- pass
|
|
|
+@dataclass
|
|
|
+class Message:
|
|
|
+ """消息模型"""
|
|
|
+ model: AiModel
|
|
|
+ task: Task
|
|
|
+ payload: List[Payload]
|
|
|
+ sentence: Sentence
|
|
|
|
|
|
|
|
|
class AiTranslateService:
|
|
|
"""AI翻译服务"""
|
|
|
|
|
|
- def __init__(self, app_url: str = "http://localhost:8000"):
|
|
|
+ def __init__(self, redis, ch, method, api_url, customer_timeout):
|
|
|
self.queue = 'ai_translate'
|
|
|
self.model_token = None
|
|
|
self.task = None
|
|
|
- self.redis_clusters = RedisClusters()
|
|
|
- self.mq = RabbitMQService()
|
|
|
+ self.redis_clusters = redis[0]
|
|
|
+ self.redis_namespace = redis[1]
|
|
|
self.api_timeout = 100
|
|
|
self.llm_timeout = 300
|
|
|
self.task_topic_id = None
|
|
|
- self.app_url = app_url
|
|
|
+ self.api_url = api_url
|
|
|
+ self.customer_timeout = customer_timeout
|
|
|
+ self.channel = ch
|
|
|
+ self.maxProcessTime = 15 * 60 # 一个句子的最大处理时间
|
|
|
|
|
|
- def process_translate(self, message_id: str, messages: List[Message], job: JobInterface) -> bool:
|
|
|
+ def process_translate(self, message_id: str, body: Message) -> bool:
|
|
|
"""处理翻译任务"""
|
|
|
|
|
|
- if not messages or len(messages) == 0:
|
|
|
- logger.error('message is not array')
|
|
|
- return False
|
|
|
+ taskStartAt = int(time.time())
|
|
|
|
|
|
- first = messages[0]
|
|
|
- self.task = first.task
|
|
|
- task_id = self.task.id
|
|
|
+ self.task = body.task
|
|
|
|
|
|
- self.redis_clusters.put(f"/task/{task_id}/message_id", message_id)
|
|
|
- pointer_key = f"/task/{task_id}/pointer"
|
|
|
+ self.redis_clusters.set(
|
|
|
+ f"{self.redis_namespace}/task/{self.task.id}/message_id", message_id)
|
|
|
+ pointer_key = f"{self.redis_namespace}/task/{message_id}/pointer"
|
|
|
pointer = 0
|
|
|
|
|
|
- if self.redis_clusters.has(pointer_key):
|
|
|
+ if self.redis_clusters.exists(pointer_key):
|
|
|
# 回到上次中断的点
|
|
|
pointer = int(self.redis_clusters.get(pointer_key))
|
|
|
logger.info(f"last break point {pointer}")
|
|
|
+ if pointer >= len(body.payload):
|
|
|
+ self.redis_clusters.delete(pointer_key)
|
|
|
+ return True
|
|
|
|
|
|
# 获取model token
|
|
|
- self.model_token = first.model.token
|
|
|
- logger.debug(f'{self.queue} ai assistant token: {self.model_token}')
|
|
|
+ self.model_token = body.model.token
|
|
|
|
|
|
self._set_task_status(self.task.id, 'running')
|
|
|
|
|
|
# 设置task discussion topic
|
|
|
- self.task_topic_id = self._task_discussion(
|
|
|
- self.task.id,
|
|
|
- 'task',
|
|
|
- self.task.title,
|
|
|
- self.task.category,
|
|
|
- None
|
|
|
- )
|
|
|
-
|
|
|
- for i in range(pointer, len(messages)):
|
|
|
- # 获取当前内存使用量(Python版本的内存监控)
|
|
|
- try:
|
|
|
- import psutil
|
|
|
- process = psutil.Process()
|
|
|
- memory_info = process.memory_info()
|
|
|
- logger.debug(
|
|
|
- f"memory usage: {memory_info.rss / 1024 / 1024:.2f} MB")
|
|
|
- except ImportError:
|
|
|
- logger.debug(
|
|
|
- "psutil not installed, skipping memory monitoring")
|
|
|
-
|
|
|
- if job.is_stop():
|
|
|
- logger.info(f"收到退出信号 pointer={i}")
|
|
|
- return False
|
|
|
+ taskTopicKey = f'{self.redis_namespace}/message/{message_id}/topic'
|
|
|
+ if self.redis_clusters.exists(taskTopicKey):
|
|
|
+ # 获取上次的task topic id
|
|
|
+ self.task_topic_id = self.redis_clusters.get(taskTopicKey)
|
|
|
+ else:
|
|
|
+ self.task_topic_id = self._task_discussion(
|
|
|
+ self.task.id,
|
|
|
+ 'task',
|
|
|
+ self.task.title,
|
|
|
+ self.task.category,
|
|
|
+ None
|
|
|
+ )
|
|
|
+ times = [self.maxProcessTime]
|
|
|
+ # breakpoint()
|
|
|
+ for i in range(pointer, len(body.payload)):
|
|
|
+ startAt = int(time.time())
|
|
|
|
|
|
# 检测停止标记的工具函数需要实现
|
|
|
# if Tools.is_stop():
|
|
|
# return False
|
|
|
|
|
|
- self.redis_clusters.put(pointer_key, i)
|
|
|
- message = messages[i]
|
|
|
+ message = body.payload[i]
|
|
|
task_discussion_content = []
|
|
|
|
|
|
# 推理
|
|
|
- try:
|
|
|
- response_llm = self._request_llm(message)
|
|
|
- task_discussion_content.append('- LLM request successful')
|
|
|
- except RequestException as e:
|
|
|
- raise e
|
|
|
+
|
|
|
+ response_llm = self._request_llm(message)
|
|
|
+ task_discussion_content.append('- LLM request successful')
|
|
|
|
|
|
if self.task.category == 'translate':
|
|
|
# 写入句子库
|
|
|
message.sentence.content = response_llm['content']
|
|
|
- try:
|
|
|
- self._save_sentence(message.sentence)
|
|
|
- except Exception as e:
|
|
|
- logger.error(f'sentence error: {e}')
|
|
|
- continue
|
|
|
+ self._save_sentence(message.sentence)
|
|
|
|
|
|
if self.task.category == 'suggest':
|
|
|
# 写入pr
|
|
|
- try:
|
|
|
- self._save_pr(message.sentence, response_llm['content'])
|
|
|
- except Exception as e:
|
|
|
- logger.error(f'sentence error: {e}')
|
|
|
- continue
|
|
|
+ self._save_pr(message.sentence, response_llm['content'])
|
|
|
|
|
|
# 获取句子id
|
|
|
s_uid = self._get_sentence_id(message.sentence)
|
|
|
|
|
|
# 写入句子 discussion
|
|
|
- topic_id = self._task_discussion(
|
|
|
- s_uid,
|
|
|
- 'sentence',
|
|
|
- self.task.title,
|
|
|
- self.task.category,
|
|
|
- None
|
|
|
- )
|
|
|
-
|
|
|
- if topic_id:
|
|
|
- logger.info(f'{self.queue} discussion create topic successful')
|
|
|
- topic_children = []
|
|
|
- # 提示词
|
|
|
- topic_children.append(message.prompt)
|
|
|
- # 任务结果
|
|
|
- topic_children.append(response_llm['content'])
|
|
|
- # 推理过程写入discussion
|
|
|
- if response_llm.get('reasoningContent'):
|
|
|
- topic_children.append(response_llm['reasoningContent'])
|
|
|
-
|
|
|
- for content in topic_children:
|
|
|
- logger.debug(f'{self.queue} discussion child request')
|
|
|
- d_id = self._task_discussion(
|
|
|
- s_uid, 'sentence', self.task.title, content, topic_id)
|
|
|
- if d_id:
|
|
|
- logger.info(
|
|
|
- f'{self.queue} discussion child successful')
|
|
|
- else:
|
|
|
- logger.error(
|
|
|
- f'{self.queue} discussion create topic response is null')
|
|
|
+ topic_children = []
|
|
|
+ # 提示词
|
|
|
+ topic_children.append(message.prompt)
|
|
|
+ # 任务结果
|
|
|
+ topic_children.append(response_llm['content'])
|
|
|
+ # 推理过程写入discussion
|
|
|
+ if response_llm.get('reasoningContent'):
|
|
|
+ topic_children.append(response_llm['reasoningContent'])
|
|
|
+ self._sentence_discussion(s_uid, topic_children)
|
|
|
|
|
|
# 修改task 完成度
|
|
|
progress = self._set_task_progress(
|
|
|
- TaskProgress(i + 1, len(messages)))
|
|
|
+ TaskProgress(i + 1, len(body.payload)))
|
|
|
task_discussion_content.append(f"- progress={progress}")
|
|
|
|
|
|
# 写入task discussion
|
|
|
@@ -258,19 +205,53 @@ class AiTranslateService:
|
|
|
else:
|
|
|
logger.error('no task discussion root')
|
|
|
|
|
|
+ if i + 1 < len(body.payload):
|
|
|
+ self.redis_clusters.set(pointer_key, i+1)
|
|
|
+ # 计算本次时间和剩余时间
|
|
|
+ # breakpoint()
|
|
|
+ onceTime = int(time.time())-startAt
|
|
|
+ times.append(onceTime)
|
|
|
+ times.sort(reverse=True)
|
|
|
+ # 取出第一个元素
|
|
|
+ maxTime = times[0]
|
|
|
+ # 计算剩余时间
|
|
|
+ remain = self.customer_timeout-(int(time.time())-taskStartAt)
|
|
|
+ if remain < maxTime:
|
|
|
+ # 时间不足
|
|
|
+ raise SectionTimeout
|
|
|
# 任务完成 修改任务状态为 done
|
|
|
- if i + 1 == len(messages):
|
|
|
- self._set_task_status(self.task.id, 'done')
|
|
|
-
|
|
|
- self.redis_clusters.forget(pointer_key)
|
|
|
+ self._set_task_status(self.task.id, 'done')
|
|
|
+ self.redis_clusters.delete(pointer_key)
|
|
|
logger.info('ai translate task complete')
|
|
|
return True
|
|
|
|
|
|
+ def _sentence_discussion(self, id, discussions):
|
|
|
+ topic_id = self._task_discussion(
|
|
|
+ id,
|
|
|
+ 'sentence',
|
|
|
+ self.task.title,
|
|
|
+ self.task.category,
|
|
|
+ None
|
|
|
+ )
|
|
|
+
|
|
|
+ if topic_id:
|
|
|
+ logger.info(f'{self.queue} discussion create topic successful')
|
|
|
+
|
|
|
+ for content in discussions:
|
|
|
+ logger.debug(f'{self.queue} discussion child request')
|
|
|
+ d_id = self._task_discussion(
|
|
|
+ id, 'sentence', self.task.title, content, topic_id)
|
|
|
+ if d_id:
|
|
|
+ logger.info(
|
|
|
+ f'{self.queue} discussion child successful')
|
|
|
+ else:
|
|
|
+ logger.error(
|
|
|
+ f'{self.queue} discussion create topic response is null')
|
|
|
+
|
|
|
def _set_task_status(self, task_id: str, status: str):
|
|
|
"""设置任务状态"""
|
|
|
- url = f"{self.app_url}/api/v2/task-status/{task_id}"
|
|
|
+ url = f"{self.api_url}/v2/task-status/{task_id}"
|
|
|
data = {'status': status}
|
|
|
-
|
|
|
logger.debug(f'ai_translate task status request: {url}, data: {data}')
|
|
|
|
|
|
headers = {'Authorization': f'Bearer {self.model_token}'}
|
|
|
@@ -280,16 +261,16 @@ class AiTranslateService:
|
|
|
if not response.ok:
|
|
|
logger.error(f'ai_translate task status error: {response.json()}')
|
|
|
else:
|
|
|
- logger.info('ai_translate task status done')
|
|
|
+ logger.info(f'ai_translate task status successful ({status})')
|
|
|
|
|
|
def _save_model_log(self, token: str, data: Dict[str, Any]) -> bool:
|
|
|
"""保存模型日志"""
|
|
|
- url = f"{self.app_url}/api/v2/model-log"
|
|
|
+ url = f"{self.api_url}/v2/model-log"
|
|
|
|
|
|
headers = {'Authorization': f'Bearer {token}'}
|
|
|
response = requests.post(
|
|
|
url, json=data, headers=headers, timeout=self.api_timeout)
|
|
|
-
|
|
|
+ # breakpoint()
|
|
|
if not response.ok:
|
|
|
logger.error(
|
|
|
f'ai-translate model log create failed: {response.json()}')
|
|
|
@@ -298,7 +279,7 @@ class AiTranslateService:
|
|
|
|
|
|
def _task_discussion(self, res_id: str, res_type: str, title: str, content: str, parent_id: Optional[str] = None):
|
|
|
"""创建任务讨论"""
|
|
|
- url = f"{self.app_url}/api/v2/discussion"
|
|
|
+ url = f"{self.api_url}/v2/discussion"
|
|
|
|
|
|
task_discussion_data = {
|
|
|
'res_id': res_id,
|
|
|
@@ -377,15 +358,21 @@ class AiTranslateService:
|
|
|
logger.info(f'{self.queue} LLM request successful')
|
|
|
|
|
|
model_log_data.update({
|
|
|
+ 'request_headers': json.dumps(dict(response.request.headers), ensure_ascii=False),
|
|
|
'response_headers': json.dumps(dict(response.headers), ensure_ascii=False),
|
|
|
'status': response.status_code,
|
|
|
'response_data': json.dumps(response.json(), ensure_ascii=False),
|
|
|
'success': True
|
|
|
})
|
|
|
- self._save_model_log(self.model_token, model_log_data)
|
|
|
break
|
|
|
-
|
|
|
except requests.exceptions.RequestException as e:
|
|
|
+ model_log_data.update({
|
|
|
+ 'request_headers': json.dumps(dict(e.response.request.headers) if e.response is not None else {}, ensure_ascii=False),
|
|
|
+ 'response_headers': json.dumps(dict(e.response.headers) if e.response is not None else {}, ensure_ascii=False),
|
|
|
+ 'status': e.response.status_code if e.response is not None else 0,
|
|
|
+ 'response_data': json.dumps(e.response.json(), ensure_ascii=False) if e.response is not None else None,
|
|
|
+ 'success': False
|
|
|
+ })
|
|
|
attempt += 1
|
|
|
status = getattr(e.response, 'status_code',
|
|
|
0) if hasattr(e, 'response') else 0
|
|
|
@@ -403,8 +390,14 @@ class AiTranslateService:
|
|
|
else:
|
|
|
logger.error("达到最大重试次数,请求最终失败")
|
|
|
raise e
|
|
|
-
|
|
|
- logger.info(f'{self.queue} model log saved')
|
|
|
+ except Exception:
|
|
|
+ raise
|
|
|
+ finally:
|
|
|
+ try:
|
|
|
+ self._save_model_log(self.model_token, model_log_data)
|
|
|
+ logger.info(f'{self.queue} model log saved')
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(e)
|
|
|
|
|
|
ai_data = response.json()
|
|
|
logger.debug(f'{self.queue} LLM http response: {response.json()}')
|
|
|
@@ -426,7 +419,7 @@ class AiTranslateService:
|
|
|
|
|
|
def _save_sentence(self, sentence: Sentence):
|
|
|
"""写入句子库"""
|
|
|
- url = f"{self.app_url}/api/v2/sentence"
|
|
|
+ url = f"{self.api_url}/v2/sentence"
|
|
|
|
|
|
logger.info(f"{self.queue} sentence update {url}")
|
|
|
|
|
|
@@ -446,7 +439,7 @@ class AiTranslateService:
|
|
|
|
|
|
def _save_pr(self, sentence: Sentence, content: str):
|
|
|
"""保存PR"""
|
|
|
- url = f"{self.app_url}/api/v2/sentpr"
|
|
|
+ url = f"{self.api_url}/v2/sentpr"
|
|
|
logger.info(f"{self.queue} sentence update {url}")
|
|
|
|
|
|
data = {
|
|
|
@@ -477,7 +470,7 @@ class AiTranslateService:
|
|
|
|
|
|
def _get_sentence_id(self, sentence: Sentence) -> str:
|
|
|
"""获取句子ID"""
|
|
|
- url = f"{self.app_url}/api/v2/sentence-info/aa"
|
|
|
+ url = f"{self.api_url}/v2/sentence-info/aa"
|
|
|
logger.info(f'ai translate: {url}')
|
|
|
|
|
|
params = {
|
|
|
@@ -509,7 +502,7 @@ class AiTranslateService:
|
|
|
logger.error(
|
|
|
f'{self.queue} progress total is zero, task_id: {self.task.id}')
|
|
|
|
|
|
- url = f"{self.app_url}/api/v2/task/{self.task.id}"
|
|
|
+ url = f"{self.api_url}/v2/task/{self.task.id}"
|
|
|
data = {'progress': progress}
|
|
|
|
|
|
logger.debug(
|