Compare commits

..

1 Commits

Author SHA1 Message Date
Ke Sun
6056952936 Merge pull request #669 from SuanmoSuanyangTechnology/release/v0.2.8
Release/v0.2.8
2026-03-23 10:17:29 +08:00
760 changed files with 13555 additions and 25578 deletions

2
.gitignore vendored
View File

@@ -25,8 +25,6 @@ examples/
time.log time.log
celerybeat-schedule.db celerybeat-schedule.db
search_results.json search_results.json
redbear-mem-metrics/
pitch-deck/
api/migrations/versions api/migrations/versions
tmp tmp

View File

@@ -1,8 +1,6 @@
import asyncio import asyncio
import json import json
import logging import logging
import os
import threading
from typing import Dict, Any, Optional from typing import Dict, Any, Optional
import redis.asyncio as redis import redis.asyncio as redis
@@ -23,50 +21,6 @@ pool = ConnectionPool.from_url(
) )
aio_redis = redis.StrictRedis(connection_pool=pool) aio_redis = redis.StrictRedis(connection_pool=pool)
_REDIS_URL = f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}"
# Thread-local storage for connection pools.
# Each thread (and each forked process) gets its own pool to avoid
# "Future attached to a different loop" errors in Celery --pool=threads
# and stale connections after fork in --pool=prefork.
_thread_local = threading.local()
def get_thread_safe_redis() -> redis.StrictRedis:
"""Return a Redis client whose connection pool is bound to the current
thread, process **and** event loop.
The pool is recreated when:
- The PID changes (fork, Celery --pool=prefork)
- The thread has no pool yet (Celery --pool=threads)
- The previously-cached event loop has been closed (Celery tasks call
``_shutdown_loop_gracefully`` which closes the loop after each run)
"""
current_pid = os.getpid()
cached_loop = getattr(_thread_local, "loop", None)
loop_stale = cached_loop is not None and cached_loop.is_closed()
if not hasattr(_thread_local, "pool") \
or getattr(_thread_local, "pid", None) != current_pid \
or loop_stale:
_thread_local.pid = current_pid
# Python 3.10+: get_event_loop() raises RuntimeError in threads
# where no loop has been set yet (e.g. Celery --pool=threads).
try:
_thread_local.loop = asyncio.get_event_loop()
except RuntimeError:
_thread_local.loop = None
_thread_local.pool = ConnectionPool.from_url(
_REDIS_URL,
db=settings.REDIS_DB,
password=settings.REDIS_PASSWORD,
decode_responses=True,
max_connections=5,
health_check_interval=30,
)
return redis.StrictRedis(connection_pool=_thread_local.pool)
async def get_redis_connection(): async def get_redis_connection():
"""获取Redis连接""" """获取Redis连接"""
@@ -90,8 +44,10 @@ async def aio_redis_set(key: str, val: str | dict, expire: int = None):
val = json.dumps(val, ensure_ascii=False) val = json.dumps(val, ensure_ascii=False)
if expire is not None: if expire is not None:
# 设置带过期时间的键值
await aio_redis.set(key, val, ex=expire) await aio_redis.set(key, val, ex=expire)
else: else:
# 设置永久键值
await aio_redis.set(key, val) await aio_redis.set(key, val)
except Exception as e: except Exception as e:
logger.error(f"Redis set错误: {str(e)}") logger.error(f"Redis set错误: {str(e)}")

View File

@@ -10,7 +10,7 @@ import logging
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from datetime import datetime from datetime import datetime
from app.aioRedis import get_thread_safe_redis from app.aioRedis import aio_redis
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -68,7 +68,7 @@ class ActivityStatsCache:
"cached": True, "cached": True,
} }
value = json.dumps(payload, ensure_ascii=False) value = json.dumps(payload, ensure_ascii=False)
await get_thread_safe_redis().set(key, value, ex=expire) await aio_redis.set(key, value, ex=expire)
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}") logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}")
return True return True
except Exception as e: except Exception as e:
@@ -90,7 +90,7 @@ class ActivityStatsCache:
""" """
try: try:
key = cls._get_key(workspace_id) key = cls._get_key(workspace_id)
value = await get_thread_safe_redis().get(key) value = await aio_redis.get(key)
if value: if value:
payload = json.loads(value) payload = json.loads(value)
logger.info(f"命中活动统计缓存: {key}") logger.info(f"命中活动统计缓存: {key}")
@@ -116,7 +116,7 @@ class ActivityStatsCache:
""" """
try: try:
key = cls._get_key(workspace_id) key = cls._get_key(workspace_id)
result = await get_thread_safe_redis().delete(key) result = await aio_redis.delete(key)
logger.info(f"删除活动统计缓存: {key}, 结果: {result}") logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
return result > 0 return result > 0
except Exception as e: except Exception as e:

View File

@@ -9,7 +9,7 @@ import logging
from typing import Optional, List, Dict, Any from typing import Optional, List, Dict, Any
from datetime import datetime from datetime import datetime
from app.aioRedis import get_thread_safe_redis from app.aioRedis import aio_redis
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -62,7 +62,7 @@ class InterestMemoryCache:
"cached": True, "cached": True,
} }
value = json.dumps(payload, ensure_ascii=False) value = json.dumps(payload, ensure_ascii=False)
await get_thread_safe_redis().set(key, value, ex=expire) await aio_redis.set(key, value, ex=expire)
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}") logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}")
return True return True
except Exception as e: except Exception as e:
@@ -86,7 +86,7 @@ class InterestMemoryCache:
""" """
try: try:
key = cls._get_key(end_user_id, language) key = cls._get_key(end_user_id, language)
value = await get_thread_safe_redis().get(key) value = await aio_redis.get(key)
if value: if value:
payload = json.loads(value) payload = json.loads(value)
logger.info(f"命中兴趣分布缓存: {key}") logger.info(f"命中兴趣分布缓存: {key}")
@@ -114,7 +114,7 @@ class InterestMemoryCache:
""" """
try: try:
key = cls._get_key(end_user_id, language) key = cls._get_key(end_user_id, language)
result = await get_thread_safe_redis().delete(key) result = await aio_redis.delete(key)
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}") logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
return result > 0 return result > 0
except Exception as e: except Exception as e:

View File

@@ -1,6 +1,5 @@
import os import os
import platform import platform
import re
from datetime import timedelta from datetime import timedelta
from urllib.parse import quote from urllib.parse import quote
@@ -12,24 +11,21 @@ from app.core.logging_config import get_logger
logger = get_logger(__name__) logger = get_logger(__name__)
def _mask_url(url: str) -> str:
"""隐藏 URL 中的密码部分,适用于 redis:// 和 amqp:// 等协议"""
return re.sub(r'(://[^:]*:)[^@]+(@)', r'\1***\2', url)
# macOS fork() safety - must be set before any Celery initialization # macOS fork() safety - must be set before any Celery initialization
if platform.system() == 'Darwin': if platform.system() == 'Darwin':
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
# 创建 Celery 应用实例 # 创建 Celery 应用实例
# broker: 优先使用环境变量 CELERY_BROKER_URL支持 amqp:// 等任意协议), # broker: 任务队列(使用 Redis DB由 CELERY_BROKER_DB 指定)
# 未配置则回退到 Redis 方案 # backend: 结果存储(使用 Redis DB由 CELERY_BACKEND_DB 指定)
# backend: 结果存储(使用 Redis
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND # NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md # 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
_broker_url = os.getenv("CELERY_BROKER_URL") or \ # Build canonical broker/backend URLs and force them into os.environ so that
f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}" # Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first)
# cannot be overridden by stray env vars.
# See: https://github.com/celery/celery/issues/4284
_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}" _backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
os.environ["CELERY_BROKER_URL"] = _broker_url os.environ["CELERY_BROKER_URL"] = _broker_url
os.environ["CELERY_RESULT_BACKEND"] = _backend_url os.environ["CELERY_RESULT_BACKEND"] = _backend_url
@@ -49,8 +45,8 @@ celery_app = Celery(
logger.info( logger.info(
"Celery app initialized", "Celery app initialized",
extra={ extra={
"broker": _mask_url(_broker_url), "broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"),
"backend": _mask_url(_backend_url), "backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"),
}, },
) )
# Default queue for unrouted tasks # Default queue for unrouted tasks
@@ -81,7 +77,6 @@ celery_app.conf.update(
# Worker 设置 (per-worker settings are in docker-compose command line) # Worker 设置 (per-worker settings are in docker-compose command line)
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
worker_redirect_stdouts_level='INFO', # stdout/print → INFO instead of WARNING
# 结果过期时间 # 结果过期时间
result_expires=3600, # 结果保存1小时 result_expires=3600, # 结果保存1小时
@@ -108,9 +103,6 @@ celery_app.conf.update(
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'}, 'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'}, 'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
# Clustering tasks → memory_tasks queue (使用相同的 worker避免 macOS fork 问题)
'app.tasks.run_incremental_clustering': {'queue': 'memory_tasks'},
# Document tasks → document_tasks queue (prefork worker) # Document tasks → document_tasks queue (prefork worker)
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'}, 'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'}, 'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},

View File

@@ -8,13 +8,11 @@ from fastapi import APIRouter
from . import ( from . import (
api_key_controller, api_key_controller,
app_controller, app_controller,
app_log_controller,
auth_controller, auth_controller,
chunk_controller, chunk_controller,
document_controller, document_controller,
emotion_config_controller, emotion_config_controller,
emotion_controller, emotion_controller,
end_user_controller,
file_controller, file_controller,
file_storage_controller, file_storage_controller,
home_page_controller, home_page_controller,
@@ -71,7 +69,6 @@ manager_router.include_router(chunk_controller.router)
manager_router.include_router(test_controller.router) manager_router.include_router(test_controller.router)
manager_router.include_router(knowledgeshare_controller.router) manager_router.include_router(knowledgeshare_controller.router)
manager_router.include_router(app_controller.router) manager_router.include_router(app_controller.router)
manager_router.include_router(app_log_controller.router)
manager_router.include_router(upload_controller.router) manager_router.include_router(upload_controller.router)
manager_router.include_router(memory_agent_controller.router) manager_router.include_router(memory_agent_controller.router)
manager_router.include_router(memory_dashboard_controller.router) manager_router.include_router(memory_dashboard_controller.router)
@@ -99,6 +96,5 @@ manager_router.include_router(file_storage_controller.router)
manager_router.include_router(ontology_controller.router) manager_router.include_router(ontology_controller.router)
manager_router.include_router(skill_controller.router) manager_router.include_router(skill_controller.router)
manager_router.include_router(i18n_controller.router) manager_router.include_router(i18n_controller.router)
manager_router.include_router(end_user_controller.router)
__all__ = ["manager_router"] __all__ = ["manager_router"]

View File

@@ -65,42 +65,16 @@ def list_apps(
- 默认包含本工作空间的应用和分享给本工作空间的应用 - 默认包含本工作空间的应用和分享给本工作空间的应用
- 设置 include_shared=false 可以只查看本工作空间的应用 - 设置 include_shared=false 可以只查看本工作空间的应用
- 当提供 ids 参数时,按逗号分割获取指定应用,不分页 - 当提供 ids 参数时,按逗号分割获取指定应用,不分页
- search 参数支持应用名称模糊搜索、API Key 精确搜索
""" """
from sqlalchemy import select as sa_select
from app.models.api_key_model import ApiKey
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
service = app_service.AppService(db) service = app_service.AppService(db)
# 通过 search 参数搜索:支持应用名称模糊搜索和 API Key 精确搜索 # 当 ids 存在且不为 None 时,根据 ids 获取应用
if search:
search = search.strip()
# 尝试作为 API Key 精确匹配API Key 通常较长)
if len(search) >= 10:
matched_id = db.execute(
sa_select(ApiKey.resource_id).where(
ApiKey.workspace_id == workspace_id,
ApiKey.api_key == search,
ApiKey.resource_id.isnot(None),
)
).scalar_one_or_none()
if matched_id:
# 找到 API Key直接返回关联的应用
ids = str(matched_id)
# 当 ids 存在时,根据 ids 获取应用(不分页)
if ids is not None: if ids is not None:
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()] app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
if app_ids: items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id) items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
items = [service._convert_to_schema(app, workspace_id) for app in items_orm] return success(data=items)
# 返回标准分页格式
meta = PageMeta(page=1, pagesize=len(items), total=len(items), hasnext=False)
return success(data=PageData(page=meta, items=items))
# ids 为空时,返回空列表
meta = PageMeta(page=1, pagesize=0, total=0, hasnext=False)
return success(data=PageData(page=meta, items=[]))
# 正常分页查询 # 正常分页查询
items_orm, total = app_service.list_apps( items_orm, total = app_service.list_apps(

View File

@@ -1,89 +0,0 @@
"""应用日志(消息记录)接口"""
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail
from app.schemas.response_schema import PageData, PageMeta
from app.services.app_service import AppService
from app.services.app_log_service import AppLogService
router = APIRouter(prefix="/apps", tags=["App Logs"])
logger = get_business_logger()
@router.get("/{app_id}/logs", summary="应用日志 - 会话列表")
@cur_workspace_access_guard()
def list_app_logs(
app_id: uuid.UUID,
page: int = Query(1, ge=1),
pagesize: int = Query(20, ge=1, le=100),
is_draft: Optional[bool] = None,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""查看应用下所有会话记录(分页)
- 支持按 is_draft 筛选(草稿会话 / 发布会话)
- 按最新更新时间倒序排列
- 所有人(包括共享者和被共享者)都只能查看自己的会话记录
"""
workspace_id = current_user.current_workspace_id
# 验证应用访问权限
app_service = AppService(db)
app_service.get_app(app_id, workspace_id)
# 使用 Service 层查询
log_service = AppLogService(db)
conversations, total = log_service.list_conversations(
app_id=app_id,
workspace_id=workspace_id,
page=page,
pagesize=pagesize,
is_draft=is_draft
)
items = [AppLogConversation.model_validate(c) for c in conversations]
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
return success(data=PageData(page=meta, items=items))
@router.get("/{app_id}/logs/{conversation_id}", summary="应用日志 - 会话消息详情")
@cur_workspace_access_guard()
def get_app_log_detail(
app_id: uuid.UUID,
conversation_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""查看某会话的完整消息记录
- 返回会话基本信息 + 所有消息(按时间正序)
- 消息 meta_data 包含模型名、token 用量等信息
- 所有人(包括共享者和被共享者)都只能查看自己的会话详情
"""
workspace_id = current_user.current_workspace_id
# 验证应用访问权限
app_service = AppService(db)
app_service.get_app(app_id, workspace_id)
# 使用 Service 层查询
log_service = AppLogService(db)
conversation = log_service.get_conversation_detail(
app_id=app_id,
conversation_id=conversation_id,
workspace_id=workspace_id
)
detail = AppLogConversationDetail.model_validate(conversation)
return success(data=detail)

View File

@@ -1,48 +0,0 @@
"""End User 管理接口 - 无需认证"""
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.db import get_db
from app.repositories.end_user_repository import EndUserRepository
from app.schemas.memory_api_schema import (
CreateEndUserRequest,
CreateEndUserResponse,
)
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
router = APIRouter(prefix="/end_users", tags=["End Users"])
logger = get_business_logger()
@router.post("")
async def create_end_user(
data: CreateEndUserRequest,
db: Session = Depends(get_db),
):
"""
Create an end user.
Creates a new end user for the given workspace.
If an end user with the same other_id already exists in the workspace,
returns the existing one.
"""
logger.info(f"Create end user request - other_id: {data.other_id}, workspace_id: {data.workspace_id}")
end_user_repo = EndUserRepository(db)
end_user = end_user_repo.get_or_create_end_user(
app_id=None,
workspace_id=data.workspace_id,
other_id=data.other_id,
)
logger.info(f"End user ready: {end_user.id}")
result = {
"id": str(end_user.id),
"other_id": end_user.other_id or "",
"other_name": end_user.other_name or "",
"workspace_id": str(end_user.workspace_id),
}
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")

View File

@@ -14,9 +14,6 @@ Routes:
import os import os
import uuid import uuid
from typing import Any from typing import Any
import httpx
import mimetypes
from urllib.parse import urlparse, unquote
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
from fastapi.responses import FileResponse, RedirectResponse from fastapi.responses import FileResponse, RedirectResponse
@@ -293,101 +290,6 @@ async def upload_file_with_share_token(
) )
@router.get("/files/info-by-url", response_model=ApiResponse)
async def get_file_info_by_url(
url: str,
):
"""
Get file information by network URL (no authentication required).
Fetches file metadata from a remote URL via HTTP HEAD request.
Falls back to GET request if HEAD is not supported.
Returns file type, name, and size.
Args:
url: The network URL of the file.
Returns:
ApiResponse with file information.
"""
api_logger.info(f"File info by URL request: url={url}")
try:
async with httpx.AsyncClient(timeout=10.0) as client:
# Try HEAD request first
response = await client.head(url, follow_redirects=True)
# If HEAD fails, try GET request (some servers don't support HEAD)
if response.status_code != 200:
api_logger.info(f"HEAD request failed with {response.status_code}, trying GET request")
response = await client.get(url, follow_redirects=True)
if response.status_code != 200:
api_logger.error(f"Failed to fetch file info: HTTP {response.status_code}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unable to access file: HTTP {response.status_code}"
)
# Get file size from Content-Length header or actual content
file_size = response.headers.get("Content-Length")
if file_size:
file_size = int(file_size)
elif hasattr(response, 'content'):
file_size = len(response.content)
else:
file_size = None
# Get content type from Content-Type header
content_type = response.headers.get("Content-Type", "application/octet-stream")
# Remove charset and other parameters from content type
content_type = content_type.split(';')[0].strip()
# Extract filename from Content-Disposition or URL
file_name = None
content_disposition = response.headers.get("Content-Disposition")
if content_disposition and "filename=" in content_disposition:
parts = content_disposition.split("filename=")
if len(parts) > 1:
file_name = parts[1].strip('"').strip("'")
if not file_name:
parsed_url = urlparse(url)
file_name = unquote(os.path.basename(parsed_url.path)) or "unknown"
# Extract file extension from filename
_, file_ext = os.path.splitext(file_name)
# If no extension found, infer from content type
if not file_ext:
ext = mimetypes.guess_extension(content_type)
if ext:
file_ext = ext
file_name = f"{file_name}{file_ext}"
api_logger.info(f"File info retrieved: name={file_name}, size={file_size}, type={content_type}")
return success(
data={
"url": url,
"file_name": file_name,
"file_ext": file_ext.lower() if file_ext else "",
"file_size": file_size,
"content_type": content_type,
},
msg="File information retrieved successfully"
)
except HTTPException:
raise
except Exception as e:
api_logger.error(f"Unexpected error: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve file information: {str(e)}"
)
@router.get("/files/{file_id}", response_model=Any) @router.get("/files/{file_id}", response_model=Any)
async def download_file( async def download_file(
request: Request, request: Request,
@@ -574,12 +476,8 @@ async def get_file_url(
# For local storage, generate signed URL with expiration # For local storage, generate signed URL with expiration
url = generate_signed_url(str(file_id), expires) url = generate_signed_url(str(file_id), expires)
else: else:
# For remote storage (OSS/S3), get presigned URL with forced download # For remote storage (OSS/S3), get presigned URL
url = await storage_service.get_file_url( url = await storage_service.get_file_url(file_key, expires=expires)
file_key,
expires=expires,
file_name=file_metadata.file_name,
)
url = _match_scheme(request, url) url = _match_scheme(request, url)
api_logger.info(f"Generated file URL: file_id={file_id}") api_logger.info(f"Generated file URL: file_id={file_id}")
@@ -790,7 +688,7 @@ async def permanent_download_file(
# For remote storage, redirect to presigned URL with long expiration # For remote storage, redirect to presigned URL with long expiration
try: try:
# Use a very long expiration (7 days max for most cloud providers) # Use a very long expiration (7 days max for most cloud providers)
presigned_url = await storage_service.get_file_url(file_key, expires=604800, file_name=file_metadata.file_name) presigned_url = await storage_service.get_file_url(file_key, expires=604800)
presigned_url = _match_scheme(request, presigned_url) presigned_url = _match_scheme(request, presigned_url)
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND) return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
except Exception as e: except Exception as e:
@@ -799,44 +697,3 @@ async def permanent_download_file(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve file: {str(e)}" detail=f"Failed to retrieve file: {str(e)}"
) )
@router.get("/files/{file_id}/status", response_model=ApiResponse)
async def get_file_status(
file_id: uuid.UUID,
db: Session = Depends(get_db),
):
"""
Get file upload/processing status (no authentication required).
This endpoint is used to check if a file (e.g., TTS audio) is ready.
Returns status: pending, completed, or failed.
Args:
file_id: The UUID of the file.
db: Database session.
Returns:
ApiResponse with file status and metadata.
"""
api_logger.info(f"File status request: file_id={file_id}")
# Query file metadata from database
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
if not file_metadata:
api_logger.warning(f"File not found in database: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist"
)
return success(
data={
"file_id": str(file_id),
"status": file_metadata.status,
"file_name": file_metadata.file_name,
"file_size": file_metadata.file_size,
"content_type": file_metadata.content_type,
},
msg="File status retrieved successfully"
)

View File

@@ -91,11 +91,9 @@ async def get_mcp_servers(
try: try:
cookies = api.get_cookies(token) cookies = api.get_cookies(token)
headers=api.builder_headers(api.headers)
headers['Authorization'] = f'Bearer {token}'
r = api.session.put( r = api.session.put(
url=api.mcp_base_url, url=api.mcp_base_url,
headers=headers, headers=api.builder_headers(api.headers),
json=body, json=body,
cookies=cookies) cookies=cookies)
raise_for_http_status(r) raise_for_http_status(r)
@@ -175,7 +173,6 @@ async def get_operational_mcp_servers(
url = f'{api.mcp_base_url}/operational' url = f'{api.mcp_base_url}/operational'
headers = api.builder_headers(api.headers) headers = api.builder_headers(api.headers)
headers['Authorization'] = f'Bearer {token}'
try: try:
cookies = api.get_cookies(access_token=token, cookies_required=True) cookies = api.get_cookies(access_token=token, cookies_required=True)
@@ -263,9 +260,7 @@ async def create_mcp_market_config(
api.login(create_data.token) api.login(create_data.token)
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None} body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
cookies = api.get_cookies(create_data.token) cookies = api.get_cookies(create_data.token)
headers = api.builder_headers(api.headers) r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
headers['Authorization'] = f'Bearer {create_data.token}'
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
raise_for_http_status(r) raise_for_http_status(r)
except Exception as e: except Exception as e:
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}") api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
@@ -295,11 +290,9 @@ async def create_mcp_market_config(
'search': "" 'search': ""
} }
cookies = api.get_cookies(token) cookies = api.get_cookies(token)
headers = api.builder_headers(api.headers)
headers['Authorization'] = f'Bearer {token}'
r = api.session.put( r = api.session.put(
url=api.mcp_base_url, url=api.mcp_base_url,
headers=headers, headers=api.builder_headers(api.headers),
json=body, json=body,
cookies=cookies) cookies=cookies)
raise_for_http_status(r) raise_for_http_status(r)
@@ -400,9 +393,7 @@ async def update_mcp_market_config(
api.login(update_data.token) api.login(update_data.token)
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None} body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
cookies = api.get_cookies(update_data.token) cookies = api.get_cookies(update_data.token)
headers = api.builder_headers(api.headers) r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
headers['Authorization'] = f'Bearer {update_data.token}'
r = api.session.put(url=api.mcp_base_url, headers=headers, json=body, cookies=cookies)
raise_for_http_status(r) raise_for_http_status(r)
except Exception as e: except Exception as e:
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}") api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")

View File

@@ -118,142 +118,142 @@ async def download_log(
return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "启动日志流式传输失败", str(e))
# @router.post("/writer_service", response_model=ApiResponse) @router.post("/writer_service", response_model=ApiResponse)
# @cur_workspace_access_guard() @cur_workspace_access_guard()
# async def write_server( async def write_server(
# user_input: Write_UserInput, user_input: Write_UserInput,
# language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
# db: Session = Depends(get_db), db: Session = Depends(get_db),
# current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
# ): ):
# """ """
# Write service endpoint - processes write operations synchronously Write service endpoint - processes write operations synchronously
#
# Args: Args:
# user_input: Write request containing message and end_user_id user_input: Write request containing message and end_user_id
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
#
# Returns: Returns:
# Response with write operation status Response with write operation status
# """ """
# # 使用集中化的语言校验 # 使用集中化的语言校验
# language = get_language_from_header(language_type) language = get_language_from_header(language_type)
#
# config_id = user_input.config_id config_id = user_input.config_id
# workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
#
# # 获取 storage_type如果为 None 则使用默认值 # 获取 storage_type如果为 None 则使用默认值
# storage_type = workspace_service.get_workspace_storage_type( storage_type = workspace_service.get_workspace_storage_type(
# db=db, db=db,
# workspace_id=workspace_id, workspace_id=workspace_id,
# user=current_user user=current_user
# ) )
# if storage_type is None: storage_type = 'neo4j' if storage_type is None: storage_type = 'neo4j'
# user_rag_memory_id = '' user_rag_memory_id = ''
#
# # 如果 storage_type 是 rag必须确保有有效的 user_rag_memory_id # 如果 storage_type 是 rag必须确保有有效的 user_rag_memory_id
# if storage_type == 'rag': if storage_type == 'rag':
# if workspace_id: if workspace_id:
# knowledge = knowledge_repository.get_knowledge_by_name( knowledge = knowledge_repository.get_knowledge_by_name(
# db=db, db=db,
# name="USER_RAG_MERORY", name="USER_RAG_MERORY",
# workspace_id=workspace_id workspace_id=workspace_id
# ) )
# if knowledge: if knowledge:
# user_rag_memory_id = str(knowledge.id) user_rag_memory_id = str(knowledge.id)
# else: else:
# api_logger.warning( api_logger.warning(
# f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储") f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
# storage_type = 'neo4j' storage_type = 'neo4j'
# else: else:
# api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
# storage_type = 'neo4j' storage_type = 'neo4j'
#
# api_logger.info( api_logger.info(
# f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
# try: try:
# messages_list = memory_agent_service.get_messages_list(user_input) messages_list = memory_agent_service.get_messages_list(user_input)
# result = await memory_agent_service.write_memory( result = await memory_agent_service.write_memory(
# user_input.end_user_id, user_input.end_user_id,
# messages_list, messages_list,
# config_id, config_id,
# db, db,
# storage_type, storage_type,
# user_rag_memory_id, user_rag_memory_id,
# language language
# ) )
#
# return success(data=result, msg="写入成功") return success(data=result, msg="写入成功")
# except BaseException as e: except BaseException as e:
# # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
# if hasattr(e, 'exceptions'): if hasattr(e, 'exceptions'):
# error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions] error_messages = [f"{type(sub_e).__name__}: {str(sub_e)}" for sub_e in e.exceptions]
# detailed_error = "; ".join(error_messages) detailed_error = "; ".join(error_messages)
# api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True) api_logger.error(f"Write operation error (TaskGroup): {detailed_error}", exc_info=True)
# return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error) return fail(BizCode.INTERNAL_ERROR, "写入失败", detailed_error)
# api_logger.error(f"Write operation error: {str(e)}", exc_info=True) api_logger.error(f"Write operation error: {str(e)}", exc_info=True)
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
#
#
# @router.post("/writer_service_async", response_model=ApiResponse) @router.post("/writer_service_async", response_model=ApiResponse)
# @cur_workspace_access_guard() @cur_workspace_access_guard()
# async def write_server_async( async def write_server_async(
# user_input: Write_UserInput, user_input: Write_UserInput,
# language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
# db: Session = Depends(get_db), db: Session = Depends(get_db),
# current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
# ): ):
# """ """
# Async write service endpoint - enqueues write processing to Celery Async write service endpoint - enqueues write processing to Celery
#
# Args: Args:
# user_input: Write request containing message and end_user_id user_input: Write request containing message and end_user_id
# language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递 language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
#
# Returns: Returns:
# Task ID for tracking async operation Task ID for tracking async operation
# Use GET /memory/write_result/{task_id} to check task status and get result Use GET /memory/write_result/{task_id} to check task status and get result
# """ """
# # 使用集中化的语言校验 # 使用集中化的语言校验
# language = get_language_from_header(language_type) language = get_language_from_header(language_type)
#
# config_id = user_input.config_id config_id = user_input.config_id
# workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# api_logger.info( api_logger.info(
# f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
#
# # 获取 storage_type如果为 None 则使用默认值 # 获取 storage_type如果为 None 则使用默认值
# storage_type = workspace_service.get_workspace_storage_type( storage_type = workspace_service.get_workspace_storage_type(
# db=db, db=db,
# workspace_id=workspace_id, workspace_id=workspace_id,
# user=current_user user=current_user
# ) )
# if storage_type is None: storage_type = 'neo4j' if storage_type is None: storage_type = 'neo4j'
# user_rag_memory_id = '' user_rag_memory_id = ''
# if workspace_id: if workspace_id:
#
# knowledge = knowledge_repository.get_knowledge_by_name( knowledge = knowledge_repository.get_knowledge_by_name(
# db=db, db=db,
# name="USER_RAG_MERORY", name="USER_RAG_MERORY",
# workspace_id=workspace_id workspace_id=workspace_id
# ) )
# if knowledge: user_rag_memory_id = str(knowledge.id) if knowledge: user_rag_memory_id = str(knowledge.id)
# api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
# try: try:
# # 获取标准化的消息列表 # 获取标准化的消息列表
# messages_list = memory_agent_service.get_messages_list(user_input) messages_list = memory_agent_service.get_messages_list(user_input)
#
# task = celery_app.send_task( task = celery_app.send_task(
# "app.core.memory.agent.write_message", "app.core.memory.agent.write_message",
# args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language] args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
# ) )
# api_logger.info(f"Write task queued: {task.id}") api_logger.info(f"Write task queued: {task.id}")
#
# return success(data={"task_id": task.id}, msg="写入任务已提交") return success(data={"task_id": task.id}, msg="写入任务已提交")
# except Exception as e: except Exception as e:
# api_logger.error(f"Async write operation failed: {str(e)}") api_logger.error(f"Async write operation failed: {str(e)}")
# return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "写入失败", str(e))
@router.post("/read_service", response_model=ApiResponse) @router.post("/read_service", response_model=ApiResponse)

View File

@@ -1,5 +1,3 @@
import time
from contextlib import contextmanager
from fastapi import APIRouter, Depends, HTTPException, status, Query from fastapi import APIRouter, Depends, HTTPException, status, Query
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -18,18 +16,6 @@ from app.core.logging_config import get_api_logger
# 获取API专用日志器 # 获取API专用日志器
api_logger = get_api_logger() api_logger = get_api_logger()
@contextmanager
def timer(label: str, user_count: int = 0):
"""上下文管理器:用于测量代码块执行时间"""
start = time.perf_counter()
try:
yield
finally:
elapsed = (time.perf_counter() - start) * 1000 # 转换为毫秒
extra_info = f", 用户数: {user_count}" if user_count > 0 else ""
api_logger.info(f"[性能统计] {label}: {elapsed:.2f}ms{extra_info}")
router = APIRouter( router = APIRouter(
prefix="/dashboard", prefix="/dashboard",
tags=["Dashboard"], tags=["Dashboard"],
@@ -66,7 +52,7 @@ async def get_workspace_end_users(
): ):
""" """
获取工作空间的宿主列表(高性能优化版本 v2 获取工作空间的宿主列表(高性能优化版本 v2
优化策略: 优化策略:
1. 批量查询 end_users一次查询而非循环 1. 批量查询 end_users一次查询而非循环
2. 并发查询所有用户的记忆数量Neo4j 2. 并发查询所有用户的记忆数量Neo4j
@@ -74,7 +60,7 @@ async def get_workspace_end_users(
4. 只返回必要字段减少数据传输 4. 只返回必要字段减少数据传输
5. 添加短期缓存减少重复查询 5. 添加短期缓存减少重复查询
6. 并发执行配置查询和记忆数量查询 6. 并发执行配置查询和记忆数量查询
返回格式: 返回格式:
{ {
"end_user": {"id": "uuid", "other_name": "名称"}, "end_user": {"id": "uuid", "other_name": "名称"},
@@ -84,149 +70,129 @@ async def get_workspace_end_users(
""" """
import asyncio import asyncio
import json import json
# from app.aioRedis import aio_redis_get, aio_redis_set from app.aioRedis import aio_redis_get, aio_redis_set
# 总耗时统计
total_start = time.perf_counter()
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# # 尝试从缓存获取30秒缓存- 暂时注释以便进行性能测试 # 尝试从缓存获取30秒缓存
# with timer("Redis缓存读取"): cache_key = f"end_users:workspace:{workspace_id}"
# cache_key = f"end_users:workspace:{workspace_id}" try:
# try: cached_data = await aio_redis_get(cache_key)
# cached_data = await aio_redis_get(cache_key) if cached_data:
# if cached_data: api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
# api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}") return success(data=json.loads(cached_data), msg="宿主列表获取成功")
# return success(data=json.loads(cached_data), msg="宿主列表获取成功") except Exception as e:
# except Exception as e: api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
# api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
# 获取当前空间类型 # 获取当前空间类型
with timer("获取空间类型"): current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user) api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}")
# 获取 end_users已优化为批量查询 # 获取 end_users已优化为批量查询
with timer("获取用户列表"): end_users = memory_dashboard_service.get_workspace_end_users(
end_users = memory_dashboard_service.get_workspace_end_users( db=db,
db=db, workspace_id=workspace_id,
workspace_id=workspace_id, current_user=current_user
current_user=current_user )
)
if not end_users: if not end_users:
api_logger.info("工作空间下没有宿主") api_logger.info("工作空间下没有宿主")
# # 缓存空结果,避免重复查询 - 暂时注释 # 缓存空结果,避免重复查询
# try: try:
# await aio_redis_set(cache_key, json.dumps([]), expire=30) await aio_redis_set(cache_key, json.dumps([]), expire=30)
# except Exception as e: except Exception as e:
# api_logger.warning(f"Redis 缓存写入失败: {str(e)}") api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
return success(data=[], msg="宿主列表获取成功") return success(data=[], msg="宿主列表获取成功")
end_user_ids = [str(user.id) for user in end_users] end_user_ids = [str(user.id) for user in end_users]
user_count = len(end_user_ids)
api_logger.info(f"需要处理的用户数: {user_count}")
# 并发执行两个独立的查询任务 # 并发执行两个独立的查询任务
async def get_memory_configs(): async def get_memory_configs():
"""获取记忆配置(在线程池中执行同步查询)""" """获取记忆配置(在线程池中执行同步查询)"""
with timer("功能模块-获取记忆配置", user_count): try:
try: return await asyncio.to_thread(
return await asyncio.to_thread( get_end_users_connected_configs_batch,
get_end_users_connected_configs_batch, end_user_ids, db
end_user_ids, db )
) except Exception as e:
except Exception as e: api_logger.error(f"批量获取记忆配置失败: {str(e)}")
api_logger.error(f"批量获取记忆配置失败: {str(e)}") return {}
return {}
async def get_memory_nums(): async def get_memory_nums():
"""获取记忆数量""" """获取记忆数量"""
with timer(f"功能模块-获取记忆数量[{current_workspace_type}]", user_count): if current_workspace_type == "rag":
if current_workspace_type == "rag": # RAG 模式:批量查询
# RAG 模式:批量查询 try:
with timer(" - RAG批量查询chunks"): chunk_map = await asyncio.to_thread(
memory_dashboard_service.get_users_total_chunk_batch,
end_user_ids, db, current_user
)
return {uid: {"total": count} for uid, count in chunk_map.items()}
except Exception as e:
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
return {uid: {"total": 0} for uid in end_user_ids}
elif current_workspace_type == "neo4j":
# Neo4j 模式:并发查询(带并发限制)
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
MAX_CONCURRENT_QUERIES = 10
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
async def get_neo4j_memory_num(end_user_id: str):
async with semaphore:
try: try:
chunk_map = await asyncio.to_thread( return await memory_storage_service.search_all(end_user_id)
memory_dashboard_service.get_users_total_chunk_batch,
end_user_ids, db, current_user
)
return {uid: {"total": count} for uid, count in chunk_map.items()}
except Exception as e: except Exception as e:
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}") api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
return {uid: {"total": 0} for uid in end_user_ids} return {"total": 0}
elif current_workspace_type == "neo4j": memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
# Neo4j 模式:并发查询(带并发限制) return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
MAX_CONCURRENT_QUERIES = 10 return {uid: {"total": 0} for uid in end_user_ids}
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
async def get_neo4j_memory_num(end_user_id: str):
async with semaphore:
single_start = time.perf_counter()
try:
result = await memory_storage_service.search_all(end_user_id)
elapsed = (time.perf_counter() - single_start) * 1000
api_logger.info(f" - Neo4j单用户查询[{end_user_id}]: {elapsed:.2f}ms")
return result
except Exception as e:
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
return {"total": 0}
with timer(" - Neo4j并发查询所有用户"):
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
return {uid: {"total": 0} for uid in end_user_ids}
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据 # 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
with timer("触发Celery初始化任务"): try:
try: from app.celery_app import celery_app as _celery_app
from app.celery_app import celery_app as _celery_app _celery_app.send_task(
_celery_app.send_task( "app.tasks.init_implicit_emotions_for_users",
"app.tasks.init_implicit_emotions_for_users", kwargs={"end_user_ids": end_user_ids},
kwargs={"end_user_ids": end_user_ids}, )
) _celery_app.send_task(
_celery_app.send_task( "app.tasks.init_interest_distribution_for_users",
"app.tasks.init_interest_distribution_for_users", kwargs={"end_user_ids": end_user_ids},
kwargs={"end_user_ids": end_user_ids}, )
) api_logger.info(f"已触发按需初始化任务,候选用户数: {len(end_user_ids)}")
api_logger.info(f"已触发按需初始化任务,候选用户数: {len(end_user_ids)}") except Exception as e:
except Exception as e: api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
# 并发执行配置查询和记忆数量查询 # 并发执行配置查询和记忆数量查询
with timer("并发执行两个功能模块"): memory_configs_map, memory_nums_map = await asyncio.gather(
memory_configs_map, memory_nums_map = await asyncio.gather( get_memory_configs(),
get_memory_configs(), get_memory_nums()
get_memory_nums() )
)
# 构建结果(优化:使用列表推导式) # 构建结果(优化:使用列表推导式)
with timer("构建返回结果"): result = []
result = [] for end_user in end_users:
for end_user in end_users: user_id = str(end_user.id)
user_id = str(end_user.id) config_info = memory_configs_map.get(user_id, {})
config_info = memory_configs_map.get(user_id, {}) result.append({
result.append({ 'end_user': {
'end_user': { 'id': user_id,
'id': user_id, 'other_name': end_user.other_name
'other_name': end_user.other_name },
}, 'memory_num': memory_nums_map.get(user_id, {"total": 0}),
'memory_num': memory_nums_map.get(user_id, {"total": 0}), 'memory_config': {
'memory_config': { "memory_config_id": config_info.get("memory_config_id"),
"memory_config_id": config_info.get("memory_config_id"), "memory_config_name": config_info.get("memory_config_name")
"memory_config_name": config_info.get("memory_config_name") }
} })
})
# 写入缓存30秒过期
# # 写入缓存30秒过期- 暂时注释以便进行性能测试 try:
# with timer("Redis缓存写入"): await aio_redis_set(cache_key, json.dumps(result), expire=30)
# try: except Exception as e:
# await aio_redis_set(cache_key, json.dumps(result), expire=30) api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
# except Exception as e:
# api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
# 触发社区聚类补全任务(异步,不阻塞接口响应) # 触发社区聚类补全任务(异步,不阻塞接口响应)
try: try:
@@ -236,8 +202,6 @@ async def get_workspace_end_users(
except Exception as e: except Exception as e:
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}") api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
total_elapsed = (time.perf_counter() - total_start) * 1000
api_logger.info(f"[性能统计] 接口总耗时: {total_elapsed:.2f}ms, 用户数: {user_count}")
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录") api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
return success(data=result, msg="宿主列表获取成功") return success(data=result, msg="宿主列表获取成功")
@@ -699,12 +663,9 @@ async def dashboard_data(
rag_data["total_memory"] = total_chunk rag_data["total_memory"] = total_chunk
# total_app: 统计当前空间下的所有app数量 # total_app: 统计当前空间下的所有app数量
# 包含自有app + 被分享给本工作空间的app from app.repositories import app_repository
from app.services import app_service as _app_svc apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
_, total_app = _app_svc.AppService(db).list_apps( rag_data["total_app"] = len(apps_orm)
workspace_id=workspace_id, include_shared=True, pagesize=1
)
rag_data["total_app"] = total_app
# total_knowledge: 使用 total_kb总知识库数 # total_knowledge: 使用 total_kb总知识库数
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user) total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
@@ -726,7 +687,7 @@ async def dashboard_data(
api_logger.warning(f"获取RAG模式API调用统计失败使用默认值: {str(e)}") api_logger.warning(f"获取RAG模式API调用统计失败使用默认值: {str(e)}")
rag_data["total_api_call"] = 0 rag_data["total_api_call"] = 0
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={total_app}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}") api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
except Exception as e: except Exception as e:
api_logger.warning(f"获取RAG相关数据失败: {str(e)}") api_logger.warning(f"获取RAG相关数据失败: {str(e)}")

View File

@@ -31,7 +31,6 @@ from app.schemas.memory_storage_schema import (
ForgettingCurveRequest, ForgettingCurveRequest,
ForgettingCurveResponse, ForgettingCurveResponse,
ForgettingCurvePoint, ForgettingCurvePoint,
PendingNodesResponse,
) )
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.services.memory_forget_service import MemoryForgetService from app.services.memory_forget_service import MemoryForgetService
@@ -309,100 +308,6 @@ async def get_forgetting_stats(
return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "获取遗忘引擎统计失败", str(e))
@router.get("/pending-nodes", response_model=ApiResponse)
async def get_pending_nodes(
end_user_id: str,
page: int = 1,
pagesize: int = 10,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
获取待遗忘节点列表(独立分页接口)
查询满足遗忘条件的节点(激活值低于阈值且最后访问时间超过最小天数)。
此接口独立分页,与 /stats 接口分离。
Args:
end_user_id: 组ID即 end_user_id必填
page: 页码从1开始默认1
pagesize: 每页数量默认10
current_user: 当前用户
db: 数据库会话
Returns:
ApiResponse: 包含待遗忘节点列表和分页信息的响应
Examples:
- 第1页每页10条GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=1&pagesize=10
- 第2页每页20条GET /memory/forget-memory/pending-nodes?end_user_id=xxx&page=2&pagesize=20
Notes:
- page 从1开始pagesize 必须大于0
- 返回格式:{"items": [...], "page": {"page": 1, "pagesize": 10, "total": 100, "hasnext": true}}
"""
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 验证 end_user_id 必填
if not end_user_id:
api_logger.warning(f"用户 {current_user.username} 尝试获取待遗忘节点但未提供 end_user_id")
return fail(BizCode.INVALID_PARAMETER, "end_user_id 不能为空", "end_user_id is required")
# 通过 end_user_id 获取关联的 config_id
try:
from app.services.memory_agent_service import get_end_user_connected_config
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
config_id = resolve_config_id(config_id, db)
if config_id is None:
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
except ValueError as e:
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
except Exception as e:
api_logger.error(f"获取终端用户配置时发生错误: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取终端用户配置失败", str(e))
# 验证分页参数
if page < 1:
return fail(BizCode.INVALID_PARAMETER, "page 必须大于等于1", "page < 1")
if pagesize < 1:
return fail(BizCode.INVALID_PARAMETER, "pagesize 必须大于等于1", "pagesize < 1")
api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取待遗忘节点: "
f"end_user_id={end_user_id}, page={page}, pagesize={pagesize}"
)
try:
# 调用服务层获取待遗忘节点列表
result = await forget_service.get_pending_nodes(
db=db,
end_user_id=end_user_id,
config_id=config_id,
page=page,
pagesize=pagesize
)
# 构建响应
response_data = PendingNodesResponse(**result)
return success(data=response_data.model_dump(), msg="查询成功")
except Exception as e:
api_logger.error(f"获取待遗忘节点列表失败: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取待遗忘节点列表失败", str(e))
@router.post("/forgetting_curve", response_model=ApiResponse) @router.post("/forgetting_curve", response_model=ApiResponse)
async def get_forgetting_curve( async def get_forgetting_curve(
request: ForgettingCurveRequest, request: ForgettingCurveRequest,

View File

@@ -54,8 +54,8 @@ router = APIRouter(
@router.get("/info", response_model=ApiResponse) @router.get("/info", response_model=ApiResponse)
async def get_storage_info( async def get_storage_info(
storage_id: str, storage_id: str,
current_user: User = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
""" """
Example wrapper endpoint - retrieves storage information Example wrapper endpoint - retrieves storage information
@@ -75,19 +75,24 @@ async def get_storage_info(
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
def create_config( def create_config(
payload: ConfigParamsCreate, payload: ConfigParamsCreate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"), x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间 # 检查用户是否已选择工作空间
if workspace_id is None: if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求创建配置: {payload.config_name}") api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求创建配置: {payload.config_name}")
try: try:
# 将 workspace_id 注入到 payload 中(保持为 UUID 类型) # 将 workspace_id 注入到 payload 中(保持为 UUID 类型)
@@ -102,11 +107,9 @@ def create_config(
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}") api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type) lang = get_language_from_header(x_language_type)
if lang == "en": if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
else: else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg) return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {err_str}") api_logger.error(f"Create config failed: {err_str}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str) return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
@@ -116,11 +119,9 @@ def create_config(
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}") api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type) lang = get_language_from_header(x_language_type)
if lang == "en": if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
else: else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg) return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {str(e)}") api_logger.error(f"Create config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
@@ -128,10 +129,10 @@ def create_config(
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称) @router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
def delete_config( def delete_config(
config_id: UUID | int, config_id: UUID|int,
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"), force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
"""删除记忆配置(带终端用户保护) """删除记忆配置(带终端用户保护)
@@ -144,24 +145,24 @@ def delete_config(
force: 设置为 true 可强制删除(即使有终端用户正在使用) force: 设置为 true 可强制删除(即使有终端用户正在使用)
""" """
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
config_id = resolve_config_id(config_id, db) config_id=resolve_config_id(config_id, db)
# 检查用户是否已选择工作空间 # 检查用户是否已选择工作空间
if workspace_id is None: if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info( api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: " f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: "
f"config_id={config_id}, force={force}" f"config_id={config_id}, force={force}"
) )
try: try:
# 使用带保护的删除服务 # 使用带保护的删除服务
from app.services.memory_config_service import MemoryConfigService from app.services.memory_config_service import MemoryConfigService
config_service = MemoryConfigService(db) config_service = MemoryConfigService(db)
result = config_service.delete_config(config_id=config_id, force=force) result = config_service.delete_config(config_id=config_id, force=force)
if result["status"] == "error": if result["status"] == "error":
api_logger.warning( api_logger.warning(
f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}" f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}"
@@ -171,7 +172,7 @@ def delete_config(
msg=result["message"], msg=result["message"],
data={"config_id": str(config_id), "is_default": result.get("is_default", False)} data={"config_id": str(config_id), "is_default": result.get("is_default", False)}
) )
if result["status"] == "warning": if result["status"] == "warning":
api_logger.warning( api_logger.warning(
f"记忆配置正在使用,无法删除: config_id={config_id}, " f"记忆配置正在使用,无法删除: config_id={config_id}, "
@@ -185,7 +186,7 @@ def delete_config(
"force_required": result["force_required"] "force_required": result["force_required"]
} }
) )
api_logger.info( api_logger.info(
f"记忆配置删除成功: config_id={config_id}, " f"记忆配置删除成功: config_id={config_id}, "
f"affected_users={result['affected_users']}" f"affected_users={result['affected_users']}"
@@ -194,7 +195,7 @@ def delete_config(
msg=result["message"], msg=result["message"],
data={"affected_users": result["affected_users"]} data={"affected_users": result["affected_users"]}
) )
except Exception as e: except Exception as e:
api_logger.error(f"Delete config failed: {str(e)}", exc_info=True) api_logger.error(f"Delete config failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e))
@@ -202,9 +203,9 @@ def delete_config(
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc @router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
def update_config( def update_config(
payload: ConfigUpdate, payload: ConfigUpdate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
payload.config_id = resolve_config_id(payload.config_id, db) payload.config_id = resolve_config_id(payload.config_id, db)
@@ -212,13 +213,12 @@ def update_config(
if workspace_id is None: if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 校验至少有一个字段需要更新 # 校验至少有一个字段需要更新
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None: if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段") api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
"config_name, config_desc, scene_id 均为空")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}") api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
try: try:
svc = DataConfigService(db) svc = DataConfigService(db)
@@ -231,9 +231,9 @@ def update_config(
@router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选 @router.post("/update_config_extracted", response_model=ApiResponse) # 更新数据库中的部分内容 所有业务字段均可选
def update_config_extracted( def update_config_extracted(
payload: ConfigUpdateExtracted, payload: ConfigUpdateExtracted,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
payload.config_id = resolve_config_id(payload.config_id, db) payload.config_id = resolve_config_id(payload.config_id, db)
@@ -241,7 +241,7 @@ def update_config_extracted(
if workspace_id is None: if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新提取配置: {payload.config_id}") api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新提取配置: {payload.config_id}")
try: try:
svc = DataConfigService(db) svc = DataConfigService(db)
@@ -256,11 +256,11 @@ def update_config_extracted(
# 遗忘引擎配置接口已迁移到 memory_forget_controller.py # 遗忘引擎配置接口已迁移到 memory_forget_controller.py
# 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config # 使用新接口: /api/memory/forget/read_config 和 /api/memory/forget/update_config
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除 @router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
def read_config_extracted( def read_config_extracted(
config_id: UUID | int, config_id: UUID | int,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
config_id = resolve_config_id(config_id, db) config_id = resolve_config_id(config_id, db)
@@ -268,7 +268,7 @@ def read_config_extracted(
if workspace_id is None: if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取提取配置: {config_id}") api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取提取配置: {config_id}")
try: try:
svc = DataConfigService(db) svc = DataConfigService(db)
@@ -278,19 +278,18 @@ def read_config_extracted(
api_logger.error(f"Read config extracted failed: {str(e)}") api_logger.error(f"Read config extracted failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "查询配置失败", str(e))
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
@router.get("/read_all_config", response_model=ApiResponse) # 读取所有配置文件列表
def read_all_config( def read_all_config(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间 # 检查用户是否已选择工作空间
if workspace_id is None: if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询配置但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试查询配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取所有配置") api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求读取所有配置")
try: try:
svc = DataConfigService(db) svc = DataConfigService(db)
@@ -304,14 +303,14 @@ def read_all_config(
@router.post("/pilot_run", response_model=None) @router.post("/pilot_run", response_model=None)
async def pilot_run( async def pilot_run(
payload: ConfigPilotRun, payload: ConfigPilotRun,
language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> StreamingResponse: ) -> StreamingResponse:
# 使用集中化的语言校验 # 使用集中化的语言校验
language = get_language_from_header(language_type) language = get_language_from_header(language_type)
api_logger.info( api_logger.info(
f"Pilot run requested: config_id={payload.config_id}, " f"Pilot run requested: config_id={payload.config_id}, "
f"dialogue_text_length={len(payload.dialogue_text)}, " f"dialogue_text_length={len(payload.dialogue_text)}, "
@@ -334,9 +333,9 @@ async def pilot_run(
@router.get("/search/kb_type_distribution", response_model=ApiResponse) @router.get("/search/kb_type_distribution", response_model=ApiResponse)
async def get_kb_type_distribution( async def get_kb_type_distribution(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}") api_logger.info(f"KB type distribution requested for end_user_id: {end_user_id}")
try: try:
result = await kb_type_distribution(end_user_id) result = await kb_type_distribution(end_user_id)
@@ -345,12 +344,12 @@ async def get_kb_type_distribution(
api_logger.error(f"KB type distribution failed: {str(e)}") api_logger.error(f"KB type distribution failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "知识库类型分布查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "知识库类型分布查询失败", str(e))
@router.get("/search/dialogue", response_model=ApiResponse) @router.get("/search/dialogue", response_model=ApiResponse)
async def search_dialogues_num( async def search_dialogues_num(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}") api_logger.info(f"Search dialogue requested for end_user_id: {end_user_id}")
try: try:
result = await search_dialogue(end_user_id) result = await search_dialogue(end_user_id)
@@ -362,9 +361,9 @@ async def search_dialogues_num(
@router.get("/search/chunk", response_model=ApiResponse) @router.get("/search/chunk", response_model=ApiResponse)
async def search_chunks_num( async def search_chunks_num(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}") api_logger.info(f"Search chunk requested for end_user_id: {end_user_id}")
try: try:
result = await search_chunk(end_user_id) result = await search_chunk(end_user_id)
@@ -376,9 +375,9 @@ async def search_chunks_num(
@router.get("/search/statement", response_model=ApiResponse) @router.get("/search/statement", response_model=ApiResponse)
async def search_statements_num( async def search_statements_num(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search statement requested for end_user_id: {end_user_id}") api_logger.info(f"Search statement requested for end_user_id: {end_user_id}")
try: try:
result = await search_statement(end_user_id) result = await search_statement(end_user_id)
@@ -390,9 +389,9 @@ async def search_statements_num(
@router.get("/search/entity", response_model=ApiResponse) @router.get("/search/entity", response_model=ApiResponse)
async def search_entities_num( async def search_entities_num(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search entity requested for end_user_id: {end_user_id}") api_logger.info(f"Search entity requested for end_user_id: {end_user_id}")
try: try:
result = await search_entity(end_user_id) result = await search_entity(end_user_id)
@@ -404,9 +403,9 @@ async def search_entities_num(
@router.get("/search", response_model=ApiResponse) @router.get("/search", response_model=ApiResponse)
async def search_all_num( async def search_all_num(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search all requested for end_user_id: {end_user_id}") api_logger.info(f"Search all requested for end_user_id: {end_user_id}")
try: try:
result = await search_all(end_user_id) result = await search_all(end_user_id)
@@ -418,9 +417,9 @@ async def search_all_num(
@router.get("/search/detials", response_model=ApiResponse) @router.get("/search/detials", response_model=ApiResponse)
async def search_entities_detials( async def search_entities_detials(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search details requested for end_user_id: {end_user_id}") api_logger.info(f"Search details requested for end_user_id: {end_user_id}")
try: try:
result = await search_detials(end_user_id) result = await search_detials(end_user_id)
@@ -432,9 +431,9 @@ async def search_entities_detials(
@router.get("/search/edges", response_model=ApiResponse) @router.get("/search/edges", response_model=ApiResponse)
async def search_entity_edges( async def search_entity_edges(
end_user_id: Optional[str] = None, end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Search edges requested for end_user_id: {end_user_id}") api_logger.info(f"Search edges requested for end_user_id: {end_user_id}")
try: try:
result = await search_edges(end_user_id) result = await search_edges(end_user_id)
@@ -444,12 +443,14 @@ async def search_entity_edges(
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse) @router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
async def get_hot_memory_tags_api( async def get_hot_memory_tags_api(
limit: int = 10, limit: int = 10,
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
""" """
获取热门记忆标签带Redis缓存 获取热门记忆标签带Redis缓存
@@ -460,18 +461,18 @@ async def get_hot_memory_tags_api(
- 缓存未命中:~600-800ms取决于LLM速度 - 缓存未命中:~600-800ms取决于LLM速度
""" """
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# 构建缓存键 # 构建缓存键
cache_key = f"hot_memory_tags:{workspace_id}:{limit}" cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}") api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}")
try: try:
# 尝试从Redis缓存获取 # 尝试从Redis缓存获取
import json import json
from app.aioRedis import aio_redis_get, aio_redis_set from app.aioRedis import aio_redis_get, aio_redis_set
cached_result = await aio_redis_get(cache_key) cached_result = await aio_redis_get(cache_key)
if cached_result: if cached_result:
api_logger.info(f"Cache hit for key: {cache_key}") api_logger.info(f"Cache hit for key: {cache_key}")
@@ -480,11 +481,11 @@ async def get_hot_memory_tags_api(
return success(data=data, msg="查询成功(缓存)") return success(data=data, msg="查询成功(缓存)")
except json.JSONDecodeError: except json.JSONDecodeError:
api_logger.warning(f"Failed to parse cached data, will refresh") api_logger.warning(f"Failed to parse cached data, will refresh")
# 缓存未命中,执行查询 # 缓存未命中,执行查询
api_logger.info(f"Cache miss for key: {cache_key}, executing query") api_logger.info(f"Cache miss for key: {cache_key}, executing query")
result = await analytics_hot_memory_tags(db, current_user, limit) result = await analytics_hot_memory_tags(db, current_user, limit)
# 写入缓存过期时间5分钟 # 写入缓存过期时间5分钟
# 注意result是列表需要转换为JSON字符串 # 注意result是列表需要转换为JSON字符串
try: try:
@@ -494,9 +495,9 @@ async def get_hot_memory_tags_api(
except Exception as cache_error: except Exception as cache_error:
# 缓存写入失败不影响主流程 # 缓存写入失败不影响主流程
api_logger.warning(f"Failed to cache result: {str(cache_error)}") api_logger.warning(f"Failed to cache result: {str(cache_error)}")
return success(data=result, msg="查询成功") return success(data=result, msg="查询成功")
except Exception as e: except Exception as e:
api_logger.error(f"Hot memory tags failed: {str(e)}") api_logger.error(f"Hot memory tags failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
@@ -504,8 +505,8 @@ async def get_hot_memory_tags_api(
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse) @router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
async def clear_hot_memory_tags_cache( async def clear_hot_memory_tags_cache(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
""" """
清除热门标签缓存 清除热门标签缓存
@@ -515,12 +516,12 @@ async def clear_hot_memory_tags_cache(
- 数据更新后立即生效 - 数据更新后立即生效
""" """
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}") api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}")
try: try:
from app.aioRedis import aio_redis_delete from app.aioRedis import aio_redis_delete
# 清除所有limit的缓存常见的limit值 # 清除所有limit的缓存常见的limit值
cleared_count = 0 cleared_count = 0
for limit in [5, 10, 15, 20, 30, 50]: for limit in [5, 10, 15, 20, 30, 50]:
@@ -529,12 +530,12 @@ async def clear_hot_memory_tags_cache(
if result: if result:
cleared_count += 1 cleared_count += 1
api_logger.info(f"Cleared cache for key: {cache_key}") api_logger.info(f"Cleared cache for key: {cache_key}")
return success( return success(
data={"cleared_count": cleared_count}, data={"cleared_count": cleared_count},
msg=f"成功清除 {cleared_count} 个缓存" msg=f"成功清除 {cleared_count} 个缓存"
) )
except Exception as e: except Exception as e:
api_logger.error(f"Clear cache failed: {str(e)}") api_logger.error(f"Clear cache failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e))
@@ -542,7 +543,7 @@ async def clear_hot_memory_tags_cache(
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse) @router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
async def get_recent_activity_stats_api( async def get_recent_activity_stats_api(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}") api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
@@ -552,3 +553,4 @@ async def get_recent_activity_stats_api(
except Exception as e: except Exception as e:
api_logger.error(f"Recent activity stats failed: {str(e)}") api_logger.error(f"Recent activity stats failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))

View File

@@ -42,7 +42,6 @@ def get_model_strategies():
@router.get("", response_model=ApiResponse) @router.get("", response_model=ApiResponse)
def get_model_list( def get_model_list(
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING"), type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING"),
capability: Optional[list[str]] = Query(None, description="能力筛选(支持多个,如 ?capability=chat 或 ?capability=chat, embedding"),
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"), provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
is_active: Optional[bool] = Query(None, description="激活状态筛选"), is_active: Optional[bool] = Query(None, description="激活状态筛选"),
is_public: Optional[bool] = Query(None, description="公开状态筛选"), is_public: Optional[bool] = Query(None, description="公开状态筛选"),
@@ -75,21 +74,10 @@ def get_model_list(
unique_flat_type = list(dict.fromkeys(flat_type)) unique_flat_type = list(dict.fromkeys(flat_type))
type_list = [ModelType(t.lower()) for t in unique_flat_type] type_list = [ModelType(t.lower()) for t in unique_flat_type]
capability_list = []
if capability is not None:
flat_capability = []
for item in capability:
split_items = [c.strip() for c in item.split(', ') if c.strip()]
flat_capability.extend(split_items)
unique_flat_capability = list(dict.fromkeys(flat_capability))
capability_list = unique_flat_capability
api_logger.error(f"获取模型type_list: {type_list}") api_logger.error(f"获取模型type_list: {type_list}")
query = model_schema.ModelConfigQuery( query = model_schema.ModelConfigQuery(
type=type_list, type=type_list,
provider=provider, provider=provider,
capability=capability_list,
is_active=is_active, is_active=is_active,
is_public=is_public, is_public=is_public,
search=search, search=search,

View File

@@ -27,7 +27,6 @@ from app.services.conversation_service import ConversationService
from app.services.release_share_service import ReleaseShareService from app.services.release_share_service import ReleaseShareService
from app.services.shared_chat_service import SharedChatService from app.services.shared_chat_service import SharedChatService
from app.services.workflow_service import WorkflowService from app.services.workflow_service import WorkflowService
from app.models.file_metadata_model import FileMetadata
from app.utils.app_config_utils import workflow_config_4_app_release, \ from app.utils.app_config_utils import workflow_config_4_app_release, \
agent_config_4_app_release, multi_agent_config_4_app_release agent_config_4_app_release, multi_agent_config_4_app_release
@@ -260,41 +259,8 @@ def get_conversation(
conv_service = ConversationService(db) conv_service = ConversationService(db)
messages = conv_service.get_messages(conversation_id) messages = conv_service.get_messages(conversation_id)
file_ids = [] # 构建响应
message_file_id_map = {} conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump()
# 第一次遍历:解析 audio_url收集所有有效的 file_id
for idx, m in enumerate(messages):
if m.role == "assistant" and m.meta_data:
audio_url = m.meta_data.get("audio_url")
if not audio_url:
continue
try:
file_id = uuid.UUID(audio_url.rstrip("/").split("/")[-1])
except (ValueError, IndexError):
# audio_url 无法解析为 UUID标记为 unknown
m.meta_data["audio_status"] = "unknown"
continue
file_ids.append(file_id)
message_file_id_map[idx] = file_id
# 批量查询所有相关的 FileMetadata
file_status_map = {}
if file_ids:
file_metas = (
db.query(FileMetadata)
.filter(FileMetadata.id.in_(set(file_ids)))
.all()
)
file_status_map = {fm.id: fm.status for fm in file_metas}
# 第二次遍历:将查询结果映射回消息
for idx, file_id in message_file_id_map.items():
m = messages[idx]
m.meta_data["audio_status"] = file_status_map.get(file_id, "unknown")
conv_dict = conversation_schema.Conversation.model_validate(conversation).model_dump(mode="json")
conv_dict["messages"] = [ conv_dict["messages"] = [
conversation_schema.Message.model_validate(m) for m in messages conversation_schema.Message.model_validate(m) for m in messages
] ]
@@ -354,16 +320,6 @@ async def chat(
other_id=other_id, other_id=other_id,
original_user_id=user_id original_user_id=user_id
) )
# Only extract and set memory_config_id when the end user doesn't have one yet
if not new_end_user.memory_config_id:
from app.services.memory_config_service import MemoryConfigService
memory_config_service = MemoryConfigService(db)
memory_config_id, _ = memory_config_service.extract_memory_config_id(release.type, release.config or {})
if memory_config_id:
new_end_user.memory_config_id = memory_config_id
db.commit()
db.refresh(new_end_user)
end_user_id = str(new_end_user.id) end_user_id = str(new_end_user.id)
# appid = share.app_id # appid = share.app_id
@@ -713,7 +669,6 @@ async def config_query(
content = { content = {
"app_type": release.app.type, "app_type": release.app.type,
"variables": release.config.get("variables"), "variables": release.config.get("variables"),
"memory": release.config.get("memory", {}).get("enabled"),
"features": release.config.get("features") "features": release.config.get("features")
} }
elif release.app.type == AppType.MULTI_AGENT: elif release.app.type == AppType.MULTI_AGENT:

View File

@@ -91,7 +91,7 @@ async def chat(
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id) app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
other_id = payload.user_id other_id = payload.user_id
workspace_id = api_key_auth.workspace_id workspace_id = app.workspace_id
end_user_repo = EndUserRepository(db) end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user( new_end_user = end_user_repo.get_or_create_end_user(
app_id=app.id, app_id=app.id,

View File

@@ -6,7 +6,6 @@ from app.core.response_utils import success
from app.db import get_db from app.db import get_db
from app.schemas.api_key_schema import ApiKeyAuth from app.schemas.api_key_schema import ApiKeyAuth
from app.schemas.memory_api_schema import ( from app.schemas.memory_api_schema import (
ListConfigsResponse,
MemoryReadRequest, MemoryReadRequest,
MemoryReadResponse, MemoryReadResponse,
MemoryWriteRequest, MemoryWriteRequest,
@@ -32,15 +31,14 @@ async def write_memory_api_service(
request: Request, request: Request,
api_key_auth: ApiKeyAuth = None, api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db), db: Session = Depends(get_db),
message: str = Body(..., description="Message content"), payload: MemoryWriteRequest = Body(..., embed=False),
): ):
""" """
Write memory to storage. Write memory to storage.
Stores memory content for the specified end user using the Memory API Service. Stores memory content for the specified end user using the Memory API Service.
""" """
body = await request.json()
payload = MemoryWriteRequest(**body)
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}") logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
memory_api_service = MemoryAPIService(db) memory_api_service = MemoryAPIService(db)
@@ -64,15 +62,13 @@ async def read_memory_api_service(
request: Request, request: Request,
api_key_auth: ApiKeyAuth = None, api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db), db: Session = Depends(get_db),
message: str = Body(..., description="Query message"), payload: MemoryReadRequest = Body(..., embed=False),
): ):
""" """
Read memory from storage. Read memory from storage.
Queries and retrieves memories for the specified end user with context-aware responses. Queries and retrieves memories for the specified end user with context-aware responses.
""" """
body = await request.json()
payload = MemoryReadRequest(**body)
logger.info(f"Memory read request - end_user_id: {payload.end_user_id}") logger.info(f"Memory read request - end_user_id: {payload.end_user_id}")
memory_api_service = MemoryAPIService(db) memory_api_service = MemoryAPIService(db)
@@ -89,27 +85,3 @@ async def read_memory_api_service(
logger.info(f"Memory read successful for end_user: {payload.end_user_id}") logger.info(f"Memory read successful for end_user: {payload.end_user_id}")
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully") return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully")
@router.get("/configs")
@require_api_key(scopes=["memory"])
async def list_memory_configs(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
List all memory configs for the workspace.
Returns all available memory configurations associated with the authorized workspace.
"""
logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}")
memory_api_service = MemoryAPIService(db)
result = memory_api_service.list_memory_configs(
workspace_id=api_key_auth.workspace_id,
)
logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")

View File

@@ -111,18 +111,6 @@ def get_current_user_info(
break break
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}") api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
# 设置权限:如果用户来自 SSO Source则使用该 Source 的 permissions否则返回 "all" 表示拥有所有权限
if current_user.external_source:
from premium.sso.models import SSOSource
source = db.query(SSOSource).filter(SSOSource.source_code == current_user.external_source).first()
if source and source.permissions:
result_schema.permissions = source.permissions
else:
result_schema.permissions = []
else:
result_schema.permissions = ["all"]
return success(data=result_schema, msg=t("users.info.get_success")) return success(data=result_schema, msg=t("users.info.get_success"))
@@ -147,6 +135,7 @@ def get_tenant_superusers(
return success(data=superusers_schema, msg=t("users.list.superusers_success")) return success(data=superusers_schema, msg=t("users.list.superusers_success"))
@router.get("/{user_id}", response_model=ApiResponse) @router.get("/{user_id}", response_model=ApiResponse)
def get_user_info_by_id( def get_user_info_by_id(
user_id: uuid.UUID, user_id: uuid.UUID,

View File

@@ -5,7 +5,7 @@
from typing import Optional from typing import Optional
import datetime import datetime
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from fastapi import APIRouter, Depends, Header from fastapi import APIRouter, Depends,Header
from app.db import get_db from app.db import get_db
from app.core.language_utils import get_language_from_header from app.core.language_utils import get_language_from_header
@@ -19,15 +19,13 @@ from app.services.user_memory_service import (
analytics_graph_data, analytics_graph_data,
analytics_community_graph_data, analytics_community_graph_data,
) )
from app.services.memory_entity_relationship_service import MemoryEntityService, MemoryEmotion, MemoryInteraction from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
from app.schemas.response_schema import ApiResponse from app.schemas.response_schema import ApiResponse
from app.schemas.memory_storage_schema import GenerateCacheRequest from app.schemas.memory_storage_schema import GenerateCacheRequest
from app.repositories.workspace_repository import WorkspaceRepository from app.repositories.workspace_repository import WorkspaceRepository
from app.repositories.end_user_repository import EndUserRepository from app.schemas.end_user_schema import (
from app.schemas.end_user_info_schema import ( EndUserProfileResponse,
EndUserInfoResponse, EndUserProfileUpdate,
EndUserInfoCreate,
EndUserInfoUpdate,
) )
from app.models.end_user_model import EndUser from app.models.end_user_model import EndUser
from app.dependencies import get_current_user from app.dependencies import get_current_user
@@ -47,9 +45,9 @@ router = APIRouter(
@router.get("/analytics/memory_insight/report", response_model=ApiResponse) @router.get("/analytics/memory_insight/report", response_model=ApiResponse)
async def get_memory_insight_report_api( async def get_memory_insight_report_api(
end_user_id: str, end_user_id: str,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
""" """
获取缓存的记忆洞察报告 获取缓存的记忆洞察报告
@@ -75,10 +73,10 @@ async def get_memory_insight_report_api(
@router.get("/analytics/user_summary", response_model=ApiResponse) @router.get("/analytics/user_summary", response_model=ApiResponse)
async def get_user_summary_api( async def get_user_summary_api(
end_user_id: str, end_user_id: str,
language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
""" """
获取缓存的用户摘要 获取缓存的用户摘要
@@ -92,7 +90,7 @@ async def get_user_summary_api(
""" """
# 使用集中化的语言校验 # 使用集中化的语言校验
language = get_language_from_header(language_type) language = get_language_from_header(language_type)
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db) workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
@@ -104,7 +102,7 @@ async def get_user_summary_api(
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}") api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
try: try:
# 调用服务层获取缓存数据 # 调用服务层获取缓存数据
result = await user_memory_service.get_cached_user_summary(db, end_user_id, model_id, language) result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language)
if result["is_cached"]: if result["is_cached"]:
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}") api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
@@ -119,10 +117,10 @@ async def get_user_summary_api(
@router.post("/analytics/generate_cache", response_model=ApiResponse) @router.post("/analytics/generate_cache", response_model=ApiResponse)
async def generate_cache_api( async def generate_cache_api(
request: GenerateCacheRequest, request: GenerateCacheRequest,
language_type: str = Header(default=None, alias="X-Language-Type"), language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
""" """
手动触发缓存生成 手动触发缓存生成
@@ -136,7 +134,7 @@ async def generate_cache_api(
""" """
# 使用集中化的语言校验 # 使用集中化的语言校验
language = get_language_from_header(language_type) language = get_language_from_header(language_type)
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间 # 检查用户是否已选择工作空间
@@ -157,12 +155,10 @@ async def generate_cache_api(
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}") api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
# 生成记忆洞察 # 生成记忆洞察
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language)
language=language)
# 生成用户摘要 # 生成用户摘要
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language)
language=language)
# 构建响应 # 构建响应
result = { result = {
@@ -213,9 +209,9 @@ async def generate_cache_api(
@router.get("/analytics/node_statistics", response_model=ApiResponse) @router.get("/analytics/node_statistics", response_model=ApiResponse)
async def get_node_statistics_api( async def get_node_statistics_api(
end_user_id: str, end_user_id: str,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
@@ -224,8 +220,7 @@ async def get_node_statistics_api(
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info( api_logger.info(f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
f"记忆类型统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
try: try:
# 调用新的记忆类型统计函数 # 调用新的记忆类型统计函数
@@ -233,23 +228,21 @@ async def get_node_statistics_api(
# 计算总数用于日志 # 计算总数用于日志
total_count = sum(item["count"] for item in result) total_count = sum(item["count"] for item in result)
api_logger.info( api_logger.info(f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
f"成功获取记忆类型统计: end_user_id={end_user_id}, 总记忆数={total_count}, 类型数={len(result)}")
return success(data=result, msg="查询成功") return success(data=result, msg="查询成功")
except Exception as e: except Exception as e:
api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}") api_logger.error(f"记忆类型查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "记忆类型查询失败", str(e))
@router.get("/analytics/graph_data", response_model=ApiResponse) @router.get("/analytics/graph_data", response_model=ApiResponse)
async def get_graph_data_api( async def get_graph_data_api(
end_user_id: str, end_user_id: str,
node_types: Optional[str] = None, node_types: Optional[str] = None,
limit: int = 100, limit: int = 100,
depth: int = 1, depth: int = 1,
center_node_id: Optional[str] = None, center_node_id: Optional[str] = None,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
@@ -305,9 +298,9 @@ async def get_graph_data_api(
@router.get("/analytics/community_graph", response_model=ApiResponse) @router.get("/analytics/community_graph", response_model=ApiResponse)
async def get_community_graph_data_api( async def get_community_graph_data_api(
end_user_id: str, end_user_id: str,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
@@ -338,130 +331,111 @@ async def get_community_graph_data_api(
api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}") api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "社区图谱查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "社区图谱查询失败", str(e))
#=======================终端用户信息接口=======================
@router.get("/end_user_info", response_model=ApiResponse) @router.get("/read_end_user/profile", response_model=ApiResponse)
async def get_end_user_info( async def get_end_user_profile(
end_user_id: str, end_user_id: str,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
"""
查询终端用户信息记录
根据 end_user_id 查询单条终端用户信息记录。
"""
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
if workspace_models:
model_id = workspace_models.get("llm", None)
else:
model_id = None
# 检查用户是否已选择工作空间
if workspace_id is None: if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询终端用户信息但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info( api_logger.info(
f"查询终端用户信息请求: end_user_id={end_user_id}, user={current_user.username}, " f"用户信息查询请求: end_user_id={end_user_id}, user={current_user.username}, "
f"workspace={workspace_id}" f"workspace={workspace_id}"
) )
# 校验 end_user 是否属于当前工作空间 try:
end_user_repo = EndUserRepository(db) # 查询终端用户
end_user = end_user_repo.get_end_user_by_id(end_user_id) end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
if end_user is None:
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", "end_user not found") if not end_user:
if str(end_user.workspace_id) != str(workspace_id): api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
api_logger.warning( return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
f"用户 {current_user.username} 尝试查询不属于工作空间 {workspace_id} 的终端用户 {end_user_id}" # 构建响应数据
profile_data = EndUserProfileResponse(
id=end_user.id,
other_name=end_user.other_name,
position=end_user.position,
department=end_user.department,
contact=end_user.contact,
phone=end_user.phone,
hire_date=end_user.hire_date,
updatetime_profile=end_user.updatetime_profile
) )
return fail(BizCode.PERMISSION_DENIED, "该终端用户不属于当前工作空间", "end_user workspace mismatch")
result = user_memory_service.get_end_user_info(db, end_user_id) api_logger.info(f"成功获取用户信息: end_user_id={end_user_id}")
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="查询成功")
if result["success"]: except Exception as e:
api_logger.info(f"成功查询终端用户信息: end_user_id={end_user_id}") api_logger.error(f"用户信息查询失败: end_user_id={end_user_id}, error={str(e)}")
return success(data=result["data"], msg="查询成功") return fail(BizCode.INTERNAL_ERROR, "用户信息查询失败", str(e))
else:
error_msg = result["error"]
api_logger.error(f"查询终端用户信息失败: end_user_id={end_user_id}, error={error_msg}")
if error_msg == "终端用户信息记录不存在":
return fail(BizCode.USER_NOT_FOUND, "终端用户信息记录不存在", error_msg)
elif error_msg == "无效的终端用户ID格式":
return fail(BizCode.INVALID_USER_ID, "无效的终端用户ID格式", error_msg)
else:
return fail(BizCode.INTERNAL_ERROR, "查询终端用户信息失败", error_msg)
@router.post("/end_user_info/updated", response_model=ApiResponse) @router.post("/updated_end_user/profile", response_model=ApiResponse)
async def update_end_user_info( async def update_end_user_profile(
info_update: EndUserInfoUpdate, profile_update: EndUserProfileUpdate,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
) -> dict: ) -> dict:
""" """
更新终端用户信息记录 更新终端用户的基本信息
根据 end_user_id 更新终端用户信息记录,支持批量更新多个别名 该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息
所有字段都是可选的,只更新提供的字段。
示例请求体:
{
"end_user_id": "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
"other_name": "张三1",
"aliases": ["小张", "张工"],
"meta_data": {"position": "工程师", "department": "技术部"}
}
""" """
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
end_user_id = info_update.end_user_id end_user_id = profile_update.end_user_id
# 验证工作空间
if workspace_id is None: if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新终端用户信息但未选择工作空间") api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None") return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info( api_logger.info(
f"更新终端用户信息请求: end_user_id={end_user_id}, user={current_user.username}, " f"用户信息更新请求: end_user_id={end_user_id}, user={current_user.username}, "
f"workspace={workspace_id}" f"workspace={workspace_id}"
) )
# 校验 end_user 是否属于当前工作空间 # 调用 Service 层处理业务逻辑
end_user_repo = EndUserRepository(db) result = user_memory_service.update_end_user_profile(db, end_user_id, profile_update)
end_user = end_user_repo.get_end_user_by_id(end_user_id)
if end_user is None:
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", "end_user not found")
if str(end_user.workspace_id) != str(workspace_id):
api_logger.warning(
f"用户 {current_user.username} 尝试更新不属于工作空间 {workspace_id} 的终端用户 {end_user_id}"
)
return fail(BizCode.PERMISSION_DENIED, "该终端用户不属于当前工作空间", "end_user workspace mismatch")
# 获取更新数据(排除 end_user_id
update_data = info_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
result = user_memory_service.update_end_user_info(db, end_user_id, update_data)
if result["success"]: if result["success"]:
api_logger.info(f"成功更新终端用户信息: end_user_id={end_user_id}") api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}")
return success(data=result["data"], msg="更新成功") return success(data=result["data"], msg="更新成功")
else: else:
error_msg = result["error"] error_msg = result["error"]
api_logger.error(f"终端用户信息更新失败: end_user_id={end_user_id}, error={error_msg}") api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
if error_msg == "终端用户信息记录不存在": # 根据错误类型映射到合适的业务错误码
return fail(BizCode.USER_NOT_FOUND, "终端用户信息记录不存在", error_msg) if error_msg == "终端用户不存在":
elif error_msg == "无效的终端用户ID格式": return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg)
return fail(BizCode.INVALID_USER_ID, "无效的终端用户ID格式", error_msg) elif error_msg == "无效的用户ID格式":
return fail(BizCode.INVALID_USER_ID, "无效的用户ID格式", error_msg)
else: else:
return fail(BizCode.INTERNAL_ERROR, "终端用户信息更新失败", error_msg) # 只有未预期的错误才使用 INTERNAL_ERROR
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
@router.get("/memory_space/timeline_memories", response_model=ApiResponse) @router.get("/memory_space/timeline_memories", response_model=ApiResponse)
async def memory_space_timeline_of_shared_memories( async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"),
id: str, label: str, current_user: User = Depends(get_current_user),
language_type: str = Header(default=None, alias="X-Language-Type"), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), ):
db: Session = Depends(get_db),
):
# 使用集中化的语言校验 # 使用集中化的语言校验
language = get_language_from_header(language_type) language = get_language_from_header(language_type)
workspace_id = current_user.current_workspace_id workspace_id=current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db) workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id) workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
@@ -473,13 +447,11 @@ async def memory_space_timeline_of_shared_memories(
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language) timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
return success(data=timeline_memories_result, msg="共同记忆时间线") return success(data=timeline_memories_result, msg="共同记忆时间线")
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse) @router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
async def memory_space_relationship_evolution(id: str, label: str, async def memory_space_relationship_evolution(id: str, label: str,
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
try: try:
api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}") api_logger.info(f"关系演变查询请求: id={id}, table={label}, user={current_user.username}")

View File

@@ -329,6 +329,7 @@ class LangChainAgent:
db.close() db.close()
except Exception as e: except Exception as e:
logger.warning(f"Failed to get db session: {e}") logger.warning(f"Failed to get db session: {e}")
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}') logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}') print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
try: try:
@@ -597,10 +598,8 @@ class LangChainAgent:
for msg in reversed(output_messages): for msg in reversed(output_messages):
if isinstance(msg, AIMessage): if isinstance(msg, AIMessage):
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get( total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
"total_tokens", 0) if response_meta else 0
0
) if response_meta else 0
yield total_tokens yield total_tokens
break break
if memory_flag: if memory_flag:

View File

@@ -231,8 +231,8 @@ class Settings:
# Celery configuration (internal) # Celery configuration (internal)
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持 # NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
# 详见 docs/celery-env-bug-report.md # 详见 docs/celery-env-bug-report.md
# 默认使用 Redis 作为 broker 和 backend与业务缓存隔离 # 默认使用 Redis DB 3 (broker)DB 4 (backend),与业务缓存 (DB 1/2) 隔离
# 如需使用 RabbitMQ在 .env 中设置 CELERY_BROKER_URL=amqp://user:pass@host:5672/vhost # 多人共用同一 Redis 时,每位开发者应在 .env 中配置不同的 DB 编号避免任务互相干扰
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3")) REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3"))
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4")) REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4"))

View File

@@ -529,9 +529,8 @@ def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") -
# Fallback to console only if file write fails # Fallback to console only if file write fails
print(f"Warning: Could not write to timing log: {e}") print(f"Warning: Could not write to timing log: {e}")
# Always log at INFO level (avoids Celery treating stdout as WARNING) # Always print to console (backward compatible behavior)
_timing_logger = logging.getLogger(__name__) print(f"{step_name}: {duration:.2f}s")
_timing_logger.info(f"{step_name}: {duration:.2f}s")
def get_agent_logger(name: str = "agent_service", def get_agent_logger(name: str = "agent_service",

View File

@@ -155,7 +155,7 @@ async def clean_databases(data) -> str:
# Process reranked results # Process reranked results
reranked = results.get('reranked_results', {}) reranked = results.get('reranked_results', {})
if reranked: if reranked:
for category in ['summaries', 'communities', 'statements', 'chunks', 'entities']: for category in ['summaries', 'statements', 'chunks', 'entities']:
items = reranked.get(category, []) items = reranked.get(category, [])
if isinstance(items, list): if isinstance(items, list):
content_list.extend(items) content_list.extend(items)
@@ -169,18 +169,11 @@ async def clean_databases(data) -> str:
elif isinstance(time_search, list): elif isinstance(time_search, list):
content_list.extend(time_search) content_list.extend(time_search)
# Extract text content,对 community 按 name 去重(多次 tool 调用会产生重复) # Extract text content
text_parts = [] text_parts = []
seen_community_names = set()
for item in content_list: for item in content_list:
if isinstance(item, dict): if isinstance(item, dict):
# community 节点用 name 去重 text = item.get('statement') or item.get('content', '')
if 'member_count' in item or 'core_entities' in item:
community_name = item.get('name') or item.get('id', '')
if community_name in seen_community_names:
continue
seen_community_names.add(community_name)
text = item.get('statement') or item.get('content') or item.get('summary', '')
if text: if text:
text_parts.append(text) text_parts.append(text)
elif isinstance(item, str): elif isinstance(item, str):
@@ -361,11 +354,7 @@ async def retrieve(state: ReadState) -> ReadState:
) )
time_retrieval_tool = create_time_retrieval_tool(end_user_id) time_retrieval_tool = create_time_retrieval_tool(end_user_id)
search_params = { search_params = {"end_user_id": end_user_id, "return_raw_results": True}
"end_user_id": end_user_id,
"return_raw_results": True,
"include": ["summaries", "statements", "chunks", "entities", "communities"],
}
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params) hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
agent = create_agent( agent = create_agent(
llm, llm,
@@ -401,32 +390,8 @@ async def retrieve(state: ReadState) -> ReadState:
raw_results = tool_results['content'] raw_results = tool_results['content']
clean_content = await clean_databases(raw_results) clean_content = await clean_databases(raw_results)
# 社区展开:从 tool 返回结果中提取命中的 community
# 沿 BELONGS_TO_COMMUNITY 关系拉取关联 Statement 追加到 clean_content
_expanded_stmts_to_write = []
try:
results_dict = raw_results.get('results', {}) if isinstance(raw_results, dict) else {}
reranked = results_dict.get('reranked_results', {})
community_hits = reranked.get('communities', [])
if not community_hits:
community_hits = results_dict.get('communities', [])
if community_hits:
from app.core.memory.agent.services.search_service import expand_communities_to_statements
_expanded_stmts_to_write, new_texts = await expand_communities_to_statements(
community_results=community_hits,
end_user_id=end_user_id,
existing_content=clean_content,
)
if new_texts:
clean_content = clean_content + '\n' + '\n'.join(new_texts)
except Exception as parse_err:
logger.warning(f"[Retrieve] 解析社区命中结果失败,跳过展开: {parse_err}")
try: try:
raw_results = raw_results['results'] raw_results = raw_results['results']
# 写回展开结果,接口返回中可见(已在 helper 中清洗过字段)
if _expanded_stmts_to_write and isinstance(raw_results, dict):
raw_results.setdefault('reranked_results', {})['expanded_statements'] = _expanded_stmts_to_write
except Exception: except Exception:
raw_results = [] raw_results = []

View File

@@ -334,22 +334,13 @@ async def Input_Summary(state: ReadState) -> ReadState:
"end_user_id": end_user_id, "end_user_id": end_user_id,
"question": data, "question": data,
"return_raw_results": True, "return_raw_results": True,
"include": ["summaries", "communities"] # MemorySummary 和 Community 同为高维度概括节点 "include": ["summaries"] # Only search summary nodes for faster performance
} }
try: try:
if storage_type != "rag": if storage_type != "rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search( retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
**search_params, memory_config=memory_config)
memory_config=memory_config,
expand_communities=False, # 路径 "2" 只需要 community 的 summary 文本,不展开到 Statement
)
# 调试:打印 community 检索结果数量
if raw_results and isinstance(raw_results, dict):
reranked = raw_results.get('reranked_results', {})
community_hits = reranked.get('communities', [])
logger.debug(f"[Input_Summary] community 命中数: {len(community_hits)}, "
f"summary 命中数: {len(reranked.get('summaries', []))}")
else: else:
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data) retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
except Exception as e: except Exception as e:

View File

@@ -178,7 +178,7 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope)
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
elif int(is_end_user_id) == int(scope): elif int(is_end_user_id) == int(scope):
logger.info('写入长期记忆NEO4J') logger.info('写入长期记忆NEO4J')
formatted_messages = redis_messages formatted_messages = (redis_messages)
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly) # Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
if hasattr(memory_config, 'config_id'): if hasattr(memory_config, 'config_id'):
config_id = memory_config.config_id config_id = memory_config.config_id

View File

@@ -252,10 +252,9 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development # TODO: fact_summary functionality temporarily disabled, will be enabled after future development
fields_to_remove = { fields_to_remove = {
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids', 'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'apply_id', 'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary" 'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
} }
# 注意:'id' 字段保留community 展开时需要用 community id 查询成员 statements
if isinstance(data, dict): if isinstance(data, dict):
# Clean dictionary # Clean dictionary
@@ -311,7 +310,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
"search_type": search_type, "search_type": search_type,
"end_user_id": end_user_id or search_params.get("end_user_id"), "end_user_id": end_user_id or search_params.get("end_user_id"),
"limit": limit or search_params.get("limit", 10), "limit": limit or search_params.get("limit", 10),
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities", "communities"]), "include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
"output_path": None, # Don't save to file "output_path": None, # Don't save to file
"memory_config": memory_config, "memory_config": memory_config,
"rerank_alpha": rerank_alpha, "rerank_alpha": rerank_alpha,

View File

@@ -13,72 +13,6 @@ from app.core.memory.utils.data.text_utils import escape_lucene_query
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
# 需要从展开结果中过滤的字段(含 Neo4j DateTime不可 JSON 序列化)
_EXPAND_FIELDS_TO_REMOVE = {
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'apply_id',
'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary'
}
def _clean_expand_fields(obj):
"""递归过滤展开结果中不可序列化的字段DateTime 等)。"""
if isinstance(obj, dict):
return {k: _clean_expand_fields(v) for k, v in obj.items() if k not in _EXPAND_FIELDS_TO_REMOVE}
if isinstance(obj, list):
return [_clean_expand_fields(i) for i in obj]
return obj
async def expand_communities_to_statements(
community_results: List[dict],
end_user_id: str,
existing_content: str = "",
limit: int = 10,
) -> Tuple[List[dict], List[str]]:
"""
社区展开 helper给定命中的 community 列表,拉取关联 Statement。
- 对展开结果去重(过滤已在 existing_content 中出现的文本)
- 过滤不可序列化字段
- 返回 (cleaned_expanded_stmts, new_texts)
- cleaned_expanded_stmts: 可直接写回 raw_results 的列表
- new_texts: 去重后新增的 statement 文本列表,用于追加到 clean_content
"""
community_ids = [r.get("id") for r in community_results if r.get("id")]
if not community_ids or not end_user_id:
return [], []
from app.repositories.neo4j.graph_search import search_graph_community_expand
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
connector = Neo4jConnector()
try:
result = await search_graph_community_expand(
connector=connector,
community_ids=community_ids,
end_user_id=end_user_id,
limit=limit,
)
except Exception as e:
logger.warning(f"[expand_communities] 社区展开检索失败,跳过: {e}")
return [], []
finally:
await connector.close()
expanded_stmts = result.get("expanded_statements", [])
if not expanded_stmts:
return [], []
existing_lines = set(existing_content.splitlines())
new_texts = [
s["statement"] for s in expanded_stmts
if s.get("statement") and s["statement"] not in existing_lines
]
cleaned = _clean_expand_fields(expanded_stmts)
logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements新增 {len(new_texts)}community_ids={community_ids}")
return cleaned, new_texts
class SearchService: class SearchService:
"""Service for executing hybrid search and processing results.""" """Service for executing hybrid search and processing results."""
@@ -87,7 +21,7 @@ class SearchService:
"""Initialize the search service.""" """Initialize the search service."""
logger.info("SearchService initialized") logger.info("SearchService initialized")
def extract_content_from_result(self, result: dict, node_type: str = "") -> str: def extract_content_from_result(self, result: dict) -> str:
""" """
Extract only meaningful content from search results, dropping all metadata. Extract only meaningful content from search results, dropping all metadata.
@@ -96,11 +30,9 @@ class SearchService:
- Entities: extract 'name' and 'fact_summary' fields - Entities: extract 'name' and 'fact_summary' fields
- Summaries: extract 'content' field - Summaries: extract 'content' field
- Chunks: extract 'content' field - Chunks: extract 'content' field
- Communities: extract 'content' field (c.summary), prefixed with community name
Args: Args:
result: Search result dictionary result: Search result dictionary
node_type: Hint for node type ("community", "summary", etc.)
Returns: Returns:
Clean content string without metadata Clean content string without metadata
@@ -114,21 +46,8 @@ class SearchService:
if 'statement' in result and result['statement']: if 'statement' in result and result['statement']:
content_parts.append(result['statement']) content_parts.append(result['statement'])
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定 # Summaries/Chunks: extract content field
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要 if 'content' in result and result['content']:
is_community = (
node_type == "community"
or 'member_count' in result
or 'core_entities' in result
)
if is_community:
name = result.get('name', '')
content = result.get('content', '')
if content:
prefix = f"[主题:{name}] " if name else ""
content_parts.append(f"{prefix}{content}")
elif 'content' in result and result['content']:
# Summaries / Chunks
content_parts.append(result['content']) content_parts.append(result['content'])
# Entities: extract name and fact_summary (commented out in original) # Entities: extract name and fact_summary (commented out in original)
@@ -180,8 +99,7 @@ class SearchService:
rerank_alpha: float = 0.4, rerank_alpha: float = 0.4,
output_path: str = "search_results.json", output_path: str = "search_results.json",
return_raw_results: bool = False, return_raw_results: bool = False,
memory_config = None, memory_config = None
expand_communities: bool = True,
) -> Tuple[str, str, Optional[dict]]: ) -> Tuple[str, str, Optional[dict]]:
""" """
Execute hybrid search and return clean content. Execute hybrid search and return clean content.
@@ -196,15 +114,13 @@ class SearchService:
output_path: Path to save search results (default: "search_results.json") output_path: Path to save search results (default: "search_results.json")
return_raw_results: If True, also return the raw search results as third element (default: False) return_raw_results: If True, also return the raw search results as third element (default: False)
memory_config: Memory configuration object (required) memory_config: Memory configuration object (required)
expand_communities: If True, expand community hits to member statements (default: True).
Set to False for quick-summary paths that only need community-level text.
Returns: Returns:
Tuple of (clean_content, cleaned_query, raw_results) Tuple of (clean_content, cleaned_query, raw_results)
raw_results is None if return_raw_results=False raw_results is None if return_raw_results=False
""" """
if include is None: if include is None:
include = ["statements", "chunks", "entities", "summaries", "communities"] include = ["statements", "chunks", "entities", "summaries"]
# Clean query # Clean query
cleaned_query = self.clean_query(question) cleaned_query = self.clean_query(question)
@@ -230,8 +146,8 @@ class SearchService:
if search_type == "hybrid": if search_type == "hybrid":
reranked_results = answer.get('reranked_results', {}) reranked_results = answer.get('reranked_results', {})
# Priority order: summaries first (most contextual), then communities, statements, chunks, entities # Priority order: summaries first (most contextual), then statements, chunks, entities
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities'] priority_order = ['summaries', 'statements', 'chunks', 'entities']
for category in priority_order: for category in priority_order:
if category in include and category in reranked_results: if category in include and category in reranked_results:
@@ -241,33 +157,19 @@ class SearchService:
else: else:
# For keyword or embedding search, results are directly in answer dict # For keyword or embedding search, results are directly in answer dict
# Apply same priority order # Apply same priority order
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities'] priority_order = ['summaries', 'statements', 'chunks', 'entities']
for category in priority_order: for category in priority_order:
if category in include and category in answer: if category in include and category in answer:
category_results = answer[category] category_results = answer[category]
if isinstance(category_results, list): if isinstance(category_results, list):
answer_list.extend(category_results) answer_list.extend(category_results)
# 对命中的 community 节点展开其成员 statements路径 "0"/"1" 需要,路径 "2" 不需要)
if expand_communities and "communities" in include:
community_results = (
answer.get('reranked_results', {}).get('communities', [])
if search_type == "hybrid"
else answer.get('communities', [])
)
cleaned_stmts, new_texts = await expand_communities_to_statements(
community_results=community_results,
end_user_id=end_user_id,
)
answer_list.extend(cleaned_stmts)
# Extract clean content from all results,按类型传入 node_type 区分 community # Extract clean content from all results
content_list = [] content_list = [
for ans in answer_list: self.extract_content_from_result(ans)
# community 节点有 member_count 或 core_entities 字段 for ans in answer_list
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else "" ]
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
# Filter out empty strings and join with newlines # Filter out empty strings and join with newlines

View File

@@ -11,7 +11,7 @@ async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker", chunker_strategy: str = "RecursiveChunker",
end_user_id: str = "group_1", end_user_id: str = "group_1",
messages: list = None, messages: list = None,
ref_id: str = "", ref_id: str = "wyl_20251027",
config_id: str = None config_id: str = None
) -> List[DialogData]: ) -> List[DialogData]:
"""Generate chunks from structured messages using the specified chunker strategy. """Generate chunks from structured messages using the specified chunker strategy.
@@ -40,13 +40,12 @@ async def get_chunked_dialogs(
role = msg['role'] role = msg['role']
content = msg['content'] content = msg['content']
files = msg.get("file_content", [])
if role not in ['user', 'assistant']: if role not in ['user', 'assistant']:
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}") raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
if content.strip(): if content.strip():
conversation_messages.append(ConversationMessage(role=role, msg=content.strip(), files=files)) conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
if not conversation_messages: if not conversation_messages:
raise ValueError("Message list cannot be empty after filtering") raise ValueError("Message list cannot be empty after filtering")
@@ -85,7 +84,7 @@ async def get_chunked_dialogs(
pruning_scene=memory_config.pruning_scene or "education", pruning_scene=memory_config.pruning_scene or "education",
pruning_threshold=memory_config.pruning_threshold, pruning_threshold=memory_config.pruning_threshold,
scene_id=str(memory_config.scene_id) if memory_config.scene_id else None, scene_id=str(memory_config.scene_id) if memory_config.scene_id else None,
ontology_class_infos=memory_config.ontology_class_infos, ontology_classes=memory_config.ontology_classes,
) )
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}") logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")

View File

@@ -39,30 +39,6 @@
比如:输入历史信息内容:[{'Query': '4月27日我和你推荐过一本书书名是什么', 'ANswer': '张曼玉推荐了《小王子》'}] 比如:输入历史信息内容:[{'Query': '4月27日我和你推荐过一本书书名是什么', 'ANswer': '张曼玉推荐了《小王子》'}]
拆分问题4月27日我和你推荐过一本书书名是什么可以拆分为4月27日张曼玉推荐过一本书书名是什么 拆分问题4月27日我和你推荐过一本书书名是什么可以拆分为4月27日张曼玉推荐过一本书书名是什么
## 指代消歧规则Coreference Resolution
在拆分问题时,必须解析并替换所有指代词和抽象称呼,使问题具体化:
1. **"用户"的消歧**
- "用户是谁?" → 分析历史记录,找出对话发起者的姓名
- 如果历史中有"我叫X"、"我的名字是X"、或多次提到某个人物,则"用户"指的就是这个人
- 示例:历史中有"老李的原名叫李建国",则"用户是谁?"应拆分为"李建国是谁?"或"老李(李建国)是谁?"
2. **"我"的消歧**
- "我喜欢什么?" → 从历史中找出对话发起者的姓名,替换为"X喜欢什么"
- 示例:历史中有"张曼玉推荐了《小王子》",则"我推荐的书是什么?"应拆分为"张曼玉推荐的书是什么?"
3. **"他/她/它"的消歧**
- 从上下文或历史中找出最近提到的同类实体
- 示例:历史中有"老李的同事叫他建国哥",则"他的同事怎么称呼他?"应拆分为"老李的同事怎么称呼他?"
4. **"那个人/这个人"的消歧**
- 从历史中找出最近提到的人物
- 示例:历史中有"李建国",则"那个人的原名是什么?"应拆分为"李建国的原名是什么?"
5. **优先级**
- 如果历史记录中反复出现某个人物(如"老李"、"李建国"、"建国哥"),则"用户"很可能指的就是这个人
- 如果无法从历史中确定指代对象保留原问题但在reason中说明"无法确定指代对象"
输出要求: 输出要求:
@@ -95,34 +71,6 @@
"reason": "输出原问题的关键要素" "reason": "输出原问题的关键要素"
} }
] ]
## 指代消歧示例(重要):
示例1 - "用户"的消歧:
输入历史:[{'Query': '老李的原名叫什么?', 'Answer': '李建国'}, {'Query': '老李的同事叫他什么?', 'Answer': '建国哥'}]
输入问题:"用户是谁?"
输出:
[
{
"original_question": "用户是谁?",
"extended_question": "李建国是谁?",
"type": "单跳",
"reason": "历史中反复提到'老李/李建国/建国哥''用户'指的就是对话发起者李建国"
}
]
示例2 - "我"的消歧:
输入历史:[{'Query': '张曼玉推荐了什么书?', 'Answer': '《小王子》'}]
输入问题:"我推荐的书是什么?"
输出:
[
{
"original_question": "我推荐的书是什么?",
"extended_question": "张曼玉推荐的书是什么?",
"type": "单跳",
"reason": "历史中提到张曼玉推荐了书,'我'指的就是张曼玉"
}
]
**Output format** **Output format**
**CRITICAL JSON FORMATTING REQUIREMENTS:** **CRITICAL JSON FORMATTING REQUIREMENTS:**
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes 1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes

View File

@@ -27,30 +27,6 @@
比如:输入历史信息内容:[{'Query': '4月27日我和你推荐过一本书书名是什么', 'ANswer': '张曼玉推荐了《小王子》'}] 比如:输入历史信息内容:[{'Query': '4月27日我和你推荐过一本书书名是什么', 'ANswer': '张曼玉推荐了《小王子》'}]
拆分问题4月27日我和你推荐过一本书书名是什么可以拆分为4月27日张曼玉推荐过一本书书名是什么 拆分问题4月27日我和你推荐过一本书书名是什么可以拆分为4月27日张曼玉推荐过一本书书名是什么
## 指代消歧规则Coreference Resolution
在拆分问题时,必须解析并替换所有指代词和抽象称呼,使问题具体化:
1. **"用户"的消歧**
- "用户是谁?" → 分析历史记录,找出对话发起者的姓名
- 如果历史中有"我叫X"、"我的名字是X"、或多次提到某个人物(如"老李"、"李建国"),则"用户"指的就是这个人
- 示例:历史中反复出现"老李/李建国/建国哥",则"用户是谁?"应拆分为"李建国是谁?"或"老李(李建国)是谁?"
2. **"我"的消歧**
- "我喜欢什么?" → 从历史中找出对话发起者的姓名,替换为"X喜欢什么"
- 示例:历史中有"张曼玉推荐了《小王子》",则"我推荐的书是什么?"应拆分为"张曼玉推荐的书是什么?"
3. **"他/她/它"的消歧**
- 从上下文或历史中找出最近提到的同类实体
- 示例:历史中有"老李的同事叫他建国哥",则"他的同事怎么称呼他?"应拆分为"老李的同事怎么称呼他?"
4. **"那个人/这个人"的消歧**
- 从历史中找出最近提到的人物
- 示例:历史中有"李建国",则"那个人的原名是什么?"应拆分为"李建国的原名是什么?"
5. **优先级**
- 如果历史记录中反复出现某个人物(如"老李"、"李建国"、"建国哥"),则"用户"很可能指的就是这个人
- 如果无法从历史中确定指代对象保留原问题但在reason中说明"无法确定指代对象"
## 指令: ## 指令:
你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型: 你是一个智能数据拆分助手,请根据数据特性判断输入属于哪种类型:
单跳Single-hop 单跳Single-hop
@@ -175,34 +151,6 @@
] ]
- 必须通过json.loads()的格式支持的形式输出 - 必须通过json.loads()的格式支持的形式输出
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。 - 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
## 指代消歧示例(重要):
示例1 - "用户"的消歧:
输入历史:[{'Query': '老李的原名叫什么?', 'Answer': '李建国'}, {'Query': '老李的同事叫他什么?', 'Answer': '建国哥'}]
输入问题:"用户是谁?"
输出:
[
{
"id": "Q1",
"question": "李建国是谁?",
"type": "单跳",
"reason": "历史中反复提到'老李/李建国/建国哥''用户'指的就是对话发起者李建国"
}
]
示例2 - "我"的消歧:
输入历史:[{'Query': '张曼玉推荐了什么书?', 'Answer': '《小王子》'}]
输入问题:"我推荐的书是什么?"
输出:
[
{
"id": "Q1",
"question": "张曼玉推荐的书是什么?",
"type": "单跳",
"reason": "历史中提到张曼玉推荐了书,'我'指的就是张曼玉"
}
]
- 关键的JSON格式要求 - 关键的JSON格式要求
1.JSON结构仅使用标准ASCII双引号-切勿使用中文引号“”或其他Unicode引号 1.JSON结构仅使用标准ASCII双引号-切勿使用中文引号“”或其他Unicode引号
2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们 2.如果提取的语句文本包含引号,请使用反斜杠(\“)正确转义它们

View File

@@ -6,17 +6,14 @@ pipeline. Only MemoryConfig is needed - clients are constructed internally.
""" """
import asyncio import asyncio
import time import time
import uuid
from datetime import datetime from datetime import datetime
from typing import List, Optional
from dotenv import load_dotenv from dotenv import load_dotenv
from app.core.logging_config import get_agent_logger from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import \ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation
memory_summary_generation
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.utils.log.logging_utils import log_time from app.core.memory.utils.log.logging_utils import log_time
from app.db import get_db_context from app.db import get_db_context
@@ -26,17 +23,18 @@ from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo
from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig from app.schemas.memory_config_schema import MemoryConfig
load_dotenv() load_dotenv()
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
async def write( async def write(
end_user_id: str, end_user_id: str,
memory_config: MemoryConfig, memory_config: MemoryConfig,
messages: list, messages: list,
ref_id: str = "", ref_id: str = "wyl20251027",
language: str = "zh", language: str = "zh",
) -> None: ) -> None:
""" """
Execute the complete knowledge extraction pipeline. Execute the complete knowledge extraction pipeline.
@@ -45,11 +43,9 @@ async def write(
end_user_id: Group identifier end_user_id: Group identifier
memory_config: MemoryConfig object containing all configuration memory_config: MemoryConfig object containing all configuration
messages: Structured message list [{"role": "user", "content": "..."}, ...] messages: Structured message list [{"role": "user", "content": "..."}, ...]
ref_id: Reference ID, defaults to "" ref_id: Reference ID, defaults to "wyl20251027"
language: 语言类型 ("zh" 中文, "en" 英文),默认中文 language: 语言类型 ("zh" 中文, "en" 英文),默认中文
""" """
if not ref_id:
ref_id = uuid.uuid4().hex
# Extract config values # Extract config values
embedding_model_id = str(memory_config.embedding_model_id) embedding_model_id = str(memory_config.embedding_model_id)
chunker_strategy = memory_config.chunker_strategy chunker_strategy = memory_config.chunker_strategy
@@ -103,14 +99,14 @@ async def write(
if memory_config.scene_id: if memory_config.scene_id:
try: try:
from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_for_scene from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_for_scene
with get_db_context() as db: with get_db_context() as db:
ontology_types = load_ontology_types_for_scene( ontology_types = load_ontology_types_for_scene(
scene_id=memory_config.scene_id, scene_id=memory_config.scene_id,
workspace_id=memory_config.workspace_id, workspace_id=memory_config.workspace_id,
db=db db=db
) )
if ontology_types: if ontology_types:
logger.info( logger.info(
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}" f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
@@ -139,11 +135,9 @@ async def write(
all_chunk_nodes, all_chunk_nodes,
all_statement_nodes, all_statement_nodes,
all_entity_nodes, all_entity_nodes,
all_perceptual_nodes,
all_statement_chunk_edges, all_statement_chunk_edges,
all_statement_entity_edges, all_statement_entity_edges,
all_entity_entity_edges, all_entity_entity_edges,
all_perceptual_edges,
all_dedup_details, all_dedup_details,
) = await orchestrator.run(chunked_dialogs, is_pilot_run=False) ) = await orchestrator.run(chunked_dialogs, is_pilot_run=False)
@@ -151,6 +145,11 @@ async def write(
# Step 3: Save all data to Neo4j database # Step 3: Save all data to Neo4j database
step_start = time.time() step_start = time.time()
from app.repositories.neo4j.create_indexes import create_fulltext_indexes
try:
await create_fulltext_indexes()
except Exception as e:
logger.error(f"Error creating indexes: {e}", exc_info=True)
# 添加死锁重试机制 # 添加死锁重试机制
max_retries = 3 max_retries = 3
@@ -163,43 +162,15 @@ async def write(
chunk_nodes=all_chunk_nodes, chunk_nodes=all_chunk_nodes,
statement_nodes=all_statement_nodes, statement_nodes=all_statement_nodes,
entity_nodes=all_entity_nodes, entity_nodes=all_entity_nodes,
perceptual_nodes=all_perceptual_nodes,
statement_chunk_edges=all_statement_chunk_edges, statement_chunk_edges=all_statement_chunk_edges,
statement_entity_edges=all_statement_entity_edges, statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges, entity_edges=all_entity_entity_edges,
perceptual_edges=all_perceptual_edges,
connector=neo4j_connector, connector=neo4j_connector,
config_id=config_id,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
) )
if success: if success:
logger.info("Successfully saved all data to Neo4j") logger.info("Successfully saved all data to Neo4j")
# 使用 Celery 异步任务触发聚类(不阻塞主流程)
if all_entity_nodes:
try:
from app.tasks import run_incremental_clustering
end_user_id = all_entity_nodes[0].end_user_id
new_entity_ids = [e.id for e in all_entity_nodes]
# 异步提交 Celery 任务
task = run_incremental_clustering.apply_async(
kwargs={
"end_user_id": end_user_id,
"new_entity_ids": new_entity_ids,
"llm_model_id": str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
"embedding_model_id": str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
},
# 设置任务优先级(低优先级,不影响主业务)
priority=3,
)
logger.info(
f"[Clustering] 增量聚类任务已提交到 Celery - "
f"task_id={task.id}, end_user_id={end_user_id}, entity_count={len(new_entity_ids)}"
)
except Exception as e:
# 聚类任务提交失败不影响主流程
logger.error(f"[Clustering] 提交聚类任务失败(不影响主流程): {e}", exc_info=True)
break break
else: else:
logger.warning("Failed to save some data to Neo4j") logger.warning("Failed to save some data to Neo4j")
@@ -233,8 +204,9 @@ async def write(
summaries = await memory_summary_generation( summaries = await memory_summary_generation(
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
) )
ms_connector = Neo4jConnector()
try: try:
ms_connector = Neo4jConnector()
await add_memory_summary_nodes(summaries, ms_connector) await add_memory_summary_nodes(summaries, ms_connector)
await add_memory_summary_statement_edges(summaries, ms_connector) await add_memory_summary_statement_edges(summaries, ms_connector)
finally: finally:
@@ -274,21 +246,5 @@ async def write(
except Exception as cache_err: except Exception as cache_err:
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
# Close LLM/Embedder underlying httpx clients to prevent
# 'RuntimeError: Event loop is closed' during garbage collection
for client_obj in (llm_client, embedder_client):
try:
underlying = getattr(client_obj, 'client', None) or getattr(client_obj, 'model', None)
if underlying is None:
continue
# Unwrap RedBearLLM / RedBearEmbeddings to get the LangChain model
inner = getattr(underlying, '_model', underlying)
# LangChain OpenAI models expose async_client (httpx.AsyncClient)
http_client = getattr(inner, 'async_client', None)
if http_client is not None and hasattr(http_client, 'aclose'):
await http_client.aclose()
except Exception:
pass
logger.info("=== Pipeline Complete ===") logger.info("=== Pipeline Complete ===")
logger.info(f"Total execution time: {total_time:.2f} seconds") logger.info(f"Total execution time: {total_time:.2f} seconds")

View File

@@ -1,10 +1,10 @@
from typing import Any, List
import re
import os
import asyncio import asyncio
import json import json
import logging
import os
from typing import Any, List
import numpy as np import numpy as np
import logging
# Fix tokenizer parallelism warning # Fix tokenizer parallelism warning
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -246,7 +246,6 @@ class ChunkerClient:
"total_sub_chunks": len(sub_chunks), "total_sub_chunks": len(sub_chunks),
"chunker_strategy": self.chunker_config.chunker_strategy, "chunker_strategy": self.chunker_config.chunker_strategy,
}, },
files=msg.files
) )
dialogue.chunks.append(chunk) dialogue.chunks.append(chunk)
else: else:
@@ -259,7 +258,6 @@ class ChunkerClient:
"message_role": msg.role, "message_role": msg.role,
"chunker_strategy": self.chunker_config.chunker_strategy, "chunker_strategy": self.chunker_config.chunker_strategy,
}, },
files=msg.files
) )
dialogue.chunks.append(chunk) dialogue.chunks.append(chunk)

View File

@@ -65,7 +65,7 @@ class OpenAIClient(LLMClient):
type=type_ type=type_
) )
logger.debug(f"OpenAI 客户端初始化完成: type={type_}") logger.info(f"OpenAI 客户端初始化完成: type={type_}")
async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any: async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any:
""" """

View File

@@ -2,7 +2,6 @@
OpenAI Embedder 客户端实现 OpenAI Embedder 客户端实现
基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。 基于 LangChain 和 RedBearEmbeddings 的 OpenAI 嵌入模型客户端实现。
自动支持火山引擎的多模态 Embedding。
""" """
from typing import List from typing import List
@@ -14,7 +13,6 @@ from app.core.memory.llm_tools.embedder_client import (
) )
from app.core.models.base import RedBearModelConfig from app.core.models.base import RedBearModelConfig
from app.core.models.embedding import RedBearEmbeddings from app.core.models.embedding import RedBearEmbeddings
from app.models.models_model import ModelProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -27,7 +25,6 @@ class OpenAIEmbedderClient(EmbedderClient):
- 批量文本嵌入 - 批量文本嵌入
- 自动重试机制 - 自动重试机制
- 错误处理 - 错误处理
- 火山引擎多模态 Embedding自动识别
""" """
def __init__(self, model_config: RedBearModelConfig): def __init__(self, model_config: RedBearModelConfig):
@@ -39,7 +36,7 @@ class OpenAIEmbedderClient(EmbedderClient):
""" """
super().__init__(model_config) super().__init__(model_config)
# 初始化 RedBearEmbeddings(自动支持火山引擎多模态) # 初始化 RedBearEmbeddings 模型
self.model = RedBearEmbeddings( self.model = RedBearEmbeddings(
RedBearModelConfig( RedBearModelConfig(
model_name=self.model_name, model_name=self.model_name,
@@ -50,9 +47,8 @@ class OpenAIEmbedderClient(EmbedderClient):
timeout=self.timeout, timeout=self.timeout,
) )
) )
self.is_multimodal = self.model.is_multimodal_supported()
logger.info(f"OpenAI Embedder 客户端初始化完成 (provider={self.provider}, multimodal={self.is_multimodal})") logger.info("OpenAI Embedder 客户端初始化完成")
async def response( async def response(
self, self,
@@ -81,14 +77,7 @@ class OpenAIEmbedderClient(EmbedderClient):
return [] return []
# 生成嵌入向量 # 生成嵌入向量
if self.is_multimodal: embeddings = await self.model.aembed_documents(texts)
# 火山引擎多模态 Embedding
embeddings = await self.model.aembed_multimodal(
[{"type": "text", "text": text} for text in texts]
)
else:
# 普通 Embedding
embeddings = await self.model.aembed_documents(texts)
logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量") logger.debug(f"成功生成 {len(embeddings)} 个嵌入向量")
return embeddings return embeddings

View File

@@ -6,7 +6,6 @@ of the memory system including LLM, chunking, pruning, and search.
Classes: Classes:
LLMConfig: Configuration for LLM client LLMConfig: Configuration for LLM client
ChunkerConfig: Configuration for dialogue chunking ChunkerConfig: Configuration for dialogue chunking
OntologyClassInfo: Single ontology class with name and description
PruningConfig: Configuration for semantic pruning PruningConfig: Configuration for semantic pruning
TemporalSearchParams: Parameters for temporal search queries TemporalSearchParams: Parameters for temporal search queries
""" """
@@ -51,41 +50,30 @@ class ChunkerConfig(BaseModel):
min_characters_per_chunk: Optional[int] = Field(24, ge=0, description="The minimum number of characters in each chunk.") min_characters_per_chunk: Optional[int] = Field(24, ge=0, description="The minimum number of characters in each chunk.")
class OntologyClassInfo(BaseModel):
"""本体类型的名称与语义描述,用于剪枝提示词注入。
Attributes:
class_name: 本体类型名称(如"患者""课程"
class_description: 本体类型语义描述,告知 LLM 该类型在当前场景下的含义
"""
class_name: str = Field(..., description="本体类型名称")
class_description: str = Field(default="", description="本体类型语义描述")
class PruningConfig(BaseModel): class PruningConfig(BaseModel):
"""Configuration for semantic pruning of dialogue content. """Configuration for semantic pruning of dialogue content.
Attributes: Attributes:
pruning_switch: Enable or disable semantic pruning pruning_switch: Enable or disable semantic pruning
pruning_scene: Scene name for pruning from ontology_scene table pruning_scene: Scene name for pruning, either a built-in key
('education', 'online_service', 'outbound') or a custom scene_name
from ontology_scene table
pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal) pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal)
scene_id: Optional ontology scene UUID scene_id: Optional ontology scene UUID, used to load custom ontology classes
ontology_class_infos: Full ontology class info (name + description) from ontology_classes: List of class_name strings from ontology_class table,
ontology_class table, injected into the pruning prompt to drive injected into the prompt when pruning_scene is not a built-in scene
scene-aware preservation decisions
""" """
pruning_switch: bool = Field(False, description="Enable semantic pruning when True.") pruning_switch: bool = Field(False, description="Enable semantic pruning when True.")
pruning_scene: str = Field( pruning_scene: str = Field(
"education", "education",
description="Scene name from ontology_scene table.", description="Scene for pruning: built-in key or custom scene_name from ontology_scene.",
) )
pruning_threshold: float = Field( pruning_threshold: float = Field(
0.5, ge=0.0, le=0.9, 0.5, ge=0.0, le=0.9,
description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).") description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).")
scene_id: Optional[str] = Field(None, description="Ontology scene UUID (optional).") scene_id: Optional[str] = Field(None, description="Ontology scene UUID (optional).")
ontology_class_infos: List[OntologyClassInfo] = Field( ontology_classes: Optional[List[str]] = Field(
default_factory=list, None, description="Class names from ontology_class table for custom scenes."
description="Full ontology class info (name + description) injected into pruning prompt."
) )

View File

@@ -44,21 +44,21 @@ def parse_historical_datetime(v):
""" """
if v is None: if v is None:
return v return v
# 处理 Neo4j DateTime 对象 # 处理 Neo4j DateTime 对象
if hasattr(v, 'to_native'): if hasattr(v, 'to_native'):
return v.to_native() return v.to_native()
# 处理 Python datetime 对象 # 处理 Python datetime 对象
if isinstance(v, datetime): if isinstance(v, datetime):
return v return v
if isinstance(v, str): if isinstance(v, str):
# 匹配 ISO 8601 格式YYYY-MM-DD 或 YYYY-MM-DDTHH:MM:SS[.ffffff][Z|±HH:MM] # 匹配 ISO 8601 格式YYYY-MM-DD 或 YYYY-MM-DDTHH:MM:SS[.ffffff][Z|±HH:MM]
# 支持1-4位年份 # 支持1-4位年份
pattern = r'^(\d{1,4})-(\d{2})-(\d{2})(?:T(\d{2}):(\d{2}):(\d{2})(?:\.(\d+))?(?:Z|([+-]\d{2}:\d{2}))?)?' pattern = r'^(\d{1,4})-(\d{2})-(\d{2})(?:T(\d{2}):(\d{2}):(\d{2})(?:\.(\d+))?(?:Z|([+-]\d{2}:\d{2}))?)?'
match = re.match(pattern, v) match = re.match(pattern, v)
if match: if match:
try: try:
year = int(match.group(1)) year = int(match.group(1))
@@ -68,31 +68,31 @@ def parse_historical_datetime(v):
minute = int(match.group(5)) if match.group(5) else 0 minute = int(match.group(5)) if match.group(5) else 0
second = int(match.group(6)) if match.group(6) else 0 second = int(match.group(6)) if match.group(6) else 0
microsecond = 0 microsecond = 0
# 处理微秒 # 处理微秒
if match.group(7): if match.group(7):
# 补齐或截断到6位 # 补齐或截断到6位
us_str = match.group(7).ljust(6, '0')[:6] us_str = match.group(7).ljust(6, '0')[:6]
microsecond = int(us_str) microsecond = int(us_str)
# 处理时区 # 处理时区
tzinfo = None tzinfo = None
if 'Z' in v or match.group(8): if 'Z' in v or match.group(8):
tzinfo = timezone.utc tzinfo = timezone.utc
# 创建 datetime 对象 # 创建 datetime 对象
return datetime(year, month, day, hour, minute, second, microsecond, tzinfo=tzinfo) return datetime(year, month, day, hour, minute, second, microsecond, tzinfo=tzinfo)
except (ValueError, OverflowError): except (ValueError, OverflowError):
# 日期值无效如月份13、日期32等 # 日期值无效如月份13、日期32等
return None return None
# 如果不匹配模式,尝试使用 fromisoformat用于标准格式 # 如果不匹配模式,尝试使用 fromisoformat用于标准格式
try: try:
return datetime.fromisoformat(v.replace('Z', '+00:00')) return datetime.fromisoformat(v.replace('Z', '+00:00'))
except Exception: except Exception:
return None return None
return v return v
@@ -114,7 +114,7 @@ class Edge(BaseModel):
end_user_id: str = Field(..., description="The end user ID of the edge.") end_user_id: str = Field(..., description="The end user ID of the edge.")
run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.") run_id: str = Field(default_factory=lambda: uuid4().hex, description="Unique identifier for this pipeline run.")
created_at: datetime = Field(..., description="The valid time of the edge from system perspective.") created_at: datetime = Field(..., description="The valid time of the edge from system perspective.")
expired_at: Optional[datetime] = Field(default=None, description="The expired time of the edge from system perspective.") expired_at: Optional[datetime] = Field(None, description="The expired time of the edge from system perspective.")
class ChunkEdge(Edge): class ChunkEdge(Edge):
@@ -167,7 +167,7 @@ class EntityEntityEdge(Edge):
source_statement_id: str = Field(..., description="Statement where this relationship was extracted") source_statement_id: str = Field(..., description="Statement where this relationship was extracted")
valid_at: Optional[datetime] = Field(None, description="Temporal validity start") valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end") invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
@field_validator('valid_at', 'invalid_at', mode='before') @field_validator('valid_at', 'invalid_at', mode='before')
@classmethod @classmethod
def validate_datetime(cls, v): def validate_datetime(cls, v):
@@ -175,12 +175,6 @@ class EntityEntityEdge(Edge):
return parse_historical_datetime(v) return parse_historical_datetime(v)
class PerceptualEdge(Edge):
"""Edge connecting perceptual nodes to their source chunks
"""
pass
class Node(BaseModel): class Node(BaseModel):
"""Base class for all graph nodes in the knowledge graph. """Base class for all graph nodes in the knowledge graph.
@@ -212,8 +206,7 @@ class DialogueNode(Node):
ref_id: str = Field(..., description="Reference identifier of the dialog") ref_id: str = Field(..., description="Reference identifier of the dialog")
content: str = Field(..., description="Dialogue content") content: str = Field(..., description="Dialogue content")
dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector") dialog_embedding: Optional[List[float]] = Field(None, description="Dialog embedding vector")
config_id: Optional[int | str] = Field(None, config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this dialogue (integer or string)")
description="Configuration ID used to process this dialogue (integer or string)")
class StatementNode(Node): class StatementNode(Node):
@@ -248,17 +241,17 @@ class StatementNode(Node):
chunk_id: str = Field(..., description="ID of the parent chunk") chunk_id: str = Field(..., description="ID of the parent chunk")
stmt_type: str = Field(..., description="Type of the statement") stmt_type: str = Field(..., description="Type of the statement")
statement: str = Field(..., description="The statement text content") statement: str = Field(..., description="The statement text content")
# Speaker identification # Speaker identification
speaker: Optional[str] = Field( speaker: Optional[str] = Field(
None, None,
description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses" description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses"
) )
# Emotion fields (ordered as requested, emotion_intensity first for display) # Emotion fields (ordered as requested, emotion_intensity first for display)
emotion_intensity: Optional[float] = Field( emotion_intensity: Optional[float] = Field(
None, None,
ge=0.0, ge=0.0,
le=1.0, le=1.0,
description="Emotion intensity: 0.0-1.0 (displayed on node)" description="Emotion intensity: 0.0-1.0 (displayed on node)"
) )
@@ -271,26 +264,25 @@ class StatementNode(Node):
description="Emotion subject: self/other/object" description="Emotion subject: self/other/object"
) )
emotion_type: Optional[str] = Field( emotion_type: Optional[str] = Field(
None, None,
description="Emotion type: joy/sadness/anger/fear/surprise/neutral" description="Emotion type: joy/sadness/anger/fear/surprise/neutral"
) )
emotion_keywords: Optional[List[str]] = Field( emotion_keywords: Optional[List[str]] = Field(
default_factory=list, default_factory=list,
description="Emotion keywords list, max 3 items" description="Emotion keywords list, max 3 items"
) )
# Temporal fields # Temporal fields
temporal_info: TemporalInfo = Field(..., description="Temporal information") temporal_info: TemporalInfo = Field(..., description="Temporal information")
valid_at: Optional[datetime] = Field(None, description="Temporal validity start") valid_at: Optional[datetime] = Field(None, description="Temporal validity start")
invalid_at: Optional[datetime] = Field(None, description="Temporal validity end") invalid_at: Optional[datetime] = Field(None, description="Temporal validity end")
# Embedding and other fields # Embedding and other fields
statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector") statement_embedding: Optional[List[float]] = Field(None, description="Statement embedding vector")
chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector") chunk_embedding: Optional[List[float]] = Field(None, description="Chunk embedding vector")
connect_strength: str = Field(..., description="Strong VS Weak classification of this statement") connect_strength: str = Field(..., description="Strong VS Weak classification of this statement")
config_id: Optional[int | str] = Field(None, config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this statement (integer or string)")
description="Configuration ID used to process this statement (integer or string)")
# ACT-R Memory Activation Properties # ACT-R Memory Activation Properties
importance_score: float = Field( importance_score: float = Field(
default=0.5, default=0.5,
@@ -317,13 +309,13 @@ class StatementNode(Node):
ge=0, ge=0,
description="Total number of times this node has been accessed" description="Total number of times this node has been accessed"
) )
@field_validator('valid_at', 'invalid_at', mode='before') @field_validator('valid_at', 'invalid_at', mode='before')
@classmethod @classmethod
def validate_datetime(cls, v): def validate_datetime(cls, v):
"""使用通用的历史日期解析函数""" """使用通用的历史日期解析函数"""
return parse_historical_datetime(v) return parse_historical_datetime(v)
@field_validator('emotion_type', mode='before') @field_validator('emotion_type', mode='before')
@classmethod @classmethod
def validate_emotion_type(cls, v): def validate_emotion_type(cls, v):
@@ -334,7 +326,7 @@ class StatementNode(Node):
if v not in valid_types: if v not in valid_types:
raise ValueError(f"emotion_type must be one of {valid_types}, got {v}") raise ValueError(f"emotion_type must be one of {valid_types}, got {v}")
return v return v
@field_validator('emotion_subject', mode='before') @field_validator('emotion_subject', mode='before')
@classmethod @classmethod
def validate_emotion_subject(cls, v): def validate_emotion_subject(cls, v):
@@ -345,7 +337,7 @@ class StatementNode(Node):
if v not in valid_subjects: if v not in valid_subjects:
raise ValueError(f"emotion_subject must be one of {valid_subjects}, got {v}") raise ValueError(f"emotion_subject must be one of {valid_subjects}, got {v}")
return v return v
@field_validator('emotion_keywords', mode='before') @field_validator('emotion_keywords', mode='before')
@classmethod @classmethod
def validate_emotion_keywords(cls, v): def validate_emotion_keywords(cls, v):
@@ -413,20 +405,19 @@ class ExtractedEntityNode(Node):
entity_type: str = Field(..., description="Type of the entity") entity_type: str = Field(..., description="Type of the entity")
description: str = Field(..., description="Entity description") description: str = Field(..., description="Entity description")
example: str = Field( example: str = Field(
default="", default="",
description="A concise example (around 20 characters) to help understand the entity" description="A concise example (around 20 characters) to help understand the entity"
) )
aliases: List[str] = Field( aliases: List[str] = Field(
default_factory=list, default_factory=list,
description="Entity aliases - alternative names for this entity" description="Entity aliases - alternative names for this entity"
) )
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector") name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 # TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
# fact_summary: str = Field(default="", description="Summary of the fact about this entity") # fact_summary: str = Field(default="", description="Summary of the fact about this entity")
connect_strength: str = Field(..., description="Strong VS Weak about this entity") connect_strength: str = Field(..., description="Strong VS Weak about this entity")
config_id: Optional[int | str] = Field(None, config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
description="Configuration ID used to process this entity (integer or string)")
# ACT-R Memory Activation Properties # ACT-R Memory Activation Properties
importance_score: float = Field( importance_score: float = Field(
default=0.5, default=0.5,
@@ -453,16 +444,16 @@ class ExtractedEntityNode(Node):
ge=0, ge=0,
description="Total number of times this node has been accessed" description="Total number of times this node has been accessed"
) )
# Explicit Memory Classification # Explicit Memory Classification
is_explicit_memory: bool = Field( is_explicit_memory: bool = Field(
default=False, default=False,
description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)" description="Whether this entity represents explicit/semantic memory (knowledge, concepts, definitions, theories, principles)"
) )
@field_validator('aliases', mode='before') @field_validator('aliases', mode='before')
@classmethod @classmethod
def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段 def validate_aliases_field(cls, v): # 字段验证器 自动清理和验证 aliases 字段
"""Validate and clean aliases field using utility function. """Validate and clean aliases field using utility function.
This validator ensures that the aliases field is always a valid list of strings. This validator ensures that the aliases field is always a valid list of strings.
@@ -516,9 +507,8 @@ class MemorySummaryNode(Node):
memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory") memory_type: Optional[str] = Field(None, description="Type/category of the episodic memory")
summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary") summary_embedding: Optional[List[float]] = Field(None, description="Embedding vector for the summary")
metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary") metadata: dict = Field(default_factory=dict, description="Additional metadata for the summary")
config_id: Optional[int | str] = Field(None, config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this summary (integer or string)")
description="Configuration ID used to process this summary (integer or string)")
# ACT-R Forgetting Engine Properties # ACT-R Forgetting Engine Properties
original_statement_id: Optional[str] = Field( original_statement_id: Optional[str] = Field(
None, None,
@@ -532,7 +522,7 @@ class MemorySummaryNode(Node):
None, None,
description="Timestamp when the nodes were merged" description="Timestamp when the nodes were merged"
) )
# ACT-R Memory Activation Properties # ACT-R Memory Activation Properties
importance_score: float = Field( importance_score: float = Field(
default=0.5, default=0.5,
@@ -559,18 +549,3 @@ class MemorySummaryNode(Node):
ge=0, ge=0,
description="Total number of times this node has been accessed (reset to 1 on creation)" description="Total number of times this node has been accessed (reset to 1 on creation)"
) )
class PerceptualNode(Node):
"""Node representing a multimodal message in the knowledge graph.
"""
perceptual_type: int
file_path: str
file_name: str
file_ext: str
summary: str
keywords: list[str]
topic: str
domain: str
file_type: str
summary_embedding: list[float] | None

View File

@@ -30,7 +30,6 @@ class ConversationMessage(BaseModel):
""" """
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').") role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
msg: str = Field(..., description="The text content of the message.") msg: str = Field(..., description="The text content of the message.")
files: list[tuple] = Field(default_factory=list, description="The file content of the message", exclude=True)
class TemporalValidityRange(BaseModel): class TemporalValidityRange(BaseModel):
@@ -131,8 +130,7 @@ class Chunk(BaseModel):
content: str = Field(..., description="The content of the chunk as a string.") content: str = Field(..., description="The content of the chunk as a string.")
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).") speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.") statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
files: list[tuple] = Field(default_factory=list, description="List of files in the chunk.") chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
chunk_embedding: Optional[List[float]] = Field(default=None, description="The embedding vector of the chunk.")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.") metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
@classmethod @classmethod

View File

@@ -238,7 +238,7 @@ def rerank_with_activation(
reranked: Dict[str, List[Dict[str, Any]]] = {} reranked: Dict[str, List[Dict[str, Any]]] = {}
for category in ["statements", "chunks", "entities", "summaries", "communities"]: for category in ["statements", "chunks", "entities", "summaries"]:
keyword_items = keyword_results.get(category, []) keyword_items = keyword_results.get(category, [])
embedding_items = embedding_results.get(category, []) embedding_items = embedding_results.get(category, [])
@@ -281,23 +281,21 @@ def rerank_with_activation(
for item in items_list: for item in items_list:
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id") item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
if item_id and item_id in combined_items: if item_id and item_id in combined_items:
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value") combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value", 0)
# 步骤 4: 计算基础分数和最终分数 # 步骤 4: 计算基础分数和最终分数
for item_id, item in combined_items.items(): for item_id, item in combined_items.items():
bm25_norm = float(item.get("bm25_score", 0) or 0) bm25_norm = float(item.get("bm25_score", 0) or 0)
emb_norm = float(item.get("embedding_score", 0) or 0) emb_norm = float(item.get("embedding_score", 0) or 0)
# normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义 act_norm = float(item.get("normalized_activation_value", 0) or 0)
raw_act_norm = item.get("normalized_activation_value")
act_norm = float(raw_act_norm) if raw_act_norm is not None else None
# 第一阶段只考虑内容相关性BM25 + Embedding # 第一阶段只考虑内容相关性BM25 + Embedding
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重 # alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
base_score = content_score # 第一阶段用内容分数 base_score = content_score # 第一阶段用内容分数
# 存储激活度分数供第二阶段使用None 表示无激活值,不参与激活值排序) # 存储激活度分数供第二阶段使用
item["activation_score"] = act_norm # 可能为 None item["activation_score"] = act_norm
item["content_score"] = content_score item["content_score"] = content_score
item["base_score"] = base_score item["base_score"] = base_score
@@ -726,8 +724,6 @@ async def run_hybrid_search(
try: try:
keyword_task = None keyword_task = None
embedding_task = None embedding_task = None
keyword_results: Dict[str, List] = {}
embedding_results: Dict[str, List] = {}
if search_type in ["keyword", "hybrid"]: if search_type in ["keyword", "hybrid"]:
# Keyword-based search # Keyword-based search
@@ -750,42 +746,35 @@ async def run_hybrid_search(
# 从数据库读取嵌入器配置(按 ID并构建 RedBearModelConfig # 从数据库读取嵌入器配置(按 ID并构建 RedBearModelConfig
config_load_start = time.time() config_load_start = time.time()
try: with get_db_context() as db:
with get_db_context() as db: config_service = MemoryConfigService(db)
config_service = MemoryConfigService(db) embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id)) rb_config = RedBearModelConfig(
rb_config = RedBearModelConfig( model_name=embedder_config_dict["model_name"],
model_name=embedder_config_dict["model_name"], provider=embedder_config_dict["provider"],
provider=embedder_config_dict["provider"], api_key=embedder_config_dict["api_key"],
api_key=embedder_config_dict["api_key"], base_url=embedder_config_dict["base_url"],
base_url=embedder_config_dict["base_url"], type="llm"
type="llm" )
) config_load_time = time.time() - config_load_start
config_load_time = time.time() - config_load_start logger.info(f"[PERF] Config loading took {config_load_time:.4f}s")
logger.info(f"[PERF] Config loading took {config_load_time:.4f}s")
# Init embedder # Init embedder
embedder_init_start = time.time() embedder_init_start = time.time()
embedder = OpenAIEmbedderClient(model_config=rb_config) embedder = OpenAIEmbedderClient(model_config=rb_config)
embedder_init_time = time.time() - embedder_init_start embedder_init_time = time.time() - embedder_init_start
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s") logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
embedding_task = asyncio.create_task( embedding_task = asyncio.create_task(
search_graph_by_embedding( search_graph_by_embedding(
connector=connector, connector=connector,
embedder_client=embedder, embedder_client=embedder,
query_text=query_text, query_text=query_text,
end_user_id=end_user_id, end_user_id=end_user_id,
limit=limit, limit=limit,
include=include, include=include,
)
) )
except Exception as emb_init_err: )
logger.warning(
f"[PERF] Embedding search skipped due to init error "
f"(embedding_model_id={memory_config.embedding_model_id}): {emb_init_err}"
)
embedding_task = None
if keyword_task: if keyword_task:
keyword_results = await keyword_task keyword_results = await keyword_task

View File

@@ -7,7 +7,6 @@
- 增量更新incremental_update新实体到达时只处理新实体及其邻居 - 增量更新incremental_update新实体到达时只处理新实体及其邻居
""" """
import asyncio
import logging import logging
import uuid import uuid
from math import sqrt from math import sqrt
@@ -20,9 +19,8 @@ logger = logging.getLogger(__name__)
# 全量迭代最大轮数,防止不收敛 # 全量迭代最大轮数,防止不收敛
MAX_ITERATIONS = 10 MAX_ITERATIONS = 10
# 社区摘要核心实体数量
# 社区核心实体取 top-N 数量 CORE_ENTITY_LIMIT = 5
CORE_ENTITY_LIMIT = 10
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float: def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
@@ -69,16 +67,15 @@ class LabelPropagationEngine:
def __init__( def __init__(
self, self,
connector: Neo4jConnector, connector: Neo4jConnector,
config_id: Optional[str] = None,
llm_model_id: Optional[str] = None, llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None, embedding_model_id: Optional[str] = None,
): ):
self.connector = connector self.connector = connector
self.repo = CommunityRepository(connector) self.repo = CommunityRepository(connector)
self.config_id = config_id
self.llm_model_id = llm_model_id self.llm_model_id = llm_model_id
self.embedding_model_id = embedding_model_id self.embedding_model_id = embedding_model_id
# 缓存客户端实例,避免重复初始化
self._llm_client = None
self._embedder_client = None
# ────────────────────────────────────────────────────────────────────────── # ──────────────────────────────────────────────────────────────────────────
# 公开接口 # 公开接口
@@ -108,81 +105,58 @@ class LabelPropagationEngine:
async def full_clustering(self, end_user_id: str) -> None: async def full_clustering(self, end_user_id: str) -> None:
""" """
全量标签传播初始化(分批处理,控制内存峰值) 全量标签传播初始化。
策略: 1. 拉取所有实体,初始化每个实体为独立社区
- 每次只加载 BATCH_SIZE 个实体及其邻居进内存 2. 迭代:每轮对所有实体做邻居投票,更新社区标签
- labels 字典跨批次共享(只存 id→community_id内存极小 3. 直到标签不再变化或达到 MAX_ITERATIONS
- 每批独立跑 MAX_ITERATIONS 轮 LPA批次间通过 labels 传递社区信息 4. 将最终标签写入 Neo4j
- 所有批次完成后统一 flush 和 merge
""" """
BATCH_SIZE = 888 # 每批实体数,可按需调整 entities = await self.repo.get_all_entities(end_user_id)
if not entities:
# 轻量查询:只获取总数和 ID 列表,不加载 embedding 等大字段
total_count = await self.repo.get_entity_count(end_user_id)
if not total_count:
logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类") logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类")
return return
all_entity_ids = await self.repo.get_all_entity_ids(end_user_id) # 初始化:每个实体持有自己 id 作为社区标签
logger.info(f"[Clustering] 用户 {end_user_id}{total_count} 个实体," labels: Dict[str, str] = {e["id"]: e["id"] for e in entities}
f"分批大小 {BATCH_SIZE},共 {(total_count + BATCH_SIZE - 1) // BATCH_SIZE}") embeddings: Dict[str, Optional[List[float]]] = {
e["id"]: e.get("name_embedding") for e in entities
}
# labels 跨批次共享:只存 id→community_id内存极小 # 预加载所有实体的邻居,避免迭代内 O(iterations * |E|) 次 Neo4j 往返
labels: Dict[str, str] = {eid: eid for eid in all_entity_ids} logger.info(f"[Clustering] 预加载 {len(entities)} 个实体的邻居图...")
del all_entity_ids # 释放 ID 列表,后续按批次加载完整数据 neighbors_cache: Dict[str, List[Dict]] = await self.repo.get_all_entity_neighbors_batch(end_user_id)
logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")
for batch_start in range(0, total_count, BATCH_SIZE): for iteration in range(MAX_ITERATIONS):
batch_entities = await self.repo.get_entities_page( changed = 0
end_user_id, skip=batch_start, limit=BATCH_SIZE # 随机顺序Python dict 在 3.7+ 保持插入顺序,这里直接遍历)
) for entity in entities:
if not batch_entities: eid = entity["id"]
break # 直接从缓存取邻居,不再发起 Neo4j 查询
neighbors = neighbors_cache.get(eid, [])
batch_ids = [e["id"] for e in batch_entities] # 将邻居的当前内存标签注入(覆盖 Neo4j 中的旧值)
batch_embeddings: Dict[str, Optional[List[float]]] = { enriched = []
e["id"]: e.get("name_embedding") for e in batch_entities for nb in neighbors:
} nb_copy = dict(nb)
nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
enriched.append(nb_copy)
new_label = _weighted_vote(enriched, embeddings.get(eid))
if new_label and new_label != labels[eid]:
labels[eid] = new_label
changed += 1
logger.info( logger.info(
f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1}" f"[Clustering] 全量迭代 {iteration + 1}/{MAX_ITERATIONS}"
f"加载 {len(batch_entities)} 个实体的邻居图..." f"标签变化数: {changed}"
) )
neighbors_cache = await self.repo.get_entity_neighbors_for_ids( if changed == 0:
batch_ids, end_user_id logger.info("[Clustering] 标签已收敛,提前结束迭代")
) break
logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")
for iteration in range(MAX_ITERATIONS): # 将最终标签写入 Neo4j
changed = 0
for entity in batch_entities:
eid = entity["id"]
neighbors = neighbors_cache.get(eid, [])
# 注入跨批次的最新标签邻居可能在其他批次labels 里有其最新值)
enriched = []
for nb in neighbors:
nb_copy = dict(nb)
nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
enriched.append(nb_copy)
new_label = _weighted_vote(enriched, batch_embeddings.get(eid))
if new_label and new_label != labels[eid]:
labels[eid] = new_label
changed += 1
logger.info(
f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1} "
f"迭代 {iteration + 1}/{MAX_ITERATIONS},标签变化数: {changed}"
)
if changed == 0:
logger.info("[Clustering] 标签已收敛,提前结束本批迭代")
break
# 释放本批次的大对象
del neighbors_cache, batch_embeddings, batch_entities
# 所有批次完成,统一写入 Neo4j
await self._flush_labels(labels, end_user_id) await self._flush_labels(labels, end_user_id)
pre_merge_count = len(set(labels.values())) pre_merge_count = len(set(labels.values()))
logger.info( logger.info(
@@ -190,6 +164,7 @@ class LabelPropagationEngine:
f"{len(labels)} 个实体,开始后处理合并" f"{len(labels)} 个实体,开始后处理合并"
) )
# 全量初始化后做一轮社区合并(基于 name_embedding 余弦相似度)
all_community_ids = list(set(labels.values())) all_community_ids = list(set(labels.values()))
await self._evaluate_merge(all_community_ids, end_user_id) await self._evaluate_merge(all_community_ids, end_user_id)
@@ -197,15 +172,17 @@ class LabelPropagationEngine:
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区," f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
f"{len(labels)} 个实体" f"{len(labels)} 个实体"
) )
# 为所有社区生成元数据
# 查询存活社区并生成元数据 # 注意_evaluate_merge 后部分社区已被合并消解,需重新从 Neo4j 查询实际存活社区
# 不能复用 labels.values(),那里包含已被 dissolve 的旧社区 ID
surviving_communities = await self.repo.get_all_entities(end_user_id) surviving_communities = await self.repo.get_all_entities(end_user_id)
surviving_community_ids = list({ surviving_community_ids = list({
e.get("community_id") for e in surviving_communities e.get("community_id") for e in surviving_communities
if e.get("community_id") if e.get("community_id")
}) })
logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}") logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}")
await self._generate_community_metadata(surviving_community_ids, end_user_id) for cid in surviving_community_ids:
await self._generate_community_metadata(cid, end_user_id)
async def incremental_update( async def incremental_update(
self, new_entity_ids: List[str], end_user_id: str self, new_entity_ids: List[str], end_user_id: str
@@ -218,17 +195,8 @@ class LabelPropagationEngine:
3. 若邻居无社区 → 创建新社区 3. 若邻居无社区 → 创建新社区
4. 若邻居分属多个社区 → 评估是否合并 4. 若邻居分属多个社区 → 评估是否合并
""" """
# 收集所有需要生成元数据的社区ID
communities_to_update = set()
for entity_id in new_entity_ids: for entity_id in new_entity_ids:
cid = await self._process_single_entity(entity_id, end_user_id) await self._process_single_entity(entity_id, end_user_id)
if cid:
communities_to_update.add(cid)
# 批量生成所有社区的元数据
if communities_to_update:
await self._generate_community_metadata(list(communities_to_update), end_user_id, force=True)
# ────────────────────────────────────────────────────────────────────────── # ──────────────────────────────────────────────────────────────────────────
# 内部方法 # 内部方法
@@ -236,21 +204,8 @@ class LabelPropagationEngine:
async def _process_single_entity( async def _process_single_entity(
self, entity_id: str, end_user_id: str self, entity_id: str, end_user_id: str
) -> Optional[str]: ) -> None:
""" """处理单个新实体的社区分配。"""
处理单个新实体的社区分配。
该函数会为新实体分配社区,可能的情况包括:
1. 孤立实体(无邻居):创建新的单成员社区
2. 邻居都没有社区:创建新社区并将实体和邻居都加入
3. 邻居有社区:通过加权投票选择最合适的社区加入
Returns:
Optional[str]: 分配到的社区ID。当前实现总是返回一个有效的社区ID
但返回类型保留为Optional以支持未来可能的扩展场景
(例如:实体无法分配到任何社区的情况)。
调用方应检查返回值的真假性truthiness
"""
neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id) neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id)
# 查询自身 embedding从邻居查询结果中无法获取需单独查 # 查询自身 embedding从邻居查询结果中无法获取需单独查
@@ -262,7 +217,7 @@ class LabelPropagationEngine:
await self.repo.upsert_community(new_cid, end_user_id, member_count=1) await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id) await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}") logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
return new_cid return
# 统计邻居社区分布 # 统计邻居社区分布
community_ids_in_neighbors = set( community_ids_in_neighbors = set(
@@ -284,7 +239,7 @@ class LabelPropagationEngine:
logger.debug( logger.debug(
f"[Clustering] 新实体 {entity_id}{len(neighbors)} 个无社区邻居 → 新社区 {new_cid}" f"[Clustering] 新实体 {entity_id}{len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
) )
return new_cid await self._generate_community_metadata(new_cid, end_user_id)
else: else:
# 加入得票最多的社区 # 加入得票最多的社区
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id) await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
@@ -296,8 +251,7 @@ class LabelPropagationEngine:
await self._evaluate_merge( await self._evaluate_merge(
list(community_ids_in_neighbors), end_user_id list(community_ids_in_neighbors), end_user_id
) )
# 返回目标社区ID稍后批量生成元数据 await self._generate_community_metadata(target_cid, end_user_id)
return target_cid
async def _evaluate_merge( async def _evaluate_merge(
self, community_ids: List[str], end_user_id: str self, community_ids: List[str], end_user_id: str
@@ -461,223 +415,94 @@ class LabelPropagationEngine:
except Exception: except Exception:
return None return None
@staticmethod
def _build_entity_lines(members: List[Dict]) -> List[str]:
"""将实体列表格式化为 prompt 行,包含 name、aliases、description、example。"""
lines = []
for m in members:
m_name = m.get("name", "")
aliases = m.get("aliases") or []
description = m.get("description") or ""
example = m.get("example") or ""
aliases_str = f"(别名:{''.join(aliases)}" if aliases else ""
desc_str = f"{description}" if description else ""
example_str = f"(示例:{example}" if example else ""
lines.append(f"- {m_name}{aliases_str}{desc_str}{example_str}")
return lines
async def _generate_community_metadata( async def _generate_community_metadata(
self, community_ids: List[str], end_user_id: str, force: bool = False self, community_id: str, end_user_id: str
) -> None: ) -> None:
""" """
一个或多个社区生成并写入元数据(优化版:批量 LLM 调用) 为社区生成并写入元数据:名称、摘要、核心实体
流程: - core_entities按 activation_value 排序取 top-N 实体名称列表(无需 LLM
1. 批量准备所有社区的 prompt - name / summary若有 llm_model_id 则调用 LLM 生成,否则用实体名称拼接兜底
2. 并发调用 LLM 生成所有社区的 name / summary
3. 批量 embed 所有 summary
4. 批量写入数据库
Args:
force: 为 True 时跳过完整性检查,强制重新生成(用于增量更新成员变化后)
""" """
async def _prepare_one(cid: str) -> Optional[Dict]: try:
"""准备单个社区的数据和 prompt""" # 先检查属性是否已完整,完整则跳过,避免重复生成
try: check_embedding = bool(self.embedding_model_id)
if not force: if await self.repo.is_community_complete(community_id, end_user_id, check_embedding=check_embedding):
check_embedding = bool(self.embedding_model_id) logger.debug(f"[Clustering] 社区 {community_id} 属性已完整,跳过生成")
if await self.repo.is_community_complete(cid, end_user_id, check_embedding=check_embedding): return
return None
members = await self.repo.get_community_members(cid, end_user_id) members = await self.repo.get_community_members(community_id, end_user_id)
if not members: if not members:
logger.warning(f"[Clustering] 社区 {cid} 无成员,跳过元数据生成") return
return None
sorted_members = sorted( # 核心实体:按 activation_value 降序取 top-N
members, sorted_members = sorted(
key=lambda m: m.get("activation_value") or 0, members,
reverse=True, key=lambda m: m.get("activation_value") or 0,
) reverse=True,
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")] )
all_names = [m["name"] for m in members if m.get("name")] core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
all_names = [m["name"] for m in members if m.get("name")]
# 默认值 name = "".join(core_entities[:3]) if core_entities else community_id[:8]
name = "".join(core_entities[:3]) if core_entities else cid[:8] summary = f"包含实体:{', '.join(all_names)}"
summary = f"包含实体:{', '.join(all_names)}"
# 准备 LLM prompt如果配置了 LLM # 若有 LLM 配置,调用 LLM 生成更好的名称和摘要
prompt = None if self.llm_model_id:
if self.llm_model_id: try:
entity_list_str = "\n".join(self._build_entity_lines(members)) from app.db import get_db_context
relationships = await self.repo.get_community_relationships(cid, end_user_id) from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
rel_lines = [
f"- {r['subject']}{r['predicate']}{r['object']}" entity_list_str = "".join(all_names)
for r in relationships
if r.get("subject") and r.get("predicate") and r.get("object")
]
rel_section = (
f"\n实体间关系:\n" + "\n".join(rel_lines)
if rel_lines else ""
)
prompt = ( prompt = (
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n" f"以下是一组语义相关的实体:{entity_list_str}\n\n"
f"请为这组实体所代表的主题:\n" f"请为这组实体所代表的主题:\n"
f"1. 起一个简洁的中文名称不超过10个字\n" f"1. 起一个简洁的中文名称不超过10个字\n"
f"2. 写一句话摘要(不超过80个字\n\n" f"2. 写一句话摘要(不超过50个字\n\n"
f"严格按以下格式输出,不要有其他内容:\n" f"严格按以下格式输出,不要有其他内容:\n"
f"名称:<名称>\n摘要:<摘要>" f"名称:<名称>\n摘要:<摘要>"
) )
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(self.llm_model_id)
response = await llm_client.chat([{"role": "user", "content": prompt}])
text = response.content if hasattr(response, "content") else str(response)
return { for line in text.strip().splitlines():
"community_id": cid, if line.startswith("名称:"):
"end_user_id": end_user_id, name = line[3:].strip()
"name": name, elif line.startswith("摘要:"):
"summary": summary, summary = line[3:].strip()
"core_entities": core_entities,
"prompt": prompt,
"summary_embedding": None,
}
except Exception as e:
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {e}", exc_info=True)
return None
# --- 阶段1并发准备所有社区数据 ---
results = await asyncio.gather(
*[_prepare_one(cid) for cid in community_ids],
return_exceptions=True,
)
metadata_list = []
for cid, res in zip(community_ids, results):
if isinstance(res, Exception):
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {res}", exc_info=res)
elif res is not None:
metadata_list.append(res)
if not metadata_list:
logger.warning(f"[Clustering] 无有效元数据可写入community_ids={community_ids}")
return
# --- 阶段2批量调用 LLM 生成 name 和 summary ---
if self.llm_model_id:
llm_client = self._get_llm_client()
if not llm_client:
logger.warning(
f"[Clustering] LLM 已配置model_id={self.llm_model_id})但客户端初始化失败,"
f"将跳过社区元数据的 LLM 富化。请检查 model_id 是否正确或数据库连接是否正常。"
)
if llm_client:
prompts_to_process = [(i, m) for i, m in enumerate(metadata_list) if m.get("prompt")]
if prompts_to_process:
logger.info(f"[Clustering] 批量调用 LLM 生成 {len(prompts_to_process)} 个社区元数据")
async def _call_llm(idx: int, meta: Dict) -> tuple:
"""单个 LLM 调用"""
try:
response = await llm_client.chat([{"role": "user", "content": meta["prompt"]}])
text = response.content if hasattr(response, "content") else str(response)
return (idx, text, None)
except Exception as e:
logger.warning(f"[Clustering] 社区 {meta['community_id']} LLM 生成失败: {e}")
return (idx, None, e)
# 并发调用所有 LLM 请求
llm_results = await asyncio.gather(
*[_call_llm(idx, meta) for idx, meta in prompts_to_process],
return_exceptions=True
)
# 解析 LLM 响应
for result in llm_results:
if isinstance(result, Exception):
continue
idx, text, error = result
if error or not text:
continue
meta = metadata_list[idx]
for line in text.strip().splitlines():
if line.startswith("名称:"):
meta["name"] = line[3:].strip()
elif line.startswith("摘要:"):
meta["summary"] = line[3:].strip()
logger.info(f"[Clustering] LLM 批量生成完成")
# --- 阶段3批量生成 summary_embedding ---
if self.embedding_model_id:
embedder = self._get_embedder_client()
if not embedder:
logger.warning(
f"[Clustering] Embedding 已配置model_id={self.embedding_model_id})但客户端初始化失败,"
f"将跳过社区摘要的向量化。请检查 model_id 是否正确或数据库连接是否正常。"
)
if embedder:
try:
summaries = [m["summary"] for m in metadata_list]
logger.info(f"[Clustering] 批量生成 {len(summaries)} 个 summary embedding")
embeddings = await embedder.response(summaries)
for i, meta in enumerate(metadata_list):
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
logger.info(f"[Clustering] Embedding 批量生成完成")
except Exception as e: except Exception as e:
logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True) logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}")
# --- 阶段4批量写入数据库 --- # 生成 summary_embedding
# 移除 prompt 字段(不需要存储) summary_embedding: Optional[List[float]] = None
for m in metadata_list: if self.embedding_model_id and summary:
m.pop("prompt", None) try:
from app.db import get_db_context
if len(metadata_list) == 1: from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
m = metadata_list[0]
result = await self.repo.update_community_metadata( with get_db_context() as db:
community_id=m["community_id"], embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
end_user_id=m["end_user_id"], vectors = await embedder.response([summary])
name=m["name"], if vectors:
summary=m["summary"], summary_embedding = vectors[0]
core_entities=m["core_entities"], except Exception as e:
summary_embedding=m["summary_embedding"], logger.warning(f"[Clustering] 社区 {community_id} 生成 summary_embedding 失败: {e}")
await self.repo.update_community_metadata(
community_id=community_id,
end_user_id=end_user_id,
name=name,
summary=summary,
core_entities=core_entities,
summary_embedding=summary_embedding,
) )
if not result: logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}")
logger.error(f"[Clustering] 社区 {m['community_id']} 元数据写入失败") except Exception as e:
else: logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}")
ok = await self.repo.batch_update_community_metadata(metadata_list)
if not ok:
logger.error(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据失败")
else:
logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
def _get_llm_client(self):
"""获取或创建 LLM 客户端(单例模式)"""
if self._llm_client is None and self.llm_model_id:
from app.db import get_db_context
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
with get_db_context() as db:
self._llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
logger.info(f"[Clustering] LLM 客户端初始化完成(单例): model_id={self.llm_model_id}")
return self._llm_client
def _get_embedder_client(self):
"""获取或创建 Embedder 客户端(单例模式)"""
if self._embedder_client is None and self.embedding_model_id:
from app.db import get_db_context
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
with get_db_context() as db:
self._embedder_client = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
logger.info(f"[Clustering] Embedder 客户端初始化完成(单例): model_id={self.embedding_model_id}")
return self._embedder_client
@staticmethod @staticmethod
def _new_community_id() -> str: def _new_community_id() -> str:
return str(uuid.uuid4()) return str(uuid.uuid4())

View File

@@ -9,7 +9,6 @@
""" """
import asyncio import asyncio
import logging
import os import os
import hashlib import hashlib
import json import json
@@ -21,26 +20,13 @@ from pydantic import BaseModel, Field
from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext
from app.core.memory.models.config_models import PruningConfig from app.core.memory.models.config_models import PruningConfig
from app.core.memory.utils.config.config_utils import get_pruning_config
from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering
from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import ( from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import (
SceneConfigRegistry, SceneConfigRegistry,
ScenePatterns ScenePatterns
) )
logger = logging.getLogger(__name__)
def message_has_files(message: "ConversationMessage") -> bool:
"""检查消息是否包含文件。
Args:
message: 待检查的消息对象
Returns:
bool: 如果消息包含文件则返回 True否则返回 False
"""
return message.files and len(message.files) > 0
class DialogExtractionResponse(BaseModel): class DialogExtractionResponse(BaseModel):
"""对话级一次性抽取的结构化返回,用于加速剪枝。 """对话级一次性抽取的结构化返回,用于加速剪枝。
@@ -48,8 +34,6 @@ class DialogExtractionResponse(BaseModel):
- is_related对话与场景的相关性判定。 - is_related对话与场景的相关性判定。
- times / ids / amounts / contacts / addresses / keywords重要信息片段用来在不相关对话中保留关键消息。 - times / ids / amounts / contacts / addresses / keywords重要信息片段用来在不相关对话中保留关键消息。
- preserve_keywords情绪/兴趣/爱好/个人观点相关词,包含这些词的消息必须强制保留。 - preserve_keywords情绪/兴趣/爱好/个人观点相关词,包含这些词的消息必须强制保留。
- scene_unrelated_snippets与当前场景无关且无语义关联的消息片段原文截取
用于高阈值阶段精准删除跨场景内容。
""" """
is_related: bool = Field(...) is_related: bool = Field(...)
times: List[str] = Field(default_factory=list) times: List[str] = Field(default_factory=list)
@@ -59,7 +43,6 @@ class DialogExtractionResponse(BaseModel):
addresses: List[str] = Field(default_factory=list) addresses: List[str] = Field(default_factory=list)
keywords: List[str] = Field(default_factory=list) keywords: List[str] = Field(default_factory=list)
preserve_keywords: List[str] = Field(default_factory=list, description="情绪/兴趣/爱好/个人观点相关词,包含这些词的消息强制保留") preserve_keywords: List[str] = Field(default_factory=list, description="情绪/兴趣/爱好/个人观点相关词,包含这些词的消息强制保留")
scene_unrelated_snippets: List[str] = Field(default_factory=list,description="与当前场景无关且无语义关联的消息原文片段,高阈值阶段用于精准删除跨场景内容")
class MessageImportanceResponse(BaseModel): class MessageImportanceResponse(BaseModel):
@@ -108,14 +91,12 @@ class SemanticPruner:
# 加载统一填充词库 # 加载统一填充词库
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(self.config.pruning_scene) self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(self.config.pruning_scene)
# 本体类型列表:直接使用 ontology_class_infosname + description # 本体类型列表(用于注入提示词,所有场景均支持
self._ontology_class_infos = getattr(self.config, "ontology_class_infos", None) or [] self._ontology_classes = getattr(self.config, "ontology_classes", None) or []
# _ontology_classes 仅用于日志统计
self._ontology_classes = [info.class_name for info in self._ontology_class_infos]
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene}") self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene}")
if self._ontology_class_infos: if self._ontology_classes:
self._log(f"[剪枝-初始化] 注入本体类型({len(self._ontology_class_infos)}个): {self._ontology_classes}") self._log(f"[剪枝-初始化] 注入本体类型: {self._ontology_classes}")
else: else:
self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词") self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
@@ -140,8 +121,7 @@ class SemanticPruner:
1. 空消息 1. 空消息
2. 场景特定填充词库精确匹配 2. 场景特定填充词库精确匹配
3. 常见寒暄精确匹配 3. 常见寒暄精确匹配
4. 组合寒暄模式(前缀 + 后缀组合,如"好的谢谢""同学你好""明白了" 4. 纯表情/标点
5. 纯表情/标点
""" """
t = message.msg.strip() t = message.msg.strip()
if not t: if not t:
@@ -163,55 +143,6 @@ class SemanticPruner:
if t in common_greetings: if t in common_greetings:
return True return True
# 组合寒暄模式短消息≤15字且完全由寒暄成分构成
# 策略:将消息拆分后,每个片段都能在填充词库或常见寒暄中找到,则整体为填充
if len(t) <= 15:
# 确认+称呼/感谢组合,如"好的谢谢"、"明白了"、"知道了谢谢"
_confirm_prefixes = {"好的", "", "", "嗯嗯", "", "明白", "明白了", "知道了", "了解", "收到", "没问题"}
_thanks_suffixes = {"谢谢", "谢谢你", "谢谢您", "多谢", "感谢", "谢了"}
_greeting_suffixes = {"你好", "您好", "老师好", "同学好", "大家好"}
_greeting_prefixes = {"同学", "老师", "您好", "你好"}
_close_patterns = {
"没有了", "没事了", "没问题了", "好了", "行了", "可以了",
"不用了", "不需要了", "就这样", "就这样吧", "那就这样",
}
_polite_responses = {
"不客气", "不用谢", "没关系", "没事", "应该的", "这是我应该做的",
}
# 规则1确认词 + 感谢词(如"好的谢谢"、"嗯谢谢"
for cp in _confirm_prefixes:
for ts in _thanks_suffixes:
if t == cp + ts or t == cp + "" + ts or t == cp + "," + ts:
return True
# 规则2称呼前缀 + 问候(如"同学你好"、"老师好"
for gp in _greeting_prefixes:
for gs in _greeting_suffixes:
if t == gp + gs or t.startswith(gp) and t.endswith(""):
return True
# 规则3结束语 + 感谢(如"没有了,谢谢老师"、"没有了谢谢"
for cp in _close_patterns:
if t.startswith(cp):
remainder = t[len(cp):].lstrip(",、 ")
if not remainder or any(remainder.startswith(ts) for ts in _thanks_suffixes):
return True
# 规则4礼貌回应如"不客气,祝你考试顺利"——前缀是礼貌词,后半是祝福套话)
for pr in _polite_responses:
if t.startswith(pr):
remainder = t[len(pr):].lstrip(",、 ")
# 后半是祝福/套话(不含实质信息)
if not remainder or re.match(r"^(祝|希望|期待|加油|顺利|好好|保重)", remainder):
return True
# 规则5纯确认词加"了"后缀(如"明白了"、"知道了"、"好了"
_confirm_base = {"明白", "知道", "了解", "收到", "", "", "可以", "没问题"}
for cb in _confirm_base:
if t == cb + "" or t == cb + "了。" or t == cb + "了!":
return True
# 检查是否为纯表情符号(方括号包裹) # 检查是否为纯表情符号(方括号包裹)
if re.fullmatch(r"(\[[^\]]+\])+", t): if re.fullmatch(r"(\[[^\]]+\])+", t):
return True return True
@@ -400,13 +331,13 @@ class SemanticPruner:
rendered = self.template.render( rendered = self.template.render(
pruning_scene=self.config.pruning_scene, pruning_scene=self.config.pruning_scene,
ontology_class_infos=self._ontology_class_infos, ontology_classes=self._ontology_classes,
dialog_text=dialog_text, dialog_text=dialog_text,
language=self.language language=self.language
) )
log_template_rendering("extracat_Pruning.jinja2", { log_template_rendering("extracat_Pruning.jinja2", {
"pruning_scene": self.config.pruning_scene, "pruning_scene": self.config.pruning_scene,
"ontology_class_infos_count": len(self._ontology_class_infos), "ontology_classes_count": len(self._ontology_classes),
"language": self.language "language": self.language
}) })
log_prompt_rendering("pruning-extract", rendered) log_prompt_rendering("pruning-extract", rendered)
@@ -446,193 +377,6 @@ class SemanticPruner:
) )
return fallback_response return fallback_response
def _get_pruning_mode(self) -> str:
"""根据 pruning_threshold 返回当前剪枝阶段。
- 低阈值 [0.0, 0.3)conservative 只删填充,保留所有实质内容
- 中阈值 [0.3, 0.6)semantic 保留场景相关 + 有语义关联的内容,删除无关联内容
- 高阈值 [0.6, 0.9]strict 只保留场景相关内容,跨场景内容可被删除
"""
t = float(self.config.pruning_threshold)
if t < 0.3:
return "conservative"
elif t < 0.6:
return "semantic"
else:
return "strict"
def _apply_related_dialog_pruning(
self,
msgs: List[ConversationMessage],
extraction: "DialogExtractionResponse",
dialog_label: str,
pruning_mode: str,
) -> List[ConversationMessage]:
"""相关对话统一剪枝入口,消除 prune_dialog / prune_dataset 中的重复逻辑。
- conservative只删填充
- semantic / strict场景感知剪枝
"""
if pruning_mode == "conservative":
preserve_tokens = self._build_preserve_tokens(extraction)
return self._prune_fillers_only(msgs, preserve_tokens, dialog_label)
else:
return self._prune_with_scene_filter(msgs, extraction, dialog_label, pruning_mode)
def _prune_fillers_only(
self,
msgs: List[ConversationMessage],
preserve_tokens: List[str],
dialog_label: str,
) -> List[ConversationMessage]:
"""相关对话专用只删填充消息LLM 保护消息和实质内容一律保留。
不受 pruning_threshold 约束,删多少算多少(填充有多少删多少)。
至少保留 1 条消息。
注意:填充检测优先于 preserve_tokens 保护——填充消息本身无信息价值,
即使 LLM 误将其关键词放入 preserve_tokens 也应删除。
"""
to_delete_ids: set = set()
for m in msgs:
# 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断
if message_has_files(m):
self._log(f" [保护] 带文件的消息(不参与剪枝):'{m.msg[:40]}',文件数={len(m.files)}")
continue
# 填充检测优先:先判断是否为填充,再看 LLM 保护
if self._is_filler_message(m):
to_delete_ids.add(id(m))
self._log(f" [填充] '{m.msg[:40]}' → 删除")
continue
if self._msg_matches_tokens(m, preserve_tokens):
self._log(f" [保护] '{m.msg[:40]}' → LLM保护跳过")
kept = [m for m in msgs if id(m) not in to_delete_ids]
if not kept and msgs:
kept = [msgs[0]]
deleted = len(msgs) - len(kept)
self._log(
f"[剪枝-相关] {dialog_label} 总消息={len(msgs)} "
f"填充删除={deleted} 保留={len(kept)}"
)
return kept
def _prune_with_scene_filter(
self,
msgs: List[ConversationMessage],
extraction: "DialogExtractionResponse",
dialog_label: str,
mode: str,
) -> List[ConversationMessage]:
"""场景感知剪枝,供 semantic / strict 两个阈值档位调用。
本函数体现剪枝系统的三层递进逻辑:
第一层conservative阈值 < 0.3
不进入本函数,由 _prune_fillers_only 处理。
保留标准:只问"有没有信息量",填充消息(嗯/好的/哈哈等)删除,其余一律保留。
第二层semantic阈值 [0.3, 0.6)
保留标准:内容价值优先,场景相关性是参考而非唯一标准。
- 填充消息 → 删除(最高优先级)
- 场景相关消息 → 保留
- 场景无关消息 → 有两次豁免机会:
1. 命中 scene_preserve_tokensLLM 标记的关键词/时间/金额等)→ 保留
2. 含情感词(感觉/压力/开心等)→ 保留(情感内容有记忆价值)
3. 两次豁免均未命中 → 删除
第三层strict阈值 [0.6, 0.9]
保留标准:场景相关性优先,无任何豁免。
- 填充消息 → 删除(最高优先级)
- 场景相关消息 → 保留
- 场景无关消息 → 直接删除preserve_keywords 和情感词在此模式下均不生效
至少保留 1 条消息(兜底取第一条)。
"""
# strict 模式收窄保护范围:只保护结构化关键信息(时间/编号/金额/联系方式/地址),
# 不保护 keywords / preserve_keywords让场景过滤能删掉更多内容。
# semantic 模式完整保护:包含 LLM 抽取的所有重要片段(含 keywords 和 preserve_keywords
if mode == "strict":
scene_preserve_tokens = (
extraction.times + extraction.ids + extraction.amounts +
extraction.contacts + extraction.addresses
)
else:
scene_preserve_tokens = self._build_preserve_tokens(extraction)
unrelated_snippets = extraction.scene_unrelated_snippets or []
to_delete_ids: set = set()
for m in msgs:
msg_text = m.msg.strip()
# 最高优先级保护:带有文件的消息一律保留,不参与任何剪枝判断
if message_has_files(m):
self._log(f" [保护] 带文件的消息(不参与剪枝):'{msg_text[:40]}',文件数={len(m.files)}")
continue
# 第一优先级:填充消息无论模式直接删除,不参与后续场景判断
if self._is_filler_message(m):
to_delete_ids.add(id(m))
self._log(f" [填充] '{msg_text[:40]}' → 删除")
continue
# 双向包含匹配:处理 LLM 返回片段与原始消息文本长度不完全一致的情况
is_scene_unrelated = any(
snip and (snip in msg_text or msg_text in snip)
for snip in unrelated_snippets
)
if is_scene_unrelated:
if mode == "strict":
# strict场景无关直接删除不做任何豁免
# 场景相关性是唯一裁决标准preserve_keywords 在此模式下不生效
to_delete_ids.add(id(m))
self._log(f" [场景无关-严格] '{msg_text[:40]}' → 删除")
elif mode == "semantic":
# semantic场景无关但有内容价值 → 保留
# 豁免第一层:命中 scene_preserve_tokens关键词/结构化信息保护)
if self._msg_matches_tokens(m, scene_preserve_tokens):
self._log(f" [保护] '{msg_text[:40]}' → 场景关键词保护,保留")
else:
# 豁免第二层:含情感词,认为有情境记忆价值,即使场景无关也保留
has_contextual_emotion = any(
word in msg_text
for word in ["感觉", "觉得", "心情", "开心", "难过", "高兴", "沮丧",
"喜欢", "讨厌", "", "", "担心", "害怕", "兴奋",
"压力", "", "疲惫", "", "焦虑", "委屈", "感动"]
)
if not has_contextual_emotion:
to_delete_ids.add(id(m))
self._log(f" [场景无关-语义] '{msg_text[:40]}' → 删除(无情感关联)")
else:
self._log(f" [场景关联-保留] '{msg_text[:40]}' → 有情感关联,保留")
else:
# 不在 scene_unrelated_snippets 中 → 场景相关,直接保留
if self._msg_matches_tokens(m, scene_preserve_tokens):
self._log(f" [保护] '{msg_text[:40]}' → LLM保护跳过")
# else: 普通场景相关消息,保留,不输出日志
kept = [m for m in msgs if id(m) not in to_delete_ids]
if not kept and msgs:
kept = [msgs[0]]
deleted = len(msgs) - len(kept)
self._log(
f"[剪枝-{mode}] {dialog_label} 总消息={len(msgs)} "
f"删除={deleted} 保留={len(kept)}"
)
return kept
def _build_preserve_tokens(self, extraction: "DialogExtractionResponse") -> List[str]:
"""统一构建 preserve_tokens合并 LLM 抽取的所有重要片段。"""
return (
extraction.times + extraction.ids + extraction.amounts +
extraction.contacts + extraction.addresses + extraction.keywords +
extraction.preserve_keywords
)
def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool: def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool:
"""判断消息是否包含任意抽取到的重要片段。""" """判断消息是否包含任意抽取到的重要片段。"""
if not tokens: if not tokens:
@@ -653,18 +397,16 @@ class SemanticPruner:
proportion = float(self.config.pruning_threshold) proportion = float(self.config.pruning_threshold)
extraction = await self._extract_dialog_important(dialog.content) extraction = await self._extract_dialog_important(dialog.content)
pruning_mode = self._get_pruning_mode()
self._log(f"[剪枝-模式] 阈值={proportion} → 模式={pruning_mode}")
if extraction.is_related: if extraction.is_related:
kept = self._apply_related_dialog_pruning( # 相关对话不剪枝
dialog.context.msgs, extraction, f"对话ID={dialog.id}", pruning_mode
)
dialog.context = ConversationContext(msgs=kept)
return dialog return dialog
# 在不相关对话中LLM 已通过 preserve_tokens 标记需要保护的内容 # 在不相关对话中LLM 已通过 preserve_tokens 标记需要保护的内容
preserve_tokens = self._build_preserve_tokens(extraction) preserve_tokens = (
extraction.times + extraction.ids + extraction.amounts +
extraction.contacts + extraction.addresses + extraction.keywords +
extraction.preserve_keywords
)
msgs = dialog.context.msgs msgs = dialog.context.msgs
# 分类:填充 / 其他可删LLM保护消息通过不加入任何桶来隐式保护 # 分类:填充 / 其他可删LLM保护消息通过不加入任何桶来隐式保护
@@ -731,7 +473,7 @@ class SemanticPruner:
# 阈值保护最高0.9 # 阈值保护最高0.9
proportion = float(self.config.pruning_threshold) proportion = float(self.config.pruning_threshold)
if proportion > 0.9: if proportion > 0.9:
logger.warning(f"[剪枝-数据集] 阈值{proportion}超过上限0.9已自动调整为0.9") print(f"[剪枝-数据集] 阈值{proportion}超过上限0.9已自动调整为0.9")
proportion = 0.9 proportion = 0.9
if proportion < 0.0: if proportion < 0.0:
proportion = 0.0 proportion = 0.0
@@ -739,30 +481,11 @@ class SemanticPruner:
self._log( self._log(
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch} 模式=消息级独立判断" f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch} 模式=消息级独立判断"
) )
pruning_mode = self._get_pruning_mode()
self._log(f"[剪枝-数据集] 阈值={proportion} → 剪枝阶段={pruning_mode}")
result: List[DialogData] = [] result: List[DialogData] = []
total_original_msgs = 0 total_original_msgs = 0
total_deleted_msgs = 0 total_deleted_msgs = 0
# 统计对象:直接收集结构化数据,无需事后正则解析
stats = {
"scene": self.config.pruning_scene,
"dialog_total": len(dialogs),
"deletion_ratio": proportion,
"enabled": self.config.pruning_switch,
"pruning_mode": pruning_mode,
"related_count": 0,
"unrelated_count": 0,
"related_indices": [],
"unrelated_indices": [],
"total_deleted_messages": 0,
"remaining_dialogs": 0,
"dialogs": [],
}
# 并发执行所有对话的 LLM 抽取(获取 preserve_keywords 等保护信息) # 并发执行所有对话的 LLM 抽取(获取 preserve_keywords 等保护信息)
semaphore = asyncio.Semaphore(self.max_concurrent) semaphore = asyncio.Semaphore(self.max_concurrent)
@@ -782,31 +505,12 @@ class SemanticPruner:
original_count = len(msgs) original_count = len(msgs)
total_original_msgs += original_count total_original_msgs += original_count
# 相关对话:根据阶段决定处理力度
if extraction.is_related:
stats["related_count"] += 1
stats["related_indices"].append(d_idx + 1)
kept = self._apply_related_dialog_pruning(
msgs, extraction, f"对话 {d_idx+1}", pruning_mode
)
deleted_count = original_count - len(kept)
total_deleted_msgs += deleted_count
dd.context.msgs = kept
result.append(dd)
stats["dialogs"].append({
"index": d_idx + 1,
"is_related": True,
"total_messages": original_count,
"deleted": deleted_count,
"kept": len(kept),
})
continue
stats["unrelated_count"] += 1
stats["unrelated_indices"].append(d_idx + 1)
# 从 LLM 抽取结果中获取所有需要保留的 token # 从 LLM 抽取结果中获取所有需要保留的 token
preserve_tokens = self._build_preserve_tokens(extraction) preserve_tokens = (
extraction.times + extraction.ids + extraction.amounts +
extraction.contacts + extraction.addresses + extraction.keywords +
extraction.preserve_keywords # 情绪/兴趣/爱好关键词
)
# 判断是否需要详细日志 # 判断是否需要详细日志
should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog
@@ -823,12 +527,6 @@ class SemanticPruner:
for idx, m in enumerate(msgs): for idx, m in enumerate(msgs):
msg_text = m.msg.strip() msg_text = m.msg.strip()
# 最高优先级保护:带有文件的消息一律保留,不参与分类
if message_has_files(m):
self._log(f" [保护] 带文件的消息(不参与分类,直接保留):索引{idx}, '{msg_text[:40]}', 文件数={len(m.files)}")
llm_protected_msgs.append((idx, m)) # 放入保护列表
continue
if self._msg_matches_tokens(m, preserve_tokens): if self._msg_matches_tokens(m, preserve_tokens):
llm_protected_msgs.append((idx, m)) llm_protected_msgs.append((idx, m))
@@ -845,16 +543,16 @@ class SemanticPruner:
# important_msgs 仅用于日志统计 # important_msgs 仅用于日志统计
important_msgs = llm_protected_msgs important_msgs = llm_protected_msgs
# 计算删除配额 # 计算删除配额
delete_target = int(original_count * proportion) delete_target = int(original_count * proportion)
if proportion > 0 and original_count > 0 and delete_target == 0: if proportion > 0 and original_count > 0 and delete_target == 0:
delete_target = 1 delete_target = 1
# 确保至少保留1条消息 # 确保至少保留1条消息
max_deletable = max(0, original_count - 1) max_deletable = max(0, original_count - 1)
delete_target = min(delete_target, max_deletable) delete_target = min(delete_target, max_deletable)
# 删除策略:优先删填充消息,再按出现顺序删其余可删消息 # 删除策略:优先删填充消息,再按出现顺序删其余可删消息
to_delete_indices = set() to_delete_indices = set()
deleted_details = [] deleted_details = []
@@ -872,73 +570,58 @@ class SemanticPruner:
break break
to_delete_indices.add(idx) to_delete_indices.add(idx)
deleted_details.append(f"[{idx}] 可删: '{msg.msg[:50]}'") deleted_details.append(f"[{idx}] 可删: '{msg.msg[:50]}'")
# 执行删除 # 执行删除
kept_msgs = [] kept_msgs = []
for idx, m in enumerate(msgs): for idx, m in enumerate(msgs):
if idx not in to_delete_indices: if idx not in to_delete_indices:
kept_msgs.append(m) kept_msgs.append(m)
# 确保至少保留1条 # 确保至少保留1条
if not kept_msgs and msgs: if not kept_msgs and msgs:
kept_msgs = [msgs[0]] kept_msgs = [msgs[0]]
dd.context.msgs = kept_msgs dd.context.msgs = kept_msgs
deleted_count = original_count - len(kept_msgs) deleted_count = original_count - len(kept_msgs)
total_deleted_msgs += deleted_count total_deleted_msgs += deleted_count
# 输出删除详情 # 输出删除详情
if deleted_details: if deleted_details:
self._log(f"[剪枝-删除详情] 对话 {d_idx+1} 删除了以下消息:") self._log(f"[剪枝-删除详情] 对话 {d_idx+1} 删除了以下消息:")
for detail in deleted_details: for detail in deleted_details:
self._log(f" {detail}") self._log(f" {detail}")
# ========== 问答对统计(已注释) ========== # ========== 问答对统计(已注释) ==========
# qa_info = f",问答对={len(qa_pairs)}" if qa_pairs else "" # qa_info = f",问答对={len(qa_pairs)}" if qa_pairs else ""
# ======================================== # ========================================
self._log( self._log(
f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} " f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} "
f"(保护={len(important_msgs)} 填充={len(filler_msgs)} 可删={len(deletable_msgs)}) " f"(保护={len(important_msgs)} 填充={len(filler_msgs)} 可删={len(deletable_msgs)}) "
f"删除={deleted_count} 保留={len(kept_msgs)}" f"删除={deleted_count} 保留={len(kept_msgs)}"
) )
stats["dialogs"].append({
"index": d_idx + 1,
"is_related": False,
"total_messages": original_count,
"protected": len(important_msgs),
"fillers": len(filler_msgs),
"deletable": len(deletable_msgs),
"deleted": deleted_count,
"kept": len(kept_msgs),
})
result.append(dd) result.append(dd)
# 补全统计对象
stats["total_deleted_messages"] = total_deleted_msgs
stats["remaining_dialogs"] = len(result)
self._log(f"[剪枝-数据集] 剩余对话数={len(result)}") self._log(f"[剪枝-数据集] 剩余对话数={len(result)}")
self._log(f"[剪枝-数据集] 相关对话数={stats['related_count']} 不相关对话数={stats['unrelated_count']}")
self._log(f"[剪枝-数据集] 总删除 {total_deleted_msgs}")
# 直接序列化统计对象,无需正则解析 # 保存日志
try: try:
from app.core.config import settings from app.core.config import settings
settings.ensure_memory_output_dir() settings.ensure_memory_output_dir()
log_output_path = settings.get_memory_output_path("pruned_terminal.json") log_output_path = settings.get_memory_output_path("pruned_terminal.json")
sanitized_logs = [self._sanitize_log_line(l) for l in self.run_logs]
payload = self._parse_logs_to_structured(sanitized_logs)
with open(log_output_path, "w", encoding="utf-8") as f: with open(log_output_path, "w", encoding="utf-8") as f:
json.dump(stats, f, ensure_ascii=False, indent=2) json.dump(payload, f, ensure_ascii=False, indent=2)
except Exception as e: except Exception as e:
self._log(f"[剪枝-数据集] 保存终端输出日志失败:{e}") self._log(f"[剪枝-数据集] 保存终端输出日志失败:{e}")
# Safety: avoid empty dataset # Safety: avoid empty dataset
if not result: if not result:
logger.warning("语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断") print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
return dialogs return dialogs
return result return result
def _log(self, msg: str) -> None: def _log(self, msg: str) -> None:
@@ -946,7 +629,118 @@ class SemanticPruner:
try: try:
self.run_logs.append(msg) self.run_logs.append(msg)
except Exception: except Exception:
# 任何异常都不影响打印
pass pass
logger.debug(msg) print(msg)
def _sanitize_log_line(self, line: str) -> str:
"""移除行首的方括号标签前缀,例如 [剪枝-数据集] 或 [剪枝-对话]。"""
try:
return re.sub(r"^\[[^\]]+\]\s*", "", line)
except Exception:
return line
def _parse_logs_to_structured(self, logs: List[str]) -> dict:
"""将已去前缀的日志列表解析为结构化 JSON便于数据对接。"""
summary = {
"scene": self.config.pruning_scene,
"dialog_total": None,
"deletion_ratio": None,
"enabled": None,
"related_count": None,
"unrelated_count": None,
"related_indices": [],
"unrelated_indices": [],
"total_deleted_messages": None,
"remaining_dialogs": None,
}
dialogs = []
# 解析函数
def parse_int(value: str) -> Optional[int]:
try:
return int(value)
except Exception:
return None
def parse_float(value: str) -> Optional[float]:
try:
return float(value)
except Exception:
return None
def parse_indices(s: str) -> List[int]:
s = s.strip()
if not s:
return []
parts = [p.strip() for p in s.split(",") if p.strip()]
out: List[int] = []
for p in parts:
try:
out.append(int(p))
except Exception:
pass
return out
# 正则
re_header = re.compile(r"对话总数=(\d+)\s+场景=([^\s]+)\s+删除比例=([0-9.]+)\s+开关=(True|False)")
re_counts = re.compile(r"相关对话数=(\d+)\s+不相关对话数=(\d+)")
re_indices = re.compile(r"相关对话:第\[(.*?)\]段;不相关对话:第\[(.*?)\]段")
re_dialog = re.compile(r"对话\s+(\d+)\s+总消息=(\d+)\s+分配删除=(\d+)\s+实删=(\d+)\s+保留=(\d+)")
re_total_del = re.compile(r"总删除\s+(\d+)\s+条")
re_remaining = re.compile(r"剩余对话数=(\d+)")
for line in logs:
# 第一行:总览
m = re_header.search(line)
if m:
summary["dialog_total"] = parse_int(m.group(1))
# 顶层 scene 依配置,这里不覆盖,但也可校验 m.group(2)
summary["deletion_ratio"] = parse_float(m.group(3))
summary["enabled"] = True if m.group(4) == "True" else False
continue
# 第二行:相关/不相关数量
m = re_counts.search(line)
if m:
summary["related_count"] = parse_int(m.group(1))
summary["unrelated_count"] = parse_int(m.group(2))
continue
# 第三行:相关/不相关索引
m = re_indices.search(line)
if m:
summary["related_indices"] = parse_indices(m.group(1))
summary["unrelated_indices"] = parse_indices(m.group(2))
continue
# 对话级统计
m = re_dialog.search(line)
if m:
dialogs.append({
"index": parse_int(m.group(1)),
"total_messages": parse_int(m.group(2)),
"quota_delete": parse_int(m.group(3)),
"actual_deleted": parse_int(m.group(4)),
"kept": parse_int(m.group(5)),
})
continue
# 全局删除总数
m = re_total_del.search(line)
if m:
summary["total_deleted_messages"] = parse_int(m.group(1))
continue
# 剩余对话数
m = re_remaining.search(line)
if m:
summary["remaining_dialogs"] = parse_int(m.group(1))
continue
return {
"scene": summary["scene"],
"timestamp": datetime.now().isoformat(),
"summary": {k: v for k, v in summary.items() if k != "scene"},
"dialogs": dialogs,
}

View File

@@ -203,7 +203,6 @@ def accurate_match(
) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]: ) -> Tuple[List[ExtractedEntityNode], Dict[str, str], Dict[str, Dict]]:
""" """
精确匹配:按 (end_user_id, name, entity_type) 合并实体并建立重定向与合并记录。 精确匹配:按 (end_user_id, name, entity_type) 合并实体并建立重定向与合并记录。
同时检测某实体的 name 是否命中另一实体的 aliases若命中则直接合并。
返回: (deduped_entities, id_redirect, exact_merge_map) 返回: (deduped_entities, id_redirect, exact_merge_map)
""" """
exact_merge_map: Dict[str, Dict] = {} exact_merge_map: Dict[str, Dict] = {}
@@ -241,48 +240,6 @@ def accurate_match(
pass pass
deduped_entities = list(canonical_map.values()) deduped_entities = list(canonical_map.values())
# 2) 第二轮:检测某实体的 name 是否命中另一实体的 aliasesalias-to-name 精确合并)
# 场景LLM 把 aliases 中的词(如"齐齐")又单独抽取为独立实体,需在此阶段合并掉
# 优化:先构建 (end_user_id, alias_lower) -> canonical 的反向索引,查找 O(1)
alias_index: Dict[tuple, ExtractedEntityNode] = {}
for canonical in deduped_entities:
uid = getattr(canonical, "end_user_id", None)
for alias in (getattr(canonical, "aliases", []) or []):
alias_lower = alias.strip().lower()
if alias_lower:
alias_index[(uid, alias_lower)] = canonical
i = 0
while i < len(deduped_entities):
ent = deduped_entities[i]
ent_name = (getattr(ent, "name", "") or "").strip().lower()
ent_uid = getattr(ent, "end_user_id", None)
canonical = alias_index.get((ent_uid, ent_name))
# 确保不是自身
if canonical is not None and canonical.id != ent.id:
_merge_attribute(canonical, ent)
id_redirect[ent.id] = canonical.id
for k, v in list(id_redirect.items()):
if v == ent.id:
id_redirect[k] = canonical.id
try:
k = f"{canonical.end_user_id}|{(canonical.name or '').strip()}|{(canonical.entity_type or '').strip()}"
if k not in exact_merge_map:
exact_merge_map[k] = {
"canonical_id": canonical.id,
"end_user_id": canonical.end_user_id,
"name": canonical.name,
"entity_type": canonical.entity_type,
"merged_ids": set(),
}
exact_merge_map[k]["merged_ids"].add(ent.id)
except Exception:
pass
deduped_entities.pop(i)
else:
i += 1
return deduped_entities, id_redirect, exact_merge_map return deduped_entities, id_redirect, exact_merge_map
def fuzzy_match( def fuzzy_match(

View File

@@ -25,17 +25,17 @@ from app.repositories.neo4j.neo4j_connector import Neo4jConnector
async def dedup_layers_and_merge_and_return( async def dedup_layers_and_merge_and_return(
dialogue_nodes: List[DialogueNode], dialogue_nodes: List[DialogueNode],
chunk_nodes: List[ChunkNode], chunk_nodes: List[ChunkNode],
statement_nodes: List[StatementNode], statement_nodes: List[StatementNode],
entity_nodes: List[ExtractedEntityNode], entity_nodes: List[ExtractedEntityNode],
statement_chunk_edges: List[StatementChunkEdge], statement_chunk_edges: List[StatementChunkEdge],
statement_entity_edges: List[StatementEntityEdge], statement_entity_edges: List[StatementEntityEdge],
entity_entity_edges: List[EntityEntityEdge], entity_entity_edges: List[EntityEntityEdge],
dialog_data_list: List[DialogData], dialog_data_list: List[DialogData],
pipeline_config: ExtractionPipelineConfig, pipeline_config: ExtractionPipelineConfig,
connector: Optional[Neo4jConnector] = None, connector: Optional[Neo4jConnector] = None,
llm_client=None, llm_client = None,
) -> Tuple[ ) -> Tuple[
List[DialogueNode], List[DialogueNode],
List[ChunkNode], List[ChunkNode],
@@ -44,7 +44,7 @@ async def dedup_layers_and_merge_and_return(
List[StatementChunkEdge], List[StatementChunkEdge],
List[StatementEntityEdge], List[StatementEntityEdge],
List[EntityEntityEdge], List[EntityEntityEdge],
dict dict, # 新增:返回去重详情
]: ]:
""" """
执行两层实体去重与融合: 执行两层实体去重与融合:

View File

@@ -5,11 +5,8 @@
""" """
import asyncio import asyncio
import logging
from typing import Any, Dict, List, Tuple from typing import Any, Dict, List, Tuple
logger = logging.getLogger(__name__)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.models.message_models import DialogData from app.core.memory.models.message_models import DialogData
from app.core.models.base import RedBearModelConfig from app.core.models.base import RedBearModelConfig
@@ -51,9 +48,9 @@ class EmbeddingGenerator:
return await self.embedder_client.response(texts) return await self.embedder_client.response(texts)
# 分批并行处理 # 分批并行处理
logger.info(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理") print(f"文本数量 {len(texts)} 超过批次大小 {batch_size},分批并行处理")
batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)] batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)]
logger.info(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本") print(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
# 并行发送所有批次 # 并行发送所有批次
batch_results = await asyncio.gather(*[ batch_results = await asyncio.gather(*[
@@ -65,7 +62,7 @@ class EmbeddingGenerator:
for batch_result in batch_results: for batch_result in batch_results:
embeddings.extend(batch_result) embeddings.extend(batch_result)
logger.info(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量") print(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
return embeddings return embeddings
async def generate_statement_embeddings( async def generate_statement_embeddings(
@@ -80,7 +77,7 @@ class EmbeddingGenerator:
Returns: Returns:
每个对话的陈述句嵌入向量映射列表 每个对话的陈述句嵌入向量映射列表
""" """
logger.debug("=== 生成陈述句嵌入向量 ===") print("\n=== 生成陈述句嵌入向量 ===")
# 收集所有陈述句 # 收集所有陈述句
all_statements = [] all_statements = []
@@ -105,7 +102,7 @@ class EmbeddingGenerator:
stmt_id = chunked_dialogs[d_idx].chunks[c_idx].statements[s_idx].id stmt_id = chunked_dialogs[d_idx].chunks[c_idx].statements[s_idx].id
stmt_embedding_maps[d_idx][stmt_id] = embedding stmt_embedding_maps[d_idx][stmt_id] = embedding
logger.info(f"{len(all_statements)} 个陈述句生成了嵌入向量") print(f"{len(all_statements)} 个陈述句生成了嵌入向量")
return stmt_embedding_maps return stmt_embedding_maps
async def generate_chunk_embeddings( async def generate_chunk_embeddings(
@@ -120,7 +117,7 @@ class EmbeddingGenerator:
Returns: Returns:
每个对话的分块嵌入向量映射列表 每个对话的分块嵌入向量映射列表
""" """
logger.debug("=== 生成分块嵌入向量 ===") print("\n=== 生成分块嵌入向量 ===")
# 收集所有分块 # 收集所有分块
all_chunks = [] all_chunks = []
@@ -141,7 +138,7 @@ class EmbeddingGenerator:
chunk_id = chunked_dialogs[d_idx].chunks[c_idx].id chunk_id = chunked_dialogs[d_idx].chunks[c_idx].id
chunk_embedding_maps[d_idx][chunk_id] = embedding chunk_embedding_maps[d_idx][chunk_id] = embedding
logger.info(f"{len(all_chunks)} 个分块生成了嵌入向量") print(f"{len(all_chunks)} 个分块生成了嵌入向量")
return chunk_embedding_maps return chunk_embedding_maps
async def generate_dialog_embeddings( async def generate_dialog_embeddings(
@@ -175,7 +172,7 @@ class EmbeddingGenerator:
Returns: Returns:
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表) (陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表)
""" """
logger.debug("=== 生成所有嵌入向量 ===") print("\n=== 生成所有嵌入向量 ===")
# 并发生成陈述句和分块嵌入向量 # 并发生成陈述句和分块嵌入向量
stmt_embedding_maps, chunk_embedding_maps = await asyncio.gather( stmt_embedding_maps, chunk_embedding_maps = await asyncio.gather(
@@ -186,7 +183,9 @@ class EmbeddingGenerator:
# 对话嵌入向量(当前跳过) # 对话嵌入向量(当前跳过)
dialog_embeddings = await self.generate_dialog_embeddings(chunked_dialogs) dialog_embeddings = await self.generate_dialog_embeddings(chunked_dialogs)
logger.info(f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量") print(
f"生成完成:{len(chunked_dialogs)} 个对话的嵌入向量"
)
return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings
@@ -202,7 +201,7 @@ class EmbeddingGenerator:
Returns: Returns:
更新后的三元组映射列表(实体包含嵌入向量) 更新后的三元组映射列表(实体包含嵌入向量)
""" """
logger.debug("=== 生成实体嵌入向量 ===") print("\n=== 生成实体嵌入向量 ===")
entity_texts: List[str] = [] entity_texts: List[str] = []
entity_refs: List[Any] = [] entity_refs: List[Any] = []
@@ -220,7 +219,7 @@ class EmbeddingGenerator:
entity_refs.append(ent) entity_refs.append(ent)
if not entity_texts: if not entity_texts:
logger.debug("没有找到需要生成嵌入向量的实体") print("没有找到需要生成嵌入向量的实体")
return triplet_maps return triplet_maps
# 批量生成嵌入向量 # 批量生成嵌入向量
@@ -228,13 +227,13 @@ class EmbeddingGenerator:
# 打印前几个嵌入向量的维度 # 打印前几个嵌入向量的维度
for i in range(min(5, len(embeddings))): for i in range(min(5, len(embeddings))):
logger.debug(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}") print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
# 将嵌入向量赋值给实体 # 将嵌入向量赋值给实体
for ent, emb in zip(entity_refs, embeddings): for ent, emb in zip(entity_refs, embeddings):
setattr(ent, "name_embedding", emb) setattr(ent, "name_embedding", emb)
logger.info(f"{len(entity_refs)} 个实体生成了嵌入向量") print(f"{len(entity_refs)} 个实体生成了嵌入向量")
return triplet_maps return triplet_maps
@@ -297,7 +296,7 @@ async def embedding_generation_all(
Returns: Returns:
(陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表, 更新后的三元组映射列表) (陈述句嵌入映射列表, 分块嵌入映射列表, 对话嵌入列表, 更新后的三元组映射列表)
""" """
logger.debug("=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===") print("\n=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
generator = EmbeddingGenerator(embedding_id) generator = EmbeddingGenerator(embedding_id)

View File

@@ -188,6 +188,7 @@ async def _process_chunk_summary(
response_model=MemorySummaryResponse, response_model=MemorySummaryResponse,
) )
summary_text = structured.summary.strip() summary_text = structured.summary.strip()
# Generate title and type for the summary # Generate title and type for the summary
title = None title = None
episodic_type = None episodic_type = None

View File

@@ -1,7 +1,6 @@
{# {#
对话级抽取与相关性判定模板(用于剪枝加速) 对话级抽取与相关性判定模板(用于剪枝加速)
输入pruning_scene, ontology_class_infos, dialog_text, language 输入pruning_scene, ontology_classes, dialog_text, language
- ontology_class_infos: List[{class_name: str, class_description: str}]
输出:严格 JSON不要包含任何多余文本字段 输出:严格 JSON不要包含任何多余文本字段
- is_related: bool是否与所选场景相关 - is_related: bool是否与所选场景相关
- times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等) - times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等)
@@ -19,16 +18,20 @@
#} #}
{# ── 确定场景说明 ── #} {# ── 确定场景说明 ── #}
{% if ontology_class_infos and ontology_class_infos | length > 0 %} {% if ontology_classes and ontology_classes | length > 0 %}
{% if language == 'en' %} {% if language == 'en' %}
{% set instruction = 'Scene "' ~ pruning_scene ~ '": The dialogue is relevant if it involves any of the following entity types.' %} {% set custom_types_str = ontology_classes | join(', ') %}
{% set instruction = 'Scene "' ~ pruning_scene ~ '": The dialogue is related to this scene if it involves any of the following entity types: ' ~ custom_types_str ~ '.' %}
{% else %} {% else %}
{% set instruction = '场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关。' %} {% set custom_types_str = ontology_classes | join('、') %}
{% set instruction = '场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关:' ~ custom_types_str ~ '。' %}
{% endif %} {% endif %}
{% else %} {% else %}
{% if language == 'en' %} {% if language == 'en' %}
{% set custom_types_str = '' %}
{% set instruction = 'Scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %} {% set instruction = 'Scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
{% else %} {% else %}
{% set custom_types_str = '' %}
{% set instruction = '场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %} {% set instruction = '场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
{% endif %} {% endif %}
{% endif %} {% endif %}
@@ -39,17 +42,8 @@
2. 从对话中抽取所有需要保留的重要信息片段。 2. 从对话中抽取所有需要保留的重要信息片段。
场景说明:{{ instruction }} 场景说明:{{ instruction }}
{% if custom_types_str %}
{% if ontology_class_infos and ontology_class_infos | length > 0 %} 重要提示:只要对话中出现与上述实体类型({{ custom_types_str }}相关的内容即判定为相关is_related=true
【本场景实体类型定义】
以下实体类型定义了本场景中哪些内容是重要的。
凡是与以下任意类型相关的内容,都必须保留,并将关键词/短语提取到 keywords 字段:
{% for info in ontology_class_infos %}
- {{ info.class_name }}{{ info.class_description }}
{% endfor %}
重要提示只要对话中出现与上述任意实体类型相关的内容即判定为相关is_related=true
{% endif %} {% endif %}
--- ---
@@ -57,40 +51,13 @@
以下类型的内容无论是否与场景直接相关,都必须保留,请将其关键词/短语抽取到对应字段: 以下类型的内容无论是否与场景直接相关,都必须保留,请将其关键词/短语抽取到对应字段:
- 时间信息:日期、时间点、时间段、有效期 → times 字段 - 时间信息:日期、时间点、时间段、有效期 → times 字段
- 编号信息学号、工号、订单号、申请号、账号、ID → ids 字段 - 编号信息学号、工号、订单号、申请号、账号、ID → ids 字段
- 金额信息:价格、费用、金额(含货币符号或单位,如"100元"、"¥200")→ amounts 字段(注意:考试分数、成绩分数不属于金额,不要放入此字段) - 金额信息:价格、费用、金额(含货币符号或单位 → amounts 字段
- 联系方式电话、手机号、邮箱、微信、QQ → contacts 字段 - 联系方式电话、手机号、邮箱、微信、QQ → contacts 字段
- 地址信息:地点、地址、位置 → addresses 字段 - 地址信息:地点、地址、位置 → addresses 字段
- 场景关键词:与**当前场景**强相关的专业术语、事件名称 → keywords 字段(注意:只放与当前场景直接相关的词,跨场景的内容不要放入此字段) - 场景关键词:与场景强相关的专业术语、事件名称 → keywords 字段
- **情绪与情感**:喜悦、悲伤、愤怒、焦虑、开心、难过、委屈、兴奋、害怕、担心、压力、感动等情绪表达 → preserve_keywords 字段 - **情绪与情感**:喜悦、悲伤、愤怒、焦虑、开心、难过、委屈、兴奋、害怕、担心、压力、感动等情绪表达 → preserve_keywords 字段
- **兴趣与爱好**:喜欢、热爱、爱好、擅长、享受、沉迷、着迷、讨厌某事物等个人偏好表达 → preserve_keywords 字段 - **兴趣与爱好**:喜欢、热爱、爱好、擅长、享受、沉迷、着迷、讨厌某事物等个人偏好表达 → preserve_keywords 字段
- **个人情感态度**:对人际关系、情感状态的明确表达(如"我跟室友闹矛盾了"、"我都快抑郁了"→ preserve_keywords 字段 - **个人观点与态度**:对某事物的明确看法、评价、立场 → preserve_keywords 字段
- 注意:学业目标(如"我想考研")、成绩(如"87分")、学科偏好(如"喜欢数学")属于学业信息,不属于情绪/情感,不要放入 preserve_keywords 字段
【场景无关内容标记】
请从对话中识别出与当前场景({{ pruning_scene }}**既不相关、也无语义关联**的消息片段,将其原文(或关键片段)提取到 scene_unrelated_snippets 字段。
判断标准:
- 与场景实体类型完全无关
- 与场景话题没有因果/时间/情境上的关联(例如:不是"因为上课所以累"这种关联)
- 纯粹是另一个话题的内容(如在教育场景中讨论购物、娱乐等)
注意:有情绪/感受表达的消息即使话题不同,也可能有语义关联,请谨慎标记。
**重要scene_unrelated_snippets 必须认真填写,不能为空数组。**
如果对话中存在与场景无关的内容,必须将其原文片段提取出来。
示例(场景=在线教育):
- "我最近心情很差,跟室友闹矛盾了" → 与教育场景无关,加入 scene_unrelated_snippets
- "她总是很晚回来吵到我睡觉" → 与教育场景无关,加入 scene_unrelated_snippets
- "对,我都快抑郁了" → 与教育场景无关,加入 scene_unrelated_snippets
- "期末考试12月25日" → 与教育场景相关,不加入 scene_unrelated_snippets
- "我上次高数作业87分" → 与教育场景相关,不加入 scene_unrelated_snippets
- "我的目标是考研" → 与教育场景相关,不加入 scene_unrelated_snippets
示例(场景=情感陪伴):
- "我最近心情很差,跟室友闹矛盾了" → 与情感陪伴场景相关(情绪+关系),不加入 scene_unrelated_snippets
- "对,我都快抑郁了" → 与情感陪伴场景相关(情绪),不加入 scene_unrelated_snippets
- "期末考试12月25日3号教学楼201室" → 与情感陪伴场景无关(教育信息),加入 scene_unrelated_snippets
- "我上次高数作业87分这次能考好吗" → 与情感陪伴场景无关(学业信息),加入 scene_unrelated_snippets
- "我的目标是考研,想读应用数学" → 与情感陪伴场景无关(学业目标),加入 scene_unrelated_snippets
【可以删除的内容】 【可以删除的内容】
以下类型的内容属于低价值信息,可以在剪枝时删除: 以下类型的内容属于低价值信息,可以在剪枝时删除:
@@ -121,8 +88,7 @@
"contacts": [<string>...], "contacts": [<string>...],
"addresses": [<string>...], "addresses": [<string>...],
"keywords": [<string>...], "keywords": [<string>...],
"preserve_keywords": [<string>...], "preserve_keywords": [<string>...]
"scene_unrelated_snippets": [<string>...]
} }
{% else %} {% else %}
You are a dialogue content analysis assistant. Please analyze the full dialogue below in one pass and complete two tasks: You are a dialogue content analysis assistant. Please analyze the full dialogue below in one pass and complete two tasks:
@@ -130,17 +96,8 @@ You are a dialogue content analysis assistant. Please analyze the full dialogue
2. Extract all important information fragments that must be preserved. 2. Extract all important information fragments that must be preserved.
Scenario Description: {{ instruction }} Scenario Description: {{ instruction }}
{% if custom_types_str %}
{% if ontology_class_infos and ontology_class_infos | length > 0 %} Important: If the dialogue contains content related to any of the entity types above ({{ custom_types_str }}), mark it as relevant (is_related=true).
[Scene Entity Type Definitions]
The following entity types define what content is important in this scene.
Content related to ANY of these types must be preserved and extracted into the keywords field:
{% for info in ontology_class_infos %}
- {{ info.class_name }}: {{ info.class_description }}
{% endfor %}
Important: If the dialogue contains content related to any of the entity types above, mark it as relevant (is_related=true).
{% endif %} {% endif %}
--- ---
@@ -148,22 +105,13 @@ Important: If the dialogue contains content related to any of the entity types a
The following types of content must always be preserved regardless of scene relevance. Extract their keywords/phrases into the corresponding fields: The following types of content must always be preserved regardless of scene relevance. Extract their keywords/phrases into the corresponding fields:
- Time information: dates, time points, durations, expiry dates → times field - Time information: dates, time points, durations, expiry dates → times field
- ID information: student IDs, employee IDs, order numbers, application numbers, account IDs → ids field - ID information: student IDs, employee IDs, order numbers, application numbers, account IDs → ids field
- Amount information: prices, fees, amounts (with currency symbols or units, e.g., "$100", "¥200") → amounts field (Note: exam scores and grades are NOT amounts, do not put them here) - Amount information: prices, fees, amounts (with currency symbols or units) → amounts field
- Contact information: phone numbers, emails, WeChat, QQ → contacts field - Contact information: phone numbers, emails, WeChat, QQ → contacts field
- Address information: locations, addresses, places → addresses field - Address information: locations, addresses, places → addresses field
- Scene keywords: professional terms and event names strongly related to **the current scene** → keywords field (Note: only put terms directly related to the current scene; cross-scene content should not be placed here) - Scene keywords: professional terms and event names strongly related to the scene → keywords field
- **Emotions and feelings**: joy, sadness, anger, anxiety, happiness, sadness, excitement, fear, worry, stress, being moved, etc. → preserve_keywords field - **Emotions and feelings**: joy, sadness, anger, anxiety, happiness, sadness, excitement, fear, worry, stress, being moved, etc. → preserve_keywords field
- **Interests and hobbies**: likes, loves, hobbies, good at, enjoys, obsessed with, hates something, personal preferences → preserve_keywords field - **Interests and hobbies**: likes, loves, hobbies, good at, enjoys, obsessed with, hates something, personal preferences → preserve_keywords field
- **Personal emotional attitudes**: clear expressions about interpersonal relationships or emotional states (e.g., "I had a fight with my roommate", "I'm almost depressed") → preserve_keywords field - **Personal opinions and attitudes**: clear views, evaluations, or stances on something → preserve_keywords field
- Note: Academic goals (e.g., "I want to pursue a master's degree"), grades (e.g., "87 points"), and subject preferences (e.g., "I like math") are academic information, NOT emotions/feelings — do not put them in preserve_keywords
[Scene-Unrelated Content Marking]
Please identify message snippets in the dialogue that are **neither relevant to nor semantically associated with** the current scene ({{ pruning_scene }}), and extract their original text (or key fragments) into the scene_unrelated_snippets field.
Criteria:
- Completely unrelated to the scene's entity types
- No causal/temporal/contextual association with the scene topic (e.g., "feeling tired because of class" IS associated)
- Purely belongs to a different topic (e.g., discussing shopping or entertainment in an education scene)
Note: Messages with emotional/feeling expressions may still have semantic association even if the topic differs — mark carefully.
[CAN BE DELETED] [CAN BE DELETED]
The following types of content are low-value and can be removed during pruning: The following types of content are low-value and can be removed during pruning:
@@ -193,7 +141,6 @@ Output strict JSON only (fixed keys, order doesn't matter):
"contacts": [<string>...], "contacts": [<string>...],
"addresses": [<string>...], "addresses": [<string>...],
"keywords": [<string>...], "keywords": [<string>...],
"preserve_keywords": [<string>...], "preserve_keywords": [<string>...]
"scene_unrelated_snippets": [<string>...]
} }
{% endif %} {% endif %}

View File

@@ -5,15 +5,6 @@
===Task=== ===Task===
Extract entities and knowledge triplets from the given statement. Extract entities and knowledge triplets from the given statement.
**⚠️ CRITICAL REQUIREMENTS:**
1. **ALIASES ORDER IS CRITICAL**: The FIRST alias in the array will be used as the user's primary display name (other_name). You MUST put the most important/frequently used name FIRST.
2. **ALWAYS include aliases field**: Even if empty, you MUST include "aliases": [] in EVERY entity.
<!-- TODO: v0.2.10 - denied_aliases 功能暂时禁用,将通过 Cypher 查询实现
2. **DENIED_ALIASES**: When user explicitly denies a name (e.g., "我不叫X", "I'm not called X"), you MUST put X in denied_aliases field, NOT in aliases.
3. **ALWAYS include both fields**: Even if empty, you MUST include "aliases": [] and "denied_aliases": [] in EVERY entity.
-->
{% if language == "zh" %} {% if language == "zh" %}
**重要请使用中文生成实体名称name、描述description和示例example。** **重要请使用中文生成实体名称name、描述description和示例example。**
{% else %} {% else %}
@@ -27,29 +18,34 @@ Extract entities and knowledge triplets from the given statement.
{% if ontology_types %} {% if ontology_types %}
===Ontology Type Guidance=== ===Ontology Type Guidance===
**CRITICAL: Use ONLY predefined type names below. If no exact match, use CLOSEST type. NEVER invent new types.** **CRITICAL RULE: You MUST ONLY use the predefined ontology type names listed below for the entity "type" field. Do NOT use any other type names, even if they seem reasonable.**
**Type Priority:** **If no predefined type fits an entity, use the CLOSEST matching predefined type. NEVER invent new type names.**
1. [场景类型] Scene Types (domain-specific, prefer first)
2. [通用类型] General Types (standard ontologies)
3. [通用父类] Parent Types (hierarchy context)
**Rules:** **Type Priority (from highest to lowest):**
- Type MUST exactly match predefined names 1. **[场景类型] Scene Types** - Domain-specific types, ALWAYS prefer these first
- Do NOT modify, translate, or abbreviate type names 2. **[通用类型] General Types** - Common types from standard ontologies (DBpedia)
- Prefer scene types over general types 3. **[通用父类] Parent Types** - Provide type hierarchy context
**Predefined Types:** **Type Matching Rules:**
- Entity type MUST exactly match one of the predefined type names below
- Do NOT use types like "Equipment", "Component", "Concept", "Action", "Condition", "Data", "Duration" unless they appear in the predefined list
- Do NOT modify, translate, abbreviate, or create variations of type names
- Prefer scene types (marked [场景类型]) over general types when both could apply
- If uncertain, check the type description to find the best match
**Predefined Ontology Types:**
{{ ontology_types }} {{ ontology_types }}
{% if type_hierarchy_hints %} {% if type_hierarchy_hints %}
**Hierarchy:** **Type Hierarchy Reference:**
The following shows type inheritance relationships (Child → Parent → Grandparent):
{% for hint in type_hierarchy_hints %} {% for hint in type_hierarchy_hints %}
- {{ hint }} - {{ hint }}
{% endfor %} {% endfor %}
{% endif %} {% endif %}
**ALLOWED Names:** **ALLOWED Type Names (use EXACTLY one of these, no exceptions):**
{{ ontology_type_names | join(', ') }} {{ ontology_type_names | join(', ') }}
{% endif %} {% endif %}
@@ -66,94 +62,66 @@ Extract entities and knowledge triplets from the given statement.
- **Entity descriptions must be in English** - **Entity descriptions must be in English**
- **Examples must be in English** - **Examples must be in English**
{% endif %} {% endif %}
- **Semantic Memory (is_explicit_memory):** - **Semantic Memory Classification (is_explicit_memory):**
* `true` for: Concepts, Knowledge, Definitions, Theories, Methods (e.g., "Machine Learning", "REST API") * Set to `true` if the entity represents **explicit/semantic memory**:
* `false` for: People, Organizations, Locations, Events, Specific objects - **Concepts:** "Machine Learning", "Photosynthesis", "Democracy"
* For `is_explicit_memory=true`, provide concise example (~20 chars{% if language == "zh" %},使用中文{% endif %}) - **Knowledge:** "Python Programming Language", "Theory of Relativity"
- **Definitions:** "API (Application Programming Interface)", "REST API"
**🚨🚨🚨 ALIASES & DENIED_ALIASES - MANDATORY FIELDS 🚨🚨🚨** - **Principles:** "SOLID Principles", "First Law of Thermodynamics"
- **Theories:** "Evolution Theory", "Quantum Mechanics"
**CRITICAL RULES (违反将导致提取失败):** - **Methods/Techniques:** "Agile Development", "Machine Learning Algorithm"
- **Technical Terms:** "Neural Network", "Database"
1. **EVERY entity MUST have aliases field:** * Set to `false` for:
- `"aliases": [...]` - REQUIRED, even if empty `[]` - **People:** "John Smith", "Dr. Wang"
- **Organizations:** "Microsoft", "Harvard University"
2. **ALIASES - 别名提取规则:** - **Locations:** "Beijing", "Central Park"
- **Events:** "2024 Conference", "Project Meeting"
- **Specific objects:** "iPhone 15", "Building A"
- **Example Generation (IMPORTANT for semantic memory entities):**
* For entities where `is_explicit_memory=true`, generate a **concise example (around 20 characters)** to help understand the concept
* The example should be:
- **Specific and concrete**: Use real-world scenarios or applications
- **Brief**: Around 20 characters (can be slightly longer if needed for clarity)
{% if language == "zh" %} {% if language == "zh" %}
- 包含:昵称、全名、简称、别称、网名等 - **使用中文**
- 顺序:**第一个别名将作为用户的主显示名称other_name必须把最重要/最常用的名字放在第一位**
- 提取顺序:严格按照对话中首次出现的顺序
- 示例:
* "我叫张三,大家叫我小张" → aliases=["张三", "小张"](张三是第一个,将成为 other_name
* "大家叫我小李,我全名叫李明" → aliases=["小李", "李明"](小李先出现,将成为 other_name
- 空值:如果没有别名,使用 `[]`
- 重要:只提取本次对话中明确提到的别名,不要推测或添加未提及的名字
{% else %} {% else %}
- Include: nicknames, full names, abbreviations, alternative names - **In English**
- Order: **The FIRST alias will be used as the user's primary display name (other_name). Put the most important/frequently used name FIRST**
- Extraction order: Strictly follow the order of first appearance in conversation
- Examples:
* "I'm John, people call me Johnny" → aliases=["John", "Johnny"] (John is first, will become other_name)
* "People call me Mike, my full name is Michael" → aliases=["Mike", "Michael"] (Mike appears first, will become other_name)
- Empty: If no aliases, use `[]`
- Important: Only extract aliases explicitly mentioned in current conversation, do not infer or add unmentioned names
{% endif %} {% endif %}
* For non-semantic entities (`is_explicit_memory=false`), the example field can be empty
- **Aliases Extraction:**
3. **USER ENTITY SPECIAL HANDLING:**
{% if language == "zh" %} {% if language == "zh" %}
- 用户实体的 name 字段:使用 "用户" 或 "我" * 别名使用中文
- 用户的真实姓名:放入 aliases
- **🚨 禁止将 "用户"、"我" 放入 aliases 中aliases 只能包含用户的真实姓名、昵称等**
- 示例:
* "我叫李明" → name="用户", aliases=["李明"]
* ❌ 错误aliases=["用户", "李明"]"用户"不是真实姓名,禁止放入 aliases
* ❌ 错误aliases=["我", "李明"]"我"不是真实姓名,禁止放入 aliases
{% else %} {% else %}
- User entity name field: use "User" or "I" * Aliases should be in English
- User's real name: put in aliases
- **🚨 NEVER put "User" or "I" in aliases. Aliases must only contain real names, nicknames, etc.**
- Examples:
* "I'm John" → name="User", aliases=["John"]
* ❌ Wrong: aliases=["User", "John"] ("User" is not a real name, FORBIDDEN in aliases)
* ❌ Wrong: aliases=["I", "John"] ("I" is not a real name, FORBIDDEN in aliases)
{% endif %} {% endif %}
* Include common alternative names, abbreviations and full names
* If no aliases exist, use empty array: []
- Exclude lengthy quotes, calendar dates, temporal ranges, and temporal expressions
4. **ALIASES ORDER:** - For numeric values: extract as separate entities (instance_of: 'Numeric', name: units, numeric_value: value)
{% if language == "zh" %} Example: £30 → name: 'GBP', numeric_value: 30, instance_of: 'Numeric'
- 顺序优先级:按出现顺序,先出现的在前
{% else %}
- Order priority: by appearance order, first mentioned comes first
{% endif %}
**EXAMPLES OF CORRECT EXTRACTION:**
{% if language == "zh" %}
- "我叫张三" → aliases=["张三"] (张三将成为 other_name
- "大家叫我小明,我全名叫李明" → aliases=["小明", "李明"] (小明先出现,将成为 other_name
- "我是李华,网名叫华仔" → aliases=["李华", "华仔"] (李华先出现,将成为 other_name
{% else %}
- "I'm John" → aliases=["John"] (John will become other_name)
- "People call me Mike, my full name is Michael" → aliases=["Mike", "Michael"] (Mike appears first, will become other_name)
- "I'm John Smith, username JSmith" → aliases=["John Smith", "JSmith"] (John Smith appears first, will become other_name)
{% endif %}
- Exclude lengthy quotes, dates, temporal expressions
- Numeric values: extract as entities (instance_of: 'Numeric', name: units, numeric_value: value)
**Triplet Extraction:** **Triplet Extraction:**
- Extract (subject, predicate, object) where subject/object are entities, predicate is relationship - Extract (subject, predicate, object) triplets where:
- Subject: main entity performing the action or being described
- Predicate: relationship between entities (e.g., 'is', 'works at', 'believes')
- Object: entity, value, or concept affected by the predicate
{% if language == "zh" %} {% if language == "zh" %}
- subject_name 和 object_name 使用中文 - subject_name 和 object_name 必须使用中文
{% else %} {% else %}
- subject_name and object_name in English - subject_name and object_name must be in English (translate if original is in another language)
{% endif %} {% endif %}
- Use ONLY predicates from "Predicate Instructions" (uppercase tokens) - Exclude all temporal expressions from every field
- Exclude temporal expressions, do NOT include `statement_id` - Use ONLY the predicates listed in "Predicate Instructions" (uppercase English tokens)
- **When NOT to extract:** emotions, fillers, no clear predicate, standalone nouns - Do NOT translate predicate tokens
- **If no valid triplet:** Return triplets: [] - Do NOT include `statement_id` field (assigned automatically)
**When NOT to extract triplets:**
- Non-propositional utterances (emotions, fillers, onomatopoeia)
- No clear predicate from the given definitions applies
- Standalone noun phrases or checklist items → extract as entities only
- Do NOT invent generic predicates (e.g., "IS_DOING", "FEELS", "MENTIONS")
**If no valid triplet exists:** Return triplets: [], extract entities if present, otherwise both arrays empty.
{%- if predicate_instructions -%} {%- if predicate_instructions -%}
**Predicate Instructions:** **Predicate Instructions:**
@@ -239,44 +207,26 @@ Output:
{"entity_idx": 0, "name": "三脚架", "type": "Equipment", "description": "摄影器材配件", "example": "", "aliases": ["相机三脚架"], "is_explicit_memory": false} {"entity_idx": 0, "name": "三脚架", "type": "Equipment", "description": "摄影器材配件", "example": "", "aliases": ["相机三脚架"], "is_explicit_memory": false}
] ]
} }
**Example 4 (别名 - Chinese):** "我的名字是乐力齐,我的小名是齐齐,同事们都叫我小乐"
Output:
{
"triplets": [],
"entities": [
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["乐力齐", "齐齐", "小乐"], "is_explicit_memory": false}
]
}
**Example 5 (别名顺序 - Chinese):** "我叫陈思远。对了,我的网名叫「远山」"
Output:
{
"triplets": [],
"entities": [
{"entity_idx": 0, "name": "用户", "type": "Person", "description": "用户本人", "example": "", "aliases": ["陈思远", "远山"], "is_explicit_memory": false}
]
}
{% endif %} {% endif %}
===End of Examples=== ===End of Examples===
{% if ontology_types %} {% if ontology_types %}
**⚠️ REMINDER: Examples use generic types for illustration. You MUST use predefined types from "ALLOWED Names" above.** **⚠️ REMINDER: The examples above use generic type names for illustration only. You MUST use ONLY the predefined ontology type names from the "ALLOWED Type Names" list above. For example, use "PredictiveMaintenance" instead of "Concept", use "ProductionLine" instead of "Equipment", etc. Map each entity to the closest matching predefined type.**
{% endif %} {% endif %}
===Output Format=== ===Output Format===
**JSON Requirements:** **JSON Requirements:**
- Use ASCII double quotes ("), escape with \" - Use only ASCII double quotes (") for JSON structure
- No Chinese quotes (""), no line breaks in strings - Never use Chinese quotation marks ("") or Unicode quotes
- Escape quotation marks in text with backslashes (\")
- Ensure proper string closure and comma separation
- No line breaks within JSON string values
{% if language == "zh" %} {% if language == "zh" %}
- **语言name、descriptionexample、subject_name、object_name 使用中文** - **语言要求实体名称name、描述description)、示例(example、subject_name、object_name 必须使用中文**
{% else %} {% else %}
- **Language: names, descriptions, examples in English (translate if needed)** - **Language Requirement: Entity names, descriptions, examples, subject_name, object_name must be in English**
- **If the original text is in Chinese, translate all names to English**
{% endif %} {% endif %}
- **⚠️ ALIASES ORDER: preserve temporal order of appearance**
- **🚨 MANDATORY FIELD: EVERY entity MUST include "aliases" field, even if empty array []**
{{ json_schema }} {{ json_schema }}

View File

@@ -2,7 +2,6 @@ from .base import RedBearModelConfig, get_provider_llm_class, RedBearModelFacto
from .llm import RedBearLLM from .llm import RedBearLLM
from .embedding import RedBearEmbeddings from .embedding import RedBearEmbeddings
from .rerank import RedBearRerank from .rerank import RedBearRerank
from .generation import RedBearImageGenerator, RedBearVideoGenerator
__all__ = [ __all__ = [
"RedBearModelConfig", "RedBearModelConfig",
@@ -10,7 +9,5 @@ __all__ = [
"RedBearEmbeddings", "RedBearEmbeddings",
"RedBearRerank", "RedBearRerank",
"RedBearModelFactory", "RedBearModelFactory",
"get_provider_llm_class", "get_provider_llm_class"
"RedBearImageGenerator",
"RedBearVideoGenerator"
] ]

View File

@@ -67,7 +67,7 @@ class RedBearModelFactory:
**config.extra_params **config.extra_params
} }
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]: if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]:
# 使用 httpx.Timeout 对象来设置详细的超时配置 # 使用 httpx.Timeout 对象来设置详细的超时配置
# 这样可以分别控制连接超时和读取超时 # 这样可以分别控制连接超时和读取超时
import httpx import httpx
@@ -160,13 +160,11 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
# dashscope 的 omni 模型使用 OpenAI 兼容模式 # dashscope 的 omni 模型使用 OpenAI 兼容模式
if provider == ModelProvider.DASHSCOPE and config.is_omni: if provider == ModelProvider.DASHSCOPE and config.is_omni:
return ChatOpenAI return ChatOpenAI
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.VOLCANO]: if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
if type == ModelType.LLM: if type == ModelType.LLM:
return OpenAI return OpenAI
elif type == ModelType.CHAT: elif type == ModelType.CHAT:
return ChatOpenAI return ChatOpenAI
else:
raise BusinessException(f"不支持的模型提供商及类型: {provider}-{type}", code=BizCode.PROVIDER_NOT_SUPPORTED)
elif provider == ModelProvider.DASHSCOPE: elif provider == ModelProvider.DASHSCOPE:
return ChatTongyi return ChatTongyi
elif provider == ModelProvider.OLLAMA: elif provider == ModelProvider.OLLAMA:

View File

@@ -1,190 +1,23 @@
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, TypeVar, Callable
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from app.core.models.base import RedBearModelConfig, get_provider_embedding_class, RedBearModelFactory from app.core.models.base import RedBearModelConfig,get_provider_embedding_class,RedBearModelFactory
from app.models.models_model import ModelProvider
class RedBearEmbeddings(Embeddings): class RedBearEmbeddings(Embeddings):
"""统一的 Embedding 类,自动支持多模态(根据 provider 判断)""" """Embedding → 完全符合 LangChain Embeddings"""
def __init__(self, config: RedBearModelConfig): def __init__(self, config: RedBearModelConfig):
self._model = self._create_model(config)
self._config = config self._config = config
self._is_volcano = config.provider.lower() == ModelProvider.VOLCANO
if self._is_volcano:
# 火山引擎使用 Ark SDK
self._client = self._create_volcano_client(config)
self._model = None
else:
# 其他 provider 使用 LangChain
self._model = self._create_model(config)
self._client = None
def _create_model(self, config: RedBearModelConfig) -> Embeddings: def _create_model(self, config: RedBearModelConfig) -> Embeddings:
"""根据配置创建 LangChain 模型""" """根据配置创建模型"""
embedding_class = get_provider_embedding_class(config.provider) embedding_class = get_provider_embedding_class(config.provider)
model_params = RedBearModelFactory.get_model_params(config) model_params = RedBearModelFactory.get_model_params(config)
return embedding_class(**model_params) return embedding_class(**model_params)
def _create_volcano_client(self, config: RedBearModelConfig):
"""创建火山引擎客户端"""
from volcenginesdkarkruntime import Ark
return Ark(api_key=config.api_key, base_url=config.base_url)
# ==================== LangChain 标准接口 ====================
def embed_documents(self, texts: list[str]) -> list[list[float]]: def embed_documents(self, texts: list[str]) -> list[list[float]]:
"""批量文本向量化LangChain 标准接口)""" return self._model.embed_documents(texts)
if self._is_volcano:
# 火山引擎多模态 Embedding
contents = [{"type": "text", "text": text} for text in texts]
response = self._client.multimodal_embeddings.create(
model=self._config.model_name,
input=contents,
encoding_format="float"
)
return [response.data.embedding]
else:
# 其他 provider
return self._model.embed_documents(texts)
def embed_query(self, text: str) -> List[float]: def embed_query(self, text: str) -> List[float]:
"""单个文本向量化LangChain 标准接口)""" return self._model.embed_query(text)
if self._is_volcano:
# 火山引擎多模态 Embedding
result = self.embed_documents([text])
return result[0] if result else []
else:
# 其他 provider
return self._model.embed_query(text)
# ==================== 多模态扩展方法 ====================
def embed_multimodal(
self,
contents: List[Dict[str, Any]],
**kwargs
) -> List[List[float]]:
"""
多模态向量化(仅火山引擎支持)
Args:
contents: 内容列表,格式:
- 文本: {"type": "text", "text": "..."}
- 图片: {"type": "image_url", "image_url": {"url": "..."}}
- 视频: {"type": "video_url", "video_url": {"url": "..."}}
**kwargs: 其他参数
Returns:
向量列表
"""
if not self._is_volcano:
raise NotImplementedError(
f"多模态 Embedding 仅支持火山引擎,当前 provider: {self._config.provider}"
)
response = self._client.multimodal_embeddings.create(
model=self._config.model_name,
input=contents,
**kwargs
)
return [response.data.embedding]
async def aembed_multimodal(
self,
contents: List[Dict[str, Any]],
**kwargs
) -> List[List[float]]:
"""异步多模态向量化"""
# 火山引擎 SDK 暂不支持异步,使用同步方法
return self.embed_multimodal(contents, **kwargs)
def embed_text(self, text: str, **kwargs) -> List[float]:
"""文本向量化(便捷方法)"""
if self._is_volcano:
result = self.embed_multimodal(
[{"type": "text", "text": text}],
**kwargs
)
return result[0] if result else []
else:
return self.embed_query(text)
def embed_image(self, image_url: str, **kwargs) -> List[float]:
"""图片向量化(仅火山引擎支持)"""
if not self._is_volcano:
raise NotImplementedError(
f"图片向量化仅支持火山引擎,当前 provider: {self._config.provider}"
)
result = self.embed_multimodal(
[{"type": "image_url", "image_url": {"url": image_url}}],
**kwargs
)
return result[0] if result else []
def embed_video(self, video_url: str, **kwargs) -> List[float]:
"""视频向量化(仅火山引擎支持)"""
if not self._is_volcano:
raise NotImplementedError(
f"视频向量化仅支持火山引擎,当前 provider: {self._config.provider}"
)
result = self.embed_multimodal(
[{"type": "video_url", "video_url": {"url": video_url}}],
**kwargs
)
return result[0] if result else []
def embed_batch(
self,
items: List[Union[str, Dict[str, Any]]],
**kwargs
) -> List[List[float]]:
"""
批量向量化(支持混合类型)
Args:
items: 可以是字符串列表或内容字典列表
**kwargs: 其他参数
Returns:
向量列表
"""
# 如果全是字符串,使用标准方法
if all(isinstance(item, str) for item in items):
return self.embed_documents(items)
# 如果包含字典,需要多模态支持
if not self._is_volcano:
raise NotImplementedError(
f"混合类型批量向量化仅支持火山引擎,当前 provider: {self._config.provider}"
)
# 标准化输入格式
contents = []
for item in items:
if isinstance(item, str):
contents.append({"type": "text", "text": item})
elif isinstance(item, dict):
contents.append(item)
else:
raise ValueError(f"不支持的输入类型: {type(item)}")
return self.embed_multimodal(contents, **kwargs)
# ==================== 工具方法 ====================
def is_multimodal_supported(self) -> bool:
"""检查是否支持多模态"""
return self._is_volcano
def get_provider(self) -> str:
"""获取 provider"""
return self._config.provider
# 保留 RedBearMultimodalEmbeddings 作为别名,向后兼容
RedBearMultimodalEmbeddings = RedBearEmbeddings

View File

@@ -1,344 +0,0 @@
"""
图片和视频生成模型封装
支持的 Provider:
- Volcano (火山引擎): 使用 volcenginesdkarkruntime
- OpenAI: 使用 openai SDK
"""
from typing import Any, Dict, Optional
from volcenginesdkarkruntime import Ark
from volcenginesdkarkruntime.types.images.images import (
SequentialImageGenerationOptions,
ContentGenerationTool,
OptimizePromptOptions
)
from app.core.models.base import RedBearModelConfig
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.models.models_model import ModelProvider
class RedBearImageGenerator:
"""图片生成模型封装"""
def __init__(self, config: RedBearModelConfig):
self._config = config
self._client = self._create_client(config)
def _create_client(self, config: RedBearModelConfig):
"""根据 provider 创建客户端"""
provider = config.provider.lower()
if provider == ModelProvider.VOLCANO:
return Ark(api_key=config.api_key, base_url=config.base_url)
# elif provider == ModelProvider.OPENAI:
# from openai import OpenAI
# return OpenAI(api_key=config.api_key, base_url=config.base_url)
else:
raise BusinessException(
f"不支持的图片生成提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
def generate(
self,
prompt: str,
image: Optional[Any] = None,
size: Optional[str] = "2K",
output_format: str = "png",
response_format: str = "url",
watermark: bool = False,
sequential_image_generation: Optional[str] = None,
sequential_image_generation_options: Optional[Dict] = None,
tools: Optional[list] = None,
optimize_prompt_options: Optional[Dict] = None,
stream: bool = False,
**kwargs
) -> Dict[str, Any]:
"""
生成图片
Args:
prompt: 提示词
image: 参考图片URL或URL列表图文生图/多图融合)
size: 图片尺寸,支持 "2K", "2048x2048", "1920x1080"至少3686400像素
output_format: 输出格式,如 "png", "jpg"
response_format: 返回格式,"url""b64_json"
watermark: 是否添加水印
sequential_image_generation: 组图生成模式,"auto""disabled"
sequential_image_generation_options: 组图生成选项,如 {"max_images": 4}
tools: 工具列表,如 [{"type": "web_search"}] 用于联网搜索生图
optimize_prompt_options: 提示词优化选项,如 {"mode": "fast"}
stream: 是否使用流式生成
**kwargs: 其他参数
Returns:
生成结果
"""
provider = self._config.provider.lower()
if provider == ModelProvider.VOLCANO:
params = {
"model": self._config.model_name,
"prompt": prompt,
"size": size,
"output_format": output_format,
"response_format": response_format,
"watermark": watermark,
}
if image is not None:
params["image"] = image
if sequential_image_generation:
params["sequential_image_generation"] = sequential_image_generation
if sequential_image_generation_options:
params["sequential_image_generation_options"] = SequentialImageGenerationOptions(
**sequential_image_generation_options
)
if tools:
params["tools"] = [ContentGenerationTool(**tool) if isinstance(tool, dict) else tool for tool in tools]
if optimize_prompt_options:
params["optimize_prompt_options"] = OptimizePromptOptions(**optimize_prompt_options)
if stream:
params["stream"] = True
params.update(kwargs)
response = self._client.images.generate(**params)
# elif provider == ModelProvider.OPENAI:
# response = self._client.images.generate(
# model=self._config.model_name,
# prompt=prompt,
# size=size,
# n=n,
# **kwargs
# )
else:
raise BusinessException(
f"不支持的提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
return response.model_dump() if hasattr(response, 'model_dump') else response
async def agenerate(
self,
prompt: str,
image: Optional[Any] = None,
size: Optional[str] = "2K",
output_format: str = "png",
response_format: str = "url",
watermark: bool = False,
**kwargs
) -> Dict[str, Any]:
"""异步生成图片"""
return self.generate(prompt, image, size, output_format, response_format, watermark, **kwargs)
class RedBearVideoGenerator:
"""视频生成模型封装"""
def __init__(self, config: RedBearModelConfig):
self._config = config
self._client = self._create_client(config)
def _create_client(self, config: RedBearModelConfig):
"""根据 provider 创建客户端"""
provider = config.provider.lower()
if provider == ModelProvider.VOLCANO:
return Ark(api_key=config.api_key, base_url=config.base_url)
else:
raise BusinessException(
f"不支持的视频生成提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
def generate(
self,
prompt: str,
image_url: Optional[str] = None,
first_frame_url: Optional[str] = None,
last_frame_url: Optional[str] = None,
reference_images: Optional[list] = None,
draft_task_id: Optional[str] = None,
duration: Optional[int] = None,
frames: Optional[int] = None,
ratio: Optional[str] = None,
resolution: Optional[str] = None,
generate_audio: bool = False,
watermark: bool = False,
camera_fixed: bool = False,
seed: Optional[int] = None,
return_last_frame: bool = False,
service_tier: str = "default",
execution_expires_after: Optional[int] = None,
draft: bool = False,
**kwargs
) -> Dict[str, Any]:
"""
生成视频
Args:
prompt: 提示词
image_url: 首帧图片URL图生视频-基于首帧)
first_frame_url: 首帧图片URL图生视频-基于首尾帧)
last_frame_url: 尾帧图片URL图生视频-基于首尾帧)
reference_images: 参考图片URL列表图生视频-基于参考图)
draft_task_id: Draft任务ID基于Draft生成正式视频
duration: 视频时长与frames二选一
frames: 视频帧数与duration二选一
ratio: 视频比例,如 "16:9", "9:16", "adaptive"
resolution: 视频分辨率,如 "720p", "1080p"
generate_audio: 是否生成音频
watermark: 是否添加水印
camera_fixed: 是否固定镜头
seed: 随机种子
return_last_frame: 是否返回最后一帧
service_tier: 服务层级,"default""flex"(离线推理)
execution_expires_after: 任务过期时间(秒)
draft: 是否生成样片
**kwargs: 其他参数
Returns:
生成结果包含任务ID需要轮询获取结果
"""
provider = self._config.provider.lower()
if provider == ModelProvider.VOLCANO:
content = [{"type": "text", "text": prompt}]
if draft_task_id:
content = [{"type": "draft_task", "draft_task": {"id": draft_task_id}}]
else:
if image_url:
content.append({"type": "image_url", "image_url": {"url": image_url}})
if first_frame_url:
content.append({"type": "image_url", "image_url": {"url": first_frame_url}, "role": "first_frame"})
if last_frame_url:
content.append({"type": "image_url", "image_url": {"url": last_frame_url}, "role": "last_frame"})
if reference_images:
for ref_url in reference_images:
content.append({"type": "image_url", "image_url": {"url": ref_url}, "role": "reference_image"})
params = {"model": self._config.model_name, "content": content, "watermark": watermark}
if duration:
params["duration"] = duration
if frames:
params["frames"] = frames
if ratio:
params["ratio"] = ratio
if resolution:
params["resolution"] = resolution
if generate_audio:
params["generate_audio"] = generate_audio
if camera_fixed:
params["camera_fixed"] = camera_fixed
if seed is not None:
params["seed"] = seed
if return_last_frame:
params["return_last_frame"] = return_last_frame
if service_tier != "default":
params["service_tier"] = service_tier
if execution_expires_after:
params["execution_expires_after"] = execution_expires_after
if draft:
params["draft"] = draft
params.update(kwargs)
response = self._client.content_generation.tasks.create(**params)
else:
raise BusinessException(
f"不支持的提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
return response.model_dump() if hasattr(response, 'model_dump') else response
async def agenerate(
self,
prompt: str,
image_url: Optional[str] = None,
duration: Optional[int] = None,
**kwargs
) -> Dict[str, Any]:
"""异步生成视频"""
return self.generate(prompt, image_url=image_url, duration=duration, **kwargs)
def get_task_status(self, task_id: str) -> Dict[str, Any]:
"""
查询视频生成任务状态
Args:
task_id: 任务ID
Returns:
任务状态信息
"""
provider = self._config.provider.lower()
if provider == ModelProvider.VOLCANO:
response = self._client.content_generation.tasks.get(task_id=task_id)
return response.model_dump() if hasattr(response, 'model_dump') else response
else:
raise BusinessException(
f"不支持的提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
async def aget_task_status(self, task_id: str) -> Dict[str, Any]:
"""异步查询任务状态"""
return self.get_task_status(task_id)
def list_tasks(self, page_size: int = 10, status: Optional[str] = None, **kwargs) -> Dict[str, Any]:
"""
查询视频生成任务列表
Args:
page_size: 每页数量
status: 任务状态筛选,如 "succeeded", "failed", "pending"
**kwargs: 其他参数
Returns:
任务列表
"""
provider = self._config.provider.lower()
if provider == ModelProvider.VOLCANO:
params = {"page_size": page_size}
if status:
params["status"] = status
params.update(kwargs)
response = self._client.content_generation.tasks.list(**params)
return response.model_dump() if hasattr(response, 'model_dump') else response
else:
raise BusinessException(
f"不支持的提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)
def delete_task(self, task_id: str) -> None:
"""
删除或取消视频生成任务
Args:
task_id: 任务ID
"""
provider = self._config.provider.lower()
if provider == ModelProvider.VOLCANO:
self._client.content_generation.tasks.delete(task_id=task_id)
else:
raise BusinessException(
f"不支持的提供商: {provider}",
code=BizCode.PROVIDER_NOT_SUPPORTED
)

View File

@@ -1,334 +0,0 @@
provider: volcano
models:
# Doubao-Seed 2.0 系列
- name: doubao-seed-2-0-pro-260215
type: chat
provider: volcano
description: 旗舰级全能通用模型,面向 Agent 时代的复杂推理与长链路任务执行场景。强调多模态理解、长上下文推理、结构化生成与工具增强执行。复杂指令与多约束执行能力突出,可稳定应对多步复杂规划、复杂图文推理、视频内容理解与高难度分析等场景。侧重长链路推理能力与复杂任务稳定性,适配真实业务中的复杂场景。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-2-0-lite-260215
type: chat
provider: volcano
description: 面向高频企业场景兼顾性能与成本的均衡型模型综合能力超越上一代Doubao-Seed-1.8。胜任非结构化信息处理、内容创作、搜索推荐、数据分析等生产型工作,支持长上下文、多源信息融合、多步指令执行与高保真结构化输出。在保障稳定效果的同时显著优化成本。兼顾生成质量与响应速度,适合作为通用生产级模型。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-2-0-mini-260215
type: chat
provider: volcano
description: 面向低时延、高并发与成本敏感场景提供极致的模型推理速度。模型效果与Doubao-Seed-1.6相当。支持256k上下文、4档思考长度和多模态理解适合成本和速度优先的轻量级任务。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-2-0-code-preview-260215
type: chat
provider: volcano
description: 面向真实编程环境优化的 Coding 模型,能稳定调用 Claude Code 等常见 IDE 中的工具。模型特别优化了前端能力,在使用常见的前端框架时能有良好表现。模型支持使用 Skills可以配合多种自定义技能使用。Seed 2.0 的编程加强版,更适合 Agentic Coding。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 代码模型
logo: volcano
# Doubao-Seed 1.x 系列
- name: doubao-seed-1-8-251228
type: chat
provider: volcano
description: Doubao-Seed-1.8 面向多模态 Agent 场景定向优化。Agent 能力上Tool Use、复杂指令遵循等能力均大幅增强。多模态理解方面视觉基础能力显著提升可低帧率理解超长视频视频运动理解、复杂空间理解及文档结构化解析能力也有所优化还原生支持智能上下文管理用户可配置上下文策略。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-1-6-251015
type: chat
provider: volcano
description: Doubao-Seed-1.6全新多模态深度思考模型同时支持minimal/low/medium/high 四种reasoning effort。 更强模型效果,服务复杂任务和有挑战场景。支持 256k 上下文窗口,输出长度支持最大 32k tokens。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-1-6-lite-251015
type: chat
provider: volcano
description: 更高性价比常见任务的最佳选择支持minimal、low、medium、high 四种reasoning_effort思考深度
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-1-6-flash-250828
type: chat
provider: volcano
description: Doubao-Seed-1.6-flash推理速度极致的多模态深度思考模型TPOT低至10ms 同时支持文本和视觉理解文本理解能力超过上一代lite视觉理解比肩友商pro系列模型。支持 256k 上下文窗口,输出长度支持最大 16k tokens。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-seed-code-preview-251028
type: chat
provider: volcano
description: 面向Agentic编程任务进行了深度优化。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 代码模型
logo: volcano
- name: doubao-seed-1-6-vision-250815
type: chat
provider: volcano
description: 全新Doubao-Seed-1.6系列视觉深度思考模型视觉理解能力显著增强并支持image_process视觉工具
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 大语言模型
- 多模态模型
logo: volcano
# Doubao 1.5 系列
- name: doubao-1-5-vision-pro-32k-250115
type: chat
provider: volcano
description: 全新升级的多模态大模型,支持任意分辨率和极端长宽比图像识别,增强视觉推理、文档识别、细节信息理解和指令遵循能力。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 大语言模型
- 多模态模型
logo: volcano
- name: doubao-1-5-pro-32k-250115
type: chat
provider: volcano
description: 全新一代主力模型,性能全面升级,在知识、代码、推理等方面表现卓越。最大支持 128k 上下文窗口,输出长度支持最大 12k tokens。
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
logo: volcano
- name: doubao-1-5-lite-32k-250115
type: chat
provider: volcano
description: 全新一代轻量版模型,极致响应速度,效果与时延均达到全球一流水平。支持 32k 上下文窗口,输出长度支持最大 12k tokens。
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 大语言模型
logo: volcano
# Doubao-Seedance 视频生成系列
- name: doubao-seedance-1-5-pro-251215
type: video
provider: volcano
description: 豆包视频生成模型Seedance 1.5 pro 作为全球领先的视频生成模型,可生成音画高精同步的视频内容。支持多人多语言对白,全面覆盖环境音、动作音、合成音、乐器音、背景音及人声,支持首尾帧,实现影视级叙事效果,满足影视、漫剧、电商及广告领域的高阶创作需求。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 视频生成
logo: volcano
- name: doubao-seedance-1-0-pro-250528
type: video
provider: volcano
description: 一款支持多镜头叙事的视频生成基础模型,在各维度表现出色。它在语义理解与指令遵循能力上取得突破,能生成运动流畅、细节丰富、风格多样且具备影视级美感的 1080P 高清视频
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 视频生成
logo: volcano
- name: doubao-seedance-1-0-pro-fast-251015
type: video
provider: volcano
description: 一款价格触底、效能封顶的全面模型在视频生成质量、速度、价格之间取得了卓越平衡。它继承了Seedance 1.0 pro 核心优势,同时生成速度提升、价格更具竞争力,为创作者带来效率与成本双重优化的体验。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 视频生成
logo: volcano
- name: doubao-seedance-1-0-lite-i2v-250428
type: video
provider: volcano
description: 基于首帧图片、尾帧图片(可选)、参考图片(可选)和文本提示词(可选)相结合的方式生成视频
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 视频生成
- 图生视频
logo: volcano
- name: doubao-seedance-1-0-lite-t2v-250428
type: video
provider: volcano
description: 基于文本提示词生成视频
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 视频生成
- 文生视频
logo: volcano
# Doubao-Seedream 图像生成系列
- name: doubao-seedream-5-0-260128
type: image
provider: volcano
description: 字节跳动发布的最新图像创作模型。该模型首次搭载联网检索功能,能融合实时网络信息,提升生图时效性。同时,模型的聪明度进一步升级,能够精准解析复杂指令和视觉内容。此外,模型在世界知识广度、参考一致性及专业场景生成质量上均有增强,可更好地满足企业级视觉创作需求。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 图像生成
logo: volcano
- name: doubao-seedream-4-5-251128
type: image
provider: volcano
description: 字节跳动最新推出的图像多模态模型整合了文生图、图生图、组图输出等能力融合常识和推理能力。相比前代4.0模型生成效果大幅提升,具备更好的编辑一致性和多图融合效果,能更精准的控制画面细节,小字、小人脸生成更自然,图片排版、色彩更和谐,美感提升。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 图像生成
logo: volcano
- name: doubao-seedream-4-0-250828
type: image
provider: volcano
description: 基于领先架构的SOTA级多模态图像创作模型其生成美感、指令遵循、结构完整度、主体保持一致性处于世界头部水平。模型采用同一套架构实现文生图与编辑能力的统一原生支持文本 、单图和多图输入,并能通过对提示词的深度推理,自动适配最优的图像比例尺寸与生成数量,可一次性连续输出最多 15 张内容关联的图像,支持 4K 超高清输出。
is_deprecated: false
is_official: true
capability:
- vision
is_omni: false
tags:
- 图像生成
logo: volcano
- name: doubao-seedream-3-0-t2i-250415
type: image
provider: volcano
description: 一款支持原生高分辨率的中英双语图像生成基础模型综合能力媲美GPT-4o处于世界第一梯队。支持原生 2K 分辨率输出;响应速度更快;小字生成更准确,文本排版效果增强;指令遵循能力强,美感&结构提升,保真度和细节表现较好。
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 图像生成
- 文生图
logo: volcano
# Doubao 翻译系列
- name: doubao-seed-translation-250915
type: chat
provider: volcano
description: 通用多语言翻译模型支持30余种语言互译支持 4K 上下文窗口,输出长度支持最大 3K tokens
is_deprecated: false
is_official: true
capability: []
is_omni: false
tags:
- 翻译模型
logo: volcano
# Doubao Embedding 系列
- name: doubao-embedding-vision-251215
type: embedding
provider: volcano
description: 主要面向图文多模向量检索的使用场景,支持图片输入及中、英双语文本输入,最长 128K 上下文长度。
is_deprecated: false
is_official: true
capability:
- vision
- video
is_omni: false
tags:
- 向量模型
- 多模态模型
logo: volcano

View File

@@ -61,16 +61,24 @@ class ElasticSearchConfig(BaseModel):
class ElasticSearchVector(BaseVector): class ElasticSearchVector(BaseVector):
def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey): def __init__(self, index_name: str, config: ElasticSearchConfig, embedding_config: ModelApiKey, reranker_config: ModelApiKey):
super().__init__(index_name.lower()) super().__init__(index_name.lower())
# self.embeddings = XinferenceEmbeddings(
# 初始化 Embedding 模型(自动支持火山引擎多模态) # server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"), # Default Xinference port
# model_uid="bge-m3" # replace model_uid with the model UID return from launching the model
# )
# Remove debug printing to avoid leaking sensitive information
# print("embedding:" + embedding_config.model_name + "|" + embedding_config.provider + "|" + embedding_config.api_key + "|" + embedding_config.api_base)
self.embeddings = RedBearEmbeddings(RedBearModelConfig( self.embeddings = RedBearEmbeddings(RedBearModelConfig(
model_name=embedding_config.model_name, model_name=embedding_config.model_name,
provider=embedding_config.provider, provider=embedding_config.provider,
api_key=embedding_config.api_key, api_key=embedding_config.api_key,
base_url=embedding_config.api_base base_url=embedding_config.api_base
)) ))
self.is_multimodal_embedding = self.embeddings.is_multimodal_supported() # self.reranker = XinferenceRerank(
# server_url=os.getenv("XINFERENCE_URL", "http://127.0.0.1"),
# model_uid="bge-reranker-large"
# )
# Remove debug printing to avoid leaking sensitive information
# print("reranker:"+ reranker_config.model_name + "|" + reranker_config.provider + "|" + reranker_config.api_key + "|" + reranker_config.api_base)
self.reranker = RedBearRerank(RedBearModelConfig( self.reranker = RedBearRerank(RedBearModelConfig(
model_name=reranker_config.model_name, model_name=reranker_config.model_name,
provider=reranker_config.provider, provider=reranker_config.provider,
@@ -136,11 +144,7 @@ class ElasticSearchVector(BaseVector):
def add_chunks(self, chunks: list[DocumentChunk], **kwargs): def add_chunks(self, chunks: list[DocumentChunk], **kwargs):
# 实现 Elasticsearch 保存向量 # 实现 Elasticsearch 保存向量
texts = [chunk.page_content for chunk in chunks] texts = [chunk.page_content for chunk in chunks]
if self.is_multimodal_embedding: embeddings = self.embeddings.embed_documents(list(texts))
# 火山引擎多模态 Embedding
embeddings = self.embeddings.embed_batch(texts)
else:
embeddings = self.embeddings.embed_documents(list(texts))
self.create(chunks, embeddings, **kwargs) self.create(chunks, embeddings, **kwargs)
def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs): def create(self, chunks: list[DocumentChunk], embeddings: list[list[float]], **kwargs):
@@ -390,11 +394,7 @@ class ElasticSearchVector(BaseVector):
updated count. updated count.
""" """
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index availableetc "index1,index2,index3" indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index availableetc "index1,index2,index3"
if self.is_multimodal_embedding: chunk.vector = self.embeddings.embed_query(chunk.page_content)
# 火山引擎多模态 Embedding
chunk.vector = self.embeddings.embed_text(chunk.page_content)
else:
chunk.vector = self.embeddings.embed_query(chunk.page_content)
body = { body = {
"script": { "script": {
@@ -454,11 +454,7 @@ class ElasticSearchVector(BaseVector):
def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]: def search_by_vector(self, query: str, **kwargs: Any) -> list[DocumentChunk]:
"""Search the nearest neighbors to a vector.""" """Search the nearest neighbors to a vector."""
if self.is_multimodal_embedding: query_vector = self.embeddings.embed_query(query)
# 火山引擎多模态 Embedding
query_vector = self.embeddings.embed_text(query)
else:
query_vector = self.embeddings.embed_query(query)
top_k = kwargs.get("top_k", 1024) top_k = kwargs.get("top_k", 1024)
score_threshold = float(kwargs.get("score_threshold") or 0.3) score_threshold = float(kwargs.get("score_threshold") or 0.3)
indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index availableetc "index1,index2,index3" indices = kwargs.get("indices", self._collection_name) # Default single index, multi-index availableetc "index1,index2,index3"

View File

@@ -109,13 +109,17 @@ class StorageBackend(ABC):
pass pass
@abstractmethod @abstractmethod
async def get_url( async def get_url(self, file_key: str, expires: int = 3600) -> str:
self, """
file_key: str, Get an access URL for the file.
expires: int = 3600,
file_name: Optional[str] = None Args:
) -> str: file_key: Unique identifier for the file in the storage system.
"""Get an access URL for the file.""" expires: URL validity period in seconds (default: 1 hour).
Returns:
URL for accessing the file.
"""
pass pass
async def get_permanent_url(self, file_key: str) -> Optional[str]: async def get_permanent_url(self, file_key: str) -> Optional[str]:

View File

@@ -210,12 +210,7 @@ class LocalStorage(StorageBackend):
cause=e, cause=e,
) )
async def get_url( async def get_url(self, file_key: str, expires: int = 3600) -> str:
self,
file_key: str,
expires: int = 3600,
file_name: Optional[str] = None
) -> str:
""" """
Get an access URL for the file. Get an access URL for the file.
@@ -225,7 +220,6 @@ class LocalStorage(StorageBackend):
Args: Args:
file_key: Unique identifier for the file in the storage system. file_key: Unique identifier for the file in the storage system.
expires: URL validity period in seconds (not used for local storage). expires: URL validity period in seconds (not used for local storage).
file_name: If set, adds Content-Disposition: attachment to force download.
Returns: Returns:
A relative URL path for accessing the file. A relative URL path for accessing the file.

View File

@@ -7,7 +7,6 @@ Storage Service (OSS) using the oss2 SDK.
import io import io
import logging import logging
import urllib.parse
from typing import AsyncIterator, Optional from typing import AsyncIterator, Optional
import oss2 import oss2
@@ -44,8 +43,6 @@ class OSSStorage(StorageBackend):
access_key_id: str, access_key_id: str,
access_key_secret: str, access_key_secret: str,
bucket_name: str, bucket_name: str,
connect_timeout: int = 30,
multipart_threshold: int = 10 * 1024 * 1024, # 10MB
): ):
""" """
Initialize the OSSStorage backend. Initialize the OSSStorage backend.
@@ -55,8 +52,6 @@ class OSSStorage(StorageBackend):
access_key_id: The Aliyun access key ID. access_key_id: The Aliyun access key ID.
access_key_secret: The Aliyun access key secret. access_key_secret: The Aliyun access key secret.
bucket_name: The name of the OSS bucket. bucket_name: The name of the OSS bucket.
connect_timeout: Connection timeout in seconds (default: 30).
multipart_threshold: File size threshold for multipart upload (default: 10MB).
Raises: Raises:
StorageConfigError: If any required configuration is missing. StorageConfigError: If any required configuration is missing.
@@ -73,17 +68,10 @@ class OSSStorage(StorageBackend):
self.endpoint = endpoint self.endpoint = endpoint
self.bucket_name = bucket_name self.bucket_name = bucket_name
self.multipart_threshold = multipart_threshold
try: try:
auth = oss2.Auth(access_key_id, access_key_secret) auth = oss2.Auth(access_key_id, access_key_secret)
# 设置超时和重试 self.bucket = oss2.Bucket(auth, endpoint, bucket_name)
self.bucket = oss2.Bucket(
auth,
endpoint,
bucket_name,
connect_timeout=connect_timeout
)
logger.info( logger.info(
f"OSSStorage initialized with endpoint: {endpoint}, bucket: {bucket_name}" f"OSSStorage initialized with endpoint: {endpoint}, bucket: {bucket_name}"
) )
@@ -119,38 +107,21 @@ class OSSStorage(StorageBackend):
if content_type: if content_type:
headers["Content-Type"] = content_type headers["Content-Type"] = content_type
# 大文件使用分片上传 self.bucket.put_object(file_key, content, headers=headers if headers else None)
if len(content) > self.multipart_threshold:
logger.info(f"Using multipart upload for large file: {file_key} ({len(content)} bytes)")
upload_id = self.bucket.init_multipart_upload(file_key, headers=headers if headers else None).upload_id
parts = []
part_size = 5 * 1024 * 1024 # 5MB per part
part_num = 1
for offset in range(0, len(content), part_size):
chunk = content[offset:offset + part_size]
result = self.bucket.upload_part(file_key, upload_id, part_num, chunk)
parts.append(oss2.models.PartInfo(part_num, result.etag))
part_num += 1
self.bucket.complete_multipart_upload(file_key, upload_id, parts)
else:
self.bucket.put_object(file_key, content, headers=headers if headers else None)
logger.info(f"File uploaded to OSS successfully: {file_key}") logger.info(f"File uploaded to OSS successfully: {file_key}")
return file_key return file_key
except OssError as e: except OssError as e:
logger.error(f"OSS error uploading file {file_key}: {e}") logger.error(f"OSS error uploading file {file_key}: {e}")
raise StorageUploadError( raise StorageUploadError(
message=f"Failed to upload file to OSS: {str(e)}", message=f"Failed to upload file to OSS: {e.message}",
file_key=file_key, file_key=file_key,
cause=e, cause=e,
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to upload file to OSS {file_key}: {e}") logger.error(f"Failed to upload file to OSS {file_key}: {e}")
raise StorageUploadError( raise StorageUploadError(
message=f"Failed to upload file to OSS: {str(e)}", message=f"Failed to upload file to OSS: {e}",
file_key=file_key, file_key=file_key,
cause=e, cause=e,
) )
@@ -163,73 +134,28 @@ class OSSStorage(StorageBackend):
) -> int: ) -> int:
"""Upload from async stream to OSS. Returns total bytes written.""" """Upload from async stream to OSS. Returns total bytes written."""
buf = io.BytesIO() buf = io.BytesIO()
headers = {"Content-Type": content_type} if content_type else None
upload_id = None
try: try:
# 收集流数据
total_size = 0
async for chunk in stream: async for chunk in stream:
if not chunk:
continue
buf.write(chunk) buf.write(chunk)
total_size += len(chunk)
content = buf.getvalue() content = buf.getvalue()
headers = {"Content-Type": content_type} if content_type else None
if not content: self.bucket.put_object(file_key, content, headers=headers)
raise StorageUploadError( logger.info(f"File stream uploaded to OSS successfully: {file_key}")
message="Empty stream content", return len(content)
file_key=file_key,
)
# 大文件使用分片上传
if len(content) > self.multipart_threshold:
logger.info(f"Using multipart upload for stream: {file_key} ({len(content)} bytes)")
upload_id = self.bucket.init_multipart_upload(file_key, headers=headers).upload_id
parts = []
part_size = 5 * 1024 * 1024 # 5MB
part_num = 1
for offset in range(0, len(content), part_size):
chunk = content[offset:offset + part_size]
result = self.bucket.upload_part(file_key, upload_id, part_num, chunk)
parts.append(oss2.models.PartInfo(part_num, result.etag))
part_num += 1
self.bucket.complete_multipart_upload(file_key, upload_id, parts)
else:
self.bucket.put_object(file_key, content, headers=headers)
logger.info(f"File stream uploaded to OSS successfully: {file_key} ({total_size} bytes)")
return total_size
except OssError as e: except OssError as e:
if upload_id:
try:
self.bucket.abort_multipart_upload(file_key, upload_id)
except:
pass
logger.error(f"OSS error stream uploading file {file_key}: {e}") logger.error(f"OSS error stream uploading file {file_key}: {e}")
raise StorageUploadError( raise StorageUploadError(
message=f"Failed to stream upload file to OSS: {str(e)}", message=f"Failed to stream upload file to OSS: {e.message}",
file_key=file_key, file_key=file_key,
cause=e, cause=e,
) )
except Exception as e: except Exception as e:
if upload_id:
try:
self.bucket.abort_multipart_upload(file_key, upload_id)
except:
pass
logger.error(f"Failed to stream upload file to OSS {file_key}: {e}") logger.error(f"Failed to stream upload file to OSS {file_key}: {e}")
raise StorageUploadError( raise StorageUploadError(
message=f"Failed to stream upload file to OSS: {str(e)}", message=f"Failed to stream upload file to OSS: {e}",
file_key=file_key, file_key=file_key,
cause=e, cause=e,
) )
finally:
buf.close()
async def download(self, file_key: str) -> bytes: async def download(self, file_key: str) -> bytes:
""" """
@@ -255,14 +181,14 @@ class OSSStorage(StorageBackend):
except OssError as e: except OssError as e:
logger.error(f"OSS error downloading file {file_key}: {e}") logger.error(f"OSS error downloading file {file_key}: {e}")
raise StorageDownloadError( raise StorageDownloadError(
message=f"Failed to download file from OSS: {str(e)}", message=f"Failed to download file from OSS: {e.message}",
file_key=file_key, file_key=file_key,
cause=e, cause=e,
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to download file from OSS {file_key}: {e}") logger.error(f"Failed to download file from OSS {file_key}: {e}")
raise StorageDownloadError( raise StorageDownloadError(
message=f"Failed to download file from OSS: {str(e)}", message=f"Failed to download file from OSS: {e}",
file_key=file_key, file_key=file_key,
cause=e, cause=e,
) )
@@ -288,14 +214,14 @@ class OSSStorage(StorageBackend):
except OssError as e: except OssError as e:
logger.error(f"OSS error deleting file {file_key}: {e}") logger.error(f"OSS error deleting file {file_key}: {e}")
raise StorageDeleteError( raise StorageDeleteError(
message=f"Failed to delete file from OSS: {str(e)}", message=f"Failed to delete file from OSS: {e.message}",
file_key=file_key, file_key=file_key,
cause=e, cause=e,
) )
except Exception as e: except Exception as e:
logger.error(f"Failed to delete file from OSS {file_key}: {e}") logger.error(f"Failed to delete file from OSS {file_key}: {e}")
raise StorageDeleteError( raise StorageDeleteError(
message=f"Failed to delete file from OSS: {str(e)}", message=f"Failed to delete file from OSS: {e}",
file_key=file_key, file_key=file_key,
cause=e, cause=e,
) )
@@ -316,33 +242,24 @@ class OSSStorage(StorageBackend):
logger.error(f"Failed to check file existence in OSS {file_key}: {e}") logger.error(f"Failed to check file existence in OSS {file_key}: {e}")
return False return False
async def get_url( async def get_url(self, file_key: str, expires: int = 3600) -> str:
self,
file_key: str,
expires: int = 3600,
file_name: Optional[str] = None,
) -> str:
""" """
Get a presigned URL for accessing the file. Get a presigned URL for accessing the file.
Args: Args:
file_key: Unique identifier for the file in the storage system. file_key: Unique identifier for the file in the storage system.
expires: URL validity period in seconds (default: 1 hour). expires: URL validity period in seconds (default: 1 hour).
file_name: If set, adds Content-Disposition: attachment to force download.
Returns: Returns:
A presigned URL for accessing the file. A presigned URL for accessing the file.
""" """
try: try:
params = {} url = self.bucket.sign_url("GET", file_key, expires)
if file_name:
filename_encoded = urllib.parse.quote(file_name.encode("utf-8"))
params["response-content-disposition"] = f"attachment; filename*=UTF-8''{filename_encoded}"
url = self.bucket.sign_url("GET", file_key, expires, params=params if params else None)
logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s") logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s")
return url return url
except Exception as e: except Exception as e:
logger.error(f"Failed to generate presigned URL for {file_key}: {e}") logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
# Return a basic URL format as fallback
return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}" return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}"
async def get_permanent_url(self, file_key: str) -> str: async def get_permanent_url(self, file_key: str) -> str:

View File

@@ -6,7 +6,6 @@ using the boto3 SDK.
""" """
import io import io
import urllib.parse
import logging import logging
from typing import AsyncIterator, Optional from typing import AsyncIterator, Optional
@@ -353,37 +352,31 @@ class S3Storage(StorageBackend):
logger.error(f"Failed to check file existence in S3 {file_key}: {e}") logger.error(f"Failed to check file existence in S3 {file_key}: {e}")
return False return False
async def get_url( async def get_url(self, file_key: str, expires: int = 3600) -> str:
self,
file_key: str,
expires: int = 3600,
file_name: Optional[str] = None,
) -> str:
""" """
Get a presigned URL for accessing the file. Get a presigned URL for accessing the file.
Args: Args:
file_key: Unique identifier for the file in the storage system. file_key: Unique identifier for the file in the storage system.
expires: URL validity period in seconds (default: 1 hour). expires: URL validity period in seconds (default: 1 hour).
file_name: If set, adds Content-Disposition: attachment to force download.
Returns: Returns:
A presigned URL for accessing the file. A presigned URL for accessing the file.
""" """
try: try:
params = {"Bucket": self.bucket_name, "Key": file_key}
if file_name:
filename_encoded = urllib.parse.quote(file_name.encode("utf-8"))
params["ResponseContentDisposition"] = f"attachment; filename*=UTF-8''{filename_encoded}"
url = self.client.generate_presigned_url( url = self.client.generate_presigned_url(
"get_object", "get_object",
Params=params, Params={
"Bucket": self.bucket_name,
"Key": file_key,
},
ExpiresIn=expires, ExpiresIn=expires,
) )
logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s") logger.debug(f"Generated presigned URL for {file_key}, expires in {expires}s")
return url return url
except Exception as e: except Exception as e:
logger.error(f"Failed to generate presigned URL for {file_key}: {e}") logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
# Return a basic URL format as fallback
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}" return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"
async def get_permanent_url(self, file_key: str) -> str: async def get_permanent_url(self, file_key: str) -> str:

View File

@@ -99,7 +99,7 @@ class SimpleMCPClient:
# 建立 SSE 连接 # 建立 SSE 连接
response = await self._session.get(self.server_url) response = await self._session.get(self.server_url)
if response.status not in (200, 202): if response.status != 200:
error_text = await response.text() error_text = await response.text()
raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}") raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}")
@@ -190,9 +190,7 @@ class SimpleMCPClient:
try: try:
async with self._session.post(self._endpoint_url, json=request) as response: async with self._session.post(self._endpoint_url, json=request) as response:
# MCP SSE 协议POST 请求返回 200 或 202 均为正常 if response.status != 200:
# 202 Accepted 表示请求已接受,结果通过 SSE 流异步返回
if response.status not in (200, 202):
error_text = await response.text() error_text = await response.text()
raise MCPConnectionError(f"请求失败 {response.status}: {error_text}") raise MCPConnectionError(f"请求失败 {response.status}: {error_text}")
@@ -207,7 +205,7 @@ class SimpleMCPClient:
raise MCPConnectionError("endpoint URL 未初始化") raise MCPConnectionError("endpoint URL 未初始化")
async with self._session.post(self._endpoint_url, json=notification) as response: async with self._session.post(self._endpoint_url, json=notification) as response:
if response.status not in (200, 202): if response.status != 200:
logger.warning(f"通知发送失败: {response.status}") logger.warning(f"通知发送失败: {response.status}")
async def _initialize_modelscope_session(self): async def _initialize_modelscope_session(self):

View File

@@ -9,7 +9,7 @@ from typing import Any
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from app.core.workflow.adapters.errors import ExceptionDefinition from app.core.workflow.adapters.errors import ExceptionDefineition
from app.schemas.workflow_schema import ( from app.schemas.workflow_schema import (
EdgeDefinition, EdgeDefinition,
NodeDefinition, NodeDefinition,
@@ -40,8 +40,8 @@ class WorkflowParserResult(BaseModel):
edges: list[EdgeDefinition] = Field(default_factory=list) edges: list[EdgeDefinition] = Field(default_factory=list)
nodes: list[NodeDefinition] = Field(default_factory=list) nodes: list[NodeDefinition] = Field(default_factory=list)
variables: list[VariableDefinition] = Field(default_factory=list) variables: list[VariableDefinition] = Field(default_factory=list)
warnings: list[ExceptionDefinition] = Field(default_factory=list) warnings: list[ExceptionDefineition] = Field(default_factory=list)
errors: list[ExceptionDefinition] = Field(default_factory=list) errors: list[ExceptionDefineition] = Field(default_factory=list)
class WorkflowImportResult(BaseModel): class WorkflowImportResult(BaseModel):
@@ -51,8 +51,8 @@ class WorkflowImportResult(BaseModel):
edges: list[EdgeDefinition] = Field(default_factory=list) edges: list[EdgeDefinition] = Field(default_factory=list)
nodes: list[NodeDefinition] = Field(default_factory=list) nodes: list[NodeDefinition] = Field(default_factory=list)
variables: list[VariableDefinition] = Field(default_factory=list) variables: list[VariableDefinition] = Field(default_factory=list)
warnings: list[ExceptionDefinition] = Field(default_factory=list) warnings: list[ExceptionDefineition] = Field(default_factory=list)
errors: list[ExceptionDefinition] = Field(default_factory=list) errors: list[ExceptionDefineition] = Field(default_factory=list)
class BasePlatformAdapter(ABC): class BasePlatformAdapter(ABC):

View File

@@ -9,9 +9,9 @@ from urllib.parse import quote
from app.core.workflow.adapters.base_converter import BaseConverter from app.core.workflow.adapters.base_converter import BaseConverter
from app.core.workflow.adapters.errors import ( from app.core.workflow.adapters.errors import (
UnsupportedVariableType, UnsupportVariableType,
UnknownModelWarning, UnknowModelWarning,
ExceptionDefinition, ExceptionDefineition,
ExceptionType ExceptionType
) )
from app.core.workflow.nodes.assigner.config import AssignmentItem from app.core.workflow.nodes.assigner.config import AssignmentItem
@@ -54,7 +54,7 @@ from app.core.workflow.nodes.http_request.config import (
HttpFormData, HttpFormData,
HttpTimeOutConfig, HttpTimeOutConfig,
HttpRetryConfig, HttpRetryConfig,
HttpErrorDefaultTemplate, HttpErrorDefaultTamplete,
HttpErrorHandleConfig HttpErrorHandleConfig
) )
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
@@ -108,7 +108,7 @@ class DifyConverter(BaseConverter):
try: try:
return config.model_validate(value) return config.model_validate(value)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefinition( self.errors.append(ExceptionDefineition(
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
node_id=node_id, node_id=node_id,
node_name=node_name, node_name=node_name,
@@ -138,7 +138,7 @@ class DifyConverter(BaseConverter):
var_selector = mapping.get(var_selector, var_selector) var_selector = mapping.get(var_selector, var_selector)
return var_selector return var_selector
def _process_list_variable_literal(self, variable_selector: list) -> str | None: def _process_list_variable_litearl(self, variable_selector: list) -> str | None:
if not self.process_var_selector(".".join(variable_selector)): if not self.process_var_selector(".".join(variable_selector)):
return None return None
return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}" return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}"
@@ -269,7 +269,7 @@ class DifyConverter(BaseConverter):
var_type = self.variable_type_map(var["type"]) var_type = self.variable_type_map(var["type"])
if not var_type: if not var_type:
self.errors.append( self.errors.append(
UnsupportedVariableType( UnsupportVariableType(
scope=node["id"], scope=node["id"],
name=var["variable"], name=var["variable"],
var_type=var["type"], var_type=var["type"],
@@ -281,7 +281,7 @@ class DifyConverter(BaseConverter):
if var_type in ["file", "array[file]"]: if var_type in ["file", "array[file]"]:
self.errors.append( self.errors.append(
ExceptionDefinition( ExceptionDefineition(
type=ExceptionType.VARIABLE, type=ExceptionType.VARIABLE,
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
@@ -311,7 +311,7 @@ class DifyConverter(BaseConverter):
def convert_question_classifier_node_config(self, node: dict) -> dict: def convert_question_classifier_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append( self.warnings.append(
UnknownModelWarning( UnknowModelWarning(
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
model_name=node_data["model"].get("name") model_name=node_data["model"].get("name")
@@ -327,7 +327,7 @@ class DifyConverter(BaseConverter):
) )
result = QuestionClassifierNodeConfig.model_construct( result = QuestionClassifierNodeConfig.model_construct(
input_variable=self._process_list_variable_literal(node_data.get("query_variable_selector")), input_variable=self._process_list_variable_litearl(node_data.get("query_variable_selector")),
user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")), user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")),
categories=categories, categories=categories,
).model_dump() ).model_dump()
@@ -337,13 +337,13 @@ class DifyConverter(BaseConverter):
def convert_llm_node_config(self, node: dict) -> dict: def convert_llm_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append( self.warnings.append(
UnknownModelWarning( UnknowModelWarning(
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
model_name=node_data["model"].get("name") model_name=node_data["model"].get("name")
) )
) )
context = self._process_list_variable_literal(node_data["context"]["variable_selector"]) context = self._process_list_variable_litearl(node_data["context"]["variable_selector"])
memory = MemoryWindowSetting( memory = MemoryWindowSetting(
enable=bool(node_data.get("memory")), enable=bool(node_data.get("memory")),
enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)), enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)),
@@ -367,7 +367,7 @@ class DifyConverter(BaseConverter):
) )
) )
vision = node_data["vision"]["enabled"] vision = node_data["vision"]["enabled"]
vision_input = self._process_list_variable_literal( vision_input = self._process_list_variable_litearl(
node_data["vision"]["configs"]["variable_selector"] node_data["vision"]["configs"]["variable_selector"]
) if vision else None ) if vision else None
result = LLMNodeConfig.model_construct( result = LLMNodeConfig.model_construct(
@@ -433,7 +433,7 @@ class DifyConverter(BaseConverter):
conditions.append( conditions.append(
LoopConditionDetail.model_construct( LoopConditionDetail.model_construct(
operator=self.convert_compare_operator(condition["comparison_operator"]), operator=self.convert_compare_operator(condition["comparison_operator"]),
left=self._process_list_variable_literal(condition["variable_selector"]), left=self._process_list_variable_litearl(condition["variable_selector"]),
right=self.trans_variable_format( right=self.trans_variable_format(
right_value right_value
) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type( ) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type(
@@ -453,7 +453,7 @@ class DifyConverter(BaseConverter):
right_input_type = variable["value_type"] right_input_type = variable["value_type"]
right_value_type = self.variable_type_map(variable["var_type"]) right_value_type = self.variable_type_map(variable["var_type"])
if right_input_type == ValueInputType.VARIABLE: if right_input_type == ValueInputType.VARIABLE:
right_value = self._process_list_variable_literal(variable.get("value", "")) right_value = self._process_list_variable_litearl(variable.get("value", ""))
else: else:
right_value = self.convert_variable_type(right_value_type, variable.get("value", "")) right_value = self.convert_variable_type(right_value_type, variable.get("value", ""))
loop_variables.append( loop_variables.append(
@@ -475,10 +475,10 @@ class DifyConverter(BaseConverter):
def convert_iteration_node_config(self, node: dict) -> dict: def convert_iteration_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
result = IterationNodeConfig.model_construct( result = IterationNodeConfig.model_construct(
input=self._process_list_variable_literal(node_data["iterator_selector"]), input=self._process_list_variable_litearl(node_data["iterator_selector"]),
parallel=node_data["is_parallel"], parallel=node_data["is_parallel"],
parallel_count=node_data["parallel_nums"], parallel_count=node_data["parallel_nums"],
output=self._process_list_variable_literal(node_data["output_selector"]), output=self._process_list_variable_litearl(node_data["output_selector"]),
output_type=self.variable_type_map(node_data.get("output_type")), output_type=self.variable_type_map(node_data.get("output_type")),
flatten=node_data["flatten_output"], flatten=node_data["flatten_output"],
).model_dump() ).model_dump()
@@ -494,8 +494,8 @@ class DifyConverter(BaseConverter):
continue continue
assignments.append( assignments.append(
AssignmentItem( AssignmentItem(
variable_selector=self._process_list_variable_literal(assignment["variable_selector"]), variable_selector=self._process_list_variable_litearl(assignment["variable_selector"]),
value=self._process_list_variable_literal( value=self._process_list_variable_litearl(
assignment["value"] assignment["value"]
) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"], ) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"],
operation=self.convert_assignment_operator(assignment["operation"]) operation=self.convert_assignment_operator(assignment["operation"])
@@ -514,7 +514,7 @@ class DifyConverter(BaseConverter):
input_variables.append( input_variables.append(
InputVariable.model_construct( InputVariable.model_construct(
name=input_variable["variable"], name=input_variable["variable"],
variable=self._process_list_variable_literal(input_variable["value_selector"]), variable=self._process_list_variable_litearl(input_variable["value_selector"]),
) )
) )
@@ -570,7 +570,7 @@ class DifyConverter(BaseConverter):
else: else:
if node_data["body"]["data"]: if node_data["body"]["data"]:
body_content = (node_data["body"]["data"][0].get("value") or body_content = (node_data["body"]["data"][0].get("value") or
self._process_list_variable_literal(node_data["body"]["data"][0].get("file"))) self._process_list_variable_litearl(node_data["body"]["data"][0].get("file")))
else: else:
body_content = "" body_content = ""
@@ -585,7 +585,7 @@ class DifyConverter(BaseConverter):
self.trans_variable_format(key_value[0]) self.trans_variable_format(key_value[0])
] = self.trans_variable_format(key_value[1]) ] = self.trans_variable_format(key_value[1])
else: else:
self.warnings.append(ExceptionDefinition( self.warnings.append(ExceptionDefineition(
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
@@ -603,7 +603,7 @@ class DifyConverter(BaseConverter):
self.trans_variable_format(key_value[0]) self.trans_variable_format(key_value[0])
] = self.trans_variable_format(key_value[1]) ] = self.trans_variable_format(key_value[1])
else: else:
self.warnings.append(ExceptionDefinition( self.warnings.append(ExceptionDefineition(
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
@@ -625,7 +625,7 @@ class DifyConverter(BaseConverter):
default_header = var["value"] default_header = var["value"]
elif var["key"] == "status_code": elif var["key"] == "status_code":
default_status_code = var["value"] default_status_code = var["value"]
default_value = HttpErrorDefaultTemplate( default_value = HttpErrorDefaultTamplete(
body=default_body, body=default_body,
headers=default_header, headers=default_header,
status_code=default_status_code, status_code=default_status_code,
@@ -668,7 +668,7 @@ class DifyConverter(BaseConverter):
for variable in node_data["variables"]: for variable in node_data["variables"]:
mapping.append(VariablesMappingConfig.model_construct( mapping.append(VariablesMappingConfig.model_construct(
name=variable["variable"], name=variable["variable"],
value=self._process_list_variable_literal(variable["value_selector"]) value=self._process_list_variable_litearl(variable["value_selector"])
)) ))
result = JinjaRenderNodeConfig.model_construct( result = JinjaRenderNodeConfig.model_construct(
template=node_data["template"], template=node_data["template"],
@@ -679,14 +679,14 @@ class DifyConverter(BaseConverter):
def convert_knowledge_node_config(self, node: dict) -> dict: def convert_knowledge_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append(ExceptionDefinition( self.warnings.append(ExceptionDefineition(
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
detail=f"Please reconfigure the Knowledge Retrieval node.", detail=f"Please reconfigure the Knowledge Retrieval node.",
)) ))
result = KnowledgeRetrievalNodeConfig.model_construct( result = KnowledgeRetrievalNodeConfig.model_construct(
query=self._process_list_variable_literal(node_data["query_variable_selector"]), query=self._process_list_variable_litearl(node_data["query_variable_selector"]),
).model_dump() ).model_dump()
self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result) self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result)
@@ -695,7 +695,7 @@ class DifyConverter(BaseConverter):
def convert_parameter_extractor_node_config(self, node: dict) -> dict: def convert_parameter_extractor_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append( self.warnings.append(
UnknownModelWarning( UnknowModelWarning(
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
model_name=node_data["model"].get("name") model_name=node_data["model"].get("name")
@@ -712,7 +712,7 @@ class DifyConverter(BaseConverter):
) )
) )
result = ParameterExtractorNodeConfig.model_construct( result = ParameterExtractorNodeConfig.model_construct(
text=self._process_list_variable_literal(node_data["query"]), text=self._process_list_variable_litearl(node_data["query"]),
params=params, params=params,
prompt=node_data.get("instruction") prompt=node_data.get("instruction")
).model_dump() ).model_dump()
@@ -727,14 +727,14 @@ class DifyConverter(BaseConverter):
group_type = {} group_type = {}
if not advanced_settings or not advanced_settings["group_enabled"]: if not advanced_settings or not advanced_settings["group_enabled"]:
group_variables = [ group_variables = [
self._process_list_variable_literal(variable) self._process_list_variable_litearl(variable)
for variable in node_data["variables"] for variable in node_data["variables"]
] ]
group_type["output"] = node_data["output_type"] group_type["output"] = node_data["output_type"]
else: else:
for group in advanced_settings["groups"]: for group in advanced_settings["groups"]:
group_variables[group["group_name"]] = [ group_variables[group["group_name"]] = [
self._process_list_variable_literal(variable) self._process_list_variable_litearl(variable)
for variable in group["variables"] for variable in group["variables"]
] ]
group_type[group["group_name"]] = group["output_type"] group_type[group["group_name"]] = group["output_type"]
@@ -751,7 +751,7 @@ class DifyConverter(BaseConverter):
def convert_tool_node_config(self, node: dict) -> dict: def convert_tool_node_config(self, node: dict) -> dict:
node_data = node["data"] node_data = node["data"]
self.warnings.append(ExceptionDefinition( self.warnings.append(ExceptionDefineition(
node_id=node["id"], node_id=node["id"],
node_name=node_data["title"], node_name=node_data["title"],
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,

View File

@@ -12,7 +12,7 @@ from app.core.workflow.adapters.base_adapter import (
WorkflowParserResult WorkflowParserResult
) )
from app.core.workflow.adapters.dify.converter import DifyConverter from app.core.workflow.adapters.dify.converter import DifyConverter
from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
from app.schemas.workflow_schema import ( from app.schemas.workflow_schema import (
NodeDefinition, NodeDefinition,
@@ -85,7 +85,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
if not all(field in self.config for field in require_fields): if not all(field in self.config for field in require_fields):
return False return False
if self.config.get("app", {}).get("mode") == "workflow": if self.config.get("app", {}).get("mode") == "workflow":
self.errors.append(ExceptionDefinition( self.errors.append(ExceptionDefineition(
type=ExceptionType.PLATFORM, type=ExceptionType.PLATFORM,
detail="workflow mode is not supported" detail="workflow mode is not supported"
)) ))
@@ -111,12 +111,12 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
edge = self._convert_edge(edge) edge = self._convert_edge(edge)
if edge: if edge:
self.edges.append(edge) self.edges.append(edge)
#
for variable in self.config.get("workflow").get("conversation_variables"): for variable in self.config.get("workflow").get("conversation_variables"):
con_var = self._convert_variable(variable) con_var = self._convert_variable(variable)
if variable: if variable:
self.conv_variables.append(con_var) self.conv_variables.append(con_var)
#
# for variables in config.get("workflow").get("environment_variables"): # for variables in config.get("workflow").get("environment_variables"):
# variable = self._convert_variable(variables) # variable = self._convert_variable(variables)
# conv_variables.append(variable) # conv_variables.append(variable)
@@ -152,7 +152,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
"y": node["position"]["y"] + position["y"] "y": node["position"]["y"] + position["y"]
} }
self.errors.append( self.errors.append(
ExceptionDefinition( ExceptionDefineition(
type=ExceptionType.NODE, type=ExceptionType.NODE,
node_id=node_id, node_id=node_id,
detail="parent cycle node not found" detail="parent cycle node not found"
@@ -189,7 +189,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
node_data = node["data"] node_data = node["data"]
converter = self.get_node_convert(node_type) converter = self.get_node_convert(node_type)
if node_type == NodeType.UNKNOWN: if node_type == NodeType.UNKNOWN:
self.errors.append(ExceptionDefinition( self.errors.append(ExceptionDefineition(
type=ExceptionType.NODE, type=ExceptionType.NODE,
node_id=node["id"], node_id=node["id"],
node_name=node["data"]["title"], node_name=node["data"]["title"],
@@ -197,7 +197,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
)) ))
return converter(node) return converter(node)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefinition( self.errors.append(ExceptionDefineition(
type=ExceptionType.NODE, type=ExceptionType.NODE,
node_id=node["id"], node_id=node["id"],
node_name=node["data"]["title"], node_name=node["data"]["title"],
@@ -207,6 +207,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None: def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None:
try: try:
source = edge["source"] source = edge["source"]
target = edge["target"] target = edge["target"]
label = None label = None
@@ -229,7 +230,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
label=label, label=label,
) )
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefinition( self.errors.append(ExceptionDefineition(
type=ExceptionType.EDGE, type=ExceptionType.EDGE,
detail=f"convert edge error - {e}", detail=f"convert edge error - {e}",
)) ))
@@ -245,7 +246,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
description=variable.get("description") description=variable.get("description")
) )
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefinition( self.errors.append(ExceptionDefineition(
type=ExceptionType.VARIABLE, type=ExceptionType.VARIABLE,
name=variable.get("name"), name=variable.get("name"),
detail=f"convert variable error - {e}", detail=f"convert variable error - {e}",

View File

@@ -18,7 +18,7 @@ class ExceptionType(StrEnum):
UNKNOWN = "unknown" UNKNOWN = "unknown"
class ExceptionDefinition(BaseModel): class ExceptionDefineition(BaseModel):
type: ExceptionType type: ExceptionType
detail: str detail: str
@@ -29,7 +29,7 @@ class ExceptionDefinition(BaseModel):
name: str | None = None name: str | None = None
class UnknownModelWarning(ExceptionDefinition): class UnknowModelWarning(ExceptionDefineition):
type: ExceptionType = ExceptionType.NODE type: ExceptionType = ExceptionType.NODE
def __init__(self, node_id, node_name, model_name): def __init__(self, node_id, node_name, model_name):
@@ -40,36 +40,36 @@ class UnknownModelWarning(ExceptionDefinition):
) )
class UnknownError(ExceptionDefinition): class UnknowError(ExceptionDefineition):
type: ExceptionType = ExceptionType.UNKNOWN type: ExceptionType = ExceptionType.UNKNOWN
def __init__(self, detail: str, **kwargs): def __init__(self, detail: str, **kwargs):
super().__init__(detail=detail, **kwargs) super().__init__(detail=detail, **kwargs)
class UnsupportedPlatform(ExceptionDefinition): class UnsupportPlatform(ExceptionDefineition):
type: ExceptionType = ExceptionType.PLATFORM type: ExceptionType = ExceptionType.PLATFORM
def __init__(self, platform: str): def __init__(self, platform: str):
super().__init__(detail=f"Unsupported platform {platform}") super().__init__(detail=f"Unsupport platform {platform}")
class UnsupportedVariableType(ExceptionDefinition): class UnsupportVariableType(ExceptionDefineition):
type: ExceptionType = ExceptionType.VARIABLE type: ExceptionType = ExceptionType.VARIABLE
def __init__(self, scope, name, var_type: str, **kwargs): def __init__(self, scope, name, var_type: str, **kwargs):
super().__init__(scope=scope, name=name, detail=f"Unsupported variable type: [{var_type}]", **kwargs) super().__init__(scope=scope, name=name, detail=f"Unsupport variable type[{var_type}]", **kwargs)
class InvalidConfiguration(ExceptionDefinition): class InvalidConfiguration(ExceptionDefineition):
type: ExceptionType = ExceptionType.CONFIG type: ExceptionType = ExceptionType.CONFIG
def __init__(self): def __init__(self):
super().__init__(detail="Invalid workflow configuration format") super().__init__(detail="Invalid workflow configuration format")
class UnsupportedNodeType(ExceptionDefinition): class UnsupportNodeType(ExceptionDefineition):
type: ExceptionType = ExceptionType.NODE type: ExceptionType = ExceptionType.NODE
def __init__(self, node_id: str, node_type: str): def __init__(self, node_id: str, node_type: str):
super().__init__(node_id=node_id, detail=f"Unsupported node type {node_type}") super().__init__(node_id=node_id, detail=f"Unsupport node Type {node_type}")

View File

@@ -11,7 +11,7 @@ from app.core.workflow.adapters.base_adapter import (
BasePlatformAdapter, BasePlatformAdapter,
WorkflowParserResult WorkflowParserResult
) )
from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType, UnsupportedNodeType from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType
from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter
from app.core.workflow.nodes.enums import NodeType from app.core.workflow.nodes.enums import NodeType
from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition
@@ -73,7 +73,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
try: try:
node_type = self.map_node_type(node["type"]) node_type = self.map_node_type(node["type"])
if node_type == NodeType.UNKNOWN: if node_type == NodeType.UNKNOWN:
self.errors.append(UnsupportedNodeType( self.errors.append(UnsupportNodeType(
node_id=node_id, node_id=node_id,
node_type=node["type"] node_type=node["type"]
)) ))
@@ -85,7 +85,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
return NodeDefinition(**node) return NodeDefinition(**node)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefinition( self.errors.append(ExceptionDefineition(
type=ExceptionType.NODE, type=ExceptionType.NODE,
node_id=node_id, node_id=node_id,
node_name=node_name, node_name=node_name,
@@ -97,14 +97,14 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None: def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None:
try: try:
if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids: if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids:
self.warnings.append(ExceptionDefinition( self.warnings.append(ExceptionDefineition(
type=ExceptionType.EDGE, type=ExceptionType.EDGE,
detail=f"edge {edge.get('id')} skipped: source or target node not found" detail=f"edge {edge.get('id')} skipped: source or target node not found"
)) ))
return None return None
return EdgeDefinition(**edge) return EdgeDefinition(**edge)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefinition( self.errors.append(ExceptionDefineition(
type=ExceptionType.EDGE, type=ExceptionType.EDGE,
detail=f"convert edge error - {e}" detail=f"convert edge error - {e}"
)) ))
@@ -115,7 +115,7 @@ class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
try: try:
return VariableDefinition(**variable) return VariableDefinition(**variable)
except Exception as e: except Exception as e:
self.warnings.append(ExceptionDefinition( self.warnings.append(ExceptionDefineition(
type=ExceptionType.VARIABLE, type=ExceptionType.VARIABLE,
name=variable.get("name"), name=variable.get("name"),
detail=f"convert variable error - {e}" detail=f"convert variable error - {e}"

View File

@@ -1,6 +1,6 @@
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
from app.core.workflow.adapters.base_converter import BaseConverter from app.core.workflow.adapters.base_converter import BaseConverter
from app.core.workflow.adapters.errors import ExceptionDefinition, ExceptionType from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
from app.core.workflow.nodes.base_config import BaseNodeConfig from app.core.workflow.nodes.base_config import BaseNodeConfig
from app.core.workflow.nodes.configs import ( from app.core.workflow.nodes.configs import (
StartNodeConfig, StartNodeConfig,
@@ -65,7 +65,7 @@ class MemoryBearConverter(BaseConverter):
try: try:
return config_cls.model_validate(value) return config_cls.model_validate(value)
except Exception as e: except Exception as e:
self.errors.append(ExceptionDefinition( self.errors.append(ExceptionDefineition(
type=ExceptionType.CONFIG, type=ExceptionType.CONFIG,
node_id=node_id, node_id=node_id,
node_name=node_name, node_name=node_name,

View File

@@ -7,7 +7,7 @@ import re
import uuid import uuid
from collections import defaultdict from collections import defaultdict
from functools import lru_cache from functools import lru_cache
from typing import Any, Iterable, Callable from typing import Any, Iterable
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
from langgraph.graph import START, END from langgraph.graph import START, END
@@ -20,52 +20,48 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes import NodeFactory from app.core.workflow.nodes import NodeFactory
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
from app.core.workflow.utils.expression_evaluator import evaluate_condition from app.core.workflow.utils.expression_evaluator import evaluate_condition
from app.core.workflow.validator import WorkflowValidator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Regex to split output into:
# - variable placeholders: {{ ... }}
# - normal literal text
#
# Example:
# "Hello {{user.name}}!" ->
# ["Hello ", "{{user.name}}", "!"]
_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{}]+')
# Strict variable format: {{ node_id.field_name }}
_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*}}')
class GraphBuilder: class GraphBuilder:
def __init__( def __init__(
self, self,
workflow_config: dict[str, Any], workflow_config: dict[str, Any],
stream: bool = False, stream: bool = False,
cycle: str = '', subgraph: bool = False,
variable_pool: VariablePool | None = None variable_pool: VariablePool | None = None
): ):
self.workflow_config = workflow_config self.workflow_config = workflow_config
self.stream = stream self.stream = stream
self.cycle = cycle self.subgraph = subgraph
self.start_node_id: str | None = None self.start_node_id = None
self.end_node_ids = []
self.node_map: dict[str, dict] = {} self.node_map = {node["id"]: node for node in self.nodes}
self.end_node_map: dict[str, StreamOutputConfig] = {} self.end_node_map: dict[str, StreamOutputConfig] = {}
self._find_upstream_activation_dep: Callable = self._find_upstream_activation_dep self._find_upstream_branch_node = lru_cache(
maxsize=len(self.nodes) * 2
)(self._find_upstream_branch_node)
if variable_pool: if variable_pool:
self.variable_pool = variable_pool self.variable_pool = variable_pool
else: else:
self.variable_pool = VariablePool() self.variable_pool = VariablePool()
self.graph: StateGraph | None = None self.graph = StateGraph(WorkflowState)
self.nodes: list = [] self.add_nodes()
self.edges: list = [] self.add_edges()
self.reachable_nodes: set[str] | None = None self._analyze_end_node_output()
self.end_nodes: list[dict] = [] # EDGES MUST BE ADDED AFTER NODES ARE ADDED.
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
self._adj: dict[str, list[str]] = defaultdict(list) @property
def nodes(self) -> list[dict[str, Any]]:
return self.workflow_config.get("nodes", [])
@property
def edges(self) -> list[dict[str, Any]]:
return self.workflow_config.get("edges", [])
def get_node_type(self, node_id: str) -> str: def get_node_type(self, node_id: str) -> str:
"""Retrieve the type of node given its ID. """Retrieve the type of node given its ID.
@@ -91,51 +87,60 @@ class GraphBuilder:
result[node[0]].append(node[1]) result[node[0]].append(node[1])
return result return result
def _build_adj(self): def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]:
for edge in self.edges: """
if edge["source"] not in self.reachable_nodes: Recursively find all upstream branch (control) nodes that influence the execution
continue of the given target node.
self._reverse_adj[edge.get("target")].append({
"id": edge["source"], "branch": edge.get("label")
})
self._adj[edge.get("source")].append(edge["target"])
def _find_upstream_activation_dep( This method walks upstream along the workflow graph starting from `target_node`.
self, It distinguishes between:
target_node: str - branch nodes (node types listed in `BRANCH_NODES`)
) -> tuple[tuple[tuple[str, str]], tuple[str]]: - non-branch nodes (ordinary processing nodes)
"""Find upstream dependencies that affect the activation of a target node.
Walks upstream along the workflow graph from the target node, collecting Traversal rules:
two types of dependencies: 1. For each immediate upstream node:
- Branch control nodes: upstream branch nodes (e.g. if-else) whose - If it is a branch node, it is recorded as an affecting control node.
routing outcome determines whether the target node executes. - If it is a non-branch node, the traversal continues recursively upstream.
- Output nodes: upstream END nodes that must complete their output 2. If ANY upstream path reaches a START / CYCLE_START node without encountering
before the target node can activate. a branch node, the traversal is considered invalid:
- `has_branch` will be False
- no branch nodes are returned.
3. Only when ALL upstream non-branch paths eventually lead to at least one
branch node will `has_branch` be True.
The traversal terminates early and returns empty tuples if any upstream Special case:
path reaches START/CYCLE_START without encountering a branch or output - If `target_node` has no upstream nodes AND its type is START or CYCLE_START,
node, indicating the target node is directly reachable and should be it is considered directly reachable from the workflow entry, and therefore
activated immediately. has no controlling branch nodes.
Args: Args:
target_node: The ID of the node whose upstream activation target_node (str):
dependencies are to be resolved. The identifier of the node whose upstream control branches
are to be resolved.
Returns: Returns:
A tuple of two elements: tuple[bool, tuple[tuple[str, str]]]:
- A deduplicated tuple of (branch_node_id, branch_label) pairs - has_branch (bool):
representing upstream branch control dependencies. Empty if True if every upstream path from `target_node` encounters
any clean path to START exists. at least one branch node.
- A deduplicated tuple of upstream output node IDs that must False if any path reaches a start node without a branch.
complete before this node activates. - branch_nodes (tuple[tuple[str, str]]):
A deduplicated tuple of `(branch_node_id, branch_label)` pairs
representing all branch nodes that can influence `target_node`.
Returns an empty tuple if `has_branch` is False.
""" """
source_nodes = self._reverse_adj[target_node] source_nodes = [
{
"id": edge.get("source"),
"branch": edge.get("label")
}
for edge in self.edges
if edge.get("target") == target_node
]
if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]: if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
return tuple(), tuple() return False, tuple()
branch_nodes = [] branch_nodes = []
output_nodes = []
non_branch_nodes = [] non_branch_nodes = []
for node_info in source_nodes: for node_info in source_nodes:
@@ -144,23 +149,19 @@ class GraphBuilder:
(node_info["id"], node_info["branch"]) (node_info["id"], node_info["branch"])
) )
else: else:
if self.get_node_type(node_info["id"]) == NodeType.END:
output_nodes.append(node_info["id"])
non_branch_nodes.append(node_info["id"]) non_branch_nodes.append(node_info["id"])
has_branch = True has_branch = True
for node_id in non_branch_nodes: for node_id in non_branch_nodes:
upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(node_id) node_has_branch, nodes = self._find_upstream_branch_node(node_id)
if not upstream_control_nodes: has_branch = has_branch and node_has_branch
if not upstream_output_nodes and node_id not in output_nodes: if not has_branch:
return tuple(), tuple() break
branch_nodes = [] branch_nodes.extend(nodes)
has_branch = False if not has_branch:
if has_branch: branch_nodes = []
branch_nodes.extend(upstream_control_nodes)
output_nodes.extend(upstream_output_nodes)
return tuple(set(branch_nodes)), tuple(set(output_nodes)) return has_branch, tuple(set(branch_nodes))
def _analyze_end_node_output(self): def _analyze_end_node_output(self):
""" """
@@ -181,10 +182,11 @@ class GraphBuilder:
""" """
# Collect all End nodes in the workflow # Collect all End nodes in the workflow
logger.info(f"[Prefix Analysis] Found {len(self.end_nodes)} End nodes") end_nodes = [node for node in self.nodes if node.get("type") == "end"]
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
# Iterate through each End node to analyze its output # Iterate through each End node to analyze its output
for end_node in self.end_nodes: for end_node in end_nodes:
end_node_id = end_node.get("id") end_node_id = end_node.get("id")
config = end_node.get("config", {}) config = end_node.get("config", {})
output = config.get("output") output = config.get("output")
@@ -193,33 +195,42 @@ class GraphBuilder:
if not output: if not output:
continue continue
# Regex to split output into:
# - variable placeholders: {{ ... }}
# - normal literal text
#
# Example:
# "Hello {{user.name}}!" ->
# ["Hello ", "{{user.name}}", "!"]
pattern = r'\{\{.*?\}\}|[^{}]+'
# Strict variable format: {{ node_id.field_name }}
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
variable_pattern = re.compile(variable_pattern_string)
# Split output into ordered segments # Split output into ordered segments
output_template = list(_OUTPUT_PATTERN.findall(output)) output_template = list(re.findall(pattern, output))
# Determine whether each segment is literal text # Determine whether each segment is literal text
# True -> literal (can be directly output) # True -> literal (can be directly output)
# False -> variable placeholder (needs runtime value) # False -> variable placeholder (needs runtime value)
output_flag = [ output_flag = [
not bool(_VARIABLE_PATTERN.match(item)) not bool(variable_pattern.match(item))
for item in output_template for item in output_template
] ]
# Stream mode: output activation depends on upstream branch nodes # Stream mode: output activation depends on upstream branch nodes
if self.stream: if self.stream:
# Find upstream branch nodes that can control this End node # Find upstream branch nodes that can control this End node
upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(end_node_id) has_branch, control_nodes = self._find_upstream_branch_node(end_node_id)
activate = not bool(upstream_control_nodes) and not bool(upstream_output_nodes)
# Build StreamOutputConfig for this End node # Build StreamOutputConfig for this End node
self.end_node_map[end_node_id] = StreamOutputConfig( self.end_node_map[end_node_id] = StreamOutputConfig(
id=end_node_id,
# If there is no upstream branch, output is active immediately # If there is no upstream branch, output is active immediately
activate=activate, activate=not has_branch,
# Branch nodes that control activation of this End node # Branch nodes that control activation of this End node
control_nodes=self._merge_control_nodes(upstream_control_nodes), control_nodes=self._merge_control_nodes(control_nodes),
upstream_output_nodes=list(upstream_output_nodes),
control_resolved=not bool(upstream_control_nodes),
output_resolved=not bool(upstream_output_nodes),
# Convert output segments into OutputContent objects # Convert output segments into OutputContent objects
outputs=list( outputs=list(
@@ -238,16 +249,14 @@ class GraphBuilder:
cursor=0 cursor=0
) )
logger.info(f"[Stream Analysis] end_id: {end_node_id}, " logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
f"activate: {activate}, " f"activate: {not has_branch}, "
f"control_nodes: {upstream_control_nodes}," f"control_nodes: {control_nodes},"
f"ref_outputs: {upstream_output_nodes},"
f"output: {output_template}," f"output: {output_template},"
f"output_activate: {output_flag}") f"output_activate: {output_flag}")
# Non-stream mode: all outputs are activated by default # Non-stream mode: all outputs are activated by default
else: else:
self.end_node_map[end_node_id] = StreamOutputConfig( self.end_node_map[end_node_id] = StreamOutputConfig(
id=end_node_id,
activate=True, activate=True,
control_nodes={}, control_nodes={},
outputs=list( outputs=list(
@@ -260,10 +269,7 @@ class GraphBuilder:
for output_string, activate in zip(output_template, output_flag) for output_string, activate in zip(output_template, output_flag)
] ]
), ),
cursor=0, cursor=0
upstream_output_nodes=[],
control_resolved=True,
output_resolved=True,
) )
def add_nodes(self): def add_nodes(self):
@@ -286,13 +292,24 @@ class GraphBuilder:
""" """
for node in self.nodes: for node in self.nodes:
node_type = node.get("type") node_type = node.get("type")
node_id = node.get("id") if node_type == NodeType.NOTES:
if node_id not in self.reachable_nodes:
continue continue
node_id = node.get("id")
cycle_node = node.get("cycle")
if cycle_node:
# Nodes within a loop subgraph are constructed by CycleGraphNode
if not self.subgraph:
continue
# Record start and end node IDs
if node_type in [NodeType.START, NodeType.CYCLE_START]:
self.start_node_id = node_id
elif node_type == NodeType.END:
self.end_node_ids.append(node_id)
# Create node instance (start and end nodes are also created) # Create node instance (start and end nodes are also created)
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph # NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
node_instance = NodeFactory.create_node(node, self.workflow_config, self._adj[node_id]) node_instance = NodeFactory.create_node(node, self.workflow_config)
if node_type in BRANCH_NODES: if node_type in BRANCH_NODES:
@@ -365,8 +382,6 @@ class GraphBuilder:
for edge in self.edges: for edge in self.edges:
source = edge.get("source") source = edge.get("source")
target = edge.get("target") target = edge.get("target")
if source not in self.reachable_nodes or target not in self.reachable_nodes:
continue
condition = edge.get("condition") condition = edge.get("condition")
edge_type = edge.get("type") edge_type = edge.get("type")
@@ -388,12 +403,11 @@ class GraphBuilder:
# Add conditional edges # Add conditional edges
for source_node, branches in conditional_edges.items(): for source_node, branches in conditional_edges.items():
def make_router(src, branch_list): def make_router(src, branch_list):
"""Create a router function for each source node that routes to a NOP node for later merging.""" """reate a router function for each source node that routes to a NOP node for later merging."""
def make_branch_node(node_name, targets): def make_branch_node(node_name, targets):
def node(s): def node(s):
# NOTE: NOP NODE USED FOR ROUTING ONLY. # NOTE: NOP NODE MUST NOT MODIFY STATE
# MUST NOT MUTATE STATE DIRECTLY; ONLY EMIT ACTIVATE SIGNALS.
return { return {
"activate": { "activate": {
node_id: s["activate"][node_name] node_id: s["activate"][node_name]
@@ -434,7 +448,7 @@ class GraphBuilder:
branch_activate = [] branch_activate = []
new_state = state.copy() new_state = state.copy()
new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate
node_output = variable_pool.get_node_output(src, default=dict(), strict=False) node_output = variable_pool.get_node_output(src, defalut=dict(), strict=False)
for label, branch in unique_branch.items(): for label, branch in unique_branch.items():
if node_output and evaluate_condition( if node_output and evaluate_condition(
branch["condition"], branch["condition"],
@@ -480,52 +494,12 @@ class GraphBuilder:
logger.debug(f"Added waiting edge: {sources} -> {target}") logger.debug(f"Added waiting edge: {sources} -> {target}")
# Connect End nodes to the global END node # Connect End nodes to the global END node
for node in self.reachable_nodes: for end_node_id in self.end_node_ids:
if not self._adj[node]: self.graph.add_edge(end_node_id, END)
self.graph.add_edge(node, END) logger.debug(f"Added edge: {end_node_id} -> END")
return return
def build(self) -> CompiledStateGraph: def build(self) -> CompiledStateGraph:
nodes = self.workflow_config.get("nodes", [])
edges = self.workflow_config.get("edges", [])
for node in nodes:
if (node.get("cycle") or '') == self.cycle:
node_type = node.get("type")
if node_type in [NodeType.START, NodeType.CYCLE_START]:
self.start_node_id = node.get("id")
elif node_type == NodeType.NOTES:
continue
self.nodes.append(node)
self.node_map[node.get("id")] = node
for edge in edges:
source_in = edge.get("source") in self.node_map
target_in = edge.get("target") in self.node_map
if source_in ^ target_in:
raise ValueError(
f"Cycle node is connected to external node, "
f"source: {edge.get('source')}, target: {edge.get('target')}"
)
if source_in and target_in:
self.edges.append(edge)
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
self.end_nodes = [
node
for node in self.nodes
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
]
self._build_adj()
self._find_upstream_activation_dep: Callable = lru_cache(
maxsize=len(self.nodes)*2
)(self._find_upstream_activation_dep)
self.graph = StateGraph(WorkflowState)
self.add_nodes()
self.add_edges()
self._analyze_end_node_output()
checkpointer = InMemorySaver() checkpointer = InMemorySaver()
return self.graph.compile(checkpointer=checkpointer) self.graph = self.graph.compile(checkpointer=checkpointer)
return self.graph

View File

@@ -2,7 +2,6 @@
# Author: Eternity # Author: Eternity
# @Email: 1533512157@qq.com # @Email: 1533512157@qq.com
# @Time : 2026/2/10 13:33 # @Time : 2026/2/10 13:33
from app.core.workflow.engine.runtime_schema import ExecutionContext
from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.engine.variable_pool import VariablePool
@@ -10,11 +9,9 @@ class WorkflowResultBuilder:
def build_final_output( def build_final_output(
self, self,
result: dict, result: dict,
execution_context: ExecutionContext,
variable_pool: VariablePool, variable_pool: VariablePool,
elapsed_time: float, elapsed_time: float,
final_output: str, final_output: str,
success: bool
): ):
"""Construct the final standardized output of the workflow execution. """Construct the final standardized output of the workflow execution.
@@ -28,13 +25,10 @@ class WorkflowResultBuilder:
- "node_outputs" (dict): Outputs of executed nodes. - "node_outputs" (dict): Outputs of executed nodes.
- "messages" (list): Conversation messages exchanged during execution. - "messages" (list): Conversation messages exchanged during execution.
- "error" (str, optional): Error message if any node failed. - "error" (str, optional): Error message if any node failed.
execution_context (ExecutionContext): The execution context containing metadata like
execution ID, workspace ID, and user ID.)
variable_pool (VariablePool): Variable Pool variable_pool (VariablePool): Variable Pool
elapsed_time (float): Total execution time in seconds. elapsed_time (float): Total execution time in seconds.
final_output (Any): The aggregated or final output content of the workflow final_output (Any): The aggregated or final output content of the workflow
(e.g., combined messages from all End nodes). (e.g., combined messages from all End nodes).
success (bool): Whether the execution was successful.
Returns: Returns:
dict: A dictionary containing the final workflow execution result with keys: dict: A dictionary containing the final workflow execution result with keys:
@@ -52,23 +46,18 @@ class WorkflowResultBuilder:
""" """
node_outputs = result.get("node_outputs", {}) node_outputs = result.get("node_outputs", {})
token_usage = self.aggregate_token_usage(node_outputs) token_usage = self.aggregate_token_usage(node_outputs)
conversation_vars = {} conversation_id = variable_pool.get_value("sys.conversation_id")
sys_vars = {}
if variable_pool:
conversation_vars = variable_pool.get_all_conversation_vars()
sys_vars = variable_pool.get_all_system_vars()
return { return {
"status": "completed" if success else "failed", "status": "completed",
"output": final_output, "output": final_output,
"variables": { "variables": {
"conv": conversation_vars, "conv": variable_pool.get_all_conversation_vars(),
"sys": sys_vars "sys": variable_pool.get_all_system_vars()
}, },
"node_outputs": node_outputs, "node_outputs": node_outputs,
"messages": result.get("messages", []), "messages": result.get("messages", []),
"conversation_id": execution_context.conversation_id, "conversation_id": conversation_id,
"elapsed_time": elapsed_time, "elapsed_time": elapsed_time,
"token_usage": token_usage, "token_usage": token_usage,
"error": result.get("error"), "error": result.get("error"),

View File

@@ -12,29 +12,14 @@ class ExecutionContext(BaseModel):
execution_id: str execution_id: str
workspace_id: str workspace_id: str
user_id: str user_id: str
conversation_id: str
memory_storage_type: str
user_rag_memory_id: str
checkpoint_config: RunnableConfig checkpoint_config: RunnableConfig
@classmethod @classmethod
def create( def create(cls, execution_id: str, workspace_id: str, user_id: str):
cls,
execution_id: str,
workspace_id: str,
user_id: str,
conversation_id: str,
memory_storage_type: str,
user_rag_memory_id: str
):
return cls( return cls(
execution_id=execution_id, execution_id=execution_id,
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id, user_id=user_id,
conversation_id=conversation_id,
memory_storage_type=memory_storage_type,
user_rag_memory_id=user_rag_memory_id,
checkpoint_config=RunnableConfig( checkpoint_config=RunnableConfig(
configurable={ configurable={
"thread_id": uuid.uuid4(), "thread_id": uuid.uuid4(),

View File

@@ -33,8 +33,6 @@ class WorkflowState(dict):
"workspace_id", "workspace_id",
"user_id", "user_id",
"activate", "activate",
"memory_storage_type",
"user_rag_memory_id"
}) })
__optional_keys__ = frozenset({ __optional_keys__ = frozenset({
"error", "error",
@@ -64,9 +62,6 @@ class WorkflowState(dict):
# node activate status # node activate status
activate: Annotated[dict[str, bool], merge_activate_state] activate: Annotated[dict[str, bool], merge_activate_state]
memory_storage_type: str
user_rag_memory_id: str
class WorkflowStateManager: class WorkflowStateManager:
def create_initial_state( def create_initial_state(
@@ -90,9 +85,7 @@ class WorkflowStateManager:
looping=0, looping=0,
activate={ activate={
start_node_id: True start_node_id: True
}, }
memory_storage_type=execution_context.memory_storage_type,
user_rag_memory_id=execution_context.user_rag_memory_id
) )
@staticmethod @staticmethod

View File

@@ -3,7 +3,6 @@
# @Email: 1533512157@qq.com # @Email: 1533512157@qq.com
# @Time : 2026/2/9 15:11 # @Time : 2026/2/9 15:11
import re import re
from collections import deque
from typing import AsyncGenerator from typing import AsyncGenerator
from pydantic import BaseModel, Field, PrivateAttr from pydantic import BaseModel, Field, PrivateAttr
@@ -38,8 +37,8 @@ class OutputContent(BaseModel):
activate: bool = Field( activate: bool = Field(
..., ...,
description=( description=(
"Whether this output segment is currently active." "Whether this output segment is currently active.\n"
"- True: allowed to be emitted/output" "- True: allowed to be emitted/output\n"
"- False: blocked until activated by branch control" "- False: blocked until activated by branch control"
) )
) )
@@ -47,8 +46,8 @@ class OutputContent(BaseModel):
is_variable: bool = Field( is_variable: bool = Field(
..., ...,
description=( description=(
"Whether this segment represents a variable placeholder." "Whether this segment represents a variable placeholder.\n"
"True -> variable (e.g. {{ node.field }})" "True -> variable (e.g. {{ node.field }})\n"
"False -> literal text" "False -> literal text"
) )
) )
@@ -87,16 +86,12 @@ class StreamOutputConfig(BaseModel):
- which upstream branch/control nodes gate the activation - which upstream branch/control nodes gate the activation
- how each parsed output segment is streamed and activated - how each parsed output segment is streamed and activated
""" """
id: str = Field(
...,
description="ID of the End node this configuration belongs to."
)
activate: bool = Field( activate: bool = Field(
..., ...,
description=( description=(
"Global activation flag for the End node output." "Global activation flag for the End node output.\n"
"When False, output segments should not be emitted even if available." "When False, output segments should not be emitted even if available.\n"
"This flag typically becomes True once required control branch conditions " "This flag typically becomes True once required control branch conditions "
"are satisfied." "are satisfied."
) )
@@ -105,46 +100,17 @@ class StreamOutputConfig(BaseModel):
control_nodes: dict[str, list[str]] = Field( control_nodes: dict[str, list[str]] = Field(
..., ...,
description=( description=(
"Control branch conditions for this End node output." "Control branch conditions for this End node output.\n"
"Mapping of `branch_node_id -> expected_branch_label`." "Mapping of `branch_node_id -> expected_branch_label`.\n"
"The End node output becomes globally active when a controlling branch node " "The End node output becomes globally active when a controlling branch node "
"reports a matching completion status." "reports a matching completion status."
) )
) )
upstream_output_nodes: list[str] = Field(
...,
description=(
"Upstream output node dependencies (data flow)."
"Represents END/output nodes that this output depends on."
"These nodes provide data sources required before this output can be activated "
"or streamed."
"Used to ensure correct ordering and dependency resolution in streaming mode."
)
)
control_resolved: bool = Field(
...,
description=(
"Whether all upstream branch control dependencies have been satisfied."
"True if no upstream branch nodes exist or the required branch "
"conditions have been met."
)
)
output_resolved: bool = Field(
...,
description=(
"Whether all upstream output node dependencies have been completed."
"True if no upstream output nodes exist or all upstream output "
"nodes have finished their output."
)
)
outputs: list[OutputContent] = Field( outputs: list[OutputContent] = Field(
..., ...,
description=( description=(
"Ordered list of output segments parsed from the output template." "Ordered list of output segments parsed from the output template.\n"
"Each segment represents either a literal text block or a variable placeholder " "Each segment represents either a literal text block or a variable placeholder "
"that may be activated independently." "that may be activated independently."
) )
@@ -153,97 +119,49 @@ class StreamOutputConfig(BaseModel):
cursor: int = Field( cursor: int = Field(
..., ...,
description=( description=(
"Streaming cursor index." "Streaming cursor index.\n"
"Indicates the next output segment index to be emitted." "Indicates the next output segment index to be emitted.\n"
"Segments with index < cursor are considered already streamed." "Segments with index < cursor are considered already streamed."
) )
) )
force: bool = Field(
default=False,
description=(
"Force flag for output emission."
"When True, all output segments are emitted regardless of activation state."
"Triggered when this output node has finished execution."
)
)
def update_activate(self, scope: str, status=None): def update_activate(self, scope: str, status=None):
""" """
Update streaming activation state based on upstream events. Update streaming activation state based on an upstream node or special variable.
Args: Args:
scope (str): scope (str):
Identifier of the completed upstream entity. Identifier of the completed upstream entity.
- If a control branch node, it should match a key in `control_nodes`. - If a control branch node, it should match a key in `control_nodes`.
- If an upstream output node, it should match an entry in `upstream_output_nodes`. - If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
- If a variable placeholder (e.g., "sys.xxx" or "node_id.field"),
it may appear in output segments.
status (optional): status (optional):
Completion status of the control branch node. Completion status of the control branch node.
Required when `scope` refers to a control node. Required when `scope` refers to a control node.
Behavior: Behavior:
1. Force activation: 1. Control branch nodes:
- If `self.force` is True, the method returns immediately. - If `scope` matches a key in `control_nodes` and `status` matches the expected
- If `scope == self.id`, the node marks itself as completed: branch label, the End node output becomes globally active (`activate = True`).
- `activate = True`
- `force = True`
This is typically used for final flushing when the node finishes execution.
2. Control dependency resolution: 2. Variable output segments:
- If `scope` matches a key in `control_nodes`: - For each segment that is a variable (`is_variable=True`):
- `status` must be provided. - If the segment literal references `scope`, mark the segment as active.
- If `status` matches expected branch labels, mark control as resolved - This applies both to regular node variables (e.g., "node_id.field")
(`control_resolved = True`). and special system variables (e.g., "sys.xxx").
3. Upstream output dependency resolution:
- If `scope` is in `upstream_output_nodes`,
mark data dependency as resolved (`output_resolved = True`).
4. Global activation condition:
- The node becomes active when BOTH conditions are satisfied:
- control_resolved == True
- output_resolved == True
- Once activated, `activate` remains True.
5. Variable segment activation:
- For each output segment that is a variable (`is_variable=True`):
- If the segment depends on the given `scope`,
mark the segment as active.
- This applies to both node variables (e.g., "node_id.field")
and system variables (e.g., "sys.xxx").
Notes: Notes:
- This method does NOT emit output or advance the streaming cursor. - This method does not emit output or advance the streaming cursor.
- It only updates activation and dependency resolution states. - It only updates activation flags based on upstream events or special variables.
- Activation is driven by both control flow (branch nodes) and
data flow (upstream output nodes).
""" """
if self.force:
return
if scope == self.id: # Case 1: resolve control branch dependency
self.activate = True
self.force = True
return
# resolve control branch dependency
if scope in self.control_nodes: if scope in self.control_nodes:
if status is None: if status is None:
raise RuntimeError("[Stream Output] Control node activation status not provided") raise RuntimeError("[Stream Output] Control node activation status not provided")
if status in self.control_nodes[scope]: if status in self.control_nodes[scope]:
self.control_resolved = True self.activate = True
if scope in self.upstream_output_nodes: # Case 2: activate variable segments related to this node
self.upstream_output_nodes.remove(scope)
if not self.upstream_output_nodes:
self.output_resolved = True
self.activate = self.activate or (self.control_resolved and self.output_resolved)
# activate variable segments related to this node
for i in range(len(self.outputs)): for i in range(len(self.outputs)):
if ( if (
self.outputs[i].is_variable self.outputs[i].is_variable
@@ -256,17 +174,12 @@ class StreamOutputCoordinator:
def __init__(self): def __init__(self):
self.end_outputs: dict[str, StreamOutputConfig] = {} self.end_outputs: dict[str, StreamOutputConfig] = {}
self.activate_end: str | None = None self.activate_end: str | None = None
self.output_queue: deque[str] = deque()
self.processed_outputs = []
def initialize_end_outputs( def initialize_end_outputs(
self, self,
end_node_map: dict[str, StreamOutputConfig] end_node_map: dict[str, StreamOutputConfig]
): ):
self.end_outputs = end_node_map self.end_outputs = end_node_map
self.processed_outputs = []
self.activate_end = None
self.output_queue = deque()
@property @property
def current_activate_end_info(self): def current_activate_end_info(self):
@@ -296,13 +209,10 @@ class StreamOutputCoordinator:
scope (str): The node ID or scope that has completed execution. scope (str): The node ID or scope that has completed execution.
status (str | None): Optional status of the node (used for branch/control nodes). status (str | None): Optional status of the node (used for branch/control nodes).
""" """
for node in self.end_outputs: for node in self.end_outputs.keys():
self.end_outputs[node].update_activate(scope, status) self.end_outputs[node].update_activate(scope, status)
if self.end_outputs[node].activate and node not in self.processed_outputs: if self.end_outputs[node].activate and self.activate_end is None:
self.output_queue.append(node) self.activate_end = node
self.processed_outputs.append(node)
if self.activate_end is None and self.output_queue:
self.activate_end = self.output_queue.popleft()
async def emit_activate_chunk( async def emit_activate_chunk(
self, self,
@@ -346,7 +256,7 @@ class StreamOutputCoordinator:
final_chunk = '' final_chunk = ''
current_segment = end_info.outputs[end_info.cursor] current_segment = end_info.outputs[end_info.cursor]
if not current_segment.activate and not force and not end_info.force: if not current_segment.activate and not force:
# Stop processing until this segment becomes active # Stop processing until this segment becomes active
break break
@@ -363,7 +273,7 @@ class StreamOutputCoordinator:
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}") logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}")
if final_chunk: if final_chunk:
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk_length:{len(final_chunk)}") logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}")
yield { yield {
"event": "message", "event": "message",
"data": { "data": {
@@ -375,7 +285,8 @@ class StreamOutputCoordinator:
end_info.cursor += 1 end_info.cursor += 1
if end_info.cursor >= len(end_info.outputs): if end_info.cursor >= len(end_info.outputs):
self.pop_current_activate_end() self.end_outputs.pop(self.activate_end)
self.activate_end = None
async def flush_remaining_chunk( async def flush_remaining_chunk(
self, self,
@@ -414,8 +325,6 @@ class StreamOutputCoordinator:
async for msg_event in self.emit_activate_chunk(variable_pool, force=True): async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
yield msg_event yield msg_event
if self.output_queue:
self.activate_end = self.output_queue.popleft()
# Move to next active End node if current one is done # Move to next active End node if current one is done
if not self.activate_end and self.end_outputs: if not self.activate_end and self.end_outputs:
self.activate_end = list(self.end_outputs.keys())[0] self.activate_end = list(self.end_outputs.keys())[0]

View File

@@ -13,7 +13,7 @@ from pydantic import BaseModel
from app.core.workflow.engine.runtime_schema import ExecutionContext from app.core.workflow.engine.runtime_schema import ExecutionContext
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
from app.core.workflow.variable.variable_objects import T, create_variable_instance, ArrayVariable, FileVariable from app.core.workflow.variable.variable_objects import T, create_variable_instance
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -351,12 +351,12 @@ class VariablePool:
} }
return runtime_vars return runtime_vars
def get_node_output(self, node_id: str, default: Any = None, strict: bool = True) -> dict[str, Any] | None: def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None:
"""获取指定节点的输出(运行时变量) """获取指定节点的输出(运行时变量)
Args: Args:
node_id: 节点 ID node_id: 节点 ID
default: 默认值 defalut: 默认值
strict: 是否严格模式 strict: 是否严格模式
Returns: Returns:
@@ -368,21 +368,11 @@ class VariablePool:
if strict: if strict:
raise KeyError(f"node {node_id} output not exist") raise KeyError(f"node {node_id} output not exist")
else: else:
return default return defalut
def copy(self, pool: 'VariablePool'): def copy(self, pool: 'VariablePool'):
self.variables = deepcopy(pool.variables) self.variables = deepcopy(pool.variables)
def is_file_variable(self, selector):
variable_struct = self.get_instance(selector, default=None, strict=False)
if variable_struct is None:
return False
if isinstance(variable_struct, FileVariable):
return True
elif isinstance(variable_struct, ArrayVariable) and variable_struct.child_type == FileVariable:
return True
return False
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
"""导出为字典 """导出为字典

View File

@@ -3,7 +3,6 @@
# @Email: 1533512157@qq.com # @Email: 1533512157@qq.com
# @Time : 2026/2/9 13:51 # @Time : 2026/2/9 13:51
import datetime import datetime
import time
import logging import logging
from typing import Any from typing import Any
@@ -83,15 +82,13 @@ class WorkflowExecutor:
CompiledStateGraph: The compiled and ready-to-run state graph. CompiledStateGraph: The compiled and ready-to-run state graph.
""" """
logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}") logger.info(f"Starting workflow graph build: execution_id={self.execution_context.execution_id}")
start_time = time.time()
builder = GraphBuilder( builder = GraphBuilder(
self.workflow_config, self.workflow_config,
stream=stream, stream=stream,
) )
self.graph = builder.build()
self.start_node_id = builder.start_node_id self.start_node_id = builder.start_node_id
self.variable_pool = builder.variable_pool self.variable_pool = builder.variable_pool
self.graph = builder.build()
self.stream_coordinator.initialize_end_outputs(builder.end_node_map) self.stream_coordinator.initialize_end_outputs(builder.end_node_map)
self.event_handler = EventStreamHandler( self.event_handler = EventStreamHandler(
@@ -99,8 +96,7 @@ class WorkflowExecutor:
variable_pool=self.variable_pool, variable_pool=self.variable_pool,
execution_id=self.execution_context.execution_id execution_id=self.execution_context.execution_id
) )
logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}, " logger.info(f"Workflow graph build completed: execution_id={self.execution_context.execution_id}")
f"cost: {time.time() - start_time:.4f}s")
return self.graph return self.graph
@@ -132,18 +128,89 @@ class WorkflowExecutor:
- token_usage: aggregated token usage if available - token_usage: aggregated token usage if available
- error: error message if any - error: error message if any
""" """
start = datetime.datetime.now() logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
async for event in self.execute_stream(input_data):
if event.get("event") == "workflow_end": start_time = datetime.datetime.now()
return event.get("data")
return self.result_builder.build_final_output( # Execute the workflow
{"error": "Workflow execution did not end as expected"}, try:
self.execution_context, # Build the workflow graph
self.variable_pool, graph = self.build_graph()
(datetime.datetime.now() - start).total_seconds(),
"", # Initialize the variable pool with input data
success=False await self.variable_initializer.initialize(
) variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
# Aggregate output from all End nodes
full_content = ''
for end_id in self.stream_coordinator.end_outputs.keys():
full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
# Append messages for user and assistant
if input_data.get("files"):
result["messages"].extend(
[
{
"role": "user",
"content": input_data.get("message", '')
},
{
"role": "user",
"content": input_data.get("files")
},
{
"role": "assistant",
"content": full_content
}
]
)
else:
result["messages"].extend(
[
{
"role": "user",
"content": input_data.get("message", '')
},
{
"role": "assistant",
"content": full_content
}
]
)
# Calculate elapsed time
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
logger.info(
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
except Exception as e:
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
exc_info=True)
return {
"status": "failed",
"error": str(e),
"output": None,
"node_outputs": {},
"elapsed_time": elapsed_time,
"token_usage": None
}
async def execute_stream( async def execute_stream(
self, self,
@@ -177,12 +244,11 @@ class WorkflowExecutor:
"data": { "data": {
"execution_id": self.execution_context.execution_id, "execution_id": self.execution_context.execution_id,
"workspace_id": self.execution_context.workspace_id, "workspace_id": self.execution_context.workspace_id,
"conversation_id": self.execution_context.conversation_id, "conversation_id": input_data.get("conversation_id"),
"timestamp": int(start_time.timestamp() * 1000) "timestamp": int(start_time.timestamp() * 1000)
} }
} }
result = None
full_content = ''
try: try:
# Build the workflow graph in streaming mode # Build the workflow graph in streaming mode
graph = self.build_graph(stream=True) graph = self.build_graph(stream=True)
@@ -200,6 +266,7 @@ class WorkflowExecutor:
start_node_id=self.start_node_id start_node_id=self.start_node_id
) )
full_content = ''
self.stream_coordinator.update_scope_activation("sys") self.stream_coordinator.update_scope_activation("sys")
# Execute the workflow with streaming # Execute the workflow with streaming
@@ -296,13 +363,7 @@ class WorkflowExecutor:
yield { yield {
"event": "workflow_end", "event": "workflow_end",
"data": self.result_builder.build_final_output( "data": self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
result,
self.execution_context,
self.variable_pool,
elapsed_time,
full_content,
success=True)
} }
except Exception as e: except Exception as e:
@@ -311,20 +372,16 @@ class WorkflowExecutor:
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}", logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
exc_info=True) exc_info=True)
if result is None:
result = {"error": str(e)}
else:
result["error"] = str(e)
yield { yield {
"event": "workflow_end", "event": "workflow_end",
"data": self.result_builder.build_final_output( "data": {
result, "execution_id": self.execution_context.execution_id,
self.execution_context, "status": "failed",
self.variable_pool, "error": str(e),
elapsed_time, "elapsed_time": elapsed_time,
full_content, "timestamp": end_time.isoformat()
success=False }
)
} }
@@ -333,9 +390,7 @@ async def execute_workflow(
input_data: dict[str, Any], input_data: dict[str, Any],
execution_id: str, execution_id: str,
workspace_id: str, workspace_id: str,
user_id: str, user_id: str
memory_storage_type: str,
user_rag_memory_id: str
) -> dict[str, Any]: ) -> dict[str, Any]:
""" """
Execute a workflow (convenience function, non-streaming). Execute a workflow (convenience function, non-streaming).
@@ -346,8 +401,6 @@ async def execute_workflow(
execution_id (str): Execution ID. execution_id (str): Execution ID.
workspace_id (str): Workspace ID. workspace_id (str): Workspace ID.
user_id (str): User ID. user_id (str): User ID.
user_rag_memory_id: rag knowledge db id
memory_storage_type: neo4j / rag
Returns: Returns:
dict: Workflow execution result. dict: Workflow execution result.
@@ -355,10 +408,7 @@ async def execute_workflow(
execution_context = ExecutionContext.create( execution_context = ExecutionContext.create(
execution_id=execution_id, execution_id=execution_id,
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id, user_id=user_id
conversation_id=input_data.get("conversation_id"),
memory_storage_type=memory_storage_type,
user_rag_memory_id=user_rag_memory_id
) )
executor = WorkflowExecutor( executor = WorkflowExecutor(
workflow_config=workflow_config, workflow_config=workflow_config,
@@ -372,9 +422,7 @@ async def execute_workflow_stream(
input_data: dict[str, Any], input_data: dict[str, Any],
execution_id: str, execution_id: str,
workspace_id: str, workspace_id: str,
user_id: str, user_id: str
memory_storage_type: str,
user_rag_memory_id: str
): ):
""" """
Execute a workflow in streaming mode (convenience function). Execute a workflow in streaming mode (convenience function).
@@ -385,8 +433,6 @@ async def execute_workflow_stream(
execution_id (str): Execution ID. execution_id (str): Execution ID.
workspace_id (str): Workspace ID. workspace_id (str): Workspace ID.
user_id (str): User ID. user_id (str): User ID.
user_rag_memory_id: rag knowledge db id
memory_storage_type: neo4j / rag
Yields: Yields:
dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end. dict: Streaming workflow events, e.g. node start, node end, chunk messages, workflow end.
@@ -394,10 +440,7 @@ async def execute_workflow_stream(
execution_context = ExecutionContext.create( execution_context = ExecutionContext.create(
execution_id=execution_id, execution_id=execution_id,
workspace_id=workspace_id, workspace_id=workspace_id,
user_id=user_id, user_id=user_id
memory_storage_type=memory_storage_type,
conversation_id=input_data.get("conversation_id"),
user_rag_memory_id=user_rag_memory_id
) )
executor = WorkflowExecutor( executor = WorkflowExecutor(
workflow_config=workflow_config, workflow_config=workflow_config,

View File

@@ -64,7 +64,9 @@ class AgentNode(BaseNode):
if not release: if not release:
raise ValueError(f"Agent 不存在: {agent_id}") raise ValueError(f"Agent 不存在: {agent_id}")
return release, message return release, message
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:

View File

@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
class AssignerNode(BaseNode): class AssignerNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config, down_stream_nodes) super().__init__(node_config, workflow_config)
self.variable_updater = True self.variable_updater = True
self.typed_config: AssignerNodeConfig | None = None self.typed_config: AssignerNodeConfig | None = None

View File

@@ -28,7 +28,7 @@ class BaseNode(ABC):
All node types should inherit from this class and implement the `execute` method. All node types should inherit from this class and implement the `execute` method.
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
"""Initialize the node. """Initialize the node.
Args: Args:
@@ -41,7 +41,6 @@ class BaseNode(ABC):
self.node_type = node_config["type"] self.node_type = node_config["type"]
self.cycle = node_config.get("cycle") self.cycle = node_config.get("cycle")
self.node_name = node_config.get("name", self.node_id) self.node_name = node_config.get("name", self.node_id)
self.down_stream_nodes = down_stream_nodes
# 使用 or 运算符处理 None 值 # 使用 or 运算符处理 None 值
self.config = node_config.get("config") or {} self.config = node_config.get("config") or {}
self.error_handling = node_config.get("error_handling") or {} self.error_handling = node_config.get("error_handling") or {}
@@ -94,16 +93,18 @@ class BaseNode(ABC):
dict: A dict with a single key 'activate', mapping node IDs to dict: A dict with a single key 'activate', mapping node IDs to
their activation status (True/False). their activation status (True/False).
""" """
activate_flag = self.check_activate(state) edges = self.workflow_config.get("edges")
under_stream_nodes = [
if self.node_type not in BRANCH_NODES: edge.get("target")
activate = {node_id: activate_flag for node_id in self.down_stream_nodes} for edge in edges
else: if edge.get("source") == self.node_id and self.node_type not in BRANCH_NODES
activate = {} ]
return {
activate[self.node_id] = activate_flag "activate": {
node_id: self.check_activate(state)
return {"activate": activate} for node_id in under_stream_nodes
} | {self.node_id: self.check_activate(state)}
}
@abstractmethod @abstractmethod
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
@@ -314,8 +315,8 @@ class BaseNode(ABC):
elapsed_time = (time.time() - start_time) * 1000 elapsed_time = (time.time() - start_time) * 1000
logger.debug(f"Node {self.node_id} streaming execution finished, " logger.info(f"Node {self.node_id} streaming execution finished, "
f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}") f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}")
# Extract processed output (call subclass's _extract_output) # Extract processed output (call subclass's _extract_output)
extracted_output = self._extract_output(final_result) extracted_output = self._extract_output(final_result)
@@ -427,8 +428,8 @@ class BaseNode(ABC):
when an error edge exists. If no error edge exists, this method when an error edge exists. If no error edge exists, this method
raises an exception to stop the workflow. raises an exception to stop the workflow.
""" """
# # Check if the node has an error edge defined # Check if the node has an error edge defined
# error_edge = self._find_error_edge() error_edge = self._find_error_edge()
# Extract input data (for logging or audit purposes) # Extract input data (for logging or audit purposes)
input_data = self._extract_input(state, variable_pool) input_data = self._extract_input(state, variable_pool)
@@ -446,26 +447,27 @@ class BaseNode(ABC):
"error": error_message "error": error_message
} }
# if error_edge: if error_edge:
# # If an error edge exists, log a warning and continue to error node # If an error edge exists, log a warning and continue to error node
# logger.warning( logger.warning(
# f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}" f"Node {self.node_id} execution failed, redirecting to error node: {error_edge['target']}"
# ) )
# return { return {
# "node_outputs": { "node_outputs": {
# self.node_id: node_output self.node_id: node_output
# }, },
# "error": error_message, "error": error_message,
# "error_node": self.node_id "error_node": self.node_id
# } }
# else: else:
writer = get_stream_writer() # If no error edge, send the error via stream writer and stop the workflow
writer({ writer = get_stream_writer()
"type": "node_error", writer({
**node_output "type": "node_error",
}) **node_output
logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}") })
raise Exception(f"Node {self.node_id} execution failed: {error_message}") logger.error(f"Node {self.node_id} execution failed, stopping workflow: {error_message}")
raise Exception(f"Node {self.node_id} execution failed: {error_message}")
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
"""Extracts the input data for this node (used for logging or audit). """Extracts the input data for this node (used for logging or audit).
@@ -621,6 +623,7 @@ class BaseNode(ABC):
async def process_message( async def process_message(
api_config: ModelInfo, api_config: ModelInfo,
content: str | dict | FileObject, content: str | dict | FileObject,
end_user_id: str,
enable_file=False enable_file=False
) -> list | str | None: ) -> list | str | None:
provider = api_config.provider provider = api_config.provider
@@ -639,10 +642,10 @@ class BaseNode(ABC):
return content return content
elif isinstance(content, FileObject): elif isinstance(content, FileObject):
if content.content_cache.get(f"{provider}_{api_config.is_omni}"): if content.content_cache.get(provider):
return content.content_cache[f"{provider}_{api_config.is_omni}"] return content.content_cache[provider]
with get_db_read() as db: with get_db_read() as db:
multimodal_service = MultimodalService(db, api_config=api_config) multimodel_service = MultimodalService(db, api_config=api_config)
file_obj = FileInput( file_obj = FileInput(
type=content.type, type=content.type,
url=content.url, url=content.url,
@@ -651,15 +654,16 @@ class BaseNode(ABC):
upload_file_id=uuid.UUID(content.file_id) if content.file_id else None, upload_file_id=uuid.UUID(content.file_id) if content.file_id else None,
) )
file_obj.set_content(content.get_content()) file_obj.set_content(content.get_content())
message = await multimodal_service.process_files( message = await multimodel_service.process_files(
end_user_id,
[file_obj], [file_obj],
) )
content.set_content(file_obj.get_content()) content.set_content(file_obj.get_content())
if message: if message:
content.content_cache[f"{provider}_{api_config.is_omni}"] = message content.content_cache[provider] = message
return message return message
return None return None
raise TypeError(f'Unexpected input value type - {type(content)}') raise TypeError(f'Unexpect input value type - {type(content)}')
@staticmethod @staticmethod
def process_model_output(content) -> str: def process_model_output(content) -> str:

View File

@@ -51,8 +51,8 @@ console.log(result)
class CodeNode(BaseNode): class CodeNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config, down_stream_nodes) super().__init__(node_config, workflow_config)
self.typed_config: CodeNodeConfig | None = None self.typed_config: CodeNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
@@ -128,7 +128,7 @@ class CodeNode(BaseNode):
else: else:
raise ValueError(f"Unsupported language: {self.typed_config.language}") raise ValueError(f"Unsupported language: {self.typed_config.language}")
async with httpx.AsyncClient(timeout=60) as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
"http://sandbox:8194/v1/sandbox/run", "http://sandbox:8194/v1/sandbox/run",
headers={ headers={

View File

@@ -51,7 +51,7 @@ class ConditionDetail(BaseModel):
) )
right: Any = Field( right: Any = Field(
default=None, ...,
description="Right-hand operand of the comparison expression" description="Right-hand operand of the comparison expression"
) )

View File

@@ -158,7 +158,7 @@ class LoopRuntime:
self.variable_pool.variables["conv"].update( self.variable_pool.variables["conv"].update(
self.child_variable_pool.variables["conv"] self.child_variable_pool.variables["conv"]
) )
loop_vars = self.child_variable_pool.get_node_output(self.node_id, default={}, strict=False) loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
loopstate["node_outputs"][self.node_id] = loop_vars loopstate["node_outputs"][self.node_id] = loop_vars
def evaluate_conditional(self) -> bool: def evaluate_conditional(self) -> bool:
@@ -261,4 +261,4 @@ class LoopRuntime:
idx += 1 idx += 1
logger.info(f"loop node {self.node_id}: execution completed") logger.info(f"loop node {self.node_id}: execution completed")
return self.child_variable_pool.get_node_output(self.node_id, default={}, strict=False) | {"__child_state": child_state} return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state}

View File

@@ -30,13 +30,17 @@ class CycleGraphNode(BaseNode):
It acts as a container and execution controller for a subgraph. It acts as a container and execution controller for a subgraph.
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config, down_stream_nodes) super().__init__(node_config, workflow_config)
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.cycle_nodes = list() # Nodes belonging to this cycle
self.cycle_edges = list() # Edges connecting nodes within the cycle
self.start_node_id = None # ID of the start node within the cycle self.start_node_id = None # ID of the start node within the cycle
self.graph: StateGraph | CompiledStateGraph | None = None self.graph: StateGraph | CompiledStateGraph | None = None
self.child_variable_pool: VariablePool | None = None self.child_variable_pool: VariablePool | None = None
self.build_graph()
self.iteration_flag = True
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
outputs = {"__child_state": VariableType.ARRAY_OBJECT} outputs = {"__child_state": VariableType.ARRAY_OBJECT}
@@ -115,11 +119,11 @@ class CycleGraphNode(BaseNode):
else: else:
remain_edges.append(edge) remain_edges.append(edge)
# # Update workflow_config by removing cycle nodes and internal edges # Update workflow_config by removing cycle nodes and internal edges
# self.workflow_config["nodes"] = [ self.workflow_config["nodes"] = [
# node for node in nodes if node.get("cycle") != self.node_id node for node in nodes if node.get("cycle") != self.node_id
# ] ]
# self.workflow_config["edges"] = remain_edges self.workflow_config["edges"] = remain_edges
return cycle_nodes, cycle_edges return cycle_nodes, cycle_edges
@@ -133,18 +137,18 @@ class CycleGraphNode(BaseNode):
3. Compile the graph for runtime execution 3. Compile the graph for runtime execution
""" """
from app.core.workflow.engine.graph_builder import GraphBuilder from app.core.workflow.engine.graph_builder import GraphBuilder
self.cycle_nodes, self.cycle_edges = self.pure_cycle_graph()
self.child_variable_pool = VariablePool() self.child_variable_pool = VariablePool()
builder = GraphBuilder( builder = GraphBuilder(
{ {
"nodes": self.cycle_nodes, "nodes": self.cycle_nodes,
"edges": self.cycle_edges, "edges": self.cycle_edges,
}, },
variable_pool=self.child_variable_pool, subgraph=True,
cycle=self.node_id variable_pool=self.child_variable_pool
) )
self.graph = builder.build()
self.start_node_id = builder.start_node_id self.start_node_id = builder.start_node_id
self.graph = builder.build()
self.child_variable_pool = builder.variable_pool self.child_variable_pool = builder.variable_pool
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
@@ -165,7 +169,6 @@ class CycleGraphNode(BaseNode):
Raises: Raises:
RuntimeError: If the node type is unsupported. RuntimeError: If the node type is unsupported.
""" """
self.build_graph()
if self.node_type == NodeType.LOOP: if self.node_type == NodeType.LOOP:
return await LoopRuntime( return await LoopRuntime(
start_id=self.start_node_id, start_id=self.start_node_id,
@@ -191,7 +194,6 @@ class CycleGraphNode(BaseNode):
raise RuntimeError("Unknown cycle node type") raise RuntimeError("Unknown cycle node type")
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool): async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
self.build_graph()
if self.node_type == NodeType.LOOP: if self.node_type == NodeType.LOOP:
yield { yield {
"__final__": True, "__final__": True,

View File

@@ -1,4 +0,0 @@
from .config import DocExtractorNodeConfig
from .node import DocExtractorNode
__all__ = ["DocExtractorNode", "DocExtractorNodeConfig"]

View File

@@ -1,18 +0,0 @@
from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig
class DocExtractorNodeConfig(BaseNodeConfig):
file_selector: str = Field(
...,
description="File variable selector, e.g. {{ sys.files }} or {{ node_id.file }}"
)
class Config:
json_schema_extra = {
"examples": [
{
"file_selector": "{{ sys.files }}"
}
]
}

View File

@@ -1,103 +0,0 @@
import logging
from typing import Any
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.document_extractor.config import DocExtractorNodeConfig
from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.db import get_db_read
from app.schemas.app_schema import FileInput, FileType, TransferMethod
logger = logging.getLogger(__name__)
def _file_object_to_file_input(f: FileObject) -> FileInput:
"""Convert workflow FileObject to multimodal FileInput."""
return FileInput(
type=FileType.DOCUMENT,
transfer_method=TransferMethod(f.transfer_method),
url=f.url or None,
upload_file_id=f.file_id or None,
file_type=f.origin_file_type or "",
)
def _normalise_files(val: Any) -> list[FileObject]:
if isinstance(val, FileObject):
return [val]
if isinstance(val, dict) and val.get("is_file"):
return [FileObject(**val)]
if isinstance(val, list):
result: list[FileObject] = []
for item in val:
if isinstance(item, FileObject):
result.append(item)
elif isinstance(item, dict) and item.get("is_file"):
result.append(FileObject(**item))
else:
logger.warning("Ignoring non-file entry in file list for document extractor: %r", item)
return result
return []
class DocExtractorNode(BaseNode):
"""Document Extractor Node.
Reads one or more file variables and extracts their text content
by delegating to MultimodalService._extract_document_text.
Outputs:
text (string) full concatenated text of all input files
chunks (array[string]) per-file extracted text
"""
def _output_types(self) -> dict[str, VariableType]:
return {
"text": VariableType.STRING,
"chunks": VariableType.ARRAY_STRING,
}
def _extract_output(self, business_result: Any) -> Any:
return business_result
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
return {"file_selector": self.config.get("file_selector")}
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
config = DocExtractorNodeConfig(**self.config)
raw_val = self.get_variable(config.file_selector, variable_pool, strict=False)
if raw_val is None:
logger.warning(f"Node {self.node_id}: file variable '{config.file_selector}' is empty")
return {"text": "", "chunks": []}
files = _normalise_files(raw_val)
if not files:
return {"text": "", "chunks": []}
chunks: list[str] = []
with get_db_read() as db:
from app.services.multimodal_service import MultimodalService
svc = MultimodalService(db)
for f in files:
try:
file_input = _file_object_to_file_input(f)
# Ensure URL is populated for local files
if not file_input.url:
file_input.url = await svc.get_file_url(file_input)
# Reuse cached bytes if already fetched
if f.get_content():
file_input.set_content(f.get_content())
text = await svc._extract_document_text(file_input)
chunks.append(text)
except Exception as e:
logger.error(
f"Node {self.node_id}: failed to extract file url={f.url} file_id={f.file_id}: {e}",
exc_info=True,
)
chunks.append("")
full_text = "\n\n".join(c for c in chunks if c)
logger.info(f"Node {self.node_id}: extracted {len(files)} file(s), total chars={len(full_text)}")
return {"text": full_text, "chunks": chunks}

View File

@@ -1,7 +1,9 @@
"""End 节点配置""" """End 节点配置"""
from pydantic import Field from pydantic import Field
from app.core.workflow.nodes.base_config import BaseNodeConfig from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
from app.core.workflow.variable.base_variable import VariableType
class EndNodeConfig(BaseNodeConfig): class EndNodeConfig(BaseNodeConfig):

View File

@@ -36,6 +36,8 @@ class EndNode(BaseNode):
Returns: Returns:
最终输出字符串 最终输出字符串
""" """
logger.info(f"节点 {self.node_id} (End) 开始执行")
# 获取配置的输出模板 # 获取配置的输出模板
output_template = self.config.get("output") output_template = self.config.get("output")
@@ -44,4 +46,11 @@ class EndNode(BaseNode):
output = self._render_template(output_template, variable_pool, strict=False) output = self._render_template(output_template, variable_pool, strict=False)
else: else:
output = "" output = ""
# 统计信息(用于日志)
node_outputs = state.get("node_outputs", {})
total_nodes = len(node_outputs)
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
return output return output

View File

@@ -23,13 +23,12 @@ class NodeType(StrEnum):
BREAK = "break" BREAK = "break"
MEMORY_READ = "memory-read" MEMORY_READ = "memory-read"
MEMORY_WRITE = "memory-write" MEMORY_WRITE = "memory-write"
DOCUMENT_EXTRACTOR = "document-extractor"
UNKNOWN = "unknown" UNKNOWN = "unknown"
NOTES = "notes" NOTES = "notes"
BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER}) BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER]
class ComparisonOperator(StrEnum): class ComparisonOperator(StrEnum):

View File

@@ -115,7 +115,7 @@ class HttpRetryConfig(BaseModel):
) )
class HttpErrorDefaultTemplate(BaseModel): class HttpErrorDefaultTamplete(BaseModel):
body: str = Field( body: str = Field(
default="", default="",
description="Default body returned on HTTP error", description="Default body returned on HTTP error",
@@ -143,7 +143,7 @@ class HttpErrorHandleConfig(BaseModel):
description="Error handling strategy: 'none', 'default', or 'branch'", description="Error handling strategy: 'none', 'default', or 'branch'",
) )
default: HttpErrorDefaultTemplate | None = Field( default: HttpErrorDefaultTamplete | None = Field(
default=None, default=None,
description="Default response template for error handling", description="Default response template for error handling",
) )

View File

@@ -16,7 +16,7 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
from app.core.workflow.utils.file_processor import mime_to_file_type from app.core.workflow.utils.file_processer import mime_to_file_type
from app.core.workflow.variable.base_variable import VariableType, FileObject from app.core.workflow.variable.base_variable import VariableType, FileObject
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
from app.schemas import FileType, TransferMethod from app.schemas import FileType, TransferMethod
@@ -157,8 +157,8 @@ class HttpRequestNode(BaseNode):
or a branch identifier string when error branching is enabled. or a branch identifier string when error branching is enabled.
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config, down_stream_nodes) super().__init__(node_config, workflow_config)
self.typed_config: HttpRequestNodeConfig | None = None self.typed_config: HttpRequestNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -18,7 +18,7 @@ class ConditionDetail(BaseModel):
) )
right: Any = Field( right: Any = Field(
default=None, ...,
description="Value to compare with" description="Value to compare with"
) )

View File

@@ -14,8 +14,8 @@ logger = logging.getLogger(__name__)
class IfElseNode(BaseNode): class IfElseNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config, down_stream_nodes) super().__init__(node_config, workflow_config)
self.typed_config: IfElseNodeConfig | None = None self.typed_config: IfElseNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
@@ -31,13 +31,13 @@ class IfElseNode(BaseNode):
expressions.append({ expressions.append({
"left": self.get_variable(expression.left, variable_pool, strict=False), "left": self.get_variable(expression.left, variable_pool, strict=False),
"right": expression.right "right": expression.right
if expression.input_type == ValueInputType.CONSTANT or expression.right is None if expression.input_type == ValueInputType.CONSTANT
else self.get_variable(expression.right, variable_pool, strict=False), else self.get_variable(expression.right, variable_pool, strict=False),
"operator": str(expression.operator), "operator": expression.operator,
}) })
result.append({ result.append({
"expressions": expressions, "expressions": expressions,
"logical_operator": str(case.logical_operator), "logical_operator": case.logical_operator,
}) })
return { return {
"cases": result "cases": result

View File

@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
class JinjaRenderNode(BaseNode): class JinjaRenderNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config, down_stream_nodes) super().__init__(node_config, workflow_config)
self.typed_config: JinjaRenderNodeConfig | None = None self.typed_config: JinjaRenderNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:

View File

@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
class KnowledgeRetrievalNode(BaseNode): class KnowledgeRetrievalNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config, down_stream_nodes) super().__init__(node_config, workflow_config)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None self.typed_config: KnowledgeRetrievalNodeConfig | None = None
self.vector_service: ElasticSearchVector | None = None self.vector_service: ElasticSearchVector | None = None

View File

@@ -70,8 +70,8 @@ class LLMNode(BaseNode):
- ai/assistant: AI 消息AIMessage - ai/assistant: AI 消息AIMessage
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config, down_stream_nodes) super().__init__(node_config, workflow_config)
self.typed_config: LLMNodeConfig | None = None self.typed_config: LLMNodeConfig | None = None
self.messages = [] self.messages = []
@@ -144,6 +144,7 @@ class LLMNode(BaseNode):
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}") f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
messages_config = self.typed_config.messages messages_config = self.typed_config.messages
if messages_config: if messages_config:
# 使用 LangChain 消息格式 # 使用 LangChain 消息格式
messages = [] messages = []
@@ -152,6 +153,7 @@ class LLMNode(BaseNode):
content_template = msg_config.content content_template = msg_config.content
content_template = self._render_context(content_template, variable_pool) content_template = self._render_context(content_template, variable_pool)
content = self._render_template(content_template, variable_pool) content = self._render_template(content_template, variable_pool)
user_id = self.get_variable("sys.user_id", variable_pool)
# 根据角色创建对应的消息对象 # 根据角色创建对应的消息对象
if role == "system": if role == "system":
messages.append({ messages.append({
@@ -159,31 +161,32 @@ class LLMNode(BaseNode):
"content": await self.process_message( "content": await self.process_message(
model_info, model_info,
content, content,
user_id,
self.typed_config.vision, self.typed_config.vision,
) )
}) })
elif role in ["user", "human"]: elif role in ["user", "human"]:
messages.append({ messages.append({
"role": "user", "role": "user",
"content": await self.process_message(model_info, content, self.typed_config.vision) "content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
}) })
elif role in ["ai", "assistant"]: elif role in ["ai", "assistant"]:
messages.append({ messages.append({
"role": "assistant", "role": "assistant",
"content": await self.process_message(model_info, content, self.typed_config.vision) "content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
}) })
else: else:
logger.warning(f"未知的消息角色: {role},默认使用 user") logger.warning(f"未知的消息角色: {role},默认使用 user")
messages.append({ messages.append({
"role": "user", "role": "user",
"content": await self.process_message(model_info, content, self.typed_config.vision) "content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
}) })
if self.typed_config.vision_input and self.typed_config.vision: if self.typed_config.vision_input and self.typed_config.vision:
file_content = [] file_content = []
files = variable_pool.get_instance(self.typed_config.vision_input) files = variable_pool.get_instance(self.typed_config.vision_input)
for file in files.value: for file in files.value:
content = await self.process_message(model_info, file.value, self.typed_config.vision) content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision)
if content: if content:
file_content.extend(content) file_content.extend(content)
if messages and messages[-1]["role"] == 'user': if messages and messages[-1]["role"] == 'user':
@@ -197,7 +200,7 @@ class LLMNode(BaseNode):
if isinstance(message["content"], list): if isinstance(message["content"], list):
file_content = [] file_content = []
for file in message["content"]: for file in message["content"]:
content = await self.process_message(model_info, file, self.typed_config.vision) content = await self.process_message(model_info, file, user_id, self.typed_config.vision)
if content: if content:
file_content.extend(content) file_content.extend(content)
history_message.append( history_message.append(
@@ -207,6 +210,7 @@ class LLMNode(BaseNode):
message["content"] = await self.process_message( message["content"] = await self.process_message(
model_info, model_info,
message["content"], message["content"],
user_id,
self.typed_config.vision self.typed_config.vision
) )
history_message.append(message) history_message.append(message)

View File

@@ -1,4 +1,3 @@
import re
from typing import Any from typing import Any
from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.state_manager import WorkflowState
@@ -6,16 +5,14 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.base_node import BaseNode
from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig from app.core.workflow.nodes.memory.config import MemoryReadNodeConfig, MemoryWriteNodeConfig
from app.core.workflow.variable.base_variable import VariableType from app.core.workflow.variable.base_variable import VariableType
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
from app.db import get_db_read from app.db import get_db_read
from app.schemas import FileInput
from app.services.memory_agent_service import MemoryAgentService from app.services.memory_agent_service import MemoryAgentService
from app.tasks import write_message_task from app.tasks import write_message_task
class MemoryReadNode(BaseNode): class MemoryReadNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config, down_stream_nodes) super().__init__(node_config, workflow_config)
self.typed_config: MemoryReadNodeConfig | None = None self.typed_config: MemoryReadNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
@@ -39,32 +36,19 @@ class MemoryReadNode(BaseNode):
search_switch=self.typed_config.search_switch, search_switch=self.typed_config.search_switch,
history=[], history=[],
db=db, db=db,
storage_type=state["memory_storage_type"], storage_type="neo4j",
user_rag_memory_id=state["user_rag_memory_id"] user_rag_memory_id=""
) )
class MemoryWriteNode(BaseNode): class MemoryWriteNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config, down_stream_nodes) super().__init__(node_config, workflow_config)
self.typed_config: MemoryWriteNodeConfig | None = None self.typed_config: MemoryWriteNodeConfig | None = None
def _output_types(self) -> dict[str, VariableType]: def _output_types(self) -> dict[str, VariableType]:
return {"output": VariableType.STRING} return {"output": VariableType.STRING}
@staticmethod
def _extract_multimodal_memory_variables(content: str, variable_pool: VariablePool) -> tuple[list[str], str]:
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
variable_pattern = re.compile(variable_pattern_string)
variables = variable_pattern.findall(content)
file_variables = []
for variable in variables:
if variable_pool.is_file_variable(variable):
file_variables.append(variable)
for var in file_variables:
content = content.replace(var, "")
return file_variables, content
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
self.typed_config = MemoryWriteNodeConfig(**self.config) self.typed_config = MemoryWriteNodeConfig(**self.config)
end_user_id = self.get_variable("sys.user_id", variable_pool) end_user_id = self.get_variable("sys.user_id", variable_pool)
@@ -79,42 +63,17 @@ class MemoryWriteNode(BaseNode):
}) })
for message in self.typed_config.messages: for message in self.typed_config.messages:
file_variables, content = self._extract_multimodal_memory_variables(
message.content,
variable_pool
)
file_info = []
for var in file_variables:
instence: FileVariable | ArrayVariable[FileVariable] = variable_pool.get_instance(var)
if isinstance(instence, FileVariable):
file_info.append(FileInput(
type=instence.value.type,
transfer_method=instence.value.transfer_method,
upload_file_id=instence.value.file_id,
url=instence.value.url,
file_type=instence.value.origin_file_type
).model_dump())
elif isinstance(instence, ArrayVariable) and instence.child_type == FileVariable:
for file_instence in instence.value:
file_info.append(FileInput(
type=file_instence.value.type,
transfer_method=file_instence.value.transfer_method,
upload_file_id=file_instence.value.file_id,
url=file_instence.value.url,
file_type=file_instence.value.origin_file_type
).model_dump())
messages.append({ messages.append({
"role": message.role, "role": message.role,
"content": self._render_template(content, variable_pool), "content": self._render_template(message.content, variable_pool)
"files": file_info
}) })
write_message_task.delay( write_message_task.delay(
end_user_id=end_user_id, end_user_id,
message=messages, messages,
config_id=str(self.typed_config.config_id), str(self.typed_config.config_id),
storage_type=state["memory_storage_type"], "neo4j",
user_rag_memory_id=state["user_rag_memory_id"] ""
) )
return "success" return "success"

View File

@@ -26,7 +26,6 @@ from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNode
from app.core.workflow.nodes.question_classifier import QuestionClassifierNode from app.core.workflow.nodes.question_classifier import QuestionClassifierNode
from app.core.workflow.nodes.breaker import BreakNode from app.core.workflow.nodes.breaker import BreakNode
from app.core.workflow.nodes.tool import ToolNode from app.core.workflow.nodes.tool import ToolNode
from app.core.workflow.nodes.document_extractor import DocExtractorNode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -50,8 +49,7 @@ WorkflowNode = Union[
ToolNode, ToolNode,
MemoryReadNode, MemoryReadNode,
MemoryWriteNode, MemoryWriteNode,
CodeNode, CodeNode
DocExtractorNode
] ]
@@ -83,7 +81,6 @@ class NodeFactory:
NodeType.MEMORY_READ: MemoryReadNode, NodeType.MEMORY_READ: MemoryReadNode,
NodeType.MEMORY_WRITE: MemoryWriteNode, NodeType.MEMORY_WRITE: MemoryWriteNode,
NodeType.CODE: CodeNode, NodeType.CODE: CodeNode,
NodeType.DOCUMENT_EXTRACTOR: DocExtractorNode
} }
@classmethod @classmethod
@@ -107,15 +104,13 @@ class NodeFactory:
def create_node( def create_node(
cls, cls,
node_config: dict[str, Any], node_config: dict[str, Any],
workflow_config: dict[str, Any], workflow_config: dict[str, Any]
down_stream_nodes: list[str]
) -> WorkflowNode | None: ) -> WorkflowNode | None:
"""创建节点实例 """创建节点实例
Args: Args:
node_config: 节点配置 node_config: 节点配置
workflow_config: 工作流配置 workflow_config: 工作流配置
down_stream_nodes: 下游节点
Returns: Returns:
节点实例或 None对于不支持的节点类型 节点实例或 None对于不支持的节点类型
@@ -132,7 +127,7 @@ class NodeFactory:
# 创建节点实例 # 创建节点实例
logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})") logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})")
return node_class(node_config, workflow_config, down_stream_nodes) return node_class(node_config, workflow_config)
@classmethod @classmethod
def get_supported_types(cls) -> list[str]: def get_supported_types(cls) -> list[str]:

View File

@@ -250,8 +250,6 @@ class ConditionBase(ABC):
self.type_limit = getattr(self, "type_limit", None) self.type_limit = getattr(self, "type_limit", None)
def resolve_right_literal_value(self): def resolve_right_literal_value(self):
if self.right_selector is None:
return None
if self.input_type == ValueInputType.VARIABLE: if self.input_type == ValueInputType.VARIABLE:
pattern = r"\{\{\s*(.*?)\s*\}\}" pattern = r"\{\{\s*(.*?)\s*\}\}"
right_expression = re.sub(pattern, r"\1", self.right_selector).strip() right_expression = re.sub(pattern, r"\1", self.right_selector).strip()

View File

@@ -21,8 +21,8 @@ logger = logging.getLogger(__name__)
class ParameterExtractorNode(BaseNode): class ParameterExtractorNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config, down_stream_nodes) super().__init__(node_config, workflow_config)
self.typed_config: ParameterExtractorNodeConfig | None = None self.typed_config: ParameterExtractorNodeConfig | None = None
self.response_metadata = {} self.response_metadata = {}

View File

@@ -22,8 +22,8 @@ DEFAULT_EMPTY_QUESTION_CASE = f"{DEFAULT_CASE_PREFIX}1"
class QuestionClassifierNode(BaseNode): class QuestionClassifierNode(BaseNode):
"""问题分类器节点""" """问题分类器节点"""
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config, down_stream_nodes) super().__init__(node_config, workflow_config)
self.typed_config: QuestionClassifierNodeConfig | None = None self.typed_config: QuestionClassifierNodeConfig | None = None
self.category_to_case_map = {} self.category_to_case_map = {}
self.response_metadata = {} self.response_metadata = {}

View File

@@ -27,8 +27,14 @@ class StartNode(BaseNode):
注意:变量的验证和默认值处理由 Executor 在初始化时完成。 注意:变量的验证和默认值处理由 Executor 在初始化时完成。
""" """
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config, down_stream_nodes) """初始化 Start 节点
Args:
node_config: 节点配置
workflow_config: 工作流配置
"""
super().__init__(node_config, workflow_config)
# 解析并验证配置 # 解析并验证配置
self.typed_config: StartNodeConfig | None = None self.typed_config: StartNodeConfig | None = None
@@ -56,6 +62,7 @@ class StartNode(BaseNode):
包含系统参数、会话变量和自定义变量的字典 包含系统参数、会话变量和自定义变量的字典
""" """
self.typed_config = StartNodeConfig(**self.config) self.typed_config = StartNodeConfig(**self.config)
logger.info(f"节点 {self.node_id} (Start) 开始执行")
# 处理自定义变量(传入 pool 避免重复创建) # 处理自定义变量(传入 pool 避免重复创建)
custom_vars = self._process_custom_variables(variable_pool) custom_vars = self._process_custom_variables(variable_pool)
@@ -70,9 +77,9 @@ class StartNode(BaseNode):
**custom_vars # 自定义变量作为节点输出的一部分 **custom_vars # 自定义变量作为节点输出的一部分
} }
logger.debug( logger.info(
f"Node {self.node_id} (Start) execution completed, " f"节点 {self.node_id} (Start) 执行完成,"
f"outputting {len(custom_vars)} custom variables" f"输出了 {len(custom_vars)} 个自定义变量"
) )
return result return result

Some files were not shown because too many files have changed in this diff Show More