Files
MemoryBear/api/app/services/app_chat_service.py
Timebomb2018 531d785629 fix(multimodal): support HTML image tags in document extraction and chat responses
- Replace plain image URLs with `<img src="..." data-url="...">` HTML tags in multimodal and document extractor services
- Propagate citations from workflow end events to client responses
- Update system prompts to instruct LLMs to render images using Markdown `![alt](url)` with strict UUID-preserving URL copying
2026-04-27 17:56:58 +08:00

908 lines
37 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""基于分享链接的聊天服务"""
import asyncio
import json
import time
import uuid
from typing import Optional, Dict, Any, AsyncGenerator, Annotated, List
from fastapi import Depends
from sqlalchemy.orm import Session
from app.core.agent.langchain_agent import LangChainAgent
from app.core.logging_config import get_business_logger
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db
from app.models import MultiAgentConfig, AgentConfig, ModelType
from app.models import WorkflowConfig
from app.repositories.tool_repository import ToolRepository
from app.schemas import DraftRunRequest
from app.schemas.app_schema import FileInput, FileType
from app.schemas.model_schema import ModelInfo
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.conversation_service import ConversationService
from app.services.draft_run_service import AgentRunService
from app.services.memory_agent_service import get_end_user_connected_config
from app.services.model_service import ModelApiKeyService
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
from app.services.multimodal_service import MultimodalService
from app.services.workflow_service import WorkflowService
from app.models.file_metadata_model import FileMetadata
# Module-level business logger used by the chat service methods in this file.
logger = get_business_logger()
class AppChatService:
    """Chat service for applications reached through a share link.

    Exposes three execution back ends -- single agent, multi-agent
    orchestration and workflow -- each as a blocking call and as an async
    generator.  The agent and multi-agent generators yield Server-Sent-Events
    frames (``event: <name>\\ndata: <json>\\n\\n``); the workflow generator
    forwards whatever ``WorkflowService.run_stream`` yields.

    ``agnet_chat`` / ``agnet_chat_stream`` keep their historical (misspelled)
    names because callers outside this file reference them.
    """

    # Appended to the system prompt when document image recognition is on, the
    # model has the "vision" capability and at least one document was uploaded.
    # This is runtime text sent to the model -- do not reword or translate it.
    _DOC_IMAGE_PROMPT = (
        "\n\n文档文字中包含图片位置标记如 [图片 第2页 第1张]: <img src=\"url\"...>"
        "请在回答中用 Markdown 格式 ![图片描述](url) 展示对应图片。"
        "重要:图片 URL 中包含 UUID如 /storage/permanent/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx"
        "必须将 src 属性的值原封不动复制到 Markdown 的括号中,不得增删任何字符。"
    )

    def __init__(self, db: Session):
        """Bind the service and its collaborating services to one DB session."""
        self.db = db
        self.conversation_service = ConversationService(db)
        self.agent_service = AgentRunService(db)
        self.workflow_service = WorkflowService(db)

    # ------------------------------------------------------------------
    # Small pure helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _sse(event: str, data: Any) -> str:
        """Format one Server-Sent-Events frame with a JSON ``data`` field."""
        return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n"

    @staticmethod
    def _parse_sse_data(event: Any) -> Optional[dict]:
        """Return the JSON object carried after the first ``data: `` marker.

        Returns ``None`` when ``event`` is not a string, has no ``data: ``
        marker, does not contain valid JSON, or the JSON is not an object.
        Parsing is best-effort by design: a malformed frame must never abort
        the stream that is being relayed.
        """
        if not isinstance(event, str):
            return None
        _, marker, tail = event.partition("data: ")
        if not marker:
            return None
        try:
            payload = json.loads(tail.strip())
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return None
        return payload if isinstance(payload, dict) else None

    @staticmethod
    def _zero_usage() -> Dict[str, int]:
        """Return a fresh all-zero token usage dict (never a shared instance)."""
        return {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}

    @staticmethod
    def _resolve_features(config: AgentConfig, web_search: bool) -> tuple[dict, bool]:
        """Return ``(features_config, web_search)`` for this request.

        ``config.features`` is converted with ``model_dump()`` when it offers
        that method.  ``web_search`` is forced to ``False`` unless the
        ``web_search`` feature is a dict with a truthy ``enabled`` entry.
        """
        features_config: dict = config.features or {}
        if hasattr(features_config, 'model_dump'):
            features_config = features_config.model_dump()
        web_search_feature = features_config.get("web_search", {})
        if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")):
            web_search = False
        return features_config, web_search

    # ------------------------------------------------------------------
    # Shared preparation steps for the single-agent chat methods
    # ------------------------------------------------------------------
    async def _load_history(self, conversation_id: uuid.UUID, api_key_obj: Any,
                            features_config: dict, variables: Any) -> Any:
        """Load conversation history, seeding the opening statement if needed.

        When the history is empty and an opening statement is configured, the
        opening is stored as the first assistant message and the history is
        loaded again so that it contains that message.
        """

        async def fetch():
            return await self.conversation_service.get_conversation_history(
                conversation_id=conversation_id,
                max_history=10,
                current_provider=api_key_obj.provider,
                current_is_omni=api_key_obj.is_omni
            )

        history = await fetch()
        if not history:
            opening, suggested_questions = self.agent_service._get_opening_statement(
                features_config, True, variables
            )
            if opening:
                self.conversation_service.add_message(
                    conversation_id=conversation_id,
                    role="assistant",
                    content=opening,
                    meta_data={"suggested_questions": suggested_questions}
                )
                # Reload so the history includes the opening just written.
                history = await fetch()
        return history

    async def _process_files(self, files: Optional[List[FileInput]], features_config: dict,
                             api_key_obj: Any, model_info: ModelInfo,
                             workspace_id: Optional[str], system_prompt: str) -> tuple:
        """Run uploaded files through ``MultimodalService``.

        Returns ``(processed_files, system_prompt)``.  ``processed_files`` is
        ``None`` when no files were supplied.  The image-rendering instructions
        are appended to the prompt only when document image recognition is
        enabled, the model lists the "vision" capability and a document file
        is present.
        """
        if not files:
            return None, system_prompt

        multimodal_service = MultimodalService(self.db, model_info)
        fu_config = features_config.get("file_upload", {})
        if hasattr(fu_config, "model_dump"):
            fu_config = fu_config.model_dump()
        doc_img_recognition = isinstance(fu_config, dict) and fu_config.get("document_image_recognition", False)
        processed_files = await multimodal_service.process_files(
            files, document_image_recognition=doc_img_recognition,
            workspace_id=workspace_id
        )
        logger.info(f"处理了 {len(processed_files)} 个文件")
        if doc_img_recognition and "vision" in (api_key_obj.capability or []) and any(
                f.type == FileType.DOCUMENT for f in files
        ):
            system_prompt += self._DOC_IMAGE_PROMPT
        return processed_files, system_prompt

    async def _prepare_agent_run(
            self,
            *,
            message: str,
            conversation_id: uuid.UUID,
            config: AgentConfig,
            files: Optional[List[FileInput]],
            user_id: Optional[str],
            variables: Optional[Dict[str, Any]],
            web_search: bool,
            memory: bool,
            storage_type: Optional[str],
            user_rag_memory_id: Optional[str],
            workspace_id: Optional[str],
            features_config: dict,
            streaming: bool,
    ) -> tuple:
        """Build everything a single-agent run needs, in the original order.

        Steps: prepare variables, pick an API key, render the system prompt,
        load tools (configured tools, skills, knowledge retrieval, memory),
        load history, process files, create the agent, inject runtime context
        into tools that accept it.

        Returns:
            ``(agent, api_key_obj, history, processed_files,
            citations_collector, memory_flag, tenant_id)``.
        """
        variables = self.agent_service.prepare_variables(variables, config.variables)
        api_key_obj = ModelApiKeyService.get_available_api_key(self.db, config.default_model_config_id)

        # System prompt with variable substitution.
        system_prompt = config.system_prompt
        if variables:
            rendered = render_prompt_message(system_prompt, PromptMessageRole.USER, variables)
            system_prompt = rendered.get_text_content() or system_prompt

        # Tools: configured tools, skills, knowledge retrieval, long-term memory.
        tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id))
        tools = []
        tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id))
        skill_tools, skill_prompts = self.agent_service.load_skill_config(config.skills, message, tenant_id)
        tools.extend(skill_tools)
        if skill_prompts:
            system_prompt = f"{system_prompt}\n\n{skill_prompts}"
        kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(
            config.knowledge_retrieval, user_id
        )
        tools.extend(kb_tools)
        memory_flag = False
        if memory:
            memory_tools, memory_flag = self.agent_service.load_memory_config(
                config.memory, user_id, storage_type, user_rag_memory_id
            )
            tools.extend(memory_tools)

        # A config without model parameters falls back to the defaults below.
        model_parameters = config.model_parameters or {}
        model_info = ModelInfo(
            model_name=api_key_obj.model_name,
            provider=api_key_obj.provider,
            api_key=api_key_obj.api_key,
            api_base=api_key_obj.api_base,
            capability=api_key_obj.capability,
            is_omni=api_key_obj.is_omni,
            model_type=ModelType.LLM
        )

        history = await self._load_history(conversation_id, api_key_obj, features_config, variables)
        processed_files, system_prompt = await self._process_files(
            files, features_config, api_key_obj, model_info, workspace_id, system_prompt
        )

        agent_kwargs = dict(
            model_name=api_key_obj.model_name,
            api_key=api_key_obj.api_key,
            provider=api_key_obj.provider,
            api_base=api_key_obj.api_base,
            is_omni=api_key_obj.is_omni,
            temperature=model_parameters.get("temperature", 0.7),
            max_tokens=model_parameters.get("max_tokens", 2000),
            system_prompt=system_prompt,
            tools=tools,
            deep_thinking=model_parameters.get("deep_thinking", False),
            thinking_budget_tokens=model_parameters.get("thinking_budget_tokens"),
            json_output=model_parameters.get("json_output", False),
            capability=api_key_obj.capability or [],
        )
        if streaming:
            # Only passed for streaming runs so the blocking path keeps
            # relying on LangChainAgent's own default.
            agent_kwargs["streaming"] = True
        agent = LangChainAgent(**agent_kwargs)

        # Give tools that support it access to the per-request context.
        for t in tools:
            if hasattr(t, 'tool_instance') and hasattr(t.tool_instance, 'set_runtime_context'):
                t.tool_instance.set_runtime_context(
                    user_id=user_id or "anonymous",
                    conversation_id=str(conversation_id) if conversation_id else None,
                    uploaded_files=processed_files or []
                )

        return agent, api_key_obj, history, processed_files, citations_collector, memory_flag, tenant_id

    def _build_human_meta(self, files: Optional[List[FileInput]], processed_files: Any,
                          api_key_obj: Any) -> Dict[str, Any]:
        """Build the ``meta_data`` stored with the user message.

        For local files that lack a name or size, the missing values are taken
        from ``FileMetadata`` rows whose status is ``"completed"``.
        """
        human_meta: Dict[str, Any] = {"files": [], "history_files": {}}
        if not files:
            return human_meta

        local_ids = [f.upload_file_id for f in files
                     if f.transfer_method.value == "local_file" and f.upload_file_id
                     and (not f.name or not f.size)]
        meta_map = {}
        if local_ids:
            rows = self.db.query(FileMetadata).filter(
                FileMetadata.id.in_(local_ids),
                FileMetadata.status == "completed"
            ).all()
            meta_map = {str(r.id): r for r in rows}

        for f in files:
            name, size = f.name, f.size
            if f.transfer_method.value == "local_file" and f.upload_file_id and (not name or not size):
                meta = meta_map.get(str(f.upload_file_id))
                if meta:
                    name = name or meta.file_name
                    size = size or meta.file_size
            human_meta["files"].append({
                "type": f.type,
                "url": f.url,
                "name": name,
                "size": size,
                "file_type": f.file_type,
            })

        if processed_files:
            human_meta["history_files"] = {
                "content": processed_files,
                "provider": api_key_obj.provider,
                "is_omni": api_key_obj.is_omni
            }
        return human_meta

    async def _write_long_term_memory(self, user_id: Optional[str], message: str, answer: str,
                                      files: Optional[List[FileInput]],
                                      storage_type: Optional[str],
                                      user_rag_memory_id: Optional[str]) -> None:
        """Send the user/assistant exchange to ``write_long_term``.

        Nothing is written when the end user's connected config carries no
        ``memory_config_id``.  ``files`` may be ``None``.
        """
        connected_config = get_end_user_connected_config(user_id, self.db)
        memory_config_id: str = connected_config.get("memory_config_id")
        if not memory_config_id:
            return

        file_list = []
        for file in files or []:
            file_dict = file.model_dump()
            # Stringify the id (presumably a UUID) so the payload stays serialisable.
            file_dict["upload_file_id"] = str(file_dict["upload_file_id"]) if file_dict["upload_file_id"] else None
            file_list.append(file_dict)

        messages = [
            {"role": "user", "content": message, "files": file_list},
            {"role": "assistant", "content": answer}
        ]
        await write_long_term(
            storage_type,
            user_id,
            messages,
            user_rag_memory_id,
            memory_config_id
        )

    # ------------------------------------------------------------------
    # Single agent
    # ------------------------------------------------------------------
    async def agnet_chat(
            self,
            message: str,
            conversation_id: uuid.UUID,
            config: AgentConfig,
            files: list[FileInput],
            user_id: str,
            variables: Optional[Dict[str, Any]] = None,
            web_search: bool = False,
            memory: bool = True,
            storage_type: Optional[str] = None,
            user_rag_memory_id: Optional[str] = None,
            workspace_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Run one blocking agent chat turn and persist both messages.

        Args:
            message: The user's text.
            conversation_id: Conversation the turn belongs to.
            config: Agent configuration (prompt, tools, features, model).
            files: Uploaded files for this turn; may be empty.
            user_id: End-user identifier.
            variables: Values substituted into the system prompt.
            web_search: Requested web search; ignored unless the feature is enabled.
            memory: Whether to load memory tools.
            storage_type: Forwarded to the memory loader and ``write_long_term``.
            user_rag_memory_id: Forwarded to the memory loader and ``write_long_term``.
            workspace_id: Used to resolve the tenant and for file/TTS processing.

        Returns:
            A dict with ``conversation_id``, ``message_id``, ``message``,
            ``reasoning_content``, ``usage``, ``elapsed_time``,
            ``suggested_questions``, ``citations``, ``audio_url`` and
            ``audio_status``.
        """
        start_time = time.time()

        features_config, web_search = self._resolve_features(config, web_search)
        self.agent_service._validate_file_upload(features_config, files)

        (agent, api_key_obj, history, processed_files,
         citations_collector, memory_flag, tenant_id) = await self._prepare_agent_run(
            message=message,
            conversation_id=conversation_id,
            config=config,
            files=files,
            user_id=user_id,
            variables=variables,
            web_search=web_search,
            memory=memory,
            storage_type=storage_type,
            user_rag_memory_id=user_rag_memory_id,
            workspace_id=workspace_id,
            features_config=features_config,
            streaming=False,
        )

        result = await agent.chat(
            message=message,
            history=history,
            context=None,
            files=processed_files
        )
        ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
        elapsed_time = time.time() - start_time
        answer = result["content"]
        usage = result.get("usage", self._zero_usage())

        suggested_questions = []
        sq_config = features_config.get("suggested_questions_after_answer", {})
        if isinstance(sq_config, dict) and sq_config.get("enabled"):
            suggested_questions = await self.agent_service._generate_suggested_questions(
                features_config, answer,
                {"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key,
                 "api_base": api_key_obj.api_base}, {}
            )

        audio_url = await self.agent_service._generate_tts(
            features_config, answer,
            {"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key,
             "api_base": api_key_obj.api_base, "provider": api_key_obj.provider},
            tenant_id=tenant_id, workspace_id=workspace_id
        )

        # Filter once; the same list is stored and returned.
        filtered_citations = self.agent_service._filter_citations(features_config, citations_collector)

        human_meta = self._build_human_meta(files, processed_files, api_key_obj)
        assistant_meta = {
            "model": api_key_obj.model_name,
            "usage": usage,
            "audio_url": audio_url if audio_url else None,
            "citations": filtered_citations,
            "reasoning_content": result.get("reasoning_content")
        }

        if memory_flag:
            await self._write_long_term_memory(
                user_id, message, answer, files, storage_type, user_rag_memory_id
            )

        self.conversation_service.add_message(
            conversation_id=conversation_id,
            role="user",
            content=message,
            meta_data=human_meta
        )
        ai_message = self.conversation_service.add_message(
            conversation_id=conversation_id,
            role="assistant",
            content=answer,
            meta_data=assistant_meta
        )

        return {
            "conversation_id": conversation_id,
            "message_id": str(ai_message.id),
            "message": answer,
            "reasoning_content": result.get("reasoning_content"),
            "usage": usage,
            "elapsed_time": elapsed_time,
            "suggested_questions": suggested_questions,
            "citations": filtered_citations,
            "audio_url": audio_url,
            "audio_status": "pending" if audio_url else None
        }

    async def agnet_chat_stream(
            self,
            message: str,
            conversation_id: uuid.UUID,
            config: AgentConfig,
            files: list[FileInput],
            user_id: Optional[str] = None,
            variables: Optional[Dict[str, Any]] = None,
            web_search: bool = False,
            memory: bool = True,
            storage_type: Optional[str] = None,
            user_rag_memory_id: Optional[str] = None,
            workspace_id: Optional[str] = None
    ) -> AsyncGenerator[str, None]:
        """Run one streaming agent chat turn, yielding SSE frames.

        Emits ``start`` (conversation and message ids), then ``reasoning`` and
        ``message`` frames as chunks arrive, then one ``end`` frame carrying
        timing, optional suggested questions, audio and citation data.  Both
        messages are persisted before ``end`` is sent.  Any failure other than
        generator close / task cancellation is reported as an ``end`` frame
        that contains only an ``error`` field.

        Parameters have the same meaning as in ``agnet_chat``.
        """
        try:
            start_time = time.time()
            message_id = uuid.uuid4()

            features_config, web_search = self._resolve_features(config, web_search)
            self.agent_service._validate_file_upload(features_config, files)

            yield self._sse("start", {'conversation_id': str(conversation_id), 'message_id': str(message_id)})

            (agent, api_key_obj, history, processed_files,
             citations_collector, memory_flag, tenant_id) = await self._prepare_agent_run(
                message=message,
                conversation_id=conversation_id,
                config=config,
                files=files,
                user_id=user_id,
                variables=variables,
                web_search=web_search,
                memory=memory,
                storage_type=storage_type,
                user_rag_memory_id=user_rag_memory_id,
                workspace_id=workspace_id,
                features_config=features_config,
                streaming=True,
            )

            # Text chunks are mirrored into this queue so TTS can run in
            # parallel with the model stream.
            full_content = ""
            full_reasoning = ""
            total_tokens = 0
            text_queue: asyncio.Queue = asyncio.Queue()
            api_key_config = {
                "model_name": api_key_obj.model_name,
                "api_key": api_key_obj.api_key,
                "api_base": api_key_obj.api_base,
                "provider": api_key_obj.provider,
            }
            stream_audio_url, tts_task = await self.agent_service._generate_tts_streaming(
                features_config, api_key_config,
                text_queue=text_queue,
                tenant_id=tenant_id, workspace_id=workspace_id
            )

            try:
                async for chunk in agent.chat_stream(
                        message=message,
                        history=history,
                        context=None,
                        files=processed_files
                ):
                    if isinstance(chunk, int):
                        # An int chunk carries the token total, not text.
                        total_tokens = chunk
                    elif isinstance(chunk, dict) and chunk.get("type") == "reasoning":
                        full_reasoning += chunk['content']
                        yield self._sse("reasoning", {'content': chunk['content']})
                    else:
                        full_content += chunk
                        yield self._sse("message", {'content': chunk})
                        if tts_task is not None:
                            await text_queue.put(chunk)
            finally:
                # Always send the end-of-text sentinel, including when the
                # stream fails or the client disconnects, so the TTS consumer
                # is not left waiting on the queue.  put_nowait cannot block
                # on this unbounded queue.
                # NOTE(review): assumes the consumer treats None as
                # end-of-text, as the normal-completion path already does.
                if tts_task is not None:
                    text_queue.put_nowait(None)

            elapsed_time = time.time() - start_time
            ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)

            # Payload of the closing frame.
            end_data: dict = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
            sq_config = features_config.get("suggested_questions_after_answer", {})
            if isinstance(sq_config, dict) and sq_config.get("enabled"):
                end_data["suggested_questions"] = await self.agent_service._generate_suggested_questions(
                    features_config, full_content,
                    {"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key,
                     "api_base": api_key_obj.api_base}, {}
                )
            end_data["audio_url"] = stream_audio_url

            # Non-blocking TTS status check; the task is never cancelled here.
            audio_status = "pending"
            if tts_task is not None and tts_task.done():
                try:
                    tts_task.result()
                    audio_status = "completed"
                except Exception as e:
                    logger.warning(f"TTS任务异常: {e}")
                    audio_status = "failed"
            end_data["audio_status"] = audio_status if stream_audio_url else None

            # Filter once; the same list is stored and sent.
            filtered_citations = self.agent_service._filter_citations(features_config, citations_collector)
            end_data["citations"] = filtered_citations

            human_meta = self._build_human_meta(files, processed_files, api_key_obj)
            assistant_meta = {
                "model": api_key_obj.model_name,
                "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens},
                "audio_url": stream_audio_url if stream_audio_url else None,
                "citations": filtered_citations,
                "reasoning_content": full_reasoning or None
            }

            if memory_flag:
                await self._write_long_term_memory(
                    user_id, message, full_content, files, storage_type, user_rag_memory_id
                )

            self.conversation_service.add_message(
                conversation_id=conversation_id,
                role="user",
                content=message,
                meta_data=human_meta
            )
            # Reuse the id announced in the "start" frame.
            self.conversation_service.add_message(
                message_id=message_id,
                conversation_id=conversation_id,
                role="assistant",
                content=full_content,
                meta_data=assistant_meta
            )

            yield self._sse("end", end_data)

            logger.info(
                "流式聊天完成",
                extra={
                    "conversation_id": str(conversation_id),
                    "elapsed_time": elapsed_time,
                    "message_length": len(full_content)
                }
            )
        except (GeneratorExit, asyncio.CancelledError):
            # Generator closed or task cancelled: propagate, this is a normal exit.
            logger.debug("流式聊天被中断")
            raise
        except Exception as e:
            logger.error(f"流式聊天失败: {str(e)}", exc_info=True)
            yield self._sse("end", {'error': str(e)})

    # ------------------------------------------------------------------
    # Multi agent
    # ------------------------------------------------------------------
    async def multi_agent_chat(
            self,
            message: str,
            conversation_id: uuid.UUID,
            config: MultiAgentConfig,
            user_id: Optional[str] = None,
            variables: Optional[Dict[str, Any]] = None,
            web_search: bool = False,
            memory: bool = True,
            storage_type: Optional[str] = None,
            user_rag_memory_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run one blocking multi-agent turn and persist both messages.

        ``storage_type`` and ``user_rag_memory_id`` are accepted for signature
        parity with the other chat methods but are not forwarded to
        ``orchestrator.execute``.

        Returns:
            A dict with ``conversation_id``, ``message``, ``message_id``,
            ``usage`` and ``elapsed_time``.  ``usage`` is the orchestrator's
            reported usage (all zeros when it reports none), i.e. the same
            value that is stored with the assistant message.
        """
        start_time = time.time()
        if variables is None:
            variables = {}

        orchestrator = MultiAgentOrchestrator(self.db, config)
        result = await orchestrator.execute(
            message=message,
            conversation_id=conversation_id,
            user_id=user_id,
            variables=variables,
            use_llm_routing=True,  # LLM routing is always on for this entry point
            web_search=web_search,
            memory=memory
        )
        elapsed_time = time.time() - start_time

        answer = result.get("message", "")
        usage = result.get("usage", self._zero_usage())

        self.conversation_service.add_message(
            conversation_id=conversation_id,
            role="user",
            content=message
        )
        ai_message = self.conversation_service.add_message(
            conversation_id=conversation_id,
            role="assistant",
            content=answer,
            meta_data={
                "mode": result.get("mode"),
                "elapsed_time": result.get("elapsed_time"),
                "usage": usage
            }
        )

        return {
            "conversation_id": conversation_id,
            "message": answer,
            "message_id": str(ai_message.id),
            "usage": usage,
            "elapsed_time": elapsed_time
        }

    async def multi_agent_chat_stream(
            self,
            message: str,
            conversation_id: uuid.UUID,
            config: MultiAgentConfig,
            user_id: Optional[str] = None,
            variables: Optional[Dict[str, Any]] = None,
            web_search: bool = False,
            memory: bool = True,
            storage_type: Optional[str] = None,
            user_rag_memory_id: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """Run one streaming multi-agent turn, relaying orchestrator frames.

        Emits a ``start`` frame, then forwards every orchestrator frame except
        ``sub_usage`` frames, whose ``total_tokens`` are summed instead.  The
        ``content`` of forwarded frames is accumulated and stored as the
        assistant message once the orchestrator stream ends.  Failures other
        than generator close / task cancellation are reported as an ``error``
        frame.
        """
        start_time = time.time()
        if variables is None:
            variables = {}

        try:
            message_id = uuid.uuid4()
            yield self._sse("start", {'conversation_id': str(conversation_id), 'message_id': str(message_id)})

            full_content = ""
            total_tokens = 0
            orchestrator = MultiAgentOrchestrator(self.db, config)

            async for event in orchestrator.execute_stream(
                    message=message,
                    conversation_id=conversation_id,
                    user_id=user_id,
                    variables=variables,
                    use_llm_routing=True,
                    web_search=web_search,
                    memory=memory,
                    storage_type=storage_type,
                    user_rag_memory_id=user_rag_memory_id
            ):
                if "event: sub_usage" in event:
                    # Internal accounting frame: count tokens, do not forward.
                    payload = self._parse_sse_data(event)
                    tokens = payload.get("total_tokens", 0) if payload else 0
                    if isinstance(tokens, (int, float)):
                        total_tokens += tokens
                    continue

                yield event

                # Best-effort capture of the text so it can be persisted.
                payload = self._parse_sse_data(event)
                content = payload.get("content") if payload else None
                if isinstance(content, str):
                    full_content += content

            elapsed_time = time.time() - start_time

            self.conversation_service.add_message(
                conversation_id=conversation_id,
                role="user",
                content=message
            )
            # Reuse the id announced in the "start" frame.
            self.conversation_service.add_message(
                message_id=message_id,
                conversation_id=conversation_id,
                role="assistant",
                content=full_content,
                meta_data={
                    "elapsed_time": elapsed_time,
                    "usage": {
                        "prompt_tokens": 0,
                        "completion_tokens": 0,
                        "total_tokens": total_tokens
                    }
                }
            )

            logger.info(
                "多 Agent 流式聊天完成",
                extra={
                    "conversation_id": str(conversation_id),
                    "elapsed_time": elapsed_time,
                    "message_length": len(full_content)
                }
            )
        except (GeneratorExit, asyncio.CancelledError):
            # Generator closed or task cancelled: propagate, this is a normal exit.
            logger.debug("多 Agent 流式聊天被中断")
            raise
        except Exception as e:
            logger.error(f"多 Agent 流式聊天失败: {str(e)}", exc_info=True)
            yield self._sse("error", {'error': str(e)})

    # ------------------------------------------------------------------
    # Workflow
    # ------------------------------------------------------------------
    async def workflow_chat(
            self,
            message: str,
            conversation_id: uuid.UUID,
            config: WorkflowConfig,
            app_id: uuid.UUID,
            release_id: uuid.UUID,
            workspace_id: uuid.UUID,
            files: Optional[List[FileInput]] = None,
            user_id: Optional[str] = None,
            variables: Optional[Dict[str, Any]] = None,
            web_search: bool = False,
            memory: bool = True,
            storage_type: Optional[str] = None,
            user_rag_memory_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run a workflow chat turn via ``WorkflowService.run`` (blocking).

        ``web_search``, ``memory``, ``storage_type`` and ``user_rag_memory_id``
        are accepted for signature parity but are not used here.
        """
        # NOTE(review): stream=True on the blocking path looks unintended but
        # is kept as-is; confirm how WorkflowService.run treats this flag.
        payload = DraftRunRequest(
            message=message,
            variables=variables,
            conversation_id=str(conversation_id),
            stream=True,
            user_id=user_id,
            files=files
        )
        return await self.workflow_service.run(
            app_id=app_id,
            payload=payload,
            config=config,
            workspace_id=workspace_id,
            release_id=release_id,
        )

    async def workflow_chat_stream(
            self,
            message: str,
            conversation_id: uuid.UUID,
            config: WorkflowConfig,
            app_id: uuid.UUID,
            release_id: uuid.UUID,
            workspace_id: uuid.UUID,
            user_id: Optional[str] = None,
            variables: Optional[Dict[str, Any]] = None,
            files: Optional[List[FileInput]] = None,
            web_search: bool = False,
            memory: bool = True,
            storage_type: Optional[str] = None,
            user_rag_memory_id: Optional[str] = None,
            public=False
    ) -> AsyncGenerator[dict, None]:
        """Run a workflow chat turn, forwarding ``run_stream`` events as-is.

        ``web_search``, ``memory``, ``storage_type`` and ``user_rag_memory_id``
        are accepted for signature parity but are not used here.  ``public``
        is passed straight through to ``WorkflowService.run_stream``.
        """
        payload = DraftRunRequest(
            message=message,
            variables=variables,
            conversation_id=str(conversation_id),
            stream=True,
            user_id=user_id,
            files=files
        )
        async for event in self.workflow_service.run_stream(
                app_id=app_id,
                payload=payload,
                config=config,
                workspace_id=workspace_id,
                release_id=release_id,
                public=public
        ):
            yield event
# ==================== Dependency injection ====================
def get_app_chat_service(
        db: Annotated[Session, Depends(get_db)]
) -> AppChatService:
    """Return an AppChatService bound to the injected database session."""
    service = AppChatService(db)
    return service