Merge branch 'refs/heads/develop' into feature/agent-tool_xjn

# Conflicts:
#	api/app/core/agent/langchain_agent.py
#	api/app/core/tools/mcp/client.py
This commit is contained in:
Timebomb2018
2026-04-01 15:27:34 +08:00
219 changed files with 4861 additions and 2599 deletions

View File

@@ -1,6 +1,8 @@
import asyncio import asyncio
import json import json
import logging import logging
import os
import threading
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
import redis.asyncio as redis import redis.asyncio as redis
@@ -21,6 +23,50 @@ pool = ConnectionPool.from_url(
) )
aio_redis = redis.StrictRedis(connection_pool=pool) aio_redis = redis.StrictRedis(connection_pool=pool)
_REDIS_URL = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}"
# Thread-local storage for connection pools.
# Each thread (and each forked process) gets its own pool to avoid
# "Future attached to a different loop" errors in Celery --pool=threads
# and stale connections after fork in --pool=prefork.
_thread_local = threading.local()
def get_thread_safe_redis() -> redis.StrictRedis:
"""Return a Redis client whose connection pool is bound to the current
thread, process **and** event loop.
The pool is recreated when:
- The PID changes (fork, Celery --pool=prefork)
- The thread has no pool yet (Celery --pool=threads)
- The previously-cached event loop has been closed (Celery tasks call
``_shutdown_loop_gracefully`` which closes the loop after each run)
"""
current_pid = os.getpid()
cached_loop = getattr(_thread_local, "loop", None)
loop_stale = cached_loop is not None and cached_loop.is_closed()
if not hasattr(_thread_local, "pool") \
or getattr(_thread_local, "pid", None) != current_pid \
or loop_stale:
_thread_local.pid = current_pid
# Python 3.10+: get_event_loop() raises RuntimeError in threads
# where no loop has been set yet (e.g. Celery --pool=threads).
try:
_thread_local.loop = asyncio.get_event_loop()
except RuntimeError:
_thread_local.loop = None
_thread_local.pool = ConnectionPool.from_url(
_REDIS_URL,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD,
decode_responses=True,
max_connections=5,
health_check_interval=30,
)
return redis.StrictRedis(connection_pool=_thread_local.pool)
async def get_redis_connection(): async def get_redis_connection():
"""获取Redis连接""" """获取Redis连接"""
@@ -44,10 +90,8 @@ async def aio_redis_set(key: str, val: str | dict, expire: int = None):
val = json.dumps(val, ensure_ascii=False) val = json.dumps(val, ensure_ascii=False)
if expire is not None: if expire is not None:
# 设置带过期时间的键值
await aio_redis.set(key, val, ex=expire) await aio_redis.set(key, val, ex=expire)
else: else:
# 设置永久键值
await aio_redis.set(key, val) await aio_redis.set(key, val)
except Exception as e: except Exception as e:
logger.error(f"Redis set错误: {str(e)}") logger.error(f"Redis set错误: {str(e)}")

View File

@@ -10,7 +10,7 @@ import logging
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from datetime import datetime from datetime import datetime
from app.aioRedis import aio_redis from app.aioRedis import get_thread_safe_redis
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -68,7 +68,7 @@ class ActivityStatsCache:
"cached": True, "cached": True,
} }
value = json.dumps(payload, ensure_ascii=False) value = json.dumps(payload, ensure_ascii=False)
await aio_redis.set(key, value, ex=expire) await get_thread_safe_redis().set(key, value, ex=expire)
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}") logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}")
return True return True
except Exception as e: except Exception as e:
@@ -90,7 +90,7 @@ class ActivityStatsCache:
""" """
try: try:
key = cls._get_key(workspace_id) key = cls._get_key(workspace_id)
value = await aio_redis.get(key) value = await get_thread_safe_redis().get(key)
if value: if value:
payload = json.loads(value) payload = json.loads(value)
logger.info(f"命中活动统计缓存: {key}") logger.info(f"命中活动统计缓存: {key}")
@@ -116,7 +116,7 @@ class ActivityStatsCache:
""" """
try: try:
key = cls._get_key(workspace_id) key = cls._get_key(workspace_id)
result = await aio_redis.delete(key) result = await get_thread_safe_redis().delete(key)
logger.info(f"删除活动统计缓存: {key}, 结果: {result}") logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
return result > 0 return result > 0
except Exception as e: except Exception as e:

View File

@@ -9,7 +9,7 @@ import logging
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any
from datetime import datetime from datetime import datetime
from app.aioRedis import aio_redis from app.aioRedis import get_thread_safe_redis
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -62,7 +62,7 @@ class InterestMemoryCache:
"cached": True, "cached": True,
} }
value = json.dumps(payload, ensure_ascii=False) value = json.dumps(payload, ensure_ascii=False)
await aio_redis.set(key, value, ex=expire) await get_thread_safe_redis().set(key, value, ex=expire)
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}") logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}")
return True return True
except Exception as e: except Exception as e:
@@ -86,7 +86,7 @@ class InterestMemoryCache:
""" """
try: try:
key = cls._get_key(end_user_id, language) key = cls._get_key(end_user_id, language)
value = await aio_redis.get(key) value = await get_thread_safe_redis().get(key)
if value: if value:
payload = json.loads(value) payload = json.loads(value)
logger.info(f"命中兴趣分布缓存: {key}") logger.info(f"命中兴趣分布缓存: {key}")
@@ -114,7 +114,7 @@ class InterestMemoryCache:
""" """
try: try:
key = cls._get_key(end_user_id, language) key = cls._get_key(end_user_id, language)
result = await aio_redis.delete(key) result = await get_thread_safe_redis().delete(key)
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}") logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
return result > 0 return result > 0
except Exception as e: except Exception as e:

View File

@@ -57,7 +57,6 @@ def list_apps(
page: int = 1, page: int = 1,
pagesize: int = 10, pagesize: int = 10,
ids: Optional[str] = None, ids: Optional[str] = None,
api_key: Optional[str] = None,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
@@ -66,7 +65,7 @@ def list_apps(
- 默认包含本工作空间的应用和分享给本工作空间的应用 - 默认包含本工作空间的应用和分享给本工作空间的应用
- 设置 include_shared=false 可以只查看本工作空间的应用 - 设置 include_shared=false 可以只查看本工作空间的应用
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页 - 当提供 ids 参数时,按逗号分割获取指定应用,不分页
- 当提供 api_key 参数时,查找该 API Key 关联的应用 - search 参数支持:应用名称模糊搜索、API Key 精确搜索
""" """
from sqlalchemy import select as sa_select from sqlalchemy import select as sa_select
from app.models.api_key_model import ApiKey from app.models.api_key_model import ApiKey
@@ -74,23 +73,34 @@ def list_apps(
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
service = app_service.AppService(db) service = app_service.AppService(db)
# 通过 API Key 搜索:精确匹配,将 resource_id 注入 ids 走统一分页流程 # 通过 search 参数搜索:支持应用名称模糊搜索和 API Key 精确搜索
if api_key: if search:
matched_id = db.execute( search = search.strip()
sa_select(ApiKey.resource_id).where( # 尝试作为 API Key 精确匹配API Key 通常较长)
ApiKey.workspace_id == workspace_id, if len(search) >= 10:
ApiKey.api_key == api_key, matched_id = db.execute(
ApiKey.resource_id.isnot(None), sa_select(ApiKey.resource_id).where(
) ApiKey.workspace_id == workspace_id,
).scalar_one_or_none() ApiKey.api_key == search,
ids = str(matched_id) if matched_id else "" ApiKey.resource_id.isnot(None),
)
).scalar_one_or_none()
if matched_id:
# 找到 API Key直接返回关联的应用
ids = str(matched_id)
# 当 ids 存在且不为 None 时,根据 ids 获取应用 # 当 ids 存在时,根据 ids 获取应用(不分页)
if ids is not None: if ids is not None:
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()] app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id) if app_ids:
items = [service._convert_to_schema(app, workspace_id) for app in items_orm] items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
return success(data=items) items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
# 返回标准分页格式
meta = PageMeta(page=1, pagesize=len(items), total=len(items), hasnext=False)
return success(data=PageData(page=meta, items=items))
# ids 为空时,返回空列表
meta = PageMeta(page=1, pagesize=0, total=0, hasnext=False)
return success(data=PageData(page=meta, items=[]))
# 正常分页查询 # 正常分页查询
items_orm, total = app_service.list_apps( items_orm, total = app_service.list_apps(

View File

@@ -3,17 +3,16 @@ import uuid
from typing import Optional from typing import Optional
from fastapi import APIRouter, Depends, Query from fastapi import APIRouter, Depends, Query
from sqlalchemy import select, desc, func
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.response_utils import success from app.core.response_utils import success
from app.db import get_db from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models.conversation_model import Conversation, Message from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail, AppLogMessage
from app.schemas.response_schema import PageData, PageMeta from app.schemas.response_schema import PageData, PageMeta
from app.services.app_service import AppService from app.services.app_service import AppService
from app.services.app_log_service import AppLogService
router = APIRouter(prefix="/apps", tags=["App Logs"]) router = APIRouter(prefix="/apps", tags=["App Logs"])
logger = get_business_logger() logger = get_business_logger()
@@ -25,52 +24,35 @@ def list_app_logs(
app_id: uuid.UUID, app_id: uuid.UUID,
page: int = Query(1, ge=1), page: int = Query(1, ge=1),
pagesize: int = Query(20, ge=1, le=100), pagesize: int = Query(20, ge=1, le=100),
user_id: Optional[str] = None,
is_draft: Optional[bool] = None, is_draft: Optional[bool] = None,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user=Depends(get_current_user), current_user=Depends(get_current_user),
): ):
"""查看应用下所有会话记录(分页) """查看应用下所有会话记录(分页)
- 支持按 user_id 筛选
- 支持按 is_draft 筛选(草稿会话 / 发布会话) - 支持按 is_draft 筛选(草稿会话 / 发布会话)
- 按最新更新时间倒序排列 - 按最新更新时间倒序排列
- 所有人(包括共享者和被共享者)都只能查看自己的会话记录
""" """
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# 验证应用访问权限 # 验证应用访问权限
service = AppService(db) app_service = AppService(db)
service.get_app(app_id, workspace_id) app_service.get_app(app_id, workspace_id)
stmt = select(Conversation).where( # 使用 Service 层查询
Conversation.app_id == app_id, log_service = AppLogService(db)
Conversation.workspace_id == workspace_id, conversations, total = log_service.list_conversations(
Conversation.is_active.is_(True), app_id=app_id,
workspace_id=workspace_id,
page=page,
pagesize=pagesize,
is_draft=is_draft
) )
if user_id:
stmt = stmt.where(Conversation.user_id == user_id)
if is_draft is not None:
stmt = stmt.where(Conversation.is_draft == is_draft)
total = int(db.execute(
select(func.count()).select_from(stmt.subquery())
).scalar_one())
stmt = stmt.order_by(desc(Conversation.updated_at))
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
conversations = list(db.scalars(stmt).all())
items = [AppLogConversation.model_validate(c) for c in conversations] items = [AppLogConversation.model_validate(c) for c in conversations]
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total) meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
logger.info(
"查询应用日志会话列表",
extra={"app_id": str(app_id), "total": total, "page": page}
)
return success(data=PageData(page=meta, items=items)) return success(data=PageData(page=meta, items=items))
@@ -86,44 +68,22 @@ def get_app_log_detail(
- 返回会话基本信息 + 所有消息(按时间正序) - 返回会话基本信息 + 所有消息(按时间正序)
- 消息 meta_data 包含模型名、token 用量等信息 - 消息 meta_data 包含模型名、token 用量等信息
- 所有人(包括共享者和被共享者)都只能查看自己的会话详情
""" """
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# 验证应用访问权限 # 验证应用访问权限
service = AppService(db) app_service = AppService(db)
service.get_app(app_id, workspace_id) app_service.get_app(app_id, workspace_id)
# 查询会话(确保属于该应用和工作空间) # 使用 Service 层查询
conversation = db.scalars( log_service = AppLogService(db)
select(Conversation).where( conversation = log_service.get_conversation_detail(
Conversation.id == conversation_id, app_id=app_id,
Conversation.app_id == app_id, conversation_id=conversation_id,
Conversation.workspace_id == workspace_id, workspace_id=workspace_id
Conversation.is_active.is_(True),
)
).first()
if not conversation:
from app.core.exceptions import ResourceNotFoundException
raise ResourceNotFoundException("会话", str(conversation_id))
# 查询消息(按时间正序)
messages = list(db.scalars(
select(Message)
.where(Message.conversation_id == conversation_id)
.order_by(Message.created_at)
).all())
detail = AppLogConversationDetail.model_validate(conversation)
detail.messages = [AppLogMessage.model_validate(m) for m in messages]
logger.info(
"查询应用日志会话详情",
extra={
"app_id": str(app_id),
"conversation_id": str(conversation_id),
"message_count": len(messages)
}
) )
detail = AppLogConversationDetail.model_validate(conversation)
return success(data=detail) return success(data=detail)

View File

@@ -1,3 +1,5 @@
import asyncio
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query from fastapi import APIRouter, Depends, HTTPException, status, Query
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -47,64 +49,64 @@ def get_workspace_total_end_users(
@router.get("/end_users", response_model=ApiResponse) @router.get("/end_users", response_model=ApiResponse)
async def get_workspace_end_users( async def get_workspace_end_users(
workspace_id: Optional[uuid.UUID] = Query(None, description="工作空间ID可选默认当前用户工作空间"),
keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id"),
page: int = Query(1, ge=1, description="页码从1开始"),
pagesize: int = Query(10, ge=1, description="每页数量"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
""" """
获取工作空间的宿主列表(高性能优化版本 v2 获取工作空间的宿主列表(分页查询,支持模糊搜索
优化策略: 返回工作空间下的宿主列表,支持分页查询和模糊搜索。
1. 批量查询 end_users一次查询而非循环 通过 keyword 参数同时模糊匹配 other_name 和 id 字段。
2. 并发查询所有用户的记忆数量Neo4j
3. RAG 模式使用批量查询(一次 SQL Args:
4. 只返回必要字段减少数据传输 workspace_id: 工作空间ID可选默认当前用户工作空间
5. 添加短期缓存减少重复查询 keyword: 搜索关键词(可选,同时模糊匹配 other_name 和 id
6. 并发执行配置查询和记忆数量查询 page: 页码从1开始默认1
pagesize: 每页数量默认10
返回格式: db: 数据库会话
{ current_user: 当前用户
"end_user": {"id": "uuid", "other_name": "名称"},
"memory_num": {"total": 数量}, Returns:
"memory_config": {"memory_config_id": "id", "memory_config_name": "名称"} ApiResponse: 包含宿主列表和分页信息
}
""" """
import asyncio # 如果未提供 workspace_id使用当前用户的工作空间
import json if workspace_id is None:
from app.aioRedis import aio_redis_get, aio_redis_set workspace_id = current_user.current_workspace_id
workspace_id = current_user.current_workspace_id
# 尝试从缓存获取30秒缓存
cache_key = f"end_users:workspace:{workspace_id}"
try:
cached_data = await aio_redis_get(cache_key)
if cached_data:
api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
return success(data=json.loads(cached_data), msg="宿主列表获取成功")
except Exception as e:
api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
# 获取当前空间类型 # 获取当前空间类型
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user) current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表") api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}")
# 获取 end_users(已优化为批量查询) # 获取分页的 end_users
end_users = memory_dashboard_service.get_workspace_end_users( end_users_result = memory_dashboard_service.get_workspace_end_users_paginated(
db=db, db=db,
workspace_id=workspace_id, workspace_id=workspace_id,
current_user=current_user current_user=current_user,
page=page,
pagesize=pagesize,
keyword=keyword
) )
end_users = end_users_result.get("items", [])
total = end_users_result.get("total", 0)
if not end_users: if not end_users:
api_logger.info("工作空间下没有宿主") api_logger.info(f"工作空间下没有宿主或当前页无数据: total={total}, page={page}")
# 缓存空结果,避免重复查询 return success(data={
try: "items": [],
await aio_redis_set(cache_key, json.dumps([]), expire=30) "page": {
except Exception as e: "page": page,
api_logger.warning(f"Redis 缓存写入失败: {str(e)}") "pagesize": pagesize,
return success(data=[], msg="宿主列表获取成功") "total": total,
"hasnext": (page * pagesize) < total
}
}, msg="宿主列表获取成功")
end_user_ids = [str(user.id) for user in end_users] end_user_ids = [str(user.id) for user in end_users]
# 并发执行两个独立的查询任务 # 并发执行两个独立的查询任务
async def get_memory_configs(): async def get_memory_configs():
"""获取记忆配置(在线程池中执行同步查询)""" """获取记忆配置(在线程池中执行同步查询)"""
@@ -116,7 +118,7 @@ async def get_workspace_end_users(
except Exception as e: except Exception as e:
api_logger.error(f"批量获取记忆配置失败: {str(e)}") api_logger.error(f"批量获取记忆配置失败: {str(e)}")
return {} return {}
async def get_memory_nums(): async def get_memory_nums():
"""获取记忆数量""" """获取记忆数量"""
if current_workspace_type == "rag": if current_workspace_type == "rag":
@@ -130,26 +132,18 @@ async def get_workspace_end_users(
except Exception as e: except Exception as e:
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}") api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
return {uid: {"total": 0} for uid in end_user_ids} return {uid: {"total": 0} for uid in end_user_ids}
elif current_workspace_type == "neo4j": elif current_workspace_type == "neo4j":
# Neo4j 模式:并发查询(带并发限制 # Neo4j 模式:批量查询(简化版本只返回total
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j try:
MAX_CONCURRENT_QUERIES = 10 batch_result = await memory_storage_service.search_all_batch(end_user_ids)
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES) return {uid: {"total": count} for uid, count in batch_result.items()}
except Exception as e:
async def get_neo4j_memory_num(end_user_id: str): api_logger.error(f"批量获取 Neo4j 记忆数量失败: {str(e)}")
async with semaphore: return {uid: {"total": 0} for uid in end_user_ids}
try:
return await memory_storage_service.search_all(end_user_id)
except Exception as e:
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
return {"total": 0}
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
return {uid: {"total": 0} for uid in end_user_ids} return {uid: {"total": 0} for uid in end_user_ids}
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据 # 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
try: try:
from app.celery_app import celery_app as _celery_app from app.celery_app import celery_app as _celery_app
@@ -170,13 +164,13 @@ async def get_workspace_end_users(
get_memory_configs(), get_memory_configs(),
get_memory_nums() get_memory_nums()
) )
# 构建结果(优化:使用列表推导式) # 构建结果列表
result = [] items = []
for end_user in end_users: for end_user in end_users:
user_id = str(end_user.id) user_id = str(end_user.id)
config_info = memory_configs_map.get(user_id, {}) config_info = memory_configs_map.get(user_id, {})
result.append({ items.append({
'end_user': { 'end_user': {
'id': user_id, 'id': user_id,
'other_name': end_user.other_name 'other_name': end_user.other_name
@@ -187,12 +181,6 @@ async def get_workspace_end_users(
"memory_config_name": config_info.get("memory_config_name") "memory_config_name": config_info.get("memory_config_name")
} }
}) })
# 写入缓存30秒过期
try:
await aio_redis_set(cache_key, json.dumps(result), expire=30)
except Exception as e:
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
# 触发社区聚类补全任务(异步,不阻塞接口响应) # 触发社区聚类补全任务(异步,不阻塞接口响应)
try: try:
@@ -202,7 +190,18 @@ async def get_workspace_end_users(
except Exception as e: except Exception as e:
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}") api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录") # 构建分页响应
result = {
"items": items,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"hasnext": (page * pagesize) < total
}
}
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录,总计 {total}")
return success(data=result, msg="宿主列表获取成功") return success(data=result, msg="宿主列表获取成功")

View File

@@ -31,6 +31,7 @@ from app.schemas.memory_storage_schema import (
ForgettingCurveRequest, ForgettingCurveRequest,
ForgettingCurveResponse, ForgettingCurveResponse,
ForgettingCurvePoint, ForgettingCurvePoint,
PendingNodesResponse,
) )
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.services.memory_forget_service import MemoryForgetService from app.services.memory_forget_service import MemoryForgetService
@@ -308,6 +309,100 @@ async def get_forgetting_stats(
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
@router.get("/pending-nodes", response_model=ApiResponse)
async def get_pending_nodes(
end_user_id: str,
page: int = 1,
pagesize: int = 10,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
获取待遗忘节点列表(独立分页接口)
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
此接口独立分页,与 /stats 接口分离。
Args:
end_user_id: 组ID即 end_user_id必填
page: 页码从1开始默认1
pagesize: 每页数量默认10
current_user: 当前用户
db: 数据库会话
Returns:
ApiResponse: 包含待遗忘节点列表和分页信息的响应
Examples:
- 第1页每页10条GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=1&pagesize=10
- 第2页每页20条GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=2&pagesize=20
Notes:
- page 从1开始pagesize 必须大于0
- 返回格式:{"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}}
"""
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 验证 end_user_id 必填
if not end_user_id:
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未提供 end_user_id")
return fail(BizCode.INVALID_PARAMETER, "end_user_id 不能为空", "end_user_id is required")
# 通过 end_user_id 获取关联的 config_id
try:
from app.services.memory_agent_service import get_end_user_connected_config
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
config_id = resolve_config_id(config_id, db)
if config_id is None:
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
except ValueError as e:
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
except Exception as e:
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
# 验证分页参数
if page < 1:
return fail(BizCode.INVALID_PARAMETER, "page 必须大于等于1", "page < 1")
if pagesize < 1:
return fail(BizCode.INVALID_PARAMETER, "pagesize 必须大于等于1", "pagesize < 1")
api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取待遗忘节点: "
f"end_user_id={end_user_id}, page={page}, pagesize={pagesize}"
)
try:
# 调用服务层获取待遗忘节点列表
result = await forget_service.get_pending_nodes(
db=db,
end_user_id=end_user_id,
config_id=config_id,
page=page,
pagesize=pagesize
)
# 构建响应
response_data = PendingNodesResponse(**result)
return success(data=response_data.model_dump(), msg="查询成功")
except Exception as e:
api_logger.error(f"获取待遗忘节点列表失败: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取待遗忘节点列表失败", str(e))
@router.post("/forgetting_curve", response_model=ApiResponse) @router.post("/forgetting_curve", response_model=ApiResponse)
async def get_forgetting_curve( async def get_forgetting_curve(
request: ForgettingCurveRequest, request: ForgettingCurveRequest,

View File

@@ -27,6 +27,7 @@ from app.services.conversation_service import ConversationService
from app.services.release_share_service import ReleaseShareService from app.services.release_share_service import ReleaseShareService
from app.services.shared_chat_service import SharedChatService from app.services.shared_chat_service import SharedChatService
from app.services.workflow_service import WorkflowService from app.services.workflow_service import WorkflowService
from app.models.file_metadata_model import FileMetadata
from app.utils.app_config_utils import workflow_config_4_app_release, \ from app.utils.app_config_utils import workflow_config_4_app_release, \
agent_config_4_app_release, multi_agent_config_4_app_release agent_config_4_app_release, multi_agent_config_4_app_release
@@ -259,8 +260,41 @@ def get_conversation(
conv_service = ConversationService(db) conv_service = ConversationService(db)
messages = conv_service.get_messages(conversation_id) messages = conv_service.get_messages(conversation_id)
# 构建响应 file_ids = []
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump() message_file_id_map = {}
# 第一次遍历:解析 audio_url收集所有有效的 file_id
for idx, m in enumerate(messages):
if m.role == "assistant" and m.meta_data:
audio_url = m.meta_data.get("audio_url")
if not audio_url:
continue
try:
file_id = uuid.UUID(audio_url.rstrip("/").split("/")[-1])
except (ValueError, IndexError):
# audio_url 无法解析为 UUID标记为 unknown
m.meta_data["audio_status"] = "unknown"
continue
file_ids.append(file_id)
message_file_id_map[idx] = file_id
# 批量查询所有相关的 FileMetadata
file_status_map = {}
if file_ids:
file_metas = (
db.query(FileMetadata)
.filter(FileMetadata.id.in_(set(file_ids)))
.all()
)
file_status_map = {fm.id: fm.status for fm in file_metas}
# 第二次遍历:将查询结果映射回消息
for idx, file_id in message_file_id_map.items():
m = messages[idx]
m.meta_data["audio_status"] = file_status_map.get(file_id, "unknown")
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump(mode="json")
conv_dict["messages"] = [ conv_dict["messages"] = [
conversation_schema.Message.model_validate(m) for m in messages conversation_schema.Message.model_validate(m) for m in messages
] ]
@@ -320,6 +354,16 @@ async def chat(
other_id=other_id, other_id=other_id,
original_user_id=user_id original_user_id=user_id
) )
# Only extract and set memory_config_id when the end user doesn't have one yet
if not new_end_user.memory_config_id:
from app.services.memory_config_service import MemoryConfigService
memory_config_service = MemoryConfigService(db)
memory_config_id, _ = memory_config_service.extract_memory_config_id(release.type, release.config or {})
if memory_config_id:
new_end_user.memory_config_id = memory_config_id
db.commit()
db.refresh(new_end_user)
end_user_id = str(new_end_user.id) end_user_id = str(new_end_user.id)
# appid = share.app_id # appid = share.app_id
@@ -410,30 +454,6 @@ async def chat(
agent_config = agent_config_4_app_release(release) agent_config = agent_config_4_app_release(release)
if payload.stream: if payload.stream:
# async def event_generator():
# async for event in service.chat_stream(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# ):
# yield event
# return StreamingResponse(
# event_generator(),
# media_type="text/event-stream",
# headers={
# "Cache-Control": "no-cache",
# "Connection": "keep-alive",
# "X-Accel-Buffering": "no"
# }
# )
async def event_generator(): async def event_generator():
async for event in app_chat_service.agnet_chat_stream( async for event in app_chat_service.agnet_chat_stream(
message=payload.message, message=payload.message,
@@ -459,20 +479,6 @@ async def chat(
"X-Accel-Buffering": "no" "X-Accel-Buffering": "no"
} }
) )
# 非流式返回
# result = await service.chat(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# )
# return success(data=conversation_schema.ChatResponse(**result))
result = await app_chat_service.agnet_chat( result = await app_chat_service.agnet_chat(
message=payload.message, message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID conversation_id=conversation.id, # 使用已创建的会话 ID
@@ -531,48 +537,6 @@ async def chat(
) )
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
# 多 Agent 流式返回
# if payload.stream:
# async def event_generator():
# async for event in service.multi_agent_chat_stream(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# ):
# yield event
# return StreamingResponse(
# event_generator(),
# media_type="text/event-stream",
# headers={
# "Cache-Control": "no-cache",
# "Connection": "keep-alive",
# "X-Accel-Buffering": "no"
# }
# )
# # 多 Agent 非流式返回
# result = await service.multi_agent_chat(
# share_token=share_token,
# message=payload.message,
# conversation_id=conversation.id, # 使用已创建的会话 ID
# user_id=str(new_end_user.id), # 转换为字符串
# variables=payload.variables,
# password=password,
# web_search=payload.web_search,
# memory=payload.memory,
# storage_type=storage_type,
# user_rag_memory_id=user_rag_memory_id
# )
# return success(data=conversation_schema.ChatResponse(**result))
elif app_type == AppType.WORKFLOW: elif app_type == AppType.WORKFLOW:
config = workflow_config_4_app_release(release) config = workflow_config_4_app_release(release)
if not config.id: if not config.id:

View File

@@ -4,7 +4,7 @@
认证方式: API Key 认证方式: API Key
""" """
from fastapi import APIRouter from fastapi import APIRouter
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller, end_user_api_controller
# 创建 V1 API 路由器 # 创建 V1 API 路由器
service_router = APIRouter() service_router = APIRouter()
@@ -16,5 +16,6 @@ service_router.include_router(rag_api_document_controller.router)
service_router.include_router(rag_api_file_controller.router) service_router.include_router(rag_api_file_controller.router)
service_router.include_router(rag_api_chunk_controller.router) service_router.include_router(rag_api_chunk_controller.router)
service_router.include_router(memory_api_controller.router) service_router.include_router(memory_api_controller.router)
service_router.include_router(end_user_api_controller.router)
__all__ = ["service_router"] __all__ = ["service_router"]

View File

@@ -91,7 +91,7 @@ async def chat(
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id) app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
other_id = payload.user_id other_id = payload.user_id
workspace_id = app.workspace_id workspace_id = api_key_auth.workspace_id
end_user_repo = EndUserRepository(db) end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user( new_end_user = end_user_repo.get_or_create_end_user(
app_id=app.id, app_id=app.id,

View File

@@ -0,0 +1,92 @@
"""End User 服务接口 - 基于 API Key 认证"""
import uuid
from fastapi import APIRouter, Body, Depends, Request
from sqlalchemy.orm import Session
from app.core.api_key_auth import require_api_key
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.db import get_db
from app.repositories.end_user_repository import EndUserRepository
from app.schemas.api_key_schema import ApiKeyAuth
from app.schemas.memory_api_schema import CreateEndUserRequest, CreateEndUserResponse
from app.services.memory_config_service import MemoryConfigService
router = APIRouter(prefix="/end_user", tags=["V1 - End User API"])
logger = get_business_logger()
@router.post("/create")
@require_api_key(scopes=["memory"])
async def create_end_user(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
message: str = Body(..., description="Request body"),
):
"""
Create or retrieve an end user for the workspace.
Creates a new end user and connects it to a memory configuration.
If an end user with the same other_id already exists in the workspace,
returns the existing one.
Optionally accepts a memory_config_id to connect the end user to a specific
memory configuration. If not provided, falls back to the workspace default config.
"""
body = await request.json()
payload = CreateEndUserRequest(**body)
workspace_id = api_key_auth.workspace_id
logger.info(f"Create end user request - other_id: {payload.other_id}, workspace_id: {workspace_id}")
# Resolve memory_config_id: explicit > workspace default
memory_config_id = None
config_service = MemoryConfigService(db)
if payload.memory_config_id:
try:
memory_config_id = uuid.UUID(payload.memory_config_id)
except ValueError:
raise BusinessException(
f"Invalid memory_config_id format: {payload.memory_config_id}",
BizCode.INVALID_PARAMETER
)
config = config_service.get_config_with_fallback(memory_config_id, workspace_id)
if not config:
raise BusinessException(
f"Memory config not found: {payload.memory_config_id}",
BizCode.MEMORY_CONFIG_NOT_FOUND
)
memory_config_id = config.config_id
else:
default_config = config_service.get_workspace_default_config(workspace_id)
if default_config:
memory_config_id = default_config.config_id
logger.info(f"Using workspace default memory config: {memory_config_id}")
else:
logger.warning(f"No default memory config found for workspace: {workspace_id}")
end_user_repo = EndUserRepository(db)
end_user = end_user_repo.get_or_create_end_user_with_config(
app_id=api_key_auth.resource_id,
workspace_id=workspace_id,
other_id=payload.other_id,
memory_config_id=memory_config_id,
)
logger.info(f"End user ready: {end_user.id}")
result = {
"id": str(end_user.id),
"other_id": end_user.other_id or "",
"other_name": end_user.other_name or "",
"workspace_id": str(end_user.workspace_id),
"memory_config_id": str(end_user.memory_config_id) if end_user.memory_config_id else None,
}
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")

View File

@@ -111,6 +111,18 @@ def get_current_user_info(
break break
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}") api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
# 设置权限:如果用户来自 SSO Source则使用该 Source 的 permissions否则返回 "all" 表示拥有所有权限
if current_user.external_source:
from premium.sso.models import SSOSource
source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first()
if source and source.permissions:
result_schema.permissions = source.permissions
else:
result_schema.permissions = []
else:
result_schema.permissions = ["all"]
return success(data=result_schema, msg=t("users.info.get_success")) return success(data=result_schema, msg=t("users.info.get_success"))
@@ -135,7 +147,6 @@ def get_tenant_superusers(
return success(data=superusers_schema, msg=t("users.list.superusers_success")) return success(data=superusers_schema, msg=t("users.list.superusers_success"))
@router.get("/{user_id}", response_model=ApiResponse) @router.get("/{user_id}", response_model=ApiResponse)
def get_user_info_by_id( def get_user_info_by_id(
user_id: uuid.UUID, user_id: uuid.UUID,

View File

@@ -11,18 +11,14 @@ LangChain Agent 封装
import time import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType, ModelProvider
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
from app.core.logging_config import get_business_logger
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType
logger = get_business_logger() logger = get_business_logger()
@@ -226,10 +222,9 @@ class LangChainAgent:
Returns: Returns:
List[BaseMessage]: 消息列表 List[BaseMessage]: 消息列表
""" """
messages = [] messages:list = [SystemMessage(content=self.system_prompt)]
# 添加系统提示词 # 添加系统提示词
messages.append(SystemMessage(content=self.system_prompt))
# 添加历史消息 # 添加历史消息
if history: if history:
@@ -320,12 +315,7 @@ class LangChainAgent:
message: str, message: str,
history: Optional[List[Dict[str, str]]] = None, history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None, context: Optional[str] = None,
end_user_id: Optional[str] = None, files: Optional[List[Dict[str, Any]]] = None
config_id: Optional[str] = None, # 添加这个参数
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""执行对话 """执行对话
@@ -333,32 +323,12 @@ class LangChainAgent:
message: 用户消息 message: 用户消息
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}] history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
context: 上下文信息(如知识库检索结果) context: 上下文信息(如知识库检索结果)
files: 多模态文件
Returns: Returns:
Dict: 包含 content 和元数据的字典 Dict: 包含 content 和元数据的字典
""" """
message_chat = message
start_time = time.time() start_time = time.time()
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
finally:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
try: try:
# 准备消息列表(支持多模态) # 准备消息列表(支持多模态)
messages = self._prepare_messages(message, history, context, files) messages = self._prepare_messages(message, history, context, files)
@@ -445,9 +415,6 @@ class LangChainAgent:
logger.info(f"最终提取的内容长度: {len(content)}") logger.info(f"最终提取的内容长度: {len(content)}")
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
actual_config_id)
response = { response = {
"content": content, "content": content,
"model": self.model_name, "model": self.model_name,
@@ -478,12 +445,7 @@ class LangChainAgent:
message: str, message: str,
history: Optional[List[Dict[str, str]]] = None, history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None, context: Optional[str] = None,
end_user_id: Optional[str] = None, files: Optional[List[Dict[str, Any]]] = None
config_id: Optional[str] = None,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
) -> AsyncGenerator[str | int, None]: ) -> AsyncGenerator[str | int, None]:
"""执行流式对话 """执行流式对话
@@ -491,6 +453,7 @@ class LangChainAgent:
message: 用户消息 message: 用户消息
history: 历史消息列表 history: 历史消息列表
context: 上下文信息 context: 上下文信息
files: 多模态文件
Yields: Yields:
str: 消息内容块 str: 消息内容块
@@ -501,23 +464,6 @@ class LangChainAgent:
logger.info(f" Has tools: {bool(self.tools)}") logger.info(f" Has tools: {bool(self.tools)}")
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}") logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
logger.info("=" * 80) logger.info("=" * 80)
message_chat = message
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
if actual_config_id is None and end_user_id:
try:
db = next(get_db())
try:
connected_config = get_end_user_connected_config(end_user_id, db)
actual_config_id = connected_config.get("memory_config_id")
except Exception as e:
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
finally:
db.close()
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
try: try:
# 准备消息列表(支持多模态) # 准备消息列表(支持多模态)
messages = self._prepare_messages(message, history, context, files) messages = self._prepare_messages(message, history, context, files)
@@ -527,17 +473,18 @@ class LangChainAgent:
) )
chunk_count = 0 chunk_count = 0
yielded_content = False
# 统一使用 agent 的 astream_events 实现流式输出 # 统一使用 agent 的 astream_events 实现流式输出
logger.debug("使用 Agent astream_events 实现流式输出") logger.debug("使用 Agent astream_events 实现流式输出")
full_content = '' full_content = ''
try: try:
last_event = {}
async for event in self.agent.astream_events( async for event in self.agent.astream_events(
{"messages": messages}, {"messages": messages},
version="v2", version="v2",
config={"recursion_limit": self.max_iterations} config={"recursion_limit": self.max_iterations}
): ):
last_event = event
chunk_count += 1 chunk_count += 1
kind = event.get("event") kind = event.get("event")
@@ -551,7 +498,6 @@ class LangChainAgent:
if isinstance(chunk_content, str) and chunk_content: if isinstance(chunk_content, str) and chunk_content:
full_content += chunk_content full_content += chunk_content
yield chunk_content yield chunk_content
yielded_content = True
elif isinstance(chunk_content, list): elif isinstance(chunk_content, list):
# 多模态响应:提取文本部分 # 多模态响应:提取文本部分
for item in chunk_content: for item in chunk_content:
@@ -562,18 +508,15 @@ class LangChainAgent:
if text: if text:
full_content += text full_content += text
yield text yield text
yielded_content = True
# OpenAI 格式: {"type": "text", "text": "..."} # OpenAI 格式: {"type": "text", "text": "..."}
elif item.get("type") == "text": elif item.get("type") == "text":
text = item.get("text", "") text = item.get("text", "")
if text: if text:
full_content += text full_content += text
yield text yield text
yielded_content = True
elif isinstance(item, str): elif isinstance(item, str):
full_content += item full_content += item
yield item yield item
yielded_content = True
elif kind == "on_llm_stream": elif kind == "on_llm_stream":
# 另一种 LLM 流式事件 # 另一种 LLM 流式事件
@@ -584,7 +527,6 @@ class LangChainAgent:
if isinstance(chunk_content, str) and chunk_content: if isinstance(chunk_content, str) and chunk_content:
full_content += chunk_content full_content += chunk_content
yield chunk_content yield chunk_content
yielded_content = True
elif isinstance(chunk_content, list): elif isinstance(chunk_content, list):
# 多模态响应:提取文本部分 # 多模态响应:提取文本部分
for item in chunk_content: for item in chunk_content:
@@ -595,22 +537,18 @@ class LangChainAgent:
if text: if text:
full_content += text full_content += text
yield text yield text
yielded_content = True
# OpenAI 格式: {"type": "text", "text": "..."} # OpenAI 格式: {"type": "text", "text": "..."}
elif item.get("type") == "text": elif item.get("type") == "text":
text = item.get("text", "") text = item.get("text", "")
if text: if text:
full_content += text full_content += text
yield text yield text
yielded_content = True
elif isinstance(item, str): elif isinstance(item, str):
full_content += item full_content += item
yield item yield item
yielded_content = True
elif isinstance(chunk, str): elif isinstance(chunk, str):
full_content += chunk full_content += chunk
yield chunk yield chunk
yielded_content = True
# 记录工具调用(可选) # 记录工具调用(可选)
elif kind == "on_tool_start": elif kind == "on_tool_start":
@@ -620,17 +558,14 @@ class LangChainAgent:
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件") logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
# 统计token消耗 # 统计token消耗
# 统计 token 消耗:优先使用流式过程中捕获的值,回退到最后 event 的 messages output_messages = last_event.get("data", {}).get("output", {}).get("messages", [])
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
for msg in reversed(output_messages): for msg in reversed(output_messages):
if isinstance(msg, AIMessage): if isinstance(msg, AIMessage):
stream_total_tokens = self._extract_tokens_from_message(msg) stream_total_tokens = self._extract_tokens_from_message(msg)
logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}") logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}")
yield stream_total_tokens yield stream_total_tokens
break break
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
actual_config_id)
except Exception as e: except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
raise raise

