[MODIFY] Code optimization

This commit is contained in:
Mark
2025-12-15 14:09:43 +08:00
parent d2a630addb
commit a4e276ab27
157 changed files with 15976 additions and 3601 deletions

View File

@@ -27,7 +27,7 @@ class AgentRegistry:
self._cache[str(agent.id)] = agent_info
logger.info(
f"Agent 注册成功",
"Agent 注册成功",
extra={
"agent_id": str(agent.id),
"name": agent.app.name,
@@ -92,7 +92,7 @@ class AgentRegistry:
agents = self.db.scalars(stmt).all()
logger.debug(
f"Agent 发现",
"Agent 发现",
extra={
"query": query,
"domain": domain,

View File

@@ -1,130 +0,0 @@
from typing import Any, List
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import InMemorySaver
from pydantic import BaseModel
from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import before_model
from langchain.tools import tool
from langchain_core.messages import RemoveMessage
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
from app.services.api_resquests_server import send_message, model, retrieval
class config(BaseModel):
template_str:str
params:dict
model_configs: List[dict] = []
history_memory:bool
knowledge_base:bool
class RemoryInput(BaseModel):
question: str
end_user_id: str
search_switch:str
class ChatRequest(BaseModel):
end_user_id: str
message: str
search_switch:str
kb_ids: List[str] = []
similarity_threshold:float
vector_similarity_weight:float
top_k:int
hybrid:bool
token:str
class RetrievalInput(BaseModel):
message: str
kb_ids: List[str] = []
similarity_threshold: float
vector_similarity_weight: float
top_k: int
hybrid: bool
token: str
async def tool_Retrieval(req):
tool_result = retrieval_search.invoke({
"message":req.message, "kb_ids":req.kb_ids,
"similarity_threshold":req.similarity_threshold, "vector_similarity_weight":req.vector_similarity_weight,
"top_k":req.top_k, "hybrid":req.hybrid, "token":req.token
})
return tool_result
async def tool_memory(req):
tool_result = remory_sk.invoke({
"question": req.message,
"end_user_id": req.end_user_id,
"search_switch": req.search_switch
})
return tool_result
@before_model
# ========== 消息剪枝中间件 ==========
def trim_messages(state: AgentState, runtime: Runtime) -> dict[str, Any] | None:
"""保留前1条 + 最近3~4条消息"""
messages = state["messages"]
if len(messages) <= 10:
return None
first_msg = messages[0]
recent_messages = messages[-10:] if len(messages) % 2 == 0 else messages[-11:]
new_messages = [first_msg] + recent_messages
return {
"messages": [
RemoveMessage(id=REMOVE_ALL_MESSAGES),
*new_messages
]
}
##-----------历史记忆------------
@ tool(args_schema=RemoryInput)
def remory_sk(question: str, end_user_id: str, search_switch: str):
"""
条件调用工具:
- 仅当 question 是疑问句时调用 send_message
- 根据 end_user_id 和 search_switch 参数决定是否执行检索
Args:
question: 用户的提问内容
end_user_id: 用户唯一标识符
search_switch: 搜索开关,控制是否执行检索
Returns:
检索结果或空字符串
"""
# 移除关于 configurable 的描述,避免混淆
if not end_user_id or end_user_id == '123':
print("警告: 无效的 user_id 参数")
return ''
if search_switch in ['on', 'off'] or not search_switch:
print("警告: 无效的 search_switch 参数")
return ''
return send_message(end_user_id, question, '[]', search_switch)
#-------------检索------------
@ tool(args_schema=RetrievalInput)
def retrieval_search(message,kb_ids,similarity_threshold,vector_similarity_weight,top_k,hybrid,token):
'''检索'''
search=retrieval(message,kb_ids,similarity_threshold,vector_similarity_weight,top_k,hybrid,token)
return search
async def create_dynamic_agent(model_name: str,model_id:str,PROMPT:str,token:str):
"""根据模型名动态创建代理"""
model_name, api_key, api_base=await model(model_id,token)
llm = ChatOpenAI(model=model_name, base_url=api_base, temperature=0.2,api_key=api_key)
memory = InMemorySaver()
return create_agent(
llm,
tools=[remory_sk,retrieval_search],
middleware=[trim_messages],
checkpointer=memory,
system_prompt=PROMPT
)

View File

@@ -80,7 +80,7 @@ def create_agent_discovery_tool(registry: AgentRegistry, workspace_id: uuid.UUID
result += "\n"
logger.info(
f"Agent 发现成功",
"Agent 发现成功",
extra={
"query": query,
"domain": domain,
@@ -91,7 +91,7 @@ def create_agent_discovery_tool(registry: AgentRegistry, workspace_id: uuid.UUID
return result
except Exception as e:
logger.error(f"Agent 发现失败", extra={"error": str(e)})
logger.error("Agent 发现失败", extra={"error": str(e)})
return f"发现 Agent 失败: {str(e)}"
return discover_agents
@@ -138,7 +138,7 @@ def create_agent_invocation_tool(
if workspace and workspace.storage_type:
storage_type = workspace.storage_type
logger.debug(
f"获取工作空间存储类型成功",
"获取工作空间存储类型成功",
extra={
"workspace_id": str(workspace_id),
"storage_type": storage_type
@@ -146,7 +146,7 @@ def create_agent_invocation_tool(
)
except Exception as e:
logger.warning(
f"获取工作空间存储类型失败,使用默认值 neo4j",
"获取工作空间存储类型失败,使用默认值 neo4j",
extra={"workspace_id": str(workspace_id), "error": str(e)}
)
@@ -161,7 +161,7 @@ def create_agent_invocation_tool(
if knowledge:
user_rag_memory_id = str(knowledge.id)
logger.debug(
f"获取 RAG 知识库成功",
"获取 RAG 知识库成功",
extra={
"workspace_id": str(workspace_id),
"knowledge_id": user_rag_memory_id
@@ -169,13 +169,13 @@ def create_agent_invocation_tool(
)
else:
logger.warning(
f"未找到名为 'USER_RAG_MEMORY' 的知识库,将使用 neo4j 存储",
"未找到名为 'USER_RAG_MEMORY' 的知识库,将使用 neo4j 存储",
extra={"workspace_id": str(workspace_id)}
)
storage_type = 'neo4j'
except Exception as e:
logger.warning(
f"获取 RAG 知识库失败,将使用 neo4j 存储",
"获取 RAG 知识库失败,将使用 neo4j 存储",
extra={"workspace_id": str(workspace_id), "error": str(e)}
)
storage_type = 'neo4j'
@@ -226,12 +226,12 @@ def create_agent_invocation_tool(
# 6. 获取 Agent 配置
agent_config = db.get(AgentConfig, agent_uuid)
if not agent_config:
return f"Agent 配置不存在"
return "Agent 配置不存在"
# 7. 获取模型配置
model_config = db.get(ModelConfig, agent_config.default_model_config_id)
if not model_config:
return f"Agent 模型配置不存在"
return "Agent 模型配置不存在"
# 8. 创建调用记录
invocation = AgentInvocation(
@@ -249,7 +249,7 @@ def create_agent_invocation_tool(
db.refresh(invocation)
logger.info(
f"Agent 调用开始",
"Agent 调用开始",
extra={
"invocation_id": str(invocation.id),
"caller_agent_id": str(current_agent_id),
@@ -286,7 +286,7 @@ def create_agent_invocation_tool(
db.commit()
logger.info(
f"Agent 调用成功",
"Agent 调用成功",
extra={
"invocation_id": str(invocation.id),
"caller_agent_id": str(current_agent_id),
@@ -306,7 +306,7 @@ def create_agent_invocation_tool(
db.commit()
logger.error(
f"Agent 调用失败",
"Agent 调用失败",
extra={
"invocation_id": str(invocation.id),
"caller_agent_id": str(current_agent_id),
@@ -319,7 +319,7 @@ def create_agent_invocation_tool(
except Exception as e:
logger.error(
f"Agent 调用异常",
"Agent 调用异常",
extra={
"caller_agent_id": str(current_agent_id),
"callee_agent_id": agent_id,

View File

@@ -1,16 +1,22 @@
"""API Key Service"""
from sqlalchemy.orm import Session
from typing import Optional, Tuple, List
import time
import uuid
import datetime
import math
from typing import Optional, Tuple
from datetime import datetime, timedelta
from app.models.api_key_model import ApiKey, ApiKeyType
from app.repositories.api_key_repository import ApiKeyRepository
from sqlalchemy.orm import Session
from sqlalchemy import select
from app.aioRedis import aio_redis
from app.models.api_key_model import ApiKey
from app.repositories.api_key_repository import ApiKeyRepository, ApiKeyLogRepository
from app.schemas import api_key_schema
from app.schemas.response_schema import PageData, PageMeta
from app.core.api_key_utils import generate_api_key
from app.core.exceptions import BusinessException
from app.core.api_key_utils import generate_api_key, hash_api_key, validate_resource_binding
from app.core.exceptions import (
BusinessException,
)
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
@@ -19,81 +25,108 @@ logger = get_business_logger()
class ApiKeyService:
"""API Key 业务逻辑服务"""
@staticmethod
def create_api_key(
db: Session,
*,
workspace_id: uuid.UUID,
user_id: uuid.UUID,
data: api_key_schema.ApiKeyCreate
db: Session,
*,
workspace_id: uuid.UUID,
user_id: uuid.UUID,
data: api_key_schema.ApiKeyCreate
) -> Tuple[ApiKey, str]:
"""创建 API Key
"""
创建 API Key
Returns:
Tuple[ApiKey, str]: (API Key 对象, API Key 明文)
"""
# 生成 API Key
api_key, key_hash, key_prefix = generate_api_key(data.type)
# 创建数据
api_key_data = {
"id": uuid.uuid4(),
"name": data.name,
"description": data.description,
"key_prefix": key_prefix,
"key_hash": key_hash,
"type": data.type,
"scopes": data.scopes,
"workspace_id": workspace_id,
"resource_id": data.resource_id,
"resource_type": data.resource_type,
"rate_limit": data.rate_limit,
"quota_limit": data.quota_limit,
"expires_at": data.expires_at,
"created_by": user_id,
"created_at": datetime.datetime.now(),
"updated_at": datetime.datetime.now(),
}
api_key_obj = ApiKeyRepository.create(db, api_key_data)
db.commit()
db.refresh(api_key_obj)
logger.info(f"API Key 创建成功", extra={
"api_key_id": str(api_key_obj.id),
"name": data.name,
"type": data.type
})
return api_key_obj, api_key
try:
# 验证资源绑定
if data.resource_type or data.resource_id:
is_valid, error_msg = validate_resource_binding(
data.resource_type, str(data.resource_id) if data.resource_id else None
)
if not is_valid:
raise BusinessException(error_msg, BizCode.API_KEY_INVALID_RESOURCE)
existing = db.scalar(
select(ApiKey).where(
ApiKey.workspace_id == workspace_id,
ApiKey.name == data.name,
ApiKey.is_active
)
)
if existing:
raise BusinessException(f"API Key 名称 '{data.name}' 已存在", BizCode.API_KEY_DUPLICATE_NAME)
# 生成 API Key
api_key, key_hash, key_prefix = generate_api_key(data.type)
# 创建数据
api_key_data = {
"id": uuid.uuid4(),
"name": data.name,
"description": data.description,
"key_prefix": key_prefix,
"key_hash": key_hash,
"type": data.type,
"scopes": data.scopes,
"workspace_id": workspace_id,
"resource_id": data.resource_id,
"resource_type": data.resource_type,
"rate_limit": data.rate_limit or 10,
"daily_request_limit": data.daily_request_limit or 10000,
"quota_limit": data.quota_limit,
"expires_at": data.expires_at,
"created_by": user_id,
}
api_key_obj = ApiKeyRepository.create(db, api_key_data)
db.commit()
logger.info("API Key 创建成功", extra={
"api_key_id": str(api_key_obj.id),
"workspace_id": str(workspace_id),
"api_key_name": data.name,
"type": data.type
})
return api_key_obj, api_key
except Exception as e:
db.rollback()
logger.error(f"API Key 创建失败: {e}", extra={
"workspace_id": str(workspace_id),
"api_key_name": getattr(data, 'name', 'unknown'),
"error": str(e)
})
raise
@staticmethod
def get_api_key(
db: Session,
api_key_id: uuid.UUID,
workspace_id: uuid.UUID
db: Session,
api_key_id: uuid.UUID,
workspace_id: uuid.UUID
) -> ApiKey:
"""获取 API Key"""
api_key = ApiKeyRepository.get_by_id(db, api_key_id)
if not api_key:
raise BusinessException("API Key 不存在", BizCode.NOT_FOUND)
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.API_KEY_NOT_FOUND)
if api_key.workspace_id != workspace_id:
raise BusinessException("无权访问此 API Key", BizCode.FORBIDDEN)
return api_key
@staticmethod
def list_api_keys(
db: Session,
workspace_id: uuid.UUID,
query: api_key_schema.ApiKeyQuery
db: Session,
workspace_id: uuid.UUID,
query: api_key_schema.ApiKeyQuery
) -> PageData:
"""列出 API Keys"""
items, total = ApiKeyRepository.list_by_workspace(db, workspace_id, query)
pages = math.ceil(total / query.pagesize) if total > 0 else 0
return PageData(
page=PageMeta(
page=query.page,
@@ -103,52 +136,69 @@ class ApiKeyService:
),
items=[api_key_schema.ApiKey.model_validate(item) for item in items]
)
@staticmethod
def update_api_key(
db: Session,
api_key_id: uuid.UUID,
workspace_id: uuid.UUID,
data: api_key_schema.ApiKeyUpdate
db: Session,
api_key_id: uuid.UUID,
workspace_id: uuid.UUID,
data: api_key_schema.ApiKeyUpdate
) -> ApiKey:
"""更新 API Key"""
"""更新 API Key配置"""
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
# 检查名称重复
if data.name and data.name != api_key.name:
existing = db.scalar(
select(ApiKey).where(
ApiKey.workspace_id == workspace_id,
ApiKey.name == data.name,
ApiKey.is_active,
ApiKey.id != api_key_id
)
)
if existing:
raise BusinessException(f"API Key 名称 '{data.name}' 已存在", BizCode.API_KEY_DUPLICATE_NAME)
update_data = data.model_dump(exclude_unset=True)
ApiKeyRepository.update(db, api_key_id, update_data)
db.commit()
db.refresh(api_key)
logger.info(f"API Key 更新成功", extra={"api_key_id": str(api_key_id)})
logger.info("API Key 更新成功", extra={"api_key_id": str(api_key_id)})
return api_key
@staticmethod
def delete_api_key(
db: Session,
api_key_id: uuid.UUID,
workspace_id: uuid.UUID
db: Session,
api_key_id: uuid.UUID,
workspace_id: uuid.UUID
) -> bool:
"""删除 API Key"""
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
ApiKeyRepository.delete(db, api_key_id)
db.commit()
logger.info(f"API Key 删除成功", extra={"api_key_id": str(api_key_id)})
logger.info("API Key 删除成功", extra={"api_key_id": str(api_key_id)})
return True
@staticmethod
def regenerate_api_key(
db: Session,
api_key_id: uuid.UUID,
workspace_id: uuid.UUID
db: Session,
api_key_id: uuid.UUID,
workspace_id: uuid.UUID
) -> Tuple[ApiKey, str]:
"""重新生成 API Key"""
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
# 检查 API Key 是否激活
if not api_key.is_active:
raise BusinessException("无法重新生成已停用的 API Key", BizCode.API_KEY_INACTIVE)
# 生成新的 API Key
new_api_key, key_hash, key_prefix = generate_api_key(ApiKeyType(api_key.type))
new_api_key, key_hash, key_prefix = generate_api_key(api_key_schema.ApiKeyType(api_key.type))
# 更新
ApiKeyRepository.update(db, api_key_id, {
"key_hash": key_hash,
@@ -156,18 +206,201 @@ class ApiKeyService:
})
db.commit()
db.refresh(api_key)
logger.info(f"API Key 重新生成成功", extra={"api_key_id": str(api_key_id)})
logger.info("API Key 重新生成成功", extra={"api_key_id": str(api_key_id)})
return api_key, new_api_key
@staticmethod
def get_stats(
db: Session,
api_key_id: uuid.UUID,
workspace_id: uuid.UUID
db: Session,
api_key_id: uuid.UUID,
workspace_id: uuid.UUID
) -> api_key_schema.ApiKeyStats:
"""获取使用统计"""
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
stats_data = ApiKeyRepository.get_stats(db, api_key_id)
return api_key_schema.ApiKeyStats(**stats_data)
@staticmethod
def get_logs(
db: Session,
api_key_id: uuid.UUID,
workspace_id: uuid.UUID,
filters: dict,
page: int,
pagesize: int
) -> PageData:
"""获取 API Key 使用日志"""
# 验证 API Key 权限
api_key = ApiKeyService.get_api_key(db, api_key_id, workspace_id)
items, total = ApiKeyLogRepository.list_by_api_key(
db, api_key_id, filters, page, pagesize
)
# 计算分页信息
pages = math.ceil(total / pagesize) if total > 0 else 0
return PageData(
page=PageMeta(
page=page,
pagesize=pagesize,
total=total,
hasnext=page < pages
),
items=[api_key_schema.ApiKeyLog.model_validate(item) for item in items]
)
class RateLimiterService:
def __init__(self):
self.redis = aio_redis
async def check_qps(self, api_key_id: uuid.UUID, limit: int) -> Tuple[bool, dict]:
"""
检查QPS限制
Returns:
(is_allowed, rate_limit_info)
"""
key = f"rate_limit:qps:{api_key_id}"
async with self.redis.pipeline() as pipe:
pipe.incr(key)
pipe.expire(key, 1) # 1 秒过期
results = await pipe.execute()
current = results[0]
remaining = max(0, limit - current)
reset_time = int(time.time()) + 1
return current <= limit, {
"limit": limit,
"remaining": remaining,
"reset": reset_time
}
async def check_daily_requests(
self,
api_key_id: uuid.UUID,
limit: int
) -> Tuple[bool, dict]:
"""检查日调用量限制"""
today = datetime.now().strftime("%Y%m%d")
key = f"rate_limit:daily:{api_key_id}:{today}"
now = datetime.now()
tomorrow_0 = (now + timedelta(days=1)).replace(
hour=0, minute=0, second=0, microsecond=0
)
expire_seconds = int((tomorrow_0 - now).total_seconds())
async with self.redis.pipeline() as pipe:
pipe.incr(key)
pipe.expire(key, expire_seconds, nx=True)
results = await pipe.execute()
current = results[0]
remaining = max(0, limit - current)
reset_time = int(tomorrow_0.timestamp())
return current <= limit, {
"limit": limit,
"remaining": remaining,
"reset": reset_time
}
async def check_all_limits(
self,
api_key: ApiKey
) -> Tuple[bool, str, dict]:
"""
检查所有限制
Returns:
(is_allowed, error_message, rate_limit_headers)
"""
# Check QPS
qps_ok, qps_info = await self.check_qps(
api_key.id,
api_key.rate_limit
)
if not qps_ok:
return False, "QPS limit exceeded", {
"X-RateLimit-Limit-QPS": str(qps_info["limit"]),
"X-RateLimit-Remaining-QPS": str(qps_info["remaining"]),
"X-RateLimit-Reset": str(qps_info["reset"])
}
# Check daily requests
daily_ok, daily_info = await self.check_daily_requests(
api_key.id,
api_key.daily_request_limit
)
if not daily_ok:
return False, "Daily request limit exceeded", {
"X-RateLimit-Limit-Day": str(daily_info["limit"]),
"X-RateLimit-Remaining-Day": str(daily_info["remaining"]),
"X-RateLimit-Reset": str(daily_info["reset"])
}
# All checks passed
headers = {
"X-RateLimit-Limit-QPS": str(qps_info["limit"]),
"X-RateLimit-Remaining-QPS": str(qps_info["remaining"]),
"X-RateLimit-Limit-Day": str(daily_info["limit"]),
"X-RateLimit-Remaining-Day": str(daily_info["remaining"]),
"X-RateLimit-Reset": str(daily_info["reset"])
}
return True, "", headers
class ApiKeyAuthService:
@staticmethod
def validate_api_key(
db: Session,
api_key: str
) -> Optional[ApiKey]:
"""
验证API Key 有效性
检查:
1. Key hash 是否存在
2. is_active 是否为true
3. expires_at 是否未过期
4. quota 是否未超限
"""
key_hash = hash_api_key(api_key)
api_key_obj = ApiKeyRepository.get_by_hash(db, key_hash)
if not api_key_obj:
return None
if not api_key_obj.is_active:
return None
if api_key_obj.expires_at and datetime.now() > api_key_obj.expires_at:
return None
if api_key_obj.quota_limit and api_key_obj.quota_used >= api_key_obj.quota_limit:
return None
return api_key_obj
@staticmethod
def check_scope(api_key: ApiKey, required_scope: str) -> bool:
"""检查权限范围"""
return required_scope in api_key.scopes
@staticmethod
def check_resource(
api_key: ApiKey,
resource_type: str,
resource_id: uuid.UUID
) -> bool:
"""检查资源绑定"""
if not api_key.resource_id:
return True
return (
api_key.resource_type == resource_type and
api_key.resource_id == resource_id
)

View File

@@ -58,7 +58,7 @@ class AppService:
"""
if workspace_id is not None and app.workspace_id != workspace_id:
logger.warning(
f"工作空间访问被拒",
"工作空间访问被拒",
extra={"app_id": str(app.id), "workspace_id": str(workspace_id)}
)
raise BusinessException("应用不在指定工作空间中", BizCode.WORKSPACE_NO_ACCESS)
@@ -103,7 +103,7 @@ class AppService:
"""
if not self._check_app_accessible(app, workspace_id):
logger.warning(
f"应用访问被拒",
"应用访问被拒",
extra={"app_id": str(app.id), "workspace_id": str(workspace_id)}
)
raise BusinessException("应用不可访问", BizCode.WORKSPACE_NO_ACCESS)
@@ -122,7 +122,7 @@ class AppService:
"""
app = self.db.get(App, app_id)
if not app:
logger.warning(f"应用不存在", extra={"app_id": str(app_id)})
logger.warning("应用不存在", extra={"app_id": str(app_id)})
raise ResourceNotFoundException("应用", str(app_id))
return app
@@ -257,7 +257,7 @@ class AppService:
)
logger.info(
f"多智能体配置检查通过",
"多智能体配置检查通过",
extra={
"app_id": str(app_id),
"master_agent_id": str(multi_agent_config.master_agent_id),
@@ -295,7 +295,7 @@ class AppService:
updated_at=now,
)
self.db.add(agent_cfg)
logger.debug(f"Agent 配置已创建", extra={"app_id": str(app_id)})
logger.debug("Agent 配置已创建", extra={"app_id": str(app_id)})
def _create_multi_agent_config(
self,
@@ -380,7 +380,7 @@ class AppService:
updated_at=now,
)
self.db.add(multi_agent_cfg)
logger.debug(f"多 Agent 配置已创建", extra={"app_id": str(app_id), "mode": config.orchestration_mode})
logger.debug("多 Agent 配置已创建", extra={"app_id": str(app_id), "mode": config.orchestration_mode})
def _get_next_version(self, app_id: uuid.UUID) -> int:
"""获取下一个版本号
@@ -474,7 +474,7 @@ class AppService:
BusinessException: 当创建失败时
"""
logger.info(
f"创建应用",
"创建应用",
extra={"app_name": data.name, "type": data.type, "workspace_id": str(workspace_id)}
)
@@ -511,12 +511,12 @@ class AppService:
self.db.commit()
self.db.refresh(app)
logger.info(f"应用创建成功", extra={"app_id": str(app.id), "app_name": app.name})
logger.info("应用创建成功", extra={"app_id": str(app.id), "app_name": app.name})
return app
except Exception as e:
self.db.rollback()
logger.error(f"应用创建失败", extra={"app_name": data.name, "error": str(e)})
logger.error("应用创建失败", extra={"app_name": data.name, "error": str(e)})
raise BusinessException(f"应用创建失败: {str(e)}", BizCode.INTERNAL_ERROR, cause=e)
def update_app(
@@ -540,7 +540,7 @@ class AppService:
ResourceNotFoundException: 当应用不存在时
BusinessException: 当应用不在指定工作空间时
"""
logger.info(f"更新应用", extra={"app_id": str(app_id)})
logger.info("更新应用", extra={"app_id": str(app_id)})
app = self._get_app_or_404(app_id)
self._validate_workspace_access(app, workspace_id)
@@ -556,9 +556,9 @@ class AppService:
app.updated_at = datetime.datetime.now()
self.db.commit()
self.db.refresh(app)
logger.info(f"应用更新成功", extra={"app_id": str(app_id)})
logger.info("应用更新成功", extra={"app_id": str(app_id)})
else:
logger.debug(f"应用无变更", extra={"app_id": str(app_id)})
logger.debug("应用无变更", extra={"app_id": str(app_id)})
return app
@@ -578,17 +578,17 @@ class AppService:
ResourceNotFoundException: 当应用不存在时
BusinessException: 当应用不在指定工作空间时
"""
logger.info(f"删除应用", extra={"app_id": str(app_id)})
logger.info("删除应用", extra={"app_id": str(app_id)})
app = self._get_app_or_404(app_id)
self._validate_workspace_access(app, workspace_id)
# 删除应用(级联删除相关数据)
self.db.delete(app)
# 逻辑删除应用
app.is_active = False
self.db.commit()
logger.info(
f"应用删除成功",
"应用删除成功",
extra={
"app_id": str(app_id),
"app_name": app.name,
@@ -619,7 +619,7 @@ class AppService:
ResourceNotFoundException: 当源应用不存在时
BusinessException: 当复制失败时
"""
logger.info(f"复制应用", extra={"source_app_id": str(app_id)})
logger.info("复制应用", extra={"source_app_id": str(app_id)})
try:
# 获取源应用
@@ -682,7 +682,7 @@ class AppService:
self.db.refresh(new_app)
logger.info(
f"应用复制成功",
"应用复制成功",
extra={
"source_app_id": str(app_id),
"new_app_id": str(new_app.id),
@@ -695,7 +695,7 @@ class AppService:
except Exception as e:
self.db.rollback()
logger.error(
f"应用复制失败",
"应用复制失败",
extra={"source_app_id": str(app_id), "error": str(e)}
)
raise BusinessException(f"应用复制失败: {str(e)}", BizCode.INTERNAL_ERROR, cause=e)
@@ -734,7 +734,7 @@ class AppService:
from app.models import AppShare
logger.debug(
f"查询应用列表",
"查询应用列表",
extra={
"workspace_id": str(workspace_id),
"include_shared": include_shared,
@@ -745,6 +745,7 @@ class AppService:
# 构建查询条件
filters = []
filters.append(App.is_active == True)
if type:
filters.append(App.type == type)
if visibility:
@@ -791,7 +792,7 @@ class AppService:
items = list(self.db.scalars(stmt).all())
logger.debug(
f"应用列表查询完成",
"应用列表查询完成",
extra={"total": total, "returned": len(items), "include_shared": include_shared}
)
return items, int(total)
@@ -819,7 +820,7 @@ class AppService:
ResourceNotFoundException: 当应用不存在时
BusinessException: 当应用类型不支持或不在指定工作空间时
"""
logger.info(f"更新 Agent 配置", extra={"app_id": str(app_id)})
logger.info("更新 Agent 配置", extra={"app_id": str(app_id)})
app = self._get_app_or_404(app_id)
@@ -841,7 +842,7 @@ class AppService:
updated_at=now,
)
self.db.add(agent_cfg)
logger.debug(f"创建新的 Agent 配置", extra={"app_id": str(app_id)})
logger.debug("创建新的 Agent 配置", extra={"app_id": str(app_id)})
# 转换为存储格式
storage_data = AgentConfigConverter.to_storage_format(data)
@@ -867,7 +868,7 @@ class AppService:
self.db.commit()
self.db.refresh(agent_cfg)
logger.info(f"Agent 配置更新成功", extra={"app_id": str(app_id)})
logger.info("Agent 配置更新成功", extra={"app_id": str(app_id)})
return agent_cfg
def get_agent_config(
@@ -891,7 +892,7 @@ class AppService:
ResourceNotFoundException: 当应用不存在时
BusinessException: 当应用类型不支持或不可访问时
"""
logger.debug(f"获取 Agent 配置", extra={"app_id": str(app_id)})
logger.debug("获取 Agent 配置", extra={"app_id": str(app_id)})
app = self._get_app_or_404(app_id)
@@ -908,7 +909,7 @@ class AppService:
return config
# 返回默认配置模板(不保存到数据库)
logger.debug(f"配置不存在,返回默认模板", extra={"app_id": str(app_id)})
logger.debug("配置不存在,返回默认模板", extra={"app_id": str(app_id)})
return self._create_default_agent_config(app_id)
def _create_default_agent_config(self, app_id: uuid.UUID) -> AgentConfig:
@@ -981,7 +982,7 @@ class AppService:
ResourceNotFoundException: 当应用不存在时
BusinessException: 当应用缺少配置或不在指定工作空间时
"""
logger.info(f"发布应用", extra={"app_id": str(app_id), "publisher_id": str(publisher_id)})
logger.info("发布应用", extra={"app_id": str(app_id), "publisher_id": str(publisher_id)})
app = self._get_app_or_404(app_id)
# 检查应用归属
@@ -1039,7 +1040,7 @@ class AppService:
}
logger.info(
f"多智能体应用发布配置准备完成",
"多智能体应用发布配置准备完成",
extra={
"app_id": str(app_id),
"master_agent_id": str(multi_agent_cfg.master_agent_id),
@@ -1083,7 +1084,7 @@ class AppService:
self.db.refresh(release)
logger.info(
f"应用发布成功",
"应用发布成功",
extra={"app_id": str(app_id), "version": version, "release_id": str(release.id)}
)
return release
@@ -1107,7 +1108,7 @@ class AppService:
ResourceNotFoundException: 当应用不存在时
BusinessException: 当应用不可访问时
"""
logger.debug(f"获取当前发布版本", extra={"app_id": str(app_id)})
logger.debug("获取当前发布版本", extra={"app_id": str(app_id)})
app = self._get_app_or_404(app_id)
# 只读操作,允许访问共享应用
@@ -1137,7 +1138,7 @@ class AppService:
ResourceNotFoundException: 当应用不存在时
BusinessException: 当应用不可访问时
"""
logger.debug(f"列出发布版本", extra={"app_id": str(app_id)})
logger.debug("列出发布版本", extra={"app_id": str(app_id)})
app = self._get_app_or_404(app_id)
# 只读操作,允许访问共享应用
@@ -1171,7 +1172,7 @@ class AppService:
ResourceNotFoundException: 当应用或版本不存在时
BusinessException: 当应用不在指定工作空间时
"""
logger.info(f"回滚应用", extra={"app_id": str(app_id), "version": version})
logger.info("回滚应用", extra={"app_id": str(app_id), "version": version})
app = self._get_app_or_404(app_id)
self._validate_app_accessible(app, workspace_id)
@@ -1184,7 +1185,7 @@ class AppService:
if not release:
logger.warning(
f"发布版本不存在",
"发布版本不存在",
extra={"app_id": str(app_id), "version": version}
)
raise ResourceNotFoundException("发布版本", f"app_id={app_id}, version={version}")
@@ -1196,7 +1197,7 @@ class AppService:
self.db.refresh(release)
logger.info(
f"应用回滚成功",
"应用回滚成功",
extra={"app_id": str(app_id), "version": version, "release_id": str(release.id)}
)
return release
@@ -1229,7 +1230,7 @@ class AppService:
from app.models import AppShare, Workspace
logger.info(
f"分享应用",
"分享应用",
extra={
"app_id": str(app_id),
"target_workspaces": [str(wid) for wid in target_workspace_ids],
@@ -1268,7 +1269,7 @@ class AppService:
if existing_share:
logger.debug(
f"应用已分享到该工作空间,跳过",
"应用已分享到该工作空间,跳过",
extra={"app_id": str(app_id), "target_workspace_id": str(target_ws_id)}
)
shares.append(existing_share)
@@ -1288,14 +1289,14 @@ class AppService:
shares.append(share)
logger.debug(
f"创建分享记录",
"创建分享记录",
extra={"app_id": str(app_id), "target_workspace_id": str(target_ws_id)}
)
self.db.commit()
logger.info(
f"应用分享成功",
"应用分享成功",
extra={
"app_id": str(app_id),
"shared_count": len(shares),
@@ -1326,7 +1327,7 @@ class AppService:
from app.models import AppShare
logger.info(
f"取消应用分享",
"取消应用分享",
extra={
"app_id": str(app_id),
"target_workspace_id": str(target_workspace_id)
@@ -1346,7 +1347,7 @@ class AppService:
if not share:
logger.warning(
f"分享记录不存在",
"分享记录不存在",
extra={"app_id": str(app_id), "target_workspace_id": str(target_workspace_id)}
)
raise ResourceNotFoundException(
@@ -1359,7 +1360,7 @@ class AppService:
self.db.commit()
logger.info(
f"应用分享已取消",
"应用分享已取消",
extra={"app_id": str(app_id), "target_workspace_id": str(target_workspace_id)}
)
@@ -1384,7 +1385,7 @@ class AppService:
"""
from app.models import AppShare
logger.debug(f"列出应用分享记录", extra={"app_id": str(app_id)})
logger.debug("列出应用分享记录", extra={"app_id": str(app_id)})
# 验证应用
app = self._get_app_or_404(app_id)
@@ -1398,7 +1399,7 @@ class AppService:
shares = list(self.db.scalars(stmt).all())
logger.debug(
f"应用分享记录查询完成",
"应用分享记录查询完成",
extra={"app_id": str(app_id), "count": len(shares)}
)
@@ -1435,7 +1436,7 @@ class AppService:
"""
from app.services.draft_run_service import DraftRunService
logger.info(f"试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]})
logger.info("试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]})
# 1. 验证应用
app = self._get_app_or_404(app_id)
@@ -1464,7 +1465,7 @@ class AppService:
# 4. 调用试运行服务
logger.debug(
f"准备调用试运行服务",
"准备调用试运行服务",
extra={
"app_id": str(app_id),
"model": model_config.name,
@@ -1485,7 +1486,7 @@ class AppService:
)
logger.debug(
f"试运行服务返回结果",
"试运行服务返回结果",
extra={
"result_type": str(type(result)),
"result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict",
@@ -1495,7 +1496,7 @@ class AppService:
)
logger.info(
f"试运行完成",
"试运行完成",
extra={
"app_id": str(app_id),
"elapsed_time": result.get("elapsed_time"),
@@ -1534,7 +1535,7 @@ class AppService:
"""
from app.services.draft_run_service import DraftRunService
logger.info(f"流式试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]})
logger.info("流式试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]})
# 1. 验证应用
app = self._get_app_or_404(app_id)
@@ -1609,7 +1610,7 @@ class AppService:
from app.models import ModelConfig
logger.info(
f"多模型对比试运行",
"多模型对比试运行",
extra={
"app_id": str(app_id),
"model_count": len(models),
@@ -1666,7 +1667,7 @@ class AppService:
)
logger.info(
f"多模型对比完成",
"多模型对比完成",
extra={
"app_id": str(app_id),
"successful": result["successful_count"],
@@ -1708,7 +1709,7 @@ class AppService:
from app.models import ModelConfig
logger.info(
f"多模型对比流式试运行",
"多模型对比流式试运行",
extra={
"app_id": str(app_id),
"model_count": len(models)
@@ -1765,7 +1766,7 @@ class AppService:
yield event
logger.info(
f"多模型对比流式完成",
"多模型对比流式完成",
extra={"app_id": str(app_id)}
)

View File

@@ -162,7 +162,7 @@ def register_user_with_invite(
# 接受工作空间邀请(此时用户已成为工作空间成员,并且会 commit
invite_accept = InviteAcceptRequest(token=invite_token)
workspace_service.accept_workspace_invite(db, invite_accept, user)
logger.info(f"用户接受邀请成功")
logger.info("用户接受邀请成功")
# 重新查询用户对象以确保获取最新状态
from app.repositories import user_repository
@@ -200,7 +200,7 @@ def bind_workspace_with_invite(
# 接受工作空间邀请(此时用户已成为工作空间成员,并且会 commit
invite_accept = InviteAcceptRequest(token=invite_token)
workspace_service.accept_workspace_invite(db, invite_accept, user)
logger.info(f"用户接受邀请成功")
logger.info("用户接受邀请成功")
# 重新查询用户对象以确保获取最新状态
from app.repositories import user_repository

View File

@@ -42,7 +42,7 @@ class ConversationService:
self.db.refresh(conversation)
logger.info(
f"创建会话成功",
"创建会话成功",
extra={
"conversation_id": str(conversation.id),
"app_id": str(app_id),
@@ -201,7 +201,7 @@ class ConversationService:
)
logger.debug(
f"保存会话消息成功",
"保存会话消息成功",
extra={
"conversation_id": str(conversation_id),
"user_message_length": len(user_message),
@@ -221,7 +221,7 @@ class ConversationService:
self.db.commit()
logger.info(
f"删除会话成功",
"删除会话成功",
extra={
"conversation_id": str(conversation_id),
"workspace_id": str(workspace_id)

View File

@@ -74,7 +74,7 @@ class ConversationStateManager:
state["same_agent_turns"] = 0
logger.info(
f"Agent 切换",
"Agent 切换",
extra={
"conversation_id": conversation_id,
"from": state["current_agent_id"],

File diff suppressed because it is too large Load Diff

View File

@@ -88,7 +88,7 @@ class LLMRouter:
路由结果
"""
logger.info(
f"开始 LLM 智能路由",
"开始 LLM 智能路由",
extra={
"message_length": len(message),
"conversation_id": conversation_id,
@@ -177,7 +177,7 @@ class LLMRouter:
}
logger.info(
f"路由完成",
"路由完成",
extra={
"agent_id": agent_id,
"strategy": strategy,
@@ -393,7 +393,7 @@ class LLMRouter:
# 打印供应商信息
logger.info(
f"LLM 路由使用模型",
"LLM 路由使用模型",
extra={
"provider": api_key_config.provider,
"model_name": api_key_config.model_name,
@@ -680,6 +680,6 @@ class LLMRouter:
return self.routing_rules[0].get("target_agent_id")
if self.sub_agents:
return list(self.sub_agents.keys())[0]
return next(iter(self.sub_agents.keys()))
return "default-agent"

View File

@@ -0,0 +1,593 @@
"""Master Agent 路由器 - 让 Master Agent 真正成为决策中心"""
import json
import re
import uuid
from typing import Dict, Any, List, Optional, Tuple
from sqlalchemy.orm import Session
from app.services.conversation_state_manager import ConversationStateManager
from app.models import ModelConfig, AgentConfig
from app.core.logging_config import get_business_logger
logger = get_business_logger()
class MasterAgentRouter:
"""Master Agent 路由器
让 Master Agent 作为"大脑",负责:
1. 分析用户意图
2. 选择最合适的 Sub Agent
3. 决定是否需要多 Agent 协作
4. 管理会话上下文
优势:
- 更智能的决策(基于完整上下文)
- 减少 LLM 调用次数
- 架构更清晰Master Agent 真正起作用)
"""
def __init__(
self,
db: Session,
master_agent_config: AgentConfig,
master_model_config: ModelConfig,
sub_agents: Dict[str, Any],
state_manager: ConversationStateManager,
enable_rule_fast_path: bool = True
):
"""初始化 Master Agent 路由器
Args:
db: 数据库会话
master_agent_config: Master Agent 配置
master_model_config: Master Agent 使用的模型配置
sub_agents: 子 Agent 配置字典
state_manager: 会话状态管理器
enable_rule_fast_path: 是否启用规则快速路径(性能优化)
"""
self.db = db
self.master_agent_config = master_agent_config
self.master_model_config = master_model_config
self.sub_agents = sub_agents
self.state_manager = state_manager
self.enable_rule_fast_path = enable_rule_fast_path
logger.info(
"Master Agent 路由器初始化",
extra={
"master_agent": master_agent_config.name,
"sub_agent_count": len(sub_agents),
"enable_rule_fast_path": enable_rule_fast_path
}
)
async def route(
self,
message: str,
conversation_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
"""智能路由决策
Args:
message: 用户消息
conversation_id: 会话 ID
variables: 变量参数
Returns:
路由决策结果
"""
logger.info(
"开始 Master Agent 路由",
extra={
"message_length": len(message),
"conversation_id": conversation_id
}
)
# 1. 获取会话状态
state = None
if conversation_id:
state = self.state_manager.get_state(conversation_id)
# 2. 尝试规则快速路径(可选的性能优化)
if self.enable_rule_fast_path:
rule_result = self._try_rule_fast_path(message, state)
if rule_result:
logger.info(
"规则快速路径命中",
extra={
"agent_id": rule_result["selected_agent_id"],
"confidence": rule_result["confidence"]
}
)
# 更新会话状态
if conversation_id:
self.state_manager.update_state(
conversation_id,
rule_result["selected_agent_id"],
message,
rule_result.get("topic"),
rule_result["confidence"]
)
return rule_result
# 3. 调用 Master Agent 做决策
decision = await self._master_agent_decide(message, state, variables)
# 4. 更新会话状态
if conversation_id:
self.state_manager.update_state(
conversation_id,
decision["selected_agent_id"],
message,
decision.get("topic"),
decision["confidence"]
)
logger.info(
"Master Agent 路由完成",
extra={
"agent_id": decision["selected_agent_id"],
"strategy": decision["strategy"],
"confidence": decision["confidence"]
}
)
return decision
def _try_rule_fast_path(
self,
message: str,
state: Optional[Dict[str, Any]]
) -> Optional[Dict[str, Any]]:
"""尝试规则快速路径(性能优化)
对于明确的关键词匹配,直接返回结果,不调用 Master Agent
Args:
message: 用户消息
state: 会话状态
Returns:
如果命中规则返回决策结果,否则返回 None
"""
# 定义高置信度关键词规则
high_confidence_rules = [
{
"keywords": ["数学", "方程", "计算", "求解"],
"agent_role": "数学",
"confidence_threshold": 0.9
},
{
"keywords": ["物理", "力学", "电路", "光学"],
"agent_role": "物理",
"confidence_threshold": 0.9
},
{
"keywords": ["订单", "发货", "物流", "快递"],
"agent_role": "订单",
"confidence_threshold": 0.9
},
{
"keywords": ["退款", "退货", "售后"],
"agent_role": "退款",
"confidence_threshold": 0.9
}
]
message_lower = message.lower()
for rule in high_confidence_rules:
matched_keywords = [kw for kw in rule["keywords"] if kw in message_lower]
if matched_keywords:
confidence = len(matched_keywords) / len(rule["keywords"])
if confidence >= rule["confidence_threshold"]:
# 查找对应的 agent
for agent_id, agent_data in self.sub_agents.items():
agent_info = agent_data.get("info", {})
if agent_info.get("role") == rule["agent_role"]:
return {
"selected_agent_id": agent_id,
"confidence": confidence,
"strategy": "rule_fast_path",
"reasoning": f"关键词匹配: {', '.join(matched_keywords)}",
"topic": rule["agent_role"],
"need_collaboration": False,
"routing_method": "rule"
}
return None
async def _master_agent_decide(
self,
message: str,
state: Optional[Dict[str, Any]],
variables: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
"""让 Master Agent 做路由决策
Args:
message: 用户消息
state: 会话状态
variables: 变量参数
Returns:
决策结果
"""
# 1. 构建决策 prompt
prompt = self._build_decision_prompt(message, state, variables)
# 2. 调用 Master Agent 的 LLM
try:
response = await self._call_master_agent_llm(prompt)
# 3. 解析决策
decision = self._parse_decision(response)
# 4. 验证决策
decision = self._validate_decision(decision)
return decision
except Exception as e:
logger.error(f"Master Agent 决策失败: {str(e)}")
# 降级到默认 agent
return self._get_fallback_decision(message)
def _build_decision_prompt(
self,
message: str,
state: Optional[Dict[str, Any]],
variables: Optional[Dict[str, Any]]
) -> str:
"""构建 Master Agent 的决策 prompt
Args:
message: 用户消息
state: 会话状态
variables: 变量参数
Returns:
prompt 字符串
"""
# 1. 构建 Sub Agent 描述(简化版,提升性能)
agent_descriptions = []
for agent_id, agent_data in self.sub_agents.items():
agent_info = agent_data.get("info", {})
name = agent_info.get("name", "未命名")
role = agent_info.get("role", "")
capabilities = agent_info.get("capabilities", [])
# 简化格式:一行描述
desc = f"- {agent_id}: {name}"
if role:
desc += f" ({role})"
if capabilities:
desc += f" - {', '.join(capabilities[:3])}" # 只取前3个能力
agent_descriptions.append(desc)
agents_text = "\n".join(agent_descriptions)
# 2. 构建会话上下文
context_text = ""
if state:
current_agent = state.get("current_agent_id")
last_topic = state.get("last_topic")
same_turns = state.get("same_agent_turns", 0)
if current_agent:
context_text = f"""
当前会话上下文:
- 当前使用的 Agent: {current_agent}
- 上一个主题: {last_topic}
- 连续使用轮数: {same_turns}
"""
# 获取第一个可用的 agent_id 作为示例
example_agent_id = next(iter(self.sub_agents.keys())) if self.sub_agents else "agent_id"
# 3. 构建完整 prompt简化版提升性能
prompt = f"""路由任务:分析问题并选择合适的 Agent。
可用 Agent
{agents_text}
{context_text}
问题:"{message}"
返回 JSON 格式决策:
**情况1单一问题最常见**
{{"selected_agent_id": "{example_agent_id}", "confidence": 0.9, "need_collaboration": false, "reasoning": "选择理由"}}
**情况2需要拆分成多个独立子问题**
当用户问题包含多个完全独立的子问题时使用(如"写诗+做数学题")。
必须提供 sub_questions 数组,每个子问题必须指定 agent_id
{{"selected_agent_id": "{example_agent_id}", "confidence": 0.9, "need_collaboration": true, "need_decomposition": true,
"sub_questions": [
{{"question": "具体的子问题1", "agent_id": "{example_agent_id}", "order": 1, "depends_on": []}},
{{"question": "具体的子问题2", "agent_id": "{example_agent_id}", "order": 2, "depends_on": []}}
],
"collaboration_strategy": "decomposition", "reasoning": "问题包含X个独立子问题"}}
**情况3需要多个Agent协作分析同一问题**
{{"selected_agent_id": "{example_agent_id}", "confidence": 0.9, "need_collaboration": true,
"collaboration_agents": [{{"agent_id": "{example_agent_id}", "role": "primary", "task": "主要任务", "order": 1}}],
"collaboration_strategy": "sequential", "reasoning": "需要多角度分析"}}
重要规则:
1. selected_agent_id 必须从上面的可用 Agent 列表中选择
2. 如果选择情况2拆分sub_questions 数组不能为空,必须包含具体的子问题
3. 每个子问题的 agent_id 也必须从可用列表中选择
4. depends_on 表示依赖关系(如 [1] 表示依赖第1个子问题的结果
5. 大多数情况应该选择情况1单一Agent只有明确需要时才拆分或协作
6. 只做路由决策,不要回答问题内容
请返回 JSON"""
return prompt
async def _call_master_agent_llm(self, prompt: str) -> str:
"""调用 Master Agent 的 LLM
Args:
prompt: 提示词
Returns:
LLM 响应
"""
try:
from app.core.models import RedBearLLM
from app.core.models.base import RedBearModelConfig
from app.models import ModelApiKey, ModelType
# 获取 API Key 配置
api_key_config = self.db.query(ModelApiKey).filter(
ModelApiKey.model_config_id == self.master_model_config.id,
ModelApiKey.is_active == True
).first()
if not api_key_config:
raise Exception("Master Agent 模型没有可用的 API Key")
logger.info(
"调用 Master Agent LLM",
extra={
"provider": api_key_config.provider,
"model_name": api_key_config.model_name
}
)
# 创建 RedBearModelConfig
model_config = RedBearModelConfig(
model_name=api_key_config.model_name,
provider=api_key_config.provider,
api_key=api_key_config.api_key,
base_url=api_key_config.api_base,
temperature=0.3, # 决策任务使用较低温度
max_tokens=1000
)
# 创建 LLM 实例
llm = RedBearLLM(model_config, type=ModelType.CHAT)
# 调用模型
response = await llm.ainvoke(prompt)
# 提取响应内容
if hasattr(response, 'content'):
return response.content
else:
return str(response)
except Exception as e:
logger.error(f"Master Agent LLM 调用失败: {str(e)}")
raise
def _parse_decision(self, response: str) -> Dict[str, Any]:
"""解析 Master Agent 的决策
Args:
response: LLM 响应
Returns:
决策字典
"""
try:
# 提取 JSON
json_match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', response, re.DOTALL)
if json_match:
decision = json.loads(json_match.group())
# 添加默认值
decision.setdefault("confidence", 0.8)
decision.setdefault("strategy", "master_agent")
decision.setdefault("routing_method", "master_agent")
decision.setdefault("need_collaboration", False)
decision.setdefault("collaboration_agents", [])
return decision
else:
raise ValueError("无法从响应中提取 JSON")
except Exception as e:
logger.error(f"解析 Master Agent 决策失败: {str(e)}")
logger.debug(f"原始响应: {response}")
raise
def _validate_decision(self, decision: Dict[str, Any]) -> Dict[str, Any]:
"""验证决策的有效性
Args:
decision: 决策字典
Returns:
验证后的决策
"""
# 验证 agent_id
selected_agent_id = decision.get("selected_agent_id")
if selected_agent_id not in self.sub_agents:
logger.warning(f"Master Agent 返回的 agent_id 无效: {selected_agent_id}")
# 使用默认 agent
decision["selected_agent_id"] = self._get_default_agent_id()
decision["confidence"] = 0.5
decision["reasoning"] = "原始选择无效,使用默认 Agent"
# 验证 confidence
confidence = decision.get("confidence", 0.8)
if not isinstance(confidence, (int, float)) or confidence < 0 or confidence > 1:
decision["confidence"] = 0.8
# 验证协作 agents
if decision.get("need_collaboration"):
# 检查是否是问题拆分模式
if decision.get("need_decomposition") or decision.get("sub_questions"):
# 问题拆分模式
sub_questions = decision.get("sub_questions", [])
# 验证每个子问题
valid_sub_questions = []
for sub_q in sub_questions:
if isinstance(sub_q, dict):
agent_id = sub_q.get("agent_id")
question = sub_q.get("question")
if agent_id in self.sub_agents and question:
# 确保有必要的字段
sub_q.setdefault("order", len(valid_sub_questions) + 1)
sub_q.setdefault("depends_on", [])
valid_sub_questions.append(sub_q)
else:
# 记录验证失败的原因
logger.warning(
"子问题验证失败",
extra={
"agent_id": agent_id,
"has_question": bool(question),
"agent_exists": agent_id in self.sub_agents if agent_id else False,
"available_agents": list(self.sub_agents.keys())
}
)
decision["sub_questions"] = valid_sub_questions
# 如果所有子问题都验证失败,降级处理
if not valid_sub_questions and sub_questions:
logger.warning(
"所有子问题验证失败,降级到单 Agent 模式",
extra={
"original_sub_question_count": len(sub_questions),
"available_agents": list(self.sub_agents.keys())
}
)
# 降级:取消协作标记,使用默认 Agent
decision["need_collaboration"] = False
decision["need_decomposition"] = False
decision["collaboration_strategy"] = None
# 选择第一个可用的 Agent
if self.sub_agents:
first_agent_id = next(iter(self.sub_agents.keys()))
decision["selected_agent_id"] = first_agent_id
logger.info(f"降级使用默认 Agent: {first_agent_id}")
# 设置协作策略为 decomposition
decision["collaboration_strategy"] = "decomposition"
logger.info(
"问题拆分决策验证完成",
extra={
"sub_question_count": len(valid_sub_questions),
"strategy": "decomposition"
}
)
else:
# 普通协作模式
collaboration_agents = decision.get("collaboration_agents", [])
# 如果是简单列表格式,转换为详细格式
if collaboration_agents and isinstance(collaboration_agents[0], str):
collaboration_agents = [
{
"agent_id": agent_id,
"role": "primary" if i == 0 else "secondary",
"task": "协作处理",
"order": i + 1
}
for i, agent_id in enumerate(collaboration_agents)
]
# 验证每个协作 agent
valid_agents = []
for agent_info in collaboration_agents:
if isinstance(agent_info, dict):
agent_id = agent_info.get("agent_id")
if agent_id in self.sub_agents:
# 确保有必要的字段
agent_info.setdefault("role", "secondary")
agent_info.setdefault("task", "协作处理")
agent_info.setdefault("order", len(valid_agents) + 1)
valid_agents.append(agent_info)
elif isinstance(agent_info, str) and agent_info in self.sub_agents:
valid_agents.append({
"agent_id": agent_info,
"role": "secondary",
"task": "协作处理",
"order": len(valid_agents) + 1
})
decision["collaboration_agents"] = valid_agents
# 设置默认协作策略
if not decision.get("collaboration_strategy"):
decision["collaboration_strategy"] = "sequential"
logger.info(
"协作决策验证完成",
extra={
"collaboration_agent_count": len(valid_agents),
"strategy": decision.get("collaboration_strategy")
}
)
return decision
def _get_fallback_decision(self, message: str) -> Dict[str, Any]:
"""获取降级决策(当 Master Agent 失败时)
Args:
message: 用户消息
Returns:
降级决策
"""
default_agent_id = self._get_default_agent_id()
return {
"selected_agent_id": default_agent_id,
"confidence": 0.5,
"strategy": "fallback",
"reasoning": "Master Agent 决策失败,使用默认 Agent",
"topic": "未知",
"need_collaboration": False,
"collaboration_agents": [],
"routing_method": "fallback"
}
def _get_default_agent_id(self) -> str:
"""获取默认 Agent ID
Returns:
默认 Agent ID
"""
if self.sub_agents:
# 返回第一个 agent
return next(iter(self.sub_agents.keys()))
return "default-agent"

View File

@@ -29,8 +29,7 @@ from app.core.memory.agent.utils.mcp_tools import get_mcp_server_config
from app.core.memory.agent.utils.type_classifier import status_typle
from app.db import get_db
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# TODO 后续更新
# from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.utils.llm.llm_utils import get_llm_client
from app.schemas.memory_storage_schema import ApiResponse, ok, fail
from app.models.knowledge_model import Knowledge, KnowledgeType
@@ -697,7 +696,7 @@ class MemoryAgentService:
logger.info(f"知识库类型统计成功 (workspace_id={current_workspace_id}): {result}")
else:
# 没有提供 workspace_id所有知识库类型返回 0
logger.info(f"未提供 workspace_id知识库类型统计全部为 0")
logger.info("未提供 workspace_id知识库类型统计全部为 0")
except Exception as e:
logger.error(f"知识库类型统计失败: {e}")
@@ -720,7 +719,7 @@ class MemoryAgentService:
end_users = []
for app_id in app_ids:
end_user_orm_list = end_user_repository.get_end_users_by_app_id(db, app_id)
end_users.extend([EndUserSchema.model_validate(h) for h in end_user_orm_list])
end_users.extend(h for h in end_user_orm_list)
# 统计所有宿主的 Chunk 总数
total_chunks = 0
@@ -742,7 +741,7 @@ class MemoryAgentService:
else:
# 没有 workspace_id 时,返回 0
result["memory"] = 0
logger.info(f"未提供 workspace_idmemory 统计为 0")
logger.info("未提供 workspace_idmemory 统计为 0")
except Exception as e:
logger.error(f"Neo4j memory统计失败: {e}", exc_info=True)

View File

@@ -1,5 +1,5 @@
from sqlalchemy.orm import Session
from typing import List
from typing import List, Optional
import uuid
from fastapi import HTTPException
@@ -24,6 +24,30 @@ from app.core.logging_config import get_business_logger
business_logger = get_business_logger()
def get_current_workspace_type(
db: Session,
workspace_id: uuid.UUID,
current_user: User
) -> Optional[str]:
"""获取当前工作空间类型"""
business_logger.info(f"获取工作空间类型: workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
from app.repositories.workspace_repository import get_workspace_by_id
workspace = get_workspace_by_id(db, workspace_id)
if not workspace:
business_logger.warning(f"工作空间不存在: workspace_id={workspace_id}")
return None
business_logger.info(f"成功获取工作空间类型: {workspace.storage_type}")
return workspace.storage_type
except Exception as e:
business_logger.error(f"获取工作空间类型失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_workspace_end_users(
db: Session,
workspace_id: uuid.UUID,
@@ -169,7 +193,7 @@ def get_workspace_memory_list(
business_logger.warning(f"获取宿主列表失败: {str(e)}")
result["hosts"] = []
business_logger.info(f"成功获取工作空间记忆列表")
business_logger.info("成功获取工作空间记忆列表")
return result
except HTTPException:
@@ -587,7 +611,7 @@ async def get_chunk_insight(
"insight": insight
}
business_logger.info(f"成功获取chunk洞察")
business_logger.info("成功获取chunk洞察")
return result
except Exception as e:

View File

@@ -168,7 +168,7 @@ async def get_document_chunks(
# 执行分页查询
try:
api_logger.debug(f"开始执行文档块查询")
api_logger.debug("开始执行文档块查询")
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
total, items = vector_service.search_by_segment(
document_id=str(document_id),
@@ -516,67 +516,18 @@ async def write_rag(group_id, message, user_rag_memory_id):
db=db,
current_user=current_user
)
await parse_document_by_id(document, db=db, current_user=current_user)
# 重新查询刚创建的文档ID
new_document_id = find_document_id_by_kb_and_filename(
db=db,
kb_id=user_rag_memory_id,
file_name=f"{group_id}.txt"
)
if new_document_id:
await parse_document_by_id(new_document_id, db=db, current_user=current_user)
else:
api_logger.error(f"创建文档后无法找到文档ID: group_id={group_id}")
return result
finally:
# 确保数据库会话被关闭
db.close()
# 在异步环境中调用示例
async def example_usage():
# 获取数据库会话
db_gen = get_db()
db = next(db_gen)
# 创建 CustomTextFileCreate 对象
title = '2f6ff1eb-50c7-4765-8e89-e4566be19122'
create_data = CustomTextFileCreate(
title=title,
content="莫扎特在巴黎经历母亲去世后返回萨尔茨堡,他随后创作的交响曲主题是否与格鲁克在维也纳推动的“改革歌剧”理念存在共通之处?贝多芬早年曾师从海顿,而海顿又受雇于埃斯特哈齐家族——这种师承体系如何影响了当时欧洲宫廷音乐的传承结构?斯卡拉歌剧院选择萨列里的歌剧作为开幕演出,是否与当时米兰政治环境和奥地利宫廷影响有关?"
)
# 创建用户对象
current_user = SimpleUser("6243c125-9420-402c-bbb5-d1977811ac96")
# 上传文件
result = await memory_konwledges_up(
kb_id="c71df60a-36a6-4759-a2ce-101e3087b401",
parent_id="c71df60a-36a6-4759-a2ce-101e3087b401",
create_data=create_data,
db=db,
current_user=current_user
)
print(result)
#找到document_id
# 使用刚创建的文档ID进行解析
document = find_document_id_by_kb_and_filename(db=db, kb_id="c71df60a-36a6-4759-a2ce-101e3087b401", file_name=f"{title}.txt")
print('====',document)
res___=await parse_document_by_id(document, db=db, current_user=current_user)
print(res___)
# result='e8cf9ace-d1a9-4af2-b0c4-3fc94f4f8042'
# document_id='d22e8173-50d0-4e10-a7de-aa638ef893bc'
#
# '''更新块'''
#
# new_content = "这是新的 chunk 内容,用来覆盖原来的内容"
# # 构造 ChunkUpdate 对象
# update_data = ChunkCreate(content=new_content)
# updated_chunk = await create_document_chunk(
# kb_id= result,
# document_id=document_id,
# create_data= update_data,
# db=db,
# current_user=current_user
# )
# print(updated_chunk)
return '','',''
if __name__ == "__main__":
# asyncio.run(example_usage())
asyncio.run(write_rag('1111','22222',"c71df60a-36a6-4759-a2ce-101e3087b401"))
db.close()

View File

@@ -7,9 +7,12 @@ Handles business logic for memory storage operations.
from typing import Dict, List, Optional, Any
import os
import json
from sqlalchemy.orm import Session
from dotenv import load_dotenv
from app.models.user_model import User
from app.models.end_user_model import EndUser
from app.core.logging_config import get_logger
from app.schemas.memory_storage_schema import (
ConfigFilter,
@@ -23,11 +26,10 @@ from app.schemas.memory_storage_schema import (
)
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
# TODO 后续更新
# from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
# from app.core.memory.analytics.memory_insight import MemoryInsight
# from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
# from app.core.memory.analytics.user_summary import generate_user_summary
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.analytics.memory_insight import MemoryInsight
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
from app.core.memory.analytics.user_summary import generate_user_summary
from app.repositories.data_config_repository import DataConfigRepository
logger = get_logger(__name__)
@@ -52,7 +54,7 @@ class MemoryStorageService:
Returns:
Storage information dictionary
"""
logger.info(f"Getting storage info ")
logger.info("Getting storage info ")
# Empty wrapper - implement your logic here
result = {
@@ -65,30 +67,28 @@ class MemoryStorageService:
class DataConfigService: # 数据配置服务类PostgreSQL
"""Service layer for config params CRUD.
The DB connection is optional; when absent, methods return a failure
response containing an SQL preview to aid integration.
使用 SQLAlchemy ORM 进行数据库操作。
"""
def __init__(self, db_conn: Optional[Any] = None) -> None:
self.db_conn = db_conn
# --- Driver compatibility helpers ---
@staticmethod
def _is_pgsql_conn(conn: Any) -> bool: # 判断是否为 PostgreSQL 连接
mod = type(conn).__module__
return ("psycopg2" in mod) or ("psycopg" in mod)
def __init__(self, db: Session) -> None:
"""初始化服务
Args:
db: SQLAlchemy 数据库会话
"""
self.db = db
@staticmethod
def _convert_timestamps_to_format(data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""将 created_at 和 updated_at 字段从 datetime 对象转换为 YYYYMMDDHHmmss 格式"""
from datetime import datetime
for item in data_list:
for field in ['created_at', 'updated_at']:
if field in item and item[field] is not None:
value = item[field]
dt = None
# 如果是 datetime 对象,直接使用
if isinstance(value, datetime):
dt = value
@@ -98,24 +98,21 @@ class DataConfigService: # 数据配置服务类PostgreSQL
dt = datetime.fromisoformat(value.replace('Z', '+00:00'))
except Exception:
pass # 保持原值
# 转换为 YYYYMMDDHHmmss 格式
if dt:
item[field] = dt.strftime('%Y%m%d%H%M%S')
return data_list
# --- Create ---
def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述)
if self.db_conn is None:
raise ConnectionError("数据库连接未配置")
# 如果workspace_id存在且模型字段未全部指定则自动获取
if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]):
configs = self._get_workspace_configs(params.workspace_id)
if configs is None:
raise ValueError(f"工作空间不存在: workspace_id={params.workspace_id}")
# 只在未指定时填充(允许手动覆盖)
if not params.llm_id:
params.llm_id = configs.get('llm')
@@ -123,19 +120,16 @@ class DataConfigService: # 数据配置服务类PostgreSQL
params.embedding_id = configs.get('embedding')
if not params.rerank_id:
params.rerank_id = configs.get('rerank')
query, qparams = DataConfigRepository.build_insert(params)
cur = self.db_conn.cursor()
# PostgreSQL 使用 psycopg2 的命名参数格式
cur.execute(query, qparams)
self.db_conn.commit()
return {"affected": getattr(cur, "rowcount", None)}
config = DataConfigRepository.create(self.db, params)
self.db.commit()
return {"affected": 1, "config_id": config.config_id}
def _get_workspace_configs(self, workspace_id) -> Optional[Dict[str, Any]]:
"""获取工作空间模型配置(内部方法,便于测试)"""
from app.db import SessionLocal
from app.repositories.workspace_repository import get_workspace_models_configs
db_session = SessionLocal()
try:
return get_workspace_models_configs(db_session, workspace_id)
@@ -143,121 +137,91 @@ class DataConfigService: # 数据配置服务类PostgreSQL
db_session.close()
# --- Delete ---
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置名称
query, qparams = DataConfigRepository.build_delete(key)
if self.db_conn is None:
raise ConnectionError("数据库连接未配置")
cur = self.db_conn.cursor()
cur.execute(query, qparams)
affected = getattr(cur, "rowcount", None)
self.db_conn.commit()
# 如果没有任何行被删除,抛出异常
if not affected:
def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID
success = DataConfigRepository.delete(self.db, key.config_id)
if not success:
raise ValueError("未找到配置")
return {"affected": affected}
return {"affected": 1}
# --- Update ---
def update(self, update: ConfigUpdate) -> Dict[str, Any]: # 部分更新配置参数
query, qparams = DataConfigRepository.build_update(update)
if self.db_conn is None:
raise ConnectionError("数据库连接未配置")
cur = self.db_conn.cursor()
cur.execute(query, qparams)
affected = getattr(cur, "rowcount", None)
self.db_conn.commit()
if not affected:
config = DataConfigRepository.update(self.db, update)
if not config:
raise ValueError("未找到配置")
return {"affected": affected}
return {"affected": 1}
def update_extracted(self, update: ConfigUpdateExtracted) -> Dict[str, Any]: # 更新记忆萃取引擎配置参数
query, qparams = DataConfigRepository.build_update_extracted(update)
if self.db_conn is None:
raise ConnectionError("数据库连接未配置")
cur = self.db_conn.cursor()
cur.execute(query, qparams)
affected = getattr(cur, "rowcount", None)
self.db_conn.commit()
if not affected:
config = DataConfigRepository.update_extracted(self.db, update)
if not config:
raise ValueError("未找到配置")
return {"affected": affected}
return {"affected": 1}
# --- Forget config params ---
def update_forget(self, update: ConfigUpdateForget) -> Dict[str, Any]: # 保存遗忘引擎的配置
query, qparams = DataConfigRepository.build_update_forget(update)
if self.db_conn is None:
raise ConnectionError("数据库连接未配置")
cur = self.db_conn.cursor()
cur.execute(query, qparams)
affected = getattr(cur, "rowcount", None)
self.db_conn.commit()
if not affected:
config = DataConfigRepository.update_forget(self.db, update)
if not config:
raise ValueError("未找到配置")
return {"affected": affected}
return {"affected": 1}
# --- Read ---
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取配置参数
query, qparams = DataConfigRepository.build_select_extracted(key)
if self.db_conn is None:
raise ConnectionError("数据库连接未配置")
cur = self.db_conn.cursor()
cur.execute(query, qparams)
row = cur.fetchone()
if not row:
def get_extracted(self, key: ConfigKey) -> Dict[str, Any]: # 获取萃取配置参数
result = DataConfigRepository.get_extracted_config(self.db, key.config_id)
if not result:
raise ValueError("未找到配置")
# Map row to dict (DB-API cursor description available for many drivers)
columns = [desc[0] for desc in cur.description]
raw = {columns[i]: row[i] for i in range(len(columns))}
# 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
data_list = self._convert_timestamps_to_format([raw])
return data_list[0] if data_list else raw
return result
def get_forget(self, key: ConfigKey) -> Dict[str, Any]: # 获取配置参数
query, qparams = DataConfigRepository.build_select_forget(key)
if self.db_conn is None:
raise ConnectionError("数据库连接未配置")
cur = self.db_conn.cursor()
cur.execute(query, qparams)
row = cur.fetchone()
if not row:
def get_forget(self, key: ConfigKey) -> Dict[str, Any]: # 获取遗忘配置参数
result = DataConfigRepository.get_forget_config(self.db, key.config_id)
if not result:
raise ValueError("未找到配置")
# Map row to dict (DB-API cursor description available for many drivers)
columns = [desc[0] for desc in cur.description]
raw = {columns[i]: row[i] for i in range(len(columns))}
# 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
data_list = self._convert_timestamps_to_format([raw])
return data_list[0] if data_list else raw
return result
# --- Read All ---
def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数
query, qparams = DataConfigRepository.build_select_all(workspace_id)
if self.db_conn is None:
raise ConnectionError("数据库连接未配置")
configs = DataConfigRepository.get_all(self.db, workspace_id)
cur = self.db_conn.cursor()
cur.execute(query, qparams)
rows = cur.fetchall()
# 如果没有查询到任何配置,返回空列表(这是正常情况,不应抛出异常)
if not rows:
return []
# Map rows to list of dicts
columns = [desc[0] for desc in cur.description]
data_list = [dict(zip(columns, row)) for row in rows]
# 将 UUID 转换为字符串,将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
for item in data_list:
if 'workspace_id' in item and item['workspace_id'] is not None:
item['workspace_id'] = str(item['workspace_id'])
# 将 ORM 对象转换为字典列表
data_list = []
for config in configs:
config_dict = {
"config_id": config.config_id,
"config_name": config.config_name,
"config_desc": config.config_desc,
"workspace_id": str(config.workspace_id) if config.workspace_id else None,
"group_id": config.group_id,
"user_id": config.user_id,
"apply_id": config.apply_id,
"llm_id": config.llm_id,
"embedding_id": config.embedding_id,
"rerank_id": config.rerank_id,
"llm": config.llm,
"enable_llm_dedup_blockwise": config.enable_llm_dedup_blockwise,
"enable_llm_disambiguation": config.enable_llm_disambiguation,
"deep_retrieval": config.deep_retrieval,
"t_type_strict": config.t_type_strict,
"t_name_strict": config.t_name_strict,
"t_overall": config.t_overall,
"state": config.state,
"chunker_strategy": config.chunker_strategy,
"pruning_enabled": config.pruning_enabled,
"pruning_scene": config.pruning_scene,
"pruning_threshold": config.pruning_threshold,
"enable_self_reflexion": config.enable_self_reflexion,
"iteration_period": config.iteration_period,
"reflexion_range": config.reflexion_range,
"baseline": config.baseline,
"statement_granularity": config.statement_granularity,
"include_dialogue_context": config.include_dialogue_context,
"max_context": config.max_context,
"lambda_time": config.lambda_time,
"lambda_mem": config.lambda_mem,
"offset": config.offset,
"created_at": config.created_at,
"updated_at": config.updated_at,
}
data_list.append(config_dict)
# 将 created_at 和 updated_at 转换为 YYYYMMDDHHmmss 格式
return self._convert_timestamps_to_format(data_list)
@@ -296,7 +260,7 @@ class DataConfigService: # 数据配置服务类PostgreSQL
# 应用内存覆写并刷新常量(在导入主管线前)
# 注意:仅在内存中覆写配置,不修改 runtime.json 文件
from app.core.memory.utils.config.definitions import reload_configuration_from_database
ok_override = reload_configuration_from_database(cid)
if not ok_override:
raise RuntimeError("运行时覆写失败config_id 无效或刷新常量失败")
@@ -308,7 +272,7 @@ class DataConfigService: # 数据配置服务类PostgreSQL
logger.info(f"[PILOT_RUN] Calling pipeline_main with dialogue_text length: {len(dialogue_text)}, is_pilot_run=True")
await pipeline_main(dialogue_text=dialogue_text, is_pilot_run=True)
logger.info("[PILOT_RUN] pipeline_main completed")
# 调用自我反思
# data = [
# {
@@ -346,10 +310,10 @@ class DataConfigService: # 数据配置服务类PostgreSQL
result_path = settings.get_memory_output_path("extracted_result.json")
if not os.path.isfile(result_path):
raise FileNotFoundError(f"试运行完成,但未找到提取结果文件: {result_path}")
with open(result_path, "r", encoding="utf-8") as rf:
extracted_result = json.load(rf)
extracted_result["self_reflexion"] = reflexion_result if reflexion_result else None
return {
"config_id": cid,
@@ -405,7 +369,7 @@ async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]:
DataConfigRepository.SEARCH_FOR_ALL,
group_id=end_user_id,
)
# 检查结果是否为空或长度不足
if not result or len(result) < 4:
data = {
@@ -418,7 +382,7 @@ async def search_all(end_user_id: Optional[str] = None) -> Dict[str, Any]:
},
}
return data
data = {
"total": result[-1]["Count"],
"counts": {
@@ -504,14 +468,27 @@ async def search_entity_graph(end_user_id: Optional[str] = None) -> Dict[str, An
return data
async def analytics_hot_memory_tags(end_user_id: Optional[str] = None, limit: int = 10) -> List[Dict[str, Any]]:
async def analytics_hot_memory_tags(
db: Session,
current_user: User,
limit: int = 10
) -> List[Dict[str, Any]]:
"""
获取热门记忆标签按数量排序并返回前N个
"""
workspace_id = current_user.current_workspace_id
# 获取更多标签供LLM筛选获取limit*4个标签
raw_limit = limit * 4
tags = await get_hot_memory_tags(end_user_id, limit=raw_limit)
from app.services.memory_dashboard_service import get_workspace_end_users
end_users = get_workspace_end_users(db, workspace_id, current_user)
tags = []
for end_user in end_users:
tag = await get_hot_memory_tags(str(end_user.id), limit=raw_limit)
if tag:
# 将每个用户的标签列表展平到总列表中
tags.extend(tag)
# 按频率降序排序(虽然数据库已经排序,但为了确保正确性再次排序)
sorted_tags = sorted(tags, key=lambda x: x[1], reverse=True)

View File

@@ -53,13 +53,13 @@ class ModelParameterMerger:
# 应用模型配置参数
if model_config_params:
for key in default_params.keys():
for key in default_params:
if key in model_config_params:
merged[key] = model_config_params[key]
# 应用 Agent 配置参数(优先级最高)
if agent_config_params:
for key in default_params.keys():
for key in default_params:
if key in agent_config_params and agent_config_params[key] is not None:
merged[key] = agent_config_params[key]
@@ -67,7 +67,7 @@ class ModelParameterMerger:
merged = {k: v for k, v in merged.items() if v is not None}
logger.debug(
f"参数合并完成",
"参数合并完成",
extra={
"model_params": model_config_params,
"agent_params": agent_config_params,

View File

@@ -24,17 +24,17 @@ class ModelConfigService:
"""模型配置服务"""
@staticmethod
def get_model_by_id(db: Session, model_id: uuid.UUID) -> ModelConfig:
def get_model_by_id(db: Session, model_id: uuid.UUID, tenant_id: uuid.UUID | None = None) -> ModelConfig:
"""根据ID获取模型配置"""
model = ModelConfigRepository.get_by_id(db, model_id)
model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
if not model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
return model
@staticmethod
def get_model_list(db: Session, query: ModelConfigQuery) -> PageData:
def get_model_list(db: Session, query: ModelConfigQuery, tenant_id: uuid.UUID | None = None) -> PageData:
"""获取模型配置列表"""
models, total = ModelConfigRepository.get_list(db, query)
models, total = ModelConfigRepository.get_list(db, query, tenant_id=tenant_id)
pages = math.ceil(total / query.pagesize) if total > 0 else 0
return PageData(
@@ -48,17 +48,17 @@ class ModelConfigService:
)
@staticmethod
def get_model_by_name(db: Session, name: str) -> ModelConfig:
def get_model_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None) -> ModelConfig:
"""根据名称获取模型配置"""
model = ModelConfigRepository.get_by_name(db, name)
model = ModelConfigRepository.get_by_name(db, name, tenant_id=tenant_id)
if not model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
return model
@staticmethod
def search_models_by_name(db: Session, name: str, limit: int = 10) -> List[ModelConfig]:
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ModelConfig]:
"""按名称模糊匹配获取模型配置列表"""
return ModelConfigRepository.search_by_name(db, name, limit)
return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit)
@staticmethod
async def validate_model_config(
@@ -220,10 +220,10 @@ class ModelConfigService:
}
@staticmethod
async def create_model(db: Session, model_data: ModelConfigCreate) -> ModelConfig:
async def create_model(db: Session, model_data: ModelConfigCreate, tenant_id: uuid.UUID) -> ModelConfig:
"""创建模型配置"""
# 检查名称是否已存在
if ModelConfigRepository.get_by_name(db, model_data.name):
# 检查名称是否已存在(同租户内)
if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
# 验证配置
@@ -247,6 +247,8 @@ class ModelConfigService:
# 事务处理
api_key_data = model_data.api_keys
model_config_data = model_data.dict(exclude={"api_keys", "skip_validation"})
# 添加租户ID
model_config_data["tenant_id"] = tenant_id
model = ModelConfigRepository.create(db, model_config_data)
db.flush() # 获取生成的 ID
@@ -263,28 +265,28 @@ class ModelConfigService:
return model
@staticmethod
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate) -> ModelConfig:
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig:
"""更新模型配置"""
existing_model = ModelConfigRepository.get_by_id(db, model_id)
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
if not existing_model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if model_data.name and model_data.name != existing_model.name:
if ModelConfigRepository.get_by_name(db, model_data.name):
if ModelConfigRepository.get_by_name(db, model_data.name, tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
model = ModelConfigRepository.update(db, model_id, model_data)
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
db.commit()
db.refresh(model)
return model
@staticmethod
def delete_model(db: Session, model_id: uuid.UUID) -> bool:
def delete_model(db: Session, model_id: uuid.UUID, tenant_id: uuid.UUID | None = None) -> bool:
"""删除模型配置"""
if not ModelConfigRepository.get_by_id(db, model_id):
if not ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id):
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
success = ModelConfigRepository.delete(db, model_id)
success = ModelConfigRepository.delete(db, model_id, tenant_id=tenant_id)
db.commit()
return success

File diff suppressed because it is too large Load Diff

View File

@@ -116,7 +116,7 @@ class MultiAgentService:
self.db.refresh(config)
logger.info(
f"创建多 Agent 配置成功",
"创建多 Agent 配置成功",
extra={
"config_id": str(config.id),
"app_id": str(app_id),
@@ -320,7 +320,7 @@ class MultiAgentService:
self.db.refresh(config)
logger.info(
f"创建多 Agent 配置成功",
"创建多 Agent 配置成功",
extra={
"config_id": str(config.id),
"app_id": str(app_id),
@@ -363,12 +363,12 @@ class MultiAgentService:
# if data.routing_rules is not None:
# config.routing_rules = [convert_uuids_to_str(rule.model_dump()) for rule in data.routing_rules] if data.routing_rules else None
if data.execution_config is None:
execution_config_data = {}
elif isinstance(data.execution_config, dict):
execution_config_data = convert_uuids_to_str(data.execution_config)
else:
execution_config_data = convert_uuids_to_str(data.execution_config.model_dump())
if data.execution_config is not None:
if isinstance(data.execution_config, dict):
execution_config_data = convert_uuids_to_str(data.execution_config)
else:
execution_config_data = convert_uuids_to_str(data.execution_config.model_dump())
config.execution_config = execution_config_data
if data.aggregation_strategy is not None:
config.aggregation_strategy = data.aggregation_strategy
@@ -380,7 +380,7 @@ class MultiAgentService:
self.db.refresh(config)
logger.info(
f"更新多 Agent 配置成功",
"更新多 Agent 配置成功",
extra={
"config_id": str(config.id),
"app_id": str(app_id)
@@ -399,11 +399,12 @@ class MultiAgentService:
if not config:
raise ResourceNotFoundException("多 Agent 配置", str(app_id))
self.db.delete(config)
# 逻辑删除多 Agent 配置
config.is_active = False
self.db.commit()
logger.info(
f"删除多 Agent 配置成功",
"删除多 Agent 配置成功",
extra={
"config_id": str(config.id),
"app_id": str(app_id)
@@ -542,7 +543,7 @@ class MultiAgentService:
self.db.refresh(config)
logger.info(
f"添加子 Agent 成功",
"添加子 Agent 成功",
extra={
"config_id": str(config.id),
"agent_id": str(agent_id),
@@ -586,7 +587,7 @@ class MultiAgentService:
self.db.refresh(config)
logger.info(
f"移除子 Agent 成功",
"移除子 Agent 成功",
extra={
"config_id": str(config.id),
"agent_id": str(agent_id)

View File

@@ -92,7 +92,7 @@ class ReleaseShareService:
share = self.repo.create(share)
logger.info(
f"创建分享配置",
"创建分享配置",
extra={
"share_id": str(share.id),
"release_id": str(release.id),
@@ -130,7 +130,7 @@ class ReleaseShareService:
share = self.repo.update(share)
logger.info(
f"更新分享配置",
"更新分享配置",
extra={
"share_id": str(share.id),
"release_id": str(share.release_id)
@@ -214,7 +214,7 @@ class ReleaseShareService:
self.repo.delete(share)
logger.info(
f"删除分享配置",
"删除分享配置",
extra={
"share_id": str(share.id),
"release_id": str(release_id)
@@ -249,7 +249,7 @@ class ReleaseShareService:
share = self.repo.update(share)
logger.info(
f"重新生成分享 token",
"重新生成分享 token",
extra={
"share_id": str(share.id),
"old_token": old_token,

View File

@@ -4,7 +4,7 @@ import time
import asyncio
from typing import Optional, Dict, Any, AsyncGenerator
from sqlalchemy.orm import Session
from app.services.memory_konwledges_server import write_rag
from app.models import ReleaseShare, AppRelease, Conversation
from app.services.conversation_service import ConversationService
from app.services.draft_run_service import create_web_search_tool
@@ -16,6 +16,8 @@ from app.services.multi_agent_service import MultiAgentService
from app.models import MultiAgentConfig
from app.repositories import knowledge_repository
import json
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
logger = get_business_logger()
@@ -88,7 +90,7 @@ class SharedChatService:
return conversation
except ResourceNotFoundException:
logger.warning(
f"会话不存在,将创建新会话",
"会话不存在,将创建新会话",
extra={"conversation_id": str(conversation_id)}
)
@@ -102,7 +104,7 @@ class SharedChatService:
)
logger.info(
f"为分享链接创建新会话",
"为分享链接创建新会话",
extra={
"conversation_id": str(conversation.id),
"share_token": share_token,
@@ -121,17 +123,24 @@ class SharedChatService:
variables: Optional[Dict[str, Any]] = None,
password: Optional[str] = None,
web_search: bool = False,
memory: bool = True
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]:
"""聊天(非流式)"""
actual_config_id = None
config_id=actual_config_id
from app.core.agent.langchain_agent import LangChainAgent
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
from app.services.model_parameter_merger import ModelParameterMerger
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from sqlalchemy import select
from app.models import ModelApiKey
start_time = time.time()
actual_config_id=None
config_id=actual_config_id
if variables is None:
variables = {}
@@ -199,10 +208,11 @@ class SharedChatService:
tools.append(kb_tool)
# 添加长期记忆工具
memory_flag=False
if memory==True:
memory_config = config.get("memory", {})
if memory_config.get("enabled") and user_id:
memory_flag=True
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
@@ -234,6 +244,7 @@ class SharedChatService:
max_tokens=model_parameters.get("max_tokens", 2000),
system_prompt=system_prompt,
tools=tools,
)
# 加载历史消息
@@ -254,7 +265,11 @@ class SharedChatService:
message=message,
history=history,
context=None,
end_user_id=user_id
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag
)
# 保存消息
@@ -280,6 +295,7 @@ class SharedChatService:
# )
elapsed_time = time.time() - start_time
return {
"conversation_id": conversation.id,
@@ -301,7 +317,9 @@ class SharedChatService:
variables: Optional[Dict[str, Any]] = None,
password: Optional[str] = None,
web_search: bool = False,
memory: bool = True
memory: bool = True,
storage_type:Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
) -> AsyncGenerator[str, None]:
"""聊天(流式)"""
from app.core.agent.langchain_agent import LangChainAgent
@@ -312,6 +330,9 @@ class SharedChatService:
import json
start_time = time.time()
actual_config_id=None
config_id=actual_config_id
if variables is None:
variables = {}
@@ -381,9 +402,11 @@ class SharedChatService:
tools.append(kb_tool)
# 添加长期记忆工具
memory_flag=False
if memory:
memory_config = config.get("memory", {})
if memory_config.get("enabled") and user_id:
memory_flag = True
memory_tool = create_long_term_memory_tool(memory_config, user_id)
tools.append(memory_tool)
@@ -440,7 +463,11 @@ class SharedChatService:
message=message,
history=history,
context=None,
end_user_id=user_id
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag
):
full_content += chunk
# 发送消息块事件
@@ -464,13 +491,14 @@ class SharedChatService:
"usage": {}
}
)
# 发送结束事件
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)}
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
logger.info(
f"流式聊天完成",
"流式聊天完成",
extra={
"conversation_id": str(conversation.id),
"elapsed_time": elapsed_time,
@@ -539,13 +567,19 @@ class SharedChatService:
variables: Optional[Dict[str, Any]] = None,
password: Optional[str] = None,
web_search: bool = False,
memory: bool = True
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None
) -> Dict[str, Any]:
"""多 Agent 聊天(非流式)"""
from app.services.multi_agent_service import MultiAgentService
from app.models import MultiAgentConfig
start_time = time.time()
actual_config_id=None
config_id=actual_config_id
if variables is None:
variables = {}
@@ -609,6 +643,8 @@ class SharedChatService:
"sub_results": result.get("sub_results")
}
)
return {
"conversation_id": conversation.id,
@@ -630,11 +666,16 @@ class SharedChatService:
variables: Optional[Dict[str, Any]] = None,
password: Optional[str] = None,
web_search: bool = False,
memory: bool = True
memory: bool = True,
storage_type: Optional[str] = None,
user_rag_memory_id:Optional[str] = None
) -> AsyncGenerator[str, None]:
"""多 Agent 聊天(流式)"""
start_time = time.time()
actual_config_id=None
config_id=actual_config_id
if variables is None:
variables = {}
@@ -741,13 +782,14 @@ class SharedChatService:
)
logger.info(
f"多 Agent 流式聊天完成",
"多 Agent 流式聊天完成",
extra={
"conversation_id": str(conversation.id),
"elapsed_time": elapsed_time,
"message_length": len(full_content)
}
)
except (GeneratorExit, asyncio.CancelledError):
# 生成器被关闭或任务被取消,正常退出

View File

@@ -75,7 +75,7 @@ class SmartRouter:
}
"""
logger.info(
f"开始智能路由",
"开始智能路由",
extra={
"message_length": len(message),
"conversation_id": conversation_id,
@@ -170,7 +170,7 @@ class SmartRouter:
}
logger.info(
f"路由完成",
"路由完成",
extra={
"agent_id": agent_id,
"strategy": strategy,
@@ -421,6 +421,6 @@ class SmartRouter:
# 否则使用第一个子 Agent
if self.sub_agents:
return list(self.sub_agents.keys())[0]
return next(iter(self.sub_agents.keys()))
return "default-agent"

View File

@@ -64,7 +64,7 @@ def create_initial_superuser(db: Session):
raise BusinessException(
f"初始超级用户创建失败: {str(e)}",
code=BizCode.DB_ERROR,
context={"username": username, "email": email},
context={"username": user_in.username, "email": user_in.email},
cause=e
)
@@ -423,7 +423,7 @@ def update_last_login_time(db: Session, user_id: uuid.UUID) -> User:
business_logger.info(f"用户最后登录时间更新成功: {db_user.username} (ID: {user_id})")
return db_user
except HTTPException:
except (BusinessException, PermissionDeniedException):
raise
except Exception as e:
business_logger.error(f"更新用户最后登录时间失败: user_id={user_id} - {str(e)}")
@@ -438,19 +438,14 @@ async def change_password(db: Session, user_id: uuid.UUID, old_password: str, ne
# 检查权限:只能修改自己的密码
if current_user.id != user_id:
business_logger.warning(f"用户尝试修改他人密码: current_user={current_user.id}, target_user={user_id}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You can only change your own password"
)
raise PermissionDeniedException("You can only change your own password")
try:
# 获取用户
db_user = user_repository.get_user_by_id(db=db, user_id=user_id)
if not db_user:
business_logger.warning(f"用户不存在: {user_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="User not found"
)
raise BusinessException("User not found", code=BizCode.USER_NOT_FOUND)
# 验证旧密码
if not verify_password(old_password, db_user.hashed_password):

View File

@@ -148,7 +148,7 @@ def create_workspace(
description=f"工作空间 {workspace.name} 的默认知识库",
avatar='',
type=KnowledgeType.General,
permission_id=PermissionType.Private,
permission_id=PermissionType.Memory,
embedding_id=uuid.UUID(getenv('KB_embedding_id')) if None else embedding,
reranker_id=uuid.UUID(getenv('KB_reranker_id')) if None else rerank,
llm_id=uuid.UUID(getenv('KB_llm_id')) if None else llm,
@@ -459,7 +459,7 @@ def get_workspace_invites(
def validate_invite_token(db: Session, token: str) -> InviteValidateResponse:
"""验证邀请令牌"""
business_logger.info(f"验证邀请令牌")
business_logger.info("验证邀请令牌")
# 生成令牌哈希
token_hash = hashlib.sha256(token.encode()).hexdigest()
@@ -469,7 +469,7 @@ def validate_invite_token(db: Session, token: str) -> InviteValidateResponse:
invite = invite_repo.get_invite_by_token_hash(token_hash)
if not invite:
business_logger.warning(f"邀请令牌无效")
business_logger.warning("邀请令牌无效")
raise BusinessException("邀请令牌无效", BizCode.WORKSPACE_INVITE_NOT_FOUND)
# 检查邀请状态和过期时间
@@ -511,7 +511,7 @@ def accept_workspace_invite(
invite = invite_repo.get_invite_by_token_hash(token_hash)
if not invite:
business_logger.warning(f"邀请令牌无效")
business_logger.warning("邀请令牌无效")
raise BusinessException("邀请令牌无效", BizCode.WORKSPACE_INVITE_NOT_FOUND)
# 检查邀请状态
@@ -522,7 +522,7 @@ def accept_workspace_invite(
# 检查过期时间
now = datetime.datetime.now()
if invite.expires_at < now:
business_logger.warning(f"邀请已过期")
business_logger.warning("邀请已过期")
# 标记为过期
invite_repo.update_invite_status(invite.id, InviteStatus.expired)
raise BusinessException("邀请已过期", BizCode.WORKSPACE_INVITE_EXPIRED)
@@ -547,7 +547,7 @@ def accept_workspace_invite(
)
if existing_member:
business_logger.info(f"用户已是工作空间成员,更新邀请状态")
business_logger.info("用户已是工作空间成员,更新邀请状态")
invite_repo.update_invite_status(
invite.id,
InviteStatus.accepted,
@@ -739,6 +739,34 @@ def get_workspace_storage_type(
return workspace.storage_type
def get_workspace_storage_type_without_auth(
db: Session,
workspace_id: uuid.UUID,
) -> Optional[str]:
"""获取工作空间的存储类型(无需权限验证,用于公开分享等场景)
Args:
db: 数据库会话
workspace_id: 工作空间ID
Returns:
storage_type: 存储类型字符串,如果未设置则返回 None
"""
business_logger.info(f"获取工作空间 {workspace_id} 的存储类型(无权限验证)")
# 查询工作空间
workspace = workspace_repository.get_workspace_by_id(db=db, workspace_id=workspace_id)
if not workspace:
business_logger.error(f"工作空间不存在: workspace_id={workspace_id}")
raise BusinessException(
code=BizCode.WORKSPACE_NOT_FOUND,
message="工作空间不存在"
)
business_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {workspace.storage_type}")
return workspace.storage_type
def get_workspace_models_configs(
db: Session,
workspace_id: uuid.UUID,