Merge branch 'refs/heads/develop' into feature/agent-tool_xjn
# Conflicts: # api/app/core/agent/langchain_agent.py # api/app/core/tools/mcp/client.py
@@ -1,6 +1,8 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
import redis.asyncio as redis
|
import redis.asyncio as redis
|
||||||
@@ -21,6 +23,50 @@ pool = ConnectionPool.from_url(
|
|||||||
)
|
)
|
||||||
aio_redis = redis.StrictRedis(connection_pool=pool)
|
aio_redis = redis.StrictRedis(connection_pool=pool)
|
||||||
|
|
||||||
|
_REDIS_URL = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}"
|
||||||
|
|
||||||
|
# Thread-local storage for connection pools.
|
||||||
|
# Each thread (and each forked process) gets its own pool to avoid
|
||||||
|
# "Future attached to a different loop" errors in Celery --pool=threads
|
||||||
|
# and stale connections after fork in --pool=prefork.
|
||||||
|
_thread_local = threading.local()
|
||||||
|
|
||||||
|
|
||||||
|
def get_thread_safe_redis() -> redis.StrictRedis:
|
||||||
|
"""Return a Redis client whose connection pool is bound to the current
|
||||||
|
thread, process **and** event loop.
|
||||||
|
|
||||||
|
The pool is recreated when:
|
||||||
|
- The PID changes (fork, Celery --pool=prefork)
|
||||||
|
- The thread has no pool yet (Celery --pool=threads)
|
||||||
|
- The previously-cached event loop has been closed (Celery tasks call
|
||||||
|
``_shutdown_loop_gracefully`` which closes the loop after each run)
|
||||||
|
"""
|
||||||
|
current_pid = os.getpid()
|
||||||
|
cached_loop = getattr(_thread_local, "loop", None)
|
||||||
|
loop_stale = cached_loop is not None and cached_loop.is_closed()
|
||||||
|
|
||||||
|
if not hasattr(_thread_local, "pool") \
|
||||||
|
or getattr(_thread_local, "pid", None) != current_pid \
|
||||||
|
or loop_stale:
|
||||||
|
_thread_local.pid = current_pid
|
||||||
|
# Python 3.10+: get_event_loop() raises RuntimeError in threads
|
||||||
|
# where no loop has been set yet (e.g. Celery --pool=threads).
|
||||||
|
try:
|
||||||
|
_thread_local.loop = asyncio.get_event_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
_thread_local.loop = None
|
||||||
|
_thread_local.pool = ConnectionPool.from_url(
|
||||||
|
_REDIS_URL,
|
||||||
|
db=settings.REDIS_DB,
|
||||||
|
password=settings.REDIS_PASSWORD,
|
||||||
|
decode_responses=True,
|
||||||
|
max_connections=5,
|
||||||
|
health_check_interval=30,
|
||||||
|
)
|
||||||
|
|
||||||
|
return redis.StrictRedis(connection_pool=_thread_local.pool)
|
||||||
|
|
||||||
|
|
||||||
async def get_redis_connection():
|
async def get_redis_connection():
|
||||||
"""获取Redis连接"""
|
"""获取Redis连接"""
|
||||||
@@ -44,10 +90,8 @@ async def aio_redis_set(key: str, val: str | dict, expire: int = None):
|
|||||||
val = json.dumps(val, ensure_ascii=False)
|
val = json.dumps(val, ensure_ascii=False)
|
||||||
|
|
||||||
if expire is not None:
|
if expire is not None:
|
||||||
# 设置带过期时间的键值
|
|
||||||
await aio_redis.set(key, val, ex=expire)
|
await aio_redis.set(key, val, ex=expire)
|
||||||
else:
|
else:
|
||||||
# 设置永久键值
|
|
||||||
await aio_redis.set(key, val)
|
await aio_redis.set(key, val)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Redis set错误: {str(e)}")
|
logger.error(f"Redis set错误: {str(e)}")
|
||||||
|
|||||||
8
api/app/cache/memory/activity_stats_cache.py
vendored
@@ -10,7 +10,7 @@ import logging
|
|||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from app.aioRedis import aio_redis
|
from app.aioRedis import get_thread_safe_redis
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -68,7 +68,7 @@ class ActivityStatsCache:
|
|||||||
"cached": True,
|
"cached": True,
|
||||||
}
|
}
|
||||||
value = json.dumps(payload, ensure_ascii=False)
|
value = json.dumps(payload, ensure_ascii=False)
|
||||||
await aio_redis.set(key, value, ex=expire)
|
await get_thread_safe_redis().set(key, value, ex=expire)
|
||||||
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒")
|
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -90,7 +90,7 @@ class ActivityStatsCache:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
key = cls._get_key(workspace_id)
|
key = cls._get_key(workspace_id)
|
||||||
value = await aio_redis.get(key)
|
value = await get_thread_safe_redis().get(key)
|
||||||
if value:
|
if value:
|
||||||
payload = json.loads(value)
|
payload = json.loads(value)
|
||||||
logger.info(f"命中活动统计缓存: {key}")
|
logger.info(f"命中活动统计缓存: {key}")
|
||||||
@@ -116,7 +116,7 @@ class ActivityStatsCache:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
key = cls._get_key(workspace_id)
|
key = cls._get_key(workspace_id)
|
||||||
result = await aio_redis.delete(key)
|
result = await get_thread_safe_redis().delete(key)
|
||||||
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
|
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
|
||||||
return result > 0
|
return result > 0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
8
api/app/cache/memory/interest_memory.py
vendored
@@ -9,7 +9,7 @@ import logging
|
|||||||
from typing import Optional, List, Dict, Any
|
from typing import Optional, List, Dict, Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from app.aioRedis import aio_redis
|
from app.aioRedis import get_thread_safe_redis
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -62,7 +62,7 @@ class InterestMemoryCache:
|
|||||||
"cached": True,
|
"cached": True,
|
||||||
}
|
}
|
||||||
value = json.dumps(payload, ensure_ascii=False)
|
value = json.dumps(payload, ensure_ascii=False)
|
||||||
await aio_redis.set(key, value, ex=expire)
|
await get_thread_safe_redis().set(key, value, ex=expire)
|
||||||
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒")
|
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -86,7 +86,7 @@ class InterestMemoryCache:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
key = cls._get_key(end_user_id, language)
|
key = cls._get_key(end_user_id, language)
|
||||||
value = await aio_redis.get(key)
|
value = await get_thread_safe_redis().get(key)
|
||||||
if value:
|
if value:
|
||||||
payload = json.loads(value)
|
payload = json.loads(value)
|
||||||
logger.info(f"命中兴趣分布缓存: {key}")
|
logger.info(f"命中兴趣分布缓存: {key}")
|
||||||
@@ -114,7 +114,7 @@ class InterestMemoryCache:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
key = cls._get_key(end_user_id, language)
|
key = cls._get_key(end_user_id, language)
|
||||||
result = await aio_redis.delete(key)
|
result = await get_thread_safe_redis().delete(key)
|
||||||
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
|
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
|
||||||
return result > 0
|
return result > 0
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -57,7 +57,6 @@ def list_apps(
|
|||||||
page: int = 1,
|
page: int = 1,
|
||||||
pagesize: int = 10,
|
pagesize: int = 10,
|
||||||
ids: Optional[str] = None,
|
ids: Optional[str] = None,
|
||||||
api_key: Optional[str] = None,
|
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user=Depends(get_current_user),
|
current_user=Depends(get_current_user),
|
||||||
):
|
):
|
||||||
@@ -66,7 +65,7 @@ def list_apps(
|
|||||||
- 默认包含本工作空间的应用和分享给本工作空间的应用
|
- 默认包含本工作空间的应用和分享给本工作空间的应用
|
||||||
- 设置 include_shared=false 可以只查看本工作空间的应用
|
- 设置 include_shared=false 可以只查看本工作空间的应用
|
||||||
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页
|
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页
|
||||||
- 当提供 api_key 参数时,查找该 API Key 关联的应用
|
- search 参数支持:应用名称模糊搜索、API Key 精确搜索
|
||||||
"""
|
"""
|
||||||
from sqlalchemy import select as sa_select
|
from sqlalchemy import select as sa_select
|
||||||
from app.models.api_key_model import ApiKey
|
from app.models.api_key_model import ApiKey
|
||||||
@@ -74,23 +73,34 @@ def list_apps(
|
|||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
service = app_service.AppService(db)
|
service = app_service.AppService(db)
|
||||||
|
|
||||||
# 通过 API Key 搜索:精确匹配,将 resource_id 注入 ids 走统一分页流程
|
# 通过 search 参数搜索:支持应用名称模糊搜索和 API Key 精确搜索
|
||||||
if api_key:
|
if search:
|
||||||
matched_id = db.execute(
|
search = search.strip()
|
||||||
sa_select(ApiKey.resource_id).where(
|
# 尝试作为 API Key 精确匹配(API Key 通常较长)
|
||||||
ApiKey.workspace_id == workspace_id,
|
if len(search) >= 10:
|
||||||
ApiKey.api_key == api_key,
|
matched_id = db.execute(
|
||||||
ApiKey.resource_id.isnot(None),
|
sa_select(ApiKey.resource_id).where(
|
||||||
)
|
ApiKey.workspace_id == workspace_id,
|
||||||
).scalar_one_or_none()
|
ApiKey.api_key == search,
|
||||||
ids = str(matched_id) if matched_id else ""
|
ApiKey.resource_id.isnot(None),
|
||||||
|
)
|
||||||
|
).scalar_one_or_none()
|
||||||
|
if matched_id:
|
||||||
|
# 找到 API Key,直接返回关联的应用
|
||||||
|
ids = str(matched_id)
|
||||||
|
|
||||||
# 当 ids 存在且不为 None 时,根据 ids 获取应用
|
# 当 ids 存在时,根据 ids 获取应用(不分页)
|
||||||
if ids is not None:
|
if ids is not None:
|
||||||
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
|
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
|
||||||
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
if app_ids:
|
||||||
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
|
||||||
return success(data=items)
|
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
|
||||||
|
# 返回标准分页格式
|
||||||
|
meta = PageMeta(page=1, pagesize=len(items), total=len(items), hasnext=False)
|
||||||
|
return success(data=PageData(page=meta, items=items))
|
||||||
|
# ids 为空时,返回空列表
|
||||||
|
meta = PageMeta(page=1, pagesize=0, total=0, hasnext=False)
|
||||||
|
return success(data=PageData(page=meta, items=[]))
|
||||||
|
|
||||||
# 正常分页查询
|
# 正常分页查询
|
||||||
items_orm, total = app_service.list_apps(
|
items_orm, total = app_service.list_apps(
|
||||||
|
|||||||
@@ -3,17 +3,16 @@ import uuid
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Query
|
from fastapi import APIRouter, Depends, Query
|
||||||
from sqlalchemy import select, desc, func
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||||
from app.models.conversation_model import Conversation, Message
|
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail
|
||||||
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail, AppLogMessage
|
|
||||||
from app.schemas.response_schema import PageData, PageMeta
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
from app.services.app_service import AppService
|
from app.services.app_service import AppService
|
||||||
|
from app.services.app_log_service import AppLogService
|
||||||
|
|
||||||
router = APIRouter(prefix="/apps", tags=["App Logs"])
|
router = APIRouter(prefix="/apps", tags=["App Logs"])
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
@@ -25,52 +24,35 @@ def list_app_logs(
|
|||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
page: int = Query(1, ge=1),
|
page: int = Query(1, ge=1),
|
||||||
pagesize: int = Query(20, ge=1, le=100),
|
pagesize: int = Query(20, ge=1, le=100),
|
||||||
user_id: Optional[str] = None,
|
|
||||||
is_draft: Optional[bool] = None,
|
is_draft: Optional[bool] = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user=Depends(get_current_user),
|
current_user=Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""查看应用下所有会话记录(分页)
|
"""查看应用下所有会话记录(分页)
|
||||||
|
|
||||||
- 支持按 user_id 筛选
|
|
||||||
- 支持按 is_draft 筛选(草稿会话 / 发布会话)
|
- 支持按 is_draft 筛选(草稿会话 / 发布会话)
|
||||||
- 按最新更新时间倒序排列
|
- 按最新更新时间倒序排列
|
||||||
|
- 所有人(包括共享者和被共享者)都只能查看自己的会话记录
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
# 验证应用访问权限
|
# 验证应用访问权限
|
||||||
service = AppService(db)
|
app_service = AppService(db)
|
||||||
service.get_app(app_id, workspace_id)
|
app_service.get_app(app_id, workspace_id)
|
||||||
|
|
||||||
stmt = select(Conversation).where(
|
# 使用 Service 层查询
|
||||||
Conversation.app_id == app_id,
|
log_service = AppLogService(db)
|
||||||
Conversation.workspace_id == workspace_id,
|
conversations, total = log_service.list_conversations(
|
||||||
Conversation.is_active.is_(True),
|
app_id=app_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
page=page,
|
||||||
|
pagesize=pagesize,
|
||||||
|
is_draft=is_draft
|
||||||
)
|
)
|
||||||
|
|
||||||
if user_id:
|
|
||||||
stmt = stmt.where(Conversation.user_id == user_id)
|
|
||||||
|
|
||||||
if is_draft is not None:
|
|
||||||
stmt = stmt.where(Conversation.is_draft == is_draft)
|
|
||||||
|
|
||||||
total = int(db.execute(
|
|
||||||
select(func.count()).select_from(stmt.subquery())
|
|
||||||
).scalar_one())
|
|
||||||
|
|
||||||
stmt = stmt.order_by(desc(Conversation.updated_at))
|
|
||||||
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
|
|
||||||
|
|
||||||
conversations = list(db.scalars(stmt).all())
|
|
||||||
|
|
||||||
items = [AppLogConversation.model_validate(c) for c in conversations]
|
items = [AppLogConversation.model_validate(c) for c in conversations]
|
||||||
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"查询应用日志会话列表",
|
|
||||||
extra={"app_id": str(app_id), "total": total, "page": page}
|
|
||||||
)
|
|
||||||
|
|
||||||
return success(data=PageData(page=meta, items=items))
|
return success(data=PageData(page=meta, items=items))
|
||||||
|
|
||||||
|
|
||||||
@@ -86,44 +68,22 @@ def get_app_log_detail(
|
|||||||
|
|
||||||
- 返回会话基本信息 + 所有消息(按时间正序)
|
- 返回会话基本信息 + 所有消息(按时间正序)
|
||||||
- 消息 meta_data 包含模型名、token 用量等信息
|
- 消息 meta_data 包含模型名、token 用量等信息
|
||||||
|
- 所有人(包括共享者和被共享者)都只能查看自己的会话详情
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
# 验证应用访问权限
|
# 验证应用访问权限
|
||||||
service = AppService(db)
|
app_service = AppService(db)
|
||||||
service.get_app(app_id, workspace_id)
|
app_service.get_app(app_id, workspace_id)
|
||||||
|
|
||||||
# 查询会话(确保属于该应用和工作空间)
|
# 使用 Service 层查询
|
||||||
conversation = db.scalars(
|
log_service = AppLogService(db)
|
||||||
select(Conversation).where(
|
conversation = log_service.get_conversation_detail(
|
||||||
Conversation.id == conversation_id,
|
app_id=app_id,
|
||||||
Conversation.app_id == app_id,
|
conversation_id=conversation_id,
|
||||||
Conversation.workspace_id == workspace_id,
|
workspace_id=workspace_id
|
||||||
Conversation.is_active.is_(True),
|
|
||||||
)
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not conversation:
|
|
||||||
from app.core.exceptions import ResourceNotFoundException
|
|
||||||
raise ResourceNotFoundException("会话", str(conversation_id))
|
|
||||||
|
|
||||||
# 查询消息(按时间正序)
|
|
||||||
messages = list(db.scalars(
|
|
||||||
select(Message)
|
|
||||||
.where(Message.conversation_id == conversation_id)
|
|
||||||
.order_by(Message.created_at)
|
|
||||||
).all())
|
|
||||||
|
|
||||||
detail = AppLogConversationDetail.model_validate(conversation)
|
|
||||||
detail.messages = [AppLogMessage.model_validate(m) for m in messages]
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"查询应用日志会话详情",
|
|
||||||
extra={
|
|
||||||
"app_id": str(app_id),
|
|
||||||
"conversation_id": str(conversation_id),
|
|
||||||
"message_count": len(messages)
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
detail = AppLogConversationDetail.model_validate(conversation)
|
||||||
|
|
||||||
return success(data=detail)
|
return success(data=detail)
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import asyncio
|
||||||
|
import uuid
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
@@ -47,64 +49,64 @@ def get_workspace_total_end_users(
|
|||||||
|
|
||||||
@router.get("/end_users", response_model=ApiResponse)
|
@router.get("/end_users", response_model=ApiResponse)
|
||||||
async def get_workspace_end_users(
|
async def get_workspace_end_users(
|
||||||
|
workspace_id: Optional[uuid.UUID] = Query(None, description="工作空间ID(可选,默认当前用户工作空间)"),
|
||||||
|
keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id)"),
|
||||||
|
page: int = Query(1, ge=1, description="页码,从1开始"),
|
||||||
|
pagesize: int = Query(10, ge=1, description="每页数量"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取工作空间的宿主列表(高性能优化版本 v2)
|
获取工作空间的宿主列表(分页查询,支持模糊搜索)
|
||||||
|
|
||||||
优化策略:
|
返回工作空间下的宿主列表,支持分页查询和模糊搜索。
|
||||||
1. 批量查询 end_users(一次查询而非循环)
|
通过 keyword 参数同时模糊匹配 other_name 和 id 字段。
|
||||||
2. 并发查询所有用户的记忆数量(Neo4j)
|
|
||||||
3. RAG 模式使用批量查询(一次 SQL)
|
Args:
|
||||||
4. 只返回必要字段减少数据传输
|
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
||||||
5. 添加短期缓存减少重复查询
|
keyword: 搜索关键词(可选,同时模糊匹配 other_name 和 id)
|
||||||
6. 并发执行配置查询和记忆数量查询
|
page: 页码(从1开始,默认1)
|
||||||
|
pagesize: 每页数量(默认10)
|
||||||
返回格式:
|
db: 数据库会话
|
||||||
{
|
current_user: 当前用户
|
||||||
"end_user": {"id": "uuid", "other_name": "名称"},
|
|
||||||
"memory_num": {"total": 数量},
|
Returns:
|
||||||
"memory_config": {"memory_config_id": "id", "memory_config_name": "名称"}
|
ApiResponse: 包含宿主列表和分页信息
|
||||||
}
|
|
||||||
"""
|
"""
|
||||||
import asyncio
|
# 如果未提供 workspace_id,使用当前用户的工作空间
|
||||||
import json
|
if workspace_id is None:
|
||||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
workspace_id = current_user.current_workspace_id
|
|
||||||
|
|
||||||
# 尝试从缓存获取(30秒缓存)
|
|
||||||
cache_key = f"end_users:workspace:{workspace_id}"
|
|
||||||
try:
|
|
||||||
cached_data = await aio_redis_get(cache_key)
|
|
||||||
if cached_data:
|
|
||||||
api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
|
|
||||||
return success(data=json.loads(cached_data), msg="宿主列表获取成功")
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
|
|
||||||
|
|
||||||
# 获取当前空间类型
|
# 获取当前空间类型
|
||||||
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
||||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
|
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}")
|
||||||
|
|
||||||
# 获取 end_users(已优化为批量查询)
|
# 获取分页的 end_users
|
||||||
end_users = memory_dashboard_service.get_workspace_end_users(
|
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated(
|
||||||
db=db,
|
db=db,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
current_user=current_user
|
current_user=current_user,
|
||||||
|
page=page,
|
||||||
|
pagesize=pagesize,
|
||||||
|
keyword=keyword
|
||||||
)
|
)
|
||||||
|
|
||||||
|
end_users = end_users_result.get("items", [])
|
||||||
|
total = end_users_result.get("total", 0)
|
||||||
|
|
||||||
if not end_users:
|
if not end_users:
|
||||||
api_logger.info("工作空间下没有宿主")
|
api_logger.info(f"工作空间下没有宿主或当前页无数据: total={total}, page={page}")
|
||||||
# 缓存空结果,避免重复查询
|
return success(data={
|
||||||
try:
|
"items": [],
|
||||||
await aio_redis_set(cache_key, json.dumps([]), expire=30)
|
"page": {
|
||||||
except Exception as e:
|
"page": page,
|
||||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
"pagesize": pagesize,
|
||||||
return success(data=[], msg="宿主列表获取成功")
|
"total": total,
|
||||||
|
"hasnext": (page * pagesize) < total
|
||||||
|
}
|
||||||
|
}, msg="宿主列表获取成功")
|
||||||
|
|
||||||
end_user_ids = [str(user.id) for user in end_users]
|
end_user_ids = [str(user.id) for user in end_users]
|
||||||
|
|
||||||
# 并发执行两个独立的查询任务
|
# 并发执行两个独立的查询任务
|
||||||
async def get_memory_configs():
|
async def get_memory_configs():
|
||||||
"""获取记忆配置(在线程池中执行同步查询)"""
|
"""获取记忆配置(在线程池中执行同步查询)"""
|
||||||
@@ -116,7 +118,7 @@ async def get_workspace_end_users(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
|
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
async def get_memory_nums():
|
async def get_memory_nums():
|
||||||
"""获取记忆数量"""
|
"""获取记忆数量"""
|
||||||
if current_workspace_type == "rag":
|
if current_workspace_type == "rag":
|
||||||
@@ -130,26 +132,18 @@ async def get_workspace_end_users(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
|
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
|
||||||
return {uid: {"total": 0} for uid in end_user_ids}
|
return {uid: {"total": 0} for uid in end_user_ids}
|
||||||
|
|
||||||
elif current_workspace_type == "neo4j":
|
elif current_workspace_type == "neo4j":
|
||||||
# Neo4j 模式:并发查询(带并发限制)
|
# Neo4j 模式:批量查询(简化版本,只返回total)
|
||||||
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
|
try:
|
||||||
MAX_CONCURRENT_QUERIES = 10
|
batch_result = await memory_storage_service.search_all_batch(end_user_ids)
|
||||||
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
|
return {uid: {"total": count} for uid, count in batch_result.items()}
|
||||||
|
except Exception as e:
|
||||||
async def get_neo4j_memory_num(end_user_id: str):
|
api_logger.error(f"批量获取 Neo4j 记忆数量失败: {str(e)}")
|
||||||
async with semaphore:
|
return {uid: {"total": 0} for uid in end_user_ids}
|
||||||
try:
|
|
||||||
return await memory_storage_service.search_all(end_user_id)
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
|
|
||||||
return {"total": 0}
|
|
||||||
|
|
||||||
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
|
|
||||||
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
|
|
||||||
|
|
||||||
return {uid: {"total": 0} for uid in end_user_ids}
|
return {uid: {"total": 0} for uid in end_user_ids}
|
||||||
|
|
||||||
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
|
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
|
||||||
try:
|
try:
|
||||||
from app.celery_app import celery_app as _celery_app
|
from app.celery_app import celery_app as _celery_app
|
||||||
@@ -170,13 +164,13 @@ async def get_workspace_end_users(
|
|||||||
get_memory_configs(),
|
get_memory_configs(),
|
||||||
get_memory_nums()
|
get_memory_nums()
|
||||||
)
|
)
|
||||||
|
|
||||||
# 构建结果(优化:使用列表推导式)
|
# 构建结果列表
|
||||||
result = []
|
items = []
|
||||||
for end_user in end_users:
|
for end_user in end_users:
|
||||||
user_id = str(end_user.id)
|
user_id = str(end_user.id)
|
||||||
config_info = memory_configs_map.get(user_id, {})
|
config_info = memory_configs_map.get(user_id, {})
|
||||||
result.append({
|
items.append({
|
||||||
'end_user': {
|
'end_user': {
|
||||||
'id': user_id,
|
'id': user_id,
|
||||||
'other_name': end_user.other_name
|
'other_name': end_user.other_name
|
||||||
@@ -187,12 +181,6 @@ async def get_workspace_end_users(
|
|||||||
"memory_config_name": config_info.get("memory_config_name")
|
"memory_config_name": config_info.get("memory_config_name")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
# 写入缓存(30秒过期)
|
|
||||||
try:
|
|
||||||
await aio_redis_set(cache_key, json.dumps(result), expire=30)
|
|
||||||
except Exception as e:
|
|
||||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
|
||||||
|
|
||||||
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
||||||
try:
|
try:
|
||||||
@@ -202,7 +190,18 @@ async def get_workspace_end_users(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
|
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
|
||||||
|
|
||||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
# 构建分页响应
|
||||||
|
result = {
|
||||||
|
"items": items,
|
||||||
|
"page": {
|
||||||
|
"page": page,
|
||||||
|
"pagesize": pagesize,
|
||||||
|
"total": total,
|
||||||
|
"hasnext": (page * pagesize) < total
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录,总计 {total} 条")
|
||||||
return success(data=result, msg="宿主列表获取成功")
|
return success(data=result, msg="宿主列表获取成功")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from app.schemas.memory_storage_schema import (
|
|||||||
ForgettingCurveRequest,
|
ForgettingCurveRequest,
|
||||||
ForgettingCurveResponse,
|
ForgettingCurveResponse,
|
||||||
ForgettingCurvePoint,
|
ForgettingCurvePoint,
|
||||||
|
PendingNodesResponse,
|
||||||
)
|
)
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services.memory_forget_service import MemoryForgetService
|
from app.services.memory_forget_service import MemoryForgetService
|
||||||
@@ -308,6 +309,100 @@ async def get_forgetting_stats(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/pending-nodes", response_model=ApiResponse)
|
||||||
|
async def get_pending_nodes(
|
||||||
|
end_user_id: str,
|
||||||
|
page: int = 1,
|
||||||
|
pagesize: int = 10,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
获取待遗忘节点列表(独立分页接口)
|
||||||
|
|
||||||
|
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
|
||||||
|
此接口独立分页,与 /stats 接口分离。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 组ID(即 end_user_id,必填)
|
||||||
|
page: 页码(从1开始,默认1)
|
||||||
|
pagesize: 每页数量(默认10)
|
||||||
|
current_user: 当前用户
|
||||||
|
db: 数据库会话
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ApiResponse: 包含待遗忘节点列表和分页信息的响应
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
- 第1页,每页10条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=1&pagesize=10
|
||||||
|
- 第2页,每页20条:GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=2&pagesize=20
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- page 从1开始,pagesize 必须大于0
|
||||||
|
- 返回格式:{"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}}
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
# 检查用户是否已选择工作空间
|
||||||
|
if workspace_id is None:
|
||||||
|
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未选择工作空间")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
|
# 验证 end_user_id 必填
|
||||||
|
if not end_user_id:
|
||||||
|
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未提供 end_user_id")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "end_user_id 不能为空", "end_user_id is required")
|
||||||
|
|
||||||
|
# 通过 end_user_id 获取关联的 config_id
|
||||||
|
try:
|
||||||
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
|
|
||||||
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
config_id = resolve_config_id(config_id, db)
|
||||||
|
|
||||||
|
if config_id is None:
|
||||||
|
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
|
||||||
|
|
||||||
|
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
|
||||||
|
except ValueError as e:
|
||||||
|
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
|
||||||
|
|
||||||
|
# 验证分页参数
|
||||||
|
if page < 1:
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "page 必须大于等于1", "page < 1")
|
||||||
|
if pagesize < 1:
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "pagesize 必须大于等于1", "pagesize < 1")
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取待遗忘节点: "
|
||||||
|
f"end_user_id={end_user_id}, page={page}, pagesize={pagesize}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 调用服务层获取待遗忘节点列表
|
||||||
|
result = await forget_service.get_pending_nodes(
|
||||||
|
db=db,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
config_id=config_id,
|
||||||
|
page=page,
|
||||||
|
pagesize=pagesize
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建响应
|
||||||
|
response_data = PendingNodesResponse(**result)
|
||||||
|
|
||||||
|
return success(data=response_data.model_dump(), msg="查询成功")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"获取待遗忘节点列表失败: {str(e)}")
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "获取待遗忘节点列表失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/forgetting_curve", response_model=ApiResponse)
|
@router.post("/forgetting_curve", response_model=ApiResponse)
|
||||||
async def get_forgetting_curve(
|
async def get_forgetting_curve(
|
||||||
request: ForgettingCurveRequest,
|
request: ForgettingCurveRequest,
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from app.services.conversation_service import ConversationService
|
|||||||
from app.services.release_share_service import ReleaseShareService
|
from app.services.release_share_service import ReleaseShareService
|
||||||
from app.services.shared_chat_service import SharedChatService
|
from app.services.shared_chat_service import SharedChatService
|
||||||
from app.services.workflow_service import WorkflowService
|
from app.services.workflow_service import WorkflowService
|
||||||
|
from app.models.file_metadata_model import FileMetadata
|
||||||
from app.utils.app_config_utils import workflow_config_4_app_release, \
|
from app.utils.app_config_utils import workflow_config_4_app_release, \
|
||||||
agent_config_4_app_release, multi_agent_config_4_app_release
|
agent_config_4_app_release, multi_agent_config_4_app_release
|
||||||
|
|
||||||
@@ -259,8 +260,41 @@ def get_conversation(
|
|||||||
conv_service = ConversationService(db)
|
conv_service = ConversationService(db)
|
||||||
messages = conv_service.get_messages(conversation_id)
|
messages = conv_service.get_messages(conversation_id)
|
||||||
|
|
||||||
# 构建响应
|
file_ids = []
|
||||||
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump()
|
message_file_id_map = {}
|
||||||
|
|
||||||
|
# 第一次遍历:解析 audio_url,收集所有有效的 file_id
|
||||||
|
for idx, m in enumerate(messages):
|
||||||
|
if m.role == "assistant" and m.meta_data:
|
||||||
|
audio_url = m.meta_data.get("audio_url")
|
||||||
|
if not audio_url:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
file_id = uuid.UUID(audio_url.rstrip("/").split("/")[-1])
|
||||||
|
except (ValueError, IndexError):
|
||||||
|
# audio_url 无法解析为 UUID,标记为 unknown
|
||||||
|
m.meta_data["audio_status"] = "unknown"
|
||||||
|
continue
|
||||||
|
|
||||||
|
file_ids.append(file_id)
|
||||||
|
message_file_id_map[idx] = file_id
|
||||||
|
|
||||||
|
# 批量查询所有相关的 FileMetadata
|
||||||
|
file_status_map = {}
|
||||||
|
if file_ids:
|
||||||
|
file_metas = (
|
||||||
|
db.query(FileMetadata)
|
||||||
|
.filter(FileMetadata.id.in_(set(file_ids)))
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
file_status_map = {fm.id: fm.status for fm in file_metas}
|
||||||
|
|
||||||
|
# 第二次遍历:将查询结果映射回消息
|
||||||
|
for idx, file_id in message_file_id_map.items():
|
||||||
|
m = messages[idx]
|
||||||
|
m.meta_data["audio_status"] = file_status_map.get(file_id, "unknown")
|
||||||
|
|
||||||
|
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump(mode="json")
|
||||||
conv_dict["messages"] = [
|
conv_dict["messages"] = [
|
||||||
conversation_schema.Message.model_validate(m) for m in messages
|
conversation_schema.Message.model_validate(m) for m in messages
|
||||||
]
|
]
|
||||||
@@ -320,6 +354,16 @@ async def chat(
|
|||||||
other_id=other_id,
|
other_id=other_id,
|
||||||
original_user_id=user_id
|
original_user_id=user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Only extract and set memory_config_id when the end user doesn't have one yet
|
||||||
|
if not new_end_user.memory_config_id:
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
memory_config_service = MemoryConfigService(db)
|
||||||
|
memory_config_id, _ = memory_config_service.extract_memory_config_id(release.type, release.config or {})
|
||||||
|
if memory_config_id:
|
||||||
|
new_end_user.memory_config_id = memory_config_id
|
||||||
|
db.commit()
|
||||||
|
db.refresh(new_end_user)
|
||||||
end_user_id = str(new_end_user.id)
|
end_user_id = str(new_end_user.id)
|
||||||
|
|
||||||
# appid = share.app_id
|
# appid = share.app_id
|
||||||
@@ -410,30 +454,6 @@ async def chat(
|
|||||||
agent_config = agent_config_4_app_release(release)
|
agent_config = agent_config_4_app_release(release)
|
||||||
|
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
# async def event_generator():
|
|
||||||
# async for event in service.chat_stream(
|
|
||||||
# share_token=share_token,
|
|
||||||
# message=payload.message,
|
|
||||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
|
||||||
# user_id=str(new_end_user.id), # 转换为字符串
|
|
||||||
# variables=payload.variables,
|
|
||||||
# password=password,
|
|
||||||
# web_search=payload.web_search,
|
|
||||||
# memory=payload.memory,
|
|
||||||
# storage_type=storage_type,
|
|
||||||
# user_rag_memory_id=user_rag_memory_id
|
|
||||||
# ):
|
|
||||||
# yield event
|
|
||||||
|
|
||||||
# return StreamingResponse(
|
|
||||||
# event_generator(),
|
|
||||||
# media_type="text/event-stream",
|
|
||||||
# headers={
|
|
||||||
# "Cache-Control": "no-cache",
|
|
||||||
# "Connection": "keep-alive",
|
|
||||||
# "X-Accel-Buffering": "no"
|
|
||||||
# }
|
|
||||||
# )
|
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
async for event in app_chat_service.agnet_chat_stream(
|
async for event in app_chat_service.agnet_chat_stream(
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
@@ -459,20 +479,6 @@ async def chat(
|
|||||||
"X-Accel-Buffering": "no"
|
"X-Accel-Buffering": "no"
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
# 非流式返回
|
|
||||||
# result = await service.chat(
|
|
||||||
# share_token=share_token,
|
|
||||||
# message=payload.message,
|
|
||||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
|
||||||
# user_id=str(new_end_user.id), # 转换为字符串
|
|
||||||
# variables=payload.variables,
|
|
||||||
# password=password,
|
|
||||||
# web_search=payload.web_search,
|
|
||||||
# memory=payload.memory,
|
|
||||||
# storage_type=storage_type,
|
|
||||||
# user_rag_memory_id=user_rag_memory_id
|
|
||||||
# )
|
|
||||||
# return success(data=conversation_schema.ChatResponse(**result))
|
|
||||||
result = await app_chat_service.agnet_chat(
|
result = await app_chat_service.agnet_chat(
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
@@ -531,48 +537,6 @@ async def chat(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||||
# 多 Agent 流式返回
|
|
||||||
# if payload.stream:
|
|
||||||
# async def event_generator():
|
|
||||||
# async for event in service.multi_agent_chat_stream(
|
|
||||||
# share_token=share_token,
|
|
||||||
# message=payload.message,
|
|
||||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
|
||||||
# user_id=str(new_end_user.id), # 转换为字符串
|
|
||||||
# variables=payload.variables,
|
|
||||||
# password=password,
|
|
||||||
# web_search=payload.web_search,
|
|
||||||
# memory=payload.memory,
|
|
||||||
# storage_type=storage_type,
|
|
||||||
# user_rag_memory_id=user_rag_memory_id
|
|
||||||
# ):
|
|
||||||
# yield event
|
|
||||||
|
|
||||||
# return StreamingResponse(
|
|
||||||
# event_generator(),
|
|
||||||
# media_type="text/event-stream",
|
|
||||||
# headers={
|
|
||||||
# "Cache-Control": "no-cache",
|
|
||||||
# "Connection": "keep-alive",
|
|
||||||
# "X-Accel-Buffering": "no"
|
|
||||||
# }
|
|
||||||
# )
|
|
||||||
|
|
||||||
# # 多 Agent 非流式返回
|
|
||||||
# result = await service.multi_agent_chat(
|
|
||||||
# share_token=share_token,
|
|
||||||
# message=payload.message,
|
|
||||||
# conversation_id=conversation.id, # 使用已创建的会话 ID
|
|
||||||
# user_id=str(new_end_user.id), # 转换为字符串
|
|
||||||
# variables=payload.variables,
|
|
||||||
# password=password,
|
|
||||||
# web_search=payload.web_search,
|
|
||||||
# memory=payload.memory,
|
|
||||||
# storage_type=storage_type,
|
|
||||||
# user_rag_memory_id=user_rag_memory_id
|
|
||||||
# )
|
|
||||||
|
|
||||||
# return success(data=conversation_schema.ChatResponse(**result))
|
|
||||||
elif app_type == AppType.WORKFLOW:
|
elif app_type == AppType.WORKFLOW:
|
||||||
config = workflow_config_4_app_release(release)
|
config = workflow_config_4_app_release(release)
|
||||||
if not config.id:
|
if not config.id:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
认证方式: API Key
|
认证方式: API Key
|
||||||
"""
|
"""
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller
|
from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller, end_user_api_controller
|
||||||
|
|
||||||
# 创建 V1 API 路由器
|
# 创建 V1 API 路由器
|
||||||
service_router = APIRouter()
|
service_router = APIRouter()
|
||||||
@@ -16,5 +16,6 @@ service_router.include_router(rag_api_document_controller.router)
|
|||||||
service_router.include_router(rag_api_file_controller.router)
|
service_router.include_router(rag_api_file_controller.router)
|
||||||
service_router.include_router(rag_api_chunk_controller.router)
|
service_router.include_router(rag_api_chunk_controller.router)
|
||||||
service_router.include_router(memory_api_controller.router)
|
service_router.include_router(memory_api_controller.router)
|
||||||
|
service_router.include_router(end_user_api_controller.router)
|
||||||
|
|
||||||
__all__ = ["service_router"]
|
__all__ = ["service_router"]
|
||||||
|
|||||||
@@ -91,7 +91,7 @@ async def chat(
|
|||||||
|
|
||||||
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
|
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
|
||||||
other_id = payload.user_id
|
other_id = payload.user_id
|
||||||
workspace_id = app.workspace_id
|
workspace_id = api_key_auth.workspace_id
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
app_id=app.id,
|
app_id=app.id,
|
||||||
|
|||||||
92
api/app/controllers/service/end_user_api_controller.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""End User 服务接口 - 基于 API Key 认证"""
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Body, Depends, Request
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.api_key_auth import require_api_key
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.db import get_db
|
||||||
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
|
from app.schemas.api_key_schema import ApiKeyAuth
|
||||||
|
from app.schemas.memory_api_schema import CreateEndUserRequest, CreateEndUserResponse
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/end_user", tags=["V1 - End User API"])
|
||||||
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/create")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def create_end_user(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
message: str = Body(..., description="Request body"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create or retrieve an end user for the workspace.
|
||||||
|
|
||||||
|
Creates a new end user and connects it to a memory configuration.
|
||||||
|
If an end user with the same other_id already exists in the workspace,
|
||||||
|
returns the existing one.
|
||||||
|
|
||||||
|
Optionally accepts a memory_config_id to connect the end user to a specific
|
||||||
|
memory configuration. If not provided, falls back to the workspace default config.
|
||||||
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = CreateEndUserRequest(**body)
|
||||||
|
workspace_id = api_key_auth.workspace_id
|
||||||
|
|
||||||
|
logger.info(f"Create end user request - other_id: {payload.other_id}, workspace_id: {workspace_id}")
|
||||||
|
|
||||||
|
# Resolve memory_config_id: explicit > workspace default
|
||||||
|
memory_config_id = None
|
||||||
|
config_service = MemoryConfigService(db)
|
||||||
|
|
||||||
|
if payload.memory_config_id:
|
||||||
|
try:
|
||||||
|
memory_config_id = uuid.UUID(payload.memory_config_id)
|
||||||
|
except ValueError:
|
||||||
|
raise BusinessException(
|
||||||
|
f"Invalid memory_config_id format: {payload.memory_config_id}",
|
||||||
|
BizCode.INVALID_PARAMETER
|
||||||
|
)
|
||||||
|
config = config_service.get_config_with_fallback(memory_config_id, workspace_id)
|
||||||
|
if not config:
|
||||||
|
raise BusinessException(
|
||||||
|
f"Memory config not found: {payload.memory_config_id}",
|
||||||
|
BizCode.MEMORY_CONFIG_NOT_FOUND
|
||||||
|
)
|
||||||
|
memory_config_id = config.config_id
|
||||||
|
else:
|
||||||
|
default_config = config_service.get_workspace_default_config(workspace_id)
|
||||||
|
if default_config:
|
||||||
|
memory_config_id = default_config.config_id
|
||||||
|
logger.info(f"Using workspace default memory config: {memory_config_id}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"No default memory config found for workspace: {workspace_id}")
|
||||||
|
|
||||||
|
end_user_repo = EndUserRepository(db)
|
||||||
|
end_user = end_user_repo.get_or_create_end_user_with_config(
|
||||||
|
app_id=api_key_auth.resource_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
other_id=payload.other_id,
|
||||||
|
memory_config_id=memory_config_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"End user ready: {end_user.id}")
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"id": str(end_user.id),
|
||||||
|
"other_id": end_user.other_id or "",
|
||||||
|
"other_name": end_user.other_name or "",
|
||||||
|
"workspace_id": str(end_user.workspace_id),
|
||||||
|
"memory_config_id": str(end_user.memory_config_id) if end_user.memory_config_id else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
||||||
@@ -111,6 +111,18 @@ def get_current_user_info(
|
|||||||
break
|
break
|
||||||
|
|
||||||
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
|
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
|
||||||
|
|
||||||
|
# 设置权限:如果用户来自 SSO Source,则使用该 Source 的 permissions;否则返回 "all" 表示拥有所有权限
|
||||||
|
if current_user.external_source:
|
||||||
|
from premium.sso.models import SSOSource
|
||||||
|
source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first()
|
||||||
|
if source and source.permissions:
|
||||||
|
result_schema.permissions = source.permissions
|
||||||
|
else:
|
||||||
|
result_schema.permissions = []
|
||||||
|
else:
|
||||||
|
result_schema.permissions = ["all"]
|
||||||
|
|
||||||
return success(data=result_schema, msg=t("users.info.get_success"))
|
return success(data=result_schema, msg=t("users.info.get_success"))
|
||||||
|
|
||||||
|
|
||||||
@@ -135,7 +147,6 @@ def get_tenant_superusers(
|
|||||||
return success(data=superusers_schema, msg=t("users.list.superusers_success"))
|
return success(data=superusers_schema, msg=t("users.list.superusers_success"))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{user_id}", response_model=ApiResponse)
|
@router.get("/{user_id}", response_model=ApiResponse)
|
||||||
def get_user_info_by_id(
|
def get_user_info_by_id(
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
|
|||||||
@@ -11,18 +11,14 @@ LangChain Agent 封装
|
|||||||
import time
|
import time
|
||||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||||
|
|
||||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
|
||||||
from app.db import get_db
|
|
||||||
from app.core.logging_config import get_business_logger
|
|
||||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
|
||||||
from app.models.models_model import ModelType, ModelProvider
|
|
||||||
from app.services.memory_agent_service import (
|
|
||||||
get_end_user_connected_config,
|
|
||||||
)
|
|
||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||||
from langchain_core.tools import BaseTool
|
from langchain_core.tools import BaseTool
|
||||||
|
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
|
from app.models.models_model import ModelType
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
@@ -226,10 +222,9 @@ class LangChainAgent:
|
|||||||
Returns:
|
Returns:
|
||||||
List[BaseMessage]: 消息列表
|
List[BaseMessage]: 消息列表
|
||||||
"""
|
"""
|
||||||
messages = []
|
messages:list = [SystemMessage(content=self.system_prompt)]
|
||||||
|
|
||||||
# 添加系统提示词
|
# 添加系统提示词
|
||||||
messages.append(SystemMessage(content=self.system_prompt))
|
|
||||||
|
|
||||||
# 添加历史消息
|
# 添加历史消息
|
||||||
if history:
|
if history:
|
||||||
@@ -320,12 +315,7 @@ class LangChainAgent:
|
|||||||
message: str,
|
message: str,
|
||||||
history: Optional[List[Dict[str, str]]] = None,
|
history: Optional[List[Dict[str, str]]] = None,
|
||||||
context: Optional[str] = None,
|
context: Optional[str] = None,
|
||||||
end_user_id: Optional[str] = None,
|
files: Optional[List[Dict[str, Any]]] = None
|
||||||
config_id: Optional[str] = None, # 添加这个参数
|
|
||||||
storage_type: Optional[str] = None,
|
|
||||||
user_rag_memory_id: Optional[str] = None,
|
|
||||||
memory_flag: Optional[bool] = True,
|
|
||||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""执行对话
|
"""执行对话
|
||||||
|
|
||||||
@@ -333,32 +323,12 @@ class LangChainAgent:
|
|||||||
message: 用户消息
|
message: 用户消息
|
||||||
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
|
history: 历史消息列表 [{"role": "user/assistant", "content": "..."}]
|
||||||
context: 上下文信息(如知识库检索结果)
|
context: 上下文信息(如知识库检索结果)
|
||||||
|
files: 多模态文件
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Dict: 包含 content 和元数据的字典
|
Dict: 包含 content 和元数据的字典
|
||||||
"""
|
"""
|
||||||
message_chat = message
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
actual_config_id = config_id
|
|
||||||
# If config_id is None, try to get from end_user's connected config
|
|
||||||
if actual_config_id is None and end_user_id:
|
|
||||||
try:
|
|
||||||
from app.services.memory_agent_service import (
|
|
||||||
get_end_user_connected_config,
|
|
||||||
)
|
|
||||||
db = next(get_db())
|
|
||||||
try:
|
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
|
||||||
actual_config_id = connected_config.get("memory_config_id")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get db session: {e}")
|
|
||||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
|
||||||
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
|
||||||
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
|
|
||||||
try:
|
try:
|
||||||
# 准备消息列表(支持多模态)
|
# 准备消息列表(支持多模态)
|
||||||
messages = self._prepare_messages(message, history, context, files)
|
messages = self._prepare_messages(message, history, context, files)
|
||||||
@@ -445,9 +415,6 @@ class LangChainAgent:
|
|||||||
logger.info(f"最终提取的内容长度: {len(content)}")
|
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
if memory_flag:
|
|
||||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
|
|
||||||
actual_config_id)
|
|
||||||
response = {
|
response = {
|
||||||
"content": content,
|
"content": content,
|
||||||
"model": self.model_name,
|
"model": self.model_name,
|
||||||
@@ -478,12 +445,7 @@ class LangChainAgent:
|
|||||||
message: str,
|
message: str,
|
||||||
history: Optional[List[Dict[str, str]]] = None,
|
history: Optional[List[Dict[str, str]]] = None,
|
||||||
context: Optional[str] = None,
|
context: Optional[str] = None,
|
||||||
end_user_id: Optional[str] = None,
|
files: Optional[List[Dict[str, Any]]] = None
|
||||||
config_id: Optional[str] = None,
|
|
||||||
storage_type: Optional[str] = None,
|
|
||||||
user_rag_memory_id: Optional[str] = None,
|
|
||||||
memory_flag: Optional[bool] = True,
|
|
||||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
|
||||||
) -> AsyncGenerator[str | int, None]:
|
) -> AsyncGenerator[str | int, None]:
|
||||||
"""执行流式对话
|
"""执行流式对话
|
||||||
|
|
||||||
@@ -491,6 +453,7 @@ class LangChainAgent:
|
|||||||
message: 用户消息
|
message: 用户消息
|
||||||
history: 历史消息列表
|
history: 历史消息列表
|
||||||
context: 上下文信息
|
context: 上下文信息
|
||||||
|
files: 多模态文件
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
str: 消息内容块
|
str: 消息内容块
|
||||||
@@ -501,23 +464,6 @@ class LangChainAgent:
|
|||||||
logger.info(f" Has tools: {bool(self.tools)}")
|
logger.info(f" Has tools: {bool(self.tools)}")
|
||||||
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
|
logger.info(f" Tool count: {len(self.tools) if self.tools else 0}")
|
||||||
logger.info("=" * 80)
|
logger.info("=" * 80)
|
||||||
message_chat = message
|
|
||||||
actual_config_id = config_id
|
|
||||||
# If config_id is None, try to get from end_user's connected config
|
|
||||||
if actual_config_id is None and end_user_id:
|
|
||||||
try:
|
|
||||||
db = next(get_db())
|
|
||||||
try:
|
|
||||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
|
||||||
actual_config_id = connected_config.get("memory_config_id")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}")
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to get db session: {e}")
|
|
||||||
|
|
||||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
|
||||||
try:
|
try:
|
||||||
# 准备消息列表(支持多模态)
|
# 准备消息列表(支持多模态)
|
||||||
messages = self._prepare_messages(message, history, context, files)
|
messages = self._prepare_messages(message, history, context, files)
|
||||||
@@ -527,17 +473,18 @@ class LangChainAgent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
chunk_count = 0
|
chunk_count = 0
|
||||||
yielded_content = False
|
|
||||||
|
|
||||||
# 统一使用 agent 的 astream_events 实现流式输出
|
# 统一使用 agent 的 astream_events 实现流式输出
|
||||||
logger.debug("使用 Agent astream_events 实现流式输出")
|
logger.debug("使用 Agent astream_events 实现流式输出")
|
||||||
full_content = ''
|
full_content = ''
|
||||||
try:
|
try:
|
||||||
|
last_event = {}
|
||||||
async for event in self.agent.astream_events(
|
async for event in self.agent.astream_events(
|
||||||
{"messages": messages},
|
{"messages": messages},
|
||||||
version="v2",
|
version="v2",
|
||||||
config={"recursion_limit": self.max_iterations}
|
config={"recursion_limit": self.max_iterations}
|
||||||
):
|
):
|
||||||
|
last_event = event
|
||||||
chunk_count += 1
|
chunk_count += 1
|
||||||
kind = event.get("event")
|
kind = event.get("event")
|
||||||
|
|
||||||
@@ -551,7 +498,6 @@ class LangChainAgent:
|
|||||||
if isinstance(chunk_content, str) and chunk_content:
|
if isinstance(chunk_content, str) and chunk_content:
|
||||||
full_content += chunk_content
|
full_content += chunk_content
|
||||||
yield chunk_content
|
yield chunk_content
|
||||||
yielded_content = True
|
|
||||||
elif isinstance(chunk_content, list):
|
elif isinstance(chunk_content, list):
|
||||||
# 多模态响应:提取文本部分
|
# 多模态响应:提取文本部分
|
||||||
for item in chunk_content:
|
for item in chunk_content:
|
||||||
@@ -562,18 +508,15 @@ class LangChainAgent:
|
|||||||
if text:
|
if text:
|
||||||
full_content += text
|
full_content += text
|
||||||
yield text
|
yield text
|
||||||
yielded_content = True
|
|
||||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||||
elif item.get("type") == "text":
|
elif item.get("type") == "text":
|
||||||
text = item.get("text", "")
|
text = item.get("text", "")
|
||||||
if text:
|
if text:
|
||||||
full_content += text
|
full_content += text
|
||||||
yield text
|
yield text
|
||||||
yielded_content = True
|
|
||||||
elif isinstance(item, str):
|
elif isinstance(item, str):
|
||||||
full_content += item
|
full_content += item
|
||||||
yield item
|
yield item
|
||||||
yielded_content = True
|
|
||||||
|
|
||||||
elif kind == "on_llm_stream":
|
elif kind == "on_llm_stream":
|
||||||
# 另一种 LLM 流式事件
|
# 另一种 LLM 流式事件
|
||||||
@@ -584,7 +527,6 @@ class LangChainAgent:
|
|||||||
if isinstance(chunk_content, str) and chunk_content:
|
if isinstance(chunk_content, str) and chunk_content:
|
||||||
full_content += chunk_content
|
full_content += chunk_content
|
||||||
yield chunk_content
|
yield chunk_content
|
||||||
yielded_content = True
|
|
||||||
elif isinstance(chunk_content, list):
|
elif isinstance(chunk_content, list):
|
||||||
# 多模态响应:提取文本部分
|
# 多模态响应:提取文本部分
|
||||||
for item in chunk_content:
|
for item in chunk_content:
|
||||||
@@ -595,22 +537,18 @@ class LangChainAgent:
|
|||||||
if text:
|
if text:
|
||||||
full_content += text
|
full_content += text
|
||||||
yield text
|
yield text
|
||||||
yielded_content = True
|
|
||||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||||
elif item.get("type") == "text":
|
elif item.get("type") == "text":
|
||||||
text = item.get("text", "")
|
text = item.get("text", "")
|
||||||
if text:
|
if text:
|
||||||
full_content += text
|
full_content += text
|
||||||
yield text
|
yield text
|
||||||
yielded_content = True
|
|
||||||
elif isinstance(item, str):
|
elif isinstance(item, str):
|
||||||
full_content += item
|
full_content += item
|
||||||
yield item
|
yield item
|
||||||
yielded_content = True
|
|
||||||
elif isinstance(chunk, str):
|
elif isinstance(chunk, str):
|
||||||
full_content += chunk
|
full_content += chunk
|
||||||
yield chunk
|
yield chunk
|
||||||
yielded_content = True
|
|
||||||
|
|
||||||
# 记录工具调用(可选)
|
# 记录工具调用(可选)
|
||||||
elif kind == "on_tool_start":
|
elif kind == "on_tool_start":
|
||||||
@@ -620,17 +558,14 @@ class LangChainAgent:
|
|||||||
|
|
||||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||||
# 统计token消耗
|
# 统计token消耗
|
||||||
# 统计 token 消耗:优先使用流式过程中捕获的值,回退到最后 event 的 messages
|
output_messages = last_event.get("data", {}).get("output", {}).get("messages", [])
|
||||||
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
|
||||||
for msg in reversed(output_messages):
|
for msg in reversed(output_messages):
|
||||||
if isinstance(msg, AIMessage):
|
if isinstance(msg, AIMessage):
|
||||||
stream_total_tokens = self._extract_tokens_from_message(msg)
|
stream_total_tokens = self._extract_tokens_from_message(msg)
|
||||||
logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}")
|
logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}")
|
||||||
yield stream_total_tokens
|
yield stream_total_tokens
|
||||||
break
|
break
|
||||||
if memory_flag:
|
|
||||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
|
|
||||||
actual_config_id)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
|||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||||
from app.services.memory_konwledges_server import write_rag
|
|
||||||
from app.services.task_service import get_task_memory_write_result
|
from app.services.task_service import get_task_memory_write_result
|
||||||
from app.tasks import write_message_task
|
from app.tasks import write_message_task
|
||||||
from app.utils.config_utils import resolve_config_id
|
from app.utils.config_utils import resolve_config_id
|
||||||
@@ -21,25 +20,6 @@ logger = get_agent_logger(__name__)
|
|||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
|
|
||||||
|
|
||||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
|
||||||
"""
|
|
||||||
Write messages to RAG storage system
|
|
||||||
|
|
||||||
Combines user and AI messages into a single string format and stores them
|
|
||||||
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
end_user_id: User identifier for the conversation
|
|
||||||
user_message: User's input message content
|
|
||||||
ai_message: AI's response message content
|
|
||||||
user_rag_memory_id: RAG memory identifier for storage location
|
|
||||||
"""
|
|
||||||
# RAG mode: combine messages into string format (maintain original logic)
|
|
||||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
|
||||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
|
||||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
|
||||||
|
|
||||||
|
|
||||||
async def write(
|
async def write(
|
||||||
storage_type,
|
storage_type,
|
||||||
end_user_id,
|
end_user_id,
|
||||||
@@ -118,7 +98,7 @@ async def write(
|
|||||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||||
|
|
||||||
|
|
||||||
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
|
async def term_memory_save(end_user_id, strategy_type, scope):
|
||||||
"""
|
"""
|
||||||
Save long-term memory data to database
|
Save long-term memory data to database
|
||||||
|
|
||||||
@@ -127,10 +107,8 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
|
|||||||
to long-term memory storage.
|
to long-term memory storage.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
long_term_messages: Long-term message data to be saved
|
|
||||||
actual_config_id: Configuration identifier for memory settings
|
|
||||||
end_user_id: User identifier for memory association
|
end_user_id: User identifier for memory association
|
||||||
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
strategy_type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
||||||
scope: Scope/window size for memory processing
|
scope: Scope/window size for memory processing
|
||||||
"""
|
"""
|
||||||
with get_db_context() as db_session:
|
with get_db_context() as db_session:
|
||||||
@@ -138,7 +116,10 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
|
|||||||
|
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
result = write_store.get_session_by_userid(end_user_id)
|
result = write_store.get_session_by_userid(end_user_id)
|
||||||
if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
if not result:
|
||||||
|
logger.warning(f"No write data found for user {end_user_id}")
|
||||||
|
return
|
||||||
|
if strategy_type in [AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE]:
|
||||||
data = await format_parsing(result, "dict")
|
data = await format_parsing(result, "dict")
|
||||||
chunk_data = data[:scope]
|
chunk_data = data[:scope]
|
||||||
if len(chunk_data) == scope:
|
if len(chunk_data) == scope:
|
||||||
@@ -151,9 +132,6 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty
|
|||||||
logger.info(f'写入短长期:')
|
logger.info(f'写入短长期:')
|
||||||
|
|
||||||
|
|
||||||
"""Window-based dialogue processing"""
|
|
||||||
|
|
||||||
|
|
||||||
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
|
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
|
||||||
"""
|
"""
|
||||||
Process dialogue based on window size and write to Neo4j
|
Process dialogue based on window size and write to Neo4j
|
||||||
@@ -167,40 +145,33 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
|||||||
langchain_messages: Original message data list
|
langchain_messages: Original message data list
|
||||||
scope: Window size determining when to trigger long-term storage
|
scope: Window size determining when to trigger long-term storage
|
||||||
"""
|
"""
|
||||||
scope = scope
|
is_end_user_has_history = count_store.get_sessions_count(end_user_id)
|
||||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
if is_end_user_has_history:
|
||||||
if is_end_user_id is not False:
|
end_user_visit_count, redis_messages = is_end_user_has_history
|
||||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
else:
|
||||||
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||||
if is_end_user_id and int(is_end_user_id) != int(scope):
|
return
|
||||||
is_end_user_id += 1
|
end_user_visit_count += 1
|
||||||
langchain_messages += redis_messages
|
if end_user_visit_count < scope:
|
||||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
redis_messages.extend(langchain_messages)
|
||||||
elif int(is_end_user_id) == int(scope):
|
count_store.update_sessions_count(end_user_id, end_user_visit_count, redis_messages)
|
||||||
|
else:
|
||||||
logger.info('写入长期记忆NEO4J')
|
logger.info('写入长期记忆NEO4J')
|
||||||
formatted_messages = redis_messages
|
redis_messages.extend(langchain_messages)
|
||||||
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
||||||
if hasattr(memory_config, 'config_id'):
|
if hasattr(memory_config, 'config_id'):
|
||||||
config_id = memory_config.config_id
|
config_id = memory_config.config_id
|
||||||
else:
|
else:
|
||||||
config_id = memory_config
|
config_id = memory_config
|
||||||
|
|
||||||
await write(
|
write_message_task.delay(
|
||||||
AgentMemory_Long_Term.STORAGE_NEO4J,
|
end_user_id, # end_user_id: User ID
|
||||||
end_user_id,
|
redis_messages, # message: JSON string format message list
|
||||||
"",
|
config_id, # config_id: Configuration ID string
|
||||||
"",
|
AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j"
|
||||||
None,
|
"" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||||
end_user_id,
|
|
||||||
config_id,
|
|
||||||
formatted_messages
|
|
||||||
)
|
)
|
||||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
count_store.update_sessions_count(end_user_id, 0, [])
|
||||||
else:
|
|
||||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
|
||||||
|
|
||||||
|
|
||||||
"""Time-based memory processing"""
|
|
||||||
|
|
||||||
|
|
||||||
async def memory_long_term_storage(end_user_id, memory_config, time):
|
async def memory_long_term_storage(end_user_id, memory_config, time):
|
||||||
@@ -291,9 +262,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
|||||||
return result_dict
|
return result_dict
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[aggregate_judgment] 发生错误: {e}")
|
logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True)
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"is_same_event": False,
|
"is_same_event": False,
|
||||||
|
|||||||
@@ -1,49 +1,25 @@
|
|||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from langgraph.constants import END, START
|
|
||||||
from langgraph.graph import StateGraph
|
|
||||||
|
|
||||||
from app.db import get_db, get_db_context
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
||||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
aggregate_judgment
|
||||||
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
|
from app.db import get_db_context
|
||||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
from app.services.memory_konwledges_server import write_rag
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
if sys.platform.startswith("win"):
|
|
||||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
|
||||||
|
|
||||||
|
async def long_term_storage(
|
||||||
@asynccontextmanager
|
long_term_type: str,
|
||||||
async def make_write_graph():
|
langchain_messages: list,
|
||||||
"""
|
memory_config_id: str,
|
||||||
Create a write graph workflow for memory operations.
|
end_user_id: str,
|
||||||
|
scope: int = 6
|
||||||
Args:
|
):
|
||||||
user_id: User identifier
|
|
||||||
tools: MCP tools loaded from session
|
|
||||||
apply_id: Application identifier
|
|
||||||
end_user_id: Group identifier
|
|
||||||
memory_config: MemoryConfig object containing all configuration
|
|
||||||
"""
|
|
||||||
workflow = StateGraph(WriteState)
|
|
||||||
workflow.add_node("save_neo4j", write_node)
|
|
||||||
workflow.add_edge(START, "save_neo4j")
|
|
||||||
workflow.add_edge("save_neo4j", END)
|
|
||||||
|
|
||||||
graph = workflow.compile()
|
|
||||||
|
|
||||||
yield graph
|
|
||||||
|
|
||||||
|
|
||||||
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
|
|
||||||
end_user_id: str = '', scope: int = 6):
|
|
||||||
"""
|
"""
|
||||||
Handle long-term memory storage with different strategies
|
Handle long-term memory storage with different strategies
|
||||||
|
|
||||||
@@ -53,33 +29,39 @@ async def long_term_storage(long_term_type: str = "chunk", langchain_messages: l
|
|||||||
Args:
|
Args:
|
||||||
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
|
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
|
||||||
langchain_messages: List of messages to store
|
langchain_messages: List of messages to store
|
||||||
memory_config: Memory configuration identifier
|
memory_config_id: Memory configuration identifier
|
||||||
end_user_id: User group identifier
|
end_user_id: User group identifier
|
||||||
scope: Scope parameter for chunk-based storage (default: 6)
|
scope: Scope parameter for chunk-based storage (default: 6)
|
||||||
"""
|
"""
|
||||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
if langchain_messages is None:
|
||||||
aggregate_judgment
|
langchain_messages = []
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
|
||||||
write_store.save_session_write(end_user_id, langchain_messages)
|
write_store.save_session_write(end_user_id, langchain_messages)
|
||||||
# 获取数据库会话
|
# 获取数据库会话
|
||||||
with get_db_context() as db_session:
|
with get_db_context() as db_session:
|
||||||
config_service = MemoryConfigService(db_session)
|
config_service = MemoryConfigService(db_session)
|
||||||
memory_config = config_service.load_memory_config(
|
memory_config = config_service.load_memory_config(
|
||||||
config_id=memory_config, # 改为整数
|
config_id=memory_config_id, # 改为整数
|
||||||
service_name="MemoryAgentService"
|
service_name="MemoryAgentService"
|
||||||
)
|
)
|
||||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
|
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
|
||||||
'''Strategy 1: Dialogue window with 6 rounds of conversation'''
|
# Dialogue window with 6 rounds of conversation
|
||||||
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
|
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
|
||||||
"""Time-based strategy"""
|
# Time-based strategy
|
||||||
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
|
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
|
||||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||||
"""Strategy 3: Aggregate judgment"""
|
# Aggregate judgment
|
||||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||||
|
|
||||||
|
|
||||||
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
|
async def write_long_term(
|
||||||
|
storage_type: str,
|
||||||
|
end_user_id: str,
|
||||||
|
messages: list[dict],
|
||||||
|
user_rag_memory_id: str,
|
||||||
|
actual_config_id: str
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Write long-term memory with different storage types
|
Write long-term memory with different storage types
|
||||||
|
|
||||||
@@ -89,44 +71,24 @@ async def write_long_term(storage_type, end_user_id, message_chat, aimessages, u
|
|||||||
Args:
|
Args:
|
||||||
storage_type: Type of storage (RAG or traditional)
|
storage_type: Type of storage (RAG or traditional)
|
||||||
end_user_id: User group identifier
|
end_user_id: User group identifier
|
||||||
message_chat: User message content
|
messages: message list
|
||||||
aimessages: AI response messages
|
|
||||||
user_rag_memory_id: RAG memory identifier
|
user_rag_memory_id: RAG memory identifier
|
||||||
actual_config_id: Actual configuration ID
|
actual_config_id: Actual configuration ID
|
||||||
"""
|
"""
|
||||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
|
||||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
|
||||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
message_content = []
|
||||||
|
for message in messages:
|
||||||
|
message_content.append(f'{message.get("role")}:{message.get("content")}')
|
||||||
|
messages_string = "\n".join(message_content)
|
||||||
|
await write_rag(end_user_id, messages_string, user_rag_memory_id)
|
||||||
else:
|
else:
|
||||||
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
|
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
|
||||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||||
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
await long_term_storage(long_term_type=CHUNK,
|
||||||
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
|
langchain_messages=messages,
|
||||||
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
|
memory_config_id=actual_config_id,
|
||||||
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
|
end_user_id=end_user_id,
|
||||||
|
scope=SCOPE)
|
||||||
# async def main():
|
await term_memory_save(end_user_id, CHUNK, scope=SCOPE)
|
||||||
# """主函数 - 运行工作流"""
|
|
||||||
# langchain_messages = [
|
|
||||||
# {
|
|
||||||
# "role": "user",
|
|
||||||
# "content": "今天周五去爬山"
|
|
||||||
# },
|
|
||||||
# {
|
|
||||||
# "role": "assistant",
|
|
||||||
# "content": "好耶"
|
|
||||||
# }
|
|
||||||
#
|
|
||||||
# ]
|
|
||||||
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
|
||||||
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
|
||||||
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
|
||||||
#
|
|
||||||
#
|
|
||||||
#
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# import asyncio
|
|
||||||
# asyncio.run(main())
|
|
||||||
|
|||||||
@@ -3,8 +3,9 @@ import uuid
|
|||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from typing import List, Dict, Any, Optional, Union
|
from typing import List, Dict, Any, Optional, Union
|
||||||
|
|
||||||
|
from app.core.logging_config import get_logger
|
||||||
from app.core.memory.agent.utils.redis_base import (
|
from app.core.memory.agent.utils.redis_base import (
|
||||||
serialize_messages,
|
serialize_messages,
|
||||||
deserialize_messages,
|
deserialize_messages,
|
||||||
fix_encoding,
|
fix_encoding,
|
||||||
format_session_data,
|
format_session_data,
|
||||||
@@ -14,12 +15,12 @@ from app.core.memory.agent.utils.redis_base import (
|
|||||||
get_current_timestamp
|
get_current_timestamp
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RedisWriteStore:
|
class RedisWriteStore:
|
||||||
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
|
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
|
||||||
|
|
||||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||||
"""
|
"""
|
||||||
初始化 Redis 连接
|
初始化 Redis 连接
|
||||||
@@ -66,10 +67,10 @@ class RedisWriteStore:
|
|||||||
})
|
})
|
||||||
result = pipe.execute()
|
result = pipe.execute()
|
||||||
|
|
||||||
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
logger.debug(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||||
return session_id
|
return session_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[save_session_write] 保存会话失败: {e}")
|
logger.error(f"[save_session_write] 保存会话失败: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
|
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
|
||||||
@@ -99,7 +100,7 @@ class RedisWriteStore:
|
|||||||
for key, data in zip(keys, all_data):
|
for key, data in zip(keys, all_data):
|
||||||
if not data:
|
if not data:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 从 write 类型读取,匹配 sessionid 字段
|
# 从 write 类型读取,匹配 sessionid 字段
|
||||||
if data.get('sessionid') == userid:
|
if data.get('sessionid') == userid:
|
||||||
# 从 key 中提取 session_id: session:write:{session_id}
|
# 从 key 中提取 session_id: session:write:{session_id}
|
||||||
@@ -108,16 +109,16 @@ class RedisWriteStore:
|
|||||||
"sessionid": session_id,
|
"sessionid": session_id,
|
||||||
"messages": fix_encoding(data.get('messages', ''))
|
"messages": fix_encoding(data.get('messages', ''))
|
||||||
})
|
})
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
logger.debug(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
||||||
return results
|
return results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[get_session_by_userid] 查询失败: {e}")
|
logger.error(f"[get_session_by_userid] 查询失败: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
|
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
|
||||||
"""
|
"""
|
||||||
通过 end_user_id 获取所有 write 类型的会话数据
|
通过 end_user_id 获取所有 write 类型的会话数据
|
||||||
@@ -144,7 +145,7 @@ class RedisWriteStore:
|
|||||||
# 只查询 write 类型的 key
|
# 只查询 write 类型的 key
|
||||||
keys = self.r.keys('session:write:*')
|
keys = self.r.keys('session:write:*')
|
||||||
if not keys:
|
if not keys:
|
||||||
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
logger.debug(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 批量获取数据
|
# 批量获取数据
|
||||||
@@ -158,12 +159,12 @@ class RedisWriteStore:
|
|||||||
for key, data in zip(keys, all_data):
|
for key, data in zip(keys, all_data):
|
||||||
if not data:
|
if not data:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 从 write 类型读取,匹配 sessionid 字段
|
# 从 write 类型读取,匹配 sessionid 字段
|
||||||
if data.get('sessionid') == end_user_id:
|
if data.get('sessionid') == end_user_id:
|
||||||
# 从 key 中提取 session_id: session:write:{session_id}
|
# 从 key 中提取 session_id: session:write:{session_id}
|
||||||
session_id = key.split(':')[-1]
|
session_id = key.split(':')[-1]
|
||||||
|
|
||||||
# 构建完整的会话信息
|
# 构建完整的会话信息
|
||||||
session_info = {
|
session_info = {
|
||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
@@ -173,23 +174,21 @@ class RedisWriteStore:
|
|||||||
"starttime": data.get('starttime', '')
|
"starttime": data.get('starttime', '')
|
||||||
}
|
}
|
||||||
results.append(session_info)
|
results.append(session_info)
|
||||||
|
|
||||||
if not results:
|
if not results:
|
||||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 按时间排序(最新的在前)
|
# 按时间排序(最新的在前)
|
||||||
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||||
|
|
||||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
||||||
return results
|
return results
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
|
logger.error(f"[get_all_sessions_by_end_user_id] 查询失败: {e}", exc_info=True)
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def find_user_recent_sessions(self, userid: str,
|
def find_user_recent_sessions(self, userid: str,
|
||||||
minutes: int = 5) -> List[Dict[str, str]]:
|
minutes: int = 5) -> List[Dict[str, str]]:
|
||||||
"""
|
"""
|
||||||
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
|
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
|
||||||
@@ -203,11 +202,11 @@ class RedisWriteStore:
|
|||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# 只查询 write 类型的 key
|
# 只查询 write 类型的 key
|
||||||
keys = self.r.keys('session:write:*')
|
keys = self.r.keys('session:write:*')
|
||||||
if not keys:
|
if not keys:
|
||||||
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
logger.debug(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 批量获取数据
|
# 批量获取数据
|
||||||
@@ -221,7 +220,7 @@ class RedisWriteStore:
|
|||||||
for data in all_data:
|
for data in all_data:
|
||||||
if not data:
|
if not data:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 从 write 类型读取,匹配 sessionid 字段
|
# 从 write 类型读取,匹配 sessionid 字段
|
||||||
if data.get('sessionid') == userid and data.get('starttime'):
|
if data.get('sessionid') == userid and data.get('starttime'):
|
||||||
# write 类型没有 aimessages,所以 Answer 为空
|
# write 类型没有 aimessages,所以 Answer 为空
|
||||||
@@ -230,15 +229,14 @@ class RedisWriteStore:
|
|||||||
"Answer": "",
|
"Answer": "",
|
||||||
"starttime": data.get('starttime', '')
|
"starttime": data.get('starttime', '')
|
||||||
})
|
})
|
||||||
|
|
||||||
# 根据时间范围过滤
|
# 根据时间范围过滤
|
||||||
filtered_items = filter_by_time_range(matched_items, minutes)
|
filtered_items = filter_by_time_range(matched_items, minutes)
|
||||||
# 排序并移除时间字段
|
# 排序并移除时间字段
|
||||||
result_items = sort_and_limit_results(filtered_items, limit=None)
|
result_items = sort_and_limit_results(filtered_items)
|
||||||
print(result_items)
|
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
logger.debug(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
||||||
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||||
|
|
||||||
return result_items
|
return result_items
|
||||||
@@ -258,7 +256,7 @@ class RedisWriteStore:
|
|||||||
|
|
||||||
class RedisCountStore:
|
class RedisCountStore:
|
||||||
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
|
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
|
||||||
|
|
||||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||||
"""
|
"""
|
||||||
初始化 Redis 连接
|
初始化 Redis 连接
|
||||||
@@ -278,7 +276,7 @@ class RedisCountStore:
|
|||||||
decode_responses=True,
|
decode_responses=True,
|
||||||
encoding='utf-8'
|
encoding='utf-8'
|
||||||
)
|
)
|
||||||
self.uudi = session_id
|
self.uuid = session_id
|
||||||
|
|
||||||
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
|
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -295,26 +293,26 @@ class RedisCountStore:
|
|||||||
session_id = str(uuid.uuid4())
|
session_id = str(uuid.uuid4())
|
||||||
key = generate_session_key(session_id, key_type="count")
|
key = generate_session_key(session_id, key_type="count")
|
||||||
index_key = f'session:count:index:{end_user_id}' # 索引键
|
index_key = f'session:count:index:{end_user_id}' # 索引键
|
||||||
|
|
||||||
pipe = self.r.pipeline()
|
pipe = self.r.pipeline()
|
||||||
pipe.hset(key, mapping={
|
pipe.hset(key, mapping={
|
||||||
"id": self.uudi,
|
"id": self.uuid,
|
||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"count": int(count),
|
"count": int(count),
|
||||||
"messages": serialize_messages(messages),
|
"messages": serialize_messages(messages),
|
||||||
"starttime": get_current_timestamp()
|
"starttime": get_current_timestamp()
|
||||||
})
|
})
|
||||||
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
|
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
|
||||||
|
|
||||||
# 创建索引:end_user_id -> session_id 映射
|
# 创建索引:end_user_id -> session_id 映射
|
||||||
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
|
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
|
||||||
|
|
||||||
result = pipe.execute()
|
result = pipe.execute()
|
||||||
|
|
||||||
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
logger.debug(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||||
return session_id
|
return session_id
|
||||||
|
|
||||||
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
|
def get_sessions_count(self, end_user_id: str) -> tuple[int, list[dict]] | bool:
|
||||||
"""
|
"""
|
||||||
通过 end_user_id 查询访问次数统计
|
通过 end_user_id 查询访问次数统计
|
||||||
|
|
||||||
@@ -327,7 +325,7 @@ class RedisCountStore:
|
|||||||
try:
|
try:
|
||||||
# 使用索引键快速查找
|
# 使用索引键快速查找
|
||||||
index_key = f'session:count:index:{end_user_id}'
|
index_key = f'session:count:index:{end_user_id}'
|
||||||
|
|
||||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||||
try:
|
try:
|
||||||
key_type = self.r.type(index_key)
|
key_type = self.r.type(index_key)
|
||||||
@@ -335,35 +333,40 @@ class RedisCountStore:
|
|||||||
self.r.delete(index_key)
|
self.r.delete(index_key)
|
||||||
return False
|
return False
|
||||||
except Exception as type_error:
|
except Exception as type_error:
|
||||||
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
logger.error(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||||
|
|
||||||
session_id = self.r.get(index_key)
|
session_id = self.r.get(index_key)
|
||||||
|
|
||||||
if not session_id:
|
if not session_id:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 直接获取数据
|
# 直接获取数据
|
||||||
key = generate_session_key(session_id, key_type="count")
|
key = generate_session_key(session_id, key_type="count")
|
||||||
data = self.r.hgetall(key)
|
data = self.r.hgetall(key)
|
||||||
|
|
||||||
if not data:
|
if not data:
|
||||||
# 索引存在但数据不存在,清理索引
|
# 索引存在但数据不存在,清理索引
|
||||||
self.r.delete(index_key)
|
self.r.delete(index_key)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
count = data.get('count')
|
count = data.get('count')
|
||||||
messages_str = data.get('messages')
|
messages_str = data.get('messages')
|
||||||
|
|
||||||
if count is not None:
|
if count is not None:
|
||||||
messages = deserialize_messages(messages_str)
|
messages: list[dict] = deserialize_messages(messages_str)
|
||||||
return [int(count), messages]
|
return int(count), messages
|
||||||
|
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[get_sessions_count] 查询失败: {e}")
|
logger.error(f"[get_sessions_count] 查询失败: {e}")
|
||||||
return False
|
return False
|
||||||
def update_sessions_count(self, end_user_id: str, new_count: int,
|
|
||||||
messages: Any) -> bool:
|
def update_sessions_count(
|
||||||
|
self,
|
||||||
|
end_user_id: str,
|
||||||
|
new_count: int,
|
||||||
|
messages: Any
|
||||||
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
||||||
|
|
||||||
@@ -378,39 +381,39 @@ class RedisCountStore:
|
|||||||
try:
|
try:
|
||||||
# 使用索引键快速查找
|
# 使用索引键快速查找
|
||||||
index_key = f'session:count:index:{end_user_id}'
|
index_key = f'session:count:index:{end_user_id}'
|
||||||
|
|
||||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||||
try:
|
try:
|
||||||
key_type = self.r.type(index_key)
|
key_type = self.r.type(index_key)
|
||||||
if key_type != 'string' and key_type != 'none':
|
if key_type != 'string' and key_type != 'none':
|
||||||
# 索引键类型错误,删除并返回 False
|
# 索引键类型错误,删除并返回 False
|
||||||
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
logger.warning(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||||
self.r.delete(index_key)
|
self.r.delete(index_key)
|
||||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||||
return False
|
return False
|
||||||
except Exception as type_error:
|
except Exception as type_error:
|
||||||
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
logger.error(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||||
|
|
||||||
session_id = self.r.get(index_key)
|
session_id = self.r.get(index_key)
|
||||||
|
|
||||||
if not session_id:
|
if not session_id:
|
||||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 直接更新数据
|
# 直接更新数据
|
||||||
key = generate_session_key(session_id, key_type="count")
|
key = generate_session_key(session_id, key_type="count")
|
||||||
messages_str = serialize_messages(messages)
|
messages_str = serialize_messages(messages)
|
||||||
|
|
||||||
pipe = self.r.pipeline()
|
pipe = self.r.pipeline()
|
||||||
pipe.hset(key, 'count', int(new_count))
|
pipe.hset(key, 'count', str(new_count))
|
||||||
pipe.hset(key, 'messages', messages_str)
|
pipe.hset(key, 'messages', messages_str)
|
||||||
result = pipe.execute()
|
result = pipe.execute()
|
||||||
|
|
||||||
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
logger.debug(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[update_sessions_count] 更新失败: {e}")
|
logger.debug(f"[update_sessions_count] 更新失败: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def delete_all_count_sessions(self) -> int:
|
def delete_all_count_sessions(self) -> int:
|
||||||
@@ -428,7 +431,7 @@ class RedisCountStore:
|
|||||||
|
|
||||||
class RedisSessionStore:
|
class RedisSessionStore:
|
||||||
"""Redis 会话存储类,用于管理会话数据"""
|
"""Redis 会话存储类,用于管理会话数据"""
|
||||||
|
|
||||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||||
"""
|
"""
|
||||||
初始化 Redis 连接
|
初始化 Redis 连接
|
||||||
@@ -451,9 +454,9 @@ class RedisSessionStore:
|
|||||||
self.uudi = session_id
|
self.uudi = session_id
|
||||||
|
|
||||||
# ==================== 写入操作 ====================
|
# ==================== 写入操作 ====================
|
||||||
|
|
||||||
def save_session(self, userid: str, messages: str, aimessages: str,
|
def save_session(self, userid: str, messages: str, aimessages: str,
|
||||||
apply_id: str, end_user_id: str) -> str:
|
apply_id: str, end_user_id: str) -> str:
|
||||||
"""
|
"""
|
||||||
写入一条会话数据,返回 session_id
|
写入一条会话数据,返回 session_id
|
||||||
|
|
||||||
@@ -483,14 +486,14 @@ class RedisSessionStore:
|
|||||||
})
|
})
|
||||||
result = pipe.execute()
|
result = pipe.execute()
|
||||||
|
|
||||||
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
logger.debug(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||||
return session_id
|
return session_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"[save_session] 保存会话失败: {e}")
|
logger.error(f"[save_session] 保存会话失败: {e}")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
# ==================== 读取操作 ====================
|
# ==================== 读取操作 ====================
|
||||||
|
|
||||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
读取一条会话数据
|
读取一条会话数据
|
||||||
@@ -520,8 +523,8 @@ class RedisSessionStore:
|
|||||||
sessions[sid] = self.get_session(sid)
|
sessions[sid] = self.get_session(sid)
|
||||||
return sessions
|
return sessions
|
||||||
|
|
||||||
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
||||||
end_user_id: str) -> List[Dict[str, str]]:
|
end_user_id: str) -> List[Dict[str, str]]:
|
||||||
"""
|
"""
|
||||||
根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条
|
根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条
|
||||||
|
|
||||||
@@ -535,10 +538,10 @@ class RedisSessionStore:
|
|||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
keys = self.r.keys('session:*')
|
keys = self.r.keys('session:*')
|
||||||
if not keys:
|
if not keys:
|
||||||
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
logger.debug(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 批量获取数据
|
# 批量获取数据
|
||||||
@@ -556,21 +559,21 @@ class RedisSessionStore:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if (data.get('apply_id') == apply_id and
|
if (data.get('apply_id') == apply_id and
|
||||||
data.get('end_user_id') == end_user_id):
|
data.get('end_user_id') == end_user_id):
|
||||||
# 支持模糊匹配或完全匹配 sessionid
|
# 支持模糊匹配或完全匹配 sessionid
|
||||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||||
matched_items.append(format_session_data(data, include_time=True))
|
matched_items.append(format_session_data(data, include_time=True))
|
||||||
|
|
||||||
# 排序、限制数量并移除时间字段
|
# 排序、限制数量并移除时间字段
|
||||||
result_items = sort_and_limit_results(matched_items, limit=6)
|
result_items = sort_and_limit_results(matched_items, limit=6)
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
logger.debug(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||||
|
|
||||||
return result_items
|
return result_items
|
||||||
|
|
||||||
# ==================== 更新操作 ====================
|
# ==================== 更新操作 ====================
|
||||||
|
|
||||||
def update_session(self, session_id: str, field: str, value: Any) -> bool:
|
def update_session(self, session_id: str, field: str, value: Any) -> bool:
|
||||||
"""
|
"""
|
||||||
更新单个字段
|
更新单个字段
|
||||||
@@ -591,7 +594,7 @@ class RedisSessionStore:
|
|||||||
return bool(results[0])
|
return bool(results[0])
|
||||||
|
|
||||||
# ==================== 删除操作 ====================
|
# ==================== 删除操作 ====================
|
||||||
|
|
||||||
def delete_session(self, session_id: str) -> int:
|
def delete_session(self, session_id: str) -> int:
|
||||||
"""
|
"""
|
||||||
删除单条会话
|
删除单条会话
|
||||||
@@ -632,7 +635,7 @@ class RedisSessionStore:
|
|||||||
|
|
||||||
keys = self.r.keys('session:*')
|
keys = self.r.keys('session:*')
|
||||||
if not keys:
|
if not keys:
|
||||||
print("[delete_duplicate_sessions] 没有会话数据")
|
logger.debug("[delete_duplicate_sessions] 没有会话数据")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# 批量获取所有数据
|
# 批量获取所有数据
|
||||||
@@ -678,7 +681,7 @@ class RedisSessionStore:
|
|||||||
deleted_count += len(batch)
|
deleted_count += len(batch)
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
logger.debug(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||||
return deleted_count
|
return deleted_count
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -151,11 +151,6 @@ async def write(
|
|||||||
|
|
||||||
# Step 3: Save all data to Neo4j database
|
# Step 3: Save all data to Neo4j database
|
||||||
step_start = time.time()
|
step_start = time.time()
|
||||||
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
|
|
||||||
try:
|
|
||||||
await create_fulltext_indexes()
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
|
||||||
|
|
||||||
# 添加死锁重试机制
|
# 添加死锁重试机制
|
||||||
max_retries = 3
|
max_retries = 3
|
||||||
@@ -279,5 +274,21 @@ async def write(
|
|||||||
except Exception as cache_err:
|
except Exception as cache_err:
|
||||||
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||||
|
|
||||||
|
# Close LLM/Embedder underlying httpx clients to prevent
|
||||||
|
# 'RuntimeError: Event loop is closed' during garbage collection
|
||||||
|
for client_obj in (llm_client, embedder_client):
|
||||||
|
try:
|
||||||
|
underlying = getattr(client_obj, 'client', None) or getattr(client_obj, 'model', None)
|
||||||
|
if underlying is None:
|
||||||
|
continue
|
||||||
|
# Unwrap RedBearLLM / RedBearEmbeddings to get the LangChain model
|
||||||
|
inner = getattr(underlying, '_model', underlying)
|
||||||
|
# LangChain OpenAI models expose async_client (httpx.AsyncClient)
|
||||||
|
http_client = getattr(inner, 'async_client', None)
|
||||||
|
if http_client is not None and hasattr(http_client, 'aclose'):
|
||||||
|
await http_client.aclose()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
logger.info("=== Pipeline Complete ===")
|
logger.info("=== Pipeline Complete ===")
|
||||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ class LLMClient(ABC):
|
|||||||
self.max_retries = self.config.max_retries
|
self.max_retries = self.config.max_retries
|
||||||
self.timeout = self.config.timeout
|
self.timeout = self.config.timeout
|
||||||
|
|
||||||
logger.info(
|
logger.debug(
|
||||||
f"初始化 LLM 客户端: provider={self.provider}, "
|
f"初始化 LLM 客户端: provider={self.provider}, "
|
||||||
f"model={self.model_name}, max_retries={self.max_retries}"
|
f"model={self.model_name}, max_retries={self.max_retries}"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -65,7 +65,7 @@ class OpenAIClient(LLMClient):
|
|||||||
type=type_
|
type=type_
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(f"OpenAI 客户端初始化完成: type={type_}")
|
logger.debug(f"OpenAI 客户端初始化完成: type={type_}")
|
||||||
|
|
||||||
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any:
|
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -30,6 +30,18 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def message_has_files(message: "ConversationMessage") -> bool:
|
||||||
|
"""检查消息是否包含文件。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: 待检查的消息对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 如果消息包含文件则返回 True,否则返回 False
|
||||||
|
"""
|
||||||
|
return message.files and len(message.files) > 0
|
||||||
|
|
||||||
|
|
||||||
class DialogExtractionResponse(BaseModel):
|
class DialogExtractionResponse(BaseModel):
|
||||||
"""对话级一次性抽取的结构化返回,用于加速剪枝。
|
"""对话级一次性抽取的结构化返回,用于加速剪枝。
|
||||||
|
|
||||||
@@ -128,7 +140,7 @@ class SemanticPruner:
|
|||||||
1. 空消息
|
1. 空消息
|
||||||
2. 场景特定填充词库精确匹配
|
2. 场景特定填充词库精确匹配
|
||||||
3. 常见寒暄精确匹配
|
3. 常见寒暄精确匹配
|
||||||
4. 组合寒暄模式(前缀+后缀组合,如"好的谢谢"、"同学你好"、"明白了")
|
4. 组合寒暄模式(前缀 + 后缀组合,如"好的谢谢"、"同学你好"、"明白了")
|
||||||
5. 纯表情/标点
|
5. 纯表情/标点
|
||||||
"""
|
"""
|
||||||
t = message.msg.strip()
|
t = message.msg.strip()
|
||||||
@@ -482,6 +494,11 @@ class SemanticPruner:
|
|||||||
"""
|
"""
|
||||||
to_delete_ids: set = set()
|
to_delete_ids: set = set()
|
||||||
for m in msgs:
|
for m in msgs:
|
||||||
|
# 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断
|
||||||
|
if message_has_files(m):
|
||||||
|
self._log(f" [保护] 带文件的消息(不参与剪枝):'{m.msg[:40]}',文件数={len(m.files)}")
|
||||||
|
continue
|
||||||
|
|
||||||
# 填充检测优先:先判断是否为填充,再看 LLM 保护
|
# 填充检测优先:先判断是否为填充,再看 LLM 保护
|
||||||
if self._is_filler_message(m):
|
if self._is_filler_message(m):
|
||||||
to_delete_ids.add(id(m))
|
to_delete_ids.add(id(m))
|
||||||
@@ -549,6 +566,11 @@ class SemanticPruner:
|
|||||||
to_delete_ids: set = set()
|
to_delete_ids: set = set()
|
||||||
for m in msgs:
|
for m in msgs:
|
||||||
msg_text = m.msg.strip()
|
msg_text = m.msg.strip()
|
||||||
|
|
||||||
|
# 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断
|
||||||
|
if message_has_files(m):
|
||||||
|
self._log(f" [保护] 带文件的消息(不参与剪枝):'{msg_text[:40]}',文件数={len(m.files)}")
|
||||||
|
continue
|
||||||
|
|
||||||
# 第一优先级:填充消息无论模式直接删除,不参与后续场景判断
|
# 第一优先级:填充消息无论模式直接删除,不参与后续场景判断
|
||||||
if self._is_filler_message(m):
|
if self._is_filler_message(m):
|
||||||
@@ -801,6 +823,12 @@ class SemanticPruner:
|
|||||||
|
|
||||||
for idx, m in enumerate(msgs):
|
for idx, m in enumerate(msgs):
|
||||||
msg_text = m.msg.strip()
|
msg_text = m.msg.strip()
|
||||||
|
|
||||||
|
# 最高优先级保护:带有文件的消息一律保留,不参与分类
|
||||||
|
if message_has_files(m):
|
||||||
|
self._log(f" [保护] 带文件的消息(不参与分类,直接保留):索引{idx}, '{msg_text[:40]}', 文件数={len(m.files)}")
|
||||||
|
llm_protected_msgs.append((idx, m)) # 放入保护列表
|
||||||
|
continue
|
||||||
|
|
||||||
if self._msg_matches_tokens(m, preserve_tokens):
|
if self._msg_matches_tokens(m, preserve_tokens):
|
||||||
llm_protected_msgs.append((idx, m))
|
llm_protected_msgs.append((idx, m))
|
||||||
|
|||||||
@@ -182,7 +182,7 @@ class ExtractionOrchestrator:
|
|||||||
list[StatementEntityEdge],
|
list[StatementEntityEdge],
|
||||||
list[EntityEntityEdge],
|
list[EntityEntityEdge],
|
||||||
list[PerceptualEdge],
|
list[PerceptualEdge],
|
||||||
dict
|
list[DialogData]
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
运行完整的知识提取流水线(优化版:并行执行)
|
运行完整的知识提取流水线(优化版:并行执行)
|
||||||
@@ -295,6 +295,7 @@ class ExtractionOrchestrator:
|
|||||||
statement_entity_edges,
|
statement_entity_edges,
|
||||||
entity_entity_edges,
|
entity_entity_edges,
|
||||||
dialog_data_list,
|
dialog_data_list,
|
||||||
|
dedup_details,
|
||||||
) = await self._run_dedup_and_write_summary(
|
) = await self._run_dedup_and_write_summary(
|
||||||
dialogue_nodes,
|
dialogue_nodes,
|
||||||
chunk_nodes,
|
chunk_nodes,
|
||||||
@@ -306,6 +307,11 @@ class ExtractionOrchestrator:
|
|||||||
dialog_data_list,
|
dialog_data_list,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 步骤 7: 同步用户别名到数据库表(仅正式模式)
|
||||||
|
if not is_pilot_run:
|
||||||
|
logger.info("步骤 7: 同步用户别名到 end_user 和 end_user_info 表")
|
||||||
|
await self._update_end_user_other_name(entity_nodes, dialog_data_list)
|
||||||
|
|
||||||
logger.info(f"知识提取流水线运行完成({mode_str})")
|
logger.info(f"知识提取流水线运行完成({mode_str})")
|
||||||
return (
|
return (
|
||||||
dialogue_nodes,
|
dialogue_nodes,
|
||||||
@@ -1399,7 +1405,8 @@ class ExtractionOrchestrator:
|
|||||||
logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}")
|
logger.info(f"同步 Neo4j aliases 到 end_user_info: {neo4j_aliases}")
|
||||||
else:
|
else:
|
||||||
first_alias = current_aliases[0].strip() if current_aliases else ""
|
first_alias = current_aliases[0].strip() if current_aliases else ""
|
||||||
if first_alias:
|
# 确保 first_alias 不是占位名称
|
||||||
|
if first_alias and first_alias not in self.USER_PLACEHOLDER_NAMES:
|
||||||
db.add(EndUserInfo(
|
db.add(EndUserInfo(
|
||||||
end_user_id=end_user_uuid,
|
end_user_id=end_user_uuid,
|
||||||
other_name=first_alias,
|
other_name=first_alias,
|
||||||
@@ -1415,29 +1422,33 @@ class ExtractionOrchestrator:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 用户实体占位名称,不允许作为 other_name 或出现在 aliases 中
|
||||||
|
USER_PLACEHOLDER_NAMES = {'用户', '我', 'User', 'I'}
|
||||||
|
|
||||||
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]:
|
def _extract_current_aliases(self, entity_nodes: List[ExtractedEntityNode]) -> List[str]:
|
||||||
"""从实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序)
|
"""从实体节点提取用户别名(保持 LLM 提取的原始顺序,不进行任何排序)
|
||||||
|
|
||||||
这个方法直接返回 LLM 提取的别名列表,不做任何修改。
|
这个方法直接返回 LLM 提取的别名列表,并过滤掉占位名称("用户"、"我"、"User"、"I")。
|
||||||
第一个别名将被用作 other_name。
|
第一个别名将被用作 other_name。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
entity_nodes: 实体节点列表
|
entity_nodes: 实体节点列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
别名列表(保持 LLM 提取的原始顺序)
|
别名列表(保持 LLM 提取的原始顺序,已过滤占位名称)
|
||||||
"""
|
"""
|
||||||
USER_NAMES = {'用户', '我', 'User', 'I'}
|
|
||||||
for entity in entity_nodes:
|
for entity in entity_nodes:
|
||||||
if getattr(entity, 'name', '').strip() in USER_NAMES:
|
if getattr(entity, 'name', '').strip() in self.USER_PLACEHOLDER_NAMES:
|
||||||
aliases = getattr(entity, 'aliases', []) or []
|
aliases = getattr(entity, 'aliases', []) or []
|
||||||
logger.debug(f"提取到用户别名(原始顺序): {aliases}")
|
# 过滤掉占位名称,防止 "用户"/"我"/"User"/"I" 被存入 aliases 和 other_name
|
||||||
return aliases
|
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
|
||||||
|
logger.debug(f"提取到用户别名(原始顺序,已过滤占位名称): {filtered}")
|
||||||
|
return filtered
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]:
|
async def _fetch_neo4j_user_aliases(self, end_user_id: str) -> List[str]:
|
||||||
"""从 Neo4j 查询用户实体的完整 aliases 列表"""
|
"""从 Neo4j 查询用户实体的完整 aliases 列表(已过滤占位名称)"""
|
||||||
cypher = """
|
cypher = """
|
||||||
MATCH (e:ExtractedEntity)
|
MATCH (e:ExtractedEntity)
|
||||||
WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '我', 'User', 'I']
|
WHERE e.end_user_id = $end_user_id AND e.name IN ['用户', '我', 'User', 'I']
|
||||||
@@ -1451,7 +1462,10 @@ class ExtractionOrchestrator:
|
|||||||
aliases = result[0].get('aliases') or []
|
aliases = result[0].get('aliases') or []
|
||||||
if not aliases:
|
if not aliases:
|
||||||
logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}")
|
logger.debug(f"Neo4j 用户实体 aliases 为空: end_user_id={end_user_id}")
|
||||||
return aliases
|
return []
|
||||||
|
# 过滤掉占位名称,防止历史脏数据传播
|
||||||
|
filtered = [a for a in aliases if a.strip() not in self.USER_PLACEHOLDER_NAMES]
|
||||||
|
return filtered
|
||||||
|
|
||||||
def _resolve_other_name(
|
def _resolve_other_name(
|
||||||
self,
|
self,
|
||||||
@@ -1463,14 +1477,25 @@ class ExtractionOrchestrator:
|
|||||||
决定 other_name 是否需要更新,返回新值;无需更新返回 None。
|
决定 other_name 是否需要更新,返回新值;无需更新返回 None。
|
||||||
|
|
||||||
决策规则:
|
决策规则:
|
||||||
- 为空 → 用本次对话第一个别名
|
- 为空或为占位名称 → 用本次对话第一个别名
|
||||||
- 不在 Neo4j aliases 中 → 用 Neo4j 第一个别名(说明已被删除)
|
- 不在 Neo4j aliases 中 → 用 Neo4j 第一个别名(说明已被删除)
|
||||||
- 否则 → 保持不变(返回 None)
|
- 否则 → 保持不变(返回 None)
|
||||||
|
|
||||||
|
注意:返回值不允许是占位名称("用户"、"我"、"User"、"I")
|
||||||
"""
|
"""
|
||||||
if not current or not current.strip():
|
# 当前值为空或为占位名称时,需要更新
|
||||||
return current_aliases[0].strip() if current_aliases else None
|
if not current or not current.strip() or current.strip() in self.USER_PLACEHOLDER_NAMES:
|
||||||
|
candidate = current_aliases[0].strip() if current_aliases else None
|
||||||
|
# 确保候选值不是占位名称
|
||||||
|
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
|
||||||
|
return None
|
||||||
|
return candidate
|
||||||
if current not in neo4j_aliases:
|
if current not in neo4j_aliases:
|
||||||
return neo4j_aliases[0].strip() if neo4j_aliases else None
|
candidate = neo4j_aliases[0].strip() if neo4j_aliases else None
|
||||||
|
# 确保候选值不是占位名称
|
||||||
|
if candidate and candidate in self.USER_PLACEHOLDER_NAMES:
|
||||||
|
return None
|
||||||
|
return candidate
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -1492,6 +1517,7 @@ class ExtractionOrchestrator:
|
|||||||
list[StatementChunkEdge],
|
list[StatementChunkEdge],
|
||||||
list[StatementEntityEdge],
|
list[StatementEntityEdge],
|
||||||
list[EntityEntityEdge],
|
list[EntityEntityEdge],
|
||||||
|
list[DialogData],
|
||||||
dict
|
dict
|
||||||
]:
|
]:
|
||||||
"""
|
"""
|
||||||
@@ -1555,6 +1581,8 @@ class ExtractionOrchestrator:
|
|||||||
statement_chunk_edges,
|
statement_chunk_edges,
|
||||||
dedup_statement_entity_edges,
|
dedup_statement_entity_edges,
|
||||||
dedup_entity_entity_edges,
|
dedup_entity_entity_edges,
|
||||||
|
dialog_data_list,
|
||||||
|
dedup_details,
|
||||||
)
|
)
|
||||||
|
|
||||||
final_entity_nodes = dedup_entity_nodes
|
final_entity_nodes = dedup_entity_nodes
|
||||||
@@ -1562,7 +1590,16 @@ class ExtractionOrchestrator:
|
|||||||
final_entity_entity_edges = dedup_entity_entity_edges
|
final_entity_entity_edges = dedup_entity_entity_edges
|
||||||
else:
|
else:
|
||||||
# 正式模式:执行完整的两阶段去重
|
# 正式模式:执行完整的两阶段去重
|
||||||
result_tuple = await dedup_layers_and_merge_and_return(
|
(
|
||||||
|
dialogue_nodes,
|
||||||
|
chunk_nodes,
|
||||||
|
statement_nodes,
|
||||||
|
final_entity_nodes,
|
||||||
|
statement_chunk_edges,
|
||||||
|
final_statement_entity_edges,
|
||||||
|
final_entity_entity_edges,
|
||||||
|
dedup_details,
|
||||||
|
) = await dedup_layers_and_merge_and_return(
|
||||||
dialogue_nodes,
|
dialogue_nodes,
|
||||||
chunk_nodes,
|
chunk_nodes,
|
||||||
statement_nodes,
|
statement_nodes,
|
||||||
@@ -1576,21 +1613,21 @@ class ExtractionOrchestrator:
|
|||||||
llm_client=self.llm_client,
|
llm_client=self.llm_client,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 解包返回值
|
|
||||||
(
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
_,
|
|
||||||
final_entity_nodes,
|
|
||||||
_,
|
|
||||||
final_statement_entity_edges,
|
|
||||||
final_entity_entity_edges,
|
|
||||||
dedup_details,
|
|
||||||
) = result_tuple
|
|
||||||
|
|
||||||
# 保存去重消歧的详细记录到实例变量
|
# 保存去重消歧的详细记录到实例变量
|
||||||
self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes)
|
self._save_dedup_details(dedup_details, entity_nodes, final_entity_nodes)
|
||||||
|
|
||||||
|
result_tuple = (
|
||||||
|
dialogue_nodes,
|
||||||
|
chunk_nodes,
|
||||||
|
statement_nodes,
|
||||||
|
final_entity_nodes,
|
||||||
|
statement_chunk_edges,
|
||||||
|
final_statement_entity_edges,
|
||||||
|
final_entity_entity_edges,
|
||||||
|
dialog_data_list,
|
||||||
|
dedup_details,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"去重后: {len(final_entity_nodes)} 个实体节点, "
|
f"去重后: {len(final_entity_nodes)} 个实体节点, "
|
||||||
f"{len(final_statement_entity_edges)} 条陈述句-实体边, "
|
f"{len(final_statement_entity_edges)} 条陈述句-实体边, "
|
||||||
|
|||||||
@@ -105,13 +105,19 @@ Extract entities and knowledge triplets from the given statement.
|
|||||||
{% if language == "zh" %}
|
{% if language == "zh" %}
|
||||||
- 用户实体的 name 字段:使用 "用户" 或 "我"
|
- 用户实体的 name 字段:使用 "用户" 或 "我"
|
||||||
- 用户的真实姓名:放入 aliases
|
- 用户的真实姓名:放入 aliases
|
||||||
|
- **🚨 禁止将 "用户"、"我" 放入 aliases 中,aliases 只能包含用户的真实姓名、昵称等**
|
||||||
- 示例:
|
- 示例:
|
||||||
* "我叫李明" → name="用户", aliases=["李明"]
|
* "我叫李明" → name="用户", aliases=["李明"]
|
||||||
|
* ❌ 错误:aliases=["用户", "李明"]("用户"不是真实姓名,禁止放入 aliases)
|
||||||
|
* ❌ 错误:aliases=["我", "李明"]("我"不是真实姓名,禁止放入 aliases)
|
||||||
{% else %}
|
{% else %}
|
||||||
- User entity name field: use "User" or "I"
|
- User entity name field: use "User" or "I"
|
||||||
- User's real name: put in aliases
|
- User's real name: put in aliases
|
||||||
|
- **🚨 NEVER put "User" or "I" in aliases. Aliases must only contain real names, nicknames, etc.**
|
||||||
- Examples:
|
- Examples:
|
||||||
* "I'm John" → name="User", aliases=["John"]
|
* "I'm John" → name="User", aliases=["John"]
|
||||||
|
* ❌ Wrong: aliases=["User", "John"] ("User" is not a real name, FORBIDDEN in aliases)
|
||||||
|
* ❌ Wrong: aliases=["I", "John"] ("I" is not a real name, FORBIDDEN in aliases)
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -44,6 +44,8 @@ class OSSStorage(StorageBackend):
|
|||||||
access_key_id: str,
|
access_key_id: str,
|
||||||
access_key_secret: str,
|
access_key_secret: str,
|
||||||
bucket_name: str,
|
bucket_name: str,
|
||||||
|
connect_timeout: int = 30,
|
||||||
|
multipart_threshold: int = 10 * 1024 * 1024, # 10MB
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the OSSStorage backend.
|
Initialize the OSSStorage backend.
|
||||||
@@ -53,6 +55,8 @@ class OSSStorage(StorageBackend):
|
|||||||
access_key_id: The Aliyun access key ID.
|
access_key_id: The Aliyun access key ID.
|
||||||
access_key_secret: The Aliyun access key secret.
|
access_key_secret: The Aliyun access key secret.
|
||||||
bucket_name: The name of the OSS bucket.
|
bucket_name: The name of the OSS bucket.
|
||||||
|
connect_timeout: Connection timeout in seconds (default: 30).
|
||||||
|
multipart_threshold: File size threshold for multipart upload (default: 10MB).
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
StorageConfigError: If any required configuration is missing.
|
StorageConfigError: If any required configuration is missing.
|
||||||
@@ -69,10 +73,17 @@ class OSSStorage(StorageBackend):
|
|||||||
|
|
||||||
self.endpoint = endpoint
|
self.endpoint = endpoint
|
||||||
self.bucket_name = bucket_name
|
self.bucket_name = bucket_name
|
||||||
|
self.multipart_threshold = multipart_threshold
|
||||||
|
|
||||||
try:
|
try:
|
||||||
auth = oss2.Auth(access_key_id, access_key_secret)
|
auth = oss2.Auth(access_key_id, access_key_secret)
|
||||||
self.bucket = oss2.Bucket(auth, endpoint, bucket_name)
|
# 设置超时和重试
|
||||||
|
self.bucket = oss2.Bucket(
|
||||||
|
auth,
|
||||||
|
endpoint,
|
||||||
|
bucket_name,
|
||||||
|
connect_timeout=connect_timeout
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"OSSStorage initialized with endpoint: {endpoint}, bucket: {bucket_name}"
|
f"OSSStorage initialized with endpoint: {endpoint}, bucket: {bucket_name}"
|
||||||
)
|
)
|
||||||
@@ -108,21 +119,38 @@ class OSSStorage(StorageBackend):
|
|||||||
if content_type:
|
if content_type:
|
||||||
headers["Content-Type"] = content_type
|
headers["Content-Type"] = content_type
|
||||||
|
|
||||||
self.bucket.put_object(file_key, content, headers=headers if headers else None)
|
# 大文件使用分片上传
|
||||||
|
if len(content) > self.multipart_threshold:
|
||||||
|
logger.info(f"Using multipart upload for large file: {file_key} ({len(content)} bytes)")
|
||||||
|
upload_id = self.bucket.init_multipart_upload(file_key, headers=headers if headers else None).upload_id
|
||||||
|
parts = []
|
||||||
|
part_size = 5 * 1024 * 1024 # 5MB per part
|
||||||
|
part_num = 1
|
||||||
|
|
||||||
|
for offset in range(0, len(content), part_size):
|
||||||
|
chunk = content[offset:offset + part_size]
|
||||||
|
result = self.bucket.upload_part(file_key, upload_id, part_num, chunk)
|
||||||
|
parts.append(oss2.models.PartInfo(part_num, result.etag))
|
||||||
|
part_num += 1
|
||||||
|
|
||||||
|
self.bucket.complete_multipart_upload(file_key, upload_id, parts)
|
||||||
|
else:
|
||||||
|
self.bucket.put_object(file_key, content, headers=headers if headers else None)
|
||||||
|
|
||||||
logger.info(f"File uploaded to OSS successfully: {file_key}")
|
logger.info(f"File uploaded to OSS successfully: {file_key}")
|
||||||
return file_key
|
return file_key
|
||||||
|
|
||||||
except OssError as e:
|
except OssError as e:
|
||||||
logger.error(f"OSS error uploading file {file_key}: {e}")
|
logger.error(f"OSS error uploading file {file_key}: {e}")
|
||||||
raise StorageUploadError(
|
raise StorageUploadError(
|
||||||
message=f"Failed to upload file to OSS: {e.message}",
|
message=f"Failed to upload file to OSS: {str(e)}",
|
||||||
file_key=file_key,
|
file_key=file_key,
|
||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to upload file to OSS {file_key}: {e}")
|
logger.error(f"Failed to upload file to OSS {file_key}: {e}")
|
||||||
raise StorageUploadError(
|
raise StorageUploadError(
|
||||||
message=f"Failed to upload file to OSS: {e}",
|
message=f"Failed to upload file to OSS: {str(e)}",
|
||||||
file_key=file_key,
|
file_key=file_key,
|
||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
@@ -135,28 +163,73 @@ class OSSStorage(StorageBackend):
|
|||||||
) -> int:
|
) -> int:
|
||||||
"""Upload from async stream to OSS. Returns total bytes written."""
|
"""Upload from async stream to OSS. Returns total bytes written."""
|
||||||
buf = io.BytesIO()
|
buf = io.BytesIO()
|
||||||
|
headers = {"Content-Type": content_type} if content_type else None
|
||||||
|
upload_id = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# 收集流数据
|
||||||
|
total_size = 0
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
|
if not chunk:
|
||||||
|
continue
|
||||||
buf.write(chunk)
|
buf.write(chunk)
|
||||||
|
total_size += len(chunk)
|
||||||
|
|
||||||
content = buf.getvalue()
|
content = buf.getvalue()
|
||||||
headers = {"Content-Type": content_type} if content_type else None
|
|
||||||
self.bucket.put_object(file_key, content, headers=headers)
|
if not content:
|
||||||
logger.info(f"File stream uploaded to OSS successfully: {file_key}")
|
raise StorageUploadError(
|
||||||
return len(content)
|
message="Empty stream content",
|
||||||
|
file_key=file_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 大文件使用分片上传
|
||||||
|
if len(content) > self.multipart_threshold:
|
||||||
|
logger.info(f"Using multipart upload for stream: {file_key} ({len(content)} bytes)")
|
||||||
|
upload_id = self.bucket.init_multipart_upload(file_key, headers=headers).upload_id
|
||||||
|
parts = []
|
||||||
|
part_size = 5 * 1024 * 1024 # 5MB
|
||||||
|
part_num = 1
|
||||||
|
|
||||||
|
for offset in range(0, len(content), part_size):
|
||||||
|
chunk = content[offset:offset + part_size]
|
||||||
|
result = self.bucket.upload_part(file_key, upload_id, part_num, chunk)
|
||||||
|
parts.append(oss2.models.PartInfo(part_num, result.etag))
|
||||||
|
part_num += 1
|
||||||
|
|
||||||
|
self.bucket.complete_multipart_upload(file_key, upload_id, parts)
|
||||||
|
else:
|
||||||
|
self.bucket.put_object(file_key, content, headers=headers)
|
||||||
|
|
||||||
|
logger.info(f"File stream uploaded to OSS successfully: {file_key} ({total_size} bytes)")
|
||||||
|
return total_size
|
||||||
|
|
||||||
except OssError as e:
|
except OssError as e:
|
||||||
|
if upload_id:
|
||||||
|
try:
|
||||||
|
self.bucket.abort_multipart_upload(file_key, upload_id)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
logger.error(f"OSS error stream uploading file {file_key}: {e}")
|
logger.error(f"OSS error stream uploading file {file_key}: {e}")
|
||||||
raise StorageUploadError(
|
raise StorageUploadError(
|
||||||
message=f"Failed to stream upload file to OSS: {e.message}",
|
message=f"Failed to stream upload file to OSS: {str(e)}",
|
||||||
file_key=file_key,
|
file_key=file_key,
|
||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
if upload_id:
|
||||||
|
try:
|
||||||
|
self.bucket.abort_multipart_upload(file_key, upload_id)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
logger.error(f"Failed to stream upload file to OSS {file_key}: {e}")
|
logger.error(f"Failed to stream upload file to OSS {file_key}: {e}")
|
||||||
raise StorageUploadError(
|
raise StorageUploadError(
|
||||||
message=f"Failed to stream upload file to OSS: {e}",
|
message=f"Failed to stream upload file to OSS: {str(e)}",
|
||||||
file_key=file_key,
|
file_key=file_key,
|
||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
|
finally:
|
||||||
|
buf.close()
|
||||||
|
|
||||||
async def download(self, file_key: str) -> bytes:
|
async def download(self, file_key: str) -> bytes:
|
||||||
"""
|
"""
|
||||||
@@ -182,14 +255,14 @@ class OSSStorage(StorageBackend):
|
|||||||
except OssError as e:
|
except OssError as e:
|
||||||
logger.error(f"OSS error downloading file {file_key}: {e}")
|
logger.error(f"OSS error downloading file {file_key}: {e}")
|
||||||
raise StorageDownloadError(
|
raise StorageDownloadError(
|
||||||
message=f"Failed to download file from OSS: {e.message}",
|
message=f"Failed to download file from OSS: {str(e)}",
|
||||||
file_key=file_key,
|
file_key=file_key,
|
||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to download file from OSS {file_key}: {e}")
|
logger.error(f"Failed to download file from OSS {file_key}: {e}")
|
||||||
raise StorageDownloadError(
|
raise StorageDownloadError(
|
||||||
message=f"Failed to download file from OSS: {e}",
|
message=f"Failed to download file from OSS: {str(e)}",
|
||||||
file_key=file_key,
|
file_key=file_key,
|
||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
@@ -215,14 +288,14 @@ class OSSStorage(StorageBackend):
|
|||||||
except OssError as e:
|
except OssError as e:
|
||||||
logger.error(f"OSS error deleting file {file_key}: {e}")
|
logger.error(f"OSS error deleting file {file_key}: {e}")
|
||||||
raise StorageDeleteError(
|
raise StorageDeleteError(
|
||||||
message=f"Failed to delete file from OSS: {e.message}",
|
message=f"Failed to delete file from OSS: {str(e)}",
|
||||||
file_key=file_key,
|
file_key=file_key,
|
||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to delete file from OSS {file_key}: {e}")
|
logger.error(f"Failed to delete file from OSS {file_key}: {e}")
|
||||||
raise StorageDeleteError(
|
raise StorageDeleteError(
|
||||||
message=f"Failed to delete file from OSS: {e}",
|
message=f"Failed to delete file from OSS: {str(e)}",
|
||||||
file_key=file_key,
|
file_key=file_key,
|
||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -9,10 +9,10 @@ from app.core.workflow.nodes.enums import NodeType
|
|||||||
|
|
||||||
|
|
||||||
def merge_activate_state(x, y):
|
def merge_activate_state(x, y):
|
||||||
return {
|
merged = dict(x)
|
||||||
k: x.get(k, False) or y.get(k, False)
|
for k, v in y.items():
|
||||||
for k in set(x) | set(y)
|
merged[k] = merged.get(k, False) or v
|
||||||
}
|
return merged
|
||||||
|
|
||||||
|
|
||||||
def merge_looping_state(x, y):
|
def merge_looping_state(x, y):
|
||||||
|
|||||||
@@ -17,6 +17,51 @@ from app.core.workflow.variable.variable_objects import T, create_variable_insta
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
VARIABLE_PATTERN = re.compile(r"\{\{\s*(.*?)\s*}}")
|
||||||
|
|
||||||
|
|
||||||
|
class LazyVariableDict:
|
||||||
|
def __init__(self, source, literal):
|
||||||
|
self._source: dict[str, VariableStruct[Any]] = source
|
||||||
|
self._literal: bool = literal
|
||||||
|
self._cache = {}
|
||||||
|
|
||||||
|
def keys(self):
|
||||||
|
return self._source.keys()
|
||||||
|
|
||||||
|
def _resolve(self, key):
|
||||||
|
if key in self._cache:
|
||||||
|
return self._cache[key]
|
||||||
|
var_struct = self._source.get(key)
|
||||||
|
if var_struct is None:
|
||||||
|
raise KeyError(key)
|
||||||
|
value = var_struct.instance.to_literal() if self._literal else var_struct.instance.get_value()
|
||||||
|
self._cache[key] = value
|
||||||
|
return value
|
||||||
|
|
||||||
|
def get(self, key, default=None):
|
||||||
|
try:
|
||||||
|
return self._resolve(key)
|
||||||
|
except KeyError:
|
||||||
|
return default
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
return self._resolve(key)
|
||||||
|
|
||||||
|
def __getattr__(self, key):
|
||||||
|
if key.startswith('_'):
|
||||||
|
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'")
|
||||||
|
return self._resolve(key)
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
return key in self._source
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter(self._source)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self._source)
|
||||||
|
|
||||||
|
|
||||||
class VariableSelector:
|
class VariableSelector:
|
||||||
"""变量选择器
|
"""变量选择器
|
||||||
@@ -117,8 +162,7 @@ class VariablePool:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def transform_selector(selector):
|
def transform_selector(selector):
|
||||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
variable_literal = VARIABLE_PATTERN.sub(r"\1", selector).strip()
|
||||||
variable_literal = re.sub(pattern, r"\1", selector).strip()
|
|
||||||
selector = VariableSelector.from_string(variable_literal).path
|
selector = VariableSelector.from_string(variable_literal).path
|
||||||
if len(selector) != 2:
|
if len(selector) != 2:
|
||||||
raise ValueError(f"Selector not valid - {selector}")
|
raise ValueError(f"Selector not valid - {selector}")
|
||||||
@@ -303,6 +347,16 @@ class VariablePool:
|
|||||||
"""
|
"""
|
||||||
return self._get_variable_struct(selector) is not None
|
return self._get_variable_struct(selector) is not None
|
||||||
|
|
||||||
|
def lazy_namespace(self, namespace: str, literal: bool = False) -> LazyVariableDict:
|
||||||
|
return LazyVariableDict(self.variables.get(namespace, {}), literal)
|
||||||
|
|
||||||
|
def lazy_all_node_outputs(self, literal: bool = False) -> dict[str, LazyVariableDict]:
|
||||||
|
return {
|
||||||
|
ns: LazyVariableDict(vars_dict, literal)
|
||||||
|
for ns, vars_dict in self.variables.items()
|
||||||
|
if ns not in ("sys", "conv")
|
||||||
|
}
|
||||||
|
|
||||||
def get_all_system_vars(self, literal=False) -> dict[str, Any]:
|
def get_all_system_vars(self, literal=False) -> dict[str, Any]:
|
||||||
"""获取所有系统变量
|
"""获取所有系统变量
|
||||||
|
|
||||||
@@ -479,5 +533,3 @@ class VariablePoolInitializer:
|
|||||||
var_type=var_type,
|
var_type=var_type,
|
||||||
mut=False
|
mut=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -552,9 +552,9 @@ class BaseNode(ABC):
|
|||||||
|
|
||||||
return render_template(
|
return render_template(
|
||||||
template=template,
|
template=template,
|
||||||
conv_vars=variable_pool.get_all_conversation_vars(literal=True),
|
conv_vars=variable_pool.lazy_namespace("conv", literal=True),
|
||||||
node_outputs=variable_pool.get_all_node_outputs(literal=True),
|
node_outputs=variable_pool.lazy_all_node_outputs(literal=True),
|
||||||
system_vars=variable_pool.get_all_system_vars(literal=True),
|
system_vars=variable_pool.lazy_namespace("sys", literal=True),
|
||||||
strict=strict
|
strict=strict
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -579,9 +579,9 @@ class BaseNode(ABC):
|
|||||||
|
|
||||||
return evaluate_condition(
|
return evaluate_condition(
|
||||||
expression=expression,
|
expression=expression,
|
||||||
conv_var=variable_pool.get_all_conversation_vars(),
|
conv_var=variable_pool.lazy_namespace("conv"),
|
||||||
node_outputs=variable_pool.get_all_node_outputs(),
|
node_outputs=variable_pool.lazy_all_node_outputs(),
|
||||||
system_vars=variable_pool.get_all_system_vars()
|
system_vars=variable_pool.lazy_namespace("sys")
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
|||||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
|
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig
|
||||||
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType
|
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType
|
||||||
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
|
from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance
|
||||||
from app.core.workflow.utils.expression_evaluator import evaluate_expression
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -85,12 +84,7 @@ class LoopRuntime:
|
|||||||
|
|
||||||
for variable in self.typed_config.cycle_vars:
|
for variable in self.typed_config.cycle_vars:
|
||||||
if variable.input_type == ValueInputType.VARIABLE:
|
if variable.input_type == ValueInputType.VARIABLE:
|
||||||
value = evaluate_expression(
|
value = self.variable_pool.get_value(variable.value)
|
||||||
expression=variable.value,
|
|
||||||
conv_var=self.variable_pool.get_all_conversation_vars(),
|
|
||||||
node_outputs=self.variable_pool.get_all_node_outputs(),
|
|
||||||
system_vars=self.variable_pool.get_all_system_vars(),
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
value = TypeTransformer.transform(variable.value, variable.type)
|
value = TypeTransformer.transform(variable.value, variable.type)
|
||||||
await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True)
|
await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True)
|
||||||
@@ -98,12 +92,7 @@ class LoopRuntime:
|
|||||||
**self.state
|
**self.state
|
||||||
)
|
)
|
||||||
loopstate["node_outputs"][self.node_id] = {
|
loopstate["node_outputs"][self.node_id] = {
|
||||||
variable.name: evaluate_expression(
|
variable.name: self.variable_pool.get_value(variable.value)
|
||||||
expression=variable.value,
|
|
||||||
conv_var=self.variable_pool.get_all_conversation_vars(),
|
|
||||||
node_outputs=self.variable_pool.get_all_node_outputs(),
|
|
||||||
system_vars=self.variable_pool.get_all_system_vars(),
|
|
||||||
)
|
|
||||||
if variable.input_type == ValueInputType.VARIABLE
|
if variable.input_type == ValueInputType.VARIABLE
|
||||||
else TypeTransformer.transform(variable.value, variable.type)
|
else TypeTransformer.transform(variable.value, variable.type)
|
||||||
for variable in self.typed_config.cycle_vars
|
for variable in self.typed_config.cycle_vars
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ class DocExtractorNode(BaseNode):
|
|||||||
# Reuse cached bytes if already fetched
|
# Reuse cached bytes if already fetched
|
||||||
if f.get_content():
|
if f.get_content():
|
||||||
file_input.set_content(f.get_content())
|
file_input.set_content(f.get_content())
|
||||||
text = await svc._extract_document_text(file_input)
|
text = await svc.extract_document_text(file_input)
|
||||||
chunks.append(text)
|
chunks.append(text)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
|
|||||||
@@ -1,19 +1,23 @@
|
|||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.models import RedBearRerank, RedBearModelConfig
|
from app.core.models import RedBearRerank, RedBearModelConfig
|
||||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector
|
from app.core.rag.models.chunk import DocumentChunk
|
||||||
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||||
from app.core.workflow.engine.state_manager import WorkflowState
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
from app.core.workflow.engine.variable_pool import VariablePool
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
from app.db import get_db_read
|
from app.db import get_db_read
|
||||||
from app.models import knowledge_model, knowledgeshare_model, ModelType
|
from app.models import knowledge_model, ModelType
|
||||||
from app.repositories import knowledge_repository, knowledgeshare_repository
|
from app.repositories import knowledge_repository
|
||||||
from app.schemas.chunk_schema import RetrieveType
|
from app.schemas.chunk_schema import RetrieveType
|
||||||
from app.services.model_service import ModelConfigService
|
from app.services.model_service import ModelConfigService
|
||||||
|
|
||||||
@@ -24,7 +28,6 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]):
|
||||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||||
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||||
self.vector_service: ElasticSearchVector | None = None
|
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {
|
return {
|
||||||
@@ -85,46 +88,54 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
unique.append(doc)
|
unique.append(doc)
|
||||||
return unique
|
return unique
|
||||||
|
|
||||||
def _get_existing_kb_ids(self, db, kb_ids):
|
def rerank(self, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
|
||||||
"""
|
"""
|
||||||
Resolve all accessible and valid knowledge base IDs for retrieval.
|
Reorder the list of document blocks and return the top_k results most relevant to the query
|
||||||
|
|
||||||
This includes:
|
|
||||||
- Private knowledge bases owned by the user
|
|
||||||
- Shared knowledge bases
|
|
||||||
- Source knowledge bases mapped via knowledge sharing relationships
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: Database session.
|
query: query string
|
||||||
kb_ids (list[UUID]): Knowledge base IDs from node configuration.
|
docs: List of document chunk to be rearranged
|
||||||
|
top_k: The number of top-level documents returned
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list[UUID]: Final list of valid knowledge base IDs.
|
Rearranged document chunk list (sorted in descending order of relevance)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the input document list is empty or top_k is invalid
|
||||||
"""
|
"""
|
||||||
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Private)
|
reranker = self.get_reranker_model()
|
||||||
|
# parameter validation
|
||||||
existing_ids = knowledge_repository.get_chunked_knowledgeids(
|
if not docs:
|
||||||
db=db,
|
raise ValueError("retrieval chunks be empty")
|
||||||
filters=filters
|
if top_k <= 0:
|
||||||
)
|
raise ValueError("top_k must be a positive integer")
|
||||||
|
try:
|
||||||
filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share)
|
# Convert to LangChain Document object
|
||||||
|
documents = [
|
||||||
share_ids = knowledge_repository.get_chunked_knowledgeids(
|
Document(
|
||||||
db=db,
|
page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute
|
||||||
filters=filters
|
metadata=doc.metadata or {} # Deal with possible None metadata
|
||||||
)
|
)
|
||||||
|
for doc in docs
|
||||||
if share_ids:
|
|
||||||
filters = [
|
|
||||||
knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids)
|
|
||||||
]
|
]
|
||||||
items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id(
|
|
||||||
db=db,
|
# Perform reordering (compress_documents will automatically handle relevance scores and indexing)
|
||||||
filters=filters
|
reranked_docs = list(reranker.compress_documents(documents, query))
|
||||||
|
|
||||||
|
# Sort in descending order based on relevance score
|
||||||
|
reranked_docs.sort(
|
||||||
|
key=lambda x: x.metadata.get("relevance_score", 0),
|
||||||
|
reverse=True
|
||||||
)
|
)
|
||||||
existing_ids.extend(items)
|
# Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"]
|
||||||
return existing_ids
|
result = []
|
||||||
|
for item in reranked_docs[:top_k]:
|
||||||
|
for doc in docs:
|
||||||
|
if doc.page_content == item.page_content:
|
||||||
|
doc.metadata["score"] = item.metadata["relevance_score"]
|
||||||
|
result.append(doc)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e
|
||||||
|
|
||||||
def get_reranker_model(self) -> RedBearRerank:
|
def get_reranker_model(self) -> RedBearRerank:
|
||||||
"""
|
"""
|
||||||
@@ -164,41 +175,77 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
)
|
)
|
||||||
return reranker
|
return reranker
|
||||||
|
|
||||||
def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config):
|
async def knowledge_retrieval(self, db, query, db_knowledge, kb_config):
|
||||||
|
rs = []
|
||||||
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
|
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
|
||||||
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
|
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
|
||||||
|
tasks = []
|
||||||
for child in children:
|
for child in children:
|
||||||
if not (child and child.chunk_num > 0 and child.status == 1):
|
if not (child and child.chunk_num > 0 and child.status == 1):
|
||||||
continue
|
continue
|
||||||
kb_config.kb_id = child.id
|
child_kb_config = kb_config.model_copy()
|
||||||
self.knowledge_retrieval(db, query, rs, child, kb_config)
|
child_kb_config.kb_id = child.id
|
||||||
return
|
tasks.append(self.knowledge_retrieval(db, query, child, child_kb_config))
|
||||||
self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
if tasks:
|
||||||
|
result = await asyncio.gather(*tasks)
|
||||||
|
for _ in result:
|
||||||
|
rs.extend(_)
|
||||||
|
return rs
|
||||||
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||||
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
|
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
|
||||||
match kb_config.retrieve_type:
|
match kb_config.retrieve_type:
|
||||||
case RetrieveType.PARTICIPLE:
|
case RetrieveType.PARTICIPLE:
|
||||||
rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
rs.extend(
|
||||||
indices=indices,
|
await asyncio.to_thread(
|
||||||
score_threshold=kb_config.similarity_threshold))
|
vector_service.search_by_full_text, **{
|
||||||
|
"query": query,
|
||||||
|
"top_k": kb_config.top_k,
|
||||||
|
"indices": indices,
|
||||||
|
"score_threshold": kb_config.similarity_threshold
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
case RetrieveType.SEMANTIC:
|
case RetrieveType.SEMANTIC:
|
||||||
rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
rs.extend(
|
||||||
indices=indices,
|
await asyncio.to_thread(
|
||||||
score_threshold=kb_config.vector_similarity_weight))
|
vector_service.search_by_vector, **{
|
||||||
|
"query": query,
|
||||||
|
"top_k": kb_config.top_k,
|
||||||
|
"indices": indices,
|
||||||
|
"score_threshold": kb_config.vector_similarity_weight
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
case RetrieveType.HYBRID:
|
case RetrieveType.HYBRID:
|
||||||
rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
rs1_task = asyncio.to_thread(
|
||||||
indices=indices,
|
vector_service.search_by_vector, **{
|
||||||
score_threshold=kb_config.vector_similarity_weight)
|
"query": query,
|
||||||
rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
"top_k": kb_config.top_k,
|
||||||
indices=indices,
|
"indices": indices,
|
||||||
score_threshold=kb_config.similarity_threshold)
|
"score_threshold": kb_config.vector_similarity_weight
|
||||||
|
}
|
||||||
|
)
|
||||||
|
rs2_task = asyncio.to_thread(
|
||||||
|
vector_service.search_by_full_text, **{
|
||||||
|
"query": query,
|
||||||
|
"top_k": kb_config.top_k,
|
||||||
|
"indices": indices,
|
||||||
|
"score_threshold": kb_config.similarity_threshold
|
||||||
|
}
|
||||||
|
)
|
||||||
|
rs1, rs2 = await asyncio.gather(rs1_task, rs2_task)
|
||||||
|
|
||||||
# Deduplicate hybrid retrieval results
|
# Deduplicate hybrid retrieval results
|
||||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||||
if not unique_rs:
|
if not unique_rs:
|
||||||
return
|
return []
|
||||||
if self.typed_config.reranker_id:
|
if self.typed_config.reranker_id:
|
||||||
self.vector_service.reranker = self.get_reranker_model()
|
rs.extend(
|
||||||
rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
await asyncio.to_thread(
|
||||||
|
self.rerank,
|
||||||
|
**{"query": query, "docs": unique_rs, "top_k": kb_config.top_k}
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
rs.extend(sorted(
|
rs.extend(sorted(
|
||||||
unique_rs,
|
unique_rs,
|
||||||
@@ -207,6 +254,7 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
)[:kb_config.top_k])
|
)[:kb_config.top_k])
|
||||||
case _:
|
case _:
|
||||||
raise RuntimeError("Unknown retrieval type")
|
raise RuntimeError("Unknown retrieval type")
|
||||||
|
return rs
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
"""
|
"""
|
||||||
@@ -238,17 +286,24 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
knowledge_bases = self.typed_config.knowledge_bases
|
knowledge_bases = self.typed_config.knowledge_bases
|
||||||
|
|
||||||
rs = []
|
rs = []
|
||||||
|
tasks = []
|
||||||
for kb_config in knowledge_bases:
|
for kb_config in knowledge_bases:
|
||||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
||||||
if not db_knowledge:
|
if not db_knowledge:
|
||||||
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
||||||
self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config)
|
tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config))
|
||||||
|
if tasks:
|
||||||
|
result = await asyncio.gather(*tasks)
|
||||||
|
for _ in result:
|
||||||
|
rs.extend(_)
|
||||||
|
|
||||||
if not rs:
|
if not rs:
|
||||||
return []
|
return []
|
||||||
if self.typed_config.reranker_id:
|
if self.typed_config.reranker_id:
|
||||||
self.vector_service.reranker = self.get_reranker_model()
|
final_rs = await asyncio.to_thread(
|
||||||
final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
self.rerank,
|
||||||
|
**{"query": query, "docs": rs, "top_k": self.typed_config.reranker_top_k}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
final_rs = sorted(
|
final_rs = sorted(
|
||||||
rs,
|
rs,
|
||||||
|
|||||||
@@ -4,32 +4,33 @@ from typing import Any
|
|||||||
|
|
||||||
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
|
from simpleeval import simple_eval, NameNotDefined, InvalidExpression
|
||||||
|
|
||||||
|
from app.core.workflow.engine.variable_pool import LazyVariableDict, VARIABLE_PATTERN
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")
|
||||||
|
|
||||||
|
|
||||||
class ExpressionEvaluator:
|
class ExpressionEvaluator:
|
||||||
"""Safe expression evaluator for workflow variables and node outputs."""
|
"""Safe expression evaluator for workflow variables and node outputs."""
|
||||||
|
|
||||||
# Reserved namespaces
|
# Reserved namespaces
|
||||||
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
|
RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def normalize_template(cls, template: str) -> str:
|
def normalize_template(cls, template: str) -> str:
|
||||||
pattern = re.compile(
|
return _NORMALIZE_PATTERN.sub(
|
||||||
r"\{\{\s*(\d+)\.(\w+)\s*}}"
|
|
||||||
)
|
|
||||||
return pattern.sub(
|
|
||||||
r'{{ node["\1"].\2 }}',
|
r'{{ node["\1"].\2 }}',
|
||||||
template
|
template
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def evaluate(
|
def evaluate(
|
||||||
cls,
|
cls,
|
||||||
expression: str,
|
expression: str,
|
||||||
conv_vars: dict[str, Any],
|
conv_vars: dict[str, Any],
|
||||||
node_outputs: dict[str, Any],
|
node_outputs: dict[str, Any],
|
||||||
system_vars: dict[str, Any] | None = None
|
system_vars: dict[str, Any] | None = None
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
Safely evaluate an expression using workflow variables.
|
Safely evaluate an expression using workflow variables.
|
||||||
@@ -49,48 +50,47 @@ class ExpressionEvaluator:
|
|||||||
# Remove Jinja2-style brackets if present
|
# Remove Jinja2-style brackets if present
|
||||||
expression = expression.strip()
|
expression = expression.strip()
|
||||||
expression = cls.normalize_template(expression)
|
expression = cls.normalize_template(expression)
|
||||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
expression = VARIABLE_PATTERN.sub(r"\1", expression).strip()
|
||||||
expression = re.sub(pattern, r"\1", expression).strip()
|
|
||||||
|
|
||||||
# Build context for evaluation
|
# Build context for evaluation
|
||||||
context = {
|
context = {
|
||||||
"conv": conv_vars, # conversation variables
|
"conv": conv_vars, # conversation variables
|
||||||
"node": node_outputs, # node outputs
|
"node": node_outputs, # node outputs
|
||||||
"sys": system_vars or {}, # system variables
|
"sys": system_vars or {}, # system variables
|
||||||
}
|
}
|
||||||
|
|
||||||
context.update(conv_vars)
|
# context.update(conv_vars)
|
||||||
context["nodes"] = node_outputs
|
# context["nodes"] = node_outputs
|
||||||
context.update(node_outputs)
|
context.update(node_outputs)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# simpleeval supports safe operations:
|
# simpleeval supports safe operations:
|
||||||
# arithmetic, comparisons, logical ops, attribute/dict/list access
|
# arithmetic, comparisons, logical ops, attribute/dict/list access
|
||||||
result = simple_eval(expression, names=context)
|
result = simple_eval(expression, names=context)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except NameNotDefined as e:
|
except NameNotDefined as e:
|
||||||
logger.error(f"Undefined variable in expression: {expression}, error: {e}")
|
logger.error(f"Undefined variable in expression: {expression}, error: {e}")
|
||||||
raise ValueError(f"Undefined variable: {e}")
|
raise ValueError(f"Undefined variable: {e}")
|
||||||
|
|
||||||
except InvalidExpression as e:
|
except InvalidExpression as e:
|
||||||
logger.error(f"Invalid expression syntax: {expression}, error: {e}")
|
logger.error(f"Invalid expression syntax: {expression}, error: {e}")
|
||||||
raise ValueError(f"Invalid expression syntax: {e}")
|
raise ValueError(f"Invalid expression syntax: {e}")
|
||||||
|
|
||||||
except SyntaxError as e:
|
except SyntaxError as e:
|
||||||
logger.error(f"Syntax error in expression: {expression}, error: {e}")
|
logger.error(f"Syntax error in expression: {expression}, error: {e}")
|
||||||
raise ValueError(f"Syntax error: {e}")
|
raise ValueError(f"Syntax error: {e}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Expression evaluation failed: {expression}, error: {e}")
|
logger.error(f"Expression evaluation failed: {expression}, error: {e}")
|
||||||
raise ValueError(f"Expression evaluation failed: {e}")
|
raise ValueError(f"Expression evaluation failed: {e}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def evaluate_bool(
|
def evaluate_bool(
|
||||||
expression: str,
|
expression: str,
|
||||||
conv_var: dict[str, Any],
|
conv_var: dict[str, Any],
|
||||||
node_outputs: dict[str, Any],
|
node_outputs: dict[str, Any],
|
||||||
system_vars: dict[str, Any] | None = None
|
system_vars: dict[str, Any] | None = None
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Evaluate a boolean expression (for conditions).
|
Evaluate a boolean expression (for conditions).
|
||||||
@@ -108,7 +108,7 @@ class ExpressionEvaluator:
|
|||||||
expression, conv_var, node_outputs, system_vars
|
expression, conv_var, node_outputs, system_vars
|
||||||
)
|
)
|
||||||
return bool(result)
|
return bool(result)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def validate_variable_names(variables: list[dict]) -> list[str]:
|
def validate_variable_names(variables: list[dict]) -> list[str]:
|
||||||
"""
|
"""
|
||||||
@@ -121,7 +121,7 @@ class ExpressionEvaluator:
|
|||||||
list[str]: List of error messages. Empty if all names are valid.
|
list[str]: List of error messages. Empty if all names are valid.
|
||||||
"""
|
"""
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
for var in variables:
|
for var in variables:
|
||||||
var_name = var.get("name", "")
|
var_name = var.get("name", "")
|
||||||
|
|
||||||
@@ -134,16 +134,16 @@ class ExpressionEvaluator:
|
|||||||
errors.append(
|
errors.append(
|
||||||
f"Variable name '{var_name}' is not a valid Python identifier"
|
f"Variable name '{var_name}' is not a valid Python identifier"
|
||||||
)
|
)
|
||||||
|
|
||||||
return errors
|
return errors
|
||||||
|
|
||||||
|
|
||||||
# 便捷函数
|
# 便捷函数
|
||||||
def evaluate_expression(
|
def evaluate_expression(
|
||||||
expression: str,
|
expression: str,
|
||||||
conv_var: dict[str, Any],
|
conv_var: dict[str, Any] | LazyVariableDict,
|
||||||
node_outputs: dict[str, Any],
|
node_outputs: dict[str, dict[str, Any] | LazyVariableDict],
|
||||||
system_vars: dict[str, Any]
|
system_vars: dict[str, Any] | LazyVariableDict
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Evaluate an expression (convenience function)."""
|
"""Evaluate an expression (convenience function)."""
|
||||||
return ExpressionEvaluator.evaluate(
|
return ExpressionEvaluator.evaluate(
|
||||||
@@ -152,11 +152,11 @@ def evaluate_expression(
|
|||||||
|
|
||||||
|
|
||||||
def evaluate_condition(
|
def evaluate_condition(
|
||||||
expression: str,
|
expression: str,
|
||||||
conv_var: dict[str, Any],
|
conv_var: dict[str, Any] | LazyVariableDict,
|
||||||
node_outputs: dict[str, Any],
|
node_outputs: dict[str, dict[str, Any] | LazyVariableDict],
|
||||||
system_vars: dict[str, Any] | None = None
|
system_vars: dict[str, Any] | LazyVariableDict
|
||||||
) -> bool:
|
) -> Any:
|
||||||
"""Evaluate a boolean condition expression (convenience function)."""
|
"""Evaluate a boolean condition expression (convenience function)."""
|
||||||
return ExpressionEvaluator.evaluate_bool(
|
return ExpressionEvaluator.evaluate_bool(
|
||||||
expression, conv_var, node_outputs, system_vars
|
expression, conv_var, node_outputs, system_vars
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
"""
|
"""
|
||||||
模板渲染器
|
Template Renderer
|
||||||
|
|
||||||
使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。
|
Provides safe template rendering using Jinja2, supporting variable references
|
||||||
|
and expressions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -10,11 +11,15 @@ from typing import Any
|
|||||||
|
|
||||||
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
|
from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined
|
||||||
|
|
||||||
|
from app.core.workflow.engine.variable_pool import LazyVariableDict
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}")
|
||||||
|
|
||||||
|
|
||||||
class SafeUndefined(Undefined):
|
class SafeUndefined(Undefined):
|
||||||
"""访问未定义属性不会报错,返回空字符串"""
|
"""Return empty string instead of raising error when accessing undefined variables"""
|
||||||
__slots__ = ()
|
__slots__ = ()
|
||||||
|
|
||||||
def _fail_with_undefined_error(self, *args, **kwargs):
|
def _fail_with_undefined_error(self, *args, **kwargs):
|
||||||
@@ -26,26 +31,22 @@ class SafeUndefined(Undefined):
|
|||||||
|
|
||||||
|
|
||||||
class TemplateRenderer:
|
class TemplateRenderer:
|
||||||
"""模板渲染器"""
|
|
||||||
|
|
||||||
def __init__(self, strict: bool = True):
|
def __init__(self, strict: bool = True):
|
||||||
"""初始化渲染器
|
"""Initialize renderer
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
strict: 是否使用严格模式(未定义变量会抛出异常)
|
strict: Whether to enable strict mode (raise error on undefined variables)
|
||||||
"""
|
"""
|
||||||
self.strict = strict
|
self.strict = strict
|
||||||
self.env = Environment(
|
self.env = Environment(
|
||||||
undefined=StrictUndefined if strict else SafeUndefined,
|
undefined=StrictUndefined if strict else SafeUndefined,
|
||||||
autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML
|
autoescape=False # Disable auto-escaping since we handle plain text instead of HTML
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def normalize_template(template: str) -> str:
|
def normalize_template(template: str) -> str:
|
||||||
pattern = re.compile(
|
"""Normalize template syntax (convert numeric node reference to dict access)"""
|
||||||
r"\{\{\s*(\d+)\.(\w+)\s*}}"
|
return _NORMALIZE_PATTERN.sub(
|
||||||
)
|
|
||||||
return pattern.sub(
|
|
||||||
r'{{ node["\1"].\2 }}',
|
r'{{ node["\1"].\2 }}',
|
||||||
template
|
template
|
||||||
)
|
)
|
||||||
@@ -53,24 +54,24 @@ class TemplateRenderer:
|
|||||||
def render(
|
def render(
|
||||||
self,
|
self,
|
||||||
template: str,
|
template: str,
|
||||||
conv_vars: dict[str, Any],
|
conv_vars: dict[str, Any] | LazyVariableDict,
|
||||||
node_outputs: dict[str, Any],
|
node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
|
||||||
system_vars: dict[str, Any] | None = None
|
system_vars: dict[str, Any] | LazyVariableDict | None = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""渲染模板
|
"""Render template
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
template: 模板字符串
|
template: Template string
|
||||||
conv_vars: 会话变量
|
conv_vars: Conversation variables
|
||||||
node_outputs: 节点输出结果
|
node_outputs: Node outputs
|
||||||
system_vars: 系统变量
|
system_vars: System variables
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
渲染后的字符串
|
Rendered string
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: 模板语法错误或变量未定义
|
ValueError: If template syntax is invalid or variables are undefined
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> renderer = TemplateRenderer()
|
>>> renderer = TemplateRenderer()
|
||||||
>>> renderer.render(
|
>>> renderer.render(
|
||||||
@@ -80,122 +81,119 @@ class TemplateRenderer:
|
|||||||
... {}
|
... {}
|
||||||
... )
|
... )
|
||||||
'Hello World!'
|
'Hello World!'
|
||||||
|
|
||||||
>>> renderer.render(
|
>>> renderer.render(
|
||||||
... "分析结果: {{node.analyze.output}}",
|
... "Analysis result: {{node.analyze.output}}",
|
||||||
... {},
|
... {},
|
||||||
... {"analyze": {"output": "正面情绪"}},
|
... {"analyze": {"output": "positive sentiment"}},
|
||||||
... {}
|
... {}
|
||||||
... )
|
... )
|
||||||
'分析结果: 正面情绪'
|
'Analysis result: positive sentiment'
|
||||||
"""
|
"""
|
||||||
# 构建命名空间上下文
|
# Build namespace context
|
||||||
context = {
|
context = {
|
||||||
"conv": conv_vars, # 会话变量:{{conv.user_name}}
|
"conv": conv_vars, # Conversation variables: {{conv.user_name}}
|
||||||
"node": node_outputs, # 节点输出:{{node.node_1.output}}
|
"node": node_outputs, # Node outputs: {{node.node_1.output}}
|
||||||
"sys": system_vars, # 系统变量:{{sys.execution_id}}
|
"sys": system_vars, # System variables: {{sys.execution_id}}
|
||||||
}
|
}
|
||||||
|
|
||||||
# 支持直接通过节点ID访问节点输出:{{llm_qa.output}}
|
# Allow direct access to node outputs by node ID: {{llm_qa.output}}
|
||||||
# 将所有节点输出添加到顶层上下文
|
|
||||||
if node_outputs:
|
if node_outputs:
|
||||||
context.update(node_outputs)
|
context.update(node_outputs)
|
||||||
|
|
||||||
# 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
|
# # 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
|
||||||
if conv_vars:
|
# if conv_vars:
|
||||||
context.update(conv_vars)
|
# context.update(conv_vars)
|
||||||
|
#
|
||||||
context["nodes"] = node_outputs or {} # 旧语法兼容
|
# context["nodes"] = node_outputs or {} # 旧语法兼容
|
||||||
template = self.normalize_template(template)
|
template = self.normalize_template(template)
|
||||||
try:
|
try:
|
||||||
tmpl = self.env.from_string(template)
|
tmpl = self.env.from_string(template)
|
||||||
return tmpl.render(**context)
|
return tmpl.render(**context)
|
||||||
|
|
||||||
except TemplateSyntaxError as e:
|
except TemplateSyntaxError as e:
|
||||||
logger.error(f"模板语法错误: {template}, 错误: {e}")
|
logger.error(f"Template syntax error: {template}, error: {e}")
|
||||||
raise ValueError(f"模板语法错误: {e}")
|
raise ValueError(f"Template syntax error: {e}")
|
||||||
|
|
||||||
except UndefinedError as e:
|
except UndefinedError as e:
|
||||||
logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}")
|
logger.error(f"Undefined variable in template: {template}, error: {e}")
|
||||||
raise ValueError(f"未定义的变量: {e}")
|
raise ValueError(f"Undefined variable: {e}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"模板渲染异常: {template}, 错误: {e}")
|
logger.error(f"Template rendering error: {template}, error: {e}")
|
||||||
raise ValueError(f"模板渲染失败: {e}")
|
raise ValueError(f"Template rendering failed: {e}")
|
||||||
|
|
||||||
def validate(self, template: str) -> list[str]:
|
def validate(self, template: str) -> list[str]:
|
||||||
"""验证模板语法
|
"""Validate template syntax
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
template: 模板字符串
|
template: Template string
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
错误列表,如果为空则验证通过
|
List of errors (empty if valid)
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> renderer = TemplateRenderer()
|
>>> renderer = TemplateRenderer()
|
||||||
>>> renderer.validate("Hello {{var.name}}!")
|
>>> renderer.validate("Hello {{var.name}}!")
|
||||||
[]
|
[]
|
||||||
|
|
||||||
>>> renderer.validate("Hello {{var.name") # 缺少结束标记
|
>>> renderer.validate("Hello {{var.name") # Missing closing tag
|
||||||
['模板语法错误: ...']
|
['Template syntax error: ...']
|
||||||
"""
|
"""
|
||||||
errors = []
|
errors = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.env.from_string(template)
|
self.env.from_string(template)
|
||||||
except TemplateSyntaxError as e:
|
except TemplateSyntaxError as e:
|
||||||
errors.append(f"模板语法错误: {e}")
|
errors.append(f"Template syntax error: {e}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
errors.append(f"模板验证失败: {e}")
|
errors.append(f"Template validation failed: {e}")
|
||||||
|
|
||||||
return errors
|
return errors
|
||||||
|
|
||||||
|
|
||||||
# 全局渲染器实例(严格模式)
|
# Global renderer instances (strict / lenient)
|
||||||
_strict_renderer = TemplateRenderer(strict=True)
|
_strict_renderer = TemplateRenderer(strict=True)
|
||||||
_lenient_renderer = TemplateRenderer(strict=False)
|
_lenient_renderer = TemplateRenderer(strict=False)
|
||||||
|
|
||||||
|
|
||||||
def render_template(
|
def render_template(
|
||||||
template: str,
|
template: str,
|
||||||
conv_vars: dict[str, Any],
|
conv_vars: dict[str, Any] | LazyVariableDict,
|
||||||
node_outputs: dict[str, Any],
|
node_outputs: dict[str, Any] | dict[str, LazyVariableDict],
|
||||||
system_vars: dict[str, Any],
|
system_vars: dict[str, Any] | LazyVariableDict,
|
||||||
strict: bool = True
|
strict: bool = True
|
||||||
) -> str:
|
) -> str:
|
||||||
"""渲染模板(便捷函数)
|
"""Render template (convenience function)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
strict: 严格模式
|
strict: Whether to use strict mode
|
||||||
template: 模板字符串
|
template: Template string
|
||||||
conv_vars: 会话变量
|
conv_vars: Conversation variables
|
||||||
node_outputs: 节点输出
|
node_outputs: Node outputs
|
||||||
system_vars: 系统变量
|
system_vars: System variables
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
渲染后的字符串
|
Rendered string
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> render_template(
|
>>> render_template(
|
||||||
... "请分析: {{var.text}}",
|
... "Analyze: {{var.text}}",
|
||||||
... {"text": "这是一段文本"},
|
... {"text": "This is a text"},
|
||||||
... {},
|
... {},
|
||||||
... {}
|
... {}
|
||||||
... )
|
... )
|
||||||
'请分析: 这是一段文本'
|
'Analyze: This is a text'
|
||||||
"""
|
"""
|
||||||
renderer = _strict_renderer if strict else _lenient_renderer
|
renderer = _strict_renderer if strict else _lenient_renderer
|
||||||
return renderer.render(template, conv_vars, node_outputs, system_vars)
|
return renderer.render(template, conv_vars, node_outputs, system_vars)
|
||||||
|
|
||||||
|
|
||||||
def validate_template(template: str) -> list[str]:
|
def validate_template(template: str) -> list[str]:
|
||||||
"""验证模板语法(便捷函数)
|
"""Validate template syntax (convenience function)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
template: 模板字符串
|
template: Template string
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
错误列表
|
List of errors
|
||||||
"""
|
"""
|
||||||
return _strict_renderer.validate(template)
|
return _strict_renderer.validate(template)
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
|
from app.repositories.neo4j.create_indexes import create_all_indexes
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI, APIRouter
|
from fastapi import FastAPI, APIRouter
|
||||||
@@ -60,8 +61,10 @@ async def lifespan(app: FastAPI):
|
|||||||
logger.warning(f"加载预定义模型时出错: {str(e)}")
|
logger.warning(f"加载预定义模型时出错: {str(e)}")
|
||||||
else:
|
else:
|
||||||
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
|
logger.info("预定义模型加载已禁用 (LOAD_MODEL=false)")
|
||||||
|
await create_all_indexes()
|
||||||
logger.info("应用程序启动完成")
|
logger.info("应用程序启动完成")
|
||||||
|
|
||||||
|
|
||||||
yield
|
yield
|
||||||
# 应用关闭事件
|
# 应用关闭事件
|
||||||
logger.info("应用程序正在关闭")
|
logger.info("应用程序正在关闭")
|
||||||
|
|||||||
@@ -19,9 +19,12 @@ class User(Base):
|
|||||||
last_login_at = Column(DateTime, nullable=True) # 最后登录时间,可为空
|
last_login_at = Column(DateTime, nullable=True) # 最后登录时间,可为空
|
||||||
|
|
||||||
# SSO 外部关联字段
|
# SSO 外部关联字段
|
||||||
external_id = Column(String(100), nullable=True) # 外部用户ID
|
external_id = Column(String(100), nullable=True) # 外部用户 ID
|
||||||
external_source = Column(String(50), nullable=True) # 来源系统
|
external_source = Column(String(50), nullable=True) # 来源系统
|
||||||
|
|
||||||
|
# 用户联系方式
|
||||||
|
phone = Column(String(50), nullable=True) # 用户电话
|
||||||
|
|
||||||
# 用户语言偏好
|
# 用户语言偏好
|
||||||
preferred_language = Column(String(10), server_default=text("'zh'"), default='zh', nullable=False, index=True) # 用户偏好语言,默认中文
|
preferred_language = Column(String(10), server_default=text("'zh'"), default='zh', nullable=False, index=True) # 用户偏好语言,默认中文
|
||||||
|
|
||||||
|
|||||||
@@ -199,6 +199,96 @@ class ConversationRepository:
|
|||||||
)
|
)
|
||||||
return conversations, total
|
return conversations, total
|
||||||
|
|
||||||
|
def list_app_conversations(
|
||||||
|
self,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
workspace_id: uuid.UUID,
|
||||||
|
is_draft: Optional[bool] = None,
|
||||||
|
page: int = 1,
|
||||||
|
pagesize: int = 20
|
||||||
|
) -> tuple[list[Conversation], int]:
|
||||||
|
"""
|
||||||
|
查询应用日志会话列表(带分页和过滤)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用 ID
|
||||||
|
workspace_id: 工作空间 ID
|
||||||
|
is_draft: 是否草稿会话(None 表示不过滤)
|
||||||
|
page: 页码(从 1 开始)
|
||||||
|
pagesize: 每页数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[List[Conversation], int]: (会话列表,总数)
|
||||||
|
"""
|
||||||
|
stmt = select(Conversation).where(
|
||||||
|
Conversation.app_id == app_id,
|
||||||
|
Conversation.workspace_id == workspace_id,
|
||||||
|
Conversation.is_active.is_(True)
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_draft is not None:
|
||||||
|
stmt = stmt.where(Conversation.is_draft == is_draft)
|
||||||
|
|
||||||
|
# Calculate total number of records
|
||||||
|
total = int(self.db.execute(
|
||||||
|
select(func.count()).select_from(stmt.subquery())
|
||||||
|
).scalar_one())
|
||||||
|
|
||||||
|
# Apply pagination
|
||||||
|
stmt = stmt.order_by(desc(Conversation.updated_at))
|
||||||
|
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
|
||||||
|
|
||||||
|
conversations = list(self.db.scalars(stmt).all())
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Listed app conversations successfully",
|
||||||
|
extra={
|
||||||
|
"app_id": str(app_id),
|
||||||
|
"workspace_id": str(workspace_id),
|
||||||
|
"returned": len(conversations),
|
||||||
|
"total": total
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return conversations, total
|
||||||
|
|
||||||
|
def get_conversation_for_app_log(
|
||||||
|
self,
|
||||||
|
conversation_id: uuid.UUID,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
workspace_id: uuid.UUID
|
||||||
|
) -> Conversation:
|
||||||
|
"""
|
||||||
|
查询应用日志的会话详情
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation_id: 会话 ID
|
||||||
|
app_id: 应用 ID
|
||||||
|
workspace_id: 工作空间 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Conversation: 会话对象
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ResourceNotFoundException: 当会话不存在时
|
||||||
|
"""
|
||||||
|
logger.info(f"Fetching conversation for app log: {conversation_id}")
|
||||||
|
|
||||||
|
stmt = select(Conversation).where(
|
||||||
|
Conversation.id == conversation_id,
|
||||||
|
Conversation.app_id == app_id,
|
||||||
|
Conversation.workspace_id == workspace_id,
|
||||||
|
Conversation.is_active.is_(True)
|
||||||
|
)
|
||||||
|
|
||||||
|
conversation = self.db.scalars(stmt).first()
|
||||||
|
|
||||||
|
if not conversation:
|
||||||
|
logger.warning(f"Conversation not found: {conversation_id}")
|
||||||
|
raise ResourceNotFoundException("会话", str(conversation_id))
|
||||||
|
|
||||||
|
logger.info(f"Conversation fetched successfully: {conversation_id}")
|
||||||
|
return conversation
|
||||||
|
|
||||||
def soft_delete_conversation_by_conversation_id(
|
def soft_delete_conversation_by_conversation_id(
|
||||||
self,
|
self,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
@@ -290,6 +380,34 @@ class MessageRepository:
|
|||||||
self.db.add(message)
|
self.db.add(message)
|
||||||
return message
|
return message
|
||||||
|
|
||||||
|
def get_messages_by_conversation(
|
||||||
|
self,
|
||||||
|
conversation_id: uuid.UUID
|
||||||
|
) -> list[Message]:
|
||||||
|
"""
|
||||||
|
查询会话的所有消息(按时间正序)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conversation_id: 会话 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Message]: 消息列表
|
||||||
|
"""
|
||||||
|
stmt = select(Message).where(
|
||||||
|
Message.conversation_id == conversation_id
|
||||||
|
).order_by(Message.created_at)
|
||||||
|
|
||||||
|
messages = list(self.db.scalars(stmt).all())
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Fetched messages for conversation",
|
||||||
|
extra={
|
||||||
|
"conversation_id": str(conversation_id),
|
||||||
|
"message_count": len(messages)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return messages
|
||||||
|
|
||||||
def get_message_by_conversation_id(
|
def get_message_by_conversation_id(
|
||||||
self,
|
self,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
|
|||||||
@@ -132,6 +132,82 @@ class EndUserRepository:
|
|||||||
db_logger.error(f"获取或创建终端用户时出错: {str(e)}")
|
db_logger.error(f"获取或创建终端用户时出错: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def get_or_create_end_user_with_config(
|
||||||
|
self,
|
||||||
|
app_id: Optional[uuid.UUID],
|
||||||
|
workspace_id: uuid.UUID,
|
||||||
|
other_id: str,
|
||||||
|
memory_config_id: Optional[uuid.UUID] = None,
|
||||||
|
other_name: Optional[str] = None
|
||||||
|
) -> EndUser:
|
||||||
|
"""获取或创建终端用户,并在单次事务中关联记忆配置。
|
||||||
|
|
||||||
|
与 get_or_create_end_user 类似,但额外支持在创建/获取时
|
||||||
|
一并设置 memory_config_id,避免多次提交。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用ID(可为 None)
|
||||||
|
workspace_id: 工作空间ID
|
||||||
|
other_id: 第三方ID
|
||||||
|
memory_config_id: 记忆配置ID(可选,仅在用户尚无配置时设置)
|
||||||
|
other_name: 用户名称(用于创建 EndUserInfo)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
EndUser: 终端用户对象(已关联记忆配置)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
end_user = (
|
||||||
|
self.db.query(EndUser)
|
||||||
|
.filter(
|
||||||
|
EndUser.workspace_id == workspace_id,
|
||||||
|
EndUser.other_id == other_id
|
||||||
|
)
|
||||||
|
.order_by(EndUser.created_at.asc())
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if end_user:
|
||||||
|
db_logger.debug(f"找到现有终端用户: workspace_id={workspace_id}, other_id={other_id}")
|
||||||
|
if app_id is not None:
|
||||||
|
end_user.app_id = app_id
|
||||||
|
if memory_config_id and not end_user.memory_config_id:
|
||||||
|
end_user.memory_config_id = memory_config_id
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(end_user)
|
||||||
|
return end_user
|
||||||
|
|
||||||
|
# 创建新用户
|
||||||
|
end_user = EndUser(
|
||||||
|
app_id=app_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
other_id=other_id,
|
||||||
|
memory_config_id=memory_config_id,
|
||||||
|
)
|
||||||
|
self.db.add(end_user)
|
||||||
|
self.db.flush()
|
||||||
|
|
||||||
|
end_user_info = EndUserInfo(
|
||||||
|
end_user_id=end_user.id,
|
||||||
|
other_name=other_name or "",
|
||||||
|
aliases=[],
|
||||||
|
meta_data={}
|
||||||
|
)
|
||||||
|
self.db.add(end_user_info)
|
||||||
|
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(end_user)
|
||||||
|
|
||||||
|
db_logger.info(
|
||||||
|
f"创建新终端用户及其信息: (other_id: {other_id}) for workspace {workspace_id}, "
|
||||||
|
f"memory_config_id={memory_config_id}"
|
||||||
|
)
|
||||||
|
return end_user
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.db.rollback()
|
||||||
|
db_logger.error(f"获取或创建终端用户(含配置)时出错: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
def get_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]:
|
def get_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]:
|
||||||
"""根据ID获取终端用户(用于缓存操作)
|
"""根据ID获取终端用户(用于缓存操作)
|
||||||
|
|
||||||
@@ -515,6 +591,51 @@ class EndUserRepository:
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def batch_update_memory_config_id_by_app(
|
||||||
|
self,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
memory_config_id: uuid.UUID
|
||||||
|
) -> int:
|
||||||
|
"""批量更新应用下所有终端用户的 memory_config_id
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用ID
|
||||||
|
memory_config_id: 新的记忆配置ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 更新的终端用户数量
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: 数据库操作失败时抛出
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from sqlalchemy import update
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
update(EndUser)
|
||||||
|
.where(EndUser.app_id == app_id)
|
||||||
|
.values(memory_config_id=memory_config_id)
|
||||||
|
)
|
||||||
|
|
||||||
|
result = self.db.execute(stmt)
|
||||||
|
self.db.commit()
|
||||||
|
|
||||||
|
updated_count = result.rowcount
|
||||||
|
|
||||||
|
db_logger.info(
|
||||||
|
f"批量更新终端用户记忆配置: app_id={app_id}, "
|
||||||
|
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return updated_count
|
||||||
|
except Exception as e:
|
||||||
|
self.db.rollback()
|
||||||
|
db_logger.error(
|
||||||
|
f"批量更新终端用户记忆配置时出错: app_id={app_id}, "
|
||||||
|
f"memory_config_id={memory_config_id}, error={str(e)}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
def count_by_memory_config_id(
|
def count_by_memory_config_id(
|
||||||
self,
|
self,
|
||||||
memory_config_id: uuid.UUID
|
memory_config_id: uuid.UUID
|
||||||
|
|||||||
@@ -78,6 +78,15 @@ class MemoryConfigRepository:
|
|||||||
OPTIONAL MATCH (n) WHERE n.end_user_id = $end_user_id RETURN 'ALL' AS Label, COUNT(n) AS Count
|
OPTIONAL MATCH (n) WHERE n.end_user_id = $end_user_id RETURN 'ALL' AS Label, COUNT(n) AS Count
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# 批量查询多个用户的记忆数量(简化版本,只返回total)
|
||||||
|
SEARCH_FOR_ALL_BATCH = """
|
||||||
|
MATCH (n) WHERE n.end_user_id IN $end_user_ids
|
||||||
|
RETURN
|
||||||
|
n.end_user_id as user_id,
|
||||||
|
count(n) as total
|
||||||
|
ORDER BY user_id
|
||||||
|
"""
|
||||||
|
|
||||||
# Extracted entity details within group/app/user
|
# Extracted entity details within group/app/user
|
||||||
SEARCH_FOR_DETIALS = """
|
SEARCH_FOR_DETIALS = """
|
||||||
MATCH (n:ExtractedEntity)
|
MATCH (n:ExtractedEntity)
|
||||||
|
|||||||
@@ -1,62 +1,47 @@
|
|||||||
|
import asyncio
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
|
|
||||||
async def create_fulltext_indexes():
|
async def create_fulltext_indexes():
|
||||||
"""Create full-text indexes for keyword search with BM25 scoring."""
|
"""Create full-text indexes for keyword search with BM25 scoring."""
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
try:
|
try:
|
||||||
print("\n" + "=" * 70)
|
|
||||||
print("Creating Full-Text Indexes (for keyword search)")
|
|
||||||
print("=" * 70)
|
|
||||||
|
|
||||||
# 创建 Statements 索引
|
# 创建 Statements 索引
|
||||||
await connector.execute_query("""
|
await connector.execute_query("""
|
||||||
CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement]
|
CREATE FULLTEXT INDEX statementsFulltext IF NOT EXISTS FOR (s:Statement) ON EACH [s.statement]
|
||||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||||
""")
|
""")
|
||||||
print("✓ Created: statementsFulltext")
|
|
||||||
|
|
||||||
# # 创建 Dialogues 索引
|
# # 创建 Dialogues 索引
|
||||||
# await connector.execute_query("""
|
# await connector.execute_query("""
|
||||||
# CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content]
|
# CREATE FULLTEXT INDEX dialoguesFulltext IF NOT EXISTS FOR (d:Dialogue) ON EACH [d.content]
|
||||||
# OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
# OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||||
# """)
|
# """)
|
||||||
|
|
||||||
# 创建 Entities 索引
|
# 创建 Entities 索引
|
||||||
await connector.execute_query("""
|
await connector.execute_query("""
|
||||||
CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name]
|
CREATE FULLTEXT INDEX entitiesFulltext IF NOT EXISTS FOR (e:ExtractedEntity) ON EACH [e.name]
|
||||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||||
""")
|
""")
|
||||||
print("✓ Created: entitiesFulltext")
|
|
||||||
|
|
||||||
# 创建 Chunks 索引
|
# 创建 Chunks 索引
|
||||||
await connector.execute_query("""
|
await connector.execute_query("""
|
||||||
CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content]
|
CREATE FULLTEXT INDEX chunksFulltext IF NOT EXISTS FOR (c:Chunk) ON EACH [c.content]
|
||||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||||
""")
|
""")
|
||||||
print("✓ Created: chunksFulltext")
|
|
||||||
|
|
||||||
# 创建 MemorySummary 索引
|
# 创建 MemorySummary 索引
|
||||||
await connector.execute_query("""
|
await connector.execute_query("""
|
||||||
CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content]
|
CREATE FULLTEXT INDEX summariesFulltext IF NOT EXISTS FOR (m:MemorySummary) ON EACH [m.content]
|
||||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||||
""")
|
""")
|
||||||
print("✓ Created: summariesFulltext")
|
|
||||||
|
|
||||||
# 创建 Community 索引
|
# 创建 Community 索引
|
||||||
await connector.execute_query("""
|
await connector.execute_query("""
|
||||||
CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary]
|
CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary]
|
||||||
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||||
""")
|
""")
|
||||||
print("✓ Created: communitiesFulltext")
|
|
||||||
|
|
||||||
print("\nFull-text indexes created successfully with BM25 support.")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ Error creating full-text indexes: {e}")
|
|
||||||
finally:
|
finally:
|
||||||
await connector.close()
|
await connector.close()
|
||||||
|
|
||||||
|
|
||||||
async def create_vector_indexes():
|
async def create_vector_indexes():
|
||||||
"""Create vector indexes for fast embedding similarity search.
|
"""Create vector indexes for fast embedding similarity search.
|
||||||
|
|
||||||
@@ -65,12 +50,7 @@ async def create_vector_indexes():
|
|||||||
"""
|
"""
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
try:
|
try:
|
||||||
print("\n" + "=" * 70)
|
|
||||||
print("Creating Vector Indexes (for embedding search)")
|
|
||||||
print("=" * 70)
|
|
||||||
print("Note: Adjust vector.dimensions if using different embedding model")
|
|
||||||
print(" Current setting: 1024 dimensions (for bge-m3)")
|
|
||||||
print()
|
|
||||||
|
|
||||||
# Statement embedding index
|
# Statement embedding index
|
||||||
await connector.execute_query("""
|
await connector.execute_query("""
|
||||||
@@ -82,7 +62,7 @@ async def create_vector_indexes():
|
|||||||
`vector.similarity_function`: 'cosine'
|
`vector.similarity_function`: 'cosine'
|
||||||
}}
|
}}
|
||||||
""")
|
""")
|
||||||
print("✓ Created: statement_embedding_index")
|
|
||||||
|
|
||||||
# Chunk embedding index
|
# Chunk embedding index
|
||||||
await connector.execute_query("""
|
await connector.execute_query("""
|
||||||
@@ -94,7 +74,7 @@ async def create_vector_indexes():
|
|||||||
`vector.similarity_function`: 'cosine'
|
`vector.similarity_function`: 'cosine'
|
||||||
}}
|
}}
|
||||||
""")
|
""")
|
||||||
print("✓ Created: chunk_embedding_index")
|
|
||||||
|
|
||||||
# Entity name embedding index
|
# Entity name embedding index
|
||||||
await connector.execute_query("""
|
await connector.execute_query("""
|
||||||
@@ -106,7 +86,7 @@ async def create_vector_indexes():
|
|||||||
`vector.similarity_function`: 'cosine'
|
`vector.similarity_function`: 'cosine'
|
||||||
}}
|
}}
|
||||||
""")
|
""")
|
||||||
print("✓ Created: entity_embedding_index")
|
|
||||||
|
|
||||||
# Memory summary embedding index
|
# Memory summary embedding index
|
||||||
await connector.execute_query("""
|
await connector.execute_query("""
|
||||||
@@ -118,8 +98,7 @@ async def create_vector_indexes():
|
|||||||
`vector.similarity_function`: 'cosine'
|
`vector.similarity_function`: 'cosine'
|
||||||
}}
|
}}
|
||||||
""")
|
""")
|
||||||
print("✓ Created: summary_embedding_index")
|
|
||||||
|
|
||||||
# Community summary embedding index
|
# Community summary embedding index
|
||||||
await connector.execute_query("""
|
await connector.execute_query("""
|
||||||
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
|
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
|
||||||
@@ -129,8 +108,7 @@ async def create_vector_indexes():
|
|||||||
`vector.dimensions`: 1024,
|
`vector.dimensions`: 1024,
|
||||||
`vector.similarity_function`: 'cosine'
|
`vector.similarity_function`: 'cosine'
|
||||||
}}
|
}}
|
||||||
""")
|
""")
|
||||||
print("✓ Created: community_summary_embedding_index")
|
|
||||||
|
|
||||||
# Dialogue embedding index (optional)
|
# Dialogue embedding index (optional)
|
||||||
await connector.execute_query("""
|
await connector.execute_query("""
|
||||||
@@ -142,91 +120,15 @@ async def create_vector_indexes():
|
|||||||
`vector.similarity_function`: 'cosine'
|
`vector.similarity_function`: 'cosine'
|
||||||
}}
|
}}
|
||||||
""")
|
""")
|
||||||
print("✓ Created: dialogue_embedding_index")
|
|
||||||
|
|
||||||
# Community summary embedding index
|
|
||||||
await connector.execute_query("""
|
|
||||||
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
|
|
||||||
FOR (c:Community)
|
|
||||||
ON c.summary_embedding
|
|
||||||
OPTIONS {indexConfig: {
|
|
||||||
`vector.dimensions`: 1024,
|
|
||||||
`vector.similarity_function`: 'cosine'
|
|
||||||
}}
|
|
||||||
""")
|
|
||||||
print("✓ Created: community_summary_embedding_index")
|
|
||||||
|
|
||||||
print("\nVector indexes created successfully!")
|
|
||||||
print("\nExpected performance improvement:")
|
|
||||||
print(" Before: ~1.4s for embedding search")
|
|
||||||
print(" After: ~0.05-0.2s for embedding search (10-30x faster!)")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ Error creating vector indexes: {e}")
|
|
||||||
finally:
|
finally:
|
||||||
await connector.close()
|
await connector.close()
|
||||||
|
|
||||||
|
|
||||||
async def create_config_id_indexes():
|
|
||||||
"""Create indexes on config_id fields for improved query performance.
|
|
||||||
|
|
||||||
These indexes enable fast filtering of nodes by configuration ID,
|
|
||||||
which is essential for configuration isolation and multi-tenant scenarios.
|
|
||||||
"""
|
|
||||||
connector = Neo4jConnector()
|
|
||||||
try:
|
|
||||||
print("\n" + "=" * 70)
|
|
||||||
print("Creating Config ID Indexes")
|
|
||||||
print("=" * 70)
|
|
||||||
|
|
||||||
# Dialogue.config_id index
|
|
||||||
await connector.execute_query("""
|
|
||||||
CREATE INDEX dialogue_config_id_index IF NOT EXISTS
|
|
||||||
FOR (d:Dialogue) ON (d.config_id)
|
|
||||||
""")
|
|
||||||
print("✓ Created: dialogue_config_id_index")
|
|
||||||
|
|
||||||
# Statement.config_id index
|
|
||||||
await connector.execute_query("""
|
|
||||||
CREATE INDEX statement_config_id_index IF NOT EXISTS
|
|
||||||
FOR (s:Statement) ON (s.config_id)
|
|
||||||
""")
|
|
||||||
print("✓ Created: statement_config_id_index")
|
|
||||||
|
|
||||||
# ExtractedEntity.config_id index
|
|
||||||
await connector.execute_query("""
|
|
||||||
CREATE INDEX entity_config_id_index IF NOT EXISTS
|
|
||||||
FOR (e:ExtractedEntity) ON (e.config_id)
|
|
||||||
""")
|
|
||||||
print("✓ Created: entity_config_id_index")
|
|
||||||
|
|
||||||
# MemorySummary.config_id index
|
|
||||||
await connector.execute_query("""
|
|
||||||
CREATE INDEX summary_config_id_index IF NOT EXISTS
|
|
||||||
FOR (m:MemorySummary) ON (m.config_id)
|
|
||||||
""")
|
|
||||||
print("✓ Created: summary_config_id_index")
|
|
||||||
|
|
||||||
print("\nConfig ID indexes created successfully!")
|
|
||||||
print("These indexes enable fast filtering by configuration ID.")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ Error creating config_id indexes: {e}")
|
|
||||||
finally:
|
|
||||||
await connector.close()
|
|
||||||
|
|
||||||
|
|
||||||
async def create_unique_constraints():
|
async def create_unique_constraints():
|
||||||
"""Create uniqueness constraints for core node identifiers.
|
"""Create uniqueness constraints for core node identifiers.
|
||||||
|
|
||||||
Ensures concurrent MERGE operations remain safe and prevents duplicates.
|
Ensures concurrent MERGE operations remain safe and prevents duplicates.
|
||||||
"""
|
"""
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
try:
|
try:
|
||||||
print("\n" + "=" * 70)
|
|
||||||
print("Creating Unique Constraints")
|
|
||||||
print("=" * 70)
|
|
||||||
|
|
||||||
# Dialogue.id unique
|
# Dialogue.id unique
|
||||||
await connector.execute_query(
|
await connector.execute_query(
|
||||||
"""
|
"""
|
||||||
@@ -234,8 +136,7 @@ async def create_unique_constraints():
|
|||||||
FOR (d:Dialogue) REQUIRE d.id IS UNIQUE
|
FOR (d:Dialogue) REQUIRE d.id IS UNIQUE
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
print("✓ Created: dialog_id_unique")
|
|
||||||
|
|
||||||
# Statement.id unique
|
# Statement.id unique
|
||||||
await connector.execute_query(
|
await connector.execute_query(
|
||||||
"""
|
"""
|
||||||
@@ -243,8 +144,7 @@ async def create_unique_constraints():
|
|||||||
FOR (s:Statement) REQUIRE s.id IS UNIQUE
|
FOR (s:Statement) REQUIRE s.id IS UNIQUE
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
print("✓ Created: statement_id_unique")
|
|
||||||
|
|
||||||
# Chunk.id unique
|
# Chunk.id unique
|
||||||
await connector.execute_query(
|
await connector.execute_query(
|
||||||
"""
|
"""
|
||||||
@@ -252,112 +152,13 @@ async def create_unique_constraints():
|
|||||||
FOR (c:Chunk) REQUIRE c.id IS UNIQUE
|
FOR (c:Chunk) REQUIRE c.id IS UNIQUE
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
print("✓ Created: chunk_id_unique")
|
|
||||||
|
|
||||||
print("\nUnique constraints ensured for Dialogue, Statement, and Chunk.")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"✗ Error creating unique constraints: {e}")
|
|
||||||
finally:
|
finally:
|
||||||
await connector.close()
|
await connector.close()
|
||||||
|
|
||||||
|
|
||||||
async def create_all_indexes():
|
async def create_all_indexes():
|
||||||
"""Create all indexes and constraints in one go."""
|
"""Create all indexes and constraints in one go."""
|
||||||
print("\n" + "=" * 70)
|
|
||||||
print("Neo4j Index & Constraint Setup")
|
|
||||||
print("=" * 70)
|
|
||||||
print("This will create:")
|
|
||||||
print(" 1. Full-text indexes (for keyword/BM25 search)")
|
|
||||||
print(" 2. Vector indexes (for embedding similarity search)")
|
|
||||||
print(" 3. Config ID indexes (for configuration isolation)")
|
|
||||||
print(" 4. Unique constraints (for data integrity)")
|
|
||||||
print("=" * 70)
|
|
||||||
|
|
||||||
await create_fulltext_indexes()
|
await create_fulltext_indexes()
|
||||||
await create_vector_indexes()
|
await create_vector_indexes()
|
||||||
await create_config_id_indexes()
|
|
||||||
await create_unique_constraints()
|
await create_unique_constraints()
|
||||||
|
|
||||||
print("\n" + "=" * 70)
|
|
||||||
print("✓ All indexes and constraints created successfully!")
|
print("✓ All indexes and constraints created successfully!")
|
||||||
print("=" * 70)
|
|
||||||
print("\nTo verify, run in Neo4j Browser:")
|
|
||||||
print(" SHOW INDEXES")
|
|
||||||
print(" SHOW CONSTRAINTS")
|
|
||||||
print()
|
|
||||||
|
|
||||||
|
|
||||||
async def check_indexes():
|
|
||||||
"""Check what indexes currently exist."""
|
|
||||||
connector = Neo4jConnector()
|
|
||||||
|
|
||||||
try:
|
|
||||||
print("\n" + "=" * 70)
|
|
||||||
print("Checking Existing Indexes")
|
|
||||||
print("=" * 70)
|
|
||||||
|
|
||||||
query = "SHOW INDEXES"
|
|
||||||
result = await connector.execute_query(query)
|
|
||||||
|
|
||||||
fulltext_indexes = [idx for idx in result if idx.get('type') == 'FULLTEXT']
|
|
||||||
vector_indexes = [idx for idx in result if idx.get('type') == 'VECTOR']
|
|
||||||
range_indexes = [idx for idx in result if idx.get('type') == 'RANGE']
|
|
||||||
|
|
||||||
print(f"\nFull-text indexes: {len(fulltext_indexes)}")
|
|
||||||
for idx in fulltext_indexes:
|
|
||||||
print(f" ✓ {idx.get('name')}")
|
|
||||||
|
|
||||||
print(f"\nVector indexes: {len(vector_indexes)}")
|
|
||||||
for idx in vector_indexes:
|
|
||||||
print(f" ✓ {idx.get('name')}")
|
|
||||||
|
|
||||||
print(f"\nRange indexes (including config_id): {len(range_indexes)}")
|
|
||||||
for idx in range_indexes:
|
|
||||||
print(f" ✓ {idx.get('name')}")
|
|
||||||
|
|
||||||
if not vector_indexes:
|
|
||||||
print("\n⚠️ WARNING: No vector indexes found!")
|
|
||||||
print(" Embedding search will be VERY SLOW (~1.4s)")
|
|
||||||
print(" Run: python create_indexes.py")
|
|
||||||
|
|
||||||
# Check for config_id indexes
|
|
||||||
config_id_indexes = [idx for idx in range_indexes if 'config_id' in idx.get('name', '')]
|
|
||||||
if len(config_id_indexes) < 4:
|
|
||||||
print("\n⚠️ WARNING: Not all config_id indexes found!")
|
|
||||||
print(f" Expected 4, found {len(config_id_indexes)}")
|
|
||||||
print(" Run: python create_indexes.py config_id")
|
|
||||||
|
|
||||||
print("=" * 70)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
await connector.close()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
import sys
|
|
||||||
|
|
||||||
if len(sys.argv) > 1:
|
|
||||||
command = sys.argv[1]
|
|
||||||
if command == "check":
|
|
||||||
asyncio.run(check_indexes())
|
|
||||||
elif command == "fulltext":
|
|
||||||
asyncio.run(create_fulltext_indexes())
|
|
||||||
elif command == "vector":
|
|
||||||
asyncio.run(create_vector_indexes())
|
|
||||||
elif command == "config_id":
|
|
||||||
asyncio.run(create_config_id_indexes())
|
|
||||||
elif command == "constraints":
|
|
||||||
asyncio.run(create_unique_constraints())
|
|
||||||
else:
|
|
||||||
print(f"Unknown command: {command}")
|
|
||||||
print("\nUsage:")
|
|
||||||
print(" python create_indexes.py # Create all indexes")
|
|
||||||
print(" python create_indexes.py check # Check existing indexes")
|
|
||||||
print(" python create_indexes.py fulltext # Create only full-text indexes")
|
|
||||||
print(" python create_indexes.py vector # Create only vector indexes")
|
|
||||||
print(" python create_indexes.py config_id # Create only config_id indexes")
|
|
||||||
print(" python create_indexes.py constraints # Create only constraints")
|
|
||||||
else:
|
|
||||||
asyncio.run(create_all_indexes())
|
|
||||||
|
|
||||||
|
|||||||
@@ -340,17 +340,22 @@ SEARCH_ENTITIES_BY_NAME_OR_ALIAS = """
|
|||||||
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
|
CALL db.index.fulltext.queryNodes("entitiesFulltext", $q) YIELD node AS e, score
|
||||||
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
||||||
WITH e, score
|
WITH e, score
|
||||||
UNION
|
WITH collect({entity: e, score: score}) AS fulltextResults
|
||||||
MATCH (e:ExtractedEntity)
|
|
||||||
WHERE ($end_user_id IS NULL OR e.end_user_id = $end_user_id)
|
OPTIONAL MATCH (ae:ExtractedEntity)
|
||||||
AND e.aliases IS NOT NULL
|
WHERE ($end_user_id IS NULL OR ae.end_user_id = $end_user_id)
|
||||||
AND ANY(alias IN e.aliases WHERE toLower(alias) CONTAINS toLower($q))
|
AND ae.aliases IS NOT NULL
|
||||||
WITH e,
|
AND ANY(alias IN ae.aliases WHERE toLower(alias) CONTAINS toLower($q))
|
||||||
|
WITH fulltextResults, collect(ae) AS aliasEntities
|
||||||
|
|
||||||
|
UNWIND (fulltextResults + [x IN aliasEntities | {entity: x, score:
|
||||||
CASE
|
CASE
|
||||||
WHEN ANY(alias IN e.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0
|
WHEN ANY(alias IN x.aliases WHERE toLower(alias) = toLower($q)) THEN 1.0
|
||||||
WHEN ANY(alias IN e.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9
|
WHEN ANY(alias IN x.aliases WHERE toLower(alias) STARTS WITH toLower($q)) THEN 0.9
|
||||||
ELSE 0.8
|
ELSE 0.8
|
||||||
END AS score
|
END
|
||||||
|
}]) AS row
|
||||||
|
WITH row.entity AS e, row.score AS score
|
||||||
WITH DISTINCT e, MAX(score) AS score
|
WITH DISTINCT e, MAX(score) AS score
|
||||||
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
|
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
|
||||||
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
OPTIONAL MATCH (c:Chunk)-[:CONTAINS]->(s)
|
||||||
|
|||||||
@@ -158,22 +158,26 @@ class UserRepository:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def get_users_by_tenant(
|
def get_users_by_tenant(
|
||||||
self,
|
self,
|
||||||
tenant_id: uuid.UUID,
|
tenant_id: uuid.UUID,
|
||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
is_active: Optional[bool] = None,
|
is_active: Optional[bool] = None,
|
||||||
|
is_superuser: Optional[bool] = None,
|
||||||
search: Optional[str] = None
|
search: Optional[str] = None
|
||||||
) -> List[User]:
|
) -> List[User]:
|
||||||
"""获取租户下的用户列表"""
|
"""获取租户下的用户列表"""
|
||||||
db_logger.debug(f"查询租户用户: tenant_id={tenant_id}")
|
db_logger.debug(f"查询租户用户: tenant_id={tenant_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
query = self.db.query(User).options(joinedload(User.tenant)).filter(User.tenant_id == tenant_id)
|
query = self.db.query(User).options(joinedload(User.tenant)).filter(User.tenant_id == tenant_id)
|
||||||
|
|
||||||
if is_active is not None:
|
if is_active is not None:
|
||||||
query = query.filter(User.is_active == is_active)
|
query = query.filter(User.is_active == is_active)
|
||||||
|
|
||||||
|
if is_superuser is not None:
|
||||||
|
query = query.filter(User.is_superuser == is_superuser)
|
||||||
|
|
||||||
if search:
|
if search:
|
||||||
query = query.filter(
|
query = query.filter(
|
||||||
or_(
|
or_(
|
||||||
@@ -181,7 +185,7 @@ class UserRepository:
|
|||||||
User.email.ilike(f"%{search}%")
|
User.email.ilike(f"%{search}%")
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
users = query.offset(skip).limit(limit).all()
|
users = query.offset(skip).limit(limit).all()
|
||||||
db_logger.debug(f"租户用户查询成功: tenant_id={tenant_id}, count={len(users)}")
|
db_logger.debug(f"租户用户查询成功: tenant_id={tenant_id}, count={len(users)}")
|
||||||
return users
|
return users
|
||||||
@@ -190,18 +194,22 @@ class UserRepository:
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
def count_users_by_tenant(
|
def count_users_by_tenant(
|
||||||
self,
|
self,
|
||||||
tenant_id: uuid.UUID,
|
tenant_id: uuid.UUID,
|
||||||
is_active: Optional[bool] = None,
|
is_active: Optional[bool] = None,
|
||||||
|
is_superuser: Optional[bool] = None,
|
||||||
search: Optional[str] = None
|
search: Optional[str] = None
|
||||||
) -> int:
|
) -> int:
|
||||||
"""统计租户下的用户数量"""
|
"""统计租户下的用户数量"""
|
||||||
try:
|
try:
|
||||||
query = self.db.query(func.count(User.id)).filter(User.tenant_id == tenant_id)
|
query = self.db.query(func.count(User.id)).filter(User.tenant_id == tenant_id)
|
||||||
|
|
||||||
if is_active is not None:
|
if is_active is not None:
|
||||||
query = query.filter(User.is_active == is_active)
|
query = query.filter(User.is_active == is_active)
|
||||||
|
|
||||||
|
if is_superuser is not None:
|
||||||
|
query = query.filter(User.is_superuser == is_superuser)
|
||||||
|
|
||||||
if search:
|
if search:
|
||||||
query = query.filter(
|
query = query.filter(
|
||||||
or_(
|
or_(
|
||||||
@@ -209,7 +217,7 @@ class UserRepository:
|
|||||||
User.email.ilike(f"%{search}%")
|
User.email.ilike(f"%{search}%")
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return query.scalar()
|
return query.scalar()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
db_logger.error(f"统计租户用户失败: tenant_id={tenant_id} - {str(e)}")
|
db_logger.error(f"统计租户用户失败: tenant_id={tenant_id} - {str(e)}")
|
||||||
|
|||||||
@@ -3,9 +3,9 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any, Annotated
|
from typing import Any, Annotated, Literal
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from sqlalchemy import desc
|
from sqlalchemy import desc, select
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
|
|
||||||
from app.models.workflow_model import (
|
from app.models.workflow_model import (
|
||||||
@@ -128,29 +128,36 @@ class WorkflowExecutionRepository:
|
|||||||
Returns:
|
Returns:
|
||||||
执行记录列表
|
执行记录列表
|
||||||
"""
|
"""
|
||||||
return self.db.query(WorkflowExecution).filter(
|
stmt = select(WorkflowExecution).filter(
|
||||||
WorkflowExecution.app_id == app_id
|
WorkflowExecution.app_id == app_id
|
||||||
).order_by(
|
).order_by(
|
||||||
desc(WorkflowExecution.started_at)
|
desc(WorkflowExecution.started_at)
|
||||||
).limit(limit).offset(offset).all()
|
).limit(limit).offset(offset)
|
||||||
|
return list(self.db.execute(stmt).scalars())
|
||||||
|
|
||||||
def get_by_conversation_id(
|
def get_by_conversation_id(
|
||||||
self,
|
self,
|
||||||
conversation_id: uuid.UUID
|
conversation_id: uuid.UUID,
|
||||||
|
status: Literal["running", "completed", "failed"] = None,
|
||||||
|
limit_count: int = 50
|
||||||
) -> list[WorkflowExecution]:
|
) -> list[WorkflowExecution]:
|
||||||
"""根据会话 ID 获取执行记录列表
|
"""根据会话 ID 获取执行记录列表
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
limit_count:
|
||||||
conversation_id: 会话 ID
|
conversation_id: 会话 ID
|
||||||
|
status: 状态(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
执行记录列表
|
执行记录列表
|
||||||
"""
|
"""
|
||||||
return self.db.query(WorkflowExecution).filter(
|
stmt = select(WorkflowExecution).filter(
|
||||||
WorkflowExecution.conversation_id == conversation_id
|
WorkflowExecution.conversation_id == conversation_id
|
||||||
).order_by(
|
)
|
||||||
desc(WorkflowExecution.started_at)
|
if status:
|
||||||
).all()
|
stmt = stmt.filter(WorkflowExecution.status == status)
|
||||||
|
stmt = stmt.order_by(desc(WorkflowExecution.started_at)).limit(limit_count)
|
||||||
|
return list(self.db.execute(stmt).scalars())
|
||||||
|
|
||||||
def count_by_app_id(self, app_id: uuid.UUID) -> int:
|
def count_by_app_id(self, app_id: uuid.UUID) -> int:
|
||||||
"""统计应用的执行次数
|
"""统计应用的执行次数
|
||||||
@@ -199,11 +206,12 @@ class WorkflowNodeExecutionRepository:
|
|||||||
Returns:
|
Returns:
|
||||||
节点执行记录列表(按执行顺序排序)
|
节点执行记录列表(按执行顺序排序)
|
||||||
"""
|
"""
|
||||||
return self.db.query(WorkflowNodeExecution).filter(
|
stmt = select(WorkflowNodeExecution).filter(
|
||||||
WorkflowNodeExecution.execution_id == execution_id
|
WorkflowNodeExecution.execution_id == execution_id
|
||||||
).order_by(
|
).order_by(
|
||||||
WorkflowNodeExecution.execution_order
|
WorkflowNodeExecution.execution_order
|
||||||
).all()
|
)
|
||||||
|
return list(self.db.execute(stmt).scalars())
|
||||||
|
|
||||||
def get_by_node_id(
|
def get_by_node_id(
|
||||||
self,
|
self,
|
||||||
@@ -219,12 +227,13 @@ class WorkflowNodeExecutionRepository:
|
|||||||
Returns:
|
Returns:
|
||||||
节点执行记录列表
|
节点执行记录列表
|
||||||
"""
|
"""
|
||||||
return self.db.query(WorkflowNodeExecution).filter(
|
stmt = select(WorkflowNodeExecution).filter(
|
||||||
WorkflowNodeExecution.execution_id == execution_id,
|
WorkflowNodeExecution.execution_id == execution_id,
|
||||||
WorkflowNodeExecution.node_id == node_id
|
WorkflowNodeExecution.node_id == node_id
|
||||||
).order_by(
|
).order_by(
|
||||||
WorkflowNodeExecution.retry_count
|
WorkflowNodeExecution.retry_count
|
||||||
).all()
|
)
|
||||||
|
return list(self.db.execute(stmt).scalars())
|
||||||
|
|
||||||
|
|
||||||
# ==================== 依赖注入函数 ====================
|
# ==================== 依赖注入函数 ====================
|
||||||
|
|||||||
@@ -276,7 +276,7 @@ class AgentConfigCreate(BaseModel):
|
|||||||
|
|
||||||
# 记忆配置
|
# 记忆配置
|
||||||
memory: MemoryConfig = Field(
|
memory: MemoryConfig = Field(
|
||||||
default_factory=lambda: MemoryConfig(enabled=True),
|
default_factory=lambda: MemoryConfig(enabled=False),
|
||||||
description="对话历史记忆配置"
|
description="对话历史记忆配置"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ class Write_UserInput(BaseModel):
|
|||||||
end_user_id: str
|
end_user_id: str
|
||||||
config_id: Optional[str] = None
|
config_id: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class AgentMemory_Long_Term(ABC):
|
class AgentMemory_Long_Term(ABC):
|
||||||
"""长期记忆配置常量"""
|
"""长期记忆配置常量"""
|
||||||
STORAGE_NEO4J = "neo4j"
|
STORAGE_NEO4J = "neo4j"
|
||||||
@@ -25,8 +26,9 @@ class AgentMemory_Long_Term(ABC):
|
|||||||
STRATEGY_CHUNK = "chunk"
|
STRATEGY_CHUNK = "chunk"
|
||||||
STRATEGY_TIME = "time"
|
STRATEGY_TIME = "time"
|
||||||
DEFAULT_SCOPE = 6
|
DEFAULT_SCOPE = 6
|
||||||
TIME_SCOPE=5
|
TIME_SCOPE = 5
|
||||||
class AgentMemoryDataset(ABC):
|
|
||||||
PRONOUN=['我','本人','在下','自己','咱','鄙人','吴','余']
|
|
||||||
NAME='用户'
|
|
||||||
|
|
||||||
|
|
||||||
|
class AgentMemoryDataset(ABC):
|
||||||
|
PRONOUN = ['我', '本人', '在下', '自己', '咱', '鄙人', '吴', '余']
|
||||||
|
NAME = '用户'
|
||||||
|
|||||||
@@ -138,21 +138,13 @@ class CreateEndUserRequest(BaseModel):
|
|||||||
"""Request schema for creating an end user.
|
"""Request schema for creating an end user.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
workspace_id: Workspace ID (required)
|
|
||||||
other_id: External user identifier (required)
|
other_id: External user identifier (required)
|
||||||
other_name: Display name for the end user
|
other_name: Display name for the end user
|
||||||
|
memory_config_id: Optional memory config ID. If not provided, uses workspace default.
|
||||||
"""
|
"""
|
||||||
workspace_id: str = Field(..., description="Workspace ID (required)")
|
|
||||||
other_id: str = Field(..., description="External user identifier (required)")
|
other_id: str = Field(..., description="External user identifier (required)")
|
||||||
other_name: Optional[str] = Field("", description="Display name")
|
other_name: Optional[str] = Field("", description="Display name")
|
||||||
|
memory_config_id: Optional[str] = Field(None, description="Memory config ID. Falls back to workspace default if not provided.")
|
||||||
@field_validator("workspace_id")
|
|
||||||
@classmethod
|
|
||||||
def validate_workspace_id(cls, v: str) -> str:
|
|
||||||
"""Validate that workspace_id is not empty."""
|
|
||||||
if not v or not v.strip():
|
|
||||||
raise ValueError("workspace_id is required and cannot be empty")
|
|
||||||
return v.strip()
|
|
||||||
|
|
||||||
@field_validator("other_id")
|
@field_validator("other_id")
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -171,11 +163,13 @@ class CreateEndUserResponse(BaseModel):
|
|||||||
other_id: External user identifier
|
other_id: External user identifier
|
||||||
other_name: Display name
|
other_name: Display name
|
||||||
workspace_id: Workspace the user belongs to
|
workspace_id: Workspace the user belongs to
|
||||||
|
memory_config_id: Connected memory config ID
|
||||||
"""
|
"""
|
||||||
id: str = Field(..., description="End user UUID")
|
id: str = Field(..., description="End user UUID")
|
||||||
other_id: str = Field(..., description="External user identifier")
|
other_id: str = Field(..., description="External user identifier")
|
||||||
other_name: str = Field("", description="Display name")
|
other_name: str = Field("", description="Display name")
|
||||||
workspace_id: str = Field(..., description="Workspace ID")
|
workspace_id: str = Field(..., description="Workspace ID")
|
||||||
|
memory_config_id: Optional[str] = Field(None, description="Connected memory config ID")
|
||||||
|
|
||||||
|
|
||||||
class MemoryConfigItem(BaseModel):
|
class MemoryConfigItem(BaseModel):
|
||||||
|
|||||||
@@ -478,6 +478,22 @@ class PendingForgettingNode(BaseModel):
|
|||||||
last_access_time: int = Field(..., description="最后访问时间(Unix时间戳,秒)")
|
last_access_time: int = Field(..., description="最后访问时间(Unix时间戳,秒)")
|
||||||
|
|
||||||
|
|
||||||
|
class PageInfo(BaseModel):
|
||||||
|
"""分页信息模型"""
|
||||||
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
|
page: int = Field(..., description="当前页码(从1开始)")
|
||||||
|
pagesize: int = Field(..., description="每页数量")
|
||||||
|
total: int = Field(..., description="总记录数")
|
||||||
|
hasnext: bool = Field(..., description="是否有下一页")
|
||||||
|
|
||||||
|
|
||||||
|
class PendingNodesResponse(BaseModel):
|
||||||
|
"""待遗忘节点列表响应模型(独立分页接口)"""
|
||||||
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
|
items: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表")
|
||||||
|
page: PageInfo = Field(..., description="分页信息")
|
||||||
|
|
||||||
|
|
||||||
class ForgettingStatsResponse(BaseModel):
|
class ForgettingStatsResponse(BaseModel):
|
||||||
"""遗忘引擎统计信息响应模型"""
|
"""遗忘引擎统计信息响应模型"""
|
||||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||||
@@ -485,7 +501,6 @@ class ForgettingStatsResponse(BaseModel):
|
|||||||
node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
|
node_distribution: Dict[str, int] = Field(..., description="节点类型分布")
|
||||||
recent_trends: List[ForgettingCycleHistoryPoint] = Field(...,
|
recent_trends: List[ForgettingCycleHistoryPoint] = Field(...,
|
||||||
description="最近7个日期的遗忘趋势数据(每天取最后一次执行)")
|
description="最近7个日期的遗忘趋势数据(每天取最后一次执行)")
|
||||||
pending_nodes: List[PendingForgettingNode] = Field(..., description="待遗忘节点列表(前20个满足遗忘条件的节点)")
|
|
||||||
timestamp: int = Field(..., description="统计时间(时间戳)")
|
timestamp: int = Field(..., description="统计时间(时间戳)")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
from dataclasses import field
|
from dataclasses import field
|
||||||
from pydantic import BaseModel, EmailStr, Field, field_validator, validator, ConfigDict
|
from pydantic import BaseModel, EmailStr, Field, field_validator, validator, ConfigDict
|
||||||
from typing import Optional
|
from typing import Optional, List
|
||||||
import datetime
|
import datetime
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
@@ -20,6 +20,7 @@ class UserCreate(UserBase):
|
|||||||
class UserUpdate(BaseModel):
|
class UserUpdate(BaseModel):
|
||||||
username: Optional[str] = None
|
username: Optional[str] = None
|
||||||
email: Optional[EmailStr] = None
|
email: Optional[EmailStr] = None
|
||||||
|
phone: Optional[str] = None
|
||||||
is_active: Optional[bool] = None
|
is_active: Optional[bool] = None
|
||||||
is_superuser: Optional[bool] = None
|
is_superuser: Optional[bool] = None
|
||||||
|
|
||||||
@@ -85,6 +86,8 @@ class User(UserBase):
|
|||||||
current_workspace_name: Optional[str] = None
|
current_workspace_name: Optional[str] = None
|
||||||
role: Optional[WorkspaceRole] = None
|
role: Optional[WorkspaceRole] = None
|
||||||
preferred_language: Optional[str] = "zh" # 用户语言偏好
|
preferred_language: Optional[str] = "zh" # 用户语言偏好
|
||||||
|
phone: Optional[str] = None # 用户电话
|
||||||
|
permissions: Optional[List[str]] = None # 用户权限列表,由 external_source 的 permissions 控制
|
||||||
|
|
||||||
# 将 datetime 转换为毫秒时间戳
|
# 将 datetime 转换为毫秒时间戳
|
||||||
@validator("created_at", pre=True)
|
@validator("created_at", pre=True)
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.core.agent.langchain_agent import LangChainAgent
|
from app.core.agent.langchain_agent import LangChainAgent
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.models import MultiAgentConfig, AgentConfig, ModelType
|
from app.models import MultiAgentConfig, AgentConfig, ModelType
|
||||||
from app.models import WorkflowConfig
|
from app.models import WorkflowConfig
|
||||||
@@ -20,11 +21,11 @@ from app.schemas.model_schema import ModelInfo
|
|||||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.draft_run_service import AgentRunService
|
from app.services.draft_run_service import AgentRunService
|
||||||
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
from app.services.model_service import ModelApiKeyService
|
from app.services.model_service import ModelApiKeyService
|
||||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
from app.services.workflow_service import WorkflowService
|
from app.services.workflow_service import WorkflowService
|
||||||
from app.schemas import FileType
|
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
@@ -43,18 +44,17 @@ class AppChatService:
|
|||||||
message: str,
|
message: str,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
config: AgentConfig,
|
config: AgentConfig,
|
||||||
user_id: Optional[str] = None,
|
files: list[FileInput],
|
||||||
|
user_id: str,
|
||||||
variables: Optional[Dict[str, Any]] = None,
|
variables: Optional[Dict[str, Any]] = None,
|
||||||
web_search: bool = False,
|
web_search: bool = False,
|
||||||
memory: bool = True,
|
memory: bool = True,
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
workspace_id: Optional[str] = None,
|
workspace_id: Optional[str] = None
|
||||||
files: Optional[List[FileInput]] = None
|
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""聊天(非流式)"""
|
"""聊天(非流式)"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
config_id = None
|
|
||||||
|
|
||||||
# 应用 features 配置
|
# 应用 features 配置
|
||||||
features_config: dict = config.features or {}
|
features_config: dict = config.features or {}
|
||||||
@@ -93,7 +93,8 @@ class AppChatService:
|
|||||||
tools.extend(skill_tools)
|
tools.extend(skill_tools)
|
||||||
if skill_prompts:
|
if skill_prompts:
|
||||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
|
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval,
|
||||||
|
user_id)
|
||||||
tools.extend(kb_tools)
|
tools.extend(kb_tools)
|
||||||
memory_flag = False
|
memory_flag = False
|
||||||
if memory:
|
if memory:
|
||||||
@@ -140,13 +141,13 @@ class AppChatService:
|
|||||||
# 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库
|
# 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库
|
||||||
is_new_conversation = len(history) == 0
|
is_new_conversation = len(history) == 0
|
||||||
if is_new_conversation:
|
if is_new_conversation:
|
||||||
opening = self.agent_service._get_opening_statement(features_config, True, variables)
|
opening, suggested_questions = self.agent_service._get_opening_statement(features_config, True, variables)
|
||||||
if opening:
|
if opening:
|
||||||
self.conversation_service.add_message(
|
self.conversation_service.add_message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=opening,
|
content=opening,
|
||||||
meta_data={}
|
meta_data={"suggested_questions": suggested_questions}
|
||||||
)
|
)
|
||||||
# 重新加载历史(包含刚写入的开场白)
|
# 重新加载历史(包含刚写入的开场白)
|
||||||
history = await self.conversation_service.get_conversation_history(
|
history = await self.conversation_service.get_conversation_history(
|
||||||
@@ -168,11 +169,6 @@ class AppChatService:
|
|||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
context=None,
|
context=None,
|
||||||
end_user_id=user_id,
|
|
||||||
storage_type=storage_type,
|
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
|
||||||
config_id=config_id,
|
|
||||||
memory_flag=memory_flag,
|
|
||||||
files=processed_files # 传递处理后的文件
|
files=processed_files # 传递处理后的文件
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -229,6 +225,21 @@ class AppChatService:
|
|||||||
# 保存消息
|
# 保存消息
|
||||||
if audio_url:
|
if audio_url:
|
||||||
assistant_meta["audio_url"] = audio_url
|
assistant_meta["audio_url"] = audio_url
|
||||||
|
if memory_flag:
|
||||||
|
connected_config = get_end_user_connected_config(user_id, self.db)
|
||||||
|
memory_config_id: str = connected_config.get("memory_config_id")
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": message, "files": [file.model_dump() for file in files]},
|
||||||
|
{"role": "assistant", "content": result["content"]}
|
||||||
|
]
|
||||||
|
if memory_config_id:
|
||||||
|
await write_long_term(
|
||||||
|
storage_type,
|
||||||
|
user_id,
|
||||||
|
messages,
|
||||||
|
user_rag_memory_id,
|
||||||
|
memory_config_id
|
||||||
|
)
|
||||||
self.conversation_service.add_message(
|
self.conversation_service.add_message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
role="user",
|
role="user",
|
||||||
@@ -264,20 +275,19 @@ class AppChatService:
|
|||||||
message: str,
|
message: str,
|
||||||
conversation_id: uuid.UUID,
|
conversation_id: uuid.UUID,
|
||||||
config: AgentConfig,
|
config: AgentConfig,
|
||||||
|
files: list[FileInput],
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
variables: Optional[Dict[str, Any]] = None,
|
variables: Optional[Dict[str, Any]] = None,
|
||||||
web_search: bool = False,
|
web_search: bool = False,
|
||||||
memory: bool = True,
|
memory: bool = True,
|
||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
workspace_id: Optional[str] = None,
|
workspace_id: Optional[str] = None
|
||||||
files: Optional[List[FileInput]] = None
|
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""聊天(流式)"""
|
"""聊天(流式)"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
config_id = None
|
|
||||||
message_id = uuid.uuid4()
|
message_id = uuid.uuid4()
|
||||||
|
|
||||||
# 应用 features 配置
|
# 应用 features 配置
|
||||||
@@ -319,7 +329,8 @@ class AppChatService:
|
|||||||
tools.extend(skill_tools)
|
tools.extend(skill_tools)
|
||||||
if skill_prompts:
|
if skill_prompts:
|
||||||
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
system_prompt = f"{system_prompt}\n\n{skill_prompts}"
|
||||||
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)
|
kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(
|
||||||
|
config.knowledge_retrieval, user_id)
|
||||||
tools.extend(kb_tools)
|
tools.extend(kb_tools)
|
||||||
# 添加长期记忆工具
|
# 添加长期记忆工具
|
||||||
memory_flag = False
|
memory_flag = False
|
||||||
@@ -367,13 +378,13 @@ class AppChatService:
|
|||||||
# 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库
|
# 如果是新会话且有开场白,作为第一条 assistant 消息写入数据库
|
||||||
is_new_conversation = len(history) == 0
|
is_new_conversation = len(history) == 0
|
||||||
if is_new_conversation:
|
if is_new_conversation:
|
||||||
opening = self.agent_service._get_opening_statement(features_config, True, variables)
|
opening, suggested_questions = self.agent_service._get_opening_statement(features_config, True, variables)
|
||||||
if opening:
|
if opening:
|
||||||
self.conversation_service.add_message(
|
self.conversation_service.add_message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=opening,
|
content=opening,
|
||||||
meta_data={}
|
meta_data={"suggested_questions": suggested_questions}
|
||||||
)
|
)
|
||||||
# 重新加载历史(包含刚写入的开场白)
|
# 重新加载历史(包含刚写入的开场白)
|
||||||
history = await self.conversation_service.get_conversation_history(
|
history = await self.conversation_service.get_conversation_history(
|
||||||
@@ -411,11 +422,6 @@ class AppChatService:
|
|||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
context=None,
|
context=None,
|
||||||
end_user_id=user_id,
|
|
||||||
storage_type=storage_type,
|
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
|
||||||
config_id=config_id,
|
|
||||||
memory_flag=memory_flag,
|
|
||||||
files=processed_files
|
files=processed_files
|
||||||
):
|
):
|
||||||
if isinstance(chunk, int):
|
if isinstance(chunk, int):
|
||||||
@@ -459,7 +465,7 @@ class AppChatService:
|
|||||||
|
|
||||||
# 保存消息
|
# 保存消息
|
||||||
human_meta = {
|
human_meta = {
|
||||||
"files":[],
|
"files": [],
|
||||||
"history_files": {}
|
"history_files": {}
|
||||||
}
|
}
|
||||||
assistant_meta = {
|
assistant_meta = {
|
||||||
@@ -484,6 +490,22 @@ class AppChatService:
|
|||||||
|
|
||||||
if stream_audio_url:
|
if stream_audio_url:
|
||||||
assistant_meta["audio_url"] = stream_audio_url
|
assistant_meta["audio_url"] = stream_audio_url
|
||||||
|
|
||||||
|
if memory_flag:
|
||||||
|
connected_config = get_end_user_connected_config(user_id, self.db)
|
||||||
|
memory_config_id: str = connected_config.get("memory_config_id")
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": message, "files": [file.model_dump() for file in files]},
|
||||||
|
{"role": "assistant", "content": full_content}
|
||||||
|
]
|
||||||
|
if memory_config_id:
|
||||||
|
await write_long_term(
|
||||||
|
storage_type,
|
||||||
|
user_id,
|
||||||
|
messages,
|
||||||
|
user_rag_memory_id,
|
||||||
|
memory_config_id
|
||||||
|
)
|
||||||
self.conversation_service.add_message(
|
self.conversation_service.add_message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
role="user",
|
role="user",
|
||||||
@@ -618,7 +640,6 @@ class AppChatService:
|
|||||||
# 2. 创建编排器
|
# 2. 创建编排器
|
||||||
orchestrator = MultiAgentOrchestrator(self.db, config)
|
orchestrator = MultiAgentOrchestrator(self.db, config)
|
||||||
|
|
||||||
|
|
||||||
# 3. 流式执行任务
|
# 3. 流式执行任务
|
||||||
async for event in orchestrator.execute_stream(
|
async for event in orchestrator.execute_stream(
|
||||||
message=message,
|
message=message,
|
||||||
|
|||||||
128
api/app/services/app_log_service.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
"""应用日志服务层"""
|
||||||
|
import uuid
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.models.conversation_model import Conversation, Message
|
||||||
|
from app.repositories.conversation_repository import ConversationRepository, MessageRepository
|
||||||
|
|
||||||
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
class AppLogService:
|
||||||
|
"""应用日志服务"""
|
||||||
|
|
||||||
|
def __init__(self, db: Session):
|
||||||
|
self.db = db
|
||||||
|
self.conversation_repository = ConversationRepository(db)
|
||||||
|
self.message_repository = MessageRepository(db)
|
||||||
|
|
||||||
|
def list_conversations(
|
||||||
|
self,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
workspace_id: uuid.UUID,
|
||||||
|
page: int = 1,
|
||||||
|
pagesize: int = 20,
|
||||||
|
is_draft: Optional[bool] = None,
|
||||||
|
) -> Tuple[list[Conversation], int]:
|
||||||
|
"""
|
||||||
|
查询应用日志会话列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用 ID
|
||||||
|
workspace_id: 工作空间 ID
|
||||||
|
page: 页码(从 1 开始)
|
||||||
|
pagesize: 每页数量
|
||||||
|
is_draft: 是否草稿会话(None 表示不过滤)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[list[Conversation], int]: (会话列表,总数)
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
"查询应用日志会话列表",
|
||||||
|
extra={
|
||||||
|
"app_id": str(app_id),
|
||||||
|
"workspace_id": str(workspace_id),
|
||||||
|
"page": page,
|
||||||
|
"pagesize": pagesize,
|
||||||
|
"is_draft": is_draft
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 使用 Repository 查询
|
||||||
|
conversations, total = self.conversation_repository.list_app_conversations(
|
||||||
|
app_id=app_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
is_draft=is_draft,
|
||||||
|
page=page,
|
||||||
|
pagesize=pagesize
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"查询应用日志会话列表成功",
|
||||||
|
extra={
|
||||||
|
"app_id": str(app_id),
|
||||||
|
"total": total,
|
||||||
|
"returned": len(conversations)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return conversations, total
|
||||||
|
|
||||||
|
def get_conversation_detail(
|
||||||
|
self,
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
conversation_id: uuid.UUID,
|
||||||
|
workspace_id: uuid.UUID
|
||||||
|
) -> Conversation:
|
||||||
|
"""
|
||||||
|
查询会话详情(包含消息)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
app_id: 应用 ID
|
||||||
|
conversation_id: 会话 ID
|
||||||
|
workspace_id: 工作空间 ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Conversation: 包含消息的会话对象
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ResourceNotFoundException: 当会话不存在时
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
"查询应用日志会话详情",
|
||||||
|
extra={
|
||||||
|
"app_id": str(app_id),
|
||||||
|
"conversation_id": str(conversation_id),
|
||||||
|
"workspace_id": str(workspace_id)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 查询会话
|
||||||
|
conversation = self.conversation_repository.get_conversation_for_app_log(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
app_id=app_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 查询消息(按时间正序)
|
||||||
|
messages = self.message_repository.get_messages_by_conversation(
|
||||||
|
conversation_id=conversation_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 将消息附加到会话对象
|
||||||
|
conversation.messages = messages
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"查询应用日志会话详情成功",
|
||||||
|
extra={
|
||||||
|
"app_id": str(app_id),
|
||||||
|
"conversation_id": str(conversation_id),
|
||||||
|
"message_count": len(messages)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return conversation
|
||||||
@@ -1084,7 +1084,6 @@ class AppService:
|
|||||||
if not exists:
|
if not exists:
|
||||||
cleaned["memory_config_id"] = None
|
cleaned["memory_config_id"] = None
|
||||||
cleaned.pop("memory_content", None)
|
cleaned.pop("memory_content", None)
|
||||||
cleaned["enabled"] = False
|
|
||||||
return cleaned
|
return cleaned
|
||||||
|
|
||||||
exists = self.db.query(
|
exists = self.db.query(
|
||||||
@@ -1096,7 +1095,6 @@ class AppService:
|
|||||||
if not exists:
|
if not exists:
|
||||||
cleaned["memory_config_id"] = None
|
cleaned["memory_config_id"] = None
|
||||||
cleaned.pop("memory_content", None)
|
cleaned.pop("memory_content", None)
|
||||||
cleaned["enabled"] = False
|
|
||||||
|
|
||||||
return cleaned
|
return cleaned
|
||||||
|
|
||||||
@@ -1684,15 +1682,15 @@ class AppService:
|
|||||||
|
|
||||||
return config.config_id
|
return config.config_id
|
||||||
|
|
||||||
def _update_endusers_memory_config_by_workspace(
|
def _update_endusers_memory_config_by_app(
|
||||||
self,
|
self,
|
||||||
workspace_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
memory_config_id: uuid.UUID
|
memory_config_id: uuid.UUID
|
||||||
) -> int:
|
) -> int:
|
||||||
"""批量更新应用下所有终端用户的 memory_config_id
|
"""批量更新应用下所有终端用户的 memory_config_id
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
workspace_id: 工作空间ID
|
app_id: 应用ID
|
||||||
memory_config_id: 新的记忆配置ID
|
memory_config_id: 新的记忆配置ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -1701,8 +1699,8 @@ class AppService:
|
|||||||
from app.repositories.end_user_repository import EndUserRepository
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
|
|
||||||
repo = EndUserRepository(self.db)
|
repo = EndUserRepository(self.db)
|
||||||
updated_count = repo.batch_update_memory_config_id_by_workspace(
|
updated_count = repo.batch_update_memory_config_id_by_app(
|
||||||
workspace_id=workspace_id,
|
app_id=app_id,
|
||||||
memory_config_id=memory_config_id
|
memory_config_id=memory_config_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1753,12 +1751,16 @@ class AppService:
|
|||||||
|
|
||||||
miss_params = []
|
miss_params = []
|
||||||
if agent_cfg.default_model_config_id is None:
|
if agent_cfg.default_model_config_id is None:
|
||||||
miss_params.append("model config")
|
miss_params.append("模型配置")
|
||||||
|
|
||||||
if agent_cfg.memory.get("enabled") and not agent_cfg.memory.get("memory_config_id"):
|
if agent_cfg.memory.get("enabled") and not agent_cfg.memory.get("memory_config_id"):
|
||||||
miss_params.append("memory config")
|
miss_params.append("记忆配置")
|
||||||
if miss_params:
|
if miss_params:
|
||||||
raise BusinessException(f"{', '.join(miss_params)} is required")
|
raise BusinessException(
|
||||||
|
f"应用发布失败:检测到以下必要配置尚未完成:{', '.join(miss_params)}。请返回应用编辑页面完成相关配置后再尝试发布。",
|
||||||
|
BizCode.CONFIG_MISSING,
|
||||||
|
context={"missing_params": miss_params},
|
||||||
|
)
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
"system_prompt": agent_cfg.system_prompt,
|
"system_prompt": agent_cfg.system_prompt,
|
||||||
@@ -1877,8 +1879,8 @@ class AppService:
|
|||||||
if memory_config_id:
|
if memory_config_id:
|
||||||
app = self.db.query(App).filter(App.id == app_id).first()
|
app = self.db.query(App).filter(App.id == app_id).first()
|
||||||
if app:
|
if app:
|
||||||
updated_count = self._update_endusers_memory_config_by_workspace(
|
updated_count = self._update_endusers_memory_config_by_app(
|
||||||
app.workspace_id, memory_config_id
|
app_id, memory_config_id
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"发布时更新终端用户记忆配置: app_id={app_id}, workspace_id={app.workspace_id}, "
|
f"发布时更新终端用户记忆配置: app_id={app_id}, workspace_id={app.workspace_id}, "
|
||||||
@@ -2014,7 +2016,7 @@ class AppService:
|
|||||||
|
|
||||||
if memory_config_id:
|
if memory_config_id:
|
||||||
|
|
||||||
updated_count = self._update_endusers_memory_config_by_workspace(app.workspace_id, memory_config_id)
|
updated_count = self._update_endusers_memory_config_by_app(app_id, memory_config_id)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"回滚时更新终端用户记忆配置: app_id={app_id}, version={version}, "
|
f"回滚时更新终端用户记忆配置: app_id={app_id}, version={version}, "
|
||||||
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
||||||
|
|||||||
@@ -214,7 +214,7 @@ class ConversationService:
|
|||||||
|
|
||||||
conversation.message_count += 1
|
conversation.message_count += 1
|
||||||
|
|
||||||
if conversation.message_count == 1 and role == "user":
|
if conversation.message_count <= 2 and role == "user":
|
||||||
conversation.title = (
|
conversation.title = (
|
||||||
content[:50] + ("..." if len(content) > 50 else "")
|
content[:50] + ("..." if len(content) > 50 else "")
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from app.core.exceptions import BusinessException
|
|||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.rag.nlp.search import knowledge_retrieval
|
from app.core.rag.nlp.search import knowledge_retrieval
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models import AgentConfig, ModelConfig, ModelType
|
from app.models import AgentConfig, ModelConfig
|
||||||
from app.repositories.tool_repository import ToolRepository
|
from app.repositories.tool_repository import ToolRepository
|
||||||
from app.schemas.app_schema import FileInput, Citation
|
from app.schemas.app_schema import FileInput, Citation
|
||||||
from app.schemas.model_schema import ModelInfo
|
from app.schemas.model_schema import ModelInfo
|
||||||
@@ -37,7 +37,6 @@ from app.services.model_parameter_merger import ModelParameterMerger
|
|||||||
from app.services.model_service import ModelApiKeyService
|
from app.services.model_service import ModelApiKeyService
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
from app.services.tool_service import ToolService
|
from app.services.tool_service import ToolService
|
||||||
from app.schemas import FileType
|
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
@@ -449,15 +448,16 @@ class AgentRunService:
|
|||||||
features_config: Dict[str, Any],
|
features_config: Dict[str, Any],
|
||||||
is_new_conversation: bool,
|
is_new_conversation: bool,
|
||||||
variables: Optional[Dict[str, Any]] = None
|
variables: Optional[Dict[str, Any]] = None
|
||||||
) -> Optional[str]:
|
) -> tuple[Any, Any]:
|
||||||
"""首轮对话时返回开场白文本(支持变量替换),否则返回 None"""
|
"""首轮对话时返回开场白文本(支持变量替换),否则返回 None"""
|
||||||
if not is_new_conversation:
|
if not is_new_conversation:
|
||||||
return None
|
return None, None
|
||||||
opening = features_config.get("opening_statement", {})
|
opening = features_config.get("opening_statement", {})
|
||||||
if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")):
|
if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")):
|
||||||
return None
|
return None, None
|
||||||
|
|
||||||
statement = opening["statement"]
|
statement = opening["statement"]
|
||||||
|
suggested_questions = opening["suggested_questions"]
|
||||||
|
|
||||||
# 如果有变量,进行替换(仅支持 {{var_name}} 格式)
|
# 如果有变量,进行替换(仅支持 {{var_name}} 格式)
|
||||||
if variables:
|
if variables:
|
||||||
@@ -465,7 +465,7 @@ class AgentRunService:
|
|||||||
placeholder = f"{{{{{var_name}}}}}"
|
placeholder = f"{{{{{var_name}}}}}"
|
||||||
statement = statement.replace(placeholder, str(var_value))
|
statement = statement.replace(placeholder, str(var_value))
|
||||||
|
|
||||||
return statement
|
return statement, suggested_questions
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _filter_citations(
|
def _filter_citations(
|
||||||
@@ -599,13 +599,16 @@ class AgentRunService:
|
|||||||
|
|
||||||
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
||||||
is_new_conversation = not conversation_id
|
is_new_conversation = not conversation_id
|
||||||
opening = self._get_opening_statement(features_config, is_new_conversation, variables)
|
opening, suggested_questions = None, None
|
||||||
|
if not sub_agent:
|
||||||
|
opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables)
|
||||||
conversation_id = await self._ensure_conversation(
|
conversation_id = await self._ensure_conversation(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
app_id=agent_config.app_id,
|
app_id=agent_config.app_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
opening_statement=opening
|
opening_statement=opening,
|
||||||
|
suggested_questions=suggested_questions
|
||||||
)
|
)
|
||||||
|
|
||||||
model_info = ModelInfo(
|
model_info = ModelInfo(
|
||||||
@@ -657,11 +660,6 @@ class AgentRunService:
|
|||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
context=context,
|
context=context,
|
||||||
end_user_id=user_id,
|
|
||||||
config_id=config_id,
|
|
||||||
storage_type=storage_type,
|
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
|
||||||
memory_flag=memory_flag,
|
|
||||||
files=processed_files # 传递处理后的文件
|
files=processed_files # 传递处理后的文件
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -845,14 +843,17 @@ class AgentRunService:
|
|||||||
|
|
||||||
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
# 5. 处理会话ID(创建或验证),新会话时写入开场白
|
||||||
is_new_conversation = not conversation_id
|
is_new_conversation = not conversation_id
|
||||||
opening = self._get_opening_statement(features_config, is_new_conversation, variables)
|
opening, suggested_questions = None, None
|
||||||
|
if not sub_agent:
|
||||||
|
opening, suggested_questions = self._get_opening_statement(features_config, is_new_conversation, variables)
|
||||||
conversation_id = await self._ensure_conversation(
|
conversation_id = await self._ensure_conversation(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
app_id=agent_config.app_id,
|
app_id=agent_config.app_id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
sub_agent=sub_agent,
|
sub_agent=sub_agent,
|
||||||
opening_statement=opening
|
opening_statement=opening,
|
||||||
|
suggested_questions=suggested_questions
|
||||||
)
|
)
|
||||||
|
|
||||||
model_info = ModelInfo(
|
model_info = ModelInfo(
|
||||||
@@ -911,11 +912,6 @@ class AgentRunService:
|
|||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
context=context,
|
context=context,
|
||||||
end_user_id=user_id,
|
|
||||||
config_id=config_id,
|
|
||||||
storage_type=storage_type,
|
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
|
||||||
memory_flag=memory_flag,
|
|
||||||
files=processed_files
|
files=processed_files
|
||||||
):
|
):
|
||||||
if isinstance(chunk, int):
|
if isinstance(chunk, int):
|
||||||
@@ -1061,7 +1057,8 @@ class AgentRunService:
|
|||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
user_id: Optional[str],
|
user_id: Optional[str],
|
||||||
sub_agent: bool = False,
|
sub_agent: bool = False,
|
||||||
opening_statement: Optional[str] = None
|
opening_statement: Optional[str] = None,
|
||||||
|
suggested_questions: Optional[List[str]] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""确保会话存在(创建或验证)
|
"""确保会话存在(创建或验证)
|
||||||
|
|
||||||
@@ -1072,6 +1069,7 @@ class AgentRunService:
|
|||||||
user_id: 用户ID
|
user_id: 用户ID
|
||||||
sub_agent: 是否为子代理
|
sub_agent: 是否为子代理
|
||||||
opening_statement: 开场白(新会话时作为第一条消息写入)
|
opening_statement: 开场白(新会话时作为第一条消息写入)
|
||||||
|
suggested_questions: 预设问题列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 会话ID
|
str: 会话ID
|
||||||
@@ -1115,7 +1113,7 @@ class AgentRunService:
|
|||||||
conversation_id=uuid.UUID(new_conv_id),
|
conversation_id=uuid.UUID(new_conv_id),
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=opening_statement,
|
content=opening_statement,
|
||||||
meta_data={}
|
meta_data={"suggested_questions": suggested_questions}
|
||||||
)
|
)
|
||||||
logger.debug(f"已保存开场白到会话 {new_conv_id}")
|
logger.debug(f"已保存开场白到会话 {new_conv_id}")
|
||||||
|
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ from app.core.memory.agent.utils.type_classifier import status_typle
|
|||||||
from app.core.memory.agent.utils.write_tools import write as write_neo4j
|
from app.core.memory.agent.utils.write_tools import write as write_neo4j
|
||||||
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
from app.core.memory.analytics.hot_memory_tags import get_interest_distribution
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
|
from app.core.memory.utils.log.audit_logger import audit_logger
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
@@ -49,10 +50,6 @@ from app.services.memory_konwledges_server import (
|
|||||||
)
|
)
|
||||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
from app.services.memory_perceptual_service import MemoryPerceptualService
|
||||||
|
|
||||||
try:
|
|
||||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
|
||||||
except ImportError:
|
|
||||||
audit_logger = None
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
config_logger = get_config_logger()
|
config_logger = get_config_logger()
|
||||||
|
|
||||||
@@ -68,24 +65,22 @@ class MemoryAgentService:
|
|||||||
if str(messages) == 'success':
|
if str(messages) == 'success':
|
||||||
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
|
logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}")
|
||||||
# 记录成功的操作
|
# 记录成功的操作
|
||||||
if audit_logger:
|
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
success=True,
|
||||||
success=True,
|
duration=duration, details={"message_length": len(message)})
|
||||||
duration=duration, details={"message_length": len(message)})
|
|
||||||
return context
|
return context
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Write operation failed for group {end_user_id}")
|
logger.warning(f"Write operation failed for group {end_user_id}")
|
||||||
|
|
||||||
# 记录失败的操作
|
# 记录失败的操作
|
||||||
if audit_logger:
|
audit_logger.log_operation(
|
||||||
audit_logger.log_operation(
|
operation="WRITE",
|
||||||
operation="WRITE",
|
config_id=config_id,
|
||||||
config_id=config_id,
|
end_user_id=end_user_id,
|
||||||
end_user_id=end_user_id,
|
success=False,
|
||||||
success=False,
|
duration=duration,
|
||||||
duration=duration,
|
error=f"写入失败: {messages[:100]}"
|
||||||
error=f"写入失败: {messages[:100]}"
|
)
|
||||||
)
|
|
||||||
|
|
||||||
raise ValueError(f"写入失败: {messages}")
|
raise ValueError(f"写入失败: {messages}")
|
||||||
|
|
||||||
@@ -338,10 +333,9 @@ class MemoryAgentService:
|
|||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
|
|
||||||
# Log failed operation
|
# Log failed operation
|
||||||
if audit_logger:
|
duration = time.time() - start_time
|
||||||
duration = time.time() - start_time
|
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
success=False, duration=duration, error=error_msg)
|
||||||
success=False, duration=duration, error=error_msg)
|
|
||||||
|
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
@@ -401,10 +395,10 @@ class MemoryAgentService:
|
|||||||
# Ensure proper error handling and logging
|
# Ensure proper error handling and logging
|
||||||
error_msg = f"Write operation failed: {str(e)}"
|
error_msg = f"Write operation failed: {str(e)}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
if audit_logger:
|
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id,
|
||||||
success=False, duration=duration, error=error_msg)
|
success=False, duration=duration, error=error_msg)
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
async def read_memory(
|
async def read_memory(
|
||||||
@@ -469,10 +463,9 @@ class MemoryAgentService:
|
|||||||
logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")
|
logger.info(f"Read operation for group {end_user_id} with config_id {config_id}")
|
||||||
|
|
||||||
# 导入审计日志记录器
|
# 导入审计日志记录器
|
||||||
try:
|
|
||||||
from app.core.memory.utils.log.audit_logger import audit_logger
|
|
||||||
except ImportError:
|
|
||||||
audit_logger = None
|
|
||||||
|
|
||||||
config_load_start = time.time()
|
config_load_start = time.time()
|
||||||
try:
|
try:
|
||||||
@@ -492,16 +485,15 @@ class MemoryAgentService:
|
|||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
|
|
||||||
# Log failed operation
|
# Log failed operation
|
||||||
if audit_logger:
|
duration = time.time() - start_time
|
||||||
duration = time.time() - start_time
|
audit_logger.log_operation(
|
||||||
audit_logger.log_operation(
|
operation="READ",
|
||||||
operation="READ",
|
config_id=config_id,
|
||||||
config_id=config_id,
|
end_user_id=end_user_id,
|
||||||
end_user_id=end_user_id,
|
success=False,
|
||||||
success=False,
|
duration=duration,
|
||||||
duration=duration,
|
error=error_msg
|
||||||
error=error_msg
|
)
|
||||||
)
|
|
||||||
|
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
@@ -633,15 +625,15 @@ class MemoryAgentService:
|
|||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
|
f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
|
||||||
if audit_logger:
|
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
audit_logger.log_operation(
|
audit_logger.log_operation(
|
||||||
operation="READ",
|
operation="READ",
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
success=True,
|
success=True,
|
||||||
duration=duration
|
duration=duration
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"answer": summary,
|
"answer": summary,
|
||||||
@@ -651,16 +643,16 @@ class MemoryAgentService:
|
|||||||
# Ensure proper error handling and logging
|
# Ensure proper error handling and logging
|
||||||
error_msg = f"Read operation failed: {str(e)}"
|
error_msg = f"Read operation failed: {str(e)}"
|
||||||
logger.error(error_msg)
|
logger.error(error_msg)
|
||||||
if audit_logger:
|
|
||||||
duration = time.time() - start_time
|
duration = time.time() - start_time
|
||||||
audit_logger.log_operation(
|
audit_logger.log_operation(
|
||||||
operation="READ",
|
operation="READ",
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
success=False,
|
success=False,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
error=error_msg
|
error=error_msg
|
||||||
)
|
)
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
|
def get_messages_list(self, user_input: Write_UserInput) -> list[dict]:
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from typing import List, Optional
|
from sqlalchemy import desc, nullslast, or_, and_, cast, String
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
import uuid
|
import uuid
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.models.app_model import App
|
from app.models.app_model import App
|
||||||
from app.models.end_user_model import EndUser
|
from app.models.end_user_model import EndUser, EndUser as EndUserModel
|
||||||
from app.models.memory_increment_model import MemoryIncrement
|
from app.models.memory_increment_model import MemoryIncrement
|
||||||
|
|
||||||
from app.repositories import (
|
from app.repositories import (
|
||||||
@@ -49,44 +50,40 @@ def get_current_workspace_type(
|
|||||||
|
|
||||||
|
|
||||||
def get_workspace_end_users(
|
def get_workspace_end_users(
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
current_user: User
|
current_user: User
|
||||||
) -> List[EndUser]:
|
) -> List[EndUser]:
|
||||||
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)
|
"""获取工作空间的所有宿主(优化版本:减少数据库查询次数)
|
||||||
|
|
||||||
返回结果按 created_at 从新到旧排序(NULL 值排在最后)
|
返回结果按 created_at 从新到旧排序(NULL 值排在最后)
|
||||||
"""
|
"""
|
||||||
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 查询应用(ORM)
|
# 查询应用(ORM)
|
||||||
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
||||||
|
|
||||||
if not apps_orm:
|
if not apps_orm:
|
||||||
business_logger.info("工作空间下没有应用")
|
business_logger.info("工作空间下没有应用")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 提取所有 app_id
|
# 提取所有 app_id
|
||||||
# app_ids = [app.id for app in apps_orm]
|
# app_ids = [app.id for app in apps_orm]
|
||||||
|
|
||||||
# 批量查询所有 end_users(一次查询而非循环查询)
|
# 批量查询所有 end_users(一次查询而非循环查询)
|
||||||
# 按 created_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性
|
# 按 created_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性
|
||||||
from app.models.end_user_model import EndUser as EndUserModel
|
|
||||||
from sqlalchemy import desc, nullslast
|
|
||||||
end_users_orm = db.query(EndUserModel).filter(
|
end_users_orm = db.query(EndUserModel).filter(
|
||||||
EndUserModel.workspace_id == workspace_id
|
EndUserModel.workspace_id == workspace_id
|
||||||
).order_by(
|
).order_by(
|
||||||
nullslast(desc(EndUserModel.created_at)),
|
nullslast(desc(EndUserModel.created_at)),
|
||||||
desc(EndUserModel.id)
|
desc(EndUserModel.id)
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
# 转换为 Pydantic 模型(只在需要时转换)
|
# 转换为 Pydantic 模型(只在需要时转换)
|
||||||
end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm]
|
end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm]
|
||||||
|
|
||||||
business_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
business_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||||
return end_users
|
return end_users
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -94,6 +91,85 @@ def get_workspace_end_users(
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def get_workspace_end_users_paginated(
|
||||||
|
db: Session,
|
||||||
|
workspace_id: uuid.UUID,
|
||||||
|
current_user: User,
|
||||||
|
page: int,
|
||||||
|
pagesize: int,
|
||||||
|
keyword: Optional[str] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""获取工作空间的宿主列表(分页版本,支持模糊搜索)
|
||||||
|
|
||||||
|
返回结果按 created_at 从新到旧排序(NULL 值排在最后)
|
||||||
|
支持通过 keyword 参数同时模糊搜索 other_name 和 id 字段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
workspace_id: 工作空间ID
|
||||||
|
current_user: 当前用户
|
||||||
|
page: 页码(从1开始)
|
||||||
|
pagesize: 每页数量
|
||||||
|
keyword: 搜索关键词(可选,同时模糊匹配 other_name 和 id)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 包含 items(宿主列表)和 total(总记录数)的字典
|
||||||
|
"""
|
||||||
|
business_logger.info(f"获取工作空间宿主列表(分页): workspace_id={workspace_id}, keyword={keyword}, page={page}, pagesize={pagesize}, 操作者: {current_user.username}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 构建基础查询
|
||||||
|
base_query = db.query(EndUserModel).filter(
|
||||||
|
EndUserModel.workspace_id == workspace_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建搜索条件(过滤空字符串和None)
|
||||||
|
keyword = keyword.strip() if keyword else None
|
||||||
|
|
||||||
|
if keyword:
|
||||||
|
keyword_pattern = f"%{keyword}%"
|
||||||
|
# other_name 匹配始终生效;id 匹配仅对 other_name 为空的记录生效
|
||||||
|
base_query = base_query.filter(
|
||||||
|
or_(
|
||||||
|
EndUserModel.other_name.ilike(keyword_pattern),
|
||||||
|
and_(
|
||||||
|
or_(
|
||||||
|
EndUserModel.other_name.is_(None),
|
||||||
|
EndUserModel.other_name == "",
|
||||||
|
),
|
||||||
|
cast(EndUserModel.id, String).ilike(keyword_pattern),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
business_logger.info(f"应用模糊搜索: keyword={keyword}(匹配 other_name;other_name 为空时匹配 id)")
|
||||||
|
|
||||||
|
# 获取总记录数
|
||||||
|
total = base_query.count()
|
||||||
|
|
||||||
|
if total == 0:
|
||||||
|
business_logger.info("工作空间下没有宿主")
|
||||||
|
return {"items": [], "total": 0}
|
||||||
|
|
||||||
|
# 分页查询
|
||||||
|
# 按 created_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性
|
||||||
|
end_users_orm = base_query.order_by(
|
||||||
|
nullslast(desc(EndUserModel.created_at)),
|
||||||
|
desc(EndUserModel.id)
|
||||||
|
).offset((page - 1) * pagesize).limit(pagesize).all()
|
||||||
|
|
||||||
|
# 转换为 Pydantic 模型
|
||||||
|
end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm]
|
||||||
|
|
||||||
|
business_logger.info(f"成功获取 {len(end_users)} 个宿主记录,总计 {total} 条")
|
||||||
|
return {"items": end_users, "total": total}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
business_logger.error(f"获取工作空间宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def get_workspace_memory_increment(
|
def get_workspace_memory_increment(
|
||||||
db: Session,
|
db: Session,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
@@ -638,7 +714,24 @@ def get_rag_content(
|
|||||||
business_logger.error(f"获取文档 {document.id} 的chunks失败: {str(e)}")
|
business_logger.error(f"获取文档 {document.id} 的chunks失败: {str(e)}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 4. 返回结果
|
# 4. 将所有 page_content 拼接后按角色分割为对话列表
|
||||||
|
merged_text = "\n".join(page_contents)
|
||||||
|
conversations = []
|
||||||
|
if merged_text.strip():
|
||||||
|
import re
|
||||||
|
# 在任意位置匹配 "user:" 或 "assistant:",不限于行首
|
||||||
|
parts = re.split(r'(user|assistant):', merged_text)
|
||||||
|
# parts 结构: ['', 'user', ' content...', 'assistant', ' content...', ...]
|
||||||
|
i = 1
|
||||||
|
while i < len(parts) - 1:
|
||||||
|
role = parts[i].strip()
|
||||||
|
content = parts[i + 1].strip()
|
||||||
|
# 将 content 中的 \n 还原为真实换行
|
||||||
|
content = content.replace("\\n", "\n")
|
||||||
|
if role in ("user", "assistant") and content:
|
||||||
|
conversations.append({"role": role, "content": content})
|
||||||
|
i += 2
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"page": {
|
"page": {
|
||||||
"page": page,
|
"page": page,
|
||||||
@@ -646,10 +739,10 @@ def get_rag_content(
|
|||||||
"total": global_total,
|
"total": global_total,
|
||||||
"hasnext": offset_end < global_total,
|
"hasnext": offset_end < global_total,
|
||||||
},
|
},
|
||||||
"items": page_contents
|
"items": conversations
|
||||||
}
|
}
|
||||||
|
|
||||||
business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(page_contents)} 条")
|
business_logger.info(f"成功获取RAG内容: total={global_total}, page={page}, 返回={len(conversations)} 条对话")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -204,30 +204,35 @@ class MemoryForgetService:
|
|||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
forgetting_threshold: float,
|
forgetting_threshold: float,
|
||||||
min_days_since_access: int,
|
min_days_since_access: int,
|
||||||
limit: int = 20
|
page: Optional[int] = None,
|
||||||
) -> list[Dict[str, Any]]:
|
pagesize: Optional[int] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取待遗忘节点列表
|
获取待遗忘节点列表
|
||||||
|
|
||||||
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)
|
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。支持分页查询。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
connector: Neo4j 连接器
|
connector: Neo4j 连接器
|
||||||
end_user_id: 组ID
|
end_user_id: 组ID
|
||||||
forgetting_threshold: 遗忘阈值
|
forgetting_threshold: 遗忘阈值
|
||||||
min_days_since_access: 最小未访问天数
|
min_days_since_access: 最小未访问天数
|
||||||
limit: 返回节点数量限制
|
page: 页码(可选,从1开始)
|
||||||
|
pagesize: 每页数量(可选)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
list: 待遗忘节点列表
|
dict: 包含待遗忘节点列表和分页信息的字典
|
||||||
|
- items: 待遗忘节点列表
|
||||||
|
- page: 分页信息(分页时)
|
||||||
"""
|
"""
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
|
||||||
# 计算最小访问时间(ISO 8601 格式字符串,使用 UTC 时区)
|
# 计算最小访问时间(ISO 8601 格式字符串,使用 UTC 时区)
|
||||||
min_access_time = datetime.now(timezone.utc) - timedelta(days=min_days_since_access)
|
min_access_time = datetime.now(timezone.utc) - timedelta(days=min_days_since_access)
|
||||||
min_access_time_str = min_access_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ')
|
min_access_time_str = min_access_time.strftime('%Y-%m-%dT%H:%M:%S.%fZ')
|
||||||
|
|
||||||
query = """
|
# 基础查询(用于获取总数)
|
||||||
|
count_query = """
|
||||||
MATCH (n)
|
MATCH (n)
|
||||||
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||||
AND n.end_user_id = $end_user_id
|
AND n.end_user_id = $end_user_id
|
||||||
@@ -235,10 +240,22 @@ class MemoryForgetService:
|
|||||||
AND n.activation_value < $threshold
|
AND n.activation_value < $threshold
|
||||||
AND n.last_access_time IS NOT NULL
|
AND n.last_access_time IS NOT NULL
|
||||||
AND datetime(n.last_access_time) < datetime($min_access_time_str)
|
AND datetime(n.last_access_time) < datetime($min_access_time_str)
|
||||||
RETURN
|
RETURN count(n) as total
|
||||||
|
"""
|
||||||
|
|
||||||
|
# 数据查询
|
||||||
|
data_query = """
|
||||||
|
MATCH (n)
|
||||||
|
WHERE (n:Statement OR n:ExtractedEntity OR n:MemorySummary)
|
||||||
|
AND n.end_user_id = $end_user_id
|
||||||
|
AND n.activation_value IS NOT NULL
|
||||||
|
AND n.activation_value < $threshold
|
||||||
|
AND n.last_access_time IS NOT NULL
|
||||||
|
AND datetime(n.last_access_time) < datetime($min_access_time_str)
|
||||||
|
RETURN
|
||||||
elementId(n) as node_id,
|
elementId(n) as node_id,
|
||||||
labels(n)[0] as node_type,
|
labels(n)[0] as node_type,
|
||||||
CASE
|
CASE
|
||||||
WHEN n:Statement THEN n.statement
|
WHEN n:Statement THEN n.statement
|
||||||
WHEN n:ExtractedEntity THEN n.name
|
WHEN n:ExtractedEntity THEN n.name
|
||||||
WHEN n:MemorySummary THEN n.content
|
WHEN n:MemorySummary THEN n.content
|
||||||
@@ -247,18 +264,32 @@ class MemoryForgetService:
|
|||||||
n.activation_value as activation_value,
|
n.activation_value as activation_value,
|
||||||
n.last_access_time as last_access_time
|
n.last_access_time as last_access_time
|
||||||
ORDER BY n.activation_value ASC
|
ORDER BY n.activation_value ASC
|
||||||
LIMIT $limit
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# 如果启用分页,添加 SKIP 和 LIMIT
|
||||||
|
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||||
|
data_query += " SKIP $skip LIMIT $limit"
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'end_user_id': end_user_id,
|
'end_user_id': end_user_id,
|
||||||
'threshold': forgetting_threshold,
|
'threshold': forgetting_threshold,
|
||||||
'min_access_time_str': min_access_time_str,
|
'min_access_time_str': min_access_time_str
|
||||||
'limit': limit
|
|
||||||
}
|
}
|
||||||
|
|
||||||
results = await connector.execute_query(query, **params)
|
# 获取总数(分页时需要)
|
||||||
|
total = 0
|
||||||
|
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||||
|
count_results = await connector.execute_query(count_query, **params)
|
||||||
|
if count_results:
|
||||||
|
total = count_results[0]['total']
|
||||||
|
|
||||||
|
# 添加分页参数
|
||||||
|
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||||
|
params['skip'] = (page - 1) * pagesize
|
||||||
|
params['limit'] = pagesize
|
||||||
|
|
||||||
|
results = await connector.execute_query(data_query, **params)
|
||||||
|
|
||||||
pending_nodes = []
|
pending_nodes = []
|
||||||
for result in results:
|
for result in results:
|
||||||
# 将节点类型标签转换为小写
|
# 将节点类型标签转换为小写
|
||||||
@@ -267,7 +298,7 @@ class MemoryForgetService:
|
|||||||
node_type_label = 'entity'
|
node_type_label = 'entity'
|
||||||
elif node_type_label == 'memorysummary':
|
elif node_type_label == 'memorysummary':
|
||||||
node_type_label = 'summary'
|
node_type_label = 'summary'
|
||||||
|
|
||||||
# 将 Neo4j DateTime 对象转换为时间戳(毫秒)
|
# 将 Neo4j DateTime 对象转换为时间戳(毫秒)
|
||||||
last_access_time = result['last_access_time']
|
last_access_time = result['last_access_time']
|
||||||
last_access_dt = convert_neo4j_datetime_to_python(last_access_time)
|
last_access_dt = convert_neo4j_datetime_to_python(last_access_time)
|
||||||
@@ -278,7 +309,7 @@ class MemoryForgetService:
|
|||||||
last_access_timestamp = int(last_access_dt.timestamp() * 1000)
|
last_access_timestamp = int(last_access_dt.timestamp() * 1000)
|
||||||
else:
|
else:
|
||||||
last_access_timestamp = 0
|
last_access_timestamp = 0
|
||||||
|
|
||||||
pending_nodes.append({
|
pending_nodes.append({
|
||||||
'node_id': str(result['node_id']),
|
'node_id': str(result['node_id']),
|
||||||
'node_type': node_type_label,
|
'node_type': node_type_label,
|
||||||
@@ -286,8 +317,20 @@ class MemoryForgetService:
|
|||||||
'activation_value': result['activation_value'],
|
'activation_value': result['activation_value'],
|
||||||
'last_access_time': last_access_timestamp
|
'last_access_time': last_access_timestamp
|
||||||
})
|
})
|
||||||
|
|
||||||
return pending_nodes
|
# 构建返回结果
|
||||||
|
result: Dict[str, Any] = {'items': pending_nodes}
|
||||||
|
|
||||||
|
# 如果启用分页,添加分页信息
|
||||||
|
if page is not None and pagesize is not None and page > 0 and pagesize > 0:
|
||||||
|
result['page'] = {
|
||||||
|
'page': page,
|
||||||
|
'pagesize': pagesize,
|
||||||
|
'total': total,
|
||||||
|
'hasnext': (page * pagesize) < total
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
async def trigger_forgetting_cycle(
|
async def trigger_forgetting_cycle(
|
||||||
self,
|
self,
|
||||||
@@ -636,7 +679,7 @@ class MemoryForgetService:
|
|||||||
api_logger.error(f"获取历史趋势数据失败: {str(e)}")
|
api_logger.error(f"获取历史趋势数据失败: {str(e)}")
|
||||||
# 失败时返回空列表,不影响主流程
|
# 失败时返回空列表,不影响主流程
|
||||||
|
|
||||||
# 获取待遗忘节点列表(前20个满足遗忘条件的节点)
|
# 获取待遗忘节点列表
|
||||||
pending_nodes = []
|
pending_nodes = []
|
||||||
try:
|
try:
|
||||||
if end_user_id:
|
if end_user_id:
|
||||||
@@ -652,8 +695,7 @@ class MemoryForgetService:
|
|||||||
connector=connector,
|
connector=connector,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
forgetting_threshold=forgetting_threshold,
|
forgetting_threshold=forgetting_threshold,
|
||||||
min_days_since_access=int(min_days),
|
min_days_since_access=int(min_days)
|
||||||
limit=20
|
|
||||||
)
|
)
|
||||||
|
|
||||||
api_logger.info(f"成功获取 {len(pending_nodes)} 个待遗忘节点")
|
api_logger.info(f"成功获取 {len(pending_nodes)} 个待遗忘节点")
|
||||||
@@ -661,24 +703,79 @@ class MemoryForgetService:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"获取待遗忘节点失败: {str(e)}")
|
api_logger.error(f"获取待遗忘节点失败: {str(e)}")
|
||||||
# 失败时返回空列表,不影响主流程
|
# 失败时返回空列表,不影响主流程
|
||||||
|
|
||||||
# 构建统计信息
|
# 构建统计信息(不包含 pending_nodes,已分离到独立接口)
|
||||||
stats = {
|
stats = {
|
||||||
'activation_metrics': activation_metrics,
|
'activation_metrics': activation_metrics,
|
||||||
'node_distribution': node_distribution,
|
'node_distribution': node_distribution,
|
||||||
'recent_trends': recent_trends,
|
'recent_trends': recent_trends,
|
||||||
'pending_nodes': pending_nodes,
|
|
||||||
'timestamp': int(datetime.now().timestamp() * 1000)
|
'timestamp': int(datetime.now().timestamp() * 1000)
|
||||||
}
|
}
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, "
|
f"成功获取遗忘引擎统计: total_nodes={stats['activation_metrics']['total_nodes']}, "
|
||||||
f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}, "
|
f"low_activation_nodes={stats['activation_metrics']['low_activation_nodes']}, "
|
||||||
f"trend_days={len(recent_trends)}, pending_nodes={len(pending_nodes)}"
|
f"trend_days={len(recent_trends)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
|
async def get_pending_nodes(
|
||||||
|
self,
|
||||||
|
db: Session,
|
||||||
|
end_user_id: str,
|
||||||
|
config_id: Optional[UUID] = None,
|
||||||
|
page: int = 1,
|
||||||
|
pagesize: int = 10
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
获取待遗忘节点列表(独立分页接口)
|
||||||
|
|
||||||
|
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db: 数据库会话
|
||||||
|
end_user_id: 组ID(必填)
|
||||||
|
config_id: 配置ID(可选,用于获取遗忘阈值)
|
||||||
|
page: 页码(从1开始,默认1)
|
||||||
|
pagesize: 每页数量(默认10)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: 包含待遗忘节点列表和分页信息的字典
|
||||||
|
- items: 待遗忘节点列表
|
||||||
|
- page: 分页信息
|
||||||
|
"""
|
||||||
|
# 获取遗忘引擎组件
|
||||||
|
_, _, forgetting_scheduler, config = await self._get_forgetting_components(db, config_id)
|
||||||
|
|
||||||
|
connector = forgetting_scheduler.connector
|
||||||
|
forgetting_threshold = config['forgetting_threshold']
|
||||||
|
|
||||||
|
# 验证 min_days_since_access 配置值
|
||||||
|
min_days = config.get('min_days_since_access')
|
||||||
|
if min_days is None or not isinstance(min_days, (int, float)) or min_days < 0:
|
||||||
|
api_logger.warning(
|
||||||
|
f"min_days_since_access 配置无效: {min_days}, 使用默认值 7"
|
||||||
|
)
|
||||||
|
min_days = 7
|
||||||
|
|
||||||
|
# 调用内部方法获取分页数据
|
||||||
|
pending_nodes_result = await self._get_pending_forgetting_nodes(
|
||||||
|
connector=connector,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
forgetting_threshold=forgetting_threshold,
|
||||||
|
min_days_since_access=int(min_days),
|
||||||
|
page=page,
|
||||||
|
pagesize=pagesize
|
||||||
|
)
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"成功获取待遗忘节点列表: end_user_id={end_user_id}, "
|
||||||
|
f"page={page}, pagesize={pagesize}, total={pending_nodes_result.get('page', {}).get('total', 0)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return pending_nodes_result
|
||||||
|
|
||||||
async def get_forgetting_curve(
|
async def get_forgetting_curve(
|
||||||
self,
|
self,
|
||||||
db: Session,
|
db: Session,
|
||||||
|
|||||||
@@ -243,28 +243,9 @@ class MemoryPerceptualService:
|
|||||||
memory_config: MemoryConfig,
|
memory_config: MemoryConfig,
|
||||||
file: FileInput
|
file: FileInput
|
||||||
):
|
):
|
||||||
memories = self.repository.get_by_url(file.url)
|
|
||||||
if memories:
|
|
||||||
business_logger.info(f"Perceptual memory already exists: {file.url}")
|
|
||||||
if end_user_id not in [memory.end_user_id for memory in memories]:
|
|
||||||
business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}")
|
|
||||||
memory_cache = memories[0]
|
|
||||||
memory = self.repository.create_perceptual_memory(
|
|
||||||
end_user_id=uuid.UUID(end_user_id),
|
|
||||||
perceptual_type=PerceptualType(memory_cache.perceptual_type),
|
|
||||||
file_path=memory_cache.file_path,
|
|
||||||
file_name=memory_cache.file_name,
|
|
||||||
file_ext=memory_cache.file_ext,
|
|
||||||
summary=memory_cache.summary,
|
|
||||||
meta_data=memory_cache.meta_data
|
|
||||||
)
|
|
||||||
self.db.commit()
|
|
||||||
return memory
|
|
||||||
else:
|
|
||||||
for memory in memories:
|
|
||||||
if memory.end_user_id == uuid.UUID(end_user_id):
|
|
||||||
return memory
|
|
||||||
llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
|
llm, model_config = self._get_mutlimodal_client(file.type, memory_config)
|
||||||
|
if model_config is None or llm is None:
|
||||||
|
return None
|
||||||
multimodel_service = MultimodalService(self.db, ModelInfo(
|
multimodel_service = MultimodalService(self.db, ModelInfo(
|
||||||
model_name=model_config.model_name,
|
model_name=model_config.model_name,
|
||||||
provider=model_config.provider,
|
provider=model_config.provider,
|
||||||
@@ -286,15 +267,20 @@ class MemoryPerceptualService:
|
|||||||
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f:
|
||||||
opt_system_prompt = f.read()
|
opt_system_prompt = f.read()
|
||||||
rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh')
|
rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh')
|
||||||
except FileNotFoundError:
|
except FileNotFoundError as e:
|
||||||
raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND)
|
business_logger.error(f"Failed to generate perceptual memory: {str(e)}")
|
||||||
|
return None
|
||||||
messages = [
|
messages = [
|
||||||
{"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]},
|
{"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]},
|
||||||
{"role": RoleType.USER.value, "content": [
|
{"role": RoleType.USER.value, "content": [
|
||||||
{"type": "text", "text": "Summarize the following file"}, file_message
|
{"type": "text", "text": "Summarize the following file"}, file_message
|
||||||
]}
|
]}
|
||||||
]
|
]
|
||||||
result = await llm.ainvoke(messages)
|
try:
|
||||||
|
result = await llm.ainvoke(messages)
|
||||||
|
except Exception as e:
|
||||||
|
business_logger.error(f"Failed to generate perceptual memory: {str(e)}")
|
||||||
|
return None
|
||||||
content = result.content
|
content = result.content
|
||||||
final_output = ""
|
final_output = ""
|
||||||
if isinstance(content, list):
|
if isinstance(content, list):
|
||||||
|
|||||||
@@ -695,6 +695,37 @@ async def search_edges(end_user_id: Optional[str] = None) -> List[Dict[str, Any]
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def search_all_batch(end_user_ids: List[str]) -> Dict[str, int]:
|
||||||
|
"""批量查询多个用户的记忆数量(简化版本,只返回total)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_ids: 用户ID列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, int]: 以user_id为key的记忆数量字典
|
||||||
|
格式: {"user_id": total_count}
|
||||||
|
"""
|
||||||
|
if not end_user_ids:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
result = await _neo4j_connector.execute_query(
|
||||||
|
MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
|
||||||
|
end_user_ids=end_user_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 转换结果为字典格式,字典格式在查询中无需遍历结果集,直接返回
|
||||||
|
data = {}
|
||||||
|
for row in result:
|
||||||
|
data[row["user_id"]] = row["total"]
|
||||||
|
|
||||||
|
# 为没有数据的用户填充默认值,转换字典格式还为无数据填充默认值
|
||||||
|
for user_id in end_user_ids:
|
||||||
|
if user_id not in data:
|
||||||
|
data[user_id] = 0
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
async def analytics_hot_memory_tags(
|
async def analytics_hot_memory_tags(
|
||||||
db: Session,
|
db: Session,
|
||||||
current_user: User,
|
current_user: User,
|
||||||
|
|||||||
@@ -69,7 +69,8 @@ class ModelConfigService:
|
|||||||
return items
|
return items
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
def get_model_by_name(db: Session, name: str, provider: str | None = None,
|
||||||
|
tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||||
"""根据名称获取模型配置"""
|
"""根据名称获取模型配置"""
|
||||||
model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id)
|
model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id)
|
||||||
if not model:
|
if not model:
|
||||||
@@ -77,21 +78,22 @@ class ModelConfigService:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ModelConfig]:
|
def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[
|
||||||
|
ModelConfig]:
|
||||||
"""按名称模糊匹配获取模型配置列表"""
|
"""按名称模糊匹配获取模型配置列表"""
|
||||||
return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit)
|
return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def validate_model_config(
|
async def validate_model_config(
|
||||||
db: Session,
|
db: Session,
|
||||||
*,
|
*,
|
||||||
model_name: str,
|
model_name: str,
|
||||||
provider: str,
|
provider: str,
|
||||||
api_key: str,
|
api_key: str,
|
||||||
api_base: Optional[str] = None,
|
api_base: Optional[str] = None,
|
||||||
model_type: str = "llm",
|
model_type: str = "llm",
|
||||||
test_message: str = "Hello",
|
test_message: str = "Hello",
|
||||||
is_omni: bool = False
|
is_omni: bool = False
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""验证模型配置是否有效
|
"""验证模型配置是否有效
|
||||||
|
|
||||||
@@ -158,13 +160,13 @@ class ModelConfigService:
|
|||||||
# 统一使用 RedBearEmbeddings(自动支持火山引擎多模态)
|
# 统一使用 RedBearEmbeddings(自动支持火山引擎多模态)
|
||||||
embedding = RedBearEmbeddings(model_config)
|
embedding = RedBearEmbeddings(model_config)
|
||||||
test_texts = [test_message, "测试文本"]
|
test_texts = [test_message, "测试文本"]
|
||||||
|
|
||||||
# 火山引擎使用 embed_batch,其他使用 embed_documents
|
# 火山引擎使用 embed_batch,其他使用 embed_documents
|
||||||
if provider.lower() == "volcano":
|
if provider.lower() == "volcano":
|
||||||
vectors = await asyncio.to_thread(embedding.embed_batch, test_texts)
|
vectors = await asyncio.to_thread(embedding.embed_batch, test_texts)
|
||||||
else:
|
else:
|
||||||
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
|
vectors = await asyncio.to_thread(embedding.embed_documents, test_texts)
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -200,11 +202,11 @@ class ModelConfigService:
|
|||||||
},
|
},
|
||||||
"error": None
|
"error": None
|
||||||
}
|
}
|
||||||
|
|
||||||
elif model_type_lower == "image":
|
elif model_type_lower == "image":
|
||||||
# 图片生成模型验证
|
# 图片生成模型验证
|
||||||
from app.core.models.generation import RedBearImageGenerator
|
from app.core.models.generation import RedBearImageGenerator
|
||||||
|
|
||||||
generator = RedBearImageGenerator(model_config)
|
generator = RedBearImageGenerator(model_config)
|
||||||
result = await generator.agenerate(
|
result = await generator.agenerate(
|
||||||
prompt="a cute panda",
|
prompt="a cute panda",
|
||||||
@@ -212,7 +214,7 @@ class ModelConfigService:
|
|||||||
)
|
)
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
logger.info(f"成功生成图片,结果: {result}")
|
logger.info(f"成功生成图片,结果: {result}")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"valid": True,
|
"valid": True,
|
||||||
"message": "图片生成模型配置验证成功",
|
"message": "图片生成模型配置验证成功",
|
||||||
@@ -224,21 +226,21 @@ class ModelConfigService:
|
|||||||
},
|
},
|
||||||
"error": None
|
"error": None
|
||||||
}
|
}
|
||||||
|
|
||||||
elif model_type_lower == "video":
|
elif model_type_lower == "video":
|
||||||
# 视频生成模型验证
|
# 视频生成模型验证
|
||||||
from app.core.models.generation import RedBearVideoGenerator
|
from app.core.models.generation import RedBearVideoGenerator
|
||||||
|
|
||||||
generator = RedBearVideoGenerator(model_config)
|
generator = RedBearVideoGenerator(model_config)
|
||||||
result = await generator.agenerate(
|
result = await generator.agenerate(
|
||||||
prompt="a cute panda playing in bamboo forest",
|
prompt="a cute panda playing in bamboo forest",
|
||||||
duration=5
|
duration=5
|
||||||
)
|
)
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
# 视频生成是异步任务,返回任务ID
|
# 视频生成是异步任务,返回任务ID
|
||||||
task_id = result.get("task_id") if isinstance(result, dict) else None
|
task_id = result.get("task_id") if isinstance(result, dict) else None
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"valid": True,
|
"valid": True,
|
||||||
"message": "视频生成模型配置验证成功",
|
"message": "视频生成模型配置验证成功",
|
||||||
@@ -265,7 +267,6 @@ class ModelConfigService:
|
|||||||
# 提取详细的错误信息
|
# 提取详细的错误信息
|
||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
error_type = type(e).__name__
|
error_type = type(e).__name__
|
||||||
print("=========error_message:",error_message.lower())
|
|
||||||
# 特殊处理常见的错误类型
|
# 特殊处理常见的错误类型
|
||||||
if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower():
|
if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower():
|
||||||
# 区域/国家限制(适用于所有提供商)
|
# 区域/国家限制(适用于所有提供商)
|
||||||
@@ -354,14 +355,16 @@ class ModelConfigService:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate,
|
||||||
|
tenant_id: uuid.UUID | None = None) -> ModelConfig:
|
||||||
"""更新模型配置"""
|
"""更新模型配置"""
|
||||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||||
if not existing_model:
|
if not existing_model:
|
||||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||||
|
|
||||||
if model_data.name and model_data.name != existing_model.name:
|
if model_data.name and model_data.name != existing_model.name:
|
||||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
|
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
|
||||||
|
tenant_id=tenant_id):
|
||||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||||
|
|
||||||
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
|
model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id)
|
||||||
@@ -370,25 +373,27 @@ class ModelConfigService:
|
|||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate,
|
||||||
|
tenant_id: uuid.UUID) -> ModelConfig:
|
||||||
"""创建组合模型"""
|
"""创建组合模型"""
|
||||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id):
|
if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE,
|
||||||
|
tenant_id=tenant_id):
|
||||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||||
|
|
||||||
# 验证所有 API Key 存在且类型匹配
|
# 验证所有 API Key 存在且类型匹配
|
||||||
for api_key_id in model_data.api_key_ids:
|
for api_key_id in model_data.api_key_ids:
|
||||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
|
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
|
||||||
|
|
||||||
# 检查 API Key 关联的模型配置类型
|
# 检查 API Key 关联的模型配置类型
|
||||||
for model_config in api_key.model_configs:
|
for model_config in api_key.model_configs:
|
||||||
# chat 和 llm 类型可以兼容
|
# chat 和 llm 类型可以兼容
|
||||||
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
||||||
config_type = model_config.type
|
config_type = model_config.type
|
||||||
request_type = model_data.type
|
request_type = model_data.type
|
||||||
|
|
||||||
if not (config_type == request_type or
|
if not (config_type == request_type or
|
||||||
(config_type in compatible_types and request_type in compatible_types)):
|
(config_type in compatible_types and request_type in compatible_types)):
|
||||||
raise BusinessException(
|
raise BusinessException(
|
||||||
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
|
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
|
||||||
@@ -399,7 +404,7 @@ class ModelConfigService:
|
|||||||
# f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型",
|
# f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型",
|
||||||
# BizCode.INVALID_PARAMETER
|
# BizCode.INVALID_PARAMETER
|
||||||
# )
|
# )
|
||||||
|
|
||||||
# 创建组合模型
|
# 创建组合模型
|
||||||
model_config_data = {
|
model_config_data = {
|
||||||
"tenant_id": tenant_id,
|
"tenant_id": tenant_id,
|
||||||
@@ -418,49 +423,51 @@ class ModelConfigService:
|
|||||||
|
|
||||||
model = ModelConfigRepository.create(db, model_config_data)
|
model = ModelConfigRepository.create(db, model_config_data)
|
||||||
db.flush()
|
db.flush()
|
||||||
|
|
||||||
# 关联 API Keys
|
# 关联 API Keys
|
||||||
for api_key_id in model_data.api_key_ids:
|
for api_key_id in model_data.api_key_ids:
|
||||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||||
if api_key:
|
if api_key:
|
||||||
model.api_keys.append(api_key)
|
model.api_keys.append(api_key)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(model)
|
db.refresh(model)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig:
|
async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate,
|
||||||
|
tenant_id: uuid.UUID) -> ModelConfig:
|
||||||
"""更新组合模型"""
|
"""更新组合模型"""
|
||||||
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id)
|
||||||
if not existing_model:
|
if not existing_model:
|
||||||
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND)
|
||||||
|
|
||||||
if model_data.name and model_data.name != existing_model.name:
|
if model_data.name and model_data.name != existing_model.name:
|
||||||
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id):
|
if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider,
|
||||||
|
tenant_id=tenant_id):
|
||||||
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME)
|
||||||
|
|
||||||
if not existing_model.is_composite:
|
if not existing_model.is_composite:
|
||||||
raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER)
|
raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER)
|
||||||
|
|
||||||
# 验证所有 API Key 存在且类型匹配
|
# 验证所有 API Key 存在且类型匹配
|
||||||
for api_key_id in model_data.api_key_ids:
|
for api_key_id in model_data.api_key_ids:
|
||||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||||
if not api_key:
|
if not api_key:
|
||||||
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
|
raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND)
|
||||||
|
|
||||||
for model_config in api_key.model_configs:
|
for model_config in api_key.model_configs:
|
||||||
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
compatible_types = {ModelType.LLM, ModelType.CHAT}
|
||||||
config_type = model_config.type
|
config_type = model_config.type
|
||||||
request_type = existing_model.type
|
request_type = existing_model.type
|
||||||
|
|
||||||
if not (config_type == request_type or
|
if not (config_type == request_type or
|
||||||
(config_type in compatible_types and request_type in compatible_types)):
|
(config_type in compatible_types and request_type in compatible_types)):
|
||||||
raise BusinessException(
|
raise BusinessException(
|
||||||
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
|
f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配",
|
||||||
BizCode.INVALID_PARAMETER
|
BizCode.INVALID_PARAMETER
|
||||||
)
|
)
|
||||||
|
|
||||||
# 更新基本信息
|
# 更新基本信息
|
||||||
existing_model.name = model_data.name
|
existing_model.name = model_data.name
|
||||||
# existing_model.type = model_data.type
|
# existing_model.type = model_data.type
|
||||||
@@ -471,14 +478,14 @@ class ModelConfigService:
|
|||||||
existing_model.is_public = model_data.is_public
|
existing_model.is_public = model_data.is_public
|
||||||
if "load_balance_strategy" in model_data.model_fields_set:
|
if "load_balance_strategy" in model_data.model_fields_set:
|
||||||
existing_model.load_balance_strategy = model_data.load_balance_strategy
|
existing_model.load_balance_strategy = model_data.load_balance_strategy
|
||||||
|
|
||||||
# 更新 API Keys 关联
|
# 更新 API Keys 关联
|
||||||
existing_model.api_keys.clear()
|
existing_model.api_keys.clear()
|
||||||
for api_key_id in model_data.api_key_ids:
|
for api_key_id in model_data.api_key_ids:
|
||||||
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
api_key = ModelApiKeyRepository.get_by_id(db, api_key_id)
|
||||||
if api_key:
|
if api_key:
|
||||||
existing_model.api_keys.append(api_key)
|
existing_model.api_keys.append(api_key)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(existing_model)
|
db.refresh(existing_model)
|
||||||
return existing_model
|
return existing_model
|
||||||
@@ -532,7 +539,7 @@ class ModelApiKeyService:
|
|||||||
"""根据provider为多个ModelConfig创建API Key"""
|
"""根据provider为多个ModelConfig创建API Key"""
|
||||||
created_keys = []
|
created_keys = []
|
||||||
failed_models = [] # 记录验证失败的模型
|
failed_models = [] # 记录验证失败的模型
|
||||||
|
|
||||||
for model_config_id in data.model_config_ids:
|
for model_config_id in data.model_config_ids:
|
||||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||||
if not model_config:
|
if not model_config:
|
||||||
@@ -540,10 +547,10 @@ class ModelApiKeyService:
|
|||||||
|
|
||||||
data.is_omni = model_config.is_omni
|
data.is_omni = model_config.is_omni
|
||||||
data.capability = model_config.capability
|
data.capability = model_config.capability
|
||||||
|
|
||||||
# 从ModelBase获取model_name
|
# 从ModelBase获取model_name
|
||||||
model_name = model_config.model_base.name if model_config.model_base else model_config.name
|
model_name = model_config.model_base.name if model_config.model_base else model_config.name
|
||||||
|
|
||||||
# 检查是否存在API Key(包括软删除),需要考虑tenant_id
|
# 检查是否存在API Key(包括软删除),需要考虑tenant_id
|
||||||
existing_key = db.query(ModelApiKey).join(
|
existing_key = db.query(ModelApiKey).join(
|
||||||
ModelApiKey.model_configs
|
ModelApiKey.model_configs
|
||||||
@@ -553,7 +560,7 @@ class ModelApiKeyService:
|
|||||||
ModelApiKey.model_name == model_name,
|
ModelApiKey.model_name == model_name,
|
||||||
ModelConfig.tenant_id == model_config.tenant_id
|
ModelConfig.tenant_id == model_config.tenant_id
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if existing_key:
|
if existing_key:
|
||||||
# 如果已存在,重新激活并更新
|
# 如果已存在,重新激活并更新
|
||||||
if existing_key.is_active:
|
if existing_key.is_active:
|
||||||
@@ -566,14 +573,14 @@ class ModelApiKeyService:
|
|||||||
existing_key.model_name = model_name
|
existing_key.model_name = model_name
|
||||||
existing_key.capability = data.capability
|
existing_key.capability = data.capability
|
||||||
existing_key.is_omni = data.is_omni
|
existing_key.is_omni = data.is_omni
|
||||||
|
|
||||||
# 检查是否已关联该模型配置
|
# 检查是否已关联该模型配置
|
||||||
if model_config not in existing_key.model_configs:
|
if model_config not in existing_key.model_configs:
|
||||||
existing_key.model_configs.append(model_config)
|
existing_key.model_configs.append(model_config)
|
||||||
|
|
||||||
created_keys.append(existing_key)
|
created_keys.append(existing_key)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 验证配置
|
# 验证配置
|
||||||
validation_result = await ModelConfigService.validate_model_config(
|
validation_result = await ModelConfigService.validate_model_config(
|
||||||
db=db,
|
db=db,
|
||||||
@@ -589,7 +596,7 @@ class ModelApiKeyService:
|
|||||||
# 记录验证失败的模型,但不抛出异常
|
# 记录验证失败的模型,但不抛出异常
|
||||||
failed_models.append(model_name)
|
failed_models.append(model_name)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 创建API Key
|
# 创建API Key
|
||||||
api_key_data = ModelApiKeyCreate(
|
api_key_data = ModelApiKeyCreate(
|
||||||
model_config_ids=[model_config_id],
|
model_config_ids=[model_config_id],
|
||||||
@@ -606,12 +613,12 @@ class ModelApiKeyService:
|
|||||||
)
|
)
|
||||||
api_key_obj = ModelApiKeyRepository.create(db, api_key_data)
|
api_key_obj = ModelApiKeyRepository.create(db, api_key_data)
|
||||||
created_keys.append(api_key_obj)
|
created_keys.append(api_key_obj)
|
||||||
|
|
||||||
if created_keys:
|
if created_keys:
|
||||||
db.commit()
|
db.commit()
|
||||||
for key in created_keys:
|
for key in created_keys:
|
||||||
db.refresh(key)
|
db.refresh(key)
|
||||||
|
|
||||||
return created_keys, failed_models
|
return created_keys, failed_models
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -626,7 +633,7 @@ class ModelApiKeyService:
|
|||||||
api_key_data.is_omni = model_config.is_omni
|
api_key_data.is_omni = model_config.is_omni
|
||||||
if api_key_data.capability is None:
|
if api_key_data.capability is None:
|
||||||
api_key_data.capability = model_config.capability
|
api_key_data.capability = model_config.capability
|
||||||
|
|
||||||
# 检查API Key是否已存在(包括软删除),需要考虑tenant_id
|
# 检查API Key是否已存在(包括软删除),需要考虑tenant_id
|
||||||
existing_key = db.query(ModelApiKey).join(
|
existing_key = db.query(ModelApiKey).join(
|
||||||
ModelApiKey.model_configs
|
ModelApiKey.model_configs
|
||||||
@@ -650,15 +657,15 @@ class ModelApiKeyService:
|
|||||||
existing_key.model_name = api_key_data.model_name
|
existing_key.model_name = api_key_data.model_name
|
||||||
existing_key.capability = api_key_data.capability
|
existing_key.capability = api_key_data.capability
|
||||||
existing_key.is_omni = api_key_data.is_omni
|
existing_key.is_omni = api_key_data.is_omni
|
||||||
|
|
||||||
# 检查是否已关联该模型配置
|
# 检查是否已关联该模型配置
|
||||||
if model_config not in existing_key.model_configs:
|
if model_config not in existing_key.model_configs:
|
||||||
existing_key.model_configs.append(model_config)
|
existing_key.model_configs.append(model_config)
|
||||||
|
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(existing_key)
|
db.refresh(existing_key)
|
||||||
return existing_key
|
return existing_key
|
||||||
|
|
||||||
# 验证配置
|
# 验证配置
|
||||||
validation_result = await ModelConfigService.validate_model_config(
|
validation_result = await ModelConfigService.validate_model_config(
|
||||||
db=db,
|
db=db,
|
||||||
@@ -691,7 +698,7 @@ class ModelApiKeyService:
|
|||||||
# 获取关联的模型配置以获取模型类型
|
# 获取关联的模型配置以获取模型类型
|
||||||
if existing_api_key.model_configs:
|
if existing_api_key.model_configs:
|
||||||
model_config = existing_api_key.model_configs[0]
|
model_config = existing_api_key.model_configs[0]
|
||||||
|
|
||||||
validation_result = await ModelConfigService.validate_model_config(
|
validation_result = await ModelConfigService.validate_model_config(
|
||||||
db=db,
|
db=db,
|
||||||
model_name=api_key_data.model_name or existing_api_key.model_name,
|
model_name=api_key_data.model_name or existing_api_key.model_name,
|
||||||
@@ -729,15 +736,15 @@ class ModelApiKeyService:
|
|||||||
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
model_config = ModelConfigRepository.get_by_id(db, model_config_id)
|
||||||
if not model_config:
|
if not model_config:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
api_keys = [key for key in model_config.api_keys if key.is_active]
|
api_keys = [key for key in model_config.api_keys if key.is_active]
|
||||||
if not api_keys:
|
if not api_keys:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 如果是轮询策略,按使用次数最少,次数相同则选最早使用的
|
# 如果是轮询策略,按使用次数最少,次数相同则选最早使用的
|
||||||
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
|
if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN:
|
||||||
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
|
return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min))
|
||||||
|
|
||||||
# 否则返回第一个
|
# 否则返回第一个
|
||||||
return api_keys[0]
|
return api_keys[0]
|
||||||
|
|
||||||
@@ -760,20 +767,19 @@ class ModelApiKeyService:
|
|||||||
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class ModelBaseService:
|
class ModelBaseService:
|
||||||
"""基础模型服务"""
|
"""基础模型服务"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List:
|
def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List:
|
||||||
models = ModelBaseRepository.get_list(db, query)
|
models = ModelBaseRepository.get_list(db, query)
|
||||||
|
|
||||||
provider_groups = {}
|
provider_groups = {}
|
||||||
for m in models:
|
for m in models:
|
||||||
model_dict = model_schema.ModelBase.model_validate(m).model_dump()
|
model_dict = model_schema.ModelBase.model_validate(m).model_dump()
|
||||||
if tenant_id:
|
if tenant_id:
|
||||||
model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id)
|
model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id)
|
||||||
|
|
||||||
provider = m.provider
|
provider = m.provider
|
||||||
if provider not in provider_groups:
|
if provider not in provider_groups:
|
||||||
provider_groups[provider] = {
|
provider_groups[provider] = {
|
||||||
@@ -781,7 +787,7 @@ class ModelBaseService:
|
|||||||
"models": []
|
"models": []
|
||||||
}
|
}
|
||||||
provider_groups[provider]["models"].append(model_dict)
|
provider_groups[provider]["models"].append(model_dict)
|
||||||
|
|
||||||
return list(provider_groups.values())
|
return list(provider_groups.values())
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -823,10 +829,10 @@ class ModelBaseService:
|
|||||||
model_base = ModelBaseRepository.get_by_id(db, model_base_id)
|
model_base = ModelBaseRepository.get_by_id(db, model_base_id)
|
||||||
if not model_base:
|
if not model_base:
|
||||||
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
|
raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND)
|
||||||
|
|
||||||
if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id):
|
if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id):
|
||||||
raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME)
|
raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME)
|
||||||
|
|
||||||
model_config_data = {
|
model_config_data = {
|
||||||
"model_id": model_base_id,
|
"model_id": model_base_id,
|
||||||
"tenant_id": tenant_id,
|
"tenant_id": tenant_id,
|
||||||
|
|||||||
@@ -12,6 +12,9 @@ import base64
|
|||||||
import csv
|
import csv
|
||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
|
import olefile
|
||||||
|
import struct
|
||||||
import zipfile
|
import zipfile
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
@@ -438,13 +441,13 @@ class MultimodalService:
|
|||||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||||
return True, {
|
return True, {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
|
"text": f"<document url=\"{file.url}\">\n{await self.extract_document_text(file)}\n</document>"
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
# 本地文件,提取文本内容
|
# 本地文件,提取文本内容
|
||||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||||
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
|
file.url = f"{server_url}/storage/permanent/{file.upload_file_id}"
|
||||||
text = await self._extract_document_text(file)
|
text = await self.extract_document_text(file)
|
||||||
file_metadata = self.db.query(FileMetadata).filter(
|
file_metadata = self.db.query(FileMetadata).filter(
|
||||||
FileMetadata.id == file.upload_file_id
|
FileMetadata.id == file.upload_file_id
|
||||||
).first()
|
).first()
|
||||||
@@ -542,7 +545,7 @@ class MultimodalService:
|
|||||||
server_url = settings.FILE_LOCAL_SERVER_URL
|
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||||
return f"{server_url}/storage/permanent/{file_id}"
|
return f"{server_url}/storage/permanent/{file_id}"
|
||||||
|
|
||||||
async def _extract_document_text(self, file: FileInput) -> str:
|
async def extract_document_text(self, file: FileInput) -> str:
|
||||||
"""
|
"""
|
||||||
提取文档文本内容
|
提取文档文本内容
|
||||||
|
|
||||||
@@ -602,31 +605,75 @@ class MultimodalService:
|
|||||||
try:
|
try:
|
||||||
word_file = io.BytesIO(file_content)
|
word_file = io.BytesIO(file_content)
|
||||||
doc = Document(word_file)
|
doc = Document(word_file)
|
||||||
return '\n'.join(p.text for p in doc.paragraphs)
|
text_lines = []
|
||||||
|
for p in doc.paragraphs:
|
||||||
|
text = p.text.strip()
|
||||||
|
if text:
|
||||||
|
text_lines.append(text)
|
||||||
|
|
||||||
|
for table in doc.tables:
|
||||||
|
for row in table.rows:
|
||||||
|
for cell in row.cells:
|
||||||
|
text = cell.text.strip()
|
||||||
|
if text:
|
||||||
|
text_lines.append(text)
|
||||||
|
|
||||||
|
full_text = "\n".join(text_lines)
|
||||||
|
return full_text.strip() or "[docx 文件无文本内容]"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"提取 docx 文本失败: {e}")
|
logger.error(f"提取 docx 文本失败: {str(e)}", exc_info=True)
|
||||||
return f"[docx 提取失败: {str(e)}]"
|
return f"[docx 提取失败: {str(e)}]"
|
||||||
|
|
||||||
# 旧版 .doc(OLE2 格式)
|
# 旧版 .doc(OLE2/CFB 格式),按 Word Binary Format 规范解析 piece table
|
||||||
try:
|
try:
|
||||||
import olefile
|
|
||||||
ole = olefile.OleFileIO(io.BytesIO(file_content))
|
ole = olefile.OleFileIO(io.BytesIO(file_content))
|
||||||
if not ole.exists('WordDocument'):
|
word_stream = ole.openstream('WordDocument').read()
|
||||||
return "[doc 提取失败: 未找到 WordDocument 流]"
|
|
||||||
# 读取 WordDocument 流,提取可见 ASCII/Unicode 文本
|
# FIB offset 0xA bit9 决定使用 0Table 还是 1Table
|
||||||
stream = ole.openstream('WordDocument').read()
|
fib_flags = struct.unpack_from('<H', word_stream, 0xA)[0]
|
||||||
# Word Binary Format: 文本在流中以 UTF-16-LE 编码存储
|
table_name = '1Table' if (fib_flags & 0x0200) else '0Table'
|
||||||
# 简单提取:过滤出可打印字符段
|
table_stream = ole.openstream(table_name).read()
|
||||||
try:
|
|
||||||
text = stream.decode('utf-16-le', errors='ignore')
|
# 从 FIB 读取 fcClx/lcbClx 定位 piece table
|
||||||
except Exception:
|
fc_clx, lcb_clx = struct.unpack_from("<II", word_stream, 0x1A2)
|
||||||
text = stream.decode('latin-1', errors='ignore')
|
clx = table_stream[fc_clx: fc_clx + lcb_clx]
|
||||||
# 过滤控制字符,保留可打印内容
|
|
||||||
import re
|
# 解析 CLX,找到 PlcPcd(piece table)
|
||||||
text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text)
|
i, plc_pcd = 0, None
|
||||||
text = re.sub(r' +', ' ', text).strip()
|
while i < len(clx):
|
||||||
|
clxt = clx[i]
|
||||||
|
if clxt == 0x01:
|
||||||
|
i += 3 + struct.unpack_from('<H', clx, i + 1)[0]
|
||||||
|
elif clxt == 0x02:
|
||||||
|
cb = struct.unpack_from('<I', clx, i + 1)[0]
|
||||||
|
plc_pcd = clx[i + 5: i + 5 + cb]
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
if plc_pcd is None:
|
||||||
|
raise ValueError("PlcPcd not found")
|
||||||
|
|
||||||
|
# PlcPcd: (n+1) 个 CP(4字节)+ n 个 PCD(8字节)
|
||||||
|
n_pieces = (len(plc_pcd) - 4) // 12
|
||||||
|
cp_array = [struct.unpack_from('<I', plc_pcd, k * 4)[0] for k in range(n_pieces + 1)]
|
||||||
|
|
||||||
|
parts = []
|
||||||
|
for k in range(n_pieces):
|
||||||
|
fc_value = struct.unpack_from('<I', plc_pcd, (n_pieces + 1) * 4 + k * 8 + 2)[0]
|
||||||
|
is_ansi = bool(fc_value & 0x40000000)
|
||||||
|
fc = fc_value & 0x3FFFFFFF
|
||||||
|
char_count = cp_array[k + 1] - cp_array[k]
|
||||||
|
|
||||||
|
if is_ansi:
|
||||||
|
parts.append(word_stream[fc: fc + char_count].decode('cp1252', errors='replace'))
|
||||||
|
else:
|
||||||
|
parts.append(word_stream[fc: fc + char_count * 2].decode('utf-16-le', errors='replace'))
|
||||||
|
|
||||||
ole.close()
|
ole.close()
|
||||||
return text
|
result = re.sub(r'[\x00-\x1f\x7f]', '', ''.join(parts))
|
||||||
|
return result.strip()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"提取 doc 文本失败: {e}")
|
logger.error(f"提取 doc 文本失败: {e}")
|
||||||
return f"[doc 提取失败: {str(e)}]"
|
return f"[doc 提取失败: {str(e)}]"
|
||||||
|
|||||||
@@ -1,26 +1,24 @@
|
|||||||
"""基于分享链接的聊天服务"""
|
"""基于分享链接的聊天服务"""
|
||||||
import uuid
|
|
||||||
import time
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
from typing import Optional, Dict, Any, AsyncGenerator
|
from typing import Optional, Dict, Any, AsyncGenerator
|
||||||
|
|
||||||
|
from deprecated import deprecated
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.repositories.model_repository import ModelApiKeyRepository
|
from app.core.error_codes import BizCode
|
||||||
from app.services.memory_konwledges_server import write_rag
|
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.models import MultiAgentConfig
|
||||||
from app.models import ReleaseShare, AppRelease, Conversation
|
from app.models import ReleaseShare, AppRelease, Conversation
|
||||||
|
from app.repositories import knowledge_repository
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.draft_run_service import create_web_search_tool
|
from app.services.draft_run_service import create_web_search_tool
|
||||||
from app.services.model_service import ModelApiKeyService
|
from app.services.model_service import ModelApiKeyService
|
||||||
from app.services.release_share_service import ReleaseShareService
|
|
||||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
|
||||||
from app.core.error_codes import BizCode
|
|
||||||
from app.core.logging_config import get_business_logger
|
|
||||||
from app.services.multi_agent_service import MultiAgentService
|
from app.services.multi_agent_service import MultiAgentService
|
||||||
from app.models import MultiAgentConfig
|
from app.services.release_share_service import ReleaseShareService
|
||||||
from app.repositories import knowledge_repository
|
|
||||||
import json
|
|
||||||
from app.services.task_service import get_task_memory_write_result
|
|
||||||
from app.tasks import write_message_task
|
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
@@ -118,6 +116,7 @@ class SharedChatService:
|
|||||||
|
|
||||||
return conversation
|
return conversation
|
||||||
|
|
||||||
|
@deprecated("Use the chat method under app_chat_service instead.")
|
||||||
async def chat(
|
async def chat(
|
||||||
self,
|
self,
|
||||||
share_token: str,
|
share_token: str,
|
||||||
@@ -136,10 +135,7 @@ class SharedChatService:
|
|||||||
config_id = actual_config_id
|
config_id = actual_config_id
|
||||||
from app.core.agent.langchain_agent import LangChainAgent
|
from app.core.agent.langchain_agent import LangChainAgent
|
||||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||||
from app.services.model_parameter_merger import ModelParameterMerger
|
|
||||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||||
from sqlalchemy import select
|
|
||||||
from app.models import ModelApiKey
|
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
actual_config_id = None
|
actual_config_id = None
|
||||||
@@ -273,11 +269,6 @@ class SharedChatService:
|
|||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
context=None,
|
context=None,
|
||||||
end_user_id=user_id,
|
|
||||||
storage_type=storage_type,
|
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
|
||||||
config_id=config_id,
|
|
||||||
memory_flag=memory_flag
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 保存消息
|
# 保存消息
|
||||||
@@ -324,6 +315,7 @@ class SharedChatService:
|
|||||||
"elapsed_time": elapsed_time
|
"elapsed_time": elapsed_time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@deprecated("Use the chat method under app_chat_service instead.")
|
||||||
async def chat_stream(
|
async def chat_stream(
|
||||||
self,
|
self,
|
||||||
share_token: str,
|
share_token: str,
|
||||||
@@ -341,8 +333,6 @@ class SharedChatService:
|
|||||||
from app.core.agent.langchain_agent import LangChainAgent
|
from app.core.agent.langchain_agent import LangChainAgent
|
||||||
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool
|
||||||
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole
|
||||||
from sqlalchemy import select
|
|
||||||
from app.models import ModelApiKey
|
|
||||||
import json
|
import json
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
@@ -486,11 +476,6 @@ class SharedChatService:
|
|||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
context=None,
|
context=None,
|
||||||
end_user_id=user_id,
|
|
||||||
storage_type=storage_type,
|
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
|
||||||
config_id=config_id,
|
|
||||||
memory_flag=memory_flag
|
|
||||||
):
|
):
|
||||||
if isinstance(chunk, int):
|
if isinstance(chunk, int):
|
||||||
total_tokens = chunk
|
total_tokens = chunk
|
||||||
@@ -585,6 +570,7 @@ class SharedChatService:
|
|||||||
|
|
||||||
return conversations, total
|
return conversations, total
|
||||||
|
|
||||||
|
@deprecated("Use the chat method under app_chat_service instead.")
|
||||||
async def multi_agent_chat(
|
async def multi_agent_chat(
|
||||||
self,
|
self,
|
||||||
share_token: str,
|
share_token: str,
|
||||||
@@ -680,6 +666,7 @@ class SharedChatService:
|
|||||||
"elapsed_time": elapsed_time
|
"elapsed_time": elapsed_time
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@deprecated("Use the chat method under app_chat_service instead.")
|
||||||
async def multi_agent_chat_stream(
|
async def multi_agent_chat_stream(
|
||||||
self,
|
self,
|
||||||
share_token: str,
|
share_token: str,
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ class TenantService:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
business_logger.error(f"删除租户失败: {str(e)}")
|
business_logger.error(f"删除租户失败: {str(e)}")
|
||||||
raise BusinessException(f"删除租户失败: {str(e)}", code=BizCode.DB_ERROR)
|
raise BusinessException(f"删除租户失败:{str(e)}", code=BizCode.DB_ERROR)
|
||||||
|
|
||||||
# 租户用户管理
|
# 租户用户管理
|
||||||
def get_tenant_users(
|
def get_tenant_users(
|
||||||
@@ -147,6 +147,7 @@ class TenantService:
|
|||||||
skip: int = 0,
|
skip: int = 0,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
is_active: Optional[bool] = None,
|
is_active: Optional[bool] = None,
|
||||||
|
is_superuser: Optional[bool] = None,
|
||||||
search: Optional[str] = None
|
search: Optional[str] = None
|
||||||
) -> List[UserModel]:
|
) -> List[UserModel]:
|
||||||
"""获取租户下的用户列表"""
|
"""获取租户下的用户列表"""
|
||||||
@@ -155,6 +156,7 @@ class TenantService:
|
|||||||
skip=skip,
|
skip=skip,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
is_active=is_active,
|
is_active=is_active,
|
||||||
|
is_superuser=is_superuser,
|
||||||
search=search
|
search=search
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -162,12 +164,14 @@ class TenantService:
|
|||||||
self,
|
self,
|
||||||
tenant_id: uuid.UUID,
|
tenant_id: uuid.UUID,
|
||||||
is_active: Optional[bool] = None,
|
is_active: Optional[bool] = None,
|
||||||
|
is_superuser: Optional[bool] = None,
|
||||||
search: Optional[str] = None
|
search: Optional[str] = None
|
||||||
) -> int:
|
) -> int:
|
||||||
"""统计租户下的用户数量"""
|
"""统计租户下的用户数量"""
|
||||||
return self.user_repo.count_users_by_tenant(
|
return self.user_repo.count_users_by_tenant(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
is_active=is_active,
|
is_active=is_active,
|
||||||
|
is_superuser=is_superuser,
|
||||||
search=search
|
search=search
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -472,6 +472,21 @@ class UserMemoryService:
|
|||||||
# 定义允许更新的字段白名单
|
# 定义允许更新的字段白名单
|
||||||
allowed_fields = {'other_name', 'aliases', 'meta_data'}
|
allowed_fields = {'other_name', 'aliases', 'meta_data'}
|
||||||
|
|
||||||
|
# 用户占位名称黑名单,不允许作为 other_name 或出现在 aliases 中
|
||||||
|
_user_placeholder_names = {'用户', '我', 'User', 'I'}
|
||||||
|
|
||||||
|
# 过滤 other_name:不允许设置为占位名称
|
||||||
|
if 'other_name' in update_data and update_data['other_name'] and update_data['other_name'].strip() in _user_placeholder_names:
|
||||||
|
logger.warning(f"拒绝将占位名称 '{update_data['other_name']}' 设置为 other_name")
|
||||||
|
del update_data['other_name']
|
||||||
|
|
||||||
|
# 过滤 aliases:移除占位名称和非字符串值
|
||||||
|
if 'aliases' in update_data and update_data['aliases']:
|
||||||
|
update_data['aliases'] = [
|
||||||
|
a for a in update_data['aliases']
|
||||||
|
if isinstance(a, str) and a.strip() and a.strip() not in _user_placeholder_names
|
||||||
|
]
|
||||||
|
|
||||||
# 检查是否更新了 aliases 字段
|
# 检查是否更新了 aliases 字段
|
||||||
aliases_updated = 'aliases' in update_data and update_data['aliases'] != end_user_info_record.aliases
|
aliases_updated = 'aliases' in update_data and update_data['aliases'] != end_user_info_record.aliases
|
||||||
|
|
||||||
|
|||||||
@@ -561,6 +561,24 @@ class WorkflowService:
|
|||||||
storage_type = 'neo4j'
|
storage_type = 'neo4j'
|
||||||
return storage_type, user_rag_memory_id
|
return storage_type, user_rag_memory_id
|
||||||
|
|
||||||
|
def _get_history_info(self, conversation_id: uuid.UUID) -> tuple[dict, list] | None:
|
||||||
|
executions = self.execution_repo.get_by_conversation_id(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
status="completed",
|
||||||
|
limit_count=1
|
||||||
|
)
|
||||||
|
|
||||||
|
if executions:
|
||||||
|
last_state = executions[0].output_data
|
||||||
|
if isinstance(last_state, dict):
|
||||||
|
variables = last_state.get("variables", {})
|
||||||
|
conv_vars = variables.get("conv", {})
|
||||||
|
# input_data["conv"] = conv_vars
|
||||||
|
# input_data["conv_messages"] = last_state.get("messages") or []
|
||||||
|
conv_messages = last_state.get("messages") or []
|
||||||
|
return conv_vars, conv_messages
|
||||||
|
return None
|
||||||
|
|
||||||
# ==================== 工作流执行 ====================
|
# ==================== 工作流执行 ====================
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
@@ -634,18 +652,11 @@ class WorkflowService:
|
|||||||
# 更新状态为运行中
|
# 更新状态为运行中
|
||||||
self.update_execution_status(execution.execution_id, "running")
|
self.update_execution_status(execution.execution_id, "running")
|
||||||
|
|
||||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
history = self._get_history_info(conversation_id_uuid)
|
||||||
|
if history:
|
||||||
for exec_res in executions:
|
conv_vars, conv_messages = history
|
||||||
if exec_res.status == "completed":
|
input_data["conv"] = conv_vars
|
||||||
last_state = exec_res.output_data
|
input_data["conv_messages"] = conv_messages
|
||||||
if isinstance(last_state, dict):
|
|
||||||
variables = last_state.get("variables", {})
|
|
||||||
conv_vars = variables.get("conv", {})
|
|
||||||
input_data["conv"] = conv_vars
|
|
||||||
input_data["conv_messages"] = last_state.get("messages") or []
|
|
||||||
break
|
|
||||||
|
|
||||||
init_message_length = len(input_data.get("conv_messages", []))
|
init_message_length = len(input_data.get("conv_messages", []))
|
||||||
|
|
||||||
result = await execute_workflow(
|
result = await execute_workflow(
|
||||||
@@ -807,17 +818,11 @@ class WorkflowService:
|
|||||||
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
|
storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id)
|
||||||
input_data["files"] = files
|
input_data["files"] = files
|
||||||
self.update_execution_status(execution.execution_id, "running")
|
self.update_execution_status(execution.execution_id, "running")
|
||||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
history = self._get_history_info(conversation_id_uuid)
|
||||||
|
if history:
|
||||||
for exec_res in executions:
|
conv_vars, conv_messages = history
|
||||||
if exec_res.status == "completed":
|
input_data["conv"] = conv_vars
|
||||||
last_state = exec_res.output_data
|
input_data["conv_messages"] = conv_messages
|
||||||
if isinstance(last_state, dict):
|
|
||||||
variables = last_state.get("variables", {})
|
|
||||||
conv_vars = variables.get("conv", {})
|
|
||||||
input_data["conv"] = conv_vars
|
|
||||||
input_data["conv_messages"] = last_state.get("messages") or []
|
|
||||||
break
|
|
||||||
init_message_length = len(input_data.get("conv_messages", []))
|
init_message_length = len(input_data.get("conv_messages", []))
|
||||||
message_id = uuid.uuid4()
|
message_id = uuid.uuid4()
|
||||||
async for event in execute_workflow_stream(
|
async for event in execute_workflow_stream(
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import hashlib
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
@@ -38,12 +37,10 @@ from app.db import get_db, get_db_context
|
|||||||
from app.models import Document, File, Knowledge
|
from app.models import Document, File, Knowledge
|
||||||
from app.models.end_user_model import EndUser
|
from app.models.end_user_model import EndUser
|
||||||
from app.schemas import document_schema, file_schema
|
from app.schemas import document_schema, file_schema
|
||||||
from app.schemas.model_schema import ModelInfo
|
|
||||||
from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config
|
from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config
|
||||||
from app.services.memory_forget_service import MemoryForgetService
|
from app.services.memory_forget_service import MemoryForgetService
|
||||||
from app.services.memory_perceptual_service import MemoryPerceptualService
|
|
||||||
from app.utils.config_utils import resolve_config_id
|
from app.utils.config_utils import resolve_config_id
|
||||||
from app.utils.redis_lock import RedisLock
|
from app.utils.redis_lock import RedisFairLock
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -104,7 +101,12 @@ def get_sync_redis_client() -> Optional[redis.StrictRedis]:
|
|||||||
|
|
||||||
|
|
||||||
def set_asyncio_event_loop():
|
def set_asyncio_event_loop():
|
||||||
"""Set the asyncio event loop for the current thread."""
|
"""Ensure an open asyncio event loop exists for the current thread.
|
||||||
|
|
||||||
|
Reuses the existing event loop if one is available and still open.
|
||||||
|
Creates and installs a new event loop only when the current one is
|
||||||
|
closed or missing (e.g. after ``_shutdown_loop_gracefully``).
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
if loop.is_closed():
|
if loop.is_closed():
|
||||||
@@ -116,6 +118,30 @@ def set_asyncio_event_loop():
|
|||||||
return loop
|
return loop
|
||||||
|
|
||||||
|
|
||||||
|
def _shutdown_loop_gracefully(loop: asyncio.AbstractEventLoop):
|
||||||
|
"""Gracefully shutdown pending async generators and tasks on the event loop.
|
||||||
|
|
||||||
|
This prevents 'RuntimeError: Event loop is closed' from httpx.AsyncClient.__del__
|
||||||
|
by giving pending aclose() coroutines a chance to run before the loop is discarded.
|
||||||
|
|
||||||
|
Note: This only tears down the given loop. Callers that need a fresh event
|
||||||
|
loop afterwards should use ``set_asyncio_event_loop()`` explicitly.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Cancel and collect all remaining tasks
|
||||||
|
all_tasks = asyncio.all_tasks(loop)
|
||||||
|
if all_tasks:
|
||||||
|
for task in all_tasks:
|
||||||
|
task.cancel()
|
||||||
|
loop.run_until_complete(asyncio.gather(*all_tasks, return_exceptions=True))
|
||||||
|
# Shutdown async generators (triggers __aclose__ on httpx clients etc.)
|
||||||
|
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
@celery_app.task(name="tasks.process_item")
|
@celery_app.task(name="tasks.process_item")
|
||||||
def process_item(item: dict):
|
def process_item(item: dict):
|
||||||
"""
|
"""
|
||||||
@@ -1148,8 +1174,28 @@ def write_message_task(
|
|||||||
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
logger.info(f"[CELERY WRITE] Write completed successfully: {result}")
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
redis_client = get_sync_redis_client()
|
||||||
|
lock = None
|
||||||
|
if redis_client is not None:
|
||||||
|
lock = RedisFairLock(
|
||||||
|
key=f"memory_write:{end_user_id}",
|
||||||
|
redis_client=redis_client,
|
||||||
|
expire=600,
|
||||||
|
timeout=3600,
|
||||||
|
auto_renewal=True,
|
||||||
|
)
|
||||||
|
if not lock.acquire():
|
||||||
|
logger.warning(f"[CELERY WRITE] 获取锁超时,跳过本次写入: end_user_id={end_user_id}")
|
||||||
|
return {
|
||||||
|
"status": "SKIPPED",
|
||||||
|
"error": "acquire lock timeout",
|
||||||
|
"end_user_id": end_user_id,
|
||||||
|
"config_id": str(config_id),
|
||||||
|
"elapsed_time": time.time() - start_time,
|
||||||
|
"task_id": self.request.id,
|
||||||
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
|
||||||
loop = set_asyncio_event_loop()
|
loop = set_asyncio_event_loop()
|
||||||
|
|
||||||
result = loop.run_until_complete(_run())
|
result = loop.run_until_complete(_run())
|
||||||
@@ -1158,7 +1204,6 @@ def write_message_task(
|
|||||||
logger.info(f"[CELERY WRITE] Task completed successfully "
|
logger.info(f"[CELERY WRITE] Task completed successfully "
|
||||||
f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
|
f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}")
|
||||||
|
|
||||||
# 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用
|
|
||||||
try:
|
try:
|
||||||
_r = get_sync_redis_client()
|
_r = get_sync_redis_client()
|
||||||
if _r is not None:
|
if _r is not None:
|
||||||
@@ -1199,6 +1244,15 @@ def write_message_task(
|
|||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
"task_id": self.request.id
|
"task_id": self.request.id
|
||||||
}
|
}
|
||||||
|
finally:
|
||||||
|
if lock is not None:
|
||||||
|
try:
|
||||||
|
lock.release()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[CELERY WRITE] 释放锁失败: {e}")
|
||||||
|
# Gracefully shutdown the event loop to prevent
|
||||||
|
# 'RuntimeError: Event loop is closed' from httpx.AsyncClient.__del__
|
||||||
|
_shutdown_loop_gracefully(loop)
|
||||||
|
|
||||||
|
|
||||||
# unused task
|
# unused task
|
||||||
@@ -2879,3 +2933,6 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace
|
|||||||
"elapsed_time": time.time() - start_time,
|
"elapsed_time": time.time() - start_time,
|
||||||
"task_id": self.request.id,
|
"task_id": self.request.id,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# unused task
|
||||||
37
api/app/utils/performance_timer.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""
|
||||||
|
性能监控工具模块
|
||||||
|
|
||||||
|
提供代码块执行时间统计功能,用于接口性能分析。
|
||||||
|
如需再次启用性能监控,只需在 controller 中导入 from app.utils.performance_timer import timer 并添加 with timer(...) 包裹需要监控的代码块即可
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
|
||||||
|
# 获取API专用日志器
|
||||||
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def timer(label: str, user_count: int = 0):
|
||||||
|
"""上下文管理器:用于测量代码块执行时间
|
||||||
|
|
||||||
|
Args:
|
||||||
|
label: 统计标签,用于标识被测量的代码块
|
||||||
|
user_count: 用户数,可选参数,用于记录处理的用户数量
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
with timer("获取用户列表"):
|
||||||
|
users = get_users()
|
||||||
|
|
||||||
|
with timer("批量处理", user_count=len(user_ids)):
|
||||||
|
process_users(user_ids)
|
||||||
|
"""
|
||||||
|
start = time.perf_counter()
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
elapsed = (time.perf_counter() - start) * 1000 # 转换为毫秒
|
||||||
|
extra_info = f", 用户数: {user_count}" if user_count > 0 else ""
|
||||||
|
api_logger.info(f"[性能统计] {label}: {elapsed:.2f}ms{extra_info}")
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
import redis
|
import redis
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import time
|
||||||
|
import threading
|
||||||
|
|
||||||
UNLOCK_SCRIPT = """
|
UNLOCK_SCRIPT = """
|
||||||
if redis.call("get", KEYS[1]) == ARGV[1] then
|
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||||||
@@ -10,45 +11,136 @@ else
|
|||||||
end
|
end
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
RENEW_SCRIPT = """
|
||||||
|
if redis.call("get", KEYS[1]) == ARGV[1] then
|
||||||
|
return redis.call("expire", KEYS[1], ARGV[2])
|
||||||
|
else
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
"""
|
||||||
|
|
||||||
class RedisLock:
|
CLEANUP_DEAD_HEAD_SCRIPT = """
|
||||||
|
local queue_key = KEYS[1]
|
||||||
|
local lock_key = KEYS[2]
|
||||||
|
|
||||||
|
local first = redis.call("lindex", queue_key, 0)
|
||||||
|
if not first then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
|
||||||
|
if redis.call("exists", lock_key) == 1 then
|
||||||
|
return 0
|
||||||
|
end
|
||||||
|
|
||||||
|
redis.call("lpop", queue_key)
|
||||||
|
return 1
|
||||||
|
"""
|
||||||
|
|
||||||
|
SAFE_RELEASE_QUEUE_SCRIPT = """
|
||||||
|
local queue_key = KEYS[1]
|
||||||
|
local value = ARGV[1]
|
||||||
|
|
||||||
|
local first = redis.call("lindex", queue_key, 0)
|
||||||
|
if first == value then
|
||||||
|
redis.call("lpop", queue_key)
|
||||||
|
return 1
|
||||||
|
end
|
||||||
|
return 0
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_str(val):
|
||||||
|
"""统一将 Redis 返回值转为 str,兼容 decode_responses=True/False"""
|
||||||
|
if val is None:
|
||||||
|
return None
|
||||||
|
if isinstance(val, bytes):
|
||||||
|
return val.decode("utf-8")
|
||||||
|
return str(val)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisFairLock:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
redis_client: redis.StrictRedis,
|
redis_client: redis.StrictRedis,
|
||||||
expire: int = 60,
|
expire: int = 30,
|
||||||
retry_interval: float = 0.1,
|
retry_interval: float = 0.05,
|
||||||
timeout: float = 30
|
timeout: float = 600,
|
||||||
|
auto_renewal: bool = True
|
||||||
):
|
):
|
||||||
self.key = key
|
self.key = key
|
||||||
self.expire = expire
|
self.queue_key = f"{key}:queue"
|
||||||
self.value = str(uuid.uuid4())
|
self.value = str(uuid.uuid4())
|
||||||
self._locked = False
|
self.expire = expire
|
||||||
self.retry_interval = retry_interval
|
self.retry_interval = retry_interval
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.redis_client = redis_client
|
self.redis = redis_client
|
||||||
|
self._locked = False
|
||||||
|
self.auto_renewal = auto_renewal
|
||||||
|
self._renew_thread = None
|
||||||
|
self._stop_renew = threading.Event()
|
||||||
|
|
||||||
def acquire(self) -> bool:
|
def acquire(self):
|
||||||
start = time.time()
|
start = time.time()
|
||||||
|
|
||||||
|
self.redis.rpush(self.queue_key, self.value)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
ok = self.redis_client.set(self.key, self.value, ex=self.expire, nx=True)
|
first = _ensure_str(self.redis.lindex(self.queue_key, 0))
|
||||||
if ok:
|
|
||||||
self._locked = True
|
if first == self.value:
|
||||||
return True
|
ok = self.redis.set(self.key, self.value, nx=True, ex=self.expire)
|
||||||
if time.time() - start >= self.timeout:
|
if ok:
|
||||||
|
self._locked = True
|
||||||
|
|
||||||
|
if self.auto_renewal:
|
||||||
|
self._start_renewal()
|
||||||
|
return True
|
||||||
|
|
||||||
|
if first:
|
||||||
|
self.redis.eval(CLEANUP_DEAD_HEAD_SCRIPT, 2, self.queue_key, self.key)
|
||||||
|
|
||||||
|
if time.time() - start > self.timeout:
|
||||||
|
self.redis.lrem(self.queue_key, 0, self.value)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
time.sleep(self.retry_interval)
|
time.sleep(self.retry_interval)
|
||||||
|
|
||||||
|
def _renewal_loop(self):
|
||||||
|
while not self._stop_renew.is_set():
|
||||||
|
time.sleep(self.expire / 3)
|
||||||
|
if self._stop_renew.is_set():
|
||||||
|
break
|
||||||
|
|
||||||
|
self.redis.eval(
|
||||||
|
RENEW_SCRIPT,
|
||||||
|
1,
|
||||||
|
self.key,
|
||||||
|
self.value,
|
||||||
|
str(self.expire)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _start_renewal(self):
|
||||||
|
self._stop_renew = threading.Event()
|
||||||
|
self._renew_thread = threading.Thread(target=self._renewal_loop, daemon=True)
|
||||||
|
self._renew_thread.start()
|
||||||
|
|
||||||
|
def _stop_renewal(self):
|
||||||
|
self._stop_renew.set()
|
||||||
|
if self._renew_thread:
|
||||||
|
self._renew_thread.join(timeout=1)
|
||||||
|
|
||||||
def release(self):
|
def release(self):
|
||||||
if not self._locked:
|
if not self._locked:
|
||||||
return
|
return
|
||||||
self.redis_client.eval(
|
|
||||||
UNLOCK_SCRIPT,
|
if self.auto_renewal:
|
||||||
1,
|
self._stop_renewal()
|
||||||
self.key,
|
|
||||||
self.value
|
self.redis.eval(UNLOCK_SCRIPT, 1, self.key, self.value)
|
||||||
)
|
|
||||||
|
self.redis.eval(SAFE_RELEASE_QUEUE_SCRIPT, 1, self.queue_key, self.value)
|
||||||
|
|
||||||
self._locked = False
|
self._locked = False
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
@@ -59,3 +151,4 @@ class RedisLock:
|
|||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
self.release()
|
self.release()
|
||||||
|
|
||||||
|
|||||||
30
api/migrations/versions/4e89970f9e7c_202603271515.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""202603271515
|
||||||
|
|
||||||
|
Revision ID: 4e89970f9e7c
|
||||||
|
Revises: 6b8a461148ff
|
||||||
|
Create Date: 2026-03-27 15:12:27.518344
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '4e89970f9e7c'
|
||||||
|
down_revision: Union[str, None] = '6b8a461148ff'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column('users', sa.Column('phone', sa.String(length=50), nullable=True))
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_column('users', 'phone')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -68,8 +68,8 @@ export const getModelTypeList = async () => {
|
|||||||
return response as any[];
|
return response as any[];
|
||||||
};
|
};
|
||||||
// 获取模型列表
|
// 获取模型列表
|
||||||
export const getModelList = async (pageInfo: PageRequest) => {
|
export const getModelList = async (types: string[], pageInfo: PageRequest) => {
|
||||||
const response = await request.get(`${apiPrefix}/models`, { ...pageInfo, is_active: true });
|
const response = await request.get(`${apiPrefix}/models`, { ...pageInfo, type: types?.join(','), is_active: true });
|
||||||
return response as any;
|
return response as any;
|
||||||
};
|
};
|
||||||
//获取模型提供者
|
//获取模型提供者
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 14:00:06
|
* @Date: 2026-02-03 14:00:06
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-24 17:48:01
|
* @Last Modified time: 2026-03-31 12:25:53
|
||||||
*/
|
*/
|
||||||
import { request } from '@/utils/request'
|
import { request } from '@/utils/request'
|
||||||
import type { AxiosRequestConfig } from 'axios'
|
import type { AxiosRequestConfig } from 'axios'
|
||||||
@@ -63,8 +63,8 @@ export const getDashboardData = () => {
|
|||||||
|
|
||||||
/****************** User Memory APIs *******************************/
|
/****************** User Memory APIs *******************************/
|
||||||
export const userMemoryListUrl = '/dashboard/end_users'
|
export const userMemoryListUrl = '/dashboard/end_users'
|
||||||
export const getUserMemoryList = () => {
|
export const getUserMemoryList = (query?: { keyword?: string }) => {
|
||||||
return request.get(userMemoryListUrl)
|
return request.get(userMemoryListUrl, query)
|
||||||
}
|
}
|
||||||
// User Memory - Total end users
|
// User Memory - Total end users
|
||||||
export const getTotalEndUsers = () => {
|
export const getTotalEndUsers = () => {
|
||||||
@@ -154,6 +154,8 @@ export const analyticsRefresh = (end_user_id: string) => {
|
|||||||
export const getForgetStats = (end_user_id: string) => {
|
export const getForgetStats = (end_user_id: string) => {
|
||||||
return request.get(`/memory/forget-memory/stats`, { end_user_id })
|
return request.get(`/memory/forget-memory/stats`, { end_user_id })
|
||||||
}
|
}
|
||||||
|
// 获取带遗忘节点列表
|
||||||
|
export const getForgetPendingNodesUrl = '/memory/forget-memory/pending-nodes'
|
||||||
// Implicit Memory - Preferences
|
// Implicit Memory - Preferences
|
||||||
export const getImplicitPreferences = (end_user_id: string) => {
|
export const getImplicitPreferences = (end_user_id: string) => {
|
||||||
return request.get(`/memory/implicit-memory/preferences/${end_user_id}`)
|
return request.get(`/memory/implicit-memory/preferences/${end_user_id}`)
|
||||||
|
|||||||
19
web/src/assets/images/common/delete_red_big.svg
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>编组 33</title>
|
||||||
|
<g id="空间外层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||||
|
<g id="平台管理-工具管理--MCP服务" transform="translate(-1032, -187)" stroke="#FF5D34">
|
||||||
|
<g id="编组-16" transform="translate(1020, 126)">
|
||||||
|
<g id="编组-33" transform="translate(12, 61)">
|
||||||
|
<g id="编组-32" transform="translate(2.5, 3)">
|
||||||
|
<line x1="-1.80133686e-14" y1="2.22222222" x2="11" y2="2.22222222" id="路径-29"></line>
|
||||||
|
<polyline id="路径-30" stroke-linejoin="round" points="3.3 2.2221179 3.3 0 7.7 0 7.7 2.22222222"></polyline>
|
||||||
|
<path d="M1.65,2.23587458 L1.65,9 C1.65,9.55228475 2.09771525,10 2.65,10 L8.35,10 C8.90228475,10 9.35,9.55228475 9.35,9 L9.35,2.22222222 L9.35,2.22222222" id="路径-31" stroke-linejoin="round"></path>
|
||||||
|
<line x1="4.4" y1="4.45203738" x2="4.4" y2="7.78537071" id="路径-32"></line>
|
||||||
|
<line x1="6.6" y1="4.45203738" x2="6.6" y2="7.78537071" id="路径-32"></line>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 1.4 KiB |
17
web/src/assets/images/common/edit_bg.svg
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="18px" height="18px" viewBox="0 0 18 18" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>编辑</title>
|
||||||
|
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||||
|
<g id="工作台-提示词-我的历史" transform="translate(-976, -320)">
|
||||||
|
<g id="编组-13备份-6" transform="translate(648, 122)">
|
||||||
|
<g id="编辑" transform="translate(328, 198)">
|
||||||
|
<rect id="矩形" fill="#EBEBEB" fill-rule="nonzero" x="0" y="0" width="18" height="18" rx="6"></rect>
|
||||||
|
<g id="编组-10" transform="translate(4.3, 4.3)" stroke="#5B6167">
|
||||||
|
<path d="M9.4,4.04322919 L9.4,7.4 C9.4,8.5045695 8.5045695,9.4 7.4,9.4 L2,9.4 C0.8954305,9.4 0,8.5045695 0,7.4 L0,2 C0,0.8954305 0.8954305,2.22044605e-16 2,0 L5.38958415,0 L5.38958415,0" id="路径"></path>
|
||||||
|
<line x1="3.74260398" y1="5.68579764" x2="9.4" y2="1.05734433e-14" id="路径-2"></line>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 1.1 KiB |
16
web/src/assets/images/common/edit_bold.svg
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>编辑</title>
|
||||||
|
<g id="空间外层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd" stroke-linejoin="round">
|
||||||
|
<g id="平台管理-工具管理--MCP服务" transform="translate(-1032, -135)" stroke="#171719">
|
||||||
|
<g id="编组-16" transform="translate(1020, 126)">
|
||||||
|
<g id="编辑" transform="translate(12, 9)">
|
||||||
|
<g id="编组-10" transform="translate(3, 3)">
|
||||||
|
<path d="M10,4.30130765 L10,8 C10,9.1045695 9.1045695,10 8,10 L2,10 C0.8954305,10 0,9.1045695 0,8 L0,2 C0,0.8954305 0.8954305,2.22044605e-16 2,0 L5.73360016,0 L5.73360016,0" id="路径"></path>
|
||||||
|
<line x1="3.98149359" y1="6.04872089" x2="10" y2="1.12483439e-14" id="路径-2"></line>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 1.0 KiB |
16
web/src/assets/images/common/eye.svg
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>link-outlined</title>
|
||||||
|
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||||
|
<g id="API-Key-管理" transform="translate(-1034, -161)" stroke="#171719">
|
||||||
|
<g id="编组-16" transform="translate(1022, 126)">
|
||||||
|
<g id="link-outlined" transform="translate(12, 35)">
|
||||||
|
<g id="编组-14" transform="translate(2.5, 4.15)">
|
||||||
|
<path d="M5.50029186,2.425 C5.92116821,2.425 6.30372933,2.58703648 6.58052274,2.85122056 C6.84862719,3.1071115 7.01717088,3.4597026 7.01717088,3.85 C7.01717088,4.24027468 6.84857286,4.5929387 6.58042003,4.84890344 C6.3036418,5.11310155 5.92112667,5.27520782 5.50029186,5.27520782 C5.07946329,5.27520782 4.69695201,5.11310096 4.42017607,4.84890315 C4.15202505,4.59293829 3.98342734,4.24027444 3.98342734,3.85 C3.98342734,3.45970284 4.15197072,3.10711191 4.42007337,2.85122085 C4.69686448,2.58703707 5.07942175,2.425 5.50029186,2.425 Z" id="路径" fill-rule="nonzero"></path>
|
||||||
|
<path d="M5.5,7.7 C8.53756612,7.7 11,5.39612383 11,3.85 C11,2.30387617 8.53756612,0 5.5,0 C2.46243388,0 0,2.26850164 0,3.85 C0,5.43149836 2.46243388,7.7 5.5,7.7 Z" id="椭圆形"></path>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 1.5 KiB |
17
web/src/assets/images/common/eye_bg.svg
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="18px" height="18px" viewBox="0 0 18 18" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>编辑</title>
|
||||||
|
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||||
|
<g id="工作台-提示词-我的历史" transform="translate(-950, -320)">
|
||||||
|
<g id="编组-13备份-6" transform="translate(648, 122)">
|
||||||
|
<g id="编辑" transform="translate(302, 198)">
|
||||||
|
<rect id="矩形" fill="#EBEBEB" fill-rule="nonzero" x="0" y="0" width="18" height="18" rx="6"></rect>
|
||||||
|
<g id="编组-16" transform="translate(2.5, 4.7)" stroke="#5B6167">
|
||||||
|
<ellipse id="椭圆形" cx="6.5" cy="4.3" rx="6.5" ry="4.3"></ellipse>
|
||||||
|
<circle id="椭圆形" cx="6.5" cy="4.3" r="1.75"></circle>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 1011 B |
13
web/src/assets/images/common/link.svg
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>link-outlined</title>
|
||||||
|
<g id="空间外层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||||
|
<g id="平台管理-工具管理--MCP服务" transform="translate(-1032, -161)" fill="#171719" fill-rule="nonzero">
|
||||||
|
<g id="编组-16" transform="translate(1020, 126)">
|
||||||
|
<g id="link-outlined" transform="translate(12, 35)">
|
||||||
|
<path d="M8.8887561,10.1978746 C8.84435493,10.1534734 8.77130783,10.1534734 8.72690666,10.1978746 L7.06257876,11.8622025 C6.29200354,12.6327777 4.99147881,12.7144186 4.14069502,11.8622025 C3.28847892,11.0099864 3.37011979,9.71089398 4.14069502,8.94031876 L5.80502291,7.27599086 C5.84942409,7.23158968 5.84942409,7.15854259 5.80502291,7.11414142 L5.23496912,6.54408763 C5.19056795,6.49968645 5.11752086,6.49968645 5.07311968,6.54408763 L3.40879178,8.20841552 C2.19706941,9.4201379 2.19706941,11.3809511 3.40879178,12.5912411 C4.62051416,13.8015312 6.58132732,13.8029635 7.7916174,12.5912411 L9.4559453,10.9269132 C9.50034647,10.8825121 9.50034647,10.809465 9.4559453,10.7650638 L8.8887561,10.1978746 Z M12.5926734,3.40879178 C11.3809511,2.19706941 9.4201379,2.19706941 8.20984782,3.40879178 L6.54408763,5.07311968 C6.49968645,5.11752086 6.49968645,5.19056795 6.54408763,5.23496912 L7.11270912,5.80359062 C7.15711029,5.84799179 7.23015739,5.84799179 7.27455856,5.80359062 L8.93888646,4.13926272 C9.70946168,3.3686875 11.0099864,3.28704663 11.8607702,4.13926272 C12.7129863,4.99147881 12.6313454,6.29057124 11.8607702,7.06114647 L10.1964423,8.72547436 C10.1520411,8.76987554 10.1520411,8.84292263 10.1964423,8.88732381 L10.7664961,9.4573776 C10.8108973,9.50177877 10.8839444,9.50177877 10.9283455,9.4573776 L12.5926734,7.7930497 C13.8029635,6.58132732 13.8029635,4.62051416 12.5926734,3.40879178 L12.5926734,3.40879178 Z M9.40581494,5.99981516 C9.36141377,5.95541399 9.28836667,5.95541399 9.2439655,5.99981516 L5.99981516,9.2425332 C5.95541399,9.28693438 5.95541399,9.35998147 5.99981516,9.40438265 L6.56700436,9.97157184 C6.61140554,10.015973 6.68445263,10.015973 6.7288538,9.97157184 L9.97157184,6.7288538 C10.015973,6.68445263 10.015973,6.61140554 9.97157184,6.56700436 L9.40581494,5.99981516 Z" id="形状"></path>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 2.4 KiB |
@@ -1,12 +1,11 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<svg width="24px" height="24px" viewBox="0 0 24 24" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
<svg width="22px" height="22px" viewBox="0 0 22 22" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
<title>更多</title>
|
<title>卡片1@3x</title>
|
||||||
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
<g id="空间外层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||||
<g id="记忆库-个人记忆-感知记忆-听觉" transform="translate(-440, -187)" fill-rule="nonzero">
|
<g id="平台管理-工具管理--MCP服务" transform="translate(-602, -128)" fill="#5B6167" fill-rule="nonzero">
|
||||||
<g id="编组-14" transform="translate(28, 168)">
|
<g id="卡片1" transform="translate(252, 112)">
|
||||||
<g id="更多" transform="translate(412, 19)">
|
<g id="更多" transform="translate(350, 16)">
|
||||||
<rect id="矩形" fill="#000000" opacity="0" x="0" y="0" width="24" height="24"></rect>
|
<path d="M5.4,12.4 C6.17319865,12.4 6.8,11.7731986 6.8,11 C6.8,10.2268014 6.17319865,9.6 5.4,9.6 C4.62680135,9.6 4,10.2268014 4,11 C4,11.7731986 4.62680135,12.4 5.4,12.4 Z M11,12.4 C11.7731986,12.4 12.4,11.7731986 12.4,11 C12.4,10.2268014 11.7731986,9.6 11,9.6 C10.2268014,9.6 9.6,10.2268014 9.6,11 C9.6,11.7731986 10.2268014,12.4 11,12.4 Z M16.6,12.4 C17.3731986,12.4 18,11.7731986 18,11 C18,10.2268014 17.3731986,9.6 16.6,9.6 C15.8268014,9.6 15.2,10.2268014 15.2,11 C15.2,11.7731986 15.8268014,12.4 16.6,12.4 Z" id="形状"></path>
|
||||||
<path d="M5.25,12 C5.25,12.8284271 5.92157288,13.5 6.75,13.5 C7.57842712,13.5 8.25,12.8284271 8.25,12 C8.25,11.1715729 7.57842712,10.5 6.75,10.5 C5.92157288,10.5 5.25,11.1715729 5.25,12 Z M10.5,12 C10.5,12.8284271 11.1715729,13.5 12,13.5 C12.8284271,13.5 13.5,12.8284271 13.5,12 C13.5,11.1715729 12.8284271,10.5 12,10.5 C11.1715729,10.5 10.5,11.1715729 10.5,12 Z M15.75,12 C15.75,12.8284271 16.4215729,13.5 17.25,13.5 C18.0784271,13.5 18.75,12.8284271 18.75,12 C18.75,11.1715729 18.0784271,10.5 17.25,10.5 C16.4215729,10.5 15.75,11.1715729 15.75,12 Z" id="形状" fill="#171719"></path>
|
|
||||||
</g>
|
</g>
|
||||||
</g>
|
</g>
|
||||||
</g>
|
</g>
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 1.3 KiB After Width: | Height: | Size: 1.2 KiB |
@@ -1,14 +1,23 @@
|
|||||||
<?xml version="1.0" encoding="UTF-8"?>
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
<svg width="24px" height="24px" viewBox="0 0 24 24" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
<svg width="22px" height="22px" viewBox="0 0 22 22" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
<title>更多</title>
|
<title>更多@3x</title>
|
||||||
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
<defs>
|
||||||
<g id="记忆库-个人记忆-感知记忆-文本" transform="translate(-440, -187)" fill-rule="nonzero">
|
<filter x="-3.4%" y="-6.2%" width="106.8%" height="114.8%" filterUnits="objectBoundingBox" id="filter-1">
|
||||||
<g id="编组-8" transform="translate(12, 76)">
|
<feOffset dx="0" dy="2" in="SourceAlpha" result="shadowOffsetOuter1"></feOffset>
|
||||||
<g id="编组-14" transform="translate(16, 92)">
|
<feGaussianBlur stdDeviation="4" in="shadowOffsetOuter1" result="shadowBlurOuter1"></feGaussianBlur>
|
||||||
<g id="更多" transform="translate(412, 19)">
|
<feColorMatrix values="0 0 0 0 0.0901960784 0 0 0 0 0.0901960784 0 0 0 0 0.0980392157 0 0 0 0.16 0" type="matrix" in="shadowBlurOuter1" result="shadowMatrixOuter1"></feColorMatrix>
|
||||||
<rect id="矩形" fill="#E4E4E4" x="0" y="0" width="24" height="24" rx="8"></rect>
|
<feMerge>
|
||||||
<path d="M5.25,12 C5.25,12.8284271 5.92157288,13.5 6.75,13.5 C7.57842712,13.5 8.25,12.8284271 8.25,12 C8.25,11.1715729 7.57842712,10.5 6.75,10.5 C5.92157288,10.5 5.25,11.1715729 5.25,12 Z M10.5,12 C10.5,12.8284271 11.1715729,13.5 12,13.5 C12.8284271,13.5 13.5,12.8284271 13.5,12 C13.5,11.1715729 12.8284271,10.5 12,10.5 C11.1715729,10.5 10.5,11.1715729 10.5,12 Z M15.75,12 C15.75,12.8284271 16.4215729,13.5 17.25,13.5 C18.0784271,13.5 18.75,12.8284271 18.75,12 C18.75,11.1715729 18.0784271,10.5 17.25,10.5 C16.4215729,10.5 15.75,11.1715729 15.75,12 Z" id="形状" fill="#171719"></path>
|
<feMergeNode in="shadowMatrixOuter1"></feMergeNode>
|
||||||
</g>
|
<feMergeNode in="SourceGraphic"></feMergeNode>
|
||||||
|
</feMerge>
|
||||||
|
</filter>
|
||||||
|
</defs>
|
||||||
|
<g id="空间外层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||||
|
<g id="平台管理-工具管理--MCP服务" transform="translate(-998, -128)" fill-rule="nonzero">
|
||||||
|
<g id="卡片1备份" filter="url(#filter-1)" transform="translate(648, 112)">
|
||||||
|
<g id="更多" transform="translate(350, 16)">
|
||||||
|
<rect id="矩形" fill="#F6F6F6" x="0" y="0" width="22" height="22" rx="8"></rect>
|
||||||
|
<path d="M5.4,12.4 C6.17319865,12.4 6.8,11.7731986 6.8,11 C6.8,10.2268014 6.17319865,9.6 5.4,9.6 C4.62680135,9.6 4,10.2268014 4,11 C4,11.7731986 4.62680135,12.4 5.4,12.4 Z M11,12.4 C11.7731986,12.4 12.4,11.7731986 12.4,11 C12.4,10.2268014 11.7731986,9.6 11,9.6 C10.2268014,9.6 9.6,10.2268014 9.6,11 C9.6,11.7731986 10.2268014,12.4 11,12.4 Z M16.6,12.4 C17.3731986,12.4 18,11.7731986 18,11 C18,10.2268014 17.3731986,9.6 16.6,9.6 C15.8268014,9.6 15.2,10.2268014 15.2,11 C15.2,11.7731986 15.8268014,12.4 16.6,12.4 Z" id="形状" fill="#5B6167"></path>
|
||||||
</g>
|
</g>
|
||||||
</g>
|
</g>
|
||||||
</g>
|
</g>
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 1.4 KiB After Width: | Height: | Size: 2.0 KiB |
BIN
web/src/assets/images/conversation/ai.png
Normal file
|
After Width: | Height: | Size: 6.2 KiB |
|
Before Width: | Height: | Size: 189 KiB After Width: | Height: | Size: 158 KiB |
BIN
web/src/assets/images/conversation/user.png
Normal file
|
After Width: | Height: | Size: 7.8 KiB |
16
web/src/assets/images/menuNew/arrow_t_r.svg
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>编组 51</title>
|
||||||
|
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round">
|
||||||
|
<g id="工作台-记忆看板-3" transform="translate(-1400, -140)" stroke="#A8A9AA" stroke-width="1.2">
|
||||||
|
<g id="编组-22" transform="translate(1180, 57)">
|
||||||
|
<g id="编组-51" transform="translate(220, 83)">
|
||||||
|
<g id="编组-49" transform="translate(4.5, 4.5)">
|
||||||
|
<polyline id="路径" points="0 0 7 0 7 7"></polyline>
|
||||||
|
<line x1="7" y1="0" x2="9.71445147e-17" y2="7" id="路径-51"></line>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 938 B |
17
web/src/assets/images/menuNew/logout_red.svg
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>退出</title>
|
||||||
|
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round">
|
||||||
|
<g id="工作台-记忆看板-3" transform="translate(-1196, -229)" stroke="#FF5D34" stroke-width="1.2">
|
||||||
|
<g id="编组-22" transform="translate(1180, 57)">
|
||||||
|
<g id="退出" transform="translate(16, 172)">
|
||||||
|
<g id="编组-7" transform="translate(2.5, 2)">
|
||||||
|
<path d="M4,12 L2,12 C0.8954305,12 0,11.1045695 0,10 L0,2 C0,0.8954305 0.8954305,2.22044605e-16 2,0 L4,0 L4,0" id="路径"></path>
|
||||||
|
<line x1="11" y1="6" x2="4.5" y2="6" id="路径-6"></line>
|
||||||
|
<polyline id="路径" points="8 3 11 6 8 9"></polyline>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 1.0 KiB |
19
web/src/assets/images/menuNew/settings.svg
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>设置-界面设置</title>
|
||||||
|
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||||
|
<g id="工作台-记忆看板-3" transform="translate(-1196, -176)">
|
||||||
|
<g id="编组-22" transform="translate(1180, 57)">
|
||||||
|
<g id="设置-界面设置" transform="translate(16, 119)">
|
||||||
|
<g id="编组-2" transform="translate(2, 2)">
|
||||||
|
<path d="M9.43360402,6.02003217 C9.70219942,5.98544974 9.97412053,5.98544974 10.2427159,6.02003217 C10.3615753,6.03694616 10.4495746,6.13920492 10.4485827,6.25925761 L10.4485827,6.97778158 C10.8355212,7.08416506 11.1878633,7.2900028 11.4705612,7.57481981 L12.0861381,7.21535276 C12.1886462,7.15532215 12.3197841,7.18070349 12.3924911,7.27464643 C12.5566377,7.49314267 12.6927553,7.73134853 12.7976622,7.98369599 C12.8417451,8.09509234 12.7980417,8.22199461 12.6947152,8.28262525 L12.0795484,8.64127201 C12.1812591,9.03280439 12.1812591,9.44382977 12.0795484,9.83536215 L12.6951253,10.1940089 C12.7984653,10.255883 12.8408681,10.3841328 12.7947775,10.4954127 C12.6902734,10.7470629 12.5545635,10.9845849 12.3908368,11.2023979 C12.3178817,11.29706 12.1858902,11.3226521 12.0828433,11.2621154 L11.4681003,10.9034823 C11.1853554,11.1882385 10.8330304,11.3940662 10.4461218,11.5005205 L10.4461218,12.2190308 C10.4472928,12.3391587 10.3591904,12.4415237 10.2402277,12.4582563 C9.97161477,12.49242 9.69975606,12.49242 9.43114314,12.4582563 C9.31227289,12.4413547 9.22425896,12.3390925 9.22524902,12.2190308 L9.22524902,11.5005069 C8.83966141,11.3929856 8.48863308,11.187108 8.20657902,10.9030585 L7.58893771,11.263756 C7.48646614,11.3234892 7.35562246,11.2983389 7.28259844,11.2048724 C7.11845115,10.9862292 6.98233473,10.7478877 6.87742727,10.4954127 C6.83256205,10.383776 6.87616574,10.2561165 6.9799505,10.195253 L7.59757815,9.83618245 C7.54895974,9.65227947 7.52310022,9.46310373 7.52057977,9.27289936 C7.51726668,9.06031128 7.54274952,8.8482532 7.59633403,8.64250245 L6.9807708,8.28344555 C6.87737248,8.2230699 6.83375833,8.09609457 6.87823389,7.98492643 C6.98279132,7.7324117 7.11893329,7.49416785 7.28340506,7.27589055 C7.35624204,7.18108729 7.48825303,7.15531561 7.59139859,7.2157629 L8.2061552,7.57481981 C8.48870496,7.29001094 8.84091669,7.08416484 9.22773725,6.97776791 L9.22773725,6.25925761 C9.22638248,6.13907231 9.31455943,6.03660714 9.43360402,6.02003217 Z M9.85567328,8.1381029 L9.83960916,8.1381029 C9.44388649,8.13394685 9.07614467,8.3416968 8.87546122,8.68278309 C8.67477776,9.02386937 8.67178651,9.44622602 8.86761874,9.79012055 C9.06345096,10.1340151 9.42821335,10.3469528 9.82395519,10.3484022 L9.85608342,10.3484022 C10.4510916,10.3389892 10.929189,9.8552216 10.9315862,9.26014377 C10.9437669,8.6534016 10.4623788,8.15138092 9.85567328,8.1381029 Z" id="形状结合" fill="#171719" fill-rule="nonzero"></path>
|
||||||
|
<path d="M5.99499898,12 L2,12 C0.8954305,12 4.4408921e-16,11.1045695 4.4408921e-16,10 L0,2 C0,0.8954305 0.8954305,2.22044605e-16 2,0 L10,0 C11.1045695,0 12,0.8954305 12,2 L12,4.98927875 L12,4.98927875" id="路径" stroke="#171719" stroke-width="1.2" stroke-linecap="round" stroke-linejoin="round"></path>
|
||||||
|
<line x1="0" y1="3.99424593" x2="12" y2="3.99424593" id="路径-2" stroke="#171719" stroke-width="1.2"></line>
|
||||||
|
<circle id="椭圆形" fill="#171719" cx="2.2" cy="2" r="1"></circle>
|
||||||
|
<circle id="椭圆形" fill="#171719" cx="4.4" cy="2" r="1"></circle>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 3.7 KiB |
13
web/src/assets/images/menuNew/userInfo.svg
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>账户</title>
|
||||||
|
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||||
|
<g id="工作台-记忆看板-3" transform="translate(-1196, -140)" fill="#171719" fill-rule="nonzero" stroke="#171719" stroke-width="0.4">
|
||||||
|
<g id="编组-22" transform="translate(1180, 57)">
|
||||||
|
<g id="账户" transform="translate(16, 83)">
|
||||||
|
<path d="M8,1 C4.13400675,1 1,4.13400675 1,8 C1,11.8659932 4.13400675,15 8,15 C11.8659932,15 15,11.8659932 15,8 C14.9947583,4.13618002 11.86382,1.0052417 8,1 L8,1 Z M12.4014141,12.3271621 L12.2224629,12.517543 L12.1748711,12.2662539 C12.0428879,11.5485013 11.7289104,10.8766306 11.2629844,10.3149355 C11.1145387,10.1521744 10.8641988,10.135306 10.6952598,10.2766811 C10.5263208,10.4180561 10.4988,10.6674498 10.6328477,10.8422598 C11.1442401,11.4572317 11.4244051,12.2317168 11.4248047,13.0315371 C11.4257399,13.0442165 11.4257399,13.0569476 11.4248047,13.069627 L11.4248047,13.1419512 L11.3638828,13.1819414 C9.31720678,14.5169283 6.67517799,14.5169283 4.62850195,13.1819414 L4.56758008,13.1419648 L4.56758008,13.0334512 C4.56758008,11.1409267 6.10177432,9.60673242 7.99429883,9.60673242 C9.43633406,9.61053035 10.6698828,8.57151634 10.9112248,7.14981524 C11.1525669,5.72811413 10.3309981,4.34023189 8.96849203,3.8679425 C7.60598597,3.39565311 6.10170668,3.97731994 5.4113796,5.24338809 C4.72105252,6.50945624 5.046918,8.08901446 6.18194141,8.97850977 L6.35137695,9.10986914 L6.16099609,9.20315234 C4.9414843,9.79213158 4.07443441,10.9255913 3.82512891,12.2567383 L3.77753711,12.5080273 L3.59858594,12.3176602 C2.45876944,11.1700068 1.82010269,9.61749329 1.8223981,8 C1.81896848,5.20888419 3.68748546,2.76216731 6.38105315,2.03070823 C9.07462084,1.29924914 11.9240679,2.46476648 13.3328918,4.87423832 C14.7417157,7.28371017 14.3599179,10.338544 12.4014141,12.3271621 L12.4014141,12.3271621 Z M5.87924609,6.65787305 C5.88029543,5.48791052 6.82940261,4.5402209 7.9993654,4.54091991 C9.16932819,4.5416197 10.1173016,5.49044339 10.1169523,6.66040633 C10.1166027,7.83036928 9.16806261,8.77862695 7.99809961,8.77862695 C6.82758179,8.77757806 5.87924609,7.82839134 5.87924609,6.65787305 L5.87924609,6.65787305 Z" id="形状"></path>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 2.4 KiB |
BIN
web/src/assets/images/model/bedrock.png
Normal file
|
After Width: | Height: | Size: 3.2 KiB |
@@ -1,15 +0,0 @@
|
|||||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
|
||||||
<g clip-path="url(#clip0_16762_59518)">
|
|
||||||
<path d="M12.6667 0H3.33333C1.49238 0 0 1.49238 0 3.33333V12.6667C0 14.5076 1.49238 16 3.33333 16H12.6667C14.5076 16 16 14.5076 16 12.6667V3.33333C16 1.49238 14.5076 0 12.6667 0Z" fill="url(#paint0_linear_16762_59518)"/>
|
|
||||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M7.99984 12.093L6.3825 12.6323L5.75184 12.2116L6.4385 11.9823L6.22784 11.3503L5.04917 11.743L4.6665 11.4883V9.66631C4.6665 9.54031 4.59517 9.42497 4.4825 9.3683L3.33317 8.79364V7.20564L4.33317 6.70564L5.33317 7.20564V8.33297C5.33317 8.45964 5.4045 8.57497 5.51717 8.63164L6.8505 9.29831L7.14917 8.70164L5.99984 8.12697V7.20564L7.14917 6.63164C7.26184 6.57497 7.33317 6.45964 7.33317 6.33297V5.33297H6.6665V6.12697L5.6665 6.62697L4.6665 6.12697V4.51164L5.33317 4.06697V5.33297H5.99984V3.62297L6.3825 3.36764L7.99984 3.90697V12.093ZM11.6665 11.333C11.8498 11.333 11.9998 11.4823 11.9998 11.6663C11.9998 11.8503 11.8498 11.9996 11.6665 11.9996C11.4832 11.9996 11.3332 11.8503 11.3332 11.6663C11.3332 11.4823 11.4832 11.333 11.6665 11.333ZM10.9998 3.99964C11.1832 3.99964 11.3332 4.14897 11.3332 4.33297C11.3332 4.51697 11.1832 4.6663 10.9998 4.6663C10.8165 4.6663 10.6665 4.51697 10.6665 4.33297C10.6665 4.14897 10.8165 3.99964 10.9998 3.99964ZM12.3332 7.99964C12.5165 7.99964 12.6665 8.14897 12.6665 8.33297C12.6665 8.51697 12.5165 8.66631 12.3332 8.66631C12.1498 8.66631 11.9998 8.51697 11.9998 8.33297C11.9998 8.14897 12.1498 7.99964 12.3332 7.99964ZM11.3945 8.66631C11.5325 9.05364 11.8992 9.33297 12.3332 9.33297C12.8845 9.33297 13.3332 8.88497 13.3332 8.33297C13.3332 7.78164 12.8845 7.33297 12.3332 7.33297C11.8992 7.33297 11.5325 7.61297 11.3945 7.99964H8.6665V6.66631H10.9998C11.1838 6.66631 11.3332 6.51764 11.3332 6.33297V5.27164C11.7205 5.13364 11.9998 4.76697 11.9998 4.33297C11.9998 3.78164 11.5512 3.33297 10.9998 3.33297C10.4485 3.33297 9.99984 3.78164 9.99984 4.33297C9.99984 4.76697 10.2792 5.13364 10.6665 5.27164V5.99964H8.6665V3.6663C8.6665 3.52297 8.5745 3.39564 8.4385 3.3503L6.4385 2.68364C6.3405 2.65097 6.23384 2.66564 6.1485 2.7223L4.1485 4.05564C4.05584 4.11764 3.99984 4.22164 3.99984 4.33297V6.12697L2.8505 6.70164C2.73784 6.75831 2.6665 6.87364 2.6665 6.99964V8.99964C2.6665 9.12631 2.73784 9.24164 2.8505 9.29831L3.99984 9.87231V11.6663C3.99984 11.7776 4.05584 11.8823 4.1485 11.9436L6.1485 13.277C6.20384 13.3143 6.26784 13.333 6.33317 13.333C6.3685 13.333 6.40384 13.3276 6.4385 13.3156L8.4385 12.649C8.5745 12.6043 8.6665 12.477 8.6665 12.333V10.6663H10.1952L10.7638 11.2356L10.7725 11.227C10.7072 11.3603 10.6665 11.5083 10.6665 11.6663C10.6665 12.2176 11.1152 12.6663 11.6665 12.6663C12.2178 12.6663 12.6665 12.2176 12.6665 11.6663C12.6665 11.115 12.2178 10.6663 11.6665 10.6663C11.5078 10.6663 11.3598 10.707 11.2272 10.773L11.2358 10.7643L10.5692 10.0976C10.5065 10.035 10.4218 9.99964 10.3332 9.99964H8.6665V8.66631H11.3945Z" fill="white"/>
|
|
||||||
</g>
|
|
||||||
<defs>
|
|
||||||
<linearGradient id="paint0_linear_16762_59518" x1="0" y1="1600" x2="1600" y2="0" gradientUnits="userSpaceOnUse">
|
|
||||||
<stop stop-color="#055F4E"/>
|
|
||||||
<stop offset="1" stop-color="#56C0A7"/>
|
|
||||||
</linearGradient>
|
|
||||||
<clipPath id="clip0_16762_59518">
|
|
||||||
<rect width="16" height="16" fill="white"/>
|
|
||||||
</clipPath>
|
|
||||||
</defs>
|
|
||||||
</svg>
|
|
||||||
|
Before Width: | Height: | Size: 3.2 KiB |
|
Before Width: | Height: | Size: 2.8 KiB After Width: | Height: | Size: 2.6 KiB |
|
Before Width: | Height: | Size: 57 KiB After Width: | Height: | Size: 3.2 KiB |
BIN
web/src/assets/images/model/ollama.png
Normal file
|
After Width: | Height: | Size: 2.2 KiB |
|
Before Width: | Height: | Size: 7.7 KiB |
BIN
web/src/assets/images/model/openai.png
Normal file
|
After Width: | Height: | Size: 3.3 KiB |
|
Before Width: | Height: | Size: 6.9 KiB |
|
Before Width: | Height: | Size: 7.1 KiB After Width: | Height: | Size: 2.5 KiB |
BIN
web/src/assets/images/model/xinference.png
Normal file
|
After Width: | Height: | Size: 2.8 KiB |
@@ -1,24 +0,0 @@
|
|||||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
|
||||||
<g id="Xorbits Square" clip-path="url(#clip0_9850_26870)">
|
|
||||||
<path id="Vector" d="M8.00391 12.3124C8.69334 13.0754 9.47526 13.7494 10.3316 14.3188C11.0667 14.8105 11.8509 15.2245 12.6716 15.5541C14.1617 14.1465 15.3959 12.4907 16.3192 10.6606L21.7051 0L12.3133 7.38353C10.5832 8.74456 9.12178 10.416 8.00391 12.3124Z" fill="url(#paint0_linear_9850_26870)"/>
|
|
||||||
<path id="Vector_2" d="M7.23504 18.9512C6.56092 18.5012 5.92386 18.0265 5.3221 17.5394L2.06445 24L7.91975 19.3959C7.69034 19.2494 7.46092 19.103 7.23504 18.9512Z" fill="url(#paint1_linear_9850_26870)"/>
|
|
||||||
<path id="Vector_3" d="M19.3161 8.57474C21.0808 10.9147 21.5961 13.5159 20.3996 15.3053C18.6526 17.9189 13.9161 17.8183 9.82024 15.0812C5.72435 12.3441 3.82024 8.0065 5.56729 5.39297C6.76377 3.60356 9.36318 3.0865 12.2008 3.81886C7.29318 1.73474 2.62376 1.94121 0.813177 4.64474C-1.45976 8.04709 1.64435 14.1177 7.74494 18.1889C13.8455 22.26 20.6361 22.8124 22.9091 19.4118C24.7179 16.703 23.1173 12.3106 19.3161 8.57474Z" fill="url(#paint2_linear_9850_26870)"/>
|
|
||||||
</g>
|
|
||||||
<defs>
|
|
||||||
<linearGradient id="paint0_linear_9850_26870" x1="2.15214" y1="24.3018" x2="21.2921" y2="0.0988218" gradientUnits="userSpaceOnUse">
|
|
||||||
<stop stop-color="#E9A85E"/>
|
|
||||||
<stop offset="1" stop-color="#F52B76"/>
|
|
||||||
</linearGradient>
|
|
||||||
<linearGradient id="paint1_linear_9850_26870" x1="2.06269" y1="24.2294" x2="21.2027" y2="0.028252" gradientUnits="userSpaceOnUse">
|
|
||||||
<stop stop-color="#E9A85E"/>
|
|
||||||
<stop offset="1" stop-color="#F52B76"/>
|
|
||||||
</linearGradient>
|
|
||||||
<linearGradient id="paint2_linear_9850_26870" x1="-0.613606" y1="3.843" x2="21.4449" y2="18.7258" gradientUnits="userSpaceOnUse">
|
|
||||||
<stop stop-color="#6A0CF5"/>
|
|
||||||
<stop offset="1" stop-color="#AB66F3"/>
|
|
||||||
</linearGradient>
|
|
||||||
<clipPath id="clip0_9850_26870">
|
|
||||||
<rect width="24" height="24" fill="white"/>
|
|
||||||
</clipPath>
|
|
||||||
</defs>
|
|
||||||
</svg>
|
|
||||||
|
Before Width: | Height: | Size: 1.8 KiB |
19
web/src/assets/images/prompt/delete.svg
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="18px" height="18px" viewBox="0 0 18 18" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>编组 33</title>
|
||||||
|
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||||
|
<g id="工作台-提示词-我的历史" transform="translate(-1002, -560)" stroke="#5B6167">
|
||||||
|
<g id="编组-13备份-4" transform="translate(648, 362)">
|
||||||
|
<g id="编组-33" transform="translate(354, 198)">
|
||||||
|
<g id="编组-32" transform="translate(3.5, 4)">
|
||||||
|
<line x1="-1.80133686e-14" y1="2.22222222" x2="11" y2="2.22222222" id="路径-29"></line>
|
||||||
|
<polyline id="路径-30" stroke-linejoin="round" points="3.3 2.2221179 3.3 0 7.7 0 7.7 2.22222222"></polyline>
|
||||||
|
<path d="M1.65,2.23587458 L1.65,9 C1.65,9.55228475 2.09771525,10 2.65,10 L8.35,10 C8.90228475,10 9.35,9.55228475 9.35,9 L9.35,2.22222222 L9.35,2.22222222" id="路径-31" stroke-linejoin="round"></path>
|
||||||
|
<line x1="4.4" y1="4.45203738" x2="4.4" y2="7.78537071" id="路径-32"></line>
|
||||||
|
<line x1="6.6" y1="4.45203738" x2="6.6" y2="7.78537071" id="路径-32"></line>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 1.4 KiB |
20
web/src/assets/images/prompt/delete_hover.svg
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="18px" height="18px" viewBox="0 0 18 18" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>编组 33</title>
|
||||||
|
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||||
|
<g id="工作台-提示词-我的历史" transform="translate(-1002, -320)">
|
||||||
|
<g id="编组-13备份-6" transform="translate(648, 122)">
|
||||||
|
<g id="编组-33" transform="translate(354, 198)">
|
||||||
|
<rect id="矩形" fill-opacity="0.08" fill="#FF5D34" x="0" y="0" width="18" height="18" rx="6"></rect>
|
||||||
|
<g id="编组-32" transform="translate(3.5, 4)" stroke="#FF5D34">
|
||||||
|
<line x1="-1.80133686e-14" y1="2.22222222" x2="11" y2="2.22222222" id="路径-29"></line>
|
||||||
|
<polyline id="路径-30" stroke-linejoin="round" points="3.3 2.2221179 3.3 0 7.7 0 7.7 2.22222222"></polyline>
|
||||||
|
<path d="M1.65,2.23587458 L1.65,9 C1.65,9.55228475 2.09771525,10 2.65,10 L8.35,10 C8.90228475,10 9.35,9.55228475 9.35,9 L9.35,2.22222222 L9.35,2.22222222" id="路径-31" stroke-linejoin="round"></path>
|
||||||
|
<line x1="4.4" y1="4.45203738" x2="4.4" y2="7.78537071" id="路径-32"></line>
|
||||||
|
<line x1="6.6" y1="4.45203738" x2="6.6" y2="7.78537071" id="路径-32"></line>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 1.5 KiB |
16
web/src/assets/images/prompt/edit.svg
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="18px" height="18px" viewBox="0 0 18 18" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>编辑</title>
|
||||||
|
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||||
|
<g id="工作台-提示词-我的历史" transform="translate(-976, -560)" stroke="#5B6167">
|
||||||
|
<g id="编组-13备份-4" transform="translate(648, 362)">
|
||||||
|
<g id="编辑" transform="translate(328, 198)">
|
||||||
|
<g id="编组-10" transform="translate(4.3, 4.3)">
|
||||||
|
<path d="M9.4,4.04322919 L9.4,7.4 C9.4,8.5045695 8.5045695,9.4 7.4,9.4 L2,9.4 C0.8954305,9.4 0,8.5045695 0,7.4 L0,2 C0,0.8954305 0.8954305,2.22044605e-16 2,0 L5.38958415,0 L5.38958415,0" id="路径"></path>
|
||||||
|
<line x1="3.74260398" y1="5.68579764" x2="9.4" y2="1.05734433e-14" id="路径-2"></line>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 1.0 KiB |
17
web/src/assets/images/prompt/edit_bg.svg
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="18px" height="18px" viewBox="0 0 18 18" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>编辑</title>
|
||||||
|
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||||
|
<g id="工作台-提示词-我的历史" transform="translate(-976, -320)">
|
||||||
|
<g id="编组-13备份-6" transform="translate(648, 122)">
|
||||||
|
<g id="编辑" transform="translate(328, 198)">
|
||||||
|
<rect id="矩形" fill="#EBEBEB" fill-rule="nonzero" x="0" y="0" width="18" height="18" rx="6"></rect>
|
||||||
|
<g id="编组-10" transform="translate(4.3, 4.3)" stroke="#5B6167">
|
||||||
|
<path d="M9.4,4.04322919 L9.4,7.4 C9.4,8.5045695 8.5045695,9.4 7.4,9.4 L2,9.4 C0.8954305,9.4 0,8.5045695 0,7.4 L0,2 C0,0.8954305 0.8954305,2.22044605e-16 2,0 L5.38958415,0 L5.38958415,0" id="路径"></path>
|
||||||
|
<line x1="3.74260398" y1="5.68579764" x2="9.4" y2="1.05734433e-14" id="路径-2"></line>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 1.1 KiB |
16
web/src/assets/images/prompt/eye.svg
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="18px" height="18px" viewBox="0 0 18 18" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>编辑</title>
|
||||||
|
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||||
|
<g id="工作台-提示词-我的历史" transform="translate(-950, -560)" stroke="#5B6167">
|
||||||
|
<g id="编组-13备份-4" transform="translate(648, 362)">
|
||||||
|
<g id="编辑" transform="translate(302, 198)">
|
||||||
|
<g id="编组-16" transform="translate(2.5, 4.7)">
|
||||||
|
<ellipse id="椭圆形" cx="6.5" cy="4.3" rx="6.5" ry="4.3"></ellipse>
|
||||||
|
<circle id="椭圆形" cx="6.5" cy="4.3" r="1.75"></circle>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 888 B |
17
web/src/assets/images/prompt/eye_bg.svg
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="18px" height="18px" viewBox="0 0 18 18" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>编辑</title>
|
||||||
|
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||||
|
<g id="工作台-提示词-我的历史" transform="translate(-950, -320)">
|
||||||
|
<g id="编组-13备份-6" transform="translate(648, 122)">
|
||||||
|
<g id="编辑" transform="translate(302, 198)">
|
||||||
|
<rect id="矩形" fill="#EBEBEB" fill-rule="nonzero" x="0" y="0" width="18" height="18" rx="6"></rect>
|
||||||
|
<g id="编组-16" transform="translate(2.5, 4.7)" stroke="#5B6167">
|
||||||
|
<ellipse id="椭圆形" cx="6.5" cy="4.3" rx="6.5" ry="4.3"></ellipse>
|
||||||
|
<circle id="椭圆形" cx="6.5" cy="4.3" r="1.75"></circle>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 1011 B |
13
web/src/assets/images/workflow/clear.svg
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||||
|
<title>clear-outlined</title>
|
||||||
|
<g id="空间里层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||||
|
<g id="应用管理-工作流-配置-开始" transform="translate(-1249, -24)" fill="#171719" fill-rule="nonzero">
|
||||||
|
<g id="编组-11" transform="translate(1242, 17)">
|
||||||
|
<g id="clear-outlined" transform="translate(7, 7)">
|
||||||
|
<path d="M14.4933021,14.4598985 L13.6042691,9.03045685 L13.9045274,9.03045685 C14.146076,9.03045685 14.3406568,8.82436548 14.3406568,8.56852792 L14.3406568,5.15736041 C14.3406568,4.90152284 14.146076,4.69543147 13.9045274,4.69543147 L9.77807205,4.69543147 L9.77807205,1.46192893 C9.77807205,1.20609137 9.58349122,1 9.34194262,1 L6.65806921,1 C6.4165206,1 6.22193978,1.20609137 6.22193978,1.46192893 L6.22193978,4.69543147 L2.09548441,4.69543147 C1.85393581,4.69543147 1.65935498,4.90152284 1.65935498,5.15736041 L1.65935498,8.56852792 C1.65935498,8.82436548 1.85393581,9.03045685 2.09548441,9.03045685 L2.39574275,9.03045685 L1.50670968,14.4598985 C1.50167742,14.4865482 1.5,14.513198 1.5,14.5380711 C1.5,14.7939086 1.69458082,15 1.93612943,15 L14.0638824,15 C14.0890437,15 14.114205,14.9982234 14.1376889,14.9928934 C14.3758827,14.9502538 14.5352377,14.7104061 14.4933021,14.4598985 Z M2.8335496,5.93908629 L7.3961344,5.93908629 L7.3961344,2.24365482 L8.60387743,2.24365482 L8.60387743,5.93908629 L13.1664622,5.93908629 L13.1664622,7.78680203 L2.8335496,7.78680203 L2.8335496,5.93908629 Z M10.6838793,13.7563452 L10.6838793,10.9847716 C10.6838793,10.906599 10.6234922,10.8426396 10.5496857,10.8426396 L9.74452363,10.8426396 C9.67071711,10.8426396 9.61032996,10.906599 9.61032996,10.9847716 L9.61032996,13.7563452 L6.38968187,13.7563452 L6.38968187,10.9847716 C6.38968187,10.906599 6.32929472,10.8426396 6.2554882,10.8426396 L5.45032617,10.8426396 C5.37651966,10.8426396 5.3161325,10.906599 5.3161325,10.9847716 L5.3161325,13.7563452 L2.81342055,13.7563452 L3.56993737,9.13705584 L12.428397,9.13705584 L13.1849139,13.7563452 L10.6838793,13.7563452 Z" id="形状"></path>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
After Width: | Height: | Size: 2.3 KiB |