View File

@@ -12,7 +12,6 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context from app.db import get_db_context
from app.repositories.memory_short_repository import LongTermMemoryRepository from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id from app.utils.config_utils import resolve_config_id
@@ -21,25 +20,6 @@ logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
"""
Write messages to RAG storage system
Combines user and AI messages into a single string format and stores them
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
Args:
end_user_id: User identifier for the conversation
user_message: User's input message content
ai_message: AI's response message content
user_rag_memory_id: RAG memory identifier for storage location
"""
# RAG mode: combine messages into string format (maintain original logic)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
async def write( async def write(
storage_type, storage_type,
end_user_id, end_user_id,
@@ -118,7 +98,7 @@ async def write(
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope): async def term_memory_save(end_user_id, strategy_type, scope):
""" """
Save long-term memory data to database Save long-term memory data to database
@@ -127,10 +107,8 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
to long-term memory storage. to long-term memory storage.
Args: Args:
long_term_messages: Long-term message data to be saved
actual_config_id: Configuration identifier for memory settings
end_user_id: User identifier for memory association end_user_id: User identifier for memory association
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE) strategy_type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
scope: Scope/window size for memory processing scope: Scope/window size for memory processing
""" """
with get_db_context() as db_session: with get_db_context() as db_session:
@@ -138,7 +116,10 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
from app.core.memory.agent.utils.redis_tool import write_store from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id) result = write_store.get_session_by_userid(end_user_id)
if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: if not result:
logger.warning(f"No write data found for user {end_user_id}")
return
if strategy_type in [AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE]:
data = await format_parsing(result, "dict") data = await format_parsing(result, "dict")
chunk_data = data[:scope] chunk_data = data[:scope]
if len(chunk_data) == scope: if len(chunk_data) == scope:
@@ -151,9 +132,6 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
logger.info(f'写入短长期:') logger.info(f'写入短长期:')
"""Window-based dialogue processing"""
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope): async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
""" """
Process dialogue based on window size and write to Neo4j Process dialogue based on window size and write to Neo4j
@@ -167,40 +145,33 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
langchain_messages: Original message data list langchain_messages: Original message data list
scope: Window size determining when to trigger long-term storage scope: Window size determining when to trigger long-term storage
""" """
scope = scope is_end_user_has_history = count_store.get_sessions_count(end_user_id)
is_end_user_id = count_store.get_sessions_count(end_user_id) if is_end_user_has_history:
if is_end_user_id is not False: end_user_visit_count, redis_messages = is_end_user_has_history
is_end_user_id = count_store.get_sessions_count(end_user_id)[0] else:
redis_messages = count_store.get_sessions_count(end_user_id)[1] count_store.save_sessions_count(end_user_id, 1, langchain_messages)
if is_end_user_id and int(is_end_user_id) != int(scope): return
is_end_user_id += 1 end_user_visit_count += 1
langchain_messages += redis_messages if end_user_visit_count < scope:
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) redis_messages.extend(langchain_messages)
elif int(is_end_user_id) == int(scope): count_store.update_sessions_count(end_user_id, end_user_visit_count, redis_messages)
else:
logger.info('写入长期记忆NEO4J') logger.info('写入长期记忆NEO4J')
formatted_messages = redis_messages redis_messages.extend(langchain_messages)
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly) # Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
if hasattr(memory_config, 'config_id'): if hasattr(memory_config, 'config_id'):
config_id = memory_config.config_id config_id = memory_config.config_id
else: else:
config_id = memory_config config_id = memory_config
await write( write_message_task.delay(
AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, # end_user_id: User ID
end_user_id, redis_messages, # message: JSON string format message list
"", config_id, # config_id: Configuration ID string
"", AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
None, "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
end_user_id,
config_id,
formatted_messages
) )
count_store.update_sessions_count(end_user_id, 1, langchain_messages) count_store.update_sessions_count(end_user_id, 0, [])
else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
"""Time-based memory processing"""
async def memory_long_term_storage(end_user_id, memory_config, time): async def memory_long_term_storage(end_user_id, memory_config, time):
@@ -291,9 +262,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
return result_dict return result_dict
except Exception as e: except Exception as e:
print(f"[aggregate_judgment] 发生错误: {e}") logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True)
import traceback
traceback.print_exc()
return { return {
"is_same_event": False, "is_same_event": False,

View File

@@ -1,49 +1,25 @@
import asyncio
import json
import sys
import warnings import warnings
from contextlib import asynccontextmanager
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from app.db import get_db, get_db_context
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store
from app.db import get_db_context
from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
from app.services.memory_konwledges_server import write_rag
warnings.filterwarnings("ignore", category=RuntimeWarning) warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
if sys.platform.startswith("win"):
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
async def long_term_storage(
@asynccontextmanager long_term_type: str,
async def make_write_graph(): langchain_messages: list,
""" memory_config_id: str,
Create a write graph workflow for memory operations. end_user_id: str,
scope: int = 6
Args: ):
user_id: User identifier
tools: MCP tools loaded from session
apply_id: Application identifier
end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration
"""
workflow = StateGraph(WriteState)
workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "save_neo4j")
workflow.add_edge("save_neo4j", END)
graph = workflow.compile()
yield graph
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
end_user_id: str = '', scope: int = 6):
""" """
Handle long-term memory storage with different strategies Handle long-term memory storage with different strategies
@@ -53,33 +29,39 @@ async def long_term_storage(long_term_type: str = "chunk", langchain_messages: l
Args: Args:
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate') long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
langchain_messages: List of messages to store langchain_messages: List of messages to store
memory_config: Memory configuration identifier memory_config_id: Memory configuration identifier
end_user_id: User group identifier end_user_id: User group identifier
scope: Scope parameter for chunk-based storage (default: 6) scope: Scope parameter for chunk-based storage (default: 6)
""" """
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \ if langchain_messages is None:
aggregate_judgment langchain_messages = []
from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, langchain_messages) write_store.save_session_write(end_user_id, langchain_messages)
# 获取数据库会话 # 获取数据库会话
with get_db_context() as db_session: with get_db_context() as db_session:
config_service = MemoryConfigService(db_session) config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config( memory_config = config_service.load_memory_config(
config_id=memory_config, # 改为整数 config_id=memory_config_id, # 改为整数
service_name="MemoryAgentService" service_name="MemoryAgentService"
) )
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK: if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
'''Strategy 1: Dialogue window with 6 rounds of conversation''' # Dialogue window with 6 rounds of conversation
await window_dialogue(end_user_id, langchain_messages, memory_config, scope) await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME: if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
"""Time-based strategy""" # Time-based strategy
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE) await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE: if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
"""Strategy 3: Aggregate judgment""" # Aggregate judgment
await aggregate_judgment(end_user_id, langchain_messages, memory_config) await aggregate_judgment(end_user_id, langchain_messages, memory_config)
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id): async def write_long_term(
storage_type: str,
end_user_id: str,
messages: list[dict],
user_rag_memory_id: str,
actual_config_id: str
):
""" """
Write long-term memory with different storage types Write long-term memory with different storage types
@@ -89,44 +71,24 @@ async def write_long_term(storage_type, end_user_id, message_chat, aimessages, u
Args: Args:
storage_type: Type of storage (RAG or traditional) storage_type: Type of storage (RAG or traditional)
end_user_id: User group identifier end_user_id: User group identifier
message_chat: User message content messages: message list
aimessages: AI response messages
user_rag_memory_id: RAG memory identifier user_rag_memory_id: RAG memory identifier
actual_config_id: Actual configuration ID actual_config_id: Actual configuration ID
""" """
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
if storage_type == AgentMemory_Long_Term.STORAGE_RAG: if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id) message_content = []
for message in messages:
message_content.append(f'{message.get("role")}:{message.get("content")}')
messages_string = "\n".join(message_content)
await write_rag(end_user_id, messages_string, user_rag_memory_id)
else: else:
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once) # AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
long_term_messages = await agent_chat_messages(message_chat, aimessages) await long_term_storage(long_term_type=CHUNK,
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages, langchain_messages=messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE) memory_config_id=actual_config_id,
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE) end_user_id=end_user_id,
scope=SCOPE)
# async def main(): await term_memory_save(end_user_id, CHUNK, scope=SCOPE)
# """主函数 - 运行工作流"""
# langchain_messages = [
# {
# "role": "user",
# "content": "今天周五去爬山"
# },
# {
# "role": "assistant",
# "content": "好耶"
# }
#
# ]
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
#
#
#
# if __name__ == "__main__":
# import asyncio
# asyncio.run(main())

View File

@@ -3,8 +3,9 @@ import uuid
from app.core.config import settings from app.core.config import settings
from typing import List, Dict, Any, Optional, Union from typing import List, Dict, Any, Optional, Union
from app.core.logging_config import get_logger
from app.core.memory.agent.utils.redis_base import ( from app.core.memory.agent.utils.redis_base import (
serialize_messages, serialize_messages,
deserialize_messages, deserialize_messages,
fix_encoding, fix_encoding,
format_session_data, format_session_data,
@@ -14,12 +15,12 @@ from app.core.memory.agent.utils.redis_base import (
get_current_timestamp get_current_timestamp
) )
logger = get_logger(__name__)
class RedisWriteStore: class RedisWriteStore:
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据""" """Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
""" """
初始化 Redis 连接 初始化 Redis 连接
@@ -66,10 +67,10 @@ class RedisWriteStore:
}) })
result = pipe.execute() result = pipe.execute()
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}") logger.debug(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
return session_id return session_id
except Exception as e: except Exception as e:
print(f"[save_session_write] 保存会话失败: {e}") logger.error(f"[save_session_write] 保存会话失败: {e}")
raise e raise e
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]: def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
@@ -99,7 +100,7 @@ class RedisWriteStore:
for key, data in zip(keys, all_data): for key, data in zip(keys, all_data):
if not data: if not data:
continue continue
# 从 write 类型读取,匹配 sessionid 字段 # 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == userid: if data.get('sessionid') == userid:
# 从 key 中提取 session_id: session:write:{session_id} # 从 key 中提取 session_id: session:write:{session_id}
@@ -108,16 +109,16 @@ class RedisWriteStore:
"sessionid": session_id, "sessionid": session_id,
"messages": fix_encoding(data.get('messages', '')) "messages": fix_encoding(data.get('messages', ''))
}) })
if not results: if not results:
return False return False
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据") logger.debug(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
return results return results
except Exception as e: except Exception as e:
print(f"[get_session_by_userid] 查询失败: {e}") logger.error(f"[get_session_by_userid] 查询失败: {e}")
return False return False
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]: def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
""" """
通过 end_user_id 获取所有 write 类型的会话数据 通过 end_user_id 获取所有 write 类型的会话数据
@@ -144,7 +145,7 @@ class RedisWriteStore:
# 只查询 write 类型的 key # 只查询 write 类型的 key
keys = self.r.keys('session:write:*') keys = self.r.keys('session:write:*')
if not keys: if not keys:
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话") logger.debug(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
return False return False
# 批量获取数据 # 批量获取数据
@@ -158,12 +159,12 @@ class RedisWriteStore:
for key, data in zip(keys, all_data): for key, data in zip(keys, all_data):
if not data: if not data:
continue continue
# 从 write 类型读取,匹配 sessionid 字段 # 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == end_user_id: if data.get('sessionid') == end_user_id:
# 从 key 中提取 session_id: session:write:{session_id} # 从 key 中提取 session_id: session:write:{session_id}
session_id = key.split(':')[-1] session_id = key.split(':')[-1]
# 构建完整的会话信息 # 构建完整的会话信息
session_info = { session_info = {
"session_id": session_id, "session_id": session_id,
@@ -173,23 +174,21 @@ class RedisWriteStore:
"starttime": data.get('starttime', '') "starttime": data.get('starttime', '')
} }
results.append(session_info) results.append(session_info)
if not results: if not results:
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据") logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
return False return False
# 按时间排序(最新的在前) # 按时间排序(最新的在前)
results.sort(key=lambda x: x.get('starttime', ''), reverse=True) results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据") logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
return results return results
except Exception as e: except Exception as e:
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}") logger.error(f"[get_all_sessions_by_end_user_id] 查询失败: {e}", exc_info=True)
import traceback
traceback.print_exc()
return False return False
def find_user_recent_sessions(self, userid: str, def find_user_recent_sessions(self, userid: str,
minutes: int = 5) -> List[Dict[str, str]]: minutes: int = 5) -> List[Dict[str, str]]:
""" """
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据 根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
@@ -203,11 +202,11 @@ class RedisWriteStore:
""" """
import time import time
start_time = time.time() start_time = time.time()
# 只查询 write 类型的 key # 只查询 write 类型的 key
keys = self.r.keys('session:write:*') keys = self.r.keys('session:write:*')
if not keys: if not keys:
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") logger.debug(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return [] return []
# 批量获取数据 # 批量获取数据
@@ -221,7 +220,7 @@ class RedisWriteStore:
for data in all_data: for data in all_data:
if not data: if not data:
continue continue
# 从 write 类型读取,匹配 sessionid 字段 # 从 write 类型读取,匹配 sessionid 字段
if data.get('sessionid') == userid and data.get('starttime'): if data.get('sessionid') == userid and data.get('starttime'):
# write 类型没有 aimessages所以 Answer 为空 # write 类型没有 aimessages所以 Answer 为空
@@ -230,15 +229,14 @@ class RedisWriteStore:
"Answer": "", "Answer": "",
"starttime": data.get('starttime', '') "starttime": data.get('starttime', '')
}) })
# 根据时间范围过滤 # 根据时间范围过滤
filtered_items = filter_by_time_range(matched_items, minutes) filtered_items = filter_by_time_range(matched_items, minutes)
# 排序并移除时间字段 # 排序并移除时间字段
result_items = sort_and_limit_results(filtered_items, limit=None) result_items = sort_and_limit_results(filtered_items)
print(result_items)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, " logger.debug(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items return result_items
@@ -258,7 +256,7 @@ class RedisWriteStore:
class RedisCountStore: class RedisCountStore:
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据""" """Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
""" """
初始化 Redis 连接 初始化 Redis 连接
@@ -278,7 +276,7 @@ class RedisCountStore:
decode_responses=True, decode_responses=True,
encoding='utf-8' encoding='utf-8'
) )
self.uudi = session_id self.uuid = session_id
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str: def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
""" """
@@ -295,26 +293,26 @@ class RedisCountStore:
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
key = generate_session_key(session_id, key_type="count") key = generate_session_key(session_id, key_type="count")
index_key = f'session:count:index:{end_user_id}' # 索引键 index_key = f'session:count:index:{end_user_id}' # 索引键
pipe = self.r.pipeline() pipe = self.r.pipeline()
pipe.hset(key, mapping={ pipe.hset(key, mapping={
"id": self.uudi, "id": self.uuid,
"end_user_id": end_user_id, "end_user_id": end_user_id,
"count": int(count), "count": int(count),
"messages": serialize_messages(messages), "messages": serialize_messages(messages),
"starttime": get_current_timestamp() "starttime": get_current_timestamp()
}) })
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期 pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
# 创建索引end_user_id -> session_id 映射 # 创建索引end_user_id -> session_id 映射
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60) pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
result = pipe.execute() result = pipe.execute()
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}") logger.debug(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
return session_id return session_id
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]: def get_sessions_count(self, end_user_id: str) -> tuple[int, list[dict]] | bool:
""" """
通过 end_user_id 查询访问次数统计 通过 end_user_id 查询访问次数统计
@@ -327,7 +325,7 @@ class RedisCountStore:
try: try:
# 使用索引键快速查找 # 使用索引键快速查找
index_key = f'session:count:index:{end_user_id}' index_key = f'session:count:index:{end_user_id}'
# 检查索引键类型,避免 WRONGTYPE 错误 # 检查索引键类型,避免 WRONGTYPE 错误
try: try:
key_type = self.r.type(index_key) key_type = self.r.type(index_key)
@@ -335,35 +333,40 @@ class RedisCountStore:
self.r.delete(index_key) self.r.delete(index_key)
return False return False
except Exception as type_error: except Exception as type_error:
print(f"[get_sessions_count] 检查键类型失败: {type_error}") logger.error(f"[get_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key) session_id = self.r.get(index_key)
if not session_id: if not session_id:
return False return False
# 直接获取数据 # 直接获取数据
key = generate_session_key(session_id, key_type="count") key = generate_session_key(session_id, key_type="count")
data = self.r.hgetall(key) data = self.r.hgetall(key)
if not data: if not data:
# 索引存在但数据不存在,清理索引 # 索引存在但数据不存在,清理索引
self.r.delete(index_key) self.r.delete(index_key)
return False return False
count = data.get('count') count = data.get('count')
messages_str = data.get('messages') messages_str = data.get('messages')
if count is not None: if count is not None:
messages = deserialize_messages(messages_str) messages: list[dict] = deserialize_messages(messages_str)
return [int(count), messages] return int(count), messages
return False return False
except Exception as e: except Exception as e:
print(f"[get_sessions_count] 查询失败: {e}") logger.error(f"[get_sessions_count] 查询失败: {e}")
return False return False
def update_sessions_count(self, end_user_id: str, new_count: int,
messages: Any) -> bool: def update_sessions_count(
self,
end_user_id: str,
new_count: int,
messages: Any
) -> bool:
""" """
通过 end_user_id 修改访问次数统计(优化版:使用索引) 通过 end_user_id 修改访问次数统计(优化版:使用索引)
@@ -378,39 +381,39 @@ class RedisCountStore:
try: try:
# 使用索引键快速查找 # 使用索引键快速查找
index_key = f'session:count:index:{end_user_id}' index_key = f'session:count:index:{end_user_id}'
# 检查索引键类型,避免 WRONGTYPE 错误 # 检查索引键类型,避免 WRONGTYPE 错误
try: try:
key_type = self.r.type(index_key) key_type = self.r.type(index_key)
if key_type != 'string' and key_type != 'none': if key_type != 'string' and key_type != 'none':
# 索引键类型错误,删除并返回 False # 索引键类型错误,删除并返回 False
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引") logger.warning(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
self.r.delete(index_key) self.r.delete(index_key)
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False return False
except Exception as type_error: except Exception as type_error:
print(f"[update_sessions_count] 检查键类型失败: {type_error}") logger.error(f"[update_sessions_count] 检查键类型失败: {type_error}")
session_id = self.r.get(index_key) session_id = self.r.get(index_key)
if not session_id: if not session_id:
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
return False return False
# 直接更新数据 # 直接更新数据
key = generate_session_key(session_id, key_type="count") key = generate_session_key(session_id, key_type="count")
messages_str = serialize_messages(messages) messages_str = serialize_messages(messages)
pipe = self.r.pipeline() pipe = self.r.pipeline()
pipe.hset(key, 'count', int(new_count)) pipe.hset(key, 'count', str(new_count))
pipe.hset(key, 'messages', messages_str) pipe.hset(key, 'messages', messages_str)
result = pipe.execute() result = pipe.execute()
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}") logger.debug(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
return True return True
except Exception as e: except Exception as e:
print(f"[update_sessions_count] 更新失败: {e}") logger.debug(f"[update_sessions_count] 更新失败: {e}")
return False return False
def delete_all_count_sessions(self) -> int: def delete_all_count_sessions(self) -> int:
@@ -428,7 +431,7 @@ class RedisCountStore:
class RedisSessionStore: class RedisSessionStore:
"""Redis 会话存储类,用于管理会话数据""" """Redis 会话存储类,用于管理会话数据"""
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
""" """
初始化 Redis 连接 初始化 Redis 连接
@@ -451,9 +454,9 @@ class RedisSessionStore:
self.uudi = session_id self.uudi = session_id
# ==================== 写入操作 ==================== # ==================== 写入操作 ====================
def save_session(self, userid: str, messages: str, aimessages: str, def save_session(self, userid: str, messages: str, aimessages: str,
apply_id: str, end_user_id: str) -> str: apply_id: str, end_user_id: str) -> str:
""" """
写入一条会话数据,返回 session_id 写入一条会话数据,返回 session_id
@@ -483,14 +486,14 @@ class RedisSessionStore:
}) })
result = pipe.execute() result = pipe.execute()
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}") logger.debug(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
return session_id return session_id
except Exception as e: except Exception as e:
print(f"[save_session] 保存会话失败: {e}") logger.error(f"[save_session] 保存会话失败: {e}")
raise e raise e
# ==================== 读取操作 ==================== # ==================== 读取操作 ====================
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]: def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
""" """
读取一条会话数据 读取一条会话数据
@@ -520,8 +523,8 @@ class RedisSessionStore:
sessions[sid] = self.get_session(sid) sessions[sid] = self.get_session(sid)
return sessions return sessions
def find_user_apply_group(self, sessionid: str, apply_id: str, def find_user_apply_group(self, sessionid: str, apply_id: str,
end_user_id: str) -> List[Dict[str, str]]: end_user_id: str) -> List[Dict[str, str]]:
""" """
根据 sessionid、apply_id 和 end_user_id 查询会话数据返回最新的6条 根据 sessionid、apply_id 和 end_user_id 查询会话数据返回最新的6条
@@ -535,10 +538,10 @@ class RedisSessionStore:
""" """
import time import time
start_time = time.time() start_time = time.time()
keys = self.r.keys('session:*') keys = self.r.keys('session:*')
if not keys: if not keys:
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") logger.debug(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
return [] return []
# 批量获取数据 # 批量获取数据
@@ -556,21 +559,21 @@ class RedisSessionStore:
continue continue
if (data.get('apply_id') == apply_id and if (data.get('apply_id') == apply_id and
data.get('end_user_id') == end_user_id): data.get('end_user_id') == end_user_id):
# 支持模糊匹配或完全匹配 sessionid # 支持模糊匹配或完全匹配 sessionid
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid: if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
matched_items.append(format_session_data(data, include_time=True)) matched_items.append(format_session_data(data, include_time=True))
# 排序、限制数量并移除时间字段 # 排序、限制数量并移除时间字段
result_items = sort_and_limit_results(matched_items, limit=6) result_items = sort_and_limit_results(matched_items, limit=6)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") logger.debug(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
return result_items return result_items
# ==================== 更新操作 ==================== # ==================== 更新操作 ====================
def update_session(self, session_id: str, field: str, value: Any) -> bool: def update_session(self, session_id: str, field: str, value: Any) -> bool:
""" """
更新单个字段 更新单个字段
@@ -591,7 +594,7 @@ class RedisSessionStore:
return bool(results[0]) return bool(results[0])
# ==================== 删除操作 ==================== # ==================== 删除操作 ====================
def delete_session(self, session_id: str) -> int: def delete_session(self, session_id: str) -> int:
""" """
删除单条会话 删除单条会话
@@ -632,7 +635,7 @@ class RedisSessionStore:
keys = self.r.keys('session:*') keys = self.r.keys('session:*')
if not keys: if not keys:
print("[delete_duplicate_sessions] 没有会话数据") logger.debug("[delete_duplicate_sessions] 没有会话数据")
return 0 return 0
# 批量获取所有数据 # 批量获取所有数据
@@ -678,7 +681,7 @@ class RedisSessionStore:
deleted_count += len(batch) deleted_count += len(batch)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}") logger.debug(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}")
return deleted_count return deleted_count

View File

@@ -151,11 +151,6 @@ async def write(
# Step 3: Save all data to Neo4j database # Step 3: Save all data to Neo4j database
step_start = time.time() step_start = time.time()
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
try:
await create_fulltext_indexes()
except Exception as e:
logger.error(f"Error creating indexes: {e}", exc_info=True)
# 添加死锁重试机制 # 添加死锁重试机制
max_retries = 3 max_retries = 3
@@ -279,5 +274,21 @@ async def write(
except Exception as cache_err: except Exception as cache_err:
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
# Close LLM/Embedder underlying httpx clients to prevent
# 'RuntimeError: Event loop is closed' during garbage collection
for client_obj in (llm_client, embedder_client):
try:
underlying = getattr(client_obj, 'client', None) or getattr(client_obj, 'model', None)
if underlying is None:
continue
# Unwrap RedBearLLM / RedBearEmbeddings to get the LangChain model
inner = getattr(underlying, '_model', underlying)
# LangChain OpenAI models expose async_client (httpx.AsyncClient)
http_client = getattr(inner, 'async_client', None)
if http_client is not None and hasattr(http_client, 'aclose'):
await http_client.aclose()
except Exception:
pass
logger.info("=== Pipeline Complete ===") logger.info("=== Pipeline Complete ===")
logger.info(f"Total execution time: {total_time:.2f} seconds") logger.info(f"Total execution time: {total_time:.2f} seconds")

View File

@@ -56,7 +56,7 @@ class LLMClient(ABC):
self.max_retries = self.config.max_retries self.max_retries = self.config.max_retries
self.timeout = self.config.timeout self.timeout = self.config.timeout
logger.info( logger.debug(
f"初始化 LLM 客户端: provider={self.provider}, " f"初始化 LLM 客户端: provider={self.provider}, "
f"model={self.model_name}, max_retries={self.max_retries}" f"model={self.model_name}, max_retries={self.max_retries}"
) )

View File

@@ -65,7 +65,7 @@ class OpenAIClient(LLMClient):
type=type_ type=type_
) )
logger.info(f"OpenAI 客户端初始化完成: type={type_}") logger.debug(f"OpenAI 客户端初始化完成: type={type_}")
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any: async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any:
""" """

View File

