Merge branch 'develop' into feature/multimodel_memory
# Conflicts: # api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/embedding_generation.py # api/app/repositories/neo4j/add_nodes.py # api/app/repositories/neo4j/cypher_queries.py # api/app/repositories/neo4j/graph_saver.py # api/app/services/memory_agent_service.py # api/app/services/multimodal_service.py
This commit is contained in:
@@ -118,28 +118,27 @@ class AppChatService:
|
||||
|
||||
)
|
||||
|
||||
# 加载历史消息
|
||||
messages = self.conversation_service.get_messages(
|
||||
conversation_id=conversation_id,
|
||||
limit=10
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_obj.model_name,
|
||||
provider=api_key_obj.provider,
|
||||
api_key=api_key_obj.api_key,
|
||||
api_base=api_key_obj.api_base,
|
||||
capability=api_key_obj.capability,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
# 加载历史消息
|
||||
history = await self.conversation_service.get_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
max_history=10,
|
||||
current_provider=api_key_obj.provider,
|
||||
current_is_omni=api_key_obj.is_omni
|
||||
)
|
||||
history = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_obj.model_name,
|
||||
provider=api_key_obj.provider,
|
||||
api_key=api_key_obj.api_key,
|
||||
api_base=api_key_obj.api_base,
|
||||
capability=api_key_obj.capability,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
@@ -180,7 +179,8 @@ class AppChatService:
|
||||
|
||||
# 构建用户消息内容(含多模态文件)
|
||||
human_meta = {
|
||||
"files": []
|
||||
"files": [],
|
||||
"history_files": {}
|
||||
}
|
||||
assistant_meta = {
|
||||
"model": api_key_obj.model_name,
|
||||
@@ -195,6 +195,13 @@ class AppChatService:
|
||||
"url": f.url
|
||||
})
|
||||
|
||||
if processed_files:
|
||||
human_meta["history_files"] = {
|
||||
"content": processed_files,
|
||||
"provider": api_key_obj.provider,
|
||||
"is_omni": api_key_obj.is_omni
|
||||
}
|
||||
|
||||
# 保存消息
|
||||
if audio_url:
|
||||
assistant_meta["audio_url"] = audio_url
|
||||
@@ -225,6 +232,7 @@ class AppChatService:
|
||||
"suggested_questions": suggested_questions,
|
||||
"citations": self.agent_service._filter_citations(features_config, result.get("citations", [])),
|
||||
"audio_url": audio_url,
|
||||
"audio_status": "pending"
|
||||
}
|
||||
|
||||
async def agnet_chat_stream(
|
||||
@@ -313,31 +321,27 @@ class AppChatService:
|
||||
streaming=True
|
||||
)
|
||||
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_obj.model_name,
|
||||
provider=api_key_obj.provider,
|
||||
api_key=api_key_obj.api_key,
|
||||
api_base=api_key_obj.api_base,
|
||||
capability=api_key_obj.capability,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
# 加载历史消息
|
||||
history = []
|
||||
memory_config = {"enabled": True, 'max_history': 10}
|
||||
if memory_config.get("enabled"):
|
||||
messages = self.conversation_service.get_messages(
|
||||
conversation_id=conversation_id,
|
||||
limit=memory_config.get("max_history", 10)
|
||||
)
|
||||
history = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in messages
|
||||
]
|
||||
history = await self.conversation_service.get_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
max_history=10,
|
||||
current_provider=api_key_obj.provider,
|
||||
current_is_omni=api_key_obj.is_omni
|
||||
)
|
||||
|
||||
# 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_obj.model_name,
|
||||
provider=api_key_obj.provider,
|
||||
api_key=api_key_obj.api_key,
|
||||
api_base=api_key_obj.api_base,
|
||||
capability=api_key_obj.capability,
|
||||
is_omni=api_key_obj.is_omni,
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||
@@ -347,8 +351,14 @@ class AppChatService:
|
||||
total_tokens = 0
|
||||
|
||||
text_queue: asyncio.Queue = asyncio.Queue()
|
||||
api_key_config = {
|
||||
"model_name": api_key_obj.model_name,
|
||||
"api_key": api_key_obj.api_key,
|
||||
"api_base": api_key_obj.api_base,
|
||||
"provider": api_key_obj.provider,
|
||||
}
|
||||
stream_audio_url, tts_task = await self.agent_service._generate_tts_streaming(
|
||||
features_config, api_key_obj,
|
||||
features_config, api_key_config,
|
||||
text_queue=text_queue,
|
||||
tenant_id=tenant_id, workspace_id=workspace_id
|
||||
)
|
||||
@@ -378,7 +388,7 @@ class AppChatService:
|
||||
elapsed_time = time.time() - start_time
|
||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
||||
|
||||
# 发送结束事件(包含 suggested_questions、tts、citations)
|
||||
# 发送结束事件(包含 suggested_questions、tts、audio_status、citations)
|
||||
end_data: dict = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
|
||||
sq_config = features_config.get("suggested_questions_after_answer", {})
|
||||
if isinstance(sq_config, dict) and sq_config.get("enabled"):
|
||||
@@ -388,11 +398,23 @@ class AppChatService:
|
||||
"api_base": api_key_obj.api_base}, {}
|
||||
)
|
||||
end_data["audio_url"] = stream_audio_url
|
||||
# 检查TTS是否已完成(非阻塞,不取消任务)
|
||||
audio_status = "pending"
|
||||
if tts_task is not None and tts_task.done():
|
||||
# 任务已完成,检查是否有异常
|
||||
try:
|
||||
tts_task.result()
|
||||
audio_status = "completed"
|
||||
except Exception as e:
|
||||
logger.warning(f"TTS任务异常: {e}")
|
||||
audio_status = "failed"
|
||||
end_data["audio_status"] = audio_status if stream_audio_url else None
|
||||
end_data["citations"] = self.agent_service._filter_citations(features_config, [])
|
||||
|
||||
# 保存消息
|
||||
human_meta = {
|
||||
"files":[]
|
||||
"files":[],
|
||||
"history_files": {}
|
||||
}
|
||||
assistant_meta = {
|
||||
"model": api_key_obj.model_name,
|
||||
@@ -402,11 +424,16 @@ class AppChatService:
|
||||
|
||||
if files:
|
||||
for f in files:
|
||||
# url = await MultimodalService(self.db).get_file_url(f)
|
||||
human_meta["files"].append({
|
||||
"type": f.type,
|
||||
"url": f.url
|
||||
})
|
||||
if processed_files:
|
||||
human_meta["history_files"] = {
|
||||
"content": processed_files,
|
||||
"provider": api_key_obj.provider,
|
||||
"is_omni": api_key_obj.is_omni
|
||||
}
|
||||
|
||||
if stream_audio_url:
|
||||
assistant_meta["audio_url"] = stream_audio_url
|
||||
|
||||
@@ -16,6 +16,7 @@ from app.models.app_release_model import AppRelease
|
||||
from app.models.knowledge_model import Knowledge
|
||||
from app.models.models_model import ModelConfig
|
||||
from app.models.tool_model import ToolConfig as ToolConfigModel
|
||||
from app.models.skill_model import Skill
|
||||
from app.models.workflow_model import WorkflowConfig
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter
|
||||
@@ -84,7 +85,9 @@ class AppDslService:
|
||||
if "knowledge_retrieval" in cfg:
|
||||
enriched["knowledge_retrieval"] = self._enrich_knowledge_retrieval(cfg["knowledge_retrieval"])
|
||||
if "tools" in cfg:
|
||||
enriched["tools"] = self._enrich_tools(cfg["tools"])
|
||||
enriched["tools"] = self._enrich_tools(cfg.get("tools"))
|
||||
if "skills" in cfg:
|
||||
enriched["skills"] = self._enrich_skills(cfg.get("skills"))
|
||||
return enriched
|
||||
if app_type == AppType.MULTI_AGENT:
|
||||
enriched = {**cfg}
|
||||
@@ -108,6 +111,7 @@ class AppDslService:
|
||||
"variables": config.variables if config else [],
|
||||
"edges": config.edges if config else [],
|
||||
"nodes": config.nodes if config else [],
|
||||
"features": config.features if config else {},
|
||||
"execution_config": config.execution_config if config else {},
|
||||
"triggers": config.triggers if config else [],
|
||||
} if config else {}
|
||||
@@ -123,7 +127,8 @@ class AppDslService:
|
||||
"memory": config.memory if config else None,
|
||||
"variables": config.variables if config else [],
|
||||
"tools": self._enrich_tools(config.tools) if config else [],
|
||||
"skills": config.skills if config else {},
|
||||
"skills": self._enrich_skills(config.skills) if config else {},
|
||||
"features": config.features if config else {}
|
||||
} if config else {}
|
||||
dsl = {**meta, "app": app_meta, "agent_config": config_data}
|
||||
|
||||
@@ -185,6 +190,22 @@ class AppDslService:
|
||||
def _enrich_tools(self, tools: list) -> list:
|
||||
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
|
||||
|
||||
def _skill_ref(self, skill_id) -> Optional[dict]:
|
||||
if not skill_id:
|
||||
return None
|
||||
s = self.db.query(Skill).filter(Skill.id == skill_id).first()
|
||||
return {"id": str(skill_id), "name": s.name} if s else {"id": str(skill_id)}
|
||||
|
||||
def _enrich_skills(self, skills: Optional[dict]) -> Optional[dict]:
|
||||
if not skills:
|
||||
return skills
|
||||
skill_ids = skills.get("skill_ids", [])
|
||||
enriched_ids = [
|
||||
{"id": sid, "_ref": self._skill_ref(sid)}
|
||||
for sid in (skill_ids or [])
|
||||
]
|
||||
return {**skills, "skill_ids": enriched_ids}
|
||||
|
||||
def _agent_ref(self, agent_id) -> Optional[dict]:
|
||||
if not agent_id:
|
||||
return None
|
||||
@@ -249,7 +270,8 @@ class AppDslService:
|
||||
memory=self._resolve_memory(cfg.get("memory"), workspace_id, warnings),
|
||||
variables=cfg.get("variables", []),
|
||||
tools=self._resolve_tools(cfg.get("tools", []), tenant_id, warnings),
|
||||
skills=cfg.get("skills", {}),
|
||||
skills=self._resolve_skills(cfg.get("skills", {}), tenant_id, warnings),
|
||||
features=cfg.get("features", {}),
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
@@ -290,6 +312,7 @@ class AppDslService:
|
||||
edges=[e.model_dump() for e in result.edges],
|
||||
variables=[v.model_dump() for v in result.variables],
|
||||
execution_config=wf.get("execution_config", {}),
|
||||
features=wf.get("features", {}),
|
||||
triggers=wf.get("triggers", []),
|
||||
validate=False,
|
||||
)
|
||||
@@ -444,6 +467,46 @@ class AppDslService:
|
||||
return {**memory, "memory_config_id": None, "enabled": False}
|
||||
return memory
|
||||
|
||||
def _resolve_skills(self, skills: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> dict:
|
||||
if not skills:
|
||||
return skills or {}
|
||||
resolved_ids = []
|
||||
for entry in (skills.get("skill_ids") or []):
|
||||
# entry 可能是 {"id": "...", "_ref": {...}} 或直接是字符串
|
||||
if isinstance(entry, dict):
|
||||
ref = entry.get("_ref") or ({"name": None, "id": entry.get("id")} if entry.get("id") else None)
|
||||
skill_id = self._resolve_skill(ref, tenant_id, warnings)
|
||||
else:
|
||||
skill_id = self._resolve_skill({"id": str(entry)}, tenant_id, warnings)
|
||||
if skill_id:
|
||||
resolved_ids.append(str(skill_id))
|
||||
return {**{k: v for k, v in skills.items() if k != "skill_ids"}, "skill_ids": resolved_ids}
|
||||
|
||||
def _resolve_skill(self, ref: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> Optional[str]:
|
||||
if not ref:
|
||||
return None
|
||||
# 先按 id 匹配
|
||||
if ref.get("id"):
|
||||
try:
|
||||
s = self.db.query(Skill).filter(
|
||||
Skill.id == uuid.UUID(str(ref["id"])),
|
||||
Skill.tenant_id == tenant_id
|
||||
).first()
|
||||
if s:
|
||||
return str(s.id)
|
||||
except Exception:
|
||||
pass
|
||||
# 再按名称匹配
|
||||
if ref.get("name"):
|
||||
s = self.db.query(Skill).filter(
|
||||
Skill.name == ref["name"],
|
||||
Skill.tenant_id == tenant_id
|
||||
).first()
|
||||
if s:
|
||||
return str(s.id)
|
||||
warnings.append(f"未找到技能: {ref}")
|
||||
return None
|
||||
|
||||
def _resolve_tools(self, tools: list, tenant_id: uuid.UUID, warnings: list) -> list:
|
||||
result = []
|
||||
for t in (tools or []):
|
||||
|
||||
@@ -833,8 +833,6 @@ class AppService:
|
||||
|
||||
# 跨工作空间时,获取目标工作空间的 tenant_id 用于判断模型配置是否可用
|
||||
target_tenant_id = None
|
||||
available_model_ids: set = set()
|
||||
available_kb_ids: set = set()
|
||||
if is_cross_workspace:
|
||||
target_ws = self.db.get(Workspace, target_workspace_id)
|
||||
if not target_ws:
|
||||
@@ -849,28 +847,29 @@ class AppService:
|
||||
|
||||
if source_config:
|
||||
if is_cross_workspace:
|
||||
# Batch-collect and preload all referenced resources
|
||||
model_ids, kb_ids = self._collect_resource_ids_from_config(
|
||||
source_config.default_model_config_id,
|
||||
source_config.knowledge_retrieval,
|
||||
source_config.tools
|
||||
# 跨工作空间:model/tools/skills 属于 tenant 级别直接保留,
|
||||
# knowledge_bases 属于 workspace 级别需过滤,memory_config 需清空
|
||||
_, kb_ids = self._collect_resource_ids_from_config(
|
||||
None, source_config.knowledge_retrieval
|
||||
)
|
||||
available_model_ids, available_kb_ids = self._preload_cross_workspace_resources(
|
||||
target_tenant_id, target_workspace_id, model_ids, kb_ids
|
||||
)
|
||||
new_model_config_id = self._is_model_available(
|
||||
source_config.default_model_config_id, available_model_ids
|
||||
_, available_kb_ids = self._preload_cross_workspace_resources(
|
||||
target_tenant_id, target_workspace_id, set(), kb_ids
|
||||
)
|
||||
new_model_config_id = source_config.default_model_config_id
|
||||
new_knowledge_retrieval = self._clean_knowledge_retrieval(
|
||||
source_config.knowledge_retrieval, available_kb_ids
|
||||
)
|
||||
new_tools = self._clean_tools(
|
||||
source_config.tools, available_kb_ids
|
||||
new_tools = copy.deepcopy(source_config.tools) if source_config.tools else []
|
||||
new_memory = self._clean_memory_cross_workspace(
|
||||
source_config.memory, target_workspace_id
|
||||
)
|
||||
new_skills = copy.deepcopy(source_config.skills) if source_config.skills else {}
|
||||
else:
|
||||
new_model_config_id = source_config.default_model_config_id
|
||||
new_knowledge_retrieval = copy.deepcopy(source_config.knowledge_retrieval) if source_config.knowledge_retrieval else None
|
||||
new_tools = copy.deepcopy(source_config.tools) if source_config.tools else []
|
||||
new_memory = copy.deepcopy(source_config.memory) if source_config.memory else None
|
||||
new_skills = copy.deepcopy(source_config.skills) if source_config.skills else {}
|
||||
|
||||
new_config = AgentConfig(
|
||||
id=uuid.uuid4(),
|
||||
@@ -879,9 +878,11 @@ class AppService:
|
||||
default_model_config_id=new_model_config_id,
|
||||
model_parameters=copy.deepcopy(source_config.model_parameters) if source_config.model_parameters else None,
|
||||
knowledge_retrieval=new_knowledge_retrieval,
|
||||
memory=copy.deepcopy(source_config.memory) if source_config.memory else None,
|
||||
memory=new_memory,
|
||||
variables=copy.deepcopy(source_config.variables) if source_config.variables else [],
|
||||
tools=new_tools,
|
||||
skills=new_skills,
|
||||
features=copy.deepcopy(source_config.features) if source_config.features else {},
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
@@ -894,28 +895,14 @@ class AppService:
|
||||
).first()
|
||||
|
||||
if source_config:
|
||||
if is_cross_workspace:
|
||||
model_ids, kb_ids = self._collect_resource_ids_from_workflow_nodes(
|
||||
source_config.nodes
|
||||
)
|
||||
available_model_ids, available_kb_ids = self._preload_cross_workspace_resources(
|
||||
target_tenant_id, target_workspace_id, model_ids, kb_ids
|
||||
)
|
||||
new_nodes = self._clean_workflow_nodes_for_cross_workspace(
|
||||
source_config.nodes or [],
|
||||
available_model_ids,
|
||||
available_kb_ids
|
||||
)
|
||||
else:
|
||||
new_nodes = copy.deepcopy(source_config.nodes) if source_config.nodes else []
|
||||
|
||||
new_config = WorkflowConfig(
|
||||
id=uuid.uuid4(),
|
||||
app_id=new_app.id,
|
||||
nodes=new_nodes,
|
||||
nodes=copy.deepcopy(source_config.nodes) if source_config.nodes else [],
|
||||
edges=copy.deepcopy(source_config.edges) if source_config.edges else [],
|
||||
variables=copy.deepcopy(source_config.variables) if source_config.variables else [],
|
||||
execution_config=copy.deepcopy(source_config.execution_config) if source_config.execution_config else {},
|
||||
features=copy.deepcopy(source_config.features) if source_config.features else {},
|
||||
triggers=copy.deepcopy(source_config.triggers) if source_config.triggers else [],
|
||||
is_active=True,
|
||||
created_at=now,
|
||||
@@ -929,24 +916,15 @@ class AppService:
|
||||
).first()
|
||||
|
||||
if source_config:
|
||||
if is_cross_workspace:
|
||||
model_ids = {source_config.default_model_config_id} if source_config.default_model_config_id else set()
|
||||
available_model_ids, _ = self._preload_cross_workspace_resources(
|
||||
target_tenant_id, target_workspace_id, model_ids, set()
|
||||
)
|
||||
new_model_config_id = self._is_model_available(
|
||||
source_config.default_model_config_id, available_model_ids
|
||||
)
|
||||
else:
|
||||
new_model_config_id = source_config.default_model_config_id
|
||||
|
||||
# multi_agent 的 model_config_id/sub_agents/routing_rules 均属于 tenant 级别直接保留
|
||||
# 跨空间时 master_agent_id(AppRelease)属于源空间,需清空
|
||||
new_config = MultiAgentConfig(
|
||||
id=uuid.uuid4(),
|
||||
app_id=new_app.id,
|
||||
master_agent_id=source_config.master_agent_id if not is_cross_workspace else None,
|
||||
master_agent_name=source_config.master_agent_name,
|
||||
default_model_config_id=new_model_config_id,
|
||||
model_parameters=source_config.model_parameters,
|
||||
default_model_config_id=source_config.default_model_config_id,
|
||||
model_parameters=copy.deepcopy(source_config.model_parameters) if source_config.model_parameters else None,
|
||||
orchestration_mode=source_config.orchestration_mode,
|
||||
sub_agents=copy.deepcopy(source_config.sub_agents) if source_config.sub_agents else [],
|
||||
routing_rules=copy.deepcopy(source_config.routing_rules) if source_config.routing_rules else None,
|
||||
@@ -1037,8 +1015,7 @@ class AppService:
|
||||
@staticmethod
|
||||
def _collect_resource_ids_from_config(
|
||||
model_config_id: Optional[uuid.UUID],
|
||||
knowledge_retrieval: Optional[dict],
|
||||
tools: Optional[list]
|
||||
knowledge_retrieval: Optional[dict]
|
||||
) -> tuple:
|
||||
"""Extract all model config IDs and knowledge base IDs from an app config."""
|
||||
model_ids: set = set()
|
||||
@@ -1048,62 +1025,12 @@ class AppService:
|
||||
model_ids.add(model_config_id)
|
||||
|
||||
if knowledge_retrieval and isinstance(knowledge_retrieval, dict):
|
||||
if "kb_ids" in knowledge_retrieval:
|
||||
for kid in knowledge_retrieval.get("kb_ids", []):
|
||||
if kid:
|
||||
kb_ids.add(str(kid))
|
||||
if knowledge_retrieval.get("knowledge_id"):
|
||||
kb_ids.add(str(knowledge_retrieval["knowledge_id"]))
|
||||
|
||||
if tools:
|
||||
for tool in tools:
|
||||
if isinstance(tool, dict):
|
||||
kid = tool.get("knowledge_id") or tool.get("kb_id")
|
||||
if kid:
|
||||
kb_ids.add(str(kid))
|
||||
if "knowledge_bases" in knowledge_retrieval:
|
||||
for kid in knowledge_retrieval.get("knowledge_bases", []):
|
||||
kb_ids.add(str(kid.get("kb_id")))
|
||||
|
||||
return model_ids, kb_ids
|
||||
|
||||
@staticmethod
|
||||
def _collect_resource_ids_from_workflow_nodes(nodes: list) -> tuple:
|
||||
"""Extract all model config IDs and knowledge base IDs from workflow nodes."""
|
||||
model_ids: set = set()
|
||||
kb_ids: set = set()
|
||||
|
||||
for node in (nodes or []):
|
||||
if not isinstance(node, dict):
|
||||
continue
|
||||
data = node.get("data", {})
|
||||
if not isinstance(data, dict):
|
||||
continue
|
||||
for key in ("model_config_id", "default_model_config_id"):
|
||||
val = data.get(key)
|
||||
if val:
|
||||
try:
|
||||
model_ids.add(uuid.UUID(str(val)))
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
kr = data.get("knowledge_retrieval")
|
||||
if isinstance(kr, dict):
|
||||
for kid in kr.get("kb_ids", []):
|
||||
if kid:
|
||||
kb_ids.add(str(kid))
|
||||
if kr.get("knowledge_id"):
|
||||
kb_ids.add(str(kr["knowledge_id"]))
|
||||
if data.get("knowledge_id"):
|
||||
kb_ids.add(str(data["knowledge_id"]))
|
||||
for kid in data.get("kb_ids", []):
|
||||
if kid:
|
||||
kb_ids.add(str(kid))
|
||||
|
||||
return model_ids, kb_ids
|
||||
|
||||
@staticmethod
|
||||
def _is_model_available(model_config_id: Optional[uuid.UUID], available_model_ids: set) -> Optional[uuid.UUID]:
|
||||
if not model_config_id:
|
||||
return None
|
||||
return model_config_id if model_config_id in available_model_ids else None
|
||||
|
||||
@staticmethod
|
||||
def _is_kb_available(kb_id: Optional[str], available_kb_ids: set) -> Optional[str]:
|
||||
if not kb_id:
|
||||
@@ -1124,95 +1051,53 @@ class AppService:
|
||||
|
||||
cleaned = copy.deepcopy(knowledge_retrieval)
|
||||
|
||||
if "kb_ids" in cleaned and isinstance(cleaned["kb_ids"], list):
|
||||
cleaned["kb_ids"] = [
|
||||
kid for kid in cleaned["kb_ids"]
|
||||
if self._is_kb_available(kid, available_kb_ids)
|
||||
if "knowledge_bases" in cleaned and isinstance(cleaned["knowledge_bases"], list):
|
||||
cleaned["knowledge_bases"] = [
|
||||
kb for kb in cleaned["knowledge_bases"]
|
||||
if self._is_kb_available(kb.get("kb_id"), available_kb_ids)
|
||||
]
|
||||
|
||||
if "knowledge_id" in cleaned:
|
||||
cleaned["knowledge_id"] = self._is_kb_available(
|
||||
cleaned.get("knowledge_id"), available_kb_ids
|
||||
)
|
||||
|
||||
return cleaned
|
||||
|
||||
def _clean_tools(
|
||||
def _clean_memory_cross_workspace(
|
||||
self,
|
||||
tools: Optional[list],
|
||||
available_kb_ids: set
|
||||
) -> list:
|
||||
"""Clean tools config, keeping built-in tools and tools with available KBs."""
|
||||
if not tools:
|
||||
return []
|
||||
memory: Optional[dict],
|
||||
target_workspace_id: uuid.UUID
|
||||
) -> Optional[dict]:
|
||||
"""Clear memory_config_id/memory_content if it doesn't belong to target workspace."""
|
||||
if not memory:
|
||||
return None
|
||||
|
||||
cleaned = []
|
||||
for tool in tools:
|
||||
if not isinstance(tool, dict):
|
||||
cleaned.append(tool)
|
||||
continue
|
||||
from app.models.memory_config_model import MemoryConfig
|
||||
|
||||
tool_type = tool.get("type", "")
|
||||
if tool_type in ("builtin", "built_in", "system"):
|
||||
cleaned.append(copy.deepcopy(tool))
|
||||
continue
|
||||
cleaned = copy.deepcopy(memory)
|
||||
# 兼容旧字段 memory_content 和新字段 memory_config_id
|
||||
mid = cleaned.get("memory_config_id") or cleaned.get("memory_content")
|
||||
if mid:
|
||||
try:
|
||||
mid_uuid = uuid.UUID(str(mid))
|
||||
except (ValueError, AttributeError):
|
||||
exists = self.db.query(MemoryConfig).filter(
|
||||
MemoryConfig.config_id_old == int(mid),
|
||||
MemoryConfig.workspace_id == target_workspace_id
|
||||
).first()
|
||||
if not exists:
|
||||
cleaned["memory_config_id"] = None
|
||||
cleaned.pop("memory_content", None)
|
||||
cleaned["enabled"] = False
|
||||
return cleaned
|
||||
|
||||
kb_id = tool.get("knowledge_id") or tool.get("kb_id")
|
||||
if kb_id:
|
||||
if self._is_kb_available(kb_id, available_kb_ids):
|
||||
cleaned.append(copy.deepcopy(tool))
|
||||
continue
|
||||
exists = self.db.query(
|
||||
self.db.query(MemoryConfig).filter(
|
||||
MemoryConfig.config_id == mid_uuid,
|
||||
MemoryConfig.workspace_id == target_workspace_id
|
||||
).exists()
|
||||
).scalar()
|
||||
if not exists:
|
||||
cleaned["memory_config_id"] = None
|
||||
cleaned.pop("memory_content", None)
|
||||
cleaned["enabled"] = False
|
||||
|
||||
cleaned.append(copy.deepcopy(tool))
|
||||
|
||||
return cleaned
|
||||
|
||||
def _clean_workflow_nodes_for_cross_workspace(
|
||||
self,
|
||||
nodes: list,
|
||||
available_model_ids: set,
|
||||
available_kb_ids: set
|
||||
) -> list:
|
||||
"""Clean workflow nodes, using pre-loaded resource sets. Uses deepcopy to avoid mutating source."""
|
||||
if not nodes:
|
||||
return []
|
||||
|
||||
cleaned = []
|
||||
for node in nodes:
|
||||
if not isinstance(node, dict):
|
||||
cleaned.append(node)
|
||||
continue
|
||||
|
||||
node_copy = copy.deepcopy(node)
|
||||
data = node_copy.get("data")
|
||||
if not isinstance(data, dict):
|
||||
cleaned.append(node_copy)
|
||||
continue
|
||||
|
||||
for key in ("model_config_id", "default_model_config_id"):
|
||||
if key in data and data[key]:
|
||||
try:
|
||||
mid = uuid.UUID(str(data[key]))
|
||||
except (ValueError, AttributeError):
|
||||
data[key] = None
|
||||
continue
|
||||
data[key] = str(mid) if mid in available_model_ids else None
|
||||
|
||||
if "knowledge_retrieval" in data and data["knowledge_retrieval"]:
|
||||
data["knowledge_retrieval"] = self._clean_knowledge_retrieval(
|
||||
data["knowledge_retrieval"], available_kb_ids
|
||||
)
|
||||
if "knowledge_id" in data:
|
||||
data["knowledge_id"] = self._is_kb_available(
|
||||
data.get("knowledge_id"), available_kb_ids
|
||||
)
|
||||
if "kb_ids" in data and isinstance(data["kb_ids"], list):
|
||||
data["kb_ids"] = [
|
||||
kid for kid in data["kb_ids"]
|
||||
if self._is_kb_available(kid, available_kb_ids)
|
||||
]
|
||||
|
||||
cleaned.append(node_copy)
|
||||
return cleaned
|
||||
|
||||
def list_apps(
|
||||
|
||||
@@ -21,6 +21,7 @@ from app.models.conversation_model import ConversationDetail
|
||||
from app.models.prompt_optimizer_model import RoleType
|
||||
from app.repositories.conversation_repository import ConversationRepository, MessageRepository
|
||||
from app.schemas.conversation_schema import ConversationOut
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services import workspace_service
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
@@ -119,25 +120,27 @@ class ConversationService:
|
||||
|
||||
def get_user_conversations(
|
||||
self,
|
||||
user_id: uuid.UUID
|
||||
) -> list[Conversation]:
|
||||
user_id: uuid.UUID,
|
||||
page: int = 1,
|
||||
page_size: int = 20
|
||||
) -> tuple[list[Conversation], int]:
|
||||
"""
|
||||
Retrieve recent conversations for a specific user
|
||||
|
||||
This method delegates persistence logic to the repository layer and
|
||||
applies service-level defaults (e.g. recent conversation limit).
|
||||
Retrieve recent conversations for a specific user with pagination.
|
||||
|
||||
Args:
|
||||
user_id (uuid.UUID): Unique identifier of the user.
|
||||
page (int): Page number (1-based). Defaults to 1.
|
||||
page_size (int): Number of items per page. Defaults to 20.
|
||||
|
||||
Returns:
|
||||
list[Conversation]: A list of recent conversation entities.
|
||||
tuple[list[Conversation], int]: A list of recent conversation entities and total count.
|
||||
"""
|
||||
conversations = self.conversation_repo.get_conversation_by_user_id(
|
||||
conversations, total = self.conversation_repo.get_conversation_by_user_id(
|
||||
user_id,
|
||||
limit=10
|
||||
page=page,
|
||||
page_size=page_size
|
||||
)
|
||||
return conversations
|
||||
return conversations, total
|
||||
|
||||
def list_conversations(
|
||||
self,
|
||||
@@ -267,10 +270,12 @@ class ConversationService:
|
||||
|
||||
return messages
|
||||
|
||||
def get_conversation_history(
|
||||
async def get_conversation_history(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
max_history: Optional[int] = None
|
||||
max_history: Optional[int] = None,
|
||||
current_provider: Optional[str] = None,
|
||||
current_is_omni: Optional[bool] = None
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Retrieve historical conversation messages formatted as dictionaries.
|
||||
@@ -278,6 +283,8 @@ class ConversationService:
|
||||
Args:
|
||||
conversation_id (uuid.UUID): Conversation UUID.
|
||||
max_history (Optional[int]): Maximum number of messages to retrieve.
|
||||
current_provider (Optional[str]): Current provider for file handling.
|
||||
current_is_omni (Optional[bool]): Current omni flag for file handling.
|
||||
|
||||
Returns:
|
||||
List[dict]: List of message dictionaries with keys 'role' and 'content'.
|
||||
@@ -287,14 +294,30 @@ class ConversationService:
|
||||
limit=max_history
|
||||
)
|
||||
|
||||
# 转换为字典格式
|
||||
history = [
|
||||
{
|
||||
history = []
|
||||
for msg in messages:
|
||||
msg_dict = {
|
||||
"role": msg.role,
|
||||
"content": msg.content
|
||||
"content": [{"type": "text", "text": msg.content}]
|
||||
}
|
||||
for msg in messages
|
||||
]
|
||||
|
||||
# 处理用户消息中的多模态文件
|
||||
if msg.role == "user" and msg.meta_data:
|
||||
history_files = msg.meta_data.get("history_files", {})
|
||||
|
||||
if history_files and current_provider and current_is_omni is not None:
|
||||
# 检查是否需要重新处理文件
|
||||
stored_provider = history_files.get("provider")
|
||||
stored_is_omni = history_files.get("is_omni")
|
||||
|
||||
# 如果provider或is_omni不匹配,需要重新处理
|
||||
if stored_provider != current_provider or stored_is_omni != current_is_omni:
|
||||
continue
|
||||
|
||||
# provider和is_omni匹配,直接使用存储的内容
|
||||
msg_dict["content"].extend(history_files.get("content"))
|
||||
|
||||
history.append(msg_dict)
|
||||
|
||||
return history
|
||||
|
||||
@@ -510,6 +533,7 @@ class ConversationService:
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
api_base = api_config.api_base
|
||||
is_omni = api_config.is_omni
|
||||
model_type = config.type
|
||||
|
||||
llm = RedBearLLM(
|
||||
@@ -517,14 +541,17 @@ class ConversationService:
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base
|
||||
base_url=api_base,
|
||||
is_omni=is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
conversation_messages = self.get_conversation_history(
|
||||
conversation_messages = await self.get_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
max_history=20
|
||||
max_history=20,
|
||||
current_provider=provider,
|
||||
current_is_omni=is_omni
|
||||
)
|
||||
if len(conversation_messages) == 0:
|
||||
return ConversationOut(
|
||||
|
||||
@@ -579,25 +579,28 @@ class AgentRunService:
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_config["model_name"],
|
||||
provider=api_key_config["provider"],
|
||||
api_key=api_key_config["api_key"],
|
||||
api_base=api_key_config["api_base"],
|
||||
capability=api_key_config["capability"],
|
||||
is_omni=api_key_config["is_omni"],
|
||||
model_type=model_config.type
|
||||
)
|
||||
|
||||
# 6. 加载历史消息
|
||||
history = await self._load_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
max_history=10
|
||||
max_history=10,
|
||||
current_provider=api_key_config.get("provider"),
|
||||
current_is_omni=api_key_config.get("is_omni", False)
|
||||
)
|
||||
|
||||
# 6. 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
# 获取 provider 信息
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_config["model_name"],
|
||||
provider=api_key_config["provider"],
|
||||
api_key=api_key_config["api_key"],
|
||||
api_base=api_key_config["api_base"],
|
||||
capability=api_key_config["capability"],
|
||||
is_omni=api_key_config["is_omni"],
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
@@ -659,7 +662,10 @@ class AgentRunService:
|
||||
})
|
||||
},
|
||||
files=files,
|
||||
audio_url=audio_url
|
||||
processed_files=processed_files,
|
||||
audio_url=audio_url,
|
||||
provider=api_key_config.get("provider"),
|
||||
is_omni=api_key_config.get("is_omni", False)
|
||||
)
|
||||
|
||||
response = {
|
||||
@@ -676,6 +682,7 @@ class AgentRunService:
|
||||
) if not sub_agent else [],
|
||||
"citations": self._filter_citations(features_config, result.get("citations", [])),
|
||||
"audio_url": audio_url,
|
||||
"audio_status": "pending"
|
||||
}
|
||||
|
||||
logger.info(
|
||||
@@ -815,25 +822,28 @@ class AgentRunService:
|
||||
sub_agent=sub_agent
|
||||
)
|
||||
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_config["model_name"],
|
||||
provider=api_key_config["provider"],
|
||||
api_key=api_key_config["api_key"],
|
||||
api_base=api_key_config["api_base"],
|
||||
capability=api_key_config["capability"],
|
||||
is_omni=api_key_config["is_omni"],
|
||||
model_type=model_config.type
|
||||
)
|
||||
|
||||
# 6. 加载历史消息
|
||||
history = await self._load_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
max_history=memory_config.get("max_history", 10)
|
||||
max_history=memory_config.get("max_history", 10),
|
||||
current_provider=api_key_config.get("provider"),
|
||||
current_is_omni=api_key_config.get("is_omni", False)
|
||||
)
|
||||
|
||||
# 6. 处理多模态文件
|
||||
processed_files = None
|
||||
if files:
|
||||
# 获取 provider 信息
|
||||
model_info = ModelInfo(
|
||||
model_name=api_key_config["model_name"],
|
||||
provider=api_key_config["provider"],
|
||||
api_key=api_key_config["api_key"],
|
||||
api_base=api_key_config["api_base"],
|
||||
capability=api_key_config["capability"],
|
||||
is_omni=api_key_config["is_omni"],
|
||||
model_type=ModelType.LLM
|
||||
)
|
||||
provider = api_key_config.get("provider", "openai")
|
||||
multimodal_service = MultimodalService(self.db, model_info)
|
||||
processed_files = await multimodal_service.process_files(files)
|
||||
@@ -905,10 +915,13 @@ class AgentRunService:
|
||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
|
||||
},
|
||||
files=files,
|
||||
audio_url=stream_audio_url
|
||||
processed_files=processed_files,
|
||||
audio_url=stream_audio_url,
|
||||
provider=api_key_config.get("provider"),
|
||||
is_omni=api_key_config.get("is_omni", False)
|
||||
)
|
||||
|
||||
# 12. 发送结束事件(包含 suggested_questions 和 tts)
|
||||
# 12. 发送结束事件(包含 suggested_questions、audio_url 和 audio_status)
|
||||
end_data: Dict[str, Any] = {
|
||||
"conversation_id": conversation_id,
|
||||
"elapsed_time": elapsed_time,
|
||||
@@ -919,6 +932,17 @@ class AgentRunService:
|
||||
features_config, full_content, api_key_config, effective_params
|
||||
)
|
||||
end_data["audio_url"] = stream_audio_url
|
||||
# 检查TTS是否已完成(非阻塞,不取消任务)
|
||||
audio_status = "pending"
|
||||
if tts_task is not None and tts_task.done():
|
||||
# 任务已完成,检查是否有异常
|
||||
try:
|
||||
tts_task.result()
|
||||
audio_status = "completed"
|
||||
except Exception as e:
|
||||
logger.warning(f"TTS任务异常: {e}")
|
||||
audio_status = "failed"
|
||||
end_data["audio_status"] = audio_status if stream_audio_url else None
|
||||
end_data["citations"] = self._filter_citations(features_config, [])
|
||||
yield self._format_sse_event("end", end_data)
|
||||
|
||||
@@ -1115,13 +1139,17 @@ class AgentRunService:
|
||||
async def _load_conversation_history(
|
||||
self,
|
||||
conversation_id: str,
|
||||
max_history: int = 10
|
||||
max_history: int = 10,
|
||||
current_provider: Optional[str] = None,
|
||||
current_is_omni: Optional[bool] = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""加载会话历史消息
|
||||
"""加载会话历史消息,并根据当前模型配置处理多模态文件
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID
|
||||
max_history: 最大历史消息数量
|
||||
current_provider: 当前模型的provider
|
||||
current_is_omni: 当前模型的is_omni
|
||||
|
||||
Returns:
|
||||
List[Dict]: 历史消息列表
|
||||
@@ -1129,9 +1157,12 @@ class AgentRunService:
|
||||
try:
|
||||
|
||||
conversation_service = ConversationService(self.db)
|
||||
history = conversation_service.get_conversation_history(
|
||||
# 获取 API 配置用于多模态处理
|
||||
history = await conversation_service.get_conversation_history(
|
||||
conversation_id=uuid.UUID(conversation_id),
|
||||
max_history=max_history
|
||||
max_history=max_history,
|
||||
current_provider=current_provider,
|
||||
current_is_omni=current_is_omni
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
@@ -1159,7 +1190,10 @@ class AgentRunService:
|
||||
app_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None,
|
||||
files: Optional[List[FileInput]] = None,
|
||||
audio_url: Optional[str] = None
|
||||
processed_files: Optional[List[Dict[str, Any]]] = None,
|
||||
audio_url: Optional[str] = None,
|
||||
provider: Optional[str] = None,
|
||||
is_omni: Optional[bool] = None
|
||||
) -> None:
|
||||
"""保存会话消息(会话已通过 _ensure_conversation 确保存在)
|
||||
|
||||
@@ -1170,6 +1204,11 @@ class AgentRunService:
|
||||
app_id: 应用ID(未使用,保留用于兼容性)
|
||||
user_id: 用户ID(未使用,保留用于兼容性)
|
||||
meta_data: token消耗
|
||||
files: 原始文件输入
|
||||
processed_files: 处理后的文件
|
||||
audio_url: 音频URL
|
||||
provider: 模型供应商
|
||||
is_omni: 是否为全模态模型
|
||||
"""
|
||||
try:
|
||||
from app.services.conversation_service import ConversationService
|
||||
@@ -1179,15 +1218,24 @@ class AgentRunService:
|
||||
|
||||
# 保存消息(会话已经存在)
|
||||
human_meta = {
|
||||
"files": []
|
||||
"files": [],
|
||||
"history_files": {}
|
||||
}
|
||||
if files:
|
||||
for f in files:
|
||||
# url = await MultimodalService(self.db).get_file_url(f)
|
||||
human_meta["files"].append({
|
||||
"type": f.type,
|
||||
"url": f.url
|
||||
})
|
||||
|
||||
# 保存 history_files,包含 provider 和 is_omni 信息
|
||||
if processed_files:
|
||||
human_meta["history_files"] = {
|
||||
"content": processed_files,
|
||||
"provider": provider,
|
||||
"is_omni": is_omni
|
||||
}
|
||||
|
||||
# 保存用户消息
|
||||
conversation_service.add_message(
|
||||
conversation_id=conv_uuid,
|
||||
@@ -1413,8 +1461,9 @@ class AgentRunService:
|
||||
workspace_id: Optional[uuid.UUID] = None,
|
||||
) -> tuple[Optional[str], Optional[asyncio.Task]]:
|
||||
"""文本流式输入并行合成音频。
|
||||
返回 (audio_url, task),audio_url 立即可用,task 完成后文件内容就绪。
|
||||
返回 (audio_url, task),audio_url 立即可用(pending状态),task 完成后文件内容就绪。
|
||||
调用方向 text_queue put 文本 chunk,结束时 put None。
|
||||
前端可通过 GET /storage/files/{file_id}/status 轮询检查音频是否就绪。
|
||||
"""
|
||||
tts_config = features_config.get("text_to_speech", {})
|
||||
if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
|
||||
@@ -1801,6 +1850,7 @@ class AgentRunService:
|
||||
),
|
||||
"cost_estimate": self._estimate_cost(usage, model_info["model_config"]),
|
||||
"audio_url": result.get("audio_url"),
|
||||
"audio_status": result.get("audio_status"),
|
||||
"citations": result.get("citations", []),
|
||||
"suggested_questions": result.get("suggested_questions", []),
|
||||
"error": None
|
||||
@@ -1878,6 +1928,7 @@ class AgentRunService:
|
||||
"results": [{
|
||||
**r,
|
||||
"audio_url": r.get("audio_url"),
|
||||
"audio_status": r.get("audio_status"),
|
||||
"citations": r.get("citations", []),
|
||||
"suggested_questions": r.get("suggested_questions", []),
|
||||
} for r in results],
|
||||
@@ -2009,6 +2060,7 @@ class AgentRunService:
|
||||
full_content = ""
|
||||
returned_conversation_id = model_conversation_id
|
||||
audio_url = None
|
||||
audio_status = None
|
||||
citations = []
|
||||
suggested_questions = []
|
||||
|
||||
@@ -2067,6 +2119,7 @@ class AgentRunService:
|
||||
# 从 end 事件中提取 features 输出字段
|
||||
if event_type == "end" and event_data:
|
||||
audio_url = event_data.get("audio_url")
|
||||
audio_status = event_data.get("audio_status")
|
||||
citations = event_data.get("citations", [])
|
||||
suggested_questions = event_data.get("suggested_questions", [])
|
||||
|
||||
@@ -2096,6 +2149,7 @@ class AgentRunService:
|
||||
"message": full_content,
|
||||
"elapsed_time": elapsed,
|
||||
"audio_url": audio_url,
|
||||
"audio_status": audio_status,
|
||||
"citations": citations,
|
||||
"suggested_questions": suggested_questions,
|
||||
"error": None
|
||||
@@ -2110,6 +2164,7 @@ class AgentRunService:
|
||||
"elapsed_time": elapsed,
|
||||
"message_length": len(full_content),
|
||||
"audio_url": audio_url,
|
||||
"audio_status": audio_status,
|
||||
"citations": citations,
|
||||
"suggested_questions": suggested_questions,
|
||||
"timestamp": time.time()
|
||||
@@ -2246,6 +2301,7 @@ class AgentRunService:
|
||||
"message": r.get("message"),
|
||||
"elapsed_time": r.get("elapsed_time", 0),
|
||||
"audio_url": r.get("audio_url"),
|
||||
"audio_status": r.get("audio_status"),
|
||||
"citations": r.get("citations", []),
|
||||
"suggested_questions": r.get("suggested_questions", []),
|
||||
"error": r.get("error")
|
||||
|
||||
@@ -619,7 +619,7 @@ class MemoryForgetService:
|
||||
recent_trends.append({
|
||||
'date': date_str,
|
||||
'merged_count': record.merged_count,
|
||||
'average_activation': record.average_activation_value,
|
||||
'average_activation': round(record.average_activation_value, 2) if record.average_activation_value is not None else None,
|
||||
'total_nodes': record.total_nodes,
|
||||
'execution_time': int(record.execution_time.timestamp() * 1000)
|
||||
})
|
||||
|
||||
@@ -12,10 +12,12 @@ import base64
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import zipfile
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
import PyPDF2
|
||||
import chardet
|
||||
import httpx
|
||||
import magic
|
||||
import openpyxl
|
||||
@@ -39,12 +41,10 @@ PDF_MIME = ['application/pdf']
|
||||
DOC_MIME = [
|
||||
'application/msword',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'application/zip'
|
||||
]
|
||||
XLSX_MIME = [
|
||||
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||
'application/vnd.ms-excel',
|
||||
'application/zip'
|
||||
]
|
||||
CSV_MIME = ['text/csv', 'application/csv']
|
||||
JSON_MIME = ['application/json']
|
||||
@@ -402,6 +402,71 @@ class MultimodalService:
|
||||
logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
|
||||
return result
|
||||
|
||||
async def history_process_files(
        self,
        files: Optional[List[FileInput]],
) -> List[Dict[str, Any]]:
    """Convert a list of input files into content parts usable by the LLM.

    The output format is selected from ``self.provider``: a DashScope omni
    model uses the OpenAI-compatible strategy, any other provider is looked
    up in ``PROVIDER_STRATEGIES``, and an unknown provider falls back to
    ``DashScopeFormatStrategy``.

    Args:
        files: File inputs to process; ``None`` or an empty list yields ``[]``.

    Returns:
        One content dict per processed file. A file whose processing raises
        contributes a text placeholder describing the failure. A file whose
        type is not handled (or not covered by ``self.capability``) is
        skipped with a warning.
    """
    if not files:
        return []

    # A DashScope omni model takes the OpenAI-compatible format.
    if self.provider == "dashscope" and self.is_omni:
        strategy_class = OpenAIFormatStrategy
    else:
        strategy_class = PROVIDER_STRATEGIES.get(self.provider)
        if not strategy_class:
            logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略")
            strategy_class = DashScopeFormatStrategy

    result = []
    for idx, file in enumerate(files):
        try:
            # Strategy construction and URL resolution live inside the try so
            # that a failure for one file cannot abort the remaining files.
            strategy = strategy_class(file)
            if not file.url:
                file.url = await self.get_file_url(file)

            # The first tuple element (support flag) is not used here.
            if file.type == FileType.IMAGE and "vision" in self.capability:
                _, content = await self._process_image(file, strategy)
            elif file.type == FileType.DOCUMENT:
                _, content = await self._process_document(file, strategy)
            elif file.type == FileType.AUDIO and "audio" in self.capability:
                _, content = await self._process_audio(file, strategy)
            elif file.type == FileType.VIDEO and "video" in self.capability:
                _, content = await self._process_video(file, strategy)
            else:
                logger.warning(f"不支持的文件类型: {file.type}")
                continue
            result.append(content)
        except Exception as e:
            logger.error(
                "处理文件失败",
                extra={
                    "file_index": idx,
                    "file_type": file.type,
                    "error": str(e)
                },
                exc_info=True
            )
            # Keep going with the other files; record a placeholder instead.
            result.append({
                "type": "text",
                "text": f"[文件处理失败: {str(e)}]"
            })

    logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}")
    return result
|
||||
|
||||
async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]:
|
||||
"""
|
||||
处理图片文件
|
||||
@@ -561,12 +626,12 @@ class MultimodalService:
|
||||
file.set_content(file_content)
|
||||
file_mime_type = magic.from_buffer(file_content, mime=True)
|
||||
if file_mime_type in TEXT_MIME:
|
||||
return file_content.decode("utf-8")
|
||||
return self._decode_text_safe(file_content)
|
||||
elif file_mime_type in PDF_MIME:
|
||||
return await self._extract_pdf_text(file_content)
|
||||
elif file_mime_type in DOC_MIME and file.file_type.endswith(('docx', 'doc')):
|
||||
elif self._is_word_file(file_content, file_mime_type):
|
||||
return await self._extract_word_text(file_content)
|
||||
elif file_mime_type in XLSX_MIME and file.file_type.endswith(("xlsx", "xls")):
|
||||
elif self._is_excel_file(file_content, file_mime_type):
|
||||
return await self._extract_xlsx_text(file_content)
|
||||
elif file_mime_type in CSV_MIME:
|
||||
return await self._extract_csv_text(file_content)
|
||||
@@ -595,52 +660,156 @@ class MultimodalService:
|
||||
|
||||
@staticmethod
async def _extract_word_text(file_content: bytes) -> str:
    """Extract plain text from a Word document (.docx or legacy .doc).

    Content starting with the ZIP signature ``PK`` is parsed as .docx via
    ``Document``; anything else is treated as an OLE2 .doc file and read with
    ``olefile``.

    Args:
        file_content: Raw bytes of the document.

    Returns:
        The extracted text, or a bracketed error message when extraction
        fails (errors are logged, never raised).
    """
    # .docx is a ZIP container.
    if file_content[:2] == b'PK':
        try:
            doc = Document(io.BytesIO(file_content))
            return '\n'.join(p.text for p in doc.paragraphs)
        except Exception as e:
            logger.error(f"提取 docx 文本失败: {e}")
            return f"[docx 提取失败: {str(e)}]"

    # Legacy .doc (OLE2 container).
    try:
        import olefile
        import re

        ole = olefile.OleFileIO(io.BytesIO(file_content))
        try:
            if not ole.exists('WordDocument'):
                return "[doc 提取失败: 未找到 WordDocument 流]"
            stream = ole.openstream('WordDocument').read()
        finally:
            # Release the OLE handle on every path, including the early
            # return above and any exception while reading the stream.
            ole.close()

        # Crude extraction: decode the whole stream as UTF-16-LE and drop
        # whatever does not decode. With errors='ignore' this cannot raise,
        # so no secondary decoding fallback is needed.
        text = stream.decode('utf-16-le', errors='ignore')
        # Strip control characters, then collapse runs of spaces.
        text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text)
        text = re.sub(r' +', ' ', text).strip()
        return text
    except Exception as e:
        logger.error(f"提取 doc 文本失败: {e}")
        return f"[doc 提取失败: {str(e)}]"
|
||||
|
||||
@staticmethod
async def _extract_xlsx_text(file_content: bytes) -> str:
    """Extract cell text from an Excel workbook (.xlsx or legacy .xls).

    Content starting with the ZIP signature ``PK`` is read with openpyxl;
    anything else is treated as a legacy .xls file and read with xlrd. Each
    sheet is emitted as a ``[Sheet: name]`` header followed by one
    tab-separated line per row.

    Args:
        file_content: Raw bytes of the workbook.

    Returns:
        The extracted text, or a bracketed error message when extraction
        fails (errors are logged, never raised).
    """
    # .xlsx is a ZIP container.
    if file_content[:2] == b'PK':
        try:
            wb = openpyxl.load_workbook(io.BytesIO(file_content), read_only=True, data_only=True)
            try:
                parts = []
                for sheet in wb.worksheets:
                    parts.append(f"[Sheet: {sheet.title}]")
                    for row in sheet.iter_rows(values_only=True):
                        parts.append('\t'.join('' if v is None else str(v) for v in row))
                return '\n'.join(parts)
            finally:
                # Read-only workbooks load lazily and must be closed explicitly.
                wb.close()
        except Exception as e:
            logger.error(f"提取 xlsx 文本失败: {e}")
            return f"[xlsx 提取失败: {str(e)}]"

    # Legacy .xls (OLE2/BIFF).
    try:
        import xlrd
        wb = xlrd.open_workbook(file_contents=file_content)
        parts = []
        for sheet in wb.sheets():
            parts.append(f"[Sheet: {sheet.name}]")
            for row_idx in range(sheet.nrows):
                parts.append('\t'.join(str(sheet.cell_value(row_idx, col)) for col in range(sheet.ncols)))
        return '\n'.join(parts)
    except Exception as e:
        logger.error(f"提取 xls 文本失败: {e}")
        return f"[xls 提取失败: {str(e)}]"
|
||||
|
||||
@staticmethod
|
||||
async def _extract_csv_text(file_content: bytes) -> str:
|
||||
async def _extract_csv_text(self, file_content: bytes) -> str:
|
||||
"""提取 CSV 文本"""
|
||||
try:
|
||||
text = file_content.decode('utf-8-sig')
|
||||
text = self._decode_text_safe(file_content)
|
||||
reader = csv.reader(io.StringIO(text))
|
||||
return '\n'.join('\t'.join(row) for row in reader)
|
||||
except Exception as e:
|
||||
logger.error(f"提取 CSV 文本失败: {e}")
|
||||
return f"[CSV 提取失败: {str(e)}]"
|
||||
|
||||
@staticmethod
|
||||
async def _extract_json_text(file_content: bytes) -> str:
|
||||
async def _extract_json_text(self, file_content: bytes) -> str:
|
||||
"""提取 JSON 文本"""
|
||||
try:
|
||||
data = json.loads(file_content.decode('utf-8'))
|
||||
text = self._decode_text_safe(file_content)
|
||||
data = json.loads(text)
|
||||
return json.dumps(data, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"提取 JSON 文本失败: {e}")
|
||||
return f"[JSON 提取失败: {str(e)}]"
|
||||
|
||||
def _is_word_file(self, file_content: bytes, mime_type: str) -> bool:
|
||||
"""判断是不是 Word 文件(doc / docx),不依赖后缀"""
|
||||
# 旧版 .doc
|
||||
if mime_type == 'application/msword':
|
||||
return True
|
||||
|
||||
# 新版 .docx(ZIP 内部包含 word/document.xml)
|
||||
header = file_content[:4]
|
||||
if header == b'PK\x03\x04':
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(file_content)) as zf:
|
||||
return "word/document.xml" in zf.namelist()
|
||||
except:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
def _is_excel_file(self, file_content: bytes, mime_type: str) -> bool:
|
||||
"""判断是不是 Excel 文件(xls / xlsx),不依赖后缀"""
|
||||
# 旧版 .xls
|
||||
if mime_type == 'application/vnd.ms-excel':
|
||||
return True
|
||||
|
||||
# 新版 .xlsx(ZIP 内部包含 xl/workbook.xml)
|
||||
header = file_content[:4]
|
||||
if header == b'PK\x03\x04':
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(file_content)) as zf:
|
||||
return "xl/workbook.xml" in zf.namelist()
|
||||
except:
|
||||
pass
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _decode_text_safe(file_content: bytes) -> str:
|
||||
"""
|
||||
【万能文本解码】
|
||||
自动检测编码,支持 utf-8 / gbk / gb2312 / utf-8-sig / ascii 等
|
||||
永远不报错,永远不乱码
|
||||
"""
|
||||
if not file_content:
|
||||
return ""
|
||||
|
||||
# 1. 自动检测文件编码
|
||||
detect = chardet.detect(file_content)
|
||||
encoding = detect.get("encoding") or "utf-8"
|
||||
encoding = encoding.lower()
|
||||
|
||||
# 2. 兼容常见中文编码
|
||||
compatible_encodings = ["utf-8", "gbk", "gb18030", "gb2312", "ascii", "latin-1"]
|
||||
|
||||
# 3. 按优先级尝试解码
|
||||
for enc in [encoding] + compatible_encodings:
|
||||
if not enc:
|
||||
continue
|
||||
try:
|
||||
return file_content.decode(enc.strip())
|
||||
except (UnicodeDecodeError, LookupError):
|
||||
continue
|
||||
|
||||
# 终极兜底
|
||||
return file_content.decode("utf-8", errors="replace")
|
||||
|
||||
|
||||
def get_multimodal_service(db: Session) -> MultimodalService:
|
||||
"""获取多模态服务实例(依赖注入)"""
|
||||
|
||||
@@ -1408,12 +1408,11 @@ async def analytics_memory_types(
|
||||
if end_user_id:
|
||||
try:
|
||||
conversation_repo = ConversationRepository(db)
|
||||
conversations = conversation_repo.get_conversation_by_user_id(
|
||||
conversations, total = conversation_repo.get_conversation_by_user_id(
|
||||
user_id=uuid.UUID(end_user_id),
|
||||
limit=100, # 获取更多会话以准确统计
|
||||
is_activate=True
|
||||
)
|
||||
work_count = len(conversations)
|
||||
work_count = total
|
||||
logger.debug(f"工作记忆数量(会话数): {work_count} (end_user_id={end_user_id})")
|
||||
except Exception as e:
|
||||
logger.warning(f"获取会话数量失败,工作记忆数量设为0: {str(e)}")
|
||||
|
||||
@@ -78,18 +78,7 @@ def create_user(db: Session, user: UserCreate) -> User:
|
||||
business_logger.info(f"创建用户: {user.username}, email: {user.email}")
|
||||
|
||||
try:
|
||||
# 检查用户名是否已存在
|
||||
business_logger.debug(f"检查用户名是否已存在: {user.username}")
|
||||
db_user_by_username = user_repository.get_user_by_username(db, username=user.username)
|
||||
if db_user_by_username:
|
||||
business_logger.warning(f"用户名已存在: {user.username}")
|
||||
raise BusinessException(
|
||||
"用户名已存在",
|
||||
code=BizCode.DUPLICATE_NAME,
|
||||
context={"username": user.username, "email": user.email}
|
||||
)
|
||||
|
||||
# 检查邮箱是否已注册
|
||||
# 检查邮箱是否已注册(邮箱保持唯一)
|
||||
business_logger.debug(f"检查邮箱是否已注册: {user.email}")
|
||||
db_user_by_email = user_repository.get_user_by_email(db, email=user.email)
|
||||
if db_user_by_email:
|
||||
@@ -164,22 +153,7 @@ def create_superuser(db: Session, user: UserCreate, current_user: User) -> User:
|
||||
)
|
||||
|
||||
try:
|
||||
# 检查用户名是否已存在
|
||||
business_logger.debug(f"检查用户名是否已存在: {user.username}")
|
||||
db_user_by_username = user_repository.get_user_by_username(db, username=user.username)
|
||||
if db_user_by_username:
|
||||
business_logger.warning(f"用户名已存在: {user.username}")
|
||||
raise BusinessException(
|
||||
"用户名已存在",
|
||||
code=BizCode.DUPLICATE_NAME,
|
||||
context={
|
||||
"username": user.username,
|
||||
"email": user.email,
|
||||
"created_by": str(current_user.id)
|
||||
}
|
||||
)
|
||||
|
||||
# 检查邮箱是否已注册
|
||||
# 检查邮箱是否已注册(邮箱保持唯一)
|
||||
business_logger.debug(f"检查邮箱是否已注册: {user.email}")
|
||||
db_user_by_email = user_repository.get_user_by_email(db, email=user.email)
|
||||
if db_user_by_email:
|
||||
|
||||
@@ -57,6 +57,7 @@ class WorkflowService:
|
||||
edges: list[dict[str, Any]],
|
||||
variables: list[dict[str, Any]] | None = None,
|
||||
execution_config: dict[str, Any] | None = None,
|
||||
features: dict[str, Any] | None = None,
|
||||
triggers: list[dict[str, Any]] | None = None,
|
||||
validate: bool = True
|
||||
) -> WorkflowConfig:
|
||||
@@ -68,6 +69,7 @@ class WorkflowService:
|
||||
edges: 边列表
|
||||
variables: 变量列表
|
||||
execution_config: 执行配置
|
||||
features: 功能特性
|
||||
triggers: 触发器列表
|
||||
validate: 是否验证配置
|
||||
|
||||
@@ -83,6 +85,7 @@ class WorkflowService:
|
||||
"edges": edges,
|
||||
"variables": variables or [],
|
||||
"execution_config": execution_config or {},
|
||||
"features": features or {},
|
||||
"triggers": triggers or []
|
||||
}
|
||||
|
||||
@@ -103,6 +106,7 @@ class WorkflowService:
|
||||
edges=edges,
|
||||
variables=variables,
|
||||
execution_config=execution_config,
|
||||
features=features,
|
||||
triggers=triggers
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user