Merge remote-tracking branch 'origin/develop' into refactor/memory-config-management
This commit is contained in:
@@ -25,17 +25,13 @@ ENV DEBIAN_FRONTEND=noninteractive
|
||||
# 4. Setup apt
|
||||
# Python package and implicit dependencies:
|
||||
# opencv-python: libglib2.0-0 libglx-mesa0 libgl1
|
||||
# aspose-slides: pkg-config libicu-dev libgdiplus libssl1.1_1.1.1f-1ubuntu2_amd64.deb
|
||||
# python-pptx: default-jdk tika-server-standard-3.0.0.jar
|
||||
# libreoffice: libreoffice libreoffice-writer libreoffice-impress fonts-wqy-zenhei fonts-noto-cjk
|
||||
# python-docx: default-jdk tika-server-standard-3.0.0.jar
|
||||
# Building C extensions: libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev
|
||||
RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \
|
||||
apt install -y libicu-dev && \
|
||||
if [ "$NEED_MIRROR" == "1" ]; then \
|
||||
rm -f /etc/apt/sources.list.d/debian.sources && \
|
||||
echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm main contrib non-free non-free-firmware" > /etc/apt/sources.list && \
|
||||
echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm-updates main contrib non-free non-free-firmware" >> /etc/apt/sources.list && \
|
||||
echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bookworm-backports main contrib non-free non-free-firmware" >> /etc/apt/sources.list && \
|
||||
echo "deb https://mirrors.tuna.tsinghua.edu.cn/debian-security bookworm-security main contrib non-free non-free-firmware" >> /etc/apt/sources.list; \
|
||||
sed -i 's|http://ports.ubuntu.com|http://mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list; \
|
||||
sed -i 's|http://archive.ubuntu.com|http://mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list; \
|
||||
fi; \
|
||||
rm -f /etc/apt/apt.conf.d/docker-clean && \
|
||||
echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache && \
|
||||
@@ -44,7 +40,7 @@ RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \
|
||||
apt --no-install-recommends install -y ca-certificates && \
|
||||
apt update && \
|
||||
apt install -y libglib2.0-0 libglx-mesa0 libgl1 && \
|
||||
apt install -y pkg-config libgdiplus && \
|
||||
apt install -y libreoffice libreoffice-writer libreoffice-impress fonts-wqy-zenhei fonts-noto-cjk && \
|
||||
apt install -y default-jdk && \
|
||||
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
|
||||
apt install -y libjemalloc-dev && \
|
||||
@@ -64,21 +60,13 @@ RUN if [ "$NEED_MIRROR" == "1" ]; then \
|
||||
ENV PYTHONDONTWRITEBYTECODE=1 DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1
|
||||
ENV PATH=/root/.local/bin:$PATH
|
||||
|
||||
# https://forum.aspose.com/t/aspose-slides-for-net-no-usable-version-of-libssl-found-with-linux-server/271344/13
|
||||
# 5. aspose-slides on linux/arm64 is unavailable
|
||||
COPY libssl1.1_1.1.1f-1ubuntu2_amd64.deb libssl1.1_1.1.1f-1ubuntu2_arm64.deb /tmp/
|
||||
RUN if [ "$(uname -m)" = "x86_64" ]; then \
|
||||
dpkg -i /tmp/libssl1.1_1.1.1f-1ubuntu2_amd64.deb; \
|
||||
elif [ "$(uname -m)" = "aarch64" ]; then \
|
||||
dpkg -i /tmp/libssl1.1_1.1.1f-1ubuntu2_arm64.deb; \
|
||||
fi && \
|
||||
rm -f /tmp/libssl1.1_*.deb
|
||||
|
||||
|
||||
# 6. install dependencies from uv.lock file
|
||||
# 5. install dependencies from uv.lock file
|
||||
COPY ./pyproject.toml /code/pyproject.toml
|
||||
COPY ./uv.lock /code/uv.lock
|
||||
COPY ./app /code/app
|
||||
COPY ./alembic.ini /code/alembic.ini
|
||||
COPY ./migrations /code/migrations
|
||||
|
||||
# https://github.com/astral-sh/uv/issues/10462
|
||||
# uv records index url into uv.lock but doesn't failover among multiple indexes
|
||||
|
||||
@@ -83,6 +83,7 @@ celery_app.autodiscover_tasks(['app'])
|
||||
reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS)
|
||||
health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
|
||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
# 构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
@@ -97,6 +98,11 @@ beat_schedule_config = {
|
||||
"schedule": workspace_reflection_schedule,
|
||||
"args": (),
|
||||
},
|
||||
"regenerate-memory-cache": {
|
||||
"task": "app.tasks.regenerate_memory_cache",
|
||||
"schedule": memory_cache_regeneration_schedule,
|
||||
"args": (),
|
||||
},
|
||||
}
|
||||
|
||||
# 如果配置了默认工作空间ID,则添加记忆总量统计任务
|
||||
|
||||
@@ -35,6 +35,7 @@ from . import (
|
||||
tool_controller,
|
||||
tool_execution_controller,
|
||||
)
|
||||
from . import user_memory_controllers
|
||||
|
||||
# 创建管理端 API 路由器
|
||||
manager_router = APIRouter()
|
||||
@@ -58,6 +59,7 @@ manager_router.include_router(upload_controller.router)
|
||||
manager_router.include_router(memory_agent_controller.router)
|
||||
manager_router.include_router(memory_dashboard_controller.router)
|
||||
manager_router.include_router(memory_storage_controller.router)
|
||||
manager_router.include_router(user_memory_controllers.router)
|
||||
manager_router.include_router(api_key_controller.router)
|
||||
manager_router.include_router(release_share_controller.router)
|
||||
manager_router.include_router(public_share_controller.router) # 公开路由(无需认证)
|
||||
|
||||
@@ -287,7 +287,7 @@ async def get_workspace_total_memory_count(
|
||||
"total_memory_count": int,
|
||||
"host_count": int,
|
||||
"details": [
|
||||
{"host_id": "uuid", "count": 100},
|
||||
{"end_user_id": "uuid", "count": 100, "name": "用户名称"},
|
||||
...
|
||||
]
|
||||
}
|
||||
|
||||
@@ -218,6 +218,7 @@ async def start_reflection_configs(
|
||||
@router.get("/reflection/run")
|
||||
async def reflection_run(
|
||||
config_id: int,
|
||||
language_type: str = "zh",
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
@@ -257,7 +258,8 @@ async def reflection_run(
|
||||
memory_verify=result.memory_verify,
|
||||
quality_assessment=result.quality_assessment,
|
||||
violation_handling_strategy="block",
|
||||
model_id=model_id
|
||||
model_id=model_id,
|
||||
language_type=language_type
|
||||
)
|
||||
connector = Neo4jConnector()
|
||||
engine = ReflectionEngine(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import datetime
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional
|
||||
@@ -8,7 +9,12 @@ from app.core.memory.utils.self_reflexion_utils import self_reflexion
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.user_model import User
|
||||
from app.schemas.end_user_schema import (
|
||||
EndUserProfileResponse,
|
||||
EndUserProfileUpdate,
|
||||
)
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
@@ -21,11 +27,10 @@ from app.schemas.memory_storage_schema import (
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.memory_storage_service import (
|
||||
DataConfigService,
|
||||
GenerateCacheRequest,
|
||||
MemoryStorageService,
|
||||
analytics_hot_memory_tags,
|
||||
analytics_memory_insight_report,
|
||||
analytics_recent_activity_stats,
|
||||
analytics_user_summary,
|
||||
kb_type_distribution,
|
||||
search_all,
|
||||
search_chunk,
|
||||
@@ -491,20 +496,6 @@ async def get_hot_memory_tags_api(
|
||||
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||
async def get_memory_insight_report_api(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
api_logger.info(f"Memory insight report requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await analytics_memory_insight_report(end_user_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Memory insight report failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告生成失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
|
||||
async def get_recent_activity_stats_api(
|
||||
current_user: User = Depends(get_current_user),
|
||||
@@ -543,3 +534,4 @@ async def self_reflexion_endpoint(host_id: uuid.UUID) -> str:
|
||||
自我反思结果。
|
||||
"""
|
||||
return await self_reflexion(host_id)
|
||||
|
||||
|
||||
382
api/app/controllers/user_memory_controllers.py
Normal file
382
api/app/controllers/user_memory_controllers.py
Normal file
@@ -0,0 +1,382 @@
|
||||
"""
|
||||
用户记忆相关的控制器
|
||||
包含用户摘要、记忆洞察、节点统计、图数据和用户档案等接口
|
||||
"""
|
||||
from typing import Optional
|
||||
import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.services.user_memory_service import (
|
||||
UserMemoryService,
|
||||
analytics_node_statistics,
|
||||
analytics_graph_data,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||
from app.schemas.end_user_schema import (
|
||||
EndUserProfileResponse,
|
||||
EndUserProfileUpdate,
|
||||
)
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
|
||||
# Get API logger
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Initialize service
|
||||
user_memory_service = UserMemoryService()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/memory-storage",
|
||||
tags=["User Memory"],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||
async def get_memory_insight_report_api(
|
||||
end_user_id: str, # 使用 end_user_id
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""获取缓存的记忆洞察报告"""
|
||||
api_logger.info(f"记忆洞察报告请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_memory_insight(db, end_user_id)
|
||||
|
||||
if result["is_cached"]:
|
||||
# 缓存存在,返回缓存数据
|
||||
api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}")
|
||||
return success(data=result, msg="查询成功")
|
||||
else:
|
||||
# 缓存不存在,返回提示消息
|
||||
api_logger.info(f"记忆洞察报告缓存不存在: end_user_id={end_user_id}")
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"记忆洞察报告查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "记忆洞察报告查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||
async def get_user_summary_api(
|
||||
end_user_id: str, # 使用 end_user_id
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""获取缓存的用户摘要"""
|
||||
api_logger.info(f"用户摘要请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id)
|
||||
|
||||
if result["is_cached"]:
|
||||
# 缓存存在,返回缓存数据
|
||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||
return success(data=result, msg="查询成功")
|
||||
else:
|
||||
# 缓存不存在,返回提示消息
|
||||
api_logger.info(f"用户摘要缓存不存在: end_user_id={end_user_id}")
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"用户摘要查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户摘要查询失败", str(e))
|
||||
|
||||
|
||||
@router.post("/analytics/generate_cache", response_model=ApiResponse)
|
||||
async def generate_cache_api(
|
||||
request: GenerateCacheRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
手动触发缓存生成
|
||||
|
||||
- 如果提供 end_user_id,只为该用户生成
|
||||
- 如果不提供,为当前工作空间的所有用户生成
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
group_id = request.end_user_id
|
||||
|
||||
api_logger.info(
|
||||
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
|
||||
f"end_user_id={group_id if group_id else '全部用户'}"
|
||||
)
|
||||
|
||||
try:
|
||||
if group_id:
|
||||
# 为单个用户生成
|
||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}")
|
||||
|
||||
# 生成记忆洞察
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id)
|
||||
|
||||
# 生成用户摘要
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id)
|
||||
|
||||
# 构建响应
|
||||
result = {
|
||||
"end_user_id": group_id,
|
||||
"insight_success": insight_result["success"],
|
||||
"summary_success": summary_result["success"],
|
||||
"errors": []
|
||||
}
|
||||
|
||||
# 收集错误信息
|
||||
if not insight_result["success"]:
|
||||
result["errors"].append({
|
||||
"type": "insight",
|
||||
"error": insight_result.get("error")
|
||||
})
|
||||
if not summary_result["success"]:
|
||||
result["errors"].append({
|
||||
"type": "summary",
|
||||
"error": summary_result.get("error")
|
||||
})
|
||||
|
||||
# 记录结果
|
||||
if result["insight_success"] and result["summary_success"]:
|
||||
api_logger.info(f"成功为用户 {group_id} 生成缓存")
|
||||
else:
|
||||
api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: {result['errors']}")
|
||||
|
||||
return success(data=result, msg="生成完成")
|
||||
|
||||
else:
|
||||
# 为整个工作空间生成
|
||||
api_logger.info(f"开始为工作空间 {workspace_id} 批量生成缓存")
|
||||
|
||||
result = await user_memory_service.generate_cache_for_workspace(db, workspace_id)
|
||||
|
||||
# 记录统计信息
|
||||
api_logger.info(
|
||||
f"工作空间 {workspace_id} 批量生成完成: "
|
||||
f"总数={result['total_users']}, 成功={result['successful']}, 失败={result['failed']}"
|
||||
)
|
||||
|
||||
return success(data=result, msg="批量生成完成")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"缓存生成失败: user={current_user.username}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "缓存生成失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/node_statistics", response_model=ApiResponse)
|
||||
async def get_node_statistics_api(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询节点统计但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"节点统计请求: end_user_id={end_user_id}, user={current_user.username}, workspace={workspace_id}")
|
||||
|
||||
try:
|
||||
result = await analytics_node_statistics(db, end_user_id)
|
||||
|
||||
# 检查是否有错误消息
|
||||
if "message" in result and result["total"] == 0:
|
||||
api_logger.warning(f"节点统计查询返回空结果: {result.get('message')}")
|
||||
return success(data=result, msg=result.get("message", "查询成功"))
|
||||
|
||||
api_logger.info(f"成功获取节点统计: end_user_id={end_user_id}, total={result['total']}")
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"用户摘要查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户摘要查询失败", str(e))
|
||||
|
||||
@router.get("/analytics/graph_data", response_model=ApiResponse)
|
||||
async def get_graph_data_api(
|
||||
end_user_id: str,
|
||||
node_types: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
depth: int = 1,
|
||||
center_node_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询图数据但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 参数验证
|
||||
if limit > 1000:
|
||||
limit = 1000
|
||||
api_logger.warning("limit 参数超过最大值,已调整为 1000")
|
||||
|
||||
if depth > 3:
|
||||
depth = 3
|
||||
api_logger.warning("depth 参数超过最大值,已调整为 3")
|
||||
|
||||
# 解析 node_types 参数
|
||||
node_types_list = None
|
||||
if node_types:
|
||||
node_types_list = [t.strip() for t in node_types.split(",") if t.strip()]
|
||||
|
||||
api_logger.info(
|
||||
f"图数据查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}, node_types={node_types_list}, limit={limit}, depth={depth}"
|
||||
)
|
||||
|
||||
try:
|
||||
result = await analytics_graph_data(
|
||||
db=db,
|
||||
end_user_id=end_user_id,
|
||||
node_types=node_types_list,
|
||||
limit=limit,
|
||||
depth=depth,
|
||||
center_node_id=center_node_id
|
||||
)
|
||||
|
||||
# 检查是否有错误消息
|
||||
if "message" in result and result["statistics"]["total_nodes"] == 0:
|
||||
api_logger.warning(f"图数据查询返回空结果: {result.get('message')}")
|
||||
return success(data=result, msg=result.get("message", "查询成功"))
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取图数据: end_user_id={end_user_id}, "
|
||||
f"nodes={result['statistics']['total_nodes']}, "
|
||||
f"edges={result['statistics']['total_edges']}"
|
||||
)
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"图数据查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/read_end_user/profile", response_model=ApiResponse)
|
||||
async def get_end_user_profile(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"用户信息查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 查询终端用户
|
||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
|
||||
if not end_user:
|
||||
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
|
||||
|
||||
# 构建响应数据
|
||||
profile_data = EndUserProfileResponse(
|
||||
id=end_user.id,
|
||||
name=end_user.name,
|
||||
position=end_user.position,
|
||||
department=end_user.department,
|
||||
contact=end_user.contact,
|
||||
phone=end_user.phone,
|
||||
hire_date=end_user.hire_date,
|
||||
updatetime_profile=end_user.updatetime_profile
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取用户信息: end_user_id={end_user_id}")
|
||||
return success(data=profile_data.model_dump(), msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"用户信息查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息查询失败", str(e))
|
||||
|
||||
|
||||
@router.post("/updated_end_user/profile", response_model=ApiResponse)
|
||||
async def update_end_user_profile(
|
||||
profile_update: EndUserProfileUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""
|
||||
更新终端用户的基本信息
|
||||
|
||||
该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。
|
||||
所有字段都是可选的,只更新提供的字段。
|
||||
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
end_user_id = profile_update.end_user_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"用户信息更新请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 查询终端用户
|
||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
|
||||
if not end_user:
|
||||
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
|
||||
|
||||
# 更新字段(只更新提供的非 None 字段,排除 end_user_id)
|
||||
update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
|
||||
for field, value in update_data.items():
|
||||
if value is not None:
|
||||
setattr(end_user, field, value)
|
||||
|
||||
# 更新 updated_at 时间戳
|
||||
end_user.updated_at = datetime.datetime.now()
|
||||
|
||||
# 更新 updatetime_profile 为当前时间戳(毫秒)
|
||||
current_timestamp = int(datetime.datetime.now().timestamp() * 1000)
|
||||
end_user.updatetime_profile = current_timestamp
|
||||
|
||||
# 提交更改
|
||||
db.commit()
|
||||
db.refresh(end_user)
|
||||
|
||||
# 构建响应数据
|
||||
profile_data = EndUserProfileResponse(
|
||||
id=end_user.id,
|
||||
name=end_user.name,
|
||||
position=end_user.position,
|
||||
department=end_user.department,
|
||||
contact=end_user.contact,
|
||||
phone=end_user.phone,
|
||||
hire_date=end_user.hire_date,
|
||||
updatetime_profile=end_user.updatetime_profile
|
||||
)
|
||||
|
||||
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}, updatetime_profile={current_timestamp}")
|
||||
return success(data=profile_data.model_dump(), msg="更新成功")
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e))
|
||||
@@ -151,6 +151,9 @@ class Settings:
|
||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
|
||||
REFLECTION_INTERVAL_TIME:Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
# Memory Cache Regeneration Configuration
|
||||
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
||||
|
||||
# Memory Module Configuration (internal)
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
|
||||
@@ -3,53 +3,43 @@
|
||||
"source_data": [
|
||||
{
|
||||
"statement_name": "用户是2023年春天去北京工作的。",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
"statement_id": "62beac695b1346f4871740a45db88782"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户后来基本一直都在北京上班。",
|
||||
"statement_id": "4cba5ac08b674d7fb1e2ae634d2b8f0b",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
"statement_id": "4cba5ac08b674d7fb1e2ae634d2b8f0b"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户从2023年开始就一直在北京生活。",
|
||||
"statement_id": "e612a44da4db483993c350df7c97a1a1",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
"statement_id": "e612a44da4db483993c350df7c97a1a1"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户从来没有长期离开过北京。",
|
||||
"statement_id": "b3c787a2e33c49f7981accabbbb4538a",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
"statement_id": "b3c787a2e33c49f7981accabbbb4538a"
|
||||
},
|
||||
{
|
||||
"statement_name": "由于公司调整,用户在2024年上半年被调到上海待了差不多半年。",
|
||||
"statement_id": "64cde4230cb24a4da726e7db9e7aa616",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
"statement_id": "64cde4230cb24a4da726e7db9e7aa616"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户在被调到上海期间每天都是在上海办公室打卡。",
|
||||
"statement_id": "8b1b12e23b844b8088dfeb67da6ad669",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
"statement_id": "8b1b12e23b844b8088dfeb67da6ad669"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户在入职时使用的身份信息是之前的,身份证号为11010119950308123X。",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户的银行卡号是6222023847595898。",
|
||||
"statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
"statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户的身份信息和银行卡信息一直没变。",
|
||||
"statement_id": "b3ca618e1e204b83bebd70e75cf2073f",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
"statement_id": "b3ca618e1e204b83bebd70e75cf2073f"
|
||||
},
|
||||
{
|
||||
"statement_name": "用户认为在上海的那段时间更多算是远程配合。",
|
||||
"statement_id": "150af89d2c154e6eb41ff1a91e37f962",
|
||||
"statement_created_at": "2025-12-19T10:31:15.239252"
|
||||
"statement_id": "150af89d2c154e6eb41ff1a91e37f962"
|
||||
}
|
||||
],
|
||||
"databasets": [
|
||||
@@ -57,24 +47,11 @@
|
||||
"entity1_name": "Person",
|
||||
"description": "表示人类个体的通用类型",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "用户",
|
||||
"entity2": {
|
||||
"entity_idx": 0,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"connect_strength": "strong",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Person",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "用户",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "3d3896797b334572a80d57590026063d"
|
||||
}
|
||||
},
|
||||
@@ -82,24 +59,11 @@
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "身份信息",
|
||||
"entity2": {
|
||||
"entity_idx": 1,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"connect_strength": "Strong",
|
||||
"description": "用于个人身份识别的数据",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Information",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "身份信息",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "aa766a517e82490599a9b3af54cfd933"
|
||||
}
|
||||
},
|
||||
@@ -107,24 +71,11 @@
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "6222023847595898",
|
||||
"entity2": {
|
||||
"entity_idx": 1,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"connect_strength": "Strong",
|
||||
"description": "用户的银行卡号码",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"statement_id": "6c7567cd1f3c478bb42d1b65383e6f2f",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Numeric",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "6222023847595898",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "610ba361918f4e68a65ce6ad06e5c7a0"
|
||||
}
|
||||
},
|
||||
@@ -132,25 +83,13 @@
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "上海办公室",
|
||||
"entity2": {
|
||||
"entity_idx": 1,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"aliases": ["上海办"],
|
||||
"connect_strength": "Strong",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"description": "位于上海的工作办公场所",
|
||||
"statement_id": "8b1b12e23b844b8088dfeb67da6ad669",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Location",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "上海办公室",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "fb702ef695c14e14af3e56786bc8815b"
|
||||
}
|
||||
},
|
||||
@@ -158,25 +97,12 @@
|
||||
"entity1_name": "用户",
|
||||
"description": "叙述者,讲述个人工作与生活经历的个体",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "北京",
|
||||
"entity2": {
|
||||
"entity_idx": 2,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"aliases": ["京", "京城", "北平"],
|
||||
"connect_strength": "strong",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"description": "中国的首都城市,用户主要工作和生活所在地",
|
||||
"statement_id": "62beac695b1346f4871740a45db88782",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Location",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "北京",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "81b2d1a571bb46a08a2d7a1e87efb945"
|
||||
}
|
||||
},
|
||||
@@ -184,24 +110,11 @@
|
||||
"entity1_name": "11010119950308123X",
|
||||
"description": "具体的身份证号码值",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"relationship_type": "EXTRACTED_RELATIONSHIP",
|
||||
"relationship": {},
|
||||
"entity2_name": "身份证号",
|
||||
"entity2": {
|
||||
"entity_idx": 2,
|
||||
"run_id": "62b59cfebeea43dd94d91763056f069a",
|
||||
"connect_strength": "strong",
|
||||
"description": "中华人民共和国公民的身份号码",
|
||||
"created_at": "2025-12-19T10:31:15.239252000",
|
||||
"statement_id": "030afd362e9b4110b139e68e5d3e7143",
|
||||
"expired_at": "9999-12-31T00:00:00.000000000",
|
||||
"entity_type": "Identifier",
|
||||
"group_id": "88a459f5_text08",
|
||||
"user_id": "88a459f5_text08",
|
||||
"name": "身份证号",
|
||||
"apply_id": "88a459f5_text08",
|
||||
"id": "3e5f920645b2404fadb0e9ff60d1306e"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,8 +17,23 @@ import uuid
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.memory.utils.config import get_model_config
|
||||
from app.core.memory.utils.config.get_data import (
|
||||
extract_and_process_changes,
|
||||
get_data,
|
||||
get_data_statement,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.prompt.template_render import (
|
||||
render_evaluate_prompt,
|
||||
render_reflexion_prompt,
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.response_utils import success
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
UPDATE_STATEMENT_INVALID_AT,
|
||||
neo4j_query_all,
|
||||
neo4j_query_part,
|
||||
neo4j_statement_all,
|
||||
@@ -26,6 +41,10 @@ from app.repositories.neo4j.cypher_queries import (
|
||||
)
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.repositories.neo4j.neo4j_update import neo4j_data
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConflictResultSchema,
|
||||
ReflexionResultSchema,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
# 配置日志
|
||||
@@ -38,7 +57,9 @@ if not _root_logger.handlers:
|
||||
else:
|
||||
_root_logger.setLevel(logging.INFO)
|
||||
|
||||
|
||||
class TranslationResponse(BaseModel):
|
||||
"""翻译响应模型"""
|
||||
data: str
|
||||
class ReflectionRange(str, Enum):
|
||||
"""反思范围枚举"""
|
||||
PARTIAL = "partial" # 从检索结果中反思
|
||||
@@ -66,6 +87,7 @@ class ReflectionConfig(BaseModel):
|
||||
memory_verify: bool = True # 记忆验证
|
||||
quality_assessment: bool = True # 质量评估
|
||||
violation_handling_strategy: str = "warn" # 违规处理策略
|
||||
language_type: str = "zh"
|
||||
|
||||
class Config:
|
||||
use_enum_values = True
|
||||
@@ -126,6 +148,7 @@ class ReflectionEngine:
|
||||
self.update_query = update_query
|
||||
self._semaphore = asyncio.Semaphore(5) # 默认并发数为5
|
||||
|
||||
|
||||
# 延迟导入以避免循环依赖
|
||||
self._lazy_init_done = False
|
||||
|
||||
@@ -135,7 +158,6 @@ class ReflectionEngine:
|
||||
return
|
||||
|
||||
if self.neo4j_connector is None:
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
self.neo4j_connector = Neo4jConnector()
|
||||
|
||||
if self.llm_client is None:
|
||||
@@ -147,20 +169,35 @@ class ReflectionEngine:
|
||||
self.llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
elif isinstance(self.llm_client, str):
|
||||
# 如果 llm_client 是字符串(model_id),则用它初始化客户端
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
# from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
model_id = self.llm_client
|
||||
# model_id = self.llm_client
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(model_id)
|
||||
# self.llm_client = factory.get_llm_client(model_id)
|
||||
extra_params={
|
||||
"temperature": 0.2, # 降低温度提高响应速度和一致性
|
||||
"max_tokens": 600, # 限制最大token数
|
||||
"top_p": 0.8, # 优化采样参数
|
||||
"stream": False, # 确保非流式输出以获得最快响应
|
||||
}
|
||||
|
||||
model_config = get_model_config(self.llm_client)
|
||||
self.llm_client = OpenAIClient(RedBearModelConfig(
|
||||
model_name=model_config.get("model_name"),
|
||||
provider=model_config.get("provider"),
|
||||
api_key=model_config.get("api_key"),
|
||||
base_url=model_config.get("base_url"),
|
||||
timeout=model_config.get("timeout", 30),
|
||||
max_retries=model_config.get("max_retries", 2),
|
||||
extra_params=extra_params
|
||||
), type_=model_config.get("type"))
|
||||
|
||||
if self.get_data_func is None:
|
||||
from app.core.memory.utils.config.get_data import get_data
|
||||
self.get_data_func = get_data
|
||||
|
||||
# 导入get_data_statement函数
|
||||
if not hasattr(self, 'get_data_statement'):
|
||||
from app.core.memory.utils.config.get_data import get_data_statement
|
||||
self.get_data_statement = get_data_statement
|
||||
|
||||
if self.render_evaluate_prompt_func is None:
|
||||
@@ -176,11 +213,9 @@ class ReflectionEngine:
|
||||
self.render_reflexion_prompt_func = render_reflexion_prompt
|
||||
|
||||
if self.conflict_schema is None:
|
||||
from app.schemas.memory_storage_schema import ConflictResultSchema
|
||||
self.conflict_schema = ConflictResultSchema
|
||||
|
||||
if self.reflexion_schema is None:
|
||||
from app.schemas.memory_storage_schema import ReflexionResultSchema
|
||||
self.reflexion_schema = ReflexionResultSchema
|
||||
|
||||
if self.update_query is None:
|
||||
@@ -227,15 +262,12 @@ class ReflectionEngine:
|
||||
print(100 * '-')
|
||||
print(conflict_data)
|
||||
print(100 * '-')
|
||||
|
||||
# 检查是否真的有冲突
|
||||
has_conflict = conflict_data[0].get('conflict', False)
|
||||
conflicts_found = len(conflict_data[0]['data']) if has_conflict else 0
|
||||
logging.info(f"冲突状态: {has_conflict}, 发现 {conflicts_found} 个冲突")
|
||||
# # 检查是否真的有冲突
|
||||
conflicts_found=''
|
||||
|
||||
# 记录冲突数据
|
||||
await self._log_data("conflict", conflict_data)
|
||||
|
||||
conflicts_found=''
|
||||
# 3. 解决冲突
|
||||
solved_data = await self._resolve_conflicts(conflict_data, statement_databasets)
|
||||
if not solved_data:
|
||||
@@ -256,7 +288,7 @@ class ReflectionEngine:
|
||||
await self._log_data("solved_data", solved_data)
|
||||
|
||||
# 4. 应用反思结果(更新记忆库)
|
||||
memories_updated = await self._apply_reflection_results(solved_data)
|
||||
memories_updated=await self._apply_reflection_results(solved_data)
|
||||
|
||||
execution_time = asyncio.get_event_loop().time() - start_time
|
||||
|
||||
@@ -280,9 +312,60 @@ class ReflectionEngine:
|
||||
execution_time=asyncio.get_event_loop().time() - start_time
|
||||
)
|
||||
|
||||
async def Translate(self, text):
|
||||
# 翻译中文为英文
|
||||
translation_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{text}\n\n中文翻译为英文,输出格式为{{\"data\":\"翻译后的内容\"}}"
|
||||
}
|
||||
]
|
||||
|
||||
response = await self.llm_client.response_structured(
|
||||
messages=translation_messages,
|
||||
response_model=TranslationResponse
|
||||
)
|
||||
return response.data
|
||||
async def extract_translation(self,data):
|
||||
end_datas={}
|
||||
end_datas['source_data']=await self.Translate(data['source_data'])
|
||||
quality_assessments = []
|
||||
memory_verifies = []
|
||||
reflexion_data=[]
|
||||
if data['memory_verifies']!=[]:
|
||||
for i in data['memory_verifies']:
|
||||
end_data={}
|
||||
end_data['has_privacy'] = i['has_privacy']
|
||||
privacy=i['privacy_types']
|
||||
privacy_types_=[]
|
||||
for pri in privacy:
|
||||
privacy_types_.append(await self.Translate(pri))
|
||||
end_data['privacy_types']=privacy_types_
|
||||
end_data['summary']=await self.Translate(i['summary'])
|
||||
memory_verifies.append(end_data)
|
||||
end_datas['memory_verifies']=memory_verifies
|
||||
|
||||
if data['quality_assessments']!=[]:
|
||||
for i in data['quality_assessments']:
|
||||
end_data = {}
|
||||
end_data['score']=i['score']
|
||||
end_data['summary'] = await self.Translate(i['summary'])
|
||||
quality_assessments.append(end_data)
|
||||
end_datas['quality_assessments'] = quality_assessments
|
||||
for i in data['reflexion_data']:
|
||||
end_data = {}
|
||||
end_data['reason'] = await self.Translate(i['reason'])
|
||||
end_data['solution'] = await self.Translate(i['solution'])
|
||||
reflexion_data.append(end_data)
|
||||
end_datas['reflexion_data'] = reflexion_data
|
||||
return end_datas
|
||||
|
||||
async def reflection_run(self):
|
||||
self._lazy_init()
|
||||
start_time = time.time()
|
||||
memory_verifies_flag = self.config.memory_verify
|
||||
quality_assessment=self.config.quality_assessment
|
||||
language_type=self.config.language_type
|
||||
|
||||
asyncio.get_event_loop().time()
|
||||
logging.info("====== 自我反思流程开始 ======")
|
||||
@@ -291,20 +374,18 @@ class ReflectionEngine:
|
||||
|
||||
source_data, databasets = await self.extract_fields_from_json()
|
||||
result_data['baseline'] = self.config.baseline
|
||||
result_data[
|
||||
'source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合"
|
||||
|
||||
result_data['source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合"
|
||||
# 2. 检测冲突(基于事实的反思)
|
||||
conflict_data = await self._detect_conflicts(databasets, source_data)
|
||||
# 遍历数据提取字段
|
||||
quality_assessments = []
|
||||
memory_verifies = []
|
||||
for item in conflict_data:
|
||||
print(item)
|
||||
quality_assessments.append(item['quality_assessment'])
|
||||
memory_verifies.append(item['memory_verify'])
|
||||
result_data['quality_assessments'] = quality_assessments
|
||||
result_data['memory_verifies'] = memory_verifies
|
||||
result_data['quality_assessments'] = quality_assessments
|
||||
|
||||
# 检查是否真的有冲突
|
||||
has_conflict = conflict_data[0].get('conflict', False)
|
||||
@@ -314,8 +395,16 @@ class ReflectionEngine:
|
||||
# 记录冲突数据
|
||||
await self._log_data("conflict", conflict_data)
|
||||
|
||||
# Clearn conflict_data,And memory_verify和quality_assessment
|
||||
cleaned_conflict_data = []
|
||||
for item in conflict_data:
|
||||
cleaned_item = {
|
||||
'data': item['data'],
|
||||
'conflict': item['conflict']
|
||||
}
|
||||
cleaned_conflict_data.append(cleaned_item)
|
||||
# 3. 解决冲突
|
||||
solved_data = await self._resolve_conflicts(conflict_data, source_data)
|
||||
solved_data = await self._resolve_conflicts(cleaned_conflict_data, source_data)
|
||||
if not solved_data:
|
||||
return ReflectionResult(
|
||||
success=False,
|
||||
@@ -331,6 +420,14 @@ class ReflectionEngine:
|
||||
for result in item['results']:
|
||||
reflexion_data.append(result['reflexion'])
|
||||
result_data['reflexion_data'] = reflexion_data
|
||||
if memory_verifies_flag==False:
|
||||
result_data['memory_verifies']=[]
|
||||
if quality_assessment==False:
|
||||
result_data['quality_assessments']=[]
|
||||
|
||||
if language_type=='en':
|
||||
result_data=await self.extract_translation(result_data)
|
||||
print(time.time()-start_time,'----------')
|
||||
return result_data
|
||||
|
||||
|
||||
@@ -407,12 +504,13 @@ class ReflectionEngine:
|
||||
return []
|
||||
|
||||
# 使用转换后的数据
|
||||
print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长
|
||||
# print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长
|
||||
memory_verify = self.config.memory_verify
|
||||
|
||||
logging.info("====== 冲突检测开始 ======")
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
quality_assessment = self.config.quality_assessment
|
||||
language_type=self.config.language_type
|
||||
|
||||
try:
|
||||
# 渲染冲突检测提示词
|
||||
@@ -422,7 +520,8 @@ class ReflectionEngine:
|
||||
self.config.baseline,
|
||||
memory_verify,
|
||||
quality_assessment,
|
||||
statement_databasets
|
||||
statement_databasets,
|
||||
language_type
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
@@ -485,6 +584,7 @@ class ReflectionEngine:
|
||||
memory_verify,
|
||||
statement_databasets
|
||||
)
|
||||
logging.info(f"提示词长度: {len(rendered_prompt)}")
|
||||
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
|
||||
@@ -537,7 +637,8 @@ class ReflectionEngine:
|
||||
Returns:
|
||||
int: 成功更新的记忆数量
|
||||
"""
|
||||
success_count = await neo4j_data(solved_data)
|
||||
changes = extract_and_process_changes(solved_data)
|
||||
success_count = await neo4j_data(changes)
|
||||
return success_count
|
||||
|
||||
async def _log_data(self, label: str, data: Any) -> None:
|
||||
@@ -644,5 +745,8 @@ class ReflectionEngine:
|
||||
execution_time=time_result.execution_time + fact_result.execution_time
|
||||
)
|
||||
else:
|
||||
|
||||
raise ValueError(f"未知的反思基线: {self.config.baseline}")
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,20 @@ import uuid
|
||||
import logging
|
||||
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from openai import BaseModel
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from pydantic import model_validator, Field
|
||||
|
||||
from app.schemas.memory_storage_schema import SingleReflexionResultSchema
|
||||
from app.schemas.memory_storage_schema import ReflexionResultSchema
|
||||
from app.repositories.neo4j.neo4j_update import map_field_names
|
||||
# 添加项目根目录到 Python 路径
|
||||
sys.path.append(str(Path(__file__).parent))
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def _load_(data: List[Any]) -> List[Dict]:
|
||||
@@ -59,6 +73,14 @@ async def get_data(result):
|
||||
"""
|
||||
从数据库中获取数据
|
||||
"""
|
||||
EXCLUDE_FIELDS = {
|
||||
"user_id",
|
||||
"group_id",
|
||||
"entity_type",
|
||||
"connect_strength",
|
||||
"relationship_type",
|
||||
"apply_id"
|
||||
}
|
||||
neo4j_databasets=[]
|
||||
for item in result:
|
||||
filtered_item = {}
|
||||
@@ -73,14 +95,17 @@ async def get_data(result):
|
||||
rel_filtered['statement_id'] = value.get('statement_id')
|
||||
rel_filtered['expired_at'] = value.get('expired_at')
|
||||
rel_filtered['created_at'] = value.get('created_at')
|
||||
filtered_item[key] = rel_filtered
|
||||
filtered_item[key] = value
|
||||
elif key == 'entity2' and value is not None:
|
||||
# 过滤entity2的name_embedding字段
|
||||
entity2_filtered = {}
|
||||
if hasattr(value, 'items'):
|
||||
for e_key, e_value in value.items():
|
||||
if 'name_embedding' not in e_key.lower():
|
||||
entity2_filtered[e_key] = e_value
|
||||
if e_key in EXCLUDE_FIELDS:
|
||||
continue
|
||||
if 'name_embedding' in e_key.lower():
|
||||
continue
|
||||
entity2_filtered[e_key] = e_value
|
||||
filtered_item[key] = entity2_filtered
|
||||
else:
|
||||
filtered_item[key] = value
|
||||
@@ -94,8 +119,57 @@ async def get_data_statement( result):
|
||||
neo4j_databasets.append(i)
|
||||
return neo4j_databasets
|
||||
|
||||
class ReflexionResultSchema(BaseModel):
|
||||
"""Schema for the complete reflexion result data - a list of individual conflict resolutions."""
|
||||
results: List[SingleReflexionResultSchema] = Field(..., description="List of individual conflict resolution results, grouped by conflict type.")
|
||||
|
||||
@model_validator(mode="before")
|
||||
def _normalize_resolved(cls, v):
|
||||
if isinstance(v, dict):
|
||||
conflict = v.get("conflict")
|
||||
if isinstance(conflict, dict) and conflict.get("conflict") is False:
|
||||
v["resolved"] = None
|
||||
else:
|
||||
resolved = v.get("resolved")
|
||||
if isinstance(resolved, dict):
|
||||
orig = resolved.get("original_memory_id")
|
||||
mem = resolved.get("resolved_memory")
|
||||
if orig is None and (mem is None or mem == {}):
|
||||
v["resolved"] = None
|
||||
return v
|
||||
def extract_and_process_changes(DATA):
|
||||
"""提取并处理 change 字段"""
|
||||
all_changes = []
|
||||
for i, item in enumerate(DATA):
|
||||
try:
|
||||
result = ReflexionResultSchema(**item)
|
||||
for j, res in enumerate(result.results):
|
||||
if res.resolved and res.resolved.change:
|
||||
for k, change in enumerate(res.resolved.change):
|
||||
change_data = {}
|
||||
for field_item in change.field:
|
||||
for key, value in field_item.items():
|
||||
change_data[key] = value
|
||||
if isinstance(value, list):
|
||||
print(f" - {key}: {value[0]} -> {value[1]}")
|
||||
else:
|
||||
print(f" - {key}: {value}")
|
||||
|
||||
all_changes.append({
|
||||
'data': change_data
|
||||
})
|
||||
|
||||
# 测试字段映射
|
||||
try:
|
||||
mapped = map_field_names(change_data)
|
||||
print(f" 映射结果: {mapped}")
|
||||
except Exception as e:
|
||||
print(f" 映射失败: {e}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"处理结果 {i + 1} 失败: {e}")
|
||||
|
||||
return all_changes
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
@@ -1,222 +1,88 @@
|
||||
你将收到一组用户历史记忆原始数据(来源于 Neo4j),以及相关配置参数:
|
||||
原本的输入句子:{{statement_databasets}}
|
||||
需要检测冲突对象:{{ evaluate_data }}
|
||||
冲突判定类型:{{ baseline }}(取值为 TIME / FACT / HYBRID)
|
||||
记忆审核开关:{{ memory_verify }}(取值为 true / false)
|
||||
记忆质量评估开关开关:{{ quality_assessment }}(取值为 true / false)
|
||||
|
||||
你的任务是:
|
||||
对用户历史记忆数据进行冲突检测和记忆审核,并输出严格结构化的 JSON 分析结果
|
||||
数据的结构:
|
||||
statement_databasets里面statement_name是输入的句子,statement_id是连接evaluate_data里面的statement_id,代表这个句子被拆分成几个实体,需要根据整体的内容,
|
||||
需要根据以下内容做处理(冲突检测、记忆审核、记忆的质量评估)
|
||||
## 冲突定义
|
||||
# 记忆数据分析任务
|
||||
|
||||
## 输入数据
|
||||
- **原始句子**: {{statement_databasets}}
|
||||
- **检测对象**: {{ evaluate_data }}
|
||||
- **冲突类型**: {{ baseline }} (TIME/FACT/HYBRID)
|
||||
- **隐私审核**: {{ memory_verify }} (true/false)
|
||||
- **质量评估**: {{ quality_assessment }} (true/false)
|
||||
- **语言类型**:{{language_type}}(zh/en)
|
||||
## 任务目标
|
||||
对用户记忆数据进行冲突检测、隐私审核和质量评估,输出结构化JSON结果。
|
||||
**数据关系**: statement_databasets中的statement_id对应evaluate_data中的记录,代表句子拆分后的实体关系。
|
||||
## 1. 冲突检测
|
||||
### 时间冲突
|
||||
时间冲突是指同一用户的相关事件在时间维度上存在逻辑矛盾:
|
||||
|
||||
1. **同一活动的时间冲突**:
|
||||
- 同一用户的同一活动在不同时间点被记录(如"周五打球"和"周六打球")
|
||||
- 同一用户在同一时间段内被记录进行不同的互斥活动
|
||||
|
||||
2. **时间逻辑错误**:
|
||||
- expired_at 早于 created_at
|
||||
- 同一事实的 created_at 时间差异超过合理误差范围(>5分钟)
|
||||
|
||||
3. **日期属性冲突**:
|
||||
- 同一人的生日记录为不同日期(如"2月10号"和"2月16号")
|
||||
4.存在明确先后约束 A -> B,但 t(A) > t(B)
|
||||
-例:入学时间晚于毕业时间。
|
||||
-处理:标记异常、降权、触发逻辑反思或人工审查。
|
||||
5.时间属性冲突
|
||||
-单值日期属性出现多值(生日、入职日期)
|
||||
-注意:本质属于事实冲突的日期特例,归入事实冲突仲裁框架。
|
||||
6.互斥重叠冲突
|
||||
-例:同一主体的两个事件区间重叠且互斥(如同一时间出现在两地)
|
||||
-处理:证据仲裁、保留多版本(active + candidate)。
|
||||
|
||||
|
||||
|
||||
- **同一活动时间矛盾**: 同一用户同一活动的不同时间记录
|
||||
- **时间逻辑错误**: expired_at < created_at,created_at时间差>5分钟
|
||||
- **日期属性冲突**: 同一人的生日等单值属性出现多值
|
||||
- **先后约束违反**: 存在A→B约束但t(A)>t(B)(如入学>毕业)
|
||||
- **互斥重叠**: 同一时间出现在不同地点等互斥事件
|
||||
### 事实冲突
|
||||
事实冲突是指同一实体的属性或关系存在相互矛盾的陈述:
|
||||
|
||||
1. **属性互斥**:同一实体的相反属性(喜欢↔不喜欢、有↔没有、是↔不是)
|
||||
2. **关系矛盾**:同一实体在相同语境下的不同关系描述
|
||||
3. **身份冲突**:同一实体被赋予不同的类型或角色
|
||||
|
||||
### 混合冲突检测
|
||||
检测所有类型的冲突,包括但不限于时间冲突和事实冲突:
|
||||
检测任何逻辑上不一致或相互矛盾的记录
|
||||
## 记忆审核定义
|
||||
|
||||
### 隐私信息检测(隐私冲突)
|
||||
当memory_verify为true时,需要额外检测包含个人隐私信息的记录:
|
||||
|
||||
1. **身份证信息**:包含身份证号码、身份证相关描述
|
||||
2. **手机号码**:包含手机号、电话号码等联系方式
|
||||
3. **社交账号**:包含微信号、QQ号、邮箱地址等社交平台信息
|
||||
4. **银行信息**:包含银行卡号、账户信息、支付信息
|
||||
5. **税务信息**:包含税号、纳税信息、发票信息
|
||||
6. **贷款信息**:包含贷款记录、信贷信息、借款信息
|
||||
7. **其他敏感信息**:包含密码、PIN码、验证码等安全信息
|
||||
|
||||
### 隐私检测原则
|
||||
- 检测description、entity1_name、entity2_name等字段中的隐私信息
|
||||
- 识别数字模式(如手机号11位数字、身份证18位等)
|
||||
- 识别关键词(如"身份证"、"银行卡"、"密码"等)
|
||||
- 检测敏感实体类型和关系
|
||||
|
||||
## 冲突检测原则
|
||||
|
||||
**全面检测**:不区分冲突类型,检测所有可能的冲突
|
||||
**完整输出**:如果发现任何冲突或隐私信息,必须将所有相关记录都放入data字段
|
||||
**实体关联**:重点检查涉及相同实体(entity1_name, entity2_name)的记录
|
||||
**语义分析**:分析description字段的语义相似性和冲突性
|
||||
**时间逻辑**:检查时间字段的逻辑一致性
|
||||
**隐私检测**:当memory_verify为true时,检测所有包含隐私信息的记录
|
||||
|
||||
## 不符合冲突检测
|
||||
-称呼
|
||||
## 重要检测示例
|
||||
|
||||
### 冲突检测示例
|
||||
- 用户与不同时间点的关系(周五 vs 周六,2月10号 vs 2月16号)
|
||||
- 同一实体的重复定义但描述不同
|
||||
- 同一关系的不同表述但含义冲突
|
||||
- 任何逻辑上不可能同时为真的记录
|
||||
|
||||
### 隐私信息检测示例
|
||||
- 包含手机号的记录:"用户的手机号是13812345678"
|
||||
- 包含身份证的记录:"身份证号码为110101199001011234"
|
||||
- 包含银行卡的记录:"银行卡号6222021234567890"
|
||||
- 包含社交账号的记录:"微信号是user123456"
|
||||
- 包含敏感信息的实体名称或描述
|
||||
|
||||
## 输出要求
|
||||
|
||||
**关键原则**:
|
||||
1. 当存在冲突或检测到隐私信息时,conflict才为true,data字段才包含相关记录
|
||||
2. 如果发现冲突,必须将所有相关的冲突记录都放入data数组中
|
||||
3. 如果memory_verify为true且检测到隐私信息,必须将包含隐私信息的记录也放入data数组中
|
||||
4. 既没有冲突也没有隐私信息时,conflict为false,data为空数组
|
||||
5. 如果quality_assessment为true,独立分析数据质量并输出评估结果;如果为false,quality_assessment字段输出null
|
||||
6. 冲突检测、隐私审核和质量评估三个功能完全独立,互不影响
|
||||
7. 不输出conflict_memory字段
|
||||
|
||||
**处理逻辑**:
|
||||
- 首先进行冲突检测,将冲突记录加入data数组
|
||||
- 如果memory_verify为true,再进行隐私信息检测,将包含隐私信息的记录也加入data数组
|
||||
- 如果quality_assessment为true,独立进行质量评估,分析所有输入数据的质量并输出评估结果
|
||||
- 最终data数组包含所有冲突记录和隐私信息记录(去重)
|
||||
- quality_assessment字段独立输出,不影响冲突检测和隐私审核结果
|
||||
- memory_verify字段独立输出隐私检测结果,包含检测到的隐私信息类型和概述
|
||||
|
||||
返回数据格式以json方式输出:
|
||||
- 必须通过json.loads()的格式支持的形式输出,响应必须是与此确切模式匹配的有效JSON对象。不要在JSON之前或之后包含任何文本。
|
||||
- 关键的JSON格式要求{"statement":识别出的文本内容}
|
||||
1.JSON结构仅使用标准ASCII双引号(")-切勿使用中文引号("")或其他Unicode引号
|
||||
2.如果提取的语句文本包含引号,请使用反斜杠(\")正确转义它们
|
||||
3.确保所有JSON字符串都正确关闭并以逗号分隔
|
||||
4.JSON字符串值中不包括换行符
|
||||
5.正确转义的例子:"statement":"Zhang Xinhua said:\"我非常喜欢这本书\""
|
||||
6.不允许输出```json```相关符号,如```json```、``````、```python```、```javascript```、```html```、```css```、```sql```、```java```、```c```、```c++```、```c#```、```ruby```
|
||||
|
||||
## 记忆质量评估定义
|
||||
|
||||
### 质量评估标准
|
||||
当quality_assessment为true时,需要对记忆数据进行质量评估:
|
||||
|
||||
1. **数据完整性**:
|
||||
- 检查必要字段是否完整(entity1_name、entity2_name、description等)
|
||||
- 检查关系描述是否清晰明确
|
||||
- 检查时间字段的有效性
|
||||
|
||||
2. **重复字段检测**:
|
||||
- 识别相同或高度相似的记录
|
||||
- 检测冗余的实体关系
|
||||
- 分析描述内容的重复度
|
||||
|
||||
3. **无意义字段检测**:
|
||||
- 识别空值、无效值或占位符内容
|
||||
- 检测过于简单或无信息量的描述
|
||||
- 识别格式错误或不规范的数据
|
||||
|
||||
4. **上下文依赖性**:
|
||||
- 评估记录是否需要额外上下文才能理解
|
||||
- 检查实体名称的明确性
|
||||
- 分析关系描述的自包含性
|
||||
|
||||
### 质量评估输出
|
||||
- **质量百分比**:基于上述标准计算的整体质量分数(0-100)
|
||||
- **质量概述**:简要描述数据质量状况,包括主要问题和优点
|
||||
|
||||
输出是仅输出一个合法 JSON 对象,严格遵循下述结构:
|
||||
- **属性互斥**: 同一实体的相反属性(喜欢↔不喜欢)
|
||||
- **关系矛盾**: 同一实体在相同语境下的不同关系描述
|
||||
- **身份冲突**: 同一实体被赋予不同类型或角色
|
||||
### 混合冲突
|
||||
检测所有逻辑不一致或相互矛盾的记录。
|
||||
**检测原则**:
|
||||
- 重点检查相同实体的记录
|
||||
- 分析description字段语义冲突
|
||||
- 验证时间字段逻辑一致性
|
||||
## 2. 隐私审核 (memory_verify=true时)
|
||||
### 隐私信息类型
|
||||
- **身份信息**: 身份证号码、身份证相关描述
|
||||
- **联系方式**: 手机号、电话号码
|
||||
- **社交账号**: 微信号、QQ号、邮箱地址
|
||||
- **金融信息**: 银行卡号、账户信息、支付信息
|
||||
- **税务信息**: 税号、纳税信息、发票信息
|
||||
- **贷款信息**: 贷款记录、信贷信息
|
||||
- **安全信息**: 密码、PIN码、验证码
|
||||
### 检测方法
|
||||
- 检测description、entity1_name、entity2_name、name等字段
|
||||
- 识别数字模式(手机号11位、身份证18位等)
|
||||
- 识别关键词("身份证"、"银行卡"、"密码"等)
|
||||
## 3. 质量评估 (quality_assessment=true时)
|
||||
### 评估标准
|
||||
- **数据完整性**: 必要字段完整性、关系描述清晰度、时间字段有效性
|
||||
- **重复检测**: 相同或高度相似记录、冗余实体关系、描述重复度
|
||||
- **无意义检测**: 空值/无效值、过于简单的描述、格式错误
|
||||
- **上下文依赖**: 记录自包含性、实体名称明确性
|
||||
### 输出内容
|
||||
- **质量分数**: 0-100的整体质量百分比
|
||||
- **质量概述**: 简要描述数据质量状况和主要问题
|
||||
## 输出规则
|
||||
### 核心原则
|
||||
1. **conflict=true**: 存在冲突或隐私信息时,将所有相关记录放入data数组
|
||||
2. **conflict=false**: 无冲突且无隐私信息时,data为空数组
|
||||
3. **独立功能**: 冲突检测、隐私审核、质量评估三者完全独立
|
||||
4. **条件输出**:
|
||||
- quality_assessment=true时输出评估对象,否则为null
|
||||
- memory_verify=true时输出隐私检测对象,否则为null
|
||||
5. **不输出conflict_memory字段**
|
||||
### 处理流程
|
||||
1. 冲突检测 → 将冲突记录加入data
|
||||
2. 隐私审核(如启用) → 将隐私记录加入data
|
||||
3. 质量评估(如启用) → 独立输出评估结果
|
||||
4. 去重data数组中的记录
|
||||
**输出结构**:
|
||||
```json
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"entity1_name": "实体1名称",
|
||||
"description": "描述信息",
|
||||
"statement_id": "陈述ID",
|
||||
"created_at": "创建时间戳",
|
||||
"expired_at": "过期时间戳",
|
||||
"relationship_type": "关系类型",
|
||||
"relationship": "关系对象",
|
||||
"entity2_name": "实体2名称",
|
||||
"entity2": "实体2对象"
|
||||
}
|
||||
],
|
||||
"conflict": true或false,
|
||||
"data": [记录数组],
|
||||
"conflict": true/false,
|
||||
"quality_assessment": {
|
||||
"score": 质量百分比数字,
|
||||
"summary": "质量概述文本"
|
||||
"score": 数字,
|
||||
"summary": "文本"
|
||||
} 或 null,
|
||||
"memory_verify": {
|
||||
"has_privacy": true或false,
|
||||
"privacy_types": ["检测到的隐私信息类型列表"],
|
||||
"summary": "隐私检测结果概述"
|
||||
"has_privacy": true/false,
|
||||
"privacy_types": ["类型数组"],
|
||||
"summary": "概述文本"
|
||||
} 或 null
|
||||
}
|
||||
|
||||
必须遵守:
|
||||
- 只输出 JSON,不要添加解释或多余文本。
|
||||
- 使用标准双引号,必要时对内部引号进行转义。
|
||||
- 字段名与结构必须与给定模式一致。
|
||||
- data数组中包含冲突记录和隐私信息记录,如果都没有则为空数组。
|
||||
- quality_assessment字段:当quality_assessment参数为true时输出评估对象,为false时输出null。
|
||||
- memory_verify字段:当memory_verify参数为true时输出隐私检测结果对象,为false时输出null。
|
||||
|
||||
### memory_verify字段说明
|
||||
当memory_verify为true时,需要输出隐私检测结果:
|
||||
- **has_privacy**: 布尔值,表示是否检测到隐私信息
|
||||
- **privacy_types**: 字符串数组,包含检测到的隐私信息类型(如["手机号码", "身份证信息"])
|
||||
- **summary**: 字符串,简要描述隐私检测结果
|
||||
|
||||
当memory_verify为false时,memory_verify字段输出null。
|
||||
|
||||
### memory_verify字段示例
|
||||
|
||||
**示例1:检测到隐私信息**
|
||||
```json
|
||||
"memory_verify": {
|
||||
"has_privacy": true,
|
||||
"privacy_types": ["手机号码", "身份证信息"],
|
||||
"summary": "检测到2条记录包含隐私信息:1个手机号码,1个身份证号码"
|
||||
}
|
||||
```
|
||||
|
||||
**示例2:未检测到隐私信息**
|
||||
```json
|
||||
"memory_verify": {
|
||||
"has_privacy": false,
|
||||
"privacy_types": [],
|
||||
"summary": "未检测到隐私信息"
|
||||
}
|
||||
```
|
||||
|
||||
**示例3:memory_verify为false时**
|
||||
```json
|
||||
"memory_verify": null
|
||||
```
|
||||
|
||||
模式参考:
|
||||
{{ json_schema }}
|
||||
**字段说明**:
|
||||
- **data**: 包含冲突记录和隐私信息记录,无则为空数组
|
||||
- **quality_assessment**:
|
||||
quality_assessment=true时输出评估对象,否则为null(注意:- summary输出的结果不允许含有(expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容)
|
||||
- **memory_verify**: memory_verify=true时输出隐私检测对象,否则为null
|
||||
(注意:- summary输出的结果不允许含有(expired_at设为2024-01-01T00:00:00Z)等原数据字段以及涉及需要修改的字段以及内容)
|
||||
模式参考:{{ json_schema }}
|
||||
@@ -1,200 +1,155 @@
|
||||
你将收到一组用户历史记忆原始数据(来源于 Neo4j)
|
||||
你将收到一条冲突判定对象:{{ data }}。
|
||||
需要检测冲突对象:{{ statement_databasets }}
|
||||
以及需要识别的冲突对象为:{{ baseline }}
|
||||
记忆审核开关:{{ memory_verify }}(取值为 true / false)
|
||||
# 记忆冲突解决任务
|
||||
|
||||
角色:
|
||||
- 你是数据领域中解决数据冲突的专家
|
||||
## 输入数据
|
||||
- **冲突数据**: {{ data }}
|
||||
- **原始句子**: {{ statement_databasets }}
|
||||
- **冲突类型**: {{ baseline }} (TIME/FACT/HYBRID)
|
||||
- **隐私审核**: {{ memory_verify }} (true/false)
|
||||
- **语言类型**:{{language_type}}(zh/en)
|
||||
|
||||
任务:分析冲突产生原因,按冲突类型分组处理,为每种冲突类型生成独立的解决方案。
|
||||
## 任务目标
|
||||
作为数据冲突解决专家,分析冲突原因,按类型分组处理,为每种冲突生成独立解决方案。
|
||||
|
||||
数据的结构:
|
||||
statement_databasets里面statement_name是输入的句子,statement_id是连接data里面的statement_id,代表这个句子被拆分成几个实体,需要根据整体的内容,
|
||||
需要根据以下内容做处理(冲突检测、记忆审核、记忆的质量评估),data里面的statement_created_at是用户输入的时间
|
||||
**数据关系**: statement_databasets中的statement_id对应data中的记录,statement_created_at为用户输入时间。
|
||||
|
||||
**处理模式**:
|
||||
- 当memory_verify为false时:仅处理数据冲突
|
||||
- 当memory_verify为true时:处理数据冲突 + 隐私信息脱敏
|
||||
**处理模式**:
|
||||
- memory_verify=false: 仅处理数据冲突
|
||||
- memory_verify=true: 处理数据冲突 + 隐私脱敏
|
||||
|
||||
## 分组处理原则
|
||||
## 1. 冲突类型定义
|
||||
|
||||
**冲突类型识别与分组**:
|
||||
1. **日期冲突**:
|
||||
1.1.涉及用户生日的不同日期记录(如2月10号 vs 2月16号),
|
||||
1.2.涉及同一活动的不同时间记录(如周五打球 vs 周六打球)
|
||||
3. **事实属性冲突**:
|
||||
3.1. **属性互斥**:同一实体的相反属性(喜欢↔不喜欢、有↔没有、是↔不是)
|
||||
3.2. **关系矛盾**:同一实体在相同语境下的不同关系描述
|
||||
3.3. **身份冲突**:同一实体被赋予不同的类型或角色
|
||||
4. **其他冲突类型/混合冲突(时间+事实)**:根据具体数据识别
|
||||
### 时间冲突 (TIME)
|
||||
时间维度冲突:两个事件时间重叠,或同一事情在不同时间场景下的变化。
|
||||
|
||||
**分组输出要求**:
|
||||
- 每种冲突类型生成一个独立的reflexion_result对象
|
||||
- 同一类型的多个冲突记录归并到一个结果中
|
||||
- 不同类型的冲突分别处理,各自生成独立结果
|
||||
### 事实冲突 (FACT)
|
||||
同一事实对象的陈述内容相互矛盾,真假不能共存的情况。
|
||||
|
||||
## 冲突类型定义
|
||||
### 混合冲突 (HYBRID)
|
||||
检测所有类型冲突,包括时间和事实冲突的任何逻辑不一致记录。
|
||||
|
||||
### 时间冲突(TIME)
|
||||
时间维度冲突是指两个事件发生时间重叠,或者用户同一件事情和场景等情况下,时间出现了变化。
|
||||
## 2. 分组处理原则
|
||||
|
||||
### 事实冲突(FACT)
|
||||
事实冲突是指同一事实对象(同一个人、同一个时间、同一个状态)但陈述内容相互矛盾,主要为真假不能共存的情况。
|
||||
### 混合冲突(HYBRID)
|
||||
检测所有类型的冲突,包括但不限于时间冲突和事实冲突:检测任何逻辑上不一致或相互矛盾的记录
|
||||
{% if memory_verify %}
|
||||
## 隐私信息处理(memory_verify为true时启用)
|
||||
### 冲突类型识别
|
||||
- **日期冲突**: 用户生日不同日期(2月10号 vs 2月16号)、同一活动不同时间(周五 vs 周六打球)
|
||||
- **事实属性冲突**:
|
||||
- 属性互斥(喜欢↔不喜欢)
|
||||
- 关系矛盾(同一实体不同关系描述)
|
||||
- 身份冲突(同一实体不同类型/角色)
|
||||
- **其他/混合冲突**: 根据具体数据识别
|
||||
|
||||
### 隐私信息识别
|
||||
需要识别并处理以下类型的隐私信息:
|
||||
### 分组输出要求
|
||||
- 每种冲突类型生成独立的reflexion_result对象
|
||||
- 同类型多个冲突归并到一个结果
|
||||
- 不同类型分别处理,各自生成独立结果
|
||||
## 3. 隐私信息处理 (memory_verify=true时)
|
||||
|
||||
1. **身份证信息**:包含身份证号码、身份证相关描述
|
||||
2. **手机号码**:包含手机号、电话号码等联系方式
|
||||
3. **社交账号**:包含微信号、QQ号、邮箱地址等社交平台信息
|
||||
4. **银行信息**:包含银行卡号、账户信息、支付信息
|
||||
5. **税务信息**:包含税号、纳税信息、发票信息
|
||||
6. **贷款信息**:包含贷款记录、信贷信息、借款信息
|
||||
7. **其他敏感信息**:包含密码、PIN码、验证码等安全信息
|
||||
### 隐私信息类型
|
||||
- **身份信息**: 身份证号码、身份证相关描述
|
||||
- **联系方式**: 手机号、电话号码
|
||||
- **社交账号**: 微信号、QQ号、邮箱地址
|
||||
- **金融信息**: 银行卡号、账户信息、支付信息
|
||||
- **税务信息**: 税号、纳税信息、发票信息
|
||||
- **贷款信息**: 贷款记录、信贷信息
|
||||
- **安全信息**: 密码、PIN码、验证码
|
||||
|
||||
### 隐私数据脱敏规则
|
||||
对于检测到的隐私信息,按以下规则进行脱敏处理:
|
||||
### 脱敏规则
|
||||
**数字类**: 保留前三位和后四位,中间用*代替
|
||||
- 手机号: 13812345678 → 138****5678
|
||||
- 身份证: 110101199001011234 → 110***********1234
|
||||
- 银行卡: 6222021234567890 → 622***********7890
|
||||
|
||||
**数字类隐私信息脱敏**:
|
||||
- 保留前三位和后四位,中间用*代替
|
||||
- 示例:手机号13812345678 → 138****5678
|
||||
- 示例:身份证110101199001011234 → 110***********1234
|
||||
- 示例:银行卡6222021234567890 → 622***********7890
|
||||
**文本类**: 保留前三后四位字符,中间用*代替
|
||||
- 微信号: user123456 → use****3456
|
||||
- 邮箱: zhang.san@example.com → zha****@example.com
|
||||
|
||||
**文本类隐私信息脱敏**:
|
||||
- 社交账号:保留前三后四位字符,中间用*代替
|
||||
- 示例:微信号user123456 → use****3456
|
||||
- 示例:邮箱zhang.san@example.com → zha****@example.com
|
||||
**脱敏字段**: name、entity1_name、entity2_name、description、relationship
|
||||
|
||||
**脱敏处理字段**:
|
||||
- name字段:如包含隐私信息需脱敏
|
||||
- entity1_name字段:如包含隐私信息需脱敏
|
||||
- entity2_name字段:如包含隐私信息需脱敏
|
||||
- description字段:如包含隐私信息需脱敏
|
||||
{% endif %}
|
||||
## 4. 处理流程
|
||||
|
||||
## 工作步骤
|
||||
### 步骤1: 类型匹配验证
|
||||
**匹配规则**:
|
||||
- baseline="TIME": 只处理时间相关冲突(涉及时间表达式、日期、时间点)
|
||||
- baseline="FACT": 只处理事实相关冲突(属性矛盾、关系冲突、描述不一致)
|
||||
- baseline="HYBRID": 处理所有类型冲突
|
||||
|
||||
### 第一步:分析冲突类型匹配
|
||||
首先判断输入的冲突数据是否符合baseline要求的类型:
|
||||
**类型识别**:
|
||||
- 时间冲突: entity2的entity_type包含"TimeExpression"/"TemporalExpression",或entity2_name包含时间词汇
|
||||
- 事实冲突: 相同实体的不同属性描述、互斥关系陈述
|
||||
|
||||
**类型匹配规则**:
|
||||
- 如果baseline是"TIME":只处理时间相关的冲突(涉及时间表达式、日期、时间点的冲突)
|
||||
- 如果baseline是"FACT":只处理事实相关的冲突(属性矛盾、关系冲突、描述不一致)
|
||||
- 如果baseline是"HYBRID":处理所有类型的冲突,也可以当作混合冲突类型处理
|
||||
**重要**: 类型不匹配时必须输出空结果(resolved为null)
|
||||
|
||||
**类型识别**:
|
||||
- 时间冲突标识:entity2的entity_type包含"TimeExpression"、"TemporalExpression",或entity2_name包含时间词汇(周一到周日、月份日期等)
|
||||
- 事实冲突标识:相同实体的不同属性描述、互斥的关系陈述
|
||||
### 步骤2: 冲突数据分组
|
||||
**分组策略**:
|
||||
- 时间冲突组: 涉及用户时间的记录
|
||||
- 活动时间冲突组: 同一活动不同时间的记录
|
||||
- 事实冲突组: 同一实体不同属性的记录
|
||||
- 其他冲突组: 其他类型冲突记录
|
||||
|
||||
**重要**:如果输入的冲突类型与baseline不匹配,必须输出空结果(resolved为null)
|
||||
**筛选条件**: 只处理与baseline匹配的冲突类型
|
||||
|
||||
### 第二步:筛选并分组冲突数据
|
||||
按冲突类型对数据进行分组:
|
||||
### 步骤3: 冲突解决策略
|
||||
**重要**: 数据被判定为正确时不可修改
|
||||
|
||||
**分组策略**:
|
||||
1. **时间冲突组**:筛选涉及用户时间的所有记录
|
||||
2. **活动时间冲突组**:筛选涉及同一活动不同时间的记录
|
||||
3. **事实冲突组**:筛选涉及同一实体不同属性的记录
|
||||
4. **其他冲突组**:其他类型的冲突记录
|
||||
**智能解决**:
|
||||
1. 分析冲突数据,结合statement_databasets原文判定正确性
|
||||
2. 判断正确答案是否存在于data中
|
||||
3. 根据情况选择处理方式{% if memory_verify %}
|
||||
4. 隐私脱敏处理在冲突解决后进行{% endif %}
|
||||
|
||||
**筛选条件**:
|
||||
- 只处理与baseline匹配的冲突类型
|
||||
- 相同entity1_name但entity2_name不同的记录
|
||||
- 相同关系但描述矛盾的记录
|
||||
- 时间逻辑不一致的记录
|
||||
### 处理规则
|
||||
|
||||
### 第三步:冲突解决策略
|
||||
** 不可以解决的冲突情况
|
||||
1. 数据被判定为正确的情况下,不可以进行修改
|
||||
**仅当冲突类型与baseline匹配时**,对筛选出的冲突数据进行处理:
|
||||
** baseline是TIME
|
||||
-保留正确记录不变修改错误记录的expired_at为当前时间(2025-12-16T12:00:00),以及name需要修改成正确的
|
||||
** baseline不是TIME
|
||||
- 修改字段内容( name、entity1_name、entity2_name、description、relationship)字段内容是否正确,如果不正确,需要对这些字段的内容重新生成,则不需要修改expired_at字段,
|
||||
如果涉及到修改entity1_name/entity2_name字段的时候,同时也需要修改description字段,输出修改前和修改后的放入change里面的field
|
||||
|
||||
**智能解决策略**:
|
||||
1. **分析冲突数据**:识别哪些记录是正确的,哪些是错误的,需要结合statement_databasets的输入原文来判定
|
||||
2. **判断正确答案是否存在**:
|
||||
- 如果正确答案已存在于data中:只需将错误记录的expired_at设为当前日期(2025-12-16T12:00:00)
|
||||
- 如果正确答案已存在于data中:错误记录的expired_at已经设为日期,则不需要对正确的数据进行修改
|
||||
- 如果正确答案不存在于data中:需要修改现有记录的内容以包含正确信息
|
||||
**核心原则**:
|
||||
- 只输出需要修改的记录
|
||||
- 优先保留策略: 时间冲突保留最可信created_at时间,事实冲突选择最新且可信度最高记录
|
||||
- 精确记录变更: change字段包含记录ID、字段名称、新值和旧值{% if memory_verify %}
|
||||
- 隐私保护优先: 所有输出记录必须完成隐私脱敏
|
||||
- 脱敏变更记录: 隐私脱敏变更也必须在change字段中记录{% endif %}
|
||||
- 不可修改数据: 数据被判定为正确时不可修改,无数据可输出时为空
|
||||
- 输出的结果reflexion字段中的reason字段和solution不允许含有(expired_at设为2024-01-01T00:00:00Z、memory_verify=true)等原数据字段以及涉及需要修改的字段以及内容
|
||||
|
||||
{% if memory_verify %}
|
||||
**隐私处理集成**:
|
||||
- 在处理冲突的同时,需要对涉及的记录进行隐私脱敏
|
||||
- 脱敏处理应该在冲突解决之后进行,确保最终输出的记录都已脱敏
|
||||
- 在change字段中记录隐私脱敏的变更
|
||||
{% endif %}
|
||||
|
||||
**具体处理规则**:
|
||||
|
||||
**情况1:正确答案存在于data中**
|
||||
- 保留正确的记录不变
|
||||
- 基于时间关系的冲突:
|
||||
需要只修改错误记录的expired_at为当前时间(2025-12-16T12:00:00)
|
||||
- 基于事实的关系冲突
|
||||
- resolved.resolved_memory只包含被设为失效的错误记录
|
||||
- change字段只记录expired_at的变更:`[{"expired_at": "2025-12-16T12:00:00"}]`(注意:如果已存在时间,则不需要对其修改,也不需要变更 时间)
|
||||
|
||||
**情况2:正确答案不存在于data中**
|
||||
- 选择最合适的记录进行修改
|
||||
- 更新该记录的相关字段:
|
||||
- description字段:添加或修改描述信息{% if memory_verify %}(如包含隐私信息,需脱敏处理){% endif %}
|
||||
- name字段:修改名称字段{% if memory_verify %}(如需要,包含隐私信息时需脱敏){% endif %}
|
||||
- resolved.resolved_memory包含修改后的完整记录{% if memory_verify %}(已脱敏){% endif %}
|
||||
- change字段记录所有被修改的字段{% if memory_verify %},包括脱敏变更{% endif %},例如:`[{"description": "新描述"{% if memory_verify %}, "entity2_name": "138****5678"{% endif %}}]`
|
||||
|
||||
**重要原则**:
|
||||
- **只输出需要修改的记录**:resolved.resolved_memory只包含实际需要修改的数据
|
||||
- **优先保留策略**:时间冲突保留最可信的created_at时间的记录,事实冲突选择最新且可信度最高的记录
|
||||
- **精确记录变更**:change字段必须包含记录ID、字段名称、新值和旧值
|
||||
{% if memory_verify %}- **隐私保护优先**:所有输出的记录必须完成隐私脱敏处理
|
||||
- **脱敏变更记录**:隐私脱敏的变更也必须在change字段中详细记录{% endif %}
|
||||
- **不可修改数据**:数据被判定为正确时,不可以进行修改,如果没有数据可输出空
|
||||
|
||||
**变更记录格式**:
|
||||
**变更记录格式**:
|
||||
```json
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"字段名1": "修改后的值1"},
|
||||
{"字段名2": "修改后的值2"}
|
||||
{"id":修改字段对应的ID}
|
||||
{"statement_id":需要修改的对象对应的statement_id}
|
||||
{"字段名1": ["修改前的值1","修改后的值1"]},
|
||||
{"字段名2": ["修改前的值2","修改后的值2"]}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
**类型不匹配处理**:
|
||||
- 如果冲突类型与baseline不匹配,resolved必须设为null
|
||||
- reflexion.reason说明类型不匹配的原因
|
||||
**类型不匹配处理**:
|
||||
- 冲突类型与baseline不匹配时,resolved设为null
|
||||
- reflexion.reason说明类型不匹配原因
|
||||
- reflexion.solution说明无需处理
|
||||
|
||||
### 第四步:输出解决方案
|
||||
## 5. JSON输出格式
|
||||
|
||||
## 输出要求
|
||||
**嵌套字段映射**(系统会自动处理):
|
||||
**格式要求**:
|
||||
- 输出有效JSON对象,通过json.loads()解析
|
||||
- 使用标准ASCII双引号(")
|
||||
- 内部引号用反斜杠转义(\")
|
||||
- 字符串值不包含换行符
|
||||
- 不输出```json```等代码块标记
|
||||
|
||||
**嵌套字段映射**(系统自动处理):
|
||||
- `entity2.name` → 自动映射为 `name`
|
||||
- `entity1.name` → 自动映射为 `name`
|
||||
- `relationship` → 自动映射为 `statement`
|
||||
- `entity1.description` → 自动映射为 `description`
|
||||
- `entity2.description` → 自动映射为 `description`
|
||||
|
||||
返回数据格式以json方式输出:
|
||||
- 必须通过json.loads()的格式支持的形式输出
|
||||
- 响应必须是与此确切模式匹配的有效JSON对象
|
||||
- 不要在JSON之前或之后包含任何文本
|
||||
|
||||
JSON格式要求:
|
||||
1. JSON结构仅使用标准ASCII双引号(")
|
||||
2. 如果提取的语句文本包含引号,请使用反斜杠(\")正确转义
|
||||
3. 确保所有JSON字符串都正确关闭并以逗号分隔
|
||||
4. JSON字符串值中不包括换行符
|
||||
5. 不允许输出```json```相关符号
|
||||
|
||||
仅输出一个合法 JSON 对象,严格遵循下述结构:
|
||||
|
||||
**输出格式:按冲突类型分组的列表**
|
||||
**输出结构**: 按冲突类型分组的列表
|
||||
```json
|
||||
{
|
||||
"results": [
|
||||
{
|
||||
@@ -208,93 +163,24 @@ JSON格式要求:
|
||||
},
|
||||
"resolved": {
|
||||
"original_memory_id": "被设为失效的记忆id",
|
||||
"resolved_memory": {
|
||||
"entity1_name": "实体1名称",
|
||||
"entity2_name": "实体2名称",
|
||||
"description": "描述信息",
|
||||
"statement_id": "陈述ID",
|
||||
"created_at": "创建时间",
|
||||
"expired_at": "过期时间",
|
||||
"relationship_type": "关系类型",
|
||||
"relationship": {},
|
||||
"entity2": {...}
|
||||
},
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"字段名1": "修改后的值1"},
|
||||
{"字段名2": "修改后的值2"}
|
||||
]
|
||||
}
|
||||
]
|
||||
"resolved_memory": {记录对象},
|
||||
"change": [变更记录数组]
|
||||
},
|
||||
"type": "reflexion_result"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**示例:多种冲突类型的输出**
|
||||
{
|
||||
"results": [
|
||||
{
|
||||
"conflict": {
|
||||
"data": [生日冲突相关的记录],
|
||||
"conflict": true
|
||||
},
|
||||
"reflexion": {
|
||||
"reason": "检测到生日冲突:用户同时关联2月10号和2月16号两个不同日期",
|
||||
"solution": "保留最新记录(2月16号),将旧记录(2月10号)设为失效"
|
||||
},
|
||||
"resolved": {
|
||||
"original_memory_id": "df066210883545a08e727ccd8ad4ec77",
|
||||
"resolved_memory": {...},
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"expired_at": "2025-12-16T12:00:00"}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"type": "reflexion_result"
|
||||
},
|
||||
{
|
||||
"conflict": {
|
||||
"data": [篮球时间冲突相关的记录],
|
||||
"conflict": true
|
||||
},
|
||||
"reflexion": {
|
||||
"reason": "检测到活动时间冲突:用户打篮球时间存在周五和周六的冲突",
|
||||
"solution": "保留最可信的时间记录,将冲突记录设为失效"
|
||||
},
|
||||
"resolved": {
|
||||
"original_memory_id": "另一个记录ID",
|
||||
"resolved_memory": {...},
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"description": "使用系统的个人,指代说话者本人,篮球时间为周六"},
|
||||
{"entity2_name": "周六"}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"type": "reflexion_result"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
必须遵守:
|
||||
- 只输出 JSON,不要添加解释或多余文本
|
||||
- 使用标准双引号,必要时对内部引号进行转义
|
||||
- 字段名与结构必须与给定模式一致
|
||||
- **输出必须是results数组格式**,每个冲突类型作为一个独立的对象
|
||||
- **按冲突类型分组**:相同类型的冲突记录归并到一个result对象中
|
||||
- **每个result对象的conflict.data**只包含该冲突类型相关的记录
|
||||
- **resolved.resolved_memory 只包含需要修改的记录**,不需要修改的记录不要输出
|
||||
- **resolved.change 必须包含详细的变更信息**:field数组包含所有被修改的字段及其新值
|
||||
- 如果某个冲突类型经分析无需修改任何数据,该类型的resolved 必须为 null
|
||||
- 如果与baseline不匹配的冲突类型,不要在results中包含该类型
|
||||
|
||||
模式参考:
|
||||
{{ json_schema }}
|
||||
**输出要求**:
|
||||
- 只输出JSON,不添加解释文本
|
||||
- 使用标准双引号,必要时转义
|
||||
- 字段名与结构必须与模式一致
|
||||
- **results数组格式**: 每个冲突类型作为独立对象
|
||||
- **按冲突类型分组**: 相同类型冲突归并到一个result对象
|
||||
- **conflict.data**: 只包含该冲突类型相关记录
|
||||
- **resolved.resolved_memory**: 只包含需要修改的记录
|
||||
- **resolved.change**: 包含详细变更信息
|
||||
- 无需修改的冲突类型resolved为null
|
||||
- 与baseline不匹配的冲突类型不包含在results中
|
||||
模式参考: {{ json_schema }}
|
||||
@@ -9,7 +9,8 @@ prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
||||
|
||||
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any],
|
||||
baseline: str = "TIME",
|
||||
memory_verify: bool = False,quality_assessment:bool = False,statement_databasets: List[str] = []) -> str:
|
||||
memory_verify: bool = False,quality_assessment:bool = False,
|
||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
||||
"""
|
||||
Renders the evaluate prompt using the evaluate_optimized.jinja2 template.
|
||||
|
||||
@@ -30,12 +31,13 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Dict[str, Any
|
||||
baseline=baseline,
|
||||
memory_verify=memory_verify,
|
||||
quality_assessment=quality_assessment,
|
||||
statement_databasets=statement_databasets
|
||||
statement_databasets=statement_databasets,
|
||||
language_type=language_type
|
||||
)
|
||||
return rendered_prompt
|
||||
|
||||
async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any], baseline: str, memory_verify: bool = False,
|
||||
statement_databasets: List[str] = []) -> str:
|
||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
||||
"""
|
||||
Renders the reflexion prompt using the reflexion_optimized.jinja2 template.
|
||||
|
||||
@@ -51,6 +53,6 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Dict[str, Any],
|
||||
|
||||
rendered_prompt = template.render(data=data, json_schema=schema,
|
||||
baseline=baseline,memory_verify=memory_verify,
|
||||
statement_databasets=statement_databasets)
|
||||
statement_databasets=statement_databasets,language_type=language_type)
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
@@ -69,7 +69,17 @@ class WorkflowExecutor:
|
||||
初始化的工作流状态
|
||||
"""
|
||||
user_message = input_data.get("message") or ""
|
||||
conversation_vars = input_data.get("conversation_vars") or {}
|
||||
|
||||
# 会话变量处理:从配置文件获取变量定义列表,转换为字典(name -> default value)
|
||||
config_variables_list = self.workflow_config.get("variables") or []
|
||||
conversation_vars = {}
|
||||
for var_def in config_variables_list:
|
||||
if isinstance(var_def, dict):
|
||||
var_name = var_def.get("name")
|
||||
var_default = var_def.get("default")
|
||||
if var_name:
|
||||
conversation_vars[var_name] = var_default
|
||||
|
||||
input_variables = input_data.get("variables") or {} # Start 节点的自定义变量
|
||||
|
||||
# 构建分层的变量结构
|
||||
|
||||
@@ -5,9 +5,11 @@
|
||||
"""
|
||||
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.assigner import AssignerNode
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
from app.core.workflow.nodes.if_else import IfElseNode
|
||||
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.node_factory import NodeFactory, WorkflowNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
@@ -23,5 +25,7 @@ __all__ = [
|
||||
"StartNode",
|
||||
"EndNode",
|
||||
"NodeFactory",
|
||||
"WorkflowNode"
|
||||
"WorkflowNode",
|
||||
# "KnowledgeRetrievalNode",
|
||||
"AssignerNode",
|
||||
]
|
||||
|
||||
4
api/app/core/workflow/nodes/assigner/__init__.py
Normal file
4
api/app/core/workflow/nodes/assigner/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||
from app.core.workflow.nodes.assigner.node import AssignerNode
|
||||
|
||||
__all__ = ["AssignerNode", "AssignerNodeConfig"]
|
||||
21
api/app/core/workflow/nodes/assigner/config.py
Normal file
21
api/app/core/workflow/nodes/assigner/config.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.nodes.enums import AssignmentOperator
|
||||
|
||||
|
||||
class AssignerNodeConfig(BaseNodeConfig):
|
||||
variable_selector: str | list[str] = Field(
|
||||
...,
|
||||
description="Variables to be assigned",
|
||||
)
|
||||
|
||||
operation: AssignmentOperator = Field(
|
||||
...,
|
||||
description="Operator to assign",
|
||||
)
|
||||
|
||||
value: str | list[str] = Field(
|
||||
...,
|
||||
description="Values to assign",
|
||||
)
|
||||
80
api/app/core/workflow/nodes/assigner/node.py
Normal file
80
api/app/core/workflow/nodes/assigner/node.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.expression_evaluator import ExpressionEvaluator
|
||||
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import AssignmentOperator
|
||||
from app.core.workflow.nodes.operators import AssignmentOperatorInstance
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AssignerNode(BaseNode):
|
||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||
super().__init__(node_config, workflow_config)
|
||||
self.typed_config = AssignerNodeConfig(**self.config)
|
||||
|
||||
async def execute(self, state: WorkflowState) -> Any:
|
||||
"""
|
||||
Execute the assignment operation defined by this node.
|
||||
|
||||
Args:
|
||||
state: The current workflow state, including conversation variables,
|
||||
node outputs, and system variables.
|
||||
|
||||
Returns:
|
||||
None or the result of the assignment operation.
|
||||
"""
|
||||
# Initialize a variable pool for accessing conversation, node, and system variables
|
||||
pool = VariablePool(state)
|
||||
|
||||
# Get the target variable selector (e.g., "conv.test")
|
||||
variable_selector = self.typed_config.variable_selector
|
||||
if isinstance(variable_selector, str):
|
||||
# Support dot-separated string paths, e.g., "conv.test" -> ["conv", "test"]
|
||||
variable_selector = variable_selector.split('.')
|
||||
|
||||
# Only conversation variables ('conv') are allowed
|
||||
if variable_selector[0] != 'conv': # TODO: Loop node variable support (Feature)
|
||||
raise ValueError("Only conversation variables can be assigned.")
|
||||
|
||||
# Get the value or expression to assign
|
||||
value = self.typed_config.value
|
||||
if isinstance(value, list):
|
||||
value = '.'.join(value)
|
||||
value = ExpressionEvaluator.evaluate(
|
||||
expression=value,
|
||||
variables=pool.get_all_conversation_vars(),
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars(),
|
||||
)
|
||||
|
||||
# Select the appropriate assignment operator instance based on the target variable type
|
||||
operator: AssignmentOperatorInstance = AssignmentOperator.get_operator(pool.get(variable_selector))(
|
||||
pool, variable_selector, value
|
||||
)
|
||||
|
||||
# Execute the configured assignment operation
|
||||
match self.typed_config.operation:
|
||||
case AssignmentOperator.ASSIGN:
|
||||
operator.assign()
|
||||
case AssignmentOperator.CLEAR:
|
||||
operator.clear()
|
||||
case AssignmentOperator.ADD:
|
||||
operator.add()
|
||||
case AssignmentOperator.SUBTRACT:
|
||||
operator.subtract()
|
||||
case AssignmentOperator.MULTIPLY:
|
||||
operator.multiply()
|
||||
case AssignmentOperator.DIVIDE:
|
||||
operator.divide()
|
||||
case AssignmentOperator.APPEND:
|
||||
operator.append()
|
||||
case AssignmentOperator.REMOVE_FIRST:
|
||||
operator.remove_first()
|
||||
case AssignmentOperator.REMOVE_LAST:
|
||||
operator.remove_last()
|
||||
case _:
|
||||
raise ValueError(f"Invalid Operator: {self.typed_config.operation}")
|
||||
@@ -26,7 +26,12 @@ class WorkflowState(TypedDict):
|
||||
messages: Annotated[list[AnyMessage], add]
|
||||
|
||||
# 输入变量(从配置的 variables 传入)
|
||||
variables: dict[str, Any]
|
||||
# 使用深度合并函数,支持嵌套字典的更新(如 conv.xxx)
|
||||
variables: Annotated[dict[str, Any], lambda x, y: {
|
||||
**x,
|
||||
**{k: {**x.get(k, {}), **v} if isinstance(v, dict) and isinstance(x.get(k), dict) else v
|
||||
for k, v in y.items()}
|
||||
}]
|
||||
|
||||
# 节点输出(存储每个节点的执行结果,用于变量引用)
|
||||
# 使用自定义合并函数,将新的节点输出合并到现有字典中
|
||||
@@ -544,9 +549,15 @@ class BaseNode(ABC):
|
||||
# 使用变量池获取变量
|
||||
pool = VariablePool(state)
|
||||
|
||||
# 构建完整的 variables 结构
|
||||
variables = {
|
||||
"sys": pool.get_all_system_vars(),
|
||||
"conv": pool.get_all_conversation_vars()
|
||||
}
|
||||
|
||||
return render_template(
|
||||
template=template,
|
||||
variables=pool.get_all_conversation_vars(),
|
||||
variables=variables,
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars()
|
||||
)
|
||||
@@ -575,9 +586,15 @@ class BaseNode(ABC):
|
||||
# 使用变量池获取变量
|
||||
pool = VariablePool(state)
|
||||
|
||||
# 构建完整的 variables 结构(包含 sys 和 conv)
|
||||
variables = {
|
||||
"sys": pool.get_all_system_vars(),
|
||||
"conv": pool.get_all_conversation_vars()
|
||||
}
|
||||
|
||||
return evaluate_condition(
|
||||
expression=expression,
|
||||
variables=pool.get_all_conversation_vars(),
|
||||
variables=variables,
|
||||
node_outputs=pool.get_all_node_outputs(),
|
||||
system_vars=pool.get_all_system_vars()
|
||||
)
|
||||
|
||||
@@ -14,6 +14,8 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig, MessageConfig
|
||||
from app.core.workflow.nodes.agent.config import AgentNodeConfig
|
||||
from app.core.workflow.nodes.transform.config import TransformNodeConfig
|
||||
from app.core.workflow.nodes.if_else.config import IfElseNodeConfig
|
||||
# from app.core.workflow.nodes.knowledge.config import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.nodes.assigner.config import AssignerNodeConfig
|
||||
|
||||
__all__ = [
|
||||
# 基础类
|
||||
@@ -28,4 +30,6 @@ __all__ = [
|
||||
"AgentNodeConfig",
|
||||
"TransformNodeConfig",
|
||||
"IfElseNodeConfig",
|
||||
# "KnowledgeRetrievalNodeConfig",
|
||||
"AssignerNodeConfig",
|
||||
]
|
||||
|
||||
@@ -33,7 +33,7 @@ class EndNode(BaseNode):
|
||||
|
||||
# 获取配置的输出模板
|
||||
output_template = self.config.get("output")
|
||||
|
||||
|
||||
# 如果配置了输出模板,使用模板渲染;否则使用默认输出
|
||||
if output_template:
|
||||
output = self._render_template(output_template, state)
|
||||
@@ -45,17 +45,17 @@ class EndNode(BaseNode):
|
||||
total_nodes = len(node_outputs)
|
||||
|
||||
logger.info(f"节点 {self.node_id} (End) 执行完成,共执行 {total_nodes} 个节点")
|
||||
|
||||
|
||||
return output
|
||||
|
||||
|
||||
def _extract_referenced_nodes(self, template: str) -> list[str]:
|
||||
"""从模板中提取引用的节点 ID
|
||||
|
||||
|
||||
例如:'结果:{{llm_qa.output}}' -> ['llm_qa']
|
||||
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
|
||||
|
||||
Returns:
|
||||
引用的节点 ID 列表
|
||||
"""
|
||||
@@ -63,44 +63,44 @@ class EndNode(BaseNode):
|
||||
pattern = r'\{\{([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\}\}'
|
||||
matches = re.findall(pattern, template)
|
||||
return list(set(matches)) # 去重
|
||||
|
||||
|
||||
def _parse_template_parts(self, template: str, state: WorkflowState) -> list[dict]:
|
||||
"""解析模板,分离静态文本和动态引用
|
||||
|
||||
|
||||
例如:'你好 {{llm.output}}, 这是后缀'
|
||||
返回:[
|
||||
{"type": "static", "content": "你好 "},
|
||||
{"type": "dynamic", "node_id": "llm", "field": "output"},
|
||||
{"type": "static", "content": ", 这是后缀"}
|
||||
]
|
||||
|
||||
|
||||
Args:
|
||||
template: 模板字符串
|
||||
state: 工作流状态
|
||||
|
||||
|
||||
Returns:
|
||||
模板部分列表
|
||||
"""
|
||||
import re
|
||||
|
||||
|
||||
parts = []
|
||||
last_end = 0
|
||||
|
||||
|
||||
# 匹配 {{xxx}} 或 {{ xxx }} 格式(支持空格)
|
||||
pattern = r'\{\{\s*([^}]+?)\s*\}\}'
|
||||
|
||||
|
||||
for match in re.finditer(pattern, template):
|
||||
start, end = match.span()
|
||||
|
||||
|
||||
# 添加前面的静态文本
|
||||
if start > last_end:
|
||||
static_text = template[last_end:start]
|
||||
if static_text:
|
||||
parts.append({"type": "static", "content": static_text})
|
||||
|
||||
|
||||
# 解析动态引用
|
||||
ref = match.group(1).strip()
|
||||
|
||||
|
||||
# 检查是否是节点引用(如 llm.output 或 llm_qa.output)
|
||||
if '.' in ref:
|
||||
node_id, field = ref.split('.', 1)
|
||||
@@ -115,60 +115,62 @@ class EndNode(BaseNode):
|
||||
# 直接渲染这部分
|
||||
rendered = self._render_template(f"{{{{{ref}}}}}", state)
|
||||
parts.append({"type": "static", "content": rendered})
|
||||
|
||||
|
||||
last_end = end
|
||||
|
||||
|
||||
# 添加最后的静态文本
|
||||
if last_end < len(template):
|
||||
static_text = template[last_end:]
|
||||
if static_text:
|
||||
parts.append({"type": "static", "content": static_text})
|
||||
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
async def execute_stream(self, state: WorkflowState):
|
||||
"""流式执行 end 节点业务逻辑
|
||||
|
||||
|
||||
智能输出策略:
|
||||
1. 检测模板中是否引用了直接上游节点
|
||||
2. 如果引用了,只输出该引用**之后**的部分(后缀)
|
||||
3. 前缀和引用内容已经在上游节点流式输出时发送了
|
||||
|
||||
|
||||
示例:'{{start.test}}hahaha {{ llm_qa.output }} lalalalala a'
|
||||
- 直接上游节点是 llm_qa
|
||||
- 前缀 '{{start.test}}hahaha ' 已在 LLM 节点流式输出前发送
|
||||
- LLM 内容在 LLM 节点流式输出
|
||||
- End 节点只输出 ' lalalalala a'(后缀,一次性输出)
|
||||
|
||||
|
||||
Args:
|
||||
state: 工作流状态
|
||||
|
||||
|
||||
Yields:
|
||||
完成标记
|
||||
"""
|
||||
logger.info(f"节点 {self.node_id} (End) 开始执行(流式)")
|
||||
|
||||
|
||||
# 获取配置的输出模板
|
||||
output_template = self.config.get("output")
|
||||
|
||||
|
||||
if not output_template:
|
||||
output = "工作流已完成"
|
||||
yield {"__final__": True, "result": output}
|
||||
return
|
||||
|
||||
|
||||
# 找到直接上游节点
|
||||
direct_upstream_nodes = []
|
||||
for edge in self.workflow_config.get("edges", []):
|
||||
if edge.get("target") == self.node_id:
|
||||
source_node_id = edge.get("source")
|
||||
direct_upstream_nodes.append(source_node_id)
|
||||
|
||||
|
||||
logger.info(f"节点 {self.node_id} 的直接上游节点: {direct_upstream_nodes}")
|
||||
|
||||
|
||||
# 解析模板部分
|
||||
parts = self._parse_template_parts(output_template, state)
|
||||
logger.info(f"节点 {self.node_id} 解析模板,共 {len(parts)} 个部分")
|
||||
|
||||
for i, part in enumerate(parts):
|
||||
logger.info(f"[模板解析] part[{i}]: {part}")
|
||||
|
||||
# 找到第一个引用直接上游节点的动态引用
|
||||
upstream_ref_index = None
|
||||
for i, part in enumerate(parts):
|
||||
@@ -176,12 +178,12 @@ class EndNode(BaseNode):
|
||||
upstream_ref_index = i
|
||||
logger.info(f"节点 {self.node_id} 找到直接上游节点 {part['node_id']} 的引用,索引: {i}")
|
||||
break
|
||||
|
||||
|
||||
if upstream_ref_index is None:
|
||||
# 没有引用直接上游节点,输出完整模板内容
|
||||
output = self._render_template(output_template, state)
|
||||
logger.info(f"节点 {self.node_id} 没有引用直接上游节点,输出完整内容: '{output[:50]}...'")
|
||||
|
||||
|
||||
# 通过 writer 发送完整内容(作为一个 message chunk)
|
||||
from langgraph.config import get_stream_writer
|
||||
writer = get_stream_writer()
|
||||
@@ -194,57 +196,56 @@ class EndNode(BaseNode):
|
||||
"is_suffix": False
|
||||
})
|
||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送完整内容")
|
||||
|
||||
|
||||
# yield 完成标记
|
||||
yield {"__final__": True, "result": output}
|
||||
return
|
||||
|
||||
|
||||
# 有引用直接上游节点,只输出该引用之后的部分(后缀)
|
||||
logger.info(f"节点 {self.node_id} 检测到直接上游节点引用,只输出后缀部分(从索引 {upstream_ref_index + 1} 开始)")
|
||||
|
||||
|
||||
# 收集后缀部分
|
||||
suffix_parts = []
|
||||
logger.info(f"[后缀调试] 开始收集后缀,从索引 {upstream_ref_index + 1} 到 {len(parts) - 1}")
|
||||
for i in range(upstream_ref_index + 1, len(parts)):
|
||||
part = parts[i]
|
||||
|
||||
logger.info(f"[后缀调试] 处理 part[{i}]: {part}")
|
||||
if part["type"] == "static":
|
||||
# 静态文本
|
||||
logger.info(f"[后缀调试] 添加静态文本: '{part['content']}'")
|
||||
suffix_parts.append(part["content"])
|
||||
|
||||
|
||||
elif part["type"] == "dynamic":
|
||||
# 其他动态引用(如果有多个引用)
|
||||
# Other dynamic references (if there are multiple references)
|
||||
node_id = part["node_id"]
|
||||
field = part["field"]
|
||||
|
||||
# 从 streaming_buffer 或 node_outputs 读取
|
||||
streaming_buffer = state.get("streaming_buffer", {})
|
||||
if node_id in streaming_buffer:
|
||||
buffer_data = streaming_buffer[node_id]
|
||||
content = buffer_data.get("full_content", "")
|
||||
else:
|
||||
node_outputs = state.get("node_outputs", {})
|
||||
runtime_vars = state.get("runtime_vars", {})
|
||||
|
||||
# Use VariablePool to get variable value
|
||||
pool = self.get_variable_pool(state)
|
||||
try:
|
||||
# Try to get variable value with default empty string
|
||||
content = pool.get([node_id, field], default="")
|
||||
logger.info(f"[后缀调试] 获取变量 {node_id}.{field} 成功: '{content}'")
|
||||
except Exception as e:
|
||||
logger.warning(f"[后缀调试] 获取变量 {node_id}.{field} 失败: {e}")
|
||||
content = ""
|
||||
if node_id in node_outputs:
|
||||
node_output = node_outputs[node_id]
|
||||
if isinstance(node_output, dict):
|
||||
content = str(node_output.get(field, ""))
|
||||
elif node_id in runtime_vars:
|
||||
runtime_var = runtime_vars[node_id]
|
||||
if isinstance(runtime_var, dict):
|
||||
content = str(runtime_var.get(field, ""))
|
||||
|
||||
suffix_parts.append(content)
|
||||
|
||||
# Convert to string if not None
|
||||
suffix_parts.append(str(content) if content is not None else "")
|
||||
|
||||
# 拼接后缀
|
||||
suffix = "".join(suffix_parts)
|
||||
|
||||
|
||||
# 构建完整输出(用于返回,包含前缀 + 动态内容 + 后缀)
|
||||
full_output = self._render_template(output_template, state)
|
||||
|
||||
|
||||
logger.info(f"[后缀调试] 节点 {self.node_id} 后缀部分数量: {len(suffix_parts)}")
|
||||
logger.info(f"[后缀调试] 后缀内容: '{suffix}'")
|
||||
logger.info(f"[后缀调试] 后缀长度: {len(suffix)}")
|
||||
logger.info(f"[后缀调试] 后缀是否为空: {not suffix}")
|
||||
|
||||
if suffix:
|
||||
logger.info(f"节点 {self.node_id} 输出后缀: '{suffix[:50]}...' (长度: {len(suffix)})")
|
||||
logger.info(f"节点 {self.node_id} 输出后缀: '{suffix}...' (长度: {len(suffix)})")
|
||||
# 一次性输出后缀(作为单个 chunk)
|
||||
# 注意:不要直接 yield 字符串,因为 base_node 会逐字符处理
|
||||
# 而是通过 writer 直接发送
|
||||
@@ -260,13 +261,13 @@ class EndNode(BaseNode):
|
||||
})
|
||||
logger.info(f"节点 {self.node_id} 已通过 writer 发送后缀,full_content 长度: {len(full_output)}")
|
||||
else:
|
||||
logger.info(f"节点 {self.node_id} 没有后缀需要输出")
|
||||
logger.warning(f"[后缀调试] 节点 {self.node_id} 后缀为空,不发送!upstream_ref_index={upstream_ref_index}, parts数量={len(parts)}")
|
||||
|
||||
# 统计信息
|
||||
node_outputs = state.get("node_outputs", {})
|
||||
total_nodes = len(node_outputs)
|
||||
|
||||
|
||||
logger.info(f"节点 {self.node_id} (End) 执行完成(流式),共执行了 {total_nodes} 个节点")
|
||||
|
||||
|
||||
# yield 完成标记(包含完整输出)
|
||||
yield {"__final__": True, "result": full_output}
|
||||
|
||||
@@ -1,5 +1,14 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from app.core.workflow.nodes.operators import (
|
||||
StringOperator,
|
||||
NumberOperator,
|
||||
AssignmentOperatorType,
|
||||
BooleanOperator,
|
||||
ArrayOperator,
|
||||
ObjectOperator
|
||||
)
|
||||
|
||||
|
||||
class NodeType(StrEnum):
|
||||
START = "start"
|
||||
@@ -14,6 +23,7 @@ class NodeType(StrEnum):
|
||||
HTTP_REQUEST = "http-request"
|
||||
TOOL = "tool"
|
||||
AGENT = "agent"
|
||||
ASSIGNER = "assigner"
|
||||
|
||||
|
||||
class ComparisonOperator(StrEnum):
|
||||
@@ -34,3 +44,32 @@ class ComparisonOperator(StrEnum):
|
||||
class LogicOperator(StrEnum):
|
||||
AND = "and"
|
||||
OR = "or"
|
||||
|
||||
|
||||
class AssignmentOperator(StrEnum):
|
||||
ASSIGN = "assign"
|
||||
CLEAR = "clear"
|
||||
|
||||
ADD = "add" # +=
|
||||
SUBTRACT = "subtract" # -=
|
||||
MULTIPLY = "multiply" # *=
|
||||
DIVIDE = "divide" # /=
|
||||
|
||||
APPEND = "append"
|
||||
REMOVE_LAST = "remove_last"
|
||||
REMOVE_FIRST = "remove_first"
|
||||
|
||||
@classmethod
|
||||
def get_operator(cls, obj) -> AssignmentOperatorType:
|
||||
if isinstance(obj, str):
|
||||
return StringOperator
|
||||
elif isinstance(obj, bool):
|
||||
return BooleanOperator
|
||||
elif isinstance(obj, (int, float)):
|
||||
return NumberOperator
|
||||
elif isinstance(obj, list):
|
||||
return ArrayOperator
|
||||
elif isinstance(obj, dict):
|
||||
return ObjectOperator
|
||||
|
||||
raise TypeError(f"Unsupported variable type ({type(obj)})")
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.workflow.nodes import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.if_else.config import ConditionDetail
|
||||
|
||||
@@ -11,6 +11,7 @@ from langchain_core.messages import AIMessage, SystemMessage, HumanMessage
|
||||
from app.core.workflow.nodes.base_node import BaseNode, WorkflowState
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.db import get_db_context
|
||||
from app.models import ModelType
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
from app.core.exceptions import BusinessException
|
||||
@@ -136,7 +137,7 @@ class LLMNode(BaseNode):
|
||||
base_url=api_base,
|
||||
extra_params=extra_params
|
||||
),
|
||||
type=model_type
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
import logging
|
||||
from typing import Any, Union
|
||||
|
||||
# from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNode
|
||||
from app.core.workflow.nodes.agent import AgentNode
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.end import EndNode
|
||||
@@ -15,6 +16,7 @@ from app.core.workflow.nodes.if_else import IfElseNode
|
||||
from app.core.workflow.nodes.llm import LLMNode
|
||||
from app.core.workflow.nodes.start import StartNode
|
||||
from app.core.workflow.nodes.transform import TransformNode
|
||||
from app.core.workflow.nodes.assigner import AssignerNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -26,6 +28,8 @@ WorkflowNode = Union[
|
||||
IfElseNode,
|
||||
AgentNode,
|
||||
TransformNode,
|
||||
AssignerNode,
|
||||
# KnowledgeRetrievalNode,
|
||||
]
|
||||
|
||||
|
||||
@@ -42,7 +46,9 @@ class NodeFactory:
|
||||
NodeType.LLM: LLMNode,
|
||||
NodeType.AGENT: AgentNode,
|
||||
NodeType.TRANSFORM: TransformNode,
|
||||
NodeType.IF_ELSE: IfElseNode
|
||||
NodeType.IF_ELSE: IfElseNode,
|
||||
# NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
||||
NodeType.ASSIGNER: AssignerNode,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
@@ -82,10 +88,6 @@ class NodeFactory:
|
||||
"""
|
||||
node_type = node_config.get("type")
|
||||
|
||||
# 跳过条件节点(由 LangGraph 处理)
|
||||
if node_type == "condition":
|
||||
return None
|
||||
|
||||
# 获取节点类
|
||||
node_class = cls._node_types.get(node_type)
|
||||
if not node_class:
|
||||
|
||||
146
api/app/core/workflow/nodes/operators.py
Normal file
146
api/app/core/workflow/nodes/operators.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from abc import ABC
|
||||
from typing import Union, Type
|
||||
|
||||
from app.core.workflow.variable_pool import VariablePool
|
||||
|
||||
|
||||
class OperatorBase(ABC):
|
||||
def __init__(self, pool: VariablePool, left_selector, right):
|
||||
self.pool = pool
|
||||
self.left_selector = left_selector
|
||||
self.right = right
|
||||
|
||||
self.type_limit: type[str, int, dict, list] = None
|
||||
|
||||
def check(self, no_right=False):
|
||||
left = self.pool.get(self.left_selector)
|
||||
if not isinstance(left, self.type_limit):
|
||||
raise TypeError(f"The variable to be operated on must be of {self.type_limit} type")
|
||||
|
||||
if not no_right and not isinstance(self.right, self.type_limit):
|
||||
raise TypeError(f"The value assigned to the string variable must also be of {self.type_limit} type")
|
||||
|
||||
|
||||
class StringOperator(OperatorBase):
|
||||
def __init__(self, pool: VariablePool, left_selector, right):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = str
|
||||
|
||||
def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, '')
|
||||
|
||||
|
||||
class NumberOperator(OperatorBase):
|
||||
def __init__(self, pool: VariablePool, left_selector, right):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = (float, int)
|
||||
|
||||
def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, 0)
|
||||
|
||||
def add(self) -> None:
|
||||
self.check()
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin + self.right)
|
||||
|
||||
def subtract(self) -> None:
|
||||
self.check()
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin - self.right)
|
||||
|
||||
def multiply(self) -> None:
|
||||
self.check()
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin * self.right)
|
||||
|
||||
def divide(self) -> None:
|
||||
self.check()
|
||||
origin = self.pool.get(self.left_selector)
|
||||
self.pool.set(self.left_selector, origin / self.right)
|
||||
|
||||
|
||||
class BooleanOperator(OperatorBase):
|
||||
def __init__(self, pool: VariablePool, left_selector, right):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = bool
|
||||
|
||||
def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, False)
|
||||
|
||||
|
||||
class ArrayOperator(OperatorBase):
|
||||
def __init__(self, pool: VariablePool, left_selector, right):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = list
|
||||
|
||||
def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, list())
|
||||
|
||||
def append(self) -> None:
|
||||
self.check(no_right=True)
|
||||
# TODO:require type limit in list
|
||||
origin = self.pool.get(self.left_selector)
|
||||
origin.append(self.right)
|
||||
self.pool.set(self.left_selector, origin)
|
||||
|
||||
def extend(self) -> None:
|
||||
self.check(no_right=True)
|
||||
origin = self.pool.get(self.left_selector)
|
||||
origin.extend(self.right)
|
||||
self.pool.set(self.left_selector, origin)
|
||||
|
||||
def remove_last(self) -> None:
|
||||
self.check(no_right=True)
|
||||
origin = self.pool.get(self.left_selector)
|
||||
origin.pop()
|
||||
self.pool.set(self.left_selector, origin)
|
||||
|
||||
def remove_first(self) -> None:
|
||||
self.check(no_right=True)
|
||||
origin = self.pool.get(self.left_selector)
|
||||
origin.pop(0)
|
||||
self.pool.set(self.left_selector, origin)
|
||||
|
||||
|
||||
class ObjectOperator(OperatorBase):
|
||||
def __init__(self, pool: VariablePool, left_selector, right):
|
||||
super().__init__(pool, left_selector, right)
|
||||
self.type_limit = object
|
||||
|
||||
def assign(self) -> None:
|
||||
self.check()
|
||||
self.pool.set(self.left_selector, self.right)
|
||||
|
||||
def clear(self) -> None:
|
||||
self.check(no_right=True)
|
||||
self.pool.set(self.left_selector, dict())
|
||||
|
||||
|
||||
AssignmentOperatorInstance = Union[
|
||||
StringOperator,
|
||||
NumberOperator,
|
||||
BooleanOperator,
|
||||
ArrayOperator,
|
||||
ObjectOperator
|
||||
]
|
||||
AssignmentOperatorType = Type[AssignmentOperatorInstance]
|
||||
@@ -66,19 +66,26 @@ class TemplateRenderer:
|
||||
'分析结果: 正面情绪'
|
||||
"""
|
||||
# 构建命名空间上下文
|
||||
# variables 的结构:{"sys": {...}, "conv": {...}}
|
||||
sys_vars = variables.get("sys", {}) if isinstance(variables, dict) else {}
|
||||
conv_vars = variables.get("conv", {}) if isinstance(variables, dict) else {}
|
||||
|
||||
context = {
|
||||
"var": variables, # 用户变量:{{var.user_input}}
|
||||
"conv": conv_vars, # 会话变量:{{conv.user_name}}
|
||||
"node": node_outputs, # 节点输出:{{node.node_1.output}}
|
||||
"sys": system_vars or {}, # 系统变量:{{sys.execution_id}}
|
||||
"sys": {**(system_vars or {}), **sys_vars}, # 系统变量:{{sys.execution_id}}(合并两个来源)
|
||||
}
|
||||
|
||||
# 支持直接通过节点ID访问节点输出:{{llm_qa.output}}
|
||||
# 将所有节点输出添加到顶层上下文
|
||||
context.update(node_outputs)
|
||||
if node_outputs:
|
||||
context.update(node_outputs)
|
||||
|
||||
# 为了向后兼容,也支持直接访问用户变量
|
||||
context.update(variables)
|
||||
context["nodes"] = node_outputs # 旧语法兼容
|
||||
# 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}}
|
||||
if conv_vars:
|
||||
context.update(conv_vars)
|
||||
|
||||
context["nodes"] = node_outputs or {} # 旧语法兼容
|
||||
|
||||
try:
|
||||
tmpl = self.env.from_string(template)
|
||||
|
||||
@@ -10,7 +10,10 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.core.workflow.nodes import WorkflowState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -82,7 +85,7 @@ class VariablePool:
|
||||
>>> pool.set(["conv", "user_name"], "张三")
|
||||
"""
|
||||
|
||||
def __init__(self, state: dict[str, Any]):
|
||||
def __init__(self, state: "WorkflowState"):
|
||||
"""初始化变量池
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import datetime
|
||||
import uuid
|
||||
from sqlalchemy import Column, String, DateTime, ForeignKey
|
||||
from sqlalchemy import Column, String, DateTime, ForeignKey, Text, BigInteger
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from app.db import Base
|
||||
@@ -17,6 +17,21 @@ class EndUser(Base):
|
||||
reflection_time = Column(DateTime, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.datetime.now)
|
||||
updated_at = Column(DateTime, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
||||
|
||||
# 用户基本信息字段
|
||||
name = Column(String, nullable=True, comment="姓名")
|
||||
position = Column(String, nullable=True, comment="职位")
|
||||
department = Column(String, nullable=True, comment="部门")
|
||||
contact = Column(String, nullable=True, comment="联系方式")
|
||||
phone = Column(String, nullable=True, comment="电话")
|
||||
hire_date = Column(BigInteger, nullable=True, comment="入职日期(时间戳,毫秒)")
|
||||
updatetime_profile = Column(BigInteger, nullable=True, comment="核心档案信息最后更新时间(时间戳,毫秒)")
|
||||
|
||||
# 缓存字段 - Cache fields for pre-computed analytics
|
||||
memory_insight = Column(Text, nullable=True, comment="缓存的记忆洞察报告")
|
||||
user_summary = Column(Text, nullable=True, comment="缓存的用户摘要")
|
||||
memory_insight_updated_at = Column(DateTime, nullable=True, comment="洞察报告最后更新时间")
|
||||
user_summary_updated_at = Column(DateTime, nullable=True, comment="用户摘要最后更新时间")
|
||||
|
||||
# 与 App 的反向关系
|
||||
app = relationship(
|
||||
|
||||
@@ -15,25 +15,6 @@ class ModelType(StrEnum):
|
||||
EMBEDDING = "embedding"
|
||||
RERANK = "rerank"
|
||||
|
||||
@classmethod
|
||||
def from_str(cls, value: str) -> "ModelType":
|
||||
"""
|
||||
Get a ModelType enum instance from a string value.
|
||||
|
||||
Args:
|
||||
value (str): The string representation of the model type.
|
||||
|
||||
Returns:
|
||||
ModelType: The corresponding ModelType enum object.
|
||||
|
||||
Raises:
|
||||
ValueError: If the given value does not match any ModelType.
|
||||
"""
|
||||
try:
|
||||
return cls(value)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid ModelType: {value}")
|
||||
|
||||
|
||||
class ModelProvider(StrEnum):
|
||||
"""模型提供商枚举"""
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List, Optional
|
||||
import uuid
|
||||
import datetime
|
||||
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.app_model import App
|
||||
from app.models.workspace_model import Workspace
|
||||
|
||||
from app.core.logging_config import get_db_logger
|
||||
|
||||
@@ -92,6 +95,157 @@ class EndUserRepository:
|
||||
db_logger.error(f"获取或创建终端用户时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]:
|
||||
"""根据ID获取终端用户(用于缓存操作)
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
Optional[EndUser]: 终端用户对象,如果不存在则返回None
|
||||
"""
|
||||
try:
|
||||
end_user = (
|
||||
self.db.query(EndUser)
|
||||
.filter(EndUser.id == end_user_id)
|
||||
.first()
|
||||
)
|
||||
if end_user:
|
||||
db_logger.debug(f"成功查询到终端用户 {end_user_id}")
|
||||
else:
|
||||
db_logger.debug(f"未找到终端用户 {end_user_id}")
|
||||
return end_user
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"查询终端用户 {end_user_id} 时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def update_memory_insight(
|
||||
self,
|
||||
end_user_id: uuid.UUID,
|
||||
insight: str
|
||||
) -> bool:
|
||||
"""更新记忆洞察缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
insight: 记忆洞察内容
|
||||
|
||||
Returns:
|
||||
bool: 更新成功返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
updated_count = (
|
||||
self.db.query(EndUser)
|
||||
.filter(EndUser.id == end_user_id)
|
||||
.update(
|
||||
{
|
||||
EndUser.memory_insight: insight,
|
||||
EndUser.memory_insight_updated_at: datetime.datetime.now()
|
||||
},
|
||||
synchronize_session=False
|
||||
)
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
if updated_count > 0:
|
||||
db_logger.info(f"成功更新终端用户 {end_user_id} 的记忆洞察缓存")
|
||||
return True
|
||||
else:
|
||||
db_logger.warning(f"未找到终端用户 {end_user_id},无法更新记忆洞察缓存")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"更新终端用户 {end_user_id} 的记忆洞察缓存时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def update_user_summary(
|
||||
self,
|
||||
end_user_id: uuid.UUID,
|
||||
summary: str
|
||||
) -> bool:
|
||||
"""更新用户摘要缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
summary: 用户摘要内容
|
||||
|
||||
Returns:
|
||||
bool: 更新成功返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
updated_count = (
|
||||
self.db.query(EndUser)
|
||||
.filter(EndUser.id == end_user_id)
|
||||
.update(
|
||||
{
|
||||
EndUser.user_summary: summary,
|
||||
EndUser.user_summary_updated_at: datetime.datetime.now()
|
||||
},
|
||||
synchronize_session=False
|
||||
)
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
if updated_count > 0:
|
||||
db_logger.info(f"成功更新终端用户 {end_user_id} 的用户摘要缓存")
|
||||
return True
|
||||
else:
|
||||
db_logger.warning(f"未找到终端用户 {end_user_id},无法更新用户摘要缓存")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"更新终端用户 {end_user_id} 的用户摘要缓存时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_all_by_workspace(self, workspace_id: uuid.UUID) -> List[EndUser]:
|
||||
"""获取工作空间的所有终端用户
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
List[EndUser]: 终端用户列表
|
||||
"""
|
||||
try:
|
||||
end_users = (
|
||||
self.db.query(EndUser)
|
||||
.join(App, EndUser.app_id == App.id)
|
||||
.filter(App.workspace_id == workspace_id)
|
||||
.all()
|
||||
)
|
||||
db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(end_users)} 个终端用户")
|
||||
return end_users
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"查询工作空间 {workspace_id} 下的终端用户时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_all_active_workspaces(self) -> List[uuid.UUID]:
|
||||
"""获取所有活动工作空间的ID
|
||||
|
||||
Returns:
|
||||
List[uuid.UUID]: 活动工作空间ID列表
|
||||
"""
|
||||
try:
|
||||
workspace_ids = (
|
||||
self.db.query(Workspace.id)
|
||||
.filter(Workspace.is_active)
|
||||
.all()
|
||||
)
|
||||
# 提取ID(查询返回的是元组列表)
|
||||
workspace_id_list = [workspace_id[0] for workspace_id in workspace_ids]
|
||||
db_logger.info(f"成功查询到 {len(workspace_id_list)} 个活动工作空间")
|
||||
return workspace_id_list
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
db_logger.error(f"查询活动工作空间时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
def get_end_users_by_app_id(db: Session, app_id: uuid.UUID) -> List[EndUser]:
|
||||
"""根据应用ID查询宿主(返回 EndUser ORM 列表)"""
|
||||
repo = EndUserRepository(db)
|
||||
@@ -138,4 +292,30 @@ def update_end_user_other_name(
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
db_logger.error(f"更新宿主 {end_user_id} 的 other_name 时出错: {str(e)}")
|
||||
raise
|
||||
raise
|
||||
|
||||
# 新增的缓存操作函数(保持与类方法一致的接口)
|
||||
def get_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]:
|
||||
"""根据ID获取终端用户(用于缓存操作)"""
|
||||
repo = EndUserRepository(db)
|
||||
return repo.get_by_id(end_user_id)
|
||||
|
||||
def update_memory_insight(db: Session, end_user_id: uuid.UUID, insight: str) -> bool:
|
||||
"""更新记忆洞察缓存"""
|
||||
repo = EndUserRepository(db)
|
||||
return repo.update_memory_insight(end_user_id, insight)
|
||||
|
||||
def update_user_summary(db: Session, end_user_id: uuid.UUID, summary: str) -> bool:
|
||||
"""更新用户摘要缓存"""
|
||||
repo = EndUserRepository(db)
|
||||
return repo.update_user_summary(end_user_id, summary)
|
||||
|
||||
def get_all_by_workspace(db: Session, workspace_id: uuid.UUID) -> List[EndUser]:
|
||||
"""获取工作空间的所有终端用户"""
|
||||
repo = EndUserRepository(db)
|
||||
return repo.get_all_by_workspace(workspace_id)
|
||||
|
||||
def get_all_active_workspaces(db: Session) -> List[uuid.UUID]:
|
||||
"""获取所有活动工作空间的ID"""
|
||||
repo = EndUserRepository(db)
|
||||
return repo.get_all_active_workspaces()
|
||||
|
||||
@@ -783,7 +783,9 @@ neo4j_query_part = """
|
||||
m.created_at as created_at,
|
||||
m.expired_at as expired_at,
|
||||
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
|
||||
rel as relationship,
|
||||
rel.predicate as predicate,
|
||||
rel.statement as relationship,
|
||||
rel.statement_id as relationship_statement_id,
|
||||
CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name,
|
||||
other as entity2
|
||||
"""
|
||||
@@ -799,7 +801,9 @@ neo4j_query_all = """
|
||||
m.created_at as created_at,
|
||||
m.expired_at as expired_at,
|
||||
CASE WHEN rel IS NULL THEN "NO_RELATIONSHIP" ELSE type(rel) END as relationship_type,
|
||||
rel as relationship,
|
||||
rel.predicate as predicate,
|
||||
rel.statement as relationship,
|
||||
rel.statement_id as relationship_statement_id,
|
||||
CASE WHEN other IS NULL THEN "ISOLATED_NODE" ELSE other.name END as entity2_name,
|
||||
other as entity2
|
||||
"""
|
||||
|
||||
@@ -67,11 +67,81 @@ async def update_neo4j_data(neo4j_dict_data, update_databases):
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
async def update_neo4j_data_edge(neo4j_dict_data, update_databases):
|
||||
"""
|
||||
Update Neo4j data based on query criteria and update parameters
|
||||
|
||||
Args:
|
||||
neo4j_dict_data: find
|
||||
update_databases: update
|
||||
"""
|
||||
try:
|
||||
# 构建WHERE条件
|
||||
where_conditions = []
|
||||
params = {}
|
||||
|
||||
for key, value in neo4j_dict_data.items():
|
||||
if value is not None:
|
||||
param_name = f"param_{key}"
|
||||
where_conditions.append(f"r.{key} = ${param_name}")
|
||||
params[param_name] = value
|
||||
|
||||
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
|
||||
|
||||
# 构建SET条件
|
||||
set_conditions = []
|
||||
for key, value in update_databases.items():
|
||||
if value is not None:
|
||||
param_name = f"update_{key}"
|
||||
set_conditions.append(f"r.{key} = ${param_name}")
|
||||
params[param_name] = value
|
||||
|
||||
set_clause = ", ".join(set_conditions)
|
||||
|
||||
if not set_clause:
|
||||
print("警告: 没有需要更新的字段")
|
||||
return False
|
||||
|
||||
# 构建Cypher查询
|
||||
cypher_query = f"""
|
||||
MATCH (n)-[r]->(m)
|
||||
WHERE {where_clause}
|
||||
SET {set_clause}
|
||||
RETURN count(r) as updated_count, collect(type(r)) as relation_types
|
||||
"""
|
||||
|
||||
print(f"\n执行Cypher查询: {cypher_query}")
|
||||
print(f"参数: {params}")
|
||||
|
||||
# 执行更新
|
||||
result = await neo4j_connector.execute_query(cypher_query, **params)
|
||||
|
||||
if result:
|
||||
updated_count = result[0].get('updated_count', 0)
|
||||
updated_names = result[0].get('updated_names', [])
|
||||
print(f"成功更新 {updated_count} 个节点")
|
||||
if updated_names:
|
||||
print(f"更新的实体名称: {updated_names}")
|
||||
return updated_count > 0
|
||||
else:
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"更新过程中出现错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
def map_field_names(data_dict):
|
||||
mapped_dict = {}
|
||||
has_name_field = False
|
||||
|
||||
# 辅助函数:提取值(如果是数组则取最后一个值,否则直接返回)
|
||||
def extract_value(value):
|
||||
if isinstance(value, list) and len(value) > 0:
|
||||
# 如果是数组 [old_value, new_value],取新值(最后一个)
|
||||
return value[-1]
|
||||
return value
|
||||
|
||||
# 第一遍:检查是否有name相关字段
|
||||
for key, value in data_dict.items():
|
||||
if key in ['name', 'entity2.name', 'entity1.name']:
|
||||
@@ -82,22 +152,25 @@ def map_field_names(data_dict):
|
||||
|
||||
# 第二遍:根据规则映射和过滤字段
|
||||
for key, value in data_dict.items():
|
||||
# 提取实际值(处理数组格式)
|
||||
actual_value = extract_value(value)
|
||||
|
||||
if key == 'entity2.name' or key == 'entity2_name':
|
||||
# 将 entity2.name 映射为 name
|
||||
mapped_dict['name'] = value
|
||||
print(f"字段名映射: {key} -> name")
|
||||
mapped_dict['name'] = actual_value
|
||||
print(f"字段名映射: {key} -> name (值: {value} -> {actual_value})")
|
||||
elif key == 'entity1.name' or key == 'entity1_name':
|
||||
# 将 entity1.name 映射为 name
|
||||
mapped_dict['name'] = value
|
||||
print(f"字段名映射: {key} -> name")
|
||||
mapped_dict['name'] = actual_value
|
||||
print(f"字段名映射: {key} -> name (值: {value} -> {actual_value})")
|
||||
elif key == 'entity1.description':
|
||||
# 将 entity1.description 映射为 description
|
||||
mapped_dict['description'] = value
|
||||
print(f"字段名映射: {key} -> description")
|
||||
mapped_dict['description'] = actual_value
|
||||
print(f"字段名映射: {key} -> description (值: {value} -> {actual_value})")
|
||||
elif key == 'entity2.description':
|
||||
# 将 entity2.description 映射为 description
|
||||
mapped_dict['description'] = value
|
||||
print(f"字段名映射: {key} -> description")
|
||||
mapped_dict['description'] = actual_value
|
||||
print(f"字段名映射: {key} -> description (值: {value} -> {actual_value})")
|
||||
elif key == 'relationship_type':
|
||||
# 跳过relationship_type字段
|
||||
print(f"字段过滤: 跳过不需要的字段 '{key}'")
|
||||
@@ -109,8 +182,8 @@ def map_field_names(data_dict):
|
||||
continue
|
||||
else:
|
||||
# 如果没有name字段,保留entity1_name
|
||||
mapped_dict[key] = value
|
||||
print(f"字段保留: {key}")
|
||||
mapped_dict[key] = actual_value
|
||||
print(f"字段保留: {key} (值: {value} -> {actual_value})")
|
||||
elif key == 'entity2_name':
|
||||
if has_name_field:
|
||||
# 如果有name字段,跳过entity2_name
|
||||
@@ -122,7 +195,11 @@ def map_field_names(data_dict):
|
||||
continue
|
||||
elif '.' not in key:
|
||||
# 不包含点号的其他字段直接保留
|
||||
mapped_dict[key] = value
|
||||
mapped_dict[key] = actual_value
|
||||
if isinstance(value, list):
|
||||
print(f"字段保留: {key} (数组值: {value} -> {actual_value})")
|
||||
else:
|
||||
print(f"字段保留: {key}")
|
||||
else:
|
||||
# 其他包含点号的字段跳过并警告
|
||||
print(f"警告: 跳过不支持的嵌套字段 '{key}'")
|
||||
@@ -139,89 +216,57 @@ async def neo4j_data(solved_data):
|
||||
"""
|
||||
success_count = 0
|
||||
|
||||
ori_entity = {}
|
||||
updata_entity = {}
|
||||
ori_edge = {}
|
||||
updata_edge = {}
|
||||
ori_expired_at={}
|
||||
updat_expired_at={}
|
||||
for i in solved_data:
|
||||
neo4j_dict_data = {}
|
||||
update_databases = {}
|
||||
results = i['results']
|
||||
for data in results:
|
||||
resolved = data.get('resolved')
|
||||
if not resolved:
|
||||
print("跳过:resolved为None")
|
||||
databasets = i['data']
|
||||
for key, values in databasets.items():
|
||||
if str(values)=='NONE':
|
||||
continue
|
||||
if isinstance(values, list):
|
||||
if key == 'description':
|
||||
ori_entity[key] = values[0]
|
||||
updata_entity[key] = values[1]
|
||||
if key == 'entity2_name' or key == 'entity1_name':
|
||||
key = 'name'
|
||||
ori_entity[key] = values[0]
|
||||
updata_entity[key] = values[1]
|
||||
ori_expired_at[key] = values[0]
|
||||
if key == 'statement':
|
||||
ori_edge[key] = values[0]
|
||||
updata_edge[key] = values[1]
|
||||
if key=='expired_at':
|
||||
updat_expired_at[key] = values[1]
|
||||
|
||||
try:
|
||||
change_list = resolved.get('change', [])
|
||||
except (AttributeError, TypeError):
|
||||
change_list = []
|
||||
elif key == 'statement_id':
|
||||
ori_edge[key] = values
|
||||
updata_edge[key] = values
|
||||
|
||||
if change_list == []:
|
||||
print("跳过:change_list为空")
|
||||
continue
|
||||
ori_entity[key] = values
|
||||
updata_entity[key] = values
|
||||
|
||||
if change_list and len(change_list) > 0:
|
||||
change = change_list[0]
|
||||
print(f"change: {change}")
|
||||
field_data = change.get('field', [])
|
||||
print(f"field_data: {field_data}")
|
||||
print(f"field_data type: {type(field_data)}")
|
||||
|
||||
# 字段名映射和过滤函数
|
||||
ori_expired_at[key] = values
|
||||
|
||||
|
||||
# 处理field数据,可能是字典或列表
|
||||
if isinstance(field_data, dict):
|
||||
# 如果是字典,映射字段名后更新
|
||||
mapped_data = map_field_names(field_data)
|
||||
update_databases.update(mapped_data)
|
||||
elif isinstance(field_data, list):
|
||||
# 如果是列表,遍历每个字典并更新
|
||||
for field_item in field_data:
|
||||
if isinstance(field_item, dict):
|
||||
mapped_item = map_field_names(field_item)
|
||||
update_databases.update(mapped_item)
|
||||
else:
|
||||
print(f"警告: field_item不是字典: {field_item}")
|
||||
else:
|
||||
print(f"警告: field_data类型不支持: {type(field_data)}")
|
||||
|
||||
if 'entity1_name' in data:
|
||||
data['name'] = data.pop('entity1_name')
|
||||
if 'entity2_name' in data:
|
||||
data.pop('entity2_name', None)
|
||||
|
||||
resolved_memory = resolved.get('resolved_memory', {})
|
||||
|
||||
entity2 = None
|
||||
if isinstance(resolved_memory, dict):
|
||||
entity2 = resolved_memory.get('entity2')
|
||||
|
||||
if entity2 and isinstance(entity2, dict) and len(entity2) >= 5:
|
||||
stat_id = resolved.get('original_memory_id')
|
||||
# 安全地获取description
|
||||
statement_id = None
|
||||
if isinstance(resolved_memory, dict):
|
||||
statement_id = resolved_memory.get('statement_id')
|
||||
|
||||
# 只有当neo4j_dict_data中还没有statement_id时才使用original_memory_id
|
||||
if statement_id and 'id' not in neo4j_dict_data:
|
||||
neo4j_dict_data['id'] = stat_id
|
||||
neo4j_dict_data['statement_id'] = statement_id
|
||||
else:
|
||||
# 处理original_memory_id,它可能是字符串或字典
|
||||
try:
|
||||
for key, value in resolved_memory.items():
|
||||
if key == 'statement_id':
|
||||
neo4j_dict_data['statement_id'] = value
|
||||
if key == 'description':
|
||||
neo4j_dict_data['description'] = value
|
||||
except AttributeError:
|
||||
neo4j_dict_data=[]
|
||||
|
||||
print(neo4j_dict_data)
|
||||
print(update_databases)
|
||||
if neo4j_dict_data!=[]:
|
||||
await update_neo4j_data(neo4j_dict_data, update_databases)
|
||||
success_count += 1
|
||||
print(ori_entity)
|
||||
print(updata_entity)
|
||||
print(100*'-')
|
||||
print(ori_edge)
|
||||
print(updata_edge)
|
||||
expired_at_ = updat_expired_at.get('expired_at', None)
|
||||
if expired_at_ is not None:
|
||||
await update_neo4j_data(ori_expired_at, updat_expired_at)
|
||||
success_count += 1
|
||||
if ori_entity != updata_entity:
|
||||
await update_neo4j_data(ori_entity, updata_entity)
|
||||
success_count += 1
|
||||
if ori_edge != updata_edge:
|
||||
await update_neo4j_data_edge(ori_edge, updata_edge)
|
||||
success_count += 1
|
||||
|
||||
return success_count
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import uuid
|
||||
import datetime
|
||||
from typing import Optional, Any, List, Dict, TYPE_CHECKING
|
||||
import uuid
|
||||
from typing import Optional, Any, List, Dict
|
||||
|
||||
from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator
|
||||
|
||||
|
||||
@@ -20,20 +21,19 @@ class KnowledgeBaseConfig(BaseModel):
|
||||
class KnowledgeRetrievalConfig(BaseModel):
|
||||
"""知识库检索配置(支持多个知识库,每个有独立配置)"""
|
||||
knowledge_bases: List[KnowledgeBaseConfig] = Field(
|
||||
default_factory=list,
|
||||
default_factory=list,
|
||||
description="关联的知识库列表,每个知识库有独立配置"
|
||||
)
|
||||
|
||||
|
||||
# 多知识库融合策略
|
||||
merge_strategy: str = Field(
|
||||
default="weighted",
|
||||
default="weighted",
|
||||
description="多知识库结果融合策略: weighted | rrf | concat"
|
||||
)
|
||||
reranker_id: Optional[str] = Field(default=None, description="多知识库结果融合的模型ID")
|
||||
reranker_top_k: int = Field(default=10, ge=0, le=1024, description="多知识库结果融合的模型参数")
|
||||
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
"""工具配置"""
|
||||
enabled: bool = Field(default=False, description="是否启用该工具")
|
||||
@@ -63,7 +63,7 @@ class VariableDefinition(BaseModel):
|
||||
name: str = Field(..., description="变量名称(标识符)")
|
||||
display_name: Optional[str] = Field(None, description="显示名称(用户看到的名称)")
|
||||
type: str = Field(
|
||||
default="string",
|
||||
default="string",
|
||||
description="变量类型: string(单行文本) | text(多行文本) | number(数字)"
|
||||
)
|
||||
required: bool = Field(default=False, description="是否必填")
|
||||
@@ -75,32 +75,32 @@ class AgentConfigCreate(BaseModel):
|
||||
"""Agent 行为配置"""
|
||||
# 提示词配置
|
||||
system_prompt: Optional[str] = Field(default=None, description="系统提示词,定义 Agent 的角色和行为准则")
|
||||
|
||||
|
||||
# 模型配置
|
||||
default_model_config_id: Optional[uuid.UUID] = Field(default=None, description="默认使用的模型配置ID")
|
||||
model_parameters: ModelParameters = Field(
|
||||
default_factory=ModelParameters,
|
||||
description="模型参数配置(temperature、max_tokens 等)"
|
||||
)
|
||||
|
||||
|
||||
# 知识库关联
|
||||
knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = Field(
|
||||
default=None,
|
||||
description="知识库检索配置"
|
||||
)
|
||||
|
||||
|
||||
# 记忆配置
|
||||
memory: MemoryConfig = Field(
|
||||
default_factory=lambda: MemoryConfig(enabled=True),
|
||||
description="对话历史记忆配置"
|
||||
)
|
||||
|
||||
|
||||
# 变量配置
|
||||
variables: List[VariableDefinition] = Field(
|
||||
default_factory=list,
|
||||
description="Agent 可用的变量列表"
|
||||
)
|
||||
|
||||
|
||||
# 工具配置
|
||||
tools: Dict[str, ToolConfig] = Field(
|
||||
default_factory=dict,
|
||||
@@ -120,7 +120,7 @@ class AppCreate(BaseModel):
|
||||
|
||||
# only for type=agent
|
||||
agent_config: Optional[AgentConfigCreate] = None
|
||||
|
||||
|
||||
# only for type=multi_agent
|
||||
multi_agent_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
@@ -139,23 +139,23 @@ class AgentConfigUpdate(BaseModel):
|
||||
"""更新 Agent 行为配置"""
|
||||
# 提示词配置
|
||||
system_prompt: Optional[str] = Field(default=None, description="系统提示词")
|
||||
|
||||
|
||||
# 模型配置
|
||||
default_model_config_id: Optional[uuid.UUID] = Field(default=None, description="默认模型配置ID")
|
||||
model_parameters: Optional[ModelParameters] = Field(default=None, description="模型参数配置")
|
||||
|
||||
|
||||
# 知识库关联
|
||||
knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = Field(
|
||||
default=None,
|
||||
description="知识库检索配置"
|
||||
)
|
||||
|
||||
|
||||
# 记忆配置
|
||||
memory: Optional[MemoryConfig] = Field(default=None, description="对话历史记忆配置")
|
||||
|
||||
|
||||
# 变量配置
|
||||
variables: Optional[List[VariableDefinition]] = Field(default=None, description="变量列表")
|
||||
|
||||
|
||||
# 工具配置
|
||||
tools: Optional[Dict[str, ToolConfig]] = Field(default=None, description="工具配置")
|
||||
|
||||
@@ -185,7 +185,7 @@ class App(BaseModel):
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
@@ -197,26 +197,26 @@ class AgentConfig(BaseModel):
|
||||
|
||||
id: uuid.UUID
|
||||
app_id: uuid.UUID
|
||||
|
||||
|
||||
# 提示词
|
||||
system_prompt: Optional[str] = None
|
||||
|
||||
|
||||
# 模型配置
|
||||
default_model_config_id: Optional[uuid.UUID] = None
|
||||
model_parameters: ModelParameters = Field(default_factory=ModelParameters)
|
||||
|
||||
|
||||
# 知识库检索
|
||||
knowledge_retrieval: Optional[KnowledgeRetrievalConfig] = None
|
||||
|
||||
|
||||
# 记忆配置
|
||||
memory: MemoryConfig = Field(default_factory=lambda: MemoryConfig(enabled=True))
|
||||
|
||||
|
||||
# 变量配置
|
||||
variables: List[VariableDefinition] = []
|
||||
|
||||
|
||||
# 工具配置
|
||||
tools: Dict[str, ToolConfig] = {}
|
||||
|
||||
|
||||
is_active: bool
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
@@ -228,7 +228,7 @@ class AgentConfig(BaseModel):
|
||||
if v is None:
|
||||
return ModelParameters()
|
||||
return v
|
||||
|
||||
|
||||
@field_validator("memory", mode="before")
|
||||
@classmethod
|
||||
def validate_memory(cls, v):
|
||||
@@ -236,7 +236,7 @@ class AgentConfig(BaseModel):
|
||||
if v is None:
|
||||
return MemoryConfig(enabled=True)
|
||||
return v
|
||||
|
||||
|
||||
@field_validator("variables", mode="before")
|
||||
@classmethod
|
||||
def validate_variables(cls, v):
|
||||
@@ -244,7 +244,7 @@ class AgentConfig(BaseModel):
|
||||
if v is None:
|
||||
return []
|
||||
return v
|
||||
|
||||
|
||||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
def validate_tools(cls, v):
|
||||
@@ -256,7 +256,7 @@ class AgentConfig(BaseModel):
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
@@ -294,15 +294,15 @@ class AppRelease(BaseModel):
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
@field_serializer("published_at", when_used="json")
|
||||
def _serialize_published_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
|
||||
# ---------- App Share Schemas ----------
|
||||
|
||||
@@ -314,7 +314,7 @@ class AppShareCreate(BaseModel):
|
||||
class AppShare(BaseModel):
|
||||
"""应用分享输出"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
id: uuid.UUID
|
||||
source_app_id: uuid.UUID
|
||||
source_workspace_id: uuid.UUID
|
||||
@@ -322,11 +322,11 @@ class AppShare(BaseModel):
|
||||
shared_by: uuid.UUID
|
||||
created_at: datetime.datetime
|
||||
updated_at: datetime.datetime
|
||||
|
||||
|
||||
@field_serializer("created_at", when_used="json")
|
||||
def _serialize_created_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
@field_serializer("updated_at", when_used="json")
|
||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
@@ -382,14 +382,14 @@ class DraftRunCompareRequest(BaseModel):
|
||||
conversation_id: Optional[str] = Field(None, description="会话ID")
|
||||
user_id: Optional[str] = Field(None, description="用户ID")
|
||||
variables: Optional[Dict[str, Any]] = Field(None, description="变量参数")
|
||||
|
||||
|
||||
models: List[ModelCompareItem] = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=5,
|
||||
description="要对比的模型列表(1-5个)"
|
||||
)
|
||||
|
||||
|
||||
parallel: bool = Field(True, description="是否并行执行")
|
||||
stream: bool = Field(False, description="是否流式返回")
|
||||
timeout: Optional[int] = Field(60, ge=10, le=300, description="超时时间(秒)")
|
||||
@@ -400,14 +400,14 @@ class ModelRunResult(BaseModel):
|
||||
model_config_id: uuid.UUID
|
||||
model_name: str
|
||||
label: Optional[str] = None
|
||||
|
||||
|
||||
parameters_used: Dict[str, Any] = Field(..., description="实际使用的参数")
|
||||
|
||||
|
||||
message: Optional[str] = None
|
||||
usage: Optional[Dict[str, Any]] = None
|
||||
elapsed_time: float
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
tokens_per_second: Optional[float] = None
|
||||
cost_estimate: Optional[float] = None
|
||||
conversation_id: Optional[str] = None
|
||||
@@ -416,10 +416,10 @@ class ModelRunResult(BaseModel):
|
||||
class DraftRunCompareResponse(BaseModel):
|
||||
"""多模型对比响应"""
|
||||
results: List[ModelRunResult]
|
||||
|
||||
|
||||
total_elapsed_time: float
|
||||
successful_count: int
|
||||
failed_count: int
|
||||
|
||||
|
||||
fastest_model: Optional[str] = None
|
||||
cheapest_model: Optional[str] = None
|
||||
|
||||
@@ -16,3 +16,37 @@ class EndUser(BaseModel):
|
||||
reflection_time: Optional[datetime.datetime] = Field(description="反思时间", default_factory=datetime.datetime.now)
|
||||
created_at: datetime.datetime = Field(description="创建时间", default_factory=datetime.datetime.now)
|
||||
updated_at: datetime.datetime = Field(description="更新时间", default_factory=datetime.datetime.now)
|
||||
|
||||
# 用户基本信息字段
|
||||
name: Optional[str] = Field(description="姓名", default=None)
|
||||
position: Optional[str] = Field(description="职位", default=None)
|
||||
department: Optional[str] = Field(description="部门", default=None)
|
||||
contact: Optional[str] = Field(description="联系方式", default=None)
|
||||
phone: Optional[str] = Field(description="电话", default=None)
|
||||
hire_date: Optional[int] = Field(description="入职日期(时间戳,毫秒)", default=None)
|
||||
updatetime_profile: Optional[int] = Field(description="核心档案信息最后更新时间(时间戳,毫秒)", default=None)
|
||||
|
||||
|
||||
class EndUserProfileResponse(BaseModel):
|
||||
"""终端用户基本信息响应模型"""
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID = Field(description="终端用户ID")
|
||||
name: Optional[str] = Field(description="姓名", default=None)
|
||||
position: Optional[str] = Field(description="职位", default=None)
|
||||
department: Optional[str] = Field(description="部门", default=None)
|
||||
contact: Optional[str] = Field(description="联系方式", default=None)
|
||||
phone: Optional[str] = Field(description="电话", default=None)
|
||||
hire_date: Optional[int] = Field(description="入职日期(时间戳,毫秒)", default=None)
|
||||
updatetime_profile: Optional[int] = Field(description="核心档案信息最后更新时间(时间戳,毫秒)", default=None)
|
||||
|
||||
|
||||
class EndUserProfileUpdate(BaseModel):
|
||||
"""终端用户基本信息更新请求模型"""
|
||||
end_user_id: str = Field(description="终端用户ID")
|
||||
name: Optional[str] = Field(description="姓名", default=None)
|
||||
position: Optional[str] = Field(description="职位", default=None)
|
||||
department: Optional[str] = Field(description="部门", default=None)
|
||||
contact: Optional[str] = Field(description="联系方式", default=None)
|
||||
phone: Optional[str] = Field(description="电话", default=None)
|
||||
hire_date: Optional[int] = Field(description="入职日期(时间戳,毫秒)", default=None)
|
||||
@@ -31,21 +31,19 @@ class BaseDataSchema(BaseModel):
|
||||
# 保持原有必需字段为可选,以兼容不同数据源
|
||||
id: Optional[str] = Field(None, description="The unique identifier for the data entry.")
|
||||
statement: Optional[str] = Field(None, description="The statement text.")
|
||||
group_id: Optional[str] = Field(None, description="The group identifier.")
|
||||
chunk_id: Optional[str] = Field(None, description="The chunk identifier.")
|
||||
created_at: str = Field(..., description="The creation timestamp in ISO 8601 format.")
|
||||
expired_at: Optional[str] = Field(None, description="The expiration timestamp in ISO 8601 format.")
|
||||
valid_at: Optional[str] = Field(None, description="The validation timestamp in ISO 8601 format.")
|
||||
invalid_at: Optional[str] = Field(None, description="The invalidation timestamp in ISO 8601 format.")
|
||||
entity_ids: List[str] = Field([], description="The list of entity identifiers.")
|
||||
description: Optional[str] = Field(None, description="The description of the data entry.")
|
||||
|
||||
# 新增字段以匹配实际输入数据
|
||||
entity1_name: str = Field(..., description="The first entity name.")
|
||||
entity2_name: Optional[str] = Field(None, description="The second entity name.")
|
||||
statement_id: str = Field(..., description="The statement identifier.")
|
||||
relationship_type: str = Field(..., description="The relationship type.")
|
||||
relationship: Optional[Dict[str, Any]] = Field(None, description="The relationship object.")
|
||||
# 新增字段 - 设为可选以保持向后兼容性
|
||||
predicate: Optional[str] = Field(None, description="The predicate describing the relationship between entities.")
|
||||
relationship_statement_id: Optional[str] = Field(None, description="The relationship statement identifier.")
|
||||
# 保留原有字段 - 修改relationship字段类型以支持字符串和字典
|
||||
relationship: Optional[Union[str, Dict[str, Any]]] = Field(None, description="The relationship object or string.")
|
||||
entity2: Optional[Dict[str, Any]] = Field(None, description="The second entity object.")
|
||||
|
||||
|
||||
@@ -99,8 +97,17 @@ class ReflexionSchema(BaseModel):
|
||||
|
||||
|
||||
class ChangeRecordSchema(BaseModel):
|
||||
"""Schema for individual change records"""
|
||||
field: List[Dict[str, str]] = Field(..., description="List of field changes, each containing field name and new value.")
|
||||
"""Schema for individual change records
|
||||
|
||||
字段值格式说明:
|
||||
- id 和 statement_id: 字符串或 None
|
||||
- 其他字段: 可以是字符串、None,数组 [修改前的值, 修改后的值],或嵌套字典结构
|
||||
- entity2等嵌套对象的字段也遵循 [old_value, new_value] 格式
|
||||
"""
|
||||
field: List[Dict[str, Any]] = Field(
|
||||
...,
|
||||
description="List of field changes. First item: {id: value or None}, second: {statement_id: value}, followed by changed fields as {field_name: [old_value, new_value]} or {field_name: new_value} or nested structures like {entity2: {field_name: [old, new]}}"
|
||||
)
|
||||
|
||||
class ResolvedSchema(BaseModel):
|
||||
"""Schema for the resolved memory data in the reflexion_data"""
|
||||
@@ -375,3 +382,12 @@ def fail(
|
||||
error=error_code,
|
||||
time=time or _now_ms(),
|
||||
)
|
||||
|
||||
class GenerateCacheRequest(BaseModel):
|
||||
"""缓存生成请求模型"""
|
||||
model_config = ConfigDict(populate_by_name=True, extra="forbid")
|
||||
|
||||
end_user_id: Optional[str] = Field(
|
||||
None,
|
||||
description="终端用户ID(UUID格式)。如果提供,只为该用户生成;如果不提供,为当前工作空间的所有用户生成"
|
||||
)
|
||||
|
||||
@@ -268,10 +268,20 @@ async def get_workspace_total_memory_count(
|
||||
# 如果提供了 end_user_id,只查询该用户
|
||||
if end_user_id:
|
||||
search_result = await memory_storage_service.search_all(end_user_id=end_user_id)
|
||||
# 查询用户名称
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(uuid.UUID(end_user_id))
|
||||
user_name = end_user.name if end_user else None
|
||||
|
||||
return {
|
||||
"total_memory_count": search_result.get("total", 0),
|
||||
"host_count": 1,
|
||||
"details": [{"end_user_id": end_user_id, "count": search_result.get("total", 0)}]
|
||||
"details": [{
|
||||
"end_user_id": end_user_id,
|
||||
"count": search_result.get("total", 0),
|
||||
"name": user_name
|
||||
}]
|
||||
}
|
||||
|
||||
for host in hosts:
|
||||
@@ -287,17 +297,19 @@ async def get_workspace_total_memory_count(
|
||||
|
||||
details.append({
|
||||
"end_user_id": end_user_id_str,
|
||||
"count": host_total
|
||||
"count": host_total,
|
||||
"name": host.name # 添加 name 字段
|
||||
})
|
||||
|
||||
business_logger.debug(f"EndUser {end_user_id_str} 记忆数: {host_total}")
|
||||
business_logger.debug(f"EndUser {end_user_id_str} ({host.name}) 记忆数: {host_total}")
|
||||
|
||||
except Exception as e:
|
||||
business_logger.warning(f"获取 end_user {host.id} 记忆数失败: {str(e)}")
|
||||
# 失败的 host 记为 0
|
||||
details.append({
|
||||
"end_user_id": str(host.id),
|
||||
"count": 0
|
||||
"count": 0,
|
||||
"name": host.name # 添加 name 字段
|
||||
})
|
||||
|
||||
result = {
|
||||
|
||||
@@ -8,21 +8,18 @@ import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from app.core.logging_config import get_config_logger, get_logger
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.analytics.memory_insight import MemoryInsight
|
||||
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
|
||||
from app.core.memory.analytics.user_summary import generate_user_summary
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.user_model import User
|
||||
from app.repositories.data_config_repository import DataConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import ConfigurationError, MemoryConfig
|
||||
from app.schemas.memory_config_schema import ConfigurationError
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigFilter,
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
@@ -68,6 +65,7 @@ class MemoryStorageService:
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
"""Service layer for config params CRUD.
|
||||
@@ -86,7 +84,6 @@ class DataConfigService: # 数据配置服务类(PostgreSQL)
|
||||
@staticmethod
|
||||
def _convert_timestamps_to_format(data_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""将 created_at 和 updated_at 字段从 datetime 对象转换为 YYYYMMDDHHmmss 格式"""
|
||||
from datetime import datetime
|
||||
|
||||
for item in data_list:
|
||||
for field in ['created_at', 'updated_at']:
|
||||
@@ -569,14 +566,6 @@ async def analytics_hot_memory_tags(
|
||||
return [{"name": t, "frequency": f} for t, f in top_tags]
|
||||
|
||||
|
||||
async def analytics_memory_insight_report(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
insight = MemoryInsight(end_user_id)
|
||||
report = await insight.generate_insight_report()
|
||||
await insight.close()
|
||||
data = {"report": report}
|
||||
return data
|
||||
|
||||
|
||||
async def analytics_recent_activity_stats() -> Dict[str, Any]:
|
||||
stats, _msg = get_recent_activity_stats()
|
||||
total = (
|
||||
@@ -610,8 +599,3 @@ async def analytics_recent_activity_stats() -> Dict[str, Any]:
|
||||
data = {"total": total, "stats": stats, "latest_relative": latest_relative}
|
||||
return data
|
||||
|
||||
|
||||
async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
summary = await generate_user_summary(end_user_id)
|
||||
data = {"summary": summary}
|
||||
return data
|
||||
@@ -169,7 +169,7 @@ class PromptOptimizerService:
|
||||
provider=api_config.provider,
|
||||
api_key=api_config.api_key,
|
||||
base_url=api_config.api_base
|
||||
), type=ModelType.from_str(model_config.type))
|
||||
), type=ModelType(model_config.type))
|
||||
|
||||
# build message
|
||||
messages = [
|
||||
|
||||
831
api/app/services/user_memory_service.py
Normal file
831
api/app/services/user_memory_service.py
Normal file
@@ -0,0 +1,831 @@
|
||||
"""
|
||||
User Memory Service
|
||||
|
||||
处理用户记忆相关的业务逻辑,包括记忆洞察、用户摘要、节点统计和图数据等。
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Any
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.core.memory.analytics.memory_insight import MemoryInsight
|
||||
from app.core.memory.analytics.user_summary import generate_user_summary
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Neo4j connector instance
|
||||
_neo4j_connector = Neo4jConnector()
|
||||
|
||||
|
||||
class UserMemoryService:
|
||||
"""用户记忆服务类"""
|
||||
|
||||
def __init__(self):
|
||||
logger.info("UserMemoryService initialized")
|
||||
|
||||
async def get_cached_memory_insight(
|
||||
self,
|
||||
db: Session,
|
||||
end_user_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
从数据库获取缓存的记忆洞察
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
end_user_id: 终端用户ID (UUID)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"report": str,
|
||||
"updated_at": datetime,
|
||||
"is_cached": bool
|
||||
}
|
||||
"""
|
||||
try:
|
||||
# 转换为UUID并查询用户
|
||||
user_uuid = uuid.UUID(end_user_id)
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(user_uuid)
|
||||
|
||||
if not end_user:
|
||||
logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户")
|
||||
return {
|
||||
"report": None,
|
||||
"updated_at": None,
|
||||
"is_cached": False,
|
||||
"message": "用户不存在"
|
||||
}
|
||||
|
||||
# 检查是否有缓存数据
|
||||
if end_user.memory_insight:
|
||||
logger.info(f"成功获取 end_user_id {end_user_id} 的缓存记忆洞察")
|
||||
return {
|
||||
"report": end_user.memory_insight,
|
||||
"updated_at": end_user.memory_insight_updated_at,
|
||||
"is_cached": True
|
||||
}
|
||||
else:
|
||||
logger.info(f"end_user_id {end_user_id} 的记忆洞察缓存为空")
|
||||
return {
|
||||
"report": None,
|
||||
"updated_at": None,
|
||||
"is_cached": False,
|
||||
"message": "数据尚未生成,请稍后重试或联系管理员"
|
||||
}
|
||||
|
||||
except ValueError:
|
||||
logger.error(f"无效的 end_user_id 格式: {end_user_id}")
|
||||
return {
|
||||
"report": None,
|
||||
"updated_at": None,
|
||||
"is_cached": False,
|
||||
"message": "无效的用户ID格式"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取缓存记忆洞察时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
async def get_cached_user_summary(
|
||||
self,
|
||||
db: Session,
|
||||
end_user_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
从数据库获取缓存的用户摘要
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
end_user_id: 终端用户ID (UUID)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"summary": str,
|
||||
"updated_at": datetime,
|
||||
"is_cached": bool
|
||||
}
|
||||
"""
|
||||
try:
|
||||
# 转换为UUID并查询用户
|
||||
user_uuid = uuid.UUID(end_user_id)
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(user_uuid)
|
||||
|
||||
if not end_user:
|
||||
logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户")
|
||||
return {
|
||||
"summary": None,
|
||||
"updated_at": None,
|
||||
"is_cached": False,
|
||||
"message": "用户不存在"
|
||||
}
|
||||
|
||||
# 检查是否有缓存数据
|
||||
if end_user.user_summary:
|
||||
logger.info(f"成功获取 end_user_id {end_user_id} 的缓存用户摘要")
|
||||
return {
|
||||
"summary": end_user.user_summary,
|
||||
"updated_at": end_user.user_summary_updated_at,
|
||||
"is_cached": True
|
||||
}
|
||||
else:
|
||||
logger.info(f"end_user_id {end_user_id} 的用户摘要缓存为空")
|
||||
return {
|
||||
"summary": None,
|
||||
"updated_at": None,
|
||||
"is_cached": False,
|
||||
"message": "数据尚未生成,请稍后重试或联系管理员"
|
||||
}
|
||||
|
||||
except ValueError:
|
||||
logger.error(f"无效的 end_user_id 格式: {end_user_id}")
|
||||
return {
|
||||
"summary": None,
|
||||
"updated_at": None,
|
||||
"is_cached": False,
|
||||
"message": "无效的用户ID格式"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取缓存用户摘要时出错: {str(e)}")
|
||||
raise
|
||||
|
||||
async def generate_and_cache_insight(
|
||||
self,
|
||||
db: Session,
|
||||
end_user_id: str,
|
||||
workspace_id: Optional[uuid.UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
生成并缓存记忆洞察
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
end_user_id: 终端用户ID (UUID)
|
||||
workspace_id: 工作空间ID (可选)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": bool,
|
||||
"report": str,
|
||||
"error": Optional[str]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始为 end_user_id {end_user_id} 生成记忆洞察")
|
||||
|
||||
# 转换为UUID并查询用户
|
||||
user_uuid = uuid.UUID(end_user_id)
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(user_uuid)
|
||||
|
||||
if not end_user:
|
||||
logger.error(f"end_user_id {end_user_id} 不存在")
|
||||
return {
|
||||
"success": False,
|
||||
"report": None,
|
||||
"error": "用户不存在"
|
||||
}
|
||||
|
||||
# 使用 end_user_id 调用分析函数
|
||||
try:
|
||||
logger.info(f"使用 end_user_id={end_user_id} 生成记忆洞察")
|
||||
result = await analytics_memory_insight_report(end_user_id)
|
||||
report = result.get("report", "")
|
||||
|
||||
if not report:
|
||||
logger.warning(f"end_user_id {end_user_id} 的记忆洞察生成结果为空")
|
||||
return {
|
||||
"success": False,
|
||||
"report": None,
|
||||
"error": "生成的洞察报告为空,可能Neo4j中没有该用户的数据"
|
||||
}
|
||||
|
||||
# 更新数据库缓存
|
||||
success = repo.update_memory_insight(user_uuid, report)
|
||||
|
||||
if success:
|
||||
logger.info(f"成功为 end_user_id {end_user_id} 生成并缓存记忆洞察")
|
||||
return {
|
||||
"success": True,
|
||||
"report": report,
|
||||
"error": None
|
||||
}
|
||||
else:
|
||||
logger.error(f"更新 end_user_id {end_user_id} 的记忆洞察缓存失败")
|
||||
return {
|
||||
"success": False,
|
||||
"report": report,
|
||||
"error": "数据库更新失败"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"调用分析函数生成记忆洞察时出错: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"report": None,
|
||||
"error": f"Neo4j或LLM服务不可用: {str(e)}"
|
||||
}
|
||||
|
||||
except ValueError:
|
||||
logger.error(f"无效的 end_user_id 格式: {end_user_id}")
|
||||
return {
|
||||
"success": False,
|
||||
"report": None,
|
||||
"error": "无效的用户ID格式"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"生成并缓存记忆洞察时出错: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"report": None,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def generate_and_cache_summary(
|
||||
self,
|
||||
db: Session,
|
||||
end_user_id: str,
|
||||
workspace_id: Optional[uuid.UUID] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
生成并缓存用户摘要
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
end_user_id: 终端用户ID (UUID)
|
||||
workspace_id: 工作空间ID (可选)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": bool,
|
||||
"summary": str,
|
||||
"error": Optional[str]
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"开始为 end_user_id {end_user_id} 生成用户摘要")
|
||||
|
||||
# 转换为UUID并查询用户
|
||||
user_uuid = uuid.UUID(end_user_id)
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(user_uuid)
|
||||
|
||||
if not end_user:
|
||||
logger.error(f"end_user_id {end_user_id} 不存在")
|
||||
return {
|
||||
"success": False,
|
||||
"summary": None,
|
||||
"error": "用户不存在"
|
||||
}
|
||||
|
||||
# 使用 end_user_id 调用分析函数
|
||||
try:
|
||||
logger.info(f"使用 end_user_id={end_user_id} 生成用户摘要")
|
||||
result = await analytics_user_summary(end_user_id)
|
||||
summary = result.get("summary", "")
|
||||
|
||||
if not summary:
|
||||
logger.warning(f"end_user_id {end_user_id} 的用户摘要生成结果为空")
|
||||
return {
|
||||
"success": False,
|
||||
"summary": None,
|
||||
"error": "生成的用户摘要为空,可能Neo4j中没有该用户的数据"
|
||||
}
|
||||
|
||||
# 更新数据库缓存
|
||||
success = repo.update_user_summary(user_uuid, summary)
|
||||
|
||||
if success:
|
||||
logger.info(f"成功为 end_user_id {end_user_id} 生成并缓存用户摘要")
|
||||
return {
|
||||
"success": True,
|
||||
"summary": summary,
|
||||
"error": None
|
||||
}
|
||||
else:
|
||||
logger.error(f"更新 end_user_id {end_user_id} 的用户摘要缓存失败")
|
||||
return {
|
||||
"success": False,
|
||||
"summary": summary,
|
||||
"error": "数据库更新失败"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"调用分析函数生成用户摘要时出错: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"summary": None,
|
||||
"error": f"Neo4j或LLM服务不可用: {str(e)}"
|
||||
}
|
||||
|
||||
except ValueError:
|
||||
logger.error(f"无效的 end_user_id 格式: {end_user_id}")
|
||||
return {
|
||||
"success": False,
|
||||
"summary": None,
|
||||
"error": "无效的用户ID格式"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"生成并缓存用户摘要时出错: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"summary": None,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def generate_cache_for_workspace(
|
||||
self,
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
为整个工作空间生成缓存
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
workspace_id: 工作空间ID
|
||||
|
||||
Returns:
|
||||
{
|
||||
"total_users": int,
|
||||
"successful": int,
|
||||
"failed": int,
|
||||
"errors": List[Dict]
|
||||
}
|
||||
"""
|
||||
logger.info(f"开始为工作空间 {workspace_id} 批量生成缓存")
|
||||
|
||||
total_users = 0
|
||||
successful = 0
|
||||
failed = 0
|
||||
errors = []
|
||||
|
||||
try:
|
||||
# 获取工作空间的所有终端用户
|
||||
repo = EndUserRepository(db)
|
||||
end_users = repo.get_all_by_workspace(workspace_id)
|
||||
total_users = len(end_users)
|
||||
|
||||
logger.info(f"工作空间 {workspace_id} 共有 {total_users} 个终端用户")
|
||||
|
||||
# 遍历每个用户并生成缓存
|
||||
for end_user in end_users:
|
||||
end_user_id = str(end_user.id)
|
||||
|
||||
try:
|
||||
# 生成记忆洞察
|
||||
insight_result = await self.generate_and_cache_insight(db, end_user_id)
|
||||
|
||||
# 生成用户摘要
|
||||
summary_result = await self.generate_and_cache_summary(db, end_user_id)
|
||||
|
||||
# 检查是否都成功
|
||||
if insight_result["success"] and summary_result["success"]:
|
||||
successful += 1
|
||||
logger.info(f"成功为终端用户 {end_user_id} 生成缓存")
|
||||
else:
|
||||
failed += 1
|
||||
error_info = {
|
||||
"end_user_id": end_user_id,
|
||||
"insight_error": insight_result.get("error"),
|
||||
"summary_error": summary_result.get("error")
|
||||
}
|
||||
errors.append(error_info)
|
||||
logger.warning(f"终端用户 {end_user_id} 的缓存生成部分失败: {error_info}")
|
||||
|
||||
except Exception as e:
|
||||
# 单个用户失败不影响其他用户
|
||||
failed += 1
|
||||
error_info = {
|
||||
"end_user_id": end_user_id,
|
||||
"error": str(e)
|
||||
}
|
||||
errors.append(error_info)
|
||||
logger.error(f"为终端用户 {end_user_id} 生成缓存时出错: {str(e)}")
|
||||
|
||||
# 记录统计信息
|
||||
logger.info(
|
||||
f"工作空间 {workspace_id} 批量生成完成: "
|
||||
f"总数={total_users}, 成功={successful}, 失败={failed}"
|
||||
)
|
||||
|
||||
return {
|
||||
"total_users": total_users,
|
||||
"successful": successful,
|
||||
"failed": failed,
|
||||
"errors": errors
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"为工作空间 {workspace_id} 批量生成缓存时出错: {str(e)}")
|
||||
return {
|
||||
"total_users": total_users,
|
||||
"successful": successful,
|
||||
"failed": failed,
|
||||
"errors": errors + [{"error": f"批量处理失败: {str(e)}"}]
|
||||
}
|
||||
|
||||
|
||||
# 独立的分析函数
|
||||
|
||||
async def analytics_memory_insight_report(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
生成记忆洞察报告
|
||||
|
||||
Args:
|
||||
end_user_id: 可选的终端用户ID
|
||||
|
||||
Returns:
|
||||
包含报告的字典
|
||||
"""
|
||||
insight = MemoryInsight(end_user_id)
|
||||
report = await insight.generate_insight_report()
|
||||
await insight.close()
|
||||
data = {"report": report}
|
||||
return data
|
||||
|
||||
|
||||
async def analytics_user_summary(end_user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
生成用户摘要
|
||||
|
||||
Args:
|
||||
end_user_id: 可选的终端用户ID
|
||||
|
||||
Returns:
|
||||
包含摘要的字典
|
||||
"""
|
||||
summary = await generate_user_summary(end_user_id)
|
||||
data = {"summary": summary}
|
||||
return data
|
||||
|
||||
|
||||
async def analytics_node_statistics(
|
||||
db: Session,
|
||||
end_user_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
统计 Neo4j 中四种节点类型的数量和百分比
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
end_user_id: 可选的终端用户ID (UUID),用于过滤特定用户的节点
|
||||
|
||||
Returns:
|
||||
{
|
||||
"total": int, # 总节点数
|
||||
"nodes": [
|
||||
{
|
||||
"type": str, # 节点类型
|
||||
"count": int, # 节点数量
|
||||
"percentage": float # 百分比
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
# 定义四种节点类型的查询
|
||||
node_types = ["Chunk", "MemorySummary", "Statement", "ExtractedEntity"]
|
||||
|
||||
# 存储每种节点类型的计数
|
||||
node_counts = {}
|
||||
|
||||
# 查询每种节点类型的数量
|
||||
for node_type in node_types:
|
||||
# 构建查询语句
|
||||
if end_user_id:
|
||||
query = f"""
|
||||
MATCH (n:{node_type})
|
||||
WHERE n.group_id = $group_id
|
||||
RETURN count(n) as count
|
||||
"""
|
||||
result = await _neo4j_connector.execute_query(query, group_id=end_user_id)
|
||||
else:
|
||||
query = f"""
|
||||
MATCH (n:{node_type})
|
||||
RETURN count(n) as count
|
||||
"""
|
||||
result = await _neo4j_connector.execute_query(query)
|
||||
|
||||
# 提取计数结果
|
||||
count = result[0]["count"] if result and len(result) > 0 else 0
|
||||
node_counts[node_type] = count
|
||||
|
||||
# 计算总数
|
||||
total = sum(node_counts.values())
|
||||
|
||||
# 构建返回数据,包含百分比
|
||||
nodes = []
|
||||
for node_type in node_types:
|
||||
count = node_counts[node_type]
|
||||
percentage = round((count / total * 100), 2) if total > 0 else 0.0
|
||||
nodes.append({
|
||||
"type": node_type,
|
||||
"count": count,
|
||||
"percentage": percentage
|
||||
})
|
||||
|
||||
data = {
|
||||
"total": total,
|
||||
"nodes": nodes
|
||||
}
|
||||
|
||||
return data
|
||||
|
||||
|
||||
async def analytics_graph_data(
|
||||
db: Session,
|
||||
end_user_id: str,
|
||||
node_types: Optional[List[str]] = None,
|
||||
limit: int = 100,
|
||||
depth: int = 1,
|
||||
center_node_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取 Neo4j 图数据,用于前端可视化
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
end_user_id: 终端用户ID
|
||||
node_types: 可选的节点类型列表
|
||||
limit: 返回节点数量限制
|
||||
depth: 图遍历深度
|
||||
center_node_id: 可选的中心节点ID
|
||||
|
||||
Returns:
|
||||
包含节点、边和统计信息的字典
|
||||
"""
|
||||
try:
|
||||
# 1. 获取 group_id
|
||||
user_uuid = uuid.UUID(end_user_id)
|
||||
repo = EndUserRepository(db)
|
||||
end_user = repo.get_by_id(user_uuid)
|
||||
|
||||
if not end_user:
|
||||
logger.warning(f"未找到 end_user_id 为 {end_user_id} 的用户")
|
||||
return {
|
||||
"nodes": [],
|
||||
"edges": [],
|
||||
"statistics": {
|
||||
"total_nodes": 0,
|
||||
"total_edges": 0,
|
||||
"node_types": {},
|
||||
"edge_types": {}
|
||||
},
|
||||
"message": "用户不存在"
|
||||
}
|
||||
|
||||
# 2. 构建节点查询
|
||||
if center_node_id:
|
||||
# 基于中心节点的扩展查询
|
||||
node_query = f"""
|
||||
MATCH path = (center)-[*1..{depth}]-(connected)
|
||||
WHERE center.group_id = $group_id
|
||||
AND elementId(center) = $center_node_id
|
||||
WITH collect(DISTINCT center) + collect(DISTINCT connected) as all_nodes
|
||||
UNWIND all_nodes as n
|
||||
RETURN DISTINCT
|
||||
elementId(n) as id,
|
||||
labels(n)[0] as label,
|
||||
properties(n) as properties
|
||||
LIMIT $limit
|
||||
"""
|
||||
node_params = {
|
||||
"group_id": end_user_id,
|
||||
"center_node_id": center_node_id,
|
||||
"limit": limit
|
||||
}
|
||||
elif node_types:
|
||||
# 按节点类型过滤查询
|
||||
node_query = """
|
||||
MATCH (n)
|
||||
WHERE n.group_id = $group_id
|
||||
AND labels(n)[0] IN $node_types
|
||||
RETURN
|
||||
elementId(n) as id,
|
||||
labels(n)[0] as label,
|
||||
properties(n) as properties
|
||||
LIMIT $limit
|
||||
"""
|
||||
node_params = {
|
||||
"group_id": end_user_id,
|
||||
"node_types": node_types,
|
||||
"limit": limit
|
||||
}
|
||||
else:
|
||||
# 查询所有节点
|
||||
node_query = """
|
||||
MATCH (n)
|
||||
WHERE n.group_id = $group_id
|
||||
RETURN
|
||||
elementId(n) as id,
|
||||
labels(n)[0] as label,
|
||||
properties(n) as properties
|
||||
LIMIT $limit
|
||||
"""
|
||||
node_params = {
|
||||
"group_id": end_user_id,
|
||||
"limit": limit
|
||||
}
|
||||
|
||||
# 执行节点查询
|
||||
node_results = await _neo4j_connector.execute_query(node_query, **node_params)
|
||||
|
||||
# 3. 格式化节点数据
|
||||
nodes = []
|
||||
node_ids = []
|
||||
node_type_counts = {}
|
||||
|
||||
for record in node_results:
|
||||
node_id = record["id"]
|
||||
node_label = record["label"]
|
||||
node_props = record["properties"]
|
||||
|
||||
# 根据节点类型提取需要的属性字段
|
||||
filtered_props = _extract_node_properties(node_label, node_props)
|
||||
|
||||
# 直接使用数据库中的 caption,如果没有则使用节点类型作为默认值
|
||||
caption = filtered_props.get("caption", node_label)
|
||||
|
||||
nodes.append({
|
||||
"id": node_id,
|
||||
"label": node_label,
|
||||
"properties": filtered_props,
|
||||
"caption": caption
|
||||
})
|
||||
|
||||
node_ids.append(node_id)
|
||||
node_type_counts[node_label] = node_type_counts.get(node_label, 0) + 1
|
||||
|
||||
# 4. 查询节点之间的关系
|
||||
if len(node_ids) > 0:
|
||||
edge_query = """
|
||||
MATCH (n)-[r]->(m)
|
||||
WHERE elementId(n) IN $node_ids
|
||||
AND elementId(m) IN $node_ids
|
||||
RETURN
|
||||
elementId(r) as id,
|
||||
elementId(n) as source,
|
||||
elementId(m) as target,
|
||||
type(r) as rel_type,
|
||||
properties(r) as properties
|
||||
"""
|
||||
edge_results = await _neo4j_connector.execute_query(
|
||||
edge_query,
|
||||
node_ids=node_ids
|
||||
)
|
||||
else:
|
||||
edge_results = []
|
||||
|
||||
# 5. 格式化边数据
|
||||
edges = []
|
||||
edge_type_counts = {}
|
||||
|
||||
for record in edge_results:
|
||||
edge_id = record["id"]
|
||||
source = record["source"]
|
||||
target = record["target"]
|
||||
rel_type = record["rel_type"]
|
||||
edge_props = record["properties"]
|
||||
|
||||
# 清理边属性中的 Neo4j 特殊类型
|
||||
# 对于边,我们保留所有属性,但清理特殊类型
|
||||
cleaned_edge_props = {}
|
||||
if edge_props:
|
||||
for key, value in edge_props.items():
|
||||
cleaned_edge_props[key] = _clean_neo4j_value(value)
|
||||
|
||||
# 直接使用关系类型作为 caption,如果 properties 中有 caption 则使用它
|
||||
caption = cleaned_edge_props.get("caption", rel_type)
|
||||
|
||||
edges.append({
|
||||
"id": edge_id,
|
||||
"source": source,
|
||||
"target": target,
|
||||
"type": rel_type,
|
||||
"properties": cleaned_edge_props,
|
||||
"caption": caption
|
||||
})
|
||||
|
||||
edge_type_counts[rel_type] = edge_type_counts.get(rel_type, 0) + 1
|
||||
|
||||
# 6. 构建统计信息
|
||||
statistics = {
|
||||
"total_nodes": len(nodes),
|
||||
"total_edges": len(edges),
|
||||
"node_types": node_type_counts,
|
||||
"edge_types": edge_type_counts
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"成功获取图数据: end_user_id={end_user_id}, "
|
||||
f"nodes={len(nodes)}, edges={len(edges)}"
|
||||
)
|
||||
|
||||
return {
|
||||
"nodes": nodes,
|
||||
"edges": edges,
|
||||
"statistics": statistics
|
||||
}
|
||||
|
||||
except ValueError:
|
||||
logger.error(f"无效的 end_user_id 格式: {end_user_id}")
|
||||
return {
|
||||
"nodes": [],
|
||||
"edges": [],
|
||||
"statistics": {
|
||||
"total_nodes": 0,
|
||||
"total_edges": 0,
|
||||
"node_types": {},
|
||||
"edge_types": {}
|
||||
},
|
||||
"message": "无效的用户ID格式"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"获取图数据失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
# 辅助函数
|
||||
|
||||
def _extract_node_properties(label: str, properties: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
根据节点类型提取需要的属性字段
|
||||
|
||||
Args:
|
||||
label: 节点类型标签
|
||||
properties: 节点的所有属性
|
||||
|
||||
Returns:
|
||||
过滤后的属性字典
|
||||
"""
|
||||
# 定义每种节点类型需要的字段(白名单)
|
||||
field_whitelist = {
|
||||
"Dialogue": ["content", "created_at"],
|
||||
"Chunk": ["content", "created_at"],
|
||||
"Statement": ["temporal_info", "stmt_type", "statement", "valid_at", "created_at", "caption"],
|
||||
"ExtractedEntity": ["description", "name", "entity_type", "created_at", "caption"],
|
||||
"MemorySummary": ["summary", "content", "created_at", "caption"] # 添加 content 字段
|
||||
}
|
||||
|
||||
# 获取该节点类型的白名单字段
|
||||
allowed_fields = field_whitelist.get(label, [])
|
||||
|
||||
# 如果没有定义白名单,返回空字典(或者可以返回所有字段)
|
||||
if not allowed_fields:
|
||||
# 对于未定义的节点类型,只返回基本字段
|
||||
allowed_fields = ["name", "created_at", "caption"]
|
||||
|
||||
# 提取白名单中的字段
|
||||
filtered_props = {}
|
||||
for field in allowed_fields:
|
||||
if field in properties:
|
||||
value = properties[field]
|
||||
# 清理 Neo4j 特殊类型
|
||||
filtered_props[field] = _clean_neo4j_value(value)
|
||||
|
||||
return filtered_props
|
||||
|
||||
|
||||
def _clean_neo4j_value(value: Any) -> Any:
|
||||
"""
|
||||
清理单个值的 Neo4j 特殊类型
|
||||
|
||||
Args:
|
||||
value: 需要清理的值
|
||||
|
||||
Returns:
|
||||
清理后的值
|
||||
"""
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# 处理列表
|
||||
if isinstance(value, list):
|
||||
return [_clean_neo4j_value(item) for item in value]
|
||||
|
||||
# 处理字典
|
||||
if isinstance(value, dict):
|
||||
return {k: _clean_neo4j_value(v) for k, v in value.items()}
|
||||
|
||||
# 处理 Neo4j DateTime 类型
|
||||
if hasattr(value, '__class__') and 'neo4j.time' in str(type(value)):
|
||||
try:
|
||||
if hasattr(value, 'to_native'):
|
||||
native_dt = value.to_native()
|
||||
return native_dt.isoformat()
|
||||
return str(value)
|
||||
except Exception:
|
||||
return str(value)
|
||||
|
||||
# 处理其他 Neo4j 特殊类型
|
||||
if hasattr(value, '__class__') and 'neo4j' in str(type(value)):
|
||||
try:
|
||||
return str(value)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
# 返回原始值
|
||||
return value
|
||||
@@ -39,14 +39,14 @@ class WorkflowService:
|
||||
# ==================== 配置管理 ====================
|
||||
|
||||
def create_workflow_config(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
nodes: list[dict[str, Any]],
|
||||
edges: list[dict[str, Any]],
|
||||
variables: list[dict[str, Any]] | None = None,
|
||||
execution_config: dict[str, Any] | None = None,
|
||||
triggers: list[dict[str, Any]] | None = None,
|
||||
validate: bool = True
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
nodes: list[dict[str, Any]],
|
||||
edges: list[dict[str, Any]],
|
||||
variables: list[dict[str, Any]] | None = None,
|
||||
execution_config: dict[str, Any] | None = None,
|
||||
triggers: list[dict[str, Any]] | None = None,
|
||||
validate: bool = True
|
||||
) -> WorkflowConfig:
|
||||
"""创建工作流配置
|
||||
|
||||
@@ -109,14 +109,14 @@ class WorkflowService:
|
||||
return self.config_repo.get_by_app_id(app_id)
|
||||
|
||||
def update_workflow_config(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
nodes: list[dict[str, Any]] | None = None,
|
||||
edges: list[dict[str, Any]] | None = None,
|
||||
variables: list[dict[str, Any]] | None = None,
|
||||
execution_config: dict[str, Any] | None = None,
|
||||
triggers: list[dict[str, Any]] | None = None,
|
||||
validate: bool = True
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
nodes: list[dict[str, Any]] | None = None,
|
||||
edges: list[dict[str, Any]] | None = None,
|
||||
variables: list[dict[str, Any]] | None = None,
|
||||
execution_config: dict[str, Any] | None = None,
|
||||
triggers: list[dict[str, Any]] | None = None,
|
||||
validate: bool = True
|
||||
) -> WorkflowConfig:
|
||||
"""更新工作流配置
|
||||
|
||||
@@ -226,8 +226,8 @@ class WorkflowService:
|
||||
return config
|
||||
|
||||
def validate_workflow_config_for_publish(
|
||||
self,
|
||||
app_id: uuid.UUID
|
||||
self,
|
||||
app_id: uuid.UUID
|
||||
) -> tuple[bool, list[str]]:
|
||||
"""验证工作流配置是否可以发布
|
||||
|
||||
@@ -260,13 +260,13 @@ class WorkflowService:
|
||||
# ==================== 执行管理 ====================
|
||||
|
||||
def create_execution(
|
||||
self,
|
||||
workflow_config_id: uuid.UUID,
|
||||
app_id: uuid.UUID,
|
||||
trigger_type: str,
|
||||
triggered_by: uuid.UUID | None = None,
|
||||
conversation_id: uuid.UUID | None = None,
|
||||
input_data: dict[str, Any] | None = None
|
||||
self,
|
||||
workflow_config_id: uuid.UUID,
|
||||
app_id: uuid.UUID,
|
||||
trigger_type: str,
|
||||
triggered_by: uuid.UUID | None = None,
|
||||
conversation_id: uuid.UUID | None = None,
|
||||
input_data: dict[str, Any] | None = None
|
||||
) -> WorkflowExecution:
|
||||
"""创建工作流执行记录
|
||||
|
||||
@@ -314,10 +314,10 @@ class WorkflowService:
|
||||
return self.execution_repo.get_by_execution_id(execution_id)
|
||||
|
||||
def get_executions_by_app(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
limit: int = 50,
|
||||
offset: int = 0
|
||||
) -> list[WorkflowExecution]:
|
||||
"""获取应用的执行记录列表
|
||||
|
||||
@@ -332,12 +332,12 @@ class WorkflowService:
|
||||
return self.execution_repo.get_by_app_id(app_id, limit, offset)
|
||||
|
||||
def update_execution_status(
|
||||
self,
|
||||
execution_id: str,
|
||||
status: str,
|
||||
output_data: dict[str, Any] | None = None,
|
||||
error_message: str | None = None,
|
||||
error_node_id: str | None = None
|
||||
self,
|
||||
execution_id: str,
|
||||
status: str,
|
||||
output_data: dict[str, Any] | None = None,
|
||||
error_message: str | None = None,
|
||||
error_node_id: str | None = None
|
||||
) -> WorkflowExecution:
|
||||
"""更新执行状态
|
||||
|
||||
@@ -407,10 +407,10 @@ class WorkflowService:
|
||||
# ==================== 工作流执行 ====================
|
||||
|
||||
async def run(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
payload: DraftRunRequest,
|
||||
config: WorkflowConfig
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
payload: DraftRunRequest,
|
||||
config: WorkflowConfig
|
||||
):
|
||||
"""运行工作流
|
||||
|
||||
@@ -527,10 +527,10 @@ class WorkflowService:
|
||||
)
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
payload: DraftRunRequest,
|
||||
config: WorkflowConfig
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
payload: DraftRunRequest,
|
||||
config: WorkflowConfig
|
||||
):
|
||||
"""运行工作流(流式)
|
||||
|
||||
@@ -600,11 +600,11 @@ class WorkflowService:
|
||||
|
||||
# 调用流式执行(executor 会发送 workflow_start 和 workflow_end 事件)
|
||||
async for event in self._run_workflow_stream(
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id="",
|
||||
user_id=payload.user_id
|
||||
workflow_config=workflow_config_dict,
|
||||
input_data=input_data,
|
||||
execution_id=execution.execution_id,
|
||||
workspace_id="",
|
||||
user_id=payload.user_id
|
||||
):
|
||||
# 直接转发 executor 的事件(已经是正确的格式)
|
||||
yield event
|
||||
@@ -626,12 +626,12 @@ class WorkflowService:
|
||||
}
|
||||
|
||||
async def run_workflow(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
input_data: dict[str, Any],
|
||||
triggered_by: uuid.UUID,
|
||||
conversation_id: uuid.UUID | None = None,
|
||||
stream: bool = False
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
input_data: dict[str, Any],
|
||||
triggered_by: uuid.UUID,
|
||||
conversation_id: uuid.UUID | None = None,
|
||||
stream: bool = False
|
||||
) -> AsyncGenerator | dict:
|
||||
"""运行工作流
|
||||
|
||||
@@ -778,12 +778,12 @@ class WorkflowService:
|
||||
return clean_value(event)
|
||||
|
||||
async def _run_workflow_stream(
|
||||
self,
|
||||
workflow_config: dict[str, Any],
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str):
|
||||
self,
|
||||
workflow_config: dict[str, Any],
|
||||
input_data: dict[str, Any],
|
||||
execution_id: str,
|
||||
workspace_id: str,
|
||||
user_id: str):
|
||||
"""运行工作流(流式,内部方法)
|
||||
|
||||
Args:
|
||||
@@ -800,11 +800,11 @@ class WorkflowService:
|
||||
|
||||
try:
|
||||
async for event in execute_workflow_stream(
|
||||
workflow_config=workflow_config,
|
||||
input_data=input_data,
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
workflow_config=workflow_config,
|
||||
input_data=input_data,
|
||||
execution_id=execution_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id
|
||||
):
|
||||
# 直接转发事件(executor 已经返回正确格式)
|
||||
yield event
|
||||
@@ -828,7 +828,7 @@ class WorkflowService:
|
||||
# ==================== 依赖注入函数 ====================
|
||||
|
||||
def get_workflow_service(
|
||||
db: Annotated[Session, Depends(get_db)]
|
||||
db: Annotated[Session, Depends(get_db)]
|
||||
) -> WorkflowService:
|
||||
"""获取工作流服务(依赖注入)"""
|
||||
return WorkflowService(db)
|
||||
|
||||
704
api/app/tasks.py
704
api/app/tasks.py
@@ -21,7 +21,7 @@ from app.core.rag.prompts.generator import question_proposal
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import (
|
||||
ElasticSearchVectorFactory,
|
||||
)
|
||||
from app.db import get_db
|
||||
from app.db import get_db, get_db_context
|
||||
from app.models.document_model import Document
|
||||
from app.models.knowledge_model import Knowledge
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
@@ -50,124 +50,122 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
||||
"""
|
||||
Document parsing, vectorization, and storage
|
||||
"""
|
||||
db = next(get_db()) # Manually call the generator
|
||||
db_document = None
|
||||
db_knowledge = None
|
||||
progress_msg = f"{datetime.now().strftime('%H:%M:%S')} Task has been received.\n"
|
||||
try:
|
||||
db_document = db.query(Document).filter(Document.id == document_id).first()
|
||||
db_knowledge = db.query(Knowledge).filter(Knowledge.id == db_document.kb_id).first()
|
||||
# 1. Document parsing & segmentation
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Start to parse.\n"
|
||||
start_time = time.time()
|
||||
db_document.progress = 0.0
|
||||
db_document.progress_msg = progress_msg
|
||||
db_document.process_begin_at = datetime.now(tz=timezone.utc)
|
||||
db_document.process_duration = 0.0
|
||||
db_document.run = 1
|
||||
db.commit()
|
||||
db.refresh(db_document)
|
||||
|
||||
def progress_callback(prog=None, msg=None):
|
||||
nonlocal progress_msg # Declare the use of an external progress_msg variable
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.\n"
|
||||
# Prepare to configure chat_mdl、vision_model information
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
||||
)
|
||||
vision_model = QWenCV(
|
||||
key=db_knowledge.image2text.api_keys[0].api_key,
|
||||
model_name=db_knowledge.image2text.api_keys[0].model_name,
|
||||
lang="Chinese",
|
||||
base_url=db_knowledge.image2text.api_keys[0].api_base
|
||||
)
|
||||
from app.core.rag.app.naive import chunk
|
||||
res = chunk(filename=file_path,
|
||||
from_page=0,
|
||||
to_page=100000,
|
||||
callback=progress_callback,
|
||||
vision_model=vision_model,
|
||||
parser_config=db_document.parser_config,
|
||||
is_root=False)
|
||||
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Finish parsing.\n"
|
||||
db_document.progress = 0.8
|
||||
db_document.progress_msg = progress_msg
|
||||
db.commit()
|
||||
db.refresh(db_document)
|
||||
|
||||
# 2. Document vectorization and storage
|
||||
total_chunks = len(res)
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Generate {total_chunks} chunks.\n"
|
||||
batch_size = 100
|
||||
total_batches = ceil(total_chunks / batch_size)
|
||||
progress_per_batch = 0.2 / total_batches # Progress of each batch
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
# 2.1 Delete document vector index
|
||||
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
|
||||
# 2.2 Vectorize and import batch documents
|
||||
for batch_start in range(0, total_chunks, batch_size):
|
||||
batch_end = min(batch_start + batch_size, total_chunks) # prevent out-of-bounds
|
||||
batch = res[batch_start: batch_end] # Retrieve the current batch
|
||||
chunks = []
|
||||
|
||||
# Process the current batch
|
||||
for idx_in_batch, item in enumerate(batch):
|
||||
global_idx = batch_start + idx_in_batch # Calculate global index
|
||||
metadata = {
|
||||
"doc_id": uuid.uuid4().hex,
|
||||
"file_id": str(db_document.file_id),
|
||||
"file_name": db_document.file_name,
|
||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||
"document_id": str(db_document.id),
|
||||
"knowledge_id": str(db_document.kb_id),
|
||||
"sort_id": global_idx,
|
||||
"status": 1,
|
||||
}
|
||||
if db_document.parser_config.get("auto_questions", 0):
|
||||
topn = db_document.parser_config["auto_questions"]
|
||||
cached = get_llm_cache(chat_model.model_name, item["content_with_weight"], "question", {"topn": topn})
|
||||
if not cached:
|
||||
cached = question_proposal(chat_model, item["content_with_weight"], topn)
|
||||
set_llm_cache(chat_model.model_name, item["content_with_weight"], cached, "question", {"topn": topn})
|
||||
chunks.append(DocumentChunk(page_content=f"question: {cached} answer: {item['content_with_weight']}", metadata=metadata))
|
||||
else:
|
||||
chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata))
|
||||
|
||||
# Bulk segmented vector import
|
||||
vector_service.add_chunks(chunks)
|
||||
|
||||
# Update progress
|
||||
db_document.progress += progress_per_batch
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Embedding progress ({db_document.progress}).\n"
|
||||
with get_db_context() as db:
|
||||
db_document = None
|
||||
db_knowledge = None
|
||||
progress_msg = f"{datetime.now().strftime('%H:%M:%S')} Task has been received.\n"
|
||||
try:
|
||||
db_document = db.query(Document).filter(Document.id == document_id).first()
|
||||
db_knowledge = db.query(Knowledge).filter(Knowledge.id == db_document.kb_id).first()
|
||||
# 1. Document parsing & segmentation
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Start to parse.\n"
|
||||
start_time = time.time()
|
||||
db_document.progress = 0.0
|
||||
db_document.progress_msg = progress_msg
|
||||
db_document.process_duration = time.time() - start_time
|
||||
db_document.run = 0
|
||||
db_document.process_begin_at = datetime.now(tz=timezone.utc)
|
||||
db_document.process_duration = 0.0
|
||||
db_document.run = 1
|
||||
db.commit()
|
||||
db.refresh(db_document)
|
||||
|
||||
# Vectorization and data entry completed
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Indexing done.\n"
|
||||
db_document.chunk_num = total_chunks
|
||||
db_document.progress = 1.0
|
||||
db_document.process_duration = time.time() - start_time
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Task done ({db_document.process_duration}s).\n"
|
||||
db_document.progress_msg = progress_msg
|
||||
db_document.run = 0
|
||||
db.commit()
|
||||
result = f"parse document '{db_document.file_name}' processed successfully."
|
||||
return result
|
||||
except Exception as e:
|
||||
if 'db_document' in locals():
|
||||
db_document.progress_msg += f"Failed to vectorize and import the parsed document:{str(e)}\n"
|
||||
def progress_callback(prog=None, msg=None):
|
||||
nonlocal progress_msg # Declare the use of an external progress_msg variable
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} parse progress: {prog} msg: {msg}.\n"
|
||||
# Prepare to configure chat_mdl、vision_model information
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
||||
)
|
||||
vision_model = QWenCV(
|
||||
key=db_knowledge.image2text.api_keys[0].api_key,
|
||||
model_name=db_knowledge.image2text.api_keys[0].model_name,
|
||||
lang="Chinese",
|
||||
base_url=db_knowledge.image2text.api_keys[0].api_base
|
||||
)
|
||||
from app.core.rag.app.naive import chunk
|
||||
res = chunk(filename=file_path,
|
||||
from_page=0,
|
||||
to_page=100000,
|
||||
callback=progress_callback,
|
||||
vision_model=vision_model,
|
||||
parser_config=db_document.parser_config,
|
||||
is_root=False)
|
||||
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Finish parsing.\n"
|
||||
db_document.progress = 0.8
|
||||
db_document.progress_msg = progress_msg
|
||||
db.commit()
|
||||
db.refresh(db_document)
|
||||
|
||||
# 2. Document vectorization and storage
|
||||
total_chunks = len(res)
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Generate {total_chunks} chunks.\n"
|
||||
batch_size = 100
|
||||
total_batches = ceil(total_chunks / batch_size)
|
||||
progress_per_batch = 0.2 / total_batches # Progress of each batch
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
# 2.1 Delete document vector index
|
||||
vector_service.delete_by_metadata_field(key="document_id", value=str(document_id))
|
||||
# 2.2 Vectorize and import batch documents
|
||||
for batch_start in range(0, total_chunks, batch_size):
|
||||
batch_end = min(batch_start + batch_size, total_chunks) # prevent out-of-bounds
|
||||
batch = res[batch_start: batch_end] # Retrieve the current batch
|
||||
chunks = []
|
||||
|
||||
# Process the current batch
|
||||
for idx_in_batch, item in enumerate(batch):
|
||||
global_idx = batch_start + idx_in_batch # Calculate global index
|
||||
metadata = {
|
||||
"doc_id": uuid.uuid4().hex,
|
||||
"file_id": str(db_document.file_id),
|
||||
"file_name": db_document.file_name,
|
||||
"file_created_at": int(db_document.created_at.timestamp() * 1000),
|
||||
"document_id": str(db_document.id),
|
||||
"knowledge_id": str(db_document.kb_id),
|
||||
"sort_id": global_idx,
|
||||
"status": 1,
|
||||
}
|
||||
if db_document.parser_config.get("auto_questions", 0):
|
||||
topn = db_document.parser_config["auto_questions"]
|
||||
cached = get_llm_cache(chat_model.model_name, item["content_with_weight"], "question", {"topn": topn})
|
||||
if not cached:
|
||||
cached = question_proposal(chat_model, item["content_with_weight"], topn)
|
||||
set_llm_cache(chat_model.model_name, item["content_with_weight"], cached, "question", {"topn": topn})
|
||||
chunks.append(DocumentChunk(page_content=f"question: {cached} answer: {item['content_with_weight']}", metadata=metadata))
|
||||
else:
|
||||
chunks.append(DocumentChunk(page_content=item["content_with_weight"], metadata=metadata))
|
||||
|
||||
# Bulk segmented vector import
|
||||
vector_service.add_chunks(chunks)
|
||||
|
||||
# Update progress
|
||||
db_document.progress += progress_per_batch
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Embedding progress ({db_document.progress}).\n"
|
||||
db_document.progress_msg = progress_msg
|
||||
db_document.process_duration = time.time() - start_time
|
||||
db_document.run = 0
|
||||
db.commit()
|
||||
db.refresh(db_document)
|
||||
|
||||
# Vectorization and data entry completed
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Indexing done.\n"
|
||||
db_document.chunk_num = total_chunks
|
||||
db_document.progress = 1.0
|
||||
db_document.process_duration = time.time() - start_time
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Task done ({db_document.process_duration}s).\n"
|
||||
db_document.progress_msg = progress_msg
|
||||
db_document.run = 0
|
||||
db.commit()
|
||||
result = f"parse document '{db_document.file_name}' failed."
|
||||
return result
|
||||
finally:
|
||||
db.close()
|
||||
result = f"parse document '{db_document.file_name}' processed successfully."
|
||||
return result
|
||||
except Exception as e:
|
||||
if 'db_document' in locals():
|
||||
db_document.progress_msg += f"Failed to vectorize and import the parsed document:{str(e)}\n"
|
||||
db_document.run = 0
|
||||
db.commit()
|
||||
result = f"parse document '{db_document.file_name}' failed."
|
||||
return result
|
||||
|
||||
|
||||
@celery_app.task(name="app.core.memory.agent.read_message", bind=True)
|
||||
@@ -435,75 +433,75 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
|
||||
from app.repositories.memory_increment_repository import write_memory_increment
|
||||
from app.services.memory_storage_service import search_all
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
workspace_uuid = uuid.UUID(workspace_id)
|
||||
|
||||
# 1. 查询当前workspace下的所有app
|
||||
apps = db.query(App).filter(App.workspace_id == workspace_uuid).all()
|
||||
|
||||
if not apps:
|
||||
# 如果没有app,总量为0
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
workspace_uuid = uuid.UUID(workspace_id)
|
||||
|
||||
# 1. 查询当前workspace下的所有app
|
||||
apps = db.query(App).filter(App.workspace_id == workspace_uuid).all()
|
||||
|
||||
if not apps:
|
||||
# 如果没有app,总量为0
|
||||
memory_increment = write_memory_increment(
|
||||
db=db,
|
||||
workspace_id=workspace_uuid,
|
||||
total_num=0
|
||||
)
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"workspace_id": workspace_id,
|
||||
"total_num": 0,
|
||||
"end_user_count": 0,
|
||||
"memory_increment_id": str(memory_increment.id),
|
||||
"created_at": memory_increment.created_at.isoformat(),
|
||||
}
|
||||
|
||||
# 2. 查询所有app下的end_user_id(去重)
|
||||
app_ids = [app.id for app in apps]
|
||||
end_users = db.query(EndUser.id).filter(
|
||||
EndUser.app_id.in_(app_ids)
|
||||
).distinct().all()
|
||||
|
||||
# 3. 遍历所有end_user,查询每个宿主的记忆总量并累加
|
||||
total_num = 0
|
||||
end_user_details = []
|
||||
|
||||
for (end_user_id,) in end_users:
|
||||
try:
|
||||
# 调用 search_all 接口查询该宿主的总量
|
||||
result = await search_all(str(end_user_id))
|
||||
user_total = result.get("total", 0)
|
||||
total_num += user_total
|
||||
end_user_details.append({
|
||||
"end_user_id": str(end_user_id),
|
||||
"total": user_total
|
||||
})
|
||||
except Exception as e:
|
||||
# 记录单个用户查询失败,但继续处理其他用户
|
||||
end_user_details.append({
|
||||
"end_user_id": str(end_user_id),
|
||||
"total": 0,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
# 4. 写入数据库
|
||||
memory_increment = write_memory_increment(
|
||||
db=db,
|
||||
workspace_id=workspace_uuid,
|
||||
total_num=0
|
||||
total_num=total_num
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"workspace_id": workspace_id,
|
||||
"total_num": 0,
|
||||
"end_user_count": 0,
|
||||
"total_num": total_num,
|
||||
"end_user_count": len(end_users),
|
||||
"end_user_details": end_user_details,
|
||||
"memory_increment_id": str(memory_increment.id),
|
||||
"created_at": memory_increment.created_at.isoformat(),
|
||||
}
|
||||
|
||||
# 2. 查询所有app下的end_user_id(去重)
|
||||
app_ids = [app.id for app in apps]
|
||||
end_users = db.query(EndUser.id).filter(
|
||||
EndUser.app_id.in_(app_ids)
|
||||
).distinct().all()
|
||||
|
||||
# 3. 遍历所有end_user,查询每个宿主的记忆总量并累加
|
||||
total_num = 0
|
||||
end_user_details = []
|
||||
|
||||
for (end_user_id,) in end_users:
|
||||
try:
|
||||
# 调用 search_all 接口查询该宿主的总量
|
||||
result = await search_all(str(end_user_id))
|
||||
user_total = result.get("total", 0)
|
||||
total_num += user_total
|
||||
end_user_details.append({
|
||||
"end_user_id": str(end_user_id),
|
||||
"total": user_total
|
||||
})
|
||||
except Exception as e:
|
||||
# 记录单个用户查询失败,但继续处理其他用户
|
||||
end_user_details.append({
|
||||
"end_user_id": str(end_user_id),
|
||||
"total": 0,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
# 4. 写入数据库
|
||||
memory_increment = write_memory_increment(
|
||||
db=db,
|
||||
workspace_id=workspace_uuid,
|
||||
total_num=total_num
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"workspace_id": workspace_id,
|
||||
"total_num": total_num,
|
||||
"end_user_count": len(end_users),
|
||||
"end_user_details": end_user_details,
|
||||
"memory_increment_id": str(memory_increment.id),
|
||||
"created_at": memory_increment.created_at.isoformat(),
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
result = asyncio.run(_run())
|
||||
@@ -520,6 +518,198 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(name="app.tasks.regenerate_memory_cache", bind=True)
|
||||
def regenerate_memory_cache(self) -> Dict[str, Any]:
|
||||
"""定时任务:为所有用户重新生成记忆洞察和用户摘要缓存
|
||||
|
||||
遍历所有活动工作空间的所有终端用户,为每个用户重新生成记忆洞察和用户摘要。
|
||||
实现错误隔离,单个用户失败不影响其他用户的处理。
|
||||
|
||||
Returns:
|
||||
包含任务执行结果的字典,包括:
|
||||
- status: 任务状态 (SUCCESS/FAILURE)
|
||||
- message: 执行消息
|
||||
- workspace_count: 处理的工作空间数量
|
||||
- total_users: 总用户数
|
||||
- successful: 成功生成的用户数
|
||||
- failed: 失败的用户数
|
||||
- workspace_results: 每个工作空间的详细结果
|
||||
- elapsed_time: 执行耗时(秒)
|
||||
- task_id: 任务ID
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
async def _run() -> Dict[str, Any]:
|
||||
from app.core.logging_config import get_logger
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.services.user_memory_service import UserMemoryService
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger.info("开始执行记忆缓存重新生成定时任务")
|
||||
|
||||
service = UserMemoryService()
|
||||
|
||||
total_users = 0
|
||||
successful = 0
|
||||
failed = 0
|
||||
workspace_results = []
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
# 获取所有活动工作空间
|
||||
repo = EndUserRepository(db)
|
||||
workspaces = repo.get_all_active_workspaces()
|
||||
logger.info(f"找到 {len(workspaces)} 个活动工作空间")
|
||||
|
||||
# 遍历每个工作空间
|
||||
for workspace_id in workspaces:
|
||||
logger.info(f"开始处理工作空间: {workspace_id}")
|
||||
workspace_start_time = time.time()
|
||||
|
||||
try:
|
||||
# 获取工作空间的所有终端用户
|
||||
end_users = repo.get_all_by_workspace(workspace_id)
|
||||
workspace_user_count = len(end_users)
|
||||
total_users += workspace_user_count
|
||||
|
||||
logger.info(f"工作空间 {workspace_id} 有 {workspace_user_count} 个终端用户")
|
||||
|
||||
workspace_successful = 0
|
||||
workspace_failed = 0
|
||||
workspace_errors = []
|
||||
|
||||
# 遍历每个用户并生成缓存
|
||||
for end_user in end_users:
|
||||
end_user_id = str(end_user.id)
|
||||
|
||||
try:
|
||||
# 生成记忆洞察
|
||||
insight_result = await service.generate_and_cache_insight(db, end_user_id)
|
||||
|
||||
# 生成用户摘要
|
||||
summary_result = await service.generate_and_cache_summary(db, end_user_id)
|
||||
|
||||
# 检查是否都成功
|
||||
if insight_result["success"] and summary_result["success"]:
|
||||
workspace_successful += 1
|
||||
successful += 1
|
||||
logger.info(f"成功为终端用户 {end_user_id} 重新生成缓存")
|
||||
else:
|
||||
workspace_failed += 1
|
||||
failed += 1
|
||||
error_info = {
|
||||
"end_user_id": end_user_id,
|
||||
"insight_error": insight_result.get("error"),
|
||||
"summary_error": summary_result.get("error")
|
||||
}
|
||||
workspace_errors.append(error_info)
|
||||
logger.warning(f"终端用户 {end_user_id} 的缓存重新生成部分失败: {error_info}")
|
||||
|
||||
except Exception as e:
|
||||
# 单个用户失败不影响其他用户(错误隔离)
|
||||
workspace_failed += 1
|
||||
failed += 1
|
||||
error_info = {
|
||||
"end_user_id": end_user_id,
|
||||
"error": str(e)
|
||||
}
|
||||
workspace_errors.append(error_info)
|
||||
logger.error(f"为终端用户 {end_user_id} 重新生成缓存时出错: {str(e)}")
|
||||
|
||||
workspace_elapsed = time.time() - workspace_start_time
|
||||
|
||||
# 记录工作空间处理结果
|
||||
workspace_result = {
|
||||
"workspace_id": str(workspace_id),
|
||||
"total_users": workspace_user_count,
|
||||
"successful": workspace_successful,
|
||||
"failed": workspace_failed,
|
||||
"errors": workspace_errors[:10], # 只保留前10个错误
|
||||
"elapsed_time": workspace_elapsed
|
||||
}
|
||||
workspace_results.append(workspace_result)
|
||||
|
||||
logger.info(
|
||||
f"工作空间 {workspace_id} 处理完成: "
|
||||
f"总数={workspace_user_count}, 成功={workspace_successful}, "
|
||||
f"失败={workspace_failed}, 耗时={workspace_elapsed:.2f}秒"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# 工作空间处理失败,记录错误并继续处理下一个
|
||||
logger.error(f"处理工作空间 {workspace_id} 时出错: {str(e)}")
|
||||
workspace_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"error": str(e),
|
||||
"total_users": 0,
|
||||
"successful": 0,
|
||||
"failed": 0,
|
||||
"errors": []
|
||||
})
|
||||
|
||||
# 记录总体统计信息
|
||||
logger.info(
|
||||
f"记忆缓存重新生成定时任务完成: "
|
||||
f"工作空间数={len(workspaces)}, 总用户数={total_users}, "
|
||||
f"成功={successful}, 失败={failed}"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": f"成功处理 {len(workspaces)} 个工作空间,总共 {successful}/{total_users} 个用户缓存重新生成成功",
|
||||
"workspace_count": len(workspaces),
|
||||
"total_users": total_users,
|
||||
"successful": successful,
|
||||
"failed": failed,
|
||||
"workspace_results": workspace_results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"记忆缓存重新生成定时任务执行失败: {str(e)}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"workspace_count": len(workspace_results),
|
||||
"total_users": total_users,
|
||||
"successful": successful,
|
||||
"failed": failed,
|
||||
"workspace_results": workspace_results
|
||||
}
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
import nest_asyncio
|
||||
nest_asyncio.apply()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# 尝试获取现有事件循环,如果不存在则创建新的
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_closed():
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
result = loop.run_until_complete(_run())
|
||||
elapsed_time = time.time() - start_time
|
||||
result["elapsed_time"] = elapsed_time
|
||||
result["task_id"] = self.request.id
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
elapsed_time = time.time() - start_time
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"elapsed_time": elapsed_time,
|
||||
"task_id": self.request.id
|
||||
}
|
||||
|
||||
|
||||
@celery_app.task(name="app.tasks.workspace_reflection_task", bind=True)
|
||||
def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
"""定时任务:每30秒运行工作空间反思功能
|
||||
@@ -538,100 +728,98 @@ def workspace_reflection_task(self) -> Dict[str, Any]:
|
||||
)
|
||||
|
||||
api_logger = get_api_logger()
|
||||
db = next(get_db())
|
||||
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
# 获取所有工作空间
|
||||
workspaces = db.query(Workspace).all()
|
||||
|
||||
try:
|
||||
# 获取所有工作空间
|
||||
workspaces = db.query(Workspace).all()
|
||||
if not workspaces:
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": "没有找到工作空间",
|
||||
"workspace_count": 0,
|
||||
"reflection_results": []
|
||||
}
|
||||
|
||||
all_reflection_results = []
|
||||
|
||||
# 遍历每个工作空间
|
||||
for workspace in workspaces:
|
||||
workspace_id = workspace.id
|
||||
api_logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}")
|
||||
|
||||
try:
|
||||
reflection_service = MemoryReflectionService(db)
|
||||
|
||||
# 使用服务类处理复杂查询逻辑
|
||||
service = WorkspaceAppService(db)
|
||||
result = service.get_workspace_apps_detailed(str(workspace_id))
|
||||
|
||||
workspace_reflection_results = []
|
||||
|
||||
for data in result['apps_detailed_info']:
|
||||
if data['data_configs'] == []:
|
||||
continue
|
||||
|
||||
releases = data['releases']
|
||||
data_configs = data['data_configs']
|
||||
end_users = data['end_users']
|
||||
|
||||
for base, config, user in zip(releases, data_configs, end_users):
|
||||
if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']:
|
||||
# 调用反思服务
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||
|
||||
reflection_result = await reflection_service.start_reflection_from_data(
|
||||
config_data=config,
|
||||
end_user_id=user['id']
|
||||
)
|
||||
|
||||
workspace_reflection_results.append({
|
||||
"app_id": base['app_id'],
|
||||
"config_id": config['config_id'],
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": reflection_result
|
||||
})
|
||||
|
||||
all_reflection_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"reflection_count": len(workspace_reflection_results),
|
||||
"reflection_results": workspace_reflection_results
|
||||
})
|
||||
|
||||
api_logger.info(
|
||||
f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
|
||||
all_reflection_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"error": str(e),
|
||||
"reflection_count": 0,
|
||||
"reflection_results": []
|
||||
})
|
||||
|
||||
total_reflections = sum(r.get("reflection_count", 0) for r in all_reflection_results)
|
||||
|
||||
if not workspaces:
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": "没有找到工作空间",
|
||||
"message": f"成功处理 {len(workspaces)} 个工作空间,总共 {total_reflections} 个反思任务",
|
||||
"workspace_count": len(workspaces),
|
||||
"total_reflections": total_reflections,
|
||||
"workspace_results": all_reflection_results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"工作空间反思任务执行失败: {str(e)}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"workspace_count": 0,
|
||||
"reflection_results": []
|
||||
}
|
||||
|
||||
all_reflection_results = []
|
||||
|
||||
# 遍历每个工作空间
|
||||
for workspace in workspaces:
|
||||
workspace_id = workspace.id
|
||||
api_logger.info(f"开始处理工作空间反思,workspace_id: {workspace_id}")
|
||||
|
||||
try:
|
||||
reflection_service = MemoryReflectionService(db)
|
||||
|
||||
# 使用服务类处理复杂查询逻辑
|
||||
service = WorkspaceAppService(db)
|
||||
result = service.get_workspace_apps_detailed(str(workspace_id))
|
||||
|
||||
workspace_reflection_results = []
|
||||
|
||||
for data in result['apps_detailed_info']:
|
||||
if data['data_configs'] == []:
|
||||
continue
|
||||
|
||||
releases = data['releases']
|
||||
data_configs = data['data_configs']
|
||||
end_users = data['end_users']
|
||||
|
||||
for base, config, user in zip(releases, data_configs, end_users):
|
||||
if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']:
|
||||
# 调用反思服务
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||
|
||||
reflection_result = await reflection_service.start_reflection_from_data(
|
||||
config_data=config,
|
||||
end_user_id=user['id']
|
||||
)
|
||||
|
||||
workspace_reflection_results.append({
|
||||
"app_id": base['app_id'],
|
||||
"config_id": config['config_id'],
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": reflection_result
|
||||
})
|
||||
|
||||
all_reflection_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"reflection_count": len(workspace_reflection_results),
|
||||
"reflection_results": workspace_reflection_results
|
||||
})
|
||||
|
||||
api_logger.info(
|
||||
f"工作空间 {workspace_id} 反思处理完成,处理了 {len(workspace_reflection_results)} 个任务")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"处理工作空间 {workspace_id} 反思失败: {str(e)}")
|
||||
all_reflection_results.append({
|
||||
"workspace_id": str(workspace_id),
|
||||
"error": str(e),
|
||||
"reflection_count": 0,
|
||||
"reflection_results": []
|
||||
})
|
||||
|
||||
total_reflections = sum(r.get("reflection_count", 0) for r in all_reflection_results)
|
||||
|
||||
return {
|
||||
"status": "SUCCESS",
|
||||
"message": f"成功处理 {len(workspaces)} 个工作空间,总共 {total_reflections} 个反思任务",
|
||||
"workspace_count": len(workspaces),
|
||||
"total_reflections": total_reflections,
|
||||
"workspace_results": all_reflection_results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"工作空间反思任务执行失败: {str(e)}")
|
||||
return {
|
||||
"status": "FAILURE",
|
||||
"error": str(e),
|
||||
"workspace_count": 0,
|
||||
"reflection_results": []
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
try:
|
||||
# 使用 nest_asyncio 来避免事件循环冲突
|
||||
try:
|
||||
|
||||
@@ -1,22 +1,71 @@
|
||||
version: '3.8'
|
||||
version: '3.9'
|
||||
|
||||
services:
|
||||
# MCP Server - standalone service
|
||||
mcp-server:
|
||||
image: redbear-mem:latest
|
||||
container_name: mcp-server
|
||||
ports:
|
||||
- "8081:8081" # MCP server port
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- SERVER_IP=0.0.0.0 # Bind to all interfaces
|
||||
volumes:
|
||||
- ./files:/files
|
||||
- /etc/localtime:/etc/localtime:ro
|
||||
command: python -m app.core.memory.agent.mcp_server.server
|
||||
healthcheck:
|
||||
test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8081/sse')"]
|
||||
interval: 10s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
start_period: 30s
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- default
|
||||
- celery
|
||||
|
||||
# FastAPI application - connects to MCP server
|
||||
api:
|
||||
image: redbear-mem:latest
|
||||
container_name: api
|
||||
ports:
|
||||
- "8000:8000"
|
||||
- "8002:8000"
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- MCP_SERVER_URL=http://mcp-server:8081
|
||||
- SERVER_IP=0.0.0.0 # Ensure MCP server binds to all interfaces
|
||||
volumes:
|
||||
- ./files:/files
|
||||
- /etc/localtime:/etc/localtime:ro
|
||||
command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload --log-level debug
|
||||
depends_on:
|
||||
mcp-server:
|
||||
condition: service_healthy
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- default
|
||||
- celery
|
||||
|
||||
# Celery worker - connects to MCP server
|
||||
worker:
|
||||
image: redbear-mem:latest
|
||||
container_name: worker
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
- MCP_SERVER_URL=http://mcp-server:8081
|
||||
volumes:
|
||||
- ./files:/files
|
||||
command: celery -A app.celery_worker.celery_app worker --loglevel=info
|
||||
- /etc/localtime:/etc/localtime:ro
|
||||
command: celery -A app.celery_worker.celery_app worker --loglevel=info
|
||||
depends_on:
|
||||
mcp-server:
|
||||
condition: service_healthy
|
||||
restart: unless-stopped
|
||||
networks:
|
||||
- celery
|
||||
networks:
|
||||
celery:
|
||||
@@ -30,6 +30,11 @@ RESULT_BACKEND=
|
||||
CELERY_BROKER=
|
||||
CELERY_BACKEND=
|
||||
|
||||
# Memory Cache Regeneration Configuration
|
||||
# Interval in hours for regenerating memory insight and user summary cache
|
||||
# Default: 24 hours
|
||||
MEMORY_CACHE_REGENERATION_HOURS=24
|
||||
|
||||
# ElasticSearch configuration
|
||||
ELASTICSEARCH_HOST=
|
||||
ELASTICSEARCH_PORT=
|
||||
|
||||
50
api/migrations/versions/f3d893ccb866_202512231644.py
Normal file
50
api/migrations/versions/f3d893ccb866_202512231644.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""202512231644
|
||||
|
||||
Revision ID: f3d893ccb866
|
||||
Revises: 022550fdcfda
|
||||
Create Date: 2025-12-23 16:47:30.897690
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'f3d893ccb866'
|
||||
down_revision: Union[str, None] = '022550fdcfda'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('end_users', sa.Column('name', sa.String(), nullable=True, comment='姓名'))
|
||||
op.add_column('end_users', sa.Column('position', sa.String(), nullable=True, comment='职位'))
|
||||
op.add_column('end_users', sa.Column('department', sa.String(), nullable=True, comment='部门'))
|
||||
op.add_column('end_users', sa.Column('contact', sa.String(), nullable=True, comment='联系方式'))
|
||||
op.add_column('end_users', sa.Column('phone', sa.String(), nullable=True, comment='电话'))
|
||||
op.add_column('end_users', sa.Column('hire_date', sa.BigInteger(), nullable=True, comment='入职日期(时间戳,毫秒)'))
|
||||
op.add_column('end_users', sa.Column('updatetime_profile', sa.BigInteger(), nullable=True, comment='核心档案信息最后更新时间(时间戳,毫秒)'))
|
||||
op.add_column('end_users', sa.Column('memory_insight', sa.Text(), nullable=True, comment='缓存的记忆洞察报告'))
|
||||
op.add_column('end_users', sa.Column('user_summary', sa.Text(), nullable=True, comment='缓存的用户摘要'))
|
||||
op.add_column('end_users', sa.Column('memory_insight_updated_at', sa.DateTime(), nullable=True, comment='洞察报告最后更新时间'))
|
||||
op.add_column('end_users', sa.Column('user_summary_updated_at', sa.DateTime(), nullable=True, comment='用户摘要最后更新时间'))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('end_users', 'user_summary_updated_at')
|
||||
op.drop_column('end_users', 'memory_insight_updated_at')
|
||||
op.drop_column('end_users', 'user_summary')
|
||||
op.drop_column('end_users', 'memory_insight')
|
||||
op.drop_column('end_users', 'updatetime_profile')
|
||||
op.drop_column('end_users', 'hire_date')
|
||||
op.drop_column('end_users', 'phone')
|
||||
op.drop_column('end_users', 'contact')
|
||||
op.drop_column('end_users', 'department')
|
||||
op.drop_column('end_users', 'position')
|
||||
op.drop_column('end_users', 'name')
|
||||
# ### end Alembic commands ###
|
||||
@@ -127,6 +127,7 @@ dependencies = [
|
||||
"uvicorn>=0.34.0",
|
||||
"celery>=5.5.2",
|
||||
"simpleeval>=1.0.3",
|
||||
"langchain-aws>=1.0.0a1",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
|
||||
71
api/uv.lock
generated
71
api/uv.lock
generated
@@ -1,5 +1,5 @@
|
||||
version = 1
|
||||
revision = 2
|
||||
revision = 3
|
||||
requires-python = "==3.12.*"
|
||||
resolution-markers = [
|
||||
"sys_platform == 'darwin'",
|
||||
@@ -7,6 +7,9 @@ resolution-markers = [
|
||||
"(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')",
|
||||
]
|
||||
|
||||
[options]
|
||||
prerelease-mode = "allow"
|
||||
|
||||
[[package]]
|
||||
name = "aiohappyeyeballs"
|
||||
version = "2.6.1"
|
||||
@@ -241,6 +244,34 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc", size = 8458, upload-time = "2024-11-08T17:25:46.184Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "boto3"
|
||||
version = "1.42.14"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "botocore" },
|
||||
{ name = "jmespath" },
|
||||
{ name = "s3transfer" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/09/72/e236ca627bc0461710685f5b7438f759ef3b4106e0e08dda08513a6539ab/boto3-1.42.14.tar.gz", hash = "sha256:a5d005667b480c844ed3f814a59f199ce249d0f5669532a17d06200c0a93119c", size = 112825, upload-time = "2025-12-19T20:27:15.325Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bb/ba/c657ea6f6d63563cc46748202fccd097b51755d17add00ebe4ea27580d06/boto3-1.42.14-py3-none-any.whl", hash = "sha256:bfcc665227bb4432a235cb4adb47719438d6472e5ccbf7f09512046c3f749670", size = 140571, upload-time = "2025-12-19T20:27:13.316Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "botocore"
|
||||
version = "1.42.14"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "jmespath" },
|
||||
{ name = "python-dateutil" },
|
||||
{ name = "urllib3" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/35/3f/50c56f093c2c6ce6de1f579726598db1cf9a9cccd3bf8693f73b1cf5e319/botocore-1.42.14.tar.gz", hash = "sha256:cf5bebb580803c6cfd9886902ca24834b42ecaa808da14fb8cd35ad523c9f621", size = 14910547, upload-time = "2025-12-19T20:27:04.431Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/ad/94/67a78a8d08359e779894d4b1672658a3c7fcce216b48f06dfbe1de45521d/botocore-1.42.14-py3-none-any.whl", hash = "sha256:efe89adfafa00101390ec2c371d453b3359d5f9690261bc3bd70131e0d453e8e", size = 14583247, upload-time = "2025-12-19T20:27:00.54Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cachetools"
|
||||
version = "6.2.1"
|
||||
@@ -1114,6 +1145,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2f/9c/6753e6522b8d0ef07d3a3d239426669e984fb0eba15a315cdbc1253904e4/jiter-0.12.0-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c24e864cb30ab82311c6425655b0cdab0a98c5d973b065c66a3f020740c2324c", size = 346110, upload-time = "2025-11-09T20:49:21.817Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "jmespath"
|
||||
version = "1.0.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/00/2a/e867e8531cf3e36b41201936b7fa7ba7b5702dbef42922193f05c8976cd6/jmespath-1.0.1.tar.gz", hash = "sha256:90261b206d6defd58fdd5e85f478bf633a2901798906be2ad389150c5c60edbe", size = 25843, upload-time = "2022-06-17T18:00:12.224Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/31/b4/b9b800c45527aadd64d5b442f9b932b00648617eb5d63d2c7a6587b7cafc/jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980", size = 20256, upload-time = "2022-06-17T18:00:10.251Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "joblib"
|
||||
version = "1.5.2"
|
||||
@@ -1262,6 +1302,21 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f3/39/ed3121ea3a0c60a0cda6ea5c4c1cece013e8bbc9b18344ff3ae507728f98/langchain-1.1.3-py3-none-any.whl", hash = "sha256:e5b208ed93e553df4087117a40bd0d450f9095030a843cad35c53ff2814bf731", size = 102227, upload-time = "2025-12-08T19:31:47.246Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "langchain-aws"
|
||||
version = "1.0.0a1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "boto3" },
|
||||
{ name = "langchain-core" },
|
||||
{ name = "numpy" },
|
||||
{ name = "pydantic" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/c0/c3/a98c0849c13c6880b5629409cadb22d4070e9c611013da127be975f8c0dc/langchain_aws-1.0.0a1.tar.gz", hash = "sha256:3bb193a5fa915520c52bb47581e892d11ac4d114939a1b3ecfeca56fe153fff7", size = 121650, upload-time = "2025-09-18T20:52:36.098Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/7b/be49a224fe3aa07ed869801356f06e1d7a321bb7f22b6f7935dce86d258a/langchain_aws-1.0.0a1-py3-none-any.whl", hash = "sha256:24207d05c619ea61dfeab0a0f7086ae388cc3f2f5c03a8ae56b12d1b77d72585", size = 146839, upload-time = "2025-09-18T20:52:35.013Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "langchain-classic"
|
||||
version = "1.0.0"
|
||||
@@ -2825,6 +2880,7 @@ dependencies = [
|
||||
{ name = "json-repair" },
|
||||
{ name = "kombu" },
|
||||
{ name = "langchain" },
|
||||
{ name = "langchain-aws" },
|
||||
{ name = "langchain-community" },
|
||||
{ name = "langchain-mcp-adapters" },
|
||||
{ name = "langchain-ollama" },
|
||||
@@ -2949,6 +3005,7 @@ requires-dist = [
|
||||
{ name = "json-repair", specifier = "==0.53.0" },
|
||||
{ name = "kombu", specifier = "==5.5.4" },
|
||||
{ name = "langchain", specifier = ">=1.0.3" },
|
||||
{ name = "langchain-aws", specifier = ">=1.0.0a1" },
|
||||
{ name = "langchain-community", specifier = ">=0.3.31" },
|
||||
{ name = "langchain-mcp-adapters", specifier = ">=0.1.13" },
|
||||
{ name = "langchain-ollama" },
|
||||
@@ -3199,6 +3256,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b3/e5/8925a4208f131b218f9a7e459c0d6fcac8324ae35da269cb437894576366/ruamel_yaml_clib-0.2.15-cp312-cp312-win_amd64.whl", hash = "sha256:2b216904750889133d9222b7b873c199d48ecbb12912aca78970f84a5aa1a4bc", size = 119013, upload-time = "2025-11-16T16:13:32.164Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "s3transfer"
|
||||
version = "0.16.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "botocore" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/05/04/74127fc843314818edfa81b5540e26dd537353b123a4edc563109d8f17dd/s3transfer-0.16.0.tar.gz", hash = "sha256:8e990f13268025792229cd52fa10cb7163744bf56e719e0b9cb925ab79abf920", size = 153827, upload-time = "2025-12-01T02:30:59.114Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl", hash = "sha256:18e25d66fed509e3868dc1572b3f427ff947dd2c56f844a5bf09481ad3f3b2fe", size = 86830, upload-time = "2025-12-01T02:30:57.729Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scikit-learn"
|
||||
version = "1.7.2"
|
||||
|
||||
Reference in New Issue
Block a user