@@ -30,6 +30,18 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def message_has_files(message: "ConversationMessage") -> bool:
"""检查消息是否包含文件。
Args:
message: 待检查的消息对象
Returns:
bool: 如果消息包含文件则返回 True否则返回 False
"""
return message.files and len(message.files) > 0
class DialogExtractionResponse(BaseModel): class DialogExtractionResponse(BaseModel):
"""对话级一次性抽取的结构化返回,用于加速剪枝。 """对话级一次性抽取的结构化返回,用于加速剪枝。
@@ -128,7 +140,7 @@ class SemanticPruner:
1. 空消息 1. 空消息
2. 场景特定填充词库精确匹配 2. 场景特定填充词库精确匹配
3. 常见寒暄精确匹配 3. 常见寒暄精确匹配
4. 组合寒暄模式(前缀+后缀组合,如"好的谢谢""同学你好""明白了" 4. 组合寒暄模式(前缀 + 后缀组合,如"好的谢谢""同学你好""明白了"
5. 纯表情/标点 5. 纯表情/标点
""" """
t = message.msg.strip() t = message.msg.strip()
@@ -482,6 +494,11 @@ class SemanticPruner:
""" """
to_delete_ids: set = set() to_delete_ids: set = set()
for m in msgs: for m in msgs:
# 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断
if message_has_files(m):
self._log(f" [保护] 带文件的消息(不参与剪枝):'{m.msg[:40]}',文件数={len(m.files)}")
continue
# 填充检测优先:先判断是否为填充,再看 LLM 保护 # 填充检测优先:先判断是否为填充,再看 LLM 保护
if self._is_filler_message(m): if self._is_filler_message(m):
to_delete_ids.add(id(m)) to_delete_ids.add(id(m))
@@ -549,6 +566,11 @@ class SemanticPruner:
to_delete_ids: set = set() to_delete_ids: set = set()
for m in msgs: for m in msgs:
msg_text = m.msg.strip() msg_text = m.msg.strip()
# 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断
if message_has_files(m):
self._log(f" [保护] 带文件的消息(不参与剪枝):'{msg_text[:40]}',文件数={len(m.files)}")
continue
# 第一优先级:填充消息无论模式直接删除,不参与后续场景判断 # 第一优先级:填充消息无论模式直接删除,不参与后续场景判断
if self._is_filler_message(m): if self._is_filler_message(m):
@@ -801,6 +823,12 @@ class SemanticPruner:
for idx, m in enumerate(msgs): for idx, m in enumerate(msgs):
msg_text = m.msg.strip() msg_text = m.msg.strip()
# 最高优先级保护:带有文件的消息一律保留,不参与分类
if message_has_files(m):
self._log(f" [保护] 带文件的消息(不参与分类,直接保留):索引{idx}, '{msg_text[:40]}', 文件数={len(m.files)}")
llm_protected_msgs.append((idx, m)) # 放入保护列表
continue
if self._msg_matches_tokens(m, preserve_tokens): if self._msg_matches_tokens(m, preserve_tokens):
llm_protected_msgs.append((idx, m)) llm_protected_msgs.append((idx, m))

View File

@@ -182,7 +182,7 @@ class ExtractionOrchestrator:
list[StatementEntityEdge], list[StatementEntityEdge],
list[EntityEntityEdge], list[EntityEntityEdge],
list[PerceptualEdge], list[PerceptualEdge],
dict list[DialogData]
]: ]:
""" """
运行完整的知识提取流水线(优化版:并行执行) 运行完整的知识提取流水线(优化版:并行执行)
@@ -295,6 +295,7 @@ class ExtractionOrchestrator:
statement_entity_edges, statement_entity_edges,
entity_entity_edges, entity_entity_edges,
dialog_data_list, dialog_data_list,
dedup_details,
) = await self._run_dedup_and_write_summary( ) = await self._run_dedup_and_write_summary(
dialogue_nodes, dialogue_nodes,
chunk_nodes, chunk_nodes,
@@ -306,6 +307,11 @@ class ExtractionOrchestrator:
dialog_data_list, dialog_data_list,
) )
# 步骤 7: 同步用户别名到数据库表(仅正式模式)
if not is_pilot_run:
logger.info("步骤 7: 同步用户别名到 end_user 和 end_user_info 表")
await self._update_end_user_other_name(entity_nodes, dialog_data_list)
logger.info(f"知识提取流水线运行完成({mode_str}") logger.info(f"知识提取流水线运行完成({mode_str}")
return ( return (
dialogue_nodes, dialogue_nodes,
@@ -1399,7 +1405,8 @@ class ExtractionOrchestrator:
logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}") logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}")
else: else:
first_alias = current_aliases[0].strip() if current_aliases else "" first_alias = current_aliases[0].strip() if current_aliases else ""
if first_alias: # 确保 first_alias 不是占位名称
if first_alias and first_alias not in self.USER_PLACEHOLDER_NAMES:
db.add(EndUserInfo( db.add(EndUserInfo(
end_user_id=end_user_uuid, end_user_id=end_user_uuid,
other_name=first_alias, other_name=first_alias,
@@ -1415,29 +1422,33 @@ class ExtractionOrchestrator:
# 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中
USER_PLACEHOLDER_NAMES = {'用户', '', 'User', 'I'}
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]: def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]:
"""从实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序) """从实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序)
这个方法直接返回 LLM 提取的别名列表,不做任何修改 这个方法直接返回 LLM 提取的别名列表,并过滤掉占位名称("用户""""User""I"
第一个别名将被用作 other_name。 第一个别名将被用作 other_name。
Args: Args:
entity_nodes: 实体节点列表 entity_nodes: 实体节点列表
Returns: Returns:
别名列表(保持 LLM 提取的原始顺序) 别名列表(保持 LLM 提取的原始顺序,已过滤占位名称
""" """
USER_NAMES = {'用户', '', 'User', 'I'}
for entity in entity_nodes: for entity in entity_nodes:
if getattr(entity, 'name', '').strip() in USER_NAMES: if getattr(entity, 'name', '').strip() in self.USER_PLACEHOLDER_NAMES:
aliases = getattr(entity, 'aliases', []) or [] aliases = getattr(entity, 'aliases', []) or []
logger.debug(f"提取到用户别名(原始顺序): {aliases}") # 过滤掉占位名称,防止 "用户"/"我"/"User"/"I" 被存入 aliases 和 other_name
return aliases filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
logger.debug(f"提取到用户别名(原始顺序,已过滤占位名称): {filtered}")
return filtered
return [] return []
async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]: async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]:
"""从 Neo4j 查询用户实体的完整 aliases 列表""" """从 Neo4j 查询用户实体的完整 aliases 列表(已过滤占位名称)"""
cypher = """ cypher = """
MATCH (e:ExtractedEntity) MATCH (e:ExtractedEntity)
WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '', 'User', 'I'] WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '', 'User', 'I']
@@ -1451,7 +1462,10 @@ class ExtractionOrchestrator:
aliases = result[0].get('aliases') or [] aliases = result[0].get('aliases') or []
if not aliases: if not aliases:
logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}") logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}")
return aliases return []
# 过滤掉占位名称,防止历史脏数据传播
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
return filtered
def _resolve_other_name( def _resolve_other_name(
self, self,
@@ -1463,14 +1477,25 @@ class ExtractionOrchestrator:
决定 other_name 是否需要更新,返回新值;无需更新返回 None。 决定 other_name 是否需要更新,返回新值;无需更新返回 None。
决策规则: 决策规则:
- 为空 → 用本次对话第一个别名 - 为空或为占位名称 → 用本次对话第一个别名
- 不在 Neo4j aliases 中 → 用 Neo4j 第一个别名(说明已被删除) - 不在 Neo4j aliases 中 → 用 Neo4j 第一个别名(说明已被删除)
- 否则 → 保持不变(返回 None - 否则 → 保持不变(返回 None
注意:返回值不允许是占位名称("用户""""User""I"
""" """
if not current or not current.strip(): # 当前值为空或为占位名称时,需要更新
return current_aliases[0].strip() if current_aliases else None if not current or not current.strip() or current.strip() in self.USER_PLACEHOLDER_NAMES:
candidate = current_aliases[0].strip() if current_aliases else None
# 确保候选值不是占位名称
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
return None
return candidate
if current not in neo4j_aliases: if current not in neo4j_aliases:
return neo4j_aliases[0].strip() if neo4j_aliases else None candidate = neo4j_aliases[0].strip() if neo4j_aliases else None
# 确保候选值不是占位名称
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
return None
return candidate
return None return None
@@ -1492,6 +1517,7 @@ class ExtractionOrchestrator:
list[StatementChunkEdge], list[StatementChunkEdge],
list[StatementEntityEdge], list[StatementEntityEdge],
list[EntityEntityEdge], list[EntityEntityEdge],
list[DialogData],
dict dict
]: ]:
""" """
@@ -1555,6 +1581,8 @@ class ExtractionOrchestrator:
statement_chunk_edges, statement_chunk_edges,
dedup_statement_entity_edges, dedup_statement_entity_edges,
dedup_entity_entity_edges, dedup_entity_entity_edges,
dialog_data_list,
dedup_details,
) )
final_entity_nodes = dedup_entity_nodes final_entity_nodes = dedup_entity_nodes
@@ -1562,7 +1590,16 @@ class ExtractionOrchestrator:
final_entity_entity_edges = dedup_entity_entity_edges final_entity_entity_edges = dedup_entity_entity_edges
else: else:
# 正式模式:执行完整的两阶段去重 # 正式模式:执行完整的两阶段去重
result_tuple = await dedup_layers_and_merge_and_return( (
dialogue_nodes,
chunk_nodes,
statement_nodes,
final_entity_nodes,
statement_chunk_edges,
final_statement_entity_edges,
final_entity_entity_edges,
dedup_details,
) = await dedup_layers_and_merge_and_return(
dialogue_nodes, dialogue_nodes,
chunk_nodes, chunk_nodes,
statement_nodes, statement_nodes,
@@ -1576,21 +1613,21 @@ class ExtractionOrchestrator:
llm_client=self.llm_client, llm_client=self.llm_client,
) )
# 解包返回值
(
_,
_,
_,
final_entity_nodes,
_,
final_statement_entity_edges,
final_entity_entity_edges,
dedup_details,
) = result_tuple
# 保存去重消歧的详细记录到实例变量 # 保存去重消歧的详细记录到实例变量
self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes) self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes)
result_tuple = (
dialogue_nodes,
chunk_nodes,
statement_nodes,
final_entity_nodes,
statement_chunk_edges,
final_statement_entity_edges,
final_entity_entity_edges,
dialog_data_list,
dedup_details,
)
logger.info( logger.info(
f"去重后: {len(final_entity_nodes)} 个实体节点, " f"去重后: {len(final_entity_nodes)} 个实体节点, "
f"{len(final_statement_entity_edges)} 条陈述句-实体边, " f"{len(final_statement_entity_edges)} 条陈述句-实体边, "

View File

@@ -105,13 +105,19 @@ Extract entities and knowledge triplets from the given statement.
{% if language == "zh" %} {% if language == "zh" %}
- 用户实体的 name 字段:使用 "用户" 或 "我" - 用户实体的 name 字段:使用 "用户" 或 "我"
- 用户的真实姓名:放入 aliases - 用户的真实姓名:放入 aliases
- **🚨 禁止将 "用户"、"我" 放入 aliases 中aliases 只能包含用户的真实姓名、昵称等**
- 示例: - 示例:
* "我叫李明" → name="用户", aliases=["李明"] * "我叫李明" → name="用户", aliases=["李明"]
* ❌ 错误aliases=["用户", "李明"]"用户"不是真实姓名,禁止放入 aliases
* ❌ 错误aliases=["我", "李明"]"我"不是真实姓名,禁止放入 aliases
{% else %} {% else %}
- User entity name field: use "User" or "I" - User entity name field: use "User" or "I"
- User's real name: put in aliases - User's real name: put in aliases
- **🚨 NEVER put "User" or "I" in aliases. Aliases must only contain real names, nicknames, etc.**
- Examples: - Examples:
* "I'm John" → name="User", aliases=["John"] * "I'm John" → name="User", aliases=["John"]
* ❌ Wrong: aliases=["User", "John"] ("User" is not a real name, FORBIDDEN in aliases)
* ❌ Wrong: aliases=["I", "John"] ("I" is not a real name, FORBIDDEN in aliases)
{% endif %} {% endif %}

View File

@@ -44,6 +44,8 @@ class OSSStorage(StorageBackend):
access_key_id: str, access_key_id: str,
access_key_secret: str, access_key_secret: str,
bucket_name: str, bucket_name: str,
connect_timeout: int = 30,
multipart_threshold: int = 10 * 1024 * 1024, # 10MB
): ):
""" """
Initialize the OSSStorage backend. Initialize the OSSStorage backend.
@@ -53,6 +55,8 @@ class OSSStorage(StorageBackend):
access_key_id: The Aliyun access key ID. access_key_id: The Aliyun access key ID.
access_key_secret: The Aliyun access key secret. access_key_secret: The Aliyun access key secret.
bucket_name: The name of the OSS bucket. bucket_name: The name of the OSS bucket.
connect_timeout: Connection timeout in seconds (default: 30).
multipart_threshold: File size threshold for multipart upload (default: 10MB).
Raises: Raises:
StorageConfigError: If any required configuration is missing. StorageConfigError: If any required configuration is missing.
@@ -69,10 +73,17 @@ class OSSStorage(StorageBackend):
self.endpoint = endpoint self.endpoint = endpoint
self.bucket_name = bucket_name self.bucket_name = bucket_name
self.multipart_threshold = multipart_threshold
try: try:
auth = oss2.Auth(access_key_id, access_key_secret) auth = oss2.Auth(access_key_id, access_key_secret)
self.bucket = oss2.Bucket(auth, endpoint, bucket_name) # 设置超时和重试
self.bucket = oss2.Bucket(
auth,
endpoint,
bucket_name,
connect_timeout=connect_timeout
)
logger.info( logger.info(
f"OSSStorage initialized with endpoint: {endpoint}, bucket: {bucket_name}" f"OSSStorage initialized with endpoint: {endpoint}, bucket: {bucket_name}"
) )
@@ -108,21 +119,38 @@ class OSSStorage(StorageBackend):
if content_type: if content_type:
headers["Content-Type"] = content_type headers["Content-Type"] = content_type
self.bucket.put_object(file_key, content, headers=headers if headers else None) # 大文件使用分片上传
if len(content) > self.multipart_threshold:
logger.info(f"Using multipart upload for large file: {file_key} ({len(content)} bytes)")
upload_id = self.bucket.init_multipart_upload(file_key, headers=headers if headers else None).upload_id
parts = []
part_size = 5 * 1024 * 1024 # 5MB per part
part_num = 1
for offset in range(0, len(content), part_size):
chunk = content[offset:offset + part_size]
result = self.bucket.upload_part(file_key, upload_id, part_num, chunk)
parts.append(oss2.models.PartInfo(part_num, result.etag))
part_num += 1
self.bucket.complete_multipart_upload(file_key, upload_id, parts)
else:
self.bucket.put_object(file_key, content, headers=headers if headers else None)
logger.info(f"File uploaded to OSS successfully: {file_key}") logger.info(f"File uploaded to OSS successfully: {file_key}")
return file_key return file_key
except OssError as e: except OssError as e:
logger.error(f"OSS error uploading file {file_key}: {e}") logger.error(f"OSS error uploading file {file_key}: {e}")
raise StorageUploadError( raise StorageUploadError(
message=f"Failed to upload file to OSS: {e.message}", message=f"Failed to upload file to OSS: {str(e)}",
file_key=file_key, file_key=file_key,
cause=e, cause=e,
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to upload file to OSS {file_key}: {e}") logger.error(f"Failed to upload file to OSS {file_key}: {e}")
raise StorageUploadError( raise StorageUploadError(
message=f"Failed to upload file to OSS: {e}", message=f"Failed to upload file to OSS: {str(e)}",
file_key=file_key, file_key=file_key,
cause=e, cause=e,
) )
@@ -135,28 +163,73 @@ class OSSStorage(StorageBackend):
) -> int: ) -> int:
"""Upload from async stream to OSS. Returns total bytes written.""" """Upload from async stream to OSS. Returns total bytes written."""
buf = io.BytesIO() buf = io.BytesIO()
headers = {"Content-Type": content_type} if content_type else None
upload_id = None
try: try:
# 收集流数据
total_size = 0
async for chunk in stream: async for chunk in stream:
if not chunk:
continue
buf.write(chunk) buf.write(chunk)
total_size += len(chunk)
content = buf.getvalue() content = buf.getvalue()
headers = {"Content-Type": content_type} if content_type else None
self.bucket.put_object(file_key, content, headers=headers) if not content:
logger.info(f"File stream uploaded to OSS successfully: {file_key}") raise StorageUploadError(
return len(content) message="Empty stream content",
file_key=file_key,
)
# 大文件使用分片上传
if len(content) > self.multipart_threshold:
logger.info(f"Using multipart upload for stream: {file_key} ({len(content)} bytes)")
upload_id = self.bucket.init_multipart_upload(file_key, headers=headers).upload_id
parts = []
part_size = 5 * 1024 * 1024 # 5MB
part_num = 1
for offset in range(0, len(content), part_size):
chunk = content[offset:offset + part_size]
result = self.bucket.upload_part(file_key, upload_id, part_num, chunk)
parts.append(oss2.models.PartInfo(part_num, result.etag))
part_num += 1
self.bucket.complete_multipart_upload(file_key, upload_id, parts)
else:
self.bucket.put_object(file_key, content, headers=headers)
logger.info(f"File stream uploaded to OSS successfully: {file_key} ({total_size} bytes)")
return total_size
except OssError as e: except OssError as e:
if upload_id:
try:
self.bucket.abort_multipart_upload(file_key, upload_id)
except:
pass
logger.error(f"OSS error stream uploading file {file_key}: {e}") logger.error(f"OSS error stream uploading file {file_key}: {e}")
raise StorageUploadError( raise StorageUploadError(
message=f"Failed to stream upload file to OSS: {e.message}", message=f"Failed to stream upload file to OSS: {str(e)}",
file_key=file_key, file_key=file_key,
cause=e, cause=e,
) )
except Exception as e: except Exception as e:
if upload_id:
try:
self.bucket.abort_multipart_upload(file_key, upload_id)
except:
pass
logger.error(f"Failed to stream upload file to OSS {file_key}: {e}") logger.error(f"Failed to stream upload file to OSS {file_key}: {e}")
raise StorageUploadError( raise StorageUploadError(
message=f"Failed to stream upload file to OSS: {e}", message=f"Failed to stream upload file to OSS: {str(e)}",
file_key=file_key, file_key=file_key,
cause=e, cause=e,
) )
finally:
buf.close()
async def download(self, file_key: str) -> bytes: async def download(self, file_key: str) -> bytes:
""" """
@@ -182,14 +255,14 @@ class OSSStorage(StorageBackend):
except OssError as e: except OssError as e:
logger.error(f"OSS error downloading file {file_key}: {e}") logger.error(f"OSS error downloading file {file_key}: {e}")
raise StorageDownloadError( raise StorageDownloadError(
message=f"Failed to download file from OSS: {e.message}", message=f"Failed to download file from OSS: {str(e)}",
file_key=file_key, file_key=file_key,
cause=e, cause=e,
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to download file from OSS {file_key}: {e}") logger.error(f"Failed to download file from OSS {file_key}: {e}")
raise StorageDownloadError( raise StorageDownloadError(
message=f"Failed to download file from OSS: {e}", message=f"Failed to download file from OSS: {str(e)}",
file_key=file_key, file_key=file_key,
cause=e, cause=e,
) )
@@ -215,14 +288,14 @@ class OSSStorage(StorageBackend):
except OssError as e: except OssError as e:
logger.error(f"OSS error deleting file {file_key}: {e}") logger.error(f"OSS error deleting file {file_key}: {e}")
raise StorageDeleteError( raise StorageDeleteError(
message=f"Failed to delete file from OSS: {e.message}", message=f"Failed to delete file from OSS: {str(e)}",
file_key=file_key, file_key=file_key,
cause=e, cause=e,
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to delete file from OSS {file_key}: {e}") logger.error(f"Failed to delete file from OSS {file_key}: {e}")
raise StorageDeleteError( raise StorageDeleteError(
message=f"Failed to delete file from OSS: {e}", message=f"Failed to delete file from OSS: {str(e)}",
file_key=file_key, file_key=file_key,
cause=e, cause=e,
) )

View File

@@ -9,10 +9,10 @@ from app.core.workflow.nodes.enums import NodeType
def merge_activate_state(x, y): def merge_activate_state(x, y):
return { merged = dict(x)
k: x.get(k, False) or y.get(k, False) for k, v in y.items():
for k in set(x) | set(y) merged[k] = merged.get(k, False) or v
} return merged
def merge_looping_state(x, y): def merge_looping_state(x, y):

View File

@@ -17,6 +17,51 @@ from app.core.workflow.variable.variable_objects import T, create_variable_insta
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
VARIABLE_PATTERN = re.compile(r"\{\{\s*(.*?)\s*}}")
class LazyVariableDict:
def __init__(self, source, literal):
self._source: dict[str, VariableStruct[Any]] = source
self._literal: bool = literal
self._cache = {}
def keys(self):
return self._source.keys()
def _resolve(self, key):
if key in self._cache:
return self._cache[key]
var_struct = self._source.get(key)
if var_struct is None:
raise KeyError(key)
value = var_struct.instance.to_literal() if self._literal else var_struct.instance.get_value()
self._cache[key] = value
return value
def get(self, key, default=None):
try:
return self._resolve(key)
except KeyError:
return default
def __getitem__(self, key):
return self._resolve(key)
def __getattr__(self, key):
if key.startswith('_'):
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'")
return self._resolve(key)
def __contains__(self, key):
return key in self._source
def __iter__(self):
return iter(self._source)
def __len__(self):
return len(self._source)
class VariableSelector: class VariableSelector:
"""变量选择器 """变量选择器
@@ -117,8 +162,7 @@ class VariablePool:
@staticmethod @staticmethod
def transform_selector(selector): def transform_selector(selector):
pattern = r"\{\{\s*(.*?)\s*\}\}" variable_literal = VARIABLE_PATTERN.sub(r"\1", selector).strip()
variable_literal = re.sub(pattern, r"\1", selector).strip()
selector = VariableSelector.from_string(variable_literal).path selector = VariableSelector.from_string(variable_literal).path
if len(selector) != 2: if len(selector) != 2:
raise ValueError(f"Selector not valid - {selector}") raise ValueError(f"Selector not valid - {selector}")
@@ -303,6 +347,16 @@ class VariablePool:
""" """
return self._get_variable_struct(selector) is not None return self._get_variable_struct(selector) is not None
def lazy_namespace(self, namespace: str, literal: bool = False) -> LazyVariableDict:
return LazyVariableDict(self.variables.get(namespace, {}), literal)
def lazy_all_node_outputs(self, literal: bool = False) -> dict[str, LazyVariableDict]:
return {
ns: LazyVariableDict(vars_dict, literal)
for ns, vars_dict in self.variables.items()
if ns not in ("sys", "conv")
}
def get_all_system_vars(self, literal=False) -> dict[str, Any]: def get_all_system_vars(self, literal=False) -> dict[str, Any]:
"""获取所有系统变量 """获取所有系统变量
@@ -479,5 +533,3 @@ class VariablePoolInitializer:
var_type=var_type, var_type=var_type,
mut=False mut=False
) )

View File

@@ -552,9 +552,9 @@ class BaseNode(ABC):
return render_template( return render_template(
template=template, template=template,
conv_vars=variable_pool.get_all_conversation_vars(literal=True), conv_vars=variable_pool.lazy_namespace("conv", literal=True),
node_outputs=variable_pool.get_all_node_outputs(literal=True), node_outputs=variable_pool.lazy_all_node_outputs(literal=True),
system_vars=variable_pool.get_all_system_vars(literal=True), system_vars=variable_pool.lazy_namespace("sys", literal=True),
strict=strict strict=strict
) )
@@ -579,9 +579,9 @@ class BaseNode(ABC):
return evaluate_condition( return evaluate_condition(
expression=expression, expression=expression,
conv_var=variable_pool.get_all_conversation_vars(), conv_var=variable_pool.lazy_namespace("conv"),
node_outputs=variable_pool.get_all_node_outputs(), node_outputs=variable_pool.lazy_all_node_outputs(),
system_vars=variable_pool.get_all_system_vars() system_vars=variable_pool.lazy_namespace("sys")
) )
@staticmethod @staticmethod

View File

@@ -11,7 +11,6 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
from app.core.workflow.utils.expression_evaluator import evaluate_expression
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -85,12 +84,7 @@ class LoopRuntime:
for variable in self.typed_config.cycle_vars: for variable in self.typed_config.cycle_vars:
if variable.input_type == ValueInputType.VARIABLE: if variable.input_type == ValueInputType.VARIABLE:
value = evaluate_expression( value = self.variable_pool.get_value(variable.value)
expression=variable.value,
conv_var=self.variable_pool.get_all_conversation_vars(),
node_outputs=self.variable_pool.get_all_node_outputs(),
system_vars=self.variable_pool.get_all_system_vars(),
)
else: else:
value = TypeTransformer.transform(variable.value, variable.type) value = TypeTransformer.transform(variable.value, variable.type)
await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True) await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True)
@@ -98,12 +92,7 @@ class LoopRuntime:
**self.state **self.state
) )
loopstate["node_outputs"][self.node_id] = { loopstate["node_outputs"][self.node_id] = {
variable.name: evaluate_expression( variable.name: self.variable_pool.get_value(variable.value)
expression=variable.value,
conv_var=self.variable_pool.get_all_conversation_vars(),
node_outputs=self.variable_pool.get_all_node_outputs(),
system_vars=self.variable_pool.get_all_system_vars(),
)
if variable.input_type == ValueInputType.VARIABLE if variable.input_type == ValueInputType.VARIABLE
else TypeTransformer.transform(variable.value, variable.type) else TypeTransformer.transform(variable.value, variable.type)
for variable in self.typed_config.cycle_vars for variable in self.typed_config.cycle_vars

View File

@@ -89,7 +89,7 @@ class DocExtractorNode(BaseNode):
# Reuse cached bytes if already fetched # Reuse cached bytes if already fetched
if f.get_content(): if f.get_content():
file_input.set_content(f.get_content()) file_input.set_content(f.get_content())
text = await svc._extract_document_text(file_input) text = await svc.extract_document_text(file_input)
chunks.append(text) chunks.append(text)
except Exception as e: except Exception as e:
logger.error( logger.error(

View File

@@ -1,19 +1,23 @@
import asyncio
import logging import logging
import uuid import uuid
from typing import Any from typing import Any
from langchain_core.documents import Document
from app.core.error_codes import BizCode from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException from app.core.exceptions import BusinessException
from app.core.models import RedBearRerank, RedBearModelConfig from app.core.models import RedBearRerank, RedBearModelConfig
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector from app.core.rag.models.chunk import DocumentChunk
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
from app.db import get_db_read from app.db import get_db_read
from app.models import knowledge_model, knowledgeshare_model, ModelType from app.models import knowledge_model, ModelType
from app.repositories import knowledge_repository, knowledgeshare_repository from app.repositories import knowledge_repository
from app.schemas.chunk_schema import RetrieveType from app.schemas.chunk_schema import RetrieveType
from app.services.model_service import ModelConfigService from app.services.model_service import ModelConfigService
@@ -24,7 +28,6 @@ class KnowledgeRetrievalNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
super().__init__(node_config, workflow_config, down_stream_nodes) super().__init__(node_config, workflow_config, down_stream_nodes)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None self.typed_config: KnowledgeRetrievalNodeConfig | None = None
self.vector_service: ElasticSearchVector | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
return { return {
@@ -85,46 +88,54 @@ class KnowledgeRetrievalNode(BaseNode):
unique.append(doc) unique.append(doc)
return unique return unique
def _get_existing_kb_ids(self, db, kb_ids): def rerank(self, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
""" """
Resolve all accessible and valid knowledge base IDs for retrieval. Reorder the list of document blocks and return the top_k results most relevant to the query
This includes:
- Private knowledge bases owned by the user
- Shared knowledge bases
- Source knowledge bases mapped via knowledge sharing relationships
Args: Args:
db: Database session. query: query string
kb_ids (list[UUID]): Knowledge base IDs from node configuration. docs: List of document chunk to be rearranged
top_k: The number of top-level documents returned
Returns: Returns:
list[UUID]: Final list of valid knowledge base IDs. Rearranged document chunk list (sorted in descending order of relevance)
Raises:
ValueError: If the input document list is empty or top_k is invalid
""" """
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Private) reranker = self.get_reranker_model()
# parameter validation
existing_ids = knowledge_repository.get_chunked_knowledgeids( if not docs:
db=db, raise ValueError("retrieval chunks be empty")
filters=filters if top_k <= 0:
) raise ValueError("top_k must be a positive integer")
try:
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share) # Convert to LangChain Document object
documents = [
share_ids = knowledge_repository.get_chunked_knowledgeids( Document(
db=db, page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute
filters=filters metadata=doc.metadata or {} # Deal with possible None metadata
) )
for doc in docs
if share_ids:
filters = [
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids)
] ]
items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
db=db, # Perform reordering (compress_documents will automatically handle relevance scores and indexing)
filters=filters reranked_docs = list(reranker.compress_documents(documents, query))
# Sort in descending order based on relevance score
reranked_docs.sort(
key=lambda x: x.metadata.get("relevance_score", 0),
reverse=True
) )
existing_ids.extend(items) # Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"]
return existing_ids result = []
for item in reranked_docs[:top_k]:
for doc in docs:
if doc.page_content == item.page_content:
doc.metadata["score"] = item.metadata["relevance_score"]
result.append(doc)
return result
except Exception as e:
raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e
def get_reranker_model(self) -> RedBearRerank: def get_reranker_model(self) -> RedBearRerank:
""" """
@@ -164,41 +175,77 @@ class KnowledgeRetrievalNode(BaseNode):
) )
return reranker return reranker
def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config): async def knowledge_retrieval(self, db, query, db_knowledge, kb_config):
rs = []
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER: if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id) children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
tasks = []
for child in children: for child in children:
if not (child and child.chunk_num > 0 and child.status == 1): if not (child and child.chunk_num > 0 and child.status == 1):
continue continue
kb_config.kb_id = child.id child_kb_config = kb_config.model_copy()
self.knowledge_retrieval(db, query, rs, child, kb_config) child_kb_config.kb_id = child.id
return tasks.append(self.knowledge_retrieval(db, query, child, child_kb_config))
self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) if tasks:
result = await asyncio.gather(*tasks)
for _ in result:
rs.extend(_)
return rs
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
indices = f"Vector_index_{kb_config.kb_id}_Node".lower() indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
match kb_config.retrieve_type: match kb_config.retrieve_type:
case RetrieveType.PARTICIPLE: case RetrieveType.PARTICIPLE:
rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, rs.extend(
indices=indices, await asyncio.to_thread(
score_threshold=kb_config.similarity_threshold)) vector_service.search_by_full_text, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.similarity_threshold
}
)
)
case RetrieveType.SEMANTIC: case RetrieveType.SEMANTIC:
rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k, rs.extend(
indices=indices, await asyncio.to_thread(
score_threshold=kb_config.vector_similarity_weight)) vector_service.search_by_vector, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.vector_similarity_weight
}
)
)
case RetrieveType.HYBRID: case RetrieveType.HYBRID:
rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k, rs1_task = asyncio.to_thread(
indices=indices, vector_service.search_by_vector, **{
score_threshold=kb_config.vector_similarity_weight) "query": query,
rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, "top_k": kb_config.top_k,
indices=indices, "indices": indices,
score_threshold=kb_config.similarity_threshold) "score_threshold": kb_config.vector_similarity_weight
}
)
rs2_task = asyncio.to_thread(
vector_service.search_by_full_text, **{
"query": query,
"top_k": kb_config.top_k,
"indices": indices,
"score_threshold": kb_config.similarity_threshold
}
)
rs1, rs2 = await asyncio.gather(rs1_task, rs2_task)
# Deduplicate hybrid retrieval results # Deduplicate hybrid retrieval results
unique_rs = self._deduplicate_docs(rs1, rs2) unique_rs = self._deduplicate_docs(rs1, rs2)
if not unique_rs: if not unique_rs:
return return []
if self.typed_config.reranker_id: if self.typed_config.reranker_id:
self.vector_service.reranker = self.get_reranker_model() rs.extend(
rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) await asyncio.to_thread(
self.rerank,
**{"query": query, "docs": unique_rs, "top_k": kb_config.top_k}
)
)
else: else:
rs.extend(sorted( rs.extend(sorted(
unique_rs, unique_rs,
@@ -207,6 +254,7 @@ class KnowledgeRetrievalNode(BaseNode):
)[:kb_config.top_k]) )[:kb_config.top_k])
case _: case _:
raise RuntimeError("Unknown retrieval type") raise RuntimeError("Unknown retrieval type")
return rs
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
""" """
@@ -238,17 +286,24 @@ class KnowledgeRetrievalNode(BaseNode):
knowledge_bases = self.typed_config.knowledge_bases knowledge_bases = self.typed_config.knowledge_bases
rs = [] rs = []
tasks = []
for kb_config in knowledge_bases: for kb_config in knowledge_bases:
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id) db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
if not db_knowledge: if not db_knowledge:
raise RuntimeError("The knowledge base does not exist or access is denied.") raise RuntimeError("The knowledge base does not exist or access is denied.")
self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config) tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config))
if tasks:
result = await asyncio.gather(*tasks)
for _ in result:
rs.extend(_)
if not rs: if not rs:
return [] return []
if self.typed_config.reranker_id: if self.typed_config.reranker_id:
self.vector_service.reranker = self.get_reranker_model() final_rs = await asyncio.to_thread(
final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k) self.rerank,
**{"query": query, "docs": rs, "top_k": self.typed_config.reranker_top_k}
)
else: else:
final_rs = sorted( final_rs = sorted(
rs, rs,

View File

@@ -4,32 +4,33 @@ from typing import Any
from simpleeval import simple_eval, NameNotDefined, InvalidExpression from simpleeval import simple_eval, NameNotDefined, InvalidExpression
from app.core.workflow.engine.variable_pool import LazyVariableDict, VARIABLE_PATTERN
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")
class ExpressionEvaluator: class ExpressionEvaluator:
"""Safe expression evaluator for workflow variables and node outputs.""" """Safe expression evaluator for workflow variables and node outputs."""
# Reserved namespaces # Reserved namespaces
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"} RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
@classmethod @classmethod
def normalize_template(cls, template: str) -> str: def normalize_template(cls, template: str) -> str:
pattern = re.compile( return _NORMALIZE_PATTERN.sub(
r"\{\{\s*(\d+)\.(\w+)\s*}}"
)
return pattern.sub(
r'{{ node["\1"].\2 }}', r'{{ node["\1"].\2 }}',
template template
) )
@classmethod @classmethod
def evaluate( def evaluate(
cls, cls,
expression: str, expression: str,
conv_vars: dict[str, Any], conv_vars: dict[str, Any],
node_outputs: dict[str, Any], node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None system_vars: dict[str, Any] | None = None
) -> Any: ) -> Any:
""" """
Safely evaluate an expression using workflow variables. Safely evaluate an expression using workflow variables.
@@ -49,48 +50,47 @@ class ExpressionEvaluator:
# Remove Jinja2-style brackets if present # Remove Jinja2-style brackets if present
expression = expression.strip() expression = expression.strip()
expression = cls.normalize_template(expression) expression = cls.normalize_template(expression)
pattern = r"\{\{\s*(.*?)\s*\}\}" expression = VARIABLE_PATTERN.sub(r"\1", expression).strip()
expression = re.sub(pattern, r"\1", expression).strip()
# Build context for evaluation # Build context for evaluation
context = { context = {
"conv": conv_vars, # conversation variables "conv": conv_vars, # conversation variables
"node": node_outputs, # node outputs "node": node_outputs, # node outputs
"sys": system_vars or {}, # system variables "sys": system_vars or {}, # system variables
} }
context.update(conv_vars) # context.update(conv_vars)
context["nodes"] = node_outputs # context["nodes"] = node_outputs
context.update(node_outputs) context.update(node_outputs)
try: try:
# simpleeval supports safe operations: # simpleeval supports safe operations:
# arithmetic, comparisons, logical ops, attribute/dict/list access # arithmetic, comparisons, logical ops, attribute/dict/list access
result = simple_eval(expression, names=context) result = simple_eval(expression, names=context)
return result return result
except NameNotDefined as e: except NameNotDefined as e:
logger.error(f"Undefined variable in expression: {expression}, error: {e}") logger.error(f"Undefined variable in expression: {expression}, error: {e}")
raise ValueError(f"Undefined variable: {e}") raise ValueError(f"Undefined variable: {e}")
except InvalidExpression as e: except InvalidExpression as e:
logger.error(f"Invalid expression syntax: {expression}, error: {e}") logger.error(f"Invalid expression syntax: {expression}, error: {e}")
raise ValueError(f"Invalid expression syntax: {e}") raise ValueError(f"Invalid expression syntax: {e}")
except SyntaxError as e: except SyntaxError as e:
logger.error(f"Syntax error in expression: {expression}, error: {e}") logger.error(f"Syntax error in expression: {expression}, error: {e}")
raise ValueError(f"Syntax error: {e}") raise ValueError(f"Syntax error: {e}")
except Exception as e: except Exception as e:
logger.error(f"Expression evaluation failed: {expression}, error: {e}") logger.error(f"Expression evaluation failed: {expression}, error: {e}")
raise ValueError(f"Expression evaluation failed: {e}") raise ValueError(f"Expression evaluation failed: {e}")
@staticmethod @staticmethod
def evaluate_bool( def evaluate_bool(
expression: str, expression: str,
conv_var: dict[str, Any], conv_var: dict[str, Any],
node_outputs: dict[str, Any], node_outputs: dict[str, Any],
system_vars: dict[str, Any] | None = None system_vars: dict[str, Any] | None = None
) -> bool: ) -> bool:
""" """
Evaluate a boolean expression (for conditions). Evaluate a boolean expression (for conditions).
@@ -108,7 +108,7 @@ class ExpressionEvaluator:
expression, conv_var, node_outputs, system_vars expression, conv_var, node_outputs, system_vars
) )
return bool(result) return bool(result)
@staticmethod @staticmethod
def validate_variable_names(variables: list[dict]) -> list[str]: def validate_variable_names(variables: list[dict]) -> list[str]:
""" """
@@ -121,7 +121,7 @@ class ExpressionEvaluator:
list[str]: List of error messages. Empty if all names are valid. list[str]: List of error messages. Empty if all names are valid.
""" """
errors = [] errors = []
for var in variables: for var in variables:
var_name = var.get("name", "") var_name = var.get("name", "")
@@ -134,16 +134,16 @@ class ExpressionEvaluator:
errors.append( errors.append(
f"Variable name '{var_name}' is not a valid Python identifier" f"Variable name '{var_name}' is not a valid Python identifier"
) )
return errors return errors
# 便捷函数 # 便捷函数
def evaluate_expression( def evaluate_expression(
expression: str, expression: str,
conv_var: dict[str, Any], conv_var: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, Any], node_outputs: dict[str, dict[str, Any] | LazyVariableDict],
system_vars: dict[str, Any] system_vars: dict[str, Any] | LazyVariableDict
) -> Any: ) -> Any:
"""Evaluate an expression (convenience function).""" """Evaluate an expression (convenience function)."""
return ExpressionEvaluator.evaluate( return ExpressionEvaluator.evaluate(
@@ -152,11 +152,11 @@ def evaluate_expression(
def evaluate_condition( def evaluate_condition(
expression: str, expression: str,
conv_var: dict[str, Any], conv_var: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, Any], node_outputs: dict[str, dict[str, Any] | LazyVariableDict],
system_vars: dict[str, Any] | None = None system_vars: dict[str, Any] | LazyVariableDict
) -> bool: ) -> Any:
"""Evaluate a boolean condition expression (convenience function).""" """Evaluate a boolean condition expression (convenience function)."""
return ExpressionEvaluator.evaluate_bool( return ExpressionEvaluator.evaluate_bool(
expression, conv_var, node_outputs, system_vars expression, conv_var, node_outputs, system_vars

View File

@@ -1,7 +1,8 @@
""" """
模板渲染器 Template Renderer
使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。 Provides safe template rendering using Jinja2, supporting variable references
and expressions.
""" """
import logging import logging
@@ -10,11 +11,15 @@ from typing import Any
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
from app.core.workflow.engine.variable_pool import LazyVariableDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")
class SafeUndefined(Undefined): class SafeUndefined(Undefined):
"""访问未定义属性不会报错,返回空字符串""" """Return empty string instead of raising error when accessing undefined variables"""
__slots__ = () __slots__ = ()
def _fail_with_undefined_error(self, *args, **kwargs): def _fail_with_undefined_error(self, *args, **kwargs):
@@ -26,26 +31,22 @@ class SafeUndefined(Undefined):
class TemplateRenderer: class TemplateRenderer:
"""模板渲染器"""
def __init__(self, strict: bool = True): def __init__(self, strict: bool = True):
"""初始化渲染器 """Initialize renderer
Args: Args:
strict: 是否使用严格模式(未定义变量会抛出异常) strict: Whether to enable strict mode (raise error on undefined variables)
""" """
self.strict = strict self.strict = strict
self.env = Environment( self.env = Environment(
undefined=StrictUndefined if strict else SafeUndefined, undefined=StrictUndefined if strict else SafeUndefined,
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML autoescape=False # Disable auto-escaping since we handle plain text instead of HTML
) )
@staticmethod @staticmethod
def normalize_template(template: str) -> str: def normalize_template(template: str) -> str:
pattern = re.compile( """Normalize template syntax (convert numeric node reference to dict access)"""
r"\{\{\s*(\d+)\.(\w+)\s*}}" return _NORMALIZE_PATTERN.sub(
)
return pattern.sub(
r'{{ node["\1"].\2 }}', r'{{ node["\1"].\2 }}',
template template
) )
@@ -53,24 +54,24 @@ class TemplateRenderer:
def render( def render(
self, self,
template: str, template: str,
conv_vars: dict[str, Any], conv_vars: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, Any], node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
system_vars: dict[str, Any] | None = None system_vars: dict[str, Any] | LazyVariableDict | None = None
) -> str: ) -> str:
"""渲染模板 """Render template
Args: Args:
template: 模板字符串 template: Template string
conv_vars: 会话变量 conv_vars: Conversation variables
node_outputs: 节点输出结果 node_outputs: Node outputs
system_vars: 系统变量 system_vars: System variables
Returns: Returns:
渲染后的字符串 Rendered string
Raises: Raises:
ValueError: 模板语法错误或变量未定义 ValueError: If template syntax is invalid or variables are undefined
Examples: Examples:
>>> renderer = TemplateRenderer() >>> renderer = TemplateRenderer()
>>> renderer.render( >>> renderer.render(
@@ -80,122 +81,119 @@ class TemplateRenderer:
... {} ... {}
... ) ... )
'Hello World!' 'Hello World!'
>>> renderer.render( >>> renderer.render(
... "分析结果: {{node.analyze.output}}", ... "Analysis result: {{node.analyze.output}}",
... {}, ... {},
... {"analyze": {"output": "正面情绪"}}, ... {"analyze": {"output": "positive sentiment"}},
... {} ... {}
... ) ... )
'分析结果: 正面情绪' 'Analysis result: positive sentiment'
""" """
# 构建命名空间上下文 # Build namespace context
context = { context = {
"conv": conv_vars, # 会话变量:{{conv.user_name}} "conv": conv_vars, # Conversation variables: {{conv.user_name}}
"node": node_outputs, # 节点输出:{{node.node_1.output}} "node": node_outputs, # Node outputs: {{node.node_1.output}}
"sys": system_vars, # 系统变量:{{sys.execution_id}} "sys": system_vars, # System variables: {{sys.execution_id}}
} }
# 支持直接通过节点ID访问节点输出{{llm_qa.output}} # Allow direct access to node outputs by node ID: {{llm_qa.output}}
# 将所有节点输出添加到顶层上下文
if node_outputs: if node_outputs:
context.update(node_outputs) context.update(node_outputs)
# 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}} # # 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
if conv_vars: # if conv_vars:
context.update(conv_vars) # context.update(conv_vars)
#
context["nodes"] = node_outputs or {} # 旧语法兼容 # context["nodes"] = node_outputs or {} # 旧语法兼容
template = self.normalize_template(template) template = self.normalize_template(template)
try: try:
tmpl = self.env.from_string(template) tmpl = self.env.from_string(template)
return tmpl.render(**context) return tmpl.render(**context)
except TemplateSyntaxError as e: except TemplateSyntaxError as e:
logger.error(f"模板语法错误: {template}, 错误: {e}") logger.error(f"Template syntax error: {template}, error: {e}")
raise ValueError(f"模板语法错误: {e}") raise ValueError(f"Template syntax error: {e}")
except UndefinedError as e: except UndefinedError as e:
logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}") logger.error(f"Undefined variable in template: {template}, error: {e}")
raise ValueError(f"未定义的变量: {e}") raise ValueError(f"Undefined variable: {e}")
except Exception as e: except Exception as e:
logger.error(f"模板渲染异常: {template}, 错误: {e}") logger.error(f"Template rendering error: {template}, error: {e}")
raise ValueError(f"模板渲染失败: {e}") raise ValueError(f"Template rendering failed: {e}")
def validate(self, template: str) -> list[str]: def validate(self, template: str) -> list[str]:
"""验证模板语法 """Validate template syntax
Args: Args:
template: 模板字符串 template: Template string
Returns: Returns:
错误列表,如果为空则验证通过 List of errors (empty if valid)
Examples: Examples:
>>> renderer = TemplateRenderer() >>> renderer = TemplateRenderer()
>>> renderer.validate("Hello {{var.name}}!") >>> renderer.validate("Hello {{var.name}}!")
[] []
>>> renderer.validate("Hello {{var.name") # 缺少结束标记 >>> renderer.validate("Hello {{var.name") # Missing closing tag
['模板语法错误: ...'] ['Template syntax error: ...']
""" """
errors = [] errors = []
try: try:
self.env.from_string(template) self.env.from_string(template)
except TemplateSyntaxError as e: except TemplateSyntaxError as e:
errors.append(f"模板语法错误: {e}") errors.append(f"Template syntax error: {e}")
except Exception as e: except Exception as e:
errors.append(f"模板验证失败: {e}") errors.append(f"Template validation failed: {e}")
return errors return errors
# 全局渲染器实例(严格模式) # Global renderer instances (strict / lenient)
_strict_renderer = TemplateRenderer(strict=True) _strict_renderer = TemplateRenderer(strict=True)
_lenient_renderer = TemplateRenderer(strict=False) _lenient_renderer = TemplateRenderer(strict=False)
def render_template( def render_template(
template: str, template: str,
conv_vars: dict[str, Any], conv_vars: dict[str, Any] | LazyVariableDict,
node_outputs: dict[str, Any], node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
system_vars: dict[str, Any], system_vars: dict[str, Any] | LazyVariableDict,
strict: bool = True strict: bool = True
) -> str: ) -> str:
"""渲染模板(便捷函数) """Render template (convenience function)
Args: Args:
strict: 严格模式 strict: Whether to use strict mode
template: 模板字符串 template: Template string
conv_vars: 会话变量 conv_vars: Conversation variables
node_outputs: 节点输出 node_outputs: Node outputs
system_vars: 系统变量 system_vars: System variables
Returns: Returns:
渲染后的字符串 Rendered string
Examples: Examples:
>>> render_template( >>> render_template(
... "请分析: {{var.text}}", ... "Analyze: {{var.text}}",
... {"text": "这是一段文本"}, ... {"text": "This is a text"},
... {}, ... {},
... {} ... {}
... ) ... )
'请分析: 这是一段文本' 'Analyze: This is a text'
""" """
renderer = _strict_renderer if strict else _lenient_renderer renderer = _strict_renderer if strict else _lenient_renderer
return renderer.render(template, conv_vars, node_outputs, system_vars) return renderer.render(template, conv_vars, node_outputs, system_vars)
def validate_template(template: str) -> list[str]: def validate_template(template: str) -> list[str]:
"""验证模板语法(便捷函数) """Validate template syntax (convenience function)
Args: Args:
template: 模板字符串 template: Template string
Returns: Returns:
错误列表 List of errors
""" """
return _strict_renderer.validate(template) return _strict_renderer.validate(template)

View File

@@ -1,5 +1,6 @@
import os import os
import subprocess import subprocess
from app.repositories.neo4j.create_indexes import create_all_indexes
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from fastapi import FastAPI, APIRouter from fastapi import FastAPI, APIRouter
@@ -60,8 +61,10 @@ async def lifespan(app: FastAPI):
logger.warning(f"加载预定义模型时出错: {str(e)}") logger.warning(f"加载预定义模型时出错: {str(e)}")
else: else:
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)") logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
await create_all_indexes()
logger.info("应用程序启动完成") logger.info("应用程序启动完成")
yield yield
# 应用关闭事件 # 应用关闭事件
logger.info("应用程序正在关闭") logger.info("应用程序正在关闭")

View File

@@ -19,9 +19,12 @@ class User(Base):
last_login_at = Column(DateTime, nullable=True) # 最后登录时间,可为空 last_login_at = Column(DateTime, nullable=True) # 最后登录时间,可为空
# SSO 外部关联字段 # SSO 外部关联字段
external_id = Column(String(100), nullable=True) # 外部用户ID external_id = Column(String(100), nullable=True) # 外部用户 ID
external_source = Column(String(50), nullable=True) # 来源系统 external_source = Column(String(50), nullable=True) # 来源系统
# 用户联系方式
phone = Column(String(50), nullable=True) # 用户电话
# 用户语言偏好 # 用户语言偏好
preferred_language = Column(String(10), server_default=text("'zh'"), default='zh', nullable=False, index=True) # 用户偏好语言,默认中文 preferred_language = Column(String(10), server_default=text("'zh'"), default='zh', nullable=False, index=True) # 用户偏好语言,默认中文

View File

@@ -199,6 +199,96 @@ class ConversationRepository:
) )
return conversations, total return conversations, total
def list_app_conversations(
self,
app_id: uuid.UUID,
workspace_id: uuid.UUID,
is_draft: Optional[bool] = None,
page: int = 1,
pagesize: int = 20
) -> tuple[list[Conversation], int]:
"""
查询应用日志会话列表(带分页和过滤)
Args:
app_id: 应用 ID
workspace_id: 工作空间 ID
is_draft: 是否草稿会话None 表示不过滤)
page: 页码(从 1 开始)
pagesize: 每页数量
Returns:
Tuple[List[Conversation], int]: (会话列表,总数)
"""
stmt = select(Conversation).where(
Conversation.app_id == app_id,
Conversation.workspace_id == workspace_id,
Conversation.is_active.is_(True)
)
if is_draft is not None:
stmt = stmt.where(Conversation.is_draft == is_draft)
# Calculate total number of records
total = int(self.db.execute(
select(func.count()).select_from(stmt.subquery())
).scalar_one())
# Apply pagination
stmt = stmt.order_by(desc(Conversation.updated_at))
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
conversations = list(self.db.scalars(stmt).all())
logger.info(
"Listed app conversations successfully",
extra={
"app_id": str(app_id),
"workspace_id": str(workspace_id),
"returned": len(conversations),
"total": total
}
)
return conversations, total
def get_conversation_for_app_log(
self,
conversation_id: uuid.UUID,
app_id: uuid.UUID,
workspace_id: uuid.UUID
) -> Conversation:
"""
查询应用日志的会话详情
Args:
conversation_id: 会话 ID
app_id: 应用 ID
workspace_id: 工作空间 ID
Returns:
Conversation: 会话对象
Raises:
ResourceNotFoundException: 当会话不存在时
"""
logger.info(f"Fetching conversation for app log: {conversation_id}")
stmt = select(Conversation).where(
Conversation.id == conversation_id,
Conversation.app_id == app_id,
Conversation.workspace_id == workspace_id,
Conversation.is_active.is_(True)
)
conversation = self.db.scalars(stmt).first()
if not conversation:
logger.warning(f"Conversation not found: {conversation_id}")
raise ResourceNotFoundException("会话", str(conversation_id))
logger.info(f"Conversation fetched successfully: {conversation_id}")
return conversation
def soft_delete_conversation_by_conversation_id( def soft_delete_conversation_by_conversation_id(
self, self,
conversation_id: uuid.UUID, conversation_id: uuid.UUID,
@@ -290,6 +380,34 @@ class MessageRepository:
self.db.add(message) self.db.add(message)
return message return message
def get_messages_by_conversation(
self,
conversation_id: uuid.UUID
) -> list[Message]:
"""
查询会话的所有消息(按时间正序)
Args:
conversation_id: 会话 ID
Returns:
List[Message]: 消息列表
"""
stmt = select(Message).where(
Message.conversation_id == conversation_id
).order_by(Message.created_at)
messages = list(self.db.scalars(stmt).all())
logger.info(
"Fetched messages for conversation",
extra={
"conversation_id": str(conversation_id),
"message_count": len(messages)
}
)
return messages
def get_message_by_conversation_id( def get_message_by_conversation_id(
self, self,
conversation_id: uuid.UUID, conversation_id: uuid.UUID,

View File

@@ -132,6 +132,82 @@ class EndUserRepository:
db_logger.error(f"获取或创建终端用户时出错: {str(e)}") db_logger.error(f"获取或创建终端用户时出错: {str(e)}")
raise raise
def get_or_create_end_user_with_config(
self,
app_id: Optional[uuid.UUID],
workspace_id: uuid.UUID,
other_id: str,
memory_config_id: Optional[uuid.UUID] = None,
other_name: Optional[str] = None
) -> EndUser:
"""获取或创建终端用户,并在单次事务中关联记忆配置。
与 get_or_create_end_user 类似,但额外支持在创建/获取时
一并设置 memory_config_id避免多次提交。
Args:
app_id: 应用ID可为 None
workspace_id: 工作空间ID
other_id: 第三方ID
memory_config_id: 记忆配置ID可选仅在用户尚无配置时设置
other_name: 用户名称(用于创建 EndUserInfo
Returns:
EndUser: 终端用户对象(已关联记忆配置)
"""
try:
end_user = (
self.db.query(EndUser)
.filter(
EndUser.workspace_id == workspace_id,
EndUser.other_id == other_id
)
.order_by(EndUser.created_at.asc())
.first()
)
if end_user:
db_logger.debug(f"找到现有终端用户: workspace_id={workspace_id}, other_id={other_id}")
if app_id is not None:
end_user.app_id = app_id
if memory_config_id and not end_user.memory_config_id:
end_user.memory_config_id = memory_config_id
self.db.commit()
self.db.refresh(end_user)
return end_user
# 创建新用户
end_user = EndUser(
app_id=app_id,
workspace_id=workspace_id,
other_id=other_id,
memory_config_id=memory_config_id,
)
self.db.add(end_user)
self.db.flush()
end_user_info = EndUserInfo(
end_user_id=end_user.id,
other_name=other_name or "",
aliases=[],
meta_data={}
)
self.db.add(end_user_info)
self.db.commit()
self.db.refresh(end_user)
db_logger.info(
f"创建新终端用户及其信息: (other_id: {other_id}) for workspace {workspace_id}, "
f"memory_config_id={memory_config_id}"
)
return end_user
except Exception as e:
self.db.rollback()
db_logger.error(f"获取或创建终端用户(含配置)时出错: {str(e)}")
raise
def get_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]: def get_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]:
"""根据ID获取终端用户用于缓存操作 """根据ID获取终端用户用于缓存操作
@@ -515,6 +591,51 @@ class EndUserRepository:
) )
raise raise
def batch_update_memory_config_id_by_app(
self,
app_id: uuid.UUID,
memory_config_id: uuid.UUID
) -> int:
"""批量更新应用下所有终端用户的 memory_config_id
Args:
app_id: 应用ID
memory_config_id: 新的记忆配置ID
Returns:
int: 更新的终端用户数量
Raises:
Exception: 数据库操作失败时抛出
"""
try:
from sqlalchemy import update
stmt = (
update(EndUser)
.where(EndUser.app_id == app_id)
.values(memory_config_id=memory_config_id)
)
result = self.db.execute(stmt)
self.db.commit()
updated_count = result.rowcount
db_logger.info(
f"批量更新终端用户记忆配置: app_id={app_id}, "
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
)
return updated_count
except Exception as e:
self.db.rollback()
db_logger.error(
f"批量更新终端用户记忆配置时出错: app_id={app_id}, "
f"memory_config_id={memory_config_id}, error={str(e)}"
)
raise
def count_by_memory_config_id( def count_by_memory_config_id(
self, self,
memory_config_id: uuid.UUID memory_config_id: uuid.UUID

View File

@@ -78,6 +78,15 @@ class MemoryConfigRepository:
OPTIONAL MATCH (n) WHERE n.end_user_id = $end_user_id RETURN 'ALL' AS Label, COUNT(n) AS Count OPTIONAL MATCH (n) WHERE n.end_user_id = $end_user_id RETURN 'ALL' AS Label, COUNT(n) AS Count
""" """
# 批量查询多个用户的记忆数量简化版本只返回total
SEARCH_FOR_ALL_BATCH = """
MATCH (n) WHERE n.end_user_id IN $end_user_ids
RETURN
n.end_user_id as user_id,
count(n) as total
ORDER BY user_id
"""
# Extracted entity details within group/app/user # Extracted entity details within group/app/user
SEARCH_FOR_DETIALS = """ SEARCH_FOR_DETIALS = """
MATCH (n:ExtractedEntity) MATCH (n:ExtractedEntity)

View File

@@ -1,62 +1,47 @@
import asyncio
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def create_fulltext_indexes(): async def create_fulltext_indexes():
"""Create full-text indexes for keyword search with BM25 scoring.""" """Create full-text indexes for keyword search with BM25 scoring."""
connector = Neo4jConnector() connector = Neo4jConnector()
try: try:
print("\n" + "=" * 70)
print("Creating Full-Text Indexes (for keyword search)")
print("=" * 70)
# 创建 Statements 索引 # 创建 Statements 索引
await connector.execute_query(""" await connector.execute_query("""
CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement] CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""") """)
print("✓ Created: statementsFulltext")
# # 创建 Dialogues 索引 # # 创建 Dialogues 索引
# await connector.execute_query(""" # await connector.execute_query("""
# CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content] # CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content]
# OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } # OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
# """) # """)
# 创建 Entities 索引 # 创建 Entities 索引
await connector.execute_query(""" await connector.execute_query("""
CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name] CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""") """)
print("✓ Created: entitiesFulltext")
# 创建 Chunks 索引 # 创建 Chunks 索引
await connector.execute_query(""" await connector.execute_query("""
CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content] CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""") """)
print("✓ Created: chunksFulltext")
# 创建 MemorySummary 索引 # 创建 MemorySummary 索引
await connector.execute_query(""" await connector.execute_query("""
CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content] CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""") """)
print("✓ Created: summariesFulltext")
# 创建 Community 索引 # 创建 Community 索引
await connector.execute_query(""" await connector.execute_query("""
CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary] CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary]
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } } OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
""") """)
print("✓ Created: communitiesFulltext")
print("\nFull-text indexes created successfully with BM25 support.")
except Exception as e:
print(f"✗ Error creating full-text indexes: {e}")
finally: finally:
await connector.close() await connector.close()
async def create_vector_indexes(): async def create_vector_indexes():
"""Create vector indexes for fast embedding similarity search. """Create vector indexes for fast embedding similarity search.
@@ -65,12 +50,7 @@ async def create_vector_indexes():
""" """
connector = Neo4jConnector() connector = Neo4jConnector()
try: try:
print("\n" + "=" * 70)
print("Creating Vector Indexes (for embedding search)")
print("=" * 70)
print("Note: Adjust vector.dimensions if using different embedding model")
print(" Current setting: 1024 dimensions (for bge-m3)")
print()
# Statement embedding index # Statement embedding index
await connector.execute_query(""" await connector.execute_query("""
@@ -82,7 +62,7 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine' `vector.similarity_function`: 'cosine'
}} }}
""") """)
print("✓ Created: statement_embedding_index")
# Chunk embedding index # Chunk embedding index
await connector.execute_query(""" await connector.execute_query("""
@@ -94,7 +74,7 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine' `vector.similarity_function`: 'cosine'
}} }}
""") """)
print("✓ Created: chunk_embedding_index")
# Entity name embedding index # Entity name embedding index
await connector.execute_query(""" await connector.execute_query("""
@@ -106,7 +86,7 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine' `vector.similarity_function`: 'cosine'
}} }}
""") """)
print("✓ Created: entity_embedding_index")
# Memory summary embedding index # Memory summary embedding index
await connector.execute_query(""" await connector.execute_query("""
@@ -118,8 +98,7 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine' `vector.similarity_function`: 'cosine'
}} }}
""") """)
print("✓ Created: summary_embedding_index")
# Community summary embedding index # Community summary embedding index
await connector.execute_query(""" await connector.execute_query("""
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
@@ -129,8 +108,7 @@ async def create_vector_indexes():
`vector.dimensions`: 1024, `vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine' `vector.similarity_function`: 'cosine'
}} }}
""") """)
print("✓ Created: community_summary_embedding_index")
# Dialogue embedding index (optional) # Dialogue embedding index (optional)
await connector.execute_query(""" await connector.execute_query("""
@@ -142,91 +120,15 @@ async def create_vector_indexes():
`vector.similarity_function`: 'cosine' `vector.similarity_function`: 'cosine'
}} }}
""") """)
print("✓ Created: dialogue_embedding_index")
# Community summary embedding index
await connector.execute_query("""
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
FOR (c:Community)
ON c.summary_embedding
OPTIONS {indexConfig: {
`vector.dimensions`: 1024,
`vector.similarity_function`: 'cosine'
}}
""")
print("✓ Created: community_summary_embedding_index")
print("\nVector indexes created successfully!")
print("\nExpected performance improvement:")
print(" Before: ~1.4s for embedding search")
print(" After: ~0.05-0.2s for embedding search (10-30x faster!)")
except Exception as e:
print(f"✗ Error creating vector indexes: {e}")
finally: finally:
await connector.close() await connector.close()
async def create_config_id_indexes():
"""Create indexes on config_id fields for improved query performance.
These indexes enable fast filtering of nodes by configuration ID,
which is essential for configuration isolation and multi-tenant scenarios.
"""
connector = Neo4jConnector()
try:
print("\n" + "=" * 70)
print("Creating Config ID Indexes")
print("=" * 70)
# Dialogue.config_id index
await connector.execute_query("""
CREATE INDEX dialogue_config_id_index IF NOT EXISTS
FOR (d:Dialogue) ON (d.config_id)
""")
print("✓ Created: dialogue_config_id_index")
# Statement.config_id index
await connector.execute_query("""
CREATE INDEX statement_config_id_index IF NOT EXISTS
FOR (s:Statement) ON (s.config_id)
""")
print("✓ Created: statement_config_id_index")
# ExtractedEntity.config_id index
await connector.execute_query("""
CREATE INDEX entity_config_id_index IF NOT EXISTS
FOR (e:ExtractedEntity) ON (e.config_id)
""")
print("✓ Created: entity_config_id_index")
# MemorySummary.config_id index
await connector.execute_query("""
CREATE INDEX summary_config_id_index IF NOT EXISTS
FOR (m:MemorySummary) ON (m.config_id)
""")
print("✓ Created: summary_config_id_index")
print("\nConfig ID indexes created successfully!")
print("These indexes enable fast filtering by configuration ID.")
except Exception as e:
print(f"✗ Error creating config_id indexes: {e}")
finally:
await connector.close()
async def create_unique_constraints(): async def create_unique_constraints():
"""Create uniqueness constraints for core node identifiers. """Create uniqueness constraints for core node identifiers.
Ensures concurrent MERGE operations remain safe and prevents duplicates. Ensures concurrent MERGE operations remain safe and prevents duplicates.
""" """
connector = Neo4jConnector() connector = Neo4jConnector()
try: try:
print("\n" + "=" * 70)
print("Creating Unique Constraints")
print("=" * 70)
# Dialogue.id unique # Dialogue.id unique
await connector.execute_query( await connector.execute_query(
""" """
@@ -234,8 +136,7 @@ async def create_unique_constraints():
FOR (d:Dialogue) REQUIRE d.id IS UNIQUE FOR (d:Dialogue) REQUIRE d.id IS UNIQUE
""" """
) )
print("✓ Created: dialog_id_unique")
# Statement.id unique # Statement.id unique
await connector.execute_query( await connector.execute_query(
""" """
@@ -243,8 +144,7 @@ async def create_unique_constraints():
FOR (s:Statement) REQUIRE s.id IS UNIQUE FOR (s:Statement) REQUIRE s.id IS UNIQUE
""" """
) )
print("✓ Created: statement_id_unique")
# Chunk.id unique # Chunk.id unique
await connector.execute_query( await connector.execute_query(
""" """
@@ -252,112 +152,13 @@ async def create_unique_constraints():
FOR (c:Chunk) REQUIRE c.id IS UNIQUE FOR (c:Chunk) REQUIRE c.id IS UNIQUE
""" """
) )
print("✓ Created: chunk_id_unique")
print("\nUnique constraints ensured for Dialogue, Statement, and Chunk.")
except Exception as e:
print(f"✗ Error creating unique constraints: {e}")
finally: finally:
await connector.close() await connector.close()
async def create_all_indexes(): async def create_all_indexes():
"""Create all indexes and constraints in one go.""" """Create all indexes and constraints in one go."""
print("\n" + "=" * 70)
print("Neo4j Index & Constraint Setup")
print("=" * 70)
print("This will create:")
print(" 1. Full-text indexes (for keyword/BM25 search)")
print(" 2. Vector indexes (for embedding similarity search)")
print(" 3. Config ID indexes (for configuration isolation)")
print(" 4. Unique constraints (for data integrity)")
print("=" * 70)
await create_fulltext_indexes() await create_fulltext_indexes()
await create_vector_indexes() await create_vector_indexes()
await create_config_id_indexes()
await create_unique_constraints() await create_unique_constraints()
print("\n" + "=" * 70)
print("✓ All indexes and constraints created successfully!") print("✓ All indexes and constraints created successfully!")
print("=" * 70)
print("\nTo verify, run in Neo4j Browser:")
print(" SHOW INDEXES")
print(" SHOW CONSTRAINTS")
print()
async def check_indexes():
"""Check what indexes currently exist."""
connector = Neo4jConnector()
try:
print("\n" + "=" * 70)
print("Checking Existing Indexes")
print("=" * 70)
query = "SHOW INDEXES"
result = await connector.execute_query(query)
fulltext_indexes = [idx for idx in result if idx.get('type') == 'FULLTEXT']
vector_indexes = [idx for idx in result if idx.get('type') == 'VECTOR']
range_indexes = [idx for idx in result if idx.get('type') == 'RANGE']
print(f"\nFull-text indexes: {len(fulltext_indexes)}")
for idx in fulltext_indexes:
print(f"{idx.get('name')}")
print(f"\nVector indexes: {len(vector_indexes)}")
for idx in vector_indexes:
print(f"{idx.get('name')}")
print(f"\nRange indexes (including config_id): {len(range_indexes)}")
for idx in range_indexes:
print(f"{idx.get('name')}")
if not vector_indexes:
print("\n⚠️ WARNING: No vector indexes found!")
print(" Embedding search will be VERY SLOW (~1.4s)")
print(" Run: python create_indexes.py")
# Check for config_id indexes
config_id_indexes = [idx for idx in range_indexes if 'config_id' in idx.get('name', '')]
if len(config_id_indexes) < 4:
print("\n⚠️ WARNING: Not all config_id indexes found!")
print(f" Expected 4, found {len(config_id_indexes)}")
print(" Run: python create_indexes.py config_id")
print("=" * 70)
finally:
await connector.close()
if __name__ == "__main__":
import asyncio
import sys
if len(sys.argv) > 1:
command = sys.argv[1]
if command == "check":
asyncio.run(check_indexes())
elif command == "fulltext":
asyncio.run(create_fulltext_indexes())
elif command == "vector":
asyncio.run(create_vector_indexes())
elif command == "config_id":
asyncio.run(create_config_id_indexes())
elif command == "constraints":
asyncio.run(create_unique_constraints())
else:
print(f"Unknown command: {command}")
print("\nUsage:")
print(" python create_indexes.py # Create all indexes")
print(" python create_indexes.py check # Check existing indexes")
print(" python create_indexes.py fulltext # Create only full-text indexes")
print(" python create_indexes.py vector # Create only vector indexes")
print(" python create_indexes.py config_id # Create only config_id indexes")
print(" python create_indexes.py constraints # Create only constraints")
else:
asyncio.run(create_all_indexes())

View File

@@ -340,17 +340,22 @@ SEARCH_ENTITIES_BY_NAME_OR_ALIAS = """
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
WITH e, score WITH e, score
UNION WITH collect({entity: e, score: score}) AS fulltextResults
MATCH (e:ExtractedEntity)
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id) OPTIONAL MATCH (ae:ExtractedEntity)
AND e.aliases IS NOT NULL WHERE ($end_user_id IS NULL OR ae.end_user_id = $end_user_id)
AND ANY(alias IN e.aliases WHERE toLower(alias) CONTAINS toLower($q)) AND ae.aliases IS NOT NULL
WITH e, AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($q))
WITH fulltextResults, collect(ae) AS aliasEntities
UNWIND (fulltextResults + [x IN aliasEntities | {entity: x, score:
CASE CASE
WHEN ANY(alias IN e.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0 WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0
WHEN ANY(alias IN e.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9 WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9
ELSE 0.8 ELSE 0.8
END AS score END
}]) AS row
WITH row.entity AS e, row.score AS score
WITH DISTINCT e, MAX(score) AS score WITH DISTINCT e, MAX(score) AS score
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e) OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s) OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)

View File

@@ -158,22 +158,26 @@ class UserRepository:
raise raise
def get_users_by_tenant( def get_users_by_tenant(
self, self,
tenant_id: uuid.UUID, tenant_id: uuid.UUID,
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
is_active: Optional[bool] = None, is_active: Optional[bool] = None,
is_superuser: Optional[bool] = None,
search: Optional[str] = None search: Optional[str] = None
) -> List[User]: ) -> List[User]:
"""获取租户下的用户列表""" """获取租户下的用户列表"""
db_logger.debug(f"查询租户用户: tenant_id={tenant_id}") db_logger.debug(f"查询租户用户: tenant_id={tenant_id}")
try: try:
query = self.db.query(User).options(joinedload(User.tenant)).filter(User.tenant_id == tenant_id) query = self.db.query(User).options(joinedload(User.tenant)).filter(User.tenant_id == tenant_id)
if is_active is not None: if is_active is not None:
query = query.filter(User.is_active == is_active) query = query.filter(User.is_active == is_active)
if is_superuser is not None:
query = query.filter(User.is_superuser == is_superuser)
if search: if search:
query = query.filter( query = query.filter(
or_( or_(
@@ -181,7 +185,7 @@ class UserRepository:
User.email.ilike(f"%{search}%") User.email.ilike(f"%{search}%")
) )
) )
users = query.offset(skip).limit(limit).all() users = query.offset(skip).limit(limit).all()
db_logger.debug(f"租户用户查询成功: tenant_id={tenant_id}, count={len(users)}") db_logger.debug(f"租户用户查询成功: tenant_id={tenant_id}, count={len(users)}")
return users return users
@@ -190,18 +194,22 @@ class UserRepository:
raise raise
def count_users_by_tenant( def count_users_by_tenant(
self, self,
tenant_id: uuid.UUID, tenant_id: uuid.UUID,
is_active: Optional[bool] = None, is_active: Optional[bool] = None,
is_superuser: Optional[bool] = None,
search: Optional[str] = None search: Optional[str] = None
) -> int: ) -> int:
"""统计租户下的用户数量""" """统计租户下的用户数量"""
try: try:
query = self.db.query(func.count(User.id)).filter(User.tenant_id == tenant_id) query = self.db.query(func.count(User.id)).filter(User.tenant_id == tenant_id)
if is_active is not None: if is_active is not None:
query = query.filter(User.is_active == is_active) query = query.filter(User.is_active == is_active)
if is_superuser is not None:
query = query.filter(User.is_superuser == is_superuser)
if search: if search:
query = query.filter( query = query.filter(
or_( or_(
@@ -209,7 +217,7 @@ class UserRepository:
User.email.ilike(f"%{search}%") User.email.ilike(f"%{search}%")
) )
) )
return query.scalar() return query.scalar()
except Exception as e: except Exception as e:
db_logger.error(f"统计租户用户失败: tenant_id={tenant_id} - {str(e)}") db_logger.error(f"统计租户用户失败: tenant_id={tenant_id} - {str(e)}")

View File

@@ -3,9 +3,9 @@
""" """
import uuid import uuid
from typing import Any, Annotated from typing import Any, Annotated, Literal
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy import desc from sqlalchemy import desc, select
from fastapi import Depends from fastapi import Depends
from app.models.workflow_model import ( from app.models.workflow_model import (
@@ -128,29 +128,36 @@ class WorkflowExecutionRepository:
Returns: Returns:
执行记录列表 执行记录列表
""" """
return self.db.query(WorkflowExecution).filter( stmt = select(WorkflowExecution).filter(
WorkflowExecution.app_id == app_id WorkflowExecution.app_id == app_id
).order_by( ).order_by(
desc(WorkflowExecution.started_at) desc(WorkflowExecution.started_at)
).limit(limit).offset(offset).all() ).limit(limit).offset(offset)
return list(self.db.execute(stmt).scalars())
def get_by_conversation_id( def get_by_conversation_id(
self, self,
conversation_id: uuid.UUID conversation_id: uuid.UUID,
status: Literal["running", "completed", "failed"] = None,
limit_count: int = 50
) -> list[WorkflowExecution]: ) -> list[WorkflowExecution]:
"""根据会话 ID 获取执行记录列表 """根据会话 ID 获取执行记录列表
Args: Args:
limit_count:
conversation_id: 会话 ID conversation_id: 会话 ID
status: 状态(可选)
Returns: Returns:
执行记录列表 执行记录列表
""" """
return self.db.query(WorkflowExecution).filter( stmt = select(WorkflowExecution).filter(
WorkflowExecution.conversation_id == conversation_id WorkflowExecution.conversation_id == conversation_id
).order_by( )
desc(WorkflowExecution.started_at) if status:
).all() stmt = stmt.filter(WorkflowExecution.status == status)
stmt = stmt.order_by(desc(WorkflowExecution.started_at)).limit(limit_count)
return list(self.db.execute(stmt).scalars())
def count_by_app_id(self, app_id: uuid.UUID) -> int: def count_by_app_id(self, app_id: uuid.UUID) -> int:
"""统计应用的执行次数 """统计应用的执行次数
@@ -199,11 +206,12 @@ class WorkflowNodeExecutionRepository:
Returns: Returns:
节点执行记录列表(按执行顺序排序) 节点执行记录列表(按执行顺序排序)
""" """
return self.db.query(WorkflowNodeExecution).filter( stmt = select(WorkflowNodeExecution).filter(
WorkflowNodeExecution.execution_id == execution_id WorkflowNodeExecution.execution_id == execution_id
).order_by( ).order_by(
WorkflowNodeExecution.execution_order WorkflowNodeExecution.execution_order
).all() )
return list(self.db.execute(stmt).scalars())
def get_by_node_id( def get_by_node_id(
self, self,
@@ -219,12 +227,13 @@ class WorkflowNodeExecutionRepository:
Returns: Returns:
节点执行记录列表 节点执行记录列表
""" """
return self.db.query(WorkflowNodeExecution).filter( stmt = select(WorkflowNodeExecution).filter(
WorkflowNodeExecution.execution_id == execution_id, WorkflowNodeExecution.execution_id == execution_id,
WorkflowNodeExecution.node_id == node_id WorkflowNodeExecution.node_id == node_id
).order_by( ).order_by(
WorkflowNodeExecution.retry_count WorkflowNodeExecution.retry_count
).all() )
return list(self.db.execute(stmt).scalars())
# ==================== 依赖注入函数 ==================== # ==================== 依赖注入函数 ====================

View File

@@ -276,7 +276,7 @@ class AgentConfigCreate(BaseModel):
# 记忆配置 # 记忆配置
memory: MemoryConfig = Field( memory: MemoryConfig = Field(
default_factory=lambda: MemoryConfig(enabled=True), default_factory=lambda: MemoryConfig(enabled=False),
description="对话历史记忆配置" description="对话历史记忆配置"
) )

View File

@@ -17,6 +17,7 @@ class Write_UserInput(BaseModel):
end_user_id: str end_user_id: str
config_id: Optional[str] = None config_id: Optional[str] = None
class AgentMemory_Long_Term(ABC): class AgentMemory_Long_Term(ABC):
"""长期记忆配置常量""" """长期记忆配置常量"""
STORAGE_NEO4J = "neo4j" STORAGE_NEO4J = "neo4j"
@@ -25,8 +26,9 @@ class AgentMemory_Long_Term(ABC):
STRATEGY_CHUNK = "chunk" STRATEGY_CHUNK = "chunk"
STRATEGY_TIME = "time" STRATEGY_TIME = "time"
DEFAULT_SCOPE = 6 DEFAULT_SCOPE = 6
TIME_SCOPE=5 TIME_SCOPE = 5
class AgentMemoryDataset(ABC):
PRONOUN=['','本人','在下','自己','','鄙人','','']
NAME='用户'
class AgentMemoryDataset(ABC):
PRONOUN = ['', '本人', '在下', '自己', '', '鄙人', '', '']
NAME = '用户'

View File

@@ -138,21 +138,13 @@ class CreateEndUserRequest(BaseModel):
"""Request schema for creating an end user. """Request schema for creating an end user.
Attributes: Attributes:
workspace_id: Workspace ID (required)
other_id: External user identifier (required) other_id: External user identifier (required)
other_name: Display name for the end user other_name: Display name for the end user
memory_config_id: Optional memory config ID. If not provided, uses workspace default.
""" """
workspace_id: str = Field(..., description="Workspace ID (required)")
other_id: str = Field(..., description="External user identifier (required)") other_id: str = Field(..., description="External user identifier (required)")
other_name: Optional[str] = Field("", description="Display name") other_name: Optional[str] = Field("", description="Display name")
memory_config_id: Optional[str] = Field(None, description="Memory config ID. Falls back to workspace default if not provided.")
@field_validator("workspace_id")
@classmethod
def validate_workspace_id(cls, v: str) -> str:
"""Validate that workspace_id is not empty."""
if not v or not v.strip():
raise ValueError("workspace_id is required and cannot be empty")
return v.strip()
@field_validator("other_id") @field_validator("other_id")
@classmethod @classmethod
@@ -171,11 +163,13 @@ class CreateEndUserResponse(BaseModel):
other_id: External user identifier other_id: External user identifier
other_name: Display name other_name: Display name
workspace_id: Workspace the user belongs to workspace_id: Workspace the user belongs to
memory_config_id: Connected memory config ID
""" """
id: str = Field(..., description="End user UUID") id: str = Field(..., description="End user UUID")
other_id: str = Field(..., description="External user identifier") other_id: str = Field(..., description="External user identifier")
other_name: str = Field("", description="Display name") other_name: str = Field("", description="Display name")
workspace_id: str = Field(..., description="Workspace ID") workspace_id: str = Field(..., description="Workspace ID")
memory_config_id: Optional[str] = Field(None, description="Connected memory config ID")
class MemoryConfigItem(BaseModel): class MemoryConfigItem(BaseModel):

View File

@@ -478,6 +478,22 @@ class PendingForgettingNode(BaseModel):
last_access_time: int = Field(..., description="最后访问时间Unix时间戳") last_access_time: int = Field(..., description="最后访问时间Unix时间戳")
class PageInfo(BaseModel):
"""分页信息模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
page: int = Field(..., description="当前页码从1开始")
pagesize: int = Field(..., description="每页数量")
total: int = Field(..., description="总记录数")
hasnext: bool = Field(..., description="是否有下一页")
class PendingNodesResponse(BaseModel):
"""待遗忘节点列表响应模型(独立分页接口)"""
model_config = ConfigDict(populate_by_name=True, extra="forbid")
items: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表")
page: PageInfo = Field(..., description="分页信息")
class ForgettingStatsResponse(BaseModel): class ForgettingStatsResponse(BaseModel):
"""遗忘引擎统计信息响应模型""" """遗忘引擎统计信息响应模型"""
model_config = ConfigDict(populate_by_name=True, extra="forbid") model_config = ConfigDict(populate_by_name=True, extra="forbid")
@@ -485,7 +501,6 @@ class ForgettingStatsResponse(BaseModel):
node_distribution: Dict[str, int] = Field(..., description="节点类型分布") node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
recent_trends: List[ForgettingCycleHistoryPoint] = Field(..., recent_trends: List[ForgettingCycleHistoryPoint] = Field(...,
description="最近7个日期的遗忘趋势数据每天取最后一次执行") description="最近7个日期的遗忘趋势数据每天取最后一次执行")
pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表前20个满足遗忘条件的节点")
timestamp: int = Field(..., description="统计时间(时间戳)") timestamp: int = Field(..., description="统计时间(时间戳)")

View File

@@ -1,6 +1,6 @@
from dataclasses import field from dataclasses import field
from pydantic import BaseModel, EmailStr, Field, field_validator, validator, ConfigDict from pydantic import BaseModel, EmailStr, Field, field_validator, validator, ConfigDict
from typing import Optional from typing import Optional, List
import datetime import datetime
import uuid import uuid
@@ -20,6 +20,7 @@ class UserCreate(UserBase):
class UserUpdate(BaseModel): class UserUpdate(BaseModel):
username: Optional[str] = None username: Optional[str] = None
email: Optional[EmailStr] = None email: Optional[EmailStr] = None
phone: Optional[str] = None
is_active: Optional[bool] = None is_active: Optional[bool] = None
is_superuser: Optional[bool] = None is_superuser: Optional[bool] = None
@@ -85,6 +86,8 @@ class User(UserBase):
current_workspace_name: Optional[str] = None current_workspace_name: Optional[str] = None
role: Optional[WorkspaceRole] = None role: Optional[WorkspaceRole] = None
preferred_language: Optional[str] = "zh" # 用户语言偏好 preferred_language: Optional[str] = "zh" # 用户语言偏好
phone: Optional[str] = None # 用户电话
permissions: Optional[List[str]] = None # 用户权限列表,由 external_source 的 permissions 控制
# 将 datetime 转换为毫秒时间戳 # 将 datetime 转换为毫秒时间戳
@validator("created_at", pre=True) @validator("created_at", pre=True)

View File

@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
from app.core.agent.langchain_agent import LangChainAgent from app.core.agent.langchain_agent import LangChainAgent
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db from app.db import get_db
from app.models import MultiAgentConfig, AgentConfig, ModelType from app.models import MultiAgentConfig, AgentConfig, ModelType
from app.models import WorkflowConfig from app.models import WorkflowConfig
@@ -20,11 +21,11 @@ from app.schemas.model_schema import ModelInfo
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
from app.services.draft_run_service import AgentRunService from app.services.draft_run_service import AgentRunService
from app.services.memory_agent_service import get_end_user_connected_config
from app.services.model_service import ModelApiKeyService from app.services.model_service import ModelApiKeyService
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
from app.services.multimodal_service import MultimodalService from app.services.multimodal_service import MultimodalService
from app.services.workflow_service import WorkflowService from app.services.workflow_service import WorkflowService
from app.schemas import FileType
logger = get_business_logger() logger = get_business_logger()
@@ -43,18 +44,17 @@ class AppChatService:
message: str, message: str,
conversation_id: uuid.UUID, conversation_id: uuid.UUID,
config: AgentConfig, config: AgentConfig,
user_id: Optional[str] = None, files: list[FileInput],
user_id: str,
variables: Optional[Dict[str, Any]] = None, variables: Optional[Dict[str, Any]] = None,
web_search: bool = False, web_search: bool = False,
memory: bool = True, memory: bool = True,
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None, user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None, workspace_id: Optional[str] = None
files: Optional[List[FileInput]] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""聊天(非流式)""" """聊天(非流式)"""
start_time = time.time() start_time = time.time()
config_id = None
# 应用 features 配置 # 应用 features 配置
features_config: dict = config.features or {} features_config: dict = config.features or {}
@@ -93,7 +93,8 @@ class AppChatService:
tools.extend(skill_tools) tools.extend(skill_tools)
if skill_prompts: if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}" system_prompt = f"{system_prompt}\n\n{skill_prompts}"
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval,
user_id)
tools.extend(kb_tools) tools.extend(kb_tools)
memory_flag = False memory_flag = False
if memory: if memory:
@@ -140,13 +141,13 @@ class AppChatService:
# 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库 # 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库
is_new_conversation = len(history) == 0 is_new_conversation = len(history) == 0
if is_new_conversation: if is_new_conversation:
opening = self.agent_service._get_opening_statement(features_config, True, variables) opening, suggested_questions = self.agent_service._get_opening_statement(features_config, True, variables)
if opening: if opening:
self.conversation_service.add_message( self.conversation_service.add_message(
conversation_id=conversation_id, conversation_id=conversation_id,
role="assistant", role="assistant",
content=opening, content=opening,
meta_data={} meta_data={"suggested_questions": suggested_questions}
) )
# 重新加载历史(包含刚写入的开场白) # 重新加载历史(包含刚写入的开场白)
history = await self.conversation_service.get_conversation_history( history = await self.conversation_service.get_conversation_history(
@@ -168,11 +169,6 @@ class AppChatService:
message=message, message=message,
history=history, history=history,
context=None, context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag,
files=processed_files # 传递处理后的文件 files=processed_files # 传递处理后的文件
) )
@@ -229,6 +225,21 @@ class AppChatService:
# 保存消息 # 保存消息
if audio_url: if audio_url:
assistant_meta["audio_url"] = audio_url assistant_meta["audio_url"] = audio_url
if memory_flag:
connected_config = get_end_user_connected_config(user_id, self.db)
memory_config_id: str = connected_config.get("memory_config_id")
messages = [
{"role": "user", "content": message, "files": [file.model_dump() for file in files]},
{"role": "assistant", "content": result["content"]}
]
if memory_config_id:
await write_long_term(
storage_type,
user_id,
messages,
user_rag_memory_id,
memory_config_id
)
self.conversation_service.add_message( self.conversation_service.add_message(
conversation_id=conversation_id, conversation_id=conversation_id,
role="user", role="user",
@@ -264,20 +275,19 @@ class AppChatService:
message: str, message: str,
conversation_id: uuid.UUID, conversation_id: uuid.UUID,
config: AgentConfig, config: AgentConfig,
files: list[FileInput],
user_id: Optional[str] = None, user_id: Optional[str] = None,
variables: Optional[Dict[str, Any]] = None, variables: Optional[Dict[str, Any]] = None,
web_search: bool = False, web_search: bool = False,
memory: bool = True, memory: bool = True,
storage_type: Optional[str] = None, storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None, user_rag_memory_id: Optional[str] = None,
workspace_id: Optional[str] = None, workspace_id: Optional[str] = None
files: Optional[List[FileInput]] = None
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
"""聊天(流式)""" """聊天(流式)"""
try: try:
start_time = time.time() start_time = time.time()
config_id = None
message_id = uuid.uuid4() message_id = uuid.uuid4()
# 应用 features 配置 # 应用 features 配置
@@ -319,7 +329,8 @@ class AppChatService:
tools.extend(skill_tools) tools.extend(skill_tools)
if skill_prompts: if skill_prompts:
system_prompt = f"{system_prompt}\n\n{skill_prompts}" system_prompt = f"{system_prompt}\n\n{skill_prompts}"
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(
config.knowledge_retrieval, user_id)
tools.extend(kb_tools) tools.extend(kb_tools)
# 添加长期记忆工具 # 添加长期记忆工具
memory_flag = False memory_flag = False
@@ -367,13 +378,13 @@ class AppChatService:
# 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库 # 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库
is_new_conversation = len(history) == 0 is_new_conversation = len(history) == 0
if is_new_conversation: if is_new_conversation:
opening = self.agent_service._get_opening_statement(features_config, True, variables) opening, suggested_questions = self.agent_service._get_opening_statement(features_config, True, variables)
if opening: if opening:
self.conversation_service.add_message( self.conversation_service.add_message(
conversation_id=conversation_id, conversation_id=conversation_id,
role="assistant", role="assistant",
content=opening, content=opening,
meta_data={} meta_data={"suggested_questions": suggested_questions}
) )
# 重新加载历史(包含刚写入的开场白) # 重新加载历史(包含刚写入的开场白)
history = await self.conversation_service.get_conversation_history( history = await self.conversation_service.get_conversation_history(
@@ -411,11 +422,6 @@ class AppChatService:
message=message, message=message,
history=history, history=history,
context=None, context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag,
files=processed_files files=processed_files
): ):
if isinstance(chunk, int): if isinstance(chunk, int):
@@ -459,7 +465,7 @@ class AppChatService:
# 保存消息 # 保存消息
human_meta = { human_meta = {
"files":[], "files": [],
"history_files": {} "history_files": {}
} }
assistant_meta = { assistant_meta = {
@@ -484,6 +490,22 @@ class AppChatService:
if stream_audio_url: if stream_audio_url:
assistant_meta["audio_url"] = stream_audio_url assistant_meta["audio_url"] = stream_audio_url
if memory_flag:
connected_config = get_end_user_connected_config(user_id, self.db)
memory_config_id: str = connected_config.get("memory_config_id")
messages = [
{"role": "user", "content": message, "files": [file.model_dump() for file in files]},
{"role": "assistant", "content": full_content}
]
if memory_config_id:
await write_long_term(
storage_type,
user_id,
messages,
user_rag_memory_id,
memory_config_id
)
self.conversation_service.add_message( self.conversation_service.add_message(
conversation_id=conversation_id, conversation_id=conversation_id,
role="user", role="user",
@@ -618,7 +640,6 @@ class AppChatService:
# 2. 创建编排器 # 2. 创建编排器
orchestrator = MultiAgentOrchestrator(self.db, config) orchestrator = MultiAgentOrchestrator(self.db, config)
# 3. 流式执行任务 # 3. 流式执行任务
async for event in orchestrator.execute_stream( async for event in orchestrator.execute_stream(
message=message, message=message,

View File

@@ -0,0 +1,128 @@
"""应用日志服务层"""
import uuid
from typing import Optional, Tuple
from datetime import datetime
from sqlalchemy.orm import Session
from app.core.logging_config import get_business_logger
from app.models.conversation_model import Conversation, Message
from app.repositories.conversation_repository import ConversationRepository, MessageRepository
logger = get_business_logger()
class AppLogService:
"""应用日志服务"""
def __init__(self, db: Session):
self.db = db
self.conversation_repository = ConversationRepository(db)
self.message_repository = MessageRepository(db)
def list_conversations(
self,
app_id: uuid.UUID,
workspace_id: uuid.UUID,
page: int = 1,
pagesize: int = 20,
is_draft: Optional[bool] = None,
) -> Tuple[list[Conversation], int]:
"""
查询应用日志会话列表
Args:
app_id: 应用 ID
workspace_id: 工作空间 ID
page: 页码(从 1 开始)
pagesize: 每页数量
is_draft: 是否草稿会话None 表示不过滤)
Returns:
Tuple[list[Conversation], int]: (会话列表,总数)
"""
logger.info(
"查询应用日志会话列表",
extra={
"app_id": str(app_id),
"workspace_id": str(workspace_id),
"page": page,
"pagesize": pagesize,
"is_draft": is_draft
}
)
# 使用 Repository 查询
conversations, total = self.conversation_repository.list_app_conversations(
app_id=app_id,
workspace_id=workspace_id,
is_draft=is_draft,
page=page,
pagesize=pagesize
)
logger.info(
"查询应用日志会话列表成功",
extra={
"app_id": str(app_id),
"total": total,
"returned": len(conversations)
}
)
return conversations, total
def get_conversation_detail(
self,
app_id: uuid.UUID,
conversation_id: uuid.UUID,
workspace_id: uuid.UUID
) -> Conversation:
"""
查询会话详情(包含消息)
Args:
app_id: 应用 ID
conversation_id: 会话 ID
workspace_id: 工作空间 ID
Returns:
Conversation: 包含消息的会话对象
Raises:
ResourceNotFoundException: 当会话不存在时
"""
logger.info(
"查询应用日志会话详情",
extra={
"app_id": str(app_id),
"conversation_id": str(conversation_id),
"workspace_id": str(workspace_id)
}
)
# 查询会话
conversation = self.conversation_repository.get_conversation_for_app_log(
conversation_id=conversation_id,
app_id=app_id,
workspace_id=workspace_id
)
# 查询消息(按时间正序)
messages = self.message_repository.get_messages_by_conversation(
conversation_id=conversation_id
)
# 将消息附加到会话对象
conversation.messages = messages
logger.info(
"查询应用日志会话详情成功",
extra={
"app_id": str(app_id),
"conversation_id": str(conversation_id),
"message_count": len(messages)
}
)
return conversation

View File

@@ -1084,7 +1084,6 @@ class AppService:
if not exists: if not exists:
cleaned["memory_config_id"] = None cleaned["memory_config_id"] = None
cleaned.pop("memory_content", None) cleaned.pop("memory_content", None)
cleaned["enabled"] = False
return cleaned return cleaned
exists = self.db.query( exists = self.db.query(
@@ -1096,7 +1095,6 @@ class AppService:
if not exists: if not exists:
cleaned["memory_config_id"] = None cleaned["memory_config_id"] = None
cleaned.pop("memory_content", None) cleaned.pop("memory_content", None)
cleaned["enabled"] = False
return cleaned return cleaned
@@ -1684,15 +1682,15 @@ class AppService:
return config.config_id return config.config_id
def _update_endusers_memory_config_by_workspace( def _update_endusers_memory_config_by_app(
self, self,
workspace_id: uuid.UUID, app_id: uuid.UUID,
memory_config_id: uuid.UUID memory_config_id: uuid.UUID
) -> int: ) -> int:
"""批量更新应用下所有终端用户的 memory_config_id """批量更新应用下所有终端用户的 memory_config_id
Args: Args:
workspace_id: 工作空间ID app_id: 应用ID
memory_config_id: 新的记忆配置ID memory_config_id: 新的记忆配置ID
Returns: Returns:
@@ -1701,8 +1699,8 @@ class AppService:
from app.repositories.end_user_repository import EndUserRepository from app.repositories.end_user_repository import EndUserRepository
repo = EndUserRepository(self.db) repo = EndUserRepository(self.db)
updated_count = repo.batch_update_memory_config_id_by_workspace( updated_count = repo.batch_update_memory_config_id_by_app(
workspace_id=workspace_id, app_id=app_id,
memory_config_id=memory_config_id memory_config_id=memory_config_id
) )
@@ -1753,12 +1751,16 @@ class AppService:
miss_params = [] miss_params = []
if agent_cfg.default_model_config_id is None: if agent_cfg.default_model_config_id is None:
miss_params.append("model config") miss_params.append("模型配置")
if agent_cfg.memory.get("enabled") and not agent_cfg.memory.get("memory_config_id"): if agent_cfg.memory.get("enabled") and not agent_cfg.memory.get("memory_config_id"):
miss_params.append("memory config") miss_params.append("记忆配置")
if miss_params: if miss_params:
raise BusinessException(f"{', '.join(miss_params)} is required") raise BusinessException(
f"应用发布失败:检测到以下必要配置尚未完成:{', '.join(miss_params)}。请返回应用编辑页面完成相关配置后再尝试发布。",
BizCode.CONFIG_MISSING,
context={"missing_params": miss_params},
)
config = { config = {
"system_prompt": agent_cfg.system_prompt, "system_prompt": agent_cfg.system_prompt,
@@ -1877,8 +1879,8 @@ class AppService:
if memory_config_id: if memory_config_id:
app = self.db.query(App).filter(App.id == app_id).first() app = self.db.query(App).filter(App.id == app_id).first()
if app: if app:
updated_count = self._update_endusers_memory_config_by_workspace( updated_count = self._update_endusers_memory_config_by_app(
app.workspace_id, memory_config_id app_id, memory_config_id
) )
logger.info( logger.info(
f"发布时更新终端用户记忆配置: app_id={app_id}, workspace_id={app.workspace_id}, " f"发布时更新终端用户记忆配置: app_id={app_id}, workspace_id={app.workspace_id}, "
@@ -2014,7 +2016,7 @@ class AppService:
if memory_config_id: if memory_config_id:
updated_count = self._update_endusers_memory_config_by_workspace(app.workspace_id, memory_config_id) updated_count = self._update_endusers_memory_config_by_app(app_id, memory_config_id)
logger.info( logger.info(
f"回滚时更新终端用户记忆配置: app_id={app_id}, version={version}, " f"回滚时更新终端用户记忆配置: app_id={app_id}, version={version}, "
f"memory_config_id={memory_config_id}, updated_count={updated_count}" f"memory_config_id={memory_config_id}, updated_count={updated_count}"

View File

@@ -214,7 +214,7 @@ class ConversationService:
conversation.message_count += 1 conversation.message_count += 1
if conversation.message_count == 1 and role == "user": if conversation.message_count <= 2 and role == "user":
conversation.title = ( conversation.title = (
content[:50] + ("..." if len(content) > 50 else "") content[:50] + ("..." if len(content) > 50 else "")
) )

View File

@@ -24,7 +24,7 @@ from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger from app.core.logging_config import get_business_logger
from app.core.rag.nlp.search import knowledge_retrieval from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context from app.db import get_db_context
from app.models import AgentConfig, ModelConfig, ModelType from app.models import AgentConfig, ModelConfig
from app.repositories.tool_repository import ToolRepository from app.repositories.tool_repository import ToolRepository
from app.schemas.app_schema import FileInput, Citation from app.schemas.app_schema import FileInput, Citation
from app.schemas.model_schema import ModelInfo from app.schemas.model_schema import ModelInfo
@@ -37,7 +37,6 @@ from app.services.model_parameter_merger import ModelParameterMerger
from app.services.model_service import ModelApiKeyService from app.services.model_service import ModelApiKeyService
from app.services.multimodal_service import MultimodalService from app.services.multimodal_service import MultimodalService
from app.services.tool_service import ToolService from app.services.tool_service import ToolService
from app.schemas import FileType
logger = get_business_logger() logger = get_business_logger()
@@ -449,15 +448,16 @@ class AgentRunService:
features_config: Dict[str, Any], features_config: Dict[str, Any],
is_new_conversation: bool, is_new_conversation: bool,
variables: Optional[Dict[str, Any]] = None variables: Optional[Dict[str, Any]] = None
) -> Optional[str]: ) -> tuple[Any, Any]:
"""首轮对话时返回开场白文本(支持变量替换),否则返回 None""" """首轮对话时返回开场白文本(支持变量替换),否则返回 None"""
if not is_new_conversation: if not is_new_conversation:
return None return None, None
opening = features_config.get("opening_statement", {}) opening = features_config.get("opening_statement", {})
if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")): if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")):
return None return None, None
statement = opening["statement"] statement = opening["statement"]
suggested_questions = opening["suggested_questions"]
# 如果有变量,进行替换(仅支持 {{var_name}} 格式) # 如果有变量,进行替换(仅支持 {{var_name}} 格式)
if variables: if variables:
@@ -465,7 +465,7 @@ class AgentRunService:
placeholder = f"{{{{{var_name}}}}}" placeholder = f"{{{{{var_name}}}}}"
statement = statement.replace(placeholder, str(var_value)) statement = statement.replace(placeholder, str(var_value))
return statement return statement, suggested_questions
@staticmethod @staticmethod
def _filter_citations( def _filter_citations(
@@ -599,13 +599,16 @@ class AgentRunService:
# 5. 处理会话ID创建或验证新会话时写入开场白 # 5. 处理会话ID创建或验证新会话时写入开场白
is_new_conversation = not conversation_id is_new_conversation = not conversation_id
opening = self._get_opening_statement(features_config, is_new_conversation, variables) opening, suggested_questions = None, None
if not sub_agent:
opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables)
conversation_id = await self._ensure_conversation( conversation_id = await self._ensure_conversation(
conversation_id=conversation_id, conversation_id=conversation_id,
app_id=agent_config.app_id, app_id=agent_config.app_id,
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id, user_id=user_id,
opening_statement=opening opening_statement=opening,
suggested_questions=suggested_questions
) )
model_info = ModelInfo( model_info = ModelInfo(
@@ -657,11 +660,6 @@ class AgentRunService:
message=message, message=message,
history=history, history=history,
context=context, context=context,
end_user_id=user_id,
config_id=config_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_flag=memory_flag,
files=processed_files # 传递处理后的文件 files=processed_files # 传递处理后的文件
) )
@@ -845,14 +843,17 @@ class AgentRunService:
# 5. 处理会话ID创建或验证新会话时写入开场白 # 5. 处理会话ID创建或验证新会话时写入开场白
is_new_conversation = not conversation_id is_new_conversation = not conversation_id
opening = self._get_opening_statement(features_config, is_new_conversation, variables) opening, suggested_questions = None, None
if not sub_agent:
opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables)
conversation_id = await self._ensure_conversation( conversation_id = await self._ensure_conversation(
conversation_id=conversation_id, conversation_id=conversation_id,
app_id=agent_config.app_id, app_id=agent_config.app_id,
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id, user_id=user_id,
sub_agent=sub_agent, sub_agent=sub_agent,
opening_statement=opening opening_statement=opening,
suggested_questions=suggested_questions
) )
model_info = ModelInfo( model_info = ModelInfo(
@@ -911,11 +912,6 @@ class AgentRunService:
message=message, message=message,
history=history, history=history,
context=context, context=context,
end_user_id=user_id,
config_id=config_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_flag=memory_flag,
files=processed_files files=processed_files
): ):
if isinstance(chunk, int): if isinstance(chunk, int):
@@ -1061,7 +1057,8 @@ class AgentRunService:
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
user_id: Optional[str], user_id: Optional[str],
sub_agent: bool = False, sub_agent: bool = False,
opening_statement: Optional[str] = None opening_statement: Optional[str] = None,
suggested_questions: Optional[List[str]] = None
) -> str: ) -> str:
"""确保会话存在(创建或验证) """确保会话存在(创建或验证)
@@ -1072,6 +1069,7 @@ class AgentRunService:
user_id: 用户ID user_id: 用户ID
sub_agent: 是否为子代理 sub_agent: 是否为子代理
opening_statement: 开场白(新会话时作为第一条消息写入) opening_statement: 开场白(新会话时作为第一条消息写入)
suggested_questions: 预设问题列表
Returns: Returns:
str: 会话ID str: 会话ID
@@ -1115,7 +1113,7 @@ class AgentRunService:
conversation_id=uuid.UUID(new_conv_id), conversation_id=uuid.UUID(new_conv_id),
role="assistant", role="assistant",
content=opening_statement, content=opening_statement,
meta_data={} meta_data={"suggested_questions": suggested_questions}
) )
logger.debug(f"已保存开场白到会话 {new_conv_id}") logger.debug(f"已保存开场白到会话 {new_conv_id}")

View File

@@ -37,6 +37,7 @@ from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write as write_neo4j from app.core.memory.agent.utils.write_tools import write as write_neo4j
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.log.audit_logger import audit_logger
from app.db import get_db_context from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType from app.models.knowledge_model import Knowledge, KnowledgeType
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
@@ -49,10 +50,6 @@ from app.services.memory_konwledges_server import (
) )
from app.services.memory_perceptual_service import MemoryPerceptualService from app.services.memory_perceptual_service import MemoryPerceptualService
try:
from app.core.memory.utils.log.audit_logger import audit_logger
except ImportError:
audit_logger = None
logger = get_logger(__name__) logger = get_logger(__name__)
config_logger = get_config_logger() config_logger = get_config_logger()
@@ -68,24 +65,22 @@ class MemoryAgentService:
if str(messages) == 'success': if str(messages) == 'success':
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}") logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
# 记录成功的操作 # 记录成功的操作
if audit_logger: audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True,
success=True, duration=duration, details={"message_length": len(message)})
duration=duration, details={"message_length": len(message)})
return context return context
else: else:
logger.warning(f"Write operation failed for group {end_user_id}") logger.warning(f"Write operation failed for group {end_user_id}")
# 记录失败的操作 # 记录失败的操作
if audit_logger: audit_logger.log_operation(
audit_logger.log_operation( operation="WRITE",
operation="WRITE", config_id=config_id,
config_id=config_id, end_user_id=end_user_id,
end_user_id=end_user_id, success=False,
success=False, duration=duration,
duration=duration, error=f"写入失败: {messages[:100]}"
error=f"写入失败: {messages[:100]}" )
)
raise ValueError(f"写入失败: {messages}") raise ValueError(f"写入失败: {messages}")
@@ -338,10 +333,9 @@ class MemoryAgentService:
logger.error(error_msg) logger.error(error_msg)
# Log failed operation # Log failed operation
if audit_logger: duration = time.time() - start_time
duration = time.time() - start_time audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg)
success=False, duration=duration, error=error_msg)
raise ValueError(error_msg) raise ValueError(error_msg)
@@ -401,10 +395,10 @@ class MemoryAgentService:
# Ensure proper error handling and logging # Ensure proper error handling and logging
error_msg = f"Write operation failed: {str(e)}" error_msg = f"Write operation failed: {str(e)}"
logger.error(error_msg) logger.error(error_msg)
if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
success=False, duration=duration, error=error_msg) success=False, duration=duration, error=error_msg)
raise ValueError(error_msg) raise ValueError(error_msg)
async def read_memory( async def read_memory(
@@ -469,10 +463,9 @@ class MemoryAgentService:
logger.info(f"Read operation for group {end_user_id} with config_id {config_id}") logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")
# 导入审计日志记录器 # 导入审计日志记录器
try:
from app.core.memory.utils.log.audit_logger import audit_logger
except ImportError:
audit_logger = None
config_load_start = time.time() config_load_start = time.time()
try: try:
@@ -492,16 +485,15 @@ class MemoryAgentService:
logger.error(error_msg) logger.error(error_msg)
# Log failed operation # Log failed operation
if audit_logger: duration = time.time() - start_time
duration = time.time() - start_time audit_logger.log_operation(
audit_logger.log_operation( operation="READ",
operation="READ", config_id=config_id,
config_id=config_id, end_user_id=end_user_id,
end_user_id=end_user_id, success=False,
success=False, duration=duration,
duration=duration, error=error_msg
error=error_msg )
)
raise ValueError(error_msg) raise ValueError(error_msg)
@@ -633,15 +625,15 @@ class MemoryAgentService:
total_time = time.time() - start_time total_time = time.time() - start_time
logger.info( logger.info(
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)") f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation( audit_logger.log_operation(
operation="READ", operation="READ",
config_id=config_id, config_id=config_id,
end_user_id=end_user_id, end_user_id=end_user_id,
success=True, success=True,
duration=duration duration=duration
) )
return { return {
"answer": summary, "answer": summary,
@@ -651,16 +643,16 @@ class MemoryAgentService:
# Ensure proper error handling and logging # Ensure proper error handling and logging
error_msg = f"Read operation failed: {str(e)}" error_msg = f"Read operation failed: {str(e)}"
logger.error(error_msg) logger.error(error_msg)
if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation( audit_logger.log_operation(
operation="READ", operation="READ",
config_id=config_id, config_id=config_id,
end_user_id=end_user_id, end_user_id=end_user_id,
success=False, success=False,
duration=duration, duration=duration,
error=error_msg error=error_msg
) )
raise ValueError(error_msg) raise ValueError(error_msg)
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]: def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:

View File

@@ -1,11 +1,12 @@
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from typing import List, Optional from sqlalchemy import desc, nullslast, or_, and_, cast, String
from typing import List, Optional, Dict, Any
import uuid import uuid
from fastapi import HTTPException from fastapi import HTTPException
from app.models.user_model import User from app.models.user_model import User
from app.models.app_model import App from app.models.app_model import App
from app.models.end_user_model import EndUser from app.models.end_user_model import EndUser, EndUser as EndUserModel
from app.models.memory_increment_model import MemoryIncrement from app.models.memory_increment_model import MemoryIncrement
from app.repositories import ( from app.repositories import (
@@ -49,44 +50,40 @@ def get_current_workspace_type(
def get_workspace_end_users( def get_workspace_end_users(
db: Session, db: Session,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
current_user: User current_user: User
) -> List[EndUser]: ) -> List[EndUser]:
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数) """获取工作空间的所有宿主(优化版本:减少数据库查询次数)
返回结果按 created_at 从新到旧排序NULL 值排在最后) 返回结果按 created_at 从新到旧排序NULL 值排在最后)
""" """
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}") business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
try: try:
# 查询应用ORM # 查询应用ORM
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id) apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
if not apps_orm: if not apps_orm:
business_logger.info("工作空间下没有应用") business_logger.info("工作空间下没有应用")
return [] return []
# 提取所有 app_id # 提取所有 app_id
# app_ids = [app.id for app in apps_orm] # app_ids = [app.id for app in apps_orm]
# 批量查询所有 end_users一次查询而非循环查询 # 批量查询所有 end_users一次查询而非循环查询
# 按 created_at 降序排序NULL 值排在最后id 作为次级排序键保证确定性 # 按 created_at 降序排序NULL 值排在最后id 作为次级排序键保证确定性
from app.models.end_user_model import EndUser as EndUserModel
from sqlalchemy import desc, nullslast
end_users_orm = db.query(EndUserModel).filter( end_users_orm = db.query(EndUserModel).filter(
EndUserModel.workspace_id == workspace_id EndUserModel.workspace_id == workspace_id
).order_by( ).order_by(
nullslast(desc(EndUserModel.created_at)), nullslast(desc(EndUserModel.created_at)),
desc(EndUserModel.id) desc(EndUserModel.id)
).all() ).all()
# 转换为 Pydantic 模型(只在需要时转换) # 转换为 Pydantic 模型(只在需要时转换)
end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm] end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm]
business_logger.info(f"成功获取 {len(end_users)} 个宿主记录") business_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
return end_users return end_users
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@@ -94,6 +91,85 @@ def get_workspace_end_users(
raise raise
def get_workspace_end_users_paginated(
db: Session,
workspace_id: uuid.UUID,
current_user: User,
page: int,
pagesize: int,
keyword: Optional[str] = None
) -> Dict[str, Any]:
"""获取工作空间的宿主列表(分页版本,支持模糊搜索)
返回结果按 created_at 从新到旧排序NULL 值排在最后)
支持通过 keyword 参数同时模糊搜索 other_name 和 id 字段
Args:
db: 数据库会话
workspace_id: 工作空间ID
current_user: 当前用户
page: 页码从1开始
pagesize: 每页数量
keyword: 搜索关键词(可选,同时模糊匹配 other_name 和 id
Returns:
dict: 包含 items宿主列表和 total总记录数的字典
"""
business_logger.info(f"获取工作空间宿主列表(分页): workspace_id={workspace_id}, keyword={keyword}, page={page}, pagesize={pagesize}, 操作者: {current_user.username}")
try:
# 构建基础查询
base_query = db.query(EndUserModel).filter(
EndUserModel.workspace_id == workspace_id
)
# 构建搜索条件过滤空字符串和None
keyword = keyword.strip() if keyword else None
if keyword:
keyword_pattern = f"%{keyword}%"
# other_name 匹配始终生效id 匹配仅对 other_name 为空的记录生效
base_query = base_query.filter(
or_(
EndUserModel.other_name.ilike(keyword_pattern),
and_(
or_(
EndUserModel.other_name.is_(None),
EndUserModel.other_name == "",
),
cast(EndUserModel.id, String).ilike(keyword_pattern),
),
)
)
business_logger.info(f"应用模糊搜索: keyword={keyword}(匹配 other_nameother_name 为空时匹配 id")
# 获取总记录数
total = base_query.count()
if total == 0:
business_logger.info("工作空间下没有宿主")
return {"items": [], "total": 0}
# 分页查询
# 按 created_at 降序排序NULL 值排在最后id 作为次级排序键保证确定性
end_users_orm = base_query.order_by(
nullslast(desc(EndUserModel.created_at)),
desc(EndUserModel.id)
).offset((page - 1) * pagesize).limit(pagesize).all()
# 转换为 Pydantic 模型
end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm]
business_logger.info(f"成功获取 {len(end_users)} 个宿主记录,总计 {total}")
return {"items": end_users, "total": total}
except HTTPException:
raise
except Exception as e:
business_logger.error(f"获取工作空间宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_workspace_memory_increment( def get_workspace_memory_increment(
db: Session, db: Session,
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
@@ -638,7 +714,24 @@ def get_rag_content(
business_logger.error(f"获取文档 {document.id} 的chunks失败: {str(e)}") business_logger.error(f"获取文档 {document.id} 的chunks失败: {str(e)}")
continue continue
# 4. 返回结果 # 4. 将所有 page_content 拼接后按角色分割为对话列表
merged_text = "\n".join(page_contents)
conversations = []
if merged_text.strip():
import re
# 在任意位置匹配 "user:" 或 "assistant:",不限于行首
parts = re.split(r'(user|assistant):', merged_text)
# parts 结构: ['', 'user', ' content...', 'assistant', ' content...', ...]
i = 1
while i < len(parts) - 1:
role = parts[i].strip()
content = parts[i + 1].strip()
# 将 content 中的 \n 还原为真实换行
content = content.replace("\\n", "\n")
if role in ("user", "assistant") and content:
conversations.append({"role": role, "content": content})
i += 2
result = { result = {
"page": { "page": {
"page": page, "page": page,
@@ -646,10 +739,10 @@ def get_rag_content(
"total": global_total, "total": global_total,
"hasnext": offset_end < global_total, "hasnext": offset_end < global_total,
}, },
"items": page_contents "items": conversations
} }
business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(page_contents)}") business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(conversations)}对话")
return result return result
except Exception as e: except Exception as e:

View File

@@ -204,30 +204,35 @@ class MemoryForgetService:
end_user_id: str, end_user_id: str,
forgetting_threshold: float, forgetting_threshold: float,
min_days_since_access: int, min_days_since_access: int,
limit: int = 20 page: Optional[int] = None,
) -> list[Dict[str, Any]]: pagesize: Optional[int] = None
) -> Dict[str, Any]:
""" """
获取待遗忘节点列表 获取待遗忘节点列表
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数) 查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。支持分页查询。
Args: Args:
connector: Neo4j 连接器 connector: Neo4j 连接器
end_user_id: 组ID end_user_id: 组ID
forgetting_threshold: 遗忘阈值 forgetting_threshold: 遗忘阈值
min_days_since_access: 最小未访问天数 min_days_since_access: 最小未访问天数
limit: 返回节点数量限制 page: 页码可选从1开始
pagesize: 每页数量(可选)
Returns: Returns:
list: 待遗忘节点列表 dict: 包含待遗忘节点列表和分页信息的字典
- items: 待遗忘节点列表
- page: 分页信息(分页时)
""" """
from datetime import timedelta from datetime import timedelta
# 计算最小访问时间ISO 8601 格式字符串,使用 UTC 时区) # 计算最小访问时间ISO 8601 格式字符串,使用 UTC 时区)
min_access_time = datetime.now(timezone.utc) - timedelta(days=min_days_since_access) min_access_time = datetime.now(timezone.utc) - timedelta(days=min_days_since_access)
min_access_time_str = min_access_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ') min_access_time_str = min_access_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ')
query = """ # 基础查询(用于获取总数)
count_query = """
MATCH (n) MATCH (n)
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary) WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
AND n.end_user_id = $end_user_id AND n.end_user_id = $end_user_id
@@ -235,10 +240,22 @@ class MemoryForgetService:
AND n.activation_value < $threshold AND n.activation_value < $threshold
AND n.last_access_time IS NOT NULL AND n.last_access_time IS NOT NULL
AND datetime(n.last_access_time) < datetime($min_access_time_str) AND datetime(n.last_access_time) < datetime($min_access_time_str)
RETURN RETURN count(n) as total
"""
# 数据查询
data_query = """
MATCH (n)
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
AND n.end_user_id = $end_user_id
AND n.activation_value IS NOT NULL
AND n.activation_value < $threshold
AND n.last_access_time IS NOT NULL
AND datetime(n.last_access_time) < datetime($min_access_time_str)
RETURN
elementId(n) as node_id, elementId(n) as node_id,
labels(n)[0] as node_type, labels(n)[0] as node_type,
CASE CASE
WHEN n:Statement THEN n.statement WHEN n:Statement THEN n.statement
WHEN n:ExtractedEntity THEN n.name WHEN n:ExtractedEntity THEN n.name
WHEN n:MemorySummary THEN n.content WHEN n:MemorySummary THEN n.content
@@ -247,18 +264,32 @@ class MemoryForgetService:
n.activation_value as activation_value, n.activation_value as activation_value,
n.last_access_time as last_access_time n.last_access_time as last_access_time
ORDER BY n.activation_value ASC ORDER BY n.activation_value ASC
LIMIT $limit
""" """
# 如果启用分页,添加 SKIP 和 LIMIT
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
data_query += " SKIP $skip LIMIT $limit"
params = { params = {
'end_user_id': end_user_id, 'end_user_id': end_user_id,
'threshold': forgetting_threshold, 'threshold': forgetting_threshold,
'min_access_time_str': min_access_time_str, 'min_access_time_str': min_access_time_str
'limit': limit
} }
results = await connector.execute_query(query, **params) # 获取总数(分页时需要)
total = 0
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
count_results = await connector.execute_query(count_query, **params)
if count_results:
total = count_results[0]['total']
# 添加分页参数
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
params['skip'] = (page - 1) * pagesize
params['limit'] = pagesize
results = await connector.execute_query(data_query, **params)
pending_nodes = [] pending_nodes = []
for result in results: for result in results:
# 将节点类型标签转换为小写 # 将节点类型标签转换为小写
@@ -267,7 +298,7 @@ class MemoryForgetService:
node_type_label = 'entity' node_type_label = 'entity'
elif node_type_label == 'memorysummary': elif node_type_label == 'memorysummary':
node_type_label = 'summary' node_type_label = 'summary'
# 将 Neo4j DateTime 对象转换为时间戳(毫秒) # 将 Neo4j DateTime 对象转换为时间戳(毫秒)
last_access_time = result['last_access_time'] last_access_time = result['last_access_time']
last_access_dt = convert_neo4j_datetime_to_python(last_access_time) last_access_dt = convert_neo4j_datetime_to_python(last_access_time)
@@ -278,7 +309,7 @@ class MemoryForgetService:
last_access_timestamp = int(last_access_dt.timestamp() * 1000) last_access_timestamp = int(last_access_dt.timestamp() * 1000)
else: else:
last_access_timestamp = 0 last_access_timestamp = 0
pending_nodes.append({ pending_nodes.append({
'node_id': str(result['node_id']), 'node_id': str(result['node_id']),
'node_type': node_type_label, 'node_type': node_type_label,
@@ -286,8 +317,20 @@ class MemoryForgetService:
'activation_value': result['activation_value'], 'activation_value': result['activation_value'],
'last_access_time': last_access_timestamp 'last_access_time': last_access_timestamp
}) })
return pending_nodes # 构建返回结果
result: Dict[str, Any] = {'items': pending_nodes}
# 如果启用分页,添加分页信息
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
result['page'] = {
'page': page,
'pagesize': pagesize,
'total': total,
'hasnext': (page * pagesize) < total
}
return result
async def trigger_forgetting_cycle( async def trigger_forgetting_cycle(
self, self,
@@ -636,7 +679,7 @@ class MemoryForgetService:
api_logger.error(f"获取历史趋势数据失败: {str(e)}") api_logger.error(f"获取历史趋势数据失败: {str(e)}")
# 失败时返回空列表,不影响主流程 # 失败时返回空列表,不影响主流程
# 获取待遗忘节点列表前20个满足遗忘条件的节点 # 获取待遗忘节点列表
pending_nodes = [] pending_nodes = []
try: try:
if end_user_id: if end_user_id:
@@ -652,8 +695,7 @@ class MemoryForgetService:
connector=connector, connector=connector,
end_user_id=end_user_id, end_user_id=end_user_id,
forgetting_threshold=forgetting_threshold, forgetting_threshold=forgetting_threshold,
min_days_since_access=int(min_days), min_days_since_access=int(min_days)
limit=20
) )
api_logger.info(f"成功获取 {len(pending_nodes)} 个待遗忘节点") api_logger.info(f"成功获取 {len(pending_nodes)} 个待遗忘节点")
@@ -661,24 +703,79 @@ class MemoryForgetService:
except Exception as e: except Exception as e:
api_logger.error(f"获取待遗忘节点失败: {str(e)}") api_logger.error(f"获取待遗忘节点失败: {str(e)}")
# 失败时返回空列表,不影响主流程 # 失败时返回空列表,不影响主流程
# 构建统计信息 # 构建统计信息(不包含 pending_nodes已分离到独立接口
stats = { stats = {
'activation_metrics': activation_metrics, 'activation_metrics': activation_metrics,
'node_distribution': node_distribution, 'node_distribution': node_distribution,
'recent_trends': recent_trends, 'recent_trends': recent_trends,
'pending_nodes': pending_nodes,
'timestamp': int(datetime.now().timestamp() * 1000) 'timestamp': int(datetime.now().timestamp() * 1000)
} }
api_logger.info( api_logger.info(
f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, " f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, "
f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}, " f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}, "
f"trend_days={len(recent_trends)}, pending_nodes={len(pending_nodes)}" f"trend_days={len(recent_trends)}"
) )
return stats return stats
async def get_pending_nodes(
self,
db: Session,
end_user_id: str,
config_id: Optional[UUID] = None,
page: int = 1,
pagesize: int = 10
) -> Dict[str, Any]:
"""
获取待遗忘节点列表(独立分页接口)
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
Args:
db: 数据库会话
end_user_id: 组ID必填
config_id: 配置ID可选用于获取遗忘阈值
page: 页码从1开始默认1
pagesize: 每页数量默认10
Returns:
dict: 包含待遗忘节点列表和分页信息的字典
- items: 待遗忘节点列表
- page: 分页信息
"""
# 获取遗忘引擎组件
_, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id)
connector = forgetting_scheduler.connector
forgetting_threshold = config['forgetting_threshold']
# 验证 min_days_since_access 配置值
min_days = config.get('min_days_since_access')
if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0:
api_logger.warning(
f"min_days_since_access 配置无效: {min_days}, 使用默认值 7"
)
min_days = 7
# 调用内部方法获取分页数据
pending_nodes_result = await self._get_pending_forgetting_nodes(
connector=connector,
end_user_id=end_user_id,
forgetting_threshold=forgetting_threshold,
min_days_since_access=int(min_days),
page=page,
pagesize=pagesize
)
api_logger.info(
f"成功获取待遗忘节点列表: end_user_id={end_user_id}, "
f"page={page}, pagesize={pagesize}, total={pending_nodes_result.get('page', {}).get('total', 0)}"
)
return pending_nodes_result
async def get_forgetting_curve( async def get_forgetting_curve(
self, self,
db: Session, db: Session,

View File

@@ -243,28 +243,9 @@ class MemoryPerceptualService:
memory_config: MemoryConfig, memory_config: MemoryConfig,
file: FileInput file: FileInput
): ):
memories = self.repository.get_by_url(file.url)
if memories:
business_logger.info(f"Perceptual memory already exists: {file.url}")
if end_user_id not in [memory.end_user_id for memory in memories]:
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
memory_cache = memories[0]
memory = self.repository.create_perceptual_memory(
end_user_id=uuid.UUID(end_user_id),
perceptual_type=PerceptualType(memory_cache.perceptual_type),
file_path=memory_cache.file_path,
file_name=memory_cache.file_name,
file_ext=memory_cache.file_ext,
summary=memory_cache.summary,
meta_data=memory_cache.meta_data
)
self.db.commit()
return memory
else:
for memory in memories:
if memory.end_user_id == uuid.UUID(end_user_id):
return memory
llm, model_config = self._get_mutlimodal_client(file.type, memory_config) llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
if model_config is None or llm is None:
return None
multimodel_service = MultimodalService(self.db, ModelInfo( multimodel_service = MultimodalService(self.db, ModelInfo(
model_name=model_config.model_name, model_name=model_config.model_name,
provider=model_config.provider, provider=model_config.provider,
@@ -286,15 +267,20 @@ class MemoryPerceptualService:
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f: with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
opt_system_prompt = f.read() opt_system_prompt = f.read()
rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh') rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh')
except FileNotFoundError: except FileNotFoundError as e:
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND) business_logger.error(f"Failed to generate perceptual memory: {str(e)}")
return None
messages = [ messages = [
{"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]}, {"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]},
{"role": RoleType.USER.value, "content": [ {"role": RoleType.USER.value, "content": [
{"type": "text", "text": "Summarize the following file"}, file_message {"type": "text", "text": "Summarize the following file"}, file_message
]} ]}
] ]
result = await llm.ainvoke(messages) try:
result = await llm.ainvoke(messages)
except Exception as e:
business_logger.error(f"Failed to generate perceptual memory: {str(e)}")
return None
content = result.content content = result.content
final_output = "" final_output = ""
if isinstance(content, list): if isinstance(content, list):

View File

@@ -695,6 +695,37 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]
return result return result
async def search_all_batch(end_user_ids: List[str]) -> Dict[str, int]:
"""批量查询多个用户的记忆数量简化版本只返回total
Args:
end_user_ids: 用户ID列表
Returns:
Dict[str, int]: 以user_id为key的记忆数量字典
格式: {"user_id": total_count}
"""
if not end_user_ids:
return {}
result = await _neo4j_connector.execute_query(
MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
end_user_ids=end_user_ids,
)
# 转换结果为字典格式,字典格式在查询中无需遍历结果集,直接返回
data = {}
for row in result:
data[row["user_id"]] = row["total"]
# 为没有数据的用户填充默认值,转换字典格式还为无数据填充默认值
for user_id in end_user_ids:
if user_id not in data:
data[user_id] = 0
return data
async def analytics_hot_memory_tags( async def analytics_hot_memory_tags(
db: Session, db: Session,
current_user: User, current_user: User,

View File

@@ -69,7 +69,8 @@ class ModelConfigService:
return items return items
@staticmethod @staticmethod
def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig: def get_model_by_name(db: Session, name: str, provider: str | None = None,
tenant_id: uuid.UUID | None = None) -> ModelConfig:
"""根据名称获取模型配置""" """根据名称获取模型配置"""
model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id) model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id)
if not model: if not model:
@@ -77,21 +78,22 @@ class ModelConfigService:
return model return model
@staticmethod @staticmethod
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ModelConfig]: def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[
ModelConfig]:
"""按名称模糊匹配获取模型配置列表""" """按名称模糊匹配获取模型配置列表"""
return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit) return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit)
@staticmethod @staticmethod
async def validate_model_config( async def validate_model_config(
db: Session, db: Session,
*, *,
model_name: str, model_name: str,
provider: str, provider: str,
api_key: str, api_key: str,
api_base: Optional[str] = None, api_base: Optional[str] = None,
model_type: str = "llm", model_type: str = "llm",
test_message: str = "Hello", test_message: str = "Hello",
is_omni: bool = False is_omni: bool = False
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""验证模型配置是否有效 """验证模型配置是否有效
@@ -158,13 +160,13 @@ class ModelConfigService:
# 统一使用 RedBearEmbeddings自动支持火山引擎多模态 # 统一使用 RedBearEmbeddings自动支持火山引擎多模态
embedding = RedBearEmbeddings(model_config) embedding = RedBearEmbeddings(model_config)
test_texts = [test_message, "测试文本"] test_texts = [test_message, "测试文本"]
# 火山引擎使用 embed_batch其他使用 embed_documents # 火山引擎使用 embed_batch其他使用 embed_documents
if provider.lower() == "volcano": if provider.lower() == "volcano":
vectors = await asyncio.to_thread(embedding.embed_batch, test_texts) vectors = await asyncio.to_thread(embedding.embed_batch, test_texts)
else: else:
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts) vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
return { return {
@@ -200,11 +202,11 @@ class ModelConfigService:
}, },
"error": None "error": None
} }
elif model_type_lower == "image": elif model_type_lower == "image":
# 图片生成模型验证 # 图片生成模型验证
from app.core.models.generation import RedBearImageGenerator from app.core.models.generation import RedBearImageGenerator
generator = RedBearImageGenerator(model_config) generator = RedBearImageGenerator(model_config)
result = await generator.agenerate( result = await generator.agenerate(
prompt="a cute panda", prompt="a cute panda",
@@ -212,7 +214,7 @@ class ModelConfigService:
) )
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
logger.info(f"成功生成图片,结果: {result}") logger.info(f"成功生成图片,结果: {result}")
return { return {
"valid": True, "valid": True,
"message": "图片生成模型配置验证成功", "message": "图片生成模型配置验证成功",
@@ -224,21 +226,21 @@ class ModelConfigService:
}, },
"error": None "error": None
} }
elif model_type_lower == "video": elif model_type_lower == "video":
# 视频生成模型验证 # 视频生成模型验证
from app.core.models.generation import RedBearVideoGenerator from app.core.models.generation import RedBearVideoGenerator
generator = RedBearVideoGenerator(model_config) generator = RedBearVideoGenerator(model_config)
result = await generator.agenerate( result = await generator.agenerate(
prompt="a cute panda playing in bamboo forest", prompt="a cute panda playing in bamboo forest",
duration=5 duration=5
) )
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
# 视频生成是异步任务返回任务ID # 视频生成是异步任务返回任务ID
task_id = result.get("task_id") if isinstance(result, dict) else None task_id = result.get("task_id") if isinstance(result, dict) else None
return { return {
"valid": True, "valid": True,
"message": "视频生成模型配置验证成功", "message": "视频生成模型配置验证成功",
@@ -265,7 +267,6 @@ class ModelConfigService:
# 提取详细的错误信息 # 提取详细的错误信息
error_message = str(e) error_message = str(e)
error_type = type(e).__name__ error_type = type(e).__name__
print("=========error_message:",error_message.lower())
# 特殊处理常见的错误类型 # 特殊处理常见的错误类型
if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower(): if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower():
# 区域/国家限制(适用于所有提供商) # 区域/国家限制(适用于所有提供商)
@@ -354,14 +355,16 @@ class ModelConfigService:
return model return model
@staticmethod @staticmethod
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig: def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate,
tenant_id: uuid.UUID | None = None) -> ModelConfig:
"""更新模型配置""" """更新模型配置"""
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id) existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
if not existing_model: if not existing_model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if model_data.name and model_data.name != existing_model.name: if model_data.name and model_data.name != existing_model.name:
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id): if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id) model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
@@ -370,25 +373,27 @@ class ModelConfigService:
return model return model
@staticmethod @staticmethod
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig: async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate,
tenant_id: uuid.UUID) -> ModelConfig:
"""创建组合模型""" """创建组合模型"""
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id): if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE,
tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
# 验证所有 API Key 存在且类型匹配 # 验证所有 API Key 存在且类型匹配
for api_key_id in model_data.api_key_ids: for api_key_id in model_data.api_key_ids:
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if not api_key: if not api_key:
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND) raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
# 检查 API Key 关联的模型配置类型 # 检查 API Key 关联的模型配置类型
for model_config in api_key.model_configs: for model_config in api_key.model_configs:
# chat 和 llm 类型可以兼容 # chat 和 llm 类型可以兼容
compatible_types = {ModelType.LLM, ModelType.CHAT} compatible_types = {ModelType.LLM, ModelType.CHAT}
config_type = model_config.type config_type = model_config.type
request_type = model_data.type request_type = model_data.type
if not (config_type == request_type or if not (config_type == request_type or
(config_type in compatible_types and request_type in compatible_types)): (config_type in compatible_types and request_type in compatible_types)):
raise BusinessException( raise BusinessException(
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配", f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
@@ -399,7 +404,7 @@ class ModelConfigService:
# f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型", # f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型",
# BizCode.INVALID_PARAMETER # BizCode.INVALID_PARAMETER
# ) # )
# 创建组合模型 # 创建组合模型
model_config_data = { model_config_data = {
"tenant_id": tenant_id, "tenant_id": tenant_id,
@@ -418,49 +423,51 @@ class ModelConfigService:
model = ModelConfigRepository.create(db, model_config_data) model = ModelConfigRepository.create(db, model_config_data)
db.flush() db.flush()
# 关联 API Keys # 关联 API Keys
for api_key_id in model_data.api_key_ids: for api_key_id in model_data.api_key_ids:
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if api_key: if api_key:
model.api_keys.append(api_key) model.api_keys.append(api_key)
db.commit() db.commit()
db.refresh(model) db.refresh(model)
return model return model
@staticmethod @staticmethod
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig: async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate,
tenant_id: uuid.UUID) -> ModelConfig:
"""更新组合模型""" """更新组合模型"""
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id) existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
if not existing_model: if not existing_model:
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
if model_data.name and model_data.name != existing_model.name: if model_data.name and model_data.name != existing_model.name:
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id): if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
tenant_id=tenant_id):
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
if not existing_model.is_composite: if not existing_model.is_composite:
raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER) raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER)
# 验证所有 API Key 存在且类型匹配 # 验证所有 API Key 存在且类型匹配
for api_key_id in model_data.api_key_ids: for api_key_id in model_data.api_key_ids:
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if not api_key: if not api_key:
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND) raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
for model_config in api_key.model_configs: for model_config in api_key.model_configs:
compatible_types = {ModelType.LLM, ModelType.CHAT} compatible_types = {ModelType.LLM, ModelType.CHAT}
config_type = model_config.type config_type = model_config.type
request_type = existing_model.type request_type = existing_model.type
if not (config_type == request_type or if not (config_type == request_type or
(config_type in compatible_types and request_type in compatible_types)): (config_type in compatible_types and request_type in compatible_types)):
raise BusinessException( raise BusinessException(
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配", f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
BizCode.INVALID_PARAMETER BizCode.INVALID_PARAMETER
) )
# 更新基本信息 # 更新基本信息
existing_model.name = model_data.name existing_model.name = model_data.name
# existing_model.type = model_data.type # existing_model.type = model_data.type
@@ -471,14 +478,14 @@ class ModelConfigService:
existing_model.is_public = model_data.is_public existing_model.is_public = model_data.is_public
if "load_balance_strategy" in model_data.model_fields_set: if "load_balance_strategy" in model_data.model_fields_set:
existing_model.load_balance_strategy = model_data.load_balance_strategy existing_model.load_balance_strategy = model_data.load_balance_strategy
# 更新 API Keys 关联 # 更新 API Keys 关联
existing_model.api_keys.clear() existing_model.api_keys.clear()
for api_key_id in model_data.api_key_ids: for api_key_id in model_data.api_key_ids:
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
if api_key: if api_key:
existing_model.api_keys.append(api_key) existing_model.api_keys.append(api_key)
db.commit() db.commit()
db.refresh(existing_model) db.refresh(existing_model)
return existing_model return existing_model
@@ -532,7 +539,7 @@ class ModelApiKeyService:
"""根据provider为多个ModelConfig创建API Key""" """根据provider为多个ModelConfig创建API Key"""
created_keys = [] created_keys = []
failed_models = [] # 记录验证失败的模型 failed_models = [] # 记录验证失败的模型
for model_config_id in data.model_config_ids: for model_config_id in data.model_config_ids:
model_config = ModelConfigRepository.get_by_id(db, model_config_id) model_config = ModelConfigRepository.get_by_id(db, model_config_id)
if not model_config: if not model_config:
@@ -540,10 +547,10 @@ class ModelApiKeyService:
data.is_omni = model_config.is_omni data.is_omni = model_config.is_omni
data.capability = model_config.capability data.capability = model_config.capability
# 从ModelBase获取model_name # 从ModelBase获取model_name
model_name = model_config.model_base.name if model_config.model_base else model_config.name model_name = model_config.model_base.name if model_config.model_base else model_config.name
# 检查是否存在API Key包括软删除需要考虑tenant_id # 检查是否存在API Key包括软删除需要考虑tenant_id
existing_key = db.query(ModelApiKey).join( existing_key = db.query(ModelApiKey).join(
ModelApiKey.model_configs ModelApiKey.model_configs
@@ -553,7 +560,7 @@ class ModelApiKeyService:
ModelApiKey.model_name == model_name, ModelApiKey.model_name == model_name,
ModelConfig.tenant_id == model_config.tenant_id ModelConfig.tenant_id == model_config.tenant_id
).first() ).first()
if existing_key: if existing_key:
# 如果已存在,重新激活并更新 # 如果已存在,重新激活并更新
if existing_key.is_active: if existing_key.is_active:
@@ -566,14 +573,14 @@ class ModelApiKeyService:
existing_key.model_name = model_name existing_key.model_name = model_name
existing_key.capability = data.capability existing_key.capability = data.capability
existing_key.is_omni = data.is_omni existing_key.is_omni = data.is_omni
# 检查是否已关联该模型配置 # 检查是否已关联该模型配置
if model_config not in existing_key.model_configs: if model_config not in existing_key.model_configs:
existing_key.model_configs.append(model_config) existing_key.model_configs.append(model_config)
created_keys.append(existing_key) created_keys.append(existing_key)
continue continue
# 验证配置 # 验证配置
validation_result = await ModelConfigService.validate_model_config( validation_result = await ModelConfigService.validate_model_config(
db=db, db=db,
@@ -589,7 +596,7 @@ class ModelApiKeyService:
# 记录验证失败的模型,但不抛出异常 # 记录验证失败的模型,但不抛出异常
failed_models.append(model_name) failed_models.append(model_name)
continue continue
# 创建API Key # 创建API Key
api_key_data = ModelApiKeyCreate( api_key_data = ModelApiKeyCreate(
model_config_ids=[model_config_id], model_config_ids=[model_config_id],
@@ -606,12 +613,12 @@ class ModelApiKeyService:
) )
api_key_obj = ModelApiKeyRepository.create(db, api_key_data) api_key_obj = ModelApiKeyRepository.create(db, api_key_data)
created_keys.append(api_key_obj) created_keys.append(api_key_obj)
if created_keys: if created_keys:
db.commit() db.commit()
for key in created_keys: for key in created_keys:
db.refresh(key) db.refresh(key)
return created_keys, failed_models return created_keys, failed_models
@staticmethod @staticmethod
@@ -626,7 +633,7 @@ class ModelApiKeyService:
api_key_data.is_omni = model_config.is_omni api_key_data.is_omni = model_config.is_omni
if api_key_data.capability is None: if api_key_data.capability is None:
api_key_data.capability = model_config.capability api_key_data.capability = model_config.capability
# 检查API Key是否已存在(包括软删除)需要考虑tenant_id # 检查API Key是否已存在(包括软删除)需要考虑tenant_id
existing_key = db.query(ModelApiKey).join( existing_key = db.query(ModelApiKey).join(
ModelApiKey.model_configs ModelApiKey.model_configs
@@ -650,15 +657,15 @@ class ModelApiKeyService:
existing_key.model_name = api_key_data.model_name existing_key.model_name = api_key_data.model_name
existing_key.capability = api_key_data.capability existing_key.capability = api_key_data.capability
existing_key.is_omni = api_key_data.is_omni existing_key.is_omni = api_key_data.is_omni
# 检查是否已关联该模型配置 # 检查是否已关联该模型配置
if model_config not in existing_key.model_configs: if model_config not in existing_key.model_configs:
existing_key.model_configs.append(model_config) existing_key.model_configs.append(model_config)
db.commit() db.commit()
db.refresh(existing_key) db.refresh(existing_key)
return existing_key return existing_key
# 验证配置 # 验证配置
validation_result = await ModelConfigService.validate_model_config( validation_result = await ModelConfigService.validate_model_config(
db=db, db=db,
@@ -691,7 +698,7 @@ class ModelApiKeyService:
# 获取关联的模型配置以获取模型类型 # 获取关联的模型配置以获取模型类型
if existing_api_key.model_configs: if existing_api_key.model_configs:
model_config = existing_api_key.model_configs[0] model_config = existing_api_key.model_configs[0]
validation_result = await ModelConfigService.validate_model_config( validation_result = await ModelConfigService.validate_model_config(
db=db, db=db,
model_name=api_key_data.model_name or existing_api_key.model_name, model_name=api_key_data.model_name or existing_api_key.model_name,
@@ -729,15 +736,15 @@ class ModelApiKeyService:
model_config = ModelConfigRepository.get_by_id(db, model_config_id) model_config = ModelConfigRepository.get_by_id(db, model_config_id)
if not model_config: if not model_config:
return None return None
api_keys = [key for key in model_config.api_keys if key.is_active] api_keys = [key for key in model_config.api_keys if key.is_active]
if not api_keys: if not api_keys:
return None return None
# 如果是轮询策略,按使用次数最少,次数相同则选最早使用的 # 如果是轮询策略,按使用次数最少,次数相同则选最早使用的
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min)) return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
# 否则返回第一个 # 否则返回第一个
return api_keys[0] return api_keys[0]
@@ -760,20 +767,19 @@ class ModelApiKeyService:
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
class ModelBaseService: class ModelBaseService:
"""基础模型服务""" """基础模型服务"""
@staticmethod @staticmethod
def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List: def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List:
models = ModelBaseRepository.get_list(db, query) models = ModelBaseRepository.get_list(db, query)
provider_groups = {} provider_groups = {}
for m in models: for m in models:
model_dict = model_schema.ModelBase.model_validate(m).model_dump() model_dict = model_schema.ModelBase.model_validate(m).model_dump()
if tenant_id: if tenant_id:
model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id) model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id)
provider = m.provider provider = m.provider
if provider not in provider_groups: if provider not in provider_groups:
provider_groups[provider] = { provider_groups[provider] = {
@@ -781,7 +787,7 @@ class ModelBaseService:
"models": [] "models": []
} }
provider_groups[provider]["models"].append(model_dict) provider_groups[provider]["models"].append(model_dict)
return list(provider_groups.values()) return list(provider_groups.values())
@staticmethod @staticmethod
@@ -823,10 +829,10 @@ class ModelBaseService:
model_base = ModelBaseRepository.get_by_id(db, model_base_id) model_base = ModelBaseRepository.get_by_id(db, model_base_id)
if not model_base: if not model_base:
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND) raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id): if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id):
raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME) raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME)
model_config_data = { model_config_data = {
"model_id": model_base_id, "model_id": model_base_id,
"tenant_id": tenant_id, "tenant_id": tenant_id,

View File

@@ -12,6 +12,9 @@ import base64
import csv import csv
import io import io
import json import json
import re
import olefile
import struct
import zipfile import zipfile
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
@@ -438,13 +441,13 @@ class MultimodalService:
if file.transfer_method == TransferMethod.REMOTE_URL: if file.transfer_method == TransferMethod.REMOTE_URL:
return True, { return True, {
"type": "text", "type": "text",
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>" "text": f"<document url=\"{file.url}\">\n{await self.extract_document_text(file)}\n</document>"
} }
else: else:
# 本地文件,提取文本内容 # 本地文件,提取文本内容
server_url = settings.FILE_LOCAL_SERVER_URL server_url = settings.FILE_LOCAL_SERVER_URL
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}" file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
text = await self._extract_document_text(file) text = await self.extract_document_text(file)
file_metadata = self.db.query(FileMetadata).filter( file_metadata = self.db.query(FileMetadata).filter(
FileMetadata.id == file.upload_file_id FileMetadata.id == file.upload_file_id
).first() ).first()
@@ -542,7 +545,7 @@ class MultimodalService:
server_url = settings.FILE_LOCAL_SERVER_URL server_url = settings.FILE_LOCAL_SERVER_URL
return f"{server_url}/storage/permanent/{file_id}" return f"{server_url}/storage/permanent/{file_id}"
async def _extract_document_text(self, file: FileInput) -> str: async def extract_document_text(self, file: FileInput) -> str:
""" """
提取文档文本内容 提取文档文本内容
@@ -602,31 +605,75 @@ class MultimodalService:
try: try:
word_file = io.BytesIO(file_content) word_file = io.BytesIO(file_content)
doc = Document(word_file) doc = Document(word_file)
return '\n'.join(p.text for p in doc.paragraphs) text_lines = []
for p in doc.paragraphs:
text = p.text.strip()
if text:
text_lines.append(text)
for table in doc.tables:
for row in table.rows:
for cell in row.cells:
text = cell.text.strip()
if text:
text_lines.append(text)
full_text = "\n".join(text_lines)
return full_text.strip() or "[docx 文件无文本内容]"
except Exception as e: except Exception as e:
logger.error(f"提取 docx 文本失败: {e}") logger.error(f"提取 docx 文本失败: {str(e)}", exc_info=True)
return f"[docx 提取失败: {str(e)}]" return f"[docx 提取失败: {str(e)}]"
# 旧版 .docOLE2 格式) # 旧版 .docOLE2/CFB 格式),按 Word Binary Format 规范解析 piece table
try: try:
import olefile
ole = olefile.OleFileIO(io.BytesIO(file_content)) ole = olefile.OleFileIO(io.BytesIO(file_content))
if not ole.exists('WordDocument'): word_stream = ole.openstream('WordDocument').read()
return "[doc 提取失败: 未找到 WordDocument 流]"
# 读取 WordDocument 流,提取可见 ASCII/Unicode 文本 # FIB offset 0xA bit9 决定使用 0Table 还是 1Table
stream = ole.openstream('WordDocument').read() fib_flags = struct.unpack_from('<H', word_stream, 0xA)[0]
# Word Binary Format: 文本在流中以 UTF-16-LE 编码存储 table_name = '1Table' if (fib_flags & 0x0200) else '0Table'
# 简单提取:过滤出可打印字符段 table_stream = ole.openstream(table_name).read()
try:
text = stream.decode('utf-16-le', errors='ignore') # 从 FIB 读取 fcClx/lcbClx 定位 piece table
except Exception: fc_clx, lcb_clx = struct.unpack_from("<II", word_stream, 0x1A2)
text = stream.decode('latin-1', errors='ignore') clx = table_stream[fc_clx: fc_clx + lcb_clx]
# 过滤控制字符,保留可打印内容
import re # 解析 CLX找到 PlcPcdpiece table
text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text) i, plc_pcd = 0, None
text = re.sub(r' +', ' ', text).strip() while i < len(clx):
clxt = clx[i]
if clxt == 0x01:
i += 3 + struct.unpack_from('<H', clx, i + 1)[0]
elif clxt == 0x02:
cb = struct.unpack_from('<I', clx, i + 1)[0]
plc_pcd = clx[i + 5: i + 5 + cb]
break
else:
break
if plc_pcd is None:
raise ValueError("PlcPcd not found")
# PlcPcd: (n+1) 个 CP4字节+ n 个 PCD8字节
n_pieces = (len(plc_pcd) - 4) // 12
cp_array = [struct.unpack_from('<I', plc_pcd, k * 4)[0] for k in range(n_pieces + 1)]
parts = []
for k in range(n_pieces):
fc_value = struct.unpack_from('<I', plc_pcd, (n_pieces + 1) * 4 + k * 8 + 2)[0]
is_ansi = bool(fc_value & 0x40000000)
fc = fc_value & 0x3FFFFFFF
char_count = cp_array[k + 1] - cp_array[k]
if is_ansi:
parts.append(word_stream[fc: fc + char_count].decode('cp1252', errors='replace'))
else:
parts.append(word_stream[fc: fc + char_count * 2].decode('utf-16-le', errors='replace'))
ole.close() ole.close()
return text result = re.sub(r'[\x00-\x1f\x7f]', '', ''.join(parts))
return result.strip()
except Exception as e: except Exception as e:
logger.error(f"提取 doc 文本失败: {e}") logger.error(f"提取 doc 文本失败: {e}")
return f"[doc 提取失败: {str(e)}]" return f"[doc 提取失败: {str(e)}]"

View File

@@ -1,26 +1,24 @@
"""基于分享链接的聊天服务""" """基于分享链接的聊天服务"""
import uuid
import time
import asyncio import asyncio
import json
import time
import uuid
from typing import Optional, Dict, Any, AsyncGenerator from typing import Optional, Dict, Any, AsyncGenerator
from deprecated import deprecated
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.repositories.model_repository import ModelApiKeyRepository from app.core.error_codes import BizCode
from app.services.memory_konwledges_server import write_rag from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.logging_config import get_business_logger
from app.models import MultiAgentConfig
from app.models import ReleaseShare, AppRelease, Conversation from app.models import ReleaseShare, AppRelease, Conversation
from app.repositories import knowledge_repository
from app.services.conversation_service import ConversationService from app.services.conversation_service import ConversationService
from app.services.draft_run_service import create_web_search_tool from app.services.draft_run_service import create_web_search_tool
from app.services.model_service import ModelApiKeyService from app.services.model_service import ModelApiKeyService
from app.services.release_share_service import ReleaseShareService
from app.core.exceptions import BusinessException, ResourceNotFoundException
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
from app.services.multi_agent_service import MultiAgentService from app.services.multi_agent_service import MultiAgentService
from app.models import MultiAgentConfig from app.services.release_share_service import ReleaseShareService
from app.repositories import knowledge_repository
import json
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
logger = get_business_logger() logger = get_business_logger()
@@ -118,6 +116,7 @@ class SharedChatService:
return conversation return conversation
@deprecated("Use the chat method under app_chat_service instead.")
async def chat( async def chat(
self, self,
share_token: str, share_token: str,
@@ -136,10 +135,7 @@ class SharedChatService:
config_id = actual_config_id config_id = actual_config_id
from app.core.agent.langchain_agent import LangChainAgent from app.core.agent.langchain_agent import LangChainAgent
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
from app.services.model_parameter_merger import ModelParameterMerger
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from sqlalchemy import select
from app.models import ModelApiKey
start_time = time.time() start_time = time.time()
actual_config_id = None actual_config_id = None
@@ -273,11 +269,6 @@ class SharedChatService:
message=message, message=message,
history=history, history=history,
context=None, context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag
) )
# 保存消息 # 保存消息
@@ -324,6 +315,7 @@ class SharedChatService:
"elapsed_time": elapsed_time "elapsed_time": elapsed_time
} }
@deprecated("Use the chat method under app_chat_service instead.")
async def chat_stream( async def chat_stream(
self, self,
share_token: str, share_token: str,
@@ -341,8 +333,6 @@ class SharedChatService:
from app.core.agent.langchain_agent import LangChainAgent from app.core.agent.langchain_agent import LangChainAgent
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
from sqlalchemy import select
from app.models import ModelApiKey
import json import json
start_time = time.time() start_time = time.time()
@@ -486,11 +476,6 @@ class SharedChatService:
message=message, message=message,
history=history, history=history,
context=None, context=None,
end_user_id=user_id,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
config_id=config_id,
memory_flag=memory_flag
): ):
if isinstance(chunk, int): if isinstance(chunk, int):
total_tokens = chunk total_tokens = chunk
@@ -585,6 +570,7 @@ class SharedChatService:
return conversations, total return conversations, total
@deprecated("Use the chat method under app_chat_service instead.")
async def multi_agent_chat( async def multi_agent_chat(
self, self,
share_token: str, share_token: str,
@@ -680,6 +666,7 @@ class SharedChatService:
"elapsed_time": elapsed_time "elapsed_time": elapsed_time
} }
@deprecated("Use the chat method under app_chat_service instead.")
async def multi_agent_chat_stream( async def multi_agent_chat_stream(
self, self,
share_token: str, share_token: str,

View File

@@ -138,7 +138,7 @@ class TenantService:
except Exception as e: except Exception as e:
business_logger.error(f"删除租户失败: {str(e)}") business_logger.error(f"删除租户失败: {str(e)}")
raise BusinessException(f"删除租户失败: {str(e)}", code=BizCode.DB_ERROR) raise BusinessException(f"删除租户失败{str(e)}", code=BizCode.DB_ERROR)
# 租户用户管理 # 租户用户管理
def get_tenant_users( def get_tenant_users(
@@ -147,6 +147,7 @@ class TenantService:
skip: int = 0, skip: int = 0,
limit: int = 100, limit: int = 100,
is_active: Optional[bool] = None, is_active: Optional[bool] = None,
is_superuser: Optional[bool] = None,
search: Optional[str] = None search: Optional[str] = None
) -> List[UserModel]: ) -> List[UserModel]:
"""获取租户下的用户列表""" """获取租户下的用户列表"""
@@ -155,6 +156,7 @@ class TenantService:
skip=skip, skip=skip,
limit=limit, limit=limit,
is_active=is_active, is_active=is_active,
is_superuser=is_superuser,
search=search search=search
) )
@@ -162,12 +164,14 @@ class TenantService:
self, self,
tenant_id: uuid.UUID, tenant_id: uuid.UUID,
is_active: Optional[bool] = None, is_active: Optional[bool] = None,
is_superuser: Optional[bool] = None,
search: Optional[str] = None search: Optional[str] = None
) -> int: ) -> int:
"""统计租户下的用户数量""" """统计租户下的用户数量"""
return self.user_repo.count_users_by_tenant( return self.user_repo.count_users_by_tenant(
tenant_id=tenant_id, tenant_id=tenant_id,
is_active=is_active, is_active=is_active,
is_superuser=is_superuser,
search=search search=search
) )

View File

@@ -472,6 +472,21 @@ class UserMemoryService:
# 定义允许更新的字段白名单 # 定义允许更新的字段白名单
allowed_fields = {'other_name', 'aliases', 'meta_data'} allowed_fields = {'other_name', 'aliases', 'meta_data'}
# 用户占位名称黑名单,不允许作为 other_name 或出现在 aliases 中
_user_placeholder_names = {'用户', '', 'User', 'I'}
# 过滤 other_name不允许设置为占位名称
if 'other_name' in update_data and update_data['other_name'] and update_data['other_name'].strip() in _user_placeholder_names:
logger.warning(f"拒绝将占位名称 '{update_data['other_name']}' 设置为 other_name")
del update_data['other_name']
# 过滤 aliases移除占位名称和非字符串值
if 'aliases' in update_data and update_data['aliases']:
update_data['aliases'] = [
a for a in update_data['aliases']
if isinstance(a, str) and a.strip() and a.strip() not in _user_placeholder_names
]
# 检查是否更新了 aliases 字段 # 检查是否更新了 aliases 字段
aliases_updated = 'aliases' in update_data and update_data['aliases'] != end_user_info_record.aliases aliases_updated = 'aliases' in update_data and update_data['aliases'] != end_user_info_record.aliases

View File

@@ -561,6 +561,24 @@ class WorkflowService:
storage_type = 'neo4j' storage_type = 'neo4j'
return storage_type, user_rag_memory_id return storage_type, user_rag_memory_id
def _get_history_info(self, conversation_id: uuid.UUID) -> tuple[dict, list] | None:
executions = self.execution_repo.get_by_conversation_id(
conversation_id=conversation_id,
status="completed",
limit_count=1
)
if executions:
last_state = executions[0].output_data
if isinstance(last_state, dict):
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
# input_data["conv"] = conv_vars
# input_data["conv_messages"] = last_state.get("messages") or []
conv_messages = last_state.get("messages") or []
return conv_vars, conv_messages
return None
# ==================== 工作流执行 ==================== # ==================== 工作流执行 ====================
async def run( async def run(
@@ -634,18 +652,11 @@ class WorkflowService:
# 更新状态为运行中 # 更新状态为运行中
self.update_execution_status(execution.execution_id, "running") self.update_execution_status(execution.execution_id, "running")
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) history = self._get_history_info(conversation_id_uuid)
if history:
for exec_res in executions: conv_vars, conv_messages = history
if exec_res.status == "completed": input_data["conv"] = conv_vars
last_state = exec_res.output_data input_data["conv_messages"] = conv_messages
if isinstance(last_state, dict):
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
input_data["conv"] = conv_vars
input_data["conv_messages"] = last_state.get("messages") or []
break
init_message_length = len(input_data.get("conv_messages", [])) init_message_length = len(input_data.get("conv_messages", []))
result = await execute_workflow( result = await execute_workflow(
@@ -807,17 +818,11 @@ class WorkflowService:
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id) storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
input_data["files"] = files input_data["files"] = files
self.update_execution_status(execution.execution_id, "running") self.update_execution_status(execution.execution_id, "running")
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) history = self._get_history_info(conversation_id_uuid)
if history:
for exec_res in executions: conv_vars, conv_messages = history
if exec_res.status == "completed": input_data["conv"] = conv_vars
last_state = exec_res.output_data input_data["conv_messages"] = conv_messages
if isinstance(last_state, dict):
variables = last_state.get("variables", {})
conv_vars = variables.get("conv", {})
input_data["conv"] = conv_vars
input_data["conv_messages"] = last_state.get("messages") or []
break
init_message_length = len(input_data.get("conv_messages", [])) init_message_length = len(input_data.get("conv_messages", []))
message_id = uuid.uuid4() message_id = uuid.uuid4()
async for event in execute_workflow_stream( async for event in execute_workflow_stream(

View File

@@ -1,5 +1,4 @@
import asyncio import asyncio
import hashlib
import os import os
import re import re
import shutil import shutil
@@ -38,12 +37,10 @@ from app.db import get_db, get_db_context
from app.models import Document, File, Knowledge from app.models import Document, File, Knowledge
from app.models.end_user_model import EndUser from app.models.end_user_model import EndUser
from app.schemas import document_schema, file_schema from app.schemas import document_schema, file_schema
from app.schemas.model_schema import ModelInfo
from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config
from app.services.memory_forget_service import MemoryForgetService from app.services.memory_forget_service import MemoryForgetService
from app.services.memory_perceptual_service import MemoryPerceptualService
from app.utils.config_utils import resolve_config_id from app.utils.config_utils import resolve_config_id
from app.utils.redis_lock import RedisLock from app.utils.redis_lock import RedisFairLock
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -104,7 +101,12 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]:
def set_asyncio_event_loop(): def set_asyncio_event_loop():
"""Set the asyncio event loop for the current thread.""" """Ensure an open asyncio event loop exists for the current thread.
Reuses the existing event loop if one is available and still open.
Creates and installs a new event loop only when the current one is
closed or missing (e.g. after ``_shutdown_loop_gracefully``).
"""
try: try:
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
if loop.is_closed(): if loop.is_closed():
@@ -116,6 +118,30 @@ def set_asyncio_event_loop():
return loop return loop
def _shutdown_loop_gracefully(loop: asyncio.AbstractEventLoop):
"""Gracefully shutdown pending async generators and tasks on the event loop.
This prevents 'RuntimeError: Event loop is closed' from httpx.AsyncClient.__del__
by giving pending aclose() coroutines a chance to run before the loop is discarded.
Note: This only tears down the given loop. Callers that need a fresh event
loop afterwards should use ``set_asyncio_event_loop()`` explicitly.
"""
try:
# Cancel and collect all remaining tasks
all_tasks = asyncio.all_tasks(loop)
if all_tasks:
for task in all_tasks:
task.cancel()
loop.run_until_complete(asyncio.gather(*all_tasks, return_exceptions=True))
# Shutdown async generators (triggers __aclose__ on httpx clients etc.)
loop.run_until_complete(loop.shutdown_asyncgens())
except Exception:
pass
finally:
loop.close()
@celery_app.task(name="tasks.process_item") @celery_app.task(name="tasks.process_item")
def process_item(item: dict): def process_item(item: dict):
""" """
@@ -1148,8 +1174,28 @@ def write_message_task(
logger.info(f"[CELERY WRITE] Write completed successfully: {result}") logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
return result return result
redis_client = get_sync_redis_client()
lock = None
if redis_client is not None:
lock = RedisFairLock(
key=f"memory_write:{end_user_id}",
redis_client=redis_client,
expire=600,
timeout=3600,
auto_renewal=True,
)
if not lock.acquire():
logger.warning(f"[CELERY WRITE] 获取锁超时,跳过本次写入: end_user_id={end_user_id}")
return {
"status": "SKIPPED",
"error": "acquire lock timeout",
"end_user_id": end_user_id,
"config_id": str(config_id),
"elapsed_time": time.time() - start_time,
"task_id": self.request.id,
}
try: try:
# 尝试获取现有事件循环,如果不存在则创建新的
loop = set_asyncio_event_loop() loop = set_asyncio_event_loop()
result = loop.run_until_complete(_run()) result = loop.run_until_complete(_run())
@@ -1158,7 +1204,6 @@ def write_message_task(
logger.info(f"[CELERY WRITE] Task completed successfully " logger.info(f"[CELERY WRITE] Task completed successfully "
f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
# 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用
try: try:
_r = get_sync_redis_client() _r = get_sync_redis_client()
if _r is not None: if _r is not None:
@@ -1199,6 +1244,15 @@ def write_message_task(
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"task_id": self.request.id "task_id": self.request.id
} }
finally:
if lock is not None:
try:
lock.release()
except Exception as e:
logger.warning(f"[CELERY WRITE] 释放锁失败: {e}")
# Gracefully shutdown the event loop to prevent
# 'RuntimeError: Event loop is closed' from httpx.AsyncClient.__del__
_shutdown_loop_gracefully(loop)
# unused task # unused task
@@ -2879,3 +2933,6 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace
"elapsed_time": time.time() - start_time, "elapsed_time": time.time() - start_time,
"task_id": self.request.id, "task_id": self.request.id,
} }
# unused task

View File

@@ -0,0 +1,37 @@
"""
性能监控工具模块
提供代码块执行时间统计功能,用于接口性能分析。
如需再次启用性能监控,只需在 controller 中导入 from app.utils.performance_timer import timer 并添加 with timer(...) 包裹需要监控的代码块即可
"""
import time
from contextlib import contextmanager
from app.core.logging_config import get_api_logger
# 获取API专用日志器
api_logger = get_api_logger()
@contextmanager
def timer(label: str, user_count: int = 0):
"""上下文管理器:用于测量代码块执行时间
Args:
label: 统计标签,用于标识被测量的代码块
user_count: 用户数,可选参数,用于记录处理的用户数量
Usage:
with timer("获取用户列表"):
users = get_users()
with timer("批量处理", user_count=len(user_ids)):
process_users(user_ids)
"""
start = time.perf_counter()
try:
yield
finally:
elapsed = (time.perf_counter() - start) * 1000 # 转换为毫秒
extra_info = f", 用户数: {user_count}" if user_count > 0 else ""
api_logger.info(f"[性能统计] {label}: {elapsed:.2f}ms{extra_info}")

View File

@@ -1,6 +1,7 @@
import redis import redis
import uuid import uuid
import time import time
import threading
UNLOCK_SCRIPT = """ UNLOCK_SCRIPT = """
if redis.call("get", KEYS[1]) == ARGV[1] then if redis.call("get", KEYS[1]) == ARGV[1] then
@@ -10,45 +11,136 @@ else
end end
""" """
RENEW_SCRIPT = """
if redis.call("get", KEYS[1]) == ARGV[1] then
return redis.call("expire", KEYS[1], ARGV[2])
else
return 0
end
"""
class RedisLock: CLEANUP_DEAD_HEAD_SCRIPT = """
local queue_key = KEYS[1]
local lock_key = KEYS[2]
local first = redis.call("lindex", queue_key, 0)
if not first then
return 0
end
if redis.call("exists", lock_key) == 1 then
return 0
end
redis.call("lpop", queue_key)
return 1
"""
SAFE_RELEASE_QUEUE_SCRIPT = """
local queue_key = KEYS[1]
local value = ARGV[1]
local first = redis.call("lindex", queue_key, 0)
if first == value then
redis.call("lpop", queue_key)
return 1
end
return 0
"""
def _ensure_str(val):
"""统一将 Redis 返回值转为 str兼容 decode_responses=True/False"""
if val is None:
return None
if isinstance(val, bytes):
return val.decode("utf-8")
return str(val)
class RedisFairLock:
def __init__( def __init__(
self, self,
key: str, key: str,
redis_client: redis.StrictRedis, redis_client: redis.StrictRedis,
expire: int = 60, expire: int = 30,
retry_interval: float = 0.1, retry_interval: float = 0.05,
timeout: float = 30 timeout: float = 600,
auto_renewal: bool = True
): ):
self.key = key self.key = key
self.expire = expire self.queue_key = f"{key}:queue"
self.value = str(uuid.uuid4()) self.value = str(uuid.uuid4())
self._locked = False self.expire = expire
self.retry_interval = retry_interval self.retry_interval = retry_interval
self.timeout = timeout self.timeout = timeout
self.redis_client = redis_client self.redis = redis_client
self._locked = False
self.auto_renewal = auto_renewal
self._renew_thread = None
self._stop_renew = threading.Event()
def acquire(self) -> bool: def acquire(self):
start = time.time() start = time.time()
self.redis.rpush(self.queue_key, self.value)
while True: while True:
ok = self.redis_client.set(self.key, self.value, ex=self.expire, nx=True) first = _ensure_str(self.redis.lindex(self.queue_key, 0))
if ok:
self._locked = True if first == self.value:
return True ok = self.redis.set(self.key, self.value, nx=True, ex=self.expire)
if time.time() - start >= self.timeout: if ok:
self._locked = True
if self.auto_renewal:
self._start_renewal()
return True
if first:
self.redis.eval(CLEANUP_DEAD_HEAD_SCRIPT, 2, self.queue_key, self.key)
if time.time() - start > self.timeout:
self.redis.lrem(self.queue_key, 0, self.value)
return False return False
time.sleep(self.retry_interval) time.sleep(self.retry_interval)
def _renewal_loop(self):
while not self._stop_renew.is_set():
time.sleep(self.expire / 3)
if self._stop_renew.is_set():
break
self.redis.eval(
RENEW_SCRIPT,
1,
self.key,
self.value,
str(self.expire)
)
def _start_renewal(self):
self._stop_renew = threading.Event()
self._renew_thread = threading.Thread(target=self._renewal_loop, daemon=True)
self._renew_thread.start()
def _stop_renewal(self):
self._stop_renew.set()
if self._renew_thread:
self._renew_thread.join(timeout=1)
def release(self): def release(self):
if not self._locked: if not self._locked:
return return
self.redis_client.eval(
UNLOCK_SCRIPT, if self.auto_renewal:
1, self._stop_renewal()
self.key,
self.value self.redis.eval(UNLOCK_SCRIPT, 1, self.key, self.value)
)
self.redis.eval(SAFE_RELEASE_QUEUE_SCRIPT, 1, self.queue_key, self.value)
self._locked = False self._locked = False
def __enter__(self): def __enter__(self):
@@ -59,3 +151,4 @@ class RedisLock:
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
self.release() self.release()

View File

@@ -0,0 +1,30 @@
"""202603271515
Revision ID: 4e89970f9e7c
Revises: 6b8a461148ff
Create Date: 2026-03-27 15:12:27.518344
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '4e89970f9e7c'
down_revision: Union[str, None] = '6b8a461148ff'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('users', sa.Column('phone', sa.String(length=50), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('users', 'phone')
# ### end Alembic commands ###

View File

@@ -68,8 +68,8 @@ export const getModelTypeList = async () => {
return response as any[]; return response as any[];
}; };
// 获取模型列表 // 获取模型列表
export const getModelList = async (pageInfo: PageRequest) => { export const getModelList = async (types: string[], pageInfo: PageRequest) => {
const response = await request.get(`${apiPrefix}/models`, { ...pageInfo, is_active: true }); const response = await request.get(`${apiPrefix}/models`, { ...pageInfo, type: types?.join(','), is_active: true });
return response as any; return response as any;
}; };
//获取模型提供者 //获取模型提供者

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing * @Author: ZhaoYing
* @Date: 2026-02-03 14:00:06 * @Date: 2026-02-03 14:00:06
* @Last Modified by: ZhaoYing * @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-24 17:48:01 * @Last Modified time: 2026-03-31 12:25:53
*/ */
import { request } from '@/utils/request' import { request } from '@/utils/request'
import type { AxiosRequestConfig } from 'axios' import type { AxiosRequestConfig } from 'axios'
@@ -63,8 +63,8 @@ export const getDashboardData = () => {
/****************** User Memory APIs *******************************/ /****************** User Memory APIs *******************************/
export const userMemoryListUrl = '/dashboard/end_users' export const userMemoryListUrl = '/dashboard/end_users'
export const getUserMemoryList = () => { export const getUserMemoryList = (query?: { keyword?: string }) => {
return request.get(userMemoryListUrl) return request.get(userMemoryListUrl, query)
} }
// User Memory - Total end users // User Memory - Total end users
export const getTotalEndUsers = () => { export const getTotalEndUsers = () => {
@@ -154,6 +154,8 @@ export const analyticsRefresh = (end_user_id: string) => {
export const getForgetStats = (end_user_id: string) => { export const getForgetStats = (end_user_id: string) => {
return request.get(`/memory/forget-memory/stats`, { end_user_id }) return request.get(`/memory/forget-memory/stats`, { end_user_id })
} }
// 获取带遗忘节点列表
export const getForgetPendingNodesUrl = '/memory/forget-memory/pending-nodes'
// Implicit Memory - Preferences // Implicit Memory - Preferences
export const getImplicitPreferences = (end_user_id: string) => { export const getImplicitPreferences = (end_user_id: string) => {
return request.get(`/memory/implicit-memory/preferences/${end_user_id}`) return request.get(`/memory/implicit-memory/preferences/${end_user_id}`)

View File

@@ -0,0 +1,19 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>编组 33</title>
<g id="空间外层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="平台管理-工具管理--MCP服务" transform="translate(-1032, -187)" stroke="#FF5D34">
<g id="编组-16" transform="translate(1020, 126)">
<g id="编组-33" transform="translate(12, 61)">
<g id="编组-32" transform="translate(2.5, 3)">
<line x1="-1.80133686e-14" y1="2.22222222" x2="11" y2="2.22222222" id="路径-29"></line>
<polyline id="路径-30" stroke-linejoin="round" points="3.3 2.2221179 3.3 0 7.7 0 7.7 2.22222222"></polyline>
<path d="M1.65,2.23587458 L1.65,9 C1.65,9.55228475 2.09771525,10 2.65,10 L8.35,10 C8.90228475,10 9.35,9.55228475 9.35,9 L9.35,2.22222222 L9.35,2.22222222" id="路径-31" stroke-linejoin="round"></path>
<line x1="4.4" y1="4.45203738" x2="4.4" y2="7.78537071" id="路径-32"></line>
<line x1="6.6" y1="4.45203738" x2="6.6" y2="7.78537071" id="路径-32"></line>
</g>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.4 KiB

View File

@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="18px" height="18px" viewBox="0 0 18 18" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>编辑</title>
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="工作台-提示词-我的历史" transform="translate(-976, -320)">
<g id="编组-13备份-6" transform="translate(648, 122)">
<g id="编辑" transform="translate(328, 198)">
<rect id="矩形" fill="#EBEBEB" fill-rule="nonzero" x="0" y="0" width="18" height="18" rx="6"></rect>
<g id="编组-10" transform="translate(4.3, 4.3)" stroke="#5B6167">
<path d="M9.4,4.04322919 L9.4,7.4 C9.4,8.5045695 8.5045695,9.4 7.4,9.4 L2,9.4 C0.8954305,9.4 0,8.5045695 0,7.4 L0,2 C0,0.8954305 0.8954305,2.22044605e-16 2,0 L5.38958415,0 L5.38958415,0" id="路径"></path>
<line x1="3.74260398" y1="5.68579764" x2="9.4" y2="1.05734433e-14" id="路径-2"></line>
</g>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.1 KiB

View File

@@ -0,0 +1,16 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>编辑</title>
<g id="空间外层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd" stroke-linejoin="round">
<g id="平台管理-工具管理--MCP服务" transform="translate(-1032, -135)" stroke="#171719">
<g id="编组-16" transform="translate(1020, 126)">
<g id="编辑" transform="translate(12, 9)">
<g id="编组-10" transform="translate(3, 3)">
<path d="M10,4.30130765 L10,8 C10,9.1045695 9.1045695,10 8,10 L2,10 C0.8954305,10 0,9.1045695 0,8 L0,2 C0,0.8954305 0.8954305,2.22044605e-16 2,0 L5.73360016,0 L5.73360016,0" id="路径"></path>
<line x1="3.98149359" y1="6.04872089" x2="10" y2="1.12483439e-14" id="路径-2"></line>
</g>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.0 KiB

View File

@@ -0,0 +1,16 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>link-outlined</title>
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="API-Key-管理" transform="translate(-1034, -161)" stroke="#171719">
<g id="编组-16" transform="translate(1022, 126)">
<g id="link-outlined" transform="translate(12, 35)">
<g id="编组-14" transform="translate(2.5, 4.15)">
<path d="M5.50029186,2.425 C5.92116821,2.425 6.30372933,2.58703648 6.58052274,2.85122056 C6.84862719,3.1071115 7.01717088,3.4597026 7.01717088,3.85 C7.01717088,4.24027468 6.84857286,4.5929387 6.58042003,4.84890344 C6.3036418,5.11310155 5.92112667,5.27520782 5.50029186,5.27520782 C5.07946329,5.27520782 4.69695201,5.11310096 4.42017607,4.84890315 C4.15202505,4.59293829 3.98342734,4.24027444 3.98342734,3.85 C3.98342734,3.45970284 4.15197072,3.10711191 4.42007337,2.85122085 C4.69686448,2.58703707 5.07942175,2.425 5.50029186,2.425 Z" id="路径" fill-rule="nonzero"></path>
<path d="M5.5,7.7 C8.53756612,7.7 11,5.39612383 11,3.85 C11,2.30387617 8.53756612,0 5.5,0 C2.46243388,0 0,2.26850164 0,3.85 C0,5.43149836 2.46243388,7.7 5.5,7.7 Z" id="椭圆形"></path>
</g>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.5 KiB

View File

@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="18px" height="18px" viewBox="0 0 18 18" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>编辑</title>
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="工作台-提示词-我的历史" transform="translate(-950, -320)">
<g id="编组-13备份-6" transform="translate(648, 122)">
<g id="编辑" transform="translate(302, 198)">
<rect id="矩形" fill="#EBEBEB" fill-rule="nonzero" x="0" y="0" width="18" height="18" rx="6"></rect>
<g id="编组-16" transform="translate(2.5, 4.7)" stroke="#5B6167">
<ellipse id="椭圆形" cx="6.5" cy="4.3" rx="6.5" ry="4.3"></ellipse>
<circle id="椭圆形" cx="6.5" cy="4.3" r="1.75"></circle>
</g>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1011 B

View File

@@ -0,0 +1,13 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>link-outlined</title>
<g id="空间外层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="平台管理-工具管理--MCP服务" transform="translate(-1032, -161)" fill="#171719" fill-rule="nonzero">
<g id="编组-16" transform="translate(1020, 126)">
<g id="link-outlined" transform="translate(12, 35)">
<path d="M8.8887561,10.1978746 C8.84435493,10.1534734 8.77130783,10.1534734 8.72690666,10.1978746 L7.06257876,11.8622025 C6.29200354,12.6327777 4.99147881,12.7144186 4.14069502,11.8622025 C3.28847892,11.0099864 3.37011979,9.71089398 4.14069502,8.94031876 L5.80502291,7.27599086 C5.84942409,7.23158968 5.84942409,7.15854259 5.80502291,7.11414142 L5.23496912,6.54408763 C5.19056795,6.49968645 5.11752086,6.49968645 5.07311968,6.54408763 L3.40879178,8.20841552 C2.19706941,9.4201379 2.19706941,11.3809511 3.40879178,12.5912411 C4.62051416,13.8015312 6.58132732,13.8029635 7.7916174,12.5912411 L9.4559453,10.9269132 C9.50034647,10.8825121 9.50034647,10.809465 9.4559453,10.7650638 L8.8887561,10.1978746 Z M12.5926734,3.40879178 C11.3809511,2.19706941 9.4201379,2.19706941 8.20984782,3.40879178 L6.54408763,5.07311968 C6.49968645,5.11752086 6.49968645,5.19056795 6.54408763,5.23496912 L7.11270912,5.80359062 C7.15711029,5.84799179 7.23015739,5.84799179 7.27455856,5.80359062 L8.93888646,4.13926272 C9.70946168,3.3686875 11.0099864,3.28704663 11.8607702,4.13926272 C12.7129863,4.99147881 12.6313454,6.29057124 11.8607702,7.06114647 L10.1964423,8.72547436 C10.1520411,8.76987554 10.1520411,8.84292263 10.1964423,8.88732381 L10.7664961,9.4573776 C10.8108973,9.50177877 10.8839444,9.50177877 10.9283455,9.4573776 L12.5926734,7.7930497 C13.8029635,6.58132732 13.8029635,4.62051416 12.5926734,3.40879178 L12.5926734,3.40879178 Z M9.40581494,5.99981516 C9.36141377,5.95541399 9.28836667,5.95541399 9.2439655,5.99981516 L5.99981516,9.2425332 C5.95541399,9.28693438 5.95541399,9.35998147 5.99981516,9.40438265 L6.56700436,9.97157184 C6.61140554,10.015973 6.68445263,10.015973 6.7288538,9.97157184 L9.97157184,6.7288538 C10.015973,6.68445263 10.015973,6.61140554 9.97157184,6.56700436 L9.40581494,5.99981516 Z" id="形状"></path>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 2.4 KiB

View File

@@ -1,12 +1,11 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<svg width="24px" height="24px" viewBox="0 0 24 24" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"> <svg width="22px" height="22px" viewBox="0 0 22 22" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>更多</title> <title>卡片1@3x</title>
<g id="空间层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd"> <g id="空间层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="记忆库-个人记忆-感知记忆-听觉" transform="translate(-440, -187)" fill-rule="nonzero"> <g id="平台管理-工具管理--MCP服务" transform="translate(-602, -128)" fill="#5B6167" fill-rule="nonzero">
<g id="编组-14" transform="translate(28, 168)"> <g id="卡片1" transform="translate(252, 112)">
<g id="更多" transform="translate(412, 19)"> <g id="更多" transform="translate(350, 16)">
<rect id="矩形" fill="#000000" opacity="0" x="0" y="0" width="24" height="24"></rect> <path d="M5.4,12.4 C6.17319865,12.4 6.8,11.7731986 6.8,11 C6.8,10.2268014 6.17319865,9.6 5.4,9.6 C4.62680135,9.6 4,10.2268014 4,11 C4,11.7731986 4.62680135,12.4 5.4,12.4 Z M11,12.4 C11.7731986,12.4 12.4,11.7731986 12.4,11 C12.4,10.2268014 11.7731986,9.6 11,9.6 C10.2268014,9.6 9.6,10.2268014 9.6,11 C9.6,11.7731986 10.2268014,12.4 11,12.4 Z M16.6,12.4 C17.3731986,12.4 18,11.7731986 18,11 C18,10.2268014 17.3731986,9.6 16.6,9.6 C15.8268014,9.6 15.2,10.2268014 15.2,11 C15.2,11.7731986 15.8268014,12.4 16.6,12.4 Z" id="形状"></path>
<path d="M5.25,12 C5.25,12.8284271 5.92157288,13.5 6.75,13.5 C7.57842712,13.5 8.25,12.8284271 8.25,12 C8.25,11.1715729 7.57842712,10.5 6.75,10.5 C5.92157288,10.5 5.25,11.1715729 5.25,12 Z M10.5,12 C10.5,12.8284271 11.1715729,13.5 12,13.5 C12.8284271,13.5 13.5,12.8284271 13.5,12 C13.5,11.1715729 12.8284271,10.5 12,10.5 C11.1715729,10.5 10.5,11.1715729 10.5,12 Z M15.75,12 C15.75,12.8284271 16.4215729,13.5 17.25,13.5 C18.0784271,13.5 18.75,12.8284271 18.75,12 C18.75,11.1715729 18.0784271,10.5 17.25,10.5 C16.4215729,10.5 15.75,11.1715729 15.75,12 Z" id="形状" fill="#171719"></path>
</g> </g>
</g> </g>
</g> </g>

Before

Width:  |  Height:  |  Size: 1.3 KiB

After

Width:  |  Height:  |  Size: 1.2 KiB

View File

@@ -1,14 +1,23 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<svg width="24px" height="24px" viewBox="0 0 24 24" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"> <svg width="22px" height="22px" viewBox="0 0 22 22" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>更多</title> <title>更多@3x</title>
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd"> <defs>
<g id="记忆库-个人记忆-感知记忆-文本" transform="translate(-440, -187)" fill-rule="nonzero"> <filter x="-3.4%" y="-6.2%" width="106.8%" height="114.8%" filterUnits="objectBoundingBox" id="filter-1">
<g id="编组-8" transform="translate(12, 76)"> <feOffset dx="0" dy="2" in="SourceAlpha" result="shadowOffsetOuter1"></feOffset>
<g id="编组-14" transform="translate(16, 92)"> <feGaussianBlur stdDeviation="4" in="shadowOffsetOuter1" result="shadowBlurOuter1"></feGaussianBlur>
<g id="更多" transform="translate(412, 19)"> <feColorMatrix values="0 0 0 0 0.0901960784 0 0 0 0 0.0901960784 0 0 0 0 0.0980392157 0 0 0 0.16 0" type="matrix" in="shadowBlurOuter1" result="shadowMatrixOuter1"></feColorMatrix>
<rect id="矩形" fill="#E4E4E4" x="0" y="0" width="24" height="24" rx="8"></rect> <feMerge>
<path d="M5.25,12 C5.25,12.8284271 5.92157288,13.5 6.75,13.5 C7.57842712,13.5 8.25,12.8284271 8.25,12 C8.25,11.1715729 7.57842712,10.5 6.75,10.5 C5.92157288,10.5 5.25,11.1715729 5.25,12 Z M10.5,12 C10.5,12.8284271 11.1715729,13.5 12,13.5 C12.8284271,13.5 13.5,12.8284271 13.5,12 C13.5,11.1715729 12.8284271,10.5 12,10.5 C11.1715729,10.5 10.5,11.1715729 10.5,12 Z M15.75,12 C15.75,12.8284271 16.4215729,13.5 17.25,13.5 C18.0784271,13.5 18.75,12.8284271 18.75,12 C18.75,11.1715729 18.0784271,10.5 17.25,10.5 C16.4215729,10.5 15.75,11.1715729 15.75,12 Z" id="形状" fill="#171719"></path> <feMergeNode in="shadowMatrixOuter1"></feMergeNode>
</g> <feMergeNode in="SourceGraphic"></feMergeNode>
</feMerge>
</filter>
</defs>
<g id="空间外层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="平台管理-工具管理--MCP服务" transform="translate(-998, -128)" fill-rule="nonzero">
<g id="卡片1备份" filter="url(#filter-1)" transform="translate(648, 112)">
<g id="更多" transform="translate(350, 16)">
<rect id="矩形" fill="#F6F6F6" x="0" y="0" width="22" height="22" rx="8"></rect>
<path d="M5.4,12.4 C6.17319865,12.4 6.8,11.7731986 6.8,11 C6.8,10.2268014 6.17319865,9.6 5.4,9.6 C4.62680135,9.6 4,10.2268014 4,11 C4,11.7731986 4.62680135,12.4 5.4,12.4 Z M11,12.4 C11.7731986,12.4 12.4,11.7731986 12.4,11 C12.4,10.2268014 11.7731986,9.6 11,9.6 C10.2268014,9.6 9.6,10.2268014 9.6,11 C9.6,11.7731986 10.2268014,12.4 11,12.4 Z M16.6,12.4 C17.3731986,12.4 18,11.7731986 18,11 C18,10.2268014 17.3731986,9.6 16.6,9.6 C15.8268014,9.6 15.2,10.2268014 15.2,11 C15.2,11.7731986 15.8268014,12.4 16.6,12.4 Z" id="形状" fill="#5B6167"></path>
</g> </g>
</g> </g>
</g> </g>

Before

Width:  |  Height:  |  Size: 1.4 KiB

After

Width:  |  Height:  |  Size: 2.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 189 KiB

After

Width:  |  Height:  |  Size: 158 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.8 KiB

View File

@@ -0,0 +1,16 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>编组 51</title>
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round">
<g id="工作台-记忆看板-3" transform="translate(-1400, -140)" stroke="#A8A9AA" stroke-width="1.2">
<g id="编组-22" transform="translate(1180, 57)">
<g id="编组-51" transform="translate(220, 83)">
<g id="编组-49" transform="translate(4.5, 4.5)">
<polyline id="路径" points="0 0 7 0 7 7"></polyline>
<line x1="7" y1="0" x2="9.71445147e-17" y2="7" id="路径-51"></line>
</g>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 938 B

View File

@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>退出</title>
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round">
<g id="工作台-记忆看板-3" transform="translate(-1196, -229)" stroke="#FF5D34" stroke-width="1.2">
<g id="编组-22" transform="translate(1180, 57)">
<g id="退出" transform="translate(16, 172)">
<g id="编组-7" transform="translate(2.5, 2)">
<path d="M4,12 L2,12 C0.8954305,12 0,11.1045695 0,10 L0,2 C0,0.8954305 0.8954305,2.22044605e-16 2,0 L4,0 L4,0" id="路径"></path>
<line x1="11" y1="6" x2="4.5" y2="6" id="路径-6"></line>
<polyline id="路径" points="8 3 11 6 8 9"></polyline>
</g>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.0 KiB

View File

@@ -0,0 +1,19 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>设置-界面设置</title>
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="工作台-记忆看板-3" transform="translate(-1196, -176)">
<g id="编组-22" transform="translate(1180, 57)">
<g id="设置-界面设置" transform="translate(16, 119)">
<g id="编组-2" transform="translate(2, 2)">
<path d="M9.43360402,6.02003217 C9.70219942,5.98544974 9.97412053,5.98544974 10.2427159,6.02003217 C10.3615753,6.03694616 10.4495746,6.13920492 10.4485827,6.25925761 L10.4485827,6.97778158 C10.8355212,7.08416506 11.1878633,7.2900028 11.4705612,7.57481981 L12.0861381,7.21535276 C12.1886462,7.15532215 12.3197841,7.18070349 12.3924911,7.27464643 C12.5566377,7.49314267 12.6927553,7.73134853 12.7976622,7.98369599 C12.8417451,8.09509234 12.7980417,8.22199461 12.6947152,8.28262525 L12.0795484,8.64127201 C12.1812591,9.03280439 12.1812591,9.44382977 12.0795484,9.83536215 L12.6951253,10.1940089 C12.7984653,10.255883 12.8408681,10.3841328 12.7947775,10.4954127 C12.6902734,10.7470629 12.5545635,10.9845849 12.3908368,11.2023979 C12.3178817,11.29706 12.1858902,11.3226521 12.0828433,11.2621154 L11.4681003,10.9034823 C11.1853554,11.1882385 10.8330304,11.3940662 10.4461218,11.5005205 L10.4461218,12.2190308 C10.4472928,12.3391587 10.3591904,12.4415237 10.2402277,12.4582563 C9.97161477,12.49242 9.69975606,12.49242 9.43114314,12.4582563 C9.31227289,12.4413547 9.22425896,12.3390925 9.22524902,12.2190308 L9.22524902,11.5005069 C8.83966141,11.3929856 8.48863308,11.187108 8.20657902,10.9030585 L7.58893771,11.263756 C7.48646614,11.3234892 7.35562246,11.2983389 7.28259844,11.2048724 C7.11845115,10.9862292 6.98233473,10.7478877 6.87742727,10.4954127 C6.83256205,10.383776 6.87616574,10.2561165 6.9799505,10.195253 L7.59757815,9.83618245 C7.54895974,9.65227947 7.52310022,9.46310373 7.52057977,9.27289936 C7.51726668,9.06031128 7.54274952,8.8482532 7.59633403,8.64250245 L6.9807708,8.28344555 C6.87737248,8.2230699 6.83375833,8.09609457 6.87823389,7.98492643 C6.98279132,7.7324117 7.11893329,7.49416785 7.28340506,7.27589055 C7.35624204,7.18108729 7.48825303,7.15531561 7.59139859,7.2157629 L8.2061552,7.57481981 C8.48870496,7.29001094 8.84091669,7.08416484 9.22773725,6.97776791 L9.22773725,6.25925761 C9.22638248,6.13907231 9.31455943,6.03660714 9.43360402,6.02003217 Z M9.85567328,8.1381029 L9.83960916,8.1381029 C9.44388649,8.13394685 9.07614467,8.3416968 8.87546122,8.68278309 C8.67477776,9.02386937 8.67178651,9.44622602 8.86761874,9.79012055 C9.06345096,10.1340151 9.42821335,10.3469528 9.82395519,10.3484022 L9.85608342,10.3484022 C10.4510916,10.3389892 10.929189,9.8552216 10.9315862,9.26014377 C10.9437669,8.6534016 10.4623788,8.15138092 9.85567328,8.1381029 Z" id="形状结合" fill="#171719" fill-rule="nonzero"></path>
<path d="M5.99499898,12 L2,12 C0.8954305,12 4.4408921e-16,11.1045695 4.4408921e-16,10 L0,2 C0,0.8954305 0.8954305,2.22044605e-16 2,0 L10,0 C11.1045695,0 12,0.8954305 12,2 L12,4.98927875 L12,4.98927875" id="路径" stroke="#171719" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"></path>
<line x1="0" y1="3.99424593" x2="12" y2="3.99424593" id="路径-2" stroke="#171719" stroke-width="1.2"></line>
<circle id="椭圆形" fill="#171719" cx="2.2" cy="2" r="1"></circle>
<circle id="椭圆形" fill="#171719" cx="4.4" cy="2" r="1"></circle>
</g>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 3.7 KiB

View File

@@ -0,0 +1,13 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>账户</title>
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="工作台-记忆看板-3" transform="translate(-1196, -140)" fill="#171719" fill-rule="nonzero" stroke="#171719" stroke-width="0.4">
<g id="编组-22" transform="translate(1180, 57)">
<g id="账户" transform="translate(16, 83)">
<path d="M8,1 C4.13400675,1 1,4.13400675 1,8 C1,11.8659932 4.13400675,15 8,15 C11.8659932,15 15,11.8659932 15,8 C14.9947583,4.13618002 11.86382,1.0052417 8,1 L8,1 Z M12.4014141,12.3271621 L12.2224629,12.517543 L12.1748711,12.2662539 C12.0428879,11.5485013 11.7289104,10.8766306 11.2629844,10.3149355 C11.1145387,10.1521744 10.8641988,10.135306 10.6952598,10.2766811 C10.5263208,10.4180561 10.4988,10.6674498 10.6328477,10.8422598 C11.1442401,11.4572317 11.4244051,12.2317168 11.4248047,13.0315371 C11.4257399,13.0442165 11.4257399,13.0569476 11.4248047,13.069627 L11.4248047,13.1419512 L11.3638828,13.1819414 C9.31720678,14.5169283 6.67517799,14.5169283 4.62850195,13.1819414 L4.56758008,13.1419648 L4.56758008,13.0334512 C4.56758008,11.1409267 6.10177432,9.60673242 7.99429883,9.60673242 C9.43633406,9.61053035 10.6698828,8.57151634 10.9112248,7.14981524 C11.1525669,5.72811413 10.3309981,4.34023189 8.96849203,3.8679425 C7.60598597,3.39565311 6.10170668,3.97731994 5.4113796,5.24338809 C4.72105252,6.50945624 5.046918,8.08901446 6.18194141,8.97850977 L6.35137695,9.10986914 L6.16099609,9.20315234 C4.9414843,9.79213158 4.07443441,10.9255913 3.82512891,12.2567383 L3.77753711,12.5080273 L3.59858594,12.3176602 C2.45876944,11.1700068 1.82010269,9.61749329 1.8223981,8 C1.81896848,5.20888419 3.68748546,2.76216731 6.38105315,2.03070823 C9.07462084,1.29924914 11.9240679,2.46476648 13.3328918,4.87423832 C14.7417157,7.28371017 14.3599179,10.338544 12.4014141,12.3271621 L12.4014141,12.3271621 Z M5.87924609,6.65787305 C5.88029543,5.48791052 6.82940261,4.5402209 7.9993654,4.54091991 C9.16932819,4.5416197 10.1173016,5.49044339 10.1169523,6.66040633 C10.1166027,7.83036928 9.16806261,8.77862695 7.99809961,8.77862695 C6.82758179,8.77757806 5.87924609,7.82839134 5.87924609,6.65787305 L5.87924609,6.65787305 Z" id="形状"></path>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 2.4 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.2 KiB

View File

@@ -1,15 +0,0 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_16762_59518)">
<path d="M12.6667 0H3.33333C1.49238 0 0 1.49238 0 3.33333V12.6667C0 14.5076 1.49238 16 3.33333 16H12.6667C14.5076 16 16 14.5076 16 12.6667V3.33333C16 1.49238 14.5076 0 12.6667 0Z" fill="url(#paint0_linear_16762_59518)"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M7.99984 12.093L6.3825 12.6323L5.75184 12.2116L6.4385 11.9823L6.22784 11.3503L5.04917 11.743L4.6665 11.4883V9.66631C4.6665 9.54031 4.59517 9.42497 4.4825 9.3683L3.33317 8.79364V7.20564L4.33317 6.70564L5.33317 7.20564V8.33297C5.33317 8.45964 5.4045 8.57497 5.51717 8.63164L6.8505 9.29831L7.14917 8.70164L5.99984 8.12697V7.20564L7.14917 6.63164C7.26184 6.57497 7.33317 6.45964 7.33317 6.33297V5.33297H6.6665V6.12697L5.6665 6.62697L4.6665 6.12697V4.51164L5.33317 4.06697V5.33297H5.99984V3.62297L6.3825 3.36764L7.99984 3.90697V12.093ZM11.6665 11.333C11.8498 11.333 11.9998 11.4823 11.9998 11.6663C11.9998 11.8503 11.8498 11.9996 11.6665 11.9996C11.4832 11.9996 11.3332 11.8503 11.3332 11.6663C11.3332 11.4823 11.4832 11.333 11.6665 11.333ZM10.9998 3.99964C11.1832 3.99964 11.3332 4.14897 11.3332 4.33297C11.3332 4.51697 11.1832 4.6663 10.9998 4.6663C10.8165 4.6663 10.6665 4.51697 10.6665 4.33297C10.6665 4.14897 10.8165 3.99964 10.9998 3.99964ZM12.3332 7.99964C12.5165 7.99964 12.6665 8.14897 12.6665 8.33297C12.6665 8.51697 12.5165 8.66631 12.3332 8.66631C12.1498 8.66631 11.9998 8.51697 11.9998 8.33297C11.9998 8.14897 12.1498 7.99964 12.3332 7.99964ZM11.3945 8.66631C11.5325 9.05364 11.8992 9.33297 12.3332 9.33297C12.8845 9.33297 13.3332 8.88497 13.3332 8.33297C13.3332 7.78164 12.8845 7.33297 12.3332 7.33297C11.8992 7.33297 11.5325 7.61297 11.3945 7.99964H8.6665V6.66631H10.9998C11.1838 6.66631 11.3332 6.51764 11.3332 6.33297V5.27164C11.7205 5.13364 11.9998 4.76697 11.9998 4.33297C11.9998 3.78164 11.5512 3.33297 10.9998 3.33297C10.4485 3.33297 9.99984 3.78164 9.99984 4.33297C9.99984 4.76697 10.2792 5.13364 10.6665 5.27164V5.99964H8.6665V3.6663C8.6665 3.52297 8.5745 3.39564 8.4385 3.3503L6.4385 2.68364C6.3405 2.65097 6.23384 2.66564 6.1485 2.7223L4.1485 4.05564C4.05584 4.11764 3.99984 4.22164 3.99984 4.33297V6.12697L2.8505 6.70164C2.73784 6.75831 2.6665 6.87364 2.6665 6.99964V8.99964C2.6665 9.12631 2.73784 9.24164 2.8505 9.29831L3.99984 9.87231V11.6663C3.99984 11.7776 4.05584 11.8823 4.1485 11.9436L6.1485 13.277C6.20384 13.3143 6.26784 13.333 6.33317 13.333C6.3685 13.333 6.40384 13.3276 6.4385 13.3156L8.4385 12.649C8.5745 12.6043 8.6665 12.477 8.6665 12.333V10.6663H10.1952L10.7638 11.2356L10.7725 11.227C10.7072 11.3603 10.6665 11.5083 10.6665 11.6663C10.6665 12.2176 11.1152 12.6663 11.6665 12.6663C12.2178 12.6663 12.6665 12.2176 12.6665 11.6663C12.6665 11.115 12.2178 10.6663 11.6665 10.6663C11.5078 10.6663 11.3598 10.707 11.2272 10.773L11.2358 10.7643L10.5692 10.0976C10.5065 10.035 10.4218 9.99964 10.3332 9.99964H8.6665V8.66631H11.3945Z" fill="white"/>
</g>
<defs>
<linearGradient id="paint0_linear_16762_59518" x1="0" y1="1600" x2="1600" y2="0" gradientUnits="userSpaceOnUse">
<stop stop-color="#055F4E"/>
<stop offset="1" stop-color="#56C0A7"/>
</linearGradient>
<clipPath id="clip0_16762_59518">
<rect width="16" height="16" fill="white"/>
</clipPath>
</defs>
</svg>

Before

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.8 KiB

After

Width:  |  Height:  |  Size: 2.6 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 57 KiB

After

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.2 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 7.7 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 3.3 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 6.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.1 KiB

After

Width:  |  Height:  |  Size: 2.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.8 KiB

View File

@@ -1,24 +0,0 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<g id="Xorbits Square" clip-path="url(#clip0_9850_26870)">
<path id="Vector" d="M8.00391 12.3124C8.69334 13.0754 9.47526 13.7494 10.3316 14.3188C11.0667 14.8105 11.8509 15.2245 12.6716 15.5541C14.1617 14.1465 15.3959 12.4907 16.3192 10.6606L21.7051 0L12.3133 7.38353C10.5832 8.74456 9.12178 10.416 8.00391 12.3124Z" fill="url(#paint0_linear_9850_26870)"/>
<path id="Vector_2" d="M7.23504 18.9512C6.56092 18.5012 5.92386 18.0265 5.3221 17.5394L2.06445 24L7.91975 19.3959C7.69034 19.2494 7.46092 19.103 7.23504 18.9512Z" fill="url(#paint1_linear_9850_26870)"/>
<path id="Vector_3" d="M19.3161 8.57474C21.0808 10.9147 21.5961 13.5159 20.3996 15.3053C18.6526 17.9189 13.9161 17.8183 9.82024 15.0812C5.72435 12.3441 3.82024 8.0065 5.56729 5.39297C6.76377 3.60356 9.36318 3.0865 12.2008 3.81886C7.29318 1.73474 2.62376 1.94121 0.813177 4.64474C-1.45976 8.04709 1.64435 14.1177 7.74494 18.1889C13.8455 22.26 20.6361 22.8124 22.9091 19.4118C24.7179 16.703 23.1173 12.3106 19.3161 8.57474Z" fill="url(#paint2_linear_9850_26870)"/>
</g>
<defs>
<linearGradient id="paint0_linear_9850_26870" x1="2.15214" y1="24.3018" x2="21.2921" y2="0.0988218" gradientUnits="userSpaceOnUse">
<stop stop-color="#E9A85E"/>
<stop offset="1" stop-color="#F52B76"/>
</linearGradient>
<linearGradient id="paint1_linear_9850_26870" x1="2.06269" y1="24.2294" x2="21.2027" y2="0.028252" gradientUnits="userSpaceOnUse">
<stop stop-color="#E9A85E"/>
<stop offset="1" stop-color="#F52B76"/>
</linearGradient>
<linearGradient id="paint2_linear_9850_26870" x1="-0.613606" y1="3.843" x2="21.4449" y2="18.7258" gradientUnits="userSpaceOnUse">
<stop stop-color="#6A0CF5"/>
<stop offset="1" stop-color="#AB66F3"/>
</linearGradient>
<clipPath id="clip0_9850_26870">
<rect width="24" height="24" fill="white"/>
</clipPath>
</defs>
</svg>

Before

Width:  |  Height:  |  Size: 1.8 KiB

View File

@@ -0,0 +1,19 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="18px" height="18px" viewBox="0 0 18 18" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>编组 33</title>
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="工作台-提示词-我的历史" transform="translate(-1002, -560)" stroke="#5B6167">
<g id="编组-13备份-4" transform="translate(648, 362)">
<g id="编组-33" transform="translate(354, 198)">
<g id="编组-32" transform="translate(3.5, 4)">
<line x1="-1.80133686e-14" y1="2.22222222" x2="11" y2="2.22222222" id="路径-29"></line>
<polyline id="路径-30" stroke-linejoin="round" points="3.3 2.2221179 3.3 0 7.7 0 7.7 2.22222222"></polyline>
<path d="M1.65,2.23587458 L1.65,9 C1.65,9.55228475 2.09771525,10 2.65,10 L8.35,10 C8.90228475,10 9.35,9.55228475 9.35,9 L9.35,2.22222222 L9.35,2.22222222" id="路径-31" stroke-linejoin="round"></path>
<line x1="4.4" y1="4.45203738" x2="4.4" y2="7.78537071" id="路径-32"></line>
<line x1="6.6" y1="4.45203738" x2="6.6" y2="7.78537071" id="路径-32"></line>
</g>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.4 KiB

View File

@@ -0,0 +1,20 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="18px" height="18px" viewBox="0 0 18 18" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>编组 33</title>
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="工作台-提示词-我的历史" transform="translate(-1002, -320)">
<g id="编组-13备份-6" transform="translate(648, 122)">
<g id="编组-33" transform="translate(354, 198)">
<rect id="矩形" fill-opacity="0.08" fill="#FF5D34" x="0" y="0" width="18" height="18" rx="6"></rect>
<g id="编组-32" transform="translate(3.5, 4)" stroke="#FF5D34">
<line x1="-1.80133686e-14" y1="2.22222222" x2="11" y2="2.22222222" id="路径-29"></line>
<polyline id="路径-30" stroke-linejoin="round" points="3.3 2.2221179 3.3 0 7.7 0 7.7 2.22222222"></polyline>
<path d="M1.65,2.23587458 L1.65,9 C1.65,9.55228475 2.09771525,10 2.65,10 L8.35,10 C8.90228475,10 9.35,9.55228475 9.35,9 L9.35,2.22222222 L9.35,2.22222222" id="路径-31" stroke-linejoin="round"></path>
<line x1="4.4" y1="4.45203738" x2="4.4" y2="7.78537071" id="路径-32"></line>
<line x1="6.6" y1="4.45203738" x2="6.6" y2="7.78537071" id="路径-32"></line>
</g>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.5 KiB

View File

@@ -0,0 +1,16 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="18px" height="18px" viewBox="0 0 18 18" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>编辑</title>
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="工作台-提示词-我的历史" transform="translate(-976, -560)" stroke="#5B6167">
<g id="编组-13备份-4" transform="translate(648, 362)">
<g id="编辑" transform="translate(328, 198)">
<g id="编组-10" transform="translate(4.3, 4.3)">
<path d="M9.4,4.04322919 L9.4,7.4 C9.4,8.5045695 8.5045695,9.4 7.4,9.4 L2,9.4 C0.8954305,9.4 0,8.5045695 0,7.4 L0,2 C0,0.8954305 0.8954305,2.22044605e-16 2,0 L5.38958415,0 L5.38958415,0" id="路径"></path>
<line x1="3.74260398" y1="5.68579764" x2="9.4" y2="1.05734433e-14" id="路径-2"></line>
</g>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.0 KiB

View File

@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="18px" height="18px" viewBox="0 0 18 18" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>编辑</title>
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="工作台-提示词-我的历史" transform="translate(-976, -320)">
<g id="编组-13备份-6" transform="translate(648, 122)">
<g id="编辑" transform="translate(328, 198)">
<rect id="矩形" fill="#EBEBEB" fill-rule="nonzero" x="0" y="0" width="18" height="18" rx="6"></rect>
<g id="编组-10" transform="translate(4.3, 4.3)" stroke="#5B6167">
<path d="M9.4,4.04322919 L9.4,7.4 C9.4,8.5045695 8.5045695,9.4 7.4,9.4 L2,9.4 C0.8954305,9.4 0,8.5045695 0,7.4 L0,2 C0,0.8954305 0.8954305,2.22044605e-16 2,0 L5.38958415,0 L5.38958415,0" id="路径"></path>
<line x1="3.74260398" y1="5.68579764" x2="9.4" y2="1.05734433e-14" id="路径-2"></line>
</g>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1.1 KiB

View File

@@ -0,0 +1,16 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="18px" height="18px" viewBox="0 0 18 18" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>编辑</title>
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="工作台-提示词-我的历史" transform="translate(-950, -560)" stroke="#5B6167">
<g id="编组-13备份-4" transform="translate(648, 362)">
<g id="编辑" transform="translate(302, 198)">
<g id="编组-16" transform="translate(2.5, 4.7)">
<ellipse id="椭圆形" cx="6.5" cy="4.3" rx="6.5" ry="4.3"></ellipse>
<circle id="椭圆形" cx="6.5" cy="4.3" r="1.75"></circle>
</g>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 888 B

View File

@@ -0,0 +1,17 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="18px" height="18px" viewBox="0 0 18 18" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>编辑</title>
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="工作台-提示词-我的历史" transform="translate(-950, -320)">
<g id="编组-13备份-6" transform="translate(648, 122)">
<g id="编辑" transform="translate(302, 198)">
<rect id="矩形" fill="#EBEBEB" fill-rule="nonzero" x="0" y="0" width="18" height="18" rx="6"></rect>
<g id="编组-16" transform="translate(2.5, 4.7)" stroke="#5B6167">
<ellipse id="椭圆形" cx="6.5" cy="4.3" rx="6.5" ry="4.3"></ellipse>
<circle id="椭圆形" cx="6.5" cy="4.3" r="1.75"></circle>
</g>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 1011 B

View File

@@ -0,0 +1,13 @@
<?xml version="1.0" encoding="UTF-8"?>
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
<title>clear-outlined</title>
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
<g id="应用管理-工作流-配置-开始" transform="translate(-1249, -24)" fill="#171719" fill-rule="nonzero">
<g id="编组-11" transform="translate(1242, 17)">
<g id="clear-outlined" transform="translate(7, 7)">
<path d="M14.4933021,14.4598985 L13.6042691,9.03045685 L13.9045274,9.03045685 C14.146076,9.03045685 14.3406568,8.82436548 14.3406568,8.56852792 L14.3406568,5.15736041 C14.3406568,4.90152284 14.146076,4.69543147 13.9045274,4.69543147 L9.77807205,4.69543147 L9.77807205,1.46192893 C9.77807205,1.20609137 9.58349122,1 9.34194262,1 L6.65806921,1 C6.4165206,1 6.22193978,1.20609137 6.22193978,1.46192893 L6.22193978,4.69543147 L2.09548441,4.69543147 C1.85393581,4.69543147 1.65935498,4.90152284 1.65935498,5.15736041 L1.65935498,8.56852792 C1.65935498,8.82436548 1.85393581,9.03045685 2.09548441,9.03045685 L2.39574275,9.03045685 L1.50670968,14.4598985 C1.50167742,14.4865482 1.5,14.513198 1.5,14.5380711 C1.5,14.7939086 1.69458082,15 1.93612943,15 L14.0638824,15 C14.0890437,15 14.114205,14.9982234 14.1376889,14.9928934 C14.3758827,14.9502538 14.5352377,14.7104061 14.4933021,14.4598985 Z M2.8335496,5.93908629 L7.3961344,5.93908629 L7.3961344,2.24365482 L8.60387743,2.24365482 L8.60387743,5.93908629 L13.1664622,5.93908629 L13.1664622,7.78680203 L2.8335496,7.78680203 L2.8335496,5.93908629 Z M10.6838793,13.7563452 L10.6838793,10.9847716 C10.6838793,10.906599 10.6234922,10.8426396 10.5496857,10.8426396 L9.74452363,10.8426396 C9.67071711,10.8426396 9.61032996,10.906599 9.61032996,10.9847716 L9.61032996,13.7563452 L6.38968187,13.7563452 L6.38968187,10.9847716 C6.38968187,10.906599 6.32929472,10.8426396 6.2554882,10.8426396 L5.45032617,10.8426396 C5.37651966,10.8426396 5.3161325,10.906599 5.3161325,10.9847716 L5.3161325,13.7563452 L2.81342055,13.7563452 L3.56993737,9.13705584 L12.428397,9.13705584 L13.1849139,13.7563452 L10.6838793,13.7563452 Z" id="形状"></path>
</g>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 2.3 KiB

Some files were not shown because too many files have changed in this diff Show More