Compare commits
337 Commits
v0.2.6
...
feature/to
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e8ae46b286 | ||
|
|
78316de411 | ||
|
|
c205e7d20e | ||
|
|
81f3b50200 | ||
|
|
e3795fe1ed | ||
|
|
72a2f2a7e8 | ||
|
|
035cc17264 | ||
|
|
cf26c9f39c | ||
|
|
fabc8936ab | ||
|
|
06de54ebfd | ||
|
|
7c6e48b04e | ||
|
|
fcc81ac025 | ||
|
|
9d8c26b999 | ||
|
|
c2c832f8c9 | ||
|
|
6bc4f04293 | ||
|
|
9d150ab353 | ||
|
|
f045b59b2d | ||
|
|
0bb8278a39 | ||
|
|
e43f812c14 | ||
|
|
d584b47280 | ||
|
|
3e995cd971 | ||
|
|
b018e35ada | ||
|
|
4bc030c1ef | ||
|
|
86a0aa1f9f | ||
|
|
d523e4f3c6 | ||
|
|
186d097e00 | ||
|
|
c5cfe557da | ||
|
|
f786a66a3c | ||
|
|
ebd51928d7 | ||
|
|
2258b5c43c | ||
|
|
2e50e30071 | ||
|
|
8c804a1011 | ||
|
|
1a4c2d7cd0 | ||
|
|
c2fc4ab4ff | ||
|
|
83fcabadae | ||
|
|
d12ad213e0 | ||
|
|
33d522b387 | ||
|
|
5997458aaf | ||
|
|
68f9471caf | ||
|
|
ecbb61db27 | ||
|
|
b42815ee7a | ||
|
|
49d7398e14 | ||
|
|
91589c1497 | ||
|
|
a07727c047 | ||
|
|
25bc506f74 | ||
|
|
18ca83d763 | ||
|
|
4bbc561625 | ||
|
|
d77220a603 | ||
|
|
f52b681133 | ||
|
|
f6efa0d711 | ||
|
|
0fccc91dac | ||
|
|
8d8c6c695a | ||
|
|
57342259ce | ||
|
|
be46ed8865 | ||
|
|
04b2205769 | ||
|
|
76ba357982 | ||
|
|
2c318f6e60 | ||
|
|
3f04153f22 | ||
|
|
3df8af3852 | ||
|
|
8b9ab8a841 | ||
|
|
750dbcc7c3 | ||
|
|
5d6007aaff | ||
|
|
291767031c | ||
|
|
22ffe6ef1d | ||
|
|
02df1a70f3 | ||
|
|
8c5fa9c441 | ||
|
|
e6c558c2a0 | ||
|
|
b52e4d756c | ||
|
|
1089a52ca0 | ||
|
|
c7fb9ab8e3 | ||
|
|
83017d0c80 | ||
|
|
e24217a6ba | ||
|
|
a0f2f738df | ||
|
|
9d9250954b | ||
|
|
f042f44501 | ||
|
|
56c98648f9 | ||
|
|
956efe6a09 | ||
|
|
bb64ad23dd | ||
|
|
a97326df74 | ||
|
|
1503f8781a | ||
|
|
163ddbb6ed | ||
|
|
7bbfd33ca0 | ||
|
|
0ea47ce890 | ||
|
|
38f891235c | ||
|
|
4d83c074d9 | ||
|
|
0e9672df80 | ||
|
|
abc7460539 | ||
|
|
4bb2ccfba7 | ||
|
|
969d428320 | ||
|
|
ff64522c50 | ||
|
|
65dc1a8f48 | ||
|
|
859b7f3c7f | ||
|
|
da3f875555 | ||
|
|
44d63a44da | ||
|
|
7e5e1609b0 | ||
|
|
d94adcb19c | ||
|
|
83894df260 | ||
|
|
7b99a32a1e | ||
|
|
e8c3744f5e | ||
|
|
06d1f54030 | ||
|
|
599ccb6bde | ||
|
|
db9050c302 | ||
|
|
71b3b665b5 | ||
|
|
3b8a806661 | ||
|
|
774719fb50 | ||
|
|
a3ccd41288 | ||
|
|
8ddacb7bc9 | ||
|
|
e74a74c3fb | ||
|
|
262a9ddc48 | ||
|
|
70f84b65ec | ||
|
|
ec5cb42f67 | ||
|
|
0802481fd2 | ||
|
|
548ba0ae36 | ||
|
|
fc2360d40d | ||
|
|
ab67bda5a1 | ||
|
|
376d5ca7d0 | ||
|
|
55438136b0 | ||
|
|
82db3517d7 | ||
|
|
130490c022 | ||
|
|
ede8a11584 | ||
|
|
43130dcbc8 | ||
|
|
ff6459e439 | ||
|
|
1893de4c75 | ||
|
|
dfcc85a466 | ||
|
|
dacfb360f6 | ||
|
|
8a0d83b340 | ||
|
|
be2ce854a1 | ||
|
|
e492dcd968 | ||
|
|
55bfee856d | ||
|
|
f951075551 | ||
|
|
964086a08a | ||
|
|
67501025b3 | ||
|
|
e1cc5c841a | ||
|
|
6b839bd5a8 | ||
|
|
5df339b56d | ||
|
|
56adca9f22 | ||
|
|
1e63dd8d2d | ||
|
|
fab9272124 | ||
|
|
2f66fd9aae | ||
|
|
5616583fa1 | ||
|
|
3f0e991112 | ||
|
|
8e6288bca8 | ||
|
|
72bba0662f | ||
|
|
090f46006a | ||
|
|
abe0c7e7d1 | ||
|
|
6516f56ada | ||
|
|
ea391dc44e | ||
|
|
e21f713de0 | ||
|
|
3498e2e884 | ||
|
|
ea8edc5914 | ||
|
|
b62c40dba3 | ||
|
|
0832337839 | ||
|
|
b82f4491fb | ||
|
|
bdf0c256b3 | ||
|
|
3d91a9e926 | ||
|
|
779dbdea26 | ||
|
|
e8e342c206 | ||
|
|
78829d36cc | ||
|
|
f7c2e82dc0 | ||
|
|
19d149c129 | ||
|
|
b8e85bed61 | ||
|
|
396493ad2b | ||
|
|
f32d92b9d0 | ||
|
|
6d79db8ba3 | ||
|
|
f9fb480cc3 | ||
|
|
1efa8798bf | ||
|
|
c244e9834f | ||
|
|
b1a7b58f97 | ||
|
|
e81f39b50e | ||
|
|
e7e136036c | ||
|
|
ca84fc6c9d | ||
|
|
a0c4515a81 | ||
|
|
4bf418a3d6 | ||
|
|
f033607c8b | ||
|
|
32d612fbeb | ||
|
|
9ce3a881f3 | ||
|
|
860cd31799 | ||
|
|
d674b48f7d | ||
|
|
1635f9dbef | ||
|
|
07c899f0a9 | ||
|
|
382e4c5377 | ||
|
|
fe6518d052 | ||
|
|
dc513dfbeb | ||
|
|
3d9bc7a986 | ||
|
|
75e36173cd | ||
|
|
8097f227ca | ||
|
|
3d79b72d70 | ||
|
|
6eb9b772e7 | ||
|
|
90c8ff35d1 | ||
|
|
ad87fd96db | ||
|
|
fd1debe681 | ||
|
|
c7cc0cd922 | ||
|
|
81a232177e | ||
|
|
73aee97be5 | ||
|
|
39f3a85bb1 | ||
|
|
098a2e54ae | ||
|
|
d575478b53 | ||
|
|
aab54ca1a8 | ||
|
|
d4f2094ee0 | ||
|
|
c354618e20 | ||
|
|
5141a91041 | ||
|
|
668539e737 | ||
|
|
967139cea4 | ||
|
|
6d8b1aede4 | ||
|
|
744ba31ba6 | ||
|
|
db8257b67a | ||
|
|
85770dc037 | ||
|
|
69f976a79a | ||
|
|
fd7e77eff8 | ||
|
|
05c2a093c0 | ||
|
|
01a1e8eab1 | ||
|
|
b71bc1f875 | ||
|
|
6a0ee22d81 | ||
|
|
cbc8714414 | ||
|
|
065f8db2f7 | ||
|
|
0ac7f83726 | ||
|
|
d03473da10 | ||
|
|
dac1c01a2c | ||
|
|
f6d929ab7a | ||
|
|
a7a2dabc5a | ||
|
|
83015a3404 | ||
|
|
b88e9c5f5e | ||
|
|
8380a8a811 | ||
|
|
6c69181290 | ||
|
|
0694075447 | ||
|
|
d66b9dd8cb | ||
|
|
7267198a8c | ||
|
|
7b8f101824 | ||
|
|
0f36c5c872 | ||
|
|
6a67f028ce | ||
|
|
5d82786c20 | ||
|
|
e368f1c1d6 | ||
|
|
572ce7f9ec | ||
|
|
a4c942a21f | ||
|
|
4859ab3ba7 | ||
|
|
983b5f5087 | ||
|
|
75b87955dd | ||
|
|
110de0afbc | ||
|
|
2c074cd5c1 | ||
|
|
73e51a9b0b | ||
|
|
3a47039919 | ||
|
|
2961ea4e44 | ||
|
|
af2ffc9737 | ||
|
|
d7911244fc | ||
|
|
2a66775e45 | ||
|
|
6029a5a9a8 | ||
|
|
71d9ae15a1 | ||
|
|
f0c3d5f308 | ||
|
|
4706ea59fe | ||
|
|
5774a95f61 | ||
|
|
d660521c5c | ||
|
|
5db2c5092e | ||
|
|
59618457df | ||
|
|
c612dfbc1f | ||
|
|
fc58ac0408 | ||
|
|
8d053c97a7 | ||
|
|
a3e6f67ff7 | ||
|
|
01da2e3eee | ||
|
|
168cce1678 | ||
|
|
7240dfe793 | ||
|
|
b9340ba02d | ||
|
|
4f5ee24bc5 | ||
|
|
6a1b8d3ee3 | ||
|
|
f1207dc8b9 | ||
|
|
86c51559bb | ||
|
|
8b0f806079 | ||
|
|
99e94b3567 | ||
|
|
cfd5c1bc93 | ||
|
|
45d9e45346 | ||
|
|
fcb3845543 | ||
|
|
97eabc0c36 | ||
|
|
5328163973 | ||
|
|
7ff9dfee8c | ||
|
|
5b431400be | ||
|
|
1e1675ec12 | ||
|
|
f941541304 | ||
|
|
3f7083c5b3 | ||
|
|
e81faebf69 | ||
|
|
8a4d58c520 | ||
|
|
2ac29ee89c | ||
|
|
252cdcd6f5 | ||
|
|
16e2c95965 | ||
|
|
10560fb34c | ||
|
|
58aa60ca0e | ||
|
|
d24b186d3e | ||
|
|
b4e81615b1 | ||
|
|
424d2033ea | ||
|
|
fd556f9b00 | ||
|
|
e2f5fa87b1 | ||
|
|
e4a2bd3b9b | ||
|
|
e3ada17a78 | ||
|
|
3e5a7adfe4 | ||
|
|
3237f4cd6e | ||
|
|
beea826377 | ||
|
|
7cdbbefc64 | ||
|
|
18780622b3 | ||
|
|
f405ac4d84 | ||
|
|
9fe47e2fb2 | ||
|
|
e4aaa18f61 | ||
|
|
5c3d9717dd | ||
|
|
ac86bbd60c | ||
|
|
33d12c43b2 | ||
|
|
107c676185 | ||
|
|
0f221b7ee6 | ||
|
|
e1939ef472 | ||
|
|
5438d35f17 | ||
|
|
9c26d1f4c8 | ||
|
|
4c2b31f31f | ||
|
|
4f88a13256 | ||
|
|
21ae448ed7 | ||
|
|
50466124c8 | ||
|
|
ece88a3879 | ||
|
|
cedc4a92cc | ||
|
|
c8065b0c60 | ||
|
|
476632294f | ||
|
|
349d46e043 | ||
|
|
00e0201bf9 | ||
|
|
389dd8d402 | ||
|
|
966bd8528d | ||
|
|
8f789d47a2 | ||
|
|
94a40e49a0 | ||
|
|
8429279eea | ||
|
|
cef14cda9e | ||
|
|
c14f067afb | ||
|
|
2612abc9d0 | ||
|
|
d080b44ac3 | ||
|
|
c01ad5a19e | ||
|
|
f56bc0f85a | ||
|
|
9600d687fa | ||
|
|
41a0036bf6 | ||
|
|
08e4ad6a7c | ||
|
|
314e6e29d5 | ||
|
|
391cd602a2 | ||
|
|
247db844a4 | ||
|
|
5495d32822 | ||
|
|
a496991400 | ||
|
|
0ea83b4364 |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -25,6 +25,8 @@ examples/
|
||||
time.log
|
||||
celerybeat-schedule.db
|
||||
search_results.json
|
||||
redbear-mem-metrics/
|
||||
pitch-deck/
|
||||
|
||||
api/migrations/versions
|
||||
tmp
|
||||
|
||||
@@ -45,7 +45,8 @@ RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \
|
||||
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
|
||||
apt install -y libjemalloc-dev && \
|
||||
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
|
||||
apt install -y ghostscript
|
||||
apt install -y ghostscript && \
|
||||
apt install -y libmagic1
|
||||
|
||||
RUN if [ "$NEED_MIRROR" == "1" ]; then \
|
||||
pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
|
||||
|
||||
@@ -60,7 +60,12 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
|
||||
# are written from script.py.mako
|
||||
# output_encoding = utf-8
|
||||
|
||||
sqlalchemy.url = postgresql://user:password@localhost/dbname
|
||||
# Database connection URL - DO NOT hardcode credentials here!
|
||||
# Connection string is set dynamically from environment variables in migrations/env.py
|
||||
# Required env vars: DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME
|
||||
# Example: postgresql://user:password@localhost:5432/dbname
|
||||
; sqlalchemy.url = postgresql://user:password@host:port/dbname
|
||||
sqlalchemy.url = driver://user:password@host:port/dbname
|
||||
|
||||
|
||||
[post_write_hooks]
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import os
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import redis.asyncio as redis
|
||||
from redis.asyncio import ConnectionPool
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# 设置日志记录器
|
||||
|
||||
@@ -63,56 +63,60 @@ celery_app.conf.update(
|
||||
accept_content=['json'],
|
||||
result_serializer='json',
|
||||
|
||||
# 时区
|
||||
timezone='Asia/Shanghai',
|
||||
enable_utc=True,
|
||||
# # 时区
|
||||
# timezone='Asia/Shanghai',
|
||||
# enable_utc=False,
|
||||
|
||||
# 任务追踪
|
||||
task_track_started=True,
|
||||
task_ignore_result=False,
|
||||
|
||||
|
||||
# 超时设置
|
||||
task_time_limit=3600, # 60分钟硬超时
|
||||
task_soft_time_limit=3000, # 50分钟软超时
|
||||
|
||||
|
||||
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||
|
||||
|
||||
# 结果过期时间
|
||||
result_expires=3600, # 结果保存1小时
|
||||
|
||||
|
||||
# 任务确认设置
|
||||
task_acks_late=True,
|
||||
task_reject_on_worker_lost=True,
|
||||
worker_disable_rate_limits=True,
|
||||
|
||||
|
||||
# FLower setting
|
||||
worker_send_task_events=True,
|
||||
task_send_sent_event=True,
|
||||
|
||||
|
||||
# task routing
|
||||
task_routes={
|
||||
# Memory tasks → memory_tasks queue (threads worker)
|
||||
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||
|
||||
'app.tasks.write_perceptual_memory': {'queue': 'memory_tasks'},
|
||||
|
||||
# Long-term storage tasks → memory_tasks queue (batched write strategies)
|
||||
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
||||
|
||||
|
||||
# Document tasks → document_tasks queue (prefork worker)
|
||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
|
||||
|
||||
|
||||
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.init_community_clustering_for_users': {'queue': 'periodic_tasks'},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -129,7 +133,7 @@ implicit_emotions_update_schedule = crontab(
|
||||
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
|
||||
)
|
||||
|
||||
#构建定时任务配置
|
||||
# 构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
"run-workspace-reflection": {
|
||||
"task": "app.tasks.workspace_reflection_task",
|
||||
|
||||
@@ -13,4 +13,4 @@ logger.info("Celery worker logging initialized")
|
||||
# 导入任务模块以注册任务
|
||||
import app.tasks
|
||||
|
||||
__all__ = ['celery_app']
|
||||
__all__ = ['celery_app']
|
||||
|
||||
@@ -13,9 +13,11 @@ from . import (
|
||||
document_controller,
|
||||
emotion_config_controller,
|
||||
emotion_controller,
|
||||
end_user_controller,
|
||||
file_controller,
|
||||
file_storage_controller,
|
||||
home_page_controller,
|
||||
i18n_controller,
|
||||
implicit_memory_controller,
|
||||
knowledge_controller,
|
||||
knowledgeshare_controller,
|
||||
@@ -94,5 +96,7 @@ manager_router.include_router(memory_working_controller.router)
|
||||
manager_router.include_router(file_storage_controller.router)
|
||||
manager_router.include_router(ontology_controller.router)
|
||||
manager_router.include_router(skill_controller.router)
|
||||
manager_router.include_router(i18n_controller.router)
|
||||
manager_router.include_router(end_user_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import uuid
|
||||
import io
|
||||
from typing import Optional, Annotated
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
@@ -25,6 +27,7 @@ from app.services.app_service import AppService
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
from app.services.workflow_import_service import WorkflowImportService
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
from app.services.app_dsl_service import AppDslService
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||
logger = get_business_logger()
|
||||
@@ -50,6 +53,7 @@ def list_apps(
|
||||
status: str | None = None,
|
||||
search: str | None = None,
|
||||
include_shared: bool = True,
|
||||
shared_only: bool = False,
|
||||
page: int = 1,
|
||||
pagesize: int = 10,
|
||||
ids: Optional[str] = None,
|
||||
@@ -81,6 +85,7 @@ def list_apps(
|
||||
status=status,
|
||||
search=search,
|
||||
include_shared=include_shared,
|
||||
shared_only=shared_only,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
)
|
||||
@@ -90,6 +95,37 @@ def list_apps(
|
||||
return success(data=PageData(page=meta, items=items))
|
||||
|
||||
|
||||
@router.get("/my-shared-out", summary="列出本工作空间主动分享出去的记录")
|
||||
@cur_workspace_access_guard()
|
||||
def list_my_shared_out(
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""列出本工作空间主动分享给其他工作空间的所有记录(我的共享)"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
shares = service.list_my_shared_out(workspace_id=workspace_id)
|
||||
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||
return success(data=data)
|
||||
|
||||
|
||||
@router.delete("/share/{target_workspace_id}", summary="取消对某工作空间的所有应用分享")
|
||||
@cur_workspace_access_guard()
|
||||
def unshare_all_apps_to_workspace(
|
||||
target_workspace_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""Cancel all app shares from current workspace to a target workspace."""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
count = service.unshare_all_apps_to_workspace(
|
||||
target_workspace_id=target_workspace_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
return success(msg=f"已取消 {count} 个应用的分享", data={"count": count})
|
||||
|
||||
|
||||
@router.get("/{app_id}", summary="获取应用详情")
|
||||
@cur_workspace_access_guard()
|
||||
def get_app(
|
||||
@@ -158,6 +194,7 @@ def delete_app(
|
||||
def copy_app(
|
||||
app_id: uuid.UUID,
|
||||
new_name: Optional[str] = None,
|
||||
payload: app_schema.CopyAppRequest = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
@@ -169,6 +206,8 @@ def copy_app(
|
||||
- 不影响原应用
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# body takes precedence over query param for backward compatibility
|
||||
new_name = (payload.new_name if payload else None) or new_name
|
||||
logger.info(
|
||||
"用户请求复制应用",
|
||||
extra={
|
||||
@@ -218,6 +257,27 @@ def get_agent_config(
|
||||
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
||||
|
||||
|
||||
@router.get("/{app_id}/opening", summary="获取应用开场白配置")
|
||||
@cur_workspace_access_guard()
|
||||
def get_opening(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""返回开场白文本和预设问题,供前端对话界面初始化时展示"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
|
||||
features = cfg.features or {}
|
||||
if hasattr(features, "model_dump"):
|
||||
features = features.model_dump()
|
||||
opening = features.get("opening_statement", {})
|
||||
return success(data=app_schema.OpeningResponse(
|
||||
enabled=opening.get("enabled", False),
|
||||
statement=opening.get("statement"),
|
||||
suggested_questions=opening.get("suggested_questions", []),
|
||||
))
|
||||
|
||||
|
||||
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
|
||||
@cur_workspace_access_guard()
|
||||
def publish_app(
|
||||
@@ -299,7 +359,8 @@ def share_app(
|
||||
app_id=app_id,
|
||||
target_workspace_ids=payload.target_workspace_ids,
|
||||
user_id=current_user.id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
permission=payload.permission
|
||||
)
|
||||
|
||||
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||
@@ -330,6 +391,32 @@ def unshare_app(
|
||||
return success(msg="应用分享已取消")
|
||||
|
||||
|
||||
@router.patch("/{app_id}/share/{target_workspace_id}", summary="更新共享权限")
|
||||
@cur_workspace_access_guard()
|
||||
def update_share_permission(
|
||||
app_id: uuid.UUID,
|
||||
target_workspace_id: uuid.UUID,
|
||||
payload: app_schema.UpdateSharePermissionRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""更新共享权限(readonly <-> editable)
|
||||
|
||||
- 只能修改自己工作空间应用的共享权限
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
service = app_service.AppService(db)
|
||||
share = service.update_share_permission(
|
||||
app_id=app_id,
|
||||
target_workspace_id=target_workspace_id,
|
||||
permission=payload.permission,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
return success(data=app_schema.AppShare.model_validate(share))
|
||||
|
||||
|
||||
@router.get("/{app_id}/shares", summary="列出应用的分享记录")
|
||||
@cur_workspace_access_guard()
|
||||
def list_app_shares(
|
||||
@@ -353,6 +440,46 @@ def list_app_shares(
|
||||
return success(data=data)
|
||||
|
||||
|
||||
@router.delete("/shared/{source_workspace_id}", summary="批量移除某来源工作空间的所有共享应用")
|
||||
@cur_workspace_access_guard()
|
||||
def remove_all_shared_apps_from_workspace(
|
||||
source_workspace_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""Remove all shared apps from a specific source workspace (recipient operation)."""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
service = app_service.AppService(db)
|
||||
count = service.remove_all_shared_apps_from_workspace(
|
||||
source_workspace_id=source_workspace_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
return success(msg=f"已移除 {count} 个共享应用", data={"count": count})
|
||||
|
||||
|
||||
@router.delete("/{app_id}/shared", summary="移除共享给我的应用")
|
||||
@cur_workspace_access_guard()
|
||||
def remove_shared_app(
|
||||
app_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""被共享者从自己的工作空间移除共享应用
|
||||
|
||||
- 不会删除源应用,只删除共享记录
|
||||
- 只能移除共享给自己工作空间的应用
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
service = app_service.AppService(db)
|
||||
service.remove_shared_app(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
return success(msg="已移除共享应用")
|
||||
|
||||
|
||||
@router.post("/{app_id}/draft/run", summary="试运行 Agent(使用当前草稿配置)")
|
||||
@cur_workspace_access_guard()
|
||||
async def draft_run(
|
||||
@@ -393,7 +520,7 @@ async def draft_run(
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
from app.services.app_service import AppService
|
||||
from app.services.multi_agent_service import MultiAgentService
|
||||
from app.models import AgentConfig, ModelConfig
|
||||
from app.models import AgentConfig, ModelConfig, AppRelease
|
||||
from sqlalchemy import select
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.services.draft_run_service import AgentRunService
|
||||
@@ -410,11 +537,12 @@ async def draft_run(
|
||||
service._validate_app_accessible(app, workspace_id)
|
||||
|
||||
if payload.user_id is None:
|
||||
# 先获取 app 的 workspace_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app_id,
|
||||
workspace_id=app.workspace_id,
|
||||
other_id=str(current_user.id),
|
||||
original_user_id=str(current_user.id) # Save original user_id to other_id
|
||||
)
|
||||
payload.user_id = str(new_end_user.id)
|
||||
|
||||
@@ -431,18 +559,29 @@ async def draft_run(
|
||||
service._check_agent_config(app_id)
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||
agent_cfg = db.scalars(stmt).first()
|
||||
if not agent_cfg:
|
||||
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
# 共享应用:从最新发布版本读配置快照,而非草稿
|
||||
is_shared = app.workspace_id != workspace_id
|
||||
if is_shared:
|
||||
if not app.current_release_id:
|
||||
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
|
||||
release = db.get(AppRelease, app.current_release_id)
|
||||
if not release:
|
||||
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
agent_cfg = service._agent_config_from_release(release)
|
||||
model_config = db.get(ModelConfig, release.default_model_config_id) if release.default_model_config_id else None
|
||||
else:
|
||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||
agent_cfg = db.scalars(stmt).first()
|
||||
if not agent_cfg:
|
||||
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
# 3. 获取模型配置
|
||||
model_config = None
|
||||
if agent_cfg.default_model_config_id:
|
||||
model_config = db.get(ModelConfig, agent_cfg.default_model_config_id)
|
||||
if not model_config:
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
|
||||
# 3. 获取模型配置
|
||||
model_config = None
|
||||
if agent_cfg.default_model_config_id:
|
||||
model_config = db.get(ModelConfig, agent_cfg.default_model_config_id)
|
||||
if not model_config:
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
|
||||
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
@@ -598,7 +737,17 @@ async def draft_run(
|
||||
msg="多 Agent 任务执行成功"
|
||||
)
|
||||
elif app.type == AppType.WORKFLOW: # 工作流
|
||||
config = workflow_service.check_config(app_id)
|
||||
# 共享应用:从最新发布版本读配置快照,而非草稿
|
||||
is_shared = app.workspace_id != workspace_id
|
||||
if is_shared:
|
||||
if not app.current_release_id:
|
||||
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
|
||||
release = db.get(AppRelease, app.current_release_id)
|
||||
if not release:
|
||||
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||
config = service._workflow_config_from_release(release)
|
||||
else:
|
||||
config = workflow_service.check_config(app_id)
|
||||
# 3. 流式返回
|
||||
if payload.stream:
|
||||
logger.debug(
|
||||
@@ -741,6 +890,16 @@ async def draft_run_compare(
|
||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
service._validate_app_accessible(app, workspace_id)
|
||||
|
||||
if payload.user_id is None:
|
||||
# 先获取 app 的 workspace_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app_id,
|
||||
workspace_id=app.workspace_id,
|
||||
other_id=str(current_user.id),
|
||||
)
|
||||
payload.user_id = str(new_end_user.id)
|
||||
|
||||
# 2. 获取 Agent 配置
|
||||
from sqlalchemy import select
|
||||
from app.models import AgentConfig
|
||||
@@ -786,6 +945,13 @@ async def draft_run_compare(
|
||||
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
|
||||
})
|
||||
|
||||
# 从 features 中读取功能开关(与 draft_run 保持一致)
|
||||
features_config: dict = agent_cfg.features or {}
|
||||
if hasattr(features_config, 'model_dump'):
|
||||
features_config = features_config.model_dump()
|
||||
web_search_feature = features_config.get("web_search", {})
|
||||
web_search = isinstance(web_search_feature, dict) and web_search_feature.get("enabled", False)
|
||||
|
||||
# 流式返回
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
@@ -797,11 +963,11 @@ async def draft_run_compare(
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
user_id=payload.user_id,
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
web_search=True,
|
||||
web_search=web_search,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60,
|
||||
@@ -828,11 +994,11 @@ async def draft_run_compare(
|
||||
message=payload.message,
|
||||
workspace_id=workspace_id,
|
||||
conversation_id=payload.conversation_id,
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
user_id=payload.user_id,
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
web_search=True,
|
||||
web_search=web_search,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60,
|
||||
@@ -1010,3 +1176,57 @@ def get_workspace_api_statistics(
|
||||
)
|
||||
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@router.get("/{app_id}/export", summary="导出应用配置为 YAML 文件")
|
||||
@cur_workspace_access_guard()
|
||||
async def export_app(
|
||||
app_id: uuid.UUID,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
release_id: Optional[uuid.UUID] = None
|
||||
):
|
||||
"""导出 agent / multi_agent / workflow 应用配置为 YAML 文件流。
|
||||
release_id: 指定发布版本id,不传则导出当前草稿配置。
|
||||
"""
|
||||
yaml_str, filename = AppDslService(db).export_dsl(app_id, release_id)
|
||||
encoded = quote(filename, safe=".")
|
||||
yaml_bytes = yaml_str.encode("utf-8")
|
||||
file_stream = io.BytesIO(yaml_bytes)
|
||||
file_stream.seek(0)
|
||||
return StreamingResponse(
|
||||
file_stream,
|
||||
media_type="application/octet-stream; charset=utf-8",
|
||||
headers={"Content-Disposition": f"attachment; filename={encoded}",
|
||||
"Content-Length": str(len(yaml_bytes))}
|
||||
)
|
||||
|
||||
|
||||
@router.post("/import", summary="从 YAML 文件导入应用")
|
||||
@cur_workspace_access_guard()
|
||||
async def import_app(
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""从 YAML 文件导入 agent / multi_agent / workflow 应用。
|
||||
跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。
|
||||
"""
|
||||
if not file.filename.lower().endswith((".yaml", ".yml")):
|
||||
return fail(msg="仅支持 YAML 文件", code=BizCode.BAD_REQUEST)
|
||||
|
||||
raw = (await file.read()).decode("utf-8")
|
||||
dsl = yaml.safe_load(raw)
|
||||
if not dsl or "app" not in dsl:
|
||||
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
|
||||
|
||||
new_app, warnings = AppDslService(db).import_dsl(
|
||||
dsl=dsl,
|
||||
workspace_id=current_user.current_workspace_id,
|
||||
tenant_id=current_user.tenant_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
return success(
|
||||
data={"app": app_schema.App.model_validate(new_app), "warnings": warnings},
|
||||
msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Callable
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -16,6 +17,7 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
from app.dependencies import get_current_user, oauth2_scheme
|
||||
from app.models.user_model import User
|
||||
from app.i18n.dependencies import get_translator
|
||||
|
||||
# 获取专用日志器
|
||||
auth_logger = get_auth_logger()
|
||||
@@ -26,7 +28,8 @@ router = APIRouter(tags=["Authentication"])
|
||||
@router.post("/token", response_model=ApiResponse)
|
||||
async def login_for_access_token(
|
||||
form_data: TokenRequest,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""用户登录获取token"""
|
||||
auth_logger.info(f"用户登录请求: {form_data.email}")
|
||||
@@ -40,10 +43,10 @@ async def login_for_access_token(
|
||||
invite_info = workspace_service.validate_invite_token(db, form_data.invite)
|
||||
|
||||
if not invite_info.is_valid:
|
||||
raise BusinessException("邀请码无效或已过期", code=BizCode.BAD_REQUEST)
|
||||
raise BusinessException(t("auth.invite.invalid"), code=BizCode.BAD_REQUEST)
|
||||
|
||||
if invite_info.email != form_data.email:
|
||||
raise BusinessException("邀请邮箱与登录邮箱不匹配", code=BizCode.BAD_REQUEST)
|
||||
raise BusinessException(t("auth.invite.email_mismatch"), code=BizCode.BAD_REQUEST)
|
||||
auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}")
|
||||
try:
|
||||
# 尝试认证用户
|
||||
@@ -69,7 +72,7 @@ async def login_for_access_token(
|
||||
elif e.code == BizCode.PASSWORD_ERROR:
|
||||
# 用户存在但密码错误
|
||||
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
|
||||
raise BusinessException("接受邀请失败,密码验证错误", BizCode.LOGIN_FAILED)
|
||||
raise BusinessException(t("auth.invite.password_verification_failed"), BizCode.LOGIN_FAILED)
|
||||
else:
|
||||
# 其他认证失败情况,直接抛出
|
||||
raise
|
||||
@@ -82,7 +85,7 @@ async def login_for_access_token(
|
||||
except BusinessException as e:
|
||||
|
||||
# 其他认证失败情况,直接抛出
|
||||
raise BusinessException(e.message,BizCode.LOGIN_FAILED)
|
||||
raise BusinessException(e.message, BizCode.LOGIN_FAILED)
|
||||
|
||||
# 创建 tokens
|
||||
access_token, access_token_id = security.create_access_token(subject=user.id)
|
||||
@@ -110,14 +113,15 @@ async def login_for_access_token(
|
||||
expires_at=access_expires_at,
|
||||
refresh_expires_at=refresh_expires_at
|
||||
),
|
||||
msg="登录成功"
|
||||
msg=t("auth.login.success")
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=ApiResponse)
|
||||
async def refresh_token(
|
||||
refresh_request: RefreshTokenRequest,
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""刷新token"""
|
||||
auth_logger.info("收到token刷新请求")
|
||||
@@ -125,18 +129,18 @@ async def refresh_token(
|
||||
# 验证 refresh token
|
||||
userId = security.verify_token(refresh_request.refresh_token, "refresh")
|
||||
if not userId:
|
||||
raise BusinessException("无效的refresh token", code=BizCode.TOKEN_INVALID)
|
||||
raise BusinessException(t("auth.token.invalid_refresh_token"), code=BizCode.TOKEN_INVALID)
|
||||
|
||||
# 检查用户是否存在
|
||||
user = auth_service.get_user_by_id(db, userId)
|
||||
if not user:
|
||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
||||
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
|
||||
|
||||
# 检查 refresh token 黑名单
|
||||
if settings.ENABLE_SINGLE_SESSION:
|
||||
refresh_token_id = security.get_token_id(refresh_request.refresh_token)
|
||||
if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id):
|
||||
raise BusinessException("Refresh token已失效", code=BizCode.TOKEN_BLACKLISTED)
|
||||
raise BusinessException(t("auth.token.refresh_token_blacklisted"), code=BizCode.TOKEN_BLACKLISTED)
|
||||
|
||||
# 生成新 tokens
|
||||
new_access_token, new_access_token_id = security.create_access_token(subject=user.id)
|
||||
@@ -167,7 +171,7 @@ async def refresh_token(
|
||||
expires_at=access_expires_at,
|
||||
refresh_expires_at=refresh_expires_at
|
||||
),
|
||||
msg="token刷新成功"
|
||||
msg=t("auth.token.refresh_success")
|
||||
)
|
||||
|
||||
|
||||
@@ -175,14 +179,15 @@ async def refresh_token(
|
||||
async def logout(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db)
|
||||
db: Session = Depends(get_db),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""登出当前用户:加入token黑名单并清理会话"""
|
||||
auth_logger.info(f"用户 {current_user.username} 请求登出")
|
||||
|
||||
token_id = security.get_token_id(token)
|
||||
if not token_id:
|
||||
raise BusinessException("无效的access token", code=BizCode.TOKEN_INVALID)
|
||||
raise BusinessException(t("auth.token.invalid"), code=BizCode.TOKEN_INVALID)
|
||||
|
||||
# 加入黑名单
|
||||
await SessionService.blacklist_token(token_id)
|
||||
@@ -192,5 +197,5 @@ async def logout(
|
||||
await SessionService.clear_user_session(current_user.username)
|
||||
|
||||
auth_logger.info(f"用户 {current_user.username} 登出成功")
|
||||
return success(msg="登出成功")
|
||||
return success(msg=t("auth.logout.success"))
|
||||
|
||||
|
||||
48
api/app/controllers/end_user_controller.py
Normal file
48
api/app/controllers/end_user_controller.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""End User 管理接口 - 无需认证"""
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
from app.schemas.memory_api_schema import (
|
||||
CreateEndUserRequest,
|
||||
CreateEndUserResponse,
|
||||
)
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
router = APIRouter(prefix="/end_users", tags=["End Users"])
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def create_end_user(
|
||||
data: CreateEndUserRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Create an end user.
|
||||
|
||||
Creates a new end user for the given workspace.
|
||||
If an end user with the same other_id already exists in the workspace,
|
||||
returns the existing one.
|
||||
"""
|
||||
logger.info(f"Create end user request - other_id: {data.other_id}, workspace_id: {data.workspace_id}")
|
||||
|
||||
end_user_repo = EndUserRepository(db)
|
||||
end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=None,
|
||||
workspace_id=data.workspace_id,
|
||||
other_id=data.other_id,
|
||||
)
|
||||
|
||||
logger.info(f"End user ready: {end_user.id}")
|
||||
|
||||
result = {
|
||||
"id": str(end_user.id),
|
||||
"other_id": end_user.other_id or "",
|
||||
"other_name": end_user.other_name or "",
|
||||
"workspace_id": str(end_user.workspace_id),
|
||||
}
|
||||
|
||||
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
||||
@@ -15,7 +15,7 @@ import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
|
||||
from fastapi.responses import FileResponse, RedirectResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -47,6 +47,19 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
def _match_scheme(request: Request, url: str) -> str:
|
||||
"""
|
||||
将 presigned URL 的协议替换为与当前请求一致的协议(http/https)。
|
||||
解决反向代理场景下 presigned URL 协议与请求协议不匹配的问题。
|
||||
"""
|
||||
incoming_scheme = request.headers.get("x-forwarded-proto") or request.url.scheme
|
||||
if url.startswith("http://") and incoming_scheme == "https":
|
||||
return "https://" + url[7:]
|
||||
if url.startswith("https://") and incoming_scheme == "http":
|
||||
return "http://" + url[8:]
|
||||
return url
|
||||
|
||||
|
||||
@router.post("/files", response_model=ApiResponse)
|
||||
async def upload_file(
|
||||
file: UploadFile = File(...),
|
||||
@@ -280,6 +293,7 @@ async def upload_file_with_share_token(
|
||||
|
||||
@router.get("/files/{file_id}", response_model=Any)
|
||||
async def download_file(
|
||||
request: Request,
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
@@ -327,6 +341,7 @@ async def download_file(
|
||||
else:
|
||||
try:
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||
presigned_url = _match_scheme(request, presigned_url)
|
||||
api_logger.info(f"Redirecting to presigned URL: file_key={file_key}")
|
||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||
except FileNotFoundError:
|
||||
@@ -400,6 +415,7 @@ async def delete_file(
|
||||
|
||||
@router.get("/files/{file_id}/url", response_model=ApiResponse)
|
||||
async def get_file_url(
|
||||
request: Request,
|
||||
file_id: uuid.UUID,
|
||||
expires: int = None,
|
||||
permanent: bool = False,
|
||||
@@ -463,6 +479,7 @@ async def get_file_url(
|
||||
else:
|
||||
# For remote storage (OSS/S3), get presigned URL
|
||||
url = await storage_service.get_file_url(file_key, expires=expires)
|
||||
url = _match_scheme(request, url)
|
||||
|
||||
api_logger.info(f"Generated file URL: file_id={file_id}")
|
||||
return success(
|
||||
@@ -484,6 +501,7 @@ async def get_file_url(
|
||||
|
||||
@router.get("/public/{file_id}", response_model=Any)
|
||||
async def public_download_file(
|
||||
request: Request,
|
||||
file_id: uuid.UUID,
|
||||
expires: int = 0,
|
||||
signature: str = "",
|
||||
@@ -555,6 +573,7 @@ async def public_download_file(
|
||||
# For remote storage, redirect to presigned URL
|
||||
try:
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||
presigned_url = _match_scheme(request, presigned_url)
|
||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||
@@ -566,6 +585,7 @@ async def public_download_file(
|
||||
|
||||
@router.get("/permanent/{file_id}", response_model=Any)
|
||||
async def permanent_download_file(
|
||||
request: Request,
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
@@ -625,6 +645,7 @@ async def permanent_download_file(
|
||||
try:
|
||||
# Use a very long expiration (7 days max for most cloud providers)
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=604800)
|
||||
presigned_url = _match_scheme(request, presigned_url)
|
||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||
|
||||
833
api/app/controllers/i18n_controller.py
Normal file
833
api/app/controllers/i18n_controller.py
Normal file
@@ -0,0 +1,833 @@
|
||||
"""
|
||||
I18n Management API Controller
|
||||
|
||||
This module provides management APIs for:
|
||||
- Language management (list, get, add, update languages)
|
||||
- Translation management (get, update, reload translations)
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Callable, Optional
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, get_current_superuser
|
||||
from app.i18n.dependencies import get_translator
|
||||
from app.i18n.service import get_translation_service
|
||||
from app.models.user_model import User
|
||||
from app.schemas.i18n_schema import (
|
||||
LanguageInfo,
|
||||
LanguageListResponse,
|
||||
LanguageCreateRequest,
|
||||
LanguageUpdateRequest,
|
||||
TranslationResponse,
|
||||
TranslationUpdateRequest,
|
||||
MissingTranslationsResponse,
|
||||
ReloadResponse
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/i18n",
|
||||
tags=["I18n Management"],
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Language Management APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/languages", response_model=ApiResponse)
|
||||
def get_languages(
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get list of all supported languages.
|
||||
|
||||
Returns:
|
||||
List of language information including code, name, and status
|
||||
"""
|
||||
api_logger.info(f"Get languages request from user: {current_user.username}")
|
||||
|
||||
from app.core.config import settings
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Get available locales from translation service
|
||||
available_locales = translation_service.get_available_locales()
|
||||
|
||||
# Build language info list
|
||||
languages = []
|
||||
for locale in available_locales:
|
||||
is_default = locale == settings.I18N_DEFAULT_LANGUAGE
|
||||
is_enabled = locale in settings.I18N_SUPPORTED_LANGUAGES
|
||||
|
||||
# Get native names
|
||||
native_names = {
|
||||
"zh": "中文(简体)",
|
||||
"en": "English",
|
||||
"ja": "日本語",
|
||||
"ko": "한국어",
|
||||
"fr": "Français",
|
||||
"de": "Deutsch",
|
||||
"es": "Español"
|
||||
}
|
||||
|
||||
language_info = LanguageInfo(
|
||||
code=locale,
|
||||
name=f"{locale.upper()}",
|
||||
native_name=native_names.get(locale, locale),
|
||||
is_enabled=is_enabled,
|
||||
is_default=is_default
|
||||
)
|
||||
languages.append(language_info)
|
||||
|
||||
response = LanguageListResponse(languages=languages)
|
||||
|
||||
api_logger.info(f"Returning {len(languages)} languages")
|
||||
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/languages/{locale}", response_model=ApiResponse)
|
||||
def get_language(
|
||||
locale: str,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get information about a specific language.
|
||||
|
||||
Args:
|
||||
locale: Language code (e.g., 'zh', 'en')
|
||||
|
||||
Returns:
|
||||
Language information
|
||||
"""
|
||||
api_logger.info(f"Get language info request: locale={locale}, user={current_user.username}")
|
||||
|
||||
from app.core.config import settings
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if locale exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
# Build language info
|
||||
is_default = locale == settings.I18N_DEFAULT_LANGUAGE
|
||||
is_enabled = locale in settings.I18N_SUPPORTED_LANGUAGES
|
||||
|
||||
native_names = {
|
||||
"zh": "中文(简体)",
|
||||
"en": "English",
|
||||
"ja": "日本語",
|
||||
"ko": "한국어",
|
||||
"fr": "Français",
|
||||
"de": "Deutsch",
|
||||
"es": "Español"
|
||||
}
|
||||
|
||||
language_info = LanguageInfo(
|
||||
code=locale,
|
||||
name=f"{locale.upper()}",
|
||||
native_name=native_names.get(locale, locale),
|
||||
is_enabled=is_enabled,
|
||||
is_default=is_default
|
||||
)
|
||||
|
||||
api_logger.info(f"Returning language info for: {locale}")
|
||||
return success(data=language_info.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/languages", response_model=ApiResponse)
|
||||
def add_language(
|
||||
request: LanguageCreateRequest,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Add a new language (admin only).
|
||||
|
||||
Note: This endpoint validates the request but actual language addition
|
||||
requires creating translation files in the locales directory.
|
||||
|
||||
Args:
|
||||
request: Language creation request
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Add language request: code={request.code}, admin={current_user.username}"
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if language already exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if request.code in available_locales:
|
||||
api_logger.warning(f"Language already exists: {request.code}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=t("i18n.language.already_exists", locale=request.code)
|
||||
)
|
||||
|
||||
# Note: Actual language addition requires creating translation files
|
||||
# This endpoint serves as a validation and documentation point
|
||||
|
||||
api_logger.info(
|
||||
f"Language addition validated: {request.code}. "
|
||||
"Translation files need to be created manually."
|
||||
)
|
||||
|
||||
return success(
|
||||
msg=t(
|
||||
"i18n.language.add_instructions",
|
||||
locale=request.code,
|
||||
dir=settings.I18N_CORE_LOCALES_DIR
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@router.put("/languages/{locale}", response_model=ApiResponse)
|
||||
def update_language(
|
||||
locale: str,
|
||||
request: LanguageUpdateRequest,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Update language configuration (admin only).
|
||||
|
||||
Note: This endpoint validates the request but actual configuration
|
||||
changes require updating environment variables or config files.
|
||||
|
||||
Args:
|
||||
locale: Language code
|
||||
request: Language update request
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Update language request: locale={locale}, admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if language exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
# Note: Actual configuration changes require updating settings
|
||||
# This endpoint serves as a validation and documentation point
|
||||
|
||||
api_logger.info(
|
||||
f"Language update validated: {locale}. "
|
||||
"Configuration changes require environment variable updates."
|
||||
)
|
||||
|
||||
return success(msg=t("i18n.language.update_instructions", locale=locale))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Translation Management APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/translations", response_model=ApiResponse)
|
||||
def get_all_translations(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get all translations for all or specific locale.
|
||||
|
||||
Args:
|
||||
locale: Optional locale filter
|
||||
|
||||
Returns:
|
||||
All translations organized by locale and namespace
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get all translations request: locale={locale}, user={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
if locale:
|
||||
# Get translations for specific locale
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
translations = {
|
||||
locale: translation_service._cache.get(locale, {})
|
||||
}
|
||||
else:
|
||||
# Get all translations
|
||||
translations = translation_service._cache
|
||||
|
||||
response = TranslationResponse(translations=translations)
|
||||
|
||||
api_logger.info(f"Returning translations for: {locale or 'all locales'}")
|
||||
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/translations/{locale}", response_model=ApiResponse)
|
||||
def get_locale_translations(
|
||||
locale: str,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get all translations for a specific locale.
|
||||
|
||||
Args:
|
||||
locale: Language code
|
||||
|
||||
Returns:
|
||||
All translations for the locale organized by namespace
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get locale translations request: locale={locale}, user={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if locale exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
translations = translation_service._cache.get(locale, {})
|
||||
|
||||
api_logger.info(f"Returning {len(translations)} namespaces for locale: {locale}")
|
||||
return success(data={"locale": locale, "translations": translations}, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/translations/{locale}/{namespace}", response_model=ApiResponse)
|
||||
def get_namespace_translations(
|
||||
locale: str,
|
||||
namespace: str,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get translations for a specific namespace in a locale.
|
||||
|
||||
Args:
|
||||
locale: Language code
|
||||
namespace: Translation namespace (e.g., 'common', 'auth')
|
||||
|
||||
Returns:
|
||||
Translations for the specified namespace
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get namespace translations request: locale={locale}, "
|
||||
f"namespace={namespace}, user={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if locale exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
# Get namespace translations
|
||||
locale_translations = translation_service._cache.get(locale, {})
|
||||
namespace_translations = locale_translations.get(namespace, {})
|
||||
|
||||
if not namespace_translations:
|
||||
api_logger.warning(f"Namespace not found: {namespace} in locale: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.namespace.not_found", namespace=namespace, locale=locale)
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"Returning translations for namespace: {namespace} in locale: {locale}"
|
||||
)
|
||||
return success(
|
||||
data={
|
||||
"locale": locale,
|
||||
"namespace": namespace,
|
||||
"translations": namespace_translations
|
||||
},
|
||||
msg=t("common.success.retrieved")
|
||||
)
|
||||
|
||||
|
||||
@router.put("/translations/{locale}/{key:path}", response_model=ApiResponse)
|
||||
def update_translation(
|
||||
locale: str,
|
||||
key: str,
|
||||
request: TranslationUpdateRequest,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Update a single translation (admin only).
|
||||
|
||||
Note: This endpoint validates the request but actual translation updates
|
||||
require modifying translation files in the locales directory.
|
||||
|
||||
Args:
|
||||
locale: Language code
|
||||
key: Translation key (format: "namespace.key.subkey")
|
||||
request: Translation update request
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Update translation request: locale={locale}, key={key}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
# Check if locale exists
|
||||
available_locales = translation_service.get_available_locales()
|
||||
if locale not in available_locales:
|
||||
api_logger.warning(f"Language not found: {locale}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=t("i18n.language.not_found", locale=locale)
|
||||
)
|
||||
|
||||
# Validate key format
|
||||
if "." not in key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=t("i18n.translation.invalid_key_format", key=key)
|
||||
)
|
||||
|
||||
# Note: Actual translation updates require modifying JSON files
|
||||
# This endpoint serves as a validation and documentation point
|
||||
|
||||
api_logger.info(
|
||||
f"Translation update validated: {locale}/{key}. "
|
||||
"Translation files need to be updated manually."
|
||||
)
|
||||
|
||||
return success(
|
||||
msg=t("i18n.translation.update_instructions", locale=locale, key=key)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/translations/missing", response_model=ApiResponse)
|
||||
def get_missing_translations(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Get list of missing translations.
|
||||
|
||||
Compares translations across locales to find missing keys.
|
||||
|
||||
Args:
|
||||
locale: Optional locale to check (defaults to checking all non-default locales)
|
||||
|
||||
Returns:
|
||||
List of missing translation keys
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get missing translations request: locale={locale}, user={current_user.username}"
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
translation_service = get_translation_service()
|
||||
|
||||
default_locale = settings.I18N_DEFAULT_LANGUAGE
|
||||
available_locales = translation_service.get_available_locales()
|
||||
|
||||
# Get default locale translations as reference
|
||||
default_translations = translation_service._cache.get(default_locale, {})
|
||||
|
||||
# Collect all keys from default locale
|
||||
def collect_keys(data, prefix=""):
|
||||
keys = []
|
||||
for key, value in data.items():
|
||||
full_key = f"{prefix}.{key}" if prefix else key
|
||||
if isinstance(value, dict):
|
||||
keys.extend(collect_keys(value, full_key))
|
||||
else:
|
||||
keys.append(full_key)
|
||||
return keys
|
||||
|
||||
default_keys = set()
|
||||
for namespace, translations in default_translations.items():
|
||||
namespace_keys = collect_keys(translations, namespace)
|
||||
default_keys.update(namespace_keys)
|
||||
|
||||
# Find missing keys in target locale(s)
|
||||
missing_by_locale = {}
|
||||
|
||||
target_locales = [locale] if locale else [
|
||||
loc for loc in available_locales if loc != default_locale
|
||||
]
|
||||
|
||||
for target_locale in target_locales:
|
||||
if target_locale not in available_locales:
|
||||
continue
|
||||
|
||||
target_translations = translation_service._cache.get(target_locale, {})
|
||||
target_keys = set()
|
||||
|
||||
for namespace, translations in target_translations.items():
|
||||
namespace_keys = collect_keys(translations, namespace)
|
||||
target_keys.update(namespace_keys)
|
||||
|
||||
missing_keys = default_keys - target_keys
|
||||
if missing_keys:
|
||||
missing_by_locale[target_locale] = sorted(list(missing_keys))
|
||||
|
||||
response = MissingTranslationsResponse(missing_translations=missing_by_locale)
|
||||
|
||||
total_missing = sum(len(keys) for keys in missing_by_locale.values())
|
||||
api_logger.info(f"Found {total_missing} missing translations across {len(missing_by_locale)} locales")
|
||||
|
||||
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/reload", response_model=ApiResponse)
|
||||
def reload_translations(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Trigger hot reload of translation files (admin only).
|
||||
|
||||
Args:
|
||||
locale: Optional locale to reload (defaults to reloading all locales)
|
||||
|
||||
Returns:
|
||||
Reload status and statistics
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Reload translations request: locale={locale or 'all'}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
if not settings.I18N_ENABLE_HOT_RELOAD:
|
||||
api_logger.warning("Hot reload is disabled in configuration")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=t("i18n.reload.disabled")
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
|
||||
try:
|
||||
# Reload translations
|
||||
translation_service.reload(locale)
|
||||
|
||||
# Get statistics
|
||||
available_locales = translation_service.get_available_locales()
|
||||
reloaded_locales = [locale] if locale else available_locales
|
||||
|
||||
response = ReloadResponse(
|
||||
success=True,
|
||||
reloaded_locales=reloaded_locales,
|
||||
total_locales=len(available_locales)
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"Successfully reloaded translations for: {', '.join(reloaded_locales)}"
|
||||
)
|
||||
|
||||
return success(data=response.dict(), msg=t("i18n.reload.success"))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to reload translations: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=t("i18n.reload.failed", error=str(e))
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Performance Monitoring APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/metrics", response_model=ApiResponse)
|
||||
def get_metrics(
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Get i18n performance metrics (admin only).
|
||||
|
||||
Returns:
|
||||
Performance metrics including:
|
||||
- Request counts
|
||||
- Missing translations
|
||||
- Timing statistics
|
||||
- Locale usage
|
||||
- Error counts
|
||||
"""
|
||||
api_logger.info(f"Get metrics request: admin={current_user.username}")
|
||||
|
||||
translation_service = get_translation_service()
|
||||
metrics = translation_service.get_metrics_summary()
|
||||
|
||||
api_logger.info("Returning i18n metrics")
|
||||
return success(data=metrics, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/metrics/cache", response_model=ApiResponse)
|
||||
def get_cache_stats(
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Get cache statistics (admin only).
|
||||
|
||||
Returns:
|
||||
Cache statistics including:
|
||||
- Hit/miss rates
|
||||
- LRU cache performance
|
||||
- Loaded locales
|
||||
- Memory usage
|
||||
"""
|
||||
api_logger.info(f"Get cache stats request: admin={current_user.username}")
|
||||
|
||||
translation_service = get_translation_service()
|
||||
cache_stats = translation_service.get_cache_stats()
|
||||
memory_usage = translation_service.get_memory_usage()
|
||||
|
||||
data = {
|
||||
"cache": cache_stats,
|
||||
"memory": memory_usage
|
||||
}
|
||||
|
||||
api_logger.info("Returning cache statistics")
|
||||
return success(data=data, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/metrics/prometheus")
|
||||
def get_prometheus_metrics(
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Get metrics in Prometheus format (admin only).
|
||||
|
||||
Returns:
|
||||
Prometheus-formatted metrics as plain text
|
||||
"""
|
||||
api_logger.info(f"Get Prometheus metrics request: admin={current_user.username}")
|
||||
|
||||
from app.i18n.metrics import get_metrics
|
||||
metrics = get_metrics()
|
||||
prometheus_output = metrics.export_prometheus()
|
||||
|
||||
from fastapi.responses import PlainTextResponse
|
||||
return PlainTextResponse(content=prometheus_output)
|
||||
|
||||
|
||||
@router.post("/metrics/reset", response_model=ApiResponse)
|
||||
def reset_metrics(
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Reset all metrics (admin only).
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(f"Reset metrics request: admin={current_user.username}")
|
||||
|
||||
from app.i18n.metrics import get_metrics
|
||||
metrics = get_metrics()
|
||||
metrics.reset()
|
||||
|
||||
translation_service = get_translation_service()
|
||||
translation_service.cache.reset_stats()
|
||||
|
||||
api_logger.info("Metrics reset completed")
|
||||
return success(msg=t("i18n.metrics.reset_success"))
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Missing Translation Logging and Reporting APIs
|
||||
# ============================================================================
|
||||
|
||||
@router.get("/logs/missing", response_model=ApiResponse)
|
||||
def get_missing_translation_logs(
|
||||
locale: Optional[str] = None,
|
||||
limit: Optional[int] = 100,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Get missing translation logs (admin only).
|
||||
|
||||
Returns logged missing translations with context information.
|
||||
|
||||
Args:
|
||||
locale: Optional locale filter
|
||||
limit: Maximum number of entries to return (default: 100)
|
||||
|
||||
Returns:
|
||||
Missing translation logs with context
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Get missing translation logs request: locale={locale}, "
|
||||
f"limit={limit}, admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
translation_logger = translation_service.translation_logger
|
||||
|
||||
# Get missing translations
|
||||
missing_translations = translation_logger.get_missing_translations(locale)
|
||||
|
||||
# Get missing with context
|
||||
missing_with_context = translation_logger.get_missing_with_context(locale, limit)
|
||||
|
||||
# Get statistics
|
||||
statistics = translation_logger.get_statistics()
|
||||
|
||||
data = {
|
||||
"missing_translations": missing_translations,
|
||||
"recent_context": missing_with_context,
|
||||
"statistics": statistics
|
||||
}
|
||||
|
||||
api_logger.info(
|
||||
f"Returning {statistics['total_missing']} missing translations"
|
||||
)
|
||||
return success(data=data, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.get("/logs/missing/report", response_model=ApiResponse)
|
||||
def generate_missing_translation_report(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Generate a comprehensive missing translation report (admin only).
|
||||
|
||||
Args:
|
||||
locale: Optional locale filter
|
||||
|
||||
Returns:
|
||||
Comprehensive report with missing translations and statistics
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Generate missing translation report request: locale={locale}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
translation_logger = translation_service.translation_logger
|
||||
|
||||
# Generate report
|
||||
report = translation_logger.generate_report(locale)
|
||||
|
||||
api_logger.info(
|
||||
f"Generated report with {report['total_missing']} missing translations"
|
||||
)
|
||||
return success(data=report, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/logs/missing/export", response_model=ApiResponse)
|
||||
def export_missing_translations(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Export missing translations to JSON file (admin only).
|
||||
|
||||
Args:
|
||||
locale: Optional locale filter
|
||||
|
||||
Returns:
|
||||
Export status and file path
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Export missing translations request: locale={locale}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
from datetime import datetime
|
||||
translation_service = get_translation_service()
|
||||
translation_logger = translation_service.translation_logger
|
||||
|
||||
# Generate filename with timestamp
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
locale_suffix = f"_{locale}" if locale else "_all"
|
||||
output_file = f"logs/i18n/missing_translations{locale_suffix}_{timestamp}.json"
|
||||
|
||||
# Export to file
|
||||
translation_logger.export_to_json(output_file)
|
||||
|
||||
api_logger.info(f"Missing translations exported to: {output_file}")
|
||||
return success(
|
||||
data={"file_path": output_file},
|
||||
msg=t("i18n.logs.export_success", file=output_file)
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/logs/missing", response_model=ApiResponse)
|
||||
def clear_missing_translation_logs(
|
||||
locale: Optional[str] = None,
|
||||
t: Callable = Depends(get_translator),
|
||||
current_user: User = Depends(get_current_superuser)
|
||||
):
|
||||
"""
|
||||
Clear missing translation logs (admin only).
|
||||
|
||||
Args:
|
||||
locale: Optional locale to clear (clears all if not specified)
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Clear missing translation logs request: locale={locale or 'all'}, "
|
||||
f"admin={current_user.username}"
|
||||
)
|
||||
|
||||
translation_service = get_translation_service()
|
||||
translation_logger = translation_service.translation_logger
|
||||
|
||||
# Clear logs
|
||||
translation_logger.clear(locale)
|
||||
|
||||
api_logger.info(f"Cleared missing translation logs for: {locale or 'all locales'}")
|
||||
return success(msg=t("i18n.logs.clear_success"))
|
||||
@@ -19,7 +19,7 @@ from app.models import mcp_market_config_model
|
||||
from app.models.user_model import User
|
||||
from app.schemas import mcp_market_config_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import mcp_market_config_service
|
||||
from app.services import mcp_market_config_service, mcp_market_service
|
||||
|
||||
# Obtain a dedicated API logger
|
||||
api_logger = get_api_logger()
|
||||
@@ -55,6 +55,12 @@ async def get_mcp_servers(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The paging parameter must be greater than 0"
|
||||
)
|
||||
if page * pagesize > 100:
|
||||
api_logger.warning(f"Paging parameters exceed ModelScope limit: page={page}, pagesize={pagesize}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The maximum number of MCP services can view is 100. Please visit the ModelScope MCP Plaza."
|
||||
)
|
||||
|
||||
# 2. Query mcp market config information from the database
|
||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||
@@ -64,14 +70,16 @@ async def get_mcp_servers(
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
# 3. Execute paged query
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="MCP market config token is not configured"
|
||||
)
|
||||
api = MCPApi()
|
||||
api.login(token)
|
||||
|
||||
body = {
|
||||
@@ -115,6 +123,17 @@ async def get_mcp_servers(
|
||||
"has_next": True if page * pagesize < total else False
|
||||
}
|
||||
}
|
||||
# 5. Update mck_market.mcp_count
|
||||
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=db_mcp_market_config.mcp_market_id, current_user=current_user)
|
||||
if not db_mcp_market:
|
||||
api_logger.warning(f"The mcp market does not exist or access is denied: mcp_market_id={db_mcp_market_config.mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market does not exist or access is denied"
|
||||
)
|
||||
db_mcp_market.mcp_count = total
|
||||
db.commit()
|
||||
db.refresh(db_mcp_market)
|
||||
return success(data=result, msg="Query of mcp servers list successful")
|
||||
|
||||
|
||||
@@ -140,14 +159,16 @@ async def get_operational_mcp_servers(
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
# 2. Execute paged query
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="MCP market config token is not configured"
|
||||
)
|
||||
api = MCPApi()
|
||||
api.login(token)
|
||||
|
||||
url = f'{api.mcp_base_url}/operational'
|
||||
@@ -198,14 +219,16 @@ async def get_mcp_server(
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
# 2. Get detailed information for a specific MCP Server
|
||||
api = MCPApi()
|
||||
token = db_mcp_market_config.token
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="MCP market config token is not configured"
|
||||
)
|
||||
api = MCPApi()
|
||||
api.login(token)
|
||||
|
||||
result = api.get_mcp_server(server_id=server_id)
|
||||
@@ -226,7 +249,26 @@ async def create_mcp_market_config(
|
||||
|
||||
try:
|
||||
api_logger.debug(f"Start creating the mcp market config: {create_data.mcp_market_id}")
|
||||
# 1. Check if the mcp market name already exists
|
||||
# 1. Validate token can access ModelScope MCP market
|
||||
if not create_data.token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Token is required to access ModelScope MCP market"
|
||||
)
|
||||
try:
|
||||
api = MCPApi()
|
||||
api.login(create_data.token)
|
||||
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
||||
cookies = api.get_cookies(create_data.token)
|
||||
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unable to access ModelScope MCP market with the provided token: {str(e)}"
|
||||
)
|
||||
# 2. Check if the mcp market name already exists
|
||||
db_mcp_market_config_exist = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=create_data.mcp_market_id, current_user=current_user)
|
||||
if db_mcp_market_config_exist:
|
||||
api_logger.warning(f"The mcp market id already exists: {create_data.mcp_market_id}")
|
||||
@@ -234,6 +276,30 @@ async def create_mcp_market_config(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The mcp market id already exists: {create_data.mcp_market_id}"
|
||||
)
|
||||
# 2. verify token
|
||||
create_data.status = 1
|
||||
try:
|
||||
api = MCPApi()
|
||||
token = create_data.token
|
||||
api.login(token)
|
||||
|
||||
body = {
|
||||
'filter': {},
|
||||
'page_number': 1,
|
||||
'page_size': 20,
|
||||
'search': ""
|
||||
}
|
||||
cookies = api.get_cookies(token)
|
||||
r = api.session.put(
|
||||
url=api.mcp_base_url,
|
||||
headers=api.builder_headers(api.headers),
|
||||
json=body,
|
||||
cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except requests.exceptions.RequestException as e:
|
||||
api_logger.error(f"Failed to get MCP servers: {str(e)}")
|
||||
create_data.status = 0
|
||||
# 3. create mcp_market_config
|
||||
db_mcp_market_config = mcp_market_config_service.create_mcp_market_config(db=db, mcp_market_config=create_data, current_user=current_user)
|
||||
api_logger.info(
|
||||
f"The mcp market config has been successfully created: (ID: {db_mcp_market_config.id})")
|
||||
@@ -262,10 +328,7 @@ async def get_mcp_market_config(
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
@@ -295,10 +358,7 @@ async def get_mcp_market_config_by_mcp_market_id(
|
||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_id={mcp_market_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or access is denied"
|
||||
)
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
@@ -324,12 +384,25 @@ async def update_mcp_market_config(
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or you do not have permission to access it"
|
||||
)
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
# 2. Update fields (only update non-null fields)
|
||||
# 2. Validate new token if provided
|
||||
if update_data.token is not None:
|
||||
try:
|
||||
api = MCPApi()
|
||||
api.login(update_data.token)
|
||||
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
||||
cookies = api.get_cookies(update_data.token)
|
||||
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
|
||||
raise_for_http_status(r)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unable to access ModelScope MCP market with the provided token: {str(e)}"
|
||||
)
|
||||
|
||||
# 3. Update fields (only update non-null fields)
|
||||
api_logger.debug(f"Start updating the mcp market config fields: {mcp_market_config_id}")
|
||||
update_dict = update_data.dict(exclude_unset=True)
|
||||
updated_fields = []
|
||||
@@ -344,7 +417,7 @@ async def update_mcp_market_config(
|
||||
if updated_fields:
|
||||
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
||||
|
||||
# 3. Save to database
|
||||
# 4. Save to database
|
||||
try:
|
||||
db.commit()
|
||||
db.refresh(db_mcp_market_config)
|
||||
@@ -357,7 +430,7 @@ async def update_mcp_market_config(
|
||||
detail=f"The mcp market config update failed: {str(e)}"
|
||||
)
|
||||
|
||||
# 4. Return the updated mcp market config
|
||||
# 5. Return the updated mcp market config
|
||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||
msg="The mcp market config information updated successfully")
|
||||
|
||||
@@ -381,10 +454,7 @@ async def delete_mcp_market_config(
|
||||
if not db_mcp_market_config:
|
||||
api_logger.warning(
|
||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The mcp market config does not exist or you do not have permission to access it"
|
||||
)
|
||||
return success(msg='The mcp market config does not exist or access is denied')
|
||||
|
||||
# 2. Deleting mcp market config
|
||||
mcp_market_config_service.delete_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
from app.core.response_utils import success
|
||||
@@ -149,6 +150,21 @@ async def get_workspace_end_users(
|
||||
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
|
||||
try:
|
||||
from app.celery_app import celery_app as _celery_app
|
||||
_celery_app.send_task(
|
||||
"app.tasks.init_implicit_emotions_for_users",
|
||||
kwargs={"end_user_ids": end_user_ids},
|
||||
)
|
||||
_celery_app.send_task(
|
||||
"app.tasks.init_interest_distribution_for_users",
|
||||
kwargs={"end_user_ids": end_user_ids},
|
||||
)
|
||||
api_logger.info(f"已触发按需初始化任务,候选用户数: {len(end_user_ids)}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
|
||||
|
||||
# 并发执行配置查询和记忆数量查询
|
||||
memory_configs_map, memory_nums_map = await asyncio.gather(
|
||||
get_memory_configs(),
|
||||
@@ -177,7 +193,16 @@ async def get_workspace_end_users(
|
||||
await aio_redis_set(cache_key, json.dumps(result), expire=30)
|
||||
except Exception as e:
|
||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||
|
||||
|
||||
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
||||
# 对有 ExtractedEntity 但无 Community 节点的存量用户自动补跑全量聚类
|
||||
try:
|
||||
from app.tasks import init_community_clustering_for_users
|
||||
init_community_clustering_for_users.delay(end_user_ids=end_user_ids)
|
||||
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
|
||||
|
||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||
return success(data=result, msg="宿主列表获取成功")
|
||||
|
||||
@@ -387,14 +412,15 @@ def get_current_user_rag_total_num(
|
||||
@router.get("/rag_content", response_model=ApiResponse)
|
||||
def get_rag_content(
|
||||
end_user_id: str = Query(..., description="宿主ID"),
|
||||
limit: int = Query(15, description="返回记录数"),
|
||||
page: int = Query(1, gt=0, description="页码,从1开始"),
|
||||
pagesize: int = Query(15, gt=0, le=100, description="每页返回记录数"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取当前宿主知识库中的chunk内容
|
||||
获取当前宿主知识库中的chunk内容(分页)
|
||||
"""
|
||||
data = memory_dashboard_service.get_rag_content(end_user_id, limit, db, current_user)
|
||||
data = memory_dashboard_service.get_rag_content(end_user_id, page, pagesize, db, current_user)
|
||||
return success(data=data, msg="宿主RAGchunk数据获取成功")
|
||||
|
||||
|
||||
@@ -407,26 +433,18 @@ async def get_chunk_summary_tag(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取chunk总结、提取的标签和人物形象
|
||||
|
||||
读取RAG摘要、标签和人物形象(纯读库,不触发生成)。
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"summary": "chunk内容的总结",
|
||||
"tags": [
|
||||
{"tag": "标签1", "frequency": 5},
|
||||
{"tag": "标签2", "frequency": 3},
|
||||
...
|
||||
],
|
||||
"personas": [
|
||||
"产品设计师",
|
||||
"旅行爱好者",
|
||||
"摄影发烧友",
|
||||
...
|
||||
]
|
||||
"summary": "用户摘要",
|
||||
"tags": [{"tag": "标签1", "frequency": 5}, ...],
|
||||
"personas": ["产品设计师", ...],
|
||||
"generated": true/false // false表示尚未生产,请调用 /generate_rag_profile
|
||||
}
|
||||
"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk摘要、标签和人物形象")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 读取宿主 {end_user_id} 的RAG摘要/标签/人物形象")
|
||||
|
||||
data = await memory_dashboard_service.get_chunk_summary_and_tags(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
@@ -434,9 +452,8 @@ async def get_chunk_summary_tag(
|
||||
db=db,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取chunk摘要、{len(data.get('tags', []))} 个标签和 {len(data.get('personas', []))} 个人物形象")
|
||||
return success(data=data, msg="chunk摘要、标签和人物形象获取成功")
|
||||
|
||||
return success(data=data, msg="获取成功")
|
||||
|
||||
|
||||
@router.get("/chunk_insight", response_model=ApiResponse)
|
||||
@@ -447,24 +464,57 @@ async def get_chunk_insight(
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
获取chunk的洞察内容
|
||||
|
||||
读取RAG洞察报告(纯读库,不触发生成)。
|
||||
|
||||
返回格式:
|
||||
{
|
||||
"insight": "对chunk内容的深度洞察分析"
|
||||
"insight": "总体概述",
|
||||
"behavior_pattern": "行为模式",
|
||||
"key_findings": "关键发现",
|
||||
"growth_trajectory": "成长轨迹",
|
||||
"generated": true/false // false表示尚未生产,请调用 /generate_rag_profile
|
||||
}
|
||||
"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk洞察")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 读取宿主 {end_user_id} 的RAG洞察")
|
||||
|
||||
data = await memory_dashboard_service.get_chunk_insight(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
db=db,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info("成功获取chunk洞察")
|
||||
return success(data=data, msg="chunk洞察获取成功")
|
||||
|
||||
return success(data=data, msg="获取成功")
|
||||
|
||||
|
||||
class GenerateRagProfileRequest(BaseModel):
|
||||
end_user_id: str = Field(..., description="宿主ID")
|
||||
limit: int = Field(15, description="参与生成的chunk数量上限")
|
||||
max_tags: int = Field(10, description="最大标签数量")
|
||||
|
||||
|
||||
@router.post("/generate_rag_profile", response_model=ApiResponse)
|
||||
async def generate_rag_profile(
|
||||
body: GenerateRagProfileRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
生产接口:为RAG存储模式的宿主全量重新生成完整画像并持久化到end_user表。
|
||||
每次请求都会重新生成,覆盖已有数据。
|
||||
"""
|
||||
api_logger.info(f"用户 {current_user.username} 触发RAG画像生产: end_user_id={body.end_user_id}")
|
||||
|
||||
data = await memory_dashboard_service.generate_rag_profile(
|
||||
end_user_id=body.end_user_id,
|
||||
limit=body.limit,
|
||||
max_tags=body.max_tags,
|
||||
db=db,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
api_logger.info(f"RAG画像生产完成: {data}")
|
||||
return success(data=data, msg="RAG画像生产完成")
|
||||
|
||||
|
||||
@router.get("/dashboard_data", response_model=ApiResponse)
|
||||
@@ -553,9 +603,12 @@ async def dashboard_data(
|
||||
)
|
||||
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
|
||||
# total_app: 统计当前空间下的所有app数量
|
||||
from app.repositories import app_repository
|
||||
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
||||
neo4j_data["total_app"] = len(apps_orm)
|
||||
# 包含自有app + 被分享给本工作空间的app
|
||||
from app.services import app_service as _app_svc
|
||||
_, total_app = _app_svc.AppService(db).list_apps(
|
||||
workspace_id=workspace_id, include_shared=True, pagesize=1
|
||||
)
|
||||
neo4j_data["total_app"] = total_app
|
||||
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"获取记忆总量失败: {str(e)}")
|
||||
|
||||
@@ -1,3 +1,19 @@
|
||||
"""
|
||||
Memory Reflection Controller
|
||||
|
||||
This module provides REST API endpoints for managing memory reflection configurations
|
||||
and operations. It handles reflection engine setup, configuration management, and
|
||||
execution of self-reflection processes across memory systems.
|
||||
|
||||
Key Features:
|
||||
- Reflection configuration management (save, retrieve, update)
|
||||
- Workspace-wide reflection execution across multiple applications
|
||||
- Individual configuration-based reflection runs
|
||||
- Multi-language support for reflection outputs
|
||||
- Integration with Neo4j memory storage and LLM models
|
||||
- Comprehensive error handling and logging
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
@@ -28,9 +44,13 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
# Load environment variables for configuration
|
||||
load_dotenv()
|
||||
|
||||
# Initialize API logger for request tracking and debugging
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Configure router with prefix and tags for API organization
|
||||
router = APIRouter(
|
||||
prefix="/memory",
|
||||
tags=["Memory"],
|
||||
@@ -43,7 +63,38 @@ async def save_reflection_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Save reflection configuration to data_comfig table"""
|
||||
"""
|
||||
Save reflection configuration to memory config table
|
||||
|
||||
Persists reflection engine configuration settings to the data_config table,
|
||||
including reflection parameters, model settings, and evaluation criteria.
|
||||
Validates configuration parameters and ensures data consistency.
|
||||
|
||||
Args:
|
||||
request: Memory reflection configuration data including:
|
||||
- config_id: Configuration identifier to update
|
||||
- reflection_enabled: Whether reflection is enabled
|
||||
- reflection_period_in_hours: Reflection execution interval
|
||||
- reflexion_range: Scope of reflection (partial/all)
|
||||
- baseline: Reflection strategy (time/fact/hybrid)
|
||||
- reflection_model_id: LLM model for reflection operations
|
||||
- memory_verify: Enable memory verification checks
|
||||
- quality_assessment: Enable quality assessment evaluation
|
||||
current_user: Authenticated user saving the configuration
|
||||
db: Database session for data operations
|
||||
|
||||
Returns:
|
||||
dict: Success response with saved reflection configuration data
|
||||
|
||||
Raises:
|
||||
HTTPException 400: If config_id is missing or parameters are invalid
|
||||
HTTPException 500: If configuration save operation fails
|
||||
|
||||
Database Operations:
|
||||
- Updates memory_config table with reflection settings
|
||||
- Commits transaction and refreshes entity
|
||||
- Maintains configuration consistency
|
||||
"""
|
||||
try:
|
||||
config_id = request.config_id
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
@@ -54,6 +105,7 @@ async def save_reflection_config(
|
||||
)
|
||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||
|
||||
# Update reflection configuration in database
|
||||
memory_config = MemoryConfigRepository.update_reflection_config(
|
||||
db,
|
||||
config_id=config_id,
|
||||
@@ -66,6 +118,7 @@ async def save_reflection_config(
|
||||
quality_assessment=request.quality_assessment
|
||||
)
|
||||
|
||||
# Commit transaction and refresh entity
|
||||
db.commit()
|
||||
db.refresh(memory_config)
|
||||
|
||||
@@ -102,13 +155,55 @@ async def start_workspace_reflection(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""启动工作空间中所有匹配应用的反思功能"""
|
||||
"""
|
||||
Start reflection functionality for all matching applications in workspace
|
||||
|
||||
Initiates reflection processes across all applications within the user's current
|
||||
workspace that have valid memory configurations. Processes each application's
|
||||
configurations and associated end users, executing reflection operations
|
||||
with proper error isolation and transaction management.
|
||||
|
||||
This endpoint serves as a workspace-wide reflection orchestrator, ensuring
|
||||
that reflection failures for individual users don't affect other operations.
|
||||
|
||||
Args:
|
||||
current_user: Authenticated user initiating workspace reflection
|
||||
db: Database session for configuration queries
|
||||
|
||||
Returns:
|
||||
dict: Success response with reflection results for all processed applications:
|
||||
- app_id: Application identifier
|
||||
- config_id: Memory configuration identifier
|
||||
- end_user_id: End user identifier
|
||||
- reflection_result: Individual reflection operation result
|
||||
|
||||
Processing Logic:
|
||||
1. Retrieve all applications in the current workspace
|
||||
2. Filter applications with valid memory configurations
|
||||
3. For each configuration, find matching releases
|
||||
4. Execute reflection for each end user with isolated transactions
|
||||
5. Aggregate results with error handling per user
|
||||
|
||||
Error Handling:
|
||||
- Individual user reflection failures are isolated
|
||||
- Failed operations are logged and included in results
|
||||
- Database transactions are isolated per user to prevent cascading failures
|
||||
- Comprehensive error reporting for debugging
|
||||
|
||||
Raises:
|
||||
HTTPException 500: If workspace reflection initialization fails
|
||||
|
||||
Performance Notes:
|
||||
- Uses independent database sessions for each user operation
|
||||
- Prevents transaction failures from affecting other users
|
||||
- Comprehensive logging for operation tracking
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
try:
|
||||
api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}")
|
||||
|
||||
# 使用独立的数据库会话来获取工作空间应用详情,避免事务失败
|
||||
# Use independent database session to get workspace app details, avoiding transaction failures
|
||||
from app.db import get_db_context
|
||||
with get_db_context() as query_db:
|
||||
service = WorkspaceAppService(query_db)
|
||||
@@ -116,8 +211,9 @@ async def start_workspace_reflection(
|
||||
|
||||
reflection_results = []
|
||||
|
||||
# Process each application in the workspace
|
||||
for data in result['apps_detailed_info']:
|
||||
# 跳过没有配置的应用
|
||||
# Skip applications without configurations
|
||||
if not data['memory_configs']:
|
||||
api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过")
|
||||
continue
|
||||
@@ -126,22 +222,22 @@ async def start_workspace_reflection(
|
||||
memory_configs = data['memory_configs']
|
||||
end_users = data['end_users']
|
||||
|
||||
# 为每个配置和用户组合执行反思
|
||||
# Execute reflection for each configuration and user combination
|
||||
for config in memory_configs:
|
||||
config_id_str = str(config['config_id'])
|
||||
|
||||
# 找到匹配此配置的所有release
|
||||
# Find all releases matching this configuration
|
||||
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
|
||||
|
||||
if not matching_releases:
|
||||
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
||||
continue
|
||||
|
||||
# 为每个用户执行反思 - 使用独立的数据库会话
|
||||
# Execute reflection for each user - using independent database sessions
|
||||
for user in end_users:
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
||||
|
||||
# 为每个用户创建独立的数据库会话,避免事务失败影响其他用户
|
||||
# Create independent database session for each user to avoid transaction failure impact
|
||||
with get_db_context() as user_db:
|
||||
try:
|
||||
reflection_service = MemoryReflectionService(user_db)
|
||||
@@ -184,14 +280,51 @@ async def start_reflection_configs(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""通过config_id查询memory_config表中的反思配置信息"""
|
||||
"""
|
||||
Query reflection configuration information by config_id
|
||||
|
||||
Retrieves detailed reflection configuration settings from the memory_config
|
||||
table for a specific configuration ID. Provides comprehensive reflection
|
||||
parameters including model settings, evaluation criteria, and operational flags.
|
||||
|
||||
Args:
|
||||
config_id: Configuration identifier (UUID or integer) to query
|
||||
current_user: Authenticated user making the request
|
||||
db: Database session for data operations
|
||||
|
||||
Returns:
|
||||
dict: Success response with detailed reflection configuration:
|
||||
- config_id: Resolved configuration identifier
|
||||
- reflection_enabled: Whether reflection is enabled for this config
|
||||
- reflection_period_in_hours: Reflection execution interval
|
||||
- reflexion_range: Scope of reflection operations (partial/all)
|
||||
- baseline: Reflection strategy (time/fact/hybrid)
|
||||
- reflection_model_id: LLM model identifier for reflection
|
||||
- memory_verify: Memory verification flag
|
||||
- quality_assessment: Quality assessment flag
|
||||
|
||||
Database Operations:
|
||||
- Queries memory_config table by resolved config_id
|
||||
- Retrieves all reflection-related configuration fields
|
||||
- Resolves configuration ID for consistent formatting
|
||||
|
||||
Raises:
|
||||
HTTPException 404: If configuration with specified ID is not found
|
||||
HTTPException 500: If configuration query operation fails
|
||||
|
||||
ID Resolution:
|
||||
- Supports both UUID and integer config_id formats
|
||||
- Automatically resolves to appropriate internal format
|
||||
- Maintains consistency across different ID representations
|
||||
"""
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
try:
|
||||
config_id=resolve_config_id(config_id,db)
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
memory_config_id = resolve_config_id(result.config_id, db)
|
||||
# 构建返回数据
|
||||
|
||||
# Build response data with comprehensive configuration details
|
||||
reflection_config = {
|
||||
"config_id": memory_config_id,
|
||||
"reflection_enabled": result.enable_self_reflexion,
|
||||
@@ -204,10 +337,12 @@ async def start_reflection_configs(
|
||||
}
|
||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||
return success(data=reflection_config, msg="反思配置查询成功")
|
||||
|
||||
|
||||
api_logger.info(f"Successfully queried reflection config, config_id: {config_id}")
|
||||
return success(data=reflection_config, msg="Reflection configuration query successful")
|
||||
|
||||
except HTTPException:
|
||||
# 重新抛出HTTP异常
|
||||
# Re-raise HTTP exceptions without modification
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"查询反思配置失败: {str(e)}")
|
||||
@@ -223,13 +358,66 @@ async def reflection_run(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
# 使用集中化的语言校验
|
||||
"""
|
||||
Execute reflection engine with specified configuration
|
||||
|
||||
Runs the reflection engine using configuration parameters from the database.
|
||||
Validates model availability, sets up the reflection engine with proper
|
||||
configuration, and executes the reflection process with multi-language support.
|
||||
|
||||
This endpoint provides a test run capability for reflection configurations,
|
||||
allowing users to validate their reflection settings and see results before
|
||||
deploying to production environments.
|
||||
|
||||
Args:
|
||||
config_id: Configuration identifier (UUID or integer) for reflection settings
|
||||
language_type: Language preference header for output localization (optional)
|
||||
current_user: Authenticated user executing the reflection
|
||||
db: Database session for configuration queries
|
||||
|
||||
Returns:
|
||||
dict: Success response with reflection execution results including:
|
||||
- baseline: Reflection strategy used
|
||||
- source_data: Input data processed
|
||||
- memory_verifies: Memory verification results (if enabled)
|
||||
- quality_assessments: Quality assessment results (if enabled)
|
||||
- reflexion_data: Generated reflection insights and solutions
|
||||
|
||||
Configuration Validation:
|
||||
- Verifies configuration exists in database
|
||||
- Validates LLM model availability
|
||||
- Falls back to default model if specified model is unavailable
|
||||
- Ensures all required parameters are properly set
|
||||
|
||||
Reflection Engine Setup:
|
||||
- Creates ReflectionConfig with database parameters
|
||||
- Initializes Neo4j connector for memory access
|
||||
- Sets up ReflectionEngine with validated model
|
||||
- Configures language preferences for output
|
||||
|
||||
Error Handling:
|
||||
- Model validation with fallback to default
|
||||
- Configuration validation and error reporting
|
||||
- Comprehensive logging for debugging
|
||||
- Graceful handling of missing configurations
|
||||
|
||||
Raises:
|
||||
HTTPException 404: If configuration is not found
|
||||
HTTPException 500: If reflection execution fails
|
||||
|
||||
Performance Notes:
|
||||
- Direct database query for configuration retrieval
|
||||
- Model validation to prevent runtime failures
|
||||
- Efficient reflection engine initialization
|
||||
- Language-aware output processing
|
||||
"""
|
||||
# Use centralized language validation for consistent localization
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 使用MemoryConfigRepository查询反思配置
|
||||
|
||||
# Query reflection configuration using MemoryConfigRepository
|
||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
@@ -239,7 +427,7 @@ async def reflection_run(
|
||||
|
||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||
|
||||
# 验证模型ID是否存在
|
||||
# Validate model ID existence
|
||||
model_id = result.reflection_model_id
|
||||
if model_id:
|
||||
try:
|
||||
@@ -250,6 +438,7 @@ async def reflection_run(
|
||||
# 可以设置为None,让反思引擎使用默认模型
|
||||
model_id = None
|
||||
|
||||
# Create reflection configuration with database parameters
|
||||
config = ReflectionConfig(
|
||||
enabled=result.enable_self_reflexion,
|
||||
iteration_period=result.iteration_period,
|
||||
@@ -262,11 +451,13 @@ async def reflection_run(
|
||||
model_id=model_id,
|
||||
language_type=language_type
|
||||
)
|
||||
|
||||
# Initialize Neo4j connector and reflection engine
|
||||
connector = Neo4jConnector()
|
||||
engine = ReflectionEngine(
|
||||
config=config,
|
||||
neo4j_connector=connector,
|
||||
llm_client=model_id # 传入验证后的 model_id
|
||||
llm_client=model_id # Pass validated model_id
|
||||
)
|
||||
|
||||
result=await (engine.reflection_run())
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
"""
|
||||
Memory Short Term Controller
|
||||
|
||||
This module provides REST API endpoints for managing short-term and long-term memory
|
||||
data retrieval and analysis. It handles memory system statistics, data aggregation,
|
||||
and provides comprehensive memory insights for end users.
|
||||
|
||||
Key Features:
|
||||
- Short-term memory data retrieval and statistics
|
||||
- Long-term memory data aggregation
|
||||
- Entity count integration
|
||||
- Multi-language response support
|
||||
- Memory system analytics and reporting
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
@@ -13,9 +28,13 @@ from app.models.user_model import User
|
||||
from app.services.memory_short_service import LongService, ShortService
|
||||
from app.services.memory_storage_service import search_entity
|
||||
|
||||
# Load environment variables for configuration
|
||||
load_dotenv()
|
||||
|
||||
# Initialize API logger for request tracking and debugging
|
||||
api_logger = get_api_logger()
|
||||
|
||||
# Configure router with prefix and tags for API organization
|
||||
router = APIRouter(
|
||||
prefix="/memory/short",
|
||||
tags=["Memory"],
|
||||
@@ -27,24 +46,73 @@ async def short_term_configs(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# 使用集中化的语言校验
|
||||
"""
|
||||
Retrieve comprehensive short-term and long-term memory statistics
|
||||
|
||||
Provides a comprehensive overview of memory system data for a specific end user,
|
||||
including short-term memory entries, long-term memory aggregations, entity counts,
|
||||
and retrieval statistics. Supports multi-language responses based on request headers.
|
||||
|
||||
This endpoint serves as a central dashboard for memory system analytics, combining
|
||||
data from multiple memory subsystems to provide a holistic view of user memory state.
|
||||
|
||||
Args:
|
||||
end_user_id: Unique identifier for the end user whose memory data to retrieve
|
||||
language_type: Language preference header for response localization (optional)
|
||||
current_user: Authenticated user making the request (injected by dependency)
|
||||
db: Database session for data operations (injected by dependency)
|
||||
|
||||
Returns:
|
||||
dict: Success response containing comprehensive memory statistics:
|
||||
- short_term: List of short-term memory entries with detailed data
|
||||
- long_term: List of long-term memory aggregations and summaries
|
||||
- entity: Count of entities associated with the end user
|
||||
- retrieval_number: Total count of short-term memory retrievals
|
||||
- long_term_number: Total count of long-term memory entries
|
||||
|
||||
Response Structure:
|
||||
{
|
||||
"code": 200,
|
||||
"msg": "Short-term memory system data retrieved successfully",
|
||||
"data": {
|
||||
"short_term": [...], # Short-term memory entries
|
||||
"long_term": [...], # Long-term memory data
|
||||
"entity": 42, # Entity count
|
||||
"retrieval_number": 156, # Short-term retrieval count
|
||||
"long_term_number": 23 # Long-term memory count
|
||||
}
|
||||
}
|
||||
|
||||
Raises:
|
||||
HTTPException: If end_user_id is invalid or data retrieval fails
|
||||
|
||||
Performance Notes:
|
||||
- Combines multiple service calls for comprehensive data
|
||||
- Entity search is performed asynchronously for better performance
|
||||
- Response time depends on memory data volume for the specified user
|
||||
"""
|
||||
# Use centralized language validation for consistent localization
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
# 获取短期记忆数据
|
||||
short_term=ShortService(end_user_id, db)
|
||||
short_result=short_term.get_short_databasets()
|
||||
short_count=short_term.get_short_count()
|
||||
# Retrieve short-term memory data and statistics
|
||||
short_term = ShortService(end_user_id, db)
|
||||
short_result = short_term.get_short_databasets() # Get short-term memory entries
|
||||
short_count = short_term.get_short_count() # Get short-term retrieval count
|
||||
|
||||
long_term=LongService(end_user_id, db)
|
||||
long_result=long_term.get_long_databasets()
|
||||
# Retrieve long-term memory data and aggregations
|
||||
long_term = LongService(end_user_id, db)
|
||||
long_result = long_term.get_long_databasets() # Get long-term memory entries
|
||||
|
||||
# Get entity count for the specified end user
|
||||
entity_result = await search_entity(end_user_id)
|
||||
|
||||
# Compile comprehensive memory statistics response
|
||||
result = {
|
||||
'short_term': short_result,
|
||||
'long_term': long_result,
|
||||
'entity': entity_result.get('num', 0),
|
||||
"retrieval_number":short_count,
|
||||
"long_term_number":len(long_result)
|
||||
'short_term': short_result, # Short-term memory entries
|
||||
'long_term': long_result, # Long-term memory data
|
||||
'entity': entity_result.get('num', 0), # Entity count (default to 0 if not found)
|
||||
"retrieval_number": short_count, # Short-term retrieval statistics
|
||||
"long_term_number": len(long_result) # Long-term memory entry count
|
||||
}
|
||||
|
||||
return success(data=result, msg="短期记忆系统数据获取成功")
|
||||
@@ -8,6 +8,7 @@ from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.schemas import conversation_schema
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.conversation_service import ConversationService
|
||||
|
||||
@@ -90,11 +91,7 @@ def get_messages(
|
||||
conversation_id,
|
||||
)
|
||||
messages = [
|
||||
{
|
||||
"role": message.role,
|
||||
"content": message.content,
|
||||
"created_at": int(message.created_at.timestamp() * 1000),
|
||||
}
|
||||
conversation_schema.Message.model_validate(message)
|
||||
for message in messages_obj
|
||||
]
|
||||
return success(data=messages, msg="get conversation history success")
|
||||
|
||||
@@ -13,7 +13,6 @@ from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db, get_db_read
|
||||
from app.dependencies import get_share_user_id, ShareTokenData
|
||||
from app.models.app_model import App
|
||||
from app.models.app_model import AppType
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
@@ -22,6 +21,7 @@ from app.schemas import release_share_schema, conversation_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services import workspace_service
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.services.app_service import AppService
|
||||
from app.services.auth_service import create_access_token
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
@@ -215,8 +215,11 @@ def list_conversations(
|
||||
service = SharedChatService(db)
|
||||
share, release = service.get_release_by_share_token(share_data.share_token, password)
|
||||
end_user_repo = EndUserRepository(db)
|
||||
app_service = AppService(db)
|
||||
app = app_service._get_app_or_404(share.app_id)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
workspace_id=app.workspace_id,
|
||||
other_id=other_id
|
||||
)
|
||||
logger.debug(new_end_user.id)
|
||||
@@ -308,25 +311,29 @@ async def chat(
|
||||
|
||||
# Store end_user_id in database with original user_id
|
||||
end_user_repo = EndUserRepository(db)
|
||||
app_service = AppService(db)
|
||||
app = app_service._get_app_or_404(share.app_id)
|
||||
workspace_id = app.workspace_id
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=other_id,
|
||||
original_user_id=user_id # Save original user_id to other_id
|
||||
original_user_id=user_id
|
||||
)
|
||||
end_user_id = str(new_end_user.id)
|
||||
|
||||
appid = share.app_id
|
||||
# appid = share.app_id
|
||||
"""获取存储类型和工作空间的ID"""
|
||||
|
||||
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
|
||||
app = db.query(App).filter(
|
||||
App.id == appid,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
if not app:
|
||||
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||
# app = db.query(App).filter(
|
||||
# App.id == appid,
|
||||
# App.is_active.is_(True)
|
||||
# ).first()
|
||||
# if not app:
|
||||
# raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||
|
||||
workspace_id = app.workspace_id
|
||||
# workspace_id = app.workspace_id
|
||||
|
||||
# 直接从 workspace 获取 storage_type(公开分享场景无需权限检查)
|
||||
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
||||
@@ -610,11 +617,11 @@ async def chat(
|
||||
|
||||
# 多 Agent 非流式返回
|
||||
result = await app_chat_service.workflow_chat(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
files=payload.files,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
@@ -654,17 +661,21 @@ async def config_query(
|
||||
workflow_service = WorkflowService(db)
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": workflow_service.get_start_node_variables(release.config)
|
||||
"variables": workflow_service.get_start_node_variables(release.config),
|
||||
"memory": workflow_service.is_memory_enable(release.config),
|
||||
"features": release.config.get("features")
|
||||
}
|
||||
elif release.app.type == AppType.AGENT:
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": release.config.get("variables")
|
||||
"variables": release.config.get("variables"),
|
||||
"features": release.config.get("features")
|
||||
}
|
||||
elif release.app.type == AppType.MULTI_AGENT:
|
||||
content = {
|
||||
"app_type": release.app.type,
|
||||
"variables": []
|
||||
"variables": [],
|
||||
"features": release.config.get("features")
|
||||
}
|
||||
else:
|
||||
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
@@ -95,8 +95,8 @@ async def chat(
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
other_id=other_id,
|
||||
original_user_id=other_id # Save original user_id to other_id
|
||||
)
|
||||
end_user_id = str(new_end_user.id)
|
||||
web_search = True
|
||||
@@ -280,6 +280,7 @@ async def chat(
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
files=payload.files,
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id
|
||||
|
||||
@@ -6,6 +6,7 @@ from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.schemas.memory_api_schema import (
|
||||
ListConfigsResponse,
|
||||
MemoryReadRequest,
|
||||
MemoryReadResponse,
|
||||
MemoryWriteRequest,
|
||||
@@ -31,14 +32,15 @@ async def write_memory_api_service(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
payload: MemoryWriteRequest = Body(..., embed=False),
|
||||
|
||||
message: str = Body(..., description="Message content"),
|
||||
):
|
||||
"""
|
||||
Write memory to storage.
|
||||
|
||||
Stores memory content for the specified end user using the Memory API Service.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = MemoryWriteRequest(**body)
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
@@ -62,13 +64,15 @@ async def read_memory_api_service(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
payload: MemoryReadRequest = Body(..., embed=False),
|
||||
message: str = Body(..., description="Query message"),
|
||||
):
|
||||
"""
|
||||
Read memory from storage.
|
||||
|
||||
Queries and retrieves memories for the specified end user with context-aware responses.
|
||||
"""
|
||||
body = await request.json()
|
||||
payload = MemoryReadRequest(**body)
|
||||
logger.info(f"Memory read request - end_user_id: {payload.end_user_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
@@ -85,3 +89,27 @@ async def read_memory_api_service(
|
||||
|
||||
logger.info(f"Memory read successful for end_user: {payload.end_user_id}")
|
||||
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully")
|
||||
|
||||
|
||||
@router.get("/configs")
|
||||
@require_api_key(scopes=["memory"])
|
||||
async def list_memory_configs(
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
List all memory configs for the workspace.
|
||||
|
||||
Returns all available memory configurations associated with the authorized workspace.
|
||||
"""
|
||||
logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
result = memory_api_service.list_memory_configs(
|
||||
workspace_id=api_key_auth.workspace_id,
|
||||
)
|
||||
|
||||
logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
|
||||
return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")
|
||||
|
||||
@@ -3,8 +3,11 @@ from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.schemas.tool_schema import (
|
||||
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest, CustomToolTestRequest
|
||||
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest,
|
||||
CustomToolTestRequest, ToolActiveUpdate
|
||||
)
|
||||
|
||||
from app.core.response_utils import success
|
||||
@@ -14,6 +17,7 @@ from app.models import User
|
||||
from app.models.tool_model import ToolType, ToolStatus, AuthType
|
||||
from app.services.tool_service import ToolService
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.core.exceptions import BusinessException
|
||||
|
||||
router = APIRouter(prefix="/tools", tags=["Tool System"])
|
||||
|
||||
@@ -97,7 +101,13 @@ async def create_tool(
|
||||
):
|
||||
"""创建工具"""
|
||||
try:
|
||||
tool_id = service.create_tool(
|
||||
# 将 MCP 来源字段合并进 config
|
||||
if request.tool_type == ToolType.MCP:
|
||||
for key in ("source_channel", "market_id", "market_config_id", "mcp_service_id"):
|
||||
val = getattr(request, key, None)
|
||||
if val is not None:
|
||||
request.config[key] = val
|
||||
tool_id = await service.create_tool(
|
||||
name=request.name,
|
||||
tool_type=request.tool_type,
|
||||
tenant_id=current_user.tenant_id,
|
||||
@@ -107,6 +117,8 @@ async def create_tool(
|
||||
tags=request.tags
|
||||
)
|
||||
return success(data={"tool_id": tool_id}, msg="工具创建成功")
|
||||
except BusinessException as e:
|
||||
raise HTTPException(status_code=400, detail=e.message)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
@@ -147,7 +159,7 @@ async def delete_tool(
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""删除工具"""
|
||||
"""删除工具(逻辑删除,is_active=False)"""
|
||||
try:
|
||||
success_flag = service.delete_tool(tool_id, current_user.tenant_id)
|
||||
if not success_flag:
|
||||
@@ -159,6 +171,30 @@ async def delete_tool(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.patch("/{tool_id}/active", response_model=ApiResponse)
|
||||
async def set_tool_active(
|
||||
tool_id: str,
|
||||
request: ToolActiveUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
service: ToolService = Depends(get_tool_service)
|
||||
):
|
||||
"""设置工具可用状态(启用/禁用)
|
||||
|
||||
- is_active=true: 启用工具
|
||||
- is_active=false: 禁用工具(等同于删除,但可恢复)
|
||||
"""
|
||||
try:
|
||||
success_flag = service.set_tool_active(tool_id, current_user.tenant_id, request.is_active)
|
||||
if not success_flag:
|
||||
raise HTTPException(status_code=404, detail="工具不存在")
|
||||
action = "启用" if request.is_active else "禁用"
|
||||
return success(msg=f"工具已{action}")
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/execution/execute", response_model=ApiResponse)
|
||||
async def execute_tool(
|
||||
request: ToolExecuteRequest,
|
||||
@@ -216,8 +252,10 @@ async def sync_mcp_tools(
|
||||
try:
|
||||
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
|
||||
if not result.get("success", False):
|
||||
raise HTTPException(status_code=400, detail=result.get("message", "同步失败"))
|
||||
raise BusinessException(result.get("message", "工具列表同步失败"), BizCode.BAD_REQUEST)
|
||||
return success(data=result, msg="MCP工具列表同步完成")
|
||||
except BusinessException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -240,8 +278,10 @@ async def test_tool_connection(
|
||||
# 普通连接测试
|
||||
result = await service.test_connection(tool_id, current_user.tenant_id)
|
||||
if result["success"] is False:
|
||||
raise HTTPException(status_code=400, detail=result["message"])
|
||||
raise BusinessException(result["message"], BizCode.SERVICE_UNAVAILABLE)
|
||||
return success(data=result, msg="连接测试完成")
|
||||
except BusinessException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.orm import Session
|
||||
import uuid
|
||||
from typing import Callable
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
@@ -19,6 +20,7 @@ from app.services import user_service
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.security import verify_password
|
||||
from app.i18n.dependencies import get_translator
|
||||
|
||||
# 获取API专用日志器
|
||||
api_logger = get_api_logger()
|
||||
@@ -33,7 +35,8 @@ router = APIRouter(
|
||||
def create_superuser(
|
||||
user: user_schema.UserCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_superuser: User = Depends(get_current_superuser)
|
||||
current_superuser: User = Depends(get_current_superuser),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""创建超级管理员(仅超级管理员可访问)"""
|
||||
api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}")
|
||||
@@ -42,7 +45,7 @@ def create_superuser(
|
||||
api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="超级管理员创建成功")
|
||||
return success(data=result_schema, msg=t("users.create.superuser_success"))
|
||||
|
||||
|
||||
@router.delete("/{user_id}", response_model=ApiResponse)
|
||||
@@ -50,6 +53,7 @@ def delete_user(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""停用用户(软删除)"""
|
||||
api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
@@ -57,13 +61,14 @@ def delete_user(
|
||||
db=db, user_id_to_deactivate=user_id, current_user=current_user
|
||||
)
|
||||
api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})")
|
||||
return success(msg="用户停用成功")
|
||||
return success(msg=t("users.delete.deactivate_success"))
|
||||
|
||||
@router.post("/{user_id}/activate", response_model=ApiResponse)
|
||||
def activate_user(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""激活用户"""
|
||||
api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
@@ -74,13 +79,14 @@ def activate_user(
|
||||
api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="用户激活成功")
|
||||
return success(data=result_schema, msg=t("users.activate.success"))
|
||||
|
||||
|
||||
@router.get("", response_model=ApiResponse)
|
||||
def get_current_user_info(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前用户信息"""
|
||||
api_logger.info(f"当前用户信息请求: {current_user.username}")
|
||||
@@ -105,7 +111,7 @@ def get_current_user_info(
|
||||
break
|
||||
|
||||
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
|
||||
return success(data=result_schema, msg="用户信息获取成功")
|
||||
return success(data=result_schema, msg=t("users.info.get_success"))
|
||||
|
||||
|
||||
@router.get("/superusers", response_model=ApiResponse)
|
||||
@@ -113,6 +119,7 @@ def get_tenant_superusers(
|
||||
include_inactive: bool = False,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前租户下的超管账号列表(仅超级管理员可访问)"""
|
||||
api_logger.info(f"获取租户超管列表请求: {current_user.username}")
|
||||
@@ -125,7 +132,7 @@ def get_tenant_superusers(
|
||||
api_logger.info(f"租户超管列表获取成功: count={len(superusers)}")
|
||||
|
||||
superusers_schema = [user_schema.User.model_validate(u) for u in superusers]
|
||||
return success(data=superusers_schema, msg="租户超管列表获取成功")
|
||||
return success(data=superusers_schema, msg=t("users.list.superusers_success"))
|
||||
|
||||
|
||||
|
||||
@@ -134,6 +141,7 @@ def get_user_info_by_id(
|
||||
user_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""根据用户ID获取用户信息"""
|
||||
api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||
@@ -144,7 +152,7 @@ def get_user_info_by_id(
|
||||
api_logger.info(f"用户信息获取成功: {result.username}")
|
||||
|
||||
result_schema = user_schema.User.model_validate(result)
|
||||
return success(data=result_schema, msg="用户信息获取成功")
|
||||
return success(data=result_schema, msg=t("users.info.get_success"))
|
||||
|
||||
|
||||
@router.put("/change-password", response_model=ApiResponse)
|
||||
@@ -152,6 +160,7 @@ async def change_password(
|
||||
request: ChangePasswordRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""修改当前用户密码"""
|
||||
api_logger.info(f"用户密码修改请求: {current_user.username}")
|
||||
@@ -164,7 +173,7 @@ async def change_password(
|
||||
current_user=current_user
|
||||
)
|
||||
api_logger.info(f"用户密码修改成功: {current_user.username}")
|
||||
return success(msg="密码修改成功")
|
||||
return success(msg=t("auth.password.change_success"))
|
||||
|
||||
|
||||
@router.put("/admin/change-password", response_model=ApiResponse)
|
||||
@@ -172,6 +181,7 @@ async def admin_change_password(
|
||||
request: AdminChangePasswordRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""超级管理员修改指定用户的密码"""
|
||||
api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}")
|
||||
@@ -186,16 +196,17 @@ async def admin_change_password(
|
||||
# 根据是否生成了随机密码来构造响应
|
||||
if request.new_password:
|
||||
api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}")
|
||||
return success(msg="密码修改成功")
|
||||
return success(msg=t("auth.password.change_success"))
|
||||
else:
|
||||
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
|
||||
return success(data=generated_password, msg="密码重置成功")
|
||||
return success(data=generated_password, msg=t("auth.password.reset_success"))
|
||||
|
||||
|
||||
@router.post("/verify_pwd", response_model=ApiResponse)
|
||||
def verify_pwd(
|
||||
request: VerifyPasswordRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""验证当前用户密码"""
|
||||
api_logger.info(f"用户验证密码请求: {current_user.username}")
|
||||
@@ -203,8 +214,8 @@ def verify_pwd(
|
||||
is_valid = verify_password(request.password, current_user.hashed_password)
|
||||
api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}")
|
||||
if not is_valid:
|
||||
raise BusinessException("密码验证失败", code=BizCode.VALIDATION_FAILED)
|
||||
return success(data={"valid": is_valid}, msg="验证完成")
|
||||
raise BusinessException(t("users.errors.password_verification_failed"), code=BizCode.VALIDATION_FAILED)
|
||||
return success(data={"valid": is_valid}, msg=t("common.success.retrieved"))
|
||||
|
||||
|
||||
@router.post("/send-email-code", response_model=ApiResponse)
|
||||
@@ -212,6 +223,7 @@ async def send_email_code(
|
||||
request: SendEmailCodeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""发送邮箱验证码"""
|
||||
api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}")
|
||||
@@ -219,7 +231,7 @@ async def send_email_code(
|
||||
await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id)
|
||||
|
||||
api_logger.info(f"邮箱验证码已发送: {current_user.username}")
|
||||
return success(msg="验证码已发送到您的邮箱,请查收")
|
||||
return success(msg=t("users.email.code_sent"))
|
||||
|
||||
|
||||
@router.put("/change-email", response_model=ApiResponse)
|
||||
@@ -227,6 +239,7 @@ async def change_email(
|
||||
request: VerifyEmailCodeRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""验证验证码并修改邮箱"""
|
||||
api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}")
|
||||
@@ -239,4 +252,51 @@ async def change_email(
|
||||
)
|
||||
|
||||
api_logger.info(f"用户邮箱修改成功: {current_user.username}")
|
||||
return success(msg="邮箱修改成功")
|
||||
return success(msg=t("users.email.change_success"))
|
||||
|
||||
|
||||
|
||||
@router.get("/me/language", response_model=ApiResponse)
|
||||
def get_current_user_language(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前用户的语言偏好"""
|
||||
api_logger.info(f"获取用户语言偏好: {current_user.username}")
|
||||
|
||||
language = user_service.get_user_language_preference(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info(f"用户语言偏好获取成功: {current_user.username}, language={language}")
|
||||
return success(
|
||||
data=user_schema.LanguagePreferenceResponse(language=language),
|
||||
msg=t("users.language.get_success")
|
||||
)
|
||||
|
||||
|
||||
@router.put("/me/language", response_model=ApiResponse)
|
||||
def update_current_user_language(
|
||||
request: user_schema.LanguagePreferenceRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
"""设置当前用户的语言偏好"""
|
||||
api_logger.info(f"更新用户语言偏好: {current_user.username}, language={request.language}")
|
||||
|
||||
updated_user = user_service.update_user_language_preference(
|
||||
db=db,
|
||||
user_id=current_user.id,
|
||||
language=request.language,
|
||||
current_user=current_user
|
||||
)
|
||||
|
||||
api_logger.info(f"用户语言偏好更新成功: {current_user.username}, language={request.language}")
|
||||
return success(
|
||||
data=user_schema.LanguagePreferenceResponse(language=updated_user.preferred_language),
|
||||
msg=t("users.language.update_success")
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@ from app.services.user_memory_service import (
|
||||
UserMemoryService,
|
||||
analytics_memory_types,
|
||||
analytics_graph_data,
|
||||
analytics_community_graph_data,
|
||||
)
|
||||
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
@@ -295,6 +296,42 @@ async def get_graph_data_api(
|
||||
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/community_graph", response_model=ApiResponse)
|
||||
async def get_community_graph_data_api(
|
||||
end_user_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询社区图谱但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(
|
||||
f"社区图谱查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
result = await analytics_community_graph_data(db=db, end_user_id=end_user_id)
|
||||
|
||||
if "message" in result and result["statistics"]["total_nodes"] == 0:
|
||||
api_logger.warning(f"社区图谱查询返回空结果: {result.get('message')}")
|
||||
return success(data=result, msg=result.get("message", "查询成功"))
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取社区图谱: end_user_id={end_user_id}, "
|
||||
f"nodes={result['statistics']['total_nodes']}, "
|
||||
f"edges={result['statistics']['total_edges']}"
|
||||
)
|
||||
return success(data=result, msg="查询成功")
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "社区图谱查询失败", str(e))
|
||||
|
||||
|
||||
@router.get("/read_end_user/profile", response_model=ApiResponse)
|
||||
async def get_end_user_profile(
|
||||
end_user_id: str,
|
||||
|
||||
@@ -14,6 +14,12 @@ from app.dependencies import (
|
||||
get_current_user,
|
||||
workspace_access_guard,
|
||||
)
|
||||
from app.i18n.dependencies import get_current_language, get_translator
|
||||
from app.i18n.serializers import (
|
||||
WorkspaceSerializer,
|
||||
WorkspaceMemberSerializer,
|
||||
WorkspaceInviteSerializer
|
||||
)
|
||||
from app.models.tenant_model import Tenants
|
||||
from app.models.user_model import User
|
||||
from app.models.workspace_model import InviteStatus
|
||||
@@ -65,7 +71,9 @@ def get_workspaces(
|
||||
include_current: bool = Query(True, description="是否包含当前工作空间"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
current_tenant: Tenants = Depends(get_current_tenant)
|
||||
current_tenant: Tenants = Depends(get_current_tenant),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前租户下用户参与的所有工作空间
|
||||
|
||||
@@ -88,8 +96,13 @@ def get_workspaces(
|
||||
)
|
||||
|
||||
api_logger.info(f"成功获取 {len(workspaces)} 个工作空间")
|
||||
workspaces_schema = [WorkspaceResponse.model_validate(w) for w in workspaces]
|
||||
return success(data=workspaces_schema, msg="工作空间列表获取成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceSerializer()
|
||||
workspaces_data = [WorkspaceResponse.model_validate(w).model_dump() for w in workspaces]
|
||||
workspaces_i18n = serializer.serialize_list(workspaces_data, language)
|
||||
|
||||
return success(data=workspaces_i18n, msg=t("workspace.list_retrieved"))
|
||||
|
||||
|
||||
@router.post("", response_model=ApiResponse)
|
||||
@@ -98,6 +111,8 @@ def create_workspace(
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_superuser),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""创建新的工作空间"""
|
||||
from app.core.language_utils import get_language_from_header
|
||||
@@ -118,8 +133,13 @@ def create_workspace(
|
||||
f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, "
|
||||
f"创建者: {current_user.username}, language={language}"
|
||||
)
|
||||
result_schema = WorkspaceResponse.model_validate(result)
|
||||
return success(data=result_schema, msg="工作空间创建成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceSerializer()
|
||||
result_data = WorkspaceResponse.model_validate(result).model_dump()
|
||||
result_i18n = serializer.serialize(result_data, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.created"))
|
||||
|
||||
@router.put("", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
@@ -127,6 +147,8 @@ def update_workspace(
|
||||
workspace: WorkspaceUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""更新工作空间"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -139,14 +161,21 @@ def update_workspace(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}")
|
||||
result_schema = WorkspaceResponse.model_validate(result)
|
||||
return success(data=result_schema, msg="工作空间更新成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceSerializer()
|
||||
result_data = WorkspaceResponse.model_validate(result).model_dump()
|
||||
result_i18n = serializer.serialize(result_data, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.updated"))
|
||||
|
||||
@router.get("/members", response_model=ApiResponse)
|
||||
@cur_workspace_access_guard()
|
||||
def get_cur_workspace_members(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取工作空间成员列表(关系序列化)"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表")
|
||||
@@ -157,8 +186,14 @@ def get_cur_workspace_members(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}")
|
||||
|
||||
# 转换为表格项并使用序列化器添加国际化字段
|
||||
table_items = _convert_members_to_table_items(members)
|
||||
return success(data=table_items, msg="工作空间成员列表获取成功")
|
||||
serializer = WorkspaceMemberSerializer()
|
||||
members_data = [item.model_dump() for item in table_items]
|
||||
members_i18n = serializer.serialize_list(members_data, language)
|
||||
|
||||
return success(data=members_i18n, msg=t("workspace.members.list_retrieved"))
|
||||
|
||||
|
||||
@router.put("/members", response_model=ApiResponse)
|
||||
@@ -168,6 +203,7 @@ def update_workspace_members(
|
||||
updates: List[WorkspaceMemberUpdate],
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色")
|
||||
@@ -178,7 +214,7 @@ def update_workspace_members(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}")
|
||||
return success(msg="成员角色更新成功")
|
||||
return success(msg=t("workspace.members.role_updated"))
|
||||
|
||||
|
||||
@router.delete("/members/{member_id}", response_model=ApiResponse)
|
||||
@@ -187,6 +223,7 @@ def delete_workspace_member(
|
||||
member_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||
@@ -198,7 +235,7 @@ def delete_workspace_member(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}")
|
||||
return success(msg="成员删除成功")
|
||||
return success(msg=t("workspace.members.deleted"))
|
||||
|
||||
|
||||
# 创建空间协作邀请
|
||||
@@ -208,6 +245,8 @@ def create_workspace_invite(
|
||||
invite_data: WorkspaceInviteCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""创建工作空间邀请"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -220,7 +259,12 @@ def create_workspace_invite(
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}")
|
||||
return success(data=result, msg="邀请创建成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
result_i18n = serializer.serialize(result, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.invites.created"))
|
||||
|
||||
|
||||
@router.get("/invites", response_model=ApiResponse)
|
||||
@@ -232,6 +276,8 @@ def get_workspace_invites(
|
||||
offset: int = Query(0, ge=0),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取工作空间邀请列表"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -246,18 +292,30 @@ def get_workspace_invites(
|
||||
offset=offset
|
||||
)
|
||||
api_logger.info(f"成功获取 {len(invites)} 个邀请记录")
|
||||
return success(data=invites, msg="邀请列表获取成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
invites_i18n = serializer.serialize_list(invites, language)
|
||||
|
||||
return success(data=invites_i18n, msg=t("workspace.invites.list_retrieved"))
|
||||
|
||||
|
||||
@public_router.get("/invites/validate/{token}", response_model=ApiResponse)
|
||||
def get_workspace_invite_info(
|
||||
token: str,
|
||||
db: Session = Depends(get_db),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取工作空间邀请用户信息(无需认证)"""
|
||||
result = workspace_service.validate_invite_token(db=db, token=token)
|
||||
api_logger.info(f"工作空间邀请验证成功 - 邀请: {token}")
|
||||
return success(data=result, msg="邀请验证成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
result_i18n = serializer.serialize(result, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.invites.validated"))
|
||||
|
||||
|
||||
@router.delete("/invites/{invite_id}", response_model=ApiResponse)
|
||||
@@ -267,6 +325,8 @@ def revoke_workspace_invite(
|
||||
invite_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""撤销工作空间邀请"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -279,7 +339,12 @@ def revoke_workspace_invite(
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}")
|
||||
return success(data=result, msg="邀请撤销成功")
|
||||
|
||||
# 使用序列化器添加国际化字段
|
||||
serializer = WorkspaceInviteSerializer()
|
||||
result_i18n = serializer.serialize(result, language)
|
||||
|
||||
return success(data=result_i18n, msg=t("workspace.invites.revoked"))
|
||||
|
||||
# ==================== 公开邀请接口(无需认证) ====================
|
||||
|
||||
@@ -302,6 +367,7 @@ def switch_workspace(
|
||||
workspace_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""切换工作空间"""
|
||||
api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}")
|
||||
@@ -312,7 +378,7 @@ def switch_workspace(
|
||||
user=current_user,
|
||||
)
|
||||
api_logger.info(f"成功切换工作空间为 {workspace_id}")
|
||||
return success(msg="工作空间切换成功")
|
||||
return success(msg=t("workspace.switched"))
|
||||
|
||||
|
||||
@router.get("/storage", response_model=ApiResponse)
|
||||
@@ -320,6 +386,7 @@ def switch_workspace(
|
||||
def get_workspace_storage_type(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前工作空间的存储类型"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -331,7 +398,7 @@ def get_workspace_storage_type(
|
||||
user=current_user
|
||||
)
|
||||
api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}")
|
||||
return success(data={"storage_type": storage_type}, msg="存储类型获取成功")
|
||||
return success(data={"storage_type": storage_type}, msg=t("workspace.storage.type_retrieved"))
|
||||
|
||||
|
||||
@router.get("/workspace_models", response_model=ApiResponse)
|
||||
@@ -339,6 +406,8 @@ def get_workspace_storage_type(
|
||||
def workspace_models_configs(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
language: str = Depends(get_current_language),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""获取当前工作空间的模型配置(llm, embedding, rerank)"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -354,14 +423,14 @@ def workspace_models_configs(
|
||||
api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="工作空间不存在或无权访问"
|
||||
detail=t("workspace.not_found")
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取工作空间 {workspace_id} 的模型配置: "
|
||||
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
||||
)
|
||||
return success(data=WorkspaceModelsConfig.model_validate(configs), msg="模型配置获取成功")
|
||||
return success(data=WorkspaceModelsConfig.model_validate(configs), msg=t("workspace.models.config_retrieved"))
|
||||
|
||||
|
||||
@router.put("/workspace_models", response_model=ApiResponse)
|
||||
@@ -370,6 +439,7 @@ def update_workspace_models_configs(
|
||||
models_update: WorkspaceModelsUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
t: callable = Depends(get_translator)
|
||||
):
|
||||
"""更新当前工作空间的模型配置(llm, embedding, rerank)"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
@@ -386,5 +456,5 @@ def update_workspace_models_configs(
|
||||
f"成功更新工作空间 {workspace_id} 的模型配置: "
|
||||
f"llm={updated_workspace.llm}, embedding={updated_workspace.embedding}, rerank={updated_workspace.rerank}"
|
||||
)
|
||||
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg="模型配置更新成功")
|
||||
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg=t("workspace.models.config_updated"))
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, Dict, Optional
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field, TypeAdapter
|
||||
@@ -115,6 +114,7 @@ class Settings:
|
||||
S3_ACCESS_KEY_ID: str = os.getenv("S3_ACCESS_KEY_ID", "")
|
||||
S3_SECRET_ACCESS_KEY: str = os.getenv("S3_SECRET_ACCESS_KEY", "")
|
||||
S3_BUCKET_NAME: str = os.getenv("S3_BUCKET_NAME", "")
|
||||
S3_ENDPOINT_URL: str = os.getenv("S3_ENDPOINT_URL", "")
|
||||
|
||||
# VOLC ASR settings
|
||||
VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "")
|
||||
@@ -162,6 +162,44 @@ class Settings:
|
||||
# This controls the language used for memory summary titles and other generated content
|
||||
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
|
||||
|
||||
# ========================================================================
|
||||
# Internationalization (i18n) Configuration
|
||||
# ========================================================================
|
||||
# Default language for API responses
|
||||
I18N_DEFAULT_LANGUAGE: str = os.getenv("I18N_DEFAULT_LANGUAGE", "zh")
|
||||
|
||||
# Supported languages (comma-separated)
|
||||
I18N_SUPPORTED_LANGUAGES: list[str] = [
|
||||
lang.strip()
|
||||
for lang in os.getenv("I18N_SUPPORTED_LANGUAGES", "zh,en").split(",")
|
||||
if lang.strip()
|
||||
]
|
||||
|
||||
# Core locales directory (community edition)
|
||||
# Use absolute path to work from any working directory
|
||||
I18N_CORE_LOCALES_DIR: str = os.getenv(
|
||||
"I18N_CORE_LOCALES_DIR",
|
||||
os.path.join(os.path.dirname(os.path.dirname(__file__)), "locales")
|
||||
)
|
||||
|
||||
# Premium locales directory (enterprise edition, optional)
|
||||
I18N_PREMIUM_LOCALES_DIR: Optional[str] = os.getenv("I18N_PREMIUM_LOCALES_DIR", None)
|
||||
|
||||
# Enable translation cache
|
||||
I18N_ENABLE_TRANSLATION_CACHE: bool = os.getenv("I18N_ENABLE_TRANSLATION_CACHE", "true").lower() == "true"
|
||||
|
||||
# LRU cache size for hot translations
|
||||
I18N_LRU_CACHE_SIZE: int = int(os.getenv("I18N_LRU_CACHE_SIZE", "1000"))
|
||||
|
||||
# Enable hot reload of translation files
|
||||
I18N_ENABLE_HOT_RELOAD: bool = os.getenv("I18N_ENABLE_HOT_RELOAD", "false").lower() == "true"
|
||||
|
||||
# Fallback language when translation is missing
|
||||
I18N_FALLBACK_LANGUAGE: str = os.getenv("I18N_FALLBACK_LANGUAGE", "zh")
|
||||
|
||||
# Log missing translations
|
||||
I18N_LOG_MISSING_TRANSLATIONS: bool = os.getenv("I18N_LOG_MISSING_TRANSLATIONS", "true").lower() == "true"
|
||||
|
||||
# Logging settings
|
||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
|
||||
@@ -1,16 +1,45 @@
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
|
||||
from app.schemas.memory_agent_schema import AgentMemoryDataset
|
||||
|
||||
|
||||
def content_input_node(state: ReadState) -> ReadState:
|
||||
"""开始节点 - 提取内容并保持状态信息"""
|
||||
"""
|
||||
Start node - Extract content and maintain state information
|
||||
|
||||
Extracts the content from the first message in the state and returns it
|
||||
as the data field while preserving all other state information.
|
||||
|
||||
Args:
|
||||
state: ReadState containing messages and other state data
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with extracted content in data field
|
||||
"""
|
||||
|
||||
content = state['messages'][0].content if state.get('messages') else ''
|
||||
# 返回内容并保持所有状态信息
|
||||
# Return content and maintain all state information
|
||||
for pronoun in AgentMemoryDataset.PRONOUN:
|
||||
content = content.replace(pronoun, AgentMemoryDataset.NAME)
|
||||
|
||||
return {"data": content}
|
||||
|
||||
|
||||
def content_input_write(state: WriteState) -> WriteState:
|
||||
"""开始节点 - 提取内容并保持状态信息"""
|
||||
"""
|
||||
Start node - Extract content and maintain state information for write operations
|
||||
|
||||
Extracts the content from the first message in the state for write operations.
|
||||
|
||||
Args:
|
||||
state: WriteState containing messages and other state data
|
||||
|
||||
Returns:
|
||||
WriteState: Updated state with extracted content in data field
|
||||
"""
|
||||
|
||||
content = state['messages'][0].content if state.get('messages') else ''
|
||||
# 返回内容并保持所有状态信息
|
||||
return {"data": content}
|
||||
# Return content and maintain all state information
|
||||
for pronoun in AgentMemoryDataset.PRONOUN:
|
||||
content = content.replace(pronoun, AgentMemoryDataset.NAME)
|
||||
|
||||
return {"data": content}
|
||||
|
||||
@@ -19,19 +19,39 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class ProblemNodeService(LLMServiceMixin):
|
||||
"""问题处理节点服务类"""
|
||||
"""
|
||||
Problem processing node service class
|
||||
|
||||
Handles problem decomposition and extension operations using LLM services.
|
||||
Inherits from LLMServiceMixin to provide structured LLM calling capabilities.
|
||||
|
||||
Attributes:
|
||||
template_service: Service for rendering Jinja2 templates
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
# Create global service instance
|
||||
problem_service = ProblemNodeService()
|
||||
|
||||
|
||||
async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
"""问题分解节点"""
|
||||
"""
|
||||
Problem decomposition node
|
||||
|
||||
Breaks down complex user queries into smaller, more manageable sub-problems.
|
||||
Uses LLM to analyze the input and generate structured problem decomposition
|
||||
with question types and reasoning.
|
||||
|
||||
Args:
|
||||
state: ReadState containing user input and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with problem decomposition results
|
||||
"""
|
||||
# 从状态中获取数据
|
||||
content = state.get('data', '')
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
@@ -64,7 +84,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
# 添加更详细的日志记录
|
||||
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
||||
|
||||
# 验证结构化响应
|
||||
# Validate structured response
|
||||
if not structured or not hasattr(structured, 'root'):
|
||||
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
@@ -106,7 +126,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# 提供更详细的错误信息
|
||||
# Provide more detailed error information
|
||||
error_details = {
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
@@ -116,7 +136,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
|
||||
logger.error(f"Split_The_Problem error details: {error_details}")
|
||||
|
||||
# 创建默认的空结果
|
||||
# Create default empty result
|
||||
result = {
|
||||
"context": json.dumps([], ensure_ascii=False),
|
||||
"original": content,
|
||||
@@ -130,13 +150,25 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
# 返回更新后的状态,包含spit_context字段
|
||||
# Return updated state including spit_context field
|
||||
return {"spit_data": result}
|
||||
|
||||
|
||||
async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
"""问题扩展节点"""
|
||||
# 获取原始数据和分解结果
|
||||
"""
|
||||
Problem extension node
|
||||
|
||||
Extends the decomposed problems from Split_The_Problem node by generating
|
||||
additional related questions and organizing them by original question.
|
||||
Uses LLM to create comprehensive question extensions for better memory retrieval.
|
||||
|
||||
Args:
|
||||
state: ReadState containing decomposed problems and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with extended problem results
|
||||
"""
|
||||
# Get original data and decomposition results
|
||||
start = time.time()
|
||||
content = state.get('data', '')
|
||||
data = state.get('spit_data', '')['context']
|
||||
@@ -182,7 +214,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
|
||||
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
||||
|
||||
# 验证结构化响应
|
||||
# Validate structured response
|
||||
if not response_content or not hasattr(response_content, 'root'):
|
||||
logger.warning("Problem_Extension: 结构化响应为空或格式不正确")
|
||||
aggregated_dict = {}
|
||||
@@ -216,7 +248,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# 提供更详细的错误信息
|
||||
# Provide more detailed error information
|
||||
error_details = {
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
|
||||
@@ -29,6 +29,18 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def rag_config(state):
|
||||
"""
|
||||
Configure RAG (Retrieval-Augmented Generation) settings
|
||||
|
||||
Creates configuration for knowledge base retrieval including similarity thresholds,
|
||||
weights, and reranker settings.
|
||||
|
||||
Args:
|
||||
state: Current state containing user_rag_memory_id
|
||||
|
||||
Returns:
|
||||
dict: RAG configuration dictionary
|
||||
"""
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
kb_config = {
|
||||
"knowledge_bases": [
|
||||
@@ -48,6 +60,19 @@ async def rag_config(state):
|
||||
|
||||
|
||||
async def rag_knowledge(state, question):
|
||||
"""
|
||||
Retrieve knowledge using RAG approach
|
||||
|
||||
Performs knowledge retrieval from configured knowledge bases using the
|
||||
provided question and returns formatted results.
|
||||
|
||||
Args:
|
||||
state: Current state containing configuration
|
||||
question: Question to search for
|
||||
|
||||
Returns:
|
||||
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
|
||||
"""
|
||||
kb_config = await rag_config(state)
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
@@ -68,12 +93,24 @@ async def rag_knowledge(state, question):
|
||||
|
||||
|
||||
async def llm_infomation(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Get LLM configuration information from state
|
||||
|
||||
Retrieves model configuration details including model ID and tenant ID
|
||||
from the memory configuration in the current state.
|
||||
|
||||
Args:
|
||||
state: ReadState containing memory configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Model configuration as Pydantic model
|
||||
"""
|
||||
memory_config = state.get('memory_config', None)
|
||||
model_id = memory_config.llm_model_id
|
||||
tenant_id = memory_config.tenant_id
|
||||
|
||||
# 使用现有的 memory_config 而不是重新查询数据库
|
||||
# 或者使用线程安全的数据库访问
|
||||
# Use existing memory_config instead of re-querying database
|
||||
# or use thread-safe database access
|
||||
with get_db_context() as db:
|
||||
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id)
|
||||
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
|
||||
@@ -82,16 +119,20 @@ async def llm_infomation(state: ReadState) -> ReadState:
|
||||
|
||||
async def clean_databases(data) -> str:
|
||||
"""
|
||||
简化的数据库搜索结果清理函数
|
||||
Simplified database search result cleaning function
|
||||
|
||||
Processes and cleans search results from various sources including
|
||||
reranked results and time-based search results. Extracts text content
|
||||
from structured data and returns as formatted string.
|
||||
|
||||
Args:
|
||||
data: 搜索结果数据
|
||||
data: Search result data (can be string, dict, or other types)
|
||||
|
||||
Returns:
|
||||
清理后的内容字符串
|
||||
str: Cleaned content string
|
||||
"""
|
||||
try:
|
||||
# 解析JSON字符串
|
||||
# Parse JSON string
|
||||
if isinstance(data, str):
|
||||
try:
|
||||
data = json.loads(data)
|
||||
@@ -101,24 +142,24 @@ async def clean_databases(data) -> str:
|
||||
if not isinstance(data, dict):
|
||||
return str(data)
|
||||
|
||||
# 获取结果数据
|
||||
# Get result data
|
||||
# with open("搜索结果.json","w",encoding='utf-8') as f:
|
||||
# f.write(json.dumps(data, indent=4, ensure_ascii=False))
|
||||
results = data.get('results', data)
|
||||
if not isinstance(results, dict):
|
||||
return str(results)
|
||||
|
||||
# 收集所有内容
|
||||
# Collect all content
|
||||
content_list = []
|
||||
|
||||
# 处理重排序结果
|
||||
# Process reranked results
|
||||
reranked = results.get('reranked_results', {})
|
||||
if reranked:
|
||||
for category in ['summaries', 'statements', 'chunks', 'entities']:
|
||||
for category in ['summaries', 'communities', 'statements', 'chunks', 'entities']:
|
||||
items = reranked.get(category, [])
|
||||
if isinstance(items, list):
|
||||
content_list.extend(items)
|
||||
# 处理时间搜索结果
|
||||
# Process time search results
|
||||
time_search = results.get('time_search', {})
|
||||
if time_search:
|
||||
if isinstance(time_search, dict):
|
||||
@@ -128,11 +169,18 @@ async def clean_databases(data) -> str:
|
||||
elif isinstance(time_search, list):
|
||||
content_list.extend(time_search)
|
||||
|
||||
# 提取文本内容
|
||||
# Extract text content,对 community 按 name 去重(多次 tool 调用会产生重复)
|
||||
text_parts = []
|
||||
seen_community_names = set()
|
||||
for item in content_list:
|
||||
if isinstance(item, dict):
|
||||
text = item.get('statement') or item.get('content', '')
|
||||
# community 节点用 name 去重
|
||||
if 'member_count' in item or 'core_entities' in item:
|
||||
community_name = item.get('name') or item.get('id', '')
|
||||
if community_name in seen_community_names:
|
||||
continue
|
||||
seen_community_names.add(community_name)
|
||||
text = item.get('statement') or item.get('content') or item.get('summary', '')
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
elif isinstance(item, str):
|
||||
@@ -146,10 +194,19 @@ async def clean_databases(data) -> str:
|
||||
|
||||
|
||||
async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
'''
|
||||
|
||||
模型信息
|
||||
'''
|
||||
"""
|
||||
Retrieve information using simplified search approach
|
||||
|
||||
Processes extended problems from previous nodes and performs retrieval
|
||||
using either RAG or hybrid search based on storage type. Handles concurrent
|
||||
processing of multiple questions and deduplicates results.
|
||||
|
||||
Args:
|
||||
state: ReadState containing problem extensions and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with retrieval results and intermediate outputs
|
||||
"""
|
||||
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
storage_type = state.get('storage_type', '')
|
||||
@@ -163,7 +220,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
problem_list.append(data)
|
||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
|
||||
# 创建异步任务处理单个问题
|
||||
# Create async task to process individual questions
|
||||
async def process_question_nodes(idx, question):
|
||||
try:
|
||||
# Prepare search parameters based on storage type
|
||||
@@ -209,7 +266,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
# 并发处理所有问题
|
||||
# Process all questions concurrently
|
||||
tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)]
|
||||
databases_anser = await asyncio.gather(*tasks)
|
||||
databases_data = {
|
||||
@@ -257,7 +314,20 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
|
||||
|
||||
async def retrieve(state: ReadState) -> ReadState:
|
||||
# 从state中获取end_user_id
|
||||
"""
|
||||
Advanced retrieve function using LangChain agents and tools
|
||||
|
||||
Uses LangChain agents with specialized retrieval tools (time-based and hybrid)
|
||||
to perform sophisticated information retrieval. Supports both RAG and traditional
|
||||
memory storage approaches with concurrent processing and result deduplication.
|
||||
|
||||
Args:
|
||||
state: ReadState containing problem extensions and configuration
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state with retrieval results and intermediate outputs
|
||||
"""
|
||||
# Get end_user_id from state
|
||||
import time
|
||||
start = time.time()
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
@@ -291,7 +361,11 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
)
|
||||
|
||||
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
||||
search_params = {"end_user_id": end_user_id, "return_raw_results": True}
|
||||
search_params = {
|
||||
"end_user_id": end_user_id,
|
||||
"return_raw_results": True,
|
||||
"include": ["summaries", "statements", "chunks", "entities", "communities"],
|
||||
}
|
||||
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||
agent = create_agent(
|
||||
llm,
|
||||
@@ -299,21 +373,21 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
||||
)
|
||||
|
||||
# 创建异步任务处理单个问题
|
||||
# Create async task to process individual questions
|
||||
import asyncio
|
||||
|
||||
# 在模块级别定义信号量,限制最大并发数
|
||||
SEMAPHORE = asyncio.Semaphore(5) # 限制最多5个并发数据库操作
|
||||
# Define semaphore at module level to limit maximum concurrency
|
||||
SEMAPHORE = asyncio.Semaphore(5) # Limit to maximum 5 concurrent database operations
|
||||
|
||||
async def process_question(idx, question):
|
||||
async with SEMAPHORE: # 限制并发
|
||||
async with SEMAPHORE: # Limit concurrency
|
||||
try:
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
|
||||
question)
|
||||
else:
|
||||
cleaned_query = question
|
||||
# 使用 asyncio 在线程池中运行同步的 agent.invoke
|
||||
# Use asyncio to run synchronous agent.invoke in thread pool
|
||||
import asyncio
|
||||
response = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
@@ -327,8 +401,32 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
raw_results = tool_results['content']
|
||||
clean_content = await clean_databases(raw_results)
|
||||
|
||||
# 社区展开:从 tool 返回结果中提取命中的 community,
|
||||
# 沿 BELONGS_TO_COMMUNITY 关系拉取关联 Statement 追加到 clean_content
|
||||
_expanded_stmts_to_write = []
|
||||
try:
|
||||
results_dict = raw_results.get('results', {}) if isinstance(raw_results, dict) else {}
|
||||
reranked = results_dict.get('reranked_results', {})
|
||||
community_hits = reranked.get('communities', [])
|
||||
if not community_hits:
|
||||
community_hits = results_dict.get('communities', [])
|
||||
if community_hits:
|
||||
from app.core.memory.agent.services.search_service import expand_communities_to_statements
|
||||
_expanded_stmts_to_write, new_texts = await expand_communities_to_statements(
|
||||
community_results=community_hits,
|
||||
end_user_id=end_user_id,
|
||||
existing_content=clean_content,
|
||||
)
|
||||
if new_texts:
|
||||
clean_content = clean_content + '\n' + '\n'.join(new_texts)
|
||||
except Exception as parse_err:
|
||||
logger.warning(f"[Retrieve] 解析社区命中结果失败,跳过展开: {parse_err}")
|
||||
|
||||
try:
|
||||
raw_results = raw_results['results']
|
||||
# 写回展开结果,接口返回中可见(已在 helper 中清洗过字段)
|
||||
if _expanded_stmts_to_write and isinstance(raw_results, dict):
|
||||
raw_results.setdefault('reranked_results', {})['expanded_statements'] = _expanded_stmts_to_write
|
||||
except Exception:
|
||||
raw_results = []
|
||||
|
||||
@@ -362,7 +460,7 @@ async def retrieve(state: ReadState) -> ReadState:
|
||||
}
|
||||
}
|
||||
|
||||
# 并发处理所有问题
|
||||
# Process all questions concurrently
|
||||
import asyncio
|
||||
tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)]
|
||||
databases_anser = await asyncio.gather(*tasks)
|
||||
|
||||
@@ -23,18 +23,39 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class SummaryNodeService(LLMServiceMixin):
|
||||
"""总结节点服务类"""
|
||||
"""
|
||||
Summary node service class
|
||||
|
||||
Handles summary generation operations using LLM services. Inherits from
|
||||
LLMServiceMixin to provide structured LLM calling capabilities for
|
||||
generating summaries from retrieved information.
|
||||
|
||||
Attributes:
|
||||
template_service: Service for rendering Jinja2 templates
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
# Create global service instance
|
||||
summary_service = SummaryNodeService()
|
||||
|
||||
|
||||
async def rag_config(state):
|
||||
"""
|
||||
Configure RAG (Retrieval-Augmented Generation) settings for summary operations
|
||||
|
||||
Creates configuration for knowledge base retrieval including similarity thresholds,
|
||||
weights, and reranker settings specifically for summary generation.
|
||||
|
||||
Args:
|
||||
state: Current state containing user_rag_memory_id
|
||||
|
||||
Returns:
|
||||
dict: RAG configuration dictionary with knowledge base settings
|
||||
"""
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
kb_config = {
|
||||
"knowledge_bases": [
|
||||
@@ -54,6 +75,23 @@ async def rag_config(state):
|
||||
|
||||
|
||||
async def rag_knowledge(state, question):
|
||||
"""
|
||||
Retrieve knowledge using RAG approach for summary generation
|
||||
|
||||
Performs knowledge retrieval from configured knowledge bases using the
|
||||
provided question and returns formatted results for summary processing.
|
||||
|
||||
Args:
|
||||
state: Current state containing configuration
|
||||
question: Question to search for in knowledge base
|
||||
|
||||
Returns:
|
||||
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
|
||||
- retrieval_knowledge: List of retrieved knowledge chunks
|
||||
- clean_content: Formatted content string
|
||||
- cleaned_query: Processed query string
|
||||
- raw_results: Raw retrieval results
|
||||
"""
|
||||
kb_config = await rag_config(state)
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
@@ -74,6 +112,18 @@ async def rag_knowledge(state, question):
|
||||
|
||||
|
||||
async def summary_history(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Retrieve conversation history for summary context
|
||||
|
||||
Gets the conversation history for the current user to provide context
|
||||
for summary generation operations.
|
||||
|
||||
Args:
|
||||
state: ReadState containing end_user_id
|
||||
|
||||
Returns:
|
||||
ReadState: Conversation history data
|
||||
"""
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||
return history
|
||||
@@ -82,11 +132,26 @@ async def summary_history(state: ReadState) -> ReadState:
|
||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
|
||||
search_mode) -> str:
|
||||
"""
|
||||
增强的summary_llm函数,包含更好的错误处理和数据验证
|
||||
Enhanced summary_llm function with better error handling and data validation
|
||||
|
||||
Generates summaries using LLM with structured output. Includes fallback mechanisms
|
||||
for handling LLM failures and provides robust error recovery.
|
||||
|
||||
Args:
|
||||
state: ReadState containing current context
|
||||
history: Conversation history for context
|
||||
retrieve_info: Retrieved information to summarize
|
||||
template_name: Jinja2 template name for prompt generation
|
||||
operation_name: Type of operation (summary, input_summary, retrieve_summary)
|
||||
response_model: Pydantic model for structured output
|
||||
search_mode: Search mode flag ("0" for simple, "1" for complex)
|
||||
|
||||
Returns:
|
||||
str: Generated summary text or fallback message
|
||||
"""
|
||||
data = state.get("data", '')
|
||||
|
||||
# 构建系统提示词
|
||||
# Build system prompt
|
||||
if str(search_mode) == "0":
|
||||
system_prompt = await summary_service.template_service.render_template(
|
||||
template_name=template_name,
|
||||
@@ -103,7 +168,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
retrieve_info=retrieve_info
|
||||
)
|
||||
try:
|
||||
# 使用优化的LLM服务进行结构化输出
|
||||
# Use optimized LLM service for structured output
|
||||
with get_db_context() as db_session:
|
||||
structured = await summary_service.call_llm_structured(
|
||||
state=state,
|
||||
@@ -112,23 +177,23 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
response_model=response_model,
|
||||
fallback_value=None
|
||||
)
|
||||
# 验证结构化响应
|
||||
# Validate structured response
|
||||
if structured is None:
|
||||
logger.warning("LLM返回None,使用默认回答")
|
||||
return "信息不足,无法回答"
|
||||
|
||||
# 根据操作类型提取答案
|
||||
# Extract answer based on operation type
|
||||
if operation_name == "summary":
|
||||
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
|
||||
else:
|
||||
# 处理RetrieveSummaryResponse
|
||||
# Handle RetrieveSummaryResponse
|
||||
if hasattr(structured, 'data') and structured.data:
|
||||
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
|
||||
else:
|
||||
logger.warning("结构化响应缺少data字段")
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
# 验证答案不为空
|
||||
# Validate answer is not empty
|
||||
if not aimessages or aimessages.strip() == "":
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
@@ -137,7 +202,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
except Exception as e:
|
||||
logger.error(f"结构化输出失败: {e}", exc_info=True)
|
||||
|
||||
# 尝试非结构化输出作为fallback
|
||||
# Try unstructured output as fallback
|
||||
try:
|
||||
logger.info("尝试非结构化输出作为fallback")
|
||||
response = await summary_service.call_llm_simple(
|
||||
@@ -148,9 +213,9 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
)
|
||||
|
||||
if response and response.strip():
|
||||
# 简单清理响应
|
||||
# Simple response cleaning
|
||||
cleaned_response = response.strip()
|
||||
# 移除可能的JSON标记
|
||||
# Remove possible JSON markers
|
||||
if cleaned_response.startswith('```'):
|
||||
lines = cleaned_response.split('\n')
|
||||
cleaned_response = '\n'.join(lines[1:-1])
|
||||
@@ -165,6 +230,19 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
||||
|
||||
|
||||
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
||||
"""
|
||||
Save summary results to Redis session storage
|
||||
|
||||
Stores the generated summary and user query in Redis for session management
|
||||
and conversation history tracking.
|
||||
|
||||
Args:
|
||||
state: ReadState containing user and query information
|
||||
aimessages: Generated summary message to save
|
||||
|
||||
Returns:
|
||||
ReadState: Updated state after saving to Redis
|
||||
"""
|
||||
data = state.get("data", '')
|
||||
end_user_id = state.get("end_user_id", '')
|
||||
await SessionService(store).save_session(
|
||||
@@ -179,6 +257,20 @@ async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
||||
|
||||
|
||||
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
|
||||
"""
|
||||
Format summary results for different output types
|
||||
|
||||
Creates structured output formats for both input summary and retrieval summary
|
||||
operations, including metadata and intermediate results for frontend display.
|
||||
|
||||
Args:
|
||||
state: ReadState containing storage and user information
|
||||
aimessages: Generated summary message
|
||||
raw_results: Raw search/retrieval results
|
||||
|
||||
Returns:
|
||||
tuple: (input_summary, retrieve_summary) formatted result dictionaries
|
||||
"""
|
||||
storage_type = state.get("storage_type", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
data = state.get("data", '')
|
||||
@@ -217,6 +309,19 @@ async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState
|
||||
|
||||
|
||||
async def Input_Summary(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate quick input summary from retrieved information
|
||||
|
||||
Performs fast retrieval and generates a quick summary response for user queries.
|
||||
This function prioritizes speed by only searching summary nodes and provides
|
||||
immediate feedback to users.
|
||||
|
||||
Args:
|
||||
state: ReadState containing user query, storage configuration, and context
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing summary results with status and metadata
|
||||
"""
|
||||
start = time.time()
|
||||
storage_type = state.get("storage_type", '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
@@ -229,13 +334,22 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
"end_user_id": end_user_id,
|
||||
"question": data,
|
||||
"return_raw_results": True,
|
||||
"include": ["summaries"] # Only search summary nodes for faster performance
|
||||
"include": ["summaries", "communities"] # MemorySummary 和 Community 同为高维度概括节点
|
||||
}
|
||||
|
||||
try:
|
||||
if storage_type != "rag":
|
||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
|
||||
memory_config=memory_config)
|
||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(
|
||||
**search_params,
|
||||
memory_config=memory_config,
|
||||
expand_communities=False, # 路径 "2" 只需要 community 的 summary 文本,不展开到 Statement
|
||||
)
|
||||
# 调试:打印 community 检索结果数量
|
||||
if raw_results and isinstance(raw_results, dict):
|
||||
reranked = raw_results.get('reranked_results', {})
|
||||
community_hits = reranked.get('communities', [])
|
||||
logger.debug(f"[Input_Summary] community 命中数: {len(community_hits)}, "
|
||||
f"summary 命中数: {len(reranked.get('summaries', []))}")
|
||||
else:
|
||||
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
||||
except Exception as e:
|
||||
@@ -266,6 +380,19 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
||||
|
||||
|
||||
async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate comprehensive summary from retrieved expansion issues
|
||||
|
||||
Processes retrieved expansion issues and generates a detailed summary using LLM.
|
||||
This function handles complex retrieval results and provides comprehensive answers
|
||||
based on expanded query results.
|
||||
|
||||
Args:
|
||||
state: ReadState containing retrieve data with expansion issues
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing comprehensive summary results
|
||||
"""
|
||||
retrieve = state.get("retrieve", '')
|
||||
history = await summary_history(state)
|
||||
import json
|
||||
@@ -299,13 +426,26 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||
duration = 0.0
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
# 修复协程调用 - 先await,然后访问返回值
|
||||
# Fixed coroutine call - await first, then access return value
|
||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||
summary = summary_result[1]
|
||||
return {"summary": summary}
|
||||
|
||||
|
||||
async def Summary(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate final comprehensive summary from verified data
|
||||
|
||||
Creates the final summary using verified expansion issues and conversation history.
|
||||
This function processes verified data to generate the most comprehensive and
|
||||
accurate response to user queries.
|
||||
|
||||
Args:
|
||||
state: ReadState containing verified data and query information
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing final summary results
|
||||
"""
|
||||
start = time.time()
|
||||
query = state.get("data", '')
|
||||
verify = state.get("verify", '')
|
||||
@@ -336,13 +476,26 @@ async def Summary(state: ReadState) -> ReadState:
|
||||
duration = 0.0
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
# 修复协程调用 - 先await,然后访问返回值
|
||||
# Fixed coroutine call - await first, then access return value
|
||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||
summary = summary_result[1]
|
||||
return {"summary": summary}
|
||||
|
||||
|
||||
async def Summary_fails(state: ReadState) -> ReadState:
|
||||
"""
|
||||
Generate fallback summary when normal summary process fails
|
||||
|
||||
Provides a fallback summary generation mechanism when the standard summary
|
||||
process encounters errors or fails to produce satisfactory results. Uses
|
||||
a specialized failure template to handle edge cases.
|
||||
|
||||
Args:
|
||||
state: ReadState containing verified data and failure context
|
||||
|
||||
Returns:
|
||||
ReadState: Dictionary containing fallback summary results
|
||||
"""
|
||||
storage_type = state.get("storage_type", '')
|
||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||
history = await summary_history(state)
|
||||
|
||||
@@ -18,24 +18,46 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class VerificationNodeService(LLMServiceMixin):
|
||||
"""验证节点服务类"""
|
||||
"""
|
||||
Verification node service class
|
||||
|
||||
Handles data verification operations using LLM services. Inherits from
|
||||
LLMServiceMixin to provide structured LLM calling capabilities for
|
||||
verifying and validating retrieved information.
|
||||
|
||||
Attributes:
|
||||
template_service: Service for rendering Jinja2 templates
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
# Create global service instance
|
||||
verification_service = VerificationNodeService()
|
||||
|
||||
|
||||
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||
"""处理验证结果并生成输出格式"""
|
||||
"""
|
||||
Process verification results and generate output format
|
||||
|
||||
Transforms VerificationResult objects into structured output format suitable
|
||||
for frontend consumption. Handles conversion of VerificationItem objects to
|
||||
dictionary format and adds metadata for tracking.
|
||||
|
||||
Args:
|
||||
state: ReadState containing storage and user configuration
|
||||
messages_deal: VerificationResult containing verification outcomes
|
||||
|
||||
Returns:
|
||||
dict: Formatted verification result with status and metadata
|
||||
"""
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
data = state.get('data', '')
|
||||
|
||||
# 将 VerificationItem 对象转换为字典列表
|
||||
# Convert VerificationItem objects to dictionary list
|
||||
verified_data = []
|
||||
if messages_deal.expansion_issue:
|
||||
for item in messages_deal.expansion_issue:
|
||||
@@ -89,7 +111,7 @@ async def Verify(state: ReadState):
|
||||
|
||||
logger.info("Verify: 开始渲染模板")
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
# Generate JSON schema to guide LLM output format
|
||||
json_schema = VerificationResult.model_json_schema()
|
||||
|
||||
system_prompt = await verification_service.template_service.render_template(
|
||||
@@ -104,8 +126,8 @@ async def Verify(state: ReadState):
|
||||
# 使用优化的LLM服务,添加超时保护
|
||||
logger.info("Verify: 开始调用 LLM")
|
||||
try:
|
||||
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
||||
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
||||
# Add asyncio.wait_for timeout wrapper to prevent infinite waiting
|
||||
# Timeout set to 150 seconds (slightly longer than LLM config's 120 seconds)
|
||||
|
||||
with get_db_context() as db_session:
|
||||
structured = await asyncio.wait_for(
|
||||
@@ -122,7 +144,7 @@ async def Verify(state: ReadState):
|
||||
"reason": "验证失败或超时"
|
||||
}
|
||||
),
|
||||
timeout=150.0 # 150秒超时
|
||||
timeout=150.0 # 150 second timeout
|
||||
)
|
||||
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
||||
except asyncio.TimeoutError:
|
||||
|
||||
@@ -33,7 +33,19 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_read_graph():
|
||||
"""创建并返回 LangGraph 工作流"""
|
||||
"""
|
||||
Create and return a LangGraph workflow for memory reading operations
|
||||
|
||||
Builds a state graph workflow that handles memory retrieval, problem analysis,
|
||||
verification, and summarization. The workflow includes nodes for content input,
|
||||
problem splitting, retrieval, verification, and various summary operations.
|
||||
|
||||
Yields:
|
||||
StateGraph: Compiled LangGraph workflow for memory reading
|
||||
|
||||
Raises:
|
||||
Exception: If workflow creation fails
|
||||
"""
|
||||
try:
|
||||
# Build workflow graph
|
||||
workflow = StateGraph(ReadState)
|
||||
@@ -48,7 +60,7 @@ async def make_read_graph():
|
||||
workflow.add_node("Summary", Summary)
|
||||
workflow.add_node("Summary_fails", Summary_fails)
|
||||
|
||||
# 添加边
|
||||
# Add edges to define workflow flow
|
||||
workflow.add_edge(START, "content_input")
|
||||
workflow.add_conditional_edges("content_input", Split_continue)
|
||||
workflow.add_edge("Input_Summary", END)
|
||||
@@ -63,7 +75,7 @@ async def make_read_graph():
|
||||
'''-----'''
|
||||
# workflow.add_edge("Retrieve", END)
|
||||
|
||||
# 编译工作流
|
||||
# Compile workflow
|
||||
graph = workflow.compile()
|
||||
yield graph
|
||||
|
||||
@@ -72,108 +84,3 @@ async def make_read_graph():
|
||||
raise
|
||||
finally:
|
||||
print("工作流创建完成")
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "昨天有什么好看的电影"
|
||||
end_user_id = '88a459f5_text09' # 组ID
|
||||
storage_type = 'neo4j' # 存储类型
|
||||
search_switch = '1' # 搜索开关
|
||||
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
|
||||
|
||||
# 获取数据库会话
|
||||
db_session = next(get_db())
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=17, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
import time
|
||||
start = time.time()
|
||||
try:
|
||||
async with make_read_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
|
||||
"end_user_id": end_user_id
|
||||
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
|
||||
"memory_config": memory_config}
|
||||
# 获取节点更新信息
|
||||
_intermediate_outputs = []
|
||||
summary = ''
|
||||
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
print(f"处理节点: {node_name}")
|
||||
|
||||
# 处理不同Summary节点的返回结构
|
||||
if 'Summary' in node_name:
|
||||
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
|
||||
summary = node_data['InputSummary']['summary_result']
|
||||
elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
|
||||
summary = node_data['RetrieveSummary']['summary_result']
|
||||
elif 'summary' in node_data and 'summary_result' in node_data['summary']:
|
||||
summary = node_data['summary']['summary_result']
|
||||
elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
|
||||
summary = node_data['SummaryFails']['summary_result']
|
||||
|
||||
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
|
||||
if spit_data and spit_data != [] and spit_data != {}:
|
||||
_intermediate_outputs.append(spit_data)
|
||||
|
||||
# Problem_Extension 节点
|
||||
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
|
||||
if problem_extension and problem_extension != [] and problem_extension != {}:
|
||||
_intermediate_outputs.append(problem_extension)
|
||||
|
||||
# Retrieve 节点
|
||||
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
|
||||
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
|
||||
_intermediate_outputs.extend(retrieve_node)
|
||||
|
||||
# Verify 节点
|
||||
verify_n = node_data.get('verify', {}).get('_intermediate', None)
|
||||
if verify_n and verify_n != [] and verify_n != {}:
|
||||
_intermediate_outputs.append(verify_n)
|
||||
|
||||
# Summary 节点
|
||||
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
||||
if summary_n and summary_n != [] and summary_n != {}:
|
||||
_intermediate_outputs.append(summary_n)
|
||||
|
||||
# # 过滤掉空值
|
||||
# _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
|
||||
#
|
||||
# # 优化搜索结果
|
||||
# print("=== 开始优化搜索结果 ===")
|
||||
# optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
|
||||
# result=reorder_output_results(optimized_outputs)
|
||||
# # 保存优化后的结果到文件
|
||||
# with open('_intermediate_outputs_optimized.json', 'w', encoding='utf-8') as f:
|
||||
# import json
|
||||
# f.write(json.dumps(result, indent=4, ensure_ascii=False))
|
||||
#
|
||||
print(f"=== 最终摘要 ===")
|
||||
print(summary)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
end = time.time()
|
||||
print(100 * 'y')
|
||||
print(f"总耗时: {end - start}s")
|
||||
print(100 * 'y')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
counter = COUNTState(limit=3)
|
||||
def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
|
||||
|
||||
def Split_continue(state: ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
|
||||
@@ -25,6 +25,7 @@ def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summa
|
||||
return 'Input_Summary'
|
||||
return 'Split_The_Problem' # 默认情况
|
||||
|
||||
|
||||
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
@@ -43,8 +44,10 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
elif search_switch == '1':
|
||||
return 'Retrieve_Summary'
|
||||
return 'Retrieve_Summary' # Default based on business logic
|
||||
|
||||
|
||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||
status=state.get('verify', '')['status']
|
||||
status = state.get('verify', '')['status']
|
||||
# loop_count = counter.get_total()
|
||||
if "success" in status:
|
||||
# counter.reset()
|
||||
@@ -53,7 +56,7 @@ def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "co
|
||||
# if loop_count < 2: # Maximum loop count is 3
|
||||
# return "content_input"
|
||||
# else:
|
||||
# counter.reset()
|
||||
# counter.reset()
|
||||
return "Summary_fails"
|
||||
else:
|
||||
# Add default return value to avoid returning None
|
||||
|
||||
@@ -2,77 +2,104 @@ import json
|
||||
import os
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.core.memory.agent.utils.redis_tool import count_store
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context, get_db
|
||||
from app.db import get_db_context
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
|
||||
|
||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||
"""
|
||||
Write messages to RAG storage system
|
||||
|
||||
Combines user and AI messages into a single string format and stores them
|
||||
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
|
||||
|
||||
Args:
|
||||
end_user_id: User identifier for the conversation
|
||||
user_message: User's input message content
|
||||
ai_message: AI's response message content
|
||||
user_rag_memory_id: RAG memory identifier for storage location
|
||||
"""
|
||||
# RAG mode: combine messages into string format (maintain original logic)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
|
||||
actual_config_id, long_term_messages=[]):
|
||||
|
||||
|
||||
async def write(
|
||||
storage_type,
|
||||
end_user_id,
|
||||
user_message,
|
||||
ai_message,
|
||||
user_rag_memory_id,
|
||||
actual_end_user_id,
|
||||
actual_config_id,
|
||||
long_term_messages=None
|
||||
):
|
||||
"""
|
||||
写入记忆(支持结构化消息)
|
||||
Write memory with structured message support
|
||||
|
||||
Handles memory writing operations for different storage types (Neo4j/RAG).
|
||||
Supports both individual message pairs and batch long-term message processing.
|
||||
|
||||
Args:
|
||||
storage_type: 存储类型 (neo4j/rag)
|
||||
end_user_id: 终端用户ID
|
||||
user_message: 用户消息内容
|
||||
ai_message: AI 回复内容
|
||||
user_rag_memory_id: RAG 记忆ID
|
||||
actual_end_user_id: 实际用户ID
|
||||
actual_config_id: 配置ID
|
||||
storage_type: Storage type identifier ("neo4j" or "rag")
|
||||
end_user_id: Terminal user identifier
|
||||
user_message: User message content
|
||||
ai_message: AI response content
|
||||
user_rag_memory_id: RAG memory identifier
|
||||
actual_end_user_id: Actual user identifier for storage
|
||||
actual_config_id: Configuration identifier
|
||||
long_term_messages: Optional list of structured messages for batch processing
|
||||
|
||||
逻辑说明:
|
||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||
- Neo4j 模式:使用结构化消息列表
|
||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||
Logic explanation:
|
||||
- RAG mode: Combines user_message and ai_message into string format, maintains original logic
|
||||
- Neo4j mode: Uses structured message lists
|
||||
1. If both user_message and ai_message are not empty: Creates paired messages [user, assistant]
|
||||
2. If only user_message exists: Creates single user message [user] (for historical memory scenarios)
|
||||
3. Each message is converted to independent Chunk, preserving speaker field
|
||||
"""
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
if long_term_messages is None:
|
||||
long_term_messages = []
|
||||
with get_db_context() as db:
|
||||
actual_config_id = resolve_config_id(actual_config_id, db)
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
# Neo4j mode: Use structured message lists
|
||||
structured_messages = []
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
# Always add user message (if not empty)
|
||||
if isinstance(user_message, str) and user_message.strip() != "":
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
# Only add assistant message when AI reply is not empty
|
||||
if isinstance(ai_message, str) and ai_message.strip() != "":
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
# 如果提供了 long_term_messages,使用它替代 structured_messages
|
||||
# If long_term_messages provided, use it to replace structured_messages
|
||||
if long_term_messages and isinstance(long_term_messages, list):
|
||||
structured_messages = long_term_messages
|
||||
elif long_term_messages and isinstance(long_term_messages, str):
|
||||
# 如果是 JSON 字符串,先解析
|
||||
# If it's a JSON string, parse it first
|
||||
try:
|
||||
structured_messages = json.loads(long_term_messages)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
# If no messages, return directly
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
@@ -80,29 +107,41 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me
|
||||
logger.info(
|
||||
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # end_user_id: 用户ID
|
||||
structured_messages, # message: JSON 字符串格式的消息列表
|
||||
str(actual_config_id), # config_id: 配置ID字符串
|
||||
actual_end_user_id, # end_user_id: User ID
|
||||
structured_messages, # message: JSON string format message list
|
||||
str(actual_config_id), # config_id: Configuration ID string
|
||||
storage_type, # storage_type: "neo4j"
|
||||
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||
user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||
)
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
|
||||
|
||||
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
|
||||
"""
|
||||
Save long-term memory data to database
|
||||
|
||||
Handles the storage of long-term memory data based on different strategies
|
||||
(chunk-based or aggregate-based) and manages the transition from short-term
|
||||
to long-term memory storage.
|
||||
|
||||
Args:
|
||||
long_term_messages: Long-term message data to be saved
|
||||
actual_config_id: Configuration identifier for memory settings
|
||||
end_user_id: User identifier for memory association
|
||||
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
||||
scope: Scope/window size for memory processing
|
||||
"""
|
||||
with get_db_context() as db_session:
|
||||
repo = LongTermMemoryRepository(db_session)
|
||||
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
result = write_store.get_session_by_userid(end_user_id)
|
||||
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
data = await format_parsing(result, "dict")
|
||||
chunk_data = data[:scope]
|
||||
if len(chunk_data)==scope:
|
||||
if len(chunk_data) == scope:
|
||||
repo.upsert(end_user_id, chunk_data)
|
||||
logger.info(f'---------写入短长期-----------')
|
||||
else:
|
||||
@@ -112,18 +151,23 @@ async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,
|
||||
logger.info(f'写入短长期:')
|
||||
|
||||
|
||||
"""Window-based dialogue processing"""
|
||||
|
||||
'''根据窗口'''
|
||||
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||
'''
|
||||
根据窗口获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
langchain_messages:原始数据LIST
|
||||
scope:窗口大小
|
||||
'''
|
||||
scope=scope
|
||||
|
||||
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
|
||||
"""
|
||||
Process dialogue based on window size and write to Neo4j
|
||||
|
||||
Manages conversation data based on a sliding window approach. When the window
|
||||
reaches the specified scope size, it triggers long-term memory storage to Neo4j.
|
||||
|
||||
Args:
|
||||
end_user_id: Terminal user identifier
|
||||
memory_config: Memory configuration object containing settings
|
||||
langchain_messages: Original message data list
|
||||
scope: Window size determining when to trigger long-term storage
|
||||
"""
|
||||
scope = scope
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_id is not False:
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||
@@ -135,50 +179,72 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||
elif int(is_end_user_id) == int(scope):
|
||||
logger.info('写入长期记忆NEO4J')
|
||||
formatted_messages = (redis_messages)
|
||||
# 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用)
|
||||
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
||||
if hasattr(memory_config, 'config_id'):
|
||||
config_id = memory_config.config_id
|
||||
else:
|
||||
config_id = memory_config
|
||||
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
config_id, formatted_messages)
|
||||
|
||||
await write(
|
||||
AgentMemory_Long_Term.STORAGE_NEO4J,
|
||||
end_user_id,
|
||||
"",
|
||||
"",
|
||||
None,
|
||||
end_user_id,
|
||||
config_id,
|
||||
formatted_messages
|
||||
)
|
||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
|
||||
|
||||
"""根据时间"""
|
||||
async def memory_long_term_storage(end_user_id,memory_config,time):
|
||||
'''
|
||||
根据时间获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
'''
|
||||
"""Time-based memory processing"""
|
||||
|
||||
|
||||
async def memory_long_term_storage(end_user_id, memory_config, time):
|
||||
"""
|
||||
Process memory storage based on time intervals and write to Neo4j
|
||||
|
||||
Retrieves Redis data based on time intervals and writes it to Neo4j for
|
||||
long-term storage. This function handles time-based memory consolidation.
|
||||
|
||||
Args:
|
||||
end_user_id: Terminal user identifier
|
||||
memory_config: Memory configuration object containing settings
|
||||
time: Time interval for data retrieval
|
||||
"""
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||
format_messages = (long_time_data)
|
||||
messages=[]
|
||||
memory_config=memory_config.config_id
|
||||
format_messages = long_time_data
|
||||
messages = []
|
||||
memory_config = memory_config.config_id
|
||||
for i in format_messages:
|
||||
message=json.loads(i['Query'])
|
||||
messages+= message
|
||||
if format_messages!=[]:
|
||||
message = json.loads(i['Query'])
|
||||
messages += message
|
||||
if format_messages:
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
memory_config, messages)
|
||||
'''聚合判断'''
|
||||
|
||||
|
||||
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
||||
"""
|
||||
聚合判断函数:判断输入句子和历史消息是否描述同一事件
|
||||
Aggregation judgment function: determine if input sentence and historical messages describe the same event
|
||||
|
||||
Uses LLM-based analysis to determine whether new messages should be aggregated with existing
|
||||
historical data or stored as separate events. This helps optimize memory storage and retrieval.
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
memory_config: 内存配置对象
|
||||
"""
|
||||
end_user_id: Terminal user identifier
|
||||
ori_messages: Original message list, format like [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
memory_config: Memory configuration object containing LLM settings
|
||||
|
||||
Returns:
|
||||
dict: Aggregation judgment result containing is_same_event flag and processed output
|
||||
"""
|
||||
history = None
|
||||
try:
|
||||
# 1. 获取历史会话数据(使用新方法)
|
||||
# 1. Get historical session data (using new method)
|
||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||
history = await format_parsing(result)
|
||||
if not result:
|
||||
@@ -210,7 +276,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
||||
output_value = structured.output
|
||||
if isinstance(output_value, list):
|
||||
output_value = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in output_value
|
||||
]
|
||||
|
||||
@@ -223,16 +289,16 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config
|
||||
await write("neo4j", end_user_id, "", "", None, end_user_id,
|
||||
memory_config.config_id, output_value)
|
||||
return result_dict
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"[aggregate_judgment] 发生错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
return {
|
||||
"is_same_event": False,
|
||||
"output": ori_messages,
|
||||
"messages": ori_messages,
|
||||
"history": history if 'history' in locals() else [],
|
||||
"error": str(e)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,41 +2,53 @@ import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
from app.core.memory.src.search import (
|
||||
search_by_temporal,
|
||||
search_by_keyword_temporal,
|
||||
)
|
||||
|
||||
|
||||
def extract_tool_message_content(response):
|
||||
"""从agent响应中提取ToolMessage内容和工具名称"""
|
||||
"""
|
||||
Extract ToolMessage content and tool names from agent response
|
||||
|
||||
Parses agent response messages to extract tool execution results and metadata.
|
||||
Handles JSON parsing and provides structured access to tool output data.
|
||||
|
||||
Args:
|
||||
response: Agent response dictionary containing messages
|
||||
|
||||
Returns:
|
||||
dict: Dictionary containing tool_name and parsed content, or None if no tool message found
|
||||
- tool_name: Name of the executed tool
|
||||
- content: Parsed tool execution result (JSON or raw text)
|
||||
"""
|
||||
messages = response.get('messages', [])
|
||||
|
||||
for message in messages:
|
||||
if hasattr(message, 'tool_call_id') and hasattr(message, 'content'):
|
||||
# 这是一个ToolMessage
|
||||
# This is a ToolMessage
|
||||
tool_content = message.content
|
||||
tool_name = None
|
||||
|
||||
# 尝试获取工具名称
|
||||
# Try to get tool name
|
||||
if hasattr(message, 'name'):
|
||||
tool_name = message.name
|
||||
elif hasattr(message, 'tool_name'):
|
||||
tool_name = message.tool_name
|
||||
|
||||
try:
|
||||
# 解析JSON内容
|
||||
# Parse JSON content
|
||||
parsed_content = json.loads(tool_content)
|
||||
return {
|
||||
'tool_name': tool_name,
|
||||
'content': parsed_content
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
# 如果不是JSON格式,直接返回内容
|
||||
# If not JSON format, return content directly
|
||||
return {
|
||||
'tool_name': tool_name,
|
||||
'content': tool_content
|
||||
@@ -46,38 +58,61 @@ def extract_tool_message_content(response):
|
||||
|
||||
|
||||
class TimeRetrievalInput(BaseModel):
|
||||
"""时间检索工具的输入模式"""
|
||||
"""
|
||||
Input schema for time retrieval tool
|
||||
|
||||
Defines the expected input parameters for time-based retrieval operations.
|
||||
Used for validation and documentation of tool parameters.
|
||||
|
||||
Attributes:
|
||||
context: User input query content for search
|
||||
end_user_id: Group ID for filtering search results, defaults to test user
|
||||
"""
|
||||
context: str = Field(description="用户输入的查询内容")
|
||||
end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||
|
||||
|
||||
def create_time_retrieval_tool(end_user_id: str):
|
||||
"""
|
||||
创建一个带有特定end_user_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
||||
Create a TimeRetrieval tool with specific end_user_id (synchronous version) for searching statements by time range
|
||||
|
||||
Creates a specialized time-based retrieval tool that searches for statements within
|
||||
specified time ranges. Includes field cleaning functionality to remove unnecessary
|
||||
metadata from search results.
|
||||
|
||||
Args:
|
||||
end_user_id: User identifier for scoping search results
|
||||
|
||||
Returns:
|
||||
function: Configured TimeRetrievalWithGroupId tool function
|
||||
"""
|
||||
|
||||
|
||||
def clean_temporal_result_fields(data):
|
||||
"""
|
||||
清理时间搜索结果中不需要的字段,并修改结构
|
||||
Clean unnecessary fields from temporal search results and modify structure
|
||||
|
||||
Removes metadata fields that are not needed for end-user consumption and
|
||||
restructures the response format for better usability.
|
||||
|
||||
Args:
|
||||
data: 要清理的数据
|
||||
data: Data to be cleaned (dict, list, or other types)
|
||||
|
||||
Returns:
|
||||
清理后的数据
|
||||
Cleaned data with unnecessary fields removed
|
||||
"""
|
||||
# 需要过滤的字段列表
|
||||
# List of fields to filter out
|
||||
fields_to_remove = {
|
||||
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
||||
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
||||
'valid_at', 'invalid_at', 'statement_ids'
|
||||
}
|
||||
|
||||
|
||||
if isinstance(data, dict):
|
||||
cleaned = {}
|
||||
for key, value in data.items():
|
||||
if key == 'statements' and isinstance(value, dict) and 'statements' in value:
|
||||
# 将 statements: {"statements": [...]} 改为 time_search: {"statements": [...]}
|
||||
# Change statements: {"statements": [...]} to time_search: {"statements": [...]}
|
||||
cleaned_value = clean_temporal_result_fields(value)
|
||||
# 进一步将内部的 statements 改为 time_search
|
||||
# Further change internal statements to time_search
|
||||
if 'statements' in cleaned_value:
|
||||
cleaned['results'] = {
|
||||
'time_search': cleaned_value['statements']
|
||||
@@ -91,26 +126,35 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
return [clean_temporal_result_fields(item) for item in data]
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
@tool
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None,
|
||||
end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||
"""
|
||||
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||
显式接收参数:
|
||||
- context: 查询上下文内容
|
||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||
- end_user_id_param: 组ID(可选,用于覆盖默认组ID)
|
||||
- clean_output: 是否清理输出中的元数据字段
|
||||
-end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||
Optimized time retrieval tool, combines time range search only (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
Performs time-based search operations with automatic metadata filtering. Supports
|
||||
flexible date range specification and provides clean, user-friendly output.
|
||||
|
||||
Explicit parameters:
|
||||
- context: Query context content
|
||||
- start_date: Start time (optional, format: YYYY-MM-DD)
|
||||
- end_date: End time (optional, format: YYYY-MM-DD)
|
||||
- end_user_id_param: Group ID (optional, overrides default group ID)
|
||||
- clean_output: Whether to clean metadata fields from output
|
||||
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results with temporal data
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
# 使用传入的参数或默认值
|
||||
# Use passed parameters or default values
|
||||
actual_end_user_id = end_user_id_param or end_user_id
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||
|
||||
# 基本时间搜索
|
||||
# Basic time search
|
||||
results = await search_by_temporal(
|
||||
end_user_id=actual_end_user_id,
|
||||
start_date=actual_start_date,
|
||||
@@ -118,33 +162,43 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
limit=10
|
||||
)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
# Clean unnecessary fields from results
|
||||
if clean_output:
|
||||
cleaned_results = clean_temporal_result_fields(results)
|
||||
else:
|
||||
cleaned_results = results
|
||||
|
||||
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
@tool
|
||||
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str:
|
||||
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None,
|
||||
clean_output: bool = True) -> str:
|
||||
"""
|
||||
优化的关键词时间检索工具,结合关键词和时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||
显式接收参数:
|
||||
- context: 查询内容
|
||||
- days_back: 向前搜索的天数,默认7天
|
||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||
- clean_output: 是否清理输出中的元数据字段
|
||||
- end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||
Optimized keyword time retrieval tool, combines keyword and time range search (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
Performs combined keyword and temporal search operations with automatic metadata
|
||||
filtering. Provides more targeted search results by combining content relevance
|
||||
with time-based filtering.
|
||||
|
||||
Explicit parameters:
|
||||
- context: Query content for keyword matching
|
||||
- days_back: Number of days to search backwards, default 7 days
|
||||
- start_date: Start time (optional, format: YYYY-MM-DD)
|
||||
- end_date: End time (optional, format: YYYY-MM-DD)
|
||||
- clean_output: Whether to clean metadata fields from output
|
||||
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results combining keyword and temporal data
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
|
||||
|
||||
# 关键词时间搜索
|
||||
# Keyword time search
|
||||
results = await search_by_keyword_temporal(
|
||||
query_text=context,
|
||||
end_user_id=end_user_id,
|
||||
@@ -153,7 +207,7 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
limit=15
|
||||
)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
# Clean unnecessary fields from results
|
||||
if clean_output:
|
||||
cleaned_results = clean_temporal_result_fields(results)
|
||||
else:
|
||||
@@ -162,51 +216,61 @@ def create_time_retrieval_tool(end_user_id: str):
|
||||
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
|
||||
return TimeRetrievalWithGroupId
|
||||
|
||||
|
||||
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
"""
|
||||
创建混合检索工具,使用run_hybrid_search进行混合检索,优化输出格式并过滤不需要的字段
|
||||
Create hybrid retrieval tool using run_hybrid_search for hybrid retrieval, optimize output format and filter unnecessary fields
|
||||
|
||||
Creates an advanced hybrid search tool that combines multiple search strategies
|
||||
(keyword, vector, hybrid) with automatic result cleaning and formatting.
|
||||
|
||||
Args:
|
||||
memory_config: 内存配置对象
|
||||
**search_params: 搜索参数,包含end_user_id, limit, include等
|
||||
memory_config: Memory configuration object containing LLM and search settings
|
||||
**search_params: Search parameters including end_user_id, limit, include, etc.
|
||||
|
||||
Returns:
|
||||
function: Configured HybridSearch tool function with async capabilities
|
||||
"""
|
||||
|
||||
|
||||
def clean_result_fields(data):
|
||||
"""
|
||||
递归清理结果中不需要的字段
|
||||
Recursively clean unnecessary fields from results
|
||||
|
||||
Removes metadata fields that are not needed for end-user consumption,
|
||||
improving readability and reducing response size.
|
||||
|
||||
Args:
|
||||
data: 要清理的数据(可能是字典、列表或其他类型)
|
||||
data: Data to be cleaned (can be dict, list, or other types)
|
||||
|
||||
Returns:
|
||||
清理后的数据
|
||||
Cleaned data with unnecessary fields removed
|
||||
"""
|
||||
# 需要过滤的字段列表
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# List of fields to filter out
|
||||
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development
|
||||
fields_to_remove = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
|
||||
}
|
||||
|
||||
# 注意:'id' 字段保留,community 展开时需要用 community id 查询成员 statements
|
||||
|
||||
if isinstance(data, dict):
|
||||
# 对字典进行清理
|
||||
# Clean dictionary
|
||||
cleaned = {}
|
||||
for key, value in data.items():
|
||||
if key not in fields_to_remove:
|
||||
cleaned[key] = clean_result_fields(value) # 递归清理嵌套数据
|
||||
cleaned[key] = clean_result_fields(value) # Recursively clean nested data
|
||||
return cleaned
|
||||
elif isinstance(data, list):
|
||||
# 对列表中的每个元素进行清理
|
||||
# Clean each element in list
|
||||
return [clean_result_fields(item) for item in data]
|
||||
else:
|
||||
# 其他类型直接返回
|
||||
# Return other types directly
|
||||
return data
|
||||
|
||||
|
||||
@tool
|
||||
async def HybridSearch(
|
||||
context: str,
|
||||
@@ -216,57 +280,63 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
rerank_alpha: float = 0.6,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
clean_output: bool = True # 新增:是否清理输出字段
|
||||
clean_output: bool = True # New: whether to clean output fields
|
||||
) -> str:
|
||||
"""
|
||||
优化的混合检索工具,支持关键词、向量和混合搜索,自动过滤不需要的元数据字段
|
||||
Optimized hybrid retrieval tool, supports keyword, vector and hybrid search, automatically filters unnecessary metadata fields
|
||||
|
||||
Provides comprehensive search capabilities combining multiple search strategies
|
||||
with intelligent result ranking and automatic metadata filtering for clean output.
|
||||
|
||||
Args:
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
rerank_alpha: 重排序权重参数
|
||||
use_forgetting_rerank: 是否使用遗忘重排序
|
||||
use_llm_rerank: 是否使用LLM重排序
|
||||
clean_output: 是否清理输出中的元数据字段
|
||||
context: Query content for search
|
||||
search_type: Search type ('keyword', 'embedding', 'hybrid')
|
||||
limit: Result quantity limit
|
||||
end_user_id: Group ID for filtering search results
|
||||
rerank_alpha: Reranking weight parameter for result scoring
|
||||
use_forgetting_rerank: Whether to use forgetting-based reranking
|
||||
use_llm_rerank: Whether to use LLM-based reranking
|
||||
clean_output: Whether to clean metadata fields from output
|
||||
|
||||
Returns:
|
||||
str: JSON formatted comprehensive search results
|
||||
"""
|
||||
try:
|
||||
# 导入run_hybrid_search函数
|
||||
# Import run_hybrid_search function
|
||||
from app.core.memory.src.search import run_hybrid_search
|
||||
|
||||
# 合并参数,优先使用传入的参数
|
||||
# Merge parameters, prioritize passed parameters
|
||||
final_params = {
|
||||
"query_text": context,
|
||||
"search_type": search_type,
|
||||
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
||||
"limit": limit or search_params.get("limit", 10),
|
||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
||||
"output_path": None, # 不保存到文件
|
||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities", "communities"]),
|
||||
"output_path": None, # Don't save to file
|
||||
"memory_config": memory_config,
|
||||
"rerank_alpha": rerank_alpha,
|
||||
"use_forgetting_rerank": use_forgetting_rerank,
|
||||
"use_llm_rerank": use_llm_rerank
|
||||
}
|
||||
|
||||
# 执行混合检索
|
||||
# Execute hybrid retrieval
|
||||
raw_results = await run_hybrid_search(**final_params)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
# Clean unnecessary fields from results
|
||||
if clean_output:
|
||||
cleaned_results = clean_result_fields(raw_results)
|
||||
else:
|
||||
cleaned_results = raw_results
|
||||
|
||||
# 格式化返回结果
|
||||
# Format return results
|
||||
formatted_results = {
|
||||
"search_query": context,
|
||||
"search_type": search_type,
|
||||
"results": cleaned_results
|
||||
}
|
||||
|
||||
|
||||
return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_result = {
|
||||
"error": f"混合检索失败: {str(e)}",
|
||||
@@ -275,38 +345,52 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
return json.dumps(error_result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
return HybridSearch
|
||||
|
||||
|
||||
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
"""
|
||||
创建同步版本的混合检索工具,优化输出格式并过滤不需要的字段
|
||||
Create synchronous version of hybrid retrieval tool, optimize output format and filter unnecessary fields
|
||||
|
||||
Creates a synchronous wrapper around the async hybrid search functionality,
|
||||
making it compatible with synchronous tool execution environments.
|
||||
|
||||
Args:
|
||||
memory_config: 内存配置对象
|
||||
**search_params: 搜索参数
|
||||
memory_config: Memory configuration object containing search settings
|
||||
**search_params: Search parameters for configuration
|
||||
|
||||
Returns:
|
||||
function: Configured HybridSearchSync tool function
|
||||
"""
|
||||
|
||||
@tool
|
||||
def HybridSearchSync(
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
end_user_id: str = None,
|
||||
clean_output: bool = True
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
end_user_id: str = None,
|
||||
clean_output: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
优化的混合检索工具(同步版本),自动过滤不需要的元数据字段
|
||||
Optimized hybrid retrieval tool (synchronous version), automatically filters unnecessary metadata fields
|
||||
|
||||
Provides the same hybrid search capabilities as the async version but in a
|
||||
synchronous execution context. Automatically handles async-to-sync conversion.
|
||||
|
||||
Args:
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
end_user_id: 组ID,用于过滤搜索结果
|
||||
clean_output: 是否清理输出中的元数据字段
|
||||
context: Query content for search
|
||||
search_type: Search type ('keyword', 'embedding', 'hybrid')
|
||||
limit: Result quantity limit
|
||||
end_user_id: Group ID for filtering search results
|
||||
clean_output: Whether to clean metadata fields from output
|
||||
|
||||
Returns:
|
||||
str: JSON formatted search results
|
||||
"""
|
||||
|
||||
async def _async_search():
|
||||
# 创建异步工具并执行
|
||||
# Create async tool and execute
|
||||
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
|
||||
return await async_tool.ainvoke({
|
||||
"context": context,
|
||||
@@ -315,7 +399,7 @@ def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
"end_user_id": end_user_id,
|
||||
"clean_output": clean_output
|
||||
})
|
||||
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
return HybridSearchSync
|
||||
|
||||
return HybridSearchSync
|
||||
|
||||
@@ -1,20 +1,28 @@
|
||||
import json
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
async def format_parsing(messages: list,type:str='string'):
|
||||
|
||||
|
||||
async def format_parsing(messages: list, type: str = 'string'):
|
||||
"""
|
||||
格式化解析消息列表
|
||||
Format and parse message lists into different output types
|
||||
|
||||
Processes message lists from storage and converts them into either string format
|
||||
or dictionary format based on the specified type parameter. Handles JSON parsing
|
||||
and role-based message organization.
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
type: 返回类型 ('string' 或 'dict')
|
||||
messages: List of message objects from storage containing message data
|
||||
type: Return type specification ('string' for text format, 'dict' for key-value pairs)
|
||||
|
||||
Returns:
|
||||
格式化后的消息列表
|
||||
list: Formatted message list in the specified format
|
||||
- 'string': List of formatted text messages with role prefixes
|
||||
- 'dict': List of dictionaries mapping user messages to AI responses
|
||||
"""
|
||||
result = []
|
||||
user=[]
|
||||
ai=[]
|
||||
user = []
|
||||
ai = []
|
||||
|
||||
for message in messages:
|
||||
hstory_messages = message['messages']
|
||||
@@ -24,25 +32,38 @@ async def format_parsing(messages: list,type:str='string'):
|
||||
role = content['role']
|
||||
content = content['content']
|
||||
if type == "string":
|
||||
if role == 'human' or role=="user":
|
||||
if role == 'human' or role == "user":
|
||||
content = '用户:' + content
|
||||
else:
|
||||
content = 'AI:' + content
|
||||
result.append(content)
|
||||
if type == "dict" :
|
||||
if role == 'human' or role=="user":
|
||||
user.append( content)
|
||||
if type == "dict":
|
||||
if role == 'human' or role == "user":
|
||||
user.append(content)
|
||||
else:
|
||||
ai.append(content)
|
||||
if type == "dict":
|
||||
for key,values in zip(user,ai):
|
||||
result.append({key:values})
|
||||
for key, values in zip(user, ai):
|
||||
result.append({key: values})
|
||||
return result
|
||||
|
||||
|
||||
async def messages_parse(messages: list | dict):
|
||||
user=[]
|
||||
ai=[]
|
||||
database=[]
|
||||
"""
|
||||
Parse messages from storage format into user-AI conversation pairs
|
||||
|
||||
Extracts and organizes conversation data from stored message format,
|
||||
separating user and AI messages and pairing them for database storage.
|
||||
|
||||
Args:
|
||||
messages: List or dictionary containing stored message data with Query fields
|
||||
|
||||
Returns:
|
||||
list: List of dictionaries containing user-AI message pairs for database storage
|
||||
"""
|
||||
user = []
|
||||
ai = []
|
||||
database = []
|
||||
for message in messages:
|
||||
Query = message['Query']
|
||||
Query = json.loads(Query)
|
||||
@@ -54,10 +75,23 @@ async def messages_parse(messages: list | dict):
|
||||
ai.append(data['content'])
|
||||
for key, values in zip(user, ai):
|
||||
database.append({key, values})
|
||||
return database
|
||||
return database
|
||||
|
||||
|
||||
async def agent_chat_messages(user_content,ai_content):
|
||||
async def agent_chat_messages(user_content, ai_content):
|
||||
"""
|
||||
Create structured chat message format for agent conversations
|
||||
|
||||
Formats user and AI content into a standardized message structure suitable
|
||||
for agent processing and storage. Creates role-based message objects.
|
||||
|
||||
Args:
|
||||
user_content: User's message content string
|
||||
ai_content: AI's response content string
|
||||
|
||||
Returns:
|
||||
list: List of structured message dictionaries with role and content fields
|
||||
"""
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
|
||||
@@ -13,7 +13,6 @@ from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
@@ -42,10 +41,26 @@ async def make_write_graph():
|
||||
|
||||
yield graph
|
||||
|
||||
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
||||
|
||||
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
|
||||
end_user_id: str = '', scope: int = 6):
|
||||
"""
|
||||
Handle long-term memory storage with different strategies
|
||||
|
||||
Supports multiple storage strategies including chunk-based, time-based,
|
||||
and aggregate judgment approaches for long-term memory persistence.
|
||||
|
||||
Args:
|
||||
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
|
||||
langchain_messages: List of messages to store
|
||||
memory_config: Memory configuration identifier
|
||||
end_user_id: User group identifier
|
||||
scope: Scope parameter for chunk-based storage (default: 6)
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
||||
aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
write_store.save_session_write(end_user_id, (langchain_messages))
|
||||
write_store.save_session_write(end_user_id, langchain_messages)
|
||||
# 获取数据库会话
|
||||
with get_db_context() as db_session:
|
||||
config_service = MemoryConfigService(db_session)
|
||||
@@ -53,26 +68,39 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[
|
||||
config_id=memory_config, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
if long_term_type=='chunk':
|
||||
'''方案一:对话窗口6轮对话'''
|
||||
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
||||
if long_term_type=='time':
|
||||
"""时间"""
|
||||
await memory_long_term_storage(end_user_id, memory_config,5)
|
||||
if long_term_type=='aggregate':
|
||||
"""方案三:聚合判断"""
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
|
||||
'''Strategy 1: Dialogue window with 6 rounds of conversation'''
|
||||
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
|
||||
"""Time-based strategy"""
|
||||
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
|
||||
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
"""Strategy 3: Aggregate judgment"""
|
||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
|
||||
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
|
||||
"""
|
||||
Write long-term memory with different storage types
|
||||
|
||||
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
|
||||
Handles both RAG-based storage and traditional memory storage approaches.
|
||||
For traditional storage, uses chunk-based strategy with paired user-AI messages.
|
||||
|
||||
Args:
|
||||
storage_type: Type of storage (RAG or traditional)
|
||||
end_user_id: User group identifier
|
||||
message_chat: User message content
|
||||
aimessages: AI response messages
|
||||
user_rag_memory_id: RAG memory identifier
|
||||
actual_config_id: Actual configuration ID
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||
else:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
|
||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
||||
@@ -101,4 +129,4 @@ async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# import asyncio
|
||||
# asyncio.run(main())
|
||||
# asyncio.run(main())
|
||||
|
||||
@@ -13,6 +13,72 @@ from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
# 需要从展开结果中过滤的字段(含 Neo4j DateTime,不可 JSON 序列化)
|
||||
_EXPAND_FIELDS_TO_REMOVE = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary'
|
||||
}
|
||||
|
||||
|
||||
def _clean_expand_fields(obj):
|
||||
"""递归过滤展开结果中不可序列化的字段(DateTime 等)。"""
|
||||
if isinstance(obj, dict):
|
||||
return {k: _clean_expand_fields(v) for k, v in obj.items() if k not in _EXPAND_FIELDS_TO_REMOVE}
|
||||
if isinstance(obj, list):
|
||||
return [_clean_expand_fields(i) for i in obj]
|
||||
return obj
|
||||
|
||||
|
||||
async def expand_communities_to_statements(
|
||||
community_results: List[dict],
|
||||
end_user_id: str,
|
||||
existing_content: str = "",
|
||||
limit: int = 10,
|
||||
) -> Tuple[List[dict], List[str]]:
|
||||
"""
|
||||
社区展开 helper:给定命中的 community 列表,拉取关联 Statement。
|
||||
|
||||
- 对展开结果去重(过滤已在 existing_content 中出现的文本)
|
||||
- 过滤不可序列化字段
|
||||
- 返回 (cleaned_expanded_stmts, new_texts)
|
||||
- cleaned_expanded_stmts: 可直接写回 raw_results 的列表
|
||||
- new_texts: 去重后新增的 statement 文本列表,用于追加到 clean_content
|
||||
"""
|
||||
community_ids = [r.get("id") for r in community_results if r.get("id")]
|
||||
if not community_ids or not end_user_id:
|
||||
return [], []
|
||||
|
||||
from app.repositories.neo4j.graph_search import search_graph_community_expand
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
result = await search_graph_community_expand(
|
||||
connector=connector,
|
||||
community_ids=community_ids,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[expand_communities] 社区展开检索失败,跳过: {e}")
|
||||
return [], []
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
expanded_stmts = result.get("expanded_statements", [])
|
||||
if not expanded_stmts:
|
||||
return [], []
|
||||
|
||||
existing_lines = set(existing_content.splitlines())
|
||||
new_texts = [
|
||||
s["statement"] for s in expanded_stmts
|
||||
if s.get("statement") and s["statement"] not in existing_lines
|
||||
]
|
||||
cleaned = _clean_expand_fields(expanded_stmts)
|
||||
logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}")
|
||||
return cleaned, new_texts
|
||||
|
||||
|
||||
class SearchService:
|
||||
"""Service for executing hybrid search and processing results."""
|
||||
@@ -21,7 +87,7 @@ class SearchService:
|
||||
"""Initialize the search service."""
|
||||
logger.info("SearchService initialized")
|
||||
|
||||
def extract_content_from_result(self, result: dict) -> str:
|
||||
def extract_content_from_result(self, result: dict, node_type: str = "") -> str:
|
||||
"""
|
||||
Extract only meaningful content from search results, dropping all metadata.
|
||||
|
||||
@@ -30,9 +96,11 @@ class SearchService:
|
||||
- Entities: extract 'name' and 'fact_summary' fields
|
||||
- Summaries: extract 'content' field
|
||||
- Chunks: extract 'content' field
|
||||
- Communities: extract 'content' field (c.summary), prefixed with community name
|
||||
|
||||
Args:
|
||||
result: Search result dictionary
|
||||
node_type: Hint for node type ("community", "summary", etc.)
|
||||
|
||||
Returns:
|
||||
Clean content string without metadata
|
||||
@@ -46,8 +114,21 @@ class SearchService:
|
||||
if 'statement' in result and result['statement']:
|
||||
content_parts.append(result['statement'])
|
||||
|
||||
# Summaries/Chunks: extract content field
|
||||
if 'content' in result and result['content']:
|
||||
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
|
||||
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
|
||||
is_community = (
|
||||
node_type == "community"
|
||||
or 'member_count' in result
|
||||
or 'core_entities' in result
|
||||
)
|
||||
if is_community:
|
||||
name = result.get('name', '')
|
||||
content = result.get('content', '')
|
||||
if content:
|
||||
prefix = f"[主题:{name}] " if name else ""
|
||||
content_parts.append(f"{prefix}{content}")
|
||||
elif 'content' in result and result['content']:
|
||||
# Summaries / Chunks
|
||||
content_parts.append(result['content'])
|
||||
|
||||
# Entities: extract name and fact_summary (commented out in original)
|
||||
@@ -99,7 +180,8 @@ class SearchService:
|
||||
rerank_alpha: float = 0.4,
|
||||
output_path: str = "search_results.json",
|
||||
return_raw_results: bool = False,
|
||||
memory_config = None
|
||||
memory_config = None,
|
||||
expand_communities: bool = True,
|
||||
) -> Tuple[str, str, Optional[dict]]:
|
||||
"""
|
||||
Execute hybrid search and return clean content.
|
||||
@@ -114,13 +196,15 @@ class SearchService:
|
||||
output_path: Path to save search results (default: "search_results.json")
|
||||
return_raw_results: If True, also return the raw search results as third element (default: False)
|
||||
memory_config: Memory configuration object (required)
|
||||
expand_communities: If True, expand community hits to member statements (default: True).
|
||||
Set to False for quick-summary paths that only need community-level text.
|
||||
|
||||
Returns:
|
||||
Tuple of (clean_content, cleaned_query, raw_results)
|
||||
raw_results is None if return_raw_results=False
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries"]
|
||||
include = ["statements", "chunks", "entities", "summaries", "communities"]
|
||||
|
||||
# Clean query
|
||||
cleaned_query = self.clean_query(question)
|
||||
@@ -146,8 +230,8 @@ class SearchService:
|
||||
if search_type == "hybrid":
|
||||
reranked_results = answer.get('reranked_results', {})
|
||||
|
||||
# Priority order: summaries first (most contextual), then statements, chunks, entities
|
||||
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
||||
# Priority order: summaries first (most contextual), then communities, statements, chunks, entities
|
||||
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
|
||||
|
||||
for category in priority_order:
|
||||
if category in include and category in reranked_results:
|
||||
@@ -157,19 +241,33 @@ class SearchService:
|
||||
else:
|
||||
# For keyword or embedding search, results are directly in answer dict
|
||||
# Apply same priority order
|
||||
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
||||
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
|
||||
|
||||
for category in priority_order:
|
||||
if category in include and category in answer:
|
||||
category_results = answer[category]
|
||||
if isinstance(category_results, list):
|
||||
answer_list.extend(category_results)
|
||||
|
||||
# 对命中的 community 节点展开其成员 statements(路径 "0"/"1" 需要,路径 "2" 不需要)
|
||||
if expand_communities and "communities" in include:
|
||||
community_results = (
|
||||
answer.get('reranked_results', {}).get('communities', [])
|
||||
if search_type == "hybrid"
|
||||
else answer.get('communities', [])
|
||||
)
|
||||
cleaned_stmts, new_texts = await expand_communities_to_statements(
|
||||
community_results=community_results,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
answer_list.extend(cleaned_stmts)
|
||||
|
||||
# Extract clean content from all results
|
||||
content_list = [
|
||||
self.extract_content_from_result(ans)
|
||||
for ans in answer_list
|
||||
]
|
||||
# Extract clean content from all results,按类型传入 node_type 区分 community
|
||||
content_list = []
|
||||
for ans in answer_list:
|
||||
# community 节点有 member_count 或 core_entities 字段
|
||||
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else ""
|
||||
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
|
||||
|
||||
|
||||
# Filter out empty strings and join with newlines
|
||||
|
||||
@@ -84,7 +84,7 @@ async def get_chunked_dialogs(
|
||||
pruning_scene=memory_config.pruning_scene or "education",
|
||||
pruning_threshold=memory_config.pruning_threshold,
|
||||
scene_id=str(memory_config.scene_id) if memory_config.scene_id else None,
|
||||
ontology_classes=memory_config.ontology_classes,
|
||||
ontology_class_infos=memory_config.ontology_class_infos,
|
||||
)
|
||||
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")
|
||||
|
||||
|
||||
@@ -8,10 +8,11 @@ from langgraph.graph import add_messages
|
||||
|
||||
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
|
||||
|
||||
|
||||
class WriteState(TypedDict):
|
||||
'''
|
||||
"""
|
||||
Langgrapg Writing TypedDict
|
||||
'''
|
||||
"""
|
||||
messages: Annotated[list[AnyMessage], add_messages]
|
||||
end_user_id: str
|
||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||
@@ -20,6 +21,7 @@ class WriteState(TypedDict):
|
||||
data: str
|
||||
language: str # 语言类型 ("zh" 中文, "en" 英文)
|
||||
|
||||
|
||||
class ReadState(TypedDict):
|
||||
"""
|
||||
LangGraph 工作流状态定义
|
||||
@@ -43,18 +45,20 @@ class ReadState(TypedDict):
|
||||
config_id: str
|
||||
data: str # 新增字段用于传递内容
|
||||
spit_data: dict # 新增字段用于传递问题分解结果
|
||||
problem_extension:dict
|
||||
problem_extension: dict
|
||||
storage_type: str
|
||||
user_rag_memory_id: str
|
||||
llm_id: str
|
||||
embedding_id: str
|
||||
memory_config: object # 新增字段用于传递内存配置对象
|
||||
retrieve:dict
|
||||
retrieve: dict
|
||||
RetrieveSummary: dict
|
||||
InputSummary: dict
|
||||
verify: dict
|
||||
SummaryFails: dict
|
||||
summary: dict
|
||||
|
||||
|
||||
class COUNTState:
|
||||
"""
|
||||
工作流对话检索内容计数器
|
||||
@@ -99,6 +103,7 @@ class COUNTState:
|
||||
self.total = 0
|
||||
print("[COUNTState] 已重置为 0")
|
||||
|
||||
|
||||
def deduplicate_entries(entries):
|
||||
seen = set()
|
||||
deduped = []
|
||||
@@ -109,6 +114,7 @@ def deduplicate_entries(entries):
|
||||
deduped.append(entry)
|
||||
return deduped
|
||||
|
||||
|
||||
def merge_to_key_value_pairs(data, query_key, result_key):
|
||||
grouped = defaultdict(list)
|
||||
for item in data:
|
||||
@@ -142,4 +148,4 @@ def convert_extended_question_to_question(data):
|
||||
return [convert_extended_question_to_question(item) for item in data]
|
||||
else:
|
||||
# 其他类型直接返回
|
||||
return data
|
||||
return data
|
||||
|
||||
@@ -19,7 +19,7 @@ from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
@@ -165,10 +165,16 @@ async def write(
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
connector=neo4j_connector
|
||||
connector=neo4j_connector,
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
# 写入成功后,异步触发聚类(不阻塞写入响应)
|
||||
schedule_clustering_after_write(
|
||||
all_entity_nodes,
|
||||
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
|
||||
@@ -6,6 +6,7 @@ of the memory system including LLM, chunking, pruning, and search.
|
||||
Classes:
|
||||
LLMConfig: Configuration for LLM client
|
||||
ChunkerConfig: Configuration for dialogue chunking
|
||||
OntologyClassInfo: Single ontology class with name and description
|
||||
PruningConfig: Configuration for semantic pruning
|
||||
TemporalSearchParams: Parameters for temporal search queries
|
||||
"""
|
||||
@@ -50,30 +51,41 @@ class ChunkerConfig(BaseModel):
|
||||
min_characters_per_chunk: Optional[int] = Field(24, ge=0, description="The minimum number of characters in each chunk.")
|
||||
|
||||
|
||||
class OntologyClassInfo(BaseModel):
|
||||
"""本体类型的名称与语义描述,用于剪枝提示词注入。
|
||||
|
||||
Attributes:
|
||||
class_name: 本体类型名称(如"患者"、"课程")
|
||||
class_description: 本体类型语义描述,告知 LLM 该类型在当前场景下的含义
|
||||
"""
|
||||
class_name: str = Field(..., description="本体类型名称")
|
||||
class_description: str = Field(default="", description="本体类型语义描述")
|
||||
|
||||
|
||||
class PruningConfig(BaseModel):
|
||||
"""Configuration for semantic pruning of dialogue content.
|
||||
|
||||
Attributes:
|
||||
pruning_switch: Enable or disable semantic pruning
|
||||
pruning_scene: Scene name for pruning, either a built-in key
|
||||
('education', 'online_service', 'outbound') or a custom scene_name
|
||||
from ontology_scene table
|
||||
pruning_scene: Scene name for pruning from ontology_scene table
|
||||
pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal)
|
||||
scene_id: Optional ontology scene UUID, used to load custom ontology classes
|
||||
ontology_classes: List of class_name strings from ontology_class table,
|
||||
injected into the prompt when pruning_scene is not a built-in scene
|
||||
scene_id: Optional ontology scene UUID
|
||||
ontology_class_infos: Full ontology class info (name + description) from
|
||||
ontology_class table, injected into the pruning prompt to drive
|
||||
scene-aware preservation decisions
|
||||
"""
|
||||
pruning_switch: bool = Field(False, description="Enable semantic pruning when True.")
|
||||
pruning_scene: str = Field(
|
||||
"education",
|
||||
description="Scene for pruning: built-in key or custom scene_name from ontology_scene.",
|
||||
description="Scene name from ontology_scene table.",
|
||||
)
|
||||
pruning_threshold: float = Field(
|
||||
0.5, ge=0.0, le=0.9,
|
||||
description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).")
|
||||
scene_id: Optional[str] = Field(None, description="Ontology scene UUID (optional).")
|
||||
ontology_classes: Optional[List[str]] = Field(
|
||||
None, description="Class names from ontology_class table for custom scenes."
|
||||
ontology_class_infos: List[OntologyClassInfo] = Field(
|
||||
default_factory=list,
|
||||
description="Full ontology class info (name + description) injected into pruning prompt."
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -238,7 +238,7 @@ def rerank_with_activation(
|
||||
|
||||
reranked: Dict[str, List[Dict[str, Any]]] = {}
|
||||
|
||||
for category in ["statements", "chunks", "entities", "summaries"]:
|
||||
for category in ["statements", "chunks", "entities", "summaries", "communities"]:
|
||||
keyword_items = keyword_results.get(category, [])
|
||||
embedding_items = embedding_results.get(category, [])
|
||||
|
||||
@@ -281,21 +281,23 @@ def rerank_with_activation(
|
||||
for item in items_list:
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
if item_id and item_id in combined_items:
|
||||
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value", 0)
|
||||
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value")
|
||||
|
||||
# 步骤 4: 计算基础分数和最终分数
|
||||
for item_id, item in combined_items.items():
|
||||
bm25_norm = float(item.get("bm25_score", 0) or 0)
|
||||
emb_norm = float(item.get("embedding_score", 0) or 0)
|
||||
act_norm = float(item.get("normalized_activation_value", 0) or 0)
|
||||
# normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义
|
||||
raw_act_norm = item.get("normalized_activation_value")
|
||||
act_norm = float(raw_act_norm) if raw_act_norm is not None else None
|
||||
|
||||
# 第一阶段:只考虑内容相关性(BM25 + Embedding)
|
||||
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
|
||||
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
|
||||
base_score = content_score # 第一阶段用内容分数
|
||||
|
||||
# 存储激活度分数供第二阶段使用
|
||||
item["activation_score"] = act_norm
|
||||
# 存储激活度分数供第二阶段使用(None 表示无激活值,不参与激活值排序)
|
||||
item["activation_score"] = act_norm # 可能为 None
|
||||
item["content_score"] = content_score
|
||||
item["base_score"] = base_score
|
||||
|
||||
@@ -724,6 +726,8 @@ async def run_hybrid_search(
|
||||
try:
|
||||
keyword_task = None
|
||||
embedding_task = None
|
||||
keyword_results: Dict[str, List] = {}
|
||||
embedding_results: Dict[str, List] = {}
|
||||
|
||||
if search_type in ["keyword", "hybrid"]:
|
||||
# Keyword-based search
|
||||
@@ -746,35 +750,42 @@ async def run_hybrid_search(
|
||||
|
||||
# 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig
|
||||
config_load_start = time.time()
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
||||
rb_config = RedBearModelConfig(
|
||||
model_name=embedder_config_dict["model_name"],
|
||||
provider=embedder_config_dict["provider"],
|
||||
api_key=embedder_config_dict["api_key"],
|
||||
base_url=embedder_config_dict["base_url"],
|
||||
type="llm"
|
||||
)
|
||||
config_load_time = time.time() - config_load_start
|
||||
logger.info(f"[PERF] Config loading took {config_load_time:.4f}s")
|
||||
|
||||
# Init embedder
|
||||
embedder_init_start = time.time()
|
||||
embedder = OpenAIEmbedderClient(model_config=rb_config)
|
||||
embedder_init_time = time.time() - embedder_init_start
|
||||
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
|
||||
|
||||
embedding_task = asyncio.create_task(
|
||||
search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=query_text,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
config_service = MemoryConfigService(db)
|
||||
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
||||
rb_config = RedBearModelConfig(
|
||||
model_name=embedder_config_dict["model_name"],
|
||||
provider=embedder_config_dict["provider"],
|
||||
api_key=embedder_config_dict["api_key"],
|
||||
base_url=embedder_config_dict["base_url"],
|
||||
type="llm"
|
||||
)
|
||||
)
|
||||
config_load_time = time.time() - config_load_start
|
||||
logger.info(f"[PERF] Config loading took {config_load_time:.4f}s")
|
||||
|
||||
# Init embedder
|
||||
embedder_init_start = time.time()
|
||||
embedder = OpenAIEmbedderClient(model_config=rb_config)
|
||||
embedder_init_time = time.time() - embedder_init_start
|
||||
logger.info(f"[PERF] Embedder init took {embedder_init_time:.4f}s")
|
||||
|
||||
embedding_task = asyncio.create_task(
|
||||
search_graph_by_embedding(
|
||||
connector=connector,
|
||||
embedder_client=embedder,
|
||||
query_text=query_text,
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
include=include,
|
||||
)
|
||||
)
|
||||
except Exception as emb_init_err:
|
||||
logger.warning(
|
||||
f"[PERF] Embedding search skipped due to init error "
|
||||
f"(embedding_model_id={memory_config.embedding_model_id}): {emb_init_err}"
|
||||
)
|
||||
embedding_task = None
|
||||
|
||||
if keyword_task:
|
||||
keyword_results = await keyword_task
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine
|
||||
|
||||
__all__ = ["LabelPropagationEngine"]
|
||||
@@ -0,0 +1,572 @@
|
||||
"""标签传播聚类引擎
|
||||
|
||||
基于 ZEP 论文的动态标签传播算法,对 Neo4j 中的 ExtractedEntity 节点进行社区聚类。
|
||||
|
||||
支持两种模式:
|
||||
- 全量初始化(full_clustering):首次运行,对所有实体做完整 LPA 迭代
|
||||
- 增量更新(incremental_update):新实体到达时,只处理新实体及其邻居
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from math import sqrt
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from app.repositories.neo4j.community_repository import CommunityRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 全量迭代最大轮数,防止不收敛
|
||||
MAX_ITERATIONS = 10
|
||||
|
||||
# 社区核心实体取 top-N 数量
|
||||
CORE_ENTITY_LIMIT = 10
|
||||
|
||||
|
||||
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
|
||||
"""计算两个向量的余弦相似度,任一为空则返回 0。"""
|
||||
if not v1 or not v2 or len(v1) != len(v2):
|
||||
return 0.0
|
||||
dot = sum(a * b for a, b in zip(v1, v2))
|
||||
norm1 = sqrt(sum(a * a for a in v1))
|
||||
norm2 = sqrt(sum(b * b for b in v2))
|
||||
if norm1 == 0 or norm2 == 0:
|
||||
return 0.0
|
||||
return dot / (norm1 * norm2)
|
||||
|
||||
|
||||
def _weighted_vote(
|
||||
neighbors: List[Dict],
|
||||
self_embedding: Optional[List[float]],
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
加权多数投票,选出得票最高的社区。
|
||||
|
||||
权重 = 语义相似度(name_embedding 余弦)* activation_value 加成
|
||||
没有 community_id 的邻居不参与投票。
|
||||
"""
|
||||
votes: Dict[str, float] = {}
|
||||
for nb in neighbors:
|
||||
cid = nb.get("community_id")
|
||||
if not cid:
|
||||
continue
|
||||
sem = _cosine_similarity(self_embedding, nb.get("name_embedding"))
|
||||
act = nb.get("activation_value") or 0.5
|
||||
# 语义相似度权重 0.6,激活值权重 0.4
|
||||
weight = 0.6 * sem + 0.4 * act
|
||||
votes[cid] = votes.get(cid, 0.0) + weight
|
||||
|
||||
if not votes:
|
||||
return None
|
||||
return max(votes, key=votes.__getitem__)
|
||||
|
||||
|
||||
class LabelPropagationEngine:
|
||||
"""标签传播聚类引擎"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
connector: Neo4jConnector,
|
||||
llm_model_id: Optional[str] = None,
|
||||
embedding_model_id: Optional[str] = None,
|
||||
):
|
||||
self.connector = connector
|
||||
self.repo = CommunityRepository(connector)
|
||||
self.llm_model_id = llm_model_id
|
||||
self.embedding_model_id = embedding_model_id
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# 公开接口
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
async def run(
|
||||
self,
|
||||
end_user_id: str,
|
||||
new_entity_ids: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
统一入口:自动判断全量还是增量。
|
||||
|
||||
- 若该用户尚无 Community 节点 → 全量初始化
|
||||
- 否则 → 增量更新(仅处理 new_entity_ids)
|
||||
"""
|
||||
has_communities = await self.repo.has_communities(end_user_id)
|
||||
if not has_communities:
|
||||
logger.info(f"[Clustering] 用户 {end_user_id} 首次聚类,执行全量初始化")
|
||||
await self.full_clustering(end_user_id)
|
||||
else:
|
||||
if new_entity_ids:
|
||||
logger.info(
|
||||
f"[Clustering] 增量更新,新实体数: {len(new_entity_ids)}"
|
||||
)
|
||||
await self.incremental_update(new_entity_ids, end_user_id)
|
||||
|
||||
async def full_clustering(self, end_user_id: str) -> None:
|
||||
"""
|
||||
全量标签传播初始化(分批处理,控制内存峰值)。
|
||||
|
||||
策略:
|
||||
- 每次只加载 BATCH_SIZE 个实体及其邻居进内存
|
||||
- labels 字典跨批次共享(只存 id→community_id,内存极小)
|
||||
- 每批独立跑 MAX_ITERATIONS 轮 LPA,批次间通过 labels 传递社区信息
|
||||
- 所有批次完成后统一 flush 和 merge
|
||||
"""
|
||||
BATCH_SIZE = 888 # 每批实体数,可按需调整
|
||||
|
||||
# 轻量查询:只获取总数和 ID 列表,不加载 embedding 等大字段
|
||||
total_count = await self.repo.get_entity_count(end_user_id)
|
||||
if not total_count:
|
||||
logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类")
|
||||
return
|
||||
|
||||
all_entity_ids = await self.repo.get_all_entity_ids(end_user_id)
|
||||
logger.info(f"[Clustering] 用户 {end_user_id} 共 {total_count} 个实体,"
|
||||
f"分批大小 {BATCH_SIZE},共 {(total_count + BATCH_SIZE - 1) // BATCH_SIZE} 批")
|
||||
|
||||
# labels 跨批次共享:只存 id→community_id,内存极小
|
||||
labels: Dict[str, str] = {eid: eid for eid in all_entity_ids}
|
||||
del all_entity_ids # 释放 ID 列表,后续按批次加载完整数据
|
||||
|
||||
for batch_start in range(0, total_count, BATCH_SIZE):
|
||||
batch_entities = await self.repo.get_entities_page(
|
||||
end_user_id, skip=batch_start, limit=BATCH_SIZE
|
||||
)
|
||||
if not batch_entities:
|
||||
break
|
||||
|
||||
batch_ids = [e["id"] for e in batch_entities]
|
||||
batch_embeddings: Dict[str, Optional[List[float]]] = {
|
||||
e["id"]: e.get("name_embedding") for e in batch_entities
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1}:"
|
||||
f"加载 {len(batch_entities)} 个实体的邻居图..."
|
||||
)
|
||||
neighbors_cache = await self.repo.get_entity_neighbors_for_ids(
|
||||
batch_ids, end_user_id
|
||||
)
|
||||
logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")
|
||||
|
||||
for iteration in range(MAX_ITERATIONS):
|
||||
changed = 0
|
||||
for entity in batch_entities:
|
||||
eid = entity["id"]
|
||||
neighbors = neighbors_cache.get(eid, [])
|
||||
|
||||
# 注入跨批次的最新标签(邻居可能在其他批次,labels 里有其最新值)
|
||||
enriched = []
|
||||
for nb in neighbors:
|
||||
nb_copy = dict(nb)
|
||||
nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
|
||||
enriched.append(nb_copy)
|
||||
|
||||
new_label = _weighted_vote(enriched, batch_embeddings.get(eid))
|
||||
if new_label and new_label != labels[eid]:
|
||||
labels[eid] = new_label
|
||||
changed += 1
|
||||
|
||||
logger.info(
|
||||
f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1} "
|
||||
f"迭代 {iteration + 1}/{MAX_ITERATIONS},标签变化数: {changed}"
|
||||
)
|
||||
if changed == 0:
|
||||
logger.info("[Clustering] 标签已收敛,提前结束本批迭代")
|
||||
break
|
||||
|
||||
# 释放本批次的大对象
|
||||
del neighbors_cache, batch_embeddings, batch_entities
|
||||
|
||||
# 所有批次完成,统一写入 Neo4j
|
||||
await self._flush_labels(labels, end_user_id)
|
||||
pre_merge_count = len(set(labels.values()))
|
||||
logger.info(
|
||||
f"[Clustering] 全量迭代完成,共 {pre_merge_count} 个社区,"
|
||||
f"{len(labels)} 个实体,开始后处理合并"
|
||||
)
|
||||
|
||||
all_community_ids = list(set(labels.values()))
|
||||
await self._evaluate_merge(all_community_ids, end_user_id)
|
||||
|
||||
logger.info(
|
||||
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
|
||||
f"{len(labels)} 个实体"
|
||||
)
|
||||
|
||||
# 查询存活社区并生成元数据
|
||||
surviving_communities = await self.repo.get_all_entities(end_user_id)
|
||||
surviving_community_ids = list({
|
||||
e.get("community_id") for e in surviving_communities
|
||||
if e.get("community_id")
|
||||
})
|
||||
logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}")
|
||||
await self._generate_community_metadata(surviving_community_ids, end_user_id)
|
||||
|
||||
async def incremental_update(
|
||||
self, new_entity_ids: List[str], end_user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
增量更新:只处理新实体及其邻居,不重跑全图。
|
||||
|
||||
1. 对每个新实体查询邻居
|
||||
2. 加权多数投票决定社区归属
|
||||
3. 若邻居无社区 → 创建新社区
|
||||
4. 若邻居分属多个社区 → 评估是否合并
|
||||
"""
|
||||
for entity_id in new_entity_ids:
|
||||
await self._process_single_entity(entity_id, end_user_id)
|
||||
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
# 内部方法
|
||||
# ──────────────────────────────────────────────────────────────────────────
|
||||
|
||||
async def _process_single_entity(
|
||||
self, entity_id: str, end_user_id: str
|
||||
) -> None:
|
||||
"""处理单个新实体的社区分配。"""
|
||||
neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id)
|
||||
|
||||
# 查询自身 embedding(从邻居查询结果中无法获取,需单独查)
|
||||
self_embedding = await self._get_entity_embedding(entity_id, end_user_id)
|
||||
|
||||
if not neighbors:
|
||||
# 孤立实体:创建单成员社区
|
||||
new_cid = self._new_community_id()
|
||||
await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
|
||||
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
|
||||
logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
|
||||
return
|
||||
|
||||
# 统计邻居社区分布
|
||||
community_ids_in_neighbors = set(
|
||||
nb["community_id"] for nb in neighbors if nb.get("community_id")
|
||||
)
|
||||
|
||||
target_cid = _weighted_vote(neighbors, self_embedding)
|
||||
|
||||
if target_cid is None:
|
||||
# 邻居都没有社区,连同新实体一起创建新社区
|
||||
new_cid = self._new_community_id()
|
||||
await self.repo.upsert_community(new_cid, end_user_id)
|
||||
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
|
||||
for nb in neighbors:
|
||||
await self.repo.assign_entity_to_community(
|
||||
nb["id"], new_cid, end_user_id
|
||||
)
|
||||
await self.repo.refresh_member_count(new_cid, end_user_id)
|
||||
logger.debug(
|
||||
f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
|
||||
)
|
||||
await self._generate_community_metadata([new_cid], end_user_id)
|
||||
else:
|
||||
# 加入得票最多的社区
|
||||
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
|
||||
await self.repo.refresh_member_count(target_cid, end_user_id)
|
||||
logger.debug(f"[Clustering] 新实体 {entity_id} → 社区 {target_cid}")
|
||||
|
||||
# 若邻居分属多个社区,评估合并
|
||||
if len(community_ids_in_neighbors) > 1:
|
||||
await self._evaluate_merge(
|
||||
list(community_ids_in_neighbors), end_user_id
|
||||
)
|
||||
await self._generate_community_metadata([target_cid], end_user_id)
|
||||
|
||||
async def _evaluate_merge(
|
||||
self, community_ids: List[str], end_user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
评估多个社区是否应合并。
|
||||
|
||||
策略:计算各社区成员 embedding 的平均向量,若两两余弦相似度 > 0.75 则合并。
|
||||
合并时保留成员数最多的社区,其余成员迁移过来。
|
||||
|
||||
全量场景(社区数 > 20)使用批量查询,避免 N 次数据库往返。
|
||||
"""
|
||||
MERGE_THRESHOLD = 0.85
|
||||
BATCH_THRESHOLD = 20 # 超过此数量走批量查询
|
||||
|
||||
community_embeddings: Dict[str, Optional[List[float]]] = {}
|
||||
community_sizes: Dict[str, int] = {}
|
||||
|
||||
if len(community_ids) > BATCH_THRESHOLD:
|
||||
# 批量查询:一次拉取所有社区成员
|
||||
all_members = await self.repo.get_all_community_members_batch(
|
||||
community_ids, end_user_id
|
||||
)
|
||||
for cid in community_ids:
|
||||
members = all_members.get(cid, [])
|
||||
community_sizes[cid] = len(members)
|
||||
valid_embeddings = [
|
||||
m["name_embedding"] for m in members if m.get("name_embedding")
|
||||
]
|
||||
if valid_embeddings:
|
||||
dim = len(valid_embeddings[0])
|
||||
community_embeddings[cid] = [
|
||||
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
|
||||
for i in range(dim)
|
||||
]
|
||||
else:
|
||||
community_embeddings[cid] = None
|
||||
else:
|
||||
# 增量场景:逐个查询
|
||||
for cid in community_ids:
|
||||
members = await self.repo.get_community_members(cid, end_user_id)
|
||||
community_sizes[cid] = len(members)
|
||||
valid_embeddings = [
|
||||
m["name_embedding"] for m in members if m.get("name_embedding")
|
||||
]
|
||||
if valid_embeddings:
|
||||
dim = len(valid_embeddings[0])
|
||||
community_embeddings[cid] = [
|
||||
sum(e[i] for e in valid_embeddings) / len(valid_embeddings)
|
||||
for i in range(dim)
|
||||
]
|
||||
else:
|
||||
community_embeddings[cid] = None
|
||||
|
||||
# 找出应合并的社区对
|
||||
to_merge: List[tuple] = []
|
||||
cids = list(community_ids)
|
||||
for i in range(len(cids)):
|
||||
for j in range(i + 1, len(cids)):
|
||||
sim = _cosine_similarity(
|
||||
community_embeddings[cids[i]],
|
||||
community_embeddings[cids[j]],
|
||||
)
|
||||
if sim > MERGE_THRESHOLD:
|
||||
to_merge.append((cids[i], cids[j]))
|
||||
|
||||
logger.info(f"[Clustering] 发现 {len(to_merge)} 对可合并社区")
|
||||
|
||||
# 执行合并:逐对处理,每次合并后重新计算合并社区的平均向量
|
||||
# 避免 union-find 链式传递导致语义不相关的社区被间接合并
|
||||
# (A≈B、B≈C 不代表 A≈C,不能因传递性把 A/B/C 全部合并)
|
||||
merged_into: Dict[str, str] = {} # dissolve → keep 的最终映射
|
||||
|
||||
def get_root(x: str) -> str:
|
||||
"""路径压缩,找到 x 当前所属的根社区。"""
|
||||
while x in merged_into:
|
||||
merged_into[x] = merged_into.get(merged_into[x], merged_into[x])
|
||||
x = merged_into[x]
|
||||
return x
|
||||
|
||||
for c1, c2 in to_merge:
|
||||
root1, root2 = get_root(c1), get_root(c2)
|
||||
if root1 == root2:
|
||||
continue
|
||||
|
||||
# 用合并后的最新平均向量重新验证相似度
|
||||
# 防止链式传递:A≈B 合并后 B 的向量已更新,C 必须和新 B 相似才能合并
|
||||
current_sim = _cosine_similarity(
|
||||
community_embeddings.get(root1),
|
||||
community_embeddings.get(root2),
|
||||
)
|
||||
if current_sim <= MERGE_THRESHOLD:
|
||||
# 合并后向量已漂移,不再满足阈值,跳过
|
||||
logger.debug(
|
||||
f"[Clustering] 跳过合并 {root1} ↔ {root2},"
|
||||
f"当前相似度 {current_sim:.3f} ≤ {MERGE_THRESHOLD}"
|
||||
)
|
||||
continue
|
||||
|
||||
keep = root1 if community_sizes.get(root1, 0) >= community_sizes.get(root2, 0) else root2
|
||||
dissolve = root2 if keep == root1 else root1
|
||||
merged_into[dissolve] = keep
|
||||
|
||||
members = await self.repo.get_community_members(dissolve, end_user_id)
|
||||
for m in members:
|
||||
await self.repo.assign_entity_to_community(m["id"], keep, end_user_id)
|
||||
|
||||
# 合并后重新计算 keep 的平均向量(加权平均)
|
||||
keep_emb = community_embeddings.get(keep)
|
||||
dissolve_emb = community_embeddings.get(dissolve)
|
||||
keep_size = community_sizes.get(keep, 0)
|
||||
dissolve_size = community_sizes.get(dissolve, 0)
|
||||
total_size = keep_size + dissolve_size
|
||||
if keep_emb and dissolve_emb and total_size > 0:
|
||||
dim = len(keep_emb)
|
||||
community_embeddings[keep] = [
|
||||
(keep_emb[i] * keep_size + dissolve_emb[i] * dissolve_size) / total_size
|
||||
for i in range(dim)
|
||||
]
|
||||
community_embeddings[dissolve] = None
|
||||
|
||||
community_sizes[keep] = total_size
|
||||
community_sizes[dissolve] = 0
|
||||
await self.repo.refresh_member_count(keep, end_user_id)
|
||||
logger.info(
|
||||
f"[Clustering] 社区合并: {dissolve} → {keep},"
|
||||
f"相似度={current_sim:.3f},迁移 {len(members)} 个成员"
|
||||
)
|
||||
|
||||
async def _flush_labels(
|
||||
self, labels: Dict[str, str], end_user_id: str
|
||||
) -> None:
|
||||
"""将内存中的标签批量写入 Neo4j。"""
|
||||
# 先创建所有唯一社区节点
|
||||
unique_communities = set(labels.values())
|
||||
for cid in unique_communities:
|
||||
await self.repo.upsert_community(cid, end_user_id)
|
||||
|
||||
# 再批量分配实体
|
||||
for entity_id, community_id in labels.items():
|
||||
await self.repo.assign_entity_to_community(
|
||||
entity_id, community_id, end_user_id
|
||||
)
|
||||
|
||||
# 刷新成员数
|
||||
for cid in unique_communities:
|
||||
await self.repo.refresh_member_count(cid, end_user_id)
|
||||
|
||||
async def _get_entity_embedding(
|
||||
self, entity_id: str, end_user_id: str
|
||||
) -> Optional[List[float]]:
|
||||
"""查询单个实体的 name_embedding。"""
|
||||
try:
|
||||
result = await self.connector.execute_query(
|
||||
"MATCH (e:ExtractedEntity {id: $eid, end_user_id: $uid}) "
|
||||
"RETURN e.name_embedding AS name_embedding",
|
||||
eid=entity_id,
|
||||
uid=end_user_id,
|
||||
)
|
||||
return result[0]["name_embedding"] if result else None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _build_entity_lines(members: List[Dict]) -> List[str]:
|
||||
"""将实体列表格式化为 prompt 行,包含 name、aliases、description、example。"""
|
||||
lines = []
|
||||
for m in members:
|
||||
m_name = m.get("name", "")
|
||||
aliases = m.get("aliases") or []
|
||||
description = m.get("description") or ""
|
||||
example = m.get("example") or ""
|
||||
aliases_str = f"(别名:{'、'.join(aliases)})" if aliases else ""
|
||||
desc_str = f":{description}" if description else ""
|
||||
example_str = f"(示例:{example})" if example else ""
|
||||
lines.append(f"- {m_name}{aliases_str}{desc_str}{example_str}")
|
||||
return lines
|
||||
|
||||
async def _generate_community_metadata(
|
||||
self, community_ids: List[str], end_user_id: str
|
||||
) -> None:
|
||||
"""
|
||||
为一个或多个社区生成并写入元数据。
|
||||
|
||||
流程:
|
||||
1. 逐个社区调 LLM 生成 name / summary(串行)
|
||||
2. 收集所有 summary,一次性批量 embed
|
||||
3. 单个社区用 update_community_metadata,多个用 batch_update_community_metadata
|
||||
"""
|
||||
if not community_ids:
|
||||
return
|
||||
|
||||
from app.db import get_db_context
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
|
||||
# --- 阶段1:并发调 LLM 生成每个社区的 name / summary ---
|
||||
async def _build_one(cid: str):
|
||||
members = await self.repo.get_community_members(cid, end_user_id)
|
||||
if not members:
|
||||
return None
|
||||
|
||||
sorted_members = sorted(
|
||||
members,
|
||||
key=lambda m: m.get("activation_value") or 0,
|
||||
reverse=True,
|
||||
)
|
||||
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
|
||||
|
||||
entity_list_str = "\n".join(self._build_entity_lines(members))
|
||||
|
||||
# 方案四:注入社区内实体间关系三元组
|
||||
relationships = await self.repo.get_community_relationships(cid, end_user_id)
|
||||
rel_lines = [
|
||||
f"- {r['subject']} → {r['predicate']} → {r['object']}"
|
||||
for r in relationships
|
||||
if r.get("subject") and r.get("predicate") and r.get("object")
|
||||
]
|
||||
rel_section = (
|
||||
f"\n实体间关系:\n" + "\n".join(rel_lines)
|
||||
if rel_lines else ""
|
||||
)
|
||||
|
||||
prompt = (
|
||||
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
|
||||
f"请为这组实体所代表的主题:\n"
|
||||
f"1. 起一个简洁的中文名称(不超过10个字)\n"
|
||||
f"2. 写一句话摘要(不超过80个字)\n\n"
|
||||
f"严格按以下格式输出,不要有其他内容:\n"
|
||||
f"名称:<名称>\n摘要:<摘要>"
|
||||
)
|
||||
with get_db_context() as db:
|
||||
llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
|
||||
response = await llm_client.chat([{"role": "user", "content": prompt}])
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
|
||||
name, summary = "", ""
|
||||
for line in text.strip().splitlines():
|
||||
if line.startswith("名称:"):
|
||||
name = line[3:].strip()
|
||||
elif line.startswith("摘要:"):
|
||||
summary = line[3:].strip()
|
||||
|
||||
return {
|
||||
"community_id": cid,
|
||||
"end_user_id": end_user_id,
|
||||
"name": name,
|
||||
"summary": summary,
|
||||
"core_entities": core_entities,
|
||||
"summary_embedding": None,
|
||||
}
|
||||
|
||||
results = await asyncio.gather(
|
||||
*[_build_one(cid) for cid in community_ids],
|
||||
return_exceptions=True,
|
||||
)
|
||||
metadata_list = []
|
||||
for cid, res in zip(community_ids, results):
|
||||
if isinstance(res, Exception):
|
||||
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {res}", exc_info=res)
|
||||
elif res is not None:
|
||||
metadata_list.append(res)
|
||||
|
||||
if not metadata_list:
|
||||
return
|
||||
|
||||
# --- 阶段2:批量生成 summary_embedding ---
|
||||
summaries = [m["summary"] for m in metadata_list]
|
||||
with get_db_context() as db:
|
||||
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
|
||||
embeddings = await embedder.response(summaries)
|
||||
for i, meta in enumerate(metadata_list):
|
||||
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
|
||||
|
||||
# --- 阶段3:写入(单个 or 批量)---
|
||||
if len(metadata_list) == 1:
|
||||
m = metadata_list[0]
|
||||
result = await self.repo.update_community_metadata(
|
||||
community_id=m["community_id"],
|
||||
end_user_id=m["end_user_id"],
|
||||
name=m["name"],
|
||||
summary=m["summary"],
|
||||
core_entities=m["core_entities"],
|
||||
summary_embedding=m["summary_embedding"],
|
||||
)
|
||||
if result:
|
||||
logger.info(f"[Clustering] 社区 {m['community_id']} 元数据写入成功: name={m['name']}, summary={m['summary'][:30]}...")
|
||||
else:
|
||||
logger.warning(f"[Clustering] 社区 {m['community_id']} 元数据写入返回 False")
|
||||
else:
|
||||
ok = await self.repo.batch_update_community_metadata(metadata_list)
|
||||
if ok:
|
||||
logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
|
||||
else:
|
||||
logger.warning(f"[Clustering] 批量写入社区元数据失败")
|
||||
|
||||
@staticmethod
|
||||
def _new_community_id() -> str:
|
||||
return str(uuid.uuid4())
|
||||
@@ -20,7 +20,6 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext
|
||||
from app.core.memory.models.config_models import PruningConfig
|
||||
from app.core.memory.utils.config.config_utils import get_pruning_config
|
||||
from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering
|
||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import (
|
||||
SceneConfigRegistry,
|
||||
@@ -33,6 +32,9 @@ class DialogExtractionResponse(BaseModel):
|
||||
|
||||
- is_related:对话与场景的相关性判定。
|
||||
- times / ids / amounts / contacts / addresses / keywords:重要信息片段,用来在不相关对话中保留关键消息。
|
||||
- preserve_keywords:情绪/兴趣/爱好/个人观点相关词,包含这些词的消息必须强制保留。
|
||||
- scene_unrelated_snippets:与当前场景无关且无语义关联的消息片段(原文截取),
|
||||
用于高阈值阶段精准删除跨场景内容。
|
||||
"""
|
||||
is_related: bool = Field(...)
|
||||
times: List[str] = Field(default_factory=list)
|
||||
@@ -41,6 +43,8 @@ class DialogExtractionResponse(BaseModel):
|
||||
contacts: List[str] = Field(default_factory=list)
|
||||
addresses: List[str] = Field(default_factory=list)
|
||||
keywords: List[str] = Field(default_factory=list)
|
||||
preserve_keywords: List[str] = Field(default_factory=list, description="情绪/兴趣/爱好/个人观点相关词,包含这些词的消息强制保留")
|
||||
scene_unrelated_snippets: List[str] = Field(default_factory=list,description="与当前场景无关且无语义关联的消息原文片段,高阈值阶段用于精准删除跨场景内容")
|
||||
|
||||
|
||||
class MessageImportanceResponse(BaseModel):
|
||||
@@ -86,26 +90,19 @@ class SemanticPruner:
|
||||
self._detailed_prune_logging = True # 是否启用详细日志
|
||||
self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志
|
||||
|
||||
# 加载场景特定配置(内置场景走专门规则,自定义场景 fallback 到通用规则)
|
||||
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(
|
||||
self.config.pruning_scene,
|
||||
fallback_to_generic=True
|
||||
)
|
||||
# 加载统一填充词库
|
||||
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(self.config.pruning_scene)
|
||||
|
||||
# 判断是否为内置专门场景
|
||||
self._is_builtin_scene = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene)
|
||||
# 本体类型列表:直接使用 ontology_class_infos(name + description)
|
||||
self._ontology_class_infos = getattr(self.config, "ontology_class_infos", None) or []
|
||||
# _ontology_classes 仅用于日志统计
|
||||
self._ontology_classes = [info.class_name for info in self._ontology_class_infos]
|
||||
|
||||
# 自定义场景的本体类型列表(用于注入提示词)
|
||||
self._ontology_classes = getattr(self.config, "ontology_classes", None) or []
|
||||
|
||||
if self._is_builtin_scene:
|
||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用内置专门配置")
|
||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene}")
|
||||
if self._ontology_class_infos:
|
||||
self._log(f"[剪枝-初始化] 注入本体类型({len(self._ontology_class_infos)}个): {self._ontology_classes}")
|
||||
else:
|
||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 为自定义场景,使用通用规则 + 本体类型提示词注入")
|
||||
if self._ontology_classes:
|
||||
self._log(f"[剪枝-初始化] 注入本体类型: {self._ontology_classes}")
|
||||
else:
|
||||
self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
|
||||
self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
|
||||
|
||||
# Load Jinja2 template
|
||||
self.template = prompt_env.get_template("extracat_Pruning.jinja2")
|
||||
@@ -117,107 +114,28 @@ class SemanticPruner:
|
||||
# 运行日志:收集关键终端输出,便于写入 JSON
|
||||
self.run_logs: List[str] = []
|
||||
|
||||
def _is_important_message(self, message: ConversationMessage) -> bool:
|
||||
"""基于启发式规则识别重要信息消息,优先保留。
|
||||
|
||||
改进版:使用场景特定的模式进行识别
|
||||
- 根据 pruning_scene 动态加载对应的识别规则
|
||||
- 支持教育、在线服务、外呼三个场景的特定模式
|
||||
"""
|
||||
text = message.msg.strip()
|
||||
if not text:
|
||||
return False
|
||||
|
||||
# 使用场景特定的模式
|
||||
all_patterns = (
|
||||
self.scene_config.high_priority_patterns +
|
||||
self.scene_config.medium_priority_patterns +
|
||||
self.scene_config.low_priority_patterns
|
||||
)
|
||||
|
||||
for pattern, _ in all_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
return True
|
||||
|
||||
# 检查是否为问句(以问号结尾或包含疑问词)
|
||||
if text.endswith("?") or text.endswith("?"):
|
||||
return True
|
||||
|
||||
# 检查是否包含问句关键词
|
||||
if any(keyword in text for keyword in self.scene_config.question_keywords):
|
||||
return True
|
||||
|
||||
# 检查是否包含决策性关键词
|
||||
if any(keyword in text for keyword in self.scene_config.decision_keywords):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _importance_score(self, message: ConversationMessage) -> int:
|
||||
"""为重要消息打分,用于在保留比例内优先保留更关键的内容。
|
||||
|
||||
改进版:使用场景特定的权重体系(0-10分)
|
||||
- 根据场景动态调整不同信息类型的权重
|
||||
- 高优先级模式:4-6分
|
||||
- 中优先级模式:2-3分
|
||||
- 低优先级模式:1分
|
||||
"""
|
||||
text = message.msg.strip()
|
||||
score = 0
|
||||
|
||||
# 使用场景特定的权重
|
||||
for pattern, weight in self.scene_config.high_priority_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
score += weight
|
||||
|
||||
for pattern, weight in self.scene_config.medium_priority_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
score += weight
|
||||
|
||||
for pattern, weight in self.scene_config.low_priority_patterns:
|
||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
||||
score += weight
|
||||
|
||||
# 问句加分
|
||||
if text.endswith("?") or text.endswith("?"):
|
||||
score += 2
|
||||
|
||||
# 包含问句关键词加分
|
||||
if any(keyword in text for keyword in self.scene_config.question_keywords):
|
||||
score += 1
|
||||
|
||||
# 包含决策性关键词加分
|
||||
if any(keyword in text for keyword in self.scene_config.decision_keywords):
|
||||
score += 2
|
||||
|
||||
# 长度加分(较长的消息通常包含更多信息)
|
||||
if len(text) > 50:
|
||||
score += 1
|
||||
if len(text) > 100:
|
||||
score += 1
|
||||
|
||||
return min(score, 10) # 最高10分
|
||||
# _is_important_message 和 _importance_score 已移除:
|
||||
# 重要性判断完全由 extracat_Pruning.jinja2 提示词 + LLM 的 preserve_tokens 机制承担。
|
||||
# LLM 根据注入的本体工程类型语义识别需要保护的内容,无需硬编码正则规则。
|
||||
|
||||
def _is_filler_message(self, message: ConversationMessage) -> bool:
|
||||
"""检测典型寒暄/口头禅/确认类短消息。
|
||||
|
||||
改进版:更严格的填充消息判断,避免误删场景相关内容
|
||||
满足以下之一视为填充消息:
|
||||
- 纯标点或空白
|
||||
- 在场景特定填充词库中(精确匹配)
|
||||
- 纯表情符号
|
||||
- 常见寒暄(精确匹配短语)
|
||||
|
||||
注意:不再使用长度判断,避免误删短但重要的消息
|
||||
判断顺序:
|
||||
1. 空消息
|
||||
2. 场景特定填充词库精确匹配
|
||||
3. 常见寒暄精确匹配
|
||||
4. 组合寒暄模式(前缀+后缀组合,如"好的谢谢"、"同学你好"、"明白了")
|
||||
5. 纯表情/标点
|
||||
"""
|
||||
t = message.msg.strip()
|
||||
if not t:
|
||||
return True
|
||||
|
||||
|
||||
# 检查是否在场景特定填充词库中(精确匹配)
|
||||
if t in self.scene_config.filler_phrases:
|
||||
return True
|
||||
|
||||
|
||||
# 常见寒暄和问候(精确匹配,避免误删)
|
||||
common_greetings = {
|
||||
"在吗", "在不在", "在呢", "在的",
|
||||
@@ -229,25 +147,60 @@ class SemanticPruner:
|
||||
}
|
||||
if t in common_greetings:
|
||||
return True
|
||||
|
||||
|
||||
# 组合寒暄模式:短消息(≤15字)且完全由寒暄成分构成
|
||||
# 策略:将消息拆分后,每个片段都能在填充词库或常见寒暄中找到,则整体为填充
|
||||
if len(t) <= 15:
|
||||
# 确认+称呼/感谢组合,如"好的谢谢"、"明白了"、"知道了谢谢"
|
||||
_confirm_prefixes = {"好的", "好", "嗯", "嗯嗯", "哦", "明白", "明白了", "知道了", "了解", "收到", "没问题"}
|
||||
_thanks_suffixes = {"谢谢", "谢谢你", "谢谢您", "多谢", "感谢", "谢了"}
|
||||
_greeting_suffixes = {"你好", "您好", "老师好", "同学好", "大家好"}
|
||||
_greeting_prefixes = {"同学", "老师", "您好", "你好"}
|
||||
_close_patterns = {
|
||||
"没有了", "没事了", "没问题了", "好了", "行了", "可以了",
|
||||
"不用了", "不需要了", "就这样", "就这样吧", "那就这样",
|
||||
}
|
||||
_polite_responses = {
|
||||
"不客气", "不用谢", "没关系", "没事", "应该的", "这是我应该做的",
|
||||
}
|
||||
|
||||
# 规则1:确认词 + 感谢词(如"好的谢谢"、"嗯谢谢")
|
||||
for cp in _confirm_prefixes:
|
||||
for ts in _thanks_suffixes:
|
||||
if t == cp + ts or t == cp + "," + ts or t == cp + "," + ts:
|
||||
return True
|
||||
|
||||
# 规则2:称呼前缀 + 问候(如"同学你好"、"老师好")
|
||||
for gp in _greeting_prefixes:
|
||||
for gs in _greeting_suffixes:
|
||||
if t == gp + gs or t.startswith(gp) and t.endswith("好"):
|
||||
return True
|
||||
|
||||
# 规则3:结束语 + 感谢(如"没有了,谢谢老师"、"没有了谢谢")
|
||||
for cp in _close_patterns:
|
||||
if t.startswith(cp):
|
||||
remainder = t[len(cp):].lstrip(",,、 ")
|
||||
if not remainder or any(remainder.startswith(ts) for ts in _thanks_suffixes):
|
||||
return True
|
||||
|
||||
# 规则4:礼貌回应(如"不客气,祝你考试顺利"——前缀是礼貌词,后半是祝福套话)
|
||||
for pr in _polite_responses:
|
||||
if t.startswith(pr):
|
||||
remainder = t[len(pr):].lstrip(",,、 ")
|
||||
# 后半是祝福/套话(不含实质信息)
|
||||
if not remainder or re.match(r"^(祝|希望|期待|加油|顺利|好好|保重)", remainder):
|
||||
return True
|
||||
|
||||
# 规则5:纯确认词加"了"后缀(如"明白了"、"知道了"、"好了")
|
||||
_confirm_base = {"明白", "知道", "了解", "收到", "好", "行", "可以", "没问题"}
|
||||
for cb in _confirm_base:
|
||||
if t == cb + "了" or t == cb + "了。" or t == cb + "了!":
|
||||
return True
|
||||
|
||||
# 检查是否为纯表情符号(方括号包裹)
|
||||
if re.fullmatch(r"(\[[^\]]+\])+", t):
|
||||
return True
|
||||
|
||||
# 检查是否为纯emoji(Unicode表情)
|
||||
emoji_pattern = re.compile(
|
||||
"["
|
||||
"\U0001F600-\U0001F64F" # 表情符号
|
||||
"\U0001F300-\U0001F5FF" # 符号和象形文字
|
||||
"\U0001F680-\U0001F6FF" # 交通和地图符号
|
||||
"\U0001F1E0-\U0001F1FF" # 旗帜
|
||||
"\U00002702-\U000027B0"
|
||||
"\U000024C2-\U0001F251"
|
||||
"]+", flags=re.UNICODE
|
||||
)
|
||||
if emoji_pattern.fullmatch(t):
|
||||
return True
|
||||
|
||||
# 纯标点符号
|
||||
if re.fullmatch(r"[。!?,.!?…·\s]+", t):
|
||||
return True
|
||||
@@ -432,15 +385,13 @@ class SemanticPruner:
|
||||
|
||||
rendered = self.template.render(
|
||||
pruning_scene=self.config.pruning_scene,
|
||||
is_builtin_scene=self._is_builtin_scene,
|
||||
ontology_classes=self._ontology_classes,
|
||||
ontology_class_infos=self._ontology_class_infos,
|
||||
dialog_text=dialog_text,
|
||||
language=self.language
|
||||
)
|
||||
log_template_rendering("extracat_Pruning.jinja2", {
|
||||
"pruning_scene": self.config.pruning_scene,
|
||||
"is_builtin_scene": self._is_builtin_scene,
|
||||
"ontology_classes_count": len(self._ontology_classes),
|
||||
"ontology_class_infos_count": len(self._ontology_class_infos),
|
||||
"language": self.language
|
||||
})
|
||||
log_prompt_rendering("pruning-extract", rendered)
|
||||
@@ -480,6 +431,183 @@ class SemanticPruner:
|
||||
)
|
||||
return fallback_response
|
||||
|
||||
def _get_pruning_mode(self) -> str:
|
||||
"""根据 pruning_threshold 返回当前剪枝阶段。
|
||||
|
||||
- 低阈值 [0.0, 0.3):conservative 只删填充,保留所有实质内容
|
||||
- 中阈值 [0.3, 0.6):semantic 保留场景相关 + 有语义关联的内容,删除无关联内容
|
||||
- 高阈值 [0.6, 0.9]:strict 只保留场景相关内容,跨场景内容可被删除
|
||||
"""
|
||||
t = float(self.config.pruning_threshold)
|
||||
if t < 0.3:
|
||||
return "conservative"
|
||||
elif t < 0.6:
|
||||
return "semantic"
|
||||
else:
|
||||
return "strict"
|
||||
|
||||
def _apply_related_dialog_pruning(
|
||||
self,
|
||||
msgs: List[ConversationMessage],
|
||||
extraction: "DialogExtractionResponse",
|
||||
dialog_label: str,
|
||||
pruning_mode: str,
|
||||
) -> List[ConversationMessage]:
|
||||
"""相关对话统一剪枝入口,消除 prune_dialog / prune_dataset 中的重复逻辑。
|
||||
|
||||
- conservative:只删填充
|
||||
- semantic / strict:场景感知剪枝
|
||||
"""
|
||||
if pruning_mode == "conservative":
|
||||
preserve_tokens = self._build_preserve_tokens(extraction)
|
||||
return self._prune_fillers_only(msgs, preserve_tokens, dialog_label)
|
||||
else:
|
||||
return self._prune_with_scene_filter(msgs, extraction, dialog_label, pruning_mode)
|
||||
|
||||
def _prune_fillers_only(
|
||||
self,
|
||||
msgs: List[ConversationMessage],
|
||||
preserve_tokens: List[str],
|
||||
dialog_label: str,
|
||||
) -> List[ConversationMessage]:
|
||||
"""相关对话专用:只删填充消息,LLM 保护消息和实质内容一律保留。
|
||||
|
||||
不受 pruning_threshold 约束,删多少算多少(填充有多少删多少)。
|
||||
至少保留 1 条消息。
|
||||
注意:填充检测优先于 preserve_tokens 保护——填充消息本身无信息价值,
|
||||
即使 LLM 误将其关键词放入 preserve_tokens 也应删除。
|
||||
"""
|
||||
to_delete_ids: set = set()
|
||||
for m in msgs:
|
||||
# 填充检测优先:先判断是否为填充,再看 LLM 保护
|
||||
if self._is_filler_message(m):
|
||||
to_delete_ids.add(id(m))
|
||||
self._log(f" [填充] '{m.msg[:40]}' → 删除")
|
||||
continue
|
||||
if self._msg_matches_tokens(m, preserve_tokens):
|
||||
self._log(f" [保护] '{m.msg[:40]}' → LLM保护,跳过")
|
||||
|
||||
kept = [m for m in msgs if id(m) not in to_delete_ids]
|
||||
if not kept and msgs:
|
||||
kept = [msgs[0]]
|
||||
|
||||
deleted = len(msgs) - len(kept)
|
||||
self._log(
|
||||
f"[剪枝-相关] {dialog_label} 总消息={len(msgs)} "
|
||||
f"填充删除={deleted} 保留={len(kept)}"
|
||||
)
|
||||
return kept
|
||||
|
||||
def _prune_with_scene_filter(
|
||||
self,
|
||||
msgs: List[ConversationMessage],
|
||||
extraction: "DialogExtractionResponse",
|
||||
dialog_label: str,
|
||||
mode: str,
|
||||
) -> List[ConversationMessage]:
|
||||
"""场景感知剪枝,供 semantic / strict 两个阈值档位调用。
|
||||
|
||||
本函数体现剪枝系统的三层递进逻辑:
|
||||
|
||||
第一层(conservative,阈值 < 0.3):
|
||||
不进入本函数,由 _prune_fillers_only 处理。
|
||||
保留标准:只问"有没有信息量",填充消息(嗯/好的/哈哈等)删除,其余一律保留。
|
||||
|
||||
第二层(semantic,阈值 [0.3, 0.6)):
|
||||
保留标准:内容价值优先,场景相关性是参考而非唯一标准。
|
||||
- 填充消息 → 删除(最高优先级)
|
||||
- 场景相关消息 → 保留
|
||||
- 场景无关消息 → 有两次豁免机会:
|
||||
1. 命中 scene_preserve_tokens(LLM 标记的关键词/时间/金额等)→ 保留
|
||||
2. 含情感词(感觉/压力/开心等)→ 保留(情感内容有记忆价值)
|
||||
3. 两次豁免均未命中 → 删除
|
||||
|
||||
第三层(strict,阈值 [0.6, 0.9]):
|
||||
保留标准:场景相关性优先,无任何豁免。
|
||||
- 填充消息 → 删除(最高优先级)
|
||||
- 场景相关消息 → 保留
|
||||
- 场景无关消息 → 直接删除,preserve_keywords 和情感词在此模式下均不生效
|
||||
|
||||
至少保留 1 条消息(兜底取第一条)。
|
||||
"""
|
||||
# strict 模式收窄保护范围:只保护结构化关键信息(时间/编号/金额/联系方式/地址),
|
||||
# 不保护 keywords / preserve_keywords,让场景过滤能删掉更多内容。
|
||||
# semantic 模式完整保护:包含 LLM 抽取的所有重要片段(含 keywords 和 preserve_keywords)。
|
||||
if mode == "strict":
|
||||
scene_preserve_tokens = (
|
||||
extraction.times + extraction.ids + extraction.amounts +
|
||||
extraction.contacts + extraction.addresses
|
||||
)
|
||||
else:
|
||||
scene_preserve_tokens = self._build_preserve_tokens(extraction)
|
||||
|
||||
unrelated_snippets = extraction.scene_unrelated_snippets or []
|
||||
|
||||
to_delete_ids: set = set()
|
||||
for m in msgs:
|
||||
msg_text = m.msg.strip()
|
||||
|
||||
# 第一优先级:填充消息无论模式直接删除,不参与后续场景判断
|
||||
if self._is_filler_message(m):
|
||||
to_delete_ids.add(id(m))
|
||||
self._log(f" [填充] '{msg_text[:40]}' → 删除")
|
||||
continue
|
||||
|
||||
# 双向包含匹配:处理 LLM 返回片段与原始消息文本长度不完全一致的情况
|
||||
is_scene_unrelated = any(
|
||||
snip and (snip in msg_text or msg_text in snip)
|
||||
for snip in unrelated_snippets
|
||||
)
|
||||
|
||||
if is_scene_unrelated:
|
||||
if mode == "strict":
|
||||
# strict:场景无关直接删除,不做任何豁免
|
||||
# 场景相关性是唯一裁决标准,preserve_keywords 在此模式下不生效
|
||||
to_delete_ids.add(id(m))
|
||||
self._log(f" [场景无关-严格] '{msg_text[:40]}' → 删除")
|
||||
elif mode == "semantic":
|
||||
# semantic:场景无关但有内容价值 → 保留
|
||||
# 豁免第一层:命中 scene_preserve_tokens(关键词/结构化信息保护)
|
||||
if self._msg_matches_tokens(m, scene_preserve_tokens):
|
||||
self._log(f" [保护] '{msg_text[:40]}' → 场景关键词保护,保留")
|
||||
else:
|
||||
# 豁免第二层:含情感词,认为有情境记忆价值,即使场景无关也保留
|
||||
has_contextual_emotion = any(
|
||||
word in msg_text
|
||||
for word in ["感觉", "觉得", "心情", "开心", "难过", "高兴", "沮丧",
|
||||
"喜欢", "讨厌", "爱", "恨", "担心", "害怕", "兴奋",
|
||||
"压力", "累", "疲惫", "烦", "焦虑", "委屈", "感动"]
|
||||
)
|
||||
if not has_contextual_emotion:
|
||||
to_delete_ids.add(id(m))
|
||||
self._log(f" [场景无关-语义] '{msg_text[:40]}' → 删除(无情感关联)")
|
||||
else:
|
||||
self._log(f" [场景关联-保留] '{msg_text[:40]}' → 有情感关联,保留")
|
||||
else:
|
||||
# 不在 scene_unrelated_snippets 中 → 场景相关,直接保留
|
||||
if self._msg_matches_tokens(m, scene_preserve_tokens):
|
||||
self._log(f" [保护] '{msg_text[:40]}' → LLM保护,跳过")
|
||||
# else: 普通场景相关消息,保留,不输出日志
|
||||
|
||||
kept = [m for m in msgs if id(m) not in to_delete_ids]
|
||||
if not kept and msgs:
|
||||
kept = [msgs[0]]
|
||||
|
||||
deleted = len(msgs) - len(kept)
|
||||
self._log(
|
||||
f"[剪枝-{mode}] {dialog_label} 总消息={len(msgs)} "
|
||||
f"删除={deleted} 保留={len(kept)}"
|
||||
)
|
||||
return kept
|
||||
|
||||
def _build_preserve_tokens(self, extraction: "DialogExtractionResponse") -> List[str]:
|
||||
"""统一构建 preserve_tokens,合并 LLM 抽取的所有重要片段。"""
|
||||
return (
|
||||
extraction.times + extraction.ids + extraction.amounts +
|
||||
extraction.contacts + extraction.addresses + extraction.keywords +
|
||||
extraction.preserve_keywords
|
||||
)
|
||||
|
||||
def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool:
|
||||
"""判断消息是否包含任意抽取到的重要片段。"""
|
||||
if not tokens:
|
||||
@@ -500,66 +628,62 @@ class SemanticPruner:
|
||||
|
||||
proportion = float(self.config.pruning_threshold)
|
||||
extraction = await self._extract_dialog_important(dialog.content)
|
||||
pruning_mode = self._get_pruning_mode()
|
||||
self._log(f"[剪枝-模式] 阈值={proportion} → 模式={pruning_mode}")
|
||||
|
||||
if extraction.is_related:
|
||||
# 相关对话不剪枝
|
||||
kept = self._apply_related_dialog_pruning(
|
||||
dialog.context.msgs, extraction, f"对话ID={dialog.id}", pruning_mode
|
||||
)
|
||||
dialog.context = ConversationContext(msgs=kept)
|
||||
return dialog
|
||||
|
||||
# 在不相关对话中,识别重要/不重要消息
|
||||
tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords
|
||||
# 在不相关对话中,LLM 已通过 preserve_tokens 标记需要保护的内容
|
||||
preserve_tokens = self._build_preserve_tokens(extraction)
|
||||
msgs = dialog.context.msgs
|
||||
imp_unrel_msgs: List[ConversationMessage] = []
|
||||
unimp_unrel_msgs: List[ConversationMessage] = []
|
||||
|
||||
# 分类:填充 / 其他可删(LLM保护消息通过不加入任何桶来隐式保护)
|
||||
filler_ids: set = set()
|
||||
deletable: List[ConversationMessage] = []
|
||||
|
||||
for m in msgs:
|
||||
if self._msg_matches_tokens(m, tokens) or self._is_important_message(m):
|
||||
imp_unrel_msgs.append(m)
|
||||
if self._msg_matches_tokens(m, preserve_tokens):
|
||||
pass # 保护消息:不加入任何桶,不会被删除
|
||||
elif self._is_filler_message(m):
|
||||
filler_ids.add(id(m))
|
||||
else:
|
||||
unimp_unrel_msgs.append(m)
|
||||
# 计算总删除目标数量
|
||||
deletable.append(m)
|
||||
|
||||
# 计算删除目标
|
||||
total_unrel = len(msgs)
|
||||
delete_target = int(total_unrel * proportion)
|
||||
if proportion > 0 and total_unrel > 0 and delete_target == 0:
|
||||
delete_target = 1
|
||||
imp_del_cap = min(int(len(imp_unrel_msgs) * proportion), len(imp_unrel_msgs))
|
||||
unimp_del_cap = len(unimp_unrel_msgs)
|
||||
max_capacity = max(0, len(msgs) - 1)
|
||||
max_deletable = min(imp_del_cap + unimp_del_cap, max_capacity)
|
||||
max_deletable = min(len(filler_ids) + len(deletable), max(0, total_unrel - 1))
|
||||
delete_target = min(delete_target, max_deletable)
|
||||
# 删除配额分配
|
||||
del_unimp = min(delete_target, unimp_del_cap)
|
||||
rem = delete_target - del_unimp
|
||||
del_imp = min(rem, imp_del_cap)
|
||||
|
||||
# 选取删除集合
|
||||
unimp_delete_ids = []
|
||||
imp_delete_ids = []
|
||||
if del_unimp > 0:
|
||||
# 按出现顺序选取前 del_unimp 条不重要消息进行删除(确定性、可复现)
|
||||
unimp_delete_ids = [id(m) for m in unimp_unrel_msgs[:del_unimp]]
|
||||
if del_imp > 0:
|
||||
imp_sorted = sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))
|
||||
imp_delete_ids = [id(m) for m in imp_sorted[:del_imp]]
|
||||
|
||||
# 统计实际删除数量(重要/不重要)
|
||||
actual_unimp_deleted = 0
|
||||
actual_imp_deleted = 0
|
||||
kept_msgs = []
|
||||
delete_targets = set(unimp_delete_ids) | set(imp_delete_ids)
|
||||
# 优先删填充,再删其他可删消息(按出现顺序)
|
||||
to_delete_ids: set = set()
|
||||
for m in msgs:
|
||||
mid = id(m)
|
||||
if mid in delete_targets:
|
||||
if mid in set(unimp_delete_ids) and actual_unimp_deleted < del_unimp:
|
||||
actual_unimp_deleted += 1
|
||||
continue
|
||||
if mid in set(imp_delete_ids) and actual_imp_deleted < del_imp:
|
||||
actual_imp_deleted += 1
|
||||
continue
|
||||
kept_msgs.append(m)
|
||||
if len(to_delete_ids) >= delete_target:
|
||||
break
|
||||
if id(m) in filler_ids:
|
||||
to_delete_ids.add(id(m))
|
||||
for m in deletable:
|
||||
if len(to_delete_ids) >= delete_target:
|
||||
break
|
||||
to_delete_ids.add(id(m))
|
||||
|
||||
kept_msgs = [m for m in msgs if id(m) not in to_delete_ids]
|
||||
if not kept_msgs and msgs:
|
||||
kept_msgs = [msgs[0]]
|
||||
|
||||
deleted_total = actual_unimp_deleted + actual_imp_deleted
|
||||
deleted_total = len(msgs) - len(kept_msgs)
|
||||
protected_count = len(msgs) - len(filler_ids) - len(deletable)
|
||||
self._log(
|
||||
f"[剪枝-对话] 对话ID={dialog.id} 总消息={len(msgs)} 删除目标={delete_target} 实删={deleted_total} 保留={len(kept_msgs)}"
|
||||
f"[剪枝-对话] 对话ID={dialog.id} 总消息={len(msgs)} "
|
||||
f"(保护={protected_count} 填充={len(filler_ids)} 可删={len(deletable)}) "
|
||||
f"删除目标={delete_target} 实删={deleted_total} 保留={len(kept_msgs)}"
|
||||
)
|
||||
|
||||
dialog.context = ConversationContext(msgs=kept_msgs)
|
||||
@@ -590,140 +714,192 @@ class SemanticPruner:
|
||||
self._log(
|
||||
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch} 模式=消息级独立判断"
|
||||
)
|
||||
|
||||
|
||||
pruning_mode = self._get_pruning_mode()
|
||||
self._log(f"[剪枝-数据集] 阈值={proportion} → 剪枝阶段={pruning_mode}")
|
||||
|
||||
result: List[DialogData] = []
|
||||
total_original_msgs = 0
|
||||
total_deleted_msgs = 0
|
||||
|
||||
for d_idx, dd in enumerate(dialogs):
|
||||
|
||||
# 统计对象:直接收集结构化数据,无需事后正则解析
|
||||
stats = {
|
||||
"scene": self.config.pruning_scene,
|
||||
"dialog_total": len(dialogs),
|
||||
"deletion_ratio": proportion,
|
||||
"enabled": self.config.pruning_switch,
|
||||
"pruning_mode": pruning_mode,
|
||||
"related_count": 0,
|
||||
"unrelated_count": 0,
|
||||
"related_indices": [],
|
||||
"unrelated_indices": [],
|
||||
"total_deleted_messages": 0,
|
||||
"remaining_dialogs": 0,
|
||||
"dialogs": [],
|
||||
}
|
||||
|
||||
# 并发执行所有对话的 LLM 抽取(获取 preserve_keywords 等保护信息)
|
||||
semaphore = asyncio.Semaphore(self.max_concurrent)
|
||||
|
||||
async def extract_with_semaphore(dd: DialogData) -> DialogExtractionResponse:
|
||||
async with semaphore:
|
||||
try:
|
||||
return await self._extract_dialog_important(dd.content)
|
||||
except Exception as e:
|
||||
self._log(f"[剪枝-LLM] 对话抽取失败,使用降级策略: {str(e)[:100]}")
|
||||
return DialogExtractionResponse(is_related=True)
|
||||
|
||||
extraction_tasks = [extract_with_semaphore(dd) for dd in dialogs]
|
||||
extraction_results: List[DialogExtractionResponse] = await asyncio.gather(*extraction_tasks)
|
||||
|
||||
for d_idx, (dd, extraction) in enumerate(zip(dialogs, extraction_results)):
|
||||
msgs = dd.context.msgs
|
||||
original_count = len(msgs)
|
||||
total_original_msgs += original_count
|
||||
|
||||
# ========== 问答对保护(已注释,暂不启用,留作观察) ==========
|
||||
# qa_pairs = self._identify_qa_pairs(msgs)
|
||||
# protected_indices = self._get_protected_indices(msgs, qa_pairs, window_size=0)
|
||||
# ========================================================
|
||||
|
||||
# 消息级分类:每条消息独立判断
|
||||
important_msgs = [] # 重要消息(保留)
|
||||
unimportant_msgs = [] # 不重要消息(可删除)
|
||||
filler_msgs = [] # 填充消息(优先删除)
|
||||
|
||||
# 判断是否需要详细日志(仅对前N条消息记录)
|
||||
|
||||
# 相关对话:根据阶段决定处理力度
|
||||
if extraction.is_related:
|
||||
stats["related_count"] += 1
|
||||
stats["related_indices"].append(d_idx + 1)
|
||||
kept = self._apply_related_dialog_pruning(
|
||||
msgs, extraction, f"对话 {d_idx+1}", pruning_mode
|
||||
)
|
||||
deleted_count = original_count - len(kept)
|
||||
total_deleted_msgs += deleted_count
|
||||
dd.context.msgs = kept
|
||||
result.append(dd)
|
||||
stats["dialogs"].append({
|
||||
"index": d_idx + 1,
|
||||
"is_related": True,
|
||||
"total_messages": original_count,
|
||||
"deleted": deleted_count,
|
||||
"kept": len(kept),
|
||||
})
|
||||
continue
|
||||
|
||||
stats["unrelated_count"] += 1
|
||||
stats["unrelated_indices"].append(d_idx + 1)
|
||||
|
||||
# 从 LLM 抽取结果中获取所有需要保留的 token
|
||||
preserve_tokens = self._build_preserve_tokens(extraction)
|
||||
|
||||
# 判断是否需要详细日志
|
||||
should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog
|
||||
if self._detailed_prune_logging and original_count > self._max_debug_msgs_per_dialog:
|
||||
self._log(f" 对话[{d_idx}]消息数={original_count},仅采样前{self._max_debug_msgs_per_dialog}条进行详细日志")
|
||||
|
||||
|
||||
if extraction.preserve_keywords:
|
||||
self._log(f" 对话[{d_idx}] LLM抽取到情绪/兴趣保护词: {extraction.preserve_keywords}")
|
||||
|
||||
# 消息级分类:LLM保护 / 填充 / 其他可删
|
||||
llm_protected_msgs = [] # LLM 保护消息(preserve_tokens 命中):绝对不可删除
|
||||
filler_msgs = [] # 填充消息(优先删除)
|
||||
deletable_msgs = [] # 其余消息(按比例删除)
|
||||
|
||||
for idx, m in enumerate(msgs):
|
||||
msg_text = m.msg.strip()
|
||||
|
||||
# ========== 问答对保护判断(已注释) ==========
|
||||
# if idx in protected_indices:
|
||||
# important_msgs.append((idx, m))
|
||||
# self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(问答对保护)")
|
||||
# ==========================================
|
||||
|
||||
# 填充消息(寒暄、表情等)
|
||||
if self._is_filler_message(m):
|
||||
|
||||
if self._msg_matches_tokens(m, preserve_tokens):
|
||||
llm_protected_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 保护(LLM,不可删)")
|
||||
elif self._is_filler_message(m):
|
||||
filler_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 填充")
|
||||
# 重要信息(学号、成绩、时间、金额等)
|
||||
elif self._is_important_message(m):
|
||||
important_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(场景规则)")
|
||||
# 其他消息
|
||||
else:
|
||||
unimportant_msgs.append((idx, m))
|
||||
deletable_msgs.append((idx, m))
|
||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 不重要")
|
||||
|
||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 可删")
|
||||
|
||||
# important_msgs 仅用于日志统计
|
||||
important_msgs = llm_protected_msgs
|
||||
|
||||
# 计算删除配额
|
||||
delete_target = int(original_count * proportion)
|
||||
if proportion > 0 and original_count > 0 and delete_target == 0:
|
||||
delete_target = 1
|
||||
|
||||
|
||||
# 确保至少保留1条消息
|
||||
max_deletable = max(0, original_count - 1)
|
||||
delete_target = min(delete_target, max_deletable)
|
||||
|
||||
# 删除策略:优先删除填充消息,再删除不重要消息
|
||||
|
||||
# 删除策略:优先删填充消息,再按出现顺序删其余可删消息
|
||||
to_delete_indices = set()
|
||||
deleted_details = [] # 记录删除的消息详情
|
||||
|
||||
deleted_details = []
|
||||
|
||||
# 第一步:删除填充消息
|
||||
filler_to_delete = min(len(filler_msgs), delete_target)
|
||||
for i in range(filler_to_delete):
|
||||
idx, msg = filler_msgs[i]
|
||||
for idx, msg in filler_msgs:
|
||||
if len(to_delete_indices) >= delete_target:
|
||||
break
|
||||
to_delete_indices.add(idx)
|
||||
deleted_details.append(f"[{idx}] 填充: '{msg.msg[:50]}'")
|
||||
|
||||
# 第二步:如果还需要删除,删除不重要消息
|
||||
remaining_quota = delete_target - len(to_delete_indices)
|
||||
if remaining_quota > 0:
|
||||
unimp_to_delete = min(len(unimportant_msgs), remaining_quota)
|
||||
for i in range(unimp_to_delete):
|
||||
idx, msg = unimportant_msgs[i]
|
||||
to_delete_indices.add(idx)
|
||||
deleted_details.append(f"[{idx}] 不重要: '{msg.msg[:50]}'")
|
||||
|
||||
# 第三步:如果还需要删除,按重要性分数删除重要消息
|
||||
remaining_quota = delete_target - len(to_delete_indices)
|
||||
if remaining_quota > 0 and important_msgs:
|
||||
# 按重要性分数排序(分数低的优先删除)
|
||||
imp_sorted = sorted(important_msgs, key=lambda x: self._importance_score(x[1]))
|
||||
imp_to_delete = min(len(imp_sorted), remaining_quota)
|
||||
for i in range(imp_to_delete):
|
||||
idx, msg = imp_sorted[i]
|
||||
to_delete_indices.add(idx)
|
||||
score = self._importance_score(msg)
|
||||
deleted_details.append(f"[{idx}] 重要(分数{score}): '{msg.msg[:50]}'")
|
||||
|
||||
|
||||
# 第二步:如果还需要删除,按出现顺序删可删消息
|
||||
for idx, msg in deletable_msgs:
|
||||
if len(to_delete_indices) >= delete_target:
|
||||
break
|
||||
to_delete_indices.add(idx)
|
||||
deleted_details.append(f"[{idx}] 可删: '{msg.msg[:50]}'")
|
||||
|
||||
# 执行删除
|
||||
kept_msgs = []
|
||||
for idx, m in enumerate(msgs):
|
||||
if idx not in to_delete_indices:
|
||||
kept_msgs.append(m)
|
||||
|
||||
|
||||
# 确保至少保留1条
|
||||
if not kept_msgs and msgs:
|
||||
kept_msgs = [msgs[0]]
|
||||
|
||||
|
||||
dd.context.msgs = kept_msgs
|
||||
deleted_count = original_count - len(kept_msgs)
|
||||
total_deleted_msgs += deleted_count
|
||||
|
||||
|
||||
# 输出删除详情
|
||||
if deleted_details:
|
||||
self._log(f"[剪枝-删除详情] 对话 {d_idx+1} 删除了以下消息:")
|
||||
for detail in deleted_details:
|
||||
self._log(f" {detail}")
|
||||
|
||||
|
||||
# ========== 问答对统计(已注释) ==========
|
||||
# qa_info = f",问答对={len(qa_pairs)}" if qa_pairs else ""
|
||||
# ========================================
|
||||
|
||||
|
||||
self._log(
|
||||
f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} "
|
||||
f"(重要={len(important_msgs)} 不重要={len(unimportant_msgs)} 填充={len(filler_msgs)}) "
|
||||
f"(保护={len(important_msgs)} 填充={len(filler_msgs)} 可删={len(deletable_msgs)}) "
|
||||
f"删除={deleted_count} 保留={len(kept_msgs)}"
|
||||
)
|
||||
|
||||
result.append(dd)
|
||||
|
||||
self._log(f"[剪枝-数据集] 剩余对话数={len(result)}")
|
||||
|
||||
# 保存日志
|
||||
stats["dialogs"].append({
|
||||
"index": d_idx + 1,
|
||||
"is_related": False,
|
||||
"total_messages": original_count,
|
||||
"protected": len(important_msgs),
|
||||
"fillers": len(filler_msgs),
|
||||
"deletable": len(deletable_msgs),
|
||||
"deleted": deleted_count,
|
||||
"kept": len(kept_msgs),
|
||||
})
|
||||
|
||||
result.append(dd)
|
||||
|
||||
# 补全统计对象
|
||||
stats["total_deleted_messages"] = total_deleted_msgs
|
||||
stats["remaining_dialogs"] = len(result)
|
||||
|
||||
self._log(f"[剪枝-数据集] 剩余对话数={len(result)}")
|
||||
self._log(f"[剪枝-数据集] 相关对话数={stats['related_count']} 不相关对话数={stats['unrelated_count']}")
|
||||
self._log(f"[剪枝-数据集] 总删除 {total_deleted_msgs} 条")
|
||||
|
||||
# 直接序列化统计对象,无需正则解析
|
||||
try:
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
log_output_path = settings.get_memory_output_path("pruned_terminal.json")
|
||||
sanitized_logs = [self._sanitize_log_line(l) for l in self.run_logs]
|
||||
payload = self._parse_logs_to_structured(sanitized_logs)
|
||||
with open(log_output_path, "w", encoding="utf-8") as f:
|
||||
json.dump(payload, f, ensure_ascii=False, indent=2)
|
||||
json.dump(stats, f, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
self._log(f"[剪枝-数据集] 保存终端输出日志失败:{e}")
|
||||
|
||||
@@ -731,7 +907,7 @@ class SemanticPruner:
|
||||
if not result:
|
||||
print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
|
||||
return dialogs
|
||||
|
||||
|
||||
return result
|
||||
|
||||
def _log(self, msg: str) -> None:
|
||||
@@ -743,114 +919,4 @@ class SemanticPruner:
|
||||
pass
|
||||
print(msg)
|
||||
|
||||
def _sanitize_log_line(self, line: str) -> str:
|
||||
"""移除行首的方括号标签前缀,例如 [剪枝-数据集] 或 [剪枝-对话]。"""
|
||||
try:
|
||||
return re.sub(r"^\[[^\]]+\]\s*", "", line)
|
||||
except Exception:
|
||||
return line
|
||||
|
||||
def _parse_logs_to_structured(self, logs: List[str]) -> dict:
|
||||
"""将已去前缀的日志列表解析为结构化 JSON,便于数据对接。"""
|
||||
summary = {
|
||||
"scene": self.config.pruning_scene,
|
||||
"dialog_total": None,
|
||||
"deletion_ratio": None,
|
||||
"enabled": None,
|
||||
"related_count": None,
|
||||
"unrelated_count": None,
|
||||
"related_indices": [],
|
||||
"unrelated_indices": [],
|
||||
"total_deleted_messages": None,
|
||||
"remaining_dialogs": None,
|
||||
}
|
||||
dialogs = []
|
||||
|
||||
# 解析函数
|
||||
def parse_int(value: str) -> Optional[int]:
|
||||
try:
|
||||
return int(value)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def parse_float(value: str) -> Optional[float]:
|
||||
try:
|
||||
return float(value)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def parse_indices(s: str) -> List[int]:
|
||||
s = s.strip()
|
||||
if not s:
|
||||
return []
|
||||
parts = [p.strip() for p in s.split(",") if p.strip()]
|
||||
out: List[int] = []
|
||||
for p in parts:
|
||||
try:
|
||||
out.append(int(p))
|
||||
except Exception:
|
||||
pass
|
||||
return out
|
||||
|
||||
# 正则
|
||||
re_header = re.compile(r"对话总数=(\d+)\s+场景=([^\s]+)\s+删除比例=([0-9.]+)\s+开关=(True|False)")
|
||||
re_counts = re.compile(r"相关对话数=(\d+)\s+不相关对话数=(\d+)")
|
||||
re_indices = re.compile(r"相关对话:第\[(.*?)\]段;不相关对话:第\[(.*?)\]段")
|
||||
re_dialog = re.compile(r"对话\s+(\d+)\s+总消息=(\d+)\s+分配删除=(\d+)\s+实删=(\d+)\s+保留=(\d+)")
|
||||
re_total_del = re.compile(r"总删除\s+(\d+)\s+条")
|
||||
re_remaining = re.compile(r"剩余对话数=(\d+)")
|
||||
|
||||
for line in logs:
|
||||
# 第一行:总览
|
||||
m = re_header.search(line)
|
||||
if m:
|
||||
summary["dialog_total"] = parse_int(m.group(1))
|
||||
# 顶层 scene 依配置,这里不覆盖,但也可校验 m.group(2)
|
||||
summary["deletion_ratio"] = parse_float(m.group(3))
|
||||
summary["enabled"] = True if m.group(4) == "True" else False
|
||||
continue
|
||||
|
||||
# 第二行:相关/不相关数量
|
||||
m = re_counts.search(line)
|
||||
if m:
|
||||
summary["related_count"] = parse_int(m.group(1))
|
||||
summary["unrelated_count"] = parse_int(m.group(2))
|
||||
continue
|
||||
|
||||
# 第三行:相关/不相关索引
|
||||
m = re_indices.search(line)
|
||||
if m:
|
||||
summary["related_indices"] = parse_indices(m.group(1))
|
||||
summary["unrelated_indices"] = parse_indices(m.group(2))
|
||||
continue
|
||||
|
||||
# 对话级统计
|
||||
m = re_dialog.search(line)
|
||||
if m:
|
||||
dialogs.append({
|
||||
"index": parse_int(m.group(1)),
|
||||
"total_messages": parse_int(m.group(2)),
|
||||
"quota_delete": parse_int(m.group(3)),
|
||||
"actual_deleted": parse_int(m.group(4)),
|
||||
"kept": parse_int(m.group(5)),
|
||||
})
|
||||
continue
|
||||
|
||||
# 全局删除总数
|
||||
m = re_total_del.search(line)
|
||||
if m:
|
||||
summary["total_deleted_messages"] = parse_int(m.group(1))
|
||||
continue
|
||||
|
||||
# 剩余对话数
|
||||
m = re_remaining.search(line)
|
||||
if m:
|
||||
summary["remaining_dialogs"] = parse_int(m.group(1))
|
||||
continue
|
||||
|
||||
return {
|
||||
"scene": summary["scene"],
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"summary": {k: v for k, v in summary.items() if k != "scene"},
|
||||
"dialogs": dialogs,
|
||||
}
|
||||
|
||||
@@ -1,66 +1,25 @@
|
||||
"""
|
||||
场景特定配置 - 为不同场景提供定制化的剪枝规则
|
||||
场景特定配置 - 统一填充词库
|
||||
|
||||
功能:
|
||||
- 场景特定的重要信息识别模式
|
||||
- 场景特定的重要性评分权重
|
||||
- 场景特定的填充词库
|
||||
- 场景特定的问答对识别规则
|
||||
重要性判断已完全交由 extracat_Pruning.jinja2 提示词 + LLM preserve_tokens 机制承担。
|
||||
本模块仅保留统一填充词库(filler_phrases),用于识别无意义寒暄/表情/口头禅。
|
||||
所有场景共用同一份词库,场景差异由 LLM 语义判断处理。
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Set, Tuple
|
||||
from typing import List, Set
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScenePatterns:
|
||||
"""场景特定的识别模式"""
|
||||
|
||||
# 重要信息的正则模式(优先级从高到低)
|
||||
high_priority_patterns: List[Tuple[str, int]] = field(default_factory=list) # (pattern, weight)
|
||||
medium_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
|
||||
low_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
|
||||
|
||||
# 填充词库(无意义对话)
|
||||
"""场景特定的识别模式(仅保留填充词库)"""
|
||||
filler_phrases: Set[str] = field(default_factory=set)
|
||||
|
||||
# 问句关键词(用于识别问答对)
|
||||
question_keywords: Set[str] = field(default_factory=set)
|
||||
|
||||
# 决策性/承诺性关键词
|
||||
decision_keywords: Set[str] = field(default_factory=set)
|
||||
|
||||
|
||||
class SceneConfigRegistry:
|
||||
"""场景配置注册表 - 管理所有场景的特定配置"""
|
||||
|
||||
# 基础通用模式(所有场景共享)
|
||||
BASE_HIGH_PRIORITY = [
|
||||
(r"订单号|工单|申请号|编号|ID|账号|账户", 5),
|
||||
(r"金额|费用|价格|¥|¥|\d+元", 5),
|
||||
(r"\d{11}", 4), # 手机号
|
||||
(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", 4), # 邮箱
|
||||
]
|
||||
|
||||
BASE_MEDIUM_PRIORITY = [
|
||||
(r"\d{4}-\d{1,2}-\d{1,2}", 3), # 日期
|
||||
(r"\d{4}年\d{1,2}月\d{1,2}日", 3),
|
||||
(r"电话|手机号|微信|QQ|联系方式", 3),
|
||||
(r"地址|地点|位置", 2),
|
||||
(r"时间|日期|有效期|截止", 2),
|
||||
(r"今天|明天|后天|昨天|前天", 3), # 相对时间(提高权重)
|
||||
(r"下周|下月|下年|上周|上月|上年|本周|本月|本年", 3),
|
||||
(r"今年|去年|明年", 3),
|
||||
]
|
||||
|
||||
BASE_LOW_PRIORITY = [
|
||||
(r"\d{1,2}:\d{2}", 2), # 时间点 HH:MM
|
||||
(r"\d{1,2}点\d{0,2}分?", 2), # 时间点 X点Y分 或 X点
|
||||
(r"上午|下午|中午|晚上|早上|傍晚|凌晨", 2), # 时段(提高权重并扩充)
|
||||
(r"AM|PM|am|pm", 1),
|
||||
]
|
||||
|
||||
BASE_FILLERS = {
|
||||
"""场景配置注册表 - 所有场景共用统一填充词库"""
|
||||
|
||||
BASE_FILLERS: Set[str] = {
|
||||
# 基础寒暄
|
||||
"你好", "您好", "在吗", "在的", "在呢", "嗯", "嗯嗯", "哦", "哦哦",
|
||||
"好的", "好", "行", "可以", "不可以", "谢谢", "多谢", "感谢",
|
||||
@@ -69,7 +28,26 @@ class SceneConfigRegistry:
|
||||
"哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia",
|
||||
"额", "呃", "啊", "诶", "唉", "哎", "嗯哼",
|
||||
# 确认词
|
||||
"是的", "对", "对的", "没错", "嗯嗯", "好嘞", "收到", "明白", "了解", "知道了",
|
||||
"是的", "对", "对的", "没错", "好嘞", "收到", "明白", "了解", "知道了",
|
||||
# 服务类套话
|
||||
"请问", "请稍等", "稍等", "马上", "立即",
|
||||
"正在查询", "正在处理", "正在为您", "帮您查一下",
|
||||
"还有其他问题吗", "还需要什么帮助", "很高兴为您服务",
|
||||
"感谢您的耐心等待", "抱歉让您久等了",
|
||||
"已记录", "已反馈", "已转接", "已升级",
|
||||
"祝您生活愉快", "欢迎下次咨询",
|
||||
# 外呼套话
|
||||
"喂", "hello", "打扰了", "不好意思",
|
||||
"方便接电话吗", "现在方便吗", "占用您一点时间",
|
||||
"我是", "我们是", "我们公司", "我们这边",
|
||||
"了解一下", "介绍一下", "简单说一下",
|
||||
"考虑考虑", "想一想", "再说", "再看看",
|
||||
"不需要", "不感兴趣", "没兴趣", "不用了",
|
||||
"没问题", "那就这样", "再联系", "回头聊", "有需要再说",
|
||||
# 教育场景套话
|
||||
"老师好", "同学们好", "上课", "下课", "起立", "坐下",
|
||||
"举手", "请坐", "很好", "不错", "继续",
|
||||
"下一个", "下一题", "下一位", "还有吗", "还有问题吗",
|
||||
# 标点和符号
|
||||
"。。。", "...", "???", "???", "!!!", "!!!",
|
||||
# 表情符号
|
||||
@@ -81,246 +59,8 @@ class SceneConfigRegistry:
|
||||
"hhh", "hhhh", "2333", "666", "gg", "ok", "OK", "okok",
|
||||
"emmm", "emm", "em", "mmp", "wtf", "omg",
|
||||
}
|
||||
|
||||
BASE_QUESTION_KEYWORDS = {
|
||||
"什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时", "吗"
|
||||
}
|
||||
|
||||
BASE_DECISION_KEYWORDS = {
|
||||
"必须", "一定", "务必", "需要", "要求", "规定", "应该",
|
||||
"承诺", "保证", "确保", "负责", "同意", "答应"
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_education_config(cls) -> ScenePatterns:
|
||||
"""教育场景配置"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
||||
# 成绩相关(最高优先级)
|
||||
(r"成绩|分数|得分|满分|及格|不及格", 6),
|
||||
(r"GPA|绩点|学分|平均分", 6),
|
||||
(r"\d+分|\d+\.?\d*分", 5), # 具体分数
|
||||
(r"排名|名次|第.{1,3}名", 5), # 支持"第三名"、"第1名"等
|
||||
|
||||
# 学籍信息
|
||||
(r"学号|学生证|教师工号|工号", 5),
|
||||
(r"班级|年级|专业|院系", 4),
|
||||
|
||||
# 课程相关
|
||||
(r"课程|科目|学科|必修|选修", 4),
|
||||
(r"教材|课本|教科书|参考书", 4),
|
||||
(r"章节|第.{1,3}章|第.{1,3}节", 3), # 支持"第三章"、"第1章"等
|
||||
|
||||
# 学科内容(新增)
|
||||
(r"微积分|导数|积分|函数|极限|微分", 4),
|
||||
(r"代数|几何|三角|概率|统计", 4),
|
||||
(r"物理|化学|生物|历史|地理", 4),
|
||||
(r"英语|语文|数学|政治|哲学", 4),
|
||||
(r"定义|定理|公式|概念|原理|法则", 3),
|
||||
(r"例题|解题|证明|推导|计算", 3),
|
||||
],
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
||||
# 教学活动
|
||||
(r"作业|练习|习题|题目", 3),
|
||||
(r"考试|测验|测试|考核|期中|期末", 3),
|
||||
(r"上课|下课|课堂|讲课", 2),
|
||||
(r"提问|回答|发言|讨论", 2),
|
||||
(r"问一下|请教|咨询|询问", 2), # 新增:问询相关
|
||||
(r"理解|明白|懂|掌握|学会", 2), # 新增:学习状态
|
||||
|
||||
# 时间安排
|
||||
(r"课表|课程表|时间表", 3),
|
||||
(r"第.{1,3}节课|第.{1,3}周", 2), # 支持"第三节课"、"第1周"等
|
||||
],
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
||||
(r"老师|教师|同学|学生", 1),
|
||||
(r"教室|实验室|图书馆", 1),
|
||||
],
|
||||
filler_phrases=cls.BASE_FILLERS | {
|
||||
# 教育场景特有填充词(移除了"明白了"、"懂了"、"不懂"等,这些在教育场景中有意义)
|
||||
"老师好", "同学们好", "上课", "下课", "起立", "坐下",
|
||||
"举手", "请坐", "很好", "不错", "继续",
|
||||
"下一个", "下一题", "下一位", "还有吗", "还有问题吗",
|
||||
},
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
||||
"为啥", "咋", "咋办", "怎样", "如何做",
|
||||
"能不能", "可不可以", "行不行", "对不对", "是不是",
|
||||
},
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
||||
"必考", "重点", "考点", "难点", "关键",
|
||||
"记住", "背诵", "掌握", "理解", "复习",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_online_service_config(cls) -> ScenePatterns:
|
||||
"""在线服务场景配置"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
||||
# 工单相关(最高优先级)
|
||||
(r"工单号|工单编号|ticket|TK\d+", 6),
|
||||
(r"工单状态|处理中|已解决|已关闭|待处理", 5),
|
||||
(r"优先级|紧急|高优先级|P0|P1|P2", 5),
|
||||
|
||||
# 产品信息
|
||||
(r"产品型号|型号|SKU|产品编号", 5),
|
||||
(r"序列号|SN|设备号", 5),
|
||||
(r"版本号|软件版本|固件版本", 4),
|
||||
|
||||
# 问题描述
|
||||
(r"故障|错误|异常|bug|问题", 4),
|
||||
(r"错误代码|故障代码|error code", 5),
|
||||
(r"无法|不能|失败|报错", 3),
|
||||
],
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
||||
# 服务相关
|
||||
(r"退款|退货|换货|补发", 4),
|
||||
(r"发票|收据|凭证", 3),
|
||||
(r"物流|快递|运单号", 3),
|
||||
(r"保修|质保|售后", 3),
|
||||
|
||||
# 时效相关
|
||||
(r"SLA|响应时间|处理时长", 4),
|
||||
(r"超时|延迟|等待", 2),
|
||||
],
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
||||
(r"客服|工程师|技术支持", 1),
|
||||
(r"用户|客户|会员", 1),
|
||||
],
|
||||
filler_phrases=cls.BASE_FILLERS | {
|
||||
# 在线服务特有填充词
|
||||
"您好", "请问", "请稍等", "稍等", "马上", "立即",
|
||||
"正在查询", "正在处理", "正在为您", "帮您查一下",
|
||||
"还有其他问题吗", "还需要什么帮助", "很高兴为您服务",
|
||||
"感谢您的耐心等待", "抱歉让您久等了",
|
||||
"已记录", "已反馈", "已转接", "已升级",
|
||||
"祝您生活愉快", "再见", "欢迎下次咨询",
|
||||
},
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
||||
"能否", "可否", "是否", "有没有", "能不能",
|
||||
"怎么办", "如何处理", "怎么解决",
|
||||
},
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
||||
"立即处理", "马上解决", "尽快", "优先",
|
||||
"升级", "转接", "派单", "跟进",
|
||||
"补偿", "赔偿", "退款", "换货",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_outbound_config(cls) -> ScenePatterns:
|
||||
"""外呼场景配置"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
||||
# 意向相关(最高优先级)
|
||||
(r"意向|意愿|兴趣|感兴趣", 6),
|
||||
(r"A类|B类|C类|D类|高意向|低意向", 6),
|
||||
(r"成交|签约|下单|购买|确认", 6),
|
||||
|
||||
# 联系信息(外呼场景中更重要)
|
||||
(r"预约|约定|安排|确定时间", 5),
|
||||
(r"下次联系|回访|跟进", 5),
|
||||
(r"方便|有空|可以|时间", 4),
|
||||
|
||||
# 通话状态
|
||||
(r"接通|未接通|占线|关机|停机", 4),
|
||||
(r"通话时长|通话时间", 3),
|
||||
],
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
||||
# 客户信息
|
||||
(r"姓名|称呼|先生|女士", 3),
|
||||
(r"公司|单位|职位|职务", 3),
|
||||
(r"需求|要求|期望", 3),
|
||||
|
||||
# 跟进状态
|
||||
(r"跟进状态|进展|进度", 3),
|
||||
(r"已联系|待联系|联系中", 2),
|
||||
(r"拒绝|不感兴趣|考虑|再说", 3),
|
||||
],
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
||||
(r"销售|客户经理|业务员", 1),
|
||||
(r"产品|服务|方案", 1),
|
||||
],
|
||||
filler_phrases=cls.BASE_FILLERS | {
|
||||
# 外呼场景特有填充词
|
||||
"您好", "喂", "hello", "打扰了", "不好意思",
|
||||
"方便接电话吗", "现在方便吗", "占用您一点时间",
|
||||
"我是", "我们是", "我们公司", "我们这边",
|
||||
"了解一下", "介绍一下", "简单说一下",
|
||||
"考虑考虑", "想一想", "再说", "再看看",
|
||||
"不需要", "不感兴趣", "没兴趣", "不用了",
|
||||
"好的", "行", "可以", "没问题", "那就这样",
|
||||
"再联系", "回头聊", "有需要再说",
|
||||
},
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
||||
"有没有", "需不需要", "要不要", "考虑不考虑",
|
||||
"了解吗", "知道吗", "听说过吗",
|
||||
"方便吗", "有空吗", "在吗",
|
||||
},
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
||||
"确定", "决定", "选择", "购买", "下单",
|
||||
"预约", "安排", "约定", "确认",
|
||||
"跟进", "回访", "联系", "沟通",
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls, scene: str, fallback_to_generic: bool = True) -> ScenePatterns:
|
||||
"""根据场景名称获取配置
|
||||
|
||||
Args:
|
||||
scene: 场景名称 ('education', 'online_service', 'outbound' 或其他)
|
||||
fallback_to_generic: 如果场景不存在,是否降级到通用配置
|
||||
|
||||
Returns:
|
||||
对应场景的配置,如果场景不存在:
|
||||
- fallback_to_generic=True: 返回通用配置(仅基础规则)
|
||||
- fallback_to_generic=False: 抛出异常
|
||||
"""
|
||||
scene_map = {
|
||||
'education': cls.get_education_config,
|
||||
'online_service': cls.get_online_service_config,
|
||||
'outbound': cls.get_outbound_config,
|
||||
}
|
||||
|
||||
if scene in scene_map:
|
||||
return scene_map[scene]()
|
||||
|
||||
if fallback_to_generic:
|
||||
# 返回通用配置(仅包含基础规则,不包含场景特定规则)
|
||||
return cls.get_generic_config()
|
||||
else:
|
||||
raise ValueError(f"不支持的场景: {scene},支持的场景: {list(scene_map.keys())}")
|
||||
|
||||
@classmethod
|
||||
def get_generic_config(cls) -> ScenePatterns:
|
||||
"""通用场景配置 - 仅包含基础规则,适用于未定义的场景
|
||||
|
||||
这是一个保守的配置,只使用最通用的规则,避免误删重要信息
|
||||
"""
|
||||
return ScenePatterns(
|
||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY,
|
||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY,
|
||||
low_priority_patterns=cls.BASE_LOW_PRIORITY,
|
||||
filler_phrases=cls.BASE_FILLERS,
|
||||
question_keywords=cls.BASE_QUESTION_KEYWORDS,
|
||||
decision_keywords=cls.BASE_DECISION_KEYWORDS
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_all_scenes(cls) -> List[str]:
|
||||
"""获取所有预定义场景的列表"""
|
||||
return ['education', 'online_service', 'outbound']
|
||||
|
||||
@classmethod
|
||||
def is_scene_supported(cls, scene: str) -> bool:
|
||||
"""检查场景是否有专门的配置支持
|
||||
|
||||
Args:
|
||||
scene: 场景名称
|
||||
|
||||
Returns:
|
||||
True: 有专门配置
|
||||
False: 将使用通用配置
|
||||
"""
|
||||
return scene in cls.get_all_scenes()
|
||||
def get_config(cls, scene: str = "") -> ScenePatterns:
|
||||
"""所有场景统一返回同一份填充词库"""
|
||||
return ScenePatterns(filler_phrases=cls.BASE_FILLERS)
|
||||
|
||||
@@ -384,6 +384,14 @@ class ExtractionOrchestrator:
|
||||
|
||||
logger.info(f"陈述句提取完成,共提取 {len(all_statements)} 条陈述句")
|
||||
|
||||
# 试运行模式下,所有分块提取完成后发送完成事件
|
||||
if self.progress_callback and self.is_pilot_run:
|
||||
await self.progress_callback(
|
||||
"knowledge_extraction_complete",
|
||||
f"陈述句提取完成,共提取 {len(all_statements)} 条",
|
||||
{"total_statements": len(all_statements), "total_chunks": total_chunks}
|
||||
)
|
||||
|
||||
return dialog_data_list
|
||||
|
||||
async def _extract_triplets(
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import List, Dict, Optional
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
||||
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||
from app.core.memory.models.message_models import DialogData, Statement
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
@@ -14,15 +14,15 @@ from app.core.memory.utils.log.logging_utils import prompt_logger
|
||||
logger = get_memory_logger(__name__)
|
||||
|
||||
|
||||
|
||||
class TripletExtractor:
|
||||
"""Extracts knowledge triplets and entities from statements using LLM"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_client: OpenAIClient,
|
||||
ontology_types: Optional[OntologyTypeList] = None,
|
||||
language: str = "zh"):
|
||||
self,
|
||||
llm_client: OpenAIClient,
|
||||
ontology_types: Optional[OntologyTypeList] = None,
|
||||
language: str = "zh"
|
||||
):
|
||||
"""Initialize the TripletExtractor with an LLM client
|
||||
|
||||
Args:
|
||||
@@ -65,7 +65,8 @@ class TripletExtractor:
|
||||
|
||||
# Create messages for LLM
|
||||
messages = [
|
||||
{"role": "system", "content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
|
||||
{"role": "system",
|
||||
"content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
|
||||
{"role": "user", "content": prompt_content}
|
||||
]
|
||||
|
||||
@@ -116,7 +117,8 @@ class TripletExtractor:
|
||||
logger.error(f"Error processing statement: {e}", exc_info=True)
|
||||
return TripletExtractionResponse(triplets=[], entities=[])
|
||||
|
||||
async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[str, TripletExtractionResponse]:
|
||||
async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[
|
||||
str, TripletExtractionResponse]:
|
||||
"""Extract triplets and entities from statements
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""
|
||||
自我反思引擎实现
|
||||
Self-Reflection Engine Implementation
|
||||
|
||||
该模块实现了记忆系统的自我反思功能,包括:
|
||||
1. 基于时间的反思 - 根据时间周期触发反思
|
||||
2. 基于事实的反思 - 检测记忆冲突并解决
|
||||
3. 综合反思 - 整合多种反思策略
|
||||
4. 反思结果应用 - 更新记忆库
|
||||
This module implements the self-reflection functionality of the memory system, including:
|
||||
1. Time-based reflection - Triggers reflection based on time cycles
|
||||
2. Fact-based reflection - Detects and resolves memory conflicts
|
||||
3. Comprehensive reflection - Integrates multiple reflection strategies
|
||||
4. Reflection result application - Updates memory database
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -38,7 +38,7 @@ from app.schemas.memory_storage_schema import (
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
# 配置日志
|
||||
# Configure logging
|
||||
_root_logger = logging.getLogger()
|
||||
if not _root_logger.handlers:
|
||||
logging.basicConfig(
|
||||
@@ -49,35 +49,62 @@ else:
|
||||
_root_logger.setLevel(logging.INFO)
|
||||
|
||||
class TranslationResponse(BaseModel):
|
||||
"""翻译响应模型"""
|
||||
"""Translation response model for language conversion"""
|
||||
data: str
|
||||
|
||||
class ReflectionRange(str, Enum):
|
||||
"""反思范围枚举"""
|
||||
PARTIAL = "partial" # 从检索结果中反思
|
||||
ALL = "all" # 从整个数据库中反思
|
||||
"""
|
||||
Reflection range enumeration
|
||||
|
||||
Defines the scope of data to be included in reflection operations.
|
||||
"""
|
||||
PARTIAL = "partial" # Reflect from retrieval results
|
||||
ALL = "all" # Reflect from entire database
|
||||
|
||||
|
||||
class ReflectionBaseline(str, Enum):
|
||||
"""反思基线枚举"""
|
||||
TIME = "TIME" # 基于时间的反思
|
||||
FACT = "FACT" # 基于事实的反思
|
||||
HYBRID = "HYBRID" # 混合反思
|
||||
"""
|
||||
Reflection baseline enumeration
|
||||
|
||||
Defines the strategy or approach used for reflection operations.
|
||||
"""
|
||||
TIME = "TIME" # Time-based reflection
|
||||
FACT = "FACT" # Fact-based reflection
|
||||
HYBRID = "HYBRID" # Hybrid reflection combining multiple strategies
|
||||
|
||||
|
||||
class ReflectionConfig(BaseModel):
|
||||
"""反思引擎配置"""
|
||||
"""
|
||||
Reflection engine configuration
|
||||
|
||||
Defines all configuration parameters for the reflection engine including
|
||||
operation modes, model settings, and evaluation criteria.
|
||||
|
||||
Attributes:
|
||||
enabled: Whether reflection engine is enabled
|
||||
iteration_period: Reflection cycle period (e.g., "3" hours)
|
||||
reflexion_range: Scope of reflection (PARTIAL or ALL)
|
||||
baseline: Reflection strategy (TIME, FACT, or HYBRID)
|
||||
model_id: LLM model identifier for reflection operations
|
||||
end_user_id: User identifier for scoped operations
|
||||
output_example: Example output format for guidance
|
||||
memory_verify: Enable memory verification checks
|
||||
quality_assessment: Enable quality assessment evaluation
|
||||
violation_handling_strategy: Strategy for handling violations
|
||||
language_type: Language type for output ("zh" or "en")
|
||||
"""
|
||||
enabled: bool = False
|
||||
iteration_period: str = "3" # 反思周期
|
||||
iteration_period: str = "3" # Reflection cycle period
|
||||
reflexion_range: ReflectionRange = ReflectionRange.PARTIAL
|
||||
baseline: ReflectionBaseline = ReflectionBaseline.TIME
|
||||
model_id: Optional[str] = None # 模型ID
|
||||
model_id: Optional[str] = None # Model ID
|
||||
end_user_id: Optional[str] = None
|
||||
output_example: Optional[str] = None # 输出示例
|
||||
output_example: Optional[str] = None # Output example
|
||||
|
||||
# 评估相关字段
|
||||
memory_verify: bool = True # 记忆验证
|
||||
quality_assessment: bool = True # 质量评估
|
||||
violation_handling_strategy: str = "warn" # 违规处理策略
|
||||
# Evaluation related fields
|
||||
memory_verify: bool = True # Memory verification
|
||||
quality_assessment: bool = True # Quality assessment
|
||||
violation_handling_strategy: str = "warn" # Violation handling strategy
|
||||
language_type: str = "zh"
|
||||
|
||||
class Config:
|
||||
@@ -85,7 +112,21 @@ class ReflectionConfig(BaseModel):
|
||||
|
||||
|
||||
class ReflectionResult(BaseModel):
|
||||
"""反思结果"""
|
||||
"""
|
||||
Reflection operation result
|
||||
|
||||
Contains comprehensive information about the outcome of a reflection operation
|
||||
including success status, metrics, and execution details.
|
||||
|
||||
Attributes:
|
||||
success: Whether the reflection operation succeeded
|
||||
message: Descriptive message about the operation result
|
||||
conflicts_found: Number of conflicts detected during reflection
|
||||
conflicts_resolved: Number of conflicts successfully resolved
|
||||
memories_updated: Number of memory entries updated in database
|
||||
execution_time: Total time taken for the reflection operation
|
||||
details: Additional details about the operation (optional)
|
||||
"""
|
||||
success: bool
|
||||
message: str
|
||||
conflicts_found: int = 0
|
||||
@@ -97,9 +138,22 @@ class ReflectionResult(BaseModel):
|
||||
|
||||
class ReflectionEngine:
|
||||
"""
|
||||
自我反思引擎
|
||||
|
||||
负责执行记忆系统的自我反思,包括冲突检测、冲突解决和记忆更新。
|
||||
Self-Reflection Engine
|
||||
|
||||
Responsible for executing memory system self-reflection operations including
|
||||
conflict detection, conflict resolution, and memory updates. Supports multiple
|
||||
reflection strategies and provides comprehensive result tracking.
|
||||
|
||||
The engine can operate in different modes:
|
||||
- Time-based: Reflects on memories within specific time periods
|
||||
- Fact-based: Detects and resolves factual conflicts in memories
|
||||
- Hybrid: Combines multiple reflection strategies
|
||||
|
||||
Attributes:
|
||||
config: Reflection engine configuration
|
||||
neo4j_connector: Neo4j database connector
|
||||
llm_client: Language model client for analysis
|
||||
Various function handlers for data processing and prompt rendering
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -115,18 +169,21 @@ class ReflectionEngine:
|
||||
update_query: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
初始化反思引擎
|
||||
Initialize reflection engine
|
||||
|
||||
Sets up the reflection engine with configuration and optional dependencies.
|
||||
Uses lazy initialization to avoid circular imports and optimize startup time.
|
||||
|
||||
Args:
|
||||
config: 反思引擎配置
|
||||
neo4j_connector: Neo4j 连接器(可选)
|
||||
llm_client: LLM 客户端(可选)
|
||||
get_data_func: 获取数据的函数(可选)
|
||||
render_evaluate_prompt_func: 渲染评估提示词的函数(可选)
|
||||
render_reflexion_prompt_func: 渲染反思提示词的函数(可选)
|
||||
conflict_schema: 冲突结果 Schema(可选)
|
||||
reflexion_schema: 反思结果 Schema(可选)
|
||||
update_query: 更新查询语句(可选)
|
||||
config: Reflection engine configuration object
|
||||
neo4j_connector: Neo4j connector instance (optional, will be created if not provided)
|
||||
llm_client: LLM client instance (optional, will be created if not provided)
|
||||
get_data_func: Function for retrieving data (optional, uses default if not provided)
|
||||
render_evaluate_prompt_func: Function for rendering evaluation prompts (optional)
|
||||
render_reflexion_prompt_func: Function for rendering reflection prompts (optional)
|
||||
conflict_schema: Schema for conflict result validation (optional)
|
||||
reflexion_schema: Schema for reflection result validation (optional)
|
||||
update_query: Query string for database updates (optional)
|
||||
"""
|
||||
self.config = config
|
||||
self.neo4j_connector = neo4j_connector
|
||||
@@ -137,14 +194,20 @@ class ReflectionEngine:
|
||||
self.conflict_schema = conflict_schema
|
||||
self.reflexion_schema = reflexion_schema
|
||||
self.update_query = update_query
|
||||
self._semaphore = asyncio.Semaphore(5) # 默认并发数为5
|
||||
self._semaphore = asyncio.Semaphore(5) # Default concurrency limit of 5
|
||||
|
||||
|
||||
# 延迟导入以避免循环依赖
|
||||
# Lazy import to avoid circular dependencies
|
||||
self._lazy_init_done = False
|
||||
|
||||
def _lazy_init(self):
|
||||
"""延迟初始化,避免循环导入"""
|
||||
"""
|
||||
Lazy initialization to avoid circular imports
|
||||
|
||||
Initializes dependencies only when needed, preventing circular import issues
|
||||
and optimizing startup performance. Sets up default implementations for
|
||||
any components not provided during construction.
|
||||
"""
|
||||
if self._lazy_init_done:
|
||||
return
|
||||
|
||||
@@ -158,7 +221,7 @@ class ReflectionEngine:
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(self.config.model_id)
|
||||
elif isinstance(self.llm_client, str):
|
||||
# 如果 llm_client 是字符串(model_id),则用它初始化客户端
|
||||
# If llm_client is a string (model_id), use it to initialize the client
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
@@ -172,10 +235,10 @@ class ReflectionEngine:
|
||||
model_config = config_service.get_model_config(model_id)
|
||||
|
||||
extra_params={
|
||||
"temperature": 0.2, # 降低温度提高响应速度和一致性
|
||||
"max_tokens": 600, # 限制最大token数
|
||||
"top_p": 0.8, # 优化采样参数
|
||||
"stream": False, # 确保非流式输出以获得最快响应
|
||||
"temperature": 0.2, # Lower temperature for faster response and consistency
|
||||
"max_tokens": 600, # Limit maximum token count
|
||||
"top_p": 0.8, # Optimize sampling parameters
|
||||
"stream": False, # Ensure non-streaming output for fastest response
|
||||
}
|
||||
|
||||
self.llm_client = OpenAIClient(RedBearModelConfig(
|
||||
@@ -191,7 +254,7 @@ class ReflectionEngine:
|
||||
if self.get_data_func is None:
|
||||
self.get_data_func = get_data
|
||||
|
||||
# 导入get_data_statement函数
|
||||
# Import get_data_statement function
|
||||
if not hasattr(self, 'get_data_statement'):
|
||||
self.get_data_statement = get_data_statement
|
||||
|
||||
@@ -223,13 +286,20 @@ class ReflectionEngine:
|
||||
|
||||
async def execute_reflection(self, host_id) -> ReflectionResult:
|
||||
"""
|
||||
执行完整的反思流程
|
||||
Execute complete reflection workflow
|
||||
|
||||
Performs the full reflection process including data retrieval, conflict detection,
|
||||
conflict resolution, and memory updates. This is the main entry point for
|
||||
reflection operations.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
host_id: Host identifier for scoping reflection operations
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
ReflectionResult: Comprehensive result of the reflection operation including
|
||||
success status, conflict metrics, and execution time
|
||||
"""
|
||||
# 延迟初始化
|
||||
# Lazy initialization
|
||||
self._lazy_init()
|
||||
|
||||
if not self.config.enabled:
|
||||
@@ -243,7 +313,7 @@ class ReflectionEngine:
|
||||
|
||||
print(self.config.baseline, self.config.memory_verify, self.config.quality_assessment)
|
||||
try:
|
||||
# 1. 获取反思数据
|
||||
# 1. Get reflection data
|
||||
reflexion_data, statement_databasets = await self._get_reflexion_data(host_id)
|
||||
if not reflexion_data:
|
||||
return ReflectionResult(
|
||||
@@ -252,7 +322,7 @@ class ReflectionEngine:
|
||||
execution_time=asyncio.get_event_loop().time() - start_time
|
||||
)
|
||||
|
||||
# 2. 检测冲突(基于事实的反思)
|
||||
# 2. Detect conflicts (fact-based reflection)
|
||||
conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets)
|
||||
conflict_list=[]
|
||||
for i in conflict_data:
|
||||
@@ -261,7 +331,7 @@ class ReflectionEngine:
|
||||
|
||||
|
||||
conflicts_found=0
|
||||
# 3. 解决冲突
|
||||
# 3. Resolve conflicts
|
||||
solved_data = await self._resolve_conflicts(conflict_list, statement_databasets)
|
||||
|
||||
if not solved_data:
|
||||
@@ -276,7 +346,7 @@ class ReflectionEngine:
|
||||
logging.info(f"解决了 {conflicts_resolved} 个冲突")
|
||||
|
||||
|
||||
# 4. 应用反思结果(更新记忆库)
|
||||
# 4. Apply reflection results (update memory database)
|
||||
memories_updated=await self._apply_reflection_results(solved_data)
|
||||
|
||||
execution_time = asyncio.get_event_loop().time() - start_time
|
||||
@@ -302,7 +372,19 @@ class ReflectionEngine:
|
||||
)
|
||||
|
||||
async def Translate(self, text):
|
||||
# 翻译中文为英文
|
||||
"""
|
||||
Translate Chinese text to English
|
||||
|
||||
Uses the configured LLM to translate Chinese text to English with structured output.
|
||||
Provides consistent translation format for reflection results.
|
||||
|
||||
Args:
|
||||
text: Chinese text to be translated
|
||||
|
||||
Returns:
|
||||
str: Translated English text
|
||||
"""
|
||||
# Translate Chinese to English
|
||||
translation_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -316,6 +398,19 @@ class ReflectionEngine:
|
||||
)
|
||||
return response.data
|
||||
async def extract_translation(self,data):
|
||||
"""
|
||||
Extract and translate reflection data to English
|
||||
|
||||
Processes reflection data structure and translates all Chinese content to English.
|
||||
Handles nested data structures including memory verifications, quality assessments,
|
||||
and reflection data while preserving the original structure.
|
||||
|
||||
Args:
|
||||
data: Dictionary containing reflection data with Chinese content
|
||||
|
||||
Returns:
|
||||
dict: Translated data structure with English content
|
||||
"""
|
||||
end_datas={}
|
||||
end_datas['source_data']=await self.Translate(data['source_data'])
|
||||
quality_assessments = []
|
||||
@@ -350,6 +445,18 @@ class ReflectionEngine:
|
||||
return end_datas
|
||||
|
||||
async def reflection_run(self):
|
||||
"""
|
||||
Execute reflection workflow with comprehensive data processing
|
||||
|
||||
Performs a complete reflection operation including conflict detection, resolution,
|
||||
and result formatting. Supports both Chinese and English output based on
|
||||
configuration settings.
|
||||
|
||||
Returns:
|
||||
dict: Comprehensive reflection results including source data, memory verifications,
|
||||
quality assessments, and reflection data. Results are translated to English
|
||||
if language_type is set to 'en'.
|
||||
"""
|
||||
self._lazy_init()
|
||||
start_time = time.time()
|
||||
memory_verifies_flag = self.config.memory_verify
|
||||
@@ -367,7 +474,7 @@ class ReflectionEngine:
|
||||
result_data['source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合"
|
||||
# 2. 检测冲突(基于事实的反思)
|
||||
conflict_data = await self._detect_conflicts(databasets, source_data)
|
||||
# 遍历数据提取字段
|
||||
# Traverse data to extract fields
|
||||
quality_assessments = []
|
||||
memory_verifies = []
|
||||
for item in conflict_data:
|
||||
@@ -375,9 +482,9 @@ class ReflectionEngine:
|
||||
memory_verifies.append(item['memory_verify'])
|
||||
result_data['memory_verifies'] = memory_verifies
|
||||
result_data['quality_assessments'] = quality_assessments
|
||||
conflicts_found = 0 # 初始化为整数0而不是空字符串
|
||||
conflicts_found = 0 # Initialize as integer 0 instead of empty string
|
||||
REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"}
|
||||
# Clearn conflict_data,And memory_verify和quality_assessment
|
||||
# Clean conflict_data, and memory_verify and quality_assessment
|
||||
cleaned_conflict_data = []
|
||||
for item in conflict_data:
|
||||
cleaned_item = {
|
||||
@@ -389,7 +496,7 @@ class ReflectionEngine:
|
||||
for item in conflict_data:
|
||||
cleaned_data = []
|
||||
for row in item.get("data", []):
|
||||
# 删除 created_at / expired_at
|
||||
# Remove created_at / expired_at
|
||||
cleaned_row = {
|
||||
k: v
|
||||
for k, v in row.items()
|
||||
@@ -402,7 +509,7 @@ class ReflectionEngine:
|
||||
}
|
||||
cleaned_conflict_data_.append(cleaned_item)
|
||||
print(cleaned_conflict_data_)
|
||||
# 3. 解决冲突
|
||||
# 3. Resolve conflicts
|
||||
solved_data = await self._resolve_conflicts(cleaned_conflict_data_, source_data)
|
||||
if not solved_data:
|
||||
return ReflectionResult(
|
||||
@@ -413,7 +520,7 @@ class ReflectionEngine:
|
||||
)
|
||||
reflexion_data = []
|
||||
|
||||
# 遍历数据提取reflexion字段
|
||||
# Traverse data to extract reflexion fields
|
||||
for item in solved_data:
|
||||
if 'results' in item:
|
||||
for result in item['results']:
|
||||
@@ -431,15 +538,24 @@ class ReflectionEngine:
|
||||
|
||||
|
||||
async def extract_fields_from_json(self):
|
||||
"""从example.json中提取source_data和databasets字段"""
|
||||
"""
|
||||
Extract source_data and databasets fields from example.json
|
||||
|
||||
Reads reflection example data from the example.json file and extracts
|
||||
the source data and database statements for testing and demonstration purposes.
|
||||
|
||||
Returns:
|
||||
tuple: (source_data, databasets) extracted from the example file
|
||||
Returns empty lists if file reading fails
|
||||
"""
|
||||
|
||||
prompt_dir = os.path.join(os.path.dirname(__file__), "example")
|
||||
try:
|
||||
# 读取JSON文件
|
||||
# Read JSON file
|
||||
with open(prompt_dir + '/example.json', 'r', encoding='utf-8') as f:
|
||||
data = json.loads(f.read())
|
||||
|
||||
# 提取memory_verify下的字段
|
||||
# Extract fields under memory_verify
|
||||
memory_verify = data.get("memory_verify", {})
|
||||
source_data = memory_verify.get("source_data", [])
|
||||
databasets = memory_verify.get("databasets", [])
|
||||
@@ -451,15 +567,17 @@ class ReflectionEngine:
|
||||
|
||||
async def _get_reflexion_data(self, host_id: uuid.UUID) -> List[Any]:
|
||||
"""
|
||||
获取反思数据
|
||||
|
||||
根据配置的反思范围获取需要反思的记忆数据。
|
||||
Get reflection data from database
|
||||
|
||||
Retrieves memory data for reflection based on the configured reflection range.
|
||||
Supports both partial (from retrieval results) and full (entire database) modes.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
host_id: Host UUID identifier for scoping data retrieval
|
||||
|
||||
Returns:
|
||||
List[Any]: 反思数据列表
|
||||
tuple: (reflexion_data, statement_data) containing memory data for reflection
|
||||
Returns empty lists if query fails
|
||||
"""
|
||||
|
||||
print("=== 获取反思数据 ===")
|
||||
@@ -484,26 +602,29 @@ class ReflectionEngine:
|
||||
|
||||
async def _detect_conflicts(self, data: List[Any], statement_databasets: List[Any]) -> List[Any]:
|
||||
"""
|
||||
检测冲突(基于事实的反思)
|
||||
|
||||
使用 LLM 分析记忆数据,检测其中的冲突。
|
||||
Detect conflicts (fact-based reflection)
|
||||
|
||||
Uses LLM to analyze memory data and detect conflicts within the memories.
|
||||
Performs comprehensive conflict detection including memory verification and
|
||||
quality assessment based on configuration settings.
|
||||
|
||||
Args:
|
||||
data: 待检测的记忆数据
|
||||
data: Memory data to be analyzed for conflicts
|
||||
statement_databasets: Statement database records for context
|
||||
|
||||
Returns:
|
||||
List[Any]: 冲突记忆列表
|
||||
List[Any]: List of detected conflicts with detailed analysis
|
||||
"""
|
||||
if not data:
|
||||
return []
|
||||
|
||||
# 数据预处理:如果数据量太少,直接返回无冲突
|
||||
# Data preprocessing: if data is too small, return no conflicts directly
|
||||
if len(data) < 2:
|
||||
logging.info("数据量不足,无需检测冲突")
|
||||
return []
|
||||
|
||||
# 使用转换后的数据
|
||||
# print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长
|
||||
# Use converted data
|
||||
# print("Converted data:", data[:2] if len(data) > 2 else data) # Only print first 2 to avoid long logs
|
||||
memory_verify = self.config.memory_verify
|
||||
|
||||
logging.info("====== 冲突检测开始 ======")
|
||||
@@ -512,7 +633,7 @@ class ReflectionEngine:
|
||||
language_type=self.config.language_type
|
||||
|
||||
try:
|
||||
# 渲染冲突检测提示词
|
||||
# Render conflict detection prompt
|
||||
rendered_prompt = await self.render_evaluate_prompt_func(
|
||||
data,
|
||||
self.conflict_schema,
|
||||
@@ -526,7 +647,7 @@ class ReflectionEngine:
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
logging.info(f"提示词长度: {len(rendered_prompt)}")
|
||||
|
||||
# 调用 LLM 进行冲突检测
|
||||
# Call LLM for conflict detection
|
||||
response = await self.llm_client.response_structured(
|
||||
messages,
|
||||
self.conflict_schema
|
||||
@@ -539,7 +660,7 @@ class ReflectionEngine:
|
||||
logging.error("LLM 冲突检测输出解析失败")
|
||||
return []
|
||||
|
||||
# 标准化返回格式
|
||||
# Standardize return format
|
||||
if isinstance(response, BaseModel):
|
||||
return [response.model_dump()]
|
||||
elif hasattr(response, 'dict'):
|
||||
@@ -553,15 +674,17 @@ class ReflectionEngine:
|
||||
|
||||
async def _resolve_conflicts(self, conflicts: List[Any], statement_databasets: List[Any]) -> List[Any]:
|
||||
"""
|
||||
解决冲突
|
||||
|
||||
使用 LLM 对检测到的冲突进行反思和解决。
|
||||
Resolve detected conflicts
|
||||
|
||||
Uses LLM to perform reflection and resolution on detected conflicts.
|
||||
Processes conflicts in parallel for efficiency while respecting concurrency limits.
|
||||
|
||||
Args:
|
||||
conflicts: 冲突列表
|
||||
conflicts: List of conflicts to be resolved
|
||||
statement_databasets: Statement database records for context
|
||||
|
||||
Returns:
|
||||
List[Any]: 解决方案列表
|
||||
List[Any]: List of resolution solutions with reflection results
|
||||
"""
|
||||
if not conflicts:
|
||||
return []
|
||||
@@ -570,12 +693,12 @@ class ReflectionEngine:
|
||||
baseline = self.config.baseline
|
||||
memory_verify = self.config.memory_verify
|
||||
|
||||
# 并行处理每个冲突
|
||||
# Process each conflict in parallel
|
||||
async def _resolve_one(conflict: Any) -> Optional[Dict[str, Any]]:
|
||||
"""解决单个冲突"""
|
||||
"""Resolve a single conflict"""
|
||||
async with self._semaphore:
|
||||
try:
|
||||
# 渲染反思提示词
|
||||
# Render reflection prompt
|
||||
rendered_prompt = await self.render_reflexion_prompt_func(
|
||||
[conflict],
|
||||
self.reflexion_schema,
|
||||
@@ -587,7 +710,7 @@ class ReflectionEngine:
|
||||
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
|
||||
# 调用 LLM 进行反思
|
||||
# Call LLM for reflection
|
||||
response = await self.llm_client.response_structured(
|
||||
messages,
|
||||
self.reflexion_schema
|
||||
@@ -596,7 +719,7 @@ class ReflectionEngine:
|
||||
if not response:
|
||||
return None
|
||||
|
||||
# 标准化返回格式
|
||||
# Standardize return format
|
||||
if isinstance(response, BaseModel):
|
||||
return response.model_dump()
|
||||
elif hasattr(response, 'dict'):
|
||||
@@ -610,11 +733,11 @@ class ReflectionEngine:
|
||||
logging.warning(f"解决单个冲突失败: {e}")
|
||||
return None
|
||||
|
||||
# 并发执行所有冲突解决任务
|
||||
# Execute all conflict resolution tasks concurrently
|
||||
tasks = [_resolve_one(conflict) for conflict in conflicts]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
|
||||
# 过滤掉失败的结果
|
||||
# Filter out failed results
|
||||
solved = [r for r in results if r is not None]
|
||||
|
||||
logging.info(f"成功解决 {len(solved)}/{len(conflicts)} 个冲突")
|
||||
@@ -626,15 +749,16 @@ class ReflectionEngine:
|
||||
solved_data: List[Dict[str, Any]]
|
||||
) -> int:
|
||||
"""
|
||||
应用反思结果(更新记忆库)
|
||||
|
||||
将解决冲突后的记忆更新到 Neo4j 数据库中。
|
||||
Apply reflection results (update memory database)
|
||||
|
||||
Updates the Neo4j database with resolved conflicts and reflection results.
|
||||
Processes the solved data and applies changes to the memory storage system.
|
||||
|
||||
Args:
|
||||
solved_data: 解决方案列表
|
||||
solved_data: List of resolved conflict solutions with reflection data
|
||||
|
||||
Returns:
|
||||
int: 成功更新的记忆数量
|
||||
int: Number of successfully updated memory entries
|
||||
"""
|
||||
changes = extract_and_process_changes(solved_data)
|
||||
success_count = await neo4j_data(changes)
|
||||
@@ -642,80 +766,86 @@ class ReflectionEngine:
|
||||
|
||||
|
||||
|
||||
# 基于时间的反思方法
|
||||
# Time-based reflection methods
|
||||
async def time_based_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID,
|
||||
time_period: Optional[str] = None
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
基于时间的反思
|
||||
|
||||
根据时间周期触发反思,检查在指定时间段内的记忆。
|
||||
Time-based reflection
|
||||
|
||||
Triggers reflection based on time cycles, checking memories within
|
||||
specified time periods. Uses the configured iteration period if
|
||||
no specific time period is provided.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
time_period: 时间周期(如"三小时"),如果不提供则使用配置中的值
|
||||
host_id: Host UUID identifier for scoping reflection
|
||||
time_period: Time period (e.g., "three hours"), uses config value if not provided
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
ReflectionResult: Comprehensive reflection operation result
|
||||
"""
|
||||
period = time_period or self.config.iteration_period
|
||||
logging.info(f"执行基于时间的反思,周期: {period}")
|
||||
|
||||
# 使用标准反思流程
|
||||
# Use standard reflection workflow
|
||||
return await self.execute_reflection(host_id)
|
||||
|
||||
# 基于事实的反思方法
|
||||
# Fact-based reflection methods
|
||||
async def fact_based_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
基于事实的反思
|
||||
|
||||
检测记忆中的事实冲突并解决。
|
||||
Fact-based reflection
|
||||
|
||||
Detects and resolves factual conflicts within memories. Analyzes
|
||||
memory data for inconsistencies and contradictions that need resolution.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
host_id: Host UUID identifier for scoping reflection
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
ReflectionResult: Comprehensive reflection operation result
|
||||
"""
|
||||
logging.info("执行基于事实的反思")
|
||||
|
||||
# 使用标准反思流程
|
||||
# Use standard reflection workflow
|
||||
return await self.execute_reflection(host_id)
|
||||
|
||||
# 综合反思方法
|
||||
# Comprehensive reflection methods
|
||||
async def comprehensive_reflection(
|
||||
self,
|
||||
host_id: uuid.UUID
|
||||
) -> ReflectionResult:
|
||||
"""
|
||||
综合反思
|
||||
|
||||
整合基于时间和基于事实的反思策略。
|
||||
Comprehensive reflection
|
||||
|
||||
Integrates time-based and fact-based reflection strategies based on
|
||||
the configured baseline. Supports hybrid approaches that combine
|
||||
multiple reflection methodologies.
|
||||
|
||||
Args:
|
||||
host_id: 主机ID
|
||||
host_id: Host UUID identifier for scoping reflection
|
||||
|
||||
Returns:
|
||||
ReflectionResult: 反思结果
|
||||
ReflectionResult: Comprehensive reflection operation result combining
|
||||
multiple strategies if using hybrid baseline
|
||||
"""
|
||||
logging.info("执行综合反思")
|
||||
|
||||
# 根据配置的基线选择反思策略
|
||||
# Choose reflection strategy based on configured baseline
|
||||
if self.config.baseline == ReflectionBaseline.TIME:
|
||||
return await self.time_based_reflection(host_id)
|
||||
elif self.config.baseline == ReflectionBaseline.FACT:
|
||||
return await self.fact_based_reflection(host_id)
|
||||
elif self.config.baseline == ReflectionBaseline.HYBRID:
|
||||
# 混合策略:先执行基于时间的反思,再执行基于事实的反思
|
||||
# Hybrid strategy: execute time-based reflection first, then fact-based reflection
|
||||
time_result = await self.time_based_reflection(host_id)
|
||||
fact_result = await self.fact_based_reflection(host_id)
|
||||
|
||||
# 合并结果
|
||||
# Merge results
|
||||
return ReflectionResult(
|
||||
success=time_result.success and fact_result.success,
|
||||
message=f"时间反思: {time_result.message}; 事实反思: {fact_result.message}",
|
||||
|
||||
@@ -2,9 +2,17 @@ import json
|
||||
|
||||
|
||||
def escape_lucene_query(query: str) -> str:
|
||||
"""Escape Lucene special characters in a free-text query.
|
||||
|
||||
This prevents ParseException when using Neo4j full-text procedures.
|
||||
"""
|
||||
Escape special characters in Lucene queries
|
||||
|
||||
Prevents ParseException when using Neo4j full-text search procedures.
|
||||
Escapes all Lucene reserved special characters and operators.
|
||||
|
||||
Args:
|
||||
query: Original query string
|
||||
|
||||
Returns:
|
||||
str: Escaped query string safe for Lucene search
|
||||
"""
|
||||
if query is None:
|
||||
return ""
|
||||
@@ -22,11 +30,21 @@ def escape_lucene_query(query: str) -> str:
|
||||
return s
|
||||
|
||||
def extract_plain_query(query_input: str) -> str:
|
||||
"""Extract clean, plain-text query from various input forms.
|
||||
|
||||
"""
|
||||
Extract clean plain-text query from various input forms
|
||||
|
||||
Handles the following cases:
|
||||
- Strips surrounding quotes and whitespace
|
||||
- If input looks like JSON, prefers the 'original' field
|
||||
- Fallbacks to the raw string when parsing fails
|
||||
- Falls back to raw string when parsing fails
|
||||
- Handles dictionary-type input
|
||||
- Best-effort unescape common escape characters
|
||||
|
||||
Args:
|
||||
query_input: Query input in various forms (string, dict, etc.)
|
||||
|
||||
Returns:
|
||||
str: Extracted plain-text query string
|
||||
"""
|
||||
if query_input is None:
|
||||
return ""
|
||||
|
||||
@@ -4,7 +4,13 @@ from datetime import datetime
|
||||
|
||||
def validate_date_format(date_str: str) -> bool:
|
||||
"""
|
||||
Validate if the date string is in the format YYYY-MM-DD.
|
||||
Validate if date string conforms to YYYY-MM-DD format
|
||||
|
||||
Args:
|
||||
date_str: Date string to validate
|
||||
|
||||
Returns:
|
||||
bool: True if format is correct, False otherwise
|
||||
"""
|
||||
pattern = r"^\d{4}-\d{1,2}-\d{1,2}$"
|
||||
return bool(re.match(pattern, date_str))
|
||||
@@ -41,7 +47,20 @@ def normalize_date(date_str: str) -> str:
|
||||
|
||||
|
||||
def preprocess_date_string(date_str: str) -> str:
|
||||
"""预处理日期字符串,处理特殊格式"""
|
||||
"""
|
||||
预处理日期字符串,处理特殊格式
|
||||
|
||||
处理以下特殊格式:
|
||||
- 年份后直接跟月份没有分隔符的格式(如 "20259/28")
|
||||
- 无分隔符的纯数字格式(如 "20251028", "251028")
|
||||
- 混合分隔符,统一为 "-"
|
||||
|
||||
Args:
|
||||
date_str: 原始日期字符串
|
||||
|
||||
Returns:
|
||||
str: 预处理后的日期字符串,格式为 "YYYY-MM-DD" 或 "YYYY-MM"
|
||||
"""
|
||||
|
||||
# 处理类似 "20259/28" 的格式(年份后直接跟月份没有分隔)
|
||||
match = re.match(r'^(\d{4,5})[/\.\-_]?(\d{1,2})[/\.\-_]?(\d{1,2})$', date_str)
|
||||
@@ -78,7 +97,23 @@ def preprocess_date_string(date_str: str) -> str:
|
||||
|
||||
|
||||
def fallback_parse(date_str: str) -> str:
|
||||
"""备选解析方案"""
|
||||
"""
|
||||
备选日期解析方案
|
||||
|
||||
当智能解析失败时,尝试使用预定义的日期格式进行解析。
|
||||
支持多种常见的日期格式,包括:
|
||||
- YYYY-MM-DD, YYYY/MM/DD, YYYY.MM.DD
|
||||
- YYYYMMDD, YYMMDD
|
||||
- MM-DD-YYYY, MM/DD/YYYY, MM.DD.YYYY
|
||||
- DD-MM-YYYY, DD/MM/YYYY, DD.MM.YYYY
|
||||
- YYYY-MM, YYYY/MM, YYYY.MM
|
||||
|
||||
Args:
|
||||
date_str: 待解析的日期字符串
|
||||
|
||||
Returns:
|
||||
str: 标准化后的日期字符串(YYYY-MM-DD格式),解析失败时返回原字符串
|
||||
"""
|
||||
|
||||
# 尝试常见的日期格式[citation:4][citation:5]
|
||||
formats_to_try = [
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{#
|
||||
对话级抽取与相关性判定模板(用于剪枝加速)
|
||||
输入:pruning_scene, is_builtin_scene, ontology_classes, dialog_text, language
|
||||
输入:pruning_scene, ontology_class_infos, dialog_text, language
|
||||
- ontology_class_infos: List[{class_name: str, class_description: str}]
|
||||
输出:严格 JSON(不要包含任何多余文本),字段:
|
||||
- is_related: bool,是否与所选场景相关
|
||||
- times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等)
|
||||
@@ -9,64 +10,103 @@
|
||||
- contacts: [string],联系方式(电话/手机号/邮箱/微信/QQ等)
|
||||
- addresses: [string],地址/地点相关文本
|
||||
- keywords: [string],其它有助于保留的重要关键词(与场景强相关的术语)
|
||||
- preserve_keywords: [string],必须保留的情绪/兴趣/爱好/个人偏好相关词或短语片段
|
||||
|
||||
要求:
|
||||
- 必须只输出上述 JSON,且键名一致;不得输出解释、前后缀;不得包含注释。
|
||||
- times/ids/amounts/contacts/addresses/keywords 仅抽取原文片段或规范化后的简单字符串。
|
||||
- times/ids/amounts/contacts/addresses/keywords/preserve_keywords 仅抽取原文片段或规范化后的简单字符串。
|
||||
- 仅输出上述键;避免多余解释或字段。
|
||||
#}
|
||||
|
||||
{# ── 内置场景的固定说明 ── #}
|
||||
{% set builtin_scene_instructions = {
|
||||
'education': {
|
||||
'zh': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。',
|
||||
'en': 'Education Scenario: Teaching, courses, exams, homework, teacher/student interaction, learning resources, school management, etc.'
|
||||
},
|
||||
'online_service': {
|
||||
'zh': '在线客服场景:客户咨询、问题排查、服务工单、售后支持、订单/退款、工单升级等。',
|
||||
'en': 'Online Service Scenario: Customer inquiries, troubleshooting, service tickets, after-sales support, orders/refunds, ticket escalation, etc.'
|
||||
},
|
||||
'outbound': {
|
||||
'zh': '外呼场景:电话外呼、邀约、调研问卷、线索跟进、对话脚本、回访记录等。',
|
||||
'en': 'Outbound Scenario: Outbound calls, invitations, survey questionnaires, lead follow-up, call scripts, follow-up records, etc.'
|
||||
}
|
||||
} %}
|
||||
|
||||
{# ── 确定最终使用的场景说明 ── #}
|
||||
{% if is_builtin_scene %}
|
||||
{# 内置专门场景:使用固定说明 #}
|
||||
{% set scene_key = pruning_scene %}
|
||||
{% if scene_key not in builtin_scene_instructions %}{% set scene_key = 'education' %}{% endif %}
|
||||
{% set instruction = builtin_scene_instructions[scene_key][language] if language in ['zh', 'en'] else builtin_scene_instructions[scene_key]['zh'] %}
|
||||
{% set custom_types_str = '' %}
|
||||
{% else %}
|
||||
{# 自定义场景:使用场景名称 + 本体类型列表构建说明 #}
|
||||
{% if ontology_classes and ontology_classes | length > 0 %}
|
||||
{% if language == 'en' %}
|
||||
{% set custom_types_str = ontology_classes | join(', ') %}
|
||||
{% set instruction = 'Custom scene "' ~ pruning_scene ~ '": The dialogue is related to this scene if it involves any of the following entity types: ' ~ custom_types_str ~ '.' %}
|
||||
{% else %}
|
||||
{% set custom_types_str = ontology_classes | join('、') %}
|
||||
{% set instruction = '自定义场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关:' ~ custom_types_str ~ '。' %}
|
||||
{% endif %}
|
||||
{# ── 确定场景说明 ── #}
|
||||
{% if ontology_class_infos and ontology_class_infos | length > 0 %}
|
||||
{% if language == 'en' %}
|
||||
{% set instruction = 'Scene "' ~ pruning_scene ~ '": The dialogue is relevant if it involves any of the following entity types.' %}
|
||||
{% else %}
|
||||
{# 无本体类型时退化为通用说明 #}
|
||||
{% if language == 'en' %}
|
||||
{% set instruction = 'Custom scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
|
||||
{% else %}
|
||||
{% set instruction = '自定义场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
|
||||
{% endif %}
|
||||
{% set custom_types_str = '' %}
|
||||
{% set instruction = '场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关。' %}
|
||||
{% endif %}
|
||||
{% else %}
|
||||
{% if language == 'en' %}
|
||||
{% set instruction = 'Scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
|
||||
{% else %}
|
||||
{% set instruction = '场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
{% if language == "zh" %}
|
||||
请在下方对话全文基础上,按该场景进行一次性抽取并判定相关性:
|
||||
你是一个对话内容分析助手。请对下方对话全文进行一次性分析,完成两项任务:
|
||||
1. 判断对话是否与指定场景相关;
|
||||
2. 从对话中抽取所有需要保留的重要信息片段。
|
||||
|
||||
场景说明:{{ instruction }}
|
||||
{% if not is_builtin_scene and custom_types_str %}
|
||||
重要提示:只要对话中出现与上述实体类型({{ custom_types_str }})相关的内容,即判定为相关(is_related=true)。
|
||||
|
||||
{% if ontology_class_infos and ontology_class_infos | length > 0 %}
|
||||
【本场景实体类型定义】
|
||||
以下实体类型定义了本场景中哪些内容是重要的。
|
||||
凡是与以下任意类型相关的内容,都必须保留,并将关键词/短语提取到 keywords 字段:
|
||||
|
||||
{% for info in ontology_class_infos %}
|
||||
- {{ info.class_name }}:{{ info.class_description }}
|
||||
{% endfor %}
|
||||
|
||||
重要提示:只要对话中出现与上述任意实体类型相关的内容,即判定为相关(is_related=true)。
|
||||
{% endif %}
|
||||
|
||||
---
|
||||
【必须保留的内容(不可删除)】
|
||||
以下类型的内容无论是否与场景直接相关,都必须保留,请将其关键词/短语抽取到对应字段:
|
||||
- 时间信息:日期、时间点、时间段、有效期 → times 字段
|
||||
- 编号信息:学号、工号、订单号、申请号、账号、ID → ids 字段
|
||||
- 金额信息:价格、费用、金额(含货币符号或单位,如"100元"、"¥200")→ amounts 字段(注意:考试分数、成绩分数不属于金额,不要放入此字段)
|
||||
- 联系方式:电话、手机号、邮箱、微信、QQ → contacts 字段
|
||||
- 地址信息:地点、地址、位置 → addresses 字段
|
||||
- 场景关键词:与**当前场景**强相关的专业术语、事件名称 → keywords 字段(注意:只放与当前场景直接相关的词,跨场景的内容不要放入此字段)
|
||||
- **情绪与情感**:喜悦、悲伤、愤怒、焦虑、开心、难过、委屈、兴奋、害怕、担心、压力、感动等情绪表达 → preserve_keywords 字段
|
||||
- **兴趣与爱好**:喜欢、热爱、爱好、擅长、享受、沉迷、着迷、讨厌某事物等个人偏好表达 → preserve_keywords 字段
|
||||
- **个人情感态度**:对人际关系、情感状态的明确表达(如"我跟室友闹矛盾了"、"我都快抑郁了")→ preserve_keywords 字段
|
||||
- 注意:学业目标(如"我想考研")、成绩(如"87分")、学科偏好(如"喜欢数学")属于学业信息,不属于情绪/情感,不要放入 preserve_keywords 字段
|
||||
|
||||
【场景无关内容标记】
|
||||
请从对话中识别出与当前场景({{ pruning_scene }})**既不相关、也无语义关联**的消息片段,将其原文(或关键片段)提取到 scene_unrelated_snippets 字段。
|
||||
判断标准:
|
||||
- 与场景实体类型完全无关
|
||||
- 与场景话题没有因果/时间/情境上的关联(例如:不是"因为上课所以累"这种关联)
|
||||
- 纯粹是另一个话题的内容(如在教育场景中讨论购物、娱乐等)
|
||||
注意:有情绪/感受表达的消息即使话题不同,也可能有语义关联,请谨慎标记。
|
||||
|
||||
**重要:scene_unrelated_snippets 必须认真填写,不能为空数组。**
|
||||
如果对话中存在与场景无关的内容,必须将其原文片段提取出来。
|
||||
|
||||
示例(场景=在线教育):
|
||||
- "我最近心情很差,跟室友闹矛盾了" → 与教育场景无关,加入 scene_unrelated_snippets
|
||||
- "她总是很晚回来吵到我睡觉" → 与教育场景无关,加入 scene_unrelated_snippets
|
||||
- "对,我都快抑郁了" → 与教育场景无关,加入 scene_unrelated_snippets
|
||||
- "期末考试12月25日" → 与教育场景相关,不加入 scene_unrelated_snippets
|
||||
- "我上次高数作业87分" → 与教育场景相关,不加入 scene_unrelated_snippets
|
||||
- "我的目标是考研" → 与教育场景相关,不加入 scene_unrelated_snippets
|
||||
|
||||
示例(场景=情感陪伴):
|
||||
- "我最近心情很差,跟室友闹矛盾了" → 与情感陪伴场景相关(情绪+关系),不加入 scene_unrelated_snippets
|
||||
- "对,我都快抑郁了" → 与情感陪伴场景相关(情绪),不加入 scene_unrelated_snippets
|
||||
- "期末考试12月25日,3号教学楼201室" → 与情感陪伴场景无关(教育信息),加入 scene_unrelated_snippets
|
||||
- "我上次高数作业87分,这次能考好吗" → 与情感陪伴场景无关(学业信息),加入 scene_unrelated_snippets
|
||||
- "我的目标是考研,想读应用数学" → 与情感陪伴场景无关(学业目标),加入 scene_unrelated_snippets
|
||||
|
||||
【可以删除的内容】
|
||||
以下类型的内容属于低价值信息,可以在剪枝时删除:
|
||||
- 纯寒暄问候:如"你好"、"在吗"、"拜拜"、"嗯"、"好的"、"哦"等无实质内容的短语
|
||||
- 纯表情/符号:如"[微笑]"、"😊"、"哈哈"等
|
||||
- 重复确认:如"对对对"、"是的是的"、"嗯嗯嗯"等无新增信息的重复
|
||||
- 无意义填充:如"啊"、"呢"、"嘛"等语气词单独成句
|
||||
|
||||
**注意:即使消息很短,只要包含情绪、兴趣、爱好、个人观点等有价值信息,就必须保留,不得删除。**
|
||||
例如:
|
||||
- "我好开心呀" → 包含情绪(开心),必须保留,preserve_keywords 中加入"开心"
|
||||
- "好喜欢打羽毛球呀" → 包含兴趣爱好(喜欢打羽毛球),必须保留,preserve_keywords 中加入"喜欢打羽毛球"
|
||||
- "我好难过" → 包含情绪(难过),必须保留,preserve_keywords 中加入"难过"
|
||||
- "太好啦!看到你开心,我也跟着心情亮起来" → 包含情绪,必须保留,preserve_keywords 中加入"开心"
|
||||
|
||||
---
|
||||
对话全文:
|
||||
"""
|
||||
{{ dialog_text }}
|
||||
@@ -80,15 +120,65 @@
|
||||
"amounts": [<string>...],
|
||||
"contacts": [<string>...],
|
||||
"addresses": [<string>...],
|
||||
"keywords": [<string>...]
|
||||
"keywords": [<string>...],
|
||||
"preserve_keywords": [<string>...],
|
||||
"scene_unrelated_snippets": [<string>...]
|
||||
}
|
||||
{% else %}
|
||||
Based on the full dialogue below, perform one-time extraction and relevance determination according to this scenario:
|
||||
You are a dialogue content analysis assistant. Please analyze the full dialogue below in one pass and complete two tasks:
|
||||
1. Determine whether the dialogue is relevant to the specified scene;
|
||||
2. Extract all important information fragments that must be preserved.
|
||||
|
||||
Scenario Description: {{ instruction }}
|
||||
{% if not is_builtin_scene and custom_types_str %}
|
||||
Important: If the dialogue contains content related to any of the entity types above ({{ custom_types_str }}), mark it as relevant (is_related=true).
|
||||
|
||||
{% if ontology_class_infos and ontology_class_infos | length > 0 %}
|
||||
[Scene Entity Type Definitions]
|
||||
The following entity types define what content is important in this scene.
|
||||
Content related to ANY of these types must be preserved and extracted into the keywords field:
|
||||
|
||||
{% for info in ontology_class_infos %}
|
||||
- {{ info.class_name }}: {{ info.class_description }}
|
||||
{% endfor %}
|
||||
|
||||
Important: If the dialogue contains content related to any of the entity types above, mark it as relevant (is_related=true).
|
||||
{% endif %}
|
||||
|
||||
---
|
||||
[MUST PRESERVE (cannot be deleted)]
|
||||
The following types of content must always be preserved regardless of scene relevance. Extract their keywords/phrases into the corresponding fields:
|
||||
- Time information: dates, time points, durations, expiry dates → times field
|
||||
- ID information: student IDs, employee IDs, order numbers, application numbers, account IDs → ids field
|
||||
- Amount information: prices, fees, amounts (with currency symbols or units, e.g., "$100", "¥200") → amounts field (Note: exam scores and grades are NOT amounts, do not put them here)
|
||||
- Contact information: phone numbers, emails, WeChat, QQ → contacts field
|
||||
- Address information: locations, addresses, places → addresses field
|
||||
- Scene keywords: professional terms and event names strongly related to **the current scene** → keywords field (Note: only put terms directly related to the current scene; cross-scene content should not be placed here)
|
||||
- **Emotions and feelings**: joy, sadness, anger, anxiety, happiness, sadness, excitement, fear, worry, stress, being moved, etc. → preserve_keywords field
|
||||
- **Interests and hobbies**: likes, loves, hobbies, good at, enjoys, obsessed with, hates something, personal preferences → preserve_keywords field
|
||||
- **Personal emotional attitudes**: clear expressions about interpersonal relationships or emotional states (e.g., "I had a fight with my roommate", "I'm almost depressed") → preserve_keywords field
|
||||
- Note: Academic goals (e.g., "I want to pursue a master's degree"), grades (e.g., "87 points"), and subject preferences (e.g., "I like math") are academic information, NOT emotions/feelings — do not put them in preserve_keywords
|
||||
|
||||
[Scene-Unrelated Content Marking]
|
||||
Please identify message snippets in the dialogue that are **neither relevant to nor semantically associated with** the current scene ({{ pruning_scene }}), and extract their original text (or key fragments) into the scene_unrelated_snippets field.
|
||||
Criteria:
|
||||
- Completely unrelated to the scene's entity types
|
||||
- No causal/temporal/contextual association with the scene topic (e.g., "feeling tired because of class" IS associated)
|
||||
- Purely belongs to a different topic (e.g., discussing shopping or entertainment in an education scene)
|
||||
Note: Messages with emotional/feeling expressions may still have semantic association even if the topic differs — mark carefully.
|
||||
|
||||
[CAN BE DELETED]
|
||||
The following types of content are low-value and can be removed during pruning:
|
||||
- Pure greetings: e.g., "hello", "are you there", "bye", "ok", "yeah" — short phrases with no substantive content
|
||||
- Pure emojis/symbols: e.g., "[smile]", "😊", "haha"
|
||||
- Repetitive confirmations: e.g., "yes yes yes", "right right", "uh huh" — repetitions with no new information
|
||||
- Meaningless fillers: standalone interjections like "ah", "well", "hmm"
|
||||
|
||||
**Note: Even if a message is short, if it contains emotions, interests, hobbies, or personal opinions, it MUST be preserved.**
|
||||
Examples:
|
||||
- "I'm so happy!" → contains emotion (happy), must preserve; add "happy" to preserve_keywords
|
||||
- "I love playing badminton!" → contains interest (love playing badminton), must preserve; add "love playing badminton" to preserve_keywords
|
||||
- "I feel so sad" → contains emotion (sad), must preserve; add "sad" to preserve_keywords
|
||||
|
||||
---
|
||||
Full Dialogue:
|
||||
"""
|
||||
{{ dialog_text }}
|
||||
@@ -102,6 +192,8 @@ Output strict JSON only (fixed keys, order doesn't matter):
|
||||
"amounts": [<string>...],
|
||||
"contacts": [<string>...],
|
||||
"addresses": [<string>...],
|
||||
"keywords": [<string>...]
|
||||
"keywords": [<string>...],
|
||||
"preserve_keywords": [<string>...],
|
||||
"scene_unrelated_snippets": [<string>...]
|
||||
}
|
||||
{% endif %}
|
||||
|
||||
@@ -2,15 +2,15 @@ import os
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from typing import List, Dict, Any
|
||||
|
||||
|
||||
# Setup Jinja2 environment
|
||||
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
||||
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
||||
|
||||
|
||||
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
||||
baseline: str = "TIME",
|
||||
memory_verify: bool = False,quality_assessment:bool = False,
|
||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
||||
memory_verify: bool = False, quality_assessment: bool = False,
|
||||
statement_databasets=None, language_type: str = "zh") -> str:
|
||||
"""
|
||||
Renders the evaluate prompt using the evaluate_optimized.jinja2 template.
|
||||
|
||||
@@ -23,6 +23,8 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
if statement_databasets is None:
|
||||
statement_databasets = []
|
||||
template = prompt_env.get_template("evaluate.jinja2")
|
||||
|
||||
# Convert Pydantic model to JSON schema if needed
|
||||
@@ -46,7 +48,7 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
||||
|
||||
|
||||
async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False,
|
||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
||||
statement_databasets=None, language_type: str = "zh") -> str:
|
||||
"""
|
||||
Renders the reflexion prompt using the reflexion_optimized.jinja2 template.
|
||||
|
||||
@@ -58,6 +60,8 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
|
||||
Returns:
|
||||
Rendered prompt content as a string.
|
||||
"""
|
||||
if statement_databasets is None:
|
||||
statement_databasets = []
|
||||
template = prompt_env.get_template("reflexion.jinja2")
|
||||
|
||||
# Convert Pydantic model to JSON schema if needed
|
||||
@@ -69,7 +73,7 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
|
||||
json_schema = schema
|
||||
|
||||
rendered_prompt = template.render(data=data, json_schema=json_schema,
|
||||
baseline=baseline,memory_verify=memory_verify,
|
||||
statement_databasets=statement_databasets,language_type=language_type)
|
||||
baseline=baseline, memory_verify=memory_verify,
|
||||
statement_databasets=statement_databasets, language_type=language_type)
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
@@ -1,23 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
||||
from typing import Any, Dict, Optional, TypeVar
|
||||
|
||||
from langchain_aws import ChatBedrock
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models import BaseLLM
|
||||
from langchain_ollama import OllamaLLM
|
||||
from langchain_openai import ChatOpenAI, OpenAI
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import httpx
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.models.models_model import ModelProvider, ModelType
|
||||
from langchain_community.document_compressors import JinaRerank
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.language_models import BaseLanguageModel, BaseLLM
|
||||
from langchain_core.outputs import Generation, LLMResult
|
||||
from langchain_core.retrievers import BaseRetriever
|
||||
from langchain_core.runnables import RunnableSerializable
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -163,25 +159,17 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
|
||||
|
||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI
|
||||
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||
if type == ModelType.LLM:
|
||||
from langchain_openai import OpenAI
|
||||
return OpenAI
|
||||
elif type == ModelType.CHAT:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI
|
||||
elif provider == ModelProvider.DASHSCOPE:
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
return ChatTongyi
|
||||
elif provider == ModelProvider.OLLAMA:
|
||||
from langchain_ollama import OllamaLLM
|
||||
return OllamaLLM
|
||||
elif provider == ModelProvider.BEDROCK:
|
||||
from langchain_aws import ChatBedrock, ChatBedrockConverse
|
||||
|
||||
return ChatBedrock
|
||||
else:
|
||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||
|
||||
@@ -94,72 +94,16 @@ def knowledge_retrieval(
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=kb_id)
|
||||
if db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1:
|
||||
# Process shared knowledge base
|
||||
if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share:
|
||||
knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db,
|
||||
knowledgeshare_id=db_knowledge.id)
|
||||
if knowledgeshare:
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db,
|
||||
knowledge_id=knowledgeshare.source_kb_id)
|
||||
if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
|
||||
continue
|
||||
else:
|
||||
continue
|
||||
|
||||
if str(db_knowledge.id) not in kb_ids:
|
||||
kb_ids.append(str(db_knowledge.id))
|
||||
if str(db_knowledge.workspace_id) not in workspace_ids:
|
||||
workspace_ids.append(str(db_knowledge.workspace_id))
|
||||
if not chat_model:
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
||||
)
|
||||
if not embedding_model:
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||
base_url=db_knowledge.embedding.api_keys[0].api_base
|
||||
)
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
# Retrieve according to the configured retrieval type
|
||||
match kb_config["retrieve_type"]:
|
||||
case "participle":
|
||||
rs = vector_service.search_by_full_text(
|
||||
query=query,
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["similarity_threshold"],
|
||||
file_names_filter=file_names_filter
|
||||
)
|
||||
case "semantic":
|
||||
rs = vector_service.search_by_vector(
|
||||
query=query,
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["vector_similarity_weight"],
|
||||
file_names_filter=file_names_filter
|
||||
)
|
||||
case _: # hybrid
|
||||
rs1 = vector_service.search_by_vector(
|
||||
query=query,
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["vector_similarity_weight"],
|
||||
file_names_filter=file_names_filter
|
||||
)
|
||||
rs2 = vector_service.search_by_full_text(
|
||||
query=query,
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["similarity_threshold"],
|
||||
file_names_filter=file_names_filter
|
||||
)
|
||||
|
||||
# Deduplication of merge results
|
||||
seen_ids = set()
|
||||
unique_rs = []
|
||||
for doc in rs1 + rs2:
|
||||
if doc.metadata["doc_id"] not in seen_ids:
|
||||
seen_ids.add(doc.metadata["doc_id"])
|
||||
unique_rs.append(doc)
|
||||
rs = unique_rs
|
||||
rs, chat_model, embedding_model = _retrieve_for_knowledge(
|
||||
db=db,
|
||||
db_knowledge=db_knowledge,
|
||||
kb_config={**kb_config, "query": query}, # 或改为单独参数
|
||||
file_names_filter=file_names_filter,
|
||||
chat_model=chat_model,
|
||||
embedding_model=embedding_model,
|
||||
kb_ids=kb_ids,
|
||||
workspace_ids=workspace_ids,
|
||||
)
|
||||
|
||||
all_results.extend(rs)
|
||||
except Exception as e:
|
||||
@@ -199,6 +143,115 @@ def knowledge_retrieval(
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def _retrieve_for_knowledge(
|
||||
db: Session,
|
||||
db_knowledge,
|
||||
kb_config: Dict[str, Any],
|
||||
file_names_filter: list[str],
|
||||
chat_model: Base | None,
|
||||
embedding_model: OpenAIEmbed | None,
|
||||
kb_ids: list[str],
|
||||
workspace_ids: list[str],
|
||||
) -> tuple[list[DocumentChunk], Base | None, OpenAIEmbed | None]:
|
||||
"""
|
||||
对单个知识库进行检索。
|
||||
- 处理共享知识库
|
||||
- 如果是 Folder,则递归检索其子知识库
|
||||
- 返回本知识库(含子库)的检索结果和可能更新后的 chat_model/embedding_model
|
||||
"""
|
||||
results: list[DocumentChunk] = []
|
||||
|
||||
# 处理共享知识库
|
||||
if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share:
|
||||
knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db, knowledgeshare_id=db_knowledge.id)
|
||||
if not knowledgeshare:
|
||||
return results, chat_model, embedding_model
|
||||
|
||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=knowledgeshare.source_kb_id)
|
||||
if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
|
||||
return results, chat_model, embedding_model
|
||||
|
||||
# Folder 类型:递归处理子知识库
|
||||
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
|
||||
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
|
||||
for child in children:
|
||||
if not (child and child.chunk_num > 0 and child.status == 1):
|
||||
continue
|
||||
# 递归处理子知识库(子库如果还是 Folder,会继续往下)
|
||||
child_results, chat_model, embedding_model = _retrieve_for_knowledge(
|
||||
db=db,
|
||||
db_knowledge=child,
|
||||
kb_config=kb_config,
|
||||
file_names_filter=file_names_filter,
|
||||
chat_model=chat_model,
|
||||
embedding_model=embedding_model,
|
||||
kb_ids=kb_ids,
|
||||
workspace_ids=workspace_ids,
|
||||
)
|
||||
results.extend(child_results)
|
||||
return results, chat_model, embedding_model
|
||||
|
||||
# 普通知识库,执行一次检索
|
||||
if str(db_knowledge.id) not in kb_ids:
|
||||
kb_ids.append(str(db_knowledge.id))
|
||||
if str(db_knowledge.workspace_id) not in workspace_ids:
|
||||
workspace_ids.append(str(db_knowledge.workspace_id))
|
||||
|
||||
if not chat_model:
|
||||
chat_model = Base(
|
||||
key=db_knowledge.llm.api_keys[0].api_key,
|
||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||
base_url=db_knowledge.llm.api_keys[0].api_base,
|
||||
)
|
||||
if not embedding_model:
|
||||
embedding_model = OpenAIEmbed(
|
||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||
base_url=db_knowledge.embedding.api_keys[0].api_base,
|
||||
)
|
||||
|
||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||
|
||||
match kb_config["retrieve_type"]:
|
||||
case "participle":
|
||||
rs = vector_service.search_by_full_text(
|
||||
query=kb_config["query"], # 或者直接把 query 作为额外参数传进来
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["similarity_threshold"],
|
||||
file_names_filter=file_names_filter,
|
||||
)
|
||||
case "semantic":
|
||||
rs = vector_service.search_by_vector(
|
||||
query=kb_config["query"],
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["vector_similarity_weight"],
|
||||
file_names_filter=file_names_filter,
|
||||
)
|
||||
case _:
|
||||
rs1 = vector_service.search_by_vector(
|
||||
query=kb_config["query"],
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["vector_similarity_weight"],
|
||||
file_names_filter=file_names_filter,
|
||||
)
|
||||
rs2 = vector_service.search_by_full_text(
|
||||
query=kb_config["query"],
|
||||
top_k=kb_config["top_k"],
|
||||
score_threshold=kb_config["similarity_threshold"],
|
||||
file_names_filter=file_names_filter,
|
||||
)
|
||||
# 合并去重
|
||||
seen_ids = set()
|
||||
unique_rs = []
|
||||
for doc in rs1 + rs2:
|
||||
if doc.metadata["doc_id"] not in seen_ids:
|
||||
seen_ids.add(doc.metadata["doc_id"])
|
||||
unique_rs.append(doc)
|
||||
rs = unique_rs
|
||||
|
||||
results.extend(rs)
|
||||
return results, chat_model, embedding_model
|
||||
|
||||
|
||||
def rerank(db: Session, reranker_id: uuid, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
|
||||
"""
|
||||
|
||||
@@ -4,11 +4,12 @@ RAG chunk analysis utilities.
|
||||
|
||||
from .chunk_summary import generate_chunk_summary
|
||||
from .chunk_tags import extract_chunk_tags, extract_chunk_persona
|
||||
from .chunk_insight import generate_chunk_insight
|
||||
from .chunk_insight import generate_chunk_insight, generate_chunk_insight_sections
|
||||
|
||||
__all__ = [
|
||||
"generate_chunk_summary",
|
||||
"extract_chunk_tags",
|
||||
"extract_chunk_persona",
|
||||
"generate_chunk_insight",
|
||||
"generate_chunk_insight_sections",
|
||||
]
|
||||
|
||||
@@ -1,213 +1,207 @@
|
||||
"""
|
||||
Generate insights from RAG chunks.
|
||||
Generate memory insight report for RAG chunks using memory_insight.jinja2 prompt template.
|
||||
|
||||
This module provides functionality to analyze chunk content and generate insights using LLM.
|
||||
The memory_insight.jinja2 template produces a four-section report:
|
||||
【总体概述】 → memory_insight
|
||||
【行为模式】 → behavior_pattern
|
||||
【关键发现】 → key_findings
|
||||
【成长轨迹】 → growth_trajectory
|
||||
|
||||
generate_chunk_insight() returns the full raw text (stored in end_user.memory_insight).
|
||||
generate_chunk_insight_sections() returns a dict with all four fields for richer storage.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from collections import Counter
|
||||
from typing import Any, Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
|
||||
def _get_llm_client():
|
||||
"""Get LLM client using db context."""
|
||||
|
||||
# ── LLM client helper ────────────────────────────────────────────────────────
|
||||
|
||||
def _get_llm_client(end_user_id: Optional[str] = None):
|
||||
"""Get LLM client, preferring user-connected config with fallback to default."""
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
if end_user_id:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
workspace_id = connected_config.get("workspace_id")
|
||||
if config_id or workspace_id:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(memory_config.llm_model_id)
|
||||
except Exception as e:
|
||||
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(None) # Uses default LLM
|
||||
return factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
|
||||
class ChunkInsight(BaseModel):
|
||||
"""Pydantic model for chunk insight."""
|
||||
insight: str = Field(..., description="对chunk内容的深度洞察分析")
|
||||
# ── Domain analysis helpers (kept for building prompt inputs) ─────────────────
|
||||
|
||||
async def _classify_domain(chunk: str, llm_client) -> str:
|
||||
"""Classify a single chunk into a domain category."""
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class DomainClassification(BaseModel):
|
||||
"""Pydantic model for domain classification."""
|
||||
domain: str = Field(
|
||||
...,
|
||||
description="内容所属的领域分类",
|
||||
examples=["技术", "商业", "教育", "生活", "娱乐", "健康", "其他"]
|
||||
)
|
||||
class _Domain(BaseModel):
|
||||
domain: str = Field(..., description="领域分类")
|
||||
|
||||
|
||||
async def classify_chunk_domain(chunk: str) -> str:
|
||||
"""
|
||||
Classify a chunk into a specific domain.
|
||||
|
||||
Args:
|
||||
chunk: Chunk content string
|
||||
|
||||
Returns:
|
||||
Domain name
|
||||
"""
|
||||
try:
|
||||
llm_client = _get_llm_client()
|
||||
|
||||
prompt = f"""请将以下文本内容归类到最合适的领域中。
|
||||
|
||||
可选领域及其关键词:
|
||||
- 技术:编程、软件、硬件、算法、数据、网络、系统、开发、工程等
|
||||
- 商业:市场、销售、管理、财务、投资、创业、营销、战略等
|
||||
- 教育:学习、课程、培训、教学、知识、技能、考试、研究等
|
||||
- 生活:日常、家庭、饮食、购物、旅行、休闲、娱乐等
|
||||
- 娱乐:游戏、电影、音乐、体育、艺术、文化等
|
||||
- 健康:医疗、养生、运动、心理、保健、疾病等
|
||||
- 其他:无法归入以上类别的内容
|
||||
|
||||
文本内容: {chunk[:500]}...
|
||||
|
||||
请直接返回最合适的领域名称。"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个专业的文本分类助手。请仔细分析文本内容,选择最合适的领域分类。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
classification = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=DomainClassification
|
||||
prompt = (
|
||||
"请将以下文本归类到最合适的领域(技术/商业/教育/生活/娱乐/健康/其他)。\n\n"
|
||||
f"文本: {chunk[:500]}\n\n直接返回领域名称。"
|
||||
)
|
||||
|
||||
return classification.domain if classification else "其他"
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"分类chunk领域失败: {str(e)}")
|
||||
result = await llm_client.response_structured(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
response_model=_Domain,
|
||||
)
|
||||
return result.domain if result else "其他"
|
||||
except Exception:
|
||||
return "其他"
|
||||
|
||||
|
||||
async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20) -> Dict[str, float]:
|
||||
async def _build_insight_inputs(
|
||||
chunks: List[str],
|
||||
max_chunks: int,
|
||||
end_user_id: Optional[str],
|
||||
) -> Dict[str, Optional[str]]:
|
||||
"""
|
||||
Analyze the domain distribution of chunks.
|
||||
|
||||
Args:
|
||||
chunks: List of chunk content strings
|
||||
max_chunks: Maximum number of chunks to analyze
|
||||
|
||||
Returns:
|
||||
Dictionary of domain -> percentage
|
||||
Derive domain_distribution, active_periods, social_connections strings
|
||||
to feed into the memory_insight.jinja2 template.
|
||||
"""
|
||||
if not chunks:
|
||||
return {}
|
||||
|
||||
try:
|
||||
# 限制分析的chunk数量
|
||||
chunks_to_analyze = chunks[:max_chunks]
|
||||
|
||||
# 为每个chunk分类
|
||||
domain_counts = Counter()
|
||||
for chunk in chunks_to_analyze:
|
||||
domain = await classify_chunk_domain(chunk)
|
||||
domain_counts[domain] += 1
|
||||
|
||||
# 计算百分比
|
||||
total = sum(domain_counts.values())
|
||||
domain_distribution = {
|
||||
domain: count / total
|
||||
for domain, count in domain_counts.items()
|
||||
}
|
||||
|
||||
# 按百分比降序排序
|
||||
return dict(sorted(domain_distribution.items(), key=lambda x: x[1], reverse=True))
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"分析领域分布失败: {str(e)}")
|
||||
return {}
|
||||
llm_client = _get_llm_client(end_user_id)
|
||||
chunks_sample = chunks[:max_chunks]
|
||||
|
||||
# Domain distribution
|
||||
domain_counts: Counter = Counter()
|
||||
for chunk in chunks_sample:
|
||||
domain = await _classify_domain(chunk, llm_client)
|
||||
domain_counts[domain] += 1
|
||||
|
||||
total = sum(domain_counts.values()) or 1
|
||||
domain_distribution = ", ".join(
|
||||
f"{d}({c / total:.0%})" for d, c in domain_counts.most_common(3)
|
||||
)
|
||||
|
||||
return {
|
||||
"domain_distribution": domain_distribution,
|
||||
"active_periods": None, # RAG模式暂无时间维度数据
|
||||
"social_connections": None, # RAG模式暂无社交关联数据
|
||||
}
|
||||
|
||||
|
||||
async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str:
|
||||
# ── Section parser ────────────────────────────────────────────────────────────
|
||||
|
||||
_ZH_SECTIONS = {
|
||||
"memory_insight": r"【总体概述】(.*?)(?=【|$)",
|
||||
"behavior_pattern": r"【行为模式】(.*?)(?=【|$)",
|
||||
"key_findings": r"【关键发现】(.*?)(?=【|$)",
|
||||
"growth_trajectory": r"【成长轨迹】(.*?)(?=【|$)",
|
||||
}
|
||||
|
||||
_EN_SECTIONS = {
|
||||
"memory_insight": r"【Overview】(.*?)(?=【|$)",
|
||||
"behavior_pattern": r"【Behavior Pattern】(.*?)(?=【|$)",
|
||||
"key_findings": r"【Key Findings】(.*?)(?=【|$)",
|
||||
"growth_trajectory": r"【Growth Trajectory】(.*?)(?=【|$)",
|
||||
}
|
||||
|
||||
|
||||
def _parse_sections(text: str, language: str = "zh") -> Dict[str, str]:
|
||||
"""Extract the four sections from the LLM output."""
|
||||
patterns = _ZH_SECTIONS if language == "zh" else _EN_SECTIONS
|
||||
result = {}
|
||||
for key, pattern in patterns.items():
|
||||
match = re.search(pattern, text, re.DOTALL)
|
||||
result[key] = match.group(1).strip() if match else ""
|
||||
return result
|
||||
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
async def generate_chunk_insight(
|
||||
chunks: List[str],
|
||||
max_chunks: int = 15,
|
||||
end_user_id: Optional[str] = None,
|
||||
language: str = "zh",
|
||||
) -> str:
|
||||
"""
|
||||
Generate insights from the given chunks.
|
||||
|
||||
Args:
|
||||
chunks: List of chunk content strings
|
||||
max_chunks: Maximum number of chunks to analyze
|
||||
|
||||
Returns:
|
||||
A comprehensive insight report
|
||||
Generate a memory insight report from RAG chunks.
|
||||
|
||||
Returns the full raw report text (suitable for end_user.memory_insight).
|
||||
Use generate_chunk_insight_sections() when you need all four dimensions.
|
||||
"""
|
||||
sections = await generate_chunk_insight_sections(
|
||||
chunks=chunks,
|
||||
max_chunks=max_chunks,
|
||||
end_user_id=end_user_id,
|
||||
language=language,
|
||||
)
|
||||
return sections.get("memory_insight") or sections.get("_raw", "洞察生成失败")
|
||||
|
||||
|
||||
async def generate_chunk_insight_sections(
|
||||
chunks: List[str],
|
||||
max_chunks: int = 15,
|
||||
end_user_id: Optional[str] = None,
|
||||
language: str = "zh",
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Generate a four-section memory insight report from RAG chunks.
|
||||
|
||||
Returns a dict with keys:
|
||||
memory_insight, behavior_pattern, key_findings, growth_trajectory
|
||||
(plus '_raw' containing the full LLM output for debugging)
|
||||
"""
|
||||
if not chunks:
|
||||
business_logger.warning("没有提供chunk内容用于生成洞察")
|
||||
return "暂无足够数据生成洞察报告"
|
||||
|
||||
empty = {k: "" for k in ("memory_insight", "behavior_pattern", "key_findings", "growth_trajectory")}
|
||||
empty["_raw"] = "暂无足够数据生成洞察报告"
|
||||
return empty
|
||||
|
||||
try:
|
||||
# 1. 分析领域分布
|
||||
domain_dist = await analyze_domain_distribution(chunks, max_chunks=max_chunks)
|
||||
|
||||
# 2. 统计基本信息
|
||||
total_chunks = len(chunks)
|
||||
avg_length = sum(len(chunk) for chunk in chunks) / total_chunks if total_chunks > 0 else 0
|
||||
|
||||
# 3. 构建洞察prompt
|
||||
prompt_parts = []
|
||||
|
||||
if domain_dist:
|
||||
top_domains = ", ".join([f"{k}({v:.0%})" for k, v in list(domain_dist.items())[:3]])
|
||||
prompt_parts.append(f"- 内容领域分布: {top_domains}")
|
||||
|
||||
prompt_parts.append(f"- 内容规模: 共{total_chunks}个知识片段,平均长度{avg_length:.0f}字")
|
||||
|
||||
# 添加部分chunk内容作为参考
|
||||
sample_chunks = chunks[:5]
|
||||
sample_content = "\n".join([f"示例{i+1}: {chunk[:200]}..." for i, chunk in enumerate(sample_chunks)])
|
||||
prompt_parts.append(f"\n内容示例:\n{sample_content}")
|
||||
|
||||
system_prompt = """你是一位专业的知识内容分析师。你的任务是根据提供的信息,生成一段简洁、有洞察力的分析报告。
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt
|
||||
|
||||
重要规则:
|
||||
1. 报告需要将所有要点流畅地串联成一个段落
|
||||
2. 语言风格要专业、客观,同时易于理解
|
||||
3. 不要添加任何额外的解释或标题,直接输出报告内容
|
||||
4. 基于提供的数据和示例内容进行分析,不要编造信息
|
||||
5. 重点关注内容的主题、特点和价值
|
||||
6. 报告长度控制在150-200字
|
||||
# Build template inputs from chunk analysis
|
||||
inputs = await _build_insight_inputs(chunks, max_chunks, end_user_id)
|
||||
|
||||
例如,如果输入是:
|
||||
- 内容领域分布: 技术(60%), 商业(25%), 教育(15%)
|
||||
- 内容规模: 共50个知识片段,平均长度320字
|
||||
内容示例: [示例内容...]
|
||||
rendered_prompt = await render_memory_insight_prompt(
|
||||
domain_distribution=inputs["domain_distribution"],
|
||||
active_periods=inputs["active_periods"],
|
||||
social_connections=inputs["social_connections"],
|
||||
language=language,
|
||||
)
|
||||
|
||||
你的输出应该类似:
|
||||
"该知识库主要聚焦于技术领域(60%),涵盖商业(25%)和教育(15%)相关内容。共包含50个知识片段,平均每个片段约320字,内容详实。从示例来看,内容涉及[具体主题],体现了[特点],对[目标用户]具有较高的参考价值。"
|
||||
"""
|
||||
|
||||
user_prompt = "\n".join(prompt_parts)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt}
|
||||
]
|
||||
|
||||
# 调用LLM生成洞察
|
||||
llm_client = _get_llm_client()
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
llm_client = _get_llm_client(end_user_id)
|
||||
response = await llm_client.chat(messages=messages)
|
||||
|
||||
insight = response.content.strip()
|
||||
business_logger.info(f"成功生成chunk洞察,分析了 {min(len(chunks), max_chunks)} 个片段")
|
||||
|
||||
return insight
|
||||
|
||||
raw_text = response.content.strip() if response and response.content else ""
|
||||
|
||||
sections = _parse_sections(raw_text, language=language)
|
||||
sections["_raw"] = raw_text
|
||||
|
||||
business_logger.info(
|
||||
f"成功生成chunk洞察四维度,分析了 {min(len(chunks), max_chunks)} 个片段"
|
||||
)
|
||||
return sections
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"生成chunk洞察失败: {str(e)}")
|
||||
return "洞察生成失败"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试代码
|
||||
test_chunks = [
|
||||
"Python是一种高级编程语言,以其简洁的语法和强大的功能而闻名。它广泛应用于Web开发、数据分析、人工智能等领域。",
|
||||
"机器学习算法可以从数据中自动学习模式,无需显式编程。常见的算法包括决策树、随机森林、神经网络等。",
|
||||
"深度学习是机器学习的一个分支,使用多层神经网络来学习数据的层次化表示。它在图像识别、语音识别等任务中表现出色。",
|
||||
"自然语言处理技术使计算机能够理解和生成人类语言。应用包括机器翻译、情感分析、文本摘要等。",
|
||||
"数据科学结合了统计学、计算机科学和领域知识,用于从数据中提取有价值的洞察。"
|
||||
]
|
||||
|
||||
print("开始生成chunk洞察...")
|
||||
insight = asyncio.run(generate_chunk_insight(test_chunks))
|
||||
print(f"\n生成的洞察:\n{insight}")
|
||||
empty = {k: "" for k in ("memory_insight", "behavior_pattern", "key_findings", "growth_trajectory")}
|
||||
empty["_raw"] = "洞察生成失败"
|
||||
return empty
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
"""
|
||||
Generate summary for RAG chunks.
|
||||
|
||||
This module provides functionality to summarize chunk content using LLM.
|
||||
Generate summary for RAG chunks using memory_summary.jinja2 prompt template.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
@@ -14,94 +13,135 @@ from pydantic import BaseModel, Field
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
|
||||
def _get_llm_client():
|
||||
"""Get LLM client using db context."""
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(None) # Uses default LLM
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
|
||||
|
||||
class ChunkSummary(BaseModel):
|
||||
"""Pydantic model for chunk summary."""
|
||||
summary: str = Field(..., description="简洁的chunk内容摘要")
|
||||
# ── Schema ──────────────────────────────────────────────────────────────────
|
||||
|
||||
class MemorySummaryStatement(BaseModel):
|
||||
"""Single labelled statement extracted by memory_summary.jinja2."""
|
||||
statement: str = Field(..., description="提取的陈述内容")
|
||||
label: Optional[str] = Field(None, description="陈述标签")
|
||||
|
||||
|
||||
async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str:
|
||||
class MemorySummaryResponse(BaseModel):
|
||||
"""
|
||||
Generate a summary for the given chunks.
|
||||
|
||||
Structured output expected from memory_summary.jinja2.
|
||||
The template asks for a JSON array of labelled statements;
|
||||
we wrap it in an object so response_structured can parse it.
|
||||
"""
|
||||
statements: List[MemorySummaryStatement] = Field(
|
||||
default_factory=list,
|
||||
description="从chunk中提取的陈述列表"
|
||||
)
|
||||
summary: Optional[str] = Field(None, description="整体摘要文本(可选)")
|
||||
|
||||
|
||||
# ── LLM client helper ────────────────────────────────────────────────────────
|
||||
|
||||
def _get_llm_client(end_user_id: Optional[str] = None):
|
||||
"""Get LLM client, preferring user-connected config with fallback to default."""
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
if end_user_id:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
workspace_id = connected_config.get("workspace_id")
|
||||
if config_id or workspace_id:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(memory_config.llm_model_id)
|
||||
except Exception as e:
|
||||
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
|
||||
# ── Core function ─────────────────────────────────────────────────────────────
|
||||
|
||||
async def generate_chunk_summary(
|
||||
chunks: List[str],
|
||||
max_chunks: int = 10,
|
||||
end_user_id: Optional[str] = None,
|
||||
language: str = "zh",
|
||||
) -> str:
|
||||
"""
|
||||
Generate a user summary from RAG chunks using the memory_summary.jinja2 template.
|
||||
|
||||
The template extracts labelled statements from the chunks; we then join them
|
||||
into a coherent summary string that can be stored in end_user.user_summary.
|
||||
|
||||
Args:
|
||||
chunks: List of chunk content strings
|
||||
max_chunks: Maximum number of chunks to process (default: 10)
|
||||
|
||||
max_chunks: Maximum number of chunks to process
|
||||
end_user_id: Optional end-user ID for model selection
|
||||
language: Output language ("zh" or "en")
|
||||
|
||||
Returns:
|
||||
A concise summary of the chunks
|
||||
Summary string (joined statements or fallback text)
|
||||
"""
|
||||
if not chunks:
|
||||
business_logger.warning("没有提供chunk内容用于生成摘要")
|
||||
return "暂无内容"
|
||||
|
||||
|
||||
try:
|
||||
# 限制处理的chunk数量,避免token过多
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
|
||||
|
||||
chunks_to_process = chunks[:max_chunks]
|
||||
|
||||
# 合并chunk内容
|
||||
combined_content = "\n\n".join([f"片段{i+1}: {chunk}" for i, chunk in enumerate(chunks_to_process)])
|
||||
|
||||
# 构建prompt
|
||||
system_prompt = (
|
||||
"你是一位专业的文本摘要助手。请基于提供的文本片段,生成简洁的摘要。要求:\n"
|
||||
"- 摘要长度控制在100-150字;\n"
|
||||
"- 提取核心信息和关键要点;\n"
|
||||
"- 使用客观、清晰的语言;\n"
|
||||
"- 避免冗余和重复;\n"
|
||||
"- 如果内容涉及多个主题,按重要性排序呈现。"
|
||||
chunk_texts = "\n\n".join(
|
||||
[f"片段{i + 1}: {chunk}" for i, chunk in enumerate(chunks_to_process)]
|
||||
)
|
||||
|
||||
json_schema = MemorySummaryResponse.model_json_schema()
|
||||
|
||||
rendered_prompt = await render_memory_summary_prompt(
|
||||
chunk_texts=chunk_texts,
|
||||
json_schema=json_schema,
|
||||
max_words=200,
|
||||
language=language,
|
||||
)
|
||||
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
|
||||
llm_client = _get_llm_client(end_user_id)
|
||||
|
||||
# Try structured output; fall back to plain chat only for LLMClientException
|
||||
# (indicates the model/provider doesn't support structured output).
|
||||
# All other exceptions are re-raised so config/schema errors stay visible.
|
||||
try:
|
||||
response: MemorySummaryResponse = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=MemorySummaryResponse,
|
||||
)
|
||||
if response.summary:
|
||||
summary = response.summary.strip()
|
||||
elif response.statements:
|
||||
summary = ";".join(s.statement for s in response.statements)
|
||||
else:
|
||||
summary = "暂无内容"
|
||||
except Exception as e:
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
if isinstance(e, LLMClientException):
|
||||
business_logger.warning(
|
||||
f"结构化输出不可用,降级为普通对话: end_user_id={end_user_id}, reason={e}"
|
||||
)
|
||||
raw = await llm_client.chat(messages=messages)
|
||||
summary = raw.content.strip() if raw and raw.content else "暂无内容"
|
||||
else:
|
||||
business_logger.error(f"生成摘要时发生非预期异常: {e}")
|
||||
raise
|
||||
|
||||
business_logger.info(
|
||||
f"成功生成chunk摘要,处理了 {len(chunks_to_process)} 个片段"
|
||||
)
|
||||
|
||||
user_prompt = f"请为以下文本片段生成摘要:\n\n{combined_content}"
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
]
|
||||
|
||||
# 调用LLM生成摘要
|
||||
llm_client = _get_llm_client()
|
||||
response = await llm_client.chat(messages=messages)
|
||||
|
||||
summary = response.content.strip()
|
||||
business_logger.info(f"成功生成chunk摘要,处理了 {len(chunks_to_process)} 个片段")
|
||||
|
||||
return summary
|
||||
|
||||
|
||||
except Exception as e:
|
||||
business_logger.error(f"生成chunk摘要失败: {str(e)}")
|
||||
return "摘要生成失败"
|
||||
|
||||
|
||||
async def generate_chunk_summary_batch(chunks_list: List[List[str]]) -> List[str]:
|
||||
"""
|
||||
Generate summaries for multiple chunk lists in batch.
|
||||
|
||||
Args:
|
||||
chunks_list: List of chunk lists
|
||||
|
||||
Returns:
|
||||
List of summaries
|
||||
"""
|
||||
tasks = [generate_chunk_summary(chunks) for chunks in chunks_list]
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试代码
|
||||
test_chunks = [
|
||||
"这是第一段测试内容,讲述了关于机器学习的基础知识。",
|
||||
"第二段内容介绍了深度学习的应用场景和发展历史。",
|
||||
"第三段讨论了自然语言处理技术的最新进展。"
|
||||
]
|
||||
|
||||
print("开始生成chunk摘要...")
|
||||
summary = asyncio.run(generate_chunk_summary(test_chunks))
|
||||
print(f"\n生成的摘要:\n{summary}")
|
||||
|
||||
@@ -5,8 +5,9 @@ This module provides functionality to extract meaningful tags from chunk content
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections import Counter
|
||||
from typing import List, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
@@ -15,12 +16,31 @@ from pydantic import BaseModel, Field
|
||||
|
||||
business_logger = get_business_logger()
|
||||
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
|
||||
def _get_llm_client():
|
||||
"""Get LLM client using db context."""
|
||||
|
||||
def _get_llm_client(end_user_id: Optional[str] = None):
|
||||
"""Get LLM client, preferring user-connected config with fallback to default."""
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
if end_user_id:
|
||||
from app.services.memory_agent_service import get_end_user_connected_config
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
workspace_id = connected_config.get("workspace_id")
|
||||
if config_id or workspace_id:
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(memory_config.llm_model_id)
|
||||
except Exception as e:
|
||||
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
|
||||
factory = MemoryClientFactory(db)
|
||||
return factory.get_llm_client(None) # Uses default LLM
|
||||
return factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
|
||||
class ExtractedTags(BaseModel):
|
||||
@@ -33,7 +53,7 @@ class ExtractedPersona(BaseModel):
|
||||
personas: List[str] = Field(..., description="从文本中提取的人物形象列表,如'产品设计师'、'旅行爱好者'等")
|
||||
|
||||
|
||||
async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10) -> List[Tuple[str, int]]:
|
||||
async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10, end_user_id: Optional[str] = None) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
Extract meaningful tags from the given chunks.
|
||||
|
||||
@@ -64,7 +84,7 @@ async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks:
|
||||
"标签应该是名词或名词短语,能够准确概括文本的核心内容。"
|
||||
)
|
||||
|
||||
llm_client = _get_llm_client()
|
||||
llm_client = _get_llm_client(end_user_id)
|
||||
|
||||
# 为每个chunk单独提取标签,然后统计频率
|
||||
all_tags = []
|
||||
@@ -116,7 +136,7 @@ async def extract_chunk_tags_with_frequency(chunks: List[str], max_tags: int = 1
|
||||
return await extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=len(chunks))
|
||||
|
||||
|
||||
async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20) -> List[str]:
|
||||
async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20, end_user_id: Optional[str] = None) -> List[str]:
|
||||
"""
|
||||
Extract persona (人物形象) from the given chunks.
|
||||
|
||||
@@ -159,7 +179,7 @@ async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_ch
|
||||
]
|
||||
|
||||
# 调用LLM提取人物形象
|
||||
llm_client = _get_llm_client()
|
||||
llm_client = _get_llm_client(end_user_id)
|
||||
structured_response = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=ExtractedPersona
|
||||
|
||||
@@ -7,7 +7,7 @@ file operations across different storage backends.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
|
||||
class StorageBackend(ABC):
|
||||
@@ -42,6 +42,26 @@ class StorageBackend(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def upload_stream(
|
||||
self,
|
||||
file_key: str,
|
||||
stream: AsyncIterator[bytes],
|
||||
content_type: Optional[str] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Upload a file from an async byte stream.
|
||||
|
||||
Args:
|
||||
file_key: Unique identifier for the file.
|
||||
stream: Async iterator yielding bytes chunks.
|
||||
content_type: Optional MIME type of the file.
|
||||
|
||||
Returns:
|
||||
Total bytes written.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def download(self, file_key: str) -> bytes:
|
||||
"""
|
||||
|
||||
@@ -85,6 +85,7 @@ class StorageFactory:
|
||||
access_key_id=settings.S3_ACCESS_KEY_ID,
|
||||
secret_access_key=settings.S3_SECRET_ACCESS_KEY,
|
||||
bucket_name=settings.S3_BUCKET_NAME,
|
||||
endpoint_url=settings.S3_ENDPOINT_URL,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Optional
|
||||
|
||||
import aiofiles
|
||||
import aiofiles.os
|
||||
from typing import AsyncIterator
|
||||
|
||||
from app.core.storage.base import StorageBackend
|
||||
from app.core.storage_exceptions import (
|
||||
@@ -179,6 +180,36 @@ class LocalStorage(StorageBackend):
|
||||
full_path = self._get_full_path(file_key)
|
||||
return full_path.exists()
|
||||
|
||||
async def upload_stream(
|
||||
self,
|
||||
file_key: str,
|
||||
stream: AsyncIterator[bytes],
|
||||
content_type: Optional[str] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Upload a file from an async byte stream to the local file system.
|
||||
|
||||
Returns:
|
||||
Total bytes written.
|
||||
"""
|
||||
full_path = self._get_full_path(file_key)
|
||||
try:
|
||||
full_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
total = 0
|
||||
async with aiofiles.open(full_path, "wb") as f:
|
||||
async for chunk in stream:
|
||||
await f.write(chunk)
|
||||
total += len(chunk)
|
||||
logger.info(f"File stream uploaded successfully: {file_key}")
|
||||
return total
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stream upload file {file_key}: {e}")
|
||||
raise StorageUploadError(
|
||||
message=f"Failed to stream upload file: {e}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
|
||||
async def get_url(self, file_key: str, expires: int = 3600) -> str:
|
||||
"""
|
||||
Get an access URL for the file.
|
||||
|
||||
@@ -5,8 +5,9 @@ This module provides a storage backend that stores files on Aliyun Object
|
||||
Storage Service (OSS) using the oss2 SDK.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
import oss2
|
||||
from oss2.exceptions import NoSuchKey, OssError
|
||||
@@ -125,10 +126,39 @@ class OSSStorage(StorageBackend):
|
||||
cause=e,
|
||||
)
|
||||
|
||||
async def upload_stream(
|
||||
self,
|
||||
file_key: str,
|
||||
stream: AsyncIterator[bytes],
|
||||
content_type: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Upload from async stream to OSS. Returns total bytes written."""
|
||||
buf = io.BytesIO()
|
||||
try:
|
||||
async for chunk in stream:
|
||||
buf.write(chunk)
|
||||
content = buf.getvalue()
|
||||
headers = {"Content-Type": content_type} if content_type else None
|
||||
self.bucket.put_object(file_key, content, headers=headers)
|
||||
logger.info(f"File stream uploaded to OSS successfully: {file_key}")
|
||||
return len(content)
|
||||
except OssError as e:
|
||||
logger.error(f"OSS error stream uploading file {file_key}: {e}")
|
||||
raise StorageUploadError(
|
||||
message=f"Failed to stream upload file to OSS: {e.message}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stream upload file to OSS {file_key}: {e}")
|
||||
raise StorageUploadError(
|
||||
message=f"Failed to stream upload file to OSS: {e}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
|
||||
async def download(self, file_key: str) -> bytes:
|
||||
"""
|
||||
Download a file from OSS.
|
||||
|
||||
Args:
|
||||
file_key: Unique identifier for the file in the storage system.
|
||||
|
||||
|
||||
@@ -5,8 +5,9 @@ This module provides a storage backend that stores files on AWS S3
|
||||
using the boto3 SDK.
|
||||
"""
|
||||
|
||||
import io
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import AsyncIterator, Optional
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError, NoCredentialsError, BotoCoreError
|
||||
@@ -35,6 +36,19 @@ class S3Storage(StorageBackend):
|
||||
bucket_name: The name of the S3 bucket.
|
||||
region: The AWS region.
|
||||
"""
|
||||
AMAZON_S3_ENDPOINT_MAP = {
|
||||
"us-east-1": "https://s3.us-east-1.amazonaws.com", # 特殊:无地域后缀
|
||||
"us-east-2": "https://s3.us-east-2.amazonaws.com",
|
||||
"us-west-1": "https://s3.us-west-1.amazonaws.com",
|
||||
"us-west-2": "https://s3.us-west-2.amazonaws.com",
|
||||
"ap-east-1": "https://s3.ap-east-1.amazonaws.com", # 香港
|
||||
"ap-southeast-1": "https://s3.ap-southeast-1.amazonaws.com", # 新加坡
|
||||
"ap-southeast-2": "https://s3.ap-southeast-2.amazonaws.com", # 悉尼
|
||||
"ap-northeast-1": "https://s3.ap-northeast-1.amazonaws.com", # 东京
|
||||
"eu-central-1": "https://s3.eu-central-1.amazonaws.com", # 法兰克福
|
||||
"eu-west-1": "https://s3.eu-west-1.amazonaws.com", # 爱尔兰
|
||||
# 可根据需要扩展其他地域
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -42,6 +56,7 @@ class S3Storage(StorageBackend):
|
||||
access_key_id: str,
|
||||
secret_access_key: str,
|
||||
bucket_name: str,
|
||||
endpoint_url: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Initialize the S3Storage backend.
|
||||
@@ -51,6 +66,7 @@ class S3Storage(StorageBackend):
|
||||
access_key_id: The AWS access key ID.
|
||||
secret_access_key: The AWS secret access key.
|
||||
bucket_name: The name of the S3 bucket.
|
||||
endpoint_url: The complete URL to use for the constructed client.
|
||||
|
||||
Raises:
|
||||
StorageConfigError: If any required configuration is missing.
|
||||
@@ -69,10 +85,19 @@ class S3Storage(StorageBackend):
|
||||
self.region = region
|
||||
self.bucket_name = bucket_name
|
||||
|
||||
if not endpoint_url:
|
||||
# 优先匹配内置映射表(解决特殊地域)
|
||||
if region in self.AMAZON_S3_ENDPOINT_MAP:
|
||||
endpoint_url = self.AMAZON_S3_ENDPOINT_MAP[region]
|
||||
# 兜底:通用拼接(适配未配置的新地域)
|
||||
else:
|
||||
endpoint_url = f"https://s3.{region}.amazonaws.com"
|
||||
|
||||
try:
|
||||
self.client = boto3.client(
|
||||
"s3",
|
||||
region_name=region,
|
||||
endpoint_url=endpoint_url,
|
||||
aws_access_key_id=access_key_id,
|
||||
aws_secret_access_key=secret_access_key,
|
||||
)
|
||||
@@ -150,6 +175,62 @@ class S3Storage(StorageBackend):
|
||||
cause=e,
|
||||
)
|
||||
|
||||
async def upload_stream(
|
||||
self,
|
||||
file_key: str,
|
||||
stream: AsyncIterator[bytes],
|
||||
content_type: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Upload from async stream to S3 via multipart upload. Returns total bytes written."""
|
||||
extra_args = {"ContentType": content_type} if content_type else {}
|
||||
mpu = self.client.create_multipart_upload(
|
||||
Bucket=self.bucket_name, Key=file_key, **extra_args
|
||||
)
|
||||
upload_id = mpu["UploadId"]
|
||||
parts = []
|
||||
part_number = 1
|
||||
buf = io.BytesIO()
|
||||
total = 0
|
||||
min_part_size = 5 * 1024 * 1024 # S3 最小分片 5MB
|
||||
try:
|
||||
async for chunk in stream:
|
||||
buf.write(chunk)
|
||||
total += len(chunk)
|
||||
if buf.tell() >= min_part_size:
|
||||
buf.seek(0)
|
||||
resp = self.client.upload_part(
|
||||
Bucket=self.bucket_name, Key=file_key,
|
||||
UploadId=upload_id, PartNumber=part_number, Body=buf.read()
|
||||
)
|
||||
parts.append({"PartNumber": part_number, "ETag": resp["ETag"]})
|
||||
part_number += 1
|
||||
buf = io.BytesIO()
|
||||
# 上传剩余数据(最后一片可小于 5MB)
|
||||
remaining = buf.getvalue()
|
||||
if remaining:
|
||||
resp = self.client.upload_part(
|
||||
Bucket=self.bucket_name, Key=file_key,
|
||||
UploadId=upload_id, PartNumber=part_number, Body=remaining
|
||||
)
|
||||
parts.append({"PartNumber": part_number, "ETag": resp["ETag"]})
|
||||
self.client.complete_multipart_upload(
|
||||
Bucket=self.bucket_name, Key=file_key,
|
||||
UploadId=upload_id,
|
||||
MultipartUpload={"Parts": parts}
|
||||
)
|
||||
logger.info(f"File stream uploaded to S3 successfully: {file_key}")
|
||||
return total
|
||||
except Exception as e:
|
||||
self.client.abort_multipart_upload(
|
||||
Bucket=self.bucket_name, Key=file_key, UploadId=upload_id
|
||||
)
|
||||
logger.error(f"Failed to stream upload file to S3 {file_key}: {e}")
|
||||
raise StorageUploadError(
|
||||
message=f"Failed to stream upload file to S3: {e}",
|
||||
file_key=file_key,
|
||||
cause=e,
|
||||
)
|
||||
|
||||
async def download(self, file_key: str) -> bytes:
|
||||
"""
|
||||
Download a file from S3.
|
||||
|
||||
@@ -195,6 +195,6 @@ class MCPToolManager:
|
||||
except Exception as e:
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": "连接失败"
|
||||
"error": "连接失败",
|
||||
"message": str(e)
|
||||
}
|
||||
@@ -23,7 +23,7 @@ class SimpleMCPClient:
|
||||
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
|
||||
self.server_url = server_url
|
||||
self.connection_config = connection_config or {}
|
||||
self.timeout = self.connection_config.get("timeout", 30)
|
||||
self.timeout = self.connection_config.get("timeout", 10)
|
||||
|
||||
# 确定连接类型
|
||||
self.is_websocket = server_url.startswith(("ws://", "wss://"))
|
||||
@@ -53,6 +53,7 @@ class SimpleMCPClient:
|
||||
else:
|
||||
await self._connect_http()
|
||||
except Exception as e:
|
||||
await self.disconnect()
|
||||
logger.error(f"MCP连接失败: {self.server_url}, 错误: {e}")
|
||||
raise MCPConnectionError(f"连接失败: {e}")
|
||||
|
||||
|
||||
@@ -8,34 +8,60 @@ from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||
from app.core.workflow.adapters.errors import UnsupportVariableType, UnknowModelWarning, ExceptionDefineition, \
|
||||
from app.core.workflow.adapters.errors import (
|
||||
UnsupportVariableType,
|
||||
UnknowModelWarning,
|
||||
ExceptionDefineition,
|
||||
ExceptionType
|
||||
from app.core.workflow.nodes.assigner import AssignerNodeConfig
|
||||
)
|
||||
from app.core.workflow.nodes.assigner.config import AssignmentItem
|
||||
from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig
|
||||
from app.core.workflow.nodes.code import CodeNodeConfig
|
||||
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
|
||||
from app.core.workflow.nodes.configs import StartNodeConfig, LLMNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig
|
||||
from app.core.workflow.nodes.cycle_graph.config import ConditionDetail as LoopConditionDetail, ConditionsConfig, \
|
||||
from app.core.workflow.nodes.configs import (
|
||||
StartNodeConfig,
|
||||
LLMNodeConfig,
|
||||
AssignerNodeConfig,
|
||||
CodeNodeConfig,
|
||||
LoopNodeConfig,
|
||||
IterationNodeConfig,
|
||||
EndNodeConfig,
|
||||
HttpRequestNodeConfig,
|
||||
IfElseNodeConfig,
|
||||
JinjaRenderNodeConfig,
|
||||
KnowledgeRetrievalNodeConfig,
|
||||
NoteNodeConfig,
|
||||
ParameterExtractorNodeConfig,
|
||||
QuestionClassifierNodeConfig,
|
||||
VariableAggregatorNodeConfig
|
||||
)
|
||||
from app.core.workflow.nodes.cycle_graph.config import (
|
||||
ConditionDetail as LoopConditionDetail,
|
||||
ConditionsConfig,
|
||||
CycleVariable
|
||||
from app.core.workflow.nodes.end import EndNodeConfig
|
||||
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, AssignmentOperator, HttpAuthType, \
|
||||
HttpContentType, HttpErrorHandle
|
||||
from app.core.workflow.nodes.http_request import HttpRequestNodeConfig
|
||||
from app.core.workflow.nodes.http_request.config import HttpAuthConfig, HttpContentTypeConfig, HttpFormData, \
|
||||
HttpTimeOutConfig, HttpRetryConfig, HttpErrorDefaultTamplete, HttpErrorHandleConfig
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
)
|
||||
from app.core.workflow.nodes.enums import (
|
||||
ValueInputType,
|
||||
ComparisonOperator,
|
||||
AssignmentOperator,
|
||||
HttpAuthType,
|
||||
HttpContentType,
|
||||
HttpErrorHandle,
|
||||
NodeType
|
||||
)
|
||||
from app.core.workflow.nodes.http_request.config import (
|
||||
HttpAuthConfig,
|
||||
HttpContentTypeConfig,
|
||||
HttpFormData,
|
||||
HttpTimeOutConfig,
|
||||
HttpRetryConfig,
|
||||
HttpErrorDefaultTamplete,
|
||||
HttpErrorHandleConfig
|
||||
)
|
||||
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
|
||||
from app.core.workflow.nodes.jinja_render import JinjaRenderNodeConfig
|
||||
from app.core.workflow.nodes.jinja_render.config import VariablesMappingConfig
|
||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
||||
from app.core.workflow.nodes.llm.config import MemoryWindowSetting, MessageConfig
|
||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNodeConfig
|
||||
from app.core.workflow.nodes.parameter_extractor.config import ParamsConfig
|
||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNodeConfig
|
||||
from app.core.workflow.nodes.question_classifier.config import ClassifierConfig
|
||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||
|
||||
|
||||
@@ -48,24 +74,24 @@ class DifyConverter(BaseConverter):
|
||||
|
||||
def __init__(self):
|
||||
self.CONFIG_CONVERT_MAP = {
|
||||
"start": self.convert_start_node_config,
|
||||
"llm": self.convert_llm_node_config,
|
||||
"answer": self.convert_end_node_config,
|
||||
"if-else": self.convert_if_else_node_config,
|
||||
"loop": self.convert_loop_node_config,
|
||||
"iteration": self.convert_iteration_node_config,
|
||||
"assigner": self.convert_assigner_node_config,
|
||||
"code": self.convert_code_node_config,
|
||||
"http-request": self.convert_http_node_config,
|
||||
"template-transform": self.convert_jinja_render_node_config,
|
||||
"knowledge-retrieval": self.convert_knowledge_node_config,
|
||||
"parameter-extractor": self.convert_parameter_extractor_node_config,
|
||||
"question-classifier": self.convert_question_classifier_node_config,
|
||||
"variable-aggregator": self.convert_variable_aggregator_node_config,
|
||||
"tool": self.convert_tool_node_config,
|
||||
"loop-start": lambda x: {},
|
||||
"iteration-start": lambda x: {},
|
||||
"loop-end": lambda x: {},
|
||||
NodeType.START: self.convert_start_node_config,
|
||||
NodeType.LLM: self.convert_llm_node_config,
|
||||
NodeType.END: self.convert_end_node_config,
|
||||
NodeType.IF_ELSE: self.convert_if_else_node_config,
|
||||
NodeType.LOOP: self.convert_loop_node_config,
|
||||
NodeType.ITERATION: self.convert_iteration_node_config,
|
||||
NodeType.ASSIGNER: self.convert_assigner_node_config,
|
||||
NodeType.CODE: self.convert_code_node_config,
|
||||
NodeType.HTTP_REQUEST: self.convert_http_node_config,
|
||||
NodeType.JINJARENDER: self.convert_jinja_render_node_config,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: self.convert_knowledge_node_config,
|
||||
NodeType.PARAMETER_EXTRACTOR: self.convert_parameter_extractor_node_config,
|
||||
NodeType.QUESTION_CLASSIFIER: self.convert_question_classifier_node_config,
|
||||
NodeType.VAR_AGGREGATOR: self.convert_variable_aggregator_node_config,
|
||||
NodeType.TOOL: self.convert_tool_node_config,
|
||||
NodeType.NOTES: self.convert_notes_config,
|
||||
NodeType.CYCLE_START: lambda x: {},
|
||||
NodeType.BREAK: lambda x: {},
|
||||
}
|
||||
|
||||
def get_node_convert(self, node_type):
|
||||
@@ -732,3 +758,16 @@ class DifyConverter(BaseConverter):
|
||||
detail=f"Please reconfigure the tool node.",
|
||||
))
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def convert_notes_config(node: dict):
|
||||
node_data = node["data"]
|
||||
result = NoteNodeConfig.model_construct(
|
||||
author=node_data.get("author", ""),
|
||||
text=node_data.get("text", ""),
|
||||
width=node_data.get("width", 80),
|
||||
height=node_data.get("height", 80),
|
||||
theme=node_data.get("theme", "blue"),
|
||||
show_author=node_data.get("showAuthor", True)
|
||||
).model_dump()
|
||||
return result
|
||||
|
||||
@@ -50,7 +50,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
|
||||
def __init__(self, config: dict[str, Any]):
|
||||
DifyConverter.__init__(self)
|
||||
BasePlatformAdapter.__init__(self, config)
|
||||
BasePlatformAdapter.__init__(self, config)
|
||||
|
||||
def get_metadata(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
@@ -59,7 +59,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
|
||||
)
|
||||
|
||||
def map_node_type(self, platform_node_type) -> str:
|
||||
def map_node_type(self, platform_node_type) -> NodeType:
|
||||
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
|
||||
|
||||
@property
|
||||
@@ -84,7 +84,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
require_fields = frozenset({'app', 'kind', 'version', 'workflow'})
|
||||
if not all(field in self.config for field in require_fields):
|
||||
return False
|
||||
if self.config.get("app",{}).get("mode") == "workflow":
|
||||
if self.config.get("app", {}).get("mode") == "workflow":
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.PLATFORM,
|
||||
detail="workflow mode is not supported"
|
||||
@@ -163,13 +163,14 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
|
||||
node_data = node["data"]
|
||||
try:
|
||||
node_type = self.map_node_type(node_data["type"])
|
||||
return NodeDefinition(
|
||||
id=node["id"],
|
||||
type=self.map_node_type(node_data["type"]),
|
||||
type=node_type,
|
||||
name=node_data.get("title") or "notes",
|
||||
cycle=node.get("parentId"),
|
||||
description=None,
|
||||
config=self._convert_node_config(node),
|
||||
config=self._convert_node_config(node_type, node),
|
||||
position={
|
||||
"x": node["position"]["x"],
|
||||
"y": node["position"]["y"]
|
||||
@@ -183,17 +184,16 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
except Exception as e:
|
||||
logger.debug(f"convert node error - {e}", exc_info=True)
|
||||
|
||||
def _convert_node_config(self, node: dict):
|
||||
node_data = node["data"]
|
||||
node_type = node_data["type"]
|
||||
def _convert_node_config(self, node_type: NodeType, node: dict):
|
||||
try:
|
||||
node_data = node["data"]
|
||||
converter = self.get_node_convert(node_type)
|
||||
if node_type not in self.CONFIG_CONVERT_MAP:
|
||||
if node_type == NodeType.UNKNOWN:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node["id"],
|
||||
node_name=node["data"]["title"],
|
||||
detail=f"node type {node_type if node_type else 'notes'} is unsupported",
|
||||
detail=f"node type {node_data.get('type')} is unsupported",
|
||||
))
|
||||
return converter(node)
|
||||
except Exception as e:
|
||||
@@ -214,7 +214,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
if source in self.branch_node_cache:
|
||||
case_id = edge["sourceHandle"]
|
||||
if case_id == "false":
|
||||
label = f'CASE{len(self.branch_node_cache[source])+1}'
|
||||
label = f'CASE{len(self.branch_node_cache[source]) + 1}'
|
||||
else:
|
||||
label = f'CASE{self.branch_node_cache[source].index(case_id) + 1}'
|
||||
if source in self.error_branch_node_cache:
|
||||
@@ -257,5 +257,3 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
||||
|
||||
def _convert_execution(self, execution: dict[str, Any]) -> ExecutionConfig:
|
||||
return ExecutionConfig()
|
||||
|
||||
|
||||
|
||||
@@ -4,65 +4,145 @@
|
||||
# @Time : 2026/2/25 14:11
|
||||
from typing import Any
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.workflow.adapters.base_adapter import (
|
||||
PlatformMetadata,
|
||||
PlatformType,
|
||||
BasePlatformAdapter,
|
||||
WorkflowParserResult
|
||||
)
|
||||
from app.schemas.workflow_schema import ExecutionConfig
|
||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType
|
||||
from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
VALID_NODE_TYPES = frozenset(t.value for t in NodeType if t != NodeType.UNKNOWN)
|
||||
|
||||
|
||||
class MemoryBearAdapter(BasePlatformAdapter):
|
||||
NODE_TYPE_MAPPING = {}
|
||||
class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
||||
NODE_TYPE_MAPPING = {t.value: t for t in NodeType}
|
||||
|
||||
def __init__(self, config: dict[str, Any]):
|
||||
MemoryBearConverter.__init__(self)
|
||||
BasePlatformAdapter.__init__(self, config)
|
||||
|
||||
@property
|
||||
def origin_nodes(self):
|
||||
return self.config.get("workflow").get("nodes")
|
||||
return self.config.get("workflow").get("nodes") or []
|
||||
|
||||
@property
|
||||
def origin_edges(self):
|
||||
return self.config.get("workflow").get("edges")
|
||||
return self.config.get("workflow").get("edges") or []
|
||||
|
||||
@property
|
||||
def origin_variables(self):
|
||||
return self.config.get("workflow").get("variables")
|
||||
return self.config.get("workflow").get("variables") or []
|
||||
|
||||
def get_metadata(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
platform_name=PlatformType.MEMORY_BEAR,
|
||||
version="0.2.5",
|
||||
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
|
||||
support_node_types=list(VALID_NODE_TYPES)
|
||||
)
|
||||
|
||||
def map_node_type(self, platform_node_type) -> str:
|
||||
return platform_node_type
|
||||
def map_node_type(self, platform_node_type: str) -> NodeType:
|
||||
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
|
||||
|
||||
@staticmethod
|
||||
def _valid_nodes(node: dict[str, Any]):
|
||||
if "type" not in node["data"]:
|
||||
return False
|
||||
def _valid_node(node: dict[str, Any]) -> bool:
|
||||
if "id" not in node or "type" not in node:
|
||||
return False
|
||||
if not isinstance(node.get("config"), dict):
|
||||
return False
|
||||
return True
|
||||
|
||||
def validate_config(self) -> bool:
|
||||
require_fields = frozenset({'app', 'workflow'})
|
||||
if not all(field in self.config for field in require_fields):
|
||||
return False
|
||||
|
||||
for node in self.origin_nodes:
|
||||
if not self._valid_nodes(node):
|
||||
if not self._valid_node(node):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
|
||||
node_id = node.get("id")
|
||||
node_name = node.get("name")
|
||||
try:
|
||||
node_type = self.map_node_type(node["type"])
|
||||
if node_type == NodeType.UNKNOWN:
|
||||
self.errors.append(UnsupportNodeType(
|
||||
node_id=node_id,
|
||||
node_type=node["type"]
|
||||
))
|
||||
return None
|
||||
|
||||
config = node.get("config") or {}
|
||||
converter = self.get_node_convert(node_type)
|
||||
converter(node_id, node_name, config) # validates and appends errors if invalid
|
||||
|
||||
return NodeDefinition(**node)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.NODE,
|
||||
node_id=node_id,
|
||||
node_name=node_name,
|
||||
detail=f"convert node error - {e}"
|
||||
))
|
||||
logger.debug(f"MemoryBear convert node error - {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None:
|
||||
try:
|
||||
if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids:
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
type=ExceptionType.EDGE,
|
||||
detail=f"edge {edge.get('id')} skipped: source or target node not found"
|
||||
))
|
||||
return None
|
||||
return EdgeDefinition(**edge)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.EDGE,
|
||||
detail=f"convert edge error - {e}"
|
||||
))
|
||||
logger.debug(f"MemoryBear convert edge error - {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _convert_variable(self, variable: dict[str, Any]) -> VariableDefinition | None:
|
||||
try:
|
||||
return VariableDefinition(**variable)
|
||||
except Exception as e:
|
||||
self.warnings.append(ExceptionDefineition(
|
||||
type=ExceptionType.VARIABLE,
|
||||
name=variable.get("name"),
|
||||
detail=f"convert variable error - {e}"
|
||||
))
|
||||
logger.debug(f"MemoryBear convert variable error - {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def parse_workflow(self) -> WorkflowParserResult:
|
||||
self.nodes = self.origin_nodes
|
||||
self.edges = self.origin_edges
|
||||
self.conv_variables = self.origin_variables
|
||||
for node in self.origin_nodes:
|
||||
converted = self._convert_node(node)
|
||||
if converted:
|
||||
self.nodes.append(converted)
|
||||
|
||||
valid_node_ids = {n.id for n in self.nodes}
|
||||
|
||||
for edge in self.origin_edges:
|
||||
converted = self._convert_edge(edge, valid_node_ids)
|
||||
if converted:
|
||||
self.edges.append(converted)
|
||||
|
||||
for variable in self.origin_variables:
|
||||
converted = self._convert_variable(variable)
|
||||
if converted:
|
||||
self.conv_variables.append(converted)
|
||||
|
||||
return WorkflowParserResult(
|
||||
success=True,
|
||||
success=not self.errors and not self.warnings,
|
||||
platform=self.get_metadata(),
|
||||
execution_config=ExecutionConfig(),
|
||||
origin_config=self.config,
|
||||
@@ -72,5 +152,4 @@ class MemoryBearAdapter(BasePlatformAdapter):
|
||||
variables=self.conv_variables,
|
||||
warnings=self.warnings,
|
||||
errors=self.errors,
|
||||
|
||||
)
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.nodes.configs import (
|
||||
StartNodeConfig,
|
||||
EndNodeConfig,
|
||||
LLMNodeConfig,
|
||||
AgentNodeConfig,
|
||||
IfElseNodeConfig,
|
||||
KnowledgeRetrievalNodeConfig,
|
||||
AssignerNodeConfig,
|
||||
CodeNodeConfig,
|
||||
HttpRequestNodeConfig,
|
||||
JinjaRenderNodeConfig,
|
||||
VariableAggregatorNodeConfig,
|
||||
ParameterExtractorNodeConfig,
|
||||
LoopNodeConfig,
|
||||
IterationNodeConfig,
|
||||
QuestionClassifierNodeConfig,
|
||||
ToolNodeConfig,
|
||||
MemoryReadNodeConfig,
|
||||
MemoryWriteNodeConfig,
|
||||
NoteNodeConfig,
|
||||
)
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
|
||||
|
||||
class MemoryBearConverter(BaseConverter):
|
||||
errors: list
|
||||
warnings: list
|
||||
|
||||
CONFIG_CLASS_MAP: dict[NodeType, type[BaseNodeConfig]] = {
|
||||
NodeType.START: StartNodeConfig,
|
||||
NodeType.END: EndNodeConfig,
|
||||
NodeType.ANSWER: EndNodeConfig,
|
||||
NodeType.LLM: LLMNodeConfig,
|
||||
NodeType.AGENT: AgentNodeConfig,
|
||||
NodeType.IF_ELSE: IfElseNodeConfig,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNodeConfig,
|
||||
NodeType.ASSIGNER: AssignerNodeConfig,
|
||||
NodeType.CODE: CodeNodeConfig,
|
||||
NodeType.HTTP_REQUEST: HttpRequestNodeConfig,
|
||||
NodeType.JINJARENDER: JinjaRenderNodeConfig,
|
||||
NodeType.VAR_AGGREGATOR: VariableAggregatorNodeConfig,
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNodeConfig,
|
||||
NodeType.LOOP: LoopNodeConfig,
|
||||
NodeType.ITERATION: IterationNodeConfig,
|
||||
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNodeConfig,
|
||||
NodeType.TOOL: ToolNodeConfig,
|
||||
NodeType.MEMORY_READ: MemoryReadNodeConfig,
|
||||
NodeType.MEMORY_WRITE: MemoryWriteNodeConfig,
|
||||
NodeType.NOTES: NoteNodeConfig,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _convert_file(var):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _convert_array_file(var):
|
||||
return []
|
||||
|
||||
def config_validate(self, node_id: str, node_name: str, config_cls: type[BaseNodeConfig], value: dict):
|
||||
try:
|
||||
return config_cls.model_validate(value)
|
||||
except Exception as e:
|
||||
self.errors.append(ExceptionDefineition(
|
||||
type=ExceptionType.CONFIG,
|
||||
node_id=node_id,
|
||||
node_name=node_name,
|
||||
detail=str(e)
|
||||
))
|
||||
return None
|
||||
|
||||
def get_node_convert(self, node_type: NodeType):
|
||||
config_cls = self.CONFIG_CLASS_MAP.get(node_type)
|
||||
if not config_cls:
|
||||
return lambda node_id, node_name, config: config
|
||||
|
||||
def validate(node_id: str, node_name: str, config: dict):
|
||||
self.config_validate(node_id, node_name, config_cls, config)
|
||||
return config
|
||||
|
||||
return validate
|
||||
@@ -20,9 +20,21 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes import NodeFactory
|
||||
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
||||
from app.core.workflow.utils.expression_evaluator import evaluate_condition
|
||||
from app.core.workflow.validator import WorkflowValidator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Regex to split output into:
|
||||
# - variable placeholders: {{ ... }}
|
||||
# - normal literal text
|
||||
#
|
||||
# Example:
|
||||
# "Hello {{user.name}}!" ->
|
||||
# ["Hello ", "{{user.name}}", "!"]
|
||||
_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{}]+')
|
||||
# Strict variable format: {{ node_id.field_name }}
|
||||
_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*}}')
|
||||
|
||||
|
||||
class GraphBuilder:
|
||||
def __init__(
|
||||
@@ -37,13 +49,13 @@ class GraphBuilder:
|
||||
self.stream = stream
|
||||
self.subgraph = subgraph
|
||||
|
||||
self.start_node_id = None
|
||||
self.end_node_ids = []
|
||||
self.start_node_id: str | None = None
|
||||
|
||||
self.node_map = {node["id"]: node for node in self.nodes}
|
||||
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
||||
self._find_upstream_branch_node = lru_cache(
|
||||
self._find_upstream_activation_dep = lru_cache(
|
||||
maxsize=len(self.nodes) * 2
|
||||
)(self._find_upstream_branch_node)
|
||||
)(self._find_upstream_activation_dep)
|
||||
if variable_pool:
|
||||
self.variable_pool = variable_pool
|
||||
else:
|
||||
@@ -51,10 +63,19 @@ class GraphBuilder:
|
||||
|
||||
self.graph = StateGraph(WorkflowState)
|
||||
self.add_nodes()
|
||||
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
|
||||
self.end_nodes = [
|
||||
node
|
||||
for node in self.nodes
|
||||
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
||||
]
|
||||
self.add_edges()
|
||||
self._analyze_end_node_output()
|
||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
||||
|
||||
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
||||
self._build_reverse_adj()
|
||||
self._analyze_end_node_output()
|
||||
|
||||
@property
|
||||
def nodes(self) -> list[dict[str, Any]]:
|
||||
return self.workflow_config.get("nodes", [])
|
||||
@@ -87,60 +108,50 @@ class GraphBuilder:
|
||||
result[node[0]].append(node[1])
|
||||
return result
|
||||
|
||||
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]:
|
||||
"""
|
||||
Recursively find all upstream branch (control) nodes that influence the execution
|
||||
of the given target node.
|
||||
def _build_reverse_adj(self):
|
||||
for edge in self.edges:
|
||||
if edge["source"] not in self.reachable_nodes:
|
||||
continue
|
||||
self._reverse_adj[edge.get("target")].append({
|
||||
"id": edge["source"], "branch": edge.get("label")
|
||||
})
|
||||
|
||||
This method walks upstream along the workflow graph starting from `target_node`.
|
||||
It distinguishes between:
|
||||
- branch nodes (node types listed in `BRANCH_NODES`)
|
||||
- non-branch nodes (ordinary processing nodes)
|
||||
def _find_upstream_activation_dep(
|
||||
self,
|
||||
target_node: str
|
||||
) -> tuple[tuple[tuple[str, str]], tuple[str]]:
|
||||
"""Find upstream dependencies that affect the activation of a target node.
|
||||
|
||||
Traversal rules:
|
||||
1. For each immediate upstream node:
|
||||
- If it is a branch node, it is recorded as an affecting control node.
|
||||
- If it is a non-branch node, the traversal continues recursively upstream.
|
||||
2. If ANY upstream path reaches a START / CYCLE_START node without encountering
|
||||
a branch node, the traversal is considered invalid:
|
||||
- `has_branch` will be False
|
||||
- no branch nodes are returned.
|
||||
3. Only when ALL upstream non-branch paths eventually lead to at least one
|
||||
branch node will `has_branch` be True.
|
||||
Walks upstream along the workflow graph from the target node, collecting
|
||||
two types of dependencies:
|
||||
- Branch control nodes: upstream branch nodes (e.g. if-else) whose
|
||||
routing outcome determines whether the target node executes.
|
||||
- Output nodes: upstream END nodes that must complete their output
|
||||
before the target node can activate.
|
||||
|
||||
Special case:
|
||||
- If `target_node` has no upstream nodes AND its type is START or CYCLE_START,
|
||||
it is considered directly reachable from the workflow entry, and therefore
|
||||
has no controlling branch nodes.
|
||||
The traversal terminates early and returns empty tuples if any upstream
|
||||
path reaches START/CYCLE_START without encountering a branch or output
|
||||
node, indicating the target node is directly reachable and should be
|
||||
activated immediately.
|
||||
|
||||
Args:
|
||||
target_node (str):
|
||||
The identifier of the node whose upstream control branches
|
||||
are to be resolved.
|
||||
target_node: The ID of the node whose upstream activation
|
||||
dependencies are to be resolved.
|
||||
|
||||
Returns:
|
||||
tuple[bool, tuple[tuple[str, str]]]:
|
||||
- has_branch (bool):
|
||||
True if every upstream path from `target_node` encounters
|
||||
at least one branch node.
|
||||
False if any path reaches a start node without a branch.
|
||||
- branch_nodes (tuple[tuple[str, str]]):
|
||||
A deduplicated tuple of `(branch_node_id, branch_label)` pairs
|
||||
representing all branch nodes that can influence `target_node`.
|
||||
Returns an empty tuple if `has_branch` is False.
|
||||
A tuple of two elements:
|
||||
- A deduplicated tuple of (branch_node_id, branch_label) pairs
|
||||
representing upstream branch control dependencies. Empty if
|
||||
any clean path to START exists.
|
||||
- A deduplicated tuple of upstream output node IDs that must
|
||||
complete before this node activates.
|
||||
"""
|
||||
source_nodes = [
|
||||
{
|
||||
"id": edge.get("source"),
|
||||
"branch": edge.get("label")
|
||||
}
|
||||
for edge in self.edges
|
||||
if edge.get("target") == target_node
|
||||
]
|
||||
source_nodes = self._reverse_adj[target_node]
|
||||
if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
|
||||
return False, tuple()
|
||||
return tuple(), tuple()
|
||||
|
||||
branch_nodes = []
|
||||
output_nodes = []
|
||||
non_branch_nodes = []
|
||||
|
||||
for node_info in source_nodes:
|
||||
@@ -149,19 +160,23 @@ class GraphBuilder:
|
||||
(node_info["id"], node_info["branch"])
|
||||
)
|
||||
else:
|
||||
if self.get_node_type(node_info["id"]) == NodeType.END:
|
||||
output_nodes.append(node_info["id"])
|
||||
non_branch_nodes.append(node_info["id"])
|
||||
|
||||
has_branch = True
|
||||
for node_id in non_branch_nodes:
|
||||
node_has_branch, nodes = self._find_upstream_branch_node(node_id)
|
||||
has_branch = has_branch and node_has_branch
|
||||
if not has_branch:
|
||||
break
|
||||
branch_nodes.extend(nodes)
|
||||
if not has_branch:
|
||||
branch_nodes = []
|
||||
upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(node_id)
|
||||
if not upstream_control_nodes:
|
||||
if not upstream_output_nodes and node_id not in output_nodes:
|
||||
return tuple(), tuple()
|
||||
branch_nodes = []
|
||||
has_branch = False
|
||||
if has_branch:
|
||||
branch_nodes.extend(upstream_control_nodes)
|
||||
output_nodes.extend(upstream_output_nodes)
|
||||
|
||||
return has_branch, tuple(set(branch_nodes))
|
||||
return tuple(set(branch_nodes)), tuple(set(output_nodes))
|
||||
|
||||
def _analyze_end_node_output(self):
|
||||
"""
|
||||
@@ -182,11 +197,10 @@ class GraphBuilder:
|
||||
"""
|
||||
|
||||
# Collect all End nodes in the workflow
|
||||
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
||||
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
|
||||
logger.info(f"[Prefix Analysis] Found {len(self.end_nodes)} End nodes")
|
||||
|
||||
# Iterate through each End node to analyze its output
|
||||
for end_node in end_nodes:
|
||||
for end_node in self.end_nodes:
|
||||
end_node_id = end_node.get("id")
|
||||
config = end_node.get("config", {})
|
||||
output = config.get("output")
|
||||
@@ -195,42 +209,33 @@ class GraphBuilder:
|
||||
if not output:
|
||||
continue
|
||||
|
||||
# Regex to split output into:
|
||||
# - variable placeholders: {{ ... }}
|
||||
# - normal literal text
|
||||
#
|
||||
# Example:
|
||||
# "Hello {{user.name}}!" ->
|
||||
# ["Hello ", "{{user.name}}", "!"]
|
||||
pattern = r'\{\{.*?\}\}|[^{}]+'
|
||||
|
||||
# Strict variable format: {{ node_id.field_name }}
|
||||
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
|
||||
variable_pattern = re.compile(variable_pattern_string)
|
||||
|
||||
# Split output into ordered segments
|
||||
output_template = list(re.findall(pattern, output))
|
||||
output_template = list(_OUTPUT_PATTERN.findall(output))
|
||||
|
||||
# Determine whether each segment is literal text
|
||||
# True -> literal (can be directly output)
|
||||
# False -> variable placeholder (needs runtime value)
|
||||
output_flag = [
|
||||
not bool(variable_pattern.match(item))
|
||||
not bool(_VARIABLE_PATTERN.match(item))
|
||||
for item in output_template
|
||||
]
|
||||
|
||||
# Stream mode: output activation depends on upstream branch nodes
|
||||
if self.stream:
|
||||
# Find upstream branch nodes that can control this End node
|
||||
has_branch, control_nodes = self._find_upstream_branch_node(end_node_id)
|
||||
|
||||
upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(end_node_id)
|
||||
activate = not bool(upstream_control_nodes) and not bool(upstream_output_nodes)
|
||||
# Build StreamOutputConfig for this End node
|
||||
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||
id=end_node_id,
|
||||
# If there is no upstream branch, output is active immediately
|
||||
activate=not has_branch,
|
||||
activate=activate,
|
||||
|
||||
# Branch nodes that control activation of this End node
|
||||
control_nodes=self._merge_control_nodes(control_nodes),
|
||||
control_nodes=self._merge_control_nodes(upstream_control_nodes),
|
||||
upstream_output_nodes=list(upstream_output_nodes),
|
||||
control_resolved=not bool(upstream_control_nodes),
|
||||
output_resolved=not bool(upstream_output_nodes),
|
||||
|
||||
# Convert output segments into OutputContent objects
|
||||
outputs=list(
|
||||
@@ -249,14 +254,16 @@ class GraphBuilder:
|
||||
cursor=0
|
||||
)
|
||||
logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
|
||||
f"activate: {not has_branch}, "
|
||||
f"control_nodes: {control_nodes},"
|
||||
f"activate: {activate}, "
|
||||
f"control_nodes: {upstream_control_nodes},"
|
||||
f"ref_outputs: {upstream_output_nodes},"
|
||||
f"output: {output_template},"
|
||||
f"output_activate: {output_flag}")
|
||||
|
||||
# Non-stream mode: all outputs are activated by default
|
||||
else:
|
||||
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||
id=end_node_id,
|
||||
activate=True,
|
||||
control_nodes={},
|
||||
outputs=list(
|
||||
@@ -269,7 +276,10 @@ class GraphBuilder:
|
||||
for output_string, activate in zip(output_template, output_flag)
|
||||
]
|
||||
),
|
||||
cursor=0
|
||||
cursor=0,
|
||||
upstream_output_nodes=[],
|
||||
control_resolved=True,
|
||||
output_resolved=True,
|
||||
)
|
||||
|
||||
def add_nodes(self):
|
||||
@@ -304,8 +314,6 @@ class GraphBuilder:
|
||||
# Record start and end node IDs
|
||||
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
||||
self.start_node_id = node_id
|
||||
elif node_type == NodeType.END:
|
||||
self.end_node_ids.append(node_id)
|
||||
|
||||
# Create node instance (start and end nodes are also created)
|
||||
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
|
||||
@@ -448,7 +456,7 @@ class GraphBuilder:
|
||||
branch_activate = []
|
||||
new_state = state.copy()
|
||||
new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate
|
||||
node_output = variable_pool.get_node_output(src, defalut=dict(), strict=False)
|
||||
node_output = variable_pool.get_node_output(src, default=dict(), strict=False)
|
||||
for label, branch in unique_branch.items():
|
||||
if node_output and evaluate_condition(
|
||||
branch["condition"],
|
||||
@@ -494,9 +502,11 @@ class GraphBuilder:
|
||||
logger.debug(f"Added waiting edge: {sources} -> {target}")
|
||||
|
||||
# Connect End nodes to the global END node
|
||||
for end_node_id in self.end_node_ids:
|
||||
self.graph.add_edge(end_node_id, END)
|
||||
logger.debug(f"Added edge: {end_node_id} -> END")
|
||||
for end_node in self.end_nodes:
|
||||
end_node_id = end_node.get("id")
|
||||
if end_node_id:
|
||||
self.graph.add_edge(end_node_id, END)
|
||||
logger.debug(f"Added edge: {end_node_id} -> END")
|
||||
return
|
||||
|
||||
def build(self) -> CompiledStateGraph:
|
||||
|
||||
@@ -12,6 +12,7 @@ class WorkflowResultBuilder:
|
||||
variable_pool: VariablePool,
|
||||
elapsed_time: float,
|
||||
final_output: str,
|
||||
success: bool
|
||||
):
|
||||
"""Construct the final standardized output of the workflow execution.
|
||||
|
||||
@@ -29,6 +30,7 @@ class WorkflowResultBuilder:
|
||||
elapsed_time (float): Total execution time in seconds.
|
||||
final_output (Any): The aggregated or final output content of the workflow
|
||||
(e.g., combined messages from all End nodes).
|
||||
success (bool): Whether the execution was successful.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the final workflow execution result with keys:
|
||||
@@ -49,7 +51,7 @@ class WorkflowResultBuilder:
|
||||
conversation_id = variable_pool.get_value("sys.conversation_id")
|
||||
|
||||
return {
|
||||
"status": "completed",
|
||||
"status": "completed" if success else "failed",
|
||||
"output": final_output,
|
||||
"variables": {
|
||||
"conv": variable_pool.get_all_conversation_vars(),
|
||||
|
||||
@@ -3,9 +3,10 @@
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/9 15:11
|
||||
import re
|
||||
from queue import Queue
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
@@ -37,8 +38,8 @@ class OutputContent(BaseModel):
|
||||
activate: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether this output segment is currently active.\n"
|
||||
"- True: allowed to be emitted/output\n"
|
||||
"Whether this output segment is currently active."
|
||||
"- True: allowed to be emitted/output"
|
||||
"- False: blocked until activated by branch control"
|
||||
)
|
||||
)
|
||||
@@ -46,16 +47,17 @@ class OutputContent(BaseModel):
|
||||
is_variable: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether this segment represents a variable placeholder.\n"
|
||||
"True -> variable (e.g. {{ node.field }})\n"
|
||||
"Whether this segment represents a variable placeholder."
|
||||
"True -> variable (e.g. {{ node.field }})"
|
||||
"False -> literal text"
|
||||
)
|
||||
)
|
||||
|
||||
_SCOPE: str | None = None
|
||||
_SCOPE: str | None = PrivateAttr(default=None)
|
||||
|
||||
def get_scope(self) -> str:
|
||||
self._SCOPE = SCOPE_PATTERN.findall(self.literal)[0]
|
||||
def get_scope(self) -> str | None:
|
||||
matches = SCOPE_PATTERN.findall(self.literal)
|
||||
self._SCOPE = matches[0] if matches else None
|
||||
return self._SCOPE
|
||||
|
||||
def depends_on_scope(self, scope: str) -> bool:
|
||||
@@ -68,6 +70,8 @@ class OutputContent(BaseModel):
|
||||
Returns:
|
||||
bool: True if this segment references the given scope.
|
||||
"""
|
||||
if not self.is_variable:
|
||||
return False
|
||||
if self._SCOPE:
|
||||
return self._SCOPE == scope
|
||||
return self.get_scope() == scope
|
||||
@@ -83,12 +87,16 @@ class StreamOutputConfig(BaseModel):
|
||||
- which upstream branch/control nodes gate the activation
|
||||
- how each parsed output segment is streamed and activated
|
||||
"""
|
||||
id: str = Field(
|
||||
...,
|
||||
description="ID of the End node this configuration belongs to."
|
||||
)
|
||||
|
||||
activate: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Global activation flag for the End node output.\n"
|
||||
"When False, output segments should not be emitted even if available.\n"
|
||||
"Global activation flag for the End node output."
|
||||
"When False, output segments should not be emitted even if available."
|
||||
"This flag typically becomes True once required control branch conditions "
|
||||
"are satisfied."
|
||||
)
|
||||
@@ -97,17 +105,46 @@ class StreamOutputConfig(BaseModel):
|
||||
control_nodes: dict[str, list[str]] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Control branch conditions for this End node output.\n"
|
||||
"Mapping of `branch_node_id -> expected_branch_label`.\n"
|
||||
"Control branch conditions for this End node output."
|
||||
"Mapping of `branch_node_id -> expected_branch_label`."
|
||||
"The End node output becomes globally active when a controlling branch node "
|
||||
"reports a matching completion status."
|
||||
)
|
||||
)
|
||||
|
||||
upstream_output_nodes: list[str] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Upstream output node dependencies (data flow)."
|
||||
"Represents END/output nodes that this output depends on."
|
||||
"These nodes provide data sources required before this output can be activated "
|
||||
"or streamed."
|
||||
"Used to ensure correct ordering and dependency resolution in streaming mode."
|
||||
)
|
||||
)
|
||||
|
||||
control_resolved: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether all upstream branch control dependencies have been satisfied."
|
||||
"True if no upstream branch nodes exist or the required branch "
|
||||
"conditions have been met."
|
||||
)
|
||||
)
|
||||
|
||||
output_resolved: bool = Field(
|
||||
...,
|
||||
description=(
|
||||
"Whether all upstream output node dependencies have been completed."
|
||||
"True if no upstream output nodes exist or all upstream output "
|
||||
"nodes have finished their output."
|
||||
)
|
||||
)
|
||||
|
||||
outputs: list[OutputContent] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Ordered list of output segments parsed from the output template.\n"
|
||||
"Ordered list of output segments parsed from the output template."
|
||||
"Each segment represents either a literal text block or a variable placeholder "
|
||||
"that may be activated independently."
|
||||
)
|
||||
@@ -116,49 +153,97 @@ class StreamOutputConfig(BaseModel):
|
||||
cursor: int = Field(
|
||||
...,
|
||||
description=(
|
||||
"Streaming cursor index.\n"
|
||||
"Indicates the next output segment index to be emitted.\n"
|
||||
"Streaming cursor index."
|
||||
"Indicates the next output segment index to be emitted."
|
||||
"Segments with index < cursor are considered already streamed."
|
||||
)
|
||||
)
|
||||
|
||||
force: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Force flag for output emission."
|
||||
"When True, all output segments are emitted regardless of activation state."
|
||||
"Triggered when this output node has finished execution."
|
||||
)
|
||||
)
|
||||
|
||||
def update_activate(self, scope: str, status=None):
|
||||
"""
|
||||
Update streaming activation state based on an upstream node or special variable.
|
||||
Update streaming activation state based on upstream events.
|
||||
|
||||
Args:
|
||||
scope (str):
|
||||
Identifier of the completed upstream entity.
|
||||
- If a control branch node, it should match a key in `control_nodes`.
|
||||
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
|
||||
- If an upstream output node, it should match an entry in `upstream_output_nodes`.
|
||||
- If a variable placeholder (e.g., "sys.xxx" or "node_id.field"),
|
||||
it may appear in output segments.
|
||||
|
||||
status (optional):
|
||||
Completion status of the control branch node.
|
||||
Required when `scope` refers to a control node.
|
||||
|
||||
Behavior:
|
||||
1. Control branch nodes:
|
||||
- If `scope` matches a key in `control_nodes` and `status` matches the expected
|
||||
branch label, the End node output becomes globally active (`activate = True`).
|
||||
1. Force activation:
|
||||
- If `self.force` is True, the method returns immediately.
|
||||
- If `scope == self.id`, the node marks itself as completed:
|
||||
- `activate = True`
|
||||
- `force = True`
|
||||
This is typically used for final flushing when the node finishes execution.
|
||||
|
||||
2. Variable output segments:
|
||||
- For each segment that is a variable (`is_variable=True`):
|
||||
- If the segment literal references `scope`, mark the segment as active.
|
||||
- This applies both to regular node variables (e.g., "node_id.field")
|
||||
and special system variables (e.g., "sys.xxx").
|
||||
2. Control dependency resolution:
|
||||
- If `scope` matches a key in `control_nodes`:
|
||||
- `status` must be provided.
|
||||
- If `status` matches expected branch labels, mark control as resolved
|
||||
(`control_resolved = True`).
|
||||
|
||||
3. Upstream output dependency resolution:
|
||||
- If `scope` is in `upstream_output_nodes`,
|
||||
mark data dependency as resolved (`output_resolved = True`).
|
||||
|
||||
4. Global activation condition:
|
||||
- The node becomes active when BOTH conditions are satisfied:
|
||||
- control_resolved == True
|
||||
- output_resolved == True
|
||||
- Once activated, `activate` remains True.
|
||||
|
||||
5. Variable segment activation:
|
||||
- For each output segment that is a variable (`is_variable=True`):
|
||||
- If the segment depends on the given `scope`,
|
||||
mark the segment as active.
|
||||
- This applies to both node variables (e.g., "node_id.field")
|
||||
and system variables (e.g., "sys.xxx").
|
||||
|
||||
Notes:
|
||||
- This method does not emit output or advance the streaming cursor.
|
||||
- It only updates activation flags based on upstream events or special variables.
|
||||
- This method does NOT emit output or advance the streaming cursor.
|
||||
- It only updates activation and dependency resolution states.
|
||||
- Activation is driven by both control flow (branch nodes) and
|
||||
data flow (upstream output nodes).
|
||||
"""
|
||||
if self.force:
|
||||
return
|
||||
|
||||
# Case 1: resolve control branch dependency
|
||||
if scope in self.control_nodes.keys():
|
||||
if scope == self.id:
|
||||
self.activate = True
|
||||
self.force = True
|
||||
return
|
||||
|
||||
# resolve control branch dependency
|
||||
if scope in self.control_nodes:
|
||||
if status is None:
|
||||
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
||||
if status in self.control_nodes[scope]:
|
||||
self.activate = True
|
||||
self.control_resolved = True
|
||||
|
||||
# Case 2: activate variable segments related to this node
|
||||
if scope in self.upstream_output_nodes:
|
||||
self.upstream_output_nodes.remove(scope)
|
||||
if not self.upstream_output_nodes:
|
||||
self.output_resolved = True
|
||||
|
||||
self.activate = self.activate or (self.control_resolved and self.output_resolved)
|
||||
|
||||
# activate variable segments related to this node
|
||||
for i in range(len(self.outputs)):
|
||||
if (
|
||||
self.outputs[i].is_variable
|
||||
@@ -171,12 +256,17 @@ class StreamOutputCoordinator:
|
||||
def __init__(self):
|
||||
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
||||
self.activate_end: str | None = None
|
||||
self.output_queue: Queue = Queue()
|
||||
self.processed_outputs = []
|
||||
|
||||
def initialize_end_outputs(
|
||||
self,
|
||||
end_node_map: dict[str, StreamOutputConfig]
|
||||
):
|
||||
self.end_outputs = end_node_map
|
||||
self.processed_outputs = []
|
||||
self.activate_end = None
|
||||
self.output_queue = Queue()
|
||||
|
||||
@property
|
||||
def current_activate_end_info(self):
|
||||
@@ -208,8 +298,11 @@ class StreamOutputCoordinator:
|
||||
"""
|
||||
for node in self.end_outputs.keys():
|
||||
self.end_outputs[node].update_activate(scope, status)
|
||||
if self.end_outputs[node].activate and self.activate_end is None:
|
||||
self.activate_end = node
|
||||
if self.end_outputs[node].activate and node not in self.processed_outputs:
|
||||
self.output_queue.put(node)
|
||||
self.processed_outputs.append(node)
|
||||
if self.activate_end is None and not self.output_queue.empty():
|
||||
self.activate_end = self.output_queue.get_nowait()
|
||||
|
||||
async def emit_activate_chunk(
|
||||
self,
|
||||
@@ -253,7 +346,7 @@ class StreamOutputCoordinator:
|
||||
final_chunk = ''
|
||||
current_segment = end_info.outputs[end_info.cursor]
|
||||
|
||||
if not current_segment.activate and not force:
|
||||
if not current_segment.activate and not force and not end_info.force:
|
||||
# Stop processing until this segment becomes active
|
||||
break
|
||||
|
||||
@@ -270,7 +363,7 @@ class StreamOutputCoordinator:
|
||||
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}")
|
||||
|
||||
if final_chunk:
|
||||
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}")
|
||||
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk_length:{len(final_chunk)}")
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
@@ -282,8 +375,7 @@ class StreamOutputCoordinator:
|
||||
end_info.cursor += 1
|
||||
|
||||
if end_info.cursor >= len(end_info.outputs):
|
||||
self.end_outputs.pop(self.activate_end)
|
||||
self.activate_end = None
|
||||
self.pop_current_activate_end()
|
||||
|
||||
async def flush_remaining_chunk(
|
||||
self,
|
||||
@@ -322,6 +414,8 @@ class StreamOutputCoordinator:
|
||||
async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
|
||||
yield msg_event
|
||||
|
||||
if not self.output_queue.empty():
|
||||
self.activate_end = self.output_queue.get_nowait()
|
||||
# Move to next active End node if current one is done
|
||||
if not self.activate_end and self.end_outputs:
|
||||
self.activate_end = list(self.end_outputs.keys())[0]
|
||||
|
||||
@@ -351,12 +351,12 @@ class VariablePool:
|
||||
}
|
||||
return runtime_vars
|
||||
|
||||
def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None:
|
||||
def get_node_output(self, node_id: str, default: Any = None, strict: bool = True) -> dict[str, Any] | None:
|
||||
"""获取指定节点的输出(运行时变量)
|
||||
|
||||
Args:
|
||||
node_id: 节点 ID
|
||||
defalut: 默认值
|
||||
default: 默认值
|
||||
strict: 是否严格模式
|
||||
|
||||
Returns:
|
||||
@@ -368,7 +368,7 @@ class VariablePool:
|
||||
if strict:
|
||||
raise KeyError(f"node {node_id} output not exist")
|
||||
else:
|
||||
return defalut
|
||||
return default
|
||||
|
||||
def copy(self, pool: 'VariablePool'):
|
||||
self.variables = deepcopy(pool.variables)
|
||||
|
||||
@@ -128,89 +128,100 @@ class WorkflowExecutor:
|
||||
- token_usage: aggregated token usage if available
|
||||
- error: error message if any
|
||||
"""
|
||||
logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
|
||||
|
||||
start_time = datetime.datetime.now()
|
||||
|
||||
# Execute the workflow
|
||||
try:
|
||||
# Build the workflow graph
|
||||
graph = self.build_graph()
|
||||
|
||||
# Initialize the variable pool with input data
|
||||
await self.variable_initializer.initialize(
|
||||
variable_pool=self.variable_pool,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context
|
||||
)
|
||||
initial_state = self.state_manager.create_initial_state(
|
||||
workflow_config=self.workflow_config,
|
||||
input_data=input_data,
|
||||
execution_context=self.execution_context,
|
||||
start_node_id=self.start_node_id
|
||||
)
|
||||
|
||||
result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
|
||||
|
||||
# Aggregate output from all End nodes
|
||||
full_content = ''
|
||||
for end_id in self.stream_coordinator.end_outputs.keys():
|
||||
full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
||||
|
||||
# Append messages for user and assistant
|
||||
if input_data.get("files"):
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("message", '')
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("files")
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content
|
||||
}
|
||||
]
|
||||
)
|
||||
else:
|
||||
result["messages"].extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": input_data.get("message", '')
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": full_content
|
||||
}
|
||||
]
|
||||
)
|
||||
# Calculate elapsed time
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
|
||||
|
||||
return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
||||
|
||||
except Exception as e:
|
||||
end_time = datetime.datetime.now()
|
||||
elapsed_time = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||
exc_info=True)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"output": None,
|
||||
"node_outputs": {},
|
||||
"elapsed_time": elapsed_time,
|
||||
"token_usage": None
|
||||
}
|
||||
start = datetime.datetime.now()
|
||||
async for event in self.execute_stream(input_data):
|
||||
if event.get("event") == "workflow_end":
|
||||
return event.get("data")
|
||||
return self.result_builder.build_final_output(
|
||||
{"error": "Workflow execution did not end as expected"},
|
||||
self.variable_pool,
|
||||
(datetime.datetime.now() - start).total_seconds(),
|
||||
"",
|
||||
success=False
|
||||
)
|
||||
# logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
|
||||
#
|
||||
# start_time = datetime.datetime.now()
|
||||
#
|
||||
# # Execute the workflow
|
||||
# try:
|
||||
# # Build the workflow graph
|
||||
# graph = self.build_graph()
|
||||
#
|
||||
# # Initialize the variable pool with input data
|
||||
# await self.variable_initializer.initialize(
|
||||
# variable_pool=self.variable_pool,
|
||||
# input_data=input_data,
|
||||
# execution_context=self.execution_context
|
||||
# )
|
||||
# initial_state = self.state_manager.create_initial_state(
|
||||
# workflow_config=self.workflow_config,
|
||||
# input_data=input_data,
|
||||
# execution_context=self.execution_context,
|
||||
# start_node_id=self.start_node_id
|
||||
# )
|
||||
#
|
||||
# result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
|
||||
#
|
||||
# # Aggregate output from all End nodes
|
||||
# full_content = ''
|
||||
# for end_id in self.stream_coordinator.end_outputs.keys():
|
||||
# full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
||||
#
|
||||
# # Append messages for user and assistant
|
||||
# if input_data.get("files"):
|
||||
# result["messages"].extend(
|
||||
# [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": input_data.get("message", '')
|
||||
# },
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": input_data.get("files")
|
||||
# },
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": full_content
|
||||
# }
|
||||
# ]
|
||||
# )
|
||||
# else:
|
||||
# result["messages"].extend(
|
||||
# [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": input_data.get("message", '')
|
||||
# },
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": full_content
|
||||
# }
|
||||
# ]
|
||||
# )
|
||||
# # Calculate elapsed time
|
||||
# end_time = datetime.datetime.now()
|
||||
# elapsed_time = (end_time - start_time).total_seconds()
|
||||
#
|
||||
# logger.info(
|
||||
# f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
|
||||
#
|
||||
# return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
||||
#
|
||||
# except Exception as e:
|
||||
# end_time = datetime.datetime.now()
|
||||
# elapsed_time = (end_time - start_time).total_seconds()
|
||||
#
|
||||
# logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||
# exc_info=True)
|
||||
# return {
|
||||
# "status": "failed",
|
||||
# "error": str(e),
|
||||
# "output": None,
|
||||
# "node_outputs": {},
|
||||
# "elapsed_time": elapsed_time,
|
||||
# "token_usage": None
|
||||
# }
|
||||
|
||||
async def execute_stream(
|
||||
self,
|
||||
@@ -248,7 +259,8 @@ class WorkflowExecutor:
|
||||
"timestamp": int(start_time.timestamp() * 1000)
|
||||
}
|
||||
}
|
||||
|
||||
result = None
|
||||
full_content = ''
|
||||
try:
|
||||
# Build the workflow graph in streaming mode
|
||||
graph = self.build_graph(stream=True)
|
||||
@@ -266,7 +278,6 @@ class WorkflowExecutor:
|
||||
start_node_id=self.start_node_id
|
||||
)
|
||||
|
||||
full_content = ''
|
||||
self.stream_coordinator.update_scope_activation("sys")
|
||||
|
||||
# Execute the workflow with streaming
|
||||
@@ -363,7 +374,12 @@ class WorkflowExecutor:
|
||||
|
||||
yield {
|
||||
"event": "workflow_end",
|
||||
"data": self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
||||
"data": self.result_builder.build_final_output(
|
||||
result,
|
||||
self.variable_pool,
|
||||
elapsed_time,
|
||||
full_content,
|
||||
success=True)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -372,16 +388,19 @@ class WorkflowExecutor:
|
||||
|
||||
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||
exc_info=True)
|
||||
|
||||
if result is None:
|
||||
result = {"error": str(e)}
|
||||
else:
|
||||
result["error"] = str(e)
|
||||
yield {
|
||||
"event": "workflow_end",
|
||||
"data": {
|
||||
"execution_id": self.execution_context.execution_id,
|
||||
"status": "failed",
|
||||
"error": str(e),
|
||||
"elapsed_time": elapsed_time,
|
||||
"timestamp": end_time.isoformat()
|
||||
}
|
||||
"data": self.result_builder.build_final_output(
|
||||
result,
|
||||
self.variable_pool,
|
||||
elapsed_time,
|
||||
full_content,
|
||||
success=False
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from functools import cached_property
|
||||
@@ -15,6 +16,7 @@ from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||
from app.db import get_db_read
|
||||
from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy
|
||||
from app.schemas import FileInput
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -619,11 +621,12 @@ class BaseNode(ABC):
|
||||
|
||||
@staticmethod
|
||||
async def process_message(
|
||||
provider: str,
|
||||
is_omni: bool,
|
||||
api_config: ModelInfo,
|
||||
content: str | dict | FileObject,
|
||||
end_user_id: str,
|
||||
enable_file=False
|
||||
) -> list | str | None:
|
||||
provider = api_config.provider
|
||||
if isinstance(content, dict):
|
||||
content = FileObject(
|
||||
type=content.get("type"),
|
||||
@@ -642,16 +645,20 @@ class BaseNode(ABC):
|
||||
if content.content_cache.get(provider):
|
||||
return content.content_cache[provider]
|
||||
with get_db_read() as db:
|
||||
multimodel_service = MultimodalService(db, provider, is_omni=is_omni)
|
||||
message = await multimodel_service.process_files(
|
||||
[FileInput.model_construct(
|
||||
type=content.type,
|
||||
url=content.url,
|
||||
transfer_method=content.transfer_method,
|
||||
file_type=content.origin_file_type,
|
||||
upload_file_id=content.file_id
|
||||
)]
|
||||
multimodel_service = MultimodalService(db, api_config=api_config)
|
||||
file_obj = FileInput(
|
||||
type=content.type,
|
||||
url=content.url,
|
||||
transfer_method=content.transfer_method,
|
||||
origin_file_type=content.origin_file_type,
|
||||
upload_file_id=uuid.UUID(content.file_id) if content.file_id else None,
|
||||
)
|
||||
file_obj.set_content(content.get_content())
|
||||
message = await multimodel_service.process_files(
|
||||
end_user_id,
|
||||
[file_obj],
|
||||
)
|
||||
content.set_content(file_obj.get_content())
|
||||
if message:
|
||||
content.content_cache[provider] = message
|
||||
return message
|
||||
|
||||
@@ -128,7 +128,7 @@ class CodeNode(BaseNode):
|
||||
else:
|
||||
raise ValueError(f"Unsupported language: {self.typed_config.language}")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient(timeout=60) as client:
|
||||
response = await client.post(
|
||||
"http://sandbox:8194/v1/sandbox/run",
|
||||
headers={
|
||||
|
||||
@@ -23,6 +23,7 @@ from app.core.workflow.nodes.question_classifier.config import QuestionClassifie
|
||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||
from app.core.workflow.nodes.notes.config import NoteNodeConfig
|
||||
|
||||
__all__ = [
|
||||
# 基础类
|
||||
@@ -47,5 +48,6 @@ __all__ = [
|
||||
"ToolNodeConfig",
|
||||
"MemoryReadNodeConfig",
|
||||
"MemoryWriteNodeConfig",
|
||||
"CodeNodeConfig"
|
||||
"CodeNodeConfig",
|
||||
"NoteNodeConfig"
|
||||
]
|
||||
|
||||
@@ -51,7 +51,7 @@ class ConditionDetail(BaseModel):
|
||||
)
|
||||
|
||||
right: Any = Field(
|
||||
...,
|
||||
default=None,
|
||||
description="Right-hand operand of the comparison expression"
|
||||
)
|
||||
|
||||
|
||||
@@ -158,7 +158,7 @@ class LoopRuntime:
|
||||
self.variable_pool.variables["conv"].update(
|
||||
self.child_variable_pool.variables["conv"]
|
||||
)
|
||||
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
|
||||
loop_vars = self.child_variable_pool.get_node_output(self.node_id, default={}, strict=False)
|
||||
loopstate["node_outputs"][self.node_id] = loop_vars
|
||||
|
||||
def evaluate_conditional(self) -> bool:
|
||||
@@ -261,4 +261,4 @@ class LoopRuntime:
|
||||
idx += 1
|
||||
|
||||
logger.info(f"loop node {self.node_id}: execution completed")
|
||||
return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state}
|
||||
return self.child_variable_pool.get_node_output(self.node_id, default={}, strict=False) | {"__child_state": child_state}
|
||||
|
||||
@@ -4,6 +4,7 @@ from pydantic import Field, BaseModel, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpAuthType, HttpContentType, HttpErrorHandle
|
||||
from app.core.workflow.variable.base_variable import FileObject
|
||||
|
||||
|
||||
class HttpAuthConfig(BaseModel):
|
||||
@@ -260,6 +261,11 @@ class HttpRequestNodeOutput(BaseModel):
|
||||
description="Http response headers"
|
||||
)
|
||||
|
||||
files: list[FileObject] = Field(
|
||||
default_factory=list,
|
||||
description="List of files",
|
||||
)
|
||||
|
||||
output: str = Field(
|
||||
default="SUCCESS",
|
||||
description="HTTP response body",
|
||||
|
||||
@@ -1,24 +1,146 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import uuid
|
||||
import imghdr
|
||||
from email.message import Message
|
||||
from typing import Any, Callable, Coroutine
|
||||
|
||||
import httpx
|
||||
# import filetypes # TODO: File support (Feature)
|
||||
from httpx import AsyncClient, Response, Timeout
|
||||
import magic
|
||||
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
||||
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.core.workflow.utils.file_processer import mime_to_file_type
|
||||
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||
from app.schemas import FileType, TransferMethod
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
class HttpResponse:
|
||||
def __init__(self, response: httpx.Response):
|
||||
self.response = response
|
||||
self.headers = dict(response.headers)
|
||||
|
||||
self._is_file: bool | None = None
|
||||
|
||||
@property
|
||||
def content_type(self) -> str:
|
||||
return self.headers.get("content-type", "")
|
||||
|
||||
@property
|
||||
def content_disposition(self) -> Message | None:
|
||||
content_disposition = self.headers.get("content-disposition", "")
|
||||
if content_disposition:
|
||||
msg = Message()
|
||||
msg["content-disposition"] = content_disposition
|
||||
return msg
|
||||
return None
|
||||
|
||||
@property
|
||||
def is_file(self) -> bool:
|
||||
if self._is_file is not None:
|
||||
return self._is_file
|
||||
content_type = self.content_type.split(";")[0].strip().lower()
|
||||
|
||||
parsed_content_disposition = self.content_disposition
|
||||
if parsed_content_disposition:
|
||||
disp_type = parsed_content_disposition.get_content_disposition()
|
||||
filename = parsed_content_disposition.get_filename()
|
||||
if disp_type == "attachment" or filename:
|
||||
self._is_file = True
|
||||
return True
|
||||
|
||||
if content_type.startswith("text/") and "csv" not in content_type:
|
||||
return False
|
||||
|
||||
if content_type.startswith("application/"):
|
||||
if any(
|
||||
text_type in content_type
|
||||
for text_type in {"json", "xml", "javascript", "x-www-form-urlencoded", "yaml", "graphql"}
|
||||
):
|
||||
self._is_file = False
|
||||
return False
|
||||
try:
|
||||
content_sample = self.response.content[:1024]
|
||||
content_sample.decode("utf-8")
|
||||
text_markers = (b"{", b"[", b"<", b"function", b"var ", b"const ", b"let ")
|
||||
if any(marker in content_sample for marker in text_markers):
|
||||
return False
|
||||
except UnicodeDecodeError:
|
||||
self._is_file = True
|
||||
return True
|
||||
|
||||
main_type, _ = mimetypes.guess_type("dummy" + (mimetypes.guess_extension(content_type) or ""))
|
||||
if main_type:
|
||||
self._is_file = main_type.split("/")[0] in ("application", "image", "audio", "video")
|
||||
return self._is_file
|
||||
self._is_file = any(media_type in content_type for media_type in ("image/", "audio/", "video/"))
|
||||
return self._is_file
|
||||
|
||||
@property
|
||||
def is_image(self):
|
||||
if self.is_file:
|
||||
kind = imghdr.what(None, h=self.response.content)
|
||||
return kind is not None
|
||||
return False
|
||||
|
||||
@property
|
||||
def url(self) -> str:
|
||||
return str(self.response.url)
|
||||
|
||||
@property
|
||||
def body(self) -> str:
|
||||
if self.is_file:
|
||||
return f"{'!' if self.is_image else ''}[file]({self.url})"
|
||||
return self.response.text
|
||||
|
||||
@staticmethod
|
||||
def get_file_type(file_bytes) -> tuple[FileType | None, str | None]:
|
||||
mime = magic.from_buffer(file_bytes, mime=True)
|
||||
|
||||
if mime.startswith("image"):
|
||||
return FileType.IMAGE, mime
|
||||
elif mime.startswith("video"):
|
||||
return FileType.VIDEO, mime
|
||||
elif mime.startswith("audio"):
|
||||
return FileType.AUDIO, mime
|
||||
elif mime in ["application/pdf",
|
||||
"application/msword",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
"text/plain"]:
|
||||
return FileType.DOCUMENT, mime
|
||||
return None, None
|
||||
|
||||
@property
|
||||
def files(self) -> list[FileObject]:
|
||||
file_type, mime_type = self.get_file_type(self.response.content)
|
||||
origin_file_type = mime_to_file_type(mime_type)
|
||||
if self.is_file and file_type and origin_file_type:
|
||||
file_obj = FileObject(
|
||||
type=file_type,
|
||||
url=self.url,
|
||||
transfer_method=TransferMethod.REMOTE_URL.value,
|
||||
origin_file_type=origin_file_type,
|
||||
file_id=None,
|
||||
is_file=True
|
||||
)
|
||||
file_obj.set_content(self.response.content)
|
||||
return [
|
||||
file_obj
|
||||
]
|
||||
return []
|
||||
|
||||
|
||||
class HttpRequestNode(BaseNode):
|
||||
"""
|
||||
HTTP Request Workflow Node.
|
||||
@@ -44,6 +166,7 @@ class HttpRequestNode(BaseNode):
|
||||
"body": VariableType.STRING,
|
||||
"status_code": VariableType.NUMBER,
|
||||
"headers": VariableType.OBJECT,
|
||||
"files": VariableType.ARRAY_FILE,
|
||||
"output": VariableType.STRING
|
||||
}
|
||||
|
||||
@@ -232,10 +355,12 @@ class HttpRequestNode(BaseNode):
|
||||
)
|
||||
resp.raise_for_status()
|
||||
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
||||
response = HttpResponse(resp)
|
||||
return HttpRequestNodeOutput(
|
||||
body=resp.text,
|
||||
body=response.body,
|
||||
status_code=resp.status_code,
|
||||
headers=resp.headers,
|
||||
files=response.files
|
||||
).model_dump()
|
||||
except (httpx.HTTPStatusError, httpx.RequestError) as e:
|
||||
logger.error(f"HTTP request node exception: {e}")
|
||||
|
||||
@@ -18,7 +18,7 @@ class ConditionDetail(BaseModel):
|
||||
)
|
||||
|
||||
right: Any = Field(
|
||||
...,
|
||||
default=None,
|
||||
description="Value to compare with"
|
||||
)
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Any
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
|
||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
|
||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
@@ -23,6 +23,26 @@ class IfElseNode(BaseNode):
|
||||
"output": VariableType.STRING
|
||||
}
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
result = []
|
||||
for case in self.typed_config.cases:
|
||||
expressions = []
|
||||
for expression in case.expressions:
|
||||
expressions.append({
|
||||
"left": self.get_variable(expression.left, variable_pool, strict=False),
|
||||
"right": expression.right
|
||||
if expression.input_type == ValueInputType.CONSTANT or expression.right is None
|
||||
else self.get_variable(expression.right, variable_pool, strict=False),
|
||||
"operator": str(expression.operator),
|
||||
})
|
||||
result.append({
|
||||
"expressions": expressions,
|
||||
"logical_operator": str(case.logical_operator),
|
||||
})
|
||||
return {
|
||||
"cases": result
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
|
||||
match operator:
|
||||
|
||||
@@ -30,6 +30,12 @@ class KnowledgeRetrievalNode(BaseNode):
|
||||
"output": VariableType.ARRAY_STRING
|
||||
}
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
return {
|
||||
"query": self._render_template(self.typed_config.query, variable_pool),
|
||||
"knowledge_bases": [kb_config.model_dump(mode="json") for kb_config in self.typed_config.knowledge_bases],
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
|
||||
"""
|
||||
|
||||
@@ -20,6 +20,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.db import get_db_context
|
||||
from app.models import ModelType
|
||||
from app.schemas.model_schema import ModelInfo
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -113,12 +114,15 @@ class LLMNode(BaseNode):
|
||||
|
||||
# 在 Session 关闭前提取所有需要的数据
|
||||
api_config = self.model_balance(config)
|
||||
model_name = api_config.model_name
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
api_base = api_config.api_base
|
||||
is_omni = api_config.is_omni
|
||||
model_type = config.type
|
||||
model_info = ModelInfo(
|
||||
model_name=api_config.model_name,
|
||||
model_type=ModelType(config.type),
|
||||
api_key=api_config.api_key,
|
||||
api_base=api_config.api_base,
|
||||
provider=api_config.provider,
|
||||
is_omni=api_config.is_omni,
|
||||
capability=api_config.capability
|
||||
)
|
||||
|
||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
|
||||
@@ -126,17 +130,18 @@ class LLMNode(BaseNode):
|
||||
|
||||
llm = RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
model_name=model_info.model_name,
|
||||
provider=model_info.provider,
|
||||
api_key=model_info.api_key,
|
||||
base_url=model_info.api_base,
|
||||
extra_params=extra_params,
|
||||
is_omni=is_omni
|
||||
is_omni=model_info.is_omni
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
type=model_info.model_type
|
||||
)
|
||||
|
||||
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
||||
logger.debug(
|
||||
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
|
||||
|
||||
messages_config = self.typed_config.messages
|
||||
|
||||
@@ -148,35 +153,40 @@ class LLMNode(BaseNode):
|
||||
content_template = msg_config.content
|
||||
content_template = self._render_context(content_template, variable_pool)
|
||||
content = self._render_template(content_template, variable_pool)
|
||||
|
||||
user_id = self.get_variable("sys.user_id", variable_pool)
|
||||
# 根据角色创建对应的消息对象
|
||||
if role == "system":
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
"content": await self.process_message(
|
||||
model_info,
|
||||
content,
|
||||
user_id,
|
||||
self.typed_config.vision,
|
||||
)
|
||||
})
|
||||
elif role in ["user", "human"]:
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
})
|
||||
elif role in ["ai", "assistant"]:
|
||||
messages.append({
|
||||
"role": "assistant",
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
})
|
||||
else:
|
||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
||||
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||
})
|
||||
|
||||
if self.typed_config.vision_input and self.typed_config.vision:
|
||||
file_content = []
|
||||
files = variable_pool.get_instance(self.typed_config.vision_input)
|
||||
for file in files.value:
|
||||
content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision)
|
||||
content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.extend(content)
|
||||
if messages and messages[-1]["role"] == 'user':
|
||||
@@ -190,14 +200,19 @@ class LLMNode(BaseNode):
|
||||
if isinstance(message["content"], list):
|
||||
file_content = []
|
||||
for file in message["content"]:
|
||||
content = await self.process_message(provider, is_omni, file, self.typed_config.vision)
|
||||
content = await self.process_message(model_info, file, user_id, self.typed_config.vision)
|
||||
if content:
|
||||
file_content.extend(content)
|
||||
history_message.append(
|
||||
{"role": message["role"], "content": file_content}
|
||||
)
|
||||
else:
|
||||
message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision)
|
||||
message["content"] = await self.process_message(
|
||||
model_info,
|
||||
message["content"],
|
||||
user_id,
|
||||
self.typed_config.vision
|
||||
)
|
||||
history_message.append(message)
|
||||
messages = messages[:-1] + history_message + messages[-1:]
|
||||
self.messages = messages
|
||||
@@ -293,7 +308,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
# 调用 LLM(流式,支持字符串或消息列表)
|
||||
last_meta_data = {}
|
||||
async for chunk in llm.astream(self.messages, stream_usage=True):
|
||||
async for chunk in llm.astream(self.messages):
|
||||
# 提取内容
|
||||
if hasattr(chunk, 'content'):
|
||||
content = self.process_model_output(chunk.content)
|
||||
|
||||
0
api/app/core/workflow/nodes/notes/__init__.py
Normal file
0
api/app/core/workflow/nodes/notes/__init__.py
Normal file
12
api/app/core/workflow/nodes/notes/config.py
Normal file
12
api/app/core/workflow/nodes/notes/config.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from pydantic import Field
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||
|
||||
|
||||
class NoteNodeConfig(BaseNodeConfig):
|
||||
author: str = Field(default="", description="author")
|
||||
text: str = Field(default="", description="note content")
|
||||
width: int = Field(default=80)
|
||||
height: int = Field(default=80)
|
||||
theme: str = Field(default="blue")
|
||||
show_author: bool = Field(default=True)
|
||||
@@ -250,6 +250,8 @@ class ConditionBase(ABC):
|
||||
self.type_limit = getattr(self, "type_limit", None)
|
||||
|
||||
def resolve_right_literal_value(self):
|
||||
if self.right_selector is None:
|
||||
return None
|
||||
if self.input_type == ValueInputType.VARIABLE:
|
||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||
right_expression = re.sub(pattern, r"\1", self.right_selector).strip()
|
||||
|
||||
@@ -37,6 +37,14 @@ class ParameterExtractorNode(BaseNode):
|
||||
}
|
||||
return None
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
return {
|
||||
"text": self._render_template(self.typed_config.text, variable_pool),
|
||||
"prompt": self._render_template(self.typed_config.prompt, variable_pool),
|
||||
"params": [param.model_dump(mode="json") for param in self.typed_config.params],
|
||||
"model_id": str(self.typed_config.model_id),
|
||||
}
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
outputs = {}
|
||||
for param in self.typed_config.params:
|
||||
|
||||
@@ -27,7 +27,6 @@ class ToolNode(BaseNode):
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {
|
||||
"data": VariableType.STRING,
|
||||
"error_code": VariableType.STRING,
|
||||
"execution_time": VariableType.NUMBER
|
||||
}
|
||||
|
||||
@@ -48,10 +47,7 @@ class ToolNode(BaseNode):
|
||||
|
||||
if not tenant_id:
|
||||
logger.error(f"节点 {self.node_id} 缺少租户ID")
|
||||
return {
|
||||
"success": False,
|
||||
"data": "缺少租户ID"
|
||||
}
|
||||
raise ValueError("缺少租户ID")
|
||||
|
||||
# 渲染工具参数
|
||||
rendered_parameters = {}
|
||||
@@ -83,13 +79,8 @@ class ToolNode(BaseNode):
|
||||
logger.info(f"节点 {self.node_id} 工具执行成功")
|
||||
return {
|
||||
"data": result.data if isinstance(result.data, str) else json.dumps(result.data, ensure_ascii=False),
|
||||
"error_code": "",
|
||||
"execution_time": result.execution_time
|
||||
}
|
||||
else:
|
||||
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
|
||||
return {
|
||||
"data": result.error if isinstance(result.error, str) else json.dumps(result.error, ensure_ascii=False),
|
||||
"error_code": result.error_code,
|
||||
"execution_time": result.execution_time
|
||||
}
|
||||
raise ValueError(f"工具执行失败: {result.error if isinstance(result.error, str) else json.dumps(result.error, ensure_ascii=False)}")
|
||||
|
||||
56
api/app/core/workflow/utils/file_processer.py
Normal file
56
api/app/core/workflow/utils/file_processer.py
Normal file
@@ -0,0 +1,56 @@
|
||||
# -*- coding: UTF-8 -*-
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/3/10 13:36
|
||||
TRANSFORM_FILE_TYPE = {
|
||||
'text/plain': 'document/text',
|
||||
'text/markdown': 'document/markdown',
|
||||
'text/x-markdown': 'document/x-markdown',
|
||||
|
||||
'application/pdf': 'document/pdf',
|
||||
|
||||
'application/msword': 'document/doc',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'document/docx',
|
||||
|
||||
'application/vnd.ms-powerpoint': 'document/ppt',
|
||||
'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'document/pptx',
|
||||
}
|
||||
ALLOWED_FILE_TYPES = [
|
||||
'text/plain',
|
||||
'text/markdown',
|
||||
'text/x-markdown',
|
||||
'application/pdf',
|
||||
'application/msword',
|
||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||
'application/vnd.ms-powerpoint',
|
||||
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||
'image/jpg',
|
||||
'image/jpeg',
|
||||
'image/png',
|
||||
'image/gif',
|
||||
'image/bmp',
|
||||
'image/webp',
|
||||
'image/svg+xml',
|
||||
'video/mp4',
|
||||
'video/quicktime',
|
||||
'video/x-msvideo',
|
||||
'video/x-matroska',
|
||||
'video/webm',
|
||||
'video/x-flv',
|
||||
'video/x-ms-wmv',
|
||||
'audio/mpeg',
|
||||
'audio/wav',
|
||||
'audio/ogg',
|
||||
'audio/aac',
|
||||
'audio/flac',
|
||||
'audio/mp4',
|
||||
'audio/x-ms-wma',
|
||||
'audio/x-m4a',
|
||||
]
|
||||
|
||||
|
||||
def mime_to_file_type(mime_type):
|
||||
if mime_type not in ALLOWED_FILE_TYPES:
|
||||
return None
|
||||
|
||||
return TRANSFORM_FILE_TYPE.get(mime_type, mime_type)
|
||||
@@ -170,7 +170,7 @@ class WorkflowValidator:
|
||||
# 仅在发布时验证所有节点可达
|
||||
# 6. 验证所有节点可达(从 start 节点出发)
|
||||
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||
reachable = WorkflowValidator._get_reachable_nodes(
|
||||
reachable = WorkflowValidator.get_reachable_nodes(
|
||||
start_nodes[0]["id"],
|
||||
edges
|
||||
)
|
||||
@@ -194,7 +194,7 @@ class WorkflowValidator:
|
||||
return len(errors) == 0, errors
|
||||
|
||||
@staticmethod
|
||||
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
|
||||
def get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
|
||||
"""获取从 start 节点可达的所有节点
|
||||
|
||||
Args:
|
||||
|
||||
@@ -2,7 +2,7 @@ from enum import StrEnum
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from app.schemas import FileType
|
||||
|
||||
@@ -41,10 +41,10 @@ class VariableType(StrEnum):
|
||||
"""
|
||||
if isinstance(var, str):
|
||||
return cls.STRING
|
||||
elif isinstance(var, (int, float)):
|
||||
return cls.NUMBER
|
||||
elif isinstance(var, bool):
|
||||
return cls.BOOLEAN
|
||||
elif isinstance(var, (int, float)):
|
||||
return cls.NUMBER
|
||||
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
|
||||
return cls.FILE
|
||||
elif isinstance(var, dict):
|
||||
@@ -114,9 +114,16 @@ class FileObject(BaseModel):
|
||||
file_id: str | None
|
||||
|
||||
content_cache: dict = Field(default_factory=dict)
|
||||
|
||||
is_file: bool
|
||||
|
||||
_byte_content: bytes | None = PrivateAttr(default=None)
|
||||
|
||||
def get_content(self):
|
||||
return self._byte_content
|
||||
|
||||
def set_content(self, byte_content):
|
||||
self._byte_content = byte_content
|
||||
|
||||
|
||||
class BaseVariable(ABC):
|
||||
"""Abstract base class for all workflow variables.
|
||||
|
||||
@@ -10,6 +10,7 @@ T = TypeVar("T", bound=BaseVariable)
|
||||
|
||||
|
||||
class StringVariable(BaseVariable):
|
||||
value: str
|
||||
type = 'str'
|
||||
|
||||
def valid_value(self, value) -> str:
|
||||
@@ -22,6 +23,7 @@ class StringVariable(BaseVariable):
|
||||
|
||||
|
||||
class NumberVariable(BaseVariable):
|
||||
value: int | float
|
||||
type = 'number'
|
||||
|
||||
def valid_value(self, value) -> int | float:
|
||||
@@ -34,6 +36,7 @@ class NumberVariable(BaseVariable):
|
||||
|
||||
|
||||
class BooleanVariable(BaseVariable):
|
||||
value: bool
|
||||
type = 'boolean'
|
||||
|
||||
def valid_value(self, value) -> bool:
|
||||
@@ -46,6 +49,7 @@ class BooleanVariable(BaseVariable):
|
||||
|
||||
|
||||
class DictVariable(BaseVariable):
|
||||
value: dict
|
||||
type = 'object'
|
||||
|
||||
def valid_value(self, value) -> dict:
|
||||
@@ -58,6 +62,7 @@ class DictVariable(BaseVariable):
|
||||
|
||||
|
||||
class FileVariable(BaseVariable):
|
||||
value: FileObject
|
||||
type = 'file'
|
||||
|
||||
def valid_value(self, value) -> FileObject:
|
||||
@@ -102,6 +107,7 @@ class FileVariable(BaseVariable):
|
||||
|
||||
|
||||
class ArrayVariable(BaseVariable, Generic[T]):
|
||||
value: list[T]
|
||||
type = 'array'
|
||||
|
||||
def __init__(self, child_type: Type[T], value: list[Any]):
|
||||
@@ -129,6 +135,7 @@ class ArrayVariable(BaseVariable, Generic[T]):
|
||||
|
||||
|
||||
class NestedArrayVariable(BaseVariable):
|
||||
value: list[ArrayVariable]
|
||||
type = 'array_nest'
|
||||
|
||||
def valid_value(self, value: list[T]) -> list[T]:
|
||||
@@ -153,6 +160,7 @@ class NestedArrayVariable(BaseVariable):
|
||||
category=RuntimeWarning
|
||||
)
|
||||
class AnyVariable(BaseVariable):
|
||||
value: Any
|
||||
type = 'any'
|
||||
|
||||
def valid_value(self, value: Any) -> Any:
|
||||
|
||||
@@ -16,7 +16,7 @@ engine = create_engine(
|
||||
pool_recycle=settings.DB_POOL_RECYCLE,
|
||||
pool_timeout=settings.DB_POOL_TIMEOUT,
|
||||
connect_args={
|
||||
"options": "-c timezone=Asia/Shanghai -c statement_timeout=60000"
|
||||
"options": "-c timezone=UTC -c statement_timeout=60000"
|
||||
},
|
||||
)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
@@ -65,6 +65,7 @@ def get_db_read() -> Generator[Session, None, None]:
|
||||
yield db
|
||||
finally:
|
||||
db.rollback() # 只读任务无需 commit
|
||||
db.close()
|
||||
|
||||
|
||||
def get_pool_status():
|
||||
|
||||
61
api/app/i18n/README.md
Normal file
61
api/app/i18n/README.md
Normal file
@@ -0,0 +1,61 @@
|
||||
# Internationalization (i18n) Module
|
||||
|
||||
This module provides internationalization support for the MemoryBear API.
|
||||
|
||||
## Components
|
||||
|
||||
- `service.py` - Translation service and core translation logic
|
||||
- `middleware.py` - Language detection middleware
|
||||
- `dependencies.py` - FastAPI dependency injection functions
|
||||
- `exceptions.py` - Internationalized exception classes
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Translation
|
||||
|
||||
```python
|
||||
from app.i18n import t
|
||||
|
||||
# Simple translation
|
||||
message = t("common.success.created")
|
||||
|
||||
# Parameterized translation
|
||||
message = t("common.validation.required", field="Name")
|
||||
```
|
||||
|
||||
### Enum Translation
|
||||
|
||||
```python
|
||||
from app.i18n import t_enum
|
||||
|
||||
# Translate enum value
|
||||
role_display = t_enum("workspace_role", "manager")
|
||||
```
|
||||
|
||||
### In FastAPI Endpoints
|
||||
|
||||
```python
|
||||
from fastapi import Depends
|
||||
from app.i18n.dependencies import get_translator
|
||||
|
||||
@router.post("/workspaces")
|
||||
async def create_workspace(
|
||||
data: WorkspaceCreate,
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
workspace = await workspace_service.create(data)
|
||||
return {
|
||||
"success": True,
|
||||
"message": t("workspace.created_successfully"),
|
||||
"data": workspace
|
||||
}
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
See `app/core/config.py` for i18n configuration options:
|
||||
|
||||
- `I18N_DEFAULT_LANGUAGE` - Default language (default: "zh")
|
||||
- `I18N_SUPPORTED_LANGUAGES` - Supported languages (default: "zh,en")
|
||||
- `I18N_ENABLE_TRANSLATION_CACHE` - Enable caching (default: true)
|
||||
- `I18N_LOG_MISSING_TRANSLATIONS` - Log missing translations (default: true)
|
||||
124
api/app/i18n/__init__.py
Normal file
124
api/app/i18n/__init__.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Internationalization (i18n) module for MemoryBear Enterprise.
|
||||
|
||||
This module provides complete i18n support for the backend API including:
|
||||
- Translation loading from multiple directories (community + enterprise)
|
||||
- Translation service with caching and fallback
|
||||
- Language detection middleware
|
||||
- Dependency injection for FastAPI
|
||||
- Convenience functions for easy usage
|
||||
|
||||
Usage:
|
||||
from app.i18n import t, t_enum
|
||||
|
||||
# Simple translation
|
||||
message = t("common.success.created")
|
||||
|
||||
# Parameterized translation
|
||||
error = t("common.validation.required", field="名称")
|
||||
|
||||
# Enum translation
|
||||
role_display = t_enum("workspace_role", "manager")
|
||||
"""
|
||||
|
||||
from app.i18n.dependencies import (
|
||||
get_current_language,
|
||||
get_enum_translator,
|
||||
get_translator,
|
||||
)
|
||||
from app.i18n.exceptions import (
|
||||
BadRequestError,
|
||||
ConflictError,
|
||||
FileNotFoundError,
|
||||
FileTooLargeError,
|
||||
ForbiddenError,
|
||||
I18nException,
|
||||
InternalServerError,
|
||||
InvalidCredentialsError,
|
||||
InvalidFileTypeError,
|
||||
NotFoundError,
|
||||
QuotaExceededError,
|
||||
RateLimitExceededError,
|
||||
ServiceUnavailableError,
|
||||
TenantNotFoundError,
|
||||
TenantSuspendedError,
|
||||
TokenExpiredError,
|
||||
TokenInvalidError,
|
||||
UnauthorizedError,
|
||||
UserAlreadyExistsError,
|
||||
UserNotFoundError,
|
||||
ValidationError,
|
||||
WorkspaceNotFoundError,
|
||||
WorkspacePermissionDeniedError,
|
||||
get_current_locale,
|
||||
set_current_locale,
|
||||
)
|
||||
from app.i18n.loader import TranslationLoader
|
||||
from app.i18n.logger import (
|
||||
TranslationLogger,
|
||||
get_translation_logger,
|
||||
log_missing_translation,
|
||||
log_translation_error,
|
||||
)
|
||||
from app.i18n.middleware import LanguageMiddleware
|
||||
from app.i18n.serializers import (
|
||||
I18nResponseMixin,
|
||||
WorkspaceSerializer,
|
||||
WorkspaceMemberSerializer,
|
||||
WorkspaceInviteSerializer,
|
||||
)
|
||||
from app.i18n.service import (
|
||||
TranslationService,
|
||||
get_translation_service,
|
||||
t,
|
||||
t_enum,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TranslationLoader",
|
||||
"LanguageMiddleware",
|
||||
"TranslationService",
|
||||
"get_translation_service",
|
||||
"t",
|
||||
"t_enum",
|
||||
"get_current_language",
|
||||
"get_translator",
|
||||
"get_enum_translator",
|
||||
# Context management
|
||||
"get_current_locale",
|
||||
"set_current_locale",
|
||||
# Logging
|
||||
"TranslationLogger",
|
||||
"get_translation_logger",
|
||||
"log_missing_translation",
|
||||
"log_translation_error",
|
||||
# Serializers
|
||||
"I18nResponseMixin",
|
||||
"WorkspaceSerializer",
|
||||
"WorkspaceMemberSerializer",
|
||||
"WorkspaceInviteSerializer",
|
||||
# Exception classes
|
||||
"I18nException",
|
||||
"BadRequestError",
|
||||
"UnauthorizedError",
|
||||
"ForbiddenError",
|
||||
"NotFoundError",
|
||||
"ConflictError",
|
||||
"ValidationError",
|
||||
"InternalServerError",
|
||||
"ServiceUnavailableError",
|
||||
"WorkspaceNotFoundError",
|
||||
"WorkspacePermissionDeniedError",
|
||||
"UserNotFoundError",
|
||||
"UserAlreadyExistsError",
|
||||
"TenantNotFoundError",
|
||||
"TenantSuspendedError",
|
||||
"InvalidCredentialsError",
|
||||
"TokenExpiredError",
|
||||
"TokenInvalidError",
|
||||
"FileNotFoundError",
|
||||
"FileTooLargeError",
|
||||
"InvalidFileTypeError",
|
||||
"RateLimitExceededError",
|
||||
"QuotaExceededError",
|
||||
]
|
||||
291
api/app/i18n/cache.py
Normal file
291
api/app/i18n/cache.py
Normal file
@@ -0,0 +1,291 @@
|
||||
"""
|
||||
Advanced caching system for i18n translations.
|
||||
|
||||
This module provides:
|
||||
- LRU cache for hot translations
|
||||
- Lazy loading mechanism
|
||||
- Memory optimization
|
||||
- Cache statistics
|
||||
"""
|
||||
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from typing import Any, Dict, Optional
|
||||
from collections import OrderedDict
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TranslationCache:
|
||||
"""
|
||||
Advanced translation cache with LRU eviction and lazy loading.
|
||||
|
||||
Features:
|
||||
- LRU cache for frequently accessed translations
|
||||
- Lazy loading to reduce startup time
|
||||
- Memory-efficient storage
|
||||
- Cache hit/miss statistics
|
||||
"""
|
||||
|
||||
def __init__(self, max_lru_size: int = 1000, enable_lazy_load: bool = True):
|
||||
"""
|
||||
Initialize the translation cache.
|
||||
|
||||
Args:
|
||||
max_lru_size: Maximum size of LRU cache for hot translations
|
||||
enable_lazy_load: Enable lazy loading of locales
|
||||
"""
|
||||
self.max_lru_size = max_lru_size
|
||||
self.enable_lazy_load = enable_lazy_load
|
||||
|
||||
# Main cache: {locale: {namespace: {key: value}}}
|
||||
self._main_cache: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# LRU cache for hot translations
|
||||
self._lru_cache: OrderedDict = OrderedDict()
|
||||
|
||||
# Loaded locales tracker
|
||||
self._loaded_locales: set = set()
|
||||
|
||||
# Statistics
|
||||
self._stats = {
|
||||
"hits": 0,
|
||||
"misses": 0,
|
||||
"lru_hits": 0,
|
||||
"lru_misses": 0,
|
||||
"lazy_loads": 0
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"TranslationCache initialized with LRU size: {max_lru_size}, "
|
||||
f"lazy loading: {enable_lazy_load}"
|
||||
)
|
||||
|
||||
def set_locale_data(self, locale: str, data: Dict[str, Any]):
|
||||
"""
|
||||
Set translation data for a locale.
|
||||
|
||||
Args:
|
||||
locale: Locale code
|
||||
data: Translation data dictionary
|
||||
"""
|
||||
self._main_cache[locale] = data
|
||||
self._loaded_locales.add(locale)
|
||||
logger.debug(f"Loaded locale '{locale}' into cache")
|
||||
|
||||
def get_translation(
|
||||
self,
|
||||
locale: str,
|
||||
namespace: str,
|
||||
key_path: list
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Get translation from cache with LRU optimization.
|
||||
|
||||
Args:
|
||||
locale: Locale code
|
||||
namespace: Translation namespace
|
||||
key_path: List of nested keys
|
||||
|
||||
Returns:
|
||||
Translation string or None if not found
|
||||
"""
|
||||
# Build cache key for LRU
|
||||
cache_key = f"{locale}:{namespace}:{'.'.join(key_path)}"
|
||||
|
||||
# Check LRU cache first (hot translations)
|
||||
if cache_key in self._lru_cache:
|
||||
self._stats["lru_hits"] += 1
|
||||
self._stats["hits"] += 1
|
||||
# Move to end (most recently used)
|
||||
self._lru_cache.move_to_end(cache_key)
|
||||
return self._lru_cache[cache_key]
|
||||
|
||||
self._stats["lru_misses"] += 1
|
||||
|
||||
# Check main cache
|
||||
if locale not in self._main_cache:
|
||||
self._stats["misses"] += 1
|
||||
return None
|
||||
|
||||
if namespace not in self._main_cache[locale]:
|
||||
self._stats["misses"] += 1
|
||||
return None
|
||||
|
||||
# Navigate through nested keys
|
||||
current = self._main_cache[locale][namespace]
|
||||
for key in key_path:
|
||||
if isinstance(current, dict) and key in current:
|
||||
current = current[key]
|
||||
else:
|
||||
self._stats["misses"] += 1
|
||||
return None
|
||||
|
||||
# Return only if it's a string value
|
||||
if not isinstance(current, str):
|
||||
self._stats["misses"] += 1
|
||||
return None
|
||||
|
||||
self._stats["hits"] += 1
|
||||
|
||||
# Add to LRU cache
|
||||
self._add_to_lru(cache_key, current)
|
||||
|
||||
return current
|
||||
|
||||
def _add_to_lru(self, key: str, value: str):
|
||||
"""
|
||||
Add translation to LRU cache.
|
||||
|
||||
Args:
|
||||
key: Cache key
|
||||
value: Translation value
|
||||
"""
|
||||
# Remove oldest if cache is full
|
||||
if len(self._lru_cache) >= self.max_lru_size:
|
||||
self._lru_cache.popitem(last=False)
|
||||
|
||||
self._lru_cache[key] = value
|
||||
|
||||
def is_locale_loaded(self, locale: str) -> bool:
|
||||
"""
|
||||
Check if a locale is loaded.
|
||||
|
||||
Args:
|
||||
locale: Locale code
|
||||
|
||||
Returns:
|
||||
True if locale is loaded
|
||||
"""
|
||||
return locale in self._loaded_locales
|
||||
|
||||
def get_loaded_locales(self) -> list:
|
||||
"""
|
||||
Get list of loaded locales.
|
||||
|
||||
Returns:
|
||||
List of locale codes
|
||||
"""
|
||||
return list(self._loaded_locales)
|
||||
|
||||
def clear_lru(self):
|
||||
"""Clear the LRU cache."""
|
||||
self._lru_cache.clear()
|
||||
logger.info("LRU cache cleared")
|
||||
|
||||
def clear_locale(self, locale: str):
|
||||
"""
|
||||
Clear cache for a specific locale.
|
||||
|
||||
Args:
|
||||
locale: Locale code
|
||||
"""
|
||||
if locale in self._main_cache:
|
||||
del self._main_cache[locale]
|
||||
self._loaded_locales.discard(locale)
|
||||
|
||||
# Clear related LRU entries
|
||||
keys_to_remove = [k for k in self._lru_cache if k.startswith(f"{locale}:")]
|
||||
for key in keys_to_remove:
|
||||
del self._lru_cache[key]
|
||||
|
||||
logger.info(f"Cleared cache for locale '{locale}'")
|
||||
|
||||
def clear_all(self):
|
||||
"""Clear all caches."""
|
||||
self._main_cache.clear()
|
||||
self._lru_cache.clear()
|
||||
self._loaded_locales.clear()
|
||||
logger.info("All caches cleared")
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get cache statistics.
|
||||
|
||||
Returns:
|
||||
Dictionary with cache statistics
|
||||
"""
|
||||
total_requests = self._stats["hits"] + self._stats["misses"]
|
||||
hit_rate = (
|
||||
self._stats["hits"] / total_requests * 100
|
||||
if total_requests > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
lru_total = self._stats["lru_hits"] + self._stats["lru_misses"]
|
||||
lru_hit_rate = (
|
||||
self._stats["lru_hits"] / lru_total * 100
|
||||
if lru_total > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
return {
|
||||
"total_requests": total_requests,
|
||||
"hits": self._stats["hits"],
|
||||
"misses": self._stats["misses"],
|
||||
"hit_rate": round(hit_rate, 2),
|
||||
"lru_hits": self._stats["lru_hits"],
|
||||
"lru_misses": self._stats["lru_misses"],
|
||||
"lru_hit_rate": round(lru_hit_rate, 2),
|
||||
"lru_size": len(self._lru_cache),
|
||||
"lru_max_size": self.max_lru_size,
|
||||
"loaded_locales": len(self._loaded_locales),
|
||||
"lazy_loads": self._stats["lazy_loads"]
|
||||
}
|
||||
|
||||
def reset_stats(self):
|
||||
"""Reset cache statistics."""
|
||||
self._stats = {
|
||||
"hits": 0,
|
||||
"misses": 0,
|
||||
"lru_hits": 0,
|
||||
"lru_misses": 0,
|
||||
"lazy_loads": 0
|
||||
}
|
||||
logger.info("Cache statistics reset")
|
||||
|
||||
def get_memory_usage(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Estimate memory usage of the cache.
|
||||
|
||||
Returns:
|
||||
Dictionary with memory usage information
|
||||
"""
|
||||
import sys
|
||||
|
||||
main_cache_size = sys.getsizeof(self._main_cache)
|
||||
lru_cache_size = sys.getsizeof(self._lru_cache)
|
||||
|
||||
# Rough estimate of nested data
|
||||
for locale_data in self._main_cache.values():
|
||||
main_cache_size += sys.getsizeof(locale_data)
|
||||
for namespace_data in locale_data.values():
|
||||
main_cache_size += sys.getsizeof(namespace_data)
|
||||
|
||||
return {
|
||||
"main_cache_bytes": main_cache_size,
|
||||
"lru_cache_bytes": lru_cache_size,
|
||||
"total_bytes": main_cache_size + lru_cache_size,
|
||||
"main_cache_mb": round(main_cache_size / 1024 / 1024, 2),
|
||||
"lru_cache_mb": round(lru_cache_size / 1024 / 1024, 2),
|
||||
"total_mb": round((main_cache_size + lru_cache_size) / 1024 / 1024, 2)
|
||||
}
|
||||
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def get_cached_translation_key(locale: str, namespace: str, key: str) -> str:
|
||||
"""
|
||||
LRU cached function for building translation cache keys.
|
||||
|
||||
This reduces string concatenation overhead for frequently accessed keys.
|
||||
|
||||
Args:
|
||||
locale: Locale code
|
||||
namespace: Translation namespace
|
||||
key: Translation key
|
||||
|
||||
Returns:
|
||||
Cache key string
|
||||
"""
|
||||
return f"{locale}:{namespace}:{key}"
|
||||
158
api/app/i18n/dependencies.py
Normal file
158
api/app/i18n/dependencies.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""
|
||||
FastAPI dependency injection functions for i18n.
|
||||
|
||||
This module provides dependency injection functions that can be used
|
||||
in FastAPI route handlers to access the current language and translator.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from app.i18n.service import get_translation_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def get_current_language(request: Request) -> str:
|
||||
"""
|
||||
Get the current language from the request context.
|
||||
|
||||
This dependency extracts the language that was determined by the
|
||||
LanguageMiddleware and stored in request.state.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
|
||||
Returns:
|
||||
Language code (e.g., "zh", "en")
|
||||
|
||||
Usage:
|
||||
@router.get("/example")
|
||||
async def example(language: str = Depends(get_current_language)):
|
||||
return {"language": language}
|
||||
"""
|
||||
# Get language from request state (set by LanguageMiddleware)
|
||||
language = getattr(request.state, "language", None)
|
||||
|
||||
if language is None:
|
||||
# Fallback to default language if not set
|
||||
from app.core.config import settings
|
||||
language = settings.I18N_DEFAULT_LANGUAGE
|
||||
logger.warning(
|
||||
"Language not found in request.state, using default: "
|
||||
f"{language}"
|
||||
)
|
||||
|
||||
return language
|
||||
|
||||
|
||||
async def get_translator(request: Request) -> Callable:
|
||||
"""
|
||||
Get a translator function bound to the current request's language.
|
||||
|
||||
This dependency returns a translation function that automatically
|
||||
uses the current request's language, making it easy to translate
|
||||
strings in route handlers.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
|
||||
Returns:
|
||||
Translation function with signature: t(key: str, **params) -> str
|
||||
|
||||
Usage:
|
||||
@router.post("/workspaces")
|
||||
async def create_workspace(
|
||||
data: WorkspaceCreate,
|
||||
t: Callable = Depends(get_translator)
|
||||
):
|
||||
workspace = await workspace_service.create(data)
|
||||
return {
|
||||
"success": True,
|
||||
"message": t("workspace.created_successfully"),
|
||||
"data": workspace
|
||||
}
|
||||
|
||||
# With parameters
|
||||
@router.get("/items")
|
||||
async def get_items(t: Callable = Depends(get_translator)):
|
||||
count = 5
|
||||
return {
|
||||
"message": t("items.found", count=count)
|
||||
}
|
||||
"""
|
||||
# Get current language
|
||||
language = await get_current_language(request)
|
||||
|
||||
# Get translation service
|
||||
service = get_translation_service()
|
||||
|
||||
# Return a bound translation function
|
||||
def translate(key: str, **params) -> str:
|
||||
"""
|
||||
Translate a key using the current request's language.
|
||||
|
||||
Args:
|
||||
key: Translation key (e.g., "common.success.created")
|
||||
**params: Parameters for parameterized messages
|
||||
|
||||
Returns:
|
||||
Translated string
|
||||
"""
|
||||
return service.translate(key, language, **params)
|
||||
|
||||
return translate
|
||||
|
||||
|
||||
async def get_enum_translator(request: Request) -> Callable:
|
||||
"""
|
||||
Get an enum translator function bound to the current request's language.
|
||||
|
||||
This dependency returns a function for translating enum values
|
||||
that automatically uses the current request's language.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
|
||||
Returns:
|
||||
Enum translation function with signature:
|
||||
t_enum(enum_type: str, value: str) -> str
|
||||
|
||||
Usage:
|
||||
@router.get("/workspace/{id}")
|
||||
async def get_workspace(
|
||||
id: str,
|
||||
t_enum: Callable = Depends(get_enum_translator)
|
||||
):
|
||||
workspace = await workspace_service.get(id)
|
||||
return {
|
||||
"id": workspace.id,
|
||||
"role": workspace.role,
|
||||
"role_display": t_enum("workspace_role", workspace.role),
|
||||
"status": workspace.status,
|
||||
"status_display": t_enum("workspace_status", workspace.status)
|
||||
}
|
||||
"""
|
||||
# Get current language
|
||||
language = await get_current_language(request)
|
||||
|
||||
# Get translation service
|
||||
service = get_translation_service()
|
||||
|
||||
# Return a bound enum translation function
|
||||
def translate_enum(enum_type: str, value: str) -> str:
|
||||
"""
|
||||
Translate an enum value using the current request's language.
|
||||
|
||||
Args:
|
||||
enum_type: Enum type name (e.g., "workspace_role")
|
||||
value: Enum value (e.g., "manager")
|
||||
|
||||
Returns:
|
||||
Translated enum display name
|
||||
"""
|
||||
return service.translate_enum(enum_type, value, language)
|
||||
|
||||
return translate_enum
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user