Merge branch 'refs/heads/develop' into feature/agent-tool_xjn
# Conflicts: # api/app/core/agent/langchain_agent.py # api/app/core/tools/mcp/client.py
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import redis.asyncio as redis
|
||||
@@ -21,6 +23,50 @@ pool = ConnectionPool.from_url(
|
||||
)
|
||||
aio_redis = redis.StrictRedis(connection_pool=pool)
|
||||
|
||||
_REDIS_URL = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}"
|
||||
|
||||
# Thread-local storage for connection pools.
|
||||
# Each thread (and each forked process) gets its own pool to avoid
|
||||
# "Future attached to a different loop" errors in Celery --pool=threads
|
||||
# and stale connections after fork in --pool=prefork.
|
||||
_thread_local = threading.local()
|
||||
|
||||
|
||||
def get_thread_safe_redis() -> redis.StrictRedis:
|
||||
"""Return a Redis client whose connection pool is bound to the current
|
||||
thread, process **and** event loop.
|
||||
|
||||
The pool is recreated when:
|
||||
- The PID changes (fork, Celery --pool=prefork)
|
||||
- The thread has no pool yet (Celery --pool=threads)
|
||||
- The previously-cached event loop has been closed (Celery tasks call
|
||||
``_shutdown_loop_gracefully`` which closes the loop after each run)
|
||||
"""
|
||||
current_pid = os.getpid()
|
||||
cached_loop = getattr(_thread_local, "loop", None)
|
||||
loop_stale = cached_loop is not None and cached_loop.is_closed()
|
||||
|
||||
if not hasattr(_thread_local, "pool") \
|
||||
or getattr(_thread_local, "pid", None) != current_pid \
|
||||
or loop_stale:
|
||||
_thread_local.pid = current_pid
|
||||
# Python 3.10+: get_event_loop() raises RuntimeError in threads
|
||||
# where no loop has been set yet (e.g. Celery --pool=threads).
|
||||
try:
|
||||
_thread_local.loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
_thread_local.loop = None
|
||||
_thread_local.pool = ConnectionPool.from_url(
|
||||
_REDIS_URL,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD,
|
||||
decode_responses=True,
|
||||
max_connections=5,
|
||||
health_check_interval=30,
|
||||
)
|
||||
|
||||
return redis.StrictRedis(connection_pool=_thread_local.pool)
|
||||
|
||||
|
||||
async def get_redis_connection():
|
||||
"""获取Redis连接"""
|
||||
@@ -44,10 +90,8 @@ async def aio_redis_set(key: str, val: str | dict, expire: int = None):
|
||||
val = json.dumps(val, ensure_ascii=False)
|
||||
|
||||
if expire is not None:
|
||||
# 设置带过期时间的键值
|
||||
await aio_redis.set(key, val, ex=expire)
|
||||
else:
|
||||
# 设置永久键值
|
||||
await aio_redis.set(key, val)
|
||||
except Exception as e:
|
||||
logger.error(f"Redis set错误: {str(e)}")
|
||||
|
||||
8
api/app/cache/memory/activity_stats_cache.py
vendored
8
api/app/cache/memory/activity_stats_cache.py
vendored
@@ -10,7 +10,7 @@ import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
from app.aioRedis import get_thread_safe_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -68,7 +68,7 @@ class ActivityStatsCache:
|
||||
"cached": True,
|
||||
}
|
||||
value = json.dumps(payload, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
await get_thread_safe_redis().set(key, value, ex=expire)
|
||||
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -90,7 +90,7 @@ class ActivityStatsCache:
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
value = await aio_redis.get(key)
|
||||
value = await get_thread_safe_redis().get(key)
|
||||
if value:
|
||||
payload = json.loads(value)
|
||||
logger.info(f"命中活动统计缓存: {key}")
|
||||
@@ -116,7 +116,7 @@ class ActivityStatsCache:
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(workspace_id)
|
||||
result = await aio_redis.delete(key)
|
||||
result = await get_thread_safe_redis().delete(key)
|
||||
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
|
||||
8
api/app/cache/memory/interest_memory.py
vendored
8
api/app/cache/memory/interest_memory.py
vendored
@@ -9,7 +9,7 @@ import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
from app.aioRedis import get_thread_safe_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -62,7 +62,7 @@ class InterestMemoryCache:
|
||||
"cached": True,
|
||||
}
|
||||
value = json.dumps(payload, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
await get_thread_safe_redis().set(key, value, ex=expire)
|
||||
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
@@ -86,7 +86,7 @@ class InterestMemoryCache:
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
value = await aio_redis.get(key)
|
||||
value = await get_thread_safe_redis().get(key)
|
||||
if value:
|
||||
payload = json.loads(value)
|
||||
logger.info(f"命中兴趣分布缓存: {key}")
|
||||
@@ -114,7 +114,7 @@ class InterestMemoryCache:
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
result = await aio_redis.delete(key)
|
||||
result = await get_thread_safe_redis().delete(key)
|
||||
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
|
||||
@@ -57,7 +57,6 @@ def list_apps(
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
ids: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
@@ -66,7 +65,7 @@ def list_apps(
|
||||
- 默认包含本工作空间的应用和分享给本工作空间的应用
|
||||
- 设置 include_shared=false 可以只查看本工作空间的应用
|
||||
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页
|
||||
- 当提供 api_key 参数时,查找该 API Key 关联的应用
|
||||
- search 参数支持:应用名称模糊搜索、API Key 精确搜索
|
||||
"""
|
||||
from sqlalchemy import select as sa_select
|
||||
from app.models.api_key_model import ApiKey
|
||||
@@ -74,23 +73,34 @@ def list_apps(
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
|
||||
# 通过 API Key 搜索:精确匹配,将 resource_id 注入 ids 走统一分页流程
|
||||
if api_key:
|
||||
matched_id = db.execute(
|
||||
sa_select(ApiKey.resource_id).where(
|
||||
ApiKey.workspace_id == workspace_id,
|
||||
ApiKey.api_key == api_key,
|
||||
ApiKey.resource_id.isnot(None),
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
ids = str(matched_id) if matched_id else ""
|
||||
# 通过 search 参数搜索:支持应用名称模糊搜索和 API Key 精确搜索
|
||||
if search:
|
||||
search = search.strip()
|
||||
# 尝试作为 API Key 精确匹配(API Key 通常较长)
|
||||
if len(search) >= 10:
|
||||
matched_id = db.execute(
|
||||
sa_select(ApiKey.resource_id).where(
|
||||
ApiKey.workspace_id == workspace_id,
|
||||
ApiKey.api_key == search,
|
||||
ApiKey.resource_id.isnot(None),
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if matched_id:
|
||||
# 找到 API Key,直接返回关联的应用
|
||||
ids = str(matched_id)
|
||||
|
||||
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
||||
# 当 ids 存在时,根据 ids 获取应用(不分页)
|
||||
if ids is not None:
|
||||
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
|
||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
return success(data=items)
|
||||
if app_ids:
|
||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||
# 返回标准分页格式
|
||||
meta = PageMeta(page=1, pagesize=len(items), total=len(items), hasnext=False)
|
||||
return success(data=PageData(page=meta, items=items))
|
||||
# ids 为空时,返回空列表
|
||||
meta = PageMeta(page=1, pagesize=0, total=0, hasnext=False)
|
||||
return success(data=PageData(page=meta, items=[]))
|
||||
|
||||
# 正常分页查询
|
||||
items_orm, total = app_service.list_apps(
|
||||
|
||||
@@ -3,17 +3,16 @@ import uuid
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import select, desc, func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
from app.models.conversation_model import Conversation, Message
|
||||
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail, AppLogMessage
|
||||
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services.app_service import AppService
|
||||
from app.services.app_log_service import AppLogService
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["App Logs"])
|
||||
logger = get_business_logger()
|
||||
@@ -25,52 +24,35 @@ def list_app_logs(
|
||||
app_id: uuid.UUID,
|
||||
page: int = Query(1, ge=1),
|
||||
pagesize: int = Query(20, ge=1, le=100),
|
||||
user_id: Optional[str] = None,
|
||||
is_draft: Optional[bool] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""查看应用下所有会话记录(分页)
|
||||
|
||||
- 支持按 user_id 筛选
|
||||
- 支持按 is_draft 筛选(草稿会话 / 发布会话)
|
||||
- 按最新更新时间倒序排列
|
||||
- 所有人(包括共享者和被共享者)都只能查看自己的会话记录
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 验证应用访问权限
|
||||
service = AppService(db)
|
||||
service.get_app(app_id, workspace_id)
|
||||
app_service = AppService(db)
|
||||
app_service.get_app(app_id, workspace_id)
|
||||
|
||||
stmt = select(Conversation).where(
|
||||
Conversation.app_id == app_id,
|
||||
Conversation.workspace_id == workspace_id,
|
||||
Conversation.is_active.is_(True),
|
||||
# 使用 Service 层查询
|
||||
log_service = AppLogService(db)
|
||||
conversations, total = log_service.list_conversations(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
is_draft=is_draft
|
||||
)
|
||||
|
||||
if user_id:
|
||||
stmt = stmt.where(Conversation.user_id == user_id)
|
||||
|
||||
if is_draft is not None:
|
||||
stmt = stmt.where(Conversation.is_draft == is_draft)
|
||||
|
||||
total = int(db.execute(
|
||||
select(func.count()).select_from(stmt.subquery())
|
||||
).scalar_one())
|
||||
|
||||
stmt = stmt.order_by(desc(Conversation.updated_at))
|
||||
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
|
||||
|
||||
conversations = list(db.scalars(stmt).all())
|
||||
|
||||
items = [AppLogConversation.model_validate(c) for c in conversations]
|
||||
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
||||
|
||||
logger.info(
|
||||
"查询应用日志会话列表",
|
||||
extra={"app_id": str(app_id), "total": total, "page": page}
|
||||
)
|
||||
|
||||
return success(data=PageData(page=meta, items=items))
|
||||
|
||||
|
||||
@@ -86,44 +68,22 @@ def get_app_log_detail(
|
||||
|
||||
- 返回会话基本信息 + 所有消息(按时间正序)
|
||||
- 消息 meta_data 包含模型名、token 用量等信息
|
||||
- 所有人(包括共享者和被共享者)都只能查看自己的会话详情
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 验证应用访问权限
|
||||
service = AppService(db)
|
||||
service.get_app(app_id, workspace_id)
|
||||
app_service = AppService(db)
|
||||
app_service.get_app(app_id, workspace_id)
|
||||
|
||||
# 查询会话(确保属于该应用和工作空间)
|
||||
conversation = db.scalars(
|
||||
select(Conversation).where(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.app_id == app_id,
|
||||
Conversation.workspace_id == workspace_id,
|
||||
Conversation.is_active.is_(True),
|
||||
)
|
||||
).first()
|
||||
|
||||
if not conversation:
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
raise ResourceNotFoundException("会话", str(conversation_id))
|
||||
|
||||
# 查询消息(按时间正序)
|
||||
messages = list(db.scalars(
|
||||
select(Message)
|
||||
.where(Message.conversation_id == conversation_id)
|
||||
.order_by(Message.created_at)
|
||||
).all())
|
||||
|
||||
detail = AppLogConversationDetail.model_validate(conversation)
|
||||
detail.messages = [AppLogMessage.model_validate(m) for m in messages]
|
||||
|
||||
logger.info(
|
||||
"查询应用日志会话详情",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"conversation_id": str(conversation_id),
|
||||
"message_count": len(messages)
|
||||
}
|
||||
# 使用 Service 层查询
|
||||
log_service = AppLogService(db)
|
||||
conversation = log_service.get_conversation_detail(
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
detail = AppLogConversationDetail.model_validate(conversation)
|
||||
|
||||
return success(data=detail)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -47,64 +49,64 @@ def get_workspace_total_end_users(
|
||||
|
||||
@router.get("/end_users", response_model=ApiResponse)
|
||||
async def get_workspace_end_users(
|
||||
workspace_id: Optional[uuid.UUID] = Query(None, description="工作空间ID(可选,默认当前用户工作空间)"),
|
||||
keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id)"),
|
||||
page: int = Query(1, ge=1, description="页码,从1开始"),
|
||||
pagesize: int = Query(10, ge=1, description="每页数量"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取工作空间的宿主列表(高性能优化版本 v2)
|
||||
|
||||
优化策略:
|
||||
1. 批量查询 end_users(一次查询而非循环)
|
||||
2. 并发查询所有用户的记忆数量(Neo4j)
|
||||
3. RAG 模式使用批量查询(一次 SQL)
|
||||
4. 只返回必要字段减少数据传输
|
||||
5. 添加短期缓存减少重复查询
|
||||
6. 并发执行配置查询和记忆数量查询
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"end_user": {"id": "uuid", "other_name": "名称"},
|
||||
"memory_num": {"total": 数量},
|
||||
"memory_config": {"memory_config_id": "id", "memory_config_name": "名称"}
|
||||
}
|
||||
获取工作空间的宿主列表(分页查询,支持模糊搜索)
|
||||
|
||||
返回工作空间下的宿主列表,支持分页查询和模糊搜索。
|
||||
通过 keyword 参数同时模糊匹配 other_name 和 id 字段。
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
||||
keyword: 搜索关键词(可选,同时模糊匹配 other_name 和 id)
|
||||
page: 页码(从1开始,默认1)
|
||||
pagesize: 每页数量(默认10)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含宿主列表和分页信息
|
||||
"""
|
||||
import asyncio
|
||||
import json
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 尝试从缓存获取(30秒缓存)
|
||||
cache_key = f"end_users:workspace:{workspace_id}"
|
||||
try:
|
||||
cached_data = await aio_redis_get(cache_key)
|
||||
if cached_data:
|
||||
api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
|
||||
return success(data=json.loads(cached_data), msg="宿主列表获取成功")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
|
||||
|
||||
# 如果未提供 workspace_id,使用当前用户的工作空间
|
||||
if workspace_id is None:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 获取当前空间类型
|
||||
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
|
||||
|
||||
# 获取 end_users(已优化为批量查询)
|
||||
end_users = memory_dashboard_service.get_workspace_end_users(
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}")
|
||||
|
||||
# 获取分页的 end_users
|
||||
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user
|
||||
current_user=current_user,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
keyword=keyword
|
||||
)
|
||||
|
||||
end_users = end_users_result.get("items", [])
|
||||
total = end_users_result.get("total", 0)
|
||||
|
||||
if not end_users:
|
||||
api_logger.info("工作空间下没有宿主")
|
||||
# 缓存空结果,避免重复查询
|
||||
try:
|
||||
await aio_redis_set(cache_key, json.dumps([]), expire=30)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
return success(data=[], msg="宿主列表获取成功")
|
||||
|
||||
api_logger.info(f"工作空间下没有宿主或当前页无数据: total={total}, page={page}")
|
||||
return success(data={
|
||||
"items": [],
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"hasnext": (page * pagesize) < total
|
||||
}
|
||||
}, msg="宿主列表获取成功")
|
||||
|
||||
end_user_ids = [str(user.id) for user in end_users]
|
||||
|
||||
|
||||
# 并发执行两个独立的查询任务
|
||||
async def get_memory_configs():
|
||||
"""获取记忆配置(在线程池中执行同步查询)"""
|
||||
@@ -116,7 +118,7 @@ async def get_workspace_end_users(
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
|
||||
return {}
|
||||
|
||||
|
||||
async def get_memory_nums():
|
||||
"""获取记忆数量"""
|
||||
if current_workspace_type == "rag":
|
||||
@@ -130,26 +132,18 @@ async def get_workspace_end_users(
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
|
||||
elif current_workspace_type == "neo4j":
|
||||
# Neo4j 模式:并发查询(带并发限制)
|
||||
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
|
||||
MAX_CONCURRENT_QUERIES = 10
|
||||
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
|
||||
|
||||
async def get_neo4j_memory_num(end_user_id: str):
|
||||
async with semaphore:
|
||||
try:
|
||||
return await memory_storage_service.search_all(end_user_id)
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
|
||||
return {"total": 0}
|
||||
|
||||
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
|
||||
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
|
||||
|
||||
# Neo4j 模式:批量查询(简化版本,只返回total)
|
||||
try:
|
||||
batch_result = await memory_storage_service.search_all_batch(end_user_ids)
|
||||
return {uid: {"total": count} for uid, count in batch_result.items()}
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取 Neo4j 记忆数量失败: {str(e)}")
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
|
||||
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
|
||||
try:
|
||||
from app.celery_app import celery_app as _celery_app
|
||||
@@ -170,13 +164,13 @@ async def get_workspace_end_users(
|
||||
get_memory_configs(),
|
||||
get_memory_nums()
|
||||
)
|
||||
|
||||
# 构建结果(优化:使用列表推导式)
|
||||
result = []
|
||||
|
||||
# 构建结果列表
|
||||
items = []
|
||||
for end_user in end_users:
|
||||
user_id = str(end_user.id)
|
||||
config_info = memory_configs_map.get(user_id, {})
|
||||
result.append({
|
||||
items.append({
|
||||
'end_user': {
|
||||
'id': user_id,
|
||||
'other_name': end_user.other_name
|
||||
@@ -187,12 +181,6 @@ async def get_workspace_end_users(
|
||||
"memory_config_name": config_info.get("memory_config_name")
|
||||
}
|
||||
})
|
||||
|
||||
# 写入缓存(30秒过期)
|
||||
try:
|
||||
await aio_redis_set(cache_key, json.dumps(result), expire=30)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
|
||||
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
||||
try:
|
||||
@@ -202,7 +190,18 @@ async def get_workspace_end_users(
|
||||
except Exception as e:
|
||||
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
|
||||
|
||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||
# 构建分页响应
|
||||
result = {
|
||||
"items": items,
|
||||
"page": {
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"hasnext": (page * pagesize) < total
|
||||
}
|
||||
}
|
||||
|
||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录,总计 {total} 条")
|
||||
return success(data=result, msg="宿主列表获取成功")
|
||||
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ from app.schemas.memory_storage_schema import (
|
||||
ForgettingCurveRequest,
|
||||
ForgettingCurveResponse,
|
||||
ForgettingCurvePoint,
|
||||
PendingNodesResponse,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_forget_service import MemoryForgetService
|
||||
@@ -308,6 +309,100 @@ async def get_forgetting_stats(
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
|
||||
|
||||
|
||||
@router.get("/pending-nodes", response_model=ApiResponse)
|
||||
async def get_pending_nodes(
|
||||
end_user_id: str,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
获取待遗忘节点列表(独立分页接口)
|
||||
|
||||
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
|
||||
此接口独立分页,与 /stats 接口分离。
|
||||
|
||||
Args:
|
||||
end_user_id: 组ID(即 end_user_id,必填)
|
||||
page: 页码(从1开始,默认1)
|
||||
pagesize: 每页数量(默认10)
|
||||
current_user: 当前用户
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
ApiResponse: 包含待遗忘节点列表和分页信息的响应
|
||||
|
||||
Examples:
|
||||
- 第1页,每页10条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=1&pagesize=10
|
||||
- 第2页,每页20条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=2&pagesize=20
|
||||
|
||||
Notes:
|
||||
- page 从1开始,pagesize 必须大于0
|
||||
- 返回格式:{"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}}
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 验证 end_user_id 必填
|
||||
if not end_user_id:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未提供 end_user_id")
|
||||
return fail(BizCode.INVALID_PARAMETER, "end_user_id 不能为空", "end_user_id is required")
|
||||
|
||||
# 通过 end_user_id 获取关联的 config_id
|
||||
try:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
|
||||
if config_id is None:
|
||||
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||
|
||||
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
|
||||
|
||||
# 验证分页参数
|
||||
if page < 1:
|
||||
return fail(BizCode.INVALID_PARAMETER, "page 必须大于等于1", "page < 1")
|
||||
if pagesize < 1:
|
||||
return fail(BizCode.INVALID_PARAMETER, "pagesize 必须大于等于1", "pagesize < 1")
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取待遗忘节点: "
|
||||
f"end_user_id={end_user_id}, page={page}, pagesize={pagesize}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 调用服务层获取待遗忘节点列表
|
||||
result = await forget_service.get_pending_nodes(
|
||||
db=db,
|
||||
end_user_id=end_user_id,
|
||||
config_id=config_id,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response_data = PendingNodesResponse(**result)
|
||||
|
||||
return success(data=response_data.model_dump(), msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取待遗忘节点列表失败: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取待遗忘节点列表失败", str(e))
|
||||
|
||||
|
||||
@router.post("/forgetting_curve", response_model=ApiResponse)
|
||||
async def get_forgetting_curve(
|
||||
request: ForgettingCurveRequest,
|
||||
|
||||
@@ -27,6 +27,7 @@ from app.services.conversation_service import ConversationService
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.services.shared_chat_service import SharedChatService
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.utils.app_config_utils import workflow_config_4_app_release, \
|
||||
agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
|
||||
@@ -259,8 +260,41 @@ def get_conversation(
|
||||
conv_service = ConversationService(db)
|
||||
messages = conv_service.get_messages(conversation_id)
|
||||
|
||||
# 构建响应
|
||||
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump()
|
||||
file_ids = []
|
||||
message_file_id_map = {}
|
||||
|
||||
# 第一次遍历:解析 audio_url,收集所有有效的 file_id
|
||||
for idx, m in enumerate(messages):
|
||||
if m.role == "assistant" and m.meta_data:
|
||||
audio_url = m.meta_data.get("audio_url")
|
||||
if not audio_url:
|
||||
continue
|
||||
try:
|
||||
file_id = uuid.UUID(audio_url.rstrip("/").split("/")[-1])
|
||||
except (ValueError, IndexError):
|
||||
# audio_url 无法解析为 UUID,标记为 unknown
|
||||
m.meta_data["audio_status"] = "unknown"
|
||||
continue
|
||||
|
||||
file_ids.append(file_id)
|
||||
message_file_id_map[idx] = file_id
|
||||
|
||||
# 批量查询所有相关的 FileMetadata
|
||||
file_status_map = {}
|
||||
if file_ids:
|
||||
file_metas = (
|
||||
db.query(FileMetadata)
|
||||
.filter(FileMetadata.id.in_(set(file_ids)))
|
||||
.all()
|
||||
)
|
||||
file_status_map = {fm.id: fm.status for fm in file_metas}
|
||||
|
||||
# 第二次遍历:将查询结果映射回消息
|
||||
for idx, file_id in message_file_id_map.items():
|
||||
m = messages[idx]
|
||||
m.meta_data["audio_status"] = file_status_map.get(file_id, "unknown")
|
||||
|
||||
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump(mode="json")
|
||||
conv_dict["messages"] = [
|
||||
conversation_schema.Message.model_validate(m) for m in messages
|
||||
]
|
||||
@@ -320,6 +354,16 @@ async def chat(
|
||||
other_id=other_id,
|
||||
original_user_id=user_id
|
||||
)
|
||||
|
||||
# Only extract and set memory_config_id when the end user doesn't have one yet
|
||||
if not new_end_user.memory_config_id:
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
memory_config_service = MemoryConfigService(db)
|
||||
memory_config_id, _ = memory_config_service.extract_memory_config_id(release.type, release.config or {})
|
||||
if memory_config_id:
|
||||
new_end_user.memory_config_id = memory_config_id
|
||||
db.commit()
|
||||
db.refresh(new_end_user)
|
||||
end_user_id = str(new_end_user.id)
|
||||
|
||||
# appid = share.app_id
|
||||
@@ -410,30 +454,6 @@ async def chat(
|
||||
agent_config = agent_config_4_app_release(release)
|
||||
|
||||
if payload.stream:
|
||||
# async def event_generator():
|
||||
# async for event in service.chat_stream(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# ):
|
||||
# yield event
|
||||
|
||||
# return StreamingResponse(
|
||||
# event_generator(),
|
||||
# media_type="text/event-stream",
|
||||
# headers={
|
||||
# "Cache-Control": "no-cache",
|
||||
# "Connection": "keep-alive",
|
||||
# "X-Accel-Buffering": "no"
|
||||
# }
|
||||
# )
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.agnet_chat_stream(
|
||||
message=payload.message,
|
||||
@@ -459,20 +479,6 @@ async def chat(
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
# 非流式返回
|
||||
# result = await service.chat(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# )
|
||||
# return success(data=conversation_schema.ChatResponse(**result))
|
||||
result = await app_chat_service.agnet_chat(
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
@@ -531,48 +537,6 @@ async def chat(
|
||||
)
|
||||
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
# 多 Agent 流式返回
|
||||
# if payload.stream:
|
||||
# async def event_generator():
|
||||
# async for event in service.multi_agent_chat_stream(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# ):
|
||||
# yield event
|
||||
|
||||
# return StreamingResponse(
|
||||
# event_generator(),
|
||||
# media_type="text/event-stream",
|
||||
# headers={
|
||||
# "Cache-Control": "no-cache",
|
||||
# "Connection": "keep-alive",
|
||||
# "X-Accel-Buffering": "no"
|
||||
# }
|
||||
# )
|
||||
|
||||
# # 多 Agent 非流式返回
|
||||
# result = await service.multi_agent_chat(
|
||||
# share_token=share_token,
|
||||
# message=payload.message,
|
||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
# user_id=str(new_end_user.id), # 转换为字符串
|
||||
# variables=payload.variables,
|
||||
# password=password,
|
||||
# web_search=payload.web_search,
|
||||
# memory=payload.memory,
|
||||
# storage_type=storage_type,
|
||||
# user_rag_memory_id=user_rag_memory_id
|
||||
# )
|
||||
|
||||
# return success(data=conversation_schema.ChatResponse(**result))
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
config = workflow_config_4_app_release(release)
|
||||
if not config.id:
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
认证方式: API Key
|
||||
"""
|
||||
from fastapi import APIRouter
|
||||
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller
|
||||
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller, end_user_api_controller
|
||||
|
||||
# 创建 V1 API 路由器
|
||||
service_router = APIRouter()
|
||||
@@ -16,5 +16,6 @@ service_router.include_router(rag_api_document_controller.router)
|
||||
service_router.include_router(rag_api_file_controller.router)
|
||||
service_router.include_router(rag_api_chunk_controller.router)
|
||||
service_router.include_router(memory_api_controller.router)
|
||||
service_router.include_router(end_user_api_controller.router)
|
||||
|
||||
__all__ = ["service_router"]
|
||||
|
||||
@@ -91,7 +91,7 @@ async def chat(
|
||||
|
||||
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
|
||||
other_id = payload.user_id
|
||||
workspace_id = app.workspace_id
|
||||
workspace_id = api_key_auth.workspace_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app.id,
|
||||
|
||||
92
api/app/controllers/service/end_user_api_controller.py
Normal file
92
api/app/controllers/service/end_user_api_controller.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""End User 服务接口 - 基于 API Key 认证"""
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.api_key_auth import require_api_key
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.memory_api_schema import CreateEndUserRequest, CreateEndUserResponse
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
router = APIRouter(prefix="/end_user", tags=["V1 - End User API"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@router.post("/create")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def create_end_user(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
message: str = Body(..., description="Request body"),
|
||||
):
|
||||
"""
|
||||
Create or retrieve an end user for the workspace.
|
||||
|
||||
Creates a new end user and connects it to a memory configuration.
|
||||
If an end user with the same other_id already exists in the workspace,
|
||||
returns the existing one.
|
||||
|
||||
Optionally accepts a memory_config_id to connect the end user to a specific
|
||||
memory configuration. If not provided, falls back to the workspace default config.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = CreateEndUserRequest(**body)
|
||||
workspace_id = api_key_auth.workspace_id
|
||||
|
||||
logger.info(f"Create end user request - other_id: {payload.other_id}, workspace_id: {workspace_id}")
|
||||
|
||||
# Resolve memory_config_id: explicit > workspace default
|
||||
memory_config_id = None
|
||||
config_service = MemoryConfigService(db)
|
||||
|
||||
if payload.memory_config_id:
|
||||
try:
|
||||
memory_config_id = uuid.UUID(payload.memory_config_id)
|
||||
except ValueError:
|
||||
raise BusinessException(
|
||||
f"Invalid memory_config_id format: {payload.memory_config_id}",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
config = config_service.get_config_with_fallback(memory_config_id, workspace_id)
|
||||
if not config:
|
||||
raise BusinessException(
|
||||
f"Memory config not found: {payload.memory_config_id}",
|
||||
BizCode.MEMORY_CONFIG_NOT_FOUND
|
||||
)
|
||||
memory_config_id = config.config_id
|
||||
else:
|
||||
default_config = config_service.get_workspace_default_config(workspace_id)
|
||||
if default_config:
|
||||
memory_config_id = default_config.config_id
|
||||
logger.info(f"Using workspace default memory config: {memory_config_id}")
|
||||
else:
|
||||
logger.warning(f"No default memory config found for workspace: {workspace_id}")
|
||||
|
||||
end_user_repo = EndUserRepository(db)
|
||||
end_user = end_user_repo.get_or_create_end_user_with_config(
|
||||
app_id=api_key_auth.resource_id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=payload.other_id,
|
||||
memory_config_id=memory_config_id,
|
||||
)
|
||||
|
||||
logger.info(f"End user ready: {end_user.id}")
|
||||
|
||||
result = {
|
||||
"id": str(end_user.id),
|
||||
"other_id": end_user.other_id or "",
|
||||
"other_name": end_user.other_name or "",
|
||||
"workspace_id": str(end_user.workspace_id),
|
||||
"memory_config_id": str(end_user.memory_config_id) if end_user.memory_config_id else None,
|
||||
}
|
||||
|
||||
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
||||
@@ -111,6 +111,18 @@ def get_current_user_info(
|
||||
break
|
||||
|
||||
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
|
||||
|
||||
# 设置权限:如果用户来自 SSO Source,则使用该 Source 的 permissions;否则返回 "all" 表示拥有所有权限
|
||||
if current_user.external_source:
|
||||
from premium.sso.models import SSOSource
|
||||
source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first()
|
||||
if source and source.permissions:
|
||||
result_schema.permissions = source.permissions
|
||||
else:
|
||||
result_schema.permissions = []
|
||||
else:
|
||||
result_schema.permissions = ["all"]
|
||||
|
||||
return success(data=result_schema, msg=t("users.info.get_success"))
|
||||
|
||||
|
||||
@@ -135,7 +147,6 @@ def get_tenant_superusers(
|
||||
return success(data=superusers_schema, msg=t("users.list.superusers_success"))
|
||||
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=ApiResponse)
|
||||
def get_user_info_by_id(
|
||||
user_id: uuid.UUID,
|
||||
|
||||
@@ -11,18 +11,14 @@ LangChain Agent 封装
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType, ModelProvider
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@@ -226,10 +222,9 @@ class LangChainAgent:
|
||||
Returns:
|
||||
List[BaseMessage]: 消息列表
|
||||
"""
|
||||
messages = []
|
||||
messages:list = [SystemMessage(content=self.system_prompt)]
|
||||
|
||||
# 添加系统提示词
|
||||
messages.append(SystemMessage(content=self.system_prompt))
|
||||
|
||||
# 添加历史消息
|
||||
if history:
|
||||
@@ -320,12 +315,7 @@ class LangChainAgent:
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
config_id: Optional[str] = None, # 添加这个参数
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""执行对话
|
||||
|
||||
@@ -333,32 +323,12 @@ class LangChainAgent:
|
||||
message: 用户消息
|
||||
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
|
||||
context: 上下文信息(如知识库检索结果)
|
||||
files: 多模态文件
|
||||
|
||||
Returns:
|
||||
Dict: 包含 content 和元数据的字典
|
||||
"""
|
||||
message_chat = message
|
||||
start_time = time.time()
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
try:
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
@@ -445,9 +415,6 @@ class LangChainAgent:
|
||||
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
|
||||
actual_config_id)
|
||||
response = {
|
||||
"content": content,
|
||||
"model": self.model_name,
|
||||
@@ -478,12 +445,7 @@ class LangChainAgent:
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
config_id: Optional[str] = None,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> AsyncGenerator[str | int, None]:
|
||||
"""执行流式对话
|
||||
|
||||
@@ -491,6 +453,7 @@ class LangChainAgent:
|
||||
message: 用户消息
|
||||
history: 历史消息列表
|
||||
context: 上下文信息
|
||||
files: 多模态文件
|
||||
|
||||
Yields:
|
||||
str: 消息内容块
|
||||
@@ -501,23 +464,6 @@ class LangChainAgent:
|
||||
logger.info(f" Has tools: {bool(self.tools)}")
|
||||
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
|
||||
logger.info("=" * 80)
|
||||
message_chat = message
|
||||
actual_config_id = config_id
|
||||
# If config_id is None, try to get from end_user's connected config
|
||||
if actual_config_id is None and end_user_id:
|
||||
try:
|
||||
db = next(get_db())
|
||||
try:
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
actual_config_id = connected_config.get("memory_config_id")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
|
||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
@@ -527,17 +473,18 @@ class LangChainAgent:
|
||||
)
|
||||
|
||||
chunk_count = 0
|
||||
yielded_content = False
|
||||
|
||||
# 统一使用 agent 的 astream_events 实现流式输出
|
||||
logger.debug("使用 Agent astream_events 实现流式输出")
|
||||
full_content = ''
|
||||
try:
|
||||
last_event = {}
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
version="v2",
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
):
|
||||
last_event = event
|
||||
chunk_count += 1
|
||||
kind = event.get("event")
|
||||
|
||||
@@ -551,7 +498,6 @@ class LangChainAgent:
|
||||
if isinstance(chunk_content, str) and chunk_content:
|
||||
full_content += chunk_content
|
||||
yield chunk_content
|
||||
yielded_content = True
|
||||
elif isinstance(chunk_content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
for item in chunk_content:
|
||||
@@ -562,18 +508,15 @@ class LangChainAgent:
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
elif isinstance(item, str):
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
|
||||
elif kind == "on_llm_stream":
|
||||
# 另一种 LLM 流式事件
|
||||
@@ -584,7 +527,6 @@ class LangChainAgent:
|
||||
if isinstance(chunk_content, str) and chunk_content:
|
||||
full_content += chunk_content
|
||||
yield chunk_content
|
||||
yielded_content = True
|
||||
elif isinstance(chunk_content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
for item in chunk_content:
|
||||
@@ -595,22 +537,18 @@ class LangChainAgent:
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
elif isinstance(item, str):
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
elif isinstance(chunk, str):
|
||||
full_content += chunk
|
||||
yield chunk
|
||||
yielded_content = True
|
||||
|
||||
# 记录工具调用(可选)
|
||||
elif kind == "on_tool_start":
|
||||
@@ -620,17 +558,14 @@ class LangChainAgent:
|
||||
|
||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||
# 统计token消耗
|
||||
# 统计 token 消耗:优先使用流式过程中捕获的值,回退到最后 event 的 messages
|
||||
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
||||
output_messages = last_event.get("data", {}).get("output", {}).get("messages", [])
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
stream_total_tokens = self._extract_tokens_from_message(msg)
|
||||
logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}")
|
||||
yield stream_total_tokens
|
||||
break
|
||||
if memory_flag:
|
||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
|
||||
actual_config_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
@@ -12,7 +12,6 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
@@ -21,25 +20,6 @@ logger = get_agent_logger(__name__)
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
|
||||
|
||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
||||
"""
|
||||
Write messages to RAG storage system
|
||||
|
||||
Combines user and AI messages into a single string format and stores them
|
||||
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
|
||||
|
||||
Args:
|
||||
end_user_id: User identifier for the conversation
|
||||
user_message: User's input message content
|
||||
ai_message: AI's response message content
|
||||
user_rag_memory_id: RAG memory identifier for storage location
|
||||
"""
|
||||
# RAG mode: combine messages into string format (maintain original logic)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
|
||||
|
||||
async def write(
|
||||
storage_type,
|
||||
end_user_id,
|
||||
@@ -118,7 +98,7 @@ async def write(
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
|
||||
|
||||
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
|
||||
async def term_memory_save(end_user_id, strategy_type, scope):
|
||||
"""
|
||||
Save long-term memory data to database
|
||||
|
||||
@@ -127,10 +107,8 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
|
||||
to long-term memory storage.
|
||||
|
||||
Args:
|
||||
long_term_messages: Long-term message data to be saved
|
||||
actual_config_id: Configuration identifier for memory settings
|
||||
end_user_id: User identifier for memory association
|
||||
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
||||
strategy_type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
||||
scope: Scope/window size for memory processing
|
||||
"""
|
||||
with get_db_context() as db_session:
|
||||
@@ -138,7 +116,10 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
result = write_store.get_session_by_userid(end_user_id)
|
||||
if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
if not result:
|
||||
logger.warning(f"No write data found for user {end_user_id}")
|
||||
return
|
||||
if strategy_type in [AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE]:
|
||||
data = await format_parsing(result, "dict")
|
||||
chunk_data = data[:scope]
|
||||
if len(chunk_data) == scope:
|
||||
@@ -151,9 +132,6 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
|
||||
logger.info(f'写入短长期:')
|
||||
|
||||
|
||||
"""Window-based dialogue processing"""
|
||||
|
||||
|
||||
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
|
||||
"""
|
||||
Process dialogue based on window size and write to Neo4j
|
||||
@@ -167,40 +145,33 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||
langchain_messages: Original message data list
|
||||
scope: Window size determining when to trigger long-term storage
|
||||
"""
|
||||
scope = scope
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_id is not False:
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
||||
if is_end_user_id and int(is_end_user_id) != int(scope):
|
||||
is_end_user_id += 1
|
||||
langchain_messages += redis_messages
|
||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
||||
elif int(is_end_user_id) == int(scope):
|
||||
is_end_user_has_history = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_has_history:
|
||||
end_user_visit_count, redis_messages = is_end_user_has_history
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
return
|
||||
end_user_visit_count += 1
|
||||
if end_user_visit_count < scope:
|
||||
redis_messages.extend(langchain_messages)
|
||||
count_store.update_sessions_count(end_user_id, end_user_visit_count, redis_messages)
|
||||
else:
|
||||
logger.info('写入长期记忆NEO4J')
|
||||
formatted_messages = redis_messages
|
||||
redis_messages.extend(langchain_messages)
|
||||
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
||||
if hasattr(memory_config, 'config_id'):
|
||||
config_id = memory_config.config_id
|
||||
else:
|
||||
config_id = memory_config
|
||||
|
||||
await write(
|
||||
AgentMemory_Long_Term.STORAGE_NEO4J,
|
||||
end_user_id,
|
||||
"",
|
||||
"",
|
||||
None,
|
||||
end_user_id,
|
||||
config_id,
|
||||
formatted_messages
|
||||
write_message_task.delay(
|
||||
end_user_id, # end_user_id: User ID
|
||||
redis_messages, # message: JSON string format message list
|
||||
config_id, # config_id: Configuration ID string
|
||||
AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
|
||||
"" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||
)
|
||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
|
||||
|
||||
"""Time-based memory processing"""
|
||||
count_store.update_sessions_count(end_user_id, 0, [])
|
||||
|
||||
|
||||
async def memory_long_term_storage(end_user_id, memory_config, time):
|
||||
@@ -291,9 +262,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
||||
return result_dict
|
||||
|
||||
except Exception as e:
|
||||
print(f"[aggregate_judgment] 发生错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True)
|
||||
|
||||
return {
|
||||
"is_same_event": False,
|
||||
|
||||
@@ -1,49 +1,25 @@
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from app.db import get_db, get_db_context
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
||||
aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_write_graph():
|
||||
"""
|
||||
Create a write graph workflow for memory operations.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
tools: MCP tools loaded from session
|
||||
apply_id: Application identifier
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
workflow = StateGraph(WriteState)
|
||||
workflow.add_node("save_neo4j", write_node)
|
||||
workflow.add_edge(START, "save_neo4j")
|
||||
workflow.add_edge("save_neo4j", END)
|
||||
|
||||
graph = workflow.compile()
|
||||
|
||||
yield graph
|
||||
|
||||
|
||||
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
|
||||
end_user_id: str = '', scope: int = 6):
|
||||
async def long_term_storage(
|
||||
long_term_type: str,
|
||||
langchain_messages: list,
|
||||
memory_config_id: str,
|
||||
end_user_id: str,
|
||||
scope: int = 6
|
||||
):
|
||||
"""
|
||||
Handle long-term memory storage with different strategies
|
||||
|
||||
@@ -53,33 +29,39 @@ async def long_term_storage(long_term_type: str = "chunk", langchain_messages: l
|
||||
Args:
|
||||
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
|
||||
langchain_messages: List of messages to store
|
||||
memory_config: Memory configuration identifier
|
||||
memory_config_id: Memory configuration identifier
|
||||
end_user_id: User group identifier
|
||||
scope: Scope parameter for chunk-based storage (default: 6)
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
||||
aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
if langchain_messages is None:
|
||||
langchain_messages = []
|
||||
|
||||
write_store.save_session_write(end_user_id, langchain_messages)
|
||||
# 获取数据库会话
|
||||
with get_db_context() as db_session:
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=memory_config, # 改为整数
|
||||
config_id=memory_config_id, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
|
||||
'''Strategy 1: Dialogue window with 6 rounds of conversation'''
|
||||
# Dialogue window with 6 rounds of conversation
|
||||
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
|
||||
"""Time-based strategy"""
|
||||
# Time-based strategy
|
||||
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
"""Strategy 3: Aggregate judgment"""
|
||||
# Aggregate judgment
|
||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
|
||||
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
|
||||
async def write_long_term(
|
||||
storage_type: str,
|
||||
end_user_id: str,
|
||||
messages: list[dict],
|
||||
user_rag_memory_id: str,
|
||||
actual_config_id: str
|
||||
):
|
||||
"""
|
||||
Write long-term memory with different storage types
|
||||
|
||||
@@ -89,44 +71,24 @@ async def write_long_term(storage_type, end_user_id, message_chat, aimessages, u
|
||||
Args:
|
||||
storage_type: Type of storage (RAG or traditional)
|
||||
end_user_id: User group identifier
|
||||
message_chat: User message content
|
||||
aimessages: AI response messages
|
||||
messages: message list
|
||||
user_rag_memory_id: RAG memory identifier
|
||||
actual_config_id: Actual configuration ID
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||
message_content = []
|
||||
for message in messages:
|
||||
message_content.append(f'{message.get("role")}:{message.get("content")}')
|
||||
messages_string = "\n".join(message_content)
|
||||
await write_rag(end_user_id, messages_string, user_rag_memory_id)
|
||||
else:
|
||||
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
|
||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
||||
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
|
||||
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
|
||||
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
|
||||
|
||||
# async def main():
|
||||
# """主函数 - 运行工作流"""
|
||||
# langchain_messages = [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": "今天周五去爬山"
|
||||
# },
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": "好耶"
|
||||
# }
|
||||
#
|
||||
# ]
|
||||
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
||||
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
||||
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
||||
#
|
||||
#
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# import asyncio
|
||||
# asyncio.run(main())
|
||||
await long_term_storage(long_term_type=CHUNK,
|
||||
langchain_messages=messages,
|
||||
memory_config_id=actual_config_id,
|
||||
end_user_id=end_user_id,
|
||||
scope=SCOPE)
|
||||
await term_memory_save(end_user_id, CHUNK, scope=SCOPE)
|
||||
|
||||
@@ -3,8 +3,9 @@ import uuid
|
||||
from app.core.config import settings
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.memory.agent.utils.redis_base import (
|
||||
serialize_messages,
|
||||
serialize_messages,
|
||||
deserialize_messages,
|
||||
fix_encoding,
|
||||
format_session_data,
|
||||
@@ -14,12 +15,12 @@ from app.core.memory.agent.utils.redis_base import (
|
||||
get_current_timestamp
|
||||
)
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class RedisWriteStore:
|
||||
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
|
||||
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
@@ -66,10 +67,10 @@ class RedisWriteStore:
|
||||
})
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
logger.debug(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"[save_session_write] 保存会话失败: {e}")
|
||||
logger.error(f"[save_session_write] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
|
||||
@@ -99,7 +100,7 @@ class RedisWriteStore:
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
@@ -108,16 +109,16 @@ class RedisWriteStore:
|
||||
"sessionid": session_id,
|
||||
"messages": fix_encoding(data.get('messages', ''))
|
||||
})
|
||||
|
||||
|
||||
if not results:
|
||||
return False
|
||||
|
||||
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
||||
|
||||
logger.debug(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"[get_session_by_userid] 查询失败: {e}")
|
||||
logger.error(f"[get_session_by_userid] 查询失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
|
||||
"""
|
||||
通过 end_user_id 获取所有 write 类型的会话数据
|
||||
@@ -144,7 +145,7 @@ class RedisWriteStore:
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
||||
logger.debug(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
||||
return False
|
||||
|
||||
# 批量获取数据
|
||||
@@ -158,12 +159,12 @@ class RedisWriteStore:
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == end_user_id:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
session_id = key.split(':')[-1]
|
||||
|
||||
|
||||
# 构建完整的会话信息
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
@@ -173,23 +174,21 @@ class RedisWriteStore:
|
||||
"starttime": data.get('starttime', '')
|
||||
}
|
||||
results.append(session_info)
|
||||
|
||||
|
||||
if not results:
|
||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
||||
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
||||
return False
|
||||
|
||||
|
||||
# 按时间排序(最新的在前)
|
||||
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
|
||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
||||
|
||||
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
logger.error(f"[get_all_sessions_by_end_user_id] 查询失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def find_user_recent_sessions(self, userid: str,
|
||||
def find_user_recent_sessions(self, userid: str,
|
||||
minutes: int = 5) -> List[Dict[str, str]]:
|
||||
"""
|
||||
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
|
||||
@@ -203,11 +202,11 @@ class RedisWriteStore:
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
logger.debug(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 批量获取数据
|
||||
@@ -221,7 +220,7 @@ class RedisWriteStore:
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid and data.get('starttime'):
|
||||
# write 类型没有 aimessages,所以 Answer 为空
|
||||
@@ -230,15 +229,14 @@ class RedisWriteStore:
|
||||
"Answer": "",
|
||||
"starttime": data.get('starttime', '')
|
||||
})
|
||||
|
||||
|
||||
# 根据时间范围过滤
|
||||
filtered_items = filter_by_time_range(matched_items, minutes)
|
||||
# 排序并移除时间字段
|
||||
result_items = sort_and_limit_results(filtered_items, limit=None)
|
||||
print(result_items)
|
||||
result_items = sort_and_limit_results(filtered_items)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
||||
logger.debug(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
||||
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
@@ -258,7 +256,7 @@ class RedisWriteStore:
|
||||
|
||||
class RedisCountStore:
|
||||
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
|
||||
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
@@ -278,7 +276,7 @@ class RedisCountStore:
|
||||
decode_responses=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.uudi = session_id
|
||||
self.uuid = session_id
|
||||
|
||||
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
|
||||
"""
|
||||
@@ -295,26 +293,26 @@ class RedisCountStore:
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
index_key = f'session:count:index:{end_user_id}' # 索引键
|
||||
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"id": self.uuid,
|
||||
"end_user_id": end_user_id,
|
||||
"count": int(count),
|
||||
"messages": serialize_messages(messages),
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
|
||||
|
||||
|
||||
# 创建索引:end_user_id -> session_id 映射
|
||||
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
|
||||
|
||||
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||
|
||||
logger.debug(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||
return session_id
|
||||
|
||||
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
|
||||
def get_sessions_count(self, end_user_id: str) -> tuple[int, list[dict]] | bool:
|
||||
"""
|
||||
通过 end_user_id 查询访问次数统计
|
||||
|
||||
@@ -327,7 +325,7 @@ class RedisCountStore:
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
@@ -335,35 +333,40 @@ class RedisCountStore:
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
except Exception as type_error:
|
||||
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
logger.error(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
|
||||
if not session_id:
|
||||
return False
|
||||
|
||||
|
||||
# 直接获取数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
|
||||
if not data:
|
||||
# 索引存在但数据不存在,清理索引
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
|
||||
|
||||
count = data.get('count')
|
||||
messages_str = data.get('messages')
|
||||
|
||||
|
||||
if count is not None:
|
||||
messages = deserialize_messages(messages_str)
|
||||
return [int(count), messages]
|
||||
|
||||
messages: list[dict] = deserialize_messages(messages_str)
|
||||
return int(count), messages
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"[get_sessions_count] 查询失败: {e}")
|
||||
logger.error(f"[get_sessions_count] 查询失败: {e}")
|
||||
return False
|
||||
def update_sessions_count(self, end_user_id: str, new_count: int,
|
||||
messages: Any) -> bool:
|
||||
|
||||
def update_sessions_count(
|
||||
self,
|
||||
end_user_id: str,
|
||||
new_count: int,
|
||||
messages: Any
|
||||
) -> bool:
|
||||
"""
|
||||
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
||||
|
||||
@@ -378,39 +381,39 @@ class RedisCountStore:
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
if key_type != 'string' and key_type != 'none':
|
||||
# 索引键类型错误,删除并返回 False
|
||||
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||
logger.warning(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||
self.r.delete(index_key)
|
||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
except Exception as type_error:
|
||||
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
logger.error(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
|
||||
if not session_id:
|
||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
|
||||
|
||||
# 直接更新数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
messages_str = serialize_messages(messages)
|
||||
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, 'count', int(new_count))
|
||||
pipe.hset(key, 'count', str(new_count))
|
||||
pipe.hset(key, 'messages', messages_str)
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||
|
||||
logger.debug(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||
return True
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[update_sessions_count] 更新失败: {e}")
|
||||
logger.debug(f"[update_sessions_count] 更新失败: {e}")
|
||||
return False
|
||||
|
||||
def delete_all_count_sessions(self) -> int:
|
||||
@@ -428,7 +431,7 @@ class RedisCountStore:
|
||||
|
||||
class RedisSessionStore:
|
||||
"""Redis 会话存储类,用于管理会话数据"""
|
||||
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
@@ -451,9 +454,9 @@ class RedisSessionStore:
|
||||
self.uudi = session_id
|
||||
|
||||
# ==================== 写入操作 ====================
|
||||
|
||||
def save_session(self, userid: str, messages: str, aimessages: str,
|
||||
apply_id: str, end_user_id: str) -> str:
|
||||
|
||||
def save_session(self, userid: str, messages: str, aimessages: str,
|
||||
apply_id: str, end_user_id: str) -> str:
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
|
||||
@@ -483,14 +486,14 @@ class RedisSessionStore:
|
||||
})
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
logger.debug(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"[save_session] 保存会话失败: {e}")
|
||||
logger.error(f"[save_session] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
# ==================== 读取操作 ====================
|
||||
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
读取一条会话数据
|
||||
@@ -520,8 +523,8 @@ class RedisSessionStore:
|
||||
sessions[sid] = self.get_session(sid)
|
||||
return sessions
|
||||
|
||||
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
||||
end_user_id: str) -> List[Dict[str, str]]:
|
||||
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
||||
end_user_id: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条
|
||||
|
||||
@@ -535,10 +538,10 @@ class RedisSessionStore:
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
keys = self.r.keys('session:*')
|
||||
if not keys:
|
||||
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
logger.debug(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 批量获取数据
|
||||
@@ -556,21 +559,21 @@ class RedisSessionStore:
|
||||
continue
|
||||
|
||||
if (data.get('apply_id') == apply_id and
|
||||
data.get('end_user_id') == end_user_id):
|
||||
data.get('end_user_id') == end_user_id):
|
||||
# 支持模糊匹配或完全匹配 sessionid
|
||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||
matched_items.append(format_session_data(data, include_time=True))
|
||||
|
||||
|
||||
# 排序、限制数量并移除时间字段
|
||||
result_items = sort_and_limit_results(matched_items, limit=6)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
logger.debug(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
# ==================== 更新操作 ====================
|
||||
|
||||
|
||||
def update_session(self, session_id: str, field: str, value: Any) -> bool:
|
||||
"""
|
||||
更新单个字段
|
||||
@@ -591,7 +594,7 @@ class RedisSessionStore:
|
||||
return bool(results[0])
|
||||
|
||||
# ==================== 删除操作 ====================
|
||||
|
||||
|
||||
def delete_session(self, session_id: str) -> int:
|
||||
"""
|
||||
删除单条会话
|
||||
@@ -632,7 +635,7 @@ class RedisSessionStore:
|
||||
|
||||
keys = self.r.keys('session:*')
|
||||
if not keys:
|
||||
print("[delete_duplicate_sessions] 没有会话数据")
|
||||
logger.debug("[delete_duplicate_sessions] 没有会话数据")
|
||||
return 0
|
||||
|
||||
# 批量获取所有数据
|
||||
@@ -678,7 +681,7 @@ class RedisSessionStore:
|
||||
deleted_count += len(batch)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||
logger.debug(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||
return deleted_count
|
||||
|
||||
|
||||
|
||||
@@ -151,11 +151,6 @@ async def write(
|
||||
|
||||
# Step 3: Save all data to Neo4j database
|
||||
step_start = time.time()
|
||||
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
|
||||
try:
|
||||
await create_fulltext_indexes()
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
||||
|
||||
# 添加死锁重试机制
|
||||
max_retries = 3
|
||||
@@ -279,5 +274,21 @@ async def write(
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||
|
||||
# Close LLM/Embedder underlying httpx clients to prevent
|
||||
# 'RuntimeError: Event loop is closed' during garbage collection
|
||||
for client_obj in (llm_client, embedder_client):
|
||||
try:
|
||||
underlying = getattr(client_obj, 'client', None) or getattr(client_obj, 'model', None)
|
||||
if underlying is None:
|
||||
continue
|
||||
# Unwrap RedBearLLM / RedBearEmbeddings to get the LangChain model
|
||||
inner = getattr(underlying, '_model', underlying)
|
||||
# LangChain OpenAI models expose async_client (httpx.AsyncClient)
|
||||
http_client = getattr(inner, 'async_client', None)
|
||||
if http_client is not None and hasattr(http_client, 'aclose'):
|
||||
await http_client.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
|
||||
@@ -56,7 +56,7 @@ class LLMClient(ABC):
|
||||
self.max_retries = self.config.max_retries
|
||||
self.timeout = self.config.timeout
|
||||
|
||||
logger.info(
|
||||
logger.debug(
|
||||
f"初始化 LLM 客户端: provider={self.provider}, "
|
||||
f"model={self.model_name}, max_retries={self.max_retries}"
|
||||
)
|
||||
|
||||
@@ -65,7 +65,7 @@ class OpenAIClient(LLMClient):
|
||||
type=type_
|
||||
)
|
||||
|
||||
logger.info(f"OpenAI 客户端初始化完成: type={type_}")
|
||||
logger.debug(f"OpenAI 客户端初始化完成: type={type_}")
|
||||
|
||||
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any:
|
||||
"""
|
||||
|
||||
@@ -30,6 +30,18 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def message_has_files(message: "ConversationMessage") -> bool:
|
||||
"""检查消息是否包含文件。
|
||||
|
||||
Args:
|
||||
message: 待检查的消息对象
|
||||
|
||||
Returns:
|
||||
bool: 如果消息包含文件则返回 True,否则返回 False
|
||||
"""
|
||||
return message.files and len(message.files) > 0
|
||||
|
||||
|
||||
class DialogExtractionResponse(BaseModel):
|
||||
"""对话级一次性抽取的结构化返回,用于加速剪枝。
|
||||
|
||||
@@ -128,7 +140,7 @@ class SemanticPruner:
|
||||
1. 空消息
|
||||
2. 场景特定填充词库精确匹配
|
||||
3. 常见寒暄精确匹配
|
||||
4. 组合寒暄模式(前缀+后缀组合,如"好的谢谢"、"同学你好"、"明白了")
|
||||
4. 组合寒暄模式(前缀 + 后缀组合,如"好的谢谢"、"同学你好"、"明白了")
|
||||
5. 纯表情/标点
|
||||
"""
|
||||
t = message.msg.strip()
|
||||
@@ -482,6 +494,11 @@ class SemanticPruner:
|
||||
"""
|
||||
to_delete_ids: set = set()
|
||||
for m in msgs:
|
||||
# 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断
|
||||
if message_has_files(m):
|
||||
self._log(f" [保护] 带文件的消息(不参与剪枝):'{m.msg[:40]}',文件数={len(m.files)}")
|
||||
continue
|
||||
|
||||
# 填充检测优先:先判断是否为填充,再看 LLM 保护
|
||||
if self._is_filler_message(m):
|
||||
to_delete_ids.add(id(m))
|
||||
@@ -549,6 +566,11 @@ class SemanticPruner:
|
||||
to_delete_ids: set = set()
|
||||
for m in msgs:
|
||||
msg_text = m.msg.strip()
|
||||
|
||||
# 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断
|
||||
if message_has_files(m):
|
||||
self._log(f" [保护] 带文件的消息(不参与剪枝):'{msg_text[:40]}',文件数={len(m.files)}")
|
||||
continue
|
||||
|
||||
# 第一优先级:填充消息无论模式直接删除,不参与后续场景判断
|
||||
if self._is_filler_message(m):
|
||||
@@ -801,6 +823,12 @@ class SemanticPruner:
|
||||
|
||||
for idx, m in enumerate(msgs):
|
||||
msg_text = m.msg.strip()
|
||||
|
||||
# 最高优先级保护:带有文件的消息一律保留,不参与分类
|
||||
if message_has_files(m):
|
||||
self._log(f" [保护] 带文件的消息(不参与分类,直接保留):索引{idx}, '{msg_text[:40]}', 文件数={len(m.files)}")
|
||||
llm_protected_msgs.append((idx, m)) # 放入保护列表
|
||||
continue
|
||||
|
||||
if self._msg_matches_tokens(m, preserve_tokens):
|
||||
llm_protected_msgs.append((idx, m))
|
||||
|
||||
@@ -182,7 +182,7 @@ class ExtractionOrchestrator:
|
||||
list[StatementEntityEdge],
|
||||
list[EntityEntityEdge],
|
||||
list[PerceptualEdge],
|
||||
dict
|
||||
list[DialogData]
|
||||
]:
|
||||
"""
|
||||
运行完整的知识提取流水线(优化版:并行执行)
|
||||
@@ -295,6 +295,7 @@ class ExtractionOrchestrator:
|
||||
statement_entity_edges,
|
||||
entity_entity_edges,
|
||||
dialog_data_list,
|
||||
dedup_details,
|
||||
) = await self._run_dedup_and_write_summary(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
@@ -306,6 +307,11 @@ class ExtractionOrchestrator:
|
||||
dialog_data_list,
|
||||
)
|
||||
|
||||
# 步骤 7: 同步用户别名到数据库表(仅正式模式)
|
||||
if not is_pilot_run:
|
||||
logger.info("步骤 7: 同步用户别名到 end_user 和 end_user_info 表")
|
||||
await self._update_end_user_other_name(entity_nodes, dialog_data_list)
|
||||
|
||||
logger.info(f"知识提取流水线运行完成({mode_str})")
|
||||
return (
|
||||
dialogue_nodes,
|
||||
@@ -1399,7 +1405,8 @@ class ExtractionOrchestrator:
|
||||
logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}")
|
||||
else:
|
||||
first_alias = current_aliases[0].strip() if current_aliases else ""
|
||||
if first_alias:
|
||||
# 确保 first_alias 不是占位名称
|
||||
if first_alias and first_alias not in self.USER_PLACEHOLDER_NAMES:
|
||||
db.add(EndUserInfo(
|
||||
end_user_id=end_user_uuid,
|
||||
other_name=first_alias,
|
||||
@@ -1415,29 +1422,33 @@ class ExtractionOrchestrator:
|
||||
|
||||
|
||||
|
||||
# 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中
|
||||
USER_PLACEHOLDER_NAMES = {'用户', '我', 'User', 'I'}
|
||||
|
||||
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]:
|
||||
"""从实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序)
|
||||
|
||||
这个方法直接返回 LLM 提取的别名列表,不做任何修改。
|
||||
这个方法直接返回 LLM 提取的别名列表,并过滤掉占位名称("用户"、"我"、"User"、"I")。
|
||||
第一个别名将被用作 other_name。
|
||||
|
||||
Args:
|
||||
entity_nodes: 实体节点列表
|
||||
|
||||
Returns:
|
||||
别名列表(保持 LLM 提取的原始顺序)
|
||||
别名列表(保持 LLM 提取的原始顺序,已过滤占位名称)
|
||||
"""
|
||||
USER_NAMES = {'用户', '我', 'User', 'I'}
|
||||
for entity in entity_nodes:
|
||||
if getattr(entity, 'name', '').strip() in USER_NAMES:
|
||||
if getattr(entity, 'name', '').strip() in self.USER_PLACEHOLDER_NAMES:
|
||||
aliases = getattr(entity, 'aliases', []) or []
|
||||
logger.debug(f"提取到用户别名(原始顺序): {aliases}")
|
||||
return aliases
|
||||
# 过滤掉占位名称,防止 "用户"/"我"/"User"/"I" 被存入 aliases 和 other_name
|
||||
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
|
||||
logger.debug(f"提取到用户别名(原始顺序,已过滤占位名称): {filtered}")
|
||||
return filtered
|
||||
return []
|
||||
|
||||
|
||||
async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]:
|
||||
"""从 Neo4j 查询用户实体的完整 aliases 列表"""
|
||||
"""从 Neo4j 查询用户实体的完整 aliases 列表(已过滤占位名称)"""
|
||||
cypher = """
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '我', 'User', 'I']
|
||||
@@ -1451,7 +1462,10 @@ class ExtractionOrchestrator:
|
||||
aliases = result[0].get('aliases') or []
|
||||
if not aliases:
|
||||
logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}")
|
||||
return aliases
|
||||
return []
|
||||
# 过滤掉占位名称,防止历史脏数据传播
|
||||
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
|
||||
return filtered
|
||||
|
||||
def _resolve_other_name(
|
||||
self,
|
||||
@@ -1463,14 +1477,25 @@ class ExtractionOrchestrator:
|
||||
决定 other_name 是否需要更新,返回新值;无需更新返回 None。
|
||||
|
||||
决策规则:
|
||||
- 为空 → 用本次对话第一个别名
|
||||
- 为空或为占位名称 → 用本次对话第一个别名
|
||||
- 不在 Neo4j aliases 中 → 用 Neo4j 第一个别名(说明已被删除)
|
||||
- 否则 → 保持不变(返回 None)
|
||||
|
||||
注意:返回值不允许是占位名称("用户"、"我"、"User"、"I")
|
||||
"""
|
||||
if not current or not current.strip():
|
||||
return current_aliases[0].strip() if current_aliases else None
|
||||
# 当前值为空或为占位名称时,需要更新
|
||||
if not current or not current.strip() or current.strip() in self.USER_PLACEHOLDER_NAMES:
|
||||
candidate = current_aliases[0].strip() if current_aliases else None
|
||||
# 确保候选值不是占位名称
|
||||
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
|
||||
return None
|
||||
return candidate
|
||||
if current not in neo4j_aliases:
|
||||
return neo4j_aliases[0].strip() if neo4j_aliases else None
|
||||
candidate = neo4j_aliases[0].strip() if neo4j_aliases else None
|
||||
# 确保候选值不是占位名称
|
||||
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
|
||||
return None
|
||||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
@@ -1492,6 +1517,7 @@ class ExtractionOrchestrator:
|
||||
list[StatementChunkEdge],
|
||||
list[StatementEntityEdge],
|
||||
list[EntityEntityEdge],
|
||||
list[DialogData],
|
||||
dict
|
||||
]:
|
||||
"""
|
||||
@@ -1555,6 +1581,8 @@ class ExtractionOrchestrator:
|
||||
statement_chunk_edges,
|
||||
dedup_statement_entity_edges,
|
||||
dedup_entity_entity_edges,
|
||||
dialog_data_list,
|
||||
dedup_details,
|
||||
)
|
||||
|
||||
final_entity_nodes = dedup_entity_nodes
|
||||
@@ -1562,7 +1590,16 @@ class ExtractionOrchestrator:
|
||||
final_entity_entity_edges = dedup_entity_entity_edges
|
||||
else:
|
||||
# 正式模式:执行完整的两阶段去重
|
||||
result_tuple = await dedup_layers_and_merge_and_return(
|
||||
(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
final_entity_nodes,
|
||||
statement_chunk_edges,
|
||||
final_statement_entity_edges,
|
||||
final_entity_entity_edges,
|
||||
dedup_details,
|
||||
) = await dedup_layers_and_merge_and_return(
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
@@ -1576,21 +1613,21 @@ class ExtractionOrchestrator:
|
||||
llm_client=self.llm_client,
|
||||
)
|
||||
|
||||
# 解包返回值
|
||||
(
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
final_entity_nodes,
|
||||
_,
|
||||
final_statement_entity_edges,
|
||||
final_entity_entity_edges,
|
||||
dedup_details,
|
||||
) = result_tuple
|
||||
|
||||
# 保存去重消歧的详细记录到实例变量
|
||||
self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes)
|
||||
|
||||
result_tuple = (
|
||||
dialogue_nodes,
|
||||
chunk_nodes,
|
||||
statement_nodes,
|
||||
final_entity_nodes,
|
||||
statement_chunk_edges,
|
||||
final_statement_entity_edges,
|
||||
final_entity_entity_edges,
|
||||
dialog_data_list,
|
||||
dedup_details,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"去重后: {len(final_entity_nodes)} 个实体节点, "
|
||||
f"{len(final_statement_entity_edges)} 条陈述句-实体边, "
|
||||
|
||||
@@ -105,13 +105,19 @@ Extract entities and knowledge triplets from the given statement.
|
||||
{% if language == "zh" %}
|
||||
- 用户实体的 name 字段:使用 "用户" 或 "我"
|
||||
- 用户的真实姓名:放入 aliases
|
||||
- **🚨 禁止将 "用户"、"我" 放入 aliases 中,aliases 只能包含用户的真实姓名、昵称等**
|
||||
- 示例:
|
||||
* "我叫李明" → name="用户", aliases=["李明"]
|
||||
* ❌ 错误:aliases=["用户", "李明"]("用户"不是真实姓名,禁止放入 aliases)
|
||||
* ❌ 错误:aliases=["我", "李明"]("我"不是真实姓名,禁止放入 aliases)
|
||||
{% else %}
|
||||
- User entity name field: use "User" or "I"
|
||||
- User's real name: put in aliases
|
||||
- **🚨 NEVER put "User" or "I" in aliases. Aliases must only contain real names, nicknames, etc.**
|
||||
- Examples:
|
||||
* "I'm John" → name="User", aliases=["John"]
|
||||
* ❌ Wrong: aliases=["User", "John"] ("User" is not a real name, FORBIDDEN in aliases)
|
||||
* ❌ Wrong: aliases=["I", "John"] ("I" is not a real name, FORBIDDEN in aliases)
|
||||
{% endif %}
|
||||
|
||||
|
||||
|
||||
@@ -44,6 +44,8 @@ class OSSStorage(StorageBackend):
|
||||
access_key_id: str,
|
||||
access_key_secret: str,
|
||||
bucket_name: str,
|
||||
connect_timeout: int = 30,
|
||||
multipart_threshold: int = 10 * 1024 * 1024, # 10MB
|
||||
):
|
||||
"""
|
||||
Initialize the OSSStorage backend.
|
||||
@@ -53,6 +55,8 @@ class OSSStorage(StorageBackend):
|
||||
access_key_id: The Aliyun access key ID.
|
||||
access_key_secret: The Aliyun access key secret.
|
||||
bucket_name: The name of the OSS bucket.
|
||||
connect_timeout: Connection timeout in seconds (default: 30).
|
||||
multipart_threshold: File size threshold for multipart upload (default: 10MB).
|
||||
|
||||
Raises:
|
||||
StorageConfigError: If any required configuration is missing.
|
||||
@@ -69,10 +73,17 @@ class OSSStorage(StorageBackend):
|
||||
|
||||
self.endpoint = endpoint
|
||||
self.bucket_name = bucket_name
|
||||
self.multipart_threshold = multipart_threshold
|
||||
|
||||
try:
|
||||
auth = oss2.Auth(access_key_id, access_key_secret)
|
||||
self.bucket = oss2.Bucket(auth, endpoint, bucket_name)
|
||||
# 设置超时和重试
|
||||
self.bucket = oss2.Bucket(
|
||||
auth,
|
||||
endpoint,
|
||||
bucket_name,
|
||||
connect_timeout=connect_timeout
|
||||
)
|
||||
logger.info(
|
||||
f"OSSStorage initialized with endpoint: {endpoint}, bucket: {bucket_name}"
|
||||
)
|
||||
@@ -108,21 +119,38 @@ class OSSStorage(StorageBackend):
|
||||
if content_type:
|
||||
headers["Content-Type"] = content_type
|
||||
|
||||
self.bucket.put_object(file_key, content, headers=headers if headers else None)
|
||||
# 大文件使用分片上传
|
||||
if len(content) > self.multipart_threshold:
|
||||
logger.info(f"Using multipart upload for large file: {file_key} ({len(content)} bytes)")
|
||||
upload_id = self.bucket.init_multipart_upload(file_key, headers=headers if headers else None).upload_id
|
||||
parts = []
|
||||
part_size = 5 * 1024 * 1024 # 5MB per part
|
||||
part_num = 1
|
||||
|
||||
for offset in range(0, len(content), part_size):
|
||||
chunk = content[offset:offset + part_size]
|
||||
result = self.bucket.upload_part(file_key, upload_id, part_num, chunk)
|
||||
parts.append(oss2.models.PartInfo(part_num, result.etag))
|
||||
part_num += 1
|
||||
|
||||
self.bucket.complete_multipart_upload(file_key, upload_id, parts)
|
||||
else:
|
||||
self.bucket.put_object(file_key, content, headers=headers if headers else None)
|
||||
|
||||
logger.info(f"File uploaded to OSS successfully: {file_key}")
|
||||
return file_key
|
||||
|
||||
except OssError as e:
|
||||
logger.error(f"OSS error uploading file {file_key}: {e}")
|
||||
raise StorageUploadError(
|
||||
message=f"Failed to upload file to OSS: {e.message}",
|
||||
message=f"Failed to upload file to OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to upload file to OSS {file_key}: {e}")
|
||||
raise StorageUploadError(
|
||||
message=f"Failed to upload file to OSS: {e}",
|
||||
message=f"Failed to upload file to OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
@@ -135,28 +163,73 @@ class OSSStorage(StorageBackend):
|
||||
) -> int:
|
||||
"""Upload from async stream to OSS. Returns total bytes written."""
|
||||
buf = io.BytesIO()
|
||||
headers = {"Content-Type": content_type} if content_type else None
|
||||
upload_id = None
|
||||
|
||||
try:
|
||||
# 收集流数据
|
||||
total_size = 0
|
||||
async for chunk in stream:
|
||||
if not chunk:
|
||||
continue
|
||||
buf.write(chunk)
|
||||
total_size += len(chunk)
|
||||
|
||||
content = buf.getvalue()
|
||||
headers = {"Content-Type": content_type} if content_type else None
|
||||
self.bucket.put_object(file_key, content, headers=headers)
|
||||
logger.info(f"File stream uploaded to OSS successfully: {file_key}")
|
||||
return len(content)
|
||||
|
||||
if not content:
|
||||
raise StorageUploadError(
|
||||
message="Empty stream content",
|
||||
file_key=file_key,
|
||||
)
|
||||
|
||||
# 大文件使用分片上传
|
||||
if len(content) > self.multipart_threshold:
|
||||
logger.info(f"Using multipart upload for stream: {file_key} ({len(content)} bytes)")
|
||||
upload_id = self.bucket.init_multipart_upload(file_key, headers=headers).upload_id
|
||||
parts = []
|
||||
part_size = 5 * 1024 * 1024 # 5MB
|
||||
part_num = 1
|
||||
|
||||
for offset in range(0, len(content), part_size):
|
||||
chunk = content[offset:offset + part_size]
|
||||
result = self.bucket.upload_part(file_key, upload_id, part_num, chunk)
|
||||
parts.append(oss2.models.PartInfo(part_num, result.etag))
|
||||
part_num += 1
|
||||
|
||||
self.bucket.complete_multipart_upload(file_key, upload_id, parts)
|
||||
else:
|
||||
self.bucket.put_object(file_key, content, headers=headers)
|
||||
|
||||
logger.info(f"File stream uploaded to OSS successfully: {file_key} ({total_size} bytes)")
|
||||
return total_size
|
||||
|
||||
except OssError as e:
|
||||
if upload_id:
|
||||
try:
|
||||
self.bucket.abort_multipart_upload(file_key, upload_id)
|
||||
except:
|
||||
pass
|
||||
logger.error(f"OSS error stream uploading file {file_key}: {e}")
|
||||
raise StorageUploadError(
|
||||
message=f"Failed to stream upload file to OSS: {e.message}",
|
||||
message=f"Failed to stream upload file to OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
except Exception as e:
|
||||
if upload_id:
|
||||
try:
|
||||
self.bucket.abort_multipart_upload(file_key, upload_id)
|
||||
except:
|
||||
pass
|
||||
logger.error(f"Failed to stream upload file to OSS {file_key}: {e}")
|
||||
raise StorageUploadError(
|
||||
message=f"Failed to stream upload file to OSS: {e}",
|
||||
message=f"Failed to stream upload file to OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
finally:
|
||||
buf.close()
|
||||
|
||||
async def download(self, file_key: str) -> bytes:
|
||||
"""
|
||||
@@ -182,14 +255,14 @@ class OSSStorage(StorageBackend):
|
||||
except OssError as e:
|
||||
logger.error(f"OSS error downloading file {file_key}: {e}")
|
||||
raise StorageDownloadError(
|
||||
message=f"Failed to download file from OSS: {e.message}",
|
||||
message=f"Failed to download file from OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download file from OSS {file_key}: {e}")
|
||||
raise StorageDownloadError(
|
||||
message=f"Failed to download file from OSS: {e}",
|
||||
message=f"Failed to download file from OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
@@ -215,14 +288,14 @@ class OSSStorage(StorageBackend):
|
||||
except OssError as e:
|
||||
logger.error(f"OSS error deleting file {file_key}: {e}")
|
||||
raise StorageDeleteError(
|
||||
message=f"Failed to delete file from OSS: {e.message}",
|
||||
message=f"Failed to delete file from OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete file from OSS {file_key}: {e}")
|
||||
raise StorageDeleteError(
|
||||
message=f"Failed to delete file from OSS: {e}",
|
||||
message=f"Failed to delete file from OSS: {str(e)}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
|
||||
@@ -9,10 +9,10 @@ from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
def merge_activate_state(x, y):
|
||||
return {
|
||||
k: x.get(k, False) or y.get(k, False)
|
||||
for k in set(x) | set(y)
|
||||
}
|
||||
merged = dict(x)
|
||||
for k, v in y.items():
|
||||
merged[k] = merged.get(k, False) or v
|
||||
return merged
|
||||
|
||||
|
||||
def merge_looping_state(x, y):
|
||||
|
||||
@@ -17,6 +17,51 @@ from app.core.workflow.variable.variable_objects import T, create_variable_insta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VARIABLE_PATTERN = re.compile(r"\{\{\s*(.*?)\s*}}")
|
||||
|
||||
|
||||
class LazyVariableDict:
|
||||
def __init__(self, source, literal):
|
||||
self._source: dict[str, VariableStruct[Any]] = source
|
||||
self._literal: bool = literal
|
||||
self._cache = {}
|
||||
|
||||
def keys(self):
|
||||
return self._source.keys()
|
||||
|
||||
def _resolve(self, key):
|
||||
if key in self._cache:
|
||||
return self._cache[key]
|
||||
var_struct = self._source.get(key)
|
||||
if var_struct is None:
|
||||
raise KeyError(key)
|
||||
value = var_struct.instance.to_literal() if self._literal else var_struct.instance.get_value()
|
||||
self._cache[key] = value
|
||||
return value
|
||||
|
||||
def get(self, key, default=None):
|
||||
try:
|
||||
return self._resolve(key)
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self._resolve(key)
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key.startswith('_'):
|
||||
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'")
|
||||
return self._resolve(key)
|
||||
|
||||
def __contains__(self, key):
|
||||
return key in self._source
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._source)
|
||||
|
||||
def __len__(self):
|
||||
return len(self._source)
|
||||
|
||||
|
||||
class VariableSelector:
|
||||
"""变量选择器
|
||||
@@ -117,8 +162,7 @@ class VariablePool:
|
||||
|
||||
@staticmethod
|
||||
def transform_selector(selector):
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
variable_literal = re.sub(pattern, r"\1", selector).strip()
|
||||
variable_literal = VARIABLE_PATTERN.sub(r"\1", selector).strip()
|
||||
selector = VariableSelector.from_string(variable_literal).path
|
||||
if len(selector) != 2:
|
||||
raise ValueError(f"Selector not valid - {selector}")
|
||||
@@ -303,6 +347,16 @@ class VariablePool:
|
||||
"""
|
||||
return self._get_variable_struct(selector) is not None
|
||||
|
||||
def lazy_namespace(self, namespace: str, literal: bool = False) -> LazyVariableDict:
|
||||
return LazyVariableDict(self.variables.get(namespace, {}), literal)
|
||||
|
||||
def lazy_all_node_outputs(self, literal: bool = False) -> dict[str, LazyVariableDict]:
|
||||
return {
|
||||
ns: LazyVariableDict(vars_dict, literal)
|
||||
for ns, vars_dict in self.variables.items()
|
||||
if ns not in ("sys", "conv")
|
||||
}
|
||||
|
||||
def get_all_system_vars(self, literal=False) -> dict[str, Any]:
|
||||
"""获取所有系统变量
|
||||
|
||||
@@ -479,5 +533,3 @@ class VariablePoolInitializer:
|
||||
var_type=var_type,
|
||||
mut=False
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -552,9 +552,9 @@ class BaseNode(ABC):
|
||||
|
||||
return render_template(
|
||||
template=template,
|
||||
conv_vars=variable_pool.get_all_conversation_vars(literal=True),
|
||||
node_outputs=variable_pool.get_all_node_outputs(literal=True),
|
||||
system_vars=variable_pool.get_all_system_vars(literal=True),
|
||||
conv_vars=variable_pool.lazy_namespace("conv", literal=True),
|
||||
node_outputs=variable_pool.lazy_all_node_outputs(literal=True),
|
||||
system_vars=variable_pool.lazy_namespace("sys", literal=True),
|
||||
strict=strict
|
||||
)
|
||||
|
||||
@@ -579,9 +579,9 @@ class BaseNode(ABC):
|
||||
|
||||
return evaluate_condition(
|
||||
expression=expression,
|
||||
conv_var=variable_pool.get_all_conversation_vars(),
|
||||
node_outputs=variable_pool.get_all_node_outputs(),
|
||||
system_vars=variable_pool.get_all_system_vars()
|
||||
conv_var=variable_pool.lazy_namespace("conv"),
|
||||
node_outputs=variable_pool.lazy_all_node_outputs(),
|
||||
system_vars=variable_pool.lazy_namespace("sys")
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -11,7 +11,6 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
|
||||
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType
|
||||
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
|
||||
from app.core.workflow.utils.expression_evaluator import evaluate_expression
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -85,12 +84,7 @@ class LoopRuntime:
|
||||
|
||||
for variable in self.typed_config.cycle_vars:
|
||||
if variable.input_type == ValueInputType.VARIABLE:
|
||||
value = evaluate_expression(
|
||||
expression=variable.value,
|
||||
conv_var=self.variable_pool.get_all_conversation_vars(),
|
||||
node_outputs=self.variable_pool.get_all_node_outputs(),
|
||||
system_vars=self.variable_pool.get_all_system_vars(),
|
||||
)
|
||||
value = self.variable_pool.get_value(variable.value)
|
||||
else:
|
||||
value = TypeTransformer.transform(variable.value, variable.type)
|
||||
await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True)
|
||||
@@ -98,12 +92,7 @@ class LoopRuntime:
|
||||
**self.state
|
||||
)
|
||||
loopstate["node_outputs"][self.node_id] = {
|
||||
variable.name: evaluate_expression(
|
||||
expression=variable.value,
|
||||
conv_var=self.variable_pool.get_all_conversation_vars(),
|
||||
node_outputs=self.variable_pool.get_all_node_outputs(),
|
||||
system_vars=self.variable_pool.get_all_system_vars(),
|
||||
)
|
||||
variable.name: self.variable_pool.get_value(variable.value)
|
||||
if variable.input_type == ValueInputType.VARIABLE
|
||||
else TypeTransformer.transform(variable.value, variable.type)
|
||||
for variable in self.typed_config.cycle_vars
|
||||
|
||||
@@ -89,7 +89,7 @@ class DocExtractorNode(BaseNode):
|
||||
# Reuse cached bytes if already fetched
|
||||
if f.get_content():
|
||||
file_input.set_content(f.get_content())
|
||||
text = await svc._extract_document_text(file_input)
|
||||
text = await svc.extract_document_text(file_input)
|
||||
chunks.append(text)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
|
||||
@@ -1,19 +1,23 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.models import RedBearRerank, RedBearModelConfig
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector
|
||||
from app.core.rag.models.chunk import DocumentChunk
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.db import get_db_read
|
||||
from app.models import knowledge_model, knowledgeshare_model, ModelType
|
||||
from app.repositories import knowledge_repository, knowledgeshare_repository
|
||||
from app.models import knowledge_model, ModelType
|
||||
from app.repositories import knowledge_repository
|
||||
from app.schemas.chunk_schema import RetrieveType
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
@@ -24,7 +28,6 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||
self.vector_service: ElasticSearchVector | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
@@ -85,46 +88,54 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
unique.append(doc)
|
||||
return unique
|
||||
|
||||
def _get_existing_kb_ids(self, db, kb_ids):
|
||||
def rerank(self, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
|
||||
"""
|
||||
Resolve all accessible and valid knowledge base IDs for retrieval.
|
||||
|
||||
This includes:
|
||||
- Private knowledge bases owned by the user
|
||||
- Shared knowledge bases
|
||||
- Source knowledge bases mapped via knowledge sharing relationships
|
||||
|
||||
Reorder the list of document blocks and return the top_k results most relevant to the query
|
||||
Args:
|
||||
db: Database session.
|
||||
kb_ids (list[UUID]): Knowledge base IDs from node configuration.
|
||||
query: query string
|
||||
docs: List of document chunk to be rearranged
|
||||
top_k: The number of top-level documents returned
|
||||
|
||||
Returns:
|
||||
list[UUID]: Final list of valid knowledge base IDs.
|
||||
Rearranged document chunk list (sorted in descending order of relevance)
|
||||
|
||||
Raises:
|
||||
ValueError: If the input document list is empty or top_k is invalid
|
||||
"""
|
||||
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Private)
|
||||
|
||||
existing_ids = knowledge_repository.get_chunked_knowledgeids(
|
||||
db=db,
|
||||
filters=filters
|
||||
)
|
||||
|
||||
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share)
|
||||
|
||||
share_ids = knowledge_repository.get_chunked_knowledgeids(
|
||||
db=db,
|
||||
filters=filters
|
||||
)
|
||||
|
||||
if share_ids:
|
||||
filters = [
|
||||
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids)
|
||||
reranker = self.get_reranker_model()
|
||||
# parameter validation
|
||||
if not docs:
|
||||
raise ValueError("retrieval chunks be empty")
|
||||
if top_k <= 0:
|
||||
raise ValueError("top_k must be a positive integer")
|
||||
try:
|
||||
# Convert to LangChain Document object
|
||||
documents = [
|
||||
Document(
|
||||
page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute
|
||||
metadata=doc.metadata or {} # Deal with possible None metadata
|
||||
)
|
||||
for doc in docs
|
||||
]
|
||||
items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
|
||||
db=db,
|
||||
filters=filters
|
||||
|
||||
# Perform reordering (compress_documents will automatically handle relevance scores and indexing)
|
||||
reranked_docs = list(reranker.compress_documents(documents, query))
|
||||
|
||||
# Sort in descending order based on relevance score
|
||||
reranked_docs.sort(
|
||||
key=lambda x: x.metadata.get("relevance_score", 0),
|
||||
reverse=True
|
||||
)
|
||||
existing_ids.extend(items)
|
||||
return existing_ids
|
||||
# Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"]
|
||||
result = []
|
||||
for item in reranked_docs[:top_k]:
|
||||
for doc in docs:
|
||||
if doc.page_content == item.page_content:
|
||||
doc.metadata["score"] = item.metadata["relevance_score"]
|
||||
result.append(doc)
|
||||
return result
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e
|
||||
|
||||
def get_reranker_model(self) -> RedBearRerank:
|
||||
"""
|
||||
@@ -164,41 +175,77 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
)
|
||||
return reranker
|
||||
|
||||
def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config):
|
||||
async def knowledge_retrieval(self, db, query, db_knowledge, kb_config):
|
||||
rs = []
|
||||
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
|
||||
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
|
||||
tasks = []
|
||||
for child in children:
|
||||
if not (child and child.chunk_num > 0 and child.status == 1):
|
||||
continue
|
||||
kb_config.kb_id = child.id
|
||||
self.knowledge_retrieval(db, query, rs, child, kb_config)
|
||||
return
|
||||
self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
child_kb_config = kb_config.model_copy()
|
||||
child_kb_config.kb_id = child.id
|
||||
tasks.append(self.knowledge_retrieval(db, query, child, child_kb_config))
|
||||
if tasks:
|
||||
result = await asyncio.gather(*tasks)
|
||||
for _ in result:
|
||||
rs.extend(_)
|
||||
return rs
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
|
||||
match kb_config.retrieve_type:
|
||||
case RetrieveType.PARTICIPLE:
|
||||
rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold))
|
||||
rs.extend(
|
||||
await asyncio.to_thread(
|
||||
vector_service.search_by_full_text, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.similarity_threshold
|
||||
}
|
||||
)
|
||||
)
|
||||
case RetrieveType.SEMANTIC:
|
||||
rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight))
|
||||
rs.extend(
|
||||
await asyncio.to_thread(
|
||||
vector_service.search_by_vector, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.vector_similarity_weight
|
||||
}
|
||||
)
|
||||
)
|
||||
case RetrieveType.HYBRID:
|
||||
rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.vector_similarity_weight)
|
||||
rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||
indices=indices,
|
||||
score_threshold=kb_config.similarity_threshold)
|
||||
rs1_task = asyncio.to_thread(
|
||||
vector_service.search_by_vector, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.vector_similarity_weight
|
||||
}
|
||||
)
|
||||
rs2_task = asyncio.to_thread(
|
||||
vector_service.search_by_full_text, **{
|
||||
"query": query,
|
||||
"top_k": kb_config.top_k,
|
||||
"indices": indices,
|
||||
"score_threshold": kb_config.similarity_threshold
|
||||
}
|
||||
)
|
||||
rs1, rs2 = await asyncio.gather(rs1_task, rs2_task)
|
||||
|
||||
# Deduplicate hybrid retrieval results
|
||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||
if not unique_rs:
|
||||
return
|
||||
return []
|
||||
if self.typed_config.reranker_id:
|
||||
self.vector_service.reranker = self.get_reranker_model()
|
||||
rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||
rs.extend(
|
||||
await asyncio.to_thread(
|
||||
self.rerank,
|
||||
**{"query": query, "docs": unique_rs, "top_k": kb_config.top_k}
|
||||
)
|
||||
)
|
||||
else:
|
||||
rs.extend(sorted(
|
||||
unique_rs,
|
||||
@@ -207,6 +254,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
)[:kb_config.top_k])
|
||||
case _:
|
||||
raise RuntimeError("Unknown retrieval type")
|
||||
return rs
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
@@ -238,17 +286,24 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
knowledge_bases = self.typed_config.knowledge_bases
|
||||
|
||||
rs = []
|
||||
tasks = []
|
||||
for kb_config in knowledge_bases:
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
||||
if not db_knowledge:
|
||||
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
||||
self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config)
|
||||
tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config))
|
||||
if tasks:
|
||||
result = await asyncio.gather(*tasks)
|
||||
for _ in result:
|
||||
rs.extend(_)
|
||||
|
||||
if not rs:
|
||||
return []
|
||||
if self.typed_config.reranker_id:
|
||||
self.vector_service.reranker = self.get_reranker_model()
|
||||
final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||
final_rs = await asyncio.to_thread(
|
||||
self.rerank,
|
||||
**{"query": query, "docs": rs, "top_k": self.typed_config.reranker_top_k}
|
||||
)
|
||||
else:
|
||||
final_rs = sorted(
|
||||
rs,
|
||||
|
||||
@@ -4,32 +4,33 @@ from typing import Any
|
||||
|
||||
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
|
||||
|
||||
from app.core.workflow.engine.variable_pool import LazyVariableDict, VARIABLE_PATTERN
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")
|
||||
|
||||
|
||||
class ExpressionEvaluator:
|
||||
"""Safe expression evaluator for workflow variables and node outputs."""
|
||||
|
||||
|
||||
# Reserved namespaces
|
||||
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
|
||||
|
||||
@classmethod
|
||||
def normalize_template(cls, template: str) -> str:
|
||||
pattern = re.compile(
|
||||
r"\{\{\s*(\d+)\.(\w+)\s*}}"
|
||||
)
|
||||
return pattern.sub(
|
||||
return _NORMALIZE_PATTERN.sub(
|
||||
r'{{ node["\1"].\2 }}',
|
||||
template
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def evaluate(
|
||||
cls,
|
||||
expression: str,
|
||||
conv_vars: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
cls,
|
||||
expression: str,
|
||||
conv_vars: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
"""
|
||||
Safely evaluate an expression using workflow variables.
|
||||
@@ -49,48 +50,47 @@ class ExpressionEvaluator:
|
||||
# Remove Jinja2-style brackets if present
|
||||
expression = expression.strip()
|
||||
expression = cls.normalize_template(expression)
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
expression = re.sub(pattern, r"\1", expression).strip()
|
||||
expression = VARIABLE_PATTERN.sub(r"\1", expression).strip()
|
||||
|
||||
# Build context for evaluation
|
||||
context = {
|
||||
"conv": conv_vars, # conversation variables
|
||||
"node": node_outputs, # node outputs
|
||||
"sys": system_vars or {}, # system variables
|
||||
"conv": conv_vars, # conversation variables
|
||||
"node": node_outputs, # node outputs
|
||||
"sys": system_vars or {}, # system variables
|
||||
}
|
||||
|
||||
context.update(conv_vars)
|
||||
context["nodes"] = node_outputs
|
||||
# context.update(conv_vars)
|
||||
# context["nodes"] = node_outputs
|
||||
context.update(node_outputs)
|
||||
|
||||
|
||||
try:
|
||||
# simpleeval supports safe operations:
|
||||
# arithmetic, comparisons, logical ops, attribute/dict/list access
|
||||
result = simple_eval(expression, names=context)
|
||||
return result
|
||||
|
||||
|
||||
except NameNotDefined as e:
|
||||
logger.error(f"Undefined variable in expression: {expression}, error: {e}")
|
||||
raise ValueError(f"Undefined variable: {e}")
|
||||
|
||||
|
||||
except InvalidExpression as e:
|
||||
logger.error(f"Invalid expression syntax: {expression}, error: {e}")
|
||||
raise ValueError(f"Invalid expression syntax: {e}")
|
||||
|
||||
|
||||
except SyntaxError as e:
|
||||
logger.error(f"Syntax error in expression: {expression}, error: {e}")
|
||||
raise ValueError(f"Syntax error: {e}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Expression evaluation failed: {expression}, error: {e}")
|
||||
raise ValueError(f"Expression evaluation failed: {e}")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def evaluate_bool(
|
||||
expression: str,
|
||||
conv_var: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
expression: str,
|
||||
conv_var: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Evaluate a boolean expression (for conditions).
|
||||
@@ -108,7 +108,7 @@ class ExpressionEvaluator:
|
||||
expression, conv_var, node_outputs, system_vars
|
||||
)
|
||||
return bool(result)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def validate_variable_names(variables: list[dict]) -> list[str]:
|
||||
"""
|
||||
@@ -121,7 +121,7 @@ class ExpressionEvaluator:
|
||||
list[str]: List of error messages. Empty if all names are valid.
|
||||
"""
|
||||
errors = []
|
||||
|
||||
|
||||
for var in variables:
|
||||
var_name = var.get("name", "")
|
||||
|
||||
@@ -134,16 +134,16 @@ class ExpressionEvaluator:
|
||||
errors.append(
|
||||
f"Variable name '{var_name}' is not a valid Python identifier"
|
||||
)
|
||||
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
# 便捷函数
|
||||
def evaluate_expression(
|
||||
expression: str,
|
||||
conv_var: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any]
|
||||
expression: str,
|
||||
conv_var: dict[str, Any] | LazyVariableDict,
|
||||
node_outputs: dict[str, dict[str, Any] | LazyVariableDict],
|
||||
system_vars: dict[str, Any] | LazyVariableDict
|
||||
) -> Any:
|
||||
"""Evaluate an expression (convenience function)."""
|
||||
return ExpressionEvaluator.evaluate(
|
||||
@@ -152,11 +152,11 @@ def evaluate_expression(
|
||||
|
||||
|
||||
def evaluate_condition(
|
||||
expression: str,
|
||||
conv_var: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
) -> bool:
|
||||
expression: str,
|
||||
conv_var: dict[str, Any] | LazyVariableDict,
|
||||
node_outputs: dict[str, dict[str, Any] | LazyVariableDict],
|
||||
system_vars: dict[str, Any] | LazyVariableDict
|
||||
) -> Any:
|
||||
"""Evaluate a boolean condition expression (convenience function)."""
|
||||
return ExpressionEvaluator.evaluate_bool(
|
||||
expression, conv_var, node_outputs, system_vars
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""
|
||||
模板渲染器
|
||||
Template Renderer
|
||||
|
||||
使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。
|
||||
Provides safe template rendering using Jinja2, supporting variable references
|
||||
and expressions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
@@ -10,11 +11,15 @@ from typing import Any
|
||||
|
||||
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
|
||||
|
||||
from app.core.workflow.engine.variable_pool import LazyVariableDict
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")
|
||||
|
||||
|
||||
class SafeUndefined(Undefined):
|
||||
"""访问未定义属性不会报错,返回空字符串"""
|
||||
"""Return empty string instead of raising error when accessing undefined variables"""
|
||||
__slots__ = ()
|
||||
|
||||
def _fail_with_undefined_error(self, *args, **kwargs):
|
||||
@@ -26,26 +31,22 @@ class SafeUndefined(Undefined):
|
||||
|
||||
|
||||
class TemplateRenderer:
|
||||
"""模板渲染器"""
|
||||
|
||||
def __init__(self, strict: bool = True):
|
||||
"""初始化渲染器
|
||||
|
||||
"""Initialize renderer
|
||||
|
||||
Args:
|
||||
strict: 是否使用严格模式(未定义变量会抛出异常)
|
||||
strict: Whether to enable strict mode (raise error on undefined variables)
|
||||
"""
|
||||
self.strict = strict
|
||||
self.env = Environment(
|
||||
undefined=StrictUndefined if strict else SafeUndefined,
|
||||
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
|
||||
autoescape=False # Disable auto-escaping since we handle plain text instead of HTML
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def normalize_template(template: str) -> str:
|
||||
pattern = re.compile(
|
||||
r"\{\{\s*(\d+)\.(\w+)\s*}}"
|
||||
)
|
||||
return pattern.sub(
|
||||
"""Normalize template syntax (convert numeric node reference to dict access)"""
|
||||
return _NORMALIZE_PATTERN.sub(
|
||||
r'{{ node["\1"].\2 }}',
|
||||
template
|
||||
)
|
||||
@@ -53,24 +54,24 @@ class TemplateRenderer:
|
||||
def render(
|
||||
self,
|
||||
template: str,
|
||||
conv_vars: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any] | None = None
|
||||
conv_vars: dict[str, Any] | LazyVariableDict,
|
||||
node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
|
||||
system_vars: dict[str, Any] | LazyVariableDict | None = None
|
||||
) -> str:
|
||||
"""渲染模板
|
||||
|
||||
"""Render template
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
conv_vars: 会话变量
|
||||
node_outputs: 节点输出结果
|
||||
system_vars: 系统变量
|
||||
|
||||
template: Template string
|
||||
conv_vars: Conversation variables
|
||||
node_outputs: Node outputs
|
||||
system_vars: System variables
|
||||
|
||||
Returns:
|
||||
渲染后的字符串
|
||||
|
||||
Rendered string
|
||||
|
||||
Raises:
|
||||
ValueError: 模板语法错误或变量未定义
|
||||
|
||||
ValueError: If template syntax is invalid or variables are undefined
|
||||
|
||||
Examples:
|
||||
>>> renderer = TemplateRenderer()
|
||||
>>> renderer.render(
|
||||
@@ -80,122 +81,119 @@ class TemplateRenderer:
|
||||
... {}
|
||||
... )
|
||||
'Hello World!'
|
||||
|
||||
|
||||
>>> renderer.render(
|
||||
... "分析结果: {{node.analyze.output}}",
|
||||
... "Analysis result: {{node.analyze.output}}",
|
||||
... {},
|
||||
... {"analyze": {"output": "正面情绪"}},
|
||||
... {"analyze": {"output": "positive sentiment"}},
|
||||
... {}
|
||||
... )
|
||||
'分析结果: 正面情绪'
|
||||
'Analysis result: positive sentiment'
|
||||
"""
|
||||
# 构建命名空间上下文
|
||||
# Build namespace context
|
||||
context = {
|
||||
"conv": conv_vars, # 会话变量:{{conv.user_name}}
|
||||
"node": node_outputs, # 节点输出:{{node.node_1.output}}
|
||||
"sys": system_vars, # 系统变量:{{sys.execution_id}}
|
||||
"conv": conv_vars, # Conversation variables: {{conv.user_name}}
|
||||
"node": node_outputs, # Node outputs: {{node.node_1.output}}
|
||||
"sys": system_vars, # System variables: {{sys.execution_id}}
|
||||
}
|
||||
|
||||
# 支持直接通过节点ID访问节点输出:{{llm_qa.output}}
|
||||
# 将所有节点输出添加到顶层上下文
|
||||
# Allow direct access to node outputs by node ID: {{llm_qa.output}}
|
||||
if node_outputs:
|
||||
context.update(node_outputs)
|
||||
|
||||
# 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
|
||||
if conv_vars:
|
||||
context.update(conv_vars)
|
||||
|
||||
context["nodes"] = node_outputs or {} # 旧语法兼容
|
||||
# # 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
|
||||
# if conv_vars:
|
||||
# context.update(conv_vars)
|
||||
#
|
||||
# context["nodes"] = node_outputs or {} # 旧语法兼容
|
||||
template = self.normalize_template(template)
|
||||
try:
|
||||
tmpl = self.env.from_string(template)
|
||||
return tmpl.render(**context)
|
||||
|
||||
except TemplateSyntaxError as e:
|
||||
logger.error(f"模板语法错误: {template}, 错误: {e}")
|
||||
raise ValueError(f"模板语法错误: {e}")
|
||||
|
||||
logger.error(f"Template syntax error: {template}, error: {e}")
|
||||
raise ValueError(f"Template syntax error: {e}")
|
||||
except UndefinedError as e:
|
||||
logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}")
|
||||
raise ValueError(f"未定义的变量: {e}")
|
||||
|
||||
logger.error(f"Undefined variable in template: {template}, error: {e}")
|
||||
raise ValueError(f"Undefined variable: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"模板渲染异常: {template}, 错误: {e}")
|
||||
raise ValueError(f"模板渲染失败: {e}")
|
||||
logger.error(f"Template rendering error: {template}, error: {e}")
|
||||
raise ValueError(f"Template rendering failed: {e}")
|
||||
|
||||
def validate(self, template: str) -> list[str]:
|
||||
"""验证模板语法
|
||||
|
||||
"""Validate template syntax
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
|
||||
template: Template string
|
||||
|
||||
Returns:
|
||||
错误列表,如果为空则验证通过
|
||||
|
||||
List of errors (empty if valid)
|
||||
|
||||
Examples:
|
||||
>>> renderer = TemplateRenderer()
|
||||
>>> renderer.validate("Hello {{var.name}}!")
|
||||
[]
|
||||
|
||||
>>> renderer.validate("Hello {{var.name") # 缺少结束标记
|
||||
['模板语法错误: ...']
|
||||
|
||||
>>> renderer.validate("Hello {{var.name") # Missing closing tag
|
||||
['Template syntax error: ...']
|
||||
"""
|
||||
errors = []
|
||||
|
||||
try:
|
||||
self.env.from_string(template)
|
||||
except TemplateSyntaxError as e:
|
||||
errors.append(f"模板语法错误: {e}")
|
||||
errors.append(f"Template syntax error: {e}")
|
||||
except Exception as e:
|
||||
errors.append(f"模板验证失败: {e}")
|
||||
errors.append(f"Template validation failed: {e}")
|
||||
|
||||
return errors
|
||||
|
||||
|
||||
# 全局渲染器实例(严格模式)
|
||||
# Global renderer instances (strict / lenient)
|
||||
_strict_renderer = TemplateRenderer(strict=True)
|
||||
_lenient_renderer = TemplateRenderer(strict=False)
|
||||
|
||||
|
||||
def render_template(
|
||||
template: str,
|
||||
conv_vars: dict[str, Any],
|
||||
node_outputs: dict[str, Any],
|
||||
system_vars: dict[str, Any],
|
||||
conv_vars: dict[str, Any] | LazyVariableDict,
|
||||
node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
|
||||
system_vars: dict[str, Any] | LazyVariableDict,
|
||||
strict: bool = True
|
||||
) -> str:
|
||||
"""渲染模板(便捷函数)
|
||||
|
||||
"""Render template (convenience function)
|
||||
|
||||
Args:
|
||||
strict: 严格模式
|
||||
template: 模板字符串
|
||||
conv_vars: 会话变量
|
||||
node_outputs: 节点输出
|
||||
system_vars: 系统变量
|
||||
|
||||
strict: Whether to use strict mode
|
||||
template: Template string
|
||||
conv_vars: Conversation variables
|
||||
node_outputs: Node outputs
|
||||
system_vars: System variables
|
||||
|
||||
Returns:
|
||||
渲染后的字符串
|
||||
|
||||
Rendered string
|
||||
|
||||
Examples:
|
||||
>>> render_template(
|
||||
... "请分析: {{var.text}}",
|
||||
... {"text": "这是一段文本"},
|
||||
... "Analyze: {{var.text}}",
|
||||
... {"text": "This is a text"},
|
||||
... {},
|
||||
... {}
|
||||
... )
|
||||
'请分析: 这是一段文本'
|
||||
'Analyze: This is a text'
|
||||
"""
|
||||
renderer = _strict_renderer if strict else _lenient_renderer
|
||||
return renderer.render(template, conv_vars, node_outputs, system_vars)
|
||||
|
||||
|
||||
def validate_template(template: str) -> list[str]:
|
||||
"""验证模板语法(便捷函数)
|
||||
|
||||
"""Validate template syntax (convenience function)
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
|
||||
template: Template string
|
||||
|
||||
Returns:
|
||||
错误列表
|
||||
List of errors
|
||||
"""
|
||||
return _strict_renderer.validate(template)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import subprocess
|
||||
from app.repositories.neo4j.create_indexes import create_all_indexes
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, APIRouter
|
||||
@@ -60,8 +61,10 @@ async def lifespan(app: FastAPI):
|
||||
logger.warning(f"加载预定义模型时出错: {str(e)}")
|
||||
else:
|
||||
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
|
||||
|
||||
await create_all_indexes()
|
||||
logger.info("应用程序启动完成")
|
||||
|
||||
|
||||
yield
|
||||
# 应用关闭事件
|
||||
logger.info("应用程序正在关闭")
|
||||
|
||||
@@ -19,9 +19,12 @@ class User(Base):
|
||||
last_login_at = Column(DateTime, nullable=True) # 最后登录时间,可为空
|
||||
|
||||
# SSO 外部关联字段
|
||||
external_id = Column(String(100), nullable=True) # 外部用户ID
|
||||
external_id = Column(String(100), nullable=True) # 外部用户 ID
|
||||
external_source = Column(String(50), nullable=True) # 来源系统
|
||||
|
||||
# 用户联系方式
|
||||
phone = Column(String(50), nullable=True) # 用户电话
|
||||
|
||||
# 用户语言偏好
|
||||
preferred_language = Column(String(10), server_default=text("'zh'"), default='zh', nullable=False, index=True) # 用户偏好语言,默认中文
|
||||
|
||||
|
||||
@@ -199,6 +199,96 @@ class ConversationRepository:
|
||||
)
|
||||
return conversations, total
|
||||
|
||||
def list_app_conversations(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
is_draft: Optional[bool] = None,
|
||||
page: int = 1,
|
||||
pagesize: int = 20
|
||||
) -> tuple[list[Conversation], int]:
|
||||
"""
|
||||
查询应用日志会话列表(带分页和过滤)
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
workspace_id: 工作空间 ID
|
||||
is_draft: 是否草稿会话(None 表示不过滤)
|
||||
page: 页码(从 1 开始)
|
||||
pagesize: 每页数量
|
||||
|
||||
Returns:
|
||||
Tuple[List[Conversation], int]: (会话列表,总数)
|
||||
"""
|
||||
stmt = select(Conversation).where(
|
||||
Conversation.app_id == app_id,
|
||||
Conversation.workspace_id == workspace_id,
|
||||
Conversation.is_active.is_(True)
|
||||
)
|
||||
|
||||
if is_draft is not None:
|
||||
stmt = stmt.where(Conversation.is_draft == is_draft)
|
||||
|
||||
# Calculate total number of records
|
||||
total = int(self.db.execute(
|
||||
select(func.count()).select_from(stmt.subquery())
|
||||
).scalar_one())
|
||||
|
||||
# Apply pagination
|
||||
stmt = stmt.order_by(desc(Conversation.updated_at))
|
||||
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
|
||||
|
||||
conversations = list(self.db.scalars(stmt).all())
|
||||
|
||||
logger.info(
|
||||
"Listed app conversations successfully",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"returned": len(conversations),
|
||||
"total": total
|
||||
}
|
||||
)
|
||||
return conversations, total
|
||||
|
||||
def get_conversation_for_app_log(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
) -> Conversation:
|
||||
"""
|
||||
查询应用日志的会话详情
|
||||
|
||||
Args:
|
||||
conversation_id: 会话 ID
|
||||
app_id: 应用 ID
|
||||
workspace_id: 工作空间 ID
|
||||
|
||||
Returns:
|
||||
Conversation: 会话对象
|
||||
|
||||
Raises:
|
||||
ResourceNotFoundException: 当会话不存在时
|
||||
"""
|
||||
logger.info(f"Fetching conversation for app log: {conversation_id}")
|
||||
|
||||
stmt = select(Conversation).where(
|
||||
Conversation.id == conversation_id,
|
||||
Conversation.app_id == app_id,
|
||||
Conversation.workspace_id == workspace_id,
|
||||
Conversation.is_active.is_(True)
|
||||
)
|
||||
|
||||
conversation = self.db.scalars(stmt).first()
|
||||
|
||||
if not conversation:
|
||||
logger.warning(f"Conversation not found: {conversation_id}")
|
||||
raise ResourceNotFoundException("会话", str(conversation_id))
|
||||
|
||||
logger.info(f"Conversation fetched successfully: {conversation_id}")
|
||||
return conversation
|
||||
|
||||
def soft_delete_conversation_by_conversation_id(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
@@ -290,6 +380,34 @@ class MessageRepository:
|
||||
self.db.add(message)
|
||||
return message
|
||||
|
||||
def get_messages_by_conversation(
|
||||
self,
|
||||
conversation_id: uuid.UUID
|
||||
) -> list[Message]:
|
||||
"""
|
||||
查询会话的所有消息(按时间正序)
|
||||
|
||||
Args:
|
||||
conversation_id: 会话 ID
|
||||
|
||||
Returns:
|
||||
List[Message]: 消息列表
|
||||
"""
|
||||
stmt = select(Message).where(
|
||||
Message.conversation_id == conversation_id
|
||||
).order_by(Message.created_at)
|
||||
|
||||
messages = list(self.db.scalars(stmt).all())
|
||||
|
||||
logger.info(
|
||||
"Fetched messages for conversation",
|
||||
extra={
|
||||
"conversation_id": str(conversation_id),
|
||||
"message_count": len(messages)
|
||||
}
|
||||
)
|
||||
return messages
|
||||
|
||||
def get_message_by_conversation_id(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
|
||||
@@ -132,6 +132,82 @@ class EndUserRepository:
|
||||
db_logger.error(f"获取或创建终端用户时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_or_create_end_user_with_config(
|
||||
self,
|
||||
app_id: Optional[uuid.UUID],
|
||||
workspace_id: uuid.UUID,
|
||||
other_id: str,
|
||||
memory_config_id: Optional[uuid.UUID] = None,
|
||||
other_name: Optional[str] = None
|
||||
) -> EndUser:
|
||||
"""获取或创建终端用户,并在单次事务中关联记忆配置。
|
||||
|
||||
与 get_or_create_end_user 类似,但额外支持在创建/获取时
|
||||
一并设置 memory_config_id,避免多次提交。
|
||||
|
||||
Args:
|
||||
app_id: 应用ID(可为 None)
|
||||
workspace_id: 工作空间ID
|
||||
other_id: 第三方ID
|
||||
memory_config_id: 记忆配置ID(可选,仅在用户尚无配置时设置)
|
||||
other_name: 用户名称(用于创建 EndUserInfo)
|
||||
|
||||
Returns:
|
||||
EndUser: 终端用户对象(已关联记忆配置)
|
||||
"""
|
||||
try:
|
||||
end_user = (
|
||||
self.db.query(EndUser)
|
||||
.filter(
|
||||
EndUser.workspace_id == workspace_id,
|
||||
EndUser.other_id == other_id
|
||||
)
|
||||
.order_by(EndUser.created_at.asc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if end_user:
|
||||
db_logger.debug(f"找到现有终端用户: workspace_id={workspace_id}, other_id={other_id}")
|
||||
if app_id is not None:
|
||||
end_user.app_id = app_id
|
||||
if memory_config_id and not end_user.memory_config_id:
|
||||
end_user.memory_config_id = memory_config_id
|
||||
self.db.commit()
|
||||
self.db.refresh(end_user)
|
||||
return end_user
|
||||
|
||||
# 创建新用户
|
||||
end_user = EndUser(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=other_id,
|
||||
memory_config_id=memory_config_id,
|
||||
)
|
||||
self.db.add(end_user)
|
||||
self.db.flush()
|
||||
|
||||
end_user_info = EndUserInfo(
|
||||
end_user_id=end_user.id,
|
||||
other_name=other_name or "",
|
||||
aliases=[],
|
||||
meta_data={}
|
||||
)
|
||||
self.db.add(end_user_info)
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(end_user)
|
||||
|
||||
db_logger.info(
|
||||
f"创建新终端用户及其信息: (other_id: {other_id}) for workspace {workspace_id}, "
|
||||
f"memory_config_id={memory_config_id}"
|
||||
)
|
||||
return end_user
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"获取或创建终端用户(含配置)时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]:
|
||||
"""根据ID获取终端用户(用于缓存操作)
|
||||
|
||||
@@ -515,6 +591,51 @@ class EndUserRepository:
|
||||
)
|
||||
raise
|
||||
|
||||
def batch_update_memory_config_id_by_app(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
memory_config_id: uuid.UUID
|
||||
) -> int:
|
||||
"""批量更新应用下所有终端用户的 memory_config_id
|
||||
|
||||
Args:
|
||||
app_id: 应用ID
|
||||
memory_config_id: 新的记忆配置ID
|
||||
|
||||
Returns:
|
||||
int: 更新的终端用户数量
|
||||
|
||||
Raises:
|
||||
Exception: 数据库操作失败时抛出
|
||||
"""
|
||||
try:
|
||||
from sqlalchemy import update
|
||||
|
||||
stmt = (
|
||||
update(EndUser)
|
||||
.where(EndUser.app_id == app_id)
|
||||
.values(memory_config_id=memory_config_id)
|
||||
)
|
||||
|
||||
result = self.db.execute(stmt)
|
||||
self.db.commit()
|
||||
|
||||
updated_count = result.rowcount
|
||||
|
||||
db_logger.info(
|
||||
f"批量更新终端用户记忆配置: app_id={app_id}, "
|
||||
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
||||
)
|
||||
|
||||
return updated_count
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(
|
||||
f"批量更新终端用户记忆配置时出错: app_id={app_id}, "
|
||||
f"memory_config_id={memory_config_id}, error={str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
def count_by_memory_config_id(
|
||||
self,
|
||||
memory_config_id: uuid.UUID
|
||||
|
||||
@@ -78,6 +78,15 @@ class MemoryConfigRepository:
|
||||
OPTIONAL MATCH (n) WHERE n.end_user_id = $end_user_id RETURN 'ALL' AS Label, COUNT(n) AS Count
|
||||
"""
|
||||
|
||||
# 批量查询多个用户的记忆数量(简化版本,只返回total)
|
||||
SEARCH_FOR_ALL_BATCH = """
|
||||
MATCH (n) WHERE n.end_user_id IN $end_user_ids
|
||||
RETURN
|
||||
n.end_user_id as user_id,
|
||||
count(n) as total
|
||||
ORDER BY user_id
|
||||
"""
|
||||
|
||||
# Extracted entity details within group/app/user
|
||||
SEARCH_FOR_DETIALS = """
|
||||
MATCH (n:ExtractedEntity)
|
||||
|
||||
@@ -1,62 +1,47 @@
|
||||
import asyncio
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def create_fulltext_indexes():
|
||||
"""Create full-text indexes for keyword search with BM25 scoring."""
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
print("\n" + "=" * 70)
|
||||
print("Creating Full-Text Indexes (for keyword search)")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
# 创建 Statements 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
print("✓ Created: statementsFulltext")
|
||||
""")
|
||||
|
||||
# # 创建 Dialogues 索引
|
||||
# await connector.execute_query("""
|
||||
# CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content]
|
||||
# OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
# """)
|
||||
|
||||
# 创建 Entities 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
print("✓ Created: entitiesFulltext")
|
||||
""")
|
||||
|
||||
# 创建 Chunks 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
print("✓ Created: chunksFulltext")
|
||||
""")
|
||||
|
||||
# 创建 MemorySummary 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
print("✓ Created: summariesFulltext")
|
||||
|
||||
""")
|
||||
# 创建 Community 索引
|
||||
await connector.execute_query("""
|
||||
CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary]
|
||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||
""")
|
||||
print("✓ Created: communitiesFulltext")
|
||||
|
||||
print("\nFull-text indexes created successfully with BM25 support.")
|
||||
except Exception as e:
|
||||
print(f"✗ Error creating full-text indexes: {e}")
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
async def create_vector_indexes():
|
||||
"""Create vector indexes for fast embedding similarity search.
|
||||
|
||||
@@ -65,12 +50,7 @@ async def create_vector_indexes():
|
||||
"""
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
print("\n" + "=" * 70)
|
||||
print("Creating Vector Indexes (for embedding search)")
|
||||
print("=" * 70)
|
||||
print("Note: Adjust vector.dimensions if using different embedding model")
|
||||
print(" Current setting: 1024 dimensions (for bge-m3)")
|
||||
print()
|
||||
|
||||
|
||||
# Statement embedding index
|
||||
await connector.execute_query("""
|
||||
@@ -82,7 +62,7 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
print("✓ Created: statement_embedding_index")
|
||||
|
||||
|
||||
# Chunk embedding index
|
||||
await connector.execute_query("""
|
||||
@@ -94,7 +74,7 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
print("✓ Created: chunk_embedding_index")
|
||||
|
||||
|
||||
# Entity name embedding index
|
||||
await connector.execute_query("""
|
||||
@@ -106,7 +86,7 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
print("✓ Created: entity_embedding_index")
|
||||
|
||||
|
||||
# Memory summary embedding index
|
||||
await connector.execute_query("""
|
||||
@@ -118,8 +98,7 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
print("✓ Created: summary_embedding_index")
|
||||
|
||||
|
||||
# Community summary embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
|
||||
@@ -129,8 +108,7 @@ async def create_vector_indexes():
|
||||
`vector.dimensions`: 1024,
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
print("✓ Created: community_summary_embedding_index")
|
||||
""")
|
||||
|
||||
# Dialogue embedding index (optional)
|
||||
await connector.execute_query("""
|
||||
@@ -142,91 +120,15 @@ async def create_vector_indexes():
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
print("✓ Created: dialogue_embedding_index")
|
||||
|
||||
# Community summary embedding index
|
||||
await connector.execute_query("""
|
||||
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
|
||||
FOR (c:Community)
|
||||
ON c.summary_embedding
|
||||
OPTIONS {indexConfig: {
|
||||
`vector.dimensions`: 1024,
|
||||
`vector.similarity_function`: 'cosine'
|
||||
}}
|
||||
""")
|
||||
print("✓ Created: community_summary_embedding_index")
|
||||
|
||||
print("\nVector indexes created successfully!")
|
||||
print("\nExpected performance improvement:")
|
||||
print(" Before: ~1.4s for embedding search")
|
||||
print(" After: ~0.05-0.2s for embedding search (10-30x faster!)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error creating vector indexes: {e}")
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
async def create_config_id_indexes():
|
||||
"""Create indexes on config_id fields for improved query performance.
|
||||
|
||||
These indexes enable fast filtering of nodes by configuration ID,
|
||||
which is essential for configuration isolation and multi-tenant scenarios.
|
||||
"""
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
print("\n" + "=" * 70)
|
||||
print("Creating Config ID Indexes")
|
||||
print("=" * 70)
|
||||
|
||||
# Dialogue.config_id index
|
||||
await connector.execute_query("""
|
||||
CREATE INDEX dialogue_config_id_index IF NOT EXISTS
|
||||
FOR (d:Dialogue) ON (d.config_id)
|
||||
""")
|
||||
print("✓ Created: dialogue_config_id_index")
|
||||
|
||||
# Statement.config_id index
|
||||
await connector.execute_query("""
|
||||
CREATE INDEX statement_config_id_index IF NOT EXISTS
|
||||
FOR (s:Statement) ON (s.config_id)
|
||||
""")
|
||||
print("✓ Created: statement_config_id_index")
|
||||
|
||||
# ExtractedEntity.config_id index
|
||||
await connector.execute_query("""
|
||||
CREATE INDEX entity_config_id_index IF NOT EXISTS
|
||||
FOR (e:ExtractedEntity) ON (e.config_id)
|
||||
""")
|
||||
print("✓ Created: entity_config_id_index")
|
||||
|
||||
# MemorySummary.config_id index
|
||||
await connector.execute_query("""
|
||||
CREATE INDEX summary_config_id_index IF NOT EXISTS
|
||||
FOR (m:MemorySummary) ON (m.config_id)
|
||||
""")
|
||||
print("✓ Created: summary_config_id_index")
|
||||
|
||||
print("\nConfig ID indexes created successfully!")
|
||||
print("These indexes enable fast filtering by configuration ID.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Error creating config_id indexes: {e}")
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
async def create_unique_constraints():
|
||||
"""Create uniqueness constraints for core node identifiers.
|
||||
|
||||
Ensures concurrent MERGE operations remain safe and prevents duplicates.
|
||||
"""
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
print("\n" + "=" * 70)
|
||||
print("Creating Unique Constraints")
|
||||
print("=" * 70)
|
||||
|
||||
try:
|
||||
# Dialogue.id unique
|
||||
await connector.execute_query(
|
||||
"""
|
||||
@@ -234,8 +136,7 @@ async def create_unique_constraints():
|
||||
FOR (d:Dialogue) REQUIRE d.id IS UNIQUE
|
||||
"""
|
||||
)
|
||||
print("✓ Created: dialog_id_unique")
|
||||
|
||||
|
||||
# Statement.id unique
|
||||
await connector.execute_query(
|
||||
"""
|
||||
@@ -243,8 +144,7 @@ async def create_unique_constraints():
|
||||
FOR (s:Statement) REQUIRE s.id IS UNIQUE
|
||||
"""
|
||||
)
|
||||
print("✓ Created: statement_id_unique")
|
||||
|
||||
|
||||
# Chunk.id unique
|
||||
await connector.execute_query(
|
||||
"""
|
||||
@@ -252,112 +152,13 @@ async def create_unique_constraints():
|
||||
FOR (c:Chunk) REQUIRE c.id IS UNIQUE
|
||||
"""
|
||||
)
|
||||
print("✓ Created: chunk_id_unique")
|
||||
|
||||
print("\nUnique constraints ensured for Dialogue, Statement, and Chunk.")
|
||||
except Exception as e:
|
||||
print(f"✗ Error creating unique constraints: {e}")
|
||||
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
async def create_all_indexes():
|
||||
"""Create all indexes and constraints in one go."""
|
||||
print("\n" + "=" * 70)
|
||||
print("Neo4j Index & Constraint Setup")
|
||||
print("=" * 70)
|
||||
print("This will create:")
|
||||
print(" 1. Full-text indexes (for keyword/BM25 search)")
|
||||
print(" 2. Vector indexes (for embedding similarity search)")
|
||||
print(" 3. Config ID indexes (for configuration isolation)")
|
||||
print(" 4. Unique constraints (for data integrity)")
|
||||
print("=" * 70)
|
||||
|
||||
await create_fulltext_indexes()
|
||||
await create_vector_indexes()
|
||||
await create_config_id_indexes()
|
||||
await create_unique_constraints()
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print("✓ All indexes and constraints created successfully!")
|
||||
print("=" * 70)
|
||||
print("\nTo verify, run in Neo4j Browser:")
|
||||
print(" SHOW INDEXES")
|
||||
print(" SHOW CONSTRAINTS")
|
||||
print()
|
||||
|
||||
|
||||
async def check_indexes():
|
||||
"""Check what indexes currently exist."""
|
||||
connector = Neo4jConnector()
|
||||
|
||||
try:
|
||||
print("\n" + "=" * 70)
|
||||
print("Checking Existing Indexes")
|
||||
print("=" * 70)
|
||||
|
||||
query = "SHOW INDEXES"
|
||||
result = await connector.execute_query(query)
|
||||
|
||||
fulltext_indexes = [idx for idx in result if idx.get('type') == 'FULLTEXT']
|
||||
vector_indexes = [idx for idx in result if idx.get('type') == 'VECTOR']
|
||||
range_indexes = [idx for idx in result if idx.get('type') == 'RANGE']
|
||||
|
||||
print(f"\nFull-text indexes: {len(fulltext_indexes)}")
|
||||
for idx in fulltext_indexes:
|
||||
print(f" ✓ {idx.get('name')}")
|
||||
|
||||
print(f"\nVector indexes: {len(vector_indexes)}")
|
||||
for idx in vector_indexes:
|
||||
print(f" ✓ {idx.get('name')}")
|
||||
|
||||
print(f"\nRange indexes (including config_id): {len(range_indexes)}")
|
||||
for idx in range_indexes:
|
||||
print(f" ✓ {idx.get('name')}")
|
||||
|
||||
if not vector_indexes:
|
||||
print("\n⚠️ WARNING: No vector indexes found!")
|
||||
print(" Embedding search will be VERY SLOW (~1.4s)")
|
||||
print(" Run: python create_indexes.py")
|
||||
|
||||
# Check for config_id indexes
|
||||
config_id_indexes = [idx for idx in range_indexes if 'config_id' in idx.get('name', '')]
|
||||
if len(config_id_indexes) < 4:
|
||||
print("\n⚠️ WARNING: Not all config_id indexes found!")
|
||||
print(f" Expected 4, found {len(config_id_indexes)}")
|
||||
print(" Run: python create_indexes.py config_id")
|
||||
|
||||
print("=" * 70)
|
||||
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
if len(sys.argv) > 1:
|
||||
command = sys.argv[1]
|
||||
if command == "check":
|
||||
asyncio.run(check_indexes())
|
||||
elif command == "fulltext":
|
||||
asyncio.run(create_fulltext_indexes())
|
||||
elif command == "vector":
|
||||
asyncio.run(create_vector_indexes())
|
||||
elif command == "config_id":
|
||||
asyncio.run(create_config_id_indexes())
|
||||
elif command == "constraints":
|
||||
asyncio.run(create_unique_constraints())
|
||||
else:
|
||||
print(f"Unknown command: {command}")
|
||||
print("\nUsage:")
|
||||
print(" python create_indexes.py # Create all indexes")
|
||||
print(" python create_indexes.py check # Check existing indexes")
|
||||
print(" python create_indexes.py fulltext # Create only full-text indexes")
|
||||
print(" python create_indexes.py vector # Create only vector indexes")
|
||||
print(" python create_indexes.py config_id # Create only config_id indexes")
|
||||
print(" python create_indexes.py constraints # Create only constraints")
|
||||
else:
|
||||
asyncio.run(create_all_indexes())
|
||||
|
||||
|
||||
@@ -340,17 +340,22 @@ SEARCH_ENTITIES_BY_NAME_OR_ALIAS = """
|
||||
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
|
||||
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||
WITH e, score
|
||||
UNION
|
||||
MATCH (e:ExtractedEntity)
|
||||
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||
AND e.aliases IS NOT NULL
|
||||
AND ANY(alias IN e.aliases WHERE toLower(alias) CONTAINS toLower($q))
|
||||
WITH e,
|
||||
WITH collect({entity: e, score: score}) AS fulltextResults
|
||||
|
||||
OPTIONAL MATCH (ae:ExtractedEntity)
|
||||
WHERE ($end_user_id IS NULL OR ae.end_user_id = $end_user_id)
|
||||
AND ae.aliases IS NOT NULL
|
||||
AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($q))
|
||||
WITH fulltextResults, collect(ae) AS aliasEntities
|
||||
|
||||
UNWIND (fulltextResults + [x IN aliasEntities | {entity: x, score:
|
||||
CASE
|
||||
WHEN ANY(alias IN e.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0
|
||||
WHEN ANY(alias IN e.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9
|
||||
WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0
|
||||
WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9
|
||||
ELSE 0.8
|
||||
END AS score
|
||||
END
|
||||
}]) AS row
|
||||
WITH row.entity AS e, row.score AS score
|
||||
WITH DISTINCT e, MAX(score) AS score
|
||||
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
|
||||
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
||||
|
||||
@@ -158,22 +158,26 @@ class UserRepository:
|
||||
raise
|
||||
|
||||
def get_users_by_tenant(
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
skip: int = 0,
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: Optional[bool] = None,
|
||||
is_superuser: Optional[bool] = None,
|
||||
search: Optional[str] = None
|
||||
) -> List[User]:
|
||||
"""获取租户下的用户列表"""
|
||||
db_logger.debug(f"查询租户用户: tenant_id={tenant_id}")
|
||||
|
||||
|
||||
try:
|
||||
query = self.db.query(User).options(joinedload(User.tenant)).filter(User.tenant_id == tenant_id)
|
||||
|
||||
|
||||
if is_active is not None:
|
||||
query = query.filter(User.is_active == is_active)
|
||||
|
||||
|
||||
if is_superuser is not None:
|
||||
query = query.filter(User.is_superuser == is_superuser)
|
||||
|
||||
if search:
|
||||
query = query.filter(
|
||||
or_(
|
||||
@@ -181,7 +185,7 @@ class UserRepository:
|
||||
User.email.ilike(f"%{search}%")
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
users = query.offset(skip).limit(limit).all()
|
||||
db_logger.debug(f"租户用户查询成功: tenant_id={tenant_id}, count={len(users)}")
|
||||
return users
|
||||
@@ -190,18 +194,22 @@ class UserRepository:
|
||||
raise
|
||||
|
||||
def count_users_by_tenant(
|
||||
self,
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
is_active: Optional[bool] = None,
|
||||
is_superuser: Optional[bool] = None,
|
||||
search: Optional[str] = None
|
||||
) -> int:
|
||||
"""统计租户下的用户数量"""
|
||||
try:
|
||||
query = self.db.query(func.count(User.id)).filter(User.tenant_id == tenant_id)
|
||||
|
||||
|
||||
if is_active is not None:
|
||||
query = query.filter(User.is_active == is_active)
|
||||
|
||||
|
||||
if is_superuser is not None:
|
||||
query = query.filter(User.is_superuser == is_superuser)
|
||||
|
||||
if search:
|
||||
query = query.filter(
|
||||
or_(
|
||||
@@ -209,7 +217,7 @@ class UserRepository:
|
||||
User.email.ilike(f"%{search}%")
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
return query.scalar()
|
||||
except Exception as e:
|
||||
db_logger.error(f"统计租户用户失败: tenant_id={tenant_id} - {str(e)}")
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Any, Annotated
|
||||
from typing import Any, Annotated, Literal
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import desc, select
|
||||
from fastapi import Depends
|
||||
|
||||
from app.models.workflow_model import (
|
||||
@@ -128,29 +128,36 @@ class WorkflowExecutionRepository:
|
||||
Returns:
|
||||
执行记录列表
|
||||
"""
|
||||
return self.db.query(WorkflowExecution).filter(
|
||||
stmt = select(WorkflowExecution).filter(
|
||||
WorkflowExecution.app_id == app_id
|
||||
).order_by(
|
||||
desc(WorkflowExecution.started_at)
|
||||
).limit(limit).offset(offset).all()
|
||||
).limit(limit).offset(offset)
|
||||
return list(self.db.execute(stmt).scalars())
|
||||
|
||||
def get_by_conversation_id(
|
||||
self,
|
||||
conversation_id: uuid.UUID
|
||||
conversation_id: uuid.UUID,
|
||||
status: Literal["running", "completed", "failed"] = None,
|
||||
limit_count: int = 50
|
||||
) -> list[WorkflowExecution]:
|
||||
"""根据会话 ID 获取执行记录列表
|
||||
|
||||
Args:
|
||||
limit_count:
|
||||
conversation_id: 会话 ID
|
||||
status: 状态(可选)
|
||||
|
||||
Returns:
|
||||
执行记录列表
|
||||
"""
|
||||
return self.db.query(WorkflowExecution).filter(
|
||||
stmt = select(WorkflowExecution).filter(
|
||||
WorkflowExecution.conversation_id == conversation_id
|
||||
).order_by(
|
||||
desc(WorkflowExecution.started_at)
|
||||
).all()
|
||||
)
|
||||
if status:
|
||||
stmt = stmt.filter(WorkflowExecution.status == status)
|
||||
stmt = stmt.order_by(desc(WorkflowExecution.started_at)).limit(limit_count)
|
||||
return list(self.db.execute(stmt).scalars())
|
||||
|
||||
def count_by_app_id(self, app_id: uuid.UUID) -> int:
|
||||
"""统计应用的执行次数
|
||||
@@ -199,11 +206,12 @@ class WorkflowNodeExecutionRepository:
|
||||
Returns:
|
||||
节点执行记录列表(按执行顺序排序)
|
||||
"""
|
||||
return self.db.query(WorkflowNodeExecution).filter(
|
||||
stmt = select(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.execution_id == execution_id
|
||||
).order_by(
|
||||
WorkflowNodeExecution.execution_order
|
||||
).all()
|
||||
)
|
||||
return list(self.db.execute(stmt).scalars())
|
||||
|
||||
def get_by_node_id(
|
||||
self,
|
||||
@@ -219,12 +227,13 @@ class WorkflowNodeExecutionRepository:
|
||||
Returns:
|
||||
节点执行记录列表
|
||||
"""
|
||||
return self.db.query(WorkflowNodeExecution).filter(
|
||||
stmt = select(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.execution_id == execution_id,
|
||||
WorkflowNodeExecution.node_id == node_id
|
||||
).order_by(
|
||||
WorkflowNodeExecution.retry_count
|
||||
).all()
|
||||
)
|
||||
return list(self.db.execute(stmt).scalars())
|
||||
|
||||
|
||||
# ==================== 依赖注入函数 ====================
|
||||
|
||||
@@ -276,7 +276,7 @@ class AgentConfigCreate(BaseModel):
|
||||
|
||||
# 记忆配置
|
||||
memory: MemoryConfig = Field(
|
||||
default_factory=lambda: MemoryConfig(enabled=True),
|
||||
default_factory=lambda: MemoryConfig(enabled=False),
|
||||
description="对话历史记忆配置"
|
||||
)
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ class Write_UserInput(BaseModel):
|
||||
end_user_id: str
|
||||
config_id: Optional[str] = None
|
||||
|
||||
|
||||
class AgentMemory_Long_Term(ABC):
|
||||
"""长期记忆配置常量"""
|
||||
STORAGE_NEO4J = "neo4j"
|
||||
@@ -25,8 +26,9 @@ class AgentMemory_Long_Term(ABC):
|
||||
STRATEGY_CHUNK = "chunk"
|
||||
STRATEGY_TIME = "time"
|
||||
DEFAULT_SCOPE = 6
|
||||
TIME_SCOPE=5
|
||||
class AgentMemoryDataset(ABC):
|
||||
PRONOUN=['我','本人','在下','自己','咱','鄙人','吴','余']
|
||||
NAME='用户'
|
||||
TIME_SCOPE = 5
|
||||
|
||||
|
||||
class AgentMemoryDataset(ABC):
|
||||
PRONOUN = ['我', '本人', '在下', '自己', '咱', '鄙人', '吴', '余']
|
||||
NAME = '用户'
|
||||
|
||||
@@ -138,21 +138,13 @@ class CreateEndUserRequest(BaseModel):
|
||||
"""Request schema for creating an end user.
|
||||
|
||||
Attributes:
|
||||
workspace_id: Workspace ID (required)
|
||||
other_id: External user identifier (required)
|
||||
other_name: Display name for the end user
|
||||
memory_config_id: Optional memory config ID. If not provided, uses workspace default.
|
||||
"""
|
||||
workspace_id: str = Field(..., description="Workspace ID (required)")
|
||||
other_id: str = Field(..., description="External user identifier (required)")
|
||||
other_name: Optional[str] = Field("", description="Display name")
|
||||
|
||||
@field_validator("workspace_id")
|
||||
@classmethod
|
||||
def validate_workspace_id(cls, v: str) -> str:
|
||||
"""Validate that workspace_id is not empty."""
|
||||
if not v or not v.strip():
|
||||
raise ValueError("workspace_id is required and cannot be empty")
|
||||
return v.strip()
|
||||
memory_config_id: Optional[str] = Field(None, description="Memory config ID. Falls back to workspace default if not provided.")
|
||||
|
||||
@field_validator("other_id")
|
||||
@classmethod
|
||||
@@ -171,11 +163,13 @@ class CreateEndUserResponse(BaseModel):
|
||||
other_id: External user identifier
|
||||
other_name: Display name
|
||||
workspace_id: Workspace the user belongs to
|
||||
memory_config_id: Connected memory config ID
|
||||
"""
|
||||
id: str = Field(..., description="End user UUID")
|
||||
other_id: str = Field(..., description="External user identifier")
|
||||
other_name: str = Field("", description="Display name")
|
||||
workspace_id: str = Field(..., description="Workspace ID")
|
||||
memory_config_id: Optional[str] = Field(None, description="Connected memory config ID")
|
||||
|
||||
|
||||
class MemoryConfigItem(BaseModel):
|
||||
|
||||
@@ -478,6 +478,22 @@ class PendingForgettingNode(BaseModel):
|
||||
last_access_time: int = Field(..., description="最后访问时间(Unix时间戳,秒)")
|
||||
|
||||
|
||||
class PageInfo(BaseModel):
|
||||
"""分页信息模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
page: int = Field(..., description="当前页码(从1开始)")
|
||||
pagesize: int = Field(..., description="每页数量")
|
||||
total: int = Field(..., description="总记录数")
|
||||
hasnext: bool = Field(..., description="是否有下一页")
|
||||
|
||||
|
||||
class PendingNodesResponse(BaseModel):
|
||||
"""待遗忘节点列表响应模型(独立分页接口)"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
items: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表")
|
||||
page: PageInfo = Field(..., description="分页信息")
|
||||
|
||||
|
||||
class ForgettingStatsResponse(BaseModel):
|
||||
"""遗忘引擎统计信息响应模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
@@ -485,7 +501,6 @@ class ForgettingStatsResponse(BaseModel):
|
||||
node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
|
||||
recent_trends: List[ForgettingCycleHistoryPoint] = Field(...,
|
||||
description="最近7个日期的遗忘趋势数据(每天取最后一次执行)")
|
||||
pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表(前20个满足遗忘条件的节点)")
|
||||
timestamp: int = Field(..., description="统计时间(时间戳)")
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from dataclasses import field
|
||||
from pydantic import BaseModel, EmailStr, Field, field_validator, validator, ConfigDict
|
||||
from typing import Optional
|
||||
from typing import Optional, List
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
@@ -20,6 +20,7 @@ class UserCreate(UserBase):
|
||||
class UserUpdate(BaseModel):
|
||||
username: Optional[str] = None
|
||||
email: Optional[EmailStr] = None
|
||||
phone: Optional[str] = None
|
||||
is_active: Optional[bool] = None
|
||||
is_superuser: Optional[bool] = None
|
||||
|
||||
@@ -85,6 +86,8 @@ class User(UserBase):
|
||||
current_workspace_name: Optional[str] = None
|
||||
role: Optional[WorkspaceRole] = None
|
||||
preferred_language: Optional[str] = "zh" # 用户语言偏好
|
||||
phone: Optional[str] = None # 用户电话
|
||||
permissions: Optional[List[str]] = None # 用户权限列表,由 external_source 的 permissions 控制
|
||||
|
||||
# 将 datetime 转换为毫秒时间戳
|
||||
@validator("created_at", pre=True)
|
||||
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.db import get_db
|
||||
from app.models import MultiAgentConfig, AgentConfig, ModelType
|
||||
from app.models import WorkflowConfig
|
||||
@@ -20,11 +21,11 @@ from app.schemas.model_schema import ModelInfo
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from app.services.workflow_service import WorkflowService
|
||||
from app.schemas import FileType
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -43,18 +44,17 @@ class AppChatService:
|
||||
message: str,
|
||||
conversation_id: uuid.UUID,
|
||||
config: AgentConfig,
|
||||
user_id: Optional[str] = None,
|
||||
files: list[FileInput],
|
||||
user_id: str,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
workspace_id: Optional[str] = None,
|
||||
files: Optional[List[FileInput]] = None
|
||||
workspace_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""聊天(非流式)"""
|
||||
start_time = time.time()
|
||||
config_id = None
|
||||
|
||||
# 应用 features 配置
|
||||
features_config: dict = config.features or {}
|
||||
@@ -93,7 +93,8 @@ class AppChatService:
|
||||
tools.extend(skill_tools)
|
||||
if skill_prompts:
|
||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval,
|
||||
user_id)
|
||||
tools.extend(kb_tools)
|
||||
memory_flag = False
|
||||
if memory:
|
||||
@@ -140,13 +141,13 @@ class AppChatService:
|
||||
# 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库
|
||||
is_new_conversation = len(history) == 0
|
||||
if is_new_conversation:
|
||||
opening = self.agent_service._get_opening_statement(features_config, True, variables)
|
||||
opening, suggested_questions = self.agent_service._get_opening_statement(features_config, True, variables)
|
||||
if opening:
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=opening,
|
||||
meta_data={}
|
||||
meta_data={"suggested_questions": suggested_questions}
|
||||
)
|
||||
# 重新加载历史(包含刚写入的开场白)
|
||||
history = await self.conversation_service.get_conversation_history(
|
||||
@@ -168,11 +169,6 @@ class AppChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files # 传递处理后的文件
|
||||
)
|
||||
|
||||
@@ -229,6 +225,21 @@ class AppChatService:
|
||||
# 保存消息
|
||||
if audio_url:
|
||||
assistant_meta["audio_url"] = audio_url
|
||||
if memory_flag:
|
||||
connected_config = get_end_user_connected_config(user_id, self.db)
|
||||
memory_config_id: str = connected_config.get("memory_config_id")
|
||||
messages = [
|
||||
{"role": "user", "content": message, "files": [file.model_dump() for file in files]},
|
||||
{"role": "assistant", "content": result["content"]}
|
||||
]
|
||||
if memory_config_id:
|
||||
await write_long_term(
|
||||
storage_type,
|
||||
user_id,
|
||||
messages,
|
||||
user_rag_memory_id,
|
||||
memory_config_id
|
||||
)
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
@@ -264,20 +275,19 @@ class AppChatService:
|
||||
message: str,
|
||||
conversation_id: uuid.UUID,
|
||||
config: AgentConfig,
|
||||
files: list[FileInput],
|
||||
user_id: Optional[str] = None,
|
||||
variables: Optional[Dict[str, Any]] = None,
|
||||
web_search: bool = False,
|
||||
memory: bool = True,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
workspace_id: Optional[str] = None,
|
||||
files: Optional[List[FileInput]] = None
|
||||
workspace_id: Optional[str] = None
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""聊天(流式)"""
|
||||
|
||||
try:
|
||||
start_time = time.time()
|
||||
config_id = None
|
||||
message_id = uuid.uuid4()
|
||||
|
||||
# 应用 features 配置
|
||||
@@ -319,7 +329,8 @@ class AppChatService:
|
||||
tools.extend(skill_tools)
|
||||
if skill_prompts:
|
||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
|
||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(
|
||||
config.knowledge_retrieval, user_id)
|
||||
tools.extend(kb_tools)
|
||||
# 添加长期记忆工具
|
||||
memory_flag = False
|
||||
@@ -367,13 +378,13 @@ class AppChatService:
|
||||
# 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库
|
||||
is_new_conversation = len(history) == 0
|
||||
if is_new_conversation:
|
||||
opening = self.agent_service._get_opening_statement(features_config, True, variables)
|
||||
opening, suggested_questions = self.agent_service._get_opening_statement(features_config, True, variables)
|
||||
if opening:
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
content=opening,
|
||||
meta_data={}
|
||||
meta_data={"suggested_questions": suggested_questions}
|
||||
)
|
||||
# 重新加载历史(包含刚写入的开场白)
|
||||
history = await self.conversation_service.get_conversation_history(
|
||||
@@ -411,11 +422,6 @@ class AppChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files
|
||||
):
|
||||
if isinstance(chunk, int):
|
||||
@@ -459,7 +465,7 @@ class AppChatService:
|
||||
|
||||
# 保存消息
|
||||
human_meta = {
|
||||
"files":[],
|
||||
"files": [],
|
||||
"history_files": {}
|
||||
}
|
||||
assistant_meta = {
|
||||
@@ -484,6 +490,22 @@ class AppChatService:
|
||||
|
||||
if stream_audio_url:
|
||||
assistant_meta["audio_url"] = stream_audio_url
|
||||
|
||||
if memory_flag:
|
||||
connected_config = get_end_user_connected_config(user_id, self.db)
|
||||
memory_config_id: str = connected_config.get("memory_config_id")
|
||||
messages = [
|
||||
{"role": "user", "content": message, "files": [file.model_dump() for file in files]},
|
||||
{"role": "assistant", "content": full_content}
|
||||
]
|
||||
if memory_config_id:
|
||||
await write_long_term(
|
||||
storage_type,
|
||||
user_id,
|
||||
messages,
|
||||
user_rag_memory_id,
|
||||
memory_config_id
|
||||
)
|
||||
self.conversation_service.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
@@ -618,7 +640,6 @@ class AppChatService:
|
||||
# 2. 创建编排器
|
||||
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||
|
||||
|
||||
# 3. 流式执行任务
|
||||
async for event in orchestrator.execute_stream(
|
||||
message=message,
|
||||
|
||||
128
api/app/services/app_log_service.py
Normal file
128
api/app/services/app_log_service.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""应用日志服务层"""
|
||||
import uuid
|
||||
from typing import Optional, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models.conversation_model import Conversation, Message
|
||||
from app.repositories.conversation_repository import ConversationRepository, MessageRepository
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class AppLogService:
|
||||
"""应用日志服务"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.conversation_repository = ConversationRepository(db)
|
||||
self.message_repository = MessageRepository(db)
|
||||
|
||||
def list_conversations(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
page: int = 1,
|
||||
pagesize: int = 20,
|
||||
is_draft: Optional[bool] = None,
|
||||
) -> Tuple[list[Conversation], int]:
|
||||
"""
|
||||
查询应用日志会话列表
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
workspace_id: 工作空间 ID
|
||||
page: 页码(从 1 开始)
|
||||
pagesize: 每页数量
|
||||
is_draft: 是否草稿会话(None 表示不过滤)
|
||||
|
||||
Returns:
|
||||
Tuple[list[Conversation], int]: (会话列表,总数)
|
||||
"""
|
||||
logger.info(
|
||||
"查询应用日志会话列表",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"is_draft": is_draft
|
||||
}
|
||||
)
|
||||
|
||||
# 使用 Repository 查询
|
||||
conversations, total = self.conversation_repository.list_app_conversations(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
is_draft=is_draft,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"查询应用日志会话列表成功",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"total": total,
|
||||
"returned": len(conversations)
|
||||
}
|
||||
)
|
||||
|
||||
return conversations, total
|
||||
|
||||
def get_conversation_detail(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
conversation_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
) -> Conversation:
|
||||
"""
|
||||
查询会话详情(包含消息)
|
||||
|
||||
Args:
|
||||
app_id: 应用 ID
|
||||
conversation_id: 会话 ID
|
||||
workspace_id: 工作空间 ID
|
||||
|
||||
Returns:
|
||||
Conversation: 包含消息的会话对象
|
||||
|
||||
Raises:
|
||||
ResourceNotFoundException: 当会话不存在时
|
||||
"""
|
||||
logger.info(
|
||||
"查询应用日志会话详情",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"conversation_id": str(conversation_id),
|
||||
"workspace_id": str(workspace_id)
|
||||
}
|
||||
)
|
||||
|
||||
# 查询会话
|
||||
conversation = self.conversation_repository.get_conversation_for_app_log(
|
||||
conversation_id=conversation_id,
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 查询消息(按时间正序)
|
||||
messages = self.message_repository.get_messages_by_conversation(
|
||||
conversation_id=conversation_id
|
||||
)
|
||||
|
||||
# 将消息附加到会话对象
|
||||
conversation.messages = messages
|
||||
|
||||
logger.info(
|
||||
"查询应用日志会话详情成功",
|
||||
extra={
|
||||
"app_id": str(app_id),
|
||||
"conversation_id": str(conversation_id),
|
||||
"message_count": len(messages)
|
||||
}
|
||||
)
|
||||
|
||||
return conversation
|
||||
@@ -1084,7 +1084,6 @@ class AppService:
|
||||
if not exists:
|
||||
cleaned["memory_config_id"] = None
|
||||
cleaned.pop("memory_content", None)
|
||||
cleaned["enabled"] = False
|
||||
return cleaned
|
||||
|
||||
exists = self.db.query(
|
||||
@@ -1096,7 +1095,6 @@ class AppService:
|
||||
if not exists:
|
||||
cleaned["memory_config_id"] = None
|
||||
cleaned.pop("memory_content", None)
|
||||
cleaned["enabled"] = False
|
||||
|
||||
return cleaned
|
||||
|
||||
@@ -1684,15 +1682,15 @@ class AppService:
|
||||
|
||||
return config.config_id
|
||||
|
||||
def _update_endusers_memory_config_by_workspace(
|
||||
def _update_endusers_memory_config_by_app(
|
||||
self,
|
||||
workspace_id: uuid.UUID,
|
||||
app_id: uuid.UUID,
|
||||
memory_config_id: uuid.UUID
|
||||
) -> int:
|
||||
"""批量更新应用下所有终端用户的 memory_config_id
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
app_id: 应用ID
|
||||
memory_config_id: 新的记忆配置ID
|
||||
|
||||
Returns:
|
||||
@@ -1701,8 +1699,8 @@ class AppService:
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
|
||||
repo = EndUserRepository(self.db)
|
||||
updated_count = repo.batch_update_memory_config_id_by_workspace(
|
||||
workspace_id=workspace_id,
|
||||
updated_count = repo.batch_update_memory_config_id_by_app(
|
||||
app_id=app_id,
|
||||
memory_config_id=memory_config_id
|
||||
)
|
||||
|
||||
@@ -1753,12 +1751,16 @@ class AppService:
|
||||
|
||||
miss_params = []
|
||||
if agent_cfg.default_model_config_id is None:
|
||||
miss_params.append("model config")
|
||||
miss_params.append("模型配置")
|
||||
|
||||
if agent_cfg.memory.get("enabled") and not agent_cfg.memory.get("memory_config_id"):
|
||||
miss_params.append("memory config")
|
||||
miss_params.append("记忆配置")
|
||||
if miss_params:
|
||||
raise BusinessException(f"{', '.join(miss_params)} is required")
|
||||
raise BusinessException(
|
||||
f"应用发布失败:检测到以下必要配置尚未完成:{', '.join(miss_params)}。请返回应用编辑页面完成相关配置后再尝试发布。",
|
||||
BizCode.CONFIG_MISSING,
|
||||
context={"missing_params": miss_params},
|
||||
)
|
||||
|
||||
config = {
|
||||
"system_prompt": agent_cfg.system_prompt,
|
||||
@@ -1877,8 +1879,8 @@ class AppService:
|
||||
if memory_config_id:
|
||||
app = self.db.query(App).filter(App.id == app_id).first()
|
||||
if app:
|
||||
updated_count = self._update_endusers_memory_config_by_workspace(
|
||||
app.workspace_id, memory_config_id
|
||||
updated_count = self._update_endusers_memory_config_by_app(
|
||||
app_id, memory_config_id
|
||||
)
|
||||
logger.info(
|
||||
f"发布时更新终端用户记忆配置: app_id={app_id}, workspace_id={app.workspace_id}, "
|
||||
@@ -2014,7 +2016,7 @@ class AppService:
|
||||
|
||||
if memory_config_id:
|
||||
|
||||
updated_count = self._update_endusers_memory_config_by_workspace(app.workspace_id, memory_config_id)
|
||||
updated_count = self._update_endusers_memory_config_by_app(app_id, memory_config_id)
|
||||
logger.info(
|
||||
f"回滚时更新终端用户记忆配置: app_id={app_id}, version={version}, "
|
||||
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
||||
|
||||
@@ -214,7 +214,7 @@ class ConversationService:
|
||||
|
||||
conversation.message_count += 1
|
||||
|
||||
if conversation.message_count == 1 and role == "user":
|
||||
if conversation.message_count <= 2 and role == "user":
|
||||
conversation.title = (
|
||||
content[:50] + ("..." if len(content) > 50 else "")
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.models import AgentConfig, ModelConfig, ModelType
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from app.repositories.tool_repository import ToolRepository
|
||||
from app.schemas.app_schema import FileInput, Citation
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
@@ -37,7 +37,6 @@ from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
from app.services.tool_service import ToolService
|
||||
from app.schemas import FileType
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -449,15 +448,16 @@ class AgentRunService:
|
||||
features_config: Dict[str, Any],
|
||||
is_new_conversation: bool,
|
||||
variables: Optional[Dict[str, Any]] = None
|
||||
) -> Optional[str]:
|
||||
) -> tuple[Any, Any]:
|
||||
"""首轮对话时返回开场白文本(支持变量替换),否则返回 None"""
|
||||
if not is_new_conversation:
|
||||
return None
|
||||
return None, None
|
||||
opening = features_config.get("opening_statement", {})
|
||||
if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")):
|
||||
return None
|
||||
return None, None
|
||||
|
||||
statement = opening["statement"]
|
||||
suggested_questions = opening["suggested_questions"]
|
||||
|
||||
# 如果有变量,进行替换(仅支持 {{var_name}} 格式)
|
||||
if variables:
|
||||
@@ -465,7 +465,7 @@ class AgentRunService:
|
||||
placeholder = f"{{{{{var_name}}}}}"
|
||||
statement = statement.replace(placeholder, str(var_value))
|
||||
|
||||
return statement
|
||||
return statement, suggested_questions
|
||||
|
||||
@staticmethod
|
||||
def _filter_citations(
|
||||
@@ -599,13 +599,16 @@ class AgentRunService:
|
||||
|
||||
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
||||
is_new_conversation = not conversation_id
|
||||
opening = self._get_opening_statement(features_config, is_new_conversation, variables)
|
||||
opening, suggested_questions = None, None
|
||||
if not sub_agent:
|
||||
opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables)
|
||||
conversation_id = await self._ensure_conversation(
|
||||
conversation_id=conversation_id,
|
||||
app_id=agent_config.app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
opening_statement=opening
|
||||
opening_statement=opening,
|
||||
suggested_questions=suggested_questions
|
||||
)
|
||||
|
||||
model_info = ModelInfo(
|
||||
@@ -657,11 +660,6 @@ class AgentRunService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=context,
|
||||
end_user_id=user_id,
|
||||
config_id=config_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files # 传递处理后的文件
|
||||
)
|
||||
|
||||
@@ -845,14 +843,17 @@ class AgentRunService:
|
||||
|
||||
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
||||
is_new_conversation = not conversation_id
|
||||
opening = self._get_opening_statement(features_config, is_new_conversation, variables)
|
||||
opening, suggested_questions = None, None
|
||||
if not sub_agent:
|
||||
opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables)
|
||||
conversation_id = await self._ensure_conversation(
|
||||
conversation_id=conversation_id,
|
||||
app_id=agent_config.app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
sub_agent=sub_agent,
|
||||
opening_statement=opening
|
||||
opening_statement=opening,
|
||||
suggested_questions=suggested_questions
|
||||
)
|
||||
|
||||
model_info = ModelInfo(
|
||||
@@ -911,11 +912,6 @@ class AgentRunService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=context,
|
||||
end_user_id=user_id,
|
||||
config_id=config_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_flag=memory_flag,
|
||||
files=processed_files
|
||||
):
|
||||
if isinstance(chunk, int):
|
||||
@@ -1061,7 +1057,8 @@ class AgentRunService:
|
||||
workspace_id: uuid.UUID,
|
||||
user_id: Optional[str],
|
||||
sub_agent: bool = False,
|
||||
opening_statement: Optional[str] = None
|
||||
opening_statement: Optional[str] = None,
|
||||
suggested_questions: Optional[List[str]] = None
|
||||
) -> str:
|
||||
"""确保会话存在(创建或验证)
|
||||
|
||||
@@ -1072,6 +1069,7 @@ class AgentRunService:
|
||||
user_id: 用户ID
|
||||
sub_agent: 是否为子代理
|
||||
opening_statement: 开场白(新会话时作为第一条消息写入)
|
||||
suggested_questions: 预设问题列表
|
||||
|
||||
Returns:
|
||||
str: 会话ID
|
||||
@@ -1115,7 +1113,7 @@ class AgentRunService:
|
||||
conversation_id=uuid.UUID(new_conv_id),
|
||||
role="assistant",
|
||||
content=opening_statement,
|
||||
meta_data={}
|
||||
meta_data={"suggested_questions": suggested_questions}
|
||||
)
|
||||
logger.debug(f"已保存开场白到会话 {new_conv_id}")
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.core.memory.agent.utils.write_tools import write as write_neo4j
|
||||
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
from app.db import get_db_context
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
@@ -49,10 +50,6 @@ from app.services.memory_konwledges_server import (
|
||||
)
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
|
||||
try:
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
except ImportError:
|
||||
audit_logger = None
|
||||
logger = get_logger(__name__)
|
||||
config_logger = get_config_logger()
|
||||
|
||||
@@ -68,24 +65,22 @@ class MemoryAgentService:
|
||||
if str(messages) == 'success':
|
||||
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
|
||||
# 记录成功的操作
|
||||
if audit_logger:
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=True,
|
||||
duration=duration, details={"message_length": len(message)})
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=True,
|
||||
duration=duration, details={"message_length": len(message)})
|
||||
return context
|
||||
else:
|
||||
logger.warning(f"Write operation failed for group {end_user_id}")
|
||||
|
||||
# 记录失败的操作
|
||||
if audit_logger:
|
||||
audit_logger.log_operation(
|
||||
operation="WRITE",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=f"写入失败: {messages[:100]}"
|
||||
)
|
||||
audit_logger.log_operation(
|
||||
operation="WRITE",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=f"写入失败: {messages[:100]}"
|
||||
)
|
||||
|
||||
raise ValueError(f"写入失败: {messages}")
|
||||
|
||||
@@ -338,10 +333,9 @@ class MemoryAgentService:
|
||||
logger.error(error_msg)
|
||||
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=False, duration=duration, error=error_msg)
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=False, duration=duration, error=error_msg)
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
@@ -401,10 +395,10 @@ class MemoryAgentService:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Write operation failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=False, duration=duration, error=error_msg)
|
||||
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||
success=False, duration=duration, error=error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
async def read_memory(
|
||||
@@ -469,10 +463,9 @@ class MemoryAgentService:
|
||||
logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")
|
||||
|
||||
# 导入审计日志记录器
|
||||
try:
|
||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||
except ImportError:
|
||||
audit_logger = None
|
||||
|
||||
|
||||
|
||||
|
||||
config_load_start = time.time()
|
||||
try:
|
||||
@@ -492,16 +485,15 @@ class MemoryAgentService:
|
||||
logger.error(error_msg)
|
||||
|
||||
# Log failed operation
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_msg
|
||||
)
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_msg
|
||||
)
|
||||
|
||||
raise ValueError(error_msg)
|
||||
|
||||
@@ -633,15 +625,15 @@ class MemoryAgentService:
|
||||
total_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=True,
|
||||
duration=duration
|
||||
)
|
||||
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=True,
|
||||
duration=duration
|
||||
)
|
||||
|
||||
return {
|
||||
"answer": summary,
|
||||
@@ -651,16 +643,16 @@ class MemoryAgentService:
|
||||
# Ensure proper error handling and logging
|
||||
error_msg = f"Read operation failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
if audit_logger:
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_msg
|
||||
)
|
||||
|
||||
duration = time.time() - start_time
|
||||
audit_logger.log_operation(
|
||||
operation="READ",
|
||||
config_id=config_id,
|
||||
end_user_id=end_user_id,
|
||||
success=False,
|
||||
duration=duration,
|
||||
error=error_msg
|
||||
)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
from sqlalchemy import desc, nullslast, or_, and_, cast, String
|
||||
from typing import List, Optional, Dict, Any
|
||||
import uuid
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.models.app_model import App
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.end_user_model import EndUser, EndUser as EndUserModel
|
||||
from app.models.memory_increment_model import MemoryIncrement
|
||||
|
||||
from app.repositories import (
|
||||
@@ -49,44 +50,40 @@ def get_current_workspace_type(
|
||||
|
||||
|
||||
def get_workspace_end_users(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
current_user: User
|
||||
) -> List[EndUser]:
|
||||
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)
|
||||
|
||||
返回结果按 created_at 从新到旧排序(NULL 值排在最后)
|
||||
"""
|
||||
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
|
||||
try:
|
||||
# 查询应用(ORM)
|
||||
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
||||
|
||||
|
||||
if not apps_orm:
|
||||
business_logger.info("工作空间下没有应用")
|
||||
return []
|
||||
|
||||
|
||||
# 提取所有 app_id
|
||||
# app_ids = [app.id for app in apps_orm]
|
||||
|
||||
# 批量查询所有 end_users(一次查询而非循环查询)
|
||||
# 按 created_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性
|
||||
from app.models.end_user_model import EndUser as EndUserModel
|
||||
from sqlalchemy import desc, nullslast
|
||||
end_users_orm = db.query(EndUserModel).filter(
|
||||
EndUserModel.workspace_id == workspace_id
|
||||
).order_by(
|
||||
nullslast(desc(EndUserModel.created_at)),
|
||||
desc(EndUserModel.id)
|
||||
).all()
|
||||
|
||||
|
||||
# 转换为 Pydantic 模型(只在需要时转换)
|
||||
end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm]
|
||||
|
||||
|
||||
business_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||
return end_users
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -94,6 +91,85 @@ def get_workspace_end_users(
|
||||
raise
|
||||
|
||||
|
||||
def get_workspace_end_users_paginated(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
current_user: User,
|
||||
page: int,
|
||||
pagesize: int,
|
||||
keyword: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""获取工作空间的宿主列表(分页版本,支持模糊搜索)
|
||||
|
||||
返回结果按 created_at 从新到旧排序(NULL 值排在最后)
|
||||
支持通过 keyword 参数同时模糊搜索 other_name 和 id 字段
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
workspace_id: 工作空间ID
|
||||
current_user: 当前用户
|
||||
page: 页码(从1开始)
|
||||
pagesize: 每页数量
|
||||
keyword: 搜索关键词(可选,同时模糊匹配 other_name 和 id)
|
||||
|
||||
Returns:
|
||||
dict: 包含 items(宿主列表)和 total(总记录数)的字典
|
||||
"""
|
||||
business_logger.info(f"获取工作空间宿主列表(分页): workspace_id={workspace_id}, keyword={keyword}, page={page}, pagesize={pagesize}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 构建基础查询
|
||||
base_query = db.query(EndUserModel).filter(
|
||||
EndUserModel.workspace_id == workspace_id
|
||||
)
|
||||
|
||||
# 构建搜索条件(过滤空字符串和None)
|
||||
keyword = keyword.strip() if keyword else None
|
||||
|
||||
if keyword:
|
||||
keyword_pattern = f"%{keyword}%"
|
||||
# other_name 匹配始终生效;id 匹配仅对 other_name 为空的记录生效
|
||||
base_query = base_query.filter(
|
||||
or_(
|
||||
EndUserModel.other_name.ilike(keyword_pattern),
|
||||
and_(
|
||||
or_(
|
||||
EndUserModel.other_name.is_(None),
|
||||
EndUserModel.other_name == "",
|
||||
),
|
||||
cast(EndUserModel.id, String).ilike(keyword_pattern),
|
||||
),
|
||||
)
|
||||
)
|
||||
business_logger.info(f"应用模糊搜索: keyword={keyword}(匹配 other_name;other_name 为空时匹配 id)")
|
||||
|
||||
# 获取总记录数
|
||||
total = base_query.count()
|
||||
|
||||
if total == 0:
|
||||
business_logger.info("工作空间下没有宿主")
|
||||
return {"items": [], "total": 0}
|
||||
|
||||
# 分页查询
|
||||
# 按 created_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性
|
||||
end_users_orm = base_query.order_by(
|
||||
nullslast(desc(EndUserModel.created_at)),
|
||||
desc(EndUserModel.id)
|
||||
).offset((page - 1) * pagesize).limit(pagesize).all()
|
||||
|
||||
# 转换为 Pydantic 模型
|
||||
end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm]
|
||||
|
||||
business_logger.info(f"成功获取 {len(end_users)} 个宿主记录,总计 {total} 条")
|
||||
return {"items": end_users, "total": total}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取工作空间宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_workspace_memory_increment(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
@@ -638,7 +714,24 @@ def get_rag_content(
|
||||
business_logger.error(f"获取文档 {document.id} 的chunks失败: {str(e)}")
|
||||
continue
|
||||
|
||||
# 4. 返回结果
|
||||
# 4. 将所有 page_content 拼接后按角色分割为对话列表
|
||||
merged_text = "\n".join(page_contents)
|
||||
conversations = []
|
||||
if merged_text.strip():
|
||||
import re
|
||||
# 在任意位置匹配 "user:" 或 "assistant:",不限于行首
|
||||
parts = re.split(r'(user|assistant):', merged_text)
|
||||
# parts 结构: ['', 'user', ' content...', 'assistant', ' content...', ...]
|
||||
i = 1
|
||||
while i < len(parts) - 1:
|
||||
role = parts[i].strip()
|
||||
content = parts[i + 1].strip()
|
||||
# 将 content 中的 \n 还原为真实换行
|
||||
content = content.replace("\\n", "\n")
|
||||
if role in ("user", "assistant") and content:
|
||||
conversations.append({"role": role, "content": content})
|
||||
i += 2
|
||||
|
||||
result = {
|
||||
"page": {
|
||||
"page": page,
|
||||
@@ -646,10 +739,10 @@ def get_rag_content(
|
||||
"total": global_total,
|
||||
"hasnext": offset_end < global_total,
|
||||
},
|
||||
"items": page_contents
|
||||
"items": conversations
|
||||
}
|
||||
|
||||
business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(page_contents)} 条")
|
||||
business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(conversations)} 条对话")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -204,30 +204,35 @@ class MemoryForgetService:
|
||||
end_user_id: str,
|
||||
forgetting_threshold: float,
|
||||
min_days_since_access: int,
|
||||
limit: int = 20
|
||||
) -> list[Dict[str, Any]]:
|
||||
page: Optional[int] = None,
|
||||
pagesize: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取待遗忘节点列表
|
||||
|
||||
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)
|
||||
|
||||
|
||||
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。支持分页查询。
|
||||
|
||||
Args:
|
||||
connector: Neo4j 连接器
|
||||
end_user_id: 组ID
|
||||
forgetting_threshold: 遗忘阈值
|
||||
min_days_since_access: 最小未访问天数
|
||||
limit: 返回节点数量限制
|
||||
|
||||
page: 页码(可选,从1开始)
|
||||
pagesize: 每页数量(可选)
|
||||
|
||||
Returns:
|
||||
list: 待遗忘节点列表
|
||||
dict: 包含待遗忘节点列表和分页信息的字典
|
||||
- items: 待遗忘节点列表
|
||||
- page: 分页信息(分页时)
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
|
||||
# 计算最小访问时间(ISO 8601 格式字符串,使用 UTC 时区)
|
||||
min_access_time = datetime.now(timezone.utc) - timedelta(days=min_days_since_access)
|
||||
min_access_time_str = min_access_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ')
|
||||
|
||||
query = """
|
||||
|
||||
# 基础查询(用于获取总数)
|
||||
count_query = """
|
||||
MATCH (n)
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||
AND n.end_user_id = $end_user_id
|
||||
@@ -235,10 +240,22 @@ class MemoryForgetService:
|
||||
AND n.activation_value < $threshold
|
||||
AND n.last_access_time IS NOT NULL
|
||||
AND datetime(n.last_access_time) < datetime($min_access_time_str)
|
||||
RETURN
|
||||
RETURN count(n) as total
|
||||
"""
|
||||
|
||||
# 数据查询
|
||||
data_query = """
|
||||
MATCH (n)
|
||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||
AND n.end_user_id = $end_user_id
|
||||
AND n.activation_value IS NOT NULL
|
||||
AND n.activation_value < $threshold
|
||||
AND n.last_access_time IS NOT NULL
|
||||
AND datetime(n.last_access_time) < datetime($min_access_time_str)
|
||||
RETURN
|
||||
elementId(n) as node_id,
|
||||
labels(n)[0] as node_type,
|
||||
CASE
|
||||
CASE
|
||||
WHEN n:Statement THEN n.statement
|
||||
WHEN n:ExtractedEntity THEN n.name
|
||||
WHEN n:MemorySummary THEN n.content
|
||||
@@ -247,18 +264,32 @@ class MemoryForgetService:
|
||||
n.activation_value as activation_value,
|
||||
n.last_access_time as last_access_time
|
||||
ORDER BY n.activation_value ASC
|
||||
LIMIT $limit
|
||||
"""
|
||||
|
||||
|
||||
# 如果启用分页,添加 SKIP 和 LIMIT
|
||||
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||
data_query += " SKIP $skip LIMIT $limit"
|
||||
|
||||
params = {
|
||||
'end_user_id': end_user_id,
|
||||
'threshold': forgetting_threshold,
|
||||
'min_access_time_str': min_access_time_str,
|
||||
'limit': limit
|
||||
'min_access_time_str': min_access_time_str
|
||||
}
|
||||
|
||||
results = await connector.execute_query(query, **params)
|
||||
|
||||
|
||||
# 获取总数(分页时需要)
|
||||
total = 0
|
||||
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||
count_results = await connector.execute_query(count_query, **params)
|
||||
if count_results:
|
||||
total = count_results[0]['total']
|
||||
|
||||
# 添加分页参数
|
||||
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||
params['skip'] = (page - 1) * pagesize
|
||||
params['limit'] = pagesize
|
||||
|
||||
results = await connector.execute_query(data_query, **params)
|
||||
|
||||
pending_nodes = []
|
||||
for result in results:
|
||||
# 将节点类型标签转换为小写
|
||||
@@ -267,7 +298,7 @@ class MemoryForgetService:
|
||||
node_type_label = 'entity'
|
||||
elif node_type_label == 'memorysummary':
|
||||
node_type_label = 'summary'
|
||||
|
||||
|
||||
# 将 Neo4j DateTime 对象转换为时间戳(毫秒)
|
||||
last_access_time = result['last_access_time']
|
||||
last_access_dt = convert_neo4j_datetime_to_python(last_access_time)
|
||||
@@ -278,7 +309,7 @@ class MemoryForgetService:
|
||||
last_access_timestamp = int(last_access_dt.timestamp() * 1000)
|
||||
else:
|
||||
last_access_timestamp = 0
|
||||
|
||||
|
||||
pending_nodes.append({
|
||||
'node_id': str(result['node_id']),
|
||||
'node_type': node_type_label,
|
||||
@@ -286,8 +317,20 @@ class MemoryForgetService:
|
||||
'activation_value': result['activation_value'],
|
||||
'last_access_time': last_access_timestamp
|
||||
})
|
||||
|
||||
return pending_nodes
|
||||
|
||||
# 构建返回结果
|
||||
result: Dict[str, Any] = {'items': pending_nodes}
|
||||
|
||||
# 如果启用分页,添加分页信息
|
||||
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||
result['page'] = {
|
||||
'page': page,
|
||||
'pagesize': pagesize,
|
||||
'total': total,
|
||||
'hasnext': (page * pagesize) < total
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
async def trigger_forgetting_cycle(
|
||||
self,
|
||||
@@ -636,7 +679,7 @@ class MemoryForgetService:
|
||||
api_logger.error(f"获取历史趋势数据失败: {str(e)}")
|
||||
# 失败时返回空列表,不影响主流程
|
||||
|
||||
# 获取待遗忘节点列表(前20个满足遗忘条件的节点)
|
||||
# 获取待遗忘节点列表
|
||||
pending_nodes = []
|
||||
try:
|
||||
if end_user_id:
|
||||
@@ -652,8 +695,7 @@ class MemoryForgetService:
|
||||
connector=connector,
|
||||
end_user_id=end_user_id,
|
||||
forgetting_threshold=forgetting_threshold,
|
||||
min_days_since_access=int(min_days),
|
||||
limit=20
|
||||
min_days_since_access=int(min_days)
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取 {len(pending_nodes)} 个待遗忘节点")
|
||||
@@ -661,24 +703,79 @@ class MemoryForgetService:
|
||||
except Exception as e:
|
||||
api_logger.error(f"获取待遗忘节点失败: {str(e)}")
|
||||
# 失败时返回空列表,不影响主流程
|
||||
|
||||
# 构建统计信息
|
||||
|
||||
# 构建统计信息(不包含 pending_nodes,已分离到独立接口)
|
||||
stats = {
|
||||
'activation_metrics': activation_metrics,
|
||||
'node_distribution': node_distribution,
|
||||
'recent_trends': recent_trends,
|
||||
'pending_nodes': pending_nodes,
|
||||
'timestamp': int(datetime.now().timestamp() * 1000)
|
||||
}
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, "
|
||||
f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}, "
|
||||
f"trend_days={len(recent_trends)}, pending_nodes={len(pending_nodes)}"
|
||||
f"trend_days={len(recent_trends)}"
|
||||
)
|
||||
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
async def get_pending_nodes(
|
||||
self,
|
||||
db: Session,
|
||||
end_user_id: str,
|
||||
config_id: Optional[UUID] = None,
|
||||
page: int = 1,
|
||||
pagesize: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取待遗忘节点列表(独立分页接口)
|
||||
|
||||
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
end_user_id: 组ID(必填)
|
||||
config_id: 配置ID(可选,用于获取遗忘阈值)
|
||||
page: 页码(从1开始,默认1)
|
||||
pagesize: 每页数量(默认10)
|
||||
|
||||
Returns:
|
||||
dict: 包含待遗忘节点列表和分页信息的字典
|
||||
- items: 待遗忘节点列表
|
||||
- page: 分页信息
|
||||
"""
|
||||
# 获取遗忘引擎组件
|
||||
_, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id)
|
||||
|
||||
connector = forgetting_scheduler.connector
|
||||
forgetting_threshold = config['forgetting_threshold']
|
||||
|
||||
# 验证 min_days_since_access 配置值
|
||||
min_days = config.get('min_days_since_access')
|
||||
if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0:
|
||||
api_logger.warning(
|
||||
f"min_days_since_access 配置无效: {min_days}, 使用默认值 7"
|
||||
)
|
||||
min_days = 7
|
||||
|
||||
# 调用内部方法获取分页数据
|
||||
pending_nodes_result = await self._get_pending_forgetting_nodes(
|
||||
connector=connector,
|
||||
end_user_id=end_user_id,
|
||||
forgetting_threshold=forgetting_threshold,
|
||||
min_days_since_access=int(min_days),
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取待遗忘节点列表: end_user_id={end_user_id}, "
|
||||
f"page={page}, pagesize={pagesize}, total={pending_nodes_result.get('page', {}).get('total', 0)}"
|
||||
)
|
||||
|
||||
return pending_nodes_result
|
||||
|
||||
async def get_forgetting_curve(
|
||||
self,
|
||||
db: Session,
|
||||
|
||||
@@ -243,28 +243,9 @@ class MemoryPerceptualService:
|
||||
memory_config: MemoryConfig,
|
||||
file: FileInput
|
||||
):
|
||||
memories = self.repository.get_by_url(file.url)
|
||||
if memories:
|
||||
business_logger.info(f"Perceptual memory already exists: {file.url}")
|
||||
if end_user_id not in [memory.end_user_id for memory in memories]:
|
||||
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
|
||||
memory_cache = memories[0]
|
||||
memory = self.repository.create_perceptual_memory(
|
||||
end_user_id=uuid.UUID(end_user_id),
|
||||
perceptual_type=PerceptualType(memory_cache.perceptual_type),
|
||||
file_path=memory_cache.file_path,
|
||||
file_name=memory_cache.file_name,
|
||||
file_ext=memory_cache.file_ext,
|
||||
summary=memory_cache.summary,
|
||||
meta_data=memory_cache.meta_data
|
||||
)
|
||||
self.db.commit()
|
||||
return memory
|
||||
else:
|
||||
for memory in memories:
|
||||
if memory.end_user_id == uuid.UUID(end_user_id):
|
||||
return memory
|
||||
llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
|
||||
if model_config is None or llm is None:
|
||||
return None
|
||||
multimodel_service = MultimodalService(self.db, ModelInfo(
|
||||
model_name=model_config.model_name,
|
||||
provider=model_config.provider,
|
||||
@@ -286,15 +267,20 @@ class MemoryPerceptualService:
|
||||
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||
opt_system_prompt = f.read()
|
||||
rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh')
|
||||
except FileNotFoundError:
|
||||
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
||||
except FileNotFoundError as e:
|
||||
business_logger.error(f"Failed to generate perceptual memory: {str(e)}")
|
||||
return None
|
||||
messages = [
|
||||
{"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]},
|
||||
{"role": RoleType.USER.value, "content": [
|
||||
{"type": "text", "text": "Summarize the following file"}, file_message
|
||||
]}
|
||||
]
|
||||
result = await llm.ainvoke(messages)
|
||||
try:
|
||||
result = await llm.ainvoke(messages)
|
||||
except Exception as e:
|
||||
business_logger.error(f"Failed to generate perceptual memory: {str(e)}")
|
||||
return None
|
||||
content = result.content
|
||||
final_output = ""
|
||||
if isinstance(content, list):
|
||||
|
||||
@@ -695,6 +695,37 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]
|
||||
return result
|
||||
|
||||
|
||||
async def search_all_batch(end_user_ids: List[str]) -> Dict[str, int]:
|
||||
"""批量查询多个用户的记忆数量(简化版本,只返回total)
|
||||
|
||||
Args:
|
||||
end_user_ids: 用户ID列表
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: 以user_id为key的记忆数量字典
|
||||
格式: {"user_id": total_count}
|
||||
"""
|
||||
if not end_user_ids:
|
||||
return {}
|
||||
|
||||
result = await _neo4j_connector.execute_query(
|
||||
MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
|
||||
end_user_ids=end_user_ids,
|
||||
)
|
||||
|
||||
# 转换结果为字典格式,字典格式在查询中无需遍历结果集,直接返回
|
||||
data = {}
|
||||
for row in result:
|
||||
data[row["user_id"]] = row["total"]
|
||||
|
||||
# 为没有数据的用户填充默认值,转换字典格式还为无数据填充默认值
|
||||
for user_id in end_user_ids:
|
||||
if user_id not in data:
|
||||
data[user_id] = 0
|
||||
|
||||
return data
|
||||
|
||||
|
||||
async def analytics_hot_memory_tags(
|
||||
db: Session,
|
||||
current_user: User,
|
||||
|
||||
@@ -69,7 +69,8 @@ class ModelConfigService:
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
def get_model_by_name(db: Session, name: str, provider: str | None = None,
|
||||
tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
"""根据名称获取模型配置"""
|
||||
model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id)
|
||||
if not model:
|
||||
@@ -77,21 +78,22 @@ class ModelConfigService:
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ModelConfig]:
|
||||
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[
|
||||
ModelConfig]:
|
||||
"""按名称模糊匹配获取模型配置列表"""
|
||||
return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit)
|
||||
|
||||
@staticmethod
|
||||
async def validate_model_config(
|
||||
db: Session,
|
||||
*,
|
||||
model_name: str,
|
||||
provider: str,
|
||||
api_key: str,
|
||||
api_base: Optional[str] = None,
|
||||
model_type: str = "llm",
|
||||
test_message: str = "Hello",
|
||||
is_omni: bool = False
|
||||
db: Session,
|
||||
*,
|
||||
model_name: str,
|
||||
provider: str,
|
||||
api_key: str,
|
||||
api_base: Optional[str] = None,
|
||||
model_type: str = "llm",
|
||||
test_message: str = "Hello",
|
||||
is_omni: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""验证模型配置是否有效
|
||||
|
||||
@@ -158,13 +160,13 @@ class ModelConfigService:
|
||||
# 统一使用 RedBearEmbeddings(自动支持火山引擎多模态)
|
||||
embedding = RedBearEmbeddings(model_config)
|
||||
test_texts = [test_message, "测试文本"]
|
||||
|
||||
|
||||
# 火山引擎使用 embed_batch,其他使用 embed_documents
|
||||
if provider.lower() == "volcano":
|
||||
vectors = await asyncio.to_thread(embedding.embed_batch, test_texts)
|
||||
else:
|
||||
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
|
||||
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
@@ -200,11 +202,11 @@ class ModelConfigService:
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
|
||||
elif model_type_lower == "image":
|
||||
# 图片生成模型验证
|
||||
from app.core.models.generation import RedBearImageGenerator
|
||||
|
||||
|
||||
generator = RedBearImageGenerator(model_config)
|
||||
result = await generator.agenerate(
|
||||
prompt="a cute panda",
|
||||
@@ -212,7 +214,7 @@ class ModelConfigService:
|
||||
)
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.info(f"成功生成图片,结果: {result}")
|
||||
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"message": "图片生成模型配置验证成功",
|
||||
@@ -224,21 +226,21 @@ class ModelConfigService:
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
|
||||
|
||||
elif model_type_lower == "video":
|
||||
# 视频生成模型验证
|
||||
from app.core.models.generation import RedBearVideoGenerator
|
||||
|
||||
|
||||
generator = RedBearVideoGenerator(model_config)
|
||||
result = await generator.agenerate(
|
||||
prompt="a cute panda playing in bamboo forest",
|
||||
duration=5
|
||||
)
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
# 视频生成是异步任务,返回任务ID
|
||||
task_id = result.get("task_id") if isinstance(result, dict) else None
|
||||
|
||||
|
||||
return {
|
||||
"valid": True,
|
||||
"message": "视频生成模型配置验证成功",
|
||||
@@ -265,7 +267,6 @@ class ModelConfigService:
|
||||
# 提取详细的错误信息
|
||||
error_message = str(e)
|
||||
error_type = type(e).__name__
|
||||
print("=========error_message:",error_message.lower())
|
||||
# 特殊处理常见的错误类型
|
||||
if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower():
|
||||
# 区域/国家限制(适用于所有提供商)
|
||||
@@ -354,14 +355,16 @@ class ModelConfigService:
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate,
|
||||
tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||
"""更新模型配置"""
|
||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||
if not existing_model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
if model_data.name and model_data.name != existing_model.name:
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
|
||||
tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
|
||||
@@ -370,25 +373,27 @@ class ModelConfigService:
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate,
|
||||
tenant_id: uuid.UUID) -> ModelConfig:
|
||||
"""创建组合模型"""
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE,
|
||||
tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
|
||||
# 验证所有 API Key 存在且类型匹配
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if not api_key:
|
||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
|
||||
|
||||
|
||||
# 检查 API Key 关联的模型配置类型
|
||||
for model_config in api_key.model_configs:
|
||||
# chat 和 llm 类型可以兼容
|
||||
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
||||
config_type = model_config.type
|
||||
request_type = model_data.type
|
||||
|
||||
if not (config_type == request_type or
|
||||
|
||||
if not (config_type == request_type or
|
||||
(config_type in compatible_types and request_type in compatible_types)):
|
||||
raise BusinessException(
|
||||
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
|
||||
@@ -399,7 +404,7 @@ class ModelConfigService:
|
||||
# f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型",
|
||||
# BizCode.INVALID_PARAMETER
|
||||
# )
|
||||
|
||||
|
||||
# 创建组合模型
|
||||
model_config_data = {
|
||||
"tenant_id": tenant_id,
|
||||
@@ -418,49 +423,51 @@ class ModelConfigService:
|
||||
|
||||
model = ModelConfigRepository.create(db, model_config_data)
|
||||
db.flush()
|
||||
|
||||
|
||||
# 关联 API Keys
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if api_key:
|
||||
model.api_keys.append(api_key)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(model)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
||||
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate,
|
||||
tenant_id: uuid.UUID) -> ModelConfig:
|
||||
"""更新组合模型"""
|
||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||
if not existing_model:
|
||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
if model_data.name and model_data.name != existing_model.name:
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
|
||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
|
||||
tenant_id=tenant_id):
|
||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||
|
||||
|
||||
if not existing_model.is_composite:
|
||||
raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER)
|
||||
|
||||
|
||||
# 验证所有 API Key 存在且类型匹配
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if not api_key:
|
||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
|
||||
|
||||
|
||||
for model_config in api_key.model_configs:
|
||||
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
||||
config_type = model_config.type
|
||||
request_type = existing_model.type
|
||||
|
||||
if not (config_type == request_type or
|
||||
|
||||
if not (config_type == request_type or
|
||||
(config_type in compatible_types and request_type in compatible_types)):
|
||||
raise BusinessException(
|
||||
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
|
||||
BizCode.INVALID_PARAMETER
|
||||
)
|
||||
|
||||
|
||||
# 更新基本信息
|
||||
existing_model.name = model_data.name
|
||||
# existing_model.type = model_data.type
|
||||
@@ -471,14 +478,14 @@ class ModelConfigService:
|
||||
existing_model.is_public = model_data.is_public
|
||||
if "load_balance_strategy" in model_data.model_fields_set:
|
||||
existing_model.load_balance_strategy = model_data.load_balance_strategy
|
||||
|
||||
|
||||
# 更新 API Keys 关联
|
||||
existing_model.api_keys.clear()
|
||||
for api_key_id in model_data.api_key_ids:
|
||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||
if api_key:
|
||||
existing_model.api_keys.append(api_key)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(existing_model)
|
||||
return existing_model
|
||||
@@ -532,7 +539,7 @@ class ModelApiKeyService:
|
||||
"""根据provider为多个ModelConfig创建API Key"""
|
||||
created_keys = []
|
||||
failed_models = [] # 记录验证失败的模型
|
||||
|
||||
|
||||
for model_config_id in data.model_config_ids:
|
||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||
if not model_config:
|
||||
@@ -540,10 +547,10 @@ class ModelApiKeyService:
|
||||
|
||||
data.is_omni = model_config.is_omni
|
||||
data.capability = model_config.capability
|
||||
|
||||
|
||||
# 从ModelBase获取model_name
|
||||
model_name = model_config.model_base.name if model_config.model_base else model_config.name
|
||||
|
||||
|
||||
# 检查是否存在API Key(包括软删除),需要考虑tenant_id
|
||||
existing_key = db.query(ModelApiKey).join(
|
||||
ModelApiKey.model_configs
|
||||
@@ -553,7 +560,7 @@ class ModelApiKeyService:
|
||||
ModelApiKey.model_name == model_name,
|
||||
ModelConfig.tenant_id == model_config.tenant_id
|
||||
).first()
|
||||
|
||||
|
||||
if existing_key:
|
||||
# 如果已存在,重新激活并更新
|
||||
if existing_key.is_active:
|
||||
@@ -566,14 +573,14 @@ class ModelApiKeyService:
|
||||
existing_key.model_name = model_name
|
||||
existing_key.capability = data.capability
|
||||
existing_key.is_omni = data.is_omni
|
||||
|
||||
|
||||
# 检查是否已关联该模型配置
|
||||
if model_config not in existing_key.model_configs:
|
||||
existing_key.model_configs.append(model_config)
|
||||
|
||||
|
||||
created_keys.append(existing_key)
|
||||
continue
|
||||
|
||||
|
||||
# 验证配置
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
@@ -589,7 +596,7 @@ class ModelApiKeyService:
|
||||
# 记录验证失败的模型,但不抛出异常
|
||||
failed_models.append(model_name)
|
||||
continue
|
||||
|
||||
|
||||
# 创建API Key
|
||||
api_key_data = ModelApiKeyCreate(
|
||||
model_config_ids=[model_config_id],
|
||||
@@ -606,12 +613,12 @@ class ModelApiKeyService:
|
||||
)
|
||||
api_key_obj = ModelApiKeyRepository.create(db, api_key_data)
|
||||
created_keys.append(api_key_obj)
|
||||
|
||||
|
||||
if created_keys:
|
||||
db.commit()
|
||||
for key in created_keys:
|
||||
db.refresh(key)
|
||||
|
||||
|
||||
return created_keys, failed_models
|
||||
|
||||
@staticmethod
|
||||
@@ -626,7 +633,7 @@ class ModelApiKeyService:
|
||||
api_key_data.is_omni = model_config.is_omni
|
||||
if api_key_data.capability is None:
|
||||
api_key_data.capability = model_config.capability
|
||||
|
||||
|
||||
# 检查API Key是否已存在(包括软删除),需要考虑tenant_id
|
||||
existing_key = db.query(ModelApiKey).join(
|
||||
ModelApiKey.model_configs
|
||||
@@ -650,15 +657,15 @@ class ModelApiKeyService:
|
||||
existing_key.model_name = api_key_data.model_name
|
||||
existing_key.capability = api_key_data.capability
|
||||
existing_key.is_omni = api_key_data.is_omni
|
||||
|
||||
|
||||
# 检查是否已关联该模型配置
|
||||
if model_config not in existing_key.model_configs:
|
||||
existing_key.model_configs.append(model_config)
|
||||
|
||||
|
||||
db.commit()
|
||||
db.refresh(existing_key)
|
||||
return existing_key
|
||||
|
||||
|
||||
# 验证配置
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
@@ -691,7 +698,7 @@ class ModelApiKeyService:
|
||||
# 获取关联的模型配置以获取模型类型
|
||||
if existing_api_key.model_configs:
|
||||
model_config = existing_api_key.model_configs[0]
|
||||
|
||||
|
||||
validation_result = await ModelConfigService.validate_model_config(
|
||||
db=db,
|
||||
model_name=api_key_data.model_name or existing_api_key.model_name,
|
||||
@@ -729,15 +736,15 @@ class ModelApiKeyService:
|
||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||
if not model_config:
|
||||
return None
|
||||
|
||||
|
||||
api_keys = [key for key in model_config.api_keys if key.is_active]
|
||||
if not api_keys:
|
||||
return None
|
||||
|
||||
|
||||
# 如果是轮询策略,按使用次数最少,次数相同则选最早使用的
|
||||
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
|
||||
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
|
||||
|
||||
|
||||
# 否则返回第一个
|
||||
return api_keys[0]
|
||||
|
||||
@@ -760,20 +767,19 @@ class ModelApiKeyService:
|
||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
|
||||
class ModelBaseService:
|
||||
"""基础模型服务"""
|
||||
|
||||
@staticmethod
|
||||
def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List:
|
||||
models = ModelBaseRepository.get_list(db, query)
|
||||
|
||||
|
||||
provider_groups = {}
|
||||
for m in models:
|
||||
model_dict = model_schema.ModelBase.model_validate(m).model_dump()
|
||||
if tenant_id:
|
||||
model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id)
|
||||
|
||||
|
||||
provider = m.provider
|
||||
if provider not in provider_groups:
|
||||
provider_groups[provider] = {
|
||||
@@ -781,7 +787,7 @@ class ModelBaseService:
|
||||
"models": []
|
||||
}
|
||||
provider_groups[provider]["models"].append(model_dict)
|
||||
|
||||
|
||||
return list(provider_groups.values())
|
||||
|
||||
@staticmethod
|
||||
@@ -823,10 +829,10 @@ class ModelBaseService:
|
||||
model_base = ModelBaseRepository.get_by_id(db, model_base_id)
|
||||
if not model_base:
|
||||
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
|
||||
|
||||
|
||||
if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id):
|
||||
raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME)
|
||||
|
||||
|
||||
model_config_data = {
|
||||
"model_id": model_base_id,
|
||||
"tenant_id": tenant_id,
|
||||
|
||||
@@ -12,6 +12,9 @@ import base64
|
||||
import csv
|
||||
import io
|
||||
import json
|
||||
import re
|
||||
import olefile
|
||||
import struct
|
||||
import zipfile
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional
|
||||
@@ -438,13 +441,13 @@ class MultimodalService:
|
||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||
return True, {
|
||||
"type": "text",
|
||||
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
|
||||
"text": f"<document url=\"{file.url}\">\n{await self.extract_document_text(file)}\n</document>"
|
||||
}
|
||||
else:
|
||||
# 本地文件,提取文本内容
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
|
||||
text = await self._extract_document_text(file)
|
||||
text = await self.extract_document_text(file)
|
||||
file_metadata = self.db.query(FileMetadata).filter(
|
||||
FileMetadata.id == file.upload_file_id
|
||||
).first()
|
||||
@@ -542,7 +545,7 @@ class MultimodalService:
|
||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||
return f"{server_url}/storage/permanent/{file_id}"
|
||||
|
||||
async def _extract_document_text(self, file: FileInput) -> str:
|
||||
async def extract_document_text(self, file: FileInput) -> str:
|
||||
"""
|
||||
提取文档文本内容
|
||||
|
||||
@@ -602,31 +605,75 @@ class MultimodalService:
|
||||
try:
|
||||
word_file = io.BytesIO(file_content)
|
||||
doc = Document(word_file)
|
||||
return '\n'.join(p.text for p in doc.paragraphs)
|
||||
text_lines = []
|
||||
for p in doc.paragraphs:
|
||||
text = p.text.strip()
|
||||
if text:
|
||||
text_lines.append(text)
|
||||
|
||||
for table in doc.tables:
|
||||
for row in table.rows:
|
||||
for cell in row.cells:
|
||||
text = cell.text.strip()
|
||||
if text:
|
||||
text_lines.append(text)
|
||||
|
||||
full_text = "\n".join(text_lines)
|
||||
return full_text.strip() or "[docx 文件无文本内容]"
|
||||
except Exception as e:
|
||||
logger.error(f"提取 docx 文本失败: {e}")
|
||||
logger.error(f"提取 docx 文本失败: {str(e)}", exc_info=True)
|
||||
return f"[docx 提取失败: {str(e)}]"
|
||||
|
||||
# 旧版 .doc(OLE2 格式)
|
||||
# 旧版 .doc(OLE2/CFB 格式),按 Word Binary Format 规范解析 piece table
|
||||
try:
|
||||
import olefile
|
||||
ole = olefile.OleFileIO(io.BytesIO(file_content))
|
||||
if not ole.exists('WordDocument'):
|
||||
return "[doc 提取失败: 未找到 WordDocument 流]"
|
||||
# 读取 WordDocument 流,提取可见 ASCII/Unicode 文本
|
||||
stream = ole.openstream('WordDocument').read()
|
||||
# Word Binary Format: 文本在流中以 UTF-16-LE 编码存储
|
||||
# 简单提取:过滤出可打印字符段
|
||||
try:
|
||||
text = stream.decode('utf-16-le', errors='ignore')
|
||||
except Exception:
|
||||
text = stream.decode('latin-1', errors='ignore')
|
||||
# 过滤控制字符,保留可打印内容
|
||||
import re
|
||||
text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text)
|
||||
text = re.sub(r' +', ' ', text).strip()
|
||||
word_stream = ole.openstream('WordDocument').read()
|
||||
|
||||
# FIB offset 0xA bit9 决定使用 0Table 还是 1Table
|
||||
fib_flags = struct.unpack_from('<H', word_stream, 0xA)[0]
|
||||
table_name = '1Table' if (fib_flags & 0x0200) else '0Table'
|
||||
table_stream = ole.openstream(table_name).read()
|
||||
|
||||
# 从 FIB 读取 fcClx/lcbClx 定位 piece table
|
||||
fc_clx, lcb_clx = struct.unpack_from("<II", word_stream, 0x1A2)
|
||||
clx = table_stream[fc_clx: fc_clx + lcb_clx]
|
||||
|
||||
# 解析 CLX,找到 PlcPcd(piece table)
|
||||
i, plc_pcd = 0, None
|
||||
while i < len(clx):
|
||||
clxt = clx[i]
|
||||
if clxt == 0x01:
|
||||
i += 3 + struct.unpack_from('<H', clx, i + 1)[0]
|
||||
elif clxt == 0x02:
|
||||
cb = struct.unpack_from('<I', clx, i + 1)[0]
|
||||
plc_pcd = clx[i + 5: i + 5 + cb]
|
||||
break
|
||||
else:
|
||||
break
|
||||
|
||||
if plc_pcd is None:
|
||||
raise ValueError("PlcPcd not found")
|
||||
|
||||
# PlcPcd: (n+1) 个 CP(4字节)+ n 个 PCD(8字节)
|
||||
n_pieces = (len(plc_pcd) - 4) // 12
|
||||
cp_array = [struct.unpack_from('<I', plc_pcd, k * 4)[0] for k in range(n_pieces + 1)]
|
||||
|
||||
parts = []
|
||||
for k in range(n_pieces):
|
||||
fc_value = struct.unpack_from('<I', plc_pcd, (n_pieces + 1) * 4 + k * 8 + 2)[0]
|
||||
is_ansi = bool(fc_value & 0x40000000)
|
||||
fc = fc_value & 0x3FFFFFFF
|
||||
char_count = cp_array[k + 1] - cp_array[k]
|
||||
|
||||
if is_ansi:
|
||||
parts.append(word_stream[fc: fc + char_count].decode('cp1252', errors='replace'))
|
||||
else:
|
||||
parts.append(word_stream[fc: fc + char_count * 2].decode('utf-16-le', errors='replace'))
|
||||
|
||||
ole.close()
|
||||
return text
|
||||
result = re.sub(r'[\x00-\x1f\x7f]', '', ''.join(parts))
|
||||
return result.strip()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"提取 doc 文本失败: {e}")
|
||||
return f"[doc 提取失败: {str(e)}]"
|
||||
|
||||
@@ -1,26 +1,24 @@
|
||||
"""基于分享链接的聊天服务"""
|
||||
import uuid
|
||||
import time
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional, Dict, Any, AsyncGenerator
|
||||
|
||||
from deprecated import deprecated
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.repositories.model_repository import ModelApiKeyRepository
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.models import MultiAgentConfig
|
||||
from app.models import ReleaseShare, AppRelease, Conversation
|
||||
from app.repositories import knowledge_repository
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.draft_run_service import create_web_search_tool
|
||||
from app.services.model_service import ModelApiKeyService
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.services.multi_agent_service import MultiAgentService
|
||||
from app.models import MultiAgentConfig
|
||||
from app.repositories import knowledge_repository
|
||||
import json
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
@@ -118,6 +116,7 @@ class SharedChatService:
|
||||
|
||||
return conversation
|
||||
|
||||
@deprecated("Use the chat method under app_chat_service instead.")
|
||||
async def chat(
|
||||
self,
|
||||
share_token: str,
|
||||
@@ -136,10 +135,7 @@ class SharedChatService:
|
||||
config_id = actual_config_id
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||
from app.services.model_parameter_merger import ModelParameterMerger
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from sqlalchemy import select
|
||||
from app.models import ModelApiKey
|
||||
|
||||
start_time = time.time()
|
||||
actual_config_id = None
|
||||
@@ -273,11 +269,6 @@ class SharedChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
)
|
||||
|
||||
# 保存消息
|
||||
@@ -324,6 +315,7 @@ class SharedChatService:
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
@deprecated("Use the chat method under app_chat_service instead.")
|
||||
async def chat_stream(
|
||||
self,
|
||||
share_token: str,
|
||||
@@ -341,8 +333,6 @@ class SharedChatService:
|
||||
from app.core.agent.langchain_agent import LangChainAgent
|
||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||
from sqlalchemy import select
|
||||
from app.models import ModelApiKey
|
||||
import json
|
||||
|
||||
start_time = time.time()
|
||||
@@ -486,11 +476,6 @@ class SharedChatService:
|
||||
message=message,
|
||||
history=history,
|
||||
context=None,
|
||||
end_user_id=user_id,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
config_id=config_id,
|
||||
memory_flag=memory_flag
|
||||
):
|
||||
if isinstance(chunk, int):
|
||||
total_tokens = chunk
|
||||
@@ -585,6 +570,7 @@ class SharedChatService:
|
||||
|
||||
return conversations, total
|
||||
|
||||
@deprecated("Use the chat method under app_chat_service instead.")
|
||||
async def multi_agent_chat(
|
||||
self,
|
||||
share_token: str,
|
||||
@@ -680,6 +666,7 @@ class SharedChatService:
|
||||
"elapsed_time": elapsed_time
|
||||
}
|
||||
|
||||
@deprecated("Use the chat method under app_chat_service instead.")
|
||||
async def multi_agent_chat_stream(
|
||||
self,
|
||||
share_token: str,
|
||||
|
||||
@@ -138,7 +138,7 @@ class TenantService:
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"删除租户失败: {str(e)}")
|
||||
raise BusinessException(f"删除租户失败: {str(e)}", code=BizCode.DB_ERROR)
|
||||
raise BusinessException(f"删除租户失败:{str(e)}", code=BizCode.DB_ERROR)
|
||||
|
||||
# 租户用户管理
|
||||
def get_tenant_users(
|
||||
@@ -147,6 +147,7 @@ class TenantService:
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
is_active: Optional[bool] = None,
|
||||
is_superuser: Optional[bool] = None,
|
||||
search: Optional[str] = None
|
||||
) -> List[UserModel]:
|
||||
"""获取租户下的用户列表"""
|
||||
@@ -155,6 +156,7 @@ class TenantService:
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
is_active=is_active,
|
||||
is_superuser=is_superuser,
|
||||
search=search
|
||||
)
|
||||
|
||||
@@ -162,12 +164,14 @@ class TenantService:
|
||||
self,
|
||||
tenant_id: uuid.UUID,
|
||||
is_active: Optional[bool] = None,
|
||||
is_superuser: Optional[bool] = None,
|
||||
search: Optional[str] = None
|
||||
) -> int:
|
||||
"""统计租户下的用户数量"""
|
||||
return self.user_repo.count_users_by_tenant(
|
||||
tenant_id=tenant_id,
|
||||
is_active=is_active,
|
||||
is_superuser=is_superuser,
|
||||
search=search
|
||||
)
|
||||
|
||||
|
||||
@@ -472,6 +472,21 @@ class UserMemoryService:
|
||||
# 定义允许更新的字段白名单
|
||||
allowed_fields = {'other_name', 'aliases', 'meta_data'}
|
||||
|
||||
# 用户占位名称黑名单,不允许作为 other_name 或出现在 aliases 中
|
||||
_user_placeholder_names = {'用户', '我', 'User', 'I'}
|
||||
|
||||
# 过滤 other_name:不允许设置为占位名称
|
||||
if 'other_name' in update_data and update_data['other_name'] and update_data['other_name'].strip() in _user_placeholder_names:
|
||||
logger.warning(f"拒绝将占位名称 '{update_data['other_name']}' 设置为 other_name")
|
||||
del update_data['other_name']
|
||||
|
||||
# 过滤 aliases:移除占位名称和非字符串值
|
||||
if 'aliases' in update_data and update_data['aliases']:
|
||||
update_data['aliases'] = [
|
||||
a for a in update_data['aliases']
|
||||
if isinstance(a, str) and a.strip() and a.strip() not in _user_placeholder_names
|
||||
]
|
||||
|
||||
# 检查是否更新了 aliases 字段
|
||||
aliases_updated = 'aliases' in update_data and update_data['aliases'] != end_user_info_record.aliases
|
||||
|
||||
|
||||
@@ -561,6 +561,24 @@ class WorkflowService:
|
||||
storage_type = 'neo4j'
|
||||
return storage_type, user_rag_memory_id
|
||||
|
||||
def _get_history_info(self, conversation_id: uuid.UUID) -> tuple[dict, list] | None:
|
||||
executions = self.execution_repo.get_by_conversation_id(
|
||||
conversation_id=conversation_id,
|
||||
status="completed",
|
||||
limit_count=1
|
||||
)
|
||||
|
||||
if executions:
|
||||
last_state = executions[0].output_data
|
||||
if isinstance(last_state, dict):
|
||||
variables = last_state.get("variables", {})
|
||||
conv_vars = variables.get("conv", {})
|
||||
# input_data["conv"] = conv_vars
|
||||
# input_data["conv_messages"] = last_state.get("messages") or []
|
||||
conv_messages = last_state.get("messages") or []
|
||||
return conv_vars, conv_messages
|
||||
return None
|
||||
|
||||
# ==================== 工作流执行 ====================
|
||||
|
||||
async def run(
|
||||
@@ -634,18 +652,11 @@ class WorkflowService:
|
||||
# 更新状态为运行中
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
|
||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
||||
|
||||
for exec_res in executions:
|
||||
if exec_res.status == "completed":
|
||||
last_state = exec_res.output_data
|
||||
if isinstance(last_state, dict):
|
||||
variables = last_state.get("variables", {})
|
||||
conv_vars = variables.get("conv", {})
|
||||
input_data["conv"] = conv_vars
|
||||
input_data["conv_messages"] = last_state.get("messages") or []
|
||||
break
|
||||
|
||||
history = self._get_history_info(conversation_id_uuid)
|
||||
if history:
|
||||
conv_vars, conv_messages = history
|
||||
input_data["conv"] = conv_vars
|
||||
input_data["conv_messages"] = conv_messages
|
||||
init_message_length = len(input_data.get("conv_messages", []))
|
||||
|
||||
result = await execute_workflow(
|
||||
@@ -807,17 +818,11 @@ class WorkflowService:
|
||||
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
|
||||
input_data["files"] = files
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
||||
|
||||
for exec_res in executions:
|
||||
if exec_res.status == "completed":
|
||||
last_state = exec_res.output_data
|
||||
if isinstance(last_state, dict):
|
||||
variables = last_state.get("variables", {})
|
||||
conv_vars = variables.get("conv", {})
|
||||
input_data["conv"] = conv_vars
|
||||
input_data["conv_messages"] = last_state.get("messages") or []
|
||||
break
|
||||
history = self._get_history_info(conversation_id_uuid)
|
||||
if history:
|
||||
conv_vars, conv_messages = history
|
||||
input_data["conv"] = conv_vars
|
||||
input_data["conv_messages"] = conv_messages
|
||||
init_message_length = len(input_data.get("conv_messages", []))
|
||||
message_id = uuid.uuid4()
|
||||
async for event in execute_workflow_stream(
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
@@ -38,12 +37,10 @@ from app.db import get_db, get_db_context
|
||||
from app.models import Document, File, Knowledge
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.schemas import document_schema, file_schema
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config
|
||||
from app.services.memory_forget_service import MemoryForgetService
|
||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
from app.utils.redis_lock import RedisLock
|
||||
from app.utils.redis_lock import RedisFairLock
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -104,7 +101,12 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]:
|
||||
|
||||
|
||||
def set_asyncio_event_loop():
|
||||
"""Set the asyncio event loop for the current thread."""
|
||||
"""Ensure an open asyncio event loop exists for the current thread.
|
||||
|
||||
Reuses the existing event loop if one is available and still open.
|
||||
Creates and installs a new event loop only when the current one is
|
||||
closed or missing (e.g. after ``_shutdown_loop_gracefully``).
|
||||
"""
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
@@ -116,6 +118,30 @@ def set_asyncio_event_loop():
|
||||
return loop
|
||||
|
||||
|
||||
def _shutdown_loop_gracefully(loop: asyncio.AbstractEventLoop):
|
||||
"""Gracefully shutdown pending async generators and tasks on the event loop.
|
||||
|
||||
This prevents 'RuntimeError: Event loop is closed' from httpx.AsyncClient.__del__
|
||||
by giving pending aclose() coroutines a chance to run before the loop is discarded.
|
||||
|
||||
Note: This only tears down the given loop. Callers that need a fresh event
|
||||
loop afterwards should use ``set_asyncio_event_loop()`` explicitly.
|
||||
"""
|
||||
try:
|
||||
# Cancel and collect all remaining tasks
|
||||
all_tasks = asyncio.all_tasks(loop)
|
||||
if all_tasks:
|
||||
for task in all_tasks:
|
||||
task.cancel()
|
||||
loop.run_until_complete(asyncio.gather(*all_tasks, return_exceptions=True))
|
||||
# Shutdown async generators (triggers __aclose__ on httpx clients etc.)
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
@celery_app.task(name="tasks.process_item")
|
||||
def process_item(item: dict):
|
||||
"""
|
||||
@@ -1148,8 +1174,28 @@ def write_message_task(
|
||||
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
||||
return result
|
||||
|
||||
redis_client = get_sync_redis_client()
|
||||
lock = None
|
||||
if redis_client is not None:
|
||||
lock = RedisFairLock(
|
||||
key=f"memory_write:{end_user_id}",
|
||||
redis_client=redis_client,
|
||||
expire=600,
|
||||
timeout=3600,
|
||||
auto_renewal=True,
|
||||
)
|
||||
if not lock.acquire():
|
||||
logger.warning(f"[CELERY WRITE] 获取锁超时,跳过本次写入: end_user_id={end_user_id}")
|
||||
return {
|
||||
"status": "SKIPPED",
|
||||
"error": "acquire lock timeout",
|
||||
"end_user_id": end_user_id,
|
||||
"config_id": str(config_id),
|
||||
"elapsed_time": time.time() - start_time,
|
||||
"task_id": self.request.id,
|
||||
}
|
||||
|
||||
try:
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
loop = set_asyncio_event_loop()
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
@@ -1158,7 +1204,6 @@ def write_message_task(
|
||||
logger.info(f"[CELERY WRITE] Task completed successfully "
|
||||
f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
|
||||
|
||||
# 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用
|
||||
try:
|
||||
_r = get_sync_redis_client()
|
||||
if _r is not None:
|
||||
@@ -1199,6 +1244,15 @@ def write_message_task(
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
finally:
|
||||
if lock is not None:
|
||||
try:
|
||||
lock.release()
|
||||
except Exception as e:
|
||||
logger.warning(f"[CELERY WRITE] 释放锁失败: {e}")
|
||||
# Gracefully shutdown the event loop to prevent
|
||||
# 'RuntimeError: Event loop is closed' from httpx.AsyncClient.__del__
|
||||
_shutdown_loop_gracefully(loop)
|
||||
|
||||
|
||||
# unused task
|
||||
@@ -2879,3 +2933,6 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace
|
||||
"elapsed_time": time.time() - start_time,
|
||||
"task_id": self.request.id,
|
||||
}
|
||||
|
||||
|
||||
# unused task
|
||||
37
api/app/utils/performance_timer.py
Normal file
37
api/app/utils/performance_timer.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
性能监控工具模块
|
||||
|
||||
提供代码块执行时间统计功能,用于接口性能分析。
|
||||
如需再次启用性能监控,只需在 controller 中导入 from app.utils.performance_timer import timer 并添加 with timer(...) 包裹需要监控的代码块即可
|
||||
"""
|
||||
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from app.core.logging_config import get_api_logger
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def timer(label: str, user_count: int = 0):
|
||||
"""上下文管理器:用于测量代码块执行时间
|
||||
|
||||
Args:
|
||||
label: 统计标签,用于标识被测量的代码块
|
||||
user_count: 用户数,可选参数,用于记录处理的用户数量
|
||||
|
||||
Usage:
|
||||
with timer("获取用户列表"):
|
||||
users = get_users()
|
||||
|
||||
with timer("批量处理", user_count=len(user_ids)):
|
||||
process_users(user_ids)
|
||||
"""
|
||||
start = time.perf_counter()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
elapsed = (time.perf_counter() - start) * 1000 # 转换为毫秒
|
||||
extra_info = f", 用户数: {user_count}" if user_count > 0 else ""
|
||||
api_logger.info(f"[性能统计] {label}: {elapsed:.2f}ms{extra_info}")
|
||||
@@ -1,6 +1,7 @@
|
||||
import redis
|
||||
import uuid
|
||||
import time
|
||||
import threading
|
||||
|
||||
UNLOCK_SCRIPT = """
|
||||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||||
@@ -10,45 +11,136 @@ else
|
||||
end
|
||||
"""
|
||||
|
||||
RENEW_SCRIPT = """
|
||||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||||
return redis.call("expire", KEYS[1], ARGV[2])
|
||||
else
|
||||
return 0
|
||||
end
|
||||
"""
|
||||
|
||||
class RedisLock:
|
||||
CLEANUP_DEAD_HEAD_SCRIPT = """
|
||||
local queue_key = KEYS[1]
|
||||
local lock_key = KEYS[2]
|
||||
|
||||
local first = redis.call("lindex", queue_key, 0)
|
||||
if not first then
|
||||
return 0
|
||||
end
|
||||
|
||||
if redis.call("exists", lock_key) == 1 then
|
||||
return 0
|
||||
end
|
||||
|
||||
redis.call("lpop", queue_key)
|
||||
return 1
|
||||
"""
|
||||
|
||||
SAFE_RELEASE_QUEUE_SCRIPT = """
|
||||
local queue_key = KEYS[1]
|
||||
local value = ARGV[1]
|
||||
|
||||
local first = redis.call("lindex", queue_key, 0)
|
||||
if first == value then
|
||||
redis.call("lpop", queue_key)
|
||||
return 1
|
||||
end
|
||||
return 0
|
||||
"""
|
||||
|
||||
|
||||
def _ensure_str(val):
|
||||
"""统一将 Redis 返回值转为 str,兼容 decode_responses=True/False"""
|
||||
if val is None:
|
||||
return None
|
||||
if isinstance(val, bytes):
|
||||
return val.decode("utf-8")
|
||||
return str(val)
|
||||
|
||||
|
||||
class RedisFairLock:
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
redis_client: redis.StrictRedis,
|
||||
expire: int = 60,
|
||||
retry_interval: float = 0.1,
|
||||
timeout: float = 30
|
||||
|
||||
expire: int = 30,
|
||||
retry_interval: float = 0.05,
|
||||
timeout: float = 600,
|
||||
auto_renewal: bool = True
|
||||
):
|
||||
self.key = key
|
||||
self.expire = expire
|
||||
self.queue_key = f"{key}:queue"
|
||||
self.value = str(uuid.uuid4())
|
||||
self._locked = False
|
||||
self.expire = expire
|
||||
self.retry_interval = retry_interval
|
||||
self.timeout = timeout
|
||||
self.redis_client = redis_client
|
||||
self.redis = redis_client
|
||||
self._locked = False
|
||||
self.auto_renewal = auto_renewal
|
||||
self._renew_thread = None
|
||||
self._stop_renew = threading.Event()
|
||||
|
||||
def acquire(self) -> bool:
|
||||
def acquire(self):
|
||||
start = time.time()
|
||||
|
||||
self.redis.rpush(self.queue_key, self.value)
|
||||
|
||||
while True:
|
||||
ok = self.redis_client.set(self.key, self.value, ex=self.expire, nx=True)
|
||||
if ok:
|
||||
self._locked = True
|
||||
return True
|
||||
if time.time() - start >= self.timeout:
|
||||
first = _ensure_str(self.redis.lindex(self.queue_key, 0))
|
||||
|
||||
if first == self.value:
|
||||
ok = self.redis.set(self.key, self.value, nx=True, ex=self.expire)
|
||||
if ok:
|
||||
self._locked = True
|
||||
|
||||
if self.auto_renewal:
|
||||
self._start_renewal()
|
||||
return True
|
||||
|
||||
if first:
|
||||
self.redis.eval(CLEANUP_DEAD_HEAD_SCRIPT, 2, self.queue_key, self.key)
|
||||
|
||||
if time.time() - start > self.timeout:
|
||||
self.redis.lrem(self.queue_key, 0, self.value)
|
||||
return False
|
||||
|
||||
time.sleep(self.retry_interval)
|
||||
|
||||
def _renewal_loop(self):
|
||||
while not self._stop_renew.is_set():
|
||||
time.sleep(self.expire / 3)
|
||||
if self._stop_renew.is_set():
|
||||
break
|
||||
|
||||
self.redis.eval(
|
||||
RENEW_SCRIPT,
|
||||
1,
|
||||
self.key,
|
||||
self.value,
|
||||
str(self.expire)
|
||||
)
|
||||
|
||||
def _start_renewal(self):
|
||||
self._stop_renew = threading.Event()
|
||||
self._renew_thread = threading.Thread(target=self._renewal_loop, daemon=True)
|
||||
self._renew_thread.start()
|
||||
|
||||
def _stop_renewal(self):
|
||||
self._stop_renew.set()
|
||||
if self._renew_thread:
|
||||
self._renew_thread.join(timeout=1)
|
||||
|
||||
def release(self):
|
||||
if not self._locked:
|
||||
return
|
||||
self.redis_client.eval(
|
||||
UNLOCK_SCRIPT,
|
||||
1,
|
||||
self.key,
|
||||
self.value
|
||||
)
|
||||
|
||||
if self.auto_renewal:
|
||||
self._stop_renewal()
|
||||
|
||||
self.redis.eval(UNLOCK_SCRIPT, 1, self.key, self.value)
|
||||
|
||||
self.redis.eval(SAFE_RELEASE_QUEUE_SCRIPT, 1, self.queue_key, self.value)
|
||||
|
||||
self._locked = False
|
||||
|
||||
def __enter__(self):
|
||||
@@ -59,3 +151,4 @@ class RedisLock:
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.release()
|
||||
|
||||
|
||||
30
api/migrations/versions/4e89970f9e7c_202603271515.py
Normal file
30
api/migrations/versions/4e89970f9e7c_202603271515.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""202603271515
|
||||
|
||||
Revision ID: 4e89970f9e7c
|
||||
Revises: 6b8a461148ff
|
||||
Create Date: 2026-03-27 15:12:27.518344
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '4e89970f9e7c'
|
||||
down_revision: Union[str, None] = '6b8a461148ff'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('users', sa.Column('phone', sa.String(length=50), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('users', 'phone')
|
||||
# ### end Alembic commands ###
|
||||
Reference in New Issue
Block a user