Compare commits
172 Commits
release/v0
...
feature/to
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e8ae46b286 | ||
|
|
78316de411 | ||
|
|
c205e7d20e | ||
|
|
81f3b50200 | ||
|
|
e3795fe1ed | ||
|
|
72a2f2a7e8 | ||
|
|
035cc17264 | ||
|
|
cf26c9f39c | ||
|
|
fabc8936ab | ||
|
|
06de54ebfd | ||
|
|
7c6e48b04e | ||
|
|
fcc81ac025 | ||
|
|
9d8c26b999 | ||
|
|
c2c832f8c9 | ||
|
|
6bc4f04293 | ||
|
|
9d150ab353 | ||
|
|
f045b59b2d | ||
|
|
0bb8278a39 | ||
|
|
e43f812c14 | ||
|
|
d584b47280 | ||
|
|
3e995cd971 | ||
|
|
b018e35ada | ||
|
|
4bc030c1ef | ||
|
|
86a0aa1f9f | ||
|
|
d523e4f3c6 | ||
|
|
186d097e00 | ||
|
|
c5cfe557da | ||
|
|
f786a66a3c | ||
|
|
ebd51928d7 | ||
|
|
2258b5c43c | ||
|
|
2e50e30071 | ||
|
|
8c804a1011 | ||
|
|
1a4c2d7cd0 | ||
|
|
c2fc4ab4ff | ||
|
|
83fcabadae | ||
|
|
d12ad213e0 | ||
|
|
33d522b387 | ||
|
|
5997458aaf | ||
|
|
68f9471caf | ||
|
|
ecbb61db27 | ||
|
|
b42815ee7a | ||
|
|
49d7398e14 | ||
|
|
91589c1497 | ||
|
|
a07727c047 | ||
|
|
25bc506f74 | ||
|
|
18ca83d763 | ||
|
|
4bbc561625 | ||
|
|
d77220a603 | ||
|
|
f52b681133 | ||
|
|
f6efa0d711 | ||
|
|
0fccc91dac | ||
|
|
8d8c6c695a | ||
|
|
57342259ce | ||
|
|
be46ed8865 | ||
|
|
04b2205769 | ||
|
|
76ba357982 | ||
|
|
2c318f6e60 | ||
|
|
3f04153f22 | ||
|
|
3df8af3852 | ||
|
|
8b9ab8a841 | ||
|
|
750dbcc7c3 | ||
|
|
5d6007aaff | ||
|
|
291767031c | ||
|
|
22ffe6ef1d | ||
|
|
02df1a70f3 | ||
|
|
8c5fa9c441 | ||
|
|
e6c558c2a0 | ||
|
|
b52e4d756c | ||
|
|
1089a52ca0 | ||
|
|
c7fb9ab8e3 | ||
|
|
83017d0c80 | ||
|
|
e24217a6ba | ||
|
|
a0f2f738df | ||
|
|
9d9250954b | ||
|
|
f042f44501 | ||
|
|
56c98648f9 | ||
|
|
956efe6a09 | ||
|
|
bb64ad23dd | ||
|
|
a97326df74 | ||
|
|
1503f8781a | ||
|
|
163ddbb6ed | ||
|
|
7bbfd33ca0 | ||
|
|
0ea47ce890 | ||
|
|
38f891235c | ||
|
|
4d83c074d9 | ||
|
|
0e9672df80 | ||
|
|
abc7460539 | ||
|
|
4bb2ccfba7 | ||
|
|
969d428320 | ||
|
|
ff64522c50 | ||
|
|
65dc1a8f48 | ||
|
|
859b7f3c7f | ||
|
|
da3f875555 | ||
|
|
44d63a44da | ||
|
|
7e5e1609b0 | ||
|
|
d94adcb19c | ||
|
|
83894df260 | ||
|
|
7b99a32a1e | ||
|
|
e8c3744f5e | ||
|
|
06d1f54030 | ||
|
|
599ccb6bde | ||
|
|
db9050c302 | ||
|
|
71b3b665b5 | ||
|
|
3b8a806661 | ||
|
|
774719fb50 | ||
|
|
a3ccd41288 | ||
|
|
8ddacb7bc9 | ||
|
|
e74a74c3fb | ||
|
|
262a9ddc48 | ||
|
|
70f84b65ec | ||
|
|
ec5cb42f67 | ||
|
|
0802481fd2 | ||
|
|
548ba0ae36 | ||
|
|
fc2360d40d | ||
|
|
ab67bda5a1 | ||
|
|
376d5ca7d0 | ||
|
|
55438136b0 | ||
|
|
82db3517d7 | ||
|
|
130490c022 | ||
|
|
ede8a11584 | ||
|
|
43130dcbc8 | ||
|
|
ff6459e439 | ||
|
|
1893de4c75 | ||
|
|
dfcc85a466 | ||
|
|
dacfb360f6 | ||
|
|
8a0d83b340 | ||
|
|
be2ce854a1 | ||
|
|
e492dcd968 | ||
|
|
55bfee856d | ||
|
|
f951075551 | ||
|
|
964086a08a | ||
|
|
67501025b3 | ||
|
|
e1cc5c841a | ||
|
|
6b839bd5a8 | ||
|
|
5df339b56d | ||
|
|
56adca9f22 | ||
|
|
1e63dd8d2d | ||
|
|
fab9272124 | ||
|
|
2f66fd9aae | ||
|
|
5616583fa1 | ||
|
|
3f0e991112 | ||
|
|
8e6288bca8 | ||
|
|
72bba0662f | ||
|
|
090f46006a | ||
|
|
abe0c7e7d1 | ||
|
|
6516f56ada | ||
|
|
ea391dc44e | ||
|
|
e21f713de0 | ||
|
|
3498e2e884 | ||
|
|
ea8edc5914 | ||
|
|
b62c40dba3 | ||
|
|
0832337839 | ||
|
|
b82f4491fb | ||
|
|
bdf0c256b3 | ||
|
|
3d91a9e926 | ||
|
|
779dbdea26 | ||
|
|
e8e342c206 | ||
|
|
78829d36cc | ||
|
|
19d149c129 | ||
|
|
b8e85bed61 | ||
|
|
396493ad2b | ||
|
|
f32d92b9d0 | ||
|
|
6d79db8ba3 | ||
|
|
f9fb480cc3 | ||
|
|
1efa8798bf | ||
|
|
c244e9834f | ||
|
|
01a1e8eab1 | ||
|
|
6a0ee22d81 | ||
|
|
f6d929ab7a | ||
|
|
7b8f101824 | ||
|
|
fc58ac0408 | ||
|
|
5b431400be |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -25,6 +25,8 @@ examples/
|
|||||||
time.log
|
time.log
|
||||||
celerybeat-schedule.db
|
celerybeat-schedule.db
|
||||||
search_results.json
|
search_results.json
|
||||||
|
redbear-mem-metrics/
|
||||||
|
pitch-deck/
|
||||||
|
|
||||||
api/migrations/versions
|
api/migrations/versions
|
||||||
tmp
|
tmp
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from . import (
|
|||||||
document_controller,
|
document_controller,
|
||||||
emotion_config_controller,
|
emotion_config_controller,
|
||||||
emotion_controller,
|
emotion_controller,
|
||||||
|
end_user_controller,
|
||||||
file_controller,
|
file_controller,
|
||||||
file_storage_controller,
|
file_storage_controller,
|
||||||
home_page_controller,
|
home_page_controller,
|
||||||
@@ -96,5 +97,6 @@ manager_router.include_router(file_storage_controller.router)
|
|||||||
manager_router.include_router(ontology_controller.router)
|
manager_router.include_router(ontology_controller.router)
|
||||||
manager_router.include_router(skill_controller.router)
|
manager_router.include_router(skill_controller.router)
|
||||||
manager_router.include_router(i18n_controller.router)
|
manager_router.include_router(i18n_controller.router)
|
||||||
|
manager_router.include_router(end_user_controller.router)
|
||||||
|
|
||||||
__all__ = ["manager_router"]
|
__all__ = ["manager_router"]
|
||||||
|
|||||||
@@ -194,6 +194,7 @@ def delete_app(
|
|||||||
def copy_app(
|
def copy_app(
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
new_name: Optional[str] = None,
|
new_name: Optional[str] = None,
|
||||||
|
payload: app_schema.CopyAppRequest = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user=Depends(get_current_user),
|
current_user=Depends(get_current_user),
|
||||||
):
|
):
|
||||||
@@ -205,6 +206,8 @@ def copy_app(
|
|||||||
- 不影响原应用
|
- 不影响原应用
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
# body takes precedence over query param for backward compatibility
|
||||||
|
new_name = (payload.new_name if payload else None) or new_name
|
||||||
logger.info(
|
logger.info(
|
||||||
"用户请求复制应用",
|
"用户请求复制应用",
|
||||||
extra={
|
extra={
|
||||||
@@ -254,6 +257,27 @@ def get_agent_config(
|
|||||||
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{app_id}/opening", summary="获取应用开场白配置")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def get_opening(
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""返回开场白文本和预设问题,供前端对话界面初始化时展示"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
|
||||||
|
features = cfg.features or {}
|
||||||
|
if hasattr(features, "model_dump"):
|
||||||
|
features = features.model_dump()
|
||||||
|
opening = features.get("opening_statement", {})
|
||||||
|
return success(data=app_schema.OpeningResponse(
|
||||||
|
enabled=opening.get("enabled", False),
|
||||||
|
statement=opening.get("statement"),
|
||||||
|
suggested_questions=opening.get("suggested_questions", []),
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
|
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def publish_app(
|
def publish_app(
|
||||||
@@ -496,7 +520,7 @@ async def draft_run(
|
|||||||
# 提前验证和准备(在流式响应开始前完成)
|
# 提前验证和准备(在流式响应开始前完成)
|
||||||
from app.services.app_service import AppService
|
from app.services.app_service import AppService
|
||||||
from app.services.multi_agent_service import MultiAgentService
|
from app.services.multi_agent_service import MultiAgentService
|
||||||
from app.models import AgentConfig, ModelConfig
|
from app.models import AgentConfig, ModelConfig, AppRelease
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.services.draft_run_service import AgentRunService
|
from app.services.draft_run_service import AgentRunService
|
||||||
@@ -513,11 +537,12 @@ async def draft_run(
|
|||||||
service._validate_app_accessible(app, workspace_id)
|
service._validate_app_accessible(app, workspace_id)
|
||||||
|
|
||||||
if payload.user_id is None:
|
if payload.user_id is None:
|
||||||
|
# 先获取 app 的 workspace_id
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
|
workspace_id=app.workspace_id,
|
||||||
other_id=str(current_user.id),
|
other_id=str(current_user.id),
|
||||||
original_user_id=str(current_user.id) # Save original user_id to other_id
|
|
||||||
)
|
)
|
||||||
payload.user_id = str(new_end_user.id)
|
payload.user_id = str(new_end_user.id)
|
||||||
|
|
||||||
@@ -534,6 +559,17 @@ async def draft_run(
|
|||||||
service._check_agent_config(app_id)
|
service._check_agent_config(app_id)
|
||||||
|
|
||||||
# 2. 获取 Agent 配置
|
# 2. 获取 Agent 配置
|
||||||
|
# 共享应用:从最新发布版本读配置快照,而非草稿
|
||||||
|
is_shared = app.workspace_id != workspace_id
|
||||||
|
if is_shared:
|
||||||
|
if not app.current_release_id:
|
||||||
|
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
release = db.get(AppRelease, app.current_release_id)
|
||||||
|
if not release:
|
||||||
|
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
agent_cfg = service._agent_config_from_release(release)
|
||||||
|
model_config = db.get(ModelConfig, release.default_model_config_id) if release.default_model_config_id else None
|
||||||
|
else:
|
||||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||||
agent_cfg = db.scalars(stmt).first()
|
agent_cfg = db.scalars(stmt).first()
|
||||||
if not agent_cfg:
|
if not agent_cfg:
|
||||||
@@ -701,6 +737,16 @@ async def draft_run(
|
|||||||
msg="多 Agent 任务执行成功"
|
msg="多 Agent 任务执行成功"
|
||||||
)
|
)
|
||||||
elif app.type == AppType.WORKFLOW: # 工作流
|
elif app.type == AppType.WORKFLOW: # 工作流
|
||||||
|
# 共享应用:从最新发布版本读配置快照,而非草稿
|
||||||
|
is_shared = app.workspace_id != workspace_id
|
||||||
|
if is_shared:
|
||||||
|
if not app.current_release_id:
|
||||||
|
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
release = db.get(AppRelease, app.current_release_id)
|
||||||
|
if not release:
|
||||||
|
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
config = service._workflow_config_from_release(release)
|
||||||
|
else:
|
||||||
config = workflow_service.check_config(app_id)
|
config = workflow_service.check_config(app_id)
|
||||||
# 3. 流式返回
|
# 3. 流式返回
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
@@ -845,11 +891,12 @@ async def draft_run_compare(
|
|||||||
service._validate_app_accessible(app, workspace_id)
|
service._validate_app_accessible(app, workspace_id)
|
||||||
|
|
||||||
if payload.user_id is None:
|
if payload.user_id is None:
|
||||||
|
# 先获取 app 的 workspace_id
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
|
workspace_id=app.workspace_id,
|
||||||
other_id=str(current_user.id),
|
other_id=str(current_user.id),
|
||||||
original_user_id=str(current_user.id) # Save original user_id to other_id
|
|
||||||
)
|
)
|
||||||
payload.user_id = str(new_end_user.id)
|
payload.user_id = str(new_end_user.id)
|
||||||
|
|
||||||
@@ -898,7 +945,12 @@ async def draft_run_compare(
|
|||||||
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
|
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
|
||||||
})
|
})
|
||||||
|
|
||||||
|
# 从 features 中读取功能开关(与 draft_run 保持一致)
|
||||||
|
features_config: dict = agent_cfg.features or {}
|
||||||
|
if hasattr(features_config, 'model_dump'):
|
||||||
|
features_config = features_config.model_dump()
|
||||||
|
web_search_feature = features_config.get("web_search", {})
|
||||||
|
web_search = isinstance(web_search_feature, dict) and web_search_feature.get("enabled", False)
|
||||||
|
|
||||||
# 流式返回
|
# 流式返回
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
@@ -915,7 +967,7 @@ async def draft_run_compare(
|
|||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
web_search=True,
|
web_search=web_search,
|
||||||
memory=True,
|
memory=True,
|
||||||
parallel=payload.parallel,
|
parallel=payload.parallel,
|
||||||
timeout=payload.timeout or 60,
|
timeout=payload.timeout or 60,
|
||||||
@@ -946,7 +998,7 @@ async def draft_run_compare(
|
|||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
web_search=True,
|
web_search=web_search,
|
||||||
memory=True,
|
memory=True,
|
||||||
parallel=payload.parallel,
|
parallel=payload.parallel,
|
||||||
timeout=payload.timeout or 60,
|
timeout=payload.timeout or 60,
|
||||||
|
|||||||
48
api/app/controllers/end_user_controller.py
Normal file
48
api/app/controllers/end_user_controller.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""End User 管理接口 - 无需认证"""
|
||||||
|
|
||||||
|
from app.core.logging_config import get_business_logger
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.db import get_db
|
||||||
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
|
from app.schemas.memory_api_schema import (
|
||||||
|
CreateEndUserRequest,
|
||||||
|
CreateEndUserResponse,
|
||||||
|
)
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/end_users", tags=["End Users"])
|
||||||
|
logger = get_business_logger()
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("")
|
||||||
|
async def create_end_user(
|
||||||
|
data: CreateEndUserRequest,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create an end user.
|
||||||
|
|
||||||
|
Creates a new end user for the given workspace.
|
||||||
|
If an end user with the same other_id already exists in the workspace,
|
||||||
|
returns the existing one.
|
||||||
|
"""
|
||||||
|
logger.info(f"Create end user request - other_id: {data.other_id}, workspace_id: {data.workspace_id}")
|
||||||
|
|
||||||
|
end_user_repo = EndUserRepository(db)
|
||||||
|
end_user = end_user_repo.get_or_create_end_user(
|
||||||
|
app_id=None,
|
||||||
|
workspace_id=data.workspace_id,
|
||||||
|
other_id=data.other_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"End user ready: {end_user.id}")
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"id": str(end_user.id),
|
||||||
|
"other_id": end_user.other_id or "",
|
||||||
|
"other_name": end_user.other_name or "",
|
||||||
|
"workspace_id": str(end_user.workspace_id),
|
||||||
|
}
|
||||||
|
|
||||||
|
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
|
||||||
@@ -15,7 +15,7 @@ import os
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
|
||||||
from fastapi.responses import FileResponse, RedirectResponse
|
from fastapi.responses import FileResponse, RedirectResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -47,6 +47,19 @@ router = APIRouter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _match_scheme(request: Request, url: str) -> str:
|
||||||
|
"""
|
||||||
|
将 presigned URL 的协议替换为与当前请求一致的协议(http/https)。
|
||||||
|
解决反向代理场景下 presigned URL 协议与请求协议不匹配的问题。
|
||||||
|
"""
|
||||||
|
incoming_scheme = request.headers.get("x-forwarded-proto") or request.url.scheme
|
||||||
|
if url.startswith("http://") and incoming_scheme == "https":
|
||||||
|
return "https://" + url[7:]
|
||||||
|
if url.startswith("https://") and incoming_scheme == "http":
|
||||||
|
return "http://" + url[8:]
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
@router.post("/files", response_model=ApiResponse)
|
@router.post("/files", response_model=ApiResponse)
|
||||||
async def upload_file(
|
async def upload_file(
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
@@ -280,6 +293,7 @@ async def upload_file_with_share_token(
|
|||||||
|
|
||||||
@router.get("/files/{file_id}", response_model=Any)
|
@router.get("/files/{file_id}", response_model=Any)
|
||||||
async def download_file(
|
async def download_file(
|
||||||
|
request: Request,
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
@@ -327,6 +341,7 @@ async def download_file(
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||||
|
presigned_url = _match_scheme(request, presigned_url)
|
||||||
api_logger.info(f"Redirecting to presigned URL: file_key={file_key}")
|
api_logger.info(f"Redirecting to presigned URL: file_key={file_key}")
|
||||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
@@ -400,6 +415,7 @@ async def delete_file(
|
|||||||
|
|
||||||
@router.get("/files/{file_id}/url", response_model=ApiResponse)
|
@router.get("/files/{file_id}/url", response_model=ApiResponse)
|
||||||
async def get_file_url(
|
async def get_file_url(
|
||||||
|
request: Request,
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
expires: int = None,
|
expires: int = None,
|
||||||
permanent: bool = False,
|
permanent: bool = False,
|
||||||
@@ -463,6 +479,7 @@ async def get_file_url(
|
|||||||
else:
|
else:
|
||||||
# For remote storage (OSS/S3), get presigned URL
|
# For remote storage (OSS/S3), get presigned URL
|
||||||
url = await storage_service.get_file_url(file_key, expires=expires)
|
url = await storage_service.get_file_url(file_key, expires=expires)
|
||||||
|
url = _match_scheme(request, url)
|
||||||
|
|
||||||
api_logger.info(f"Generated file URL: file_id={file_id}")
|
api_logger.info(f"Generated file URL: file_id={file_id}")
|
||||||
return success(
|
return success(
|
||||||
@@ -484,6 +501,7 @@ async def get_file_url(
|
|||||||
|
|
||||||
@router.get("/public/{file_id}", response_model=Any)
|
@router.get("/public/{file_id}", response_model=Any)
|
||||||
async def public_download_file(
|
async def public_download_file(
|
||||||
|
request: Request,
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
expires: int = 0,
|
expires: int = 0,
|
||||||
signature: str = "",
|
signature: str = "",
|
||||||
@@ -555,6 +573,7 @@ async def public_download_file(
|
|||||||
# For remote storage, redirect to presigned URL
|
# For remote storage, redirect to presigned URL
|
||||||
try:
|
try:
|
||||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||||
|
presigned_url = _match_scheme(request, presigned_url)
|
||||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||||
@@ -566,6 +585,7 @@ async def public_download_file(
|
|||||||
|
|
||||||
@router.get("/permanent/{file_id}", response_model=Any)
|
@router.get("/permanent/{file_id}", response_model=Any)
|
||||||
async def permanent_download_file(
|
async def permanent_download_file(
|
||||||
|
request: Request,
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||||
@@ -625,6 +645,7 @@ async def permanent_download_file(
|
|||||||
try:
|
try:
|
||||||
# Use a very long expiration (7 days max for most cloud providers)
|
# Use a very long expiration (7 days max for most cloud providers)
|
||||||
presigned_url = await storage_service.get_file_url(file_key, expires=604800)
|
presigned_url = await storage_service.get_file_url(file_key, expires=604800)
|
||||||
|
presigned_url = _match_scheme(request, presigned_url)
|
||||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||||
|
|||||||
@@ -603,9 +603,12 @@ async def dashboard_data(
|
|||||||
)
|
)
|
||||||
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
|
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
|
||||||
# total_app: 统计当前空间下的所有app数量
|
# total_app: 统计当前空间下的所有app数量
|
||||||
from app.repositories import app_repository
|
# 包含自有app + 被分享给本工作空间的app
|
||||||
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
from app.services import app_service as _app_svc
|
||||||
neo4j_data["total_app"] = len(apps_orm)
|
_, total_app = _app_svc.AppService(db).list_apps(
|
||||||
|
workspace_id=workspace_id, include_shared=True, pagesize=1
|
||||||
|
)
|
||||||
|
neo4j_data["total_app"] = total_app
|
||||||
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}")
|
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"获取记忆总量失败: {str(e)}")
|
api_logger.warning(f"获取记忆总量失败: {str(e)}")
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from app.core.response_utils import success
|
|||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.models import User
|
from app.models import User
|
||||||
|
from app.schemas import conversation_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
|
|
||||||
@@ -90,11 +91,7 @@ def get_messages(
|
|||||||
conversation_id,
|
conversation_id,
|
||||||
)
|
)
|
||||||
messages = [
|
messages = [
|
||||||
{
|
conversation_schema.Message.model_validate(message)
|
||||||
"role": message.role,
|
|
||||||
"content": message.content,
|
|
||||||
"created_at": int(message.created_at.timestamp() * 1000),
|
|
||||||
}
|
|
||||||
for message in messages_obj
|
for message in messages_obj
|
||||||
]
|
]
|
||||||
return success(data=messages, msg="get conversation history success")
|
return success(data=messages, msg="get conversation history success")
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from app.core.logging_config import get_business_logger
|
|||||||
from app.core.response_utils import success, fail
|
from app.core.response_utils import success, fail
|
||||||
from app.db import get_db, get_db_read
|
from app.db import get_db, get_db_read
|
||||||
from app.dependencies import get_share_user_id, ShareTokenData
|
from app.dependencies import get_share_user_id, ShareTokenData
|
||||||
from app.models.app_model import App
|
|
||||||
from app.models.app_model import AppType
|
from app.models.app_model import AppType
|
||||||
from app.repositories import knowledge_repository
|
from app.repositories import knowledge_repository
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
@@ -22,6 +21,7 @@ from app.schemas import release_share_schema, conversation_schema
|
|||||||
from app.schemas.response_schema import PageData, PageMeta
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
from app.services import workspace_service
|
from app.services import workspace_service
|
||||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||||
|
from app.services.app_service import AppService
|
||||||
from app.services.auth_service import create_access_token
|
from app.services.auth_service import create_access_token
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.release_share_service import ReleaseShareService
|
from app.services.release_share_service import ReleaseShareService
|
||||||
@@ -215,8 +215,11 @@ def list_conversations(
|
|||||||
service = SharedChatService(db)
|
service = SharedChatService(db)
|
||||||
share, release = service.get_release_by_share_token(share_data.share_token, password)
|
share, release = service.get_release_by_share_token(share_data.share_token, password)
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
|
app_service = AppService(db)
|
||||||
|
app = app_service._get_app_or_404(share.app_id)
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
app_id=share.app_id,
|
app_id=share.app_id,
|
||||||
|
workspace_id=app.workspace_id,
|
||||||
other_id=other_id
|
other_id=other_id
|
||||||
)
|
)
|
||||||
logger.debug(new_end_user.id)
|
logger.debug(new_end_user.id)
|
||||||
@@ -308,25 +311,29 @@ async def chat(
|
|||||||
|
|
||||||
# Store end_user_id in database with original user_id
|
# Store end_user_id in database with original user_id
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
|
app_service = AppService(db)
|
||||||
|
app = app_service._get_app_or_404(share.app_id)
|
||||||
|
workspace_id = app.workspace_id
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
app_id=share.app_id,
|
app_id=share.app_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
other_id=other_id,
|
other_id=other_id,
|
||||||
original_user_id=user_id # Save original user_id to other_id
|
original_user_id=user_id
|
||||||
)
|
)
|
||||||
end_user_id = str(new_end_user.id)
|
end_user_id = str(new_end_user.id)
|
||||||
|
|
||||||
appid = share.app_id
|
# appid = share.app_id
|
||||||
"""获取存储类型和工作空间的ID"""
|
"""获取存储类型和工作空间的ID"""
|
||||||
|
|
||||||
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
|
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
|
||||||
app = db.query(App).filter(
|
# app = db.query(App).filter(
|
||||||
App.id == appid,
|
# App.id == appid,
|
||||||
App.is_active.is_(True)
|
# App.is_active.is_(True)
|
||||||
).first()
|
# ).first()
|
||||||
if not app:
|
# if not app:
|
||||||
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
# raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||||
|
|
||||||
workspace_id = app.workspace_id
|
# workspace_id = app.workspace_id
|
||||||
|
|
||||||
# 直接从 workspace 获取 storage_type(公开分享场景无需权限检查)
|
# 直接从 workspace 获取 storage_type(公开分享场景无需权限检查)
|
||||||
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
||||||
@@ -610,11 +617,11 @@ async def chat(
|
|||||||
|
|
||||||
# 多 Agent 非流式返回
|
# 多 Agent 非流式返回
|
||||||
result = await app_chat_service.workflow_chat(
|
result = await app_chat_service.workflow_chat(
|
||||||
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
user_id=end_user_id, # 转换为字符串
|
user_id=end_user_id, # 转换为字符串
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
|
files=payload.files,
|
||||||
config=config,
|
config=config,
|
||||||
web_search=payload.web_search,
|
web_search=payload.web_search,
|
||||||
memory=payload.memory,
|
memory=payload.memory,
|
||||||
@@ -654,17 +661,21 @@ async def config_query(
|
|||||||
workflow_service = WorkflowService(db)
|
workflow_service = WorkflowService(db)
|
||||||
content = {
|
content = {
|
||||||
"app_type": release.app.type,
|
"app_type": release.app.type,
|
||||||
"variables": workflow_service.get_start_node_variables(release.config)
|
"variables": workflow_service.get_start_node_variables(release.config),
|
||||||
|
"memory": workflow_service.is_memory_enable(release.config),
|
||||||
|
"features": release.config.get("features")
|
||||||
}
|
}
|
||||||
elif release.app.type == AppType.AGENT:
|
elif release.app.type == AppType.AGENT:
|
||||||
content = {
|
content = {
|
||||||
"app_type": release.app.type,
|
"app_type": release.app.type,
|
||||||
"variables": release.config.get("variables")
|
"variables": release.config.get("variables"),
|
||||||
|
"features": release.config.get("features")
|
||||||
}
|
}
|
||||||
elif release.app.type == AppType.MULTI_AGENT:
|
elif release.app.type == AppType.MULTI_AGENT:
|
||||||
content = {
|
content = {
|
||||||
"app_type": release.app.type,
|
"app_type": release.app.type,
|
||||||
"variables": []
|
"variables": [],
|
||||||
|
"features": release.config.get("features")
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)
|
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||||
|
|||||||
@@ -95,8 +95,8 @@ async def chat(
|
|||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
app_id=app.id,
|
app_id=app.id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
other_id=other_id,
|
other_id=other_id,
|
||||||
original_user_id=other_id # Save original user_id to other_id
|
|
||||||
)
|
)
|
||||||
end_user_id = str(new_end_user.id)
|
end_user_id = str(new_end_user.id)
|
||||||
web_search = True
|
web_search = True
|
||||||
@@ -280,6 +280,7 @@ async def chat(
|
|||||||
memory=memory,
|
memory=memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
files=payload.files,
|
||||||
app_id=app.id,
|
app_id=app.id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
release_id=app.current_release.id
|
release_id=app.current_release.id
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from app.core.response_utils import success
|
|||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.schemas.api_key_schema import ApiKeyAuth
|
from app.schemas.api_key_schema import ApiKeyAuth
|
||||||
from app.schemas.memory_api_schema import (
|
from app.schemas.memory_api_schema import (
|
||||||
|
ListConfigsResponse,
|
||||||
MemoryReadRequest,
|
MemoryReadRequest,
|
||||||
MemoryReadResponse,
|
MemoryReadResponse,
|
||||||
MemoryWriteRequest,
|
MemoryWriteRequest,
|
||||||
@@ -31,14 +32,15 @@ async def write_memory_api_service(
|
|||||||
request: Request,
|
request: Request,
|
||||||
api_key_auth: ApiKeyAuth = None,
|
api_key_auth: ApiKeyAuth = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
payload: MemoryWriteRequest = Body(..., embed=False),
|
message: str = Body(..., description="Message content"),
|
||||||
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Write memory to storage.
|
Write memory to storage.
|
||||||
|
|
||||||
Stores memory content for the specified end user using the Memory API Service.
|
Stores memory content for the specified end user using the Memory API Service.
|
||||||
"""
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = MemoryWriteRequest(**body)
|
||||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
|
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
memory_api_service = MemoryAPIService(db)
|
memory_api_service = MemoryAPIService(db)
|
||||||
@@ -62,13 +64,15 @@ async def read_memory_api_service(
|
|||||||
request: Request,
|
request: Request,
|
||||||
api_key_auth: ApiKeyAuth = None,
|
api_key_auth: ApiKeyAuth = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
payload: MemoryReadRequest = Body(..., embed=False),
|
message: str = Body(..., description="Query message"),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Read memory from storage.
|
Read memory from storage.
|
||||||
|
|
||||||
Queries and retrieves memories for the specified end user with context-aware responses.
|
Queries and retrieves memories for the specified end user with context-aware responses.
|
||||||
"""
|
"""
|
||||||
|
body = await request.json()
|
||||||
|
payload = MemoryReadRequest(**body)
|
||||||
logger.info(f"Memory read request - end_user_id: {payload.end_user_id}")
|
logger.info(f"Memory read request - end_user_id: {payload.end_user_id}")
|
||||||
|
|
||||||
memory_api_service = MemoryAPIService(db)
|
memory_api_service = MemoryAPIService(db)
|
||||||
@@ -85,3 +89,27 @@ async def read_memory_api_service(
|
|||||||
|
|
||||||
logger.info(f"Memory read successful for end_user: {payload.end_user_id}")
|
logger.info(f"Memory read successful for end_user: {payload.end_user_id}")
|
||||||
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully")
|
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/configs")
|
||||||
|
@require_api_key(scopes=["memory"])
|
||||||
|
async def list_memory_configs(
|
||||||
|
request: Request,
|
||||||
|
api_key_auth: ApiKeyAuth = None,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
List all memory configs for the workspace.
|
||||||
|
|
||||||
|
Returns all available memory configurations associated with the authorized workspace.
|
||||||
|
"""
|
||||||
|
logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
|
memory_api_service = MemoryAPIService(db)
|
||||||
|
|
||||||
|
result = memory_api_service.list_memory_configs(
|
||||||
|
workspace_id=api_key_auth.workspace_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
|
||||||
|
return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")
|
||||||
|
|||||||
@@ -3,8 +3,11 @@ from typing import Optional
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
from app.schemas.tool_schema import (
|
from app.schemas.tool_schema import (
|
||||||
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest, CustomToolTestRequest
|
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest,
|
||||||
|
CustomToolTestRequest, ToolActiveUpdate
|
||||||
)
|
)
|
||||||
|
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
@@ -156,7 +159,7 @@ async def delete_tool(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
service: ToolService = Depends(get_tool_service)
|
service: ToolService = Depends(get_tool_service)
|
||||||
):
|
):
|
||||||
"""删除工具"""
|
"""删除工具(逻辑删除,is_active=False)"""
|
||||||
try:
|
try:
|
||||||
success_flag = service.delete_tool(tool_id, current_user.tenant_id)
|
success_flag = service.delete_tool(tool_id, current_user.tenant_id)
|
||||||
if not success_flag:
|
if not success_flag:
|
||||||
@@ -168,6 +171,30 @@ async def delete_tool(
|
|||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/{tool_id}/active", response_model=ApiResponse)
|
||||||
|
async def set_tool_active(
|
||||||
|
tool_id: str,
|
||||||
|
request: ToolActiveUpdate,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
service: ToolService = Depends(get_tool_service)
|
||||||
|
):
|
||||||
|
"""设置工具可用状态(启用/禁用)
|
||||||
|
|
||||||
|
- is_active=true: 启用工具
|
||||||
|
- is_active=false: 禁用工具(等同于删除,但可恢复)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
success_flag = service.set_tool_active(tool_id, current_user.tenant_id, request.is_active)
|
||||||
|
if not success_flag:
|
||||||
|
raise HTTPException(status_code=404, detail="工具不存在")
|
||||||
|
action = "启用" if request.is_active else "禁用"
|
||||||
|
return success(msg=f"工具已{action}")
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/execution/execute", response_model=ApiResponse)
|
@router.post("/execution/execute", response_model=ApiResponse)
|
||||||
async def execute_tool(
|
async def execute_tool(
|
||||||
request: ToolExecuteRequest,
|
request: ToolExecuteRequest,
|
||||||
@@ -225,8 +252,10 @@ async def sync_mcp_tools(
|
|||||||
try:
|
try:
|
||||||
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
|
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
|
||||||
if not result.get("success", False):
|
if not result.get("success", False):
|
||||||
raise HTTPException(status_code=400, detail=result.get("message", "同步失败"))
|
raise BusinessException(result.get("message", "工具列表同步失败"), BizCode.BAD_REQUEST)
|
||||||
return success(data=result, msg="MCP工具列表同步完成")
|
return success(data=result, msg="MCP工具列表同步完成")
|
||||||
|
except BusinessException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -249,8 +278,10 @@ async def test_tool_connection(
|
|||||||
# 普通连接测试
|
# 普通连接测试
|
||||||
result = await service.test_connection(tool_id, current_user.tenant_id)
|
result = await service.test_connection(tool_id, current_user.tenant_id)
|
||||||
if result["success"] is False:
|
if result["success"] is False:
|
||||||
raise HTTPException(status_code=400, detail=result["message"])
|
raise BusinessException(result["message"], BizCode.SERVICE_UNAVAILABLE)
|
||||||
return success(data=result, msg="连接测试完成")
|
return success(data=result, msg="连接测试完成")
|
||||||
|
except BusinessException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
|
from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
|
||||||
|
from app.schemas.memory_agent_schema import AgentMemoryDataset
|
||||||
|
|
||||||
|
|
||||||
def content_input_node(state: ReadState) -> ReadState:
|
def content_input_node(state: ReadState) -> ReadState:
|
||||||
@@ -17,6 +18,9 @@ def content_input_node(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
content = state['messages'][0].content if state.get('messages') else ''
|
content = state['messages'][0].content if state.get('messages') else ''
|
||||||
# Return content and maintain all state information
|
# Return content and maintain all state information
|
||||||
|
for pronoun in AgentMemoryDataset.PRONOUN:
|
||||||
|
content = content.replace(pronoun, AgentMemoryDataset.NAME)
|
||||||
|
|
||||||
return {"data": content}
|
return {"data": content}
|
||||||
|
|
||||||
|
|
||||||
@@ -35,4 +39,7 @@ def content_input_write(state: WriteState) -> WriteState:
|
|||||||
|
|
||||||
content = state['messages'][0].content if state.get('messages') else ''
|
content = state['messages'][0].content if state.get('messages') else ''
|
||||||
# Return content and maintain all state information
|
# Return content and maintain all state information
|
||||||
|
for pronoun in AgentMemoryDataset.PRONOUN:
|
||||||
|
content = content.replace(pronoun, AgentMemoryDataset.NAME)
|
||||||
|
|
||||||
return {"data": content}
|
return {"data": content}
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ async def clean_databases(data) -> str:
|
|||||||
# Process reranked results
|
# Process reranked results
|
||||||
reranked = results.get('reranked_results', {})
|
reranked = results.get('reranked_results', {})
|
||||||
if reranked:
|
if reranked:
|
||||||
for category in ['summaries', 'statements', 'chunks', 'entities']:
|
for category in ['summaries', 'communities', 'statements', 'chunks', 'entities']:
|
||||||
items = reranked.get(category, [])
|
items = reranked.get(category, [])
|
||||||
if isinstance(items, list):
|
if isinstance(items, list):
|
||||||
content_list.extend(items)
|
content_list.extend(items)
|
||||||
@@ -169,11 +169,18 @@ async def clean_databases(data) -> str:
|
|||||||
elif isinstance(time_search, list):
|
elif isinstance(time_search, list):
|
||||||
content_list.extend(time_search)
|
content_list.extend(time_search)
|
||||||
|
|
||||||
# Extract text content
|
# Extract text content,对 community 按 name 去重(多次 tool 调用会产生重复)
|
||||||
text_parts = []
|
text_parts = []
|
||||||
|
seen_community_names = set()
|
||||||
for item in content_list:
|
for item in content_list:
|
||||||
if isinstance(item, dict):
|
if isinstance(item, dict):
|
||||||
text = item.get('statement') or item.get('content', '')
|
# community 节点用 name 去重
|
||||||
|
if 'member_count' in item or 'core_entities' in item:
|
||||||
|
community_name = item.get('name') or item.get('id', '')
|
||||||
|
if community_name in seen_community_names:
|
||||||
|
continue
|
||||||
|
seen_community_names.add(community_name)
|
||||||
|
text = item.get('statement') or item.get('content') or item.get('summary', '')
|
||||||
if text:
|
if text:
|
||||||
text_parts.append(text)
|
text_parts.append(text)
|
||||||
elif isinstance(item, str):
|
elif isinstance(item, str):
|
||||||
@@ -354,7 +361,11 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
)
|
)
|
||||||
|
|
||||||
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
|
||||||
search_params = {"end_user_id": end_user_id, "return_raw_results": True}
|
search_params = {
|
||||||
|
"end_user_id": end_user_id,
|
||||||
|
"return_raw_results": True,
|
||||||
|
"include": ["summaries", "statements", "chunks", "entities", "communities"],
|
||||||
|
}
|
||||||
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||||
agent = create_agent(
|
agent = create_agent(
|
||||||
llm,
|
llm,
|
||||||
@@ -390,8 +401,32 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
raw_results = tool_results['content']
|
raw_results = tool_results['content']
|
||||||
clean_content = await clean_databases(raw_results)
|
clean_content = await clean_databases(raw_results)
|
||||||
|
|
||||||
|
# 社区展开:从 tool 返回结果中提取命中的 community,
|
||||||
|
# 沿 BELONGS_TO_COMMUNITY 关系拉取关联 Statement 追加到 clean_content
|
||||||
|
_expanded_stmts_to_write = []
|
||||||
|
try:
|
||||||
|
results_dict = raw_results.get('results', {}) if isinstance(raw_results, dict) else {}
|
||||||
|
reranked = results_dict.get('reranked_results', {})
|
||||||
|
community_hits = reranked.get('communities', [])
|
||||||
|
if not community_hits:
|
||||||
|
community_hits = results_dict.get('communities', [])
|
||||||
|
if community_hits:
|
||||||
|
from app.core.memory.agent.services.search_service import expand_communities_to_statements
|
||||||
|
_expanded_stmts_to_write, new_texts = await expand_communities_to_statements(
|
||||||
|
community_results=community_hits,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
existing_content=clean_content,
|
||||||
|
)
|
||||||
|
if new_texts:
|
||||||
|
clean_content = clean_content + '\n' + '\n'.join(new_texts)
|
||||||
|
except Exception as parse_err:
|
||||||
|
logger.warning(f"[Retrieve] 解析社区命中结果失败,跳过展开: {parse_err}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
raw_results = raw_results['results']
|
raw_results = raw_results['results']
|
||||||
|
# 写回展开结果,接口返回中可见(已在 helper 中清洗过字段)
|
||||||
|
if _expanded_stmts_to_write and isinstance(raw_results, dict):
|
||||||
|
raw_results.setdefault('reranked_results', {})['expanded_statements'] = _expanded_stmts_to_write
|
||||||
except Exception:
|
except Exception:
|
||||||
raw_results = []
|
raw_results = []
|
||||||
|
|
||||||
|
|||||||
@@ -334,13 +334,22 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
"end_user_id": end_user_id,
|
"end_user_id": end_user_id,
|
||||||
"question": data,
|
"question": data,
|
||||||
"return_raw_results": True,
|
"return_raw_results": True,
|
||||||
"include": ["summaries"] # Only search summary nodes for faster performance
|
"include": ["summaries", "communities"] # MemorySummary 和 Community 同为高维度概括节点
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if storage_type != "rag":
|
if storage_type != "rag":
|
||||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
|
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(
|
||||||
memory_config=memory_config)
|
**search_params,
|
||||||
|
memory_config=memory_config,
|
||||||
|
expand_communities=False, # 路径 "2" 只需要 community 的 summary 文本,不展开到 Statement
|
||||||
|
)
|
||||||
|
# 调试:打印 community 检索结果数量
|
||||||
|
if raw_results and isinstance(raw_results, dict):
|
||||||
|
reranked = raw_results.get('reranked_results', {})
|
||||||
|
community_hits = reranked.get('communities', [])
|
||||||
|
logger.debug(f"[Input_Summary] community 命中数: {len(community_hits)}, "
|
||||||
|
f"summary 命中数: {len(reranked.get('summaries', []))}")
|
||||||
else:
|
else:
|
||||||
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -252,9 +252,10 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
|||||||
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development
|
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development
|
||||||
fields_to_remove = {
|
fields_to_remove = {
|
||||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
'expired_at', 'created_at', 'chunk_id', 'apply_id',
|
||||||
'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
|
'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
|
||||||
}
|
}
|
||||||
|
# 注意:'id' 字段保留,community 展开时需要用 community id 查询成员 statements
|
||||||
|
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
# Clean dictionary
|
# Clean dictionary
|
||||||
@@ -310,7 +311,7 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
|||||||
"search_type": search_type,
|
"search_type": search_type,
|
||||||
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
||||||
"limit": limit or search_params.get("limit", 10),
|
"limit": limit or search_params.get("limit", 10),
|
||||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities", "communities"]),
|
||||||
"output_path": None, # Don't save to file
|
"output_path": None, # Don't save to file
|
||||||
"memory_config": memory_config,
|
"memory_config": memory_config,
|
||||||
"rerank_alpha": rerank_alpha,
|
"rerank_alpha": rerank_alpha,
|
||||||
|
|||||||
@@ -13,6 +13,72 @@ from app.core.memory.utils.data.text_utils import escape_lucene_query
|
|||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
|
# 需要从展开结果中过滤的字段(含 Neo4j DateTime,不可 JSON 序列化)
|
||||||
|
_EXPAND_FIELDS_TO_REMOVE = {
|
||||||
|
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||||
|
'expired_at', 'created_at', 'chunk_id', 'apply_id',
|
||||||
|
'user_id', 'statement_ids', 'updated_at', 'chunk_ids', 'fact_summary'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _clean_expand_fields(obj):
|
||||||
|
"""递归过滤展开结果中不可序列化的字段(DateTime 等)。"""
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return {k: _clean_expand_fields(v) for k, v in obj.items() if k not in _EXPAND_FIELDS_TO_REMOVE}
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return [_clean_expand_fields(i) for i in obj]
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
|
async def expand_communities_to_statements(
|
||||||
|
community_results: List[dict],
|
||||||
|
end_user_id: str,
|
||||||
|
existing_content: str = "",
|
||||||
|
limit: int = 10,
|
||||||
|
) -> Tuple[List[dict], List[str]]:
|
||||||
|
"""
|
||||||
|
社区展开 helper:给定命中的 community 列表,拉取关联 Statement。
|
||||||
|
|
||||||
|
- 对展开结果去重(过滤已在 existing_content 中出现的文本)
|
||||||
|
- 过滤不可序列化字段
|
||||||
|
- 返回 (cleaned_expanded_stmts, new_texts)
|
||||||
|
- cleaned_expanded_stmts: 可直接写回 raw_results 的列表
|
||||||
|
- new_texts: 去重后新增的 statement 文本列表,用于追加到 clean_content
|
||||||
|
"""
|
||||||
|
community_ids = [r.get("id") for r in community_results if r.get("id")]
|
||||||
|
if not community_ids or not end_user_id:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
from app.repositories.neo4j.graph_search import search_graph_community_expand
|
||||||
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
|
connector = Neo4jConnector()
|
||||||
|
try:
|
||||||
|
result = await search_graph_community_expand(
|
||||||
|
connector=connector,
|
||||||
|
community_ids=community_ids,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[expand_communities] 社区展开检索失败,跳过: {e}")
|
||||||
|
return [], []
|
||||||
|
finally:
|
||||||
|
await connector.close()
|
||||||
|
|
||||||
|
expanded_stmts = result.get("expanded_statements", [])
|
||||||
|
if not expanded_stmts:
|
||||||
|
return [], []
|
||||||
|
|
||||||
|
existing_lines = set(existing_content.splitlines())
|
||||||
|
new_texts = [
|
||||||
|
s["statement"] for s in expanded_stmts
|
||||||
|
if s.get("statement") and s["statement"] not in existing_lines
|
||||||
|
]
|
||||||
|
cleaned = _clean_expand_fields(expanded_stmts)
|
||||||
|
logger.info(f"[expand_communities] 展开 {len(expanded_stmts)} 条 statements,新增 {len(new_texts)} 条,community_ids={community_ids}")
|
||||||
|
return cleaned, new_texts
|
||||||
|
|
||||||
|
|
||||||
class SearchService:
|
class SearchService:
|
||||||
"""Service for executing hybrid search and processing results."""
|
"""Service for executing hybrid search and processing results."""
|
||||||
@@ -21,7 +87,7 @@ class SearchService:
|
|||||||
"""Initialize the search service."""
|
"""Initialize the search service."""
|
||||||
logger.info("SearchService initialized")
|
logger.info("SearchService initialized")
|
||||||
|
|
||||||
def extract_content_from_result(self, result: dict) -> str:
|
def extract_content_from_result(self, result: dict, node_type: str = "") -> str:
|
||||||
"""
|
"""
|
||||||
Extract only meaningful content from search results, dropping all metadata.
|
Extract only meaningful content from search results, dropping all metadata.
|
||||||
|
|
||||||
@@ -30,9 +96,11 @@ class SearchService:
|
|||||||
- Entities: extract 'name' and 'fact_summary' fields
|
- Entities: extract 'name' and 'fact_summary' fields
|
||||||
- Summaries: extract 'content' field
|
- Summaries: extract 'content' field
|
||||||
- Chunks: extract 'content' field
|
- Chunks: extract 'content' field
|
||||||
|
- Communities: extract 'content' field (c.summary), prefixed with community name
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
result: Search result dictionary
|
result: Search result dictionary
|
||||||
|
node_type: Hint for node type ("community", "summary", etc.)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Clean content string without metadata
|
Clean content string without metadata
|
||||||
@@ -46,8 +114,21 @@ class SearchService:
|
|||||||
if 'statement' in result and result['statement']:
|
if 'statement' in result and result['statement']:
|
||||||
content_parts.append(result['statement'])
|
content_parts.append(result['statement'])
|
||||||
|
|
||||||
# Summaries/Chunks: extract content field
|
# Community 节点:有 member_count 或 core_entities 字段,或 node_type 明确指定
|
||||||
if 'content' in result and result['content']:
|
# 用 "[主题:{name}]" 前缀区分,让 LLM 知道这是主题级摘要
|
||||||
|
is_community = (
|
||||||
|
node_type == "community"
|
||||||
|
or 'member_count' in result
|
||||||
|
or 'core_entities' in result
|
||||||
|
)
|
||||||
|
if is_community:
|
||||||
|
name = result.get('name', '')
|
||||||
|
content = result.get('content', '')
|
||||||
|
if content:
|
||||||
|
prefix = f"[主题:{name}] " if name else ""
|
||||||
|
content_parts.append(f"{prefix}{content}")
|
||||||
|
elif 'content' in result and result['content']:
|
||||||
|
# Summaries / Chunks
|
||||||
content_parts.append(result['content'])
|
content_parts.append(result['content'])
|
||||||
|
|
||||||
# Entities: extract name and fact_summary (commented out in original)
|
# Entities: extract name and fact_summary (commented out in original)
|
||||||
@@ -99,7 +180,8 @@ class SearchService:
|
|||||||
rerank_alpha: float = 0.4,
|
rerank_alpha: float = 0.4,
|
||||||
output_path: str = "search_results.json",
|
output_path: str = "search_results.json",
|
||||||
return_raw_results: bool = False,
|
return_raw_results: bool = False,
|
||||||
memory_config = None
|
memory_config = None,
|
||||||
|
expand_communities: bool = True,
|
||||||
) -> Tuple[str, str, Optional[dict]]:
|
) -> Tuple[str, str, Optional[dict]]:
|
||||||
"""
|
"""
|
||||||
Execute hybrid search and return clean content.
|
Execute hybrid search and return clean content.
|
||||||
@@ -114,13 +196,15 @@ class SearchService:
|
|||||||
output_path: Path to save search results (default: "search_results.json")
|
output_path: Path to save search results (default: "search_results.json")
|
||||||
return_raw_results: If True, also return the raw search results as third element (default: False)
|
return_raw_results: If True, also return the raw search results as third element (default: False)
|
||||||
memory_config: Memory configuration object (required)
|
memory_config: Memory configuration object (required)
|
||||||
|
expand_communities: If True, expand community hits to member statements (default: True).
|
||||||
|
Set to False for quick-summary paths that only need community-level text.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (clean_content, cleaned_query, raw_results)
|
Tuple of (clean_content, cleaned_query, raw_results)
|
||||||
raw_results is None if return_raw_results=False
|
raw_results is None if return_raw_results=False
|
||||||
"""
|
"""
|
||||||
if include is None:
|
if include is None:
|
||||||
include = ["statements", "chunks", "entities", "summaries"]
|
include = ["statements", "chunks", "entities", "summaries", "communities"]
|
||||||
|
|
||||||
# Clean query
|
# Clean query
|
||||||
cleaned_query = self.clean_query(question)
|
cleaned_query = self.clean_query(question)
|
||||||
@@ -146,8 +230,8 @@ class SearchService:
|
|||||||
if search_type == "hybrid":
|
if search_type == "hybrid":
|
||||||
reranked_results = answer.get('reranked_results', {})
|
reranked_results = answer.get('reranked_results', {})
|
||||||
|
|
||||||
# Priority order: summaries first (most contextual), then statements, chunks, entities
|
# Priority order: summaries first (most contextual), then communities, statements, chunks, entities
|
||||||
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
|
||||||
|
|
||||||
for category in priority_order:
|
for category in priority_order:
|
||||||
if category in include and category in reranked_results:
|
if category in include and category in reranked_results:
|
||||||
@@ -157,7 +241,7 @@ class SearchService:
|
|||||||
else:
|
else:
|
||||||
# For keyword or embedding search, results are directly in answer dict
|
# For keyword or embedding search, results are directly in answer dict
|
||||||
# Apply same priority order
|
# Apply same priority order
|
||||||
priority_order = ['summaries', 'statements', 'chunks', 'entities']
|
priority_order = ['summaries', 'communities', 'statements', 'chunks', 'entities']
|
||||||
|
|
||||||
for category in priority_order:
|
for category in priority_order:
|
||||||
if category in include and category in answer:
|
if category in include and category in answer:
|
||||||
@@ -165,11 +249,25 @@ class SearchService:
|
|||||||
if isinstance(category_results, list):
|
if isinstance(category_results, list):
|
||||||
answer_list.extend(category_results)
|
answer_list.extend(category_results)
|
||||||
|
|
||||||
# Extract clean content from all results
|
# 对命中的 community 节点展开其成员 statements(路径 "0"/"1" 需要,路径 "2" 不需要)
|
||||||
content_list = [
|
if expand_communities and "communities" in include:
|
||||||
self.extract_content_from_result(ans)
|
community_results = (
|
||||||
for ans in answer_list
|
answer.get('reranked_results', {}).get('communities', [])
|
||||||
]
|
if search_type == "hybrid"
|
||||||
|
else answer.get('communities', [])
|
||||||
|
)
|
||||||
|
cleaned_stmts, new_texts = await expand_communities_to_statements(
|
||||||
|
community_results=community_results,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
)
|
||||||
|
answer_list.extend(cleaned_stmts)
|
||||||
|
|
||||||
|
# Extract clean content from all results,按类型传入 node_type 区分 community
|
||||||
|
content_list = []
|
||||||
|
for ans in answer_list:
|
||||||
|
# community 节点有 member_count 或 core_entities 字段
|
||||||
|
ntype = "community" if ('member_count' in ans or 'core_entities' in ans) else ""
|
||||||
|
content_list.append(self.extract_content_from_result(ans, node_type=ntype))
|
||||||
|
|
||||||
|
|
||||||
# Filter out empty strings and join with newlines
|
# Filter out empty strings and join with newlines
|
||||||
|
|||||||
@@ -84,7 +84,7 @@ async def get_chunked_dialogs(
|
|||||||
pruning_scene=memory_config.pruning_scene or "education",
|
pruning_scene=memory_config.pruning_scene or "education",
|
||||||
pruning_threshold=memory_config.pruning_threshold,
|
pruning_threshold=memory_config.pruning_threshold,
|
||||||
scene_id=str(memory_config.scene_id) if memory_config.scene_id else None,
|
scene_id=str(memory_config.scene_id) if memory_config.scene_id else None,
|
||||||
ontology_classes=memory_config.ontology_classes,
|
ontology_class_infos=memory_config.ontology_class_infos,
|
||||||
)
|
)
|
||||||
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")
|
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from app.core.memory.utils.log.logging_utils import log_time
|
|||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
from app.schemas.memory_config_schema import MemoryConfig
|
from app.schemas.memory_config_schema import MemoryConfig
|
||||||
|
|
||||||
@@ -166,11 +166,15 @@ async def write(
|
|||||||
statement_entity_edges=all_statement_entity_edges,
|
statement_entity_edges=all_statement_entity_edges,
|
||||||
entity_edges=all_entity_entity_edges,
|
entity_edges=all_entity_entity_edges,
|
||||||
connector=neo4j_connector,
|
connector=neo4j_connector,
|
||||||
config_id=config_id,
|
|
||||||
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
logger.info("Successfully saved all data to Neo4j")
|
logger.info("Successfully saved all data to Neo4j")
|
||||||
|
# 写入成功后,异步触发聚类(不阻塞写入响应)
|
||||||
|
schedule_clustering_after_write(
|
||||||
|
all_entity_nodes,
|
||||||
|
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||||
|
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
|
||||||
|
)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
logger.warning("Failed to save some data to Neo4j")
|
logger.warning("Failed to save some data to Neo4j")
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ of the memory system including LLM, chunking, pruning, and search.
|
|||||||
Classes:
|
Classes:
|
||||||
LLMConfig: Configuration for LLM client
|
LLMConfig: Configuration for LLM client
|
||||||
ChunkerConfig: Configuration for dialogue chunking
|
ChunkerConfig: Configuration for dialogue chunking
|
||||||
|
OntologyClassInfo: Single ontology class with name and description
|
||||||
PruningConfig: Configuration for semantic pruning
|
PruningConfig: Configuration for semantic pruning
|
||||||
TemporalSearchParams: Parameters for temporal search queries
|
TemporalSearchParams: Parameters for temporal search queries
|
||||||
"""
|
"""
|
||||||
@@ -50,30 +51,41 @@ class ChunkerConfig(BaseModel):
|
|||||||
min_characters_per_chunk: Optional[int] = Field(24, ge=0, description="The minimum number of characters in each chunk.")
|
min_characters_per_chunk: Optional[int] = Field(24, ge=0, description="The minimum number of characters in each chunk.")
|
||||||
|
|
||||||
|
|
||||||
|
class OntologyClassInfo(BaseModel):
    """Name and semantic description of a single ontology class.

    Instances are injected into the pruning prompt.

    Attributes:
        class_name: Ontology class name (e.g. "患者" / patient, "课程" / course).
        class_description: Semantic description telling the LLM what the
            class means in the current scene; defaults to an empty string.
    """

    class_name: str = Field(
        ...,
        description="本体类型名称",
    )
    class_description: str = Field(
        default="",
        description="本体类型语义描述",
    )
|
||||||
|
|
||||||
|
|
||||||
class PruningConfig(BaseModel):
|
class PruningConfig(BaseModel):
|
||||||
"""Configuration for semantic pruning of dialogue content.
|
"""Configuration for semantic pruning of dialogue content.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
pruning_switch: Enable or disable semantic pruning
|
pruning_switch: Enable or disable semantic pruning
|
||||||
pruning_scene: Scene name for pruning, either a built-in key
|
pruning_scene: Scene name for pruning from ontology_scene table
|
||||||
('education', 'online_service', 'outbound') or a custom scene_name
|
|
||||||
from ontology_scene table
|
|
||||||
pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal)
|
pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal)
|
||||||
scene_id: Optional ontology scene UUID, used to load custom ontology classes
|
scene_id: Optional ontology scene UUID
|
||||||
ontology_classes: List of class_name strings from ontology_class table,
|
ontology_class_infos: Full ontology class info (name + description) from
|
||||||
injected into the prompt when pruning_scene is not a built-in scene
|
ontology_class table, injected into the pruning prompt to drive
|
||||||
|
scene-aware preservation decisions
|
||||||
"""
|
"""
|
||||||
pruning_switch: bool = Field(False, description="Enable semantic pruning when True.")
|
pruning_switch: bool = Field(False, description="Enable semantic pruning when True.")
|
||||||
pruning_scene: str = Field(
|
pruning_scene: str = Field(
|
||||||
"education",
|
"education",
|
||||||
description="Scene for pruning: built-in key or custom scene_name from ontology_scene.",
|
description="Scene name from ontology_scene table.",
|
||||||
)
|
)
|
||||||
pruning_threshold: float = Field(
|
pruning_threshold: float = Field(
|
||||||
0.5, ge=0.0, le=0.9,
|
0.5, ge=0.0, le=0.9,
|
||||||
description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).")
|
description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).")
|
||||||
scene_id: Optional[str] = Field(None, description="Ontology scene UUID (optional).")
|
scene_id: Optional[str] = Field(None, description="Ontology scene UUID (optional).")
|
||||||
ontology_classes: Optional[List[str]] = Field(
|
ontology_class_infos: List[OntologyClassInfo] = Field(
|
||||||
None, description="Class names from ontology_class table for custom scenes."
|
default_factory=list,
|
||||||
|
description="Full ontology class info (name + description) injected into pruning prompt."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -238,7 +238,7 @@ def rerank_with_activation(
|
|||||||
|
|
||||||
reranked: Dict[str, List[Dict[str, Any]]] = {}
|
reranked: Dict[str, List[Dict[str, Any]]] = {}
|
||||||
|
|
||||||
for category in ["statements", "chunks", "entities", "summaries"]:
|
for category in ["statements", "chunks", "entities", "summaries", "communities"]:
|
||||||
keyword_items = keyword_results.get(category, [])
|
keyword_items = keyword_results.get(category, [])
|
||||||
embedding_items = embedding_results.get(category, [])
|
embedding_items = embedding_results.get(category, [])
|
||||||
|
|
||||||
@@ -281,21 +281,23 @@ def rerank_with_activation(
|
|||||||
for item in items_list:
|
for item in items_list:
|
||||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||||
if item_id and item_id in combined_items:
|
if item_id and item_id in combined_items:
|
||||||
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value", 0)
|
combined_items[item_id]["normalized_activation_value"] = item.get("normalized_activation_value")
|
||||||
|
|
||||||
# 步骤 4: 计算基础分数和最终分数
|
# 步骤 4: 计算基础分数和最终分数
|
||||||
for item_id, item in combined_items.items():
|
for item_id, item in combined_items.items():
|
||||||
bm25_norm = float(item.get("bm25_score", 0) or 0)
|
bm25_norm = float(item.get("bm25_score", 0) or 0)
|
||||||
emb_norm = float(item.get("embedding_score", 0) or 0)
|
emb_norm = float(item.get("embedding_score", 0) or 0)
|
||||||
act_norm = float(item.get("normalized_activation_value", 0) or 0)
|
# normalized_activation_value 为 None 表示该节点无激活值,保留 None 语义
|
||||||
|
raw_act_norm = item.get("normalized_activation_value")
|
||||||
|
act_norm = float(raw_act_norm) if raw_act_norm is not None else None
|
||||||
|
|
||||||
# 第一阶段:只考虑内容相关性(BM25 + Embedding)
|
# 第一阶段:只考虑内容相关性(BM25 + Embedding)
|
||||||
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
|
# alpha 控制 BM25 权重,(1-alpha) 控制 Embedding 权重
|
||||||
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
|
content_score = alpha * bm25_norm + (1 - alpha) * emb_norm
|
||||||
base_score = content_score # 第一阶段用内容分数
|
base_score = content_score # 第一阶段用内容分数
|
||||||
|
|
||||||
# 存储激活度分数供第二阶段使用
|
# 存储激活度分数供第二阶段使用(None 表示无激活值,不参与激活值排序)
|
||||||
item["activation_score"] = act_norm
|
item["activation_score"] = act_norm # 可能为 None
|
||||||
item["content_score"] = content_score
|
item["content_score"] = content_score
|
||||||
item["base_score"] = base_score
|
item["base_score"] = base_score
|
||||||
|
|
||||||
@@ -724,6 +726,8 @@ async def run_hybrid_search(
|
|||||||
try:
|
try:
|
||||||
keyword_task = None
|
keyword_task = None
|
||||||
embedding_task = None
|
embedding_task = None
|
||||||
|
keyword_results: Dict[str, List] = {}
|
||||||
|
embedding_results: Dict[str, List] = {}
|
||||||
|
|
||||||
if search_type in ["keyword", "hybrid"]:
|
if search_type in ["keyword", "hybrid"]:
|
||||||
# Keyword-based search
|
# Keyword-based search
|
||||||
@@ -746,6 +750,7 @@ async def run_hybrid_search(
|
|||||||
|
|
||||||
# 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig
|
# 从数据库读取嵌入器配置(按 ID)并构建 RedBearModelConfig
|
||||||
config_load_start = time.time()
|
config_load_start = time.time()
|
||||||
|
try:
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
config_service = MemoryConfigService(db)
|
config_service = MemoryConfigService(db)
|
||||||
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
embedder_config_dict = config_service.get_embedder_config(str(memory_config.embedding_model_id))
|
||||||
@@ -775,6 +780,12 @@ async def run_hybrid_search(
|
|||||||
include=include,
|
include=include,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
except Exception as emb_init_err:
|
||||||
|
logger.warning(
|
||||||
|
f"[PERF] Embedding search skipped due to init error "
|
||||||
|
f"(embedding_model_id={memory_config.embedding_model_id}): {emb_init_err}"
|
||||||
|
)
|
||||||
|
embedding_task = None
|
||||||
|
|
||||||
if keyword_task:
|
if keyword_task:
|
||||||
keyword_results = await keyword_task
|
keyword_results = await keyword_task
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
- 增量更新(incremental_update):新实体到达时,只处理新实体及其邻居
|
- 增量更新(incremental_update):新实体到达时,只处理新实体及其邻居
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from math import sqrt
|
from math import sqrt
|
||||||
@@ -19,8 +20,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
# 全量迭代最大轮数,防止不收敛
|
# 全量迭代最大轮数,防止不收敛
|
||||||
MAX_ITERATIONS = 10
|
MAX_ITERATIONS = 10
|
||||||
# 社区摘要核心实体数量
|
|
||||||
CORE_ENTITY_LIMIT = 5
|
# 社区核心实体取 top-N 数量
|
||||||
|
CORE_ENTITY_LIMIT = 10
|
||||||
|
|
||||||
|
|
||||||
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
|
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
|
||||||
@@ -67,13 +69,13 @@ class LabelPropagationEngine:
|
|||||||
def __init__(
    self,
    connector: Neo4jConnector,
    llm_model_id: Optional[str] = None,
    embedding_model_id: Optional[str] = None,
):
    """Bind the engine to a Neo4j connector and optional model ids.

    Args:
        connector: Neo4j connector; also used to build the
            ``CommunityRepository`` stored on ``self.repo``.
        llm_model_id: Optional id handed to the LLM client factory when
            community names/summaries are generated.
        embedding_model_id: Optional id handed to the embedder client
            factory when community summaries are embedded.
    """
    # Plain attribute wiring; the repository is the only derived object.
    self.llm_model_id = llm_model_id
    self.embedding_model_id = embedding_model_id
    self.connector = connector
    self.repo = CommunityRepository(connector)
|
||||||
|
|
||||||
# ──────────────────────────────────────────────────────────────────────────
|
# ──────────────────────────────────────────────────────────────────────────
|
||||||
# 公开接口
|
# 公开接口
|
||||||
@@ -103,58 +105,81 @@ class LabelPropagationEngine:
|
|||||||
|
|
||||||
async def full_clustering(self, end_user_id: str) -> None:
|
async def full_clustering(self, end_user_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
全量标签传播初始化。
|
全量标签传播初始化(分批处理,控制内存峰值)。
|
||||||
|
|
||||||
1. 拉取所有实体,初始化每个实体为独立社区
|
策略:
|
||||||
2. 迭代:每轮对所有实体做邻居投票,更新社区标签
|
- 每次只加载 BATCH_SIZE 个实体及其邻居进内存
|
||||||
3. 直到标签不再变化或达到 MAX_ITERATIONS
|
- labels 字典跨批次共享(只存 id→community_id,内存极小)
|
||||||
4. 将最终标签写入 Neo4j
|
- 每批独立跑 MAX_ITERATIONS 轮 LPA,批次间通过 labels 传递社区信息
|
||||||
|
- 所有批次完成后统一 flush 和 merge
|
||||||
"""
|
"""
|
||||||
entities = await self.repo.get_all_entities(end_user_id)
|
BATCH_SIZE = 888 # 每批实体数,可按需调整
|
||||||
if not entities:
|
|
||||||
|
# 轻量查询:只获取总数和 ID 列表,不加载 embedding 等大字段
|
||||||
|
total_count = await self.repo.get_entity_count(end_user_id)
|
||||||
|
if not total_count:
|
||||||
logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类")
|
logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 初始化:每个实体持有自己 id 作为社区标签
|
all_entity_ids = await self.repo.get_all_entity_ids(end_user_id)
|
||||||
labels: Dict[str, str] = {e["id"]: e["id"] for e in entities}
|
logger.info(f"[Clustering] 用户 {end_user_id} 共 {total_count} 个实体,"
|
||||||
embeddings: Dict[str, Optional[List[float]]] = {
|
f"分批大小 {BATCH_SIZE},共 {(total_count + BATCH_SIZE - 1) // BATCH_SIZE} 批")
|
||||||
e["id"]: e.get("name_embedding") for e in entities
|
|
||||||
|
# labels 跨批次共享:只存 id→community_id,内存极小
|
||||||
|
labels: Dict[str, str] = {eid: eid for eid in all_entity_ids}
|
||||||
|
del all_entity_ids # 释放 ID 列表,后续按批次加载完整数据
|
||||||
|
|
||||||
|
for batch_start in range(0, total_count, BATCH_SIZE):
|
||||||
|
batch_entities = await self.repo.get_entities_page(
|
||||||
|
end_user_id, skip=batch_start, limit=BATCH_SIZE
|
||||||
|
)
|
||||||
|
if not batch_entities:
|
||||||
|
break
|
||||||
|
|
||||||
|
batch_ids = [e["id"] for e in batch_entities]
|
||||||
|
batch_embeddings: Dict[str, Optional[List[float]]] = {
|
||||||
|
e["id"]: e.get("name_embedding") for e in batch_entities
|
||||||
}
|
}
|
||||||
|
|
||||||
# 预加载所有实体的邻居,避免迭代内 O(iterations * |E|) 次 Neo4j 往返
|
logger.info(
|
||||||
logger.info(f"[Clustering] 预加载 {len(entities)} 个实体的邻居图...")
|
f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1}:"
|
||||||
neighbors_cache: Dict[str, List[Dict]] = await self.repo.get_all_entity_neighbors_batch(end_user_id)
|
f"加载 {len(batch_entities)} 个实体的邻居图..."
|
||||||
|
)
|
||||||
|
neighbors_cache = await self.repo.get_entity_neighbors_for_ids(
|
||||||
|
batch_ids, end_user_id
|
||||||
|
)
|
||||||
logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")
|
logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")
|
||||||
|
|
||||||
for iteration in range(MAX_ITERATIONS):
|
for iteration in range(MAX_ITERATIONS):
|
||||||
changed = 0
|
changed = 0
|
||||||
# 随机顺序(Python dict 在 3.7+ 保持插入顺序,这里直接遍历)
|
for entity in batch_entities:
|
||||||
for entity in entities:
|
|
||||||
eid = entity["id"]
|
eid = entity["id"]
|
||||||
# 直接从缓存取邻居,不再发起 Neo4j 查询
|
|
||||||
neighbors = neighbors_cache.get(eid, [])
|
neighbors = neighbors_cache.get(eid, [])
|
||||||
|
|
||||||
# 将邻居的当前内存标签注入(覆盖 Neo4j 中的旧值)
|
# 注入跨批次的最新标签(邻居可能在其他批次,labels 里有其最新值)
|
||||||
enriched = []
|
enriched = []
|
||||||
for nb in neighbors:
|
for nb in neighbors:
|
||||||
nb_copy = dict(nb)
|
nb_copy = dict(nb)
|
||||||
nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
|
nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
|
||||||
enriched.append(nb_copy)
|
enriched.append(nb_copy)
|
||||||
|
|
||||||
new_label = _weighted_vote(enriched, embeddings.get(eid))
|
new_label = _weighted_vote(enriched, batch_embeddings.get(eid))
|
||||||
if new_label and new_label != labels[eid]:
|
if new_label and new_label != labels[eid]:
|
||||||
labels[eid] = new_label
|
labels[eid] = new_label
|
||||||
changed += 1
|
changed += 1
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"[Clustering] 全量迭代 {iteration + 1}/{MAX_ITERATIONS},"
|
f"[Clustering] 批次 {batch_start // BATCH_SIZE + 1} "
|
||||||
f"标签变化数: {changed}"
|
f"迭代 {iteration + 1}/{MAX_ITERATIONS},标签变化数: {changed}"
|
||||||
)
|
)
|
||||||
if changed == 0:
|
if changed == 0:
|
||||||
logger.info("[Clustering] 标签已收敛,提前结束迭代")
|
logger.info("[Clustering] 标签已收敛,提前结束本批迭代")
|
||||||
break
|
break
|
||||||
|
|
||||||
# 将最终标签写入 Neo4j
|
# 释放本批次的大对象
|
||||||
|
del neighbors_cache, batch_embeddings, batch_entities
|
||||||
|
|
||||||
|
# 所有批次完成,统一写入 Neo4j
|
||||||
await self._flush_labels(labels, end_user_id)
|
await self._flush_labels(labels, end_user_id)
|
||||||
pre_merge_count = len(set(labels.values()))
|
pre_merge_count = len(set(labels.values()))
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -162,7 +187,6 @@ class LabelPropagationEngine:
|
|||||||
f"{len(labels)} 个实体,开始后处理合并"
|
f"{len(labels)} 个实体,开始后处理合并"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 全量初始化后做一轮社区合并(基于 name_embedding 余弦相似度)
|
|
||||||
all_community_ids = list(set(labels.values()))
|
all_community_ids = list(set(labels.values()))
|
||||||
await self._evaluate_merge(all_community_ids, end_user_id)
|
await self._evaluate_merge(all_community_ids, end_user_id)
|
||||||
|
|
||||||
@@ -170,17 +194,15 @@ class LabelPropagationEngine:
|
|||||||
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
|
f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
|
||||||
f"{len(labels)} 个实体"
|
f"{len(labels)} 个实体"
|
||||||
)
|
)
|
||||||
# 为所有社区生成元数据
|
|
||||||
# 注意:_evaluate_merge 后部分社区已被合并消解,需重新从 Neo4j 查询实际存活的社区
|
# 查询存活社区并生成元数据
|
||||||
# 不能复用 labels.values(),那里包含已被 dissolve 的旧社区 ID
|
|
||||||
surviving_communities = await self.repo.get_all_entities(end_user_id)
|
surviving_communities = await self.repo.get_all_entities(end_user_id)
|
||||||
surviving_community_ids = list({
|
surviving_community_ids = list({
|
||||||
e.get("community_id") for e in surviving_communities
|
e.get("community_id") for e in surviving_communities
|
||||||
if e.get("community_id")
|
if e.get("community_id")
|
||||||
})
|
})
|
||||||
logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}")
|
logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}")
|
||||||
for cid in surviving_community_ids:
|
await self._generate_community_metadata(surviving_community_ids, end_user_id)
|
||||||
await self._generate_community_metadata(cid, end_user_id)
|
|
||||||
|
|
||||||
async def incremental_update(
|
async def incremental_update(
|
||||||
self, new_entity_ids: List[str], end_user_id: str
|
self, new_entity_ids: List[str], end_user_id: str
|
||||||
@@ -237,7 +259,7 @@ class LabelPropagationEngine:
|
|||||||
logger.debug(
|
logger.debug(
|
||||||
f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
|
f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
|
||||||
)
|
)
|
||||||
await self._generate_community_metadata(new_cid, end_user_id)
|
await self._generate_community_metadata([new_cid], end_user_id)
|
||||||
else:
|
else:
|
||||||
# 加入得票最多的社区
|
# 加入得票最多的社区
|
||||||
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
|
await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
|
||||||
@@ -249,7 +271,7 @@ class LabelPropagationEngine:
|
|||||||
await self._evaluate_merge(
|
await self._evaluate_merge(
|
||||||
list(community_ids_in_neighbors), end_user_id
|
list(community_ids_in_neighbors), end_user_id
|
||||||
)
|
)
|
||||||
await self._generate_community_metadata(target_cid, end_user_id)
|
await self._generate_community_metadata([target_cid], end_user_id)
|
||||||
|
|
||||||
async def _evaluate_merge(
|
async def _evaluate_merge(
|
||||||
self, community_ids: List[str], end_user_id: str
|
self, community_ids: List[str], end_user_id: str
|
||||||
@@ -413,71 +435,137 @@ class LabelPropagationEngine:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _build_entity_lines(members: List[Dict]) -> List[str]:
|
||||||
|
"""将实体列表格式化为 prompt 行,包含 name、aliases、description、example。"""
|
||||||
|
lines = []
|
||||||
|
for m in members:
|
||||||
|
m_name = m.get("name", "")
|
||||||
|
aliases = m.get("aliases") or []
|
||||||
|
description = m.get("description") or ""
|
||||||
|
example = m.get("example") or ""
|
||||||
|
aliases_str = f"(别名:{'、'.join(aliases)})" if aliases else ""
|
||||||
|
desc_str = f":{description}" if description else ""
|
||||||
|
example_str = f"(示例:{example})" if example else ""
|
||||||
|
lines.append(f"- {m_name}{aliases_str}{desc_str}{example_str}")
|
||||||
|
return lines
|
||||||
|
|
||||||
async def _generate_community_metadata(
    self, community_ids: List[str], end_user_id: str
) -> None:
    """Generate and persist metadata for one or more communities.

    Steps:
      1. For every community, concurrently (``asyncio.gather``): load its
         members, take the top ``CORE_ENTITY_LIMIT`` members by
         ``activation_value`` as core entities, and build a name/summary.
         A deterministic fallback (joined entity names) is always computed;
         when ``self.llm_model_id`` is set the LLM result replaces it, but
         only for the fields the LLM actually returned.
      2. Embed all summaries in one batch when ``self.embedding_model_id``
         is set; on failure ``summary_embedding`` stays ``None``.
      3. Persist: ``update_community_metadata`` for a single community,
         ``batch_update_community_metadata`` for several.

    Exceptions raised by the LLM, embedding or write steps are logged
    rather than propagated, so callers are not interrupted by metadata
    generation problems.

    Args:
        community_ids: Ids of the communities to describe; an empty list
            is a no-op.
        end_user_id: End user the communities belong to.
    """
    if not community_ids:
        return

    from app.db import get_db_context
    from app.core.memory.utils.llm.llm_utils import MemoryClientFactory

    def _parse_field(line: str, label: str) -> Optional[str]:
        """Return the text after ``label`` + colon (ASCII or full-width), else None."""
        if line.startswith(label) and line[len(label):len(label) + 1] in (":", "\uff1a"):
            return line[len(label) + 1:].strip()
        return None

    # --- Stage 1: build name / summary / core entities per community ---
    # NOTE(review): gather starts one LLM call per community with no
    # concurrency limit; consider a semaphore if communities are numerous.
    async def _build_one(cid: str) -> Optional[Dict]:
        members = await self.repo.get_community_members(cid, end_user_id)
        if not members:
            return None

        ranked = sorted(
            members,
            key=lambda m: m.get("activation_value") or 0,
            reverse=True,
        )
        core_entities = [m["name"] for m in ranked[:CORE_ENTITY_LIMIT] if m.get("name")]
        all_names = [m["name"] for m in members if m.get("name")]

        # Deterministic fallback, used when no LLM is configured, the call
        # fails, or the reply does not follow the requested format.
        name = "、".join(core_entities[:3]) if core_entities else cid[:8]
        summary = f"包含实体:{', '.join(all_names)}"

        if self.llm_model_id:
            try:
                entity_list_str = "\n".join(self._build_entity_lines(members))

                # Inject the relationship triples between entities of this community.
                relationships = await self.repo.get_community_relationships(cid, end_user_id)
                rel_lines = [
                    f"- {r['subject']} → {r['predicate']} → {r['object']}"
                    for r in (relationships or [])
                    if r.get("subject") and r.get("predicate") and r.get("object")
                ]
                rel_section = ("\n实体间关系:\n" + "\n".join(rel_lines)) if rel_lines else ""

                prompt = (
                    f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
                    f"请为这组实体所代表的主题:\n"
                    f"1. 起一个简洁的中文名称(不超过10个字)\n"
                    f"2. 写一句话摘要(不超过80个字)\n\n"
                    f"严格按以下格式输出,不要有其他内容:\n"
                    f"名称:<名称>\n摘要:<摘要>"
                )
                # The client is created inside the session scope and used after it.
                # NOTE(review): confirm the client does not need the session alive.
                with get_db_context() as db:
                    llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
                response = await llm_client.chat([{"role": "user", "content": prompt}])
                text = response.content if hasattr(response, "content") else str(response)

                for raw_line in text.strip().splitlines():
                    line = raw_line.strip()
                    parsed_name = _parse_field(line, "名称")
                    parsed_summary = _parse_field(line, "摘要")
                    # Only non-empty LLM values override the fallback.
                    if parsed_name:
                        name = parsed_name
                    elif parsed_summary:
                        summary = parsed_summary
            except Exception as e:
                logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}")

        return {
            "community_id": cid,
            "end_user_id": end_user_id,
            "name": name,
            "summary": summary,
            "core_entities": core_entities,
            "summary_embedding": None,
        }

    results = await asyncio.gather(
        *[_build_one(cid) for cid in community_ids],
        return_exceptions=True,
    )
    metadata_list = []
    for cid, res in zip(community_ids, results):
        # BaseException also covers CancelledError, which must not be
        # mistaken for a metadata dict.
        if isinstance(res, BaseException):
            logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {res}", exc_info=res)
        elif res is not None:
            metadata_list.append(res)

    if not metadata_list:
        return

    # --- Stage 2: batch-embed the summaries (best effort) ---
    if self.embedding_model_id:
        try:
            summaries = [m["summary"] for m in metadata_list]
            with get_db_context() as db:
                embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
            embeddings = await embedder.response(summaries)
            # zip stops at the shorter sequence; unmatched entries keep None.
            for meta, emb in zip(metadata_list, embeddings or []):
                meta["summary_embedding"] = emb
        except Exception as e:
            logger.warning(f"[Clustering] 社区摘要向量生成失败,summary_embedding 置空: {e}")

    # --- Stage 3: persist (single or batch) ---
    try:
        if len(metadata_list) == 1:
            m = metadata_list[0]
            result = await self.repo.update_community_metadata(
                community_id=m["community_id"],
                end_user_id=m["end_user_id"],
                name=m["name"],
                summary=m["summary"],
                core_entities=m["core_entities"],
                summary_embedding=m["summary_embedding"],
            )
            if result:
                logger.info(f"[Clustering] 社区 {m['community_id']} 元数据写入成功: name={m['name']}, summary={m['summary'][:30]}...")
            else:
                logger.warning(f"[Clustering] 社区 {m['community_id']} 元数据写入返回 False")
        else:
            ok = await self.repo.batch_update_community_metadata(metadata_list)
            if ok:
                logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
            else:
                logger.warning(f"[Clustering] 批量写入社区元数据失败")
    except Exception as e:
        logger.error(f"[Clustering] 社区元数据写入失败: {e}", exc_info=e)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _new_community_id() -> str:
|
def _new_community_id() -> str:
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext
|
from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext
|
||||||
from app.core.memory.models.config_models import PruningConfig
|
from app.core.memory.models.config_models import PruningConfig
|
||||||
from app.core.memory.utils.config.config_utils import get_pruning_config
|
|
||||||
from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering
|
from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering
|
||||||
from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import (
|
from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import (
|
||||||
SceneConfigRegistry,
|
SceneConfigRegistry,
|
||||||
@@ -34,6 +33,8 @@ class DialogExtractionResponse(BaseModel):
|
|||||||
- is_related:对话与场景的相关性判定。
|
- is_related:对话与场景的相关性判定。
|
||||||
- times / ids / amounts / contacts / addresses / keywords:重要信息片段,用来在不相关对话中保留关键消息。
|
- times / ids / amounts / contacts / addresses / keywords:重要信息片段,用来在不相关对话中保留关键消息。
|
||||||
- preserve_keywords:情绪/兴趣/爱好/个人观点相关词,包含这些词的消息必须强制保留。
|
- preserve_keywords:情绪/兴趣/爱好/个人观点相关词,包含这些词的消息必须强制保留。
|
||||||
|
- scene_unrelated_snippets:与当前场景无关且无语义关联的消息片段(原文截取),
|
||||||
|
用于高阈值阶段精准删除跨场景内容。
|
||||||
"""
|
"""
|
||||||
is_related: bool = Field(...)
|
is_related: bool = Field(...)
|
||||||
times: List[str] = Field(default_factory=list)
|
times: List[str] = Field(default_factory=list)
|
||||||
@@ -43,6 +44,7 @@ class DialogExtractionResponse(BaseModel):
|
|||||||
addresses: List[str] = Field(default_factory=list)
|
addresses: List[str] = Field(default_factory=list)
|
||||||
keywords: List[str] = Field(default_factory=list)
|
keywords: List[str] = Field(default_factory=list)
|
||||||
preserve_keywords: List[str] = Field(default_factory=list, description="情绪/兴趣/爱好/个人观点相关词,包含这些词的消息强制保留")
|
preserve_keywords: List[str] = Field(default_factory=list, description="情绪/兴趣/爱好/个人观点相关词,包含这些词的消息强制保留")
|
||||||
|
scene_unrelated_snippets: List[str] = Field(default_factory=list,description="与当前场景无关且无语义关联的消息原文片段,高阈值阶段用于精准删除跨场景内容")
|
||||||
|
|
||||||
|
|
||||||
class MessageImportanceResponse(BaseModel):
|
class MessageImportanceResponse(BaseModel):
|
||||||
@@ -91,12 +93,14 @@ class SemanticPruner:
|
|||||||
# 加载统一填充词库
|
# 加载统一填充词库
|
||||||
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(self.config.pruning_scene)
|
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(self.config.pruning_scene)
|
||||||
|
|
||||||
# 本体类型列表(用于注入提示词,所有场景均支持)
|
# 本体类型列表:直接使用 ontology_class_infos(name + description)
|
||||||
self._ontology_classes = getattr(self.config, "ontology_classes", None) or []
|
self._ontology_class_infos = getattr(self.config, "ontology_class_infos", None) or []
|
||||||
|
# _ontology_classes 仅用于日志统计
|
||||||
|
self._ontology_classes = [info.class_name for info in self._ontology_class_infos]
|
||||||
|
|
||||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene}")
|
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene}")
|
||||||
if self._ontology_classes:
|
if self._ontology_class_infos:
|
||||||
self._log(f"[剪枝-初始化] 注入本体类型: {self._ontology_classes}")
|
self._log(f"[剪枝-初始化] 注入本体类型({len(self._ontology_class_infos)}个): {self._ontology_classes}")
|
||||||
else:
|
else:
|
||||||
self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
|
self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
|
||||||
|
|
||||||
@@ -121,7 +125,8 @@ class SemanticPruner:
|
|||||||
1. 空消息
|
1. 空消息
|
||||||
2. 场景特定填充词库精确匹配
|
2. 场景特定填充词库精确匹配
|
||||||
3. 常见寒暄精确匹配
|
3. 常见寒暄精确匹配
|
||||||
4. 纯表情/标点
|
4. 组合寒暄模式(前缀+后缀组合,如"好的谢谢"、"同学你好"、"明白了")
|
||||||
|
5. 纯表情/标点
|
||||||
"""
|
"""
|
||||||
t = message.msg.strip()
|
t = message.msg.strip()
|
||||||
if not t:
|
if not t:
|
||||||
@@ -143,6 +148,55 @@ class SemanticPruner:
|
|||||||
if t in common_greetings:
|
if t in common_greetings:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
# 组合寒暄模式:短消息(≤15字)且完全由寒暄成分构成
|
||||||
|
# 策略:将消息拆分后,每个片段都能在填充词库或常见寒暄中找到,则整体为填充
|
||||||
|
if len(t) <= 15:
|
||||||
|
# 确认+称呼/感谢组合,如"好的谢谢"、"明白了"、"知道了谢谢"
|
||||||
|
_confirm_prefixes = {"好的", "好", "嗯", "嗯嗯", "哦", "明白", "明白了", "知道了", "了解", "收到", "没问题"}
|
||||||
|
_thanks_suffixes = {"谢谢", "谢谢你", "谢谢您", "多谢", "感谢", "谢了"}
|
||||||
|
_greeting_suffixes = {"你好", "您好", "老师好", "同学好", "大家好"}
|
||||||
|
_greeting_prefixes = {"同学", "老师", "您好", "你好"}
|
||||||
|
_close_patterns = {
|
||||||
|
"没有了", "没事了", "没问题了", "好了", "行了", "可以了",
|
||||||
|
"不用了", "不需要了", "就这样", "就这样吧", "那就这样",
|
||||||
|
}
|
||||||
|
_polite_responses = {
|
||||||
|
"不客气", "不用谢", "没关系", "没事", "应该的", "这是我应该做的",
|
||||||
|
}
|
||||||
|
|
||||||
|
# 规则1:确认词 + 感谢词(如"好的谢谢"、"嗯谢谢")
|
||||||
|
for cp in _confirm_prefixes:
|
||||||
|
for ts in _thanks_suffixes:
|
||||||
|
if t == cp + ts or t == cp + "," + ts or t == cp + "," + ts:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 规则2:称呼前缀 + 问候(如"同学你好"、"老师好")
|
||||||
|
for gp in _greeting_prefixes:
|
||||||
|
for gs in _greeting_suffixes:
|
||||||
|
if t == gp + gs or t.startswith(gp) and t.endswith("好"):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 规则3:结束语 + 感谢(如"没有了,谢谢老师"、"没有了谢谢")
|
||||||
|
for cp in _close_patterns:
|
||||||
|
if t.startswith(cp):
|
||||||
|
remainder = t[len(cp):].lstrip(",,、 ")
|
||||||
|
if not remainder or any(remainder.startswith(ts) for ts in _thanks_suffixes):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 规则4:礼貌回应(如"不客气,祝你考试顺利"——前缀是礼貌词,后半是祝福套话)
|
||||||
|
for pr in _polite_responses:
|
||||||
|
if t.startswith(pr):
|
||||||
|
remainder = t[len(pr):].lstrip(",,、 ")
|
||||||
|
# 后半是祝福/套话(不含实质信息)
|
||||||
|
if not remainder or re.match(r"^(祝|希望|期待|加油|顺利|好好|保重)", remainder):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# 规则5:纯确认词加"了"后缀(如"明白了"、"知道了"、"好了")
|
||||||
|
_confirm_base = {"明白", "知道", "了解", "收到", "好", "行", "可以", "没问题"}
|
||||||
|
for cb in _confirm_base:
|
||||||
|
if t == cb + "了" or t == cb + "了。" or t == cb + "了!":
|
||||||
|
return True
|
||||||
|
|
||||||
# 检查是否为纯表情符号(方括号包裹)
|
# 检查是否为纯表情符号(方括号包裹)
|
||||||
if re.fullmatch(r"(\[[^\]]+\])+", t):
|
if re.fullmatch(r"(\[[^\]]+\])+", t):
|
||||||
return True
|
return True
|
||||||
@@ -331,13 +385,13 @@ class SemanticPruner:
|
|||||||
|
|
||||||
rendered = self.template.render(
|
rendered = self.template.render(
|
||||||
pruning_scene=self.config.pruning_scene,
|
pruning_scene=self.config.pruning_scene,
|
||||||
ontology_classes=self._ontology_classes,
|
ontology_class_infos=self._ontology_class_infos,
|
||||||
dialog_text=dialog_text,
|
dialog_text=dialog_text,
|
||||||
language=self.language
|
language=self.language
|
||||||
)
|
)
|
||||||
log_template_rendering("extracat_Pruning.jinja2", {
|
log_template_rendering("extracat_Pruning.jinja2", {
|
||||||
"pruning_scene": self.config.pruning_scene,
|
"pruning_scene": self.config.pruning_scene,
|
||||||
"ontology_classes_count": len(self._ontology_classes),
|
"ontology_class_infos_count": len(self._ontology_class_infos),
|
||||||
"language": self.language
|
"language": self.language
|
||||||
})
|
})
|
||||||
log_prompt_rendering("pruning-extract", rendered)
|
log_prompt_rendering("pruning-extract", rendered)
|
||||||
@@ -377,6 +431,183 @@ class SemanticPruner:
|
|||||||
)
|
)
|
||||||
return fallback_response
|
return fallback_response
|
||||||
|
|
||||||
|
def _get_pruning_mode(self) -> str:
|
||||||
|
"""根据 pruning_threshold 返回当前剪枝阶段。
|
||||||
|
|
||||||
|
- 低阈值 [0.0, 0.3):conservative 只删填充,保留所有实质内容
|
||||||
|
- 中阈值 [0.3, 0.6):semantic 保留场景相关 + 有语义关联的内容,删除无关联内容
|
||||||
|
- 高阈值 [0.6, 0.9]:strict 只保留场景相关内容,跨场景内容可被删除
|
||||||
|
"""
|
||||||
|
t = float(self.config.pruning_threshold)
|
||||||
|
if t < 0.3:
|
||||||
|
return "conservative"
|
||||||
|
elif t < 0.6:
|
||||||
|
return "semantic"
|
||||||
|
else:
|
||||||
|
return "strict"
|
||||||
|
|
||||||
|
def _apply_related_dialog_pruning(
|
||||||
|
self,
|
||||||
|
msgs: List[ConversationMessage],
|
||||||
|
extraction: "DialogExtractionResponse",
|
||||||
|
dialog_label: str,
|
||||||
|
pruning_mode: str,
|
||||||
|
) -> List[ConversationMessage]:
|
||||||
|
"""相关对话统一剪枝入口,消除 prune_dialog / prune_dataset 中的重复逻辑。
|
||||||
|
|
||||||
|
- conservative:只删填充
|
||||||
|
- semantic / strict:场景感知剪枝
|
||||||
|
"""
|
||||||
|
if pruning_mode == "conservative":
|
||||||
|
preserve_tokens = self._build_preserve_tokens(extraction)
|
||||||
|
return self._prune_fillers_only(msgs, preserve_tokens, dialog_label)
|
||||||
|
else:
|
||||||
|
return self._prune_with_scene_filter(msgs, extraction, dialog_label, pruning_mode)
|
||||||
|
|
||||||
|
def _prune_fillers_only(
|
||||||
|
self,
|
||||||
|
msgs: List[ConversationMessage],
|
||||||
|
preserve_tokens: List[str],
|
||||||
|
dialog_label: str,
|
||||||
|
) -> List[ConversationMessage]:
|
||||||
|
"""相关对话专用:只删填充消息,LLM 保护消息和实质内容一律保留。
|
||||||
|
|
||||||
|
不受 pruning_threshold 约束,删多少算多少(填充有多少删多少)。
|
||||||
|
至少保留 1 条消息。
|
||||||
|
注意:填充检测优先于 preserve_tokens 保护——填充消息本身无信息价值,
|
||||||
|
即使 LLM 误将其关键词放入 preserve_tokens 也应删除。
|
||||||
|
"""
|
||||||
|
to_delete_ids: set = set()
|
||||||
|
for m in msgs:
|
||||||
|
# 填充检测优先:先判断是否为填充,再看 LLM 保护
|
||||||
|
if self._is_filler_message(m):
|
||||||
|
to_delete_ids.add(id(m))
|
||||||
|
self._log(f" [填充] '{m.msg[:40]}' → 删除")
|
||||||
|
continue
|
||||||
|
if self._msg_matches_tokens(m, preserve_tokens):
|
||||||
|
self._log(f" [保护] '{m.msg[:40]}' → LLM保护,跳过")
|
||||||
|
|
||||||
|
kept = [m for m in msgs if id(m) not in to_delete_ids]
|
||||||
|
if not kept and msgs:
|
||||||
|
kept = [msgs[0]]
|
||||||
|
|
||||||
|
deleted = len(msgs) - len(kept)
|
||||||
|
self._log(
|
||||||
|
f"[剪枝-相关] {dialog_label} 总消息={len(msgs)} "
|
||||||
|
f"填充删除={deleted} 保留={len(kept)}"
|
||||||
|
)
|
||||||
|
return kept
|
||||||
|
|
||||||
|
def _prune_with_scene_filter(
|
||||||
|
self,
|
||||||
|
msgs: List[ConversationMessage],
|
||||||
|
extraction: "DialogExtractionResponse",
|
||||||
|
dialog_label: str,
|
||||||
|
mode: str,
|
||||||
|
) -> List[ConversationMessage]:
|
||||||
|
"""场景感知剪枝,供 semantic / strict 两个阈值档位调用。
|
||||||
|
|
||||||
|
本函数体现剪枝系统的三层递进逻辑:
|
||||||
|
|
||||||
|
第一层(conservative,阈值 < 0.3):
|
||||||
|
不进入本函数,由 _prune_fillers_only 处理。
|
||||||
|
保留标准:只问"有没有信息量",填充消息(嗯/好的/哈哈等)删除,其余一律保留。
|
||||||
|
|
||||||
|
第二层(semantic,阈值 [0.3, 0.6)):
|
||||||
|
保留标准:内容价值优先,场景相关性是参考而非唯一标准。
|
||||||
|
- 填充消息 → 删除(最高优先级)
|
||||||
|
- 场景相关消息 → 保留
|
||||||
|
- 场景无关消息 → 有两次豁免机会:
|
||||||
|
1. 命中 scene_preserve_tokens(LLM 标记的关键词/时间/金额等)→ 保留
|
||||||
|
2. 含情感词(感觉/压力/开心等)→ 保留(情感内容有记忆价值)
|
||||||
|
3. 两次豁免均未命中 → 删除
|
||||||
|
|
||||||
|
第三层(strict,阈值 [0.6, 0.9]):
|
||||||
|
保留标准:场景相关性优先,无任何豁免。
|
||||||
|
- 填充消息 → 删除(最高优先级)
|
||||||
|
- 场景相关消息 → 保留
|
||||||
|
- 场景无关消息 → 直接删除,preserve_keywords 和情感词在此模式下均不生效
|
||||||
|
|
||||||
|
至少保留 1 条消息(兜底取第一条)。
|
||||||
|
"""
|
||||||
|
# strict 模式收窄保护范围:只保护结构化关键信息(时间/编号/金额/联系方式/地址),
|
||||||
|
# 不保护 keywords / preserve_keywords,让场景过滤能删掉更多内容。
|
||||||
|
# semantic 模式完整保护:包含 LLM 抽取的所有重要片段(含 keywords 和 preserve_keywords)。
|
||||||
|
if mode == "strict":
|
||||||
|
scene_preserve_tokens = (
|
||||||
|
extraction.times + extraction.ids + extraction.amounts +
|
||||||
|
extraction.contacts + extraction.addresses
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
scene_preserve_tokens = self._build_preserve_tokens(extraction)
|
||||||
|
|
||||||
|
unrelated_snippets = extraction.scene_unrelated_snippets or []
|
||||||
|
|
||||||
|
to_delete_ids: set = set()
|
||||||
|
for m in msgs:
|
||||||
|
msg_text = m.msg.strip()
|
||||||
|
|
||||||
|
# 第一优先级:填充消息无论模式直接删除,不参与后续场景判断
|
||||||
|
if self._is_filler_message(m):
|
||||||
|
to_delete_ids.add(id(m))
|
||||||
|
self._log(f" [填充] '{msg_text[:40]}' → 删除")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 双向包含匹配:处理 LLM 返回片段与原始消息文本长度不完全一致的情况
|
||||||
|
is_scene_unrelated = any(
|
||||||
|
snip and (snip in msg_text or msg_text in snip)
|
||||||
|
for snip in unrelated_snippets
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_scene_unrelated:
|
||||||
|
if mode == "strict":
|
||||||
|
# strict:场景无关直接删除,不做任何豁免
|
||||||
|
# 场景相关性是唯一裁决标准,preserve_keywords 在此模式下不生效
|
||||||
|
to_delete_ids.add(id(m))
|
||||||
|
self._log(f" [场景无关-严格] '{msg_text[:40]}' → 删除")
|
||||||
|
elif mode == "semantic":
|
||||||
|
# semantic:场景无关但有内容价值 → 保留
|
||||||
|
# 豁免第一层:命中 scene_preserve_tokens(关键词/结构化信息保护)
|
||||||
|
if self._msg_matches_tokens(m, scene_preserve_tokens):
|
||||||
|
self._log(f" [保护] '{msg_text[:40]}' → 场景关键词保护,保留")
|
||||||
|
else:
|
||||||
|
# 豁免第二层:含情感词,认为有情境记忆价值,即使场景无关也保留
|
||||||
|
has_contextual_emotion = any(
|
||||||
|
word in msg_text
|
||||||
|
for word in ["感觉", "觉得", "心情", "开心", "难过", "高兴", "沮丧",
|
||||||
|
"喜欢", "讨厌", "爱", "恨", "担心", "害怕", "兴奋",
|
||||||
|
"压力", "累", "疲惫", "烦", "焦虑", "委屈", "感动"]
|
||||||
|
)
|
||||||
|
if not has_contextual_emotion:
|
||||||
|
to_delete_ids.add(id(m))
|
||||||
|
self._log(f" [场景无关-语义] '{msg_text[:40]}' → 删除(无情感关联)")
|
||||||
|
else:
|
||||||
|
self._log(f" [场景关联-保留] '{msg_text[:40]}' → 有情感关联,保留")
|
||||||
|
else:
|
||||||
|
# 不在 scene_unrelated_snippets 中 → 场景相关,直接保留
|
||||||
|
if self._msg_matches_tokens(m, scene_preserve_tokens):
|
||||||
|
self._log(f" [保护] '{msg_text[:40]}' → LLM保护,跳过")
|
||||||
|
# else: 普通场景相关消息,保留,不输出日志
|
||||||
|
|
||||||
|
kept = [m for m in msgs if id(m) not in to_delete_ids]
|
||||||
|
if not kept and msgs:
|
||||||
|
kept = [msgs[0]]
|
||||||
|
|
||||||
|
deleted = len(msgs) - len(kept)
|
||||||
|
self._log(
|
||||||
|
f"[剪枝-{mode}] {dialog_label} 总消息={len(msgs)} "
|
||||||
|
f"删除={deleted} 保留={len(kept)}"
|
||||||
|
)
|
||||||
|
return kept
|
||||||
|
|
||||||
|
def _build_preserve_tokens(self, extraction: "DialogExtractionResponse") -> List[str]:
|
||||||
|
"""统一构建 preserve_tokens,合并 LLM 抽取的所有重要片段。"""
|
||||||
|
return (
|
||||||
|
extraction.times + extraction.ids + extraction.amounts +
|
||||||
|
extraction.contacts + extraction.addresses + extraction.keywords +
|
||||||
|
extraction.preserve_keywords
|
||||||
|
)
|
||||||
|
|
||||||
def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool:
|
def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool:
|
||||||
"""判断消息是否包含任意抽取到的重要片段。"""
|
"""判断消息是否包含任意抽取到的重要片段。"""
|
||||||
if not tokens:
|
if not tokens:
|
||||||
@@ -397,16 +628,18 @@ class SemanticPruner:
|
|||||||
|
|
||||||
proportion = float(self.config.pruning_threshold)
|
proportion = float(self.config.pruning_threshold)
|
||||||
extraction = await self._extract_dialog_important(dialog.content)
|
extraction = await self._extract_dialog_important(dialog.content)
|
||||||
|
pruning_mode = self._get_pruning_mode()
|
||||||
|
self._log(f"[剪枝-模式] 阈值={proportion} → 模式={pruning_mode}")
|
||||||
|
|
||||||
if extraction.is_related:
|
if extraction.is_related:
|
||||||
# 相关对话不剪枝
|
kept = self._apply_related_dialog_pruning(
|
||||||
|
dialog.context.msgs, extraction, f"对话ID={dialog.id}", pruning_mode
|
||||||
|
)
|
||||||
|
dialog.context = ConversationContext(msgs=kept)
|
||||||
return dialog
|
return dialog
|
||||||
|
|
||||||
# 在不相关对话中,LLM 已通过 preserve_tokens 标记需要保护的内容
|
# 在不相关对话中,LLM 已通过 preserve_tokens 标记需要保护的内容
|
||||||
preserve_tokens = (
|
preserve_tokens = self._build_preserve_tokens(extraction)
|
||||||
extraction.times + extraction.ids + extraction.amounts +
|
|
||||||
extraction.contacts + extraction.addresses + extraction.keywords +
|
|
||||||
extraction.preserve_keywords
|
|
||||||
)
|
|
||||||
msgs = dialog.context.msgs
|
msgs = dialog.context.msgs
|
||||||
|
|
||||||
# 分类:填充 / 其他可删(LLM保护消息通过不加入任何桶来隐式保护)
|
# 分类:填充 / 其他可删(LLM保护消息通过不加入任何桶来隐式保护)
|
||||||
@@ -482,10 +715,29 @@ class SemanticPruner:
|
|||||||
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch} 模式=消息级独立判断"
|
f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch} 模式=消息级独立判断"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
pruning_mode = self._get_pruning_mode()
|
||||||
|
self._log(f"[剪枝-数据集] 阈值={proportion} → 剪枝阶段={pruning_mode}")
|
||||||
|
|
||||||
result: List[DialogData] = []
|
result: List[DialogData] = []
|
||||||
total_original_msgs = 0
|
total_original_msgs = 0
|
||||||
total_deleted_msgs = 0
|
total_deleted_msgs = 0
|
||||||
|
|
||||||
|
# 统计对象:直接收集结构化数据,无需事后正则解析
|
||||||
|
stats = {
|
||||||
|
"scene": self.config.pruning_scene,
|
||||||
|
"dialog_total": len(dialogs),
|
||||||
|
"deletion_ratio": proportion,
|
||||||
|
"enabled": self.config.pruning_switch,
|
||||||
|
"pruning_mode": pruning_mode,
|
||||||
|
"related_count": 0,
|
||||||
|
"unrelated_count": 0,
|
||||||
|
"related_indices": [],
|
||||||
|
"unrelated_indices": [],
|
||||||
|
"total_deleted_messages": 0,
|
||||||
|
"remaining_dialogs": 0,
|
||||||
|
"dialogs": [],
|
||||||
|
}
|
||||||
|
|
||||||
# 并发执行所有对话的 LLM 抽取(获取 preserve_keywords 等保护信息)
|
# 并发执行所有对话的 LLM 抽取(获取 preserve_keywords 等保护信息)
|
||||||
semaphore = asyncio.Semaphore(self.max_concurrent)
|
semaphore = asyncio.Semaphore(self.max_concurrent)
|
||||||
|
|
||||||
@@ -505,12 +757,31 @@ class SemanticPruner:
|
|||||||
original_count = len(msgs)
|
original_count = len(msgs)
|
||||||
total_original_msgs += original_count
|
total_original_msgs += original_count
|
||||||
|
|
||||||
# 从 LLM 抽取结果中获取所有需要保留的 token
|
# 相关对话:根据阶段决定处理力度
|
||||||
preserve_tokens = (
|
if extraction.is_related:
|
||||||
extraction.times + extraction.ids + extraction.amounts +
|
stats["related_count"] += 1
|
||||||
extraction.contacts + extraction.addresses + extraction.keywords +
|
stats["related_indices"].append(d_idx + 1)
|
||||||
extraction.preserve_keywords # 情绪/兴趣/爱好关键词
|
kept = self._apply_related_dialog_pruning(
|
||||||
|
msgs, extraction, f"对话 {d_idx+1}", pruning_mode
|
||||||
)
|
)
|
||||||
|
deleted_count = original_count - len(kept)
|
||||||
|
total_deleted_msgs += deleted_count
|
||||||
|
dd.context.msgs = kept
|
||||||
|
result.append(dd)
|
||||||
|
stats["dialogs"].append({
|
||||||
|
"index": d_idx + 1,
|
||||||
|
"is_related": True,
|
||||||
|
"total_messages": original_count,
|
||||||
|
"deleted": deleted_count,
|
||||||
|
"kept": len(kept),
|
||||||
|
})
|
||||||
|
continue
|
||||||
|
|
||||||
|
stats["unrelated_count"] += 1
|
||||||
|
stats["unrelated_indices"].append(d_idx + 1)
|
||||||
|
|
||||||
|
# 从 LLM 抽取结果中获取所有需要保留的 token
|
||||||
|
preserve_tokens = self._build_preserve_tokens(extraction)
|
||||||
|
|
||||||
# 判断是否需要详细日志
|
# 判断是否需要详细日志
|
||||||
should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog
|
should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog
|
||||||
@@ -601,19 +872,34 @@ class SemanticPruner:
|
|||||||
f"删除={deleted_count} 保留={len(kept_msgs)}"
|
f"删除={deleted_count} 保留={len(kept_msgs)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
stats["dialogs"].append({
|
||||||
|
"index": d_idx + 1,
|
||||||
|
"is_related": False,
|
||||||
|
"total_messages": original_count,
|
||||||
|
"protected": len(important_msgs),
|
||||||
|
"fillers": len(filler_msgs),
|
||||||
|
"deletable": len(deletable_msgs),
|
||||||
|
"deleted": deleted_count,
|
||||||
|
"kept": len(kept_msgs),
|
||||||
|
})
|
||||||
|
|
||||||
result.append(dd)
|
result.append(dd)
|
||||||
|
|
||||||
self._log(f"[剪枝-数据集] 剩余对话数={len(result)}")
|
# 补全统计对象
|
||||||
|
stats["total_deleted_messages"] = total_deleted_msgs
|
||||||
|
stats["remaining_dialogs"] = len(result)
|
||||||
|
|
||||||
# 保存日志
|
self._log(f"[剪枝-数据集] 剩余对话数={len(result)}")
|
||||||
|
self._log(f"[剪枝-数据集] 相关对话数={stats['related_count']} 不相关对话数={stats['unrelated_count']}")
|
||||||
|
self._log(f"[剪枝-数据集] 总删除 {total_deleted_msgs} 条")
|
||||||
|
|
||||||
|
# 直接序列化统计对象,无需正则解析
|
||||||
try:
|
try:
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
settings.ensure_memory_output_dir()
|
settings.ensure_memory_output_dir()
|
||||||
log_output_path = settings.get_memory_output_path("pruned_terminal.json")
|
log_output_path = settings.get_memory_output_path("pruned_terminal.json")
|
||||||
sanitized_logs = [self._sanitize_log_line(l) for l in self.run_logs]
|
|
||||||
payload = self._parse_logs_to_structured(sanitized_logs)
|
|
||||||
with open(log_output_path, "w", encoding="utf-8") as f:
|
with open(log_output_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(payload, f, ensure_ascii=False, indent=2)
|
json.dump(stats, f, ensure_ascii=False, indent=2)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._log(f"[剪枝-数据集] 保存终端输出日志失败:{e}")
|
self._log(f"[剪枝-数据集] 保存终端输出日志失败:{e}")
|
||||||
|
|
||||||
@@ -633,114 +919,4 @@ class SemanticPruner:
|
|||||||
pass
|
pass
|
||||||
print(msg)
|
print(msg)
|
||||||
|
|
||||||
def _sanitize_log_line(self, line: str) -> str:
|
|
||||||
"""移除行首的方括号标签前缀,例如 [剪枝-数据集] 或 [剪枝-对话]。"""
|
|
||||||
try:
|
|
||||||
return re.sub(r"^\[[^\]]+\]\s*", "", line)
|
|
||||||
except Exception:
|
|
||||||
return line
|
|
||||||
|
|
||||||
def _parse_logs_to_structured(self, logs: List[str]) -> dict:
|
|
||||||
"""将已去前缀的日志列表解析为结构化 JSON,便于数据对接。"""
|
|
||||||
summary = {
|
|
||||||
"scene": self.config.pruning_scene,
|
|
||||||
"dialog_total": None,
|
|
||||||
"deletion_ratio": None,
|
|
||||||
"enabled": None,
|
|
||||||
"related_count": None,
|
|
||||||
"unrelated_count": None,
|
|
||||||
"related_indices": [],
|
|
||||||
"unrelated_indices": [],
|
|
||||||
"total_deleted_messages": None,
|
|
||||||
"remaining_dialogs": None,
|
|
||||||
}
|
|
||||||
dialogs = []
|
|
||||||
|
|
||||||
# 解析函数
|
|
||||||
def parse_int(value: str) -> Optional[int]:
|
|
||||||
try:
|
|
||||||
return int(value)
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def parse_float(value: str) -> Optional[float]:
|
|
||||||
try:
|
|
||||||
return float(value)
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
def parse_indices(s: str) -> List[int]:
|
|
||||||
s = s.strip()
|
|
||||||
if not s:
|
|
||||||
return []
|
|
||||||
parts = [p.strip() for p in s.split(",") if p.strip()]
|
|
||||||
out: List[int] = []
|
|
||||||
for p in parts:
|
|
||||||
try:
|
|
||||||
out.append(int(p))
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return out
|
|
||||||
|
|
||||||
# 正则
|
|
||||||
re_header = re.compile(r"对话总数=(\d+)\s+场景=([^\s]+)\s+删除比例=([0-9.]+)\s+开关=(True|False)")
|
|
||||||
re_counts = re.compile(r"相关对话数=(\d+)\s+不相关对话数=(\d+)")
|
|
||||||
re_indices = re.compile(r"相关对话:第\[(.*?)\]段;不相关对话:第\[(.*?)\]段")
|
|
||||||
re_dialog = re.compile(r"对话\s+(\d+)\s+总消息=(\d+)\s+分配删除=(\d+)\s+实删=(\d+)\s+保留=(\d+)")
|
|
||||||
re_total_del = re.compile(r"总删除\s+(\d+)\s+条")
|
|
||||||
re_remaining = re.compile(r"剩余对话数=(\d+)")
|
|
||||||
|
|
||||||
for line in logs:
|
|
||||||
# 第一行:总览
|
|
||||||
m = re_header.search(line)
|
|
||||||
if m:
|
|
||||||
summary["dialog_total"] = parse_int(m.group(1))
|
|
||||||
# 顶层 scene 依配置,这里不覆盖,但也可校验 m.group(2)
|
|
||||||
summary["deletion_ratio"] = parse_float(m.group(3))
|
|
||||||
summary["enabled"] = True if m.group(4) == "True" else False
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 第二行:相关/不相关数量
|
|
||||||
m = re_counts.search(line)
|
|
||||||
if m:
|
|
||||||
summary["related_count"] = parse_int(m.group(1))
|
|
||||||
summary["unrelated_count"] = parse_int(m.group(2))
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 第三行:相关/不相关索引
|
|
||||||
m = re_indices.search(line)
|
|
||||||
if m:
|
|
||||||
summary["related_indices"] = parse_indices(m.group(1))
|
|
||||||
summary["unrelated_indices"] = parse_indices(m.group(2))
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 对话级统计
|
|
||||||
m = re_dialog.search(line)
|
|
||||||
if m:
|
|
||||||
dialogs.append({
|
|
||||||
"index": parse_int(m.group(1)),
|
|
||||||
"total_messages": parse_int(m.group(2)),
|
|
||||||
"quota_delete": parse_int(m.group(3)),
|
|
||||||
"actual_deleted": parse_int(m.group(4)),
|
|
||||||
"kept": parse_int(m.group(5)),
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 全局删除总数
|
|
||||||
m = re_total_del.search(line)
|
|
||||||
if m:
|
|
||||||
summary["total_deleted_messages"] = parse_int(m.group(1))
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 剩余对话数
|
|
||||||
m = re_remaining.search(line)
|
|
||||||
if m:
|
|
||||||
summary["remaining_dialogs"] = parse_int(m.group(1))
|
|
||||||
continue
|
|
||||||
|
|
||||||
return {
|
|
||||||
"scene": summary["scene"],
|
|
||||||
"timestamp": datetime.now().isoformat(),
|
|
||||||
"summary": {k: v for k, v in summary.items() if k != "scene"},
|
|
||||||
"dialogs": dialogs,
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -384,6 +384,14 @@ class ExtractionOrchestrator:
|
|||||||
|
|
||||||
logger.info(f"陈述句提取完成,共提取 {len(all_statements)} 条陈述句")
|
logger.info(f"陈述句提取完成,共提取 {len(all_statements)} 条陈述句")
|
||||||
|
|
||||||
|
# 试运行模式下,所有分块提取完成后发送完成事件
|
||||||
|
if self.progress_callback and self.is_pilot_run:
|
||||||
|
await self.progress_callback(
|
||||||
|
"knowledge_extraction_complete",
|
||||||
|
f"陈述句提取完成,共提取 {len(all_statements)} 条",
|
||||||
|
{"total_statements": len(all_statements), "total_chunks": total_chunks}
|
||||||
|
)
|
||||||
|
|
||||||
return dialog_data_list
|
return dialog_data_list
|
||||||
|
|
||||||
async def _extract_triplets(
|
async def _extract_triplets(
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
{#
|
{#
|
||||||
对话级抽取与相关性判定模板(用于剪枝加速)
|
对话级抽取与相关性判定模板(用于剪枝加速)
|
||||||
输入:pruning_scene, ontology_classes, dialog_text, language
|
输入:pruning_scene, ontology_class_infos, dialog_text, language
|
||||||
|
- ontology_class_infos: List[{class_name: str, class_description: str}]
|
||||||
输出:严格 JSON(不要包含任何多余文本),字段:
|
输出:严格 JSON(不要包含任何多余文本),字段:
|
||||||
- is_related: bool,是否与所选场景相关
|
- is_related: bool,是否与所选场景相关
|
||||||
- times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等)
|
- times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等)
|
||||||
@@ -18,20 +19,16 @@
|
|||||||
#}
|
#}
|
||||||
|
|
||||||
{# ── 确定场景说明 ── #}
|
{# ── 确定场景说明 ── #}
|
||||||
{% if ontology_classes and ontology_classes | length > 0 %}
|
{% if ontology_class_infos and ontology_class_infos | length > 0 %}
|
||||||
{% if language == 'en' %}
|
{% if language == 'en' %}
|
||||||
{% set custom_types_str = ontology_classes | join(', ') %}
|
{% set instruction = 'Scene "' ~ pruning_scene ~ '": The dialogue is relevant if it involves any of the following entity types.' %}
|
||||||
{% set instruction = 'Scene "' ~ pruning_scene ~ '": The dialogue is related to this scene if it involves any of the following entity types: ' ~ custom_types_str ~ '.' %}
|
|
||||||
{% else %}
|
{% else %}
|
||||||
{% set custom_types_str = ontology_classes | join('、') %}
|
{% set instruction = '场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关。' %}
|
||||||
{% set instruction = '场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关:' ~ custom_types_str ~ '。' %}
|
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% else %}
|
{% else %}
|
||||||
{% if language == 'en' %}
|
{% if language == 'en' %}
|
||||||
{% set custom_types_str = '' %}
|
|
||||||
{% set instruction = 'Scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
|
{% set instruction = 'Scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
|
||||||
{% else %}
|
{% else %}
|
||||||
{% set custom_types_str = '' %}
|
|
||||||
{% set instruction = '场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
|
{% set instruction = '场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
@@ -42,8 +39,17 @@
|
|||||||
2. 从对话中抽取所有需要保留的重要信息片段。
|
2. 从对话中抽取所有需要保留的重要信息片段。
|
||||||
|
|
||||||
场景说明:{{ instruction }}
|
场景说明:{{ instruction }}
|
||||||
{% if custom_types_str %}
|
|
||||||
重要提示:只要对话中出现与上述实体类型({{ custom_types_str }})相关的内容,即判定为相关(is_related=true)。
|
{% if ontology_class_infos and ontology_class_infos | length > 0 %}
|
||||||
|
【本场景实体类型定义】
|
||||||
|
以下实体类型定义了本场景中哪些内容是重要的。
|
||||||
|
凡是与以下任意类型相关的内容,都必须保留,并将关键词/短语提取到 keywords 字段:
|
||||||
|
|
||||||
|
{% for info in ontology_class_infos %}
|
||||||
|
- {{ info.class_name }}:{{ info.class_description }}
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
|
重要提示:只要对话中出现与上述任意实体类型相关的内容,即判定为相关(is_related=true)。
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -51,13 +57,40 @@
|
|||||||
以下类型的内容无论是否与场景直接相关,都必须保留,请将其关键词/短语抽取到对应字段:
|
以下类型的内容无论是否与场景直接相关,都必须保留,请将其关键词/短语抽取到对应字段:
|
||||||
- 时间信息:日期、时间点、时间段、有效期 → times 字段
|
- 时间信息:日期、时间点、时间段、有效期 → times 字段
|
||||||
- 编号信息:学号、工号、订单号、申请号、账号、ID → ids 字段
|
- 编号信息:学号、工号、订单号、申请号、账号、ID → ids 字段
|
||||||
- 金额信息:价格、费用、金额(含货币符号或单位) → amounts 字段
|
- 金额信息:价格、费用、金额(含货币符号或单位,如"100元"、"¥200")→ amounts 字段(注意:考试分数、成绩分数不属于金额,不要放入此字段)
|
||||||
- 联系方式:电话、手机号、邮箱、微信、QQ → contacts 字段
|
- 联系方式:电话、手机号、邮箱、微信、QQ → contacts 字段
|
||||||
- 地址信息:地点、地址、位置 → addresses 字段
|
- 地址信息:地点、地址、位置 → addresses 字段
|
||||||
- 场景关键词:与场景强相关的专业术语、事件名称 → keywords 字段
|
- 场景关键词:与**当前场景**强相关的专业术语、事件名称 → keywords 字段(注意:只放与当前场景直接相关的词,跨场景的内容不要放入此字段)
|
||||||
- **情绪与情感**:喜悦、悲伤、愤怒、焦虑、开心、难过、委屈、兴奋、害怕、担心、压力、感动等情绪表达 → preserve_keywords 字段
|
- **情绪与情感**:喜悦、悲伤、愤怒、焦虑、开心、难过、委屈、兴奋、害怕、担心、压力、感动等情绪表达 → preserve_keywords 字段
|
||||||
- **兴趣与爱好**:喜欢、热爱、爱好、擅长、享受、沉迷、着迷、讨厌某事物等个人偏好表达 → preserve_keywords 字段
|
- **兴趣与爱好**:喜欢、热爱、爱好、擅长、享受、沉迷、着迷、讨厌某事物等个人偏好表达 → preserve_keywords 字段
|
||||||
- **个人观点与态度**:对某事物的明确看法、评价、立场 → preserve_keywords 字段
|
- **个人情感态度**:对人际关系、情感状态的明确表达(如"我跟室友闹矛盾了"、"我都快抑郁了")→ preserve_keywords 字段
|
||||||
|
- 注意:学业目标(如"我想考研")、成绩(如"87分")、学科偏好(如"喜欢数学")属于学业信息,不属于情绪/情感,不要放入 preserve_keywords 字段
|
||||||
|
|
||||||
|
【场景无关内容标记】
|
||||||
|
请从对话中识别出与当前场景({{ pruning_scene }})**既不相关、也无语义关联**的消息片段,将其原文(或关键片段)提取到 scene_unrelated_snippets 字段。
|
||||||
|
判断标准:
|
||||||
|
- 与场景实体类型完全无关
|
||||||
|
- 与场景话题没有因果/时间/情境上的关联(例如:不是"因为上课所以累"这种关联)
|
||||||
|
- 纯粹是另一个话题的内容(如在教育场景中讨论购物、娱乐等)
|
||||||
|
注意:有情绪/感受表达的消息即使话题不同,也可能有语义关联,请谨慎标记。
|
||||||
|
|
||||||
|
**重要:scene_unrelated_snippets 必须认真填写;只要对话中存在与场景无关的内容,就不能返回空数组。**
|
||||||
|
如果对话中存在与场景无关的内容,必须将其原文片段提取出来。
|
||||||
|
|
||||||
|
示例(场景=在线教育):
|
||||||
|
- "我最近心情很差,跟室友闹矛盾了" → 与教育场景无关,加入 scene_unrelated_snippets
|
||||||
|
- "她总是很晚回来吵到我睡觉" → 与教育场景无关,加入 scene_unrelated_snippets
|
||||||
|
- "对,我都快抑郁了" → 与教育场景无关,加入 scene_unrelated_snippets
|
||||||
|
- "期末考试12月25日" → 与教育场景相关,不加入 scene_unrelated_snippets
|
||||||
|
- "我上次高数作业87分" → 与教育场景相关,不加入 scene_unrelated_snippets
|
||||||
|
- "我的目标是考研" → 与教育场景相关,不加入 scene_unrelated_snippets
|
||||||
|
|
||||||
|
示例(场景=情感陪伴):
|
||||||
|
- "我最近心情很差,跟室友闹矛盾了" → 与情感陪伴场景相关(情绪+关系),不加入 scene_unrelated_snippets
|
||||||
|
- "对,我都快抑郁了" → 与情感陪伴场景相关(情绪),不加入 scene_unrelated_snippets
|
||||||
|
- "期末考试12月25日,3号教学楼201室" → 与情感陪伴场景无关(教育信息),加入 scene_unrelated_snippets
|
||||||
|
- "我上次高数作业87分,这次能考好吗" → 与情感陪伴场景无关(学业信息),加入 scene_unrelated_snippets
|
||||||
|
- "我的目标是考研,想读应用数学" → 与情感陪伴场景无关(学业目标),加入 scene_unrelated_snippets
|
||||||
|
|
||||||
【可以删除的内容】
|
【可以删除的内容】
|
||||||
以下类型的内容属于低价值信息,可以在剪枝时删除:
|
以下类型的内容属于低价值信息,可以在剪枝时删除:
|
||||||
@@ -88,7 +121,8 @@
|
|||||||
"contacts": [<string>...],
|
"contacts": [<string>...],
|
||||||
"addresses": [<string>...],
|
"addresses": [<string>...],
|
||||||
"keywords": [<string>...],
|
"keywords": [<string>...],
|
||||||
"preserve_keywords": [<string>...]
|
"preserve_keywords": [<string>...],
|
||||||
|
"scene_unrelated_snippets": [<string>...]
|
||||||
}
|
}
|
||||||
{% else %}
|
{% else %}
|
||||||
You are a dialogue content analysis assistant. Please analyze the full dialogue below in one pass and complete two tasks:
|
You are a dialogue content analysis assistant. Please analyze the full dialogue below in one pass and complete two tasks:
|
||||||
@@ -96,8 +130,17 @@ You are a dialogue content analysis assistant. Please analyze the full dialogue
|
|||||||
2. Extract all important information fragments that must be preserved.
|
2. Extract all important information fragments that must be preserved.
|
||||||
|
|
||||||
Scenario Description: {{ instruction }}
|
Scenario Description: {{ instruction }}
|
||||||
{% if custom_types_str %}
|
|
||||||
Important: If the dialogue contains content related to any of the entity types above ({{ custom_types_str }}), mark it as relevant (is_related=true).
|
{% if ontology_class_infos and ontology_class_infos | length > 0 %}
|
||||||
|
[Scene Entity Type Definitions]
|
||||||
|
The following entity types define what content is important in this scene.
|
||||||
|
Content related to ANY of these types must be preserved and extracted into the keywords field:
|
||||||
|
|
||||||
|
{% for info in ontology_class_infos %}
|
||||||
|
- {{ info.class_name }}: {{ info.class_description }}
|
||||||
|
{% endfor %}
|
||||||
|
|
||||||
|
Important: If the dialogue contains content related to any of the entity types above, mark it as relevant (is_related=true).
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -105,13 +148,22 @@ Important: If the dialogue contains content related to any of the entity types a
|
|||||||
The following types of content must always be preserved regardless of scene relevance. Extract their keywords/phrases into the corresponding fields:
|
The following types of content must always be preserved regardless of scene relevance. Extract their keywords/phrases into the corresponding fields:
|
||||||
- Time information: dates, time points, durations, expiry dates → times field
|
- Time information: dates, time points, durations, expiry dates → times field
|
||||||
- ID information: student IDs, employee IDs, order numbers, application numbers, account IDs → ids field
|
- ID information: student IDs, employee IDs, order numbers, application numbers, account IDs → ids field
|
||||||
- Amount information: prices, fees, amounts (with currency symbols or units) → amounts field
|
- Amount information: prices, fees, amounts (with currency symbols or units, e.g., "$100", "¥200") → amounts field (Note: exam scores and grades are NOT amounts, do not put them here)
|
||||||
- Contact information: phone numbers, emails, WeChat, QQ → contacts field
|
- Contact information: phone numbers, emails, WeChat, QQ → contacts field
|
||||||
- Address information: locations, addresses, places → addresses field
|
- Address information: locations, addresses, places → addresses field
|
||||||
- Scene keywords: professional terms and event names strongly related to the scene → keywords field
|
- Scene keywords: professional terms and event names strongly related to **the current scene** → keywords field (Note: only put terms directly related to the current scene; cross-scene content should not be placed here)
|
||||||
- **Emotions and feelings**: joy, sadness, anger, anxiety, happiness, sadness, excitement, fear, worry, stress, being moved, etc. → preserve_keywords field
|
- **Emotions and feelings**: joy, sadness, anger, anxiety, happiness, feeling upset, feeling wronged, excitement, fear, worry, stress, being moved, etc. → preserve_keywords field
|
||||||
- **Interests and hobbies**: likes, loves, hobbies, good at, enjoys, obsessed with, hates something, personal preferences → preserve_keywords field
|
- **Interests and hobbies**: likes, loves, hobbies, good at, enjoys, obsessed with, hates something, personal preferences → preserve_keywords field
|
||||||
- **Personal opinions and attitudes**: clear views, evaluations, or stances on something → preserve_keywords field
|
- **Personal emotional attitudes**: clear expressions about interpersonal relationships or emotional states (e.g., "I had a fight with my roommate", "I'm almost depressed") → preserve_keywords field
|
||||||
|
- Note: Academic goals (e.g., "I want to pursue a master's degree"), grades (e.g., "87 points"), and subject preferences (e.g., "I like math") are academic information, NOT emotions/feelings — do not put them in preserve_keywords
|
||||||
|
|
||||||
|
[Scene-Unrelated Content Marking]
|
||||||
|
Please identify message snippets in the dialogue that are **neither relevant to nor semantically associated with** the current scene ({{ pruning_scene }}), and extract their original text (or key fragments) into the scene_unrelated_snippets field.
|
||||||
|
Criteria:
|
||||||
|
- Completely unrelated to the scene's entity types
|
||||||
|
- No causal/temporal/contextual association with the scene topic (e.g., "feeling tired because of class" IS associated)
|
||||||
|
- Purely belongs to a different topic (e.g., discussing shopping or entertainment in an education scene)
|
||||||
|
Note: Messages with emotional/feeling expressions may still have semantic association even if the topic differs — mark carefully.
|
||||||
|
|
||||||
[CAN BE DELETED]
|
[CAN BE DELETED]
|
||||||
The following types of content are low-value and can be removed during pruning:
|
The following types of content are low-value and can be removed during pruning:
|
||||||
@@ -141,6 +193,7 @@ Output strict JSON only (fixed keys, order doesn't matter):
|
|||||||
"contacts": [<string>...],
|
"contacts": [<string>...],
|
||||||
"addresses": [<string>...],
|
"addresses": [<string>...],
|
||||||
"keywords": [<string>...],
|
"keywords": [<string>...],
|
||||||
"preserve_keywords": [<string>...]
|
"preserve_keywords": [<string>...],
|
||||||
|
"scene_unrelated_snippets": [<string>...]
|
||||||
}
|
}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|||||||
@@ -94,72 +94,16 @@ def knowledge_retrieval(
|
|||||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=kb_id)
|
db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=kb_id)
|
||||||
if db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1:
|
if db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1:
|
||||||
# Process shared knowledge base
|
# Process shared knowledge base
|
||||||
if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share:
|
rs, chat_model, embedding_model = _retrieve_for_knowledge(
|
||||||
knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db,
|
db=db,
|
||||||
knowledgeshare_id=db_knowledge.id)
|
db_knowledge=db_knowledge,
|
||||||
if knowledgeshare:
|
kb_config={**kb_config, "query": query}, # 或改为单独参数
|
||||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db,
|
file_names_filter=file_names_filter,
|
||||||
knowledge_id=knowledgeshare.source_kb_id)
|
chat_model=chat_model,
|
||||||
if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
|
embedding_model=embedding_model,
|
||||||
continue
|
kb_ids=kb_ids,
|
||||||
else:
|
workspace_ids=workspace_ids,
|
||||||
continue
|
|
||||||
|
|
||||||
if str(db_knowledge.id) not in kb_ids:
|
|
||||||
kb_ids.append(str(db_knowledge.id))
|
|
||||||
if str(db_knowledge.workspace_id) not in workspace_ids:
|
|
||||||
workspace_ids.append(str(db_knowledge.workspace_id))
|
|
||||||
if not chat_model:
|
|
||||||
chat_model = Base(
|
|
||||||
key=db_knowledge.llm.api_keys[0].api_key,
|
|
||||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
|
||||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
|
||||||
)
|
)
|
||||||
if not embedding_model:
|
|
||||||
embedding_model = OpenAIEmbed(
|
|
||||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
|
||||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
|
||||||
base_url=db_knowledge.embedding.api_keys[0].api_base
|
|
||||||
)
|
|
||||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
|
||||||
# Retrieve according to the configured retrieval type
|
|
||||||
match kb_config["retrieve_type"]:
|
|
||||||
case "participle":
|
|
||||||
rs = vector_service.search_by_full_text(
|
|
||||||
query=query,
|
|
||||||
top_k=kb_config["top_k"],
|
|
||||||
score_threshold=kb_config["similarity_threshold"],
|
|
||||||
file_names_filter=file_names_filter
|
|
||||||
)
|
|
||||||
case "semantic":
|
|
||||||
rs = vector_service.search_by_vector(
|
|
||||||
query=query,
|
|
||||||
top_k=kb_config["top_k"],
|
|
||||||
score_threshold=kb_config["vector_similarity_weight"],
|
|
||||||
file_names_filter=file_names_filter
|
|
||||||
)
|
|
||||||
case _: # hybrid
|
|
||||||
rs1 = vector_service.search_by_vector(
|
|
||||||
query=query,
|
|
||||||
top_k=kb_config["top_k"],
|
|
||||||
score_threshold=kb_config["vector_similarity_weight"],
|
|
||||||
file_names_filter=file_names_filter
|
|
||||||
)
|
|
||||||
rs2 = vector_service.search_by_full_text(
|
|
||||||
query=query,
|
|
||||||
top_k=kb_config["top_k"],
|
|
||||||
score_threshold=kb_config["similarity_threshold"],
|
|
||||||
file_names_filter=file_names_filter
|
|
||||||
)
|
|
||||||
|
|
||||||
# Deduplication of merge results
|
|
||||||
seen_ids = set()
|
|
||||||
unique_rs = []
|
|
||||||
for doc in rs1 + rs2:
|
|
||||||
if doc.metadata["doc_id"] not in seen_ids:
|
|
||||||
seen_ids.add(doc.metadata["doc_id"])
|
|
||||||
unique_rs.append(doc)
|
|
||||||
rs = unique_rs
|
|
||||||
|
|
||||||
all_results.extend(rs)
|
all_results.extend(rs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -199,6 +143,115 @@ def knowledge_retrieval(
|
|||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
def _retrieve_for_knowledge(
|
||||||
|
db: Session,
|
||||||
|
db_knowledge,
|
||||||
|
kb_config: Dict[str, Any],
|
||||||
|
file_names_filter: list[str],
|
||||||
|
chat_model: Base | None,
|
||||||
|
embedding_model: OpenAIEmbed | None,
|
||||||
|
kb_ids: list[str],
|
||||||
|
workspace_ids: list[str],
|
||||||
|
) -> tuple[list[DocumentChunk], Base | None, OpenAIEmbed | None]:
|
||||||
|
"""
|
||||||
|
对单个知识库进行检索。
|
||||||
|
- 处理共享知识库
|
||||||
|
- 如果是 Folder,则递归检索其子知识库
|
||||||
|
- 返回本知识库(含子库)的检索结果和可能更新后的 chat_model/embedding_model
|
||||||
|
"""
|
||||||
|
results: list[DocumentChunk] = []
|
||||||
|
|
||||||
|
# 处理共享知识库
|
||||||
|
if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share:
|
||||||
|
knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db, knowledgeshare_id=db_knowledge.id)
|
||||||
|
if not knowledgeshare:
|
||||||
|
return results, chat_model, embedding_model
|
||||||
|
|
||||||
|
db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=knowledgeshare.source_kb_id)
|
||||||
|
if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
|
||||||
|
return results, chat_model, embedding_model
|
||||||
|
|
||||||
|
# Folder 类型:递归处理子知识库
|
||||||
|
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
|
||||||
|
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
|
||||||
|
for child in children:
|
||||||
|
if not (child and child.chunk_num > 0 and child.status == 1):
|
||||||
|
continue
|
||||||
|
# 递归处理子知识库(子库如果还是 Folder,会继续往下)
|
||||||
|
child_results, chat_model, embedding_model = _retrieve_for_knowledge(
|
||||||
|
db=db,
|
||||||
|
db_knowledge=child,
|
||||||
|
kb_config=kb_config,
|
||||||
|
file_names_filter=file_names_filter,
|
||||||
|
chat_model=chat_model,
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
kb_ids=kb_ids,
|
||||||
|
workspace_ids=workspace_ids,
|
||||||
|
)
|
||||||
|
results.extend(child_results)
|
||||||
|
return results, chat_model, embedding_model
|
||||||
|
|
||||||
|
# 普通知识库,执行一次检索
|
||||||
|
if str(db_knowledge.id) not in kb_ids:
|
||||||
|
kb_ids.append(str(db_knowledge.id))
|
||||||
|
if str(db_knowledge.workspace_id) not in workspace_ids:
|
||||||
|
workspace_ids.append(str(db_knowledge.workspace_id))
|
||||||
|
|
||||||
|
if not chat_model:
|
||||||
|
chat_model = Base(
|
||||||
|
key=db_knowledge.llm.api_keys[0].api_key,
|
||||||
|
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||||
|
base_url=db_knowledge.llm.api_keys[0].api_base,
|
||||||
|
)
|
||||||
|
if not embedding_model:
|
||||||
|
embedding_model = OpenAIEmbed(
|
||||||
|
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||||
|
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||||
|
base_url=db_knowledge.embedding.api_keys[0].api_base,
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||||
|
|
||||||
|
match kb_config["retrieve_type"]:
|
||||||
|
case "participle":
|
||||||
|
rs = vector_service.search_by_full_text(
|
||||||
|
query=kb_config["query"], # 或者直接把 query 作为额外参数传进来
|
||||||
|
top_k=kb_config["top_k"],
|
||||||
|
score_threshold=kb_config["similarity_threshold"],
|
||||||
|
file_names_filter=file_names_filter,
|
||||||
|
)
|
||||||
|
case "semantic":
|
||||||
|
rs = vector_service.search_by_vector(
|
||||||
|
query=kb_config["query"],
|
||||||
|
top_k=kb_config["top_k"],
|
||||||
|
score_threshold=kb_config["vector_similarity_weight"],
|
||||||
|
file_names_filter=file_names_filter,
|
||||||
|
)
|
||||||
|
case _:
|
||||||
|
rs1 = vector_service.search_by_vector(
|
||||||
|
query=kb_config["query"],
|
||||||
|
top_k=kb_config["top_k"],
|
||||||
|
score_threshold=kb_config["vector_similarity_weight"],
|
||||||
|
file_names_filter=file_names_filter,
|
||||||
|
)
|
||||||
|
rs2 = vector_service.search_by_full_text(
|
||||||
|
query=kb_config["query"],
|
||||||
|
top_k=kb_config["top_k"],
|
||||||
|
score_threshold=kb_config["similarity_threshold"],
|
||||||
|
file_names_filter=file_names_filter,
|
||||||
|
)
|
||||||
|
# 合并去重
|
||||||
|
seen_ids = set()
|
||||||
|
unique_rs = []
|
||||||
|
for doc in rs1 + rs2:
|
||||||
|
if doc.metadata["doc_id"] not in seen_ids:
|
||||||
|
seen_ids.add(doc.metadata["doc_id"])
|
||||||
|
unique_rs.append(doc)
|
||||||
|
rs = unique_rs
|
||||||
|
|
||||||
|
results.extend(rs)
|
||||||
|
return results, chat_model, embedding_model
|
||||||
|
|
||||||
|
|
||||||
def rerank(db: Session, reranker_id: uuid, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
|
def rerank(db: Session, reranker_id: uuid, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ file operations across different storage backends.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional
|
from typing import AsyncIterator, Optional
|
||||||
|
|
||||||
|
|
||||||
class StorageBackend(ABC):
|
class StorageBackend(ABC):
|
||||||
@@ -42,6 +42,26 @@ class StorageBackend(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def upload_stream(
|
||||||
|
self,
|
||||||
|
file_key: str,
|
||||||
|
stream: AsyncIterator[bytes],
|
||||||
|
content_type: Optional[str] = None,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Upload a file from an async byte stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_key: Unique identifier for the file.
|
||||||
|
stream: Async iterator yielding bytes chunks.
|
||||||
|
content_type: Optional MIME type of the file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total bytes written.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def download(self, file_key: str) -> bytes:
|
async def download(self, file_key: str) -> bytes:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import aiofiles.os
|
import aiofiles.os
|
||||||
|
from typing import AsyncIterator
|
||||||
|
|
||||||
from app.core.storage.base import StorageBackend
|
from app.core.storage.base import StorageBackend
|
||||||
from app.core.storage_exceptions import (
|
from app.core.storage_exceptions import (
|
||||||
@@ -179,6 +180,36 @@ class LocalStorage(StorageBackend):
|
|||||||
full_path = self._get_full_path(file_key)
|
full_path = self._get_full_path(file_key)
|
||||||
return full_path.exists()
|
return full_path.exists()
|
||||||
|
|
||||||
|
async def upload_stream(
|
||||||
|
self,
|
||||||
|
file_key: str,
|
||||||
|
stream: AsyncIterator[bytes],
|
||||||
|
content_type: Optional[str] = None,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Upload a file from an async byte stream to the local file system.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total bytes written.
|
||||||
|
"""
|
||||||
|
full_path = self._get_full_path(file_key)
|
||||||
|
try:
|
||||||
|
full_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
total = 0
|
||||||
|
async with aiofiles.open(full_path, "wb") as f:
|
||||||
|
async for chunk in stream:
|
||||||
|
await f.write(chunk)
|
||||||
|
total += len(chunk)
|
||||||
|
logger.info(f"File stream uploaded successfully: {file_key}")
|
||||||
|
return total
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to stream upload file {file_key}: {e}")
|
||||||
|
raise StorageUploadError(
|
||||||
|
message=f"Failed to stream upload file: {e}",
|
||||||
|
file_key=file_key,
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
async def get_url(self, file_key: str, expires: int = 3600) -> str:
|
async def get_url(self, file_key: str, expires: int = 3600) -> str:
|
||||||
"""
|
"""
|
||||||
Get an access URL for the file.
|
Get an access URL for the file.
|
||||||
|
|||||||
@@ -5,8 +5,9 @@ This module provides a storage backend that stores files on Aliyun Object
|
|||||||
Storage Service (OSS) using the oss2 SDK.
|
Storage Service (OSS) using the oss2 SDK.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import io
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import AsyncIterator, Optional
|
||||||
|
|
||||||
import oss2
|
import oss2
|
||||||
from oss2.exceptions import NoSuchKey, OssError
|
from oss2.exceptions import NoSuchKey, OssError
|
||||||
@@ -125,10 +126,39 @@ class OSSStorage(StorageBackend):
|
|||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def upload_stream(
|
||||||
|
self,
|
||||||
|
file_key: str,
|
||||||
|
stream: AsyncIterator[bytes],
|
||||||
|
content_type: Optional[str] = None,
|
||||||
|
) -> int:
|
||||||
|
"""Upload from async stream to OSS. Returns total bytes written."""
|
||||||
|
buf = io.BytesIO()
|
||||||
|
try:
|
||||||
|
async for chunk in stream:
|
||||||
|
buf.write(chunk)
|
||||||
|
content = buf.getvalue()
|
||||||
|
headers = {"Content-Type": content_type} if content_type else None
|
||||||
|
self.bucket.put_object(file_key, content, headers=headers)
|
||||||
|
logger.info(f"File stream uploaded to OSS successfully: {file_key}")
|
||||||
|
return len(content)
|
||||||
|
except OssError as e:
|
||||||
|
logger.error(f"OSS error stream uploading file {file_key}: {e}")
|
||||||
|
raise StorageUploadError(
|
||||||
|
message=f"Failed to stream upload file to OSS: {e.message}",
|
||||||
|
file_key=file_key,
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to stream upload file to OSS {file_key}: {e}")
|
||||||
|
raise StorageUploadError(
|
||||||
|
message=f"Failed to stream upload file to OSS: {e}",
|
||||||
|
file_key=file_key,
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
async def download(self, file_key: str) -> bytes:
|
async def download(self, file_key: str) -> bytes:
|
||||||
"""
|
"""
|
||||||
Download a file from OSS.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_key: Unique identifier for the file in the storage system.
|
file_key: Unique identifier for the file in the storage system.
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,9 @@ This module provides a storage backend that stores files on AWS S3
|
|||||||
using the boto3 SDK.
|
using the boto3 SDK.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import io
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import AsyncIterator, Optional
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
from botocore.exceptions import ClientError, NoCredentialsError, BotoCoreError
|
from botocore.exceptions import ClientError, NoCredentialsError, BotoCoreError
|
||||||
@@ -174,6 +175,62 @@ class S3Storage(StorageBackend):
|
|||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def upload_stream(
|
||||||
|
self,
|
||||||
|
file_key: str,
|
||||||
|
stream: AsyncIterator[bytes],
|
||||||
|
content_type: Optional[str] = None,
|
||||||
|
) -> int:
|
||||||
|
"""Upload from async stream to S3 via multipart upload. Returns total bytes written."""
|
||||||
|
extra_args = {"ContentType": content_type} if content_type else {}
|
||||||
|
mpu = self.client.create_multipart_upload(
|
||||||
|
Bucket=self.bucket_name, Key=file_key, **extra_args
|
||||||
|
)
|
||||||
|
upload_id = mpu["UploadId"]
|
||||||
|
parts = []
|
||||||
|
part_number = 1
|
||||||
|
buf = io.BytesIO()
|
||||||
|
total = 0
|
||||||
|
min_part_size = 5 * 1024 * 1024 # S3 最小分片 5MB
|
||||||
|
try:
|
||||||
|
async for chunk in stream:
|
||||||
|
buf.write(chunk)
|
||||||
|
total += len(chunk)
|
||||||
|
if buf.tell() >= min_part_size:
|
||||||
|
buf.seek(0)
|
||||||
|
resp = self.client.upload_part(
|
||||||
|
Bucket=self.bucket_name, Key=file_key,
|
||||||
|
UploadId=upload_id, PartNumber=part_number, Body=buf.read()
|
||||||
|
)
|
||||||
|
parts.append({"PartNumber": part_number, "ETag": resp["ETag"]})
|
||||||
|
part_number += 1
|
||||||
|
buf = io.BytesIO()
|
||||||
|
# 上传剩余数据(最后一片可小于 5MB)
|
||||||
|
remaining = buf.getvalue()
|
||||||
|
if remaining:
|
||||||
|
resp = self.client.upload_part(
|
||||||
|
Bucket=self.bucket_name, Key=file_key,
|
||||||
|
UploadId=upload_id, PartNumber=part_number, Body=remaining
|
||||||
|
)
|
||||||
|
parts.append({"PartNumber": part_number, "ETag": resp["ETag"]})
|
||||||
|
self.client.complete_multipart_upload(
|
||||||
|
Bucket=self.bucket_name, Key=file_key,
|
||||||
|
UploadId=upload_id,
|
||||||
|
MultipartUpload={"Parts": parts}
|
||||||
|
)
|
||||||
|
logger.info(f"File stream uploaded to S3 successfully: {file_key}")
|
||||||
|
return total
|
||||||
|
except Exception as e:
|
||||||
|
self.client.abort_multipart_upload(
|
||||||
|
Bucket=self.bucket_name, Key=file_key, UploadId=upload_id
|
||||||
|
)
|
||||||
|
logger.error(f"Failed to stream upload file to S3 {file_key}: {e}")
|
||||||
|
raise StorageUploadError(
|
||||||
|
message=f"Failed to stream upload file to S3: {e}",
|
||||||
|
file_key=file_key,
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
async def download(self, file_key: str) -> bytes:
|
async def download(self, file_key: str) -> bytes:
|
||||||
"""
|
"""
|
||||||
Download a file from S3.
|
Download a file from S3.
|
||||||
|
|||||||
@@ -195,6 +195,6 @@ class MCPToolManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {
|
return {
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": "连接失败",
|
||||||
"message": "连接失败"
|
"message": str(e)
|
||||||
}
|
}
|
||||||
@@ -23,7 +23,7 @@ class SimpleMCPClient:
|
|||||||
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
|
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
|
||||||
self.server_url = server_url
|
self.server_url = server_url
|
||||||
self.connection_config = connection_config or {}
|
self.connection_config = connection_config or {}
|
||||||
self.timeout = self.connection_config.get("timeout", 30)
|
self.timeout = self.connection_config.get("timeout", 10)
|
||||||
|
|
||||||
# 确定连接类型
|
# 确定连接类型
|
||||||
self.is_websocket = server_url.startswith(("ws://", "wss://"))
|
self.is_websocket = server_url.startswith(("ws://", "wss://"))
|
||||||
|
|||||||
@@ -20,9 +20,21 @@ from app.core.workflow.engine.variable_pool import VariablePool
|
|||||||
from app.core.workflow.nodes import NodeFactory
|
from app.core.workflow.nodes import NodeFactory
|
||||||
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
|
||||||
from app.core.workflow.utils.expression_evaluator import evaluate_condition
|
from app.core.workflow.utils.expression_evaluator import evaluate_condition
|
||||||
|
from app.core.workflow.validator import WorkflowValidator
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Regex to split output into:
|
||||||
|
# - variable placeholders: {{ ... }}
|
||||||
|
# - normal literal text
|
||||||
|
#
|
||||||
|
# Example:
|
||||||
|
# "Hello {{user.name}}!" ->
|
||||||
|
# ["Hello ", "{{user.name}}", "!"]
|
||||||
|
_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{}]+')
|
||||||
|
# Strict variable format: {{ node_id.field_name }}
|
||||||
|
_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*}}')
|
||||||
|
|
||||||
|
|
||||||
class GraphBuilder:
|
class GraphBuilder:
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -37,13 +49,13 @@ class GraphBuilder:
|
|||||||
self.stream = stream
|
self.stream = stream
|
||||||
self.subgraph = subgraph
|
self.subgraph = subgraph
|
||||||
|
|
||||||
self.start_node_id = None
|
self.start_node_id: str | None = None
|
||||||
self.end_node_ids = []
|
|
||||||
self.node_map = {node["id"]: node for node in self.nodes}
|
self.node_map = {node["id"]: node for node in self.nodes}
|
||||||
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
self.end_node_map: dict[str, StreamOutputConfig] = {}
|
||||||
self._find_upstream_branch_node = lru_cache(
|
self._find_upstream_activation_dep = lru_cache(
|
||||||
maxsize=len(self.nodes) * 2
|
maxsize=len(self.nodes) * 2
|
||||||
)(self._find_upstream_branch_node)
|
)(self._find_upstream_activation_dep)
|
||||||
if variable_pool:
|
if variable_pool:
|
||||||
self.variable_pool = variable_pool
|
self.variable_pool = variable_pool
|
||||||
else:
|
else:
|
||||||
@@ -51,10 +63,19 @@ class GraphBuilder:
|
|||||||
|
|
||||||
self.graph = StateGraph(WorkflowState)
|
self.graph = StateGraph(WorkflowState)
|
||||||
self.add_nodes()
|
self.add_nodes()
|
||||||
|
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
|
||||||
|
self.end_nodes = [
|
||||||
|
node
|
||||||
|
for node in self.nodes
|
||||||
|
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
|
||||||
|
]
|
||||||
self.add_edges()
|
self.add_edges()
|
||||||
self._analyze_end_node_output()
|
|
||||||
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
|
||||||
|
|
||||||
|
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
|
||||||
|
self._build_reverse_adj()
|
||||||
|
self._analyze_end_node_output()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def nodes(self) -> list[dict[str, Any]]:
|
def nodes(self) -> list[dict[str, Any]]:
|
||||||
return self.workflow_config.get("nodes", [])
|
return self.workflow_config.get("nodes", [])
|
||||||
@@ -87,60 +108,50 @@ class GraphBuilder:
|
|||||||
result[node[0]].append(node[1])
|
result[node[0]].append(node[1])
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]:
|
def _build_reverse_adj(self):
|
||||||
"""
|
for edge in self.edges:
|
||||||
Recursively find all upstream branch (control) nodes that influence the execution
|
if edge["source"] not in self.reachable_nodes:
|
||||||
of the given target node.
|
continue
|
||||||
|
self._reverse_adj[edge.get("target")].append({
|
||||||
|
"id": edge["source"], "branch": edge.get("label")
|
||||||
|
})
|
||||||
|
|
||||||
This method walks upstream along the workflow graph starting from `target_node`.
|
def _find_upstream_activation_dep(
|
||||||
It distinguishes between:
|
self,
|
||||||
- branch nodes (node types listed in `BRANCH_NODES`)
|
target_node: str
|
||||||
- non-branch nodes (ordinary processing nodes)
|
) -> tuple[tuple[tuple[str, str]], tuple[str]]:
|
||||||
|
"""Find upstream dependencies that affect the activation of a target node.
|
||||||
|
|
||||||
Traversal rules:
|
Walks upstream along the workflow graph from the target node, collecting
|
||||||
1. For each immediate upstream node:
|
two types of dependencies:
|
||||||
- If it is a branch node, it is recorded as an affecting control node.
|
- Branch control nodes: upstream branch nodes (e.g. if-else) whose
|
||||||
- If it is a non-branch node, the traversal continues recursively upstream.
|
routing outcome determines whether the target node executes.
|
||||||
2. If ANY upstream path reaches a START / CYCLE_START node without encountering
|
- Output nodes: upstream END nodes that must complete their output
|
||||||
a branch node, the traversal is considered invalid:
|
before the target node can activate.
|
||||||
- `has_branch` will be False
|
|
||||||
- no branch nodes are returned.
|
|
||||||
3. Only when ALL upstream non-branch paths eventually lead to at least one
|
|
||||||
branch node will `has_branch` be True.
|
|
||||||
|
|
||||||
Special case:
|
The traversal terminates early and returns empty tuples if any upstream
|
||||||
- If `target_node` has no upstream nodes AND its type is START or CYCLE_START,
|
path reaches START/CYCLE_START without encountering a branch or output
|
||||||
it is considered directly reachable from the workflow entry, and therefore
|
node, indicating the target node is directly reachable and should be
|
||||||
has no controlling branch nodes.
|
activated immediately.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_node (str):
|
target_node: The ID of the node whose upstream activation
|
||||||
The identifier of the node whose upstream control branches
|
dependencies are to be resolved.
|
||||||
are to be resolved.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[bool, tuple[tuple[str, str]]]:
|
A tuple of two elements:
|
||||||
- has_branch (bool):
|
- A deduplicated tuple of (branch_node_id, branch_label) pairs
|
||||||
True if every upstream path from `target_node` encounters
|
representing upstream branch control dependencies. Empty if
|
||||||
at least one branch node.
|
any clean path to START exists.
|
||||||
False if any path reaches a start node without a branch.
|
- A deduplicated tuple of upstream output node IDs that must
|
||||||
- branch_nodes (tuple[tuple[str, str]]):
|
complete before this node activates.
|
||||||
A deduplicated tuple of `(branch_node_id, branch_label)` pairs
|
|
||||||
representing all branch nodes that can influence `target_node`.
|
|
||||||
Returns an empty tuple if `has_branch` is False.
|
|
||||||
"""
|
"""
|
||||||
source_nodes = [
|
source_nodes = self._reverse_adj[target_node]
|
||||||
{
|
|
||||||
"id": edge.get("source"),
|
|
||||||
"branch": edge.get("label")
|
|
||||||
}
|
|
||||||
for edge in self.edges
|
|
||||||
if edge.get("target") == target_node
|
|
||||||
]
|
|
||||||
if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
|
if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
|
||||||
return False, tuple()
|
return tuple(), tuple()
|
||||||
|
|
||||||
branch_nodes = []
|
branch_nodes = []
|
||||||
|
output_nodes = []
|
||||||
non_branch_nodes = []
|
non_branch_nodes = []
|
||||||
|
|
||||||
for node_info in source_nodes:
|
for node_info in source_nodes:
|
||||||
@@ -149,19 +160,23 @@ class GraphBuilder:
|
|||||||
(node_info["id"], node_info["branch"])
|
(node_info["id"], node_info["branch"])
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
if self.get_node_type(node_info["id"]) == NodeType.END:
|
||||||
|
output_nodes.append(node_info["id"])
|
||||||
non_branch_nodes.append(node_info["id"])
|
non_branch_nodes.append(node_info["id"])
|
||||||
|
|
||||||
has_branch = True
|
has_branch = True
|
||||||
for node_id in non_branch_nodes:
|
for node_id in non_branch_nodes:
|
||||||
node_has_branch, nodes = self._find_upstream_branch_node(node_id)
|
upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(node_id)
|
||||||
has_branch = has_branch and node_has_branch
|
if not upstream_control_nodes:
|
||||||
if not has_branch:
|
if not upstream_output_nodes and node_id not in output_nodes:
|
||||||
break
|
return tuple(), tuple()
|
||||||
branch_nodes.extend(nodes)
|
|
||||||
if not has_branch:
|
|
||||||
branch_nodes = []
|
branch_nodes = []
|
||||||
|
has_branch = False
|
||||||
|
if has_branch:
|
||||||
|
branch_nodes.extend(upstream_control_nodes)
|
||||||
|
output_nodes.extend(upstream_output_nodes)
|
||||||
|
|
||||||
return has_branch, tuple(set(branch_nodes))
|
return tuple(set(branch_nodes)), tuple(set(output_nodes))
|
||||||
|
|
||||||
def _analyze_end_node_output(self):
|
def _analyze_end_node_output(self):
|
||||||
"""
|
"""
|
||||||
@@ -182,11 +197,10 @@ class GraphBuilder:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Collect all End nodes in the workflow
|
# Collect all End nodes in the workflow
|
||||||
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
|
logger.info(f"[Prefix Analysis] Found {len(self.end_nodes)} End nodes")
|
||||||
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
|
|
||||||
|
|
||||||
# Iterate through each End node to analyze its output
|
# Iterate through each End node to analyze its output
|
||||||
for end_node in end_nodes:
|
for end_node in self.end_nodes:
|
||||||
end_node_id = end_node.get("id")
|
end_node_id = end_node.get("id")
|
||||||
config = end_node.get("config", {})
|
config = end_node.get("config", {})
|
||||||
output = config.get("output")
|
output = config.get("output")
|
||||||
@@ -195,42 +209,33 @@ class GraphBuilder:
|
|||||||
if not output:
|
if not output:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Regex to split output into:
|
|
||||||
# - variable placeholders: {{ ... }}
|
|
||||||
# - normal literal text
|
|
||||||
#
|
|
||||||
# Example:
|
|
||||||
# "Hello {{user.name}}!" ->
|
|
||||||
# ["Hello ", "{{user.name}}", "!"]
|
|
||||||
pattern = r'\{\{.*?\}\}|[^{}]+'
|
|
||||||
|
|
||||||
# Strict variable format: {{ node_id.field_name }}
|
|
||||||
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
|
|
||||||
variable_pattern = re.compile(variable_pattern_string)
|
|
||||||
|
|
||||||
# Split output into ordered segments
|
# Split output into ordered segments
|
||||||
output_template = list(re.findall(pattern, output))
|
output_template = list(_OUTPUT_PATTERN.findall(output))
|
||||||
|
|
||||||
# Determine whether each segment is literal text
|
# Determine whether each segment is literal text
|
||||||
# True -> literal (can be directly output)
|
# True -> literal (can be directly output)
|
||||||
# False -> variable placeholder (needs runtime value)
|
# False -> variable placeholder (needs runtime value)
|
||||||
output_flag = [
|
output_flag = [
|
||||||
not bool(variable_pattern.match(item))
|
not bool(_VARIABLE_PATTERN.match(item))
|
||||||
for item in output_template
|
for item in output_template
|
||||||
]
|
]
|
||||||
|
|
||||||
# Stream mode: output activation depends on upstream branch nodes
|
# Stream mode: output activation depends on upstream branch nodes
|
||||||
if self.stream:
|
if self.stream:
|
||||||
# Find upstream branch nodes that can control this End node
|
# Find upstream branch nodes that can control this End node
|
||||||
has_branch, control_nodes = self._find_upstream_branch_node(end_node_id)
|
upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(end_node_id)
|
||||||
|
activate = not bool(upstream_control_nodes) and not bool(upstream_output_nodes)
|
||||||
# Build StreamOutputConfig for this End node
|
# Build StreamOutputConfig for this End node
|
||||||
self.end_node_map[end_node_id] = StreamOutputConfig(
|
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||||
|
id=end_node_id,
|
||||||
# If there is no upstream branch, output is active immediately
|
# If there is no upstream branch, output is active immediately
|
||||||
activate=not has_branch,
|
activate=activate,
|
||||||
|
|
||||||
# Branch nodes that control activation of this End node
|
# Branch nodes that control activation of this End node
|
||||||
control_nodes=self._merge_control_nodes(control_nodes),
|
control_nodes=self._merge_control_nodes(upstream_control_nodes),
|
||||||
|
upstream_output_nodes=list(upstream_output_nodes),
|
||||||
|
control_resolved=not bool(upstream_control_nodes),
|
||||||
|
output_resolved=not bool(upstream_output_nodes),
|
||||||
|
|
||||||
# Convert output segments into OutputContent objects
|
# Convert output segments into OutputContent objects
|
||||||
outputs=list(
|
outputs=list(
|
||||||
@@ -249,14 +254,16 @@ class GraphBuilder:
|
|||||||
cursor=0
|
cursor=0
|
||||||
)
|
)
|
||||||
logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
|
logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
|
||||||
f"activate: {not has_branch}, "
|
f"activate: {activate}, "
|
||||||
f"control_nodes: {control_nodes},"
|
f"control_nodes: {upstream_control_nodes},"
|
||||||
|
f"ref_outputs: {upstream_output_nodes},"
|
||||||
f"output: {output_template},"
|
f"output: {output_template},"
|
||||||
f"output_activate: {output_flag}")
|
f"output_activate: {output_flag}")
|
||||||
|
|
||||||
# Non-stream mode: all outputs are activated by default
|
# Non-stream mode: all outputs are activated by default
|
||||||
else:
|
else:
|
||||||
self.end_node_map[end_node_id] = StreamOutputConfig(
|
self.end_node_map[end_node_id] = StreamOutputConfig(
|
||||||
|
id=end_node_id,
|
||||||
activate=True,
|
activate=True,
|
||||||
control_nodes={},
|
control_nodes={},
|
||||||
outputs=list(
|
outputs=list(
|
||||||
@@ -269,7 +276,10 @@ class GraphBuilder:
|
|||||||
for output_string, activate in zip(output_template, output_flag)
|
for output_string, activate in zip(output_template, output_flag)
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
cursor=0
|
cursor=0,
|
||||||
|
upstream_output_nodes=[],
|
||||||
|
control_resolved=True,
|
||||||
|
output_resolved=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_nodes(self):
|
def add_nodes(self):
|
||||||
@@ -304,8 +314,6 @@ class GraphBuilder:
|
|||||||
# Record start and end node IDs
|
# Record start and end node IDs
|
||||||
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
if node_type in [NodeType.START, NodeType.CYCLE_START]:
|
||||||
self.start_node_id = node_id
|
self.start_node_id = node_id
|
||||||
elif node_type == NodeType.END:
|
|
||||||
self.end_node_ids.append(node_id)
|
|
||||||
|
|
||||||
# Create node instance (start and end nodes are also created)
|
# Create node instance (start and end nodes are also created)
|
||||||
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
|
# NOTE:Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
|
||||||
@@ -448,7 +456,7 @@ class GraphBuilder:
|
|||||||
branch_activate = []
|
branch_activate = []
|
||||||
new_state = state.copy()
|
new_state = state.copy()
|
||||||
new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate
|
new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate
|
||||||
node_output = variable_pool.get_node_output(src, defalut=dict(), strict=False)
|
node_output = variable_pool.get_node_output(src, default=dict(), strict=False)
|
||||||
for label, branch in unique_branch.items():
|
for label, branch in unique_branch.items():
|
||||||
if node_output and evaluate_condition(
|
if node_output and evaluate_condition(
|
||||||
branch["condition"],
|
branch["condition"],
|
||||||
@@ -494,7 +502,9 @@ class GraphBuilder:
|
|||||||
logger.debug(f"Added waiting edge: {sources} -> {target}")
|
logger.debug(f"Added waiting edge: {sources} -> {target}")
|
||||||
|
|
||||||
# Connect End nodes to the global END node
|
# Connect End nodes to the global END node
|
||||||
for end_node_id in self.end_node_ids:
|
for end_node in self.end_nodes:
|
||||||
|
end_node_id = end_node.get("id")
|
||||||
|
if end_node_id:
|
||||||
self.graph.add_edge(end_node_id, END)
|
self.graph.add_edge(end_node_id, END)
|
||||||
logger.debug(f"Added edge: {end_node_id} -> END")
|
logger.debug(f"Added edge: {end_node_id} -> END")
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ class WorkflowResultBuilder:
|
|||||||
variable_pool: VariablePool,
|
variable_pool: VariablePool,
|
||||||
elapsed_time: float,
|
elapsed_time: float,
|
||||||
final_output: str,
|
final_output: str,
|
||||||
|
success: bool
|
||||||
):
|
):
|
||||||
"""Construct the final standardized output of the workflow execution.
|
"""Construct the final standardized output of the workflow execution.
|
||||||
|
|
||||||
@@ -29,6 +30,7 @@ class WorkflowResultBuilder:
|
|||||||
elapsed_time (float): Total execution time in seconds.
|
elapsed_time (float): Total execution time in seconds.
|
||||||
final_output (Any): The aggregated or final output content of the workflow
|
final_output (Any): The aggregated or final output content of the workflow
|
||||||
(e.g., combined messages from all End nodes).
|
(e.g., combined messages from all End nodes).
|
||||||
|
success (bool): Whether the execution was successful.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
dict: A dictionary containing the final workflow execution result with keys:
|
dict: A dictionary containing the final workflow execution result with keys:
|
||||||
@@ -49,7 +51,7 @@ class WorkflowResultBuilder:
|
|||||||
conversation_id = variable_pool.get_value("sys.conversation_id")
|
conversation_id = variable_pool.get_value("sys.conversation_id")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "completed",
|
"status": "completed" if success else "failed",
|
||||||
"output": final_output,
|
"output": final_output,
|
||||||
"variables": {
|
"variables": {
|
||||||
"conv": variable_pool.get_all_conversation_vars(),
|
"conv": variable_pool.get_all_conversation_vars(),
|
||||||
|
|||||||
@@ -3,9 +3,10 @@
|
|||||||
# @Email: 1533512157@qq.com
|
# @Email: 1533512157@qq.com
|
||||||
# @Time : 2026/2/9 15:11
|
# @Time : 2026/2/9 15:11
|
||||||
import re
|
import re
|
||||||
|
from queue import Queue
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
|
|
||||||
from app.core.logging_config import get_logger
|
from app.core.logging_config import get_logger
|
||||||
from app.core.workflow.engine.variable_pool import VariablePool
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
@@ -37,8 +38,8 @@ class OutputContent(BaseModel):
|
|||||||
activate: bool = Field(
|
activate: bool = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Whether this output segment is currently active.\n"
|
"Whether this output segment is currently active."
|
||||||
"- True: allowed to be emitted/output\n"
|
"- True: allowed to be emitted/output"
|
||||||
"- False: blocked until activated by branch control"
|
"- False: blocked until activated by branch control"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -46,16 +47,17 @@ class OutputContent(BaseModel):
|
|||||||
is_variable: bool = Field(
|
is_variable: bool = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Whether this segment represents a variable placeholder.\n"
|
"Whether this segment represents a variable placeholder."
|
||||||
"True -> variable (e.g. {{ node.field }})\n"
|
"True -> variable (e.g. {{ node.field }})"
|
||||||
"False -> literal text"
|
"False -> literal text"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
_SCOPE: str | None = None
|
_SCOPE: str | None = PrivateAttr(default=None)
|
||||||
|
|
||||||
def get_scope(self) -> str:
|
def get_scope(self) -> str | None:
|
||||||
self._SCOPE = SCOPE_PATTERN.findall(self.literal)[0]
|
matches = SCOPE_PATTERN.findall(self.literal)
|
||||||
|
self._SCOPE = matches[0] if matches else None
|
||||||
return self._SCOPE
|
return self._SCOPE
|
||||||
|
|
||||||
def depends_on_scope(self, scope: str) -> bool:
|
def depends_on_scope(self, scope: str) -> bool:
|
||||||
@@ -68,6 +70,8 @@ class OutputContent(BaseModel):
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True if this segment references the given scope.
|
bool: True if this segment references the given scope.
|
||||||
"""
|
"""
|
||||||
|
if not self.is_variable:
|
||||||
|
return False
|
||||||
if self._SCOPE:
|
if self._SCOPE:
|
||||||
return self._SCOPE == scope
|
return self._SCOPE == scope
|
||||||
return self.get_scope() == scope
|
return self.get_scope() == scope
|
||||||
@@ -83,12 +87,16 @@ class StreamOutputConfig(BaseModel):
|
|||||||
- which upstream branch/control nodes gate the activation
|
- which upstream branch/control nodes gate the activation
|
||||||
- how each parsed output segment is streamed and activated
|
- how each parsed output segment is streamed and activated
|
||||||
"""
|
"""
|
||||||
|
id: str = Field(
|
||||||
|
...,
|
||||||
|
description="ID of the End node this configuration belongs to."
|
||||||
|
)
|
||||||
|
|
||||||
activate: bool = Field(
|
activate: bool = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Global activation flag for the End node output.\n"
|
"Global activation flag for the End node output."
|
||||||
"When False, output segments should not be emitted even if available.\n"
|
"When False, output segments should not be emitted even if available."
|
||||||
"This flag typically becomes True once required control branch conditions "
|
"This flag typically becomes True once required control branch conditions "
|
||||||
"are satisfied."
|
"are satisfied."
|
||||||
)
|
)
|
||||||
@@ -97,17 +105,46 @@ class StreamOutputConfig(BaseModel):
|
|||||||
control_nodes: dict[str, list[str]] = Field(
|
control_nodes: dict[str, list[str]] = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Control branch conditions for this End node output.\n"
|
"Control branch conditions for this End node output."
|
||||||
"Mapping of `branch_node_id -> expected_branch_label`.\n"
|
"Mapping of `branch_node_id -> expected_branch_label`."
|
||||||
"The End node output becomes globally active when a controlling branch node "
|
"The End node output becomes globally active when a controlling branch node "
|
||||||
"reports a matching completion status."
|
"reports a matching completion status."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
upstream_output_nodes: list[str] = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Upstream output node dependencies (data flow)."
|
||||||
|
"Represents END/output nodes that this output depends on."
|
||||||
|
"These nodes provide data sources required before this output can be activated "
|
||||||
|
"or streamed."
|
||||||
|
"Used to ensure correct ordering and dependency resolution in streaming mode."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
control_resolved: bool = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Whether all upstream branch control dependencies have been satisfied."
|
||||||
|
"True if no upstream branch nodes exist or the required branch "
|
||||||
|
"conditions have been met."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
output_resolved: bool = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Whether all upstream output node dependencies have been completed."
|
||||||
|
"True if no upstream output nodes exist or all upstream output "
|
||||||
|
"nodes have finished their output."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
outputs: list[OutputContent] = Field(
|
outputs: list[OutputContent] = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Ordered list of output segments parsed from the output template.\n"
|
"Ordered list of output segments parsed from the output template."
|
||||||
"Each segment represents either a literal text block or a variable placeholder "
|
"Each segment represents either a literal text block or a variable placeholder "
|
||||||
"that may be activated independently."
|
"that may be activated independently."
|
||||||
)
|
)
|
||||||
@@ -116,49 +153,97 @@ class StreamOutputConfig(BaseModel):
|
|||||||
cursor: int = Field(
|
cursor: int = Field(
|
||||||
...,
|
...,
|
||||||
description=(
|
description=(
|
||||||
"Streaming cursor index.\n"
|
"Streaming cursor index."
|
||||||
"Indicates the next output segment index to be emitted.\n"
|
"Indicates the next output segment index to be emitted."
|
||||||
"Segments with index < cursor are considered already streamed."
|
"Segments with index < cursor are considered already streamed."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
force: bool = Field(
|
||||||
|
default=False,
|
||||||
|
description=(
|
||||||
|
"Force flag for output emission."
|
||||||
|
"When True, all output segments are emitted regardless of activation state."
|
||||||
|
"Triggered when this output node has finished execution."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def update_activate(self, scope: str, status=None):
|
def update_activate(self, scope: str, status=None):
|
||||||
"""
|
"""
|
||||||
Update streaming activation state based on an upstream node or special variable.
|
Update streaming activation state based on upstream events.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
scope (str):
|
scope (str):
|
||||||
Identifier of the completed upstream entity.
|
Identifier of the completed upstream entity.
|
||||||
- If a control branch node, it should match a key in `control_nodes`.
|
- If a control branch node, it should match a key in `control_nodes`.
|
||||||
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
|
- If an upstream output node, it should match an entry in `upstream_output_nodes`.
|
||||||
|
- If a variable placeholder (e.g., "sys.xxx" or "node_id.field"),
|
||||||
|
it may appear in output segments.
|
||||||
|
|
||||||
status (optional):
|
status (optional):
|
||||||
Completion status of the control branch node.
|
Completion status of the control branch node.
|
||||||
Required when `scope` refers to a control node.
|
Required when `scope` refers to a control node.
|
||||||
|
|
||||||
Behavior:
|
Behavior:
|
||||||
1. Control branch nodes:
|
1. Force activation:
|
||||||
- If `scope` matches a key in `control_nodes` and `status` matches the expected
|
- If `self.force` is True, the method returns immediately.
|
||||||
branch label, the End node output becomes globally active (`activate = True`).
|
- If `scope == self.id`, the node marks itself as completed:
|
||||||
|
- `activate = True`
|
||||||
|
- `force = True`
|
||||||
|
This is typically used for final flushing when the node finishes execution.
|
||||||
|
|
||||||
2. Variable output segments:
|
2. Control dependency resolution:
|
||||||
- For each segment that is a variable (`is_variable=True`):
|
- If `scope` matches a key in `control_nodes`:
|
||||||
- If the segment literal references `scope`, mark the segment as active.
|
- `status` must be provided.
|
||||||
- This applies both to regular node variables (e.g., "node_id.field")
|
- If `status` matches expected branch labels, mark control as resolved
|
||||||
and special system variables (e.g., "sys.xxx").
|
(`control_resolved = True`).
|
||||||
|
|
||||||
|
3. Upstream output dependency resolution:
|
||||||
|
- If `scope` is in `upstream_output_nodes`,
|
||||||
|
mark data dependency as resolved (`output_resolved = True`).
|
||||||
|
|
||||||
|
4. Global activation condition:
|
||||||
|
- The node becomes active when BOTH conditions are satisfied:
|
||||||
|
- control_resolved == True
|
||||||
|
- output_resolved == True
|
||||||
|
- Once activated, `activate` remains True.
|
||||||
|
|
||||||
|
5. Variable segment activation:
|
||||||
|
- For each output segment that is a variable (`is_variable=True`):
|
||||||
|
- If the segment depends on the given `scope`,
|
||||||
|
mark the segment as active.
|
||||||
|
- This applies to both node variables (e.g., "node_id.field")
|
||||||
|
and system variables (e.g., "sys.xxx").
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
- This method does not emit output or advance the streaming cursor.
|
- This method does NOT emit output or advance the streaming cursor.
|
||||||
- It only updates activation flags based on upstream events or special variables.
|
- It only updates activation and dependency resolution states.
|
||||||
|
- Activation is driven by both control flow (branch nodes) and
|
||||||
|
data flow (upstream output nodes).
|
||||||
"""
|
"""
|
||||||
|
if self.force:
|
||||||
|
return
|
||||||
|
|
||||||
# Case 1: resolve control branch dependency
|
if scope == self.id:
|
||||||
if scope in self.control_nodes.keys():
|
self.activate = True
|
||||||
|
self.force = True
|
||||||
|
return
|
||||||
|
|
||||||
|
# resolve control branch dependency
|
||||||
|
if scope in self.control_nodes:
|
||||||
if status is None:
|
if status is None:
|
||||||
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
||||||
if status in self.control_nodes[scope]:
|
if status in self.control_nodes[scope]:
|
||||||
self.activate = True
|
self.control_resolved = True
|
||||||
|
|
||||||
# Case 2: activate variable segments related to this node
|
if scope in self.upstream_output_nodes:
|
||||||
|
self.upstream_output_nodes.remove(scope)
|
||||||
|
if not self.upstream_output_nodes:
|
||||||
|
self.output_resolved = True
|
||||||
|
|
||||||
|
self.activate = self.activate or (self.control_resolved and self.output_resolved)
|
||||||
|
|
||||||
|
# activate variable segments related to this node
|
||||||
for i in range(len(self.outputs)):
|
for i in range(len(self.outputs)):
|
||||||
if (
|
if (
|
||||||
self.outputs[i].is_variable
|
self.outputs[i].is_variable
|
||||||
@@ -171,12 +256,17 @@ class StreamOutputCoordinator:
|
|||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
self.end_outputs: dict[str, StreamOutputConfig] = {}
|
||||||
self.activate_end: str | None = None
|
self.activate_end: str | None = None
|
||||||
|
self.output_queue: Queue = Queue()
|
||||||
|
self.processed_outputs = []
|
||||||
|
|
||||||
def initialize_end_outputs(
|
def initialize_end_outputs(
|
||||||
self,
|
self,
|
||||||
end_node_map: dict[str, StreamOutputConfig]
|
end_node_map: dict[str, StreamOutputConfig]
|
||||||
):
|
):
|
||||||
self.end_outputs = end_node_map
|
self.end_outputs = end_node_map
|
||||||
|
self.processed_outputs = []
|
||||||
|
self.activate_end = None
|
||||||
|
self.output_queue = Queue()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def current_activate_end_info(self):
|
def current_activate_end_info(self):
|
||||||
@@ -208,8 +298,11 @@ class StreamOutputCoordinator:
|
|||||||
"""
|
"""
|
||||||
for node in self.end_outputs.keys():
|
for node in self.end_outputs.keys():
|
||||||
self.end_outputs[node].update_activate(scope, status)
|
self.end_outputs[node].update_activate(scope, status)
|
||||||
if self.end_outputs[node].activate and self.activate_end is None:
|
if self.end_outputs[node].activate and node not in self.processed_outputs:
|
||||||
self.activate_end = node
|
self.output_queue.put(node)
|
||||||
|
self.processed_outputs.append(node)
|
||||||
|
if self.activate_end is None and not self.output_queue.empty():
|
||||||
|
self.activate_end = self.output_queue.get_nowait()
|
||||||
|
|
||||||
async def emit_activate_chunk(
|
async def emit_activate_chunk(
|
||||||
self,
|
self,
|
||||||
@@ -253,7 +346,7 @@ class StreamOutputCoordinator:
|
|||||||
final_chunk = ''
|
final_chunk = ''
|
||||||
current_segment = end_info.outputs[end_info.cursor]
|
current_segment = end_info.outputs[end_info.cursor]
|
||||||
|
|
||||||
if not current_segment.activate and not force:
|
if not current_segment.activate and not force and not end_info.force:
|
||||||
# Stop processing until this segment becomes active
|
# Stop processing until this segment becomes active
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -270,7 +363,7 @@ class StreamOutputCoordinator:
|
|||||||
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}")
|
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}")
|
||||||
|
|
||||||
if final_chunk:
|
if final_chunk:
|
||||||
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}")
|
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk_length:{len(final_chunk)}")
|
||||||
yield {
|
yield {
|
||||||
"event": "message",
|
"event": "message",
|
||||||
"data": {
|
"data": {
|
||||||
@@ -282,8 +375,7 @@ class StreamOutputCoordinator:
|
|||||||
end_info.cursor += 1
|
end_info.cursor += 1
|
||||||
|
|
||||||
if end_info.cursor >= len(end_info.outputs):
|
if end_info.cursor >= len(end_info.outputs):
|
||||||
self.end_outputs.pop(self.activate_end)
|
self.pop_current_activate_end()
|
||||||
self.activate_end = None
|
|
||||||
|
|
||||||
async def flush_remaining_chunk(
|
async def flush_remaining_chunk(
|
||||||
self,
|
self,
|
||||||
@@ -322,6 +414,8 @@ class StreamOutputCoordinator:
|
|||||||
async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
|
async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
|
||||||
yield msg_event
|
yield msg_event
|
||||||
|
|
||||||
|
if not self.output_queue.empty():
|
||||||
|
self.activate_end = self.output_queue.get_nowait()
|
||||||
# Move to next active End node if current one is done
|
# Move to next active End node if current one is done
|
||||||
if not self.activate_end and self.end_outputs:
|
if not self.activate_end and self.end_outputs:
|
||||||
self.activate_end = list(self.end_outputs.keys())[0]
|
self.activate_end = list(self.end_outputs.keys())[0]
|
||||||
|
|||||||
@@ -351,12 +351,12 @@ class VariablePool:
|
|||||||
}
|
}
|
||||||
return runtime_vars
|
return runtime_vars
|
||||||
|
|
||||||
def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None:
|
def get_node_output(self, node_id: str, default: Any = None, strict: bool = True) -> dict[str, Any] | None:
|
||||||
"""获取指定节点的输出(运行时变量)
|
"""获取指定节点的输出(运行时变量)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_id: 节点 ID
|
node_id: 节点 ID
|
||||||
defalut: 默认值
|
default: 默认值
|
||||||
strict: 是否严格模式
|
strict: 是否严格模式
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -368,7 +368,7 @@ class VariablePool:
|
|||||||
if strict:
|
if strict:
|
||||||
raise KeyError(f"node {node_id} output not exist")
|
raise KeyError(f"node {node_id} output not exist")
|
||||||
else:
|
else:
|
||||||
return defalut
|
return default
|
||||||
|
|
||||||
def copy(self, pool: 'VariablePool'):
|
def copy(self, pool: 'VariablePool'):
|
||||||
self.variables = deepcopy(pool.variables)
|
self.variables = deepcopy(pool.variables)
|
||||||
|
|||||||
@@ -128,89 +128,100 @@ class WorkflowExecutor:
|
|||||||
- token_usage: aggregated token usage if available
|
- token_usage: aggregated token usage if available
|
||||||
- error: error message if any
|
- error: error message if any
|
||||||
"""
|
"""
|
||||||
logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
|
start = datetime.datetime.now()
|
||||||
|
async for event in self.execute_stream(input_data):
|
||||||
start_time = datetime.datetime.now()
|
if event.get("event") == "workflow_end":
|
||||||
|
return event.get("data")
|
||||||
# Execute the workflow
|
return self.result_builder.build_final_output(
|
||||||
try:
|
{"error": "Workflow execution did not end as expected"},
|
||||||
# Build the workflow graph
|
self.variable_pool,
|
||||||
graph = self.build_graph()
|
(datetime.datetime.now() - start).total_seconds(),
|
||||||
|
"",
|
||||||
# Initialize the variable pool with input data
|
success=False
|
||||||
await self.variable_initializer.initialize(
|
|
||||||
variable_pool=self.variable_pool,
|
|
||||||
input_data=input_data,
|
|
||||||
execution_context=self.execution_context
|
|
||||||
)
|
)
|
||||||
initial_state = self.state_manager.create_initial_state(
|
# logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
|
||||||
workflow_config=self.workflow_config,
|
#
|
||||||
input_data=input_data,
|
# start_time = datetime.datetime.now()
|
||||||
execution_context=self.execution_context,
|
#
|
||||||
start_node_id=self.start_node_id
|
# # Execute the workflow
|
||||||
)
|
# try:
|
||||||
|
# # Build the workflow graph
|
||||||
result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
|
# graph = self.build_graph()
|
||||||
|
#
|
||||||
# Aggregate output from all End nodes
|
# # Initialize the variable pool with input data
|
||||||
full_content = ''
|
# await self.variable_initializer.initialize(
|
||||||
for end_id in self.stream_coordinator.end_outputs.keys():
|
# variable_pool=self.variable_pool,
|
||||||
full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
# input_data=input_data,
|
||||||
|
# execution_context=self.execution_context
|
||||||
# Append messages for user and assistant
|
# )
|
||||||
if input_data.get("files"):
|
# initial_state = self.state_manager.create_initial_state(
|
||||||
result["messages"].extend(
|
# workflow_config=self.workflow_config,
|
||||||
[
|
# input_data=input_data,
|
||||||
{
|
# execution_context=self.execution_context,
|
||||||
"role": "user",
|
# start_node_id=self.start_node_id
|
||||||
"content": input_data.get("message", '')
|
# )
|
||||||
},
|
#
|
||||||
{
|
# result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
|
||||||
"role": "user",
|
#
|
||||||
"content": input_data.get("files")
|
# # Aggregate output from all End nodes
|
||||||
},
|
# full_content = ''
|
||||||
{
|
# for end_id in self.stream_coordinator.end_outputs.keys():
|
||||||
"role": "assistant",
|
# full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
|
||||||
"content": full_content
|
#
|
||||||
}
|
# # Append messages for user and assistant
|
||||||
]
|
# if input_data.get("files"):
|
||||||
)
|
# result["messages"].extend(
|
||||||
else:
|
# [
|
||||||
result["messages"].extend(
|
# {
|
||||||
[
|
# "role": "user",
|
||||||
{
|
# "content": input_data.get("message", '')
|
||||||
"role": "user",
|
# },
|
||||||
"content": input_data.get("message", '')
|
# {
|
||||||
},
|
# "role": "user",
|
||||||
{
|
# "content": input_data.get("files")
|
||||||
"role": "assistant",
|
# },
|
||||||
"content": full_content
|
# {
|
||||||
}
|
# "role": "assistant",
|
||||||
]
|
# "content": full_content
|
||||||
)
|
# }
|
||||||
# Calculate elapsed time
|
# ]
|
||||||
end_time = datetime.datetime.now()
|
# )
|
||||||
elapsed_time = (end_time - start_time).total_seconds()
|
# else:
|
||||||
|
# result["messages"].extend(
|
||||||
logger.info(
|
# [
|
||||||
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
|
# {
|
||||||
|
# "role": "user",
|
||||||
return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
# "content": input_data.get("message", '')
|
||||||
|
# },
|
||||||
except Exception as e:
|
# {
|
||||||
end_time = datetime.datetime.now()
|
# "role": "assistant",
|
||||||
elapsed_time = (end_time - start_time).total_seconds()
|
# "content": full_content
|
||||||
|
# }
|
||||||
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
# ]
|
||||||
exc_info=True)
|
# )
|
||||||
return {
|
# # Calculate elapsed time
|
||||||
"status": "failed",
|
# end_time = datetime.datetime.now()
|
||||||
"error": str(e),
|
# elapsed_time = (end_time - start_time).total_seconds()
|
||||||
"output": None,
|
#
|
||||||
"node_outputs": {},
|
# logger.info(
|
||||||
"elapsed_time": elapsed_time,
|
# f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
|
||||||
"token_usage": None
|
#
|
||||||
}
|
# return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
||||||
|
#
|
||||||
|
# except Exception as e:
|
||||||
|
# end_time = datetime.datetime.now()
|
||||||
|
# elapsed_time = (end_time - start_time).total_seconds()
|
||||||
|
#
|
||||||
|
# logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||||
|
# exc_info=True)
|
||||||
|
# return {
|
||||||
|
# "status": "failed",
|
||||||
|
# "error": str(e),
|
||||||
|
# "output": None,
|
||||||
|
# "node_outputs": {},
|
||||||
|
# "elapsed_time": elapsed_time,
|
||||||
|
# "token_usage": None
|
||||||
|
# }
|
||||||
|
|
||||||
async def execute_stream(
|
async def execute_stream(
|
||||||
self,
|
self,
|
||||||
@@ -248,7 +259,8 @@ class WorkflowExecutor:
|
|||||||
"timestamp": int(start_time.timestamp() * 1000)
|
"timestamp": int(start_time.timestamp() * 1000)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
result = None
|
||||||
|
full_content = ''
|
||||||
try:
|
try:
|
||||||
# Build the workflow graph in streaming mode
|
# Build the workflow graph in streaming mode
|
||||||
graph = self.build_graph(stream=True)
|
graph = self.build_graph(stream=True)
|
||||||
@@ -266,7 +278,6 @@ class WorkflowExecutor:
|
|||||||
start_node_id=self.start_node_id
|
start_node_id=self.start_node_id
|
||||||
)
|
)
|
||||||
|
|
||||||
full_content = ''
|
|
||||||
self.stream_coordinator.update_scope_activation("sys")
|
self.stream_coordinator.update_scope_activation("sys")
|
||||||
|
|
||||||
# Execute the workflow with streaming
|
# Execute the workflow with streaming
|
||||||
@@ -363,7 +374,12 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
yield {
|
yield {
|
||||||
"event": "workflow_end",
|
"event": "workflow_end",
|
||||||
"data": self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
|
"data": self.result_builder.build_final_output(
|
||||||
|
result,
|
||||||
|
self.variable_pool,
|
||||||
|
elapsed_time,
|
||||||
|
full_content,
|
||||||
|
success=True)
|
||||||
}
|
}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -372,16 +388,19 @@ class WorkflowExecutor:
|
|||||||
|
|
||||||
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
|
||||||
exc_info=True)
|
exc_info=True)
|
||||||
|
if result is None:
|
||||||
|
result = {"error": str(e)}
|
||||||
|
else:
|
||||||
|
result["error"] = str(e)
|
||||||
yield {
|
yield {
|
||||||
"event": "workflow_end",
|
"event": "workflow_end",
|
||||||
"data": {
|
"data": self.result_builder.build_final_output(
|
||||||
"execution_id": self.execution_context.execution_id,
|
result,
|
||||||
"status": "failed",
|
self.variable_pool,
|
||||||
"error": str(e),
|
elapsed_time,
|
||||||
"elapsed_time": elapsed_time,
|
full_content,
|
||||||
"timestamp": end_time.isoformat()
|
success=False
|
||||||
}
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -128,7 +128,7 @@ class CodeNode(BaseNode):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported language: {self.typed_config.language}")
|
raise ValueError(f"Unsupported language: {self.typed_config.language}")
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient(timeout=60) as client:
|
||||||
response = await client.post(
|
response = await client.post(
|
||||||
"http://sandbox:8194/v1/sandbox/run",
|
"http://sandbox:8194/v1/sandbox/run",
|
||||||
headers={
|
headers={
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class ConditionDetail(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
right: Any = Field(
|
right: Any = Field(
|
||||||
...,
|
default=None,
|
||||||
description="Right-hand operand of the comparison expression"
|
description="Right-hand operand of the comparison expression"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ class LoopRuntime:
|
|||||||
self.variable_pool.variables["conv"].update(
|
self.variable_pool.variables["conv"].update(
|
||||||
self.child_variable_pool.variables["conv"]
|
self.child_variable_pool.variables["conv"]
|
||||||
)
|
)
|
||||||
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
|
loop_vars = self.child_variable_pool.get_node_output(self.node_id, default={}, strict=False)
|
||||||
loopstate["node_outputs"][self.node_id] = loop_vars
|
loopstate["node_outputs"][self.node_id] = loop_vars
|
||||||
|
|
||||||
def evaluate_conditional(self) -> bool:
|
def evaluate_conditional(self) -> bool:
|
||||||
@@ -261,4 +261,4 @@ class LoopRuntime:
|
|||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
logger.info(f"loop node {self.node_id}: execution completed")
|
logger.info(f"loop node {self.node_id}: execution completed")
|
||||||
return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state}
|
return self.child_variable_pool.get_node_output(self.node_id, default={}, strict=False) | {"__child_state": child_state}
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ class ConditionDetail(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
right: Any = Field(
|
right: Any = Field(
|
||||||
...,
|
default=None,
|
||||||
description="Value to compare with"
|
description="Value to compare with"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -31,13 +31,13 @@ class IfElseNode(BaseNode):
|
|||||||
expressions.append({
|
expressions.append({
|
||||||
"left": self.get_variable(expression.left, variable_pool, strict=False),
|
"left": self.get_variable(expression.left, variable_pool, strict=False),
|
||||||
"right": expression.right
|
"right": expression.right
|
||||||
if expression.input_type == ValueInputType.CONSTANT
|
if expression.input_type == ValueInputType.CONSTANT or expression.right is None
|
||||||
else self.get_variable(expression.right, variable_pool, strict=False),
|
else self.get_variable(expression.right, variable_pool, strict=False),
|
||||||
"operator": expression.operator,
|
"operator": str(expression.operator),
|
||||||
})
|
})
|
||||||
result.append({
|
result.append({
|
||||||
"expressions": expressions,
|
"expressions": expressions,
|
||||||
"logical_operator": case.logical_operator,
|
"logical_operator": str(case.logical_operator),
|
||||||
})
|
})
|
||||||
return {
|
return {
|
||||||
"cases": result
|
"cases": result
|
||||||
|
|||||||
@@ -250,6 +250,8 @@ class ConditionBase(ABC):
|
|||||||
self.type_limit = getattr(self, "type_limit", None)
|
self.type_limit = getattr(self, "type_limit", None)
|
||||||
|
|
||||||
def resolve_right_literal_value(self):
|
def resolve_right_literal_value(self):
|
||||||
|
if self.right_selector is None:
|
||||||
|
return None
|
||||||
if self.input_type == ValueInputType.VARIABLE:
|
if self.input_type == ValueInputType.VARIABLE:
|
||||||
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
pattern = r"\{\{\s*(.*?)\s*\}\}"
|
||||||
right_expression = re.sub(pattern, r"\1", self.right_selector).strip()
|
right_expression = re.sub(pattern, r"\1", self.right_selector).strip()
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ class ToolNode(BaseNode):
|
|||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {
|
return {
|
||||||
"data": VariableType.STRING,
|
"data": VariableType.STRING,
|
||||||
"error_code": VariableType.STRING,
|
|
||||||
"execution_time": VariableType.NUMBER
|
"execution_time": VariableType.NUMBER
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,10 +47,7 @@ class ToolNode(BaseNode):
|
|||||||
|
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
logger.error(f"节点 {self.node_id} 缺少租户ID")
|
logger.error(f"节点 {self.node_id} 缺少租户ID")
|
||||||
return {
|
raise ValueError("缺少租户ID")
|
||||||
"success": False,
|
|
||||||
"data": "缺少租户ID"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 渲染工具参数
|
# 渲染工具参数
|
||||||
rendered_parameters = {}
|
rendered_parameters = {}
|
||||||
@@ -83,13 +79,8 @@ class ToolNode(BaseNode):
|
|||||||
logger.info(f"节点 {self.node_id} 工具执行成功")
|
logger.info(f"节点 {self.node_id} 工具执行成功")
|
||||||
return {
|
return {
|
||||||
"data": result.data if isinstance(result.data, str) else json.dumps(result.data, ensure_ascii=False),
|
"data": result.data if isinstance(result.data, str) else json.dumps(result.data, ensure_ascii=False),
|
||||||
"error_code": "",
|
|
||||||
"execution_time": result.execution_time
|
"execution_time": result.execution_time
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
|
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
|
||||||
return {
|
raise ValueError(f"工具执行失败: {result.error if isinstance(result.error, str) else json.dumps(result.error, ensure_ascii=False)}")
|
||||||
"data": result.error if isinstance(result.error, str) else json.dumps(result.error, ensure_ascii=False),
|
|
||||||
"error_code": result.error_code,
|
|
||||||
"execution_time": result.execution_time
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -170,7 +170,7 @@ class WorkflowValidator:
|
|||||||
# 仅在发布时验证所有节点可达
|
# 仅在发布时验证所有节点可达
|
||||||
# 6. 验证所有节点可达(从 start 节点出发)
|
# 6. 验证所有节点可达(从 start 节点出发)
|
||||||
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
if start_nodes and not errors: # 只有在前面验证通过时才检查可达性
|
||||||
reachable = WorkflowValidator._get_reachable_nodes(
|
reachable = WorkflowValidator.get_reachable_nodes(
|
||||||
start_nodes[0]["id"],
|
start_nodes[0]["id"],
|
||||||
edges
|
edges
|
||||||
)
|
)
|
||||||
@@ -194,7 +194,7 @@ class WorkflowValidator:
|
|||||||
return len(errors) == 0, errors
|
return len(errors) == 0, errors
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
|
def get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
|
||||||
"""获取从 start 节点可达的所有节点
|
"""获取从 start 节点可达的所有节点
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ from enum import StrEnum
|
|||||||
from abc import abstractmethod, ABC
|
from abc import abstractmethod, ABC
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
|
|
||||||
from app.schemas import FileType
|
from app.schemas import FileType
|
||||||
|
|
||||||
@@ -41,10 +41,10 @@ class VariableType(StrEnum):
|
|||||||
"""
|
"""
|
||||||
if isinstance(var, str):
|
if isinstance(var, str):
|
||||||
return cls.STRING
|
return cls.STRING
|
||||||
elif isinstance(var, (int, float)):
|
|
||||||
return cls.NUMBER
|
|
||||||
elif isinstance(var, bool):
|
elif isinstance(var, bool):
|
||||||
return cls.BOOLEAN
|
return cls.BOOLEAN
|
||||||
|
elif isinstance(var, (int, float)):
|
||||||
|
return cls.NUMBER
|
||||||
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
|
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
|
||||||
return cls.FILE
|
return cls.FILE
|
||||||
elif isinstance(var, dict):
|
elif isinstance(var, dict):
|
||||||
@@ -116,7 +116,7 @@ class FileObject(BaseModel):
|
|||||||
content_cache: dict = Field(default_factory=dict)
|
content_cache: dict = Field(default_factory=dict)
|
||||||
is_file: bool
|
is_file: bool
|
||||||
|
|
||||||
_byte_content: bytes | None = None
|
_byte_content: bytes | None = PrivateAttr(default=None)
|
||||||
|
|
||||||
def get_content(self):
|
def get_content(self):
|
||||||
return self._byte_content
|
return self._byte_content
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ T = TypeVar("T", bound=BaseVariable)
|
|||||||
|
|
||||||
|
|
||||||
class StringVariable(BaseVariable):
|
class StringVariable(BaseVariable):
|
||||||
|
value: str
|
||||||
type = 'str'
|
type = 'str'
|
||||||
|
|
||||||
def valid_value(self, value) -> str:
|
def valid_value(self, value) -> str:
|
||||||
@@ -22,6 +23,7 @@ class StringVariable(BaseVariable):
|
|||||||
|
|
||||||
|
|
||||||
class NumberVariable(BaseVariable):
|
class NumberVariable(BaseVariable):
|
||||||
|
value: int | float
|
||||||
type = 'number'
|
type = 'number'
|
||||||
|
|
||||||
def valid_value(self, value) -> int | float:
|
def valid_value(self, value) -> int | float:
|
||||||
@@ -34,6 +36,7 @@ class NumberVariable(BaseVariable):
|
|||||||
|
|
||||||
|
|
||||||
class BooleanVariable(BaseVariable):
|
class BooleanVariable(BaseVariable):
|
||||||
|
value: bool
|
||||||
type = 'boolean'
|
type = 'boolean'
|
||||||
|
|
||||||
def valid_value(self, value) -> bool:
|
def valid_value(self, value) -> bool:
|
||||||
@@ -46,6 +49,7 @@ class BooleanVariable(BaseVariable):
|
|||||||
|
|
||||||
|
|
||||||
class DictVariable(BaseVariable):
|
class DictVariable(BaseVariable):
|
||||||
|
value: dict
|
||||||
type = 'object'
|
type = 'object'
|
||||||
|
|
||||||
def valid_value(self, value) -> dict:
|
def valid_value(self, value) -> dict:
|
||||||
@@ -58,6 +62,7 @@ class DictVariable(BaseVariable):
|
|||||||
|
|
||||||
|
|
||||||
class FileVariable(BaseVariable):
|
class FileVariable(BaseVariable):
|
||||||
|
value: FileObject
|
||||||
type = 'file'
|
type = 'file'
|
||||||
|
|
||||||
def valid_value(self, value) -> FileObject:
|
def valid_value(self, value) -> FileObject:
|
||||||
@@ -102,6 +107,7 @@ class FileVariable(BaseVariable):
|
|||||||
|
|
||||||
|
|
||||||
class ArrayVariable(BaseVariable, Generic[T]):
|
class ArrayVariable(BaseVariable, Generic[T]):
|
||||||
|
value: list[T]
|
||||||
type = 'array'
|
type = 'array'
|
||||||
|
|
||||||
def __init__(self, child_type: Type[T], value: list[Any]):
|
def __init__(self, child_type: Type[T], value: list[Any]):
|
||||||
@@ -129,6 +135,7 @@ class ArrayVariable(BaseVariable, Generic[T]):
|
|||||||
|
|
||||||
|
|
||||||
class NestedArrayVariable(BaseVariable):
|
class NestedArrayVariable(BaseVariable):
|
||||||
|
value: list[ArrayVariable]
|
||||||
type = 'array_nest'
|
type = 'array_nest'
|
||||||
|
|
||||||
def valid_value(self, value: list[T]) -> list[T]:
|
def valid_value(self, value: list[T]) -> list[T]:
|
||||||
@@ -153,6 +160,7 @@ class NestedArrayVariable(BaseVariable):
|
|||||||
category=RuntimeWarning
|
category=RuntimeWarning
|
||||||
)
|
)
|
||||||
class AnyVariable(BaseVariable):
|
class AnyVariable(BaseVariable):
|
||||||
|
value: Any
|
||||||
type = 'any'
|
type = 'any'
|
||||||
|
|
||||||
def valid_value(self, value: Any) -> Any:
|
def valid_value(self, value: Any) -> Any:
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ engine = create_engine(
|
|||||||
pool_recycle=settings.DB_POOL_RECYCLE,
|
pool_recycle=settings.DB_POOL_RECYCLE,
|
||||||
pool_timeout=settings.DB_POOL_TIMEOUT,
|
pool_timeout=settings.DB_POOL_TIMEOUT,
|
||||||
connect_args={
|
connect_args={
|
||||||
"options": "-c timezone=Asia/Shanghai -c statement_timeout=60000"
|
"options": "-c timezone=UTC -c statement_timeout=60000"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
@@ -65,6 +65,7 @@ def get_db_read() -> Generator[Session, None, None]:
|
|||||||
yield db
|
yield db
|
||||||
finally:
|
finally:
|
||||||
db.rollback() # 只读任务无需 commit
|
db.rollback() # 只读任务无需 commit
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
def get_pool_status():
|
def get_pool_status():
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ class AgentConfig(Base):
|
|||||||
variables = Column(JSON, default=list, nullable=True, comment="变量配置")
|
variables = Column(JSON, default=list, nullable=True, comment="变量配置")
|
||||||
tools = Column(JSON, default=list, nullable=True, comment="工具配置")
|
tools = Column(JSON, default=list, nullable=True, comment="工具配置")
|
||||||
skills = Column(JSON, default=dict, nullable=True, comment="技能配置")
|
skills = Column(JSON, default=dict, nullable=True, comment="技能配置")
|
||||||
|
features = Column(JSON, default=dict, nullable=True, comment="功能特性配置")
|
||||||
|
|
||||||
# 多 Agent 相关字段
|
# 多 Agent 相关字段
|
||||||
agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone")
|
agent_role = Column(String(20), comment="Agent 角色: master|sub|standalone")
|
||||||
|
|||||||
@@ -12,7 +12,8 @@ class EndUser(Base):
|
|||||||
__tablename__ = "end_users"
|
__tablename__ = "end_users"
|
||||||
|
|
||||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, nullable=False, index=True)
|
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, nullable=False, index=True)
|
||||||
app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=False)
|
app_id = Column(UUID(as_uuid=True), ForeignKey("apps.id"), nullable=True)
|
||||||
|
workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id"), nullable=False)
|
||||||
# end_user_id = Column(String, nullable=False, index=True)
|
# end_user_id = Column(String, nullable=False, index=True)
|
||||||
other_id = Column(String, nullable=True) # Store original user_id
|
other_id = Column(String, nullable=True) # Store original user_id
|
||||||
other_name = Column(String, default="", nullable=False)
|
other_name = Column(String, default="", nullable=False)
|
||||||
@@ -62,3 +63,6 @@ class EndUser(Base):
|
|||||||
"App",
|
"App",
|
||||||
back_populates="end_users"
|
back_populates="end_users"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 与 WorkSpace 的反向关系
|
||||||
|
workspace = relationship("Workspace", back_populates="end_users")
|
||||||
@@ -9,7 +9,6 @@ from sqlalchemy.dialects.postgresql import JSONB
|
|||||||
from app.db import Base
|
from app.db import Base
|
||||||
from app.schemas import FileType
|
from app.schemas import FileType
|
||||||
|
|
||||||
|
|
||||||
class PerceptualType(IntEnum):
|
class PerceptualType(IntEnum):
|
||||||
VISION = 1
|
VISION = 1
|
||||||
AUDIO = 2
|
AUDIO = 2
|
||||||
|
|||||||
@@ -111,6 +111,9 @@ class ToolConfig(Base):
|
|||||||
version = Column(String(50), default="1.0.0")
|
version = Column(String(50), default="1.0.0")
|
||||||
tags = Column(JSON, default=list) # 标签列表
|
tags = Column(JSON, default=list) # 标签列表
|
||||||
|
|
||||||
|
# 逻辑删除标志
|
||||||
|
is_active = Column(Boolean, default=True, server_default='true', nullable=False, index=True, comment="是否可用,False表示已删除")
|
||||||
|
|
||||||
# 时间戳
|
# 时间戳
|
||||||
created_at = Column(DateTime, default=datetime.now, nullable=False)
|
created_at = Column(DateTime, default=datetime.now, nullable=False)
|
||||||
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False)
|
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now, nullable=False)
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ class WorkflowConfig(Base):
|
|||||||
|
|
||||||
# 执行配置
|
# 执行配置
|
||||||
execution_config = Column(JSONB, nullable=False, default=dict)
|
execution_config = Column(JSONB, nullable=False, default=dict)
|
||||||
|
features = Column(JSONB, nullable=True, default=dict)
|
||||||
|
|
||||||
# 触发器配置(可选)
|
# 触发器配置(可选)
|
||||||
triggers = Column(JSONB, default=list)
|
triggers = Column(JSONB, default=list)
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ class Workspace(Base):
|
|||||||
members = relationship("WorkspaceMember", back_populates="workspace") # users collaborate through membership
|
members = relationship("WorkspaceMember", back_populates="workspace") # users collaborate through membership
|
||||||
api_keys = relationship("ApiKey", back_populates="workspace", cascade="all, delete-orphan") # API Keys
|
api_keys = relationship("ApiKey", back_populates="workspace", cascade="all, delete-orphan") # API Keys
|
||||||
memory_increments = relationship("MemoryIncrement", back_populates="workspace")
|
memory_increments = relationship("MemoryIncrement", back_populates="workspace")
|
||||||
|
end_users = relationship("EndUser", back_populates="workspace", cascade="all, delete-orphan")
|
||||||
|
|
||||||
class WorkspaceMember(Base):
|
class WorkspaceMember(Base):
|
||||||
__tablename__ = "workspace_members"
|
__tablename__ = "workspace_members"
|
||||||
|
|||||||
@@ -32,6 +32,21 @@ class EndUserRepository:
|
|||||||
db_logger.error(f"查询应用 {app_id} 下宿主时出错: {str(e)}")
|
db_logger.error(f"查询应用 {app_id} 下宿主时出错: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def get_end_users_by_workspace(self, workspace_id: uuid.UUID) -> List[EndUser]:
|
||||||
|
"""获取指定 workspace 下的所有 end_user"""
|
||||||
|
try:
|
||||||
|
end_users = (
|
||||||
|
self.db.query(EndUser)
|
||||||
|
.filter(EndUser.workspace_id == workspace_id)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(end_users)} 个终端用户")
|
||||||
|
return end_users
|
||||||
|
except Exception as e:
|
||||||
|
self.db.rollback()
|
||||||
|
db_logger.error(f"查询工作空间 {workspace_id} 下终端用户时出错: {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
def get_end_user_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]:
|
def get_end_user_by_id(self, end_user_id: uuid.UUID) -> Optional[EndUser]:
|
||||||
"""根据 end_user_id 查询宿主"""
|
"""根据 end_user_id 查询宿主"""
|
||||||
try:
|
try:
|
||||||
@@ -53,6 +68,7 @@ class EndUserRepository:
|
|||||||
def get_or_create_end_user(
|
def get_or_create_end_user(
|
||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
|
workspace_id: uuid.UUID,
|
||||||
other_id: str,
|
other_id: str,
|
||||||
original_user_id: Optional[str] = None
|
original_user_id: Optional[str] = None
|
||||||
) -> EndUser:
|
) -> EndUser:
|
||||||
@@ -60,6 +76,7 @@ class EndUserRepository:
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
app_id: 应用ID
|
app_id: 应用ID
|
||||||
|
workspace_id: 工作空间ID
|
||||||
other_id: 第三方ID
|
other_id: 第三方ID
|
||||||
original_user_id: 原始用户ID (存储到 other_id)
|
original_user_id: 原始用户ID (存储到 other_id)
|
||||||
"""
|
"""
|
||||||
@@ -68,26 +85,31 @@ class EndUserRepository:
|
|||||||
end_user = (
|
end_user = (
|
||||||
self.db.query(EndUser)
|
self.db.query(EndUser)
|
||||||
.filter(
|
.filter(
|
||||||
EndUser.app_id == app_id,
|
EndUser.workspace_id == workspace_id,
|
||||||
EndUser.other_id == other_id
|
EndUser.other_id == other_id
|
||||||
)
|
)
|
||||||
|
.order_by(EndUser.created_at.asc())
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
if end_user:
|
if end_user:
|
||||||
db_logger.debug(f"找到现有终端用户: 应用ID {app_id}、第三方ID {other_id}")
|
db_logger.debug(f"找到现有终端用户: 应用ID {workspace_id}、第三方ID {other_id}")
|
||||||
|
end_user.app_id=app_id
|
||||||
|
self.db.commit()
|
||||||
|
self.db.refresh(end_user)
|
||||||
return end_user
|
return end_user
|
||||||
|
|
||||||
# 创建新用户
|
# 创建新用户
|
||||||
end_user = EndUser(
|
end_user = EndUser(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
other_id=other_id
|
other_id=other_id
|
||||||
)
|
)
|
||||||
self.db.add(end_user)
|
self.db.add(end_user)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
self.db.refresh(end_user)
|
self.db.refresh(end_user)
|
||||||
|
|
||||||
db_logger.info(f"创建新终端用户: (other_id: {other_id}) for app {app_id}")
|
db_logger.info(f"创建新终端用户: (other_id: {other_id}) for workspace {workspace_id}")
|
||||||
return end_user
|
return end_user
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -314,8 +336,7 @@ class EndUserRepository:
|
|||||||
try:
|
try:
|
||||||
end_users = (
|
end_users = (
|
||||||
self.db.query(EndUser)
|
self.db.query(EndUser)
|
||||||
.join(App, EndUser.app_id == App.id)
|
.filter(EndUser.workspace_id == workspace_id)
|
||||||
.filter(App.workspace_id == workspace_id)
|
|
||||||
.all()
|
.all()
|
||||||
)
|
)
|
||||||
db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(end_users)} 个终端用户")
|
db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(end_users)} 个终端用户")
|
||||||
@@ -402,26 +423,61 @@ class EndUserRepository:
|
|||||||
db_logger.error(f"获取终端用户 {end_user_id} 的 memory_config_id 时出错: {str(e)}")
|
db_logger.error(f"获取终端用户 {end_user_id} 的 memory_config_id 时出错: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def batch_update_memory_config_id(
|
# def batch_update_memory_config_id(
|
||||||
|
# self,
|
||||||
|
# app_id: uuid.UUID,
|
||||||
|
# memory_config_id: uuid.UUID
|
||||||
|
# ) -> int:
|
||||||
|
# """批量更新应用下所有终端用户的 memory_config_id
|
||||||
|
#
|
||||||
|
# Args:
|
||||||
|
# app_id: 应用ID
|
||||||
|
# memory_config_id: 新的记忆配置ID
|
||||||
|
#
|
||||||
|
# Returns:
|
||||||
|
# int: 更新的行数
|
||||||
|
# """
|
||||||
|
# try:
|
||||||
|
# from sqlalchemy import update
|
||||||
|
#
|
||||||
|
# stmt = (
|
||||||
|
# update(EndUser)
|
||||||
|
# .where(EndUser.app_id == app_id)
|
||||||
|
# .values(memory_config_id=memory_config_id)
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# result = self.db.execute(stmt)
|
||||||
|
# self.db.commit()
|
||||||
|
#
|
||||||
|
# updated_count = result.rowcount
|
||||||
|
#
|
||||||
|
# db_logger.info(
|
||||||
|
# f"批量更新终端用户记忆配置: app_id={app_id}, "
|
||||||
|
# f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# return updated_count
|
||||||
|
#
|
||||||
|
# except Exception as e:
|
||||||
|
# self.db.rollback()
|
||||||
|
# db_logger.error(
|
||||||
|
# f"批量更新终端用户记忆配置时出错: app_id={app_id}, "
|
||||||
|
# f"memory_config_id={memory_config_id}, error={str(e)}"
|
||||||
|
# )
|
||||||
|
# raise
|
||||||
|
|
||||||
|
def batch_update_memory_config_id_by_workspace(
|
||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
memory_config_id: uuid.UUID
|
memory_config_id: uuid.UUID
|
||||||
) -> int:
|
) -> int:
|
||||||
"""批量更新应用下所有终端用户的 memory_config_id
|
"""批量更新工作空间下所有终端用户的 memory_config_id"""
|
||||||
|
|
||||||
Args:
|
|
||||||
app_id: 应用ID
|
|
||||||
memory_config_id: 新的记忆配置ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: 更新的行数
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
|
|
||||||
stmt = (
|
stmt = (
|
||||||
update(EndUser)
|
update(EndUser)
|
||||||
.where(EndUser.app_id == app_id)
|
.where(EndUser.workspace_id == workspace_id)
|
||||||
.values(memory_config_id=memory_config_id)
|
.values(memory_config_id=memory_config_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -431,16 +487,15 @@ class EndUserRepository:
|
|||||||
updated_count = result.rowcount
|
updated_count = result.rowcount
|
||||||
|
|
||||||
db_logger.info(
|
db_logger.info(
|
||||||
f"批量更新终端用户记忆配置: app_id={app_id}, "
|
f"批量更新终端用户记忆配置: workspace_id={workspace_id}, "
|
||||||
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return updated_count
|
return updated_count
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.db.rollback()
|
self.db.rollback()
|
||||||
db_logger.error(
|
db_logger.error(
|
||||||
f"批量更新终端用户记忆配置时出错: app_id={app_id}, "
|
f"批量更新终端用户记忆配置时出错: workspace_id={workspace_id}, "
|
||||||
f"memory_config_id={memory_config_id}, error={str(e)}"
|
f"memory_config_id={memory_config_id}, error={str(e)}"
|
||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
@@ -519,10 +574,16 @@ class EndUserRepository:
|
|||||||
)
|
)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def get_end_users_by_app_id(db: Session, app_id: uuid.UUID) -> List[EndUser]:
|
# def get_end_users_by_app_id(db: Session, app_id: uuid.UUID) -> List[EndUser]:
|
||||||
"""根据应用ID查询宿主(返回 EndUser ORM 列表)"""
|
# """根据应用ID查询宿主(返回 EndUser ORM 列表)"""
|
||||||
|
# repo = EndUserRepository(db)
|
||||||
|
# end_users = repo.get_end_users_by_app_id(app_id)
|
||||||
|
# return end_users
|
||||||
|
|
||||||
|
def get_end_users_by_workspace(db: Session, workspace_id: uuid.UUID) -> List[EndUser]:
|
||||||
|
"""根据工作空间ID查询终端用户(返回 EndUser ORM 列表)"""
|
||||||
repo = EndUserRepository(db)
|
repo = EndUserRepository(db)
|
||||||
end_users = repo.get_end_users_by_app_id(app_id)
|
end_users = repo.get_end_users_by_workspace(workspace_id)
|
||||||
return end_users
|
return end_users
|
||||||
|
|
||||||
def get_end_user_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]:
|
def get_end_user_by_id(db: Session, end_user_id: uuid.UUID) -> Optional[EndUser]:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ Implicit Emotions Storage Repository
|
|||||||
事务由调用方控制,仓储层只使用 flush/refresh
|
事务由调用方控制,仓储层只使用 flush/refresh
|
||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
from datetime import date, datetime, timedelta, timezone
|
from datetime import date, datetime, timezone
|
||||||
from typing import Generator, Optional
|
from typing import Generator, Optional
|
||||||
|
|
||||||
|
|
||||||
@@ -177,22 +177,21 @@ class ImplicitEmotionsStorageRepository:
|
|||||||
if raw is None:
|
if raw is None:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
CST = timezone(timedelta(hours=8))
|
|
||||||
last_done = datetime.fromisoformat(raw)
|
last_done = datetime.fromisoformat(raw)
|
||||||
# last_done 写入时已是 CST naive,直接使用,无需转换
|
# last_done 写入时已是 UTC aware(+00:00),确保有 tzinfo
|
||||||
if last_done.tzinfo is not None:
|
if last_done.tzinfo is None:
|
||||||
last_done = last_done.astimezone(CST).replace(tzinfo=None)
|
last_done = last_done.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
if updated_at is None:
|
if updated_at is None:
|
||||||
yield end_user_id
|
yield end_user_id
|
||||||
continue
|
continue
|
||||||
# updated_at 数据库存的是 UTC naive,转为 CST naive 再比较
|
# updated_at 数据库存的是 UTC naive,补上 UTC tzinfo 再比较
|
||||||
if updated_at.tzinfo is None:
|
if updated_at.tzinfo is None:
|
||||||
updated_at_cst = updated_at.replace(tzinfo=timezone.utc).astimezone(CST).replace(tzinfo=None)
|
updated_at_utc = updated_at.replace(tzinfo=timezone.utc)
|
||||||
else:
|
else:
|
||||||
updated_at_cst = updated_at.astimezone(CST).replace(tzinfo=None)
|
updated_at_utc = updated_at.astimezone(timezone.utc)
|
||||||
|
|
||||||
if last_done > updated_at_cst:
|
if last_done > updated_at_utc:
|
||||||
yield end_user_id
|
yield end_user_id
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"解析 last_done 时间戳失败: end_user_id={end_user_id}, raw={raw}, error={e}")
|
logger.warning(f"解析 last_done 时间戳失败: end_user_id={end_user_id}, raw={raw}, error={e}")
|
||||||
|
|||||||
@@ -111,6 +111,20 @@ def get_knowledge_by_id(db: Session, knowledge_id: uuid.UUID) -> Knowledge | Non
|
|||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def get_knowledges_by_parent_id(db: Session, parent_id: uuid.UUID) -> list[Knowledge]:
|
||||||
|
db_logger.debug(f"Query knowledge bases based on parent ID: parent_id={parent_id}")
|
||||||
|
try:
|
||||||
|
knowledges = db.query(Knowledge).filter(Knowledge.parent_id == parent_id).all()
|
||||||
|
if knowledges:
|
||||||
|
db_logger.debug(f"Knowledge bases query successful: count={len(knowledges)} (parent_id: {parent_id})")
|
||||||
|
else:
|
||||||
|
db_logger.debug(f"No knowledge bases found for given parent: parent_id={parent_id}")
|
||||||
|
return knowledges
|
||||||
|
except Exception as e:
|
||||||
|
db_logger.error(f"Failed to query the knowledge bases based on parent ID: parent_id={parent_id} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
def get_knowledge_by_name(db: Session, name: str, workspace_id: uuid.UUID) -> Knowledge | None:
|
def get_knowledge_by_name(db: Session, name: str, workspace_id: uuid.UUID) -> Knowledge | None:
|
||||||
db_logger.debug(f"Query knowledge base based on name and workspace_id: name={name}, workspace_id={workspace_id}")
|
db_logger.debug(f"Query knowledge base based on name and workspace_id: name={name}, workspace_id={workspace_id}")
|
||||||
|
|
||||||
|
|||||||
@@ -13,12 +13,18 @@ from app.repositories.neo4j.cypher_queries import (
|
|||||||
ENTITY_LEAVE_ALL_COMMUNITIES,
|
ENTITY_LEAVE_ALL_COMMUNITIES,
|
||||||
GET_ENTITY_NEIGHBORS,
|
GET_ENTITY_NEIGHBORS,
|
||||||
GET_ALL_ENTITIES_FOR_USER,
|
GET_ALL_ENTITIES_FOR_USER,
|
||||||
|
GET_ENTITY_COUNT_FOR_USER,
|
||||||
|
GET_ALL_ENTITY_IDS_FOR_USER,
|
||||||
|
GET_ENTITIES_PAGE,
|
||||||
GET_COMMUNITY_MEMBERS,
|
GET_COMMUNITY_MEMBERS,
|
||||||
|
GET_COMMUNITY_RELATIONSHIPS,
|
||||||
GET_ALL_COMMUNITY_MEMBERS_BATCH,
|
GET_ALL_COMMUNITY_MEMBERS_BATCH,
|
||||||
GET_ALL_ENTITY_NEIGHBORS_BATCH,
|
GET_ALL_ENTITY_NEIGHBORS_BATCH,
|
||||||
|
GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS,
|
||||||
CHECK_USER_HAS_COMMUNITIES,
|
CHECK_USER_HAS_COMMUNITIES,
|
||||||
UPDATE_COMMUNITY_MEMBER_COUNT,
|
UPDATE_COMMUNITY_MEMBER_COUNT,
|
||||||
UPDATE_COMMUNITY_METADATA,
|
UPDATE_COMMUNITY_METADATA,
|
||||||
|
BATCH_UPDATE_COMMUNITY_METADATA,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -110,10 +116,69 @@ class CommunityRepository:
|
|||||||
logger.error(f"get_all_entities failed: {e}")
|
logger.error(f"get_all_entities failed: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def get_entity_count(self, end_user_id: str) -> int:
|
||||||
|
"""仅返回用户实体总数,不加载实体数据。"""
|
||||||
|
try:
|
||||||
|
result = await self.connector.execute_query(
|
||||||
|
GET_ENTITY_COUNT_FOR_USER,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
)
|
||||||
|
return result[0]["entity_count"] if result else 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"get_entity_count failed: {e}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
async def get_all_entity_ids(self, end_user_id: str) -> List[str]:
|
||||||
|
"""仅返回用户所有实体 ID 列表,不加载 embedding 等大字段。"""
|
||||||
|
try:
|
||||||
|
result = await self.connector.execute_query(
|
||||||
|
GET_ALL_ENTITY_IDS_FOR_USER,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
)
|
||||||
|
return [r["id"] for r in result]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"get_all_entity_ids failed: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_entities_page(
|
||||||
|
self, end_user_id: str, skip: int, limit: int
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""分页拉取实体,用于全量聚类分批处理。"""
|
||||||
|
try:
|
||||||
|
return await self.connector.execute_query(
|
||||||
|
GET_ENTITIES_PAGE,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
skip=skip,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"get_entities_page failed: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_entity_neighbors_for_ids(
|
||||||
|
self, entity_ids: List[str], end_user_id: str
|
||||||
|
) -> Dict[str, List[Dict]]:
|
||||||
|
"""批量拉取指定实体列表的邻居,返回 {entity_id: [neighbors]}。"""
|
||||||
|
try:
|
||||||
|
rows = await self.connector.execute_query(
|
||||||
|
GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS,
|
||||||
|
entity_ids=entity_ids,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
)
|
||||||
|
result: Dict[str, List[Dict]] = {}
|
||||||
|
for row in rows:
|
||||||
|
eid = row["entity_id"]
|
||||||
|
neighbor = {k: v for k, v in row.items() if k != "entity_id"}
|
||||||
|
result.setdefault(eid, []).append(neighbor)
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"get_entity_neighbors_for_ids failed: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
async def get_community_members(
|
async def get_community_members(
|
||||||
self, community_id: str, end_user_id: str
|
self, community_id: str, end_user_id: str
|
||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""查询社区成员列表。"""
|
"""查询社区成员列表(含 example 字段)。"""
|
||||||
try:
|
try:
|
||||||
return await self.connector.execute_query(
|
return await self.connector.execute_query(
|
||||||
GET_COMMUNITY_MEMBERS,
|
GET_COMMUNITY_MEMBERS,
|
||||||
@@ -124,6 +189,20 @@ class CommunityRepository:
|
|||||||
logger.error(f"get_community_members failed: {e}")
|
logger.error(f"get_community_members failed: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
async def get_community_relationships(
|
||||||
|
self, community_id: str, end_user_id: str
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""查询社区内实体间的关系三元组(subject, predicate, object)。"""
|
||||||
|
try:
|
||||||
|
return await self.connector.execute_query(
|
||||||
|
GET_COMMUNITY_RELATIONSHIPS,
|
||||||
|
community_id=community_id,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"get_community_relationships failed: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
async def get_all_community_members_batch(
|
async def get_all_community_members_batch(
|
||||||
self, community_ids: List[str], end_user_id: str
|
self, community_ids: List[str], end_user_id: str
|
||||||
) -> Dict[str, List[Dict]]:
|
) -> Dict[str, List[Dict]]:
|
||||||
@@ -177,8 +256,9 @@ class CommunityRepository:
|
|||||||
name: str,
|
name: str,
|
||||||
summary: str,
|
summary: str,
|
||||||
core_entities: List[str],
|
core_entities: List[str],
|
||||||
|
summary_embedding: Optional[List[float]] = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""更新社区的名称、摘要和核心实体列表。"""
|
"""更新社区的名称、摘要、核心实体列表和摘要向量。"""
|
||||||
try:
|
try:
|
||||||
result = await self.connector.execute_query(
|
result = await self.connector.execute_query(
|
||||||
UPDATE_COMMUNITY_METADATA,
|
UPDATE_COMMUNITY_METADATA,
|
||||||
@@ -187,8 +267,31 @@ class CommunityRepository:
|
|||||||
name=name,
|
name=name,
|
||||||
summary=summary,
|
summary=summary,
|
||||||
core_entities=core_entities,
|
core_entities=core_entities,
|
||||||
|
summary_embedding=summary_embedding,
|
||||||
)
|
)
|
||||||
return bool(result)
|
return bool(result)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"update_community_metadata failed: {e}")
|
logger.error(f"update_community_metadata failed: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def batch_update_community_metadata(
|
||||||
|
self,
|
||||||
|
communities: List[Dict],
|
||||||
|
) -> bool:
|
||||||
|
"""批量更新多个社区的元数据。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
communities: 每项包含 community_id, end_user_id, name, summary,
|
||||||
|
core_entities, summary_embedding
|
||||||
|
"""
|
||||||
|
if not communities:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
await self.connector.execute_query(
|
||||||
|
BATCH_UPDATE_COMMUNITY_METADATA,
|
||||||
|
communities=communities,
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"batch_update_community_metadata failed: {e}")
|
||||||
|
return False
|
||||||
|
|||||||
@@ -43,6 +43,13 @@ async def create_fulltext_indexes():
|
|||||||
""")
|
""")
|
||||||
print("✓ Created: summariesFulltext")
|
print("✓ Created: summariesFulltext")
|
||||||
|
|
||||||
|
# 创建 Community 索引
|
||||||
|
await connector.execute_query("""
|
||||||
|
CREATE FULLTEXT INDEX communitiesFulltext IF NOT EXISTS FOR (c:Community) ON EACH [c.name, c.summary]
|
||||||
|
OPTIONS { indexConfig: { `fulltext.analyzer`: 'cjk' } }
|
||||||
|
""")
|
||||||
|
print("✓ Created: communitiesFulltext")
|
||||||
|
|
||||||
print("\nFull-text indexes created successfully with BM25 support.")
|
print("\nFull-text indexes created successfully with BM25 support.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"✗ Error creating full-text indexes: {e}")
|
print(f"✗ Error creating full-text indexes: {e}")
|
||||||
@@ -113,6 +120,18 @@ async def create_vector_indexes():
|
|||||||
""")
|
""")
|
||||||
print("✓ Created: summary_embedding_index")
|
print("✓ Created: summary_embedding_index")
|
||||||
|
|
||||||
|
# Community summary embedding index
|
||||||
|
await connector.execute_query("""
|
||||||
|
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
|
||||||
|
FOR (c:Community)
|
||||||
|
ON c.summary_embedding
|
||||||
|
OPTIONS {indexConfig: {
|
||||||
|
`vector.dimensions`: 1024,
|
||||||
|
`vector.similarity_function`: 'cosine'
|
||||||
|
}}
|
||||||
|
""")
|
||||||
|
print("✓ Created: community_summary_embedding_index")
|
||||||
|
|
||||||
# Dialogue embedding index (optional)
|
# Dialogue embedding index (optional)
|
||||||
await connector.execute_query("""
|
await connector.execute_query("""
|
||||||
CREATE VECTOR INDEX dialogue_embedding_index IF NOT EXISTS
|
CREATE VECTOR INDEX dialogue_embedding_index IF NOT EXISTS
|
||||||
@@ -125,6 +144,18 @@ async def create_vector_indexes():
|
|||||||
""")
|
""")
|
||||||
print("✓ Created: dialogue_embedding_index")
|
print("✓ Created: dialogue_embedding_index")
|
||||||
|
|
||||||
|
# Community summary embedding index
|
||||||
|
await connector.execute_query("""
|
||||||
|
CREATE VECTOR INDEX community_summary_embedding_index IF NOT EXISTS
|
||||||
|
FOR (c:Community)
|
||||||
|
ON c.summary_embedding
|
||||||
|
OPTIONS {indexConfig: {
|
||||||
|
`vector.dimensions`: 1024,
|
||||||
|
`vector.similarity_function`: 'cosine'
|
||||||
|
}}
|
||||||
|
""")
|
||||||
|
print("✓ Created: community_summary_embedding_index")
|
||||||
|
|
||||||
print("\nVector indexes created successfully!")
|
print("\nVector indexes created successfully!")
|
||||||
print("\nExpected performance improvement:")
|
print("\nExpected performance improvement:")
|
||||||
print(" Before: ~1.4s for embedding search")
|
print(" Before: ~1.4s for embedding search")
|
||||||
|
|||||||
@@ -1122,21 +1122,43 @@ RETURN e.id AS id,
|
|||||||
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
|
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
GET_ENTITY_COUNT_FOR_USER = """
|
||||||
|
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
|
||||||
|
RETURN count(e) AS entity_count
|
||||||
|
"""
|
||||||
|
|
||||||
|
GET_ALL_ENTITY_IDS_FOR_USER = """
|
||||||
|
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
|
||||||
|
RETURN e.id AS id
|
||||||
|
"""
|
||||||
|
|
||||||
GET_COMMUNITY_MEMBERS = """
|
GET_COMMUNITY_MEMBERS = """
|
||||||
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id})
|
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id})
|
||||||
RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type,
|
RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type,
|
||||||
e.importance_score AS importance_score, e.activation_value AS activation_value,
|
e.importance_score AS importance_score, e.activation_value AS activation_value,
|
||||||
e.name_embedding AS name_embedding
|
e.name_embedding AS name_embedding,
|
||||||
|
e.aliases AS aliases, e.description AS description,
|
||||||
|
e.example AS example
|
||||||
ORDER BY coalesce(e.activation_value, 0) DESC
|
ORDER BY coalesce(e.activation_value, 0) DESC
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
GET_COMMUNITY_RELATIONSHIPS = """
|
||||||
|
MATCH (e1:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id})
|
||||||
|
MATCH (e2:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c)
|
||||||
|
MATCH (e1)-[r:EXTRACTED_RELATIONSHIP]->(e2)
|
||||||
|
RETURN e1.name AS subject, r.predicate AS predicate, e2.name AS object
|
||||||
|
ORDER BY e1.name, r.predicate, e2.name
|
||||||
|
LIMIT 20
|
||||||
|
"""
|
||||||
|
|
||||||
GET_ALL_COMMUNITY_MEMBERS_BATCH = """
|
GET_ALL_COMMUNITY_MEMBERS_BATCH = """
|
||||||
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community)
|
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community)
|
||||||
WHERE c.community_id IN $community_ids
|
|
||||||
RETURN c.community_id AS community_id,
|
RETURN c.community_id AS community_id,
|
||||||
e.id AS id,
|
e.id AS id, e.name AS name, e.entity_type AS entity_type,
|
||||||
|
e.importance_score AS importance_score, e.activation_value AS activation_value,
|
||||||
e.name_embedding AS name_embedding,
|
e.name_embedding AS name_embedding,
|
||||||
e.activation_value AS activation_value
|
e.aliases AS aliases, e.description AS description
|
||||||
|
ORDER BY c.community_id, coalesce(e.activation_value, 0) DESC
|
||||||
"""
|
"""
|
||||||
|
|
||||||
CHECK_USER_HAS_COMMUNITIES = """
|
CHECK_USER_HAS_COMMUNITIES = """
|
||||||
@@ -1156,10 +1178,55 @@ MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
|
|||||||
SET c.name = $name,
|
SET c.name = $name,
|
||||||
c.summary = $summary,
|
c.summary = $summary,
|
||||||
c.core_entities = $core_entities,
|
c.core_entities = $core_entities,
|
||||||
|
c.summary_embedding = $summary_embedding,
|
||||||
c.updated_at = datetime()
|
c.updated_at = datetime()
|
||||||
RETURN c.community_id AS community_id
|
RETURN c.community_id AS community_id
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
BATCH_UPDATE_COMMUNITY_METADATA = """
|
||||||
|
UNWIND $communities AS row
|
||||||
|
MATCH (c:Community {community_id: row.community_id, end_user_id: row.end_user_id})
|
||||||
|
SET c.name = row.name,
|
||||||
|
c.summary = row.summary,
|
||||||
|
c.core_entities = row.core_entities,
|
||||||
|
c.summary_embedding = row.summary_embedding,
|
||||||
|
c.updated_at = datetime()
|
||||||
|
RETURN c.community_id AS community_id
|
||||||
|
"""
|
||||||
|
|
||||||
|
GET_ENTITIES_PAGE = """
|
||||||
|
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
|
||||||
|
OPTIONAL MATCH (e)-[:BELONGS_TO_COMMUNITY]->(c:Community)
|
||||||
|
RETURN e.id AS id,
|
||||||
|
e.name AS name,
|
||||||
|
e.name_embedding AS name_embedding,
|
||||||
|
e.activation_value AS activation_value,
|
||||||
|
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
|
||||||
|
ORDER BY e.id
|
||||||
|
SKIP $skip LIMIT $limit
|
||||||
|
"""
|
||||||
|
|
||||||
|
GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS = """
|
||||||
|
// 批量拉取指定实体列表的邻居(用于分批全量聚类)
|
||||||
|
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
|
||||||
|
WHERE e.id IN $entity_ids
|
||||||
|
OPTIONAL MATCH (e)-[:EXTRACTED_RELATIONSHIP]-(nb1:ExtractedEntity {end_user_id: $end_user_id})
|
||||||
|
OPTIONAL MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
|
||||||
|
OPTIONAL MATCH (s)-[:REFERENCES_ENTITY]->(nb2:ExtractedEntity {end_user_id: $end_user_id})
|
||||||
|
WHERE nb2.id <> e.id
|
||||||
|
WITH e, collect(DISTINCT nb1) + collect(DISTINCT nb2) AS all_neighbors
|
||||||
|
UNWIND all_neighbors AS nb
|
||||||
|
WITH e, nb WHERE nb IS NOT NULL
|
||||||
|
OPTIONAL MATCH (nb)-[:BELONGS_TO_COMMUNITY]->(c:Community)
|
||||||
|
RETURN DISTINCT
|
||||||
|
e.id AS entity_id,
|
||||||
|
nb.id AS id,
|
||||||
|
nb.name AS name,
|
||||||
|
nb.name_embedding AS name_embedding,
|
||||||
|
nb.activation_value AS activation_value,
|
||||||
|
CASE WHEN c IS NOT NULL THEN c.community_id ELSE null END AS community_id
|
||||||
|
"""
|
||||||
|
|
||||||
GET_ALL_ENTITY_NEIGHBORS_BATCH = """
|
GET_ALL_ENTITY_NEIGHBORS_BATCH = """
|
||||||
// 批量拉取某用户下所有实体的邻居(用于全量聚类预加载)
|
// 批量拉取某用户下所有实体的邻居(用于全量聚类预加载)
|
||||||
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
|
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})
|
||||||
@@ -1202,3 +1269,60 @@ RETURN
|
|||||||
properties(r) AS r_props,
|
properties(r) AS r_props,
|
||||||
startNode(r) = e AS r_from_e
|
startNode(r) = e AS r_from_e
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# Community keyword search: matches name or summary via fulltext index
|
||||||
|
SEARCH_COMMUNITIES_BY_KEYWORD = """
|
||||||
|
CALL db.index.fulltext.queryNodes("communitiesFulltext", $q) YIELD node AS c, score
|
||||||
|
WHERE ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
|
||||||
|
RETURN c.community_id AS id,
|
||||||
|
c.name AS name,
|
||||||
|
c.summary AS content,
|
||||||
|
c.core_entities AS core_entities,
|
||||||
|
c.member_count AS member_count,
|
||||||
|
c.end_user_id AS end_user_id,
|
||||||
|
c.updated_at AS updated_at,
|
||||||
|
score
|
||||||
|
ORDER BY score DESC
|
||||||
|
LIMIT $limit
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Community 向量检索 ──────────────────────────────────────────────────
|
||||||
|
# Community embedding-based search: cosine similarity on Community.summary_embedding
|
||||||
|
COMMUNITY_EMBEDDING_SEARCH = """
|
||||||
|
CALL db.index.vector.queryNodes('community_summary_embedding_index', $limit * 100, $embedding)
|
||||||
|
YIELD node AS c, score
|
||||||
|
WHERE c.summary_embedding IS NOT NULL
|
||||||
|
AND ($end_user_id IS NULL OR c.end_user_id = $end_user_id)
|
||||||
|
RETURN c.community_id AS id,
|
||||||
|
c.name AS name,
|
||||||
|
c.summary AS content,
|
||||||
|
c.core_entities AS core_entities,
|
||||||
|
c.member_count AS member_count,
|
||||||
|
c.end_user_id AS end_user_id,
|
||||||
|
c.updated_at AS updated_at,
|
||||||
|
score
|
||||||
|
ORDER BY score DESC
|
||||||
|
LIMIT $limit
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Community 展开检索 ──────────────────────────────────────────────────
|
||||||
|
# 命中社区后,拉取该社区所有成员实体关联的 Statement 节点(主题→细节两级检索)
|
||||||
|
EXPAND_COMMUNITY_STATEMENTS = """
|
||||||
|
MATCH (c:Community {community_id: $community_id})
|
||||||
|
MATCH (e:ExtractedEntity)-[:BELONGS_TO_COMMUNITY]->(c)
|
||||||
|
MATCH (s:Statement)-[:REFERENCES_ENTITY]->(e)
|
||||||
|
WHERE s.end_user_id = $end_user_id
|
||||||
|
RETURN s.statement AS statement,
|
||||||
|
s.id AS id,
|
||||||
|
s.end_user_id AS end_user_id,
|
||||||
|
s.created_at AS created_at,
|
||||||
|
s.valid_at AS valid_at,
|
||||||
|
s.invalid_at AS invalid_at,
|
||||||
|
COALESCE(s.activation_value, s.importance_score, 0.5) AS activation_value,
|
||||||
|
COALESCE(s.importance_score, 0.5) AS importance_score,
|
||||||
|
e.name AS source_entity,
|
||||||
|
c.name AS community_name
|
||||||
|
ORDER BY COALESCE(s.activation_value, 0) DESC
|
||||||
|
LIMIT $limit
|
||||||
|
"""
|
||||||
|
|||||||
@@ -158,11 +158,12 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
statement_chunk_edges: List[StatementChunkEdge],
|
statement_chunk_edges: List[StatementChunkEdge],
|
||||||
statement_entity_edges: List[StatementEntityEdge],
|
statement_entity_edges: List[StatementEntityEdge],
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
config_id: Optional[str] = None,
|
|
||||||
llm_model_id: Optional[str] = None,
|
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
|
||||||
|
|
||||||
|
只负责数据写入,不触发聚类。聚类由调用方在写入成功后通过
|
||||||
|
schedule_clustering_after_write() 显式触发。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dialogue_nodes: List of DialogueNode objects to save
|
dialogue_nodes: List of DialogueNode objects to save
|
||||||
chunk_nodes: List of ChunkNode objects to save
|
chunk_nodes: List of ChunkNode objects to save
|
||||||
@@ -293,9 +294,6 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
logger.info("Transaction completed. Summary: %s", summary)
|
logger.info("Transaction completed. Summary: %s", summary)
|
||||||
logger.debug("Full transaction results: %r", results)
|
logger.debug("Full transaction results: %r", results)
|
||||||
|
|
||||||
# 写入成功后,异步触发聚类(不阻塞写入响应)
|
|
||||||
schedule_clustering_after_write(entity_nodes, config_id=config_id, llm_model_id=llm_model_id)
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -307,8 +305,8 @@ async def save_dialog_and_statements_to_neo4j(
|
|||||||
|
|
||||||
def schedule_clustering_after_write(
|
def schedule_clustering_after_write(
|
||||||
entity_nodes: List,
|
entity_nodes: List,
|
||||||
config_id: Optional[str] = None,
|
|
||||||
llm_model_id: Optional[str] = None,
|
llm_model_id: Optional[str] = None,
|
||||||
|
embedding_model_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
写入 Neo4j 成功后,调度后台聚类任务。
|
写入 Neo4j 成功后,调度后台聚类任务。
|
||||||
@@ -327,14 +325,14 @@ def schedule_clustering_after_write(
|
|||||||
end_user_id = entity_nodes[0].end_user_id
|
end_user_id = entity_nodes[0].end_user_id
|
||||||
new_entity_ids = [e.id for e in entity_nodes]
|
new_entity_ids = [e.id for e in entity_nodes]
|
||||||
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
|
||||||
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id))
|
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id))
|
||||||
|
|
||||||
|
|
||||||
async def _trigger_clustering(
|
async def _trigger_clustering(
|
||||||
new_entity_ids: List[str],
|
new_entity_ids: List[str],
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
config_id: Optional[str] = None,
|
|
||||||
llm_model_id: Optional[str] = None,
|
llm_model_id: Optional[str] = None,
|
||||||
|
embedding_model_id: Optional[str] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
聚类触发函数,自动判断全量初始化还是增量更新。
|
聚类触发函数,自动判断全量初始化还是增量更新。
|
||||||
@@ -344,7 +342,7 @@ async def _trigger_clustering(
|
|||||||
from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine
|
from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine
|
||||||
logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}")
|
logger.info(f"[Clustering] 开始聚类,end_user_id={end_user_id}, 实体数={len(new_entity_ids)}")
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id)
|
engine = LabelPropagationEngine(connector, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)
|
||||||
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
|
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
|
||||||
logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}")
|
logger.info(f"[Clustering] 聚类完成,end_user_id={end_user_id}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -4,10 +4,13 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
from app.repositories.neo4j.cypher_queries import (
|
from app.repositories.neo4j.cypher_queries import (
|
||||||
CHUNK_EMBEDDING_SEARCH,
|
CHUNK_EMBEDDING_SEARCH,
|
||||||
|
COMMUNITY_EMBEDDING_SEARCH,
|
||||||
ENTITY_EMBEDDING_SEARCH,
|
ENTITY_EMBEDDING_SEARCH,
|
||||||
|
EXPAND_COMMUNITY_STATEMENTS,
|
||||||
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
MEMORY_SUMMARY_EMBEDDING_SEARCH,
|
||||||
SEARCH_CHUNK_BY_CHUNK_ID,
|
SEARCH_CHUNK_BY_CHUNK_ID,
|
||||||
SEARCH_CHUNKS_BY_CONTENT,
|
SEARCH_CHUNKS_BY_CONTENT,
|
||||||
|
SEARCH_COMMUNITIES_BY_KEYWORD,
|
||||||
SEARCH_DIALOGUE_BY_DIALOG_ID,
|
SEARCH_DIALOGUE_BY_DIALOG_ID,
|
||||||
SEARCH_ENTITIES_BY_NAME,
|
SEARCH_ENTITIES_BY_NAME,
|
||||||
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
SEARCH_MEMORY_SUMMARIES_BY_KEYWORD,
|
||||||
@@ -286,6 +289,15 @@ async def search_graph(
|
|||||||
))
|
))
|
||||||
task_keys.append("summaries")
|
task_keys.append("summaries")
|
||||||
|
|
||||||
|
if "communities" in include:
|
||||||
|
tasks.append(connector.execute_query(
|
||||||
|
SEARCH_COMMUNITIES_BY_KEYWORD,
|
||||||
|
q=q,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
limit=limit,
|
||||||
|
))
|
||||||
|
task_keys.append("communities")
|
||||||
|
|
||||||
# Execute all queries in parallel
|
# Execute all queries in parallel
|
||||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
@@ -293,6 +305,7 @@ async def search_graph(
|
|||||||
results = {}
|
results = {}
|
||||||
for key, result in zip(task_keys, task_results):
|
for key, result in zip(task_keys, task_results):
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
|
logger.warning(f"search_graph: {key} 关键词查询异常: {result}")
|
||||||
results[key] = []
|
results[key] = []
|
||||||
else:
|
else:
|
||||||
results[key] = result
|
results[key] = result
|
||||||
@@ -349,7 +362,11 @@ async def search_graph_by_embedding(
|
|||||||
print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
|
print(f"[PERF] Embedding generation took: {embed_time:.4f}s")
|
||||||
|
|
||||||
if not embeddings or not embeddings[0]:
|
if not embeddings or not embeddings[0]:
|
||||||
return {"statements": [], "chunks": [], "entities": [], "summaries": []}
|
logger.warning(
|
||||||
|
f"search_graph_by_embedding: embedding 生成失败或为空,"
|
||||||
|
f"query='{query_text[:50]}', end_user_id={end_user_id},向量检索跳过"
|
||||||
|
)
|
||||||
|
return {"statements": [], "chunks": [], "entities": [], "summaries": [], "communities": []}
|
||||||
embedding = embeddings[0]
|
embedding = embeddings[0]
|
||||||
|
|
||||||
# Prepare tasks for parallel execution
|
# Prepare tasks for parallel execution
|
||||||
@@ -396,6 +413,16 @@ async def search_graph_by_embedding(
|
|||||||
))
|
))
|
||||||
task_keys.append("summaries")
|
task_keys.append("summaries")
|
||||||
|
|
||||||
|
# Communities (向量语义匹配)
|
||||||
|
if "communities" in include:
|
||||||
|
tasks.append(connector.execute_query(
|
||||||
|
COMMUNITY_EMBEDDING_SEARCH,
|
||||||
|
embedding=embedding,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
limit=limit,
|
||||||
|
))
|
||||||
|
task_keys.append("communities")
|
||||||
|
|
||||||
# Execute all queries in parallel
|
# Execute all queries in parallel
|
||||||
query_start = time.time()
|
query_start = time.time()
|
||||||
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
@@ -408,10 +435,12 @@ async def search_graph_by_embedding(
|
|||||||
"chunks": [],
|
"chunks": [],
|
||||||
"entities": [],
|
"entities": [],
|
||||||
"summaries": [],
|
"summaries": [],
|
||||||
|
"communities": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, result in zip(task_keys, task_results):
|
for key, result in zip(task_keys, task_results):
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
|
logger.warning(f"search_graph_by_embedding: {key} 向量查询异常: {result}")
|
||||||
results[key] = []
|
results[key] = []
|
||||||
else:
|
else:
|
||||||
results[key] = result
|
results[key] = result
|
||||||
@@ -661,6 +690,62 @@ async def search_graph_by_chunk_id(
|
|||||||
return {"chunks": chunks}
|
return {"chunks": chunks}
|
||||||
|
|
||||||
|
|
||||||
|
async def search_graph_community_expand(
|
||||||
|
connector: Neo4jConnector,
|
||||||
|
community_ids: List[str],
|
||||||
|
end_user_id: str,
|
||||||
|
limit: int = 10,
|
||||||
|
) -> Dict[str, List[Dict[str, Any]]]:
|
||||||
|
"""
|
||||||
|
三期:社区展开检索 —— 主题 → 细节两级检索。
|
||||||
|
|
||||||
|
命中 Community 节点后,沿 BELONGS_TO_COMMUNITY 关系拉取成员实体,
|
||||||
|
再沿 REFERENCES_ENTITY 关系拉取关联的 Statement 节点,
|
||||||
|
按 activation_value 降序返回,实现"主题摘要 → 具体记忆"的深度召回。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
connector: Neo4j 连接器
|
||||||
|
community_ids: 已命中的社区 ID 列表
|
||||||
|
end_user_id: 用户 ID,用于数据隔离
|
||||||
|
limit: 每个社区最多返回的 Statement 数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{"expanded_statements": [Statement 列表,含 community_name / source_entity 字段]}
|
||||||
|
"""
|
||||||
|
if not community_ids or not end_user_id:
|
||||||
|
return {"expanded_statements": []}
|
||||||
|
|
||||||
|
tasks = [
|
||||||
|
connector.execute_query(
|
||||||
|
EXPAND_COMMUNITY_STATEMENTS,
|
||||||
|
community_id=cid,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
for cid in community_ids
|
||||||
|
]
|
||||||
|
|
||||||
|
task_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
expanded: List[Dict[str, Any]] = []
|
||||||
|
for cid, result in zip(community_ids, task_results):
|
||||||
|
if isinstance(result, Exception):
|
||||||
|
logger.warning(f"社区展开检索失败 community_id={cid}: {result}")
|
||||||
|
else:
|
||||||
|
expanded.extend(result)
|
||||||
|
|
||||||
|
# 按 activation_value 全局排序后去重
|
||||||
|
from app.core.memory.src.search import _deduplicate_results
|
||||||
|
expanded.sort(
|
||||||
|
key=lambda x: float(x.get("activation_value") or 0),
|
||||||
|
reverse=True,
|
||||||
|
)
|
||||||
|
expanded = _deduplicate_results(expanded)
|
||||||
|
|
||||||
|
logger.info(f"社区展开检索完成: community_ids={community_ids}, 展开 statements={len(expanded)}")
|
||||||
|
return {"expanded_statements": expanded}
|
||||||
|
|
||||||
|
|
||||||
async def search_graph_by_created_at(
|
async def search_graph_by_created_at(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
|
|||||||
@@ -74,9 +74,10 @@ class ToolRepository:
|
|||||||
status: Optional[ToolStatus] = None,
|
status: Optional[ToolStatus] = None,
|
||||||
is_enabled: Optional[bool] = None
|
is_enabled: Optional[bool] = None
|
||||||
) -> List[ToolConfig]:
|
) -> List[ToolConfig]:
|
||||||
"""根据租户查找工具"""
|
"""根据租户查找工具(只返回未删除的)"""
|
||||||
query = db.query(ToolConfig).filter(
|
query = db.query(ToolConfig).filter(
|
||||||
ToolConfig.tenant_id == tenant_id
|
ToolConfig.tenant_id == tenant_id,
|
||||||
|
ToolConfig.is_active.is_(True)
|
||||||
)
|
)
|
||||||
|
|
||||||
if name:
|
if name:
|
||||||
@@ -91,8 +92,17 @@ class ToolRepository:
|
|||||||
return query.all()
|
return query.all()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def find_by_id_and_tenant(db:Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
|
def find_by_id_and_tenant(db: Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
|
||||||
"""根据ID和租户查找工具"""
|
"""根据ID和租户查找工具(只返回未删除的)"""
|
||||||
|
return db.query(ToolConfig).filter(
|
||||||
|
ToolConfig.id == tool_id,
|
||||||
|
ToolConfig.tenant_id == tenant_id,
|
||||||
|
ToolConfig.is_active.is_(True)
|
||||||
|
).first()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def find_by_id_and_tenant_all(db: Session, tool_id: uuid.UUID, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
|
||||||
|
"""根据ID和租户查找工具(返回所有工具包括删除的)"""
|
||||||
return db.query(ToolConfig).filter(
|
return db.query(ToolConfig).filter(
|
||||||
ToolConfig.id == tool_id,
|
ToolConfig.id == tool_id,
|
||||||
ToolConfig.tenant_id == tenant_id
|
ToolConfig.tenant_id == tenant_id
|
||||||
@@ -100,29 +110,26 @@ class ToolRepository:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def count_by_tenant(db: Session, tenant_id: uuid.UUID) -> int:
|
def count_by_tenant(db: Session, tenant_id: uuid.UUID) -> int:
|
||||||
"""统计租户工具数量"""
|
"""统计租户工具数量(只统计未删除的)"""
|
||||||
return db.query(ToolConfig).filter(
|
return db.query(ToolConfig).filter(
|
||||||
ToolConfig.tenant_id == tenant_id
|
ToolConfig.tenant_id == tenant_id,
|
||||||
|
ToolConfig.is_active.is_(True)
|
||||||
).count()
|
).count()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_status_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]:
|
def get_status_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]:
|
||||||
"""获取状态统计"""
|
"""获取状态统计"""
|
||||||
return db.query(
|
return db.query(ToolConfig.status, func.count(ToolConfig.id).label('count')).filter(
|
||||||
ToolConfig.status,
|
ToolConfig.tenant_id == tenant_id,
|
||||||
func.count(ToolConfig.id).label('count')
|
ToolConfig.is_active.is_(True)
|
||||||
).filter(
|
|
||||||
ToolConfig.tenant_id == tenant_id
|
|
||||||
).group_by(ToolConfig.status).all()
|
).group_by(ToolConfig.status).all()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_type_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]:
|
def get_type_statistics(db: Session, tenant_id: uuid.UUID) -> List[tuple]:
|
||||||
"""获取类型统计"""
|
"""获取类型统计"""
|
||||||
return db.query(
|
return db.query(ToolConfig.tool_type, func.count(ToolConfig.id).label('count')).filter(
|
||||||
ToolConfig.tool_type,
|
ToolConfig.tenant_id == tenant_id,
|
||||||
func.count(ToolConfig.id).label('count')
|
ToolConfig.is_active.is_(True)
|
||||||
).filter(
|
|
||||||
ToolConfig.tenant_id == tenant_id
|
|
||||||
).group_by(ToolConfig.tool_type).all()
|
).group_by(ToolConfig.tool_type).all()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -130,6 +137,7 @@ class ToolRepository:
|
|||||||
"""统计租户启用的工具数量"""
|
"""统计租户启用的工具数量"""
|
||||||
return db.query(ToolConfig).filter(
|
return db.query(ToolConfig).filter(
|
||||||
ToolConfig.tenant_id == tenant_id,
|
ToolConfig.tenant_id == tenant_id,
|
||||||
|
ToolConfig.is_active.is_(True),
|
||||||
ToolConfig.is_enabled == True
|
ToolConfig.is_enabled == True
|
||||||
).count()
|
).count()
|
||||||
|
|
||||||
@@ -138,7 +146,8 @@ class ToolRepository:
|
|||||||
"""检查租户是否已有内置工具"""
|
"""检查租户是否已有内置工具"""
|
||||||
return db.query(ToolConfig).filter(
|
return db.query(ToolConfig).filter(
|
||||||
ToolConfig.tenant_id == tenant_id,
|
ToolConfig.tenant_id == tenant_id,
|
||||||
ToolConfig.tool_type == ToolType.BUILTIN.value
|
ToolConfig.tool_type == ToolType.BUILTIN.value,
|
||||||
|
ToolConfig.is_active.is_(True)
|
||||||
).count() > 0
|
).count() > 0
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -125,6 +125,85 @@ class SkillConfig(BaseModel):
|
|||||||
all_skills: Optional[bool] = Field(default=False, description="是否允许访问所有技能")
|
all_skills: Optional[bool] = Field(default=False, description="是否允许访问所有技能")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- App Features ----------
|
||||||
|
|
||||||
|
class FileUploadConfig(BaseModel):
|
||||||
|
"""文件上传配置"""
|
||||||
|
enabled: bool = Field(default=False)
|
||||||
|
# 允许的传输方式:local_file / remote_url,默认两种都允许
|
||||||
|
allowed_transfer_methods: List[str] = Field(
|
||||||
|
default=["local_file", "remote_url"],
|
||||||
|
description="允许的传输方式"
|
||||||
|
)
|
||||||
|
# 图片文件:PNG/JPG/JPEG/GIF/WEBP,最大 20MB
|
||||||
|
image_enabled: bool = Field(default=False)
|
||||||
|
image_max_size_mb: int = Field(default=20)
|
||||||
|
image_allowed_extensions: List[str] = Field(
|
||||||
|
default=["png", "jpg", "jpeg"]
|
||||||
|
)
|
||||||
|
# 语音文件:MP3/WAV/M4A/OGG/FLAC,最大 50MB
|
||||||
|
audio_enabled: bool = Field(default=False)
|
||||||
|
audio_max_size_mb: int = Field(default=50)
|
||||||
|
audio_allowed_extensions: List[str] = Field(
|
||||||
|
default=["mp3", "wav", "m4a"]
|
||||||
|
)
|
||||||
|
# 通用文件:PDF/DOCX/XLSX/TXT/CSV/JSON,最大 100MB
|
||||||
|
document_enabled: bool = Field(default=False)
|
||||||
|
document_max_size_mb: int = Field(default=100)
|
||||||
|
document_allowed_extensions: List[str] = Field(
|
||||||
|
default=["pdf", "docx", "xlsx", "txt", "csv", "json", "md"]
|
||||||
|
)
|
||||||
|
# 视频文件:MP4/MOV/AVI/WebM,最大 500MB
|
||||||
|
video_enabled: bool = Field(default=False)
|
||||||
|
video_max_size_mb: int = Field(default=500)
|
||||||
|
video_allowed_extensions: List[str] = Field(
|
||||||
|
default=["mp4", "mov"]
|
||||||
|
)
|
||||||
|
# 最大文件数量
|
||||||
|
max_file_count: int = Field(default=5, ge=1, le=20)
|
||||||
|
|
||||||
|
|
||||||
|
class OpeningStatementConfig(BaseModel):
|
||||||
|
"""对话开场白配置"""
|
||||||
|
enabled: bool = Field(default=False)
|
||||||
|
statement: Optional[str] = Field(default=None, description="开场白内容")
|
||||||
|
suggested_questions: List[str] = Field(default_factory=list, description="预设问题列表")
|
||||||
|
|
||||||
|
|
||||||
|
class SuggestedQuestionsConfig(BaseModel):
|
||||||
|
"""下一步问题建议配置"""
|
||||||
|
enabled: bool = Field(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
class TextToSpeechConfig(BaseModel):
|
||||||
|
"""文字转语音配置"""
|
||||||
|
enabled: bool = Field(default=False)
|
||||||
|
voice: Optional[str] = Field(default=None, description="语音音色")
|
||||||
|
language: Optional[str] = Field(default=None, description="语言")
|
||||||
|
autoplay: bool = Field(default=False, description="是否自动播放")
|
||||||
|
|
||||||
|
|
||||||
|
class CitationConfig(BaseModel):
|
||||||
|
"""引用和归属配置"""
|
||||||
|
enabled: bool = Field(default=False)
|
||||||
|
|
||||||
|
|
||||||
|
class WebSearchConfig(BaseModel):
|
||||||
|
"""联网搜索配置"""
|
||||||
|
enabled: bool = Field(default=False)
|
||||||
|
search_engine: Optional[str] = Field(default=None, description="搜索引擎")
|
||||||
|
|
||||||
|
|
||||||
|
class AppFeatures(BaseModel):
|
||||||
|
"""应用功能特性配置"""
|
||||||
|
file_upload: FileUploadConfig = Field(default_factory=FileUploadConfig)
|
||||||
|
opening_statement: OpeningStatementConfig = Field(default_factory=OpeningStatementConfig)
|
||||||
|
suggested_questions_after_answer: SuggestedQuestionsConfig = Field(default_factory=SuggestedQuestionsConfig)
|
||||||
|
text_to_speech: TextToSpeechConfig = Field(default_factory=TextToSpeechConfig)
|
||||||
|
citation: CitationConfig = Field(default_factory=CitationConfig)
|
||||||
|
web_search: WebSearchConfig = Field(default_factory=WebSearchConfig)
|
||||||
|
|
||||||
|
|
||||||
class ToolOldConfig(BaseModel):
|
class ToolOldConfig(BaseModel):
|
||||||
"""工具配置"""
|
"""工具配置"""
|
||||||
enabled: bool = Field(default=False, description="是否启用该工具")
|
enabled: bool = Field(default=False, description="是否启用该工具")
|
||||||
@@ -201,6 +280,9 @@ class AgentConfigCreate(BaseModel):
|
|||||||
# 技能配置
|
# 技能配置
|
||||||
skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表")
|
skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表")
|
||||||
|
|
||||||
|
# 功能特性
|
||||||
|
features: Optional[AppFeatures] = Field(default=None, description="功能特性配置")
|
||||||
|
|
||||||
|
|
||||||
class AppCreate(BaseModel):
|
class AppCreate(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
@@ -258,6 +340,9 @@ class AgentConfigUpdate(BaseModel):
|
|||||||
# 技能配置
|
# 技能配置
|
||||||
skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表")
|
skills: Optional[SkillConfig] = Field(default=dict, description="关联的技能列表")
|
||||||
|
|
||||||
|
# 功能特性
|
||||||
|
features: Optional[AppFeatures] = Field(default=None, description="功能特性配置")
|
||||||
|
|
||||||
|
|
||||||
# ---------- Output Schemas ----------
|
# ---------- Output Schemas ----------
|
||||||
|
|
||||||
@@ -283,6 +368,10 @@ class App(BaseModel):
|
|||||||
source_workspace_icon: Optional[str] = None # 共享来源工作空间图标
|
source_workspace_icon: Optional[str] = None # 共享来源工作空间图标
|
||||||
source_app_version: Optional[str] = None # 应用版本号
|
source_app_version: Optional[str] = None # 应用版本号
|
||||||
source_app_is_active: Optional[bool] = None # 应用是否生效
|
source_app_is_active: Optional[bool] = None # 应用是否生效
|
||||||
|
share_id: Optional[uuid.UUID] = None # 分享记录ID(取消共享时使用)
|
||||||
|
shared_by: Optional[uuid.UUID] = None # 分享者用户ID
|
||||||
|
shared_by_name: Optional[str] = None # 分享者名称
|
||||||
|
shared_at: Optional[datetime.datetime] = None # 分享时间
|
||||||
created_at: datetime.datetime
|
created_at: datetime.datetime
|
||||||
updated_at: datetime.datetime
|
updated_at: datetime.datetime
|
||||||
|
|
||||||
@@ -294,6 +383,10 @@ class App(BaseModel):
|
|||||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
@field_serializer("shared_at", when_used="json")  # applied only when serializing to JSON
|
||||||
|
def _serialize_shared_at(self, dt: Optional[datetime.datetime]):
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None  # epoch milliseconds as int, or None when no timestamp is set
|
||||||
|
|
||||||
|
|
||||||
class AgentConfig(BaseModel):
|
class AgentConfig(BaseModel):
|
||||||
"""Agent 配置输出 Schema"""
|
"""Agent 配置输出 Schema"""
|
||||||
@@ -323,6 +416,8 @@ class AgentConfig(BaseModel):
|
|||||||
|
|
||||||
skills: Optional[SkillConfig] = {}
|
skills: Optional[SkillConfig] = {}
|
||||||
|
|
||||||
|
features: Optional[AppFeatures] = None
|
||||||
|
|
||||||
is_active: bool
|
is_active: bool
|
||||||
created_at: datetime.datetime
|
created_at: datetime.datetime
|
||||||
updated_at: datetime.datetime
|
updated_at: datetime.datetime
|
||||||
@@ -359,6 +454,14 @@ class AgentConfig(BaseModel):
|
|||||||
return {}
|
return {}
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@field_validator("features", mode="before")  # NOTE(review): pydantic does not run validators on default values unless validate_default is set, so an omitted field likely stays None - confirm
|
||||||
|
@classmethod
|
||||||
|
def validate_features(cls, v):
|
||||||
|
"""Replace an explicit None with a default AppFeatures instance."""
|
||||||
|
if v is None:
|
||||||
|
return AppFeatures()
|
||||||
|
return v  # any non-None input is passed through unchanged for normal field validation
|
||||||
|
|
||||||
@field_serializer("created_at", when_used="json")
|
@field_serializer("created_at", when_used="json")
|
||||||
def _serialize_created_at(self, dt: datetime.datetime):
|
def _serialize_created_at(self, dt: datetime.datetime):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
@@ -422,6 +525,13 @@ class AppRelease(BaseModel):
|
|||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------- App Copy Schema ----------
|
||||||
|
|
||||||
|
class CopyAppRequest(BaseModel):
|
||||||
|
"""Request body for copying an app."""
|
||||||
|
new_name: Optional[str] = Field(None, description="新应用名称,不填则使用原名称-副本")  # optional; this schema only defaults it to None, the fallback name in the description is not computed here
|
||||||
|
|
||||||
|
|
||||||
# ---------- App Share Schemas ----------
|
# ---------- App Share Schemas ----------
|
||||||
|
|
||||||
class AppShareCreate(BaseModel):
|
class AppShareCreate(BaseModel):
|
||||||
@@ -500,12 +610,35 @@ class DraftRunRequest(BaseModel):
|
|||||||
files: Optional[List[FileInput]] = Field(default_factory=list, description="附件列表(支持多文件)")
|
files: Optional[List[FileInput]] = Field(default_factory=list, description="附件列表(支持多文件)")
|
||||||
|
|
||||||
|
|
||||||
|
class SuggestedQuestion(BaseModel):
|
||||||
|
"""A suggested question."""
|
||||||
|
content: str  # required
|
||||||
|
|
||||||
|
|
||||||
|
class CitationSource(BaseModel):
|
||||||
|
"""A citation source."""
|
||||||
|
title: str  # required
|
||||||
|
content: str  # required
|
||||||
|
score: Optional[float] = None  # optional; presumably a retrieval relevance score - confirm
|
||||||
|
kb_id: Optional[str] = None  # optional; presumably the knowledge base ID - confirm
|
||||||
|
|
||||||
|
|
||||||
class DraftRunResponse(BaseModel):
|
class DraftRunResponse(BaseModel):
|
||||||
"""试运行响应(非流式)"""
|
"""试运行响应(非流式)"""
|
||||||
message: str = Field(..., description="AI 回复消息")
|
message: str = Field(..., description="AI 回复消息")
|
||||||
conversation_id: Optional[str] = Field(default=None, description="会话ID(用于多轮对话)")
|
conversation_id: Optional[str] = Field(default=None, description="会话ID(用于多轮对话)")
|
||||||
usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 使用情况")
|
usage: Optional[Dict[str, Any]] = Field(default=None, description="Token 使用情况")
|
||||||
elapsed_time: Optional[float] = Field(default=None, description="耗时(秒)")
|
elapsed_time: Optional[float] = Field(default=None, description="耗时(秒)")
|
||||||
|
suggested_questions: List[str] = Field(default_factory=list, description="下一步建议问题")
|
||||||
|
citations: List[CitationSource] = Field(default_factory=list, description="引用来源")
|
||||||
|
audio_url: Optional[str] = Field(default=None, description="TTS 语音URL")
|
||||||
|
|
||||||
|
|
||||||
|
class OpeningResponse(BaseModel):
|
||||||
|
"""Response carrying an app's opening statement."""
|
||||||
|
enabled: bool  # required
|
||||||
|
statement: Optional[str] = None  # opening statement text; None when not provided
|
||||||
|
suggested_questions: List[str] = Field(default_factory=list)  # defaults to a new empty list
|
||||||
|
|
||||||
|
|
||||||
class DraftRunStreamChunk(BaseModel):
|
class DraftRunStreamChunk(BaseModel):
|
||||||
|
|||||||
@@ -51,6 +51,10 @@ class Message(BaseModel):
|
|||||||
def _serialize_created_at(self, dt: datetime.datetime):
|
def _serialize_created_at(self, dt: datetime.datetime):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
@field_serializer("meta_data", when_used="json")  # applied only when serializing to JSON
|
||||||
|
def _serialize_meta_data(self, data: Optional[Dict[str, Any]]):
|
||||||
|
return data or {}  # None (or an empty dict) is emitted as {}
|
||||||
|
|
||||||
|
|
||||||
class Conversation(BaseModel):
|
class Conversation(BaseModel):
|
||||||
"""会话输出"""
|
"""会话输出"""
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ class EndUser(BaseModel):
|
|||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
id: uuid.UUID = Field(description="终端用户ID")
|
id: uuid.UUID = Field(description="终端用户ID")
|
||||||
app_id: uuid.UUID = Field(description="应用ID")
|
app_id: Optional[uuid.UUID] = Field(description="应用ID", default=None)
|
||||||
# end_user_id: str = Field(description="终端用户ID")
|
# end_user_id: str = Field(description="终端用户ID")
|
||||||
other_id: Optional[str] = Field(description="第三方ID", default=None)
|
other_id: Optional[str] = Field(description="第三方ID", default=None)
|
||||||
other_name: Optional[str] = Field(description="其他名称", default="")
|
other_name: Optional[str] = Field(description="其他名称", default="")
|
||||||
|
|||||||
@@ -26,5 +26,7 @@ class AgentMemory_Long_Term(ABC):
|
|||||||
STRATEGY_TIME = "time"
|
STRATEGY_TIME = "time"
|
||||||
DEFAULT_SCOPE = 6
|
DEFAULT_SCOPE = 6
|
||||||
TIME_SCOPE=5
|
TIME_SCOPE=5
|
||||||
|
class AgentMemoryDataset(ABC):
|
||||||
|
PRONOUN=['我','本人','在下','自己','咱','鄙人','吴','余']  # first-person self-references; NOTE(review): '吴' is a surname - likely a typo for the pronoun '吾', confirm before changing
|
||||||
|
NAME='用户'  # the literal word for "user"
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class MemoryWriteRequest(BaseModel):
|
|||||||
"""
|
"""
|
||||||
end_user_id: str = Field(..., description="End user ID (required)")
|
end_user_id: str = Field(..., description="End user ID (required)")
|
||||||
message: str = Field(..., description="Message content to store")
|
message: str = Field(..., description="Message content to store")
|
||||||
config_id: Optional[str] = Field(None, description="Memory configuration ID")
|
config_id: str = Field(..., description="Memory configuration ID (required)")
|
||||||
storage_type: str = Field("neo4j", description="Storage type: neo4j or rag")
|
storage_type: str = Field("neo4j", description="Storage type: neo4j or rag")
|
||||||
user_rag_memory_id: Optional[str] = Field(None, description="RAG memory ID")
|
user_rag_memory_id: Optional[str] = Field(None, description="RAG memory ID")
|
||||||
|
|
||||||
@@ -68,7 +68,7 @@ class MemoryReadRequest(BaseModel):
|
|||||||
"0",
|
"0",
|
||||||
description="Search mode: 0=verify, 1=direct, 2=context"
|
description="Search mode: 0=verify, 1=direct, 2=context"
|
||||||
)
|
)
|
||||||
config_id: Optional[str] = Field(None, description="Memory configuration ID")
|
config_id: str = Field(..., description="Memory configuration ID (required)")
|
||||||
storage_type: str = Field("neo4j", description="Storage type: neo4j or rag")
|
storage_type: str = Field("neo4j", description="Storage type: neo4j or rag")
|
||||||
user_rag_memory_id: Optional[str] = Field(None, description="RAG memory ID")
|
user_rag_memory_id: Optional[str] = Field(None, description="RAG memory ID")
|
||||||
|
|
||||||
@@ -132,3 +132,79 @@ class MemoryReadResponse(BaseModel):
|
|||||||
description="Intermediate retrieval outputs"
|
description="Intermediate retrieval outputs"
|
||||||
)
|
)
|
||||||
end_user_id: str = Field(..., description="End user ID")
|
end_user_id: str = Field(..., description="End user ID")
|
||||||
|
|
||||||
|
|
||||||
|
class CreateEndUserRequest(BaseModel):
|
||||||
|
"""Request schema for creating an end user.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
workspace_id: Workspace ID (required; blank values are rejected, surrounding whitespace is stripped)
|
||||||
|
other_id: External user identifier (required; blank values are rejected, surrounding whitespace is stripped)
|
||||||
|
other_name: Display name for the end user (optional, defaults to an empty string)
|
||||||
|
"""
|
||||||
|
workspace_id: str = Field(..., description="Workspace ID (required)")
|
||||||
|
other_id: str = Field(..., description="External user identifier (required)")
|
||||||
|
other_name: Optional[str] = Field("", description="Display name")
|
||||||
|
|
||||||
|
@field_validator("workspace_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_workspace_id(cls, v: str) -> str:
|
||||||
|
"""Reject an empty or whitespace-only workspace_id and return it stripped."""
|
||||||
|
if not v or not v.strip():
|
||||||
|
raise ValueError("workspace_id is required and cannot be empty")
|
||||||
|
return v.strip()  # the stored value has no leading/trailing whitespace
|
||||||
|
|
||||||
|
@field_validator("other_id")
|
||||||
|
@classmethod
|
||||||
|
def validate_other_id(cls, v: str) -> str:
|
||||||
|
"""Reject an empty or whitespace-only other_id and return it stripped."""
|
||||||
|
if not v or not v.strip():
|
||||||
|
raise ValueError("other_id is required and cannot be empty")
|
||||||
|
return v.strip()  # the stored value has no leading/trailing whitespace
|
||||||
|
|
||||||
|
|
||||||
|
class CreateEndUserResponse(BaseModel):
|
||||||
|
"""Response schema for end user creation.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
id: Created end user UUID, carried as a string
|
||||||
|
other_id: External user identifier
|
||||||
|
other_name: Display name (defaults to an empty string)
|
||||||
|
workspace_id: Workspace the user belongs to
|
||||||
|
"""
|
||||||
|
id: str = Field(..., description="End user UUID")  # typed as str, not uuid.UUID
|
||||||
|
other_id: str = Field(..., description="External user identifier")
|
||||||
|
other_name: str = Field("", description="Display name")  # non-optional str here, so None is not accepted
|
||||||
|
workspace_id: str = Field(..., description="Workspace ID")
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryConfigItem(BaseModel):
|
||||||
|
"""Schema for a single memory config in the list response.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
config_id: Configuration UUID, carried as a string
|
||||||
|
config_name: Configuration name
|
||||||
|
config_desc: Configuration description (optional)
|
||||||
|
is_default: Whether this is the workspace default config (defaults to False)
|
||||||
|
scene_name: Associated ontology scene name (optional)
|
||||||
|
created_at: Creation timestamp as a string (optional)
|
||||||
|
updated_at: Last update timestamp as a string (optional)
|
||||||
|
"""
|
||||||
|
config_id: str = Field(..., description="Configuration ID")
|
||||||
|
config_name: str = Field(..., description="Configuration name")
|
||||||
|
config_desc: Optional[str] = Field(None, description="Configuration description")
|
||||||
|
is_default: bool = Field(False, description="Whether this is the workspace default")
|
||||||
|
scene_name: Optional[str] = Field(None, description="Associated ontology scene name")
|
||||||
|
created_at: Optional[str] = Field(None, description="Creation timestamp")  # string-typed; the timestamp format is not specified by this schema
|
||||||
|
updated_at: Optional[str] = Field(None, description="Last update timestamp")  # string-typed; the timestamp format is not specified by this schema
|
||||||
|
|
||||||
|
|
||||||
|
class ListConfigsResponse(BaseModel):
|
||||||
|
"""Response schema for listing memory configs.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
configs: List of memory config items (defaults to an empty list)
|
||||||
|
total: Total number of configs (defaults to 0)
|
||||||
|
"""
|
||||||
|
configs: List[MemoryConfigItem] = Field(default_factory=list, description="List of configs")
|
||||||
|
total: int = Field(0, description="Total number of configs")  # set by the caller; this schema does not derive it from configs
|
||||||
|
|||||||
@@ -417,7 +417,7 @@ class MemoryConfig:
|
|||||||
|
|
||||||
# Ontology scene association
|
# Ontology scene association
|
||||||
scene_id: Optional[UUID] = None
|
scene_id: Optional[UUID] = None
|
||||||
ontology_classes: Optional[list] = field(default=None)
|
ontology_class_infos: list[dict] = field(default_factory=list)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
"""Validate configuration after initialization."""
|
"""Validate configuration after initialization."""
|
||||||
|
|||||||
@@ -90,6 +90,7 @@ class ToolInfo(BaseModel):
|
|||||||
parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数")
|
parameters: List[ToolParameter] = Field(default_factory=list, description="工具参数")
|
||||||
config_data: Dict[str, Any] = Field(default_factory=dict, description="工具配置")
|
config_data: Dict[str, Any] = Field(default_factory=dict, description="工具配置")
|
||||||
status: ToolStatus = Field(ToolStatus.AVAILABLE, description="工具状态")
|
status: ToolStatus = Field(ToolStatus.AVAILABLE, description="工具状态")
|
||||||
|
is_active: bool = Field(True, description="是否可用(False 表示已删除)")
|
||||||
tags: List[str] = Field(default_factory=list, description="工具标签")
|
tags: List[str] = Field(default_factory=list, description="工具标签")
|
||||||
tenant_id: Optional[str] = Field(None, description="租户ID")
|
tenant_id: Optional[str] = Field(None, description="租户ID")
|
||||||
created_at: datetime = Field(..., description="创建时间")
|
created_at: datetime = Field(..., description="创建时间")
|
||||||
@@ -212,6 +213,11 @@ class ToolUpdateRequest(BaseModel):
|
|||||||
tags: Optional[List[str]] = None
|
tags: Optional[List[str]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ToolActiveUpdate(BaseModel):
|
||||||
|
"""Update payload for a tool's active state."""
|
||||||
|
is_active: bool = Field(..., description="True=启用, False=禁用(逻辑删除)")  # required; per the description, True enables and False disables (logical delete)
|
||||||
|
|
||||||
|
|
||||||
class ToolExecuteRequest(BaseModel):
|
class ToolExecuteRequest(BaseModel):
|
||||||
"""执行工具请求"""
|
"""执行工具请求"""
|
||||||
tool_id: str
|
tool_id: str
|
||||||
|
|||||||
@@ -80,6 +80,7 @@ class WorkflowConfigCreate(BaseModel):
|
|||||||
variables: list[VariableDefinition] = Field(default_factory=list, description="变量列表")
|
variables: list[VariableDefinition] = Field(default_factory=list, description="变量列表")
|
||||||
execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置")
|
execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置")
|
||||||
triggers: list[TriggerConfig] = Field(default_factory=list, description="触发器列表")
|
triggers: list[TriggerConfig] = Field(default_factory=list, description="触发器列表")
|
||||||
|
features: dict = Field(default_factory=dict, description="功能特性配置")
|
||||||
|
|
||||||
|
|
||||||
class WorkflowConfigUpdate(BaseModel):
|
class WorkflowConfigUpdate(BaseModel):
|
||||||
@@ -87,6 +88,7 @@ class WorkflowConfigUpdate(BaseModel):
|
|||||||
nodes: list[NodeDefinition] | None = None
|
nodes: list[NodeDefinition] | None = None
|
||||||
edges: list[EdgeDefinition] | None = None
|
edges: list[EdgeDefinition] | None = None
|
||||||
variables: list[VariableDefinition] | None = None
|
variables: list[VariableDefinition] | None = None
|
||||||
|
features: dict | None = None
|
||||||
execution_config: ExecutionConfig | None = None
|
execution_config: ExecutionConfig | None = None
|
||||||
triggers: list[TriggerConfig] | None = None
|
triggers: list[TriggerConfig] | None = None
|
||||||
|
|
||||||
@@ -102,6 +104,7 @@ class WorkflowConfig(BaseModel):
|
|||||||
variables: list[dict[str, Any]]
|
variables: list[dict[str, Any]]
|
||||||
execution_config: dict[str, Any]
|
execution_config: dict[str, Any]
|
||||||
triggers: list[dict[str, Any]]
|
triggers: list[dict[str, Any]]
|
||||||
|
features: dict | None
|
||||||
is_active: bool
|
is_active: bool
|
||||||
created_at: datetime.datetime
|
created_at: datetime.datetime
|
||||||
updated_at: datetime.datetime
|
updated_at: datetime.datetime
|
||||||
@@ -114,6 +117,10 @@ class WorkflowConfig(BaseModel):
|
|||||||
def _serialize_updated_at(self, dt: datetime.datetime):
|
def _serialize_updated_at(self, dt: datetime.datetime):
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
@field_serializer("features", when_used="json")  # applied only when serializing to JSON
|
||||||
|
def _serialize_features(self, features: dict | None):
|
||||||
|
return features or {}  # None (or an empty dict) is emitted as {}
|
||||||
|
|
||||||
|
|
||||||
# ==================== 工作流执行 ====================
|
# ==================== 工作流执行 ====================
|
||||||
|
|
||||||
|
|||||||
@@ -52,6 +52,9 @@ class AgentConfigConverter:
|
|||||||
if hasattr(config, "skills") and config.skills:
|
if hasattr(config, "skills") and config.skills:
|
||||||
result["skills"] = config.skills.model_dump()
|
result["skills"] = config.skills.model_dump()
|
||||||
|
|
||||||
|
if hasattr(config, "features") and config.features:
|
||||||
|
result["features"] = config.features.model_dump()
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from app.services.model_service import ModelApiKeyService
|
|||||||
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
from app.services.multi_agent_orchestrator import MultiAgentOrchestrator
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
from app.services.workflow_service import WorkflowService
|
from app.services.workflow_service import WorkflowService
|
||||||
|
from app.schemas import FileType
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
@@ -49,12 +50,23 @@ class AppChatService:
|
|||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
workspace_id: Optional[str] = None,
|
workspace_id: Optional[str] = None,
|
||||||
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
files: Optional[List[FileInput]] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""聊天(非流式)"""
|
"""聊天(非流式)"""
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
config_id = None
|
config_id = None
|
||||||
|
|
||||||
|
# 应用 features 配置
|
||||||
|
features_config: dict = config.features or {}
|
||||||
|
if hasattr(features_config, 'model_dump'):
|
||||||
|
features_config = features_config.model_dump()
|
||||||
|
web_search_feature = features_config.get("web_search", {})
|
||||||
|
if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")):
|
||||||
|
web_search = False
|
||||||
|
|
||||||
|
# 校验文件上传
|
||||||
|
self.agent_service._validate_file_upload(features_config, files)
|
||||||
|
|
||||||
variables = self.agent_service.prepare_variables(variables, config.variables)
|
variables = self.agent_service.prepare_variables(variables, config.variables)
|
||||||
|
|
||||||
# 获取模型配置ID
|
# 获取模型配置ID
|
||||||
@@ -107,12 +119,9 @@ class AppChatService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 加载历史消息
|
# 加载历史消息
|
||||||
history = []
|
|
||||||
memory_config = {"enabled": True, 'max_history': 10}
|
|
||||||
if memory_config.get("enabled"):
|
|
||||||
messages = self.conversation_service.get_messages(
|
messages = self.conversation_service.get_messages(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
limit=memory_config.get("max_history", 10)
|
limit=10
|
||||||
)
|
)
|
||||||
history = [
|
history = [
|
||||||
{"role": msg.role, "content": msg.content}
|
{"role": msg.role, "content": msg.content}
|
||||||
@@ -148,24 +157,61 @@ class AppChatService:
|
|||||||
files=processed_files # 传递处理后的文件
|
files=processed_files # 传递处理后的文件
|
||||||
)
|
)
|
||||||
|
|
||||||
# 保存消息
|
|
||||||
message_id = self.conversation_service.save_conversation_messages(
|
|
||||||
conversation_id=conversation_id,
|
|
||||||
user_message=message,
|
|
||||||
assistant_message=result["content"],
|
|
||||||
meta_data={
|
|
||||||
"usage": result.get("usage", {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"completion_tokens": 0,
|
|
||||||
"total_tokens": 0
|
|
||||||
})
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
|
# suggested_questions
|
||||||
|
suggested_questions = []
|
||||||
|
sq_config = features_config.get("suggested_questions_after_answer", {})
|
||||||
|
if isinstance(sq_config, dict) and sq_config.get("enabled"):
|
||||||
|
suggested_questions = await self.agent_service._generate_suggested_questions(
|
||||||
|
features_config, result["content"],
|
||||||
|
{"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key,
|
||||||
|
"api_base": api_key_obj.api_base}, {}
|
||||||
|
)
|
||||||
|
|
||||||
|
audio_url = await self.agent_service._generate_tts(
|
||||||
|
features_config, result["content"],
|
||||||
|
{"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key,
|
||||||
|
"api_base": api_key_obj.api_base, "provider": api_key_obj.provider},
|
||||||
|
tenant_id=tenant_id, workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# 构建用户消息内容(含多模态文件)
|
||||||
|
human_meta = {
|
||||||
|
"files": []
|
||||||
|
}
|
||||||
|
assistant_meta = {
|
||||||
|
"model": api_key_obj.model_name,
|
||||||
|
"usage": result.get("usage", {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}),
|
||||||
|
"audio_url": None
|
||||||
|
}
|
||||||
|
if files:
|
||||||
|
for f in files:
|
||||||
|
# url = await MultimodalService(self.db).get_file_url(f)
|
||||||
|
human_meta["files"].append({
|
||||||
|
"type": f.type,
|
||||||
|
"url": f.url
|
||||||
|
})
|
||||||
|
|
||||||
|
# 保存消息
|
||||||
|
if audio_url:
|
||||||
|
assistant_meta["audio_url"] = audio_url
|
||||||
|
self.conversation_service.add_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role="user",
|
||||||
|
content=message,
|
||||||
|
meta_data=human_meta
|
||||||
|
)
|
||||||
|
ai_message = self.conversation_service.add_message(
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
role="assistant",
|
||||||
|
content=result["content"],
|
||||||
|
meta_data=assistant_meta
|
||||||
|
)
|
||||||
|
message_id = ai_message.id
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"conversation_id": conversation_id,
|
"conversation_id": conversation_id,
|
||||||
"message_id": str(message_id),
|
"message_id": str(message_id),
|
||||||
@@ -175,7 +221,10 @@ class AppChatService:
|
|||||||
"completion_tokens": 0,
|
"completion_tokens": 0,
|
||||||
"total_tokens": 0
|
"total_tokens": 0
|
||||||
}),
|
}),
|
||||||
"elapsed_time": elapsed_time
|
"elapsed_time": elapsed_time,
|
||||||
|
"suggested_questions": suggested_questions,
|
||||||
|
"citations": self.agent_service._filter_citations(features_config, result.get("citations", [])),
|
||||||
|
"audio_url": audio_url,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def agnet_chat_stream(
|
async def agnet_chat_stream(
|
||||||
@@ -190,7 +239,7 @@ class AppChatService:
|
|||||||
storage_type: Optional[str] = None,
|
storage_type: Optional[str] = None,
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
workspace_id: Optional[str] = None,
|
workspace_id: Optional[str] = None,
|
||||||
files: Optional[List[FileInput]] = None # 新增:多模态文件
|
files: Optional[List[FileInput]] = None
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""聊天(流式)"""
|
"""聊天(流式)"""
|
||||||
|
|
||||||
@@ -198,10 +247,19 @@ class AppChatService:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
config_id = None
|
config_id = None
|
||||||
message_id = uuid.uuid4()
|
message_id = uuid.uuid4()
|
||||||
yield f"event: start\ndata: {json.dumps({
|
|
||||||
'conversation_id': str(conversation_id),
|
# 应用 features 配置
|
||||||
"message_id": str(message_id)
|
features_config: dict = config.features or {}
|
||||||
}, ensure_ascii=False)}\n\n"
|
if hasattr(features_config, 'model_dump'):
|
||||||
|
features_config = features_config.model_dump()
|
||||||
|
web_search_feature = features_config.get("web_search", {})
|
||||||
|
if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")):
|
||||||
|
web_search = False
|
||||||
|
|
||||||
|
# 校验文件上传
|
||||||
|
self.agent_service._validate_file_upload(features_config, files)
|
||||||
|
|
||||||
|
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id), 'message_id': str(message_id)}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
variables = self.agent_service.prepare_variables(variables, config.variables)
|
variables = self.agent_service.prepare_variables(variables, config.variables)
|
||||||
# 获取模型配置ID
|
# 获取模型配置ID
|
||||||
@@ -284,9 +342,17 @@ class AppChatService:
|
|||||||
processed_files = await multimodal_service.process_files(user_id, files)
|
processed_files = await multimodal_service.process_files(user_id, files)
|
||||||
logger.info(f"处理了 {len(processed_files)} 个文件")
|
logger.info(f"处理了 {len(processed_files)} 个文件")
|
||||||
|
|
||||||
# 流式调用 Agent(支持多模态)
|
# 流式调用 Agent(支持多模态),同时并行启动 TTS
|
||||||
full_content = ""
|
full_content = ""
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
|
|
||||||
|
text_queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
stream_audio_url, tts_task = await self.agent_service._generate_tts_streaming(
|
||||||
|
features_config, api_key_obj,
|
||||||
|
text_queue=text_queue,
|
||||||
|
tenant_id=tenant_id, workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
|
||||||
async for chunk in agent.chat_stream(
|
async for chunk in agent.chat_stream(
|
||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
@@ -296,39 +362,67 @@ class AppChatService:
|
|||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
memory_flag=memory_flag,
|
memory_flag=memory_flag,
|
||||||
files=processed_files # 传递处理后的文件
|
files=processed_files
|
||||||
):
|
):
|
||||||
if isinstance(chunk, int):
|
if isinstance(chunk, int):
|
||||||
total_tokens = chunk
|
total_tokens = chunk
|
||||||
else:
|
else:
|
||||||
full_content += chunk
|
full_content += chunk
|
||||||
# 发送消息块事件
|
|
||||||
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
|
yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
|
||||||
|
if tts_task is not None:
|
||||||
|
await text_queue.put(chunk)
|
||||||
|
|
||||||
|
if tts_task is not None:
|
||||||
|
await text_queue.put(None)
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
||||||
|
|
||||||
|
# 发送结束事件(包含 suggested_questions、tts、citations)
|
||||||
|
end_data: dict = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
|
||||||
|
sq_config = features_config.get("suggested_questions_after_answer", {})
|
||||||
|
if isinstance(sq_config, dict) and sq_config.get("enabled"):
|
||||||
|
end_data["suggested_questions"] = await self.agent_service._generate_suggested_questions(
|
||||||
|
features_config, full_content,
|
||||||
|
{"model_name": api_key_obj.model_name, "api_key": api_key_obj.api_key,
|
||||||
|
"api_base": api_key_obj.api_base}, {}
|
||||||
|
)
|
||||||
|
end_data["audio_url"] = stream_audio_url
|
||||||
|
end_data["citations"] = self.agent_service._filter_citations(features_config, [])
|
||||||
|
|
||||||
# 保存消息
|
# 保存消息
|
||||||
|
human_meta = {
|
||||||
|
"files":[]
|
||||||
|
}
|
||||||
|
assistant_meta = {
|
||||||
|
"model": api_key_obj.model_name,
|
||||||
|
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens},
|
||||||
|
"audio_url": None
|
||||||
|
}
|
||||||
|
|
||||||
|
if files:
|
||||||
|
for f in files:
|
||||||
|
# url = await MultimodalService(self.db).get_file_url(f)
|
||||||
|
human_meta["files"].append({
|
||||||
|
"type": f.type,
|
||||||
|
"url": f.url
|
||||||
|
})
|
||||||
|
|
||||||
|
if stream_audio_url:
|
||||||
|
assistant_meta["audio_url"] = stream_audio_url
|
||||||
self.conversation_service.add_message(
|
self.conversation_service.add_message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
role="user",
|
role="user",
|
||||||
content=message
|
content=message,
|
||||||
|
meta_data=human_meta
|
||||||
)
|
)
|
||||||
|
|
||||||
self.conversation_service.add_message(
|
self.conversation_service.add_message(
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=full_content,
|
content=full_content,
|
||||||
meta_data={
|
meta_data=assistant_meta
|
||||||
"model": api_key_obj.model_name,
|
|
||||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
|
|
||||||
|
|
||||||
# 发送结束事件
|
|
||||||
end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
|
|
||||||
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
|
yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -442,7 +536,7 @@ class AppChatService:
|
|||||||
try:
|
try:
|
||||||
message_id = uuid.uuid4()
|
message_id = uuid.uuid4()
|
||||||
# 发送开始事件
|
# 发送开始事件
|
||||||
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id), "message_id": str(message_id)}, ensure_ascii=False)}\n\n"
|
yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id), 'message_id': str(message_id)}, ensure_ascii=False)}\n\n"
|
||||||
|
|
||||||
full_content = ""
|
full_content = ""
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
@@ -534,6 +628,7 @@ class AppChatService:
|
|||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
release_id: uuid.UUID,
|
release_id: uuid.UUID,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
|
files: Optional[List[FileInput]] = None,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
variables: Optional[Dict[str, Any]] = None,
|
variables: Optional[Dict[str, Any]] = None,
|
||||||
web_search: bool = False,
|
web_search: bool = False,
|
||||||
@@ -547,7 +642,8 @@ class AppChatService:
|
|||||||
variables=variables,
|
variables=variables,
|
||||||
conversation_id=str(conversation_id),
|
conversation_id=str(conversation_id),
|
||||||
stream=True,
|
stream=True,
|
||||||
user_id=user_id
|
user_id=user_id,
|
||||||
|
files=files
|
||||||
)
|
)
|
||||||
return await self.workflow_service.run(
|
return await self.workflow_service.run(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from app.core.error_codes import BizCode
|
|||||||
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
from app.core.exceptions import BusinessException, ResourceNotFoundException
|
||||||
from app.models import AgentConfig, MultiAgentConfig
|
from app.models import AgentConfig, MultiAgentConfig
|
||||||
from app.models.app_model import App, AppType
|
from app.models.app_model import App, AppType
|
||||||
|
from app.models.appshare_model import AppShare
|
||||||
from app.models.app_release_model import AppRelease
|
from app.models.app_release_model import AppRelease
|
||||||
from app.models.knowledge_model import Knowledge
|
from app.models.knowledge_model import Knowledge
|
||||||
from app.models.models_model import ModelConfig
|
from app.models.models_model import ModelConfig
|
||||||
@@ -298,11 +299,22 @@ class AppDslService:
|
|||||||
return new_app, warnings
|
return new_app, warnings
|
||||||
|
|
||||||
def _unique_app_name(self, name: str, workspace_id: uuid.UUID, app_type: AppType) -> str:
|
def _unique_app_name(self, name: str, workspace_id: uuid.UUID, app_type: AppType) -> str:
|
||||||
|
"""生成唯一应用名称,同时检查本空间自有应用和共享到本空间的应用"""
|
||||||
|
# 本空间自有应用名
|
||||||
existing = {r[0] for r in self.db.query(App.name).filter(
|
existing = {r[0] for r in self.db.query(App.name).filter(
|
||||||
App.workspace_id == workspace_id,
|
App.workspace_id == workspace_id,
|
||||||
App.type == app_type,
|
App.type == app_type,
|
||||||
App.is_active.is_(True)
|
App.is_active.is_(True)
|
||||||
).all()}
|
).all()}
|
||||||
|
# 共享到本空间的应用名
|
||||||
|
shared_names = {r[0] for r in self.db.query(App.name).join(
|
||||||
|
AppShare, AppShare.source_app_id == App.id
|
||||||
|
).filter(
|
||||||
|
AppShare.target_workspace_id == workspace_id,
|
||||||
|
App.type == app_type,
|
||||||
|
App.is_active.is_(True)
|
||||||
|
).all()}
|
||||||
|
existing |= shared_names
|
||||||
if name not in existing:
|
if name not in existing:
|
||||||
return name
|
return name
|
||||||
counter = 1
|
counter = 1
|
||||||
|
|||||||
@@ -7,6 +7,7 @@
|
|||||||
- 应用发布和版本管理
|
- 应用发布和版本管理
|
||||||
- 应用回滚
|
- 应用回滚
|
||||||
"""
|
"""
|
||||||
|
import copy
|
||||||
import datetime
|
import datetime
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Annotated, Any, Dict, List, Optional, Tuple
|
from typing import Annotated, Any, Dict, List, Optional, Tuple
|
||||||
@@ -80,6 +81,8 @@ class AppService:
|
|||||||
)
|
)
|
||||||
raise BusinessException("应用不在指定工作空间中", BizCode.WORKSPACE_NO_ACCESS)
|
raise BusinessException("应用不在指定工作空间中", BizCode.WORKSPACE_NO_ACCESS)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _check_app_accessible(self, app: App, workspace_id: Optional[uuid.UUID]) -> bool:
|
def _check_app_accessible(self, app: App, workspace_id: Optional[uuid.UUID]) -> bool:
|
||||||
"""检查应用是否可访问(包括共享应用)
|
"""检查应用是否可访问(包括共享应用)
|
||||||
|
|
||||||
@@ -126,6 +129,28 @@ class AppService:
|
|||||||
)
|
)
|
||||||
raise BusinessException("应用不可访问", BizCode.WORKSPACE_NO_ACCESS)
|
raise BusinessException("应用不可访问", BizCode.WORKSPACE_NO_ACCESS)
|
||||||
|
|
||||||
|
def _unique_app_name(self, name: str, workspace_id: uuid.UUID, app_type: AppType) -> str:
    """Return an app name that is not already taken in the given workspace.

    A name counts as taken when an active app of the same type either belongs
    to the workspace or is shared into it. When ``name`` is free it is returned
    unchanged; otherwise the first free ``name(1)``, ``name(2)``, ... is returned.
    """
    # Active apps of this type that the workspace owns.
    owned_rows = (
        self.db.query(App.name)
        .filter(
            App.workspace_id == workspace_id,
            App.type == app_type,
            App.is_active.is_(True),
        )
        .all()
    )
    # Active apps of this type that other workspaces shared into this one.
    shared_rows = (
        self.db.query(App.name)
        .join(AppShare, AppShare.source_app_id == App.id)
        .filter(
            AppShare.target_workspace_id == workspace_id,
            App.type == app_type,
            App.is_active.is_(True),
        )
        .all()
    )
    taken = {row[0] for row in owned_rows} | {row[0] for row in shared_rows}

    if name not in taken:
        return name

    suffix = 1
    candidate = f"{name}({suffix})"
    while candidate in taken:
        suffix += 1
        candidate = f"{name}({suffix})"
    return candidate
|
||||||
|
|
||||||
def _get_share_permission(self, app: App, workspace_id: Optional[uuid.UUID]) -> Optional[str]:
|
def _get_share_permission(self, app: App, workspace_id: Optional[uuid.UUID]) -> Optional[str]:
|
||||||
"""获取共享应用的权限
|
"""获取共享应用的权限
|
||||||
|
|
||||||
@@ -148,11 +173,11 @@ class AppService:
|
|||||||
return share.permission if share else None
|
return share.permission if share else None
|
||||||
|
|
||||||
def _validate_app_writable(self, app: App, workspace_id: Optional[uuid.UUID]) -> None:
|
def _validate_app_writable(self, app: App, workspace_id: Optional[uuid.UUID]) -> None:
|
||||||
"""Validate that the app config is writable (owner only).
|
"""Validate that the app config is writable.
|
||||||
|
|
||||||
Shared apps (both readonly and editable) cannot modify config.
|
|
||||||
- Own workspace app: allowed
|
- Own workspace app: allowed
|
||||||
- Any shared app: denied
|
- Shared app with editable permission: allowed
|
||||||
|
- Shared app with readonly permission: denied
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
BusinessException: when app is not writable
|
BusinessException: when app is not writable
|
||||||
@@ -164,6 +189,11 @@ class AppService:
|
|||||||
if app.workspace_id == workspace_id:
|
if app.workspace_id == workspace_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# Check share permission
|
||||||
|
permission = self._get_share_permission(app, workspace_id)
|
||||||
|
if permission == "editable":
|
||||||
|
return
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"应用写操作被拒",
|
"应用写操作被拒",
|
||||||
extra={"app_id": str(app.id), "workspace_id": str(workspace_id)}
|
extra={"app_id": str(app.id), "workspace_id": str(workspace_id)}
|
||||||
@@ -360,6 +390,7 @@ class AppService:
|
|||||||
variables=storage_data.get("variables", []),
|
variables=storage_data.get("variables", []),
|
||||||
tools=storage_data.get("tools", []),
|
tools=storage_data.get("tools", []),
|
||||||
skills=storage_data.get("skills", {}),
|
skills=storage_data.get("skills", {}),
|
||||||
|
features=storage_data.get("features", {}),
|
||||||
is_active=True,
|
is_active=True,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
@@ -505,6 +536,10 @@ class AppService:
|
|||||||
source_workspace_icon = None
|
source_workspace_icon = None
|
||||||
source_app_version = None
|
source_app_version = None
|
||||||
source_app_is_active = None
|
source_app_is_active = None
|
||||||
|
share_id = None
|
||||||
|
shared_by = None
|
||||||
|
shared_by_name = None
|
||||||
|
shared_at = None
|
||||||
|
|
||||||
if is_shared:
|
if is_shared:
|
||||||
# 查询共享权限和来源工作空间名称
|
# 查询共享权限和来源工作空间名称
|
||||||
@@ -516,7 +551,12 @@ class AppService:
|
|||||||
)
|
)
|
||||||
share = self.db.scalars(stmt).first()
|
share = self.db.scalars(stmt).first()
|
||||||
if share:
|
if share:
|
||||||
|
share_id = share.id
|
||||||
share_permission = share.permission
|
share_permission = share.permission
|
||||||
|
shared_by = share.shared_by
|
||||||
|
shared_at = share.created_at
|
||||||
|
if share.shared_user:
|
||||||
|
shared_by_name = share.shared_user.username
|
||||||
if share.source_workspace:
|
if share.source_workspace:
|
||||||
source_workspace_name = share.source_workspace.name
|
source_workspace_name = share.source_workspace.name
|
||||||
source_workspace_icon = share.source_workspace.icon
|
source_workspace_icon = share.source_workspace.icon
|
||||||
@@ -546,6 +586,10 @@ class AppService:
|
|||||||
"source_workspace_icon": source_workspace_icon,
|
"source_workspace_icon": source_workspace_icon,
|
||||||
"source_app_version": source_app_version,
|
"source_app_version": source_app_version,
|
||||||
"source_app_is_active": source_app_is_active,
|
"source_app_is_active": source_app_is_active,
|
||||||
|
"share_id": share_id,
|
||||||
|
"shared_by": shared_by,
|
||||||
|
"shared_by_name": shared_by_name,
|
||||||
|
"shared_at": shared_at,
|
||||||
"created_at": app.created_at,
|
"created_at": app.created_at,
|
||||||
"updated_at": app.updated_at
|
"updated_at": app.updated_at
|
||||||
}
|
}
|
||||||
@@ -760,6 +804,7 @@ class AppService:
|
|||||||
# 确定新应用名称
|
# 确定新应用名称
|
||||||
if not new_name:
|
if not new_name:
|
||||||
new_name = f"{source_app.name} - 副本"
|
new_name = f"{source_app.name} - 副本"
|
||||||
|
new_name = self._unique_app_name(new_name, target_workspace_id, source_app.type)
|
||||||
|
|
||||||
now = datetime.datetime.now()
|
now = datetime.datetime.now()
|
||||||
|
|
||||||
@@ -783,6 +828,19 @@ class AppService:
|
|||||||
self.db.add(new_app)
|
self.db.add(new_app)
|
||||||
self.db.flush()
|
self.db.flush()
|
||||||
|
|
||||||
|
# 判断是否跨工作空间复制(共享应用复制到自己的工作空间)
|
||||||
|
is_cross_workspace = target_workspace_id != source_app.workspace_id
|
||||||
|
|
||||||
|
# 跨工作空间时,获取目标工作空间的 tenant_id 用于判断模型配置是否可用
|
||||||
|
target_tenant_id = None
|
||||||
|
available_model_ids: set = set()
|
||||||
|
available_kb_ids: set = set()
|
||||||
|
if is_cross_workspace:
|
||||||
|
target_ws = self.db.get(Workspace, target_workspace_id)
|
||||||
|
if not target_ws:
|
||||||
|
raise ResourceNotFoundException("工作空间", str(target_workspace_id))
|
||||||
|
target_tenant_id = target_ws.tenant_id
|
||||||
|
|
||||||
# 如果是 agent 类型,复制 AgentConfig
|
# 如果是 agent 类型,复制 AgentConfig
|
||||||
if source_app.type == AppType.AGENT:
|
if source_app.type == AppType.AGENT:
|
||||||
source_config = self.db.query(AgentConfig).filter(
|
source_config = self.db.query(AgentConfig).filter(
|
||||||
@@ -790,16 +848,40 @@ class AppService:
|
|||||||
).first()
|
).first()
|
||||||
|
|
||||||
if source_config:
|
if source_config:
|
||||||
|
if is_cross_workspace:
|
||||||
|
# Batch-collect and preload all referenced resources
|
||||||
|
model_ids, kb_ids = self._collect_resource_ids_from_config(
|
||||||
|
source_config.default_model_config_id,
|
||||||
|
source_config.knowledge_retrieval,
|
||||||
|
source_config.tools
|
||||||
|
)
|
||||||
|
available_model_ids, available_kb_ids = self._preload_cross_workspace_resources(
|
||||||
|
target_tenant_id, target_workspace_id, model_ids, kb_ids
|
||||||
|
)
|
||||||
|
new_model_config_id = self._is_model_available(
|
||||||
|
source_config.default_model_config_id, available_model_ids
|
||||||
|
)
|
||||||
|
new_knowledge_retrieval = self._clean_knowledge_retrieval(
|
||||||
|
source_config.knowledge_retrieval, available_kb_ids
|
||||||
|
)
|
||||||
|
new_tools = self._clean_tools(
|
||||||
|
source_config.tools, available_kb_ids
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_model_config_id = source_config.default_model_config_id
|
||||||
|
new_knowledge_retrieval = copy.deepcopy(source_config.knowledge_retrieval) if source_config.knowledge_retrieval else None
|
||||||
|
new_tools = copy.deepcopy(source_config.tools) if source_config.tools else []
|
||||||
|
|
||||||
new_config = AgentConfig(
|
new_config = AgentConfig(
|
||||||
id=uuid.uuid4(),
|
id=uuid.uuid4(),
|
||||||
app_id=new_app.id,
|
app_id=new_app.id,
|
||||||
system_prompt=source_config.system_prompt,
|
system_prompt=source_config.system_prompt,
|
||||||
default_model_config_id=source_config.default_model_config_id,
|
default_model_config_id=new_model_config_id,
|
||||||
model_parameters=source_config.model_parameters.copy() if source_config.model_parameters else None,
|
model_parameters=copy.deepcopy(source_config.model_parameters) if source_config.model_parameters else None,
|
||||||
knowledge_retrieval=source_config.knowledge_retrieval.copy() if source_config.knowledge_retrieval else None,
|
knowledge_retrieval=new_knowledge_retrieval,
|
||||||
memory=source_config.memory.copy() if source_config.memory else None,
|
memory=copy.deepcopy(source_config.memory) if source_config.memory else None,
|
||||||
variables=source_config.variables.copy() if source_config.variables else [],
|
variables=copy.deepcopy(source_config.variables) if source_config.variables else [],
|
||||||
tools=source_config.tools.copy() if source_config.tools else [],
|
tools=new_tools,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
@@ -812,14 +894,29 @@ class AppService:
|
|||||||
).first()
|
).first()
|
||||||
|
|
||||||
if source_config:
|
if source_config:
|
||||||
|
if is_cross_workspace:
|
||||||
|
model_ids, kb_ids = self._collect_resource_ids_from_workflow_nodes(
|
||||||
|
source_config.nodes
|
||||||
|
)
|
||||||
|
available_model_ids, available_kb_ids = self._preload_cross_workspace_resources(
|
||||||
|
target_tenant_id, target_workspace_id, model_ids, kb_ids
|
||||||
|
)
|
||||||
|
new_nodes = self._clean_workflow_nodes_for_cross_workspace(
|
||||||
|
source_config.nodes or [],
|
||||||
|
available_model_ids,
|
||||||
|
available_kb_ids
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_nodes = copy.deepcopy(source_config.nodes) if source_config.nodes else []
|
||||||
|
|
||||||
new_config = WorkflowConfig(
|
new_config = WorkflowConfig(
|
||||||
id=uuid.uuid4(),
|
id=uuid.uuid4(),
|
||||||
app_id=new_app.id,
|
app_id=new_app.id,
|
||||||
nodes=source_config.nodes.copy() if source_config.nodes else [],
|
nodes=new_nodes,
|
||||||
edges=source_config.edges.copy() if source_config.edges else [],
|
edges=copy.deepcopy(source_config.edges) if source_config.edges else [],
|
||||||
variables=source_config.variables.copy() if source_config.variables else [],
|
variables=copy.deepcopy(source_config.variables) if source_config.variables else [],
|
||||||
execution_config=source_config.execution_config.copy() if source_config.execution_config else {},
|
execution_config=copy.deepcopy(source_config.execution_config) if source_config.execution_config else {},
|
||||||
triggers=source_config.triggers.copy() if source_config.triggers else [],
|
triggers=copy.deepcopy(source_config.triggers) if source_config.triggers else [],
|
||||||
is_active=True,
|
is_active=True,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
@@ -832,17 +929,28 @@ class AppService:
|
|||||||
).first()
|
).first()
|
||||||
|
|
||||||
if source_config:
|
if source_config:
|
||||||
|
if is_cross_workspace:
|
||||||
|
model_ids = {source_config.default_model_config_id} if source_config.default_model_config_id else set()
|
||||||
|
available_model_ids, _ = self._preload_cross_workspace_resources(
|
||||||
|
target_tenant_id, target_workspace_id, model_ids, set()
|
||||||
|
)
|
||||||
|
new_model_config_id = self._is_model_available(
|
||||||
|
source_config.default_model_config_id, available_model_ids
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
new_model_config_id = source_config.default_model_config_id
|
||||||
|
|
||||||
new_config = MultiAgentConfig(
|
new_config = MultiAgentConfig(
|
||||||
id=uuid.uuid4(),
|
id=uuid.uuid4(),
|
||||||
app_id=new_app.id,
|
app_id=new_app.id,
|
||||||
master_agent_id=source_config.master_agent_id,
|
master_agent_id=source_config.master_agent_id if not is_cross_workspace else None,
|
||||||
master_agent_name=source_config.master_agent_name,
|
master_agent_name=source_config.master_agent_name,
|
||||||
default_model_config_id=source_config.default_model_config_id,
|
default_model_config_id=new_model_config_id,
|
||||||
model_parameters=source_config.model_parameters,
|
model_parameters=source_config.model_parameters,
|
||||||
orchestration_mode=source_config.orchestration_mode,
|
orchestration_mode=source_config.orchestration_mode,
|
||||||
sub_agents=source_config.sub_agents.copy() if source_config.sub_agents else [],
|
sub_agents=copy.deepcopy(source_config.sub_agents) if source_config.sub_agents else [],
|
||||||
routing_rules=source_config.routing_rules.copy() if source_config.routing_rules else None,
|
routing_rules=copy.deepcopy(source_config.routing_rules) if source_config.routing_rules else None,
|
||||||
execution_config=source_config.execution_config.copy() if source_config.execution_config else {},
|
execution_config=copy.deepcopy(source_config.execution_config) if source_config.execution_config else {},
|
||||||
aggregation_strategy=source_config.aggregation_strategy,
|
aggregation_strategy=source_config.aggregation_strategy,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
@@ -872,6 +980,241 @@ class AppService:
|
|||||||
)
|
)
|
||||||
raise BusinessException(f"应用复制失败: {str(e)}", BizCode.INTERNAL_ERROR, cause=e)
|
raise BusinessException(f"应用复制失败: {str(e)}", BizCode.INTERNAL_ERROR, cause=e)
|
||||||
|
|
||||||
|
def _preload_cross_workspace_resources(
    self,
    target_tenant_id: Optional[uuid.UUID],
    target_workspace_id: uuid.UUID,
    model_config_ids: set,
    kb_ids: set
) -> tuple:
    """Resolve, in bulk, which referenced resources the target workspace can use.

    Model configs are matched by tenant; knowledge bases are matched either by
    belonging to the target workspace or by a share record pointing at it.
    Each category is looked up with set-based queries rather than one query
    per ID.

    Args:
        target_tenant_id: Tenant of the target workspace; when falsy no model
            config is reported as available.
        target_workspace_id: Workspace the app is being copied into.
        model_config_ids: Model config IDs referenced by the source config.
        kb_ids: Knowledge base IDs (anything ``uuid.UUID(str(x))`` accepts);
            values that do not parse are ignored.

    Returns:
        ``(available_model_ids, available_kb_ids)`` as two sets of IDs.
    """
    from app.models.models_model import ModelConfig as MC
    from app.models.knowledge_model import Knowledge
    from app.models.knowledgeshare_model import KnowledgeShare

    # Model configs: usable when they belong to the target tenant.
    usable_models: set = set()
    if model_config_ids and target_tenant_id:
        model_query = select(MC.id).where(
            MC.id.in_(model_config_ids),
            MC.tenant_id == target_tenant_id
        )
        usable_models = set(self.db.scalars(model_query).all())

    usable_kbs: set = set()
    if not kb_ids:
        return usable_models, usable_kbs

    # Normalise the incoming IDs; unparseable ones can never match a row.
    parsed_ids = set()
    for raw_id in kb_ids:
        try:
            parsed_ids.add(uuid.UUID(str(raw_id)))
        except (ValueError, AttributeError):
            continue
    if not parsed_ids:
        return usable_models, usable_kbs

    # Knowledge bases that live in the target workspace.
    owned_query = select(Knowledge.id).where(
        Knowledge.id.in_(parsed_ids),
        Knowledge.workspace_id == target_workspace_id
    )
    usable_kbs.update(self.db.scalars(owned_query).all())

    # Whatever is still unresolved may have been shared into the workspace.
    unresolved = parsed_ids - usable_kbs
    if unresolved:
        shared_query = select(KnowledgeShare.source_kb_id).where(
            KnowledgeShare.source_kb_id.in_(unresolved),
            KnowledgeShare.target_workspace_id == target_workspace_id
        )
        usable_kbs.update(self.db.scalars(shared_query).all())

    return usable_models, usable_kbs
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _collect_resource_ids_from_config(
|
||||||
|
model_config_id: Optional[uuid.UUID],
|
||||||
|
knowledge_retrieval: Optional[dict],
|
||||||
|
tools: Optional[list]
|
||||||
|
) -> tuple:
|
||||||
|
"""Extract all model config IDs and knowledge base IDs from an app config."""
|
||||||
|
model_ids: set = set()
|
||||||
|
kb_ids: set = set()
|
||||||
|
|
||||||
|
if model_config_id:
|
||||||
|
model_ids.add(model_config_id)
|
||||||
|
|
||||||
|
if knowledge_retrieval and isinstance(knowledge_retrieval, dict):
|
||||||
|
if "kb_ids" in knowledge_retrieval:
|
||||||
|
for kid in knowledge_retrieval.get("kb_ids", []):
|
||||||
|
if kid:
|
||||||
|
kb_ids.add(str(kid))
|
||||||
|
if knowledge_retrieval.get("knowledge_id"):
|
||||||
|
kb_ids.add(str(knowledge_retrieval["knowledge_id"]))
|
||||||
|
|
||||||
|
if tools:
|
||||||
|
for tool in tools:
|
||||||
|
if isinstance(tool, dict):
|
||||||
|
kid = tool.get("knowledge_id") or tool.get("kb_id")
|
||||||
|
if kid:
|
||||||
|
kb_ids.add(str(kid))
|
||||||
|
|
||||||
|
return model_ids, kb_ids
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _collect_resource_ids_from_workflow_nodes(nodes: list) -> tuple:
|
||||||
|
"""Extract all model config IDs and knowledge base IDs from workflow nodes."""
|
||||||
|
model_ids: set = set()
|
||||||
|
kb_ids: set = set()
|
||||||
|
|
||||||
|
for node in (nodes or []):
|
||||||
|
if not isinstance(node, dict):
|
||||||
|
continue
|
||||||
|
data = node.get("data", {})
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
continue
|
||||||
|
for key in ("model_config_id", "default_model_config_id"):
|
||||||
|
val = data.get(key)
|
||||||
|
if val:
|
||||||
|
try:
|
||||||
|
model_ids.add(uuid.UUID(str(val)))
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
pass
|
||||||
|
kr = data.get("knowledge_retrieval")
|
||||||
|
if isinstance(kr, dict):
|
||||||
|
for kid in kr.get("kb_ids", []):
|
||||||
|
if kid:
|
||||||
|
kb_ids.add(str(kid))
|
||||||
|
if kr.get("knowledge_id"):
|
||||||
|
kb_ids.add(str(kr["knowledge_id"]))
|
||||||
|
if data.get("knowledge_id"):
|
||||||
|
kb_ids.add(str(data["knowledge_id"]))
|
||||||
|
for kid in data.get("kb_ids", []):
|
||||||
|
if kid:
|
||||||
|
kb_ids.add(str(kid))
|
||||||
|
|
||||||
|
return model_ids, kb_ids
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_model_available(model_config_id: Optional[uuid.UUID], available_model_ids: set) -> Optional[uuid.UUID]:
|
||||||
|
if not model_config_id:
|
||||||
|
return None
|
||||||
|
return model_config_id if model_config_id in available_model_ids else None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _is_kb_available(kb_id: Optional[str], available_kb_ids: set) -> Optional[str]:
|
||||||
|
if not kb_id:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return kb_id if uuid.UUID(str(kb_id)) in available_kb_ids else None
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _clean_knowledge_retrieval(
|
||||||
|
self,
|
||||||
|
knowledge_retrieval: Optional[dict],
|
||||||
|
available_kb_ids: set
|
||||||
|
) -> Optional[dict]:
|
||||||
|
"""Clean knowledge retrieval config, keeping only available KBs."""
|
||||||
|
if not knowledge_retrieval:
|
||||||
|
return None
|
||||||
|
|
||||||
|
cleaned = copy.deepcopy(knowledge_retrieval)
|
||||||
|
|
||||||
|
if "kb_ids" in cleaned and isinstance(cleaned["kb_ids"], list):
|
||||||
|
cleaned["kb_ids"] = [
|
||||||
|
kid for kid in cleaned["kb_ids"]
|
||||||
|
if self._is_kb_available(kid, available_kb_ids)
|
||||||
|
]
|
||||||
|
|
||||||
|
if "knowledge_id" in cleaned:
|
||||||
|
cleaned["knowledge_id"] = self._is_kb_available(
|
||||||
|
cleaned.get("knowledge_id"), available_kb_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
def _clean_tools(
|
||||||
|
self,
|
||||||
|
tools: Optional[list],
|
||||||
|
available_kb_ids: set
|
||||||
|
) -> list:
|
||||||
|
"""Clean tools config, keeping built-in tools and tools with available KBs."""
|
||||||
|
if not tools:
|
||||||
|
return []
|
||||||
|
|
||||||
|
cleaned = []
|
||||||
|
for tool in tools:
|
||||||
|
if not isinstance(tool, dict):
|
||||||
|
cleaned.append(tool)
|
||||||
|
continue
|
||||||
|
|
||||||
|
tool_type = tool.get("type", "")
|
||||||
|
if tool_type in ("builtin", "built_in", "system"):
|
||||||
|
cleaned.append(copy.deepcopy(tool))
|
||||||
|
continue
|
||||||
|
|
||||||
|
kb_id = tool.get("knowledge_id") or tool.get("kb_id")
|
||||||
|
if kb_id:
|
||||||
|
if self._is_kb_available(kb_id, available_kb_ids):
|
||||||
|
cleaned.append(copy.deepcopy(tool))
|
||||||
|
continue
|
||||||
|
|
||||||
|
cleaned.append(copy.deepcopy(tool))
|
||||||
|
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
def _clean_workflow_nodes_for_cross_workspace(
|
||||||
|
self,
|
||||||
|
nodes: list,
|
||||||
|
available_model_ids: set,
|
||||||
|
available_kb_ids: set
|
||||||
|
) -> list:
|
||||||
|
"""Clean workflow nodes, using pre-loaded resource sets. Uses deepcopy to avoid mutating source."""
|
||||||
|
if not nodes:
|
||||||
|
return []
|
||||||
|
|
||||||
|
cleaned = []
|
||||||
|
for node in nodes:
|
||||||
|
if not isinstance(node, dict):
|
||||||
|
cleaned.append(node)
|
||||||
|
continue
|
||||||
|
|
||||||
|
node_copy = copy.deepcopy(node)
|
||||||
|
data = node_copy.get("data")
|
||||||
|
if not isinstance(data, dict):
|
||||||
|
cleaned.append(node_copy)
|
||||||
|
continue
|
||||||
|
|
||||||
|
for key in ("model_config_id", "default_model_config_id"):
|
||||||
|
if key in data and data[key]:
|
||||||
|
try:
|
||||||
|
mid = uuid.UUID(str(data[key]))
|
||||||
|
except (ValueError, AttributeError):
|
||||||
|
data[key] = None
|
||||||
|
continue
|
||||||
|
data[key] = str(mid) if mid in available_model_ids else None
|
||||||
|
|
||||||
|
if "knowledge_retrieval" in data and data["knowledge_retrieval"]:
|
||||||
|
data["knowledge_retrieval"] = self._clean_knowledge_retrieval(
|
||||||
|
data["knowledge_retrieval"], available_kb_ids
|
||||||
|
)
|
||||||
|
if "knowledge_id" in data:
|
||||||
|
data["knowledge_id"] = self._is_kb_available(
|
||||||
|
data.get("knowledge_id"), available_kb_ids
|
||||||
|
)
|
||||||
|
if "kb_ids" in data and isinstance(data["kb_ids"], list):
|
||||||
|
data["kb_ids"] = [
|
||||||
|
kid for kid in data["kb_ids"]
|
||||||
|
if self._is_kb_available(kid, available_kb_ids)
|
||||||
|
]
|
||||||
|
|
||||||
|
cleaned.append(node_copy)
|
||||||
|
return cleaned
|
||||||
|
|
||||||
def list_apps(
|
def list_apps(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -1073,6 +1416,7 @@ class AppService:
|
|||||||
# if data.tools is not None:
|
# if data.tools is not None:
|
||||||
agent_cfg.tools = storage_data.get("tools", [])
|
agent_cfg.tools = storage_data.get("tools", [])
|
||||||
agent_cfg.skills = storage_data.get("skills", {})
|
agent_cfg.skills = storage_data.get("skills", {})
|
||||||
|
agent_cfg.features = storage_data.get("features", {})
|
||||||
|
|
||||||
agent_cfg.updated_at = now
|
agent_cfg.updated_at = now
|
||||||
|
|
||||||
@@ -1082,6 +1426,50 @@ class AppService:
|
|||||||
logger.info("Agent 配置更新成功", extra={"app_id": str(app_id)})
|
logger.info("Agent 配置更新成功", extra={"app_id": str(app_id)})
|
||||||
return agent_cfg
|
return agent_cfg
|
||||||
|
|
||||||
|
def _agent_config_from_release(self, release: "AppRelease") -> "AgentConfig":
    """Rebuild an ``AgentConfig`` from a release snapshot.

    The object is only constructed and returned; nothing in this method adds
    it to the session. Fields are read from ``release.config`` (an empty dict
    when missing); the model config ID comes from the release itself. Both
    timestamps use ``release.created_at``, falling back to the current time.
    """
    snapshot = release.config or {}
    stamp = release.created_at or datetime.datetime.now()
    return AgentConfig(
        id=uuid.uuid4(),
        app_id=release.app_id,
        system_prompt=snapshot.get("system_prompt", ""),
        default_model_config_id=release.default_model_config_id,
        model_parameters=snapshot.get("model_parameters"),
        knowledge_retrieval=snapshot.get("knowledge_retrieval"),
        memory=snapshot.get("memory", {}),
        variables=snapshot.get("variables", []),
        tools=snapshot.get("tools", []),
        skills=snapshot.get("skills", {}),
        features=snapshot.get("features", {}),
        is_active=True,
        created_at=stamp,
        updated_at=stamp,
    )
|
||||||
|
|
||||||
|
def _workflow_config_from_release(self, release: "AppRelease") -> "WorkflowConfig":
    """Rebuild a ``WorkflowConfig`` from a release snapshot.

    The object is only constructed and returned; nothing in this method adds
    it to the session. Fields are read from ``release.config`` (an empty dict
    when missing). Both timestamps use ``release.created_at``, falling back to
    the current time.

    The returned object reuses the ``id`` of the app's stored workflow config
    when one exists (a fresh UUID otherwise), so that records which reference
    a workflow config by ID (e.g. workflow executions) can point at a real row.
    """
    from app.models.workflow_model import WorkflowConfig as WorkflowConfigModel

    cfg = release.config or {}
    now = release.created_at or datetime.datetime.now()

    # Look up the source app's persisted config only to borrow its primary key.
    real_config = WorkflowConfigRepository(self.db).get_by_app_id(release.app_id)
    real_id = real_config.id if real_config else uuid.uuid4()

    return WorkflowConfigModel(
        id=real_id,
        app_id=release.app_id,
        nodes=cfg.get("nodes", []),
        edges=cfg.get("edges", []),
        variables=cfg.get("variables", []),
        execution_config=cfg.get("execution_config", {}),
        triggers=cfg.get("triggers", []),
        # Restore features like the agent rebuild does; without this the
        # transient object would carry no features at all.
        # NOTE(review): assumes the workflow release snapshot stores a
        # "features" key; falls back to {} when it does not.
        features=cfg.get("features", {}),
        is_active=True,
        created_at=now,
        updated_at=now,
    )
|
||||||
|
|
||||||
def get_agent_config(
|
def get_agent_config(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -1113,6 +1501,15 @@ class AppService:
|
|||||||
# 只读操作,允许访问共享应用
|
# 只读操作,允许访问共享应用
|
||||||
self._validate_app_accessible(app, workspace_id)
|
self._validate_app_accessible(app, workspace_id)
|
||||||
|
|
||||||
|
# 共享应用:返回最新发布版本的配置快照,而非草稿
|
||||||
|
if workspace_id and app.workspace_id != workspace_id:
|
||||||
|
if not app.current_release_id:
|
||||||
|
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
release = self.db.get(AppRelease, app.current_release_id)
|
||||||
|
if not release:
|
||||||
|
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
return self._agent_config_from_release(release)
|
||||||
|
|
||||||
stmt = select(AgentConfig).where(
|
stmt = select(AgentConfig).where(
|
||||||
AgentConfig.app_id == app_id,
|
AgentConfig.app_id == app_id,
|
||||||
AgentConfig.is_active.is_(True)
|
AgentConfig.is_active.is_(True)
|
||||||
@@ -1173,6 +1570,7 @@ class AppService:
|
|||||||
variables=[],
|
variables=[],
|
||||||
tools=[],
|
tools=[],
|
||||||
skills=[],
|
skills=[],
|
||||||
|
features={},
|
||||||
is_active=True,
|
is_active=True,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
@@ -1210,6 +1608,16 @@ class AppService:
|
|||||||
|
|
||||||
# 只读操作,允许访问共享应用
|
# 只读操作,允许访问共享应用
|
||||||
self._validate_app_accessible(app, workspace_id)
|
self._validate_app_accessible(app, workspace_id)
|
||||||
|
|
||||||
|
# 共享应用:返回最新发布版本的配置快照,而非草稿
|
||||||
|
if workspace_id and app.workspace_id != workspace_id:
|
||||||
|
if not app.current_release_id:
|
||||||
|
raise BusinessException("该应用尚未发布,无法使用", BizCode.CONFIG_MISSING)
|
||||||
|
release = self.db.get(AppRelease, app.current_release_id)
|
||||||
|
if not release:
|
||||||
|
raise BusinessException("发布版本不存在", BizCode.CONFIG_MISSING)
|
||||||
|
return self._workflow_config_from_release(release)
|
||||||
|
|
||||||
repo = WorkflowConfigRepository(self.db)
|
repo = WorkflowConfigRepository(self.db)
|
||||||
config = repo.get_by_app_id(app_id)
|
config = repo.get_by_app_id(app_id)
|
||||||
if config:
|
if config:
|
||||||
@@ -1264,6 +1672,7 @@ class AppService:
|
|||||||
variables=[var.model_dump() for var in data.variables] if data.variables else [],
|
variables=[var.model_dump() for var in data.variables] if data.variables else [],
|
||||||
execution_config=data.execution_config.model_dump() if data.execution_config else {},
|
execution_config=data.execution_config.model_dump() if data.execution_config else {},
|
||||||
triggers=[trigger.model_dump() for trigger in data.triggers] if data.triggers else [],
|
triggers=[trigger.model_dump() for trigger in data.triggers] if data.triggers else [],
|
||||||
|
features=data.features or {},
|
||||||
is_active=True,
|
is_active=True,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now
|
updated_at=now
|
||||||
@@ -1277,6 +1686,7 @@ class AppService:
|
|||||||
workflow_cfg.variables = [var.model_dump() for var in data.variables] if data.variables else []
|
workflow_cfg.variables = [var.model_dump() for var in data.variables] if data.variables else []
|
||||||
workflow_cfg.execution_config = data.execution_config.model_dump() if data.execution_config else {}
|
workflow_cfg.execution_config = data.execution_config.model_dump() if data.execution_config else {}
|
||||||
workflow_cfg.triggers = [trigger.model_dump() for trigger in data.triggers] if data.triggers else []
|
workflow_cfg.triggers = [trigger.model_dump() for trigger in data.triggers] if data.triggers else []
|
||||||
|
workflow_cfg.features = data.features or {}
|
||||||
workflow_cfg.updated_at = now
|
workflow_cfg.updated_at = now
|
||||||
|
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
@@ -1389,15 +1799,15 @@ class AppService:
|
|||||||
|
|
||||||
return config.config_id
|
return config.config_id
|
||||||
|
|
||||||
def _update_endusers_memory_config(
|
def _update_endusers_memory_config_by_workspace(
|
||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
memory_config_id: uuid.UUID
|
memory_config_id: uuid.UUID
|
||||||
) -> int:
|
) -> int:
|
||||||
"""批量更新应用下所有终端用户的 memory_config_id
|
"""批量更新应用下所有终端用户的 memory_config_id
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
app_id: 应用ID
|
workspace_id: 工作空间ID
|
||||||
memory_config_id: 新的记忆配置ID
|
memory_config_id: 新的记忆配置ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -1406,8 +1816,8 @@ class AppService:
|
|||||||
from app.repositories.end_user_repository import EndUserRepository
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
|
|
||||||
repo = EndUserRepository(self.db)
|
repo = EndUserRepository(self.db)
|
||||||
updated_count = repo.batch_update_memory_config_id(
|
updated_count = repo.batch_update_memory_config_id_by_workspace(
|
||||||
app_id=app_id,
|
workspace_id=workspace_id,
|
||||||
memory_config_id=memory_config_id
|
memory_config_id=memory_config_id
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1473,6 +1883,7 @@ class AppService:
|
|||||||
"variables": agent_cfg.variables or [],
|
"variables": agent_cfg.variables or [],
|
||||||
"tools": agent_cfg.tools or [],
|
"tools": agent_cfg.tools or [],
|
||||||
"skills": agent_cfg.skills or {},
|
"skills": agent_cfg.skills or {},
|
||||||
|
"features": agent_cfg.features or {}
|
||||||
}
|
}
|
||||||
# config = AgentConfigConverter.from_storage_format(agent_cfg)
|
# config = AgentConfigConverter.from_storage_format(agent_cfg)
|
||||||
default_model_config_id = agent_cfg.default_model_config_id
|
default_model_config_id = agent_cfg.default_model_config_id
|
||||||
@@ -1529,7 +1940,8 @@ class AppService:
|
|||||||
"edges": workflow_cfg.edges,
|
"edges": workflow_cfg.edges,
|
||||||
"variables": workflow_cfg.variables,
|
"variables": workflow_cfg.variables,
|
||||||
"execution_config": workflow_cfg.execution_config,
|
"execution_config": workflow_cfg.execution_config,
|
||||||
"triggers": workflow_cfg.triggers
|
"triggers": workflow_cfg.triggers,
|
||||||
|
"features": workflow_cfg.features or {}
|
||||||
}
|
}
|
||||||
|
|
||||||
is_valid, errors = WorkflowValidator.validate_for_publish(config)
|
is_valid, errors = WorkflowValidator.validate_for_publish(config)
|
||||||
@@ -1578,9 +1990,13 @@ class AppService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if memory_config_id:
|
if memory_config_id:
|
||||||
updated_count = self._update_endusers_memory_config(app_id, memory_config_id)
|
app = self.db.query(App).filter(App.id == app_id).first()
|
||||||
|
if app:
|
||||||
|
updated_count = self._update_endusers_memory_config_by_workspace(
|
||||||
|
app.workspace_id, memory_config_id
|
||||||
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"发布时更新终端用户记忆配置: app_id={app_id}, "
|
f"发布时更新终端用户记忆配置: app_id={app_id}, workspace_id={app.workspace_id}, "
|
||||||
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1712,7 +2128,8 @@ class AppService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if memory_config_id:
|
if memory_config_id:
|
||||||
updated_count = self._update_endusers_memory_config(app_id, memory_config_id)
|
|
||||||
|
updated_count = self._update_endusers_memory_config_by_workspace(app.workspace_id, memory_config_id)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"回滚时更新终端用户记忆配置: app_id={app_id}, version={version}, "
|
f"回滚时更新终端用户记忆配置: app_id={app_id}, version={version}, "
|
||||||
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
f"memory_config_id={memory_config_id}, updated_count={updated_count}"
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from sqlalchemy.orm import Session
|
|||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
from app.core.agent.agent_middleware import AgentMiddleware
|
from app.core.agent.agent_middleware import AgentMiddleware
|
||||||
from app.core.agent.langchain_agent import LangChainAgent
|
from app.core.agent.langchain_agent import LangChainAgent
|
||||||
|
from app.core.config import settings
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
@@ -36,6 +37,7 @@ from app.services.model_parameter_merger import ModelParameterMerger
|
|||||||
from app.services.model_service import ModelApiKeyService
|
from app.services.model_service import ModelApiKeyService
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
from app.services.tool_service import ToolService
|
from app.services.tool_service import ToolService
|
||||||
|
from app.schemas import FileType
|
||||||
|
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
|
|
||||||
@@ -98,7 +100,7 @@ def create_long_term_memory_tool(
|
|||||||
**重要:如果用户的问题可以直接回答,不要调用此工具。只在确实需要历史信息时才使用。**
|
**重要:如果用户的问题可以直接回答,不要调用此工具。只在确实需要历史信息时才使用。**
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
question: 需要检索的问题(保持原问题的核心语义,使用清晰的关键词)
|
question: 需要检索的问题(保持原问题的核心语义,使用清晰的关键词,第三人称描述的偏好、行为通常指用户本人,比如(我,本人,在下,自己,咱,鄙人,吴,余)通指用户)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
检索到的历史记忆内容
|
检索到的历史记忆内容
|
||||||
@@ -262,9 +264,12 @@ class AgentRunService:
|
|||||||
|
|
||||||
def load_tools_config(self, tools_config, web_search, tenant_id) -> list:
|
def load_tools_config(self, tools_config, web_search, tenant_id) -> list:
|
||||||
"""加载工具配置"""
|
"""加载工具配置"""
|
||||||
if not tools_config:
|
|
||||||
return []
|
|
||||||
tools = []
|
tools = []
|
||||||
|
if web_search:
|
||||||
|
search_tool = create_web_search_tool({})
|
||||||
|
tools.append(search_tool)
|
||||||
|
if not tools_config:
|
||||||
|
return tools
|
||||||
tool_service = ToolService(self.db)
|
tool_service = ToolService(self.db)
|
||||||
|
|
||||||
if tools_config and isinstance(tools_config, list):
|
if tools_config and isinstance(tools_config, list):
|
||||||
@@ -273,18 +278,9 @@ class AgentRunService:
|
|||||||
# 根据工具名称查找工具实例
|
# 根据工具名称查找工具实例
|
||||||
tool_instance = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
tool_instance = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id)
|
||||||
if tool_instance:
|
if tool_instance:
|
||||||
if tool_instance.name == "baidu_search_tool" and not web_search:
|
|
||||||
continue
|
|
||||||
# 转换为LangChain工具
|
# 转换为LangChain工具
|
||||||
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None))
|
||||||
tools.append(langchain_tool)
|
tools.append(langchain_tool)
|
||||||
elif tools_config and isinstance(tools_config, dict):
|
|
||||||
web_search_choice = tools_config.get("web_search", {})
|
|
||||||
web_search_enable = web_search_choice.get("enabled", False)
|
|
||||||
if web_search and web_search_enable:
|
|
||||||
search_tool = create_web_search_tool({})
|
|
||||||
tools.append(search_tool)
|
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"已添加网络搜索工具",
|
"已添加网络搜索工具",
|
||||||
extra={
|
extra={
|
||||||
@@ -373,6 +369,86 @@ class AgentRunService:
|
|||||||
)
|
)
|
||||||
return tools, bool(memory_config.get("enabled"))
|
return tools, bool(memory_config.get("enabled"))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _validate_file_upload(
|
||||||
|
features_config: Dict[str, Any],
|
||||||
|
files: Optional[List[FileInput]]
|
||||||
|
) -> None:
|
||||||
|
"""校验上传文件是否符合 file_upload 配置"""
|
||||||
|
if not files or not features_config:
|
||||||
|
return
|
||||||
|
fu = features_config.get("file_upload", {})
|
||||||
|
if not (isinstance(fu, dict) and fu.get("enabled")):
|
||||||
|
raise BusinessException("该应用未开启文件上传功能", BizCode.BAD_REQUEST)
|
||||||
|
max_count = fu.get("max_file_count", 5)
|
||||||
|
if len(files) > max_count:
|
||||||
|
raise BusinessException(f"文件数量超过限制(最多 {max_count} 个)", BizCode.BAD_REQUEST)
|
||||||
|
|
||||||
|
# 校验传输方式
|
||||||
|
allowed_methods = fu.get("allowed_transfer_methods", ["local_file", "remote_url"])
|
||||||
|
for f in files:
|
||||||
|
if f.transfer_method.value not in allowed_methods:
|
||||||
|
raise BusinessException(
|
||||||
|
f"不支持的文件传输方式:{f.transfer_method.value},允许的方式:{', '.join(allowed_methods)}",
|
||||||
|
BizCode.BAD_REQUEST
|
||||||
|
)
|
||||||
|
|
||||||
|
# 各类型对应的开关和大小限制配置键
|
||||||
|
type_cfg = {
|
||||||
|
"image": ("image_enabled", "image_max_size_mb", 20, "图片"),
|
||||||
|
"audio": ("audio_enabled", "audio_max_size_mb", 50, "音频"),
|
||||||
|
"document": ("document_enabled", "document_max_size_mb", 100, "文档"),
|
||||||
|
"video": ("video_enabled", "video_max_size_mb", 500, "视频"),
|
||||||
|
}
|
||||||
|
|
||||||
|
for f in files:
|
||||||
|
ftype = str(f.type) # 如 "image", "audio", "document", "video"
|
||||||
|
cfg = type_cfg.get(ftype)
|
||||||
|
if cfg is None:
|
||||||
|
continue
|
||||||
|
enabled_key, size_key, default_max_mb, label = cfg
|
||||||
|
|
||||||
|
# 校验类型开关
|
||||||
|
if not fu.get(enabled_key):
|
||||||
|
raise BusinessException(f"该应用未开启{label}文件上传", BizCode.BAD_REQUEST)
|
||||||
|
|
||||||
|
# 校验文件大小(仅当内容已加载时)
|
||||||
|
content = f.get_content()
|
||||||
|
if content is not None:
|
||||||
|
max_mb = fu.get(size_key, default_max_mb)
|
||||||
|
size_mb = len(content) / (1024 * 1024)
|
||||||
|
if size_mb > max_mb:
|
||||||
|
raise BusinessException(
|
||||||
|
f"{label}文件大小超过限制(最大 {max_mb}MB,当前 {size_mb:.1f}MB)",
|
||||||
|
BizCode.BAD_REQUEST
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _inject_opening_statement(
|
||||||
|
features_config: Dict[str, Any],
|
||||||
|
system_prompt: str,
|
||||||
|
is_new_conversation: bool
|
||||||
|
) -> str:
|
||||||
|
"""首轮对话时将开场白注入 system_prompt"""
|
||||||
|
if not is_new_conversation:
|
||||||
|
return system_prompt
|
||||||
|
opening = features_config.get("opening_statement", {})
|
||||||
|
if not (isinstance(opening, dict) and opening.get("enabled") and opening.get("statement")):
|
||||||
|
return system_prompt
|
||||||
|
statement = opening["statement"]
|
||||||
|
return f"{system_prompt}\n\n[对话开场白]\n{statement}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _filter_citations(
|
||||||
|
features_config: Dict[str, Any],
|
||||||
|
citations: List[Any]
|
||||||
|
) -> List[Any]:
|
||||||
|
"""根据 citation 开关决定是否返回引用来源"""
|
||||||
|
citation_cfg = features_config.get("citation", {})
|
||||||
|
if isinstance(citation_cfg, dict) and citation_cfg.get("enabled"):
|
||||||
|
return citations
|
||||||
|
return []
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -415,6 +491,15 @@ class AgentRunService:
|
|||||||
skills_config: dict | None = agent_config.skills
|
skills_config: dict | None = agent_config.skills
|
||||||
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
|
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
|
||||||
memory_config: dict | None = agent_config.memory
|
memory_config: dict | None = agent_config.memory
|
||||||
|
features_config: dict = agent_config.features or {}
|
||||||
|
|
||||||
|
# 从 features 中读取功能开关(优先级高于参数默认值)
|
||||||
|
web_search_feature = features_config.get("web_search", {})
|
||||||
|
if not isinstance(web_search_feature, dict) or not web_search_feature.get("enabled"):
|
||||||
|
web_search = False
|
||||||
|
|
||||||
|
# file_upload 校验
|
||||||
|
self._validate_file_upload(features_config, files)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. 获取 API Key 配置
|
# 1. 获取 API Key 配置
|
||||||
@@ -449,6 +534,10 @@ class AgentRunService:
|
|||||||
# 3. 处理系统提示词(支持变量替换)
|
# 3. 处理系统提示词(支持变量替换)
|
||||||
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
|
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
|
||||||
|
|
||||||
|
# opening_statement:首轮对话注入开场白
|
||||||
|
is_new_conversation = not conversation_id
|
||||||
|
system_prompt = self._inject_opening_statement(features_config, system_prompt, is_new_conversation)
|
||||||
|
|
||||||
# 4. 准备工具列表
|
# 4. 准备工具列表
|
||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
@@ -491,11 +580,9 @@ class AgentRunService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 6. 加载历史消息
|
# 6. 加载历史消息
|
||||||
history = []
|
|
||||||
if memory_config and memory_config.get("enabled"):
|
|
||||||
history = await self._load_conversation_history(
|
history = await self._load_conversation_history(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
max_history=agent_config.memory.get("max_history", 10)
|
max_history=10
|
||||||
)
|
)
|
||||||
|
|
||||||
# 6. 处理多模态文件
|
# 6. 处理多模态文件
|
||||||
@@ -550,8 +637,14 @@ class AgentRunService:
|
|||||||
|
|
||||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
|
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
|
||||||
|
|
||||||
# 9. 保存会话消息
|
# 9. 生成 TTS audio_url(在保存消息前生成,以便一并存入 meta_data)
|
||||||
if not sub_agent and memory_config and memory_config.get("enabled"):
|
audio_url = await self._generate_tts(
|
||||||
|
features_config, result["content"], api_key_config,
|
||||||
|
tenant_id=tenant_id, workspace_id=workspace_id
|
||||||
|
) if not sub_agent else None
|
||||||
|
|
||||||
|
# 10. 保存会话消息
|
||||||
|
if not sub_agent:
|
||||||
await self._save_conversation_message(
|
await self._save_conversation_message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
user_message=message,
|
user_message=message,
|
||||||
@@ -564,7 +657,9 @@ class AgentRunService:
|
|||||||
"completion_tokens": 0,
|
"completion_tokens": 0,
|
||||||
"total_tokens": 0
|
"total_tokens": 0
|
||||||
})
|
})
|
||||||
}
|
},
|
||||||
|
files=files,
|
||||||
|
audio_url=audio_url
|
||||||
)
|
)
|
||||||
|
|
||||||
response = {
|
response = {
|
||||||
@@ -575,7 +670,12 @@ class AgentRunService:
|
|||||||
"completion_tokens": 0,
|
"completion_tokens": 0,
|
||||||
"total_tokens": 0
|
"total_tokens": 0
|
||||||
}),
|
}),
|
||||||
"elapsed_time": elapsed_time
|
"elapsed_time": elapsed_time,
|
||||||
|
"suggested_questions": await self._generate_suggested_questions(
|
||||||
|
features_config, result["content"], api_key_config, effective_params
|
||||||
|
) if not sub_agent else [],
|
||||||
|
"citations": self._filter_citations(features_config, result.get("citations", [])),
|
||||||
|
"audio_url": audio_url,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -630,6 +730,15 @@ class AgentRunService:
|
|||||||
skills_config: dict | None = agent_config.skills
|
skills_config: dict | None = agent_config.skills
|
||||||
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
|
knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval
|
||||||
memory_config: dict | None = agent_config.memory
|
memory_config: dict | None = agent_config.memory
|
||||||
|
features_config: dict = agent_config.features or {}
|
||||||
|
|
||||||
|
# 从 features 中读取功能开关
|
||||||
|
web_search_feature = features_config.get("web_search", {})
|
||||||
|
if not (isinstance(web_search_feature, dict) and web_search_feature.get("enabled")):
|
||||||
|
web_search = False
|
||||||
|
|
||||||
|
# file_upload 校验
|
||||||
|
self._validate_file_upload(features_config, files)
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
@@ -659,6 +768,10 @@ class AgentRunService:
|
|||||||
# 3. 处理系统提示词(支持变量替换)
|
# 3. 处理系统提示词(支持变量替换)
|
||||||
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
|
system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手"
|
||||||
|
|
||||||
|
# opening_statement:首轮对话注入开场白
|
||||||
|
is_new_conversation = not conversation_id
|
||||||
|
system_prompt = self._inject_opening_statement(features_config, system_prompt, is_new_conversation)
|
||||||
|
|
||||||
# 4. 准备工具列表
|
# 4. 准备工具列表
|
||||||
tools = []
|
tools = []
|
||||||
|
|
||||||
@@ -703,8 +816,6 @@ class AgentRunService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 6. 加载历史消息
|
# 6. 加载历史消息
|
||||||
history = []
|
|
||||||
if memory_config and memory_config.get("enabled"):
|
|
||||||
history = await self._load_conversation_history(
|
history = await self._load_conversation_history(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
max_history=memory_config.get("max_history", 10)
|
max_history=memory_config.get("max_history", 10)
|
||||||
@@ -741,9 +852,18 @@ class AgentRunService:
|
|||||||
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
# 兼容新旧字段名:优先使用 memory_config_id,回退到 memory_content
|
||||||
config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None)
|
config_id = memory_config_.get("memory_config_id") or memory_config_.get("memory_content", None)
|
||||||
|
|
||||||
# 9. 流式调用 Agent(支持多模态)
|
# 9. 流式调用 Agent(支持多模态),同时并行启动 TTS
|
||||||
full_content = ""
|
full_content = ""
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
|
|
||||||
|
# 启动流式 TTS(文本边输出边合成)
|
||||||
|
text_queue: asyncio.Queue = asyncio.Queue()
|
||||||
|
stream_audio_url, tts_task = await self._generate_tts_streaming(
|
||||||
|
features_config, api_key_config,
|
||||||
|
text_queue=text_queue,
|
||||||
|
tenant_id=tenant_id, workspace_id=workspace_id
|
||||||
|
) if not sub_agent else (None, None)
|
||||||
|
|
||||||
async for chunk in agent.chat_stream(
|
async for chunk in agent.chat_stream(
|
||||||
message=message,
|
message=message,
|
||||||
history=history,
|
history=history,
|
||||||
@@ -753,28 +873,28 @@ class AgentRunService:
|
|||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
memory_flag=memory_flag,
|
memory_flag=memory_flag,
|
||||||
files=processed_files # 传递处理后的文件
|
files=processed_files
|
||||||
):
|
):
|
||||||
if isinstance(chunk, int):
|
if isinstance(chunk, int):
|
||||||
total_tokens = chunk
|
total_tokens = chunk
|
||||||
else:
|
else:
|
||||||
full_content += chunk
|
full_content += chunk
|
||||||
# 发送消息块事件
|
yield self._format_sse_event("message", {"content": chunk})
|
||||||
yield self._format_sse_event("message", {
|
if tts_task is not None:
|
||||||
"content": chunk
|
await text_queue.put(chunk)
|
||||||
})
|
|
||||||
|
# 文本结束,通知 TTS
|
||||||
|
if tts_task is not None:
|
||||||
|
await text_queue.put(None)
|
||||||
|
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
|
|
||||||
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
|
ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id"))
|
||||||
|
|
||||||
if sub_agent:
|
if sub_agent:
|
||||||
yield self._format_sse_event("sub_usage", {
|
yield self._format_sse_event("sub_usage", {"total_tokens": total_tokens})
|
||||||
"total_tokens": total_tokens
|
|
||||||
})
|
|
||||||
|
|
||||||
# 10. 保存会话消息
|
# 11. 保存会话消息
|
||||||
if not sub_agent and memory_config and memory_config.get("enabled"):
|
if not sub_agent:
|
||||||
await self._save_conversation_message(
|
await self._save_conversation_message(
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
user_message=message,
|
user_message=message,
|
||||||
@@ -783,15 +903,24 @@ class AgentRunService:
|
|||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
meta_data={
|
meta_data={
|
||||||
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
|
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
|
||||||
}
|
},
|
||||||
|
files=files,
|
||||||
|
audio_url=stream_audio_url
|
||||||
)
|
)
|
||||||
|
|
||||||
# 11. 发送结束事件
|
# 12. 发送结束事件(包含 suggested_questions 和 tts)
|
||||||
yield self._format_sse_event("end", {
|
end_data: Dict[str, Any] = {
|
||||||
"conversation_id": conversation_id,
|
"conversation_id": conversation_id,
|
||||||
"elapsed_time": elapsed_time,
|
"elapsed_time": elapsed_time,
|
||||||
"message_length": len(full_content)
|
"message_length": len(full_content)
|
||||||
})
|
}
|
||||||
|
if not sub_agent:
|
||||||
|
end_data["suggested_questions"] = await self._generate_suggested_questions(
|
||||||
|
features_config, full_content, api_key_config, effective_params
|
||||||
|
)
|
||||||
|
end_data["audio_url"] = stream_audio_url
|
||||||
|
end_data["citations"] = self._filter_citations(features_config, [])
|
||||||
|
yield self._format_sse_event("end", end_data)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"流式试运行完成",
|
"流式试运行完成",
|
||||||
@@ -1028,7 +1157,9 @@ class AgentRunService:
|
|||||||
assistant_message: str,
|
assistant_message: str,
|
||||||
meta_data: dict,
|
meta_data: dict,
|
||||||
app_id: Optional[uuid.UUID] = None,
|
app_id: Optional[uuid.UUID] = None,
|
||||||
user_id: Optional[str] = None
|
user_id: Optional[str] = None,
|
||||||
|
files: Optional[List[FileInput]] = None,
|
||||||
|
audio_url: Optional[str] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
"""保存会话消息(会话已通过 _ensure_conversation 确保存在)
|
"""保存会话消息(会话已通过 _ensure_conversation 确保存在)
|
||||||
|
|
||||||
@@ -1047,13 +1178,26 @@ class AgentRunService:
|
|||||||
conv_uuid = uuid.UUID(conversation_id)
|
conv_uuid = uuid.UUID(conversation_id)
|
||||||
|
|
||||||
# 保存消息(会话已经存在)
|
# 保存消息(会话已经存在)
|
||||||
|
human_meta = {
|
||||||
|
"files": []
|
||||||
|
}
|
||||||
|
if files:
|
||||||
|
for f in files:
|
||||||
|
# url = await MultimodalService(self.db).get_file_url(f)
|
||||||
|
human_meta["files"].append({
|
||||||
|
"type": f.type,
|
||||||
|
"url": f.url
|
||||||
|
})
|
||||||
# 保存用户消息
|
# 保存用户消息
|
||||||
conversation_service.add_message(
|
conversation_service.add_message(
|
||||||
conversation_id=conv_uuid,
|
conversation_id=conv_uuid,
|
||||||
role="user",
|
role="user",
|
||||||
content=user_message
|
content=user_message,
|
||||||
|
meta_data=human_meta
|
||||||
)
|
)
|
||||||
# 保存助手消息
|
# 保存助手消息(含 audio_url)
|
||||||
|
if audio_url:
|
||||||
|
meta_data["audio_url"] = audio_url
|
||||||
conversation_service.add_message(
|
conversation_service.add_message(
|
||||||
conversation_id=conv_uuid,
|
conversation_id=conv_uuid,
|
||||||
role="assistant",
|
role="assistant",
|
||||||
@@ -1137,6 +1281,385 @@ class AgentRunService:
|
|||||||
logger.debug("获取配置快照失败(可能是多 Agent 应用)", exc_info=True, extra={"error": str(e)})
|
logger.debug("获取配置快照失败(可能是多 Agent 应用)", exc_info=True, extra={"error": str(e)})
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
async def _generate_suggested_questions(
|
||||||
|
self,
|
||||||
|
features_config: Dict[str, Any],
|
||||||
|
assistant_message: str,
|
||||||
|
api_key_config: Dict[str, Any],
|
||||||
|
effective_params: Dict[str, Any]
|
||||||
|
) -> List[str]:
|
||||||
|
"""根据 suggested_questions_after_answer 配置生成下一步建议问题"""
|
||||||
|
sq_config = features_config.get("suggested_questions_after_answer", {})
|
||||||
|
if not isinstance(sq_config, dict) or not sq_config.get("enabled"):
|
||||||
|
return []
|
||||||
|
try:
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
llm = ChatOpenAI(
|
||||||
|
model=api_key_config["model_name"],
|
||||||
|
api_key=api_key_config["api_key"],
|
||||||
|
base_url=api_key_config.get("api_base"),
|
||||||
|
temperature=0.5,
|
||||||
|
max_tokens=200,
|
||||||
|
)
|
||||||
|
prompt = (
|
||||||
|
f"根据以下AI回复,生成3个用户可能继续追问的简短问题,每行一个,不加序号:\n\n{assistant_message}"
|
||||||
|
)
|
||||||
|
resp = await llm.ainvoke([HumanMessage(content=prompt)])
|
||||||
|
lines = [l.strip() for l in resp.content.strip().split("\n") if l.strip()]
|
||||||
|
return lines[:3]
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"生成建议问题失败: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _generate_tts(
|
||||||
|
self,
|
||||||
|
features_config: Dict[str, Any],
|
||||||
|
text: str,
|
||||||
|
api_key_config: Dict[str, Any],
|
||||||
|
tenant_id: Optional[uuid.UUID] = None,
|
||||||
|
workspace_id: Optional[uuid.UUID] = None,
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""先注册文件元数据并返回 audio_url,再后台流式写入音频内容"""
|
||||||
|
tts_config = features_config.get("text_to_speech", {})
|
||||||
|
if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
|
||||||
|
return None
|
||||||
|
if not text or not text.strip():
|
||||||
|
return None
|
||||||
|
|
||||||
|
from app.models.file_metadata_model import FileMetadata
|
||||||
|
from app.services.file_storage_service import FileStorageService, generate_file_key
|
||||||
|
|
||||||
|
provider = api_key_config.get("provider", "openai")
|
||||||
|
api_key = api_key_config.get("api_key")
|
||||||
|
api_base = api_key_config.get("api_base")
|
||||||
|
voice = tts_config.get("voice")
|
||||||
|
file_ext, content_type = ".mp3", "audio/mpeg"
|
||||||
|
|
||||||
|
file_id = uuid.uuid4()
|
||||||
|
file_key = generate_file_key(tenant_id, workspace_id, file_id, file_ext)
|
||||||
|
|
||||||
|
# 先写入 pending 状态的元数据,立即返回 URL
|
||||||
|
db_file = FileMetadata(
|
||||||
|
id=file_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
file_key=file_key,
|
||||||
|
file_name=f"tts_{file_id}{file_ext}",
|
||||||
|
file_ext=file_ext,
|
||||||
|
file_size=0,
|
||||||
|
content_type=content_type,
|
||||||
|
status="pending",
|
||||||
|
)
|
||||||
|
self.db.add(db_file)
|
||||||
|
self.db.commit()
|
||||||
|
|
||||||
|
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||||
|
audio_url = f"{server_url}/storage/permanent/{file_id}"
|
||||||
|
|
||||||
|
# 后台任务:流式生成并写入存储,完成后更新状态
|
||||||
|
async def _stream_to_storage():
|
||||||
|
try:
|
||||||
|
storage_service = FileStorageService()
|
||||||
|
if provider == "dashscope":
|
||||||
|
stream = self._tts_dashscope_stream(
|
||||||
|
api_key=api_key,
|
||||||
|
text=text,
|
||||||
|
voice=voice or "longxiaochun",
|
||||||
|
tts_config=tts_config,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
stream = self._tts_openai_stream(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
text=text,
|
||||||
|
voice=voice or "alloy",
|
||||||
|
)
|
||||||
|
|
||||||
|
total_size = await storage_service.upload_stream(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
file_id=file_id,
|
||||||
|
file_ext=file_ext,
|
||||||
|
stream=stream,
|
||||||
|
content_type=content_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新元数据状态
|
||||||
|
with get_db_context() as bg_db:
|
||||||
|
record = bg_db.get(FileMetadata, file_id)
|
||||||
|
if record:
|
||||||
|
record.status = "completed"
|
||||||
|
record.file_size = total_size
|
||||||
|
bg_db.commit()
|
||||||
|
logger.debug(f"TTS 流式写入完成,provider={provider}, file_key={file_key}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"TTS 流式写入失败: {e}")
|
||||||
|
with get_db_context() as bg_db:
|
||||||
|
record = bg_db.get(FileMetadata, file_id)
|
||||||
|
if record:
|
||||||
|
record.status = "failed"
|
||||||
|
bg_db.commit()
|
||||||
|
|
||||||
|
asyncio.create_task(_stream_to_storage())
|
||||||
|
return audio_url
|
||||||
|
|
||||||
|
async def _generate_tts_streaming(
|
||||||
|
self,
|
||||||
|
features_config: Dict[str, Any],
|
||||||
|
api_key_config: Dict[str, Any],
|
||||||
|
text_queue: asyncio.Queue,
|
||||||
|
tenant_id: Optional[uuid.UUID] = None,
|
||||||
|
workspace_id: Optional[uuid.UUID] = None,
|
||||||
|
) -> tuple[Optional[str], Optional[asyncio.Task]]:
|
||||||
|
"""文本流式输入并行合成音频。
|
||||||
|
返回 (audio_url, task),audio_url 立即可用,task 完成后文件内容就绪。
|
||||||
|
调用方向 text_queue put 文本 chunk,结束时 put None。
|
||||||
|
"""
|
||||||
|
tts_config = features_config.get("text_to_speech", {})
|
||||||
|
if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
from app.models.file_metadata_model import FileMetadata
|
||||||
|
from app.services.file_storage_service import FileStorageService, generate_file_key
|
||||||
|
|
||||||
|
provider = api_key_config.get("provider", "openai")
|
||||||
|
api_key = api_key_config.get("api_key")
|
||||||
|
api_base = api_key_config.get("api_base")
|
||||||
|
voice = tts_config.get("voice")
|
||||||
|
file_ext, content_type = ".mp3", "audio/mpeg"
|
||||||
|
|
||||||
|
file_id = uuid.uuid4()
|
||||||
|
file_key = generate_file_key(tenant_id, workspace_id, file_id, file_ext)
|
||||||
|
|
||||||
|
db_file = FileMetadata(
|
||||||
|
id=file_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
file_key=file_key,
|
||||||
|
file_name=f"tts_{file_id}{file_ext}",
|
||||||
|
file_ext=file_ext,
|
||||||
|
file_size=0,
|
||||||
|
content_type=content_type,
|
||||||
|
status="pending",
|
||||||
|
)
|
||||||
|
self.db.add(db_file)
|
||||||
|
self.db.commit()
|
||||||
|
|
||||||
|
server_url = settings.FILE_LOCAL_SERVER_URL
|
||||||
|
audio_url = f"{server_url}/storage/permanent/{file_id}"
|
||||||
|
|
||||||
|
async def _run():
|
||||||
|
try:
|
||||||
|
storage_service = FileStorageService()
|
||||||
|
if provider == "dashscope":
|
||||||
|
audio_stream = self._tts_dashscope_stream_from_queue(
|
||||||
|
api_key=api_key,
|
||||||
|
voice=voice or "longxiaochun",
|
||||||
|
tts_config=tts_config,
|
||||||
|
text_queue=text_queue,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
audio_stream = self._tts_openai_stream_from_queue(
|
||||||
|
api_key=api_key,
|
||||||
|
api_base=api_base,
|
||||||
|
voice=voice or "alloy",
|
||||||
|
text_queue=text_queue,
|
||||||
|
)
|
||||||
|
total_size = await storage_service.upload_stream(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
file_id=file_id,
|
||||||
|
file_ext=file_ext,
|
||||||
|
stream=audio_stream,
|
||||||
|
content_type=content_type,
|
||||||
|
)
|
||||||
|
with get_db_context() as bg_db:
|
||||||
|
record = bg_db.get(FileMetadata, file_id)
|
||||||
|
if record:
|
||||||
|
record.status = "completed"
|
||||||
|
record.file_size = total_size
|
||||||
|
bg_db.commit()
|
||||||
|
logger.debug(f"TTS 流式合成完成,provider={provider}, file_key={file_key}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"TTS 流式合成失败: {e}")
|
||||||
|
with get_db_context() as bg_db:
|
||||||
|
record = bg_db.get(FileMetadata, file_id)
|
||||||
|
if record:
|
||||||
|
record.status = "failed"
|
||||||
|
bg_db.commit()
|
||||||
|
|
||||||
|
task = asyncio.create_task(_run())
|
||||||
|
return audio_url, task
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _tts_openai_stream_from_queue(
|
||||||
|
api_key: str,
|
||||||
|
api_base: Optional[str],
|
||||||
|
voice: str,
|
||||||
|
text_queue: asyncio.Queue,
|
||||||
|
):
|
||||||
|
"""OpenAI TTS:收集全部文本后流式合成(OpenAI 不支持增量输入)"""
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
# 收集全部文本(此时文本流已并行输出,等待时间短)
|
||||||
|
parts = []
|
||||||
|
while True:
|
||||||
|
chunk = await text_queue.get()
|
||||||
|
if chunk is None:
|
||||||
|
break
|
||||||
|
parts.append(chunk)
|
||||||
|
full_text = "".join(parts)
|
||||||
|
if not full_text.strip():
|
||||||
|
return
|
||||||
|
client = AsyncOpenAI(api_key=api_key, base_url=api_base)
|
||||||
|
async with client.audio.speech.with_streaming_response.create(
|
||||||
|
model="tts-1",
|
||||||
|
voice=voice,
|
||||||
|
input=full_text[:4096],
|
||||||
|
) as response:
|
||||||
|
async for chunk in response.iter_bytes(chunk_size=4096):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
@staticmethod
async def _tts_dashscope_stream_from_queue(
    api_key: str,
    voice: str,
    tts_config: Dict[str, Any],
    text_queue: asyncio.Queue,
):
    """DashScope TTS with streaming text input, so synthesis overlaps text generation.

    Text chunks are read from ``text_queue`` (terminated by ``None``), split on
    sentence-ending punctuation and fed to the synthesizer from a background
    task; audio delivered through the synthesizer callbacks is re-yielded here.

    Args:
        api_key: DashScope API key (assigned to the module-level ``dashscope.api_key``).
        voice: Voice name; a ``_v2`` suffix is added or removed to match the model.
        tts_config: TTS settings; ``model`` is read, defaulting to ``cosyvoice-v2``.
        text_queue: Queue of text chunks terminated by ``None``.

    Yields:
        Audio byte chunks as they are produced.

    Raises:
        RuntimeError: When the synthesizer reports an error via ``on_error``.
        Exception: Any exception raised while feeding text to the synthesizer.
    """
    import re

    import dashscope
    from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat, ResultCallback

    # Keep the voice suffix consistent with the model generation.
    model = tts_config.get("model") or "cosyvoice-v2"
    is_v2 = model.endswith("-v2")
    if is_v2 and not voice.endswith("_v2"):
        voice = voice + "_v2"
    elif not is_v2 and voice.endswith("_v2"):
        voice = voice[:-3]

    audio_queue: asyncio.Queue = asyncio.Queue()
    # We are inside a coroutine, so the running loop is the one to hand
    # results back to from the synthesizer's callback thread.
    loop = asyncio.get_running_loop()

    class _Callback(ResultCallback):
        """Forwards synthesizer events onto ``audio_queue`` via the event loop."""

        def on_data(self, data: bytes):
            if data:
                loop.call_soon_threadsafe(audio_queue.put_nowait, data)

        def on_complete(self):
            # None marks the end of the audio stream.
            loop.call_soon_threadsafe(audio_queue.put_nowait, None)

        def on_error(self, message):
            loop.call_soon_threadsafe(audio_queue.put_nowait, RuntimeError(str(message)))

        def on_open(self): pass

        def on_close(self): pass

    dashscope.api_key = api_key
    synthesizer = SpeechSynthesizer(
        model=model,
        voice=voice,
        format=AudioFormat.MP3_22050HZ_MONO_256KBPS,
        callback=_Callback(),
    )

    sentence_end = re.compile(r'[\u3002\uff01\uff1f\.!?\n]')

    async def _feed_text():
        """Read text from text_queue and feed it to the synthesizer sentence by sentence."""
        buf = ""
        while True:
            chunk = await text_queue.get()
            if chunk is None:
                # Flush whatever is left, then signal end of input.
                if buf.strip():
                    await asyncio.to_thread(synthesizer.streaming_call, buf)
                await asyncio.to_thread(synthesizer.streaming_complete)
                break
            buf += chunk
            # Feed every complete sentence currently in the buffer.
            m = sentence_end.search(buf)
            while m:
                sentence, buf = buf[:m.end()], buf[m.end():]
                await asyncio.to_thread(synthesizer.streaming_call, sentence)
                m = sentence_end.search(buf)

    def _on_feed_done(task):
        # Surface feeder failures to the consumer; without this the loop
        # below would wait on audio_queue forever if feeding raised.
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            audio_queue.put_nowait(exc)

    # Keep a reference: the event loop only holds weak references to tasks,
    # so an unreferenced task may be garbage-collected mid-execution.
    feed_task = asyncio.create_task(_feed_text())
    feed_task.add_done_callback(_on_feed_done)

    try:
        while True:
            item = await audio_queue.get()
            if item is None:
                break
            if isinstance(item, BaseException):
                raise item
            yield item
    finally:
        # Stop feeding if the consumer bailed out early or an error occurred.
        if not feed_task.done():
            feed_task.cancel()
|
||||||
|
|
||||||
|
@staticmethod
async def _tts_openai_stream(
    api_key: str,
    api_base: Optional[str],
    text: str,
    voice: str,
):
    """Stream speech for ``text`` from an OpenAI-compatible TTS endpoint.

    Args:
        api_key: API key passed to ``AsyncOpenAI``.
        api_base: Optional base URL passed to ``AsyncOpenAI``.
        text: Text to synthesize; only the first 4096 characters are sent.
        voice: Voice name forwarded to the speech endpoint.

    Yields:
        Audio byte chunks, read from the response 4096 bytes at a time.
    """
    from openai import AsyncOpenAI

    speech_api = AsyncOpenAI(api_key=api_key, base_url=api_base).audio.speech
    streaming_request = speech_api.with_streaming_response.create(
        model="tts-1",
        voice=voice,
        input=text[:4096],
    )
    async with streaming_request as response:
        async for audio_bytes in response.iter_bytes(chunk_size=4096):
            yield audio_bytes
|
||||||
|
|
||||||
|
@staticmethod
async def _tts_dashscope_stream(
    api_key: str,
    text: str,
    voice: str,
    tts_config: Dict[str, Any],
):
    """Stream speech for ``text`` from DashScope TTS.

    The blocking synthesizer calls run in a worker thread; audio delivered
    through the synthesizer callbacks is handed to the event loop via a queue
    and re-yielded here.

    Args:
        api_key: DashScope API key (assigned to the module-level ``dashscope.api_key``).
        text: Text to synthesize; only the first 4096 characters are sent.
        voice: Voice name; a ``_v2`` suffix is added or removed to match the model.
        tts_config: TTS settings; ``model`` is read, defaulting to ``cosyvoice-v2``.

    Yields:
        Audio byte chunks as they are produced.

    Raises:
        RuntimeError: When the synthesizer reports an error via ``on_error``.
        Exception: Any exception raised by the worker thread itself.
    """
    import dashscope
    from dashscope.audio.tts_v2 import SpeechSynthesizer, AudioFormat, ResultCallback

    # Keep the voice suffix consistent with the model generation.
    model = tts_config.get("model") or "cosyvoice-v2"
    is_v2 = model.endswith("-v2")
    if is_v2 and not voice.endswith("_v2"):
        voice = voice + "_v2"
    elif not is_v2 and voice.endswith("_v2"):
        voice = voice[:-3]

    queue: asyncio.Queue = asyncio.Queue()
    # We are inside a coroutine, so the running loop is the one to hand
    # results back to from the worker/callback thread.
    loop = asyncio.get_running_loop()

    class _Callback(ResultCallback):
        """Forwards synthesizer events onto ``queue`` via the event loop."""

        def on_data(self, data: bytes):
            if data:
                loop.call_soon_threadsafe(queue.put_nowait, data)

        def on_complete(self):
            # None marks the end of the audio stream.
            loop.call_soon_threadsafe(queue.put_nowait, None)

        def on_error(self, message):
            loop.call_soon_threadsafe(queue.put_nowait, RuntimeError(str(message)))

        def on_open(self): pass

        def on_close(self): pass

    def _sync_stream():
        dashscope.api_key = api_key
        synthesizer = SpeechSynthesizer(
            model=model,
            voice=voice,
            format=AudioFormat.MP3_22050HZ_MONO_256KBPS,
            callback=_Callback(),
        )
        synthesizer.streaming_call(text[:4096])
        synthesizer.streaming_complete()

    def _on_worker_done(task):
        # Surface worker failures to the consumer; without this the loop
        # below would wait on the queue forever if the worker raised.
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            queue.put_nowait(exc)

    # Keep a reference: the event loop only holds weak references to tasks,
    # so an unreferenced task may be garbage-collected mid-execution.
    worker = asyncio.create_task(asyncio.to_thread(_sync_stream))
    worker.add_done_callback(_on_worker_done)

    while True:
        item = await queue.get()
        if item is None:
            break
        if isinstance(item, BaseException):
            raise item
        yield item
|
||||||
|
|
||||||
def _replace_variables(
|
def _replace_variables(
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
@@ -1221,6 +1744,12 @@ class AgentRunService:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 提前校验文件上传(与 run() 内部保持一致)
|
||||||
|
features_config: dict = agent_config.features or {}
|
||||||
|
if hasattr(features_config, 'model_dump'):
|
||||||
|
features_config = features_config.model_dump()
|
||||||
|
# self._validate_file_upload(features_config, files)
|
||||||
|
|
||||||
async def run_single_model(model_info):
|
async def run_single_model(model_info):
|
||||||
"""运行单个模型"""
|
"""运行单个模型"""
|
||||||
try:
|
try:
|
||||||
@@ -1271,6 +1800,9 @@ class AgentRunService:
|
|||||||
if elapsed > 0 and usage.get("completion_tokens") else None
|
if elapsed > 0 and usage.get("completion_tokens") else None
|
||||||
),
|
),
|
||||||
"cost_estimate": self._estimate_cost(usage, model_info["model_config"]),
|
"cost_estimate": self._estimate_cost(usage, model_info["model_config"]),
|
||||||
|
"audio_url": result.get("audio_url"),
|
||||||
|
"citations": result.get("citations", []),
|
||||||
|
"suggested_questions": result.get("suggested_questions", []),
|
||||||
"error": None
|
"error": None
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1343,7 +1875,12 @@ class AgentRunService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"results": results,
|
"results": [{
|
||||||
|
**r,
|
||||||
|
"audio_url": r.get("audio_url"),
|
||||||
|
"citations": r.get("citations", []),
|
||||||
|
"suggested_questions": r.get("suggested_questions", []),
|
||||||
|
} for r in results],
|
||||||
"total_elapsed_time": sum(r.get("elapsed_time", 0) for r in results),
|
"total_elapsed_time": sum(r.get("elapsed_time", 0) for r in results),
|
||||||
"successful_count": len(successful),
|
"successful_count": len(successful),
|
||||||
"failed_count": len(failed),
|
"failed_count": len(failed),
|
||||||
@@ -1434,6 +1971,12 @@ class AgentRunService:
|
|||||||
extra={"model_count": len(models), "parallel": parallel}
|
extra={"model_count": len(models), "parallel": parallel}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 提前校验文件上传
|
||||||
|
# features_config: dict = agent_config.features or {}
|
||||||
|
# if hasattr(features_config, 'model_dump'):
|
||||||
|
# features_config = features_config.model_dump()
|
||||||
|
# self._validate_file_upload(features_config, files)
|
||||||
|
|
||||||
# 发送开始事件
|
# 发送开始事件
|
||||||
yield self._format_sse_event("compare_start", {
|
yield self._format_sse_event("compare_start", {
|
||||||
"conversation_id": conversation_id,
|
"conversation_id": conversation_id,
|
||||||
@@ -1465,6 +2008,9 @@ class AgentRunService:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
full_content = ""
|
full_content = ""
|
||||||
returned_conversation_id = model_conversation_id
|
returned_conversation_id = model_conversation_id
|
||||||
|
audio_url = None
|
||||||
|
citations = []
|
||||||
|
suggested_questions = []
|
||||||
|
|
||||||
# 临时修改参数
|
# 临时修改参数
|
||||||
original_params = agent_config.model_parameters
|
original_params = agent_config.model_parameters
|
||||||
@@ -1518,6 +2064,12 @@ class AgentRunService:
|
|||||||
"content": chunk
|
"content": chunk
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
# 从 end 事件中提取 features 输出字段
|
||||||
|
if event_type == "end" and event_data:
|
||||||
|
audio_url = event_data.get("audio_url")
|
||||||
|
citations = event_data.get("citations", [])
|
||||||
|
suggested_questions = event_data.get("suggested_questions", [])
|
||||||
|
|
||||||
if event_type == "error" and event_data:
|
if event_type == "error" and event_data:
|
||||||
await event_queue.put(self._format_sse_event("model_error", {
|
await event_queue.put(self._format_sse_event("model_error", {
|
||||||
"model_index": idx,
|
"model_index": idx,
|
||||||
@@ -1543,6 +2095,9 @@ class AgentRunService:
|
|||||||
"parameters_used": model_info["parameters"],
|
"parameters_used": model_info["parameters"],
|
||||||
"message": full_content,
|
"message": full_content,
|
||||||
"elapsed_time": elapsed,
|
"elapsed_time": elapsed,
|
||||||
|
"audio_url": audio_url,
|
||||||
|
"citations": citations,
|
||||||
|
"suggested_questions": suggested_questions,
|
||||||
"error": None
|
"error": None
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1554,6 +2109,9 @@ class AgentRunService:
|
|||||||
"conversation_id": returned_conversation_id,
|
"conversation_id": returned_conversation_id,
|
||||||
"elapsed_time": elapsed,
|
"elapsed_time": elapsed,
|
||||||
"message_length": len(full_content),
|
"message_length": len(full_content),
|
||||||
|
"audio_url": audio_url,
|
||||||
|
"citations": citations,
|
||||||
|
"suggested_questions": suggested_questions,
|
||||||
"timestamp": time.time()
|
"timestamp": time.time()
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@@ -1685,8 +2243,11 @@ class AgentRunService:
|
|||||||
"model_name": r["model_name"],
|
"model_name": r["model_name"],
|
||||||
"label": r["label"],
|
"label": r["label"],
|
||||||
"conversation_id": r.get("conversation_id"),
|
"conversation_id": r.get("conversation_id"),
|
||||||
"message": r.get("message"), # 包含完整消息
|
"message": r.get("message"),
|
||||||
"elapsed_time": r.get("elapsed_time", 0),
|
"elapsed_time": r.get("elapsed_time", 0),
|
||||||
|
"audio_url": r.get("audio_url"),
|
||||||
|
"citations": r.get("citations", []),
|
||||||
|
"suggested_questions": r.get("suggested_questions", []),
|
||||||
"error": r.get("error")
|
"error": r.get("error")
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ and error handling.
|
|||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Optional
|
from typing import AsyncIterator, Optional
|
||||||
|
|
||||||
from app.core.storage import StorageFactory, StorageBackend
|
from app.core.storage import StorageFactory, StorageBackend
|
||||||
from app.core.storage_exceptions import (
|
from app.core.storage_exceptions import (
|
||||||
@@ -162,6 +162,31 @@ class FileStorageService:
|
|||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def upload_stream(
    self,
    tenant_id: uuid.UUID,
    workspace_id: uuid.UUID | None,
    file_id: uuid.UUID,
    file_ext: str,
    stream: AsyncIterator[bytes],
    content_type: Optional[str] = None,
) -> int:
    """
    Upload a file to storage from an async byte stream.

    The storage key is built with ``generate_file_key`` from the tenant,
    workspace, file id and extension; the upload itself is delegated to
    ``self.storage.upload_stream``.

    Args:
        tenant_id: Tenant the file belongs to.
        workspace_id: Workspace the file belongs to, or None.
        file_id: Identifier of the file.
        file_ext: File extension used when building the key.
        stream: Async iterator producing the file's bytes.
        content_type: Optional content type forwarded to the storage backend.

    Returns:
        Total bytes written.

    Raises:
        Exception: Any failure is logged and re-raised unchanged.
    """
    file_key = generate_file_key(tenant_id, workspace_id, file_id, file_ext)
    logger.info(f"Starting stream upload: file_key={file_key}, content_type={content_type}")

    try:
        bytes_written = await self.storage.upload_stream(file_key, stream, content_type)
        logger.info(f"Stream upload successful: file_key={file_key}, size={bytes_written} bytes")
        return bytes_written
    except Exception as e:
        # Log with the key for context, then let the caller handle the error.
        logger.error(f"Stream upload failed: file_key={file_key}, error={str(e)}")
        raise
|
||||||
|
|
||||||
async def download_file(self, file_key: str) -> bytes:
|
async def download_file(self, file_key: str) -> bytes:
|
||||||
"""
|
"""
|
||||||
Download a file from storage.
|
Download a file from storage.
|
||||||
|
|||||||
@@ -1179,7 +1179,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
|||||||
app = db.query(App).filter(App.id == app_id).first()
|
app = db.query(App).filter(App.id == app_id).first()
|
||||||
if not app:
|
if not app:
|
||||||
logger.warning(f"App not found: {app_id}")
|
logger.warning(f"App not found: {app_id}")
|
||||||
raise ValueError(f"应用不存在: {app_id}")
|
# raise ValueError(f"应用不存在: {app_id}")
|
||||||
# TODO: temp fix for draft run
|
# TODO: temp fix for draft run
|
||||||
# if not app.current_release_id:
|
# if not app.current_release_id:
|
||||||
# logger.warning(f"No current release for app: {app_id}")
|
# logger.warning(f"No current release for app: {app_id}")
|
||||||
@@ -1252,17 +1252,15 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
|
|||||||
memory_config_service = MemoryConfigService(db)
|
memory_config_service = MemoryConfigService(db)
|
||||||
memory_config = memory_config_service.get_config_with_fallback(
|
memory_config = memory_config_service.get_config_with_fallback(
|
||||||
memory_config_id=memory_config_id_to_use,
|
memory_config_id=memory_config_id_to_use,
|
||||||
workspace_id=app.workspace_id
|
workspace_id=end_user.workspace_id
|
||||||
)
|
)
|
||||||
|
|
||||||
memory_config_id = str(memory_config.config_id) if memory_config else None
|
memory_config_id = str(memory_config.config_id) if memory_config else None
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"end_user_id": str(end_user_id),
|
"end_user_id": str(end_user_id),
|
||||||
"app_id": str(app_id),
|
|
||||||
"release_id": str(app.current_release_id) if app.current_release_id else None,
|
|
||||||
"memory_config_id": memory_config_id,
|
"memory_config_id": memory_config_id,
|
||||||
"workspace_id": str(app.workspace_id)
|
"workspace_id": str(end_user.workspace_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
@@ -84,43 +84,65 @@ class MemoryAPIService:
|
|||||||
|
|
||||||
if not app:
|
if not app:
|
||||||
logger.warning(f"App not found for end_user: {end_user_id}")
|
logger.warning(f"App not found for end_user: {end_user_id}")
|
||||||
raise ResourceNotFoundException(
|
# raise ResourceNotFoundException(
|
||||||
resource_type="App",
|
# resource_type="App",
|
||||||
resource_id=str(end_user.app_id)
|
# resource_id=str(end_user.app_id)
|
||||||
)
|
# )
|
||||||
|
# temporally allow any workspace to access
|
||||||
if app.workspace_id != workspace_id:
|
# if end_user.workspace_id != workspace_id:
|
||||||
logger.warning(
|
# print(f"[DEBUG] end_user.workspace_id={end_user.workspace_id}, api_key.workspace_id={workspace_id}")
|
||||||
f"End user {end_user_id} belongs to workspace {app.workspace_id}, "
|
# logger.warning(
|
||||||
f"not authorized workspace {workspace_id}"
|
# f"End user {end_user_id} belongs to workspace {end_user.workspace_id}, "
|
||||||
)
|
# f"not authorized workspace {workspace_id}"
|
||||||
raise BusinessException(
|
# )
|
||||||
message="End user does not belong to authorized workspace",
|
# raise BusinessException(
|
||||||
code=BizCode.FORBIDDEN
|
# message=f"End user does not belong to authorized workspace. end_user.workspace_id={end_user.workspace_id}, api_key.workspace_id={workspace_id}",
|
||||||
)
|
# code=BizCode.FORBIDDEN
|
||||||
|
# )
|
||||||
|
|
||||||
logger.info(f"End user {end_user_id} validated successfully")
|
logger.info(f"End user {end_user_id} validated successfully")
|
||||||
return end_user
|
return end_user
|
||||||
|
|
||||||
|
def _update_end_user_config(self, end_user_id: str, config_id: str) -> None:
|
||||||
|
"""Update the end user's memory_config_id.
|
||||||
|
|
||||||
|
Silently updates the config association. Logs warnings on failure
|
||||||
|
but does not raise, so it won't block the main read/write operation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: End user identifier
|
||||||
|
config_id: Memory configuration ID to assign
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
config_uuid = uuid.UUID(config_id)
|
||||||
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
|
end_user_repo = EndUserRepository(self.db)
|
||||||
|
end_user_repo.update_memory_config_id(
|
||||||
|
end_user_id=uuid.UUID(end_user_id),
|
||||||
|
memory_config_id=config_uuid,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}")
|
||||||
|
|
||||||
async def write_memory(
|
async def write_memory(
|
||||||
self,
|
self,
|
||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
config_id: Optional[str] = None,
|
config_id: str,
|
||||||
storage_type: str = "neo4j",
|
storage_type: str = "neo4j",
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Write memory with validation.
|
"""Write memory with validation.
|
||||||
|
|
||||||
Validates end_user exists and belongs to workspace, then delegates
|
Validates end_user exists and belongs to workspace, updates the end user's
|
||||||
to MemoryAgentService.write_memory.
|
memory_config_id, then delegates to MemoryAgentService.write_memory.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
workspace_id: Workspace ID for resource validation
|
workspace_id: Workspace ID for resource validation
|
||||||
end_user_id: End user identifier (used as end_user_id)
|
end_user_id: End user identifier (used as end_user_id)
|
||||||
message: Message content to store
|
message: Message content to store
|
||||||
config_id: Optional memory configuration ID
|
config_id: Memory configuration ID (required)
|
||||||
storage_type: Storage backend (neo4j or rag)
|
storage_type: Storage backend (neo4j or rag)
|
||||||
user_rag_memory_id: Optional RAG memory ID
|
user_rag_memory_id: Optional RAG memory ID
|
||||||
|
|
||||||
@@ -136,7 +158,8 @@ class MemoryAPIService:
|
|||||||
# Validate end_user exists and belongs to workspace
|
# Validate end_user exists and belongs to workspace
|
||||||
self.validate_end_user(end_user_id, workspace_id)
|
self.validate_end_user(end_user_id, workspace_id)
|
||||||
|
|
||||||
# Use end_user_id as end_user_id for memory operations
|
# Update end user's memory_config_id
|
||||||
|
self._update_end_user_config(end_user_id, config_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Delegate to MemoryAgentService
|
# Delegate to MemoryAgentService
|
||||||
@@ -188,21 +211,21 @@ class MemoryAPIService:
|
|||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
search_switch: str = "0",
|
search_switch: str = "0",
|
||||||
config_id: Optional[str] = None,
|
config_id: str = "",
|
||||||
storage_type: str = "neo4j",
|
storage_type: str = "neo4j",
|
||||||
user_rag_memory_id: Optional[str] = None,
|
user_rag_memory_id: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Read memory with validation.
|
"""Read memory with validation.
|
||||||
|
|
||||||
Validates end_user exists and belongs to workspace, then delegates
|
Validates end_user exists and belongs to workspace, updates the end user's
|
||||||
to MemoryAgentService.read_memory.
|
memory_config_id, then delegates to MemoryAgentService.read_memory.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
workspace_id: Workspace ID for resource validation
|
workspace_id: Workspace ID for resource validation
|
||||||
end_user_id: End user identifier (used as end_user_id)
|
end_user_id: End user identifier (used as end_user_id)
|
||||||
message: Query message
|
message: Query message
|
||||||
search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search)
|
search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search)
|
||||||
config_id: Optional memory configuration ID
|
config_id: Memory configuration ID (required)
|
||||||
storage_type: Storage backend (neo4j or rag)
|
storage_type: Storage backend (neo4j or rag)
|
||||||
user_rag_memory_id: Optional RAG memory ID
|
user_rag_memory_id: Optional RAG memory ID
|
||||||
|
|
||||||
@@ -218,7 +241,8 @@ class MemoryAPIService:
|
|||||||
# Validate end_user exists and belongs to workspace
|
# Validate end_user exists and belongs to workspace
|
||||||
self.validate_end_user(end_user_id, workspace_id)
|
self.validate_end_user(end_user_id, workspace_id)
|
||||||
|
|
||||||
# Use end_user_id as end_user_id for memory operations
|
# Update end user's memory_config_id
|
||||||
|
self._update_end_user_config(end_user_id, config_id)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -256,3 +280,50 @@ class MemoryAPIService:
|
|||||||
message=f"Memory read failed: {str(e)}",
|
message=f"Memory read failed: {str(e)}",
|
||||||
code=BizCode.MEMORY_READ_FAILED
|
code=BizCode.MEMORY_READ_FAILED
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def list_memory_configs(
    self,
    workspace_id: uuid.UUID,
) -> Dict[str, Any]:
    """Return every memory config of a workspace in serializable form.

    Args:
        workspace_id: Workspace ID from API key authorization

    Returns:
        Dict with ``configs`` (list of config dicts) and ``total`` (their count)

    Raises:
        BusinessException: If listing fails for any reason
    """
    logger.info(f"Listing memory configs for workspace: {workspace_id}")

    def _iso_or_none(moment):
        # Timestamps may be unset; serialize those as None.
        return moment.isoformat() if moment else None

    try:
        from app.repositories.memory_config_repository import MemoryConfigRepository

        rows = MemoryConfigRepository.get_all(self.db, workspace_id=workspace_id)

        # Each row is unpacked as a (config, scene_name) pair.
        configs = [
            {
                "config_id": str(cfg.config_id),
                "config_name": cfg.config_name,
                "config_desc": cfg.config_desc,
                "is_default": cfg.is_default or False,
                "scene_name": scene_name,
                "created_at": _iso_or_none(cfg.created_at),
                "updated_at": _iso_or_none(cfg.updated_at),
            }
            for cfg, scene_name in rows
        ]

        logger.info(f"Found {len(configs)} memory configs for workspace {workspace_id}")
        return {"configs": configs, "total": len(configs)}

    except Exception as e:
        logger.error(f"Failed to list memory configs for workspace {workspace_id}: {e}")
        raise BusinessException(
            message=f"Failed to list memory configs: {str(e)}",
            code=BizCode.MEMORY_READ_FAILED
        )
|
||||||
|
|||||||
@@ -107,28 +107,29 @@ def _validate_config_id(config_id, db: Session = None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _load_ontology_classes(db: Session, scene_id, pruning_scene: Optional[str]) -> Optional[list]:
|
def _load_ontology_class_infos(db: Session, scene_id) -> list:
|
||||||
"""从 ontology_class 表加载场景类型名称列表,用于注入提示词。
|
"""从 ontology_class 表加载完整本体类型信息(name + description),用于注入剪枝提示词。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
db: 数据库会话
|
db: 数据库会话
|
||||||
scene_id: 本体场景 UUID
|
scene_id: 本体场景 UUID
|
||||||
pruning_scene: 语义剪枝场景名称(保留参数,暂未使用)
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
class_name 字符串列表,或 None(无数据时)
|
[{"class_name": ..., "class_description": ...}, ...] 或空列表
|
||||||
"""
|
"""
|
||||||
if not scene_id:
|
if not scene_id:
|
||||||
return None
|
return []
|
||||||
try:
|
try:
|
||||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||||
repo = OntologyClassRepository(db)
|
repo = OntologyClassRepository(db)
|
||||||
classes = repo.get_classes_by_scene(scene_id)
|
classes = repo.get_classes_by_scene(scene_id)
|
||||||
names = [c.class_name for c in classes if c.class_name]
|
return [
|
||||||
return names if names else None
|
{"class_name": c.class_name, "class_description": c.class_description or ""}
|
||||||
|
for c in classes if c.class_name
|
||||||
|
]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to load ontology classes for scene_id={scene_id}: {e}")
|
logger.warning(f"Failed to load ontology class infos for scene_id={scene_id}: {e}")
|
||||||
return None
|
return []
|
||||||
|
|
||||||
|
|
||||||
class MemoryConfigService:
|
class MemoryConfigService:
|
||||||
@@ -383,7 +384,7 @@ class MemoryConfigService:
|
|||||||
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5,
|
||||||
# Ontology scene association
|
# Ontology scene association
|
||||||
scene_id=memory_config.scene_id,
|
scene_id=memory_config.scene_id,
|
||||||
ontology_classes=_load_ontology_classes(self.db, memory_config.scene_id, memory_config.pruning_scene),
|
ontology_class_infos=_load_ontology_class_infos(self.db, memory_config.scene_id),
|
||||||
)
|
)
|
||||||
|
|
||||||
elapsed_ms = (time.time() - start_time) * 1000
|
elapsed_ms = (time.time() - start_time) * 1000
|
||||||
@@ -550,11 +551,13 @@ class MemoryConfigService:
|
|||||||
- pruning_switch: bool
|
- pruning_switch: bool
|
||||||
- pruning_scene: str
|
- pruning_scene: str
|
||||||
- pruning_threshold: float
|
- pruning_threshold: float
|
||||||
|
- ontology_class_infos: list of {class_name, class_description} dicts
|
||||||
"""
|
"""
|
||||||
return {
|
return {
|
||||||
"pruning_switch": memory_config.pruning_enabled,
|
"pruning_switch": memory_config.pruning_enabled,
|
||||||
"pruning_scene": memory_config.pruning_scene,
|
"pruning_scene": memory_config.pruning_scene,
|
||||||
"pruning_threshold": memory_config.pruning_threshold,
|
"pruning_threshold": memory_config.pruning_threshold,
|
||||||
|
"ontology_class_infos": memory_config.ontology_class_infos or [],
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_ontology_types(self, memory_config: MemoryConfig):
|
def get_ontology_types(self, memory_config: MemoryConfig):
|
||||||
|
|||||||
@@ -68,14 +68,14 @@ def get_workspace_end_users(
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
# 提取所有 app_id
|
# 提取所有 app_id
|
||||||
app_ids = [app.id for app in apps_orm]
|
# app_ids = [app.id for app in apps_orm]
|
||||||
|
|
||||||
# 批量查询所有 end_users(一次查询而非循环查询)
|
# 批量查询所有 end_users(一次查询而非循环查询)
|
||||||
# 按 created_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性
|
# 按 created_at 降序排序,NULL 值排在最后;id 作为次级排序键保证确定性
|
||||||
from app.models.end_user_model import EndUser as EndUserModel
|
from app.models.end_user_model import EndUser as EndUserModel
|
||||||
from sqlalchemy import desc, nullslast
|
from sqlalchemy import desc, nullslast
|
||||||
end_users_orm = db.query(EndUserModel).filter(
|
end_users_orm = db.query(EndUserModel).filter(
|
||||||
EndUserModel.app_id.in_(app_ids)
|
EndUserModel.workspace_id == workspace_id
|
||||||
).order_by(
|
).order_by(
|
||||||
nullslast(desc(EndUserModel.created_at)),
|
nullslast(desc(EndUserModel.created_at)),
|
||||||
desc(EndUserModel.id)
|
desc(EndUserModel.id)
|
||||||
|
|||||||
@@ -518,7 +518,7 @@ class MemoryForgetService:
|
|||||||
'total_nodes': result['total_nodes'] or 0,
|
'total_nodes': result['total_nodes'] or 0,
|
||||||
'nodes_with_activation': result['nodes_with_activation'] or 0,
|
'nodes_with_activation': result['nodes_with_activation'] or 0,
|
||||||
'nodes_without_activation': result['nodes_without_activation'] or 0,
|
'nodes_without_activation': result['nodes_without_activation'] or 0,
|
||||||
'average_activation_value': result['average_activation'],
|
'average_activation_value': round(result['average_activation'], 2) if result['average_activation'] is not None else None,
|
||||||
'low_activation_nodes': result['low_activation_nodes'] or 0,
|
'low_activation_nodes': result['low_activation_nodes'] or 0,
|
||||||
'forgetting_threshold': forgetting_threshold,
|
'forgetting_threshold': forgetting_threshold,
|
||||||
'timestamp': int(datetime.now().timestamp() * 1000)
|
'timestamp': int(datetime.now().timestamp() * 1000)
|
||||||
|
|||||||
@@ -5,12 +5,14 @@ from urllib.parse import urlparse, unquote
|
|||||||
|
|
||||||
import json_repair
|
import json_repair
|
||||||
from jinja2 import Template
|
from jinja2 import Template
|
||||||
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||||
|
from app.models import FileMetadata
|
||||||
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
|
from app.models.memory_perceptual_model import PerceptualType, FileStorageService
|
||||||
from app.models.prompt_optimizer_model import RoleType
|
from app.models.prompt_optimizer_model import RoleType
|
||||||
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
|
from app.repositories.memory_perceptual_repository import MemoryPerceptualRepository
|
||||||
@@ -245,6 +247,18 @@ class MemoryPerceptualService:
|
|||||||
filename = os.path.basename(path)
|
filename = os.path.basename(path)
|
||||||
filename = unquote(filename)
|
filename = unquote(filename)
|
||||||
file_ext = os.path.splitext(filename)[1]
|
file_ext = os.path.splitext(filename)[1]
|
||||||
|
try:
|
||||||
|
file_id = uuid.UUID(filename)
|
||||||
|
stmt = select(FileMetadata).where(
|
||||||
|
FileMetadata.id == file_id
|
||||||
|
)
|
||||||
|
file = self.db.execute(stmt).scalar_one_or_none()
|
||||||
|
|
||||||
|
if file:
|
||||||
|
filename = file.file_name
|
||||||
|
file_ext = file.file_ext
|
||||||
|
except ValueError:
|
||||||
|
business_logger.debug(f"Remote file, file_id={filename}")
|
||||||
if not file_ext:
|
if not file_ext:
|
||||||
if file_type == FileType.AUDIO:
|
if file_type == FileType.AUDIO:
|
||||||
file_ext = ".mp3"
|
file_ext = ".mp3"
|
||||||
@@ -262,17 +276,17 @@ class MemoryPerceptualService:
|
|||||||
}
|
}
|
||||||
if file_type in [FileType.IMAGE, FileType.VIDEO]:
|
if file_type in [FileType.IMAGE, FileType.VIDEO]:
|
||||||
file_modalities = {
|
file_modalities = {
|
||||||
"scene": content.get("scene")
|
"scene": content.get("scene", [])
|
||||||
}
|
}
|
||||||
elif file_type in [FileType.DOCUMENT]:
|
elif file_type in [FileType.DOCUMENT]:
|
||||||
file_modalities = {
|
file_modalities = {
|
||||||
"section_count": content.get("section_count"),
|
"section_count": content.get("section_count", 0),
|
||||||
"title": content.get("title"),
|
"title": content.get("title", ""),
|
||||||
"first_line": content.get("first_line")
|
"first_line": content.get("first_line", "")
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
file_modalities = {
|
file_modalities = {
|
||||||
"speaker_count": content.get("speaker_count")
|
"speaker_count": content.get("speaker_count", 0)
|
||||||
}
|
}
|
||||||
self.repository.create_perceptual_memory(
|
self.repository.create_perceptual_memory(
|
||||||
end_user_id=uuid.UUID(end_user_id),
|
end_user_id=uuid.UUID(end_user_id),
|
||||||
@@ -280,7 +294,7 @@ class MemoryPerceptualService:
|
|||||||
file_path=file_url,
|
file_path=file_url,
|
||||||
file_name=filename,
|
file_name=filename,
|
||||||
file_ext=file_ext,
|
file_ext=file_ext,
|
||||||
summary=content.get('summary'),
|
summary=content.get('summary', ""),
|
||||||
meta_data={
|
meta_data={
|
||||||
"content": file_content,
|
"content": file_content,
|
||||||
"modalities": file_modalities
|
"modalities": file_modalities
|
||||||
|
|||||||
@@ -1638,6 +1638,7 @@ class MultiAgentOrchestrator:
|
|||||||
self.variables = config_data.get("variables", [])
|
self.variables = config_data.get("variables", [])
|
||||||
self.tools = config_data.get("tools", {})
|
self.tools = config_data.get("tools", {})
|
||||||
self.skills = config_data.get("skills", {})
|
self.skills = config_data.get("skills", {})
|
||||||
|
self.features = config_data.get("features", {})
|
||||||
self.default_model_config_id = release.default_model_config_id
|
self.default_model_config_id = release.default_model_config_id
|
||||||
|
|
||||||
return AgentConfigProxy(release, app, config_data)
|
return AgentConfigProxy(release, app, config_data)
|
||||||
|
|||||||
@@ -14,9 +14,13 @@ import uuid
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
|
||||||
import PyPDF2
|
import PyPDF2
|
||||||
import httpx
|
import httpx
|
||||||
import magic
|
import magic
|
||||||
|
import openpyxl
|
||||||
from docx import Document
|
from docx import Document
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -37,8 +41,16 @@ TEXT_MIME = ['text/plain', 'text/x-markdown']
|
|||||||
PDF_MIME = ['application/pdf']
|
PDF_MIME = ['application/pdf']
|
||||||
DOC_MIME = [
|
DOC_MIME = [
|
||||||
'application/msword',
|
'application/msword',
|
||||||
'application/vnd.openxmlformats-officedocument.wordprocessingml.document'
|
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||||
|
'application/zip'
|
||||||
]
|
]
|
||||||
|
XLSX_MIME = [
|
||||||
|
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
|
||||||
|
'application/vnd.ms-excel',
|
||||||
|
'application/zip'
|
||||||
|
]
|
||||||
|
CSV_MIME = ['text/csv', 'application/csv']
|
||||||
|
JSON_MIME = ['application/json']
|
||||||
|
|
||||||
|
|
||||||
class MultimodalFormatStrategy(ABC):
|
class MultimodalFormatStrategy(ABC):
|
||||||
@@ -48,22 +60,22 @@ class MultimodalFormatStrategy(ABC):
|
|||||||
self.file = file
|
self.file = file
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
|
async def format_image(self, url: str, content: bytes | None = None) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""格式化图片"""
|
"""格式化图片"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
|
async def format_document(self, file_name: str, text: str) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""格式化文档"""
|
"""格式化文档"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def format_audio(self, file_type: str, url: str, content: bytes | None = None) -> Dict[str, Any]:
|
async def format_audio(self, file_type: str, url: str, content: bytes | None = None) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""格式化音频"""
|
"""格式化音频"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def format_video(self, url: str) -> Dict[str, Any]:
|
async def format_video(self, url: str) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""格式化视频"""
|
"""格式化视频"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -71,16 +83,16 @@ class MultimodalFormatStrategy(ABC):
|
|||||||
class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
||||||
"""通义千问策略"""
|
"""通义千问策略"""
|
||||||
|
|
||||||
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
|
async def format_image(self, url: str, content: bytes | None = None) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""通义千问图片格式:{"type": "image", "image": "url"}"""
|
"""通义千问图片格式:{"type": "image", "image": "url"}"""
|
||||||
return {
|
return True, {
|
||||||
"type": "image",
|
"type": "image",
|
||||||
"image": url
|
"image": url
|
||||||
}
|
}
|
||||||
|
|
||||||
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
|
async def format_document(self, file_name: str, text: str) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""通义千问文档格式"""
|
"""通义千问文档格式"""
|
||||||
return {
|
return True, {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||||
}
|
}
|
||||||
@@ -91,26 +103,26 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
|||||||
url: str,
|
url: str,
|
||||||
content: bytes | None = None,
|
content: bytes | None = None,
|
||||||
transcription: Optional[str] = None
|
transcription: Optional[str] = None
|
||||||
) -> Dict[str, Any]:
|
) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
通义千问音频格式
|
通义千问音频格式
|
||||||
- 原生支持: qwen-audio 系列
|
- 原生支持: qwen-audio 系列
|
||||||
- 其他模型: 需要转录为文本
|
- 其他模型: 需要转录为文本
|
||||||
"""
|
"""
|
||||||
if transcription:
|
if transcription:
|
||||||
return {
|
return True, {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": f"<audio url=\"{url}\">\ntext_transcription:{transcription}\n</audio>"
|
"text": f"<audio url=\"{url}\">\ntext_transcription:{transcription}\n</audio>"
|
||||||
}
|
}
|
||||||
# 通义千问音频格式:{"type": "audio", "audio": "url"}
|
# 通义千问音频格式:{"type": "audio", "audio": "url"}
|
||||||
return {
|
return True, {
|
||||||
"type": "audio",
|
"type": "audio",
|
||||||
"audio": url
|
"audio": url
|
||||||
}
|
}
|
||||||
|
|
||||||
async def format_video(self, url: str) -> Dict[str, Any]:
|
async def format_video(self, url: str) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""通义千问视频格式(qwen-vl 系列原生支持)"""
|
"""通义千问视频格式(qwen-vl 系列原生支持)"""
|
||||||
return {
|
return True, {
|
||||||
"type": "video",
|
"type": "video",
|
||||||
"video": url
|
"video": url
|
||||||
}
|
}
|
||||||
@@ -119,7 +131,7 @@ class DashScopeFormatStrategy(MultimodalFormatStrategy):
|
|||||||
class BedrockFormatStrategy(MultimodalFormatStrategy):
|
class BedrockFormatStrategy(MultimodalFormatStrategy):
|
||||||
"""Bedrock/Anthropic 策略"""
|
"""Bedrock/Anthropic 策略"""
|
||||||
|
|
||||||
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
|
async def format_image(self, url: str, content: bytes | None = None) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Bedrock/Anthropic 格式: base64 编码
|
Bedrock/Anthropic 格式: base64 编码
|
||||||
{"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
|
{"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}}
|
||||||
@@ -142,7 +154,7 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
|
|||||||
|
|
||||||
logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
logger.info(f"图片编码完成: media_type={media_type}, size={len(base64_data)}")
|
||||||
|
|
||||||
return {
|
return True, {
|
||||||
"type": "image",
|
"type": "image",
|
||||||
"source": {
|
"source": {
|
||||||
"type": "base64",
|
"type": "base64",
|
||||||
@@ -151,13 +163,13 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
|
async def format_document(self, file_name: str, text: str) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""Bedrock/Anthropic 文档格式(需要 base64 编码)"""
|
"""Bedrock/Anthropic 文档格式(需要 base64 编码)"""
|
||||||
# Bedrock 文档需要 base64 编码
|
# Bedrock 文档需要 base64 编码
|
||||||
text_bytes = text.encode('utf-8')
|
text_bytes = text.encode('utf-8')
|
||||||
base64_text = base64.b64encode(text_bytes).decode('utf-8')
|
base64_text = base64.b64encode(text_bytes).decode('utf-8')
|
||||||
|
|
||||||
return {
|
return True, {
|
||||||
"type": "document",
|
"type": "document",
|
||||||
"source": {
|
"source": {
|
||||||
"type": "base64",
|
"type": "base64",
|
||||||
@@ -171,24 +183,24 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
|
|||||||
url: str,
|
url: str,
|
||||||
content: bytes | None = None,
|
content: bytes | None = None,
|
||||||
transcription: Optional[str] = None
|
transcription: Optional[str] = None
|
||||||
) -> Dict[str, Any]:
|
) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
Bedrock/Anthropic 音频格式
|
Bedrock/Anthropic 音频格式
|
||||||
不支持原生音频,必须转录为文本
|
不支持原生音频,必须转录为文本
|
||||||
"""
|
"""
|
||||||
if transcription:
|
if transcription:
|
||||||
return {
|
return True, {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": f"[音频转录]\n{transcription}"
|
"text": f"[音频转录]\n{transcription}"
|
||||||
}
|
}
|
||||||
return {
|
return False, {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": "[音频文件:Bedrock 不支持原生音频,请启用音频转文本功能]"
|
"text": "[音频文件:Bedrock 不支持原生音频,请启用音频转文本功能]"
|
||||||
}
|
}
|
||||||
|
|
||||||
async def format_video(self, url: str) -> Dict[str, Any]:
|
async def format_video(self, url: str) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""Bedrock/Anthropic 视频格式"""
|
"""Bedrock/Anthropic 视频格式"""
|
||||||
return {
|
return False, {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": f"<video url=\"{url}\">\n[视频文件,当前 provider 暂不支持]\n</video>"
|
"text": f"<video url=\"{url}\">\n[视频文件,当前 provider 暂不支持]\n</video>"
|
||||||
}
|
}
|
||||||
@@ -197,18 +209,18 @@ class BedrockFormatStrategy(MultimodalFormatStrategy):
|
|||||||
class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
||||||
"""OpenAI 策略"""
|
"""OpenAI 策略"""
|
||||||
|
|
||||||
async def format_image(self, url: str, content: bytes | None = None) -> Dict[str, Any]:
|
async def format_image(self, url: str, content: bytes | None = None) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
|
"""OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}"""
|
||||||
return {
|
return True, {
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {
|
||||||
"url": url
|
"url": url
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async def format_document(self, file_name: str, text: str) -> Dict[str, Any]:
|
async def format_document(self, file_name: str, text: str) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""OpenAI 文档格式"""
|
"""OpenAI 文档格式"""
|
||||||
return {
|
return True, {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
"text": f"<document name=\"{file_name}\">\n{text}\n</document>"
|
||||||
}
|
}
|
||||||
@@ -219,14 +231,14 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
|||||||
url: str,
|
url: str,
|
||||||
content: bytes | None = None,
|
content: bytes | None = None,
|
||||||
transcription: Optional[str] = None
|
transcription: Optional[str] = None
|
||||||
) -> Dict[str, Any]:
|
) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
OpenAI 音频格式
|
OpenAI 音频格式
|
||||||
- gpt-4o-audio 系列支持原生音频(需要 base64 编码)
|
- gpt-4o-audio 系列支持原生音频(需要 base64 编码)
|
||||||
- 其他模型使用转录文本
|
- 其他模型使用转录文本
|
||||||
"""
|
"""
|
||||||
if transcription:
|
if transcription:
|
||||||
return {
|
return True, {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
|
"text": f"<audio url=\"{url}\">\n{transcription}\n</audio>"
|
||||||
}
|
}
|
||||||
@@ -255,7 +267,7 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
|||||||
# supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"}
|
# supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"}
|
||||||
file_ext = "wav" if not file_ext else file_ext
|
file_ext = "wav" if not file_ext else file_ext
|
||||||
|
|
||||||
return {
|
return True, {
|
||||||
"type": "input_audio",
|
"type": "input_audio",
|
||||||
"input_audio": {
|
"input_audio": {
|
||||||
"data": f"data:;base64,{base64_audio}",
|
"data": f"data:;base64,{base64_audio}",
|
||||||
@@ -264,14 +276,14 @@ class OpenAIFormatStrategy(MultimodalFormatStrategy):
|
|||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"下载音频失败: {e}")
|
logger.error(f"下载音频失败: {e}")
|
||||||
return {
|
return False, {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": f"[音频处理失败: {str(e)}]"
|
"text": f"[音频处理失败: {str(e)}]"
|
||||||
}
|
}
|
||||||
|
|
||||||
async def format_video(self, url: str) -> Dict[str, Any]:
|
async def format_video(self, url: str) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""OpenAI 视频格式"""
|
"""OpenAI 视频格式"""
|
||||||
return {
|
return True, {
|
||||||
"type": "video_url",
|
"type": "video_url",
|
||||||
"video_url": {
|
"video_url": {
|
||||||
"url": url
|
"url": url
|
||||||
@@ -366,20 +378,24 @@ class MultimodalService:
|
|||||||
file.url = await self.get_file_url(file)
|
file.url = await self.get_file_url(file)
|
||||||
try:
|
try:
|
||||||
if file.type == FileType.IMAGE and "vision" in self.capability:
|
if file.type == FileType.IMAGE and "vision" in self.capability:
|
||||||
content = await self._process_image(file, strategy)
|
is_support, content = await self._process_image(file, strategy)
|
||||||
result.append(content)
|
result.append(content)
|
||||||
|
if is_support:
|
||||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||||
elif file.type == FileType.DOCUMENT:
|
elif file.type == FileType.DOCUMENT:
|
||||||
content = await self._process_document(file, strategy)
|
is_support, content = await self._process_document(file, strategy)
|
||||||
result.append(content)
|
result.append(content)
|
||||||
|
if is_support:
|
||||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||||
elif file.type == FileType.AUDIO and "audio" in self.capability:
|
elif file.type == FileType.AUDIO and "audio" in self.capability:
|
||||||
content = await self._process_audio(file, strategy)
|
is_support, content = await self._process_audio(file, strategy)
|
||||||
result.append(content)
|
result.append(content)
|
||||||
|
if is_support:
|
||||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||||
elif file.type == FileType.VIDEO and "video" in self.capability:
|
elif file.type == FileType.VIDEO and "video" in self.capability:
|
||||||
content = await self._process_video(file, strategy)
|
is_support, content = await self._process_video(file, strategy)
|
||||||
result.append(content)
|
result.append(content)
|
||||||
|
if is_support:
|
||||||
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
self.write_perceptual_memory(end_user_id, file.type, file.url, content)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"不支持的文件类型: {file.type}")
|
logger.warning(f"不支持的文件类型: {file.type}")
|
||||||
@@ -413,7 +429,7 @@ class MultimodalService:
|
|||||||
if end_user_id and self.api_config:
|
if end_user_id and self.api_config:
|
||||||
write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message)
|
write_perceptual_memory.delay(end_user_id, self.api_config.model_dump(), file_type, file_url, file_message)
|
||||||
|
|
||||||
async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]:
|
async def _process_image(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
处理图片文件
|
处理图片文件
|
||||||
|
|
||||||
@@ -425,16 +441,16 @@ class MultimodalService:
|
|||||||
Dict: 根据 provider 返回不同格式的图片内容
|
Dict: 根据 provider 返回不同格式的图片内容
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
url = await self.get_file_url(file)
|
# url = await self.get_file_url(file)
|
||||||
return await strategy.format_image(url, content=file.get_content())
|
return await strategy.format_image(file.url, content=file.get_content())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理图片失败: {e}", exc_info=True)
|
logger.error(f"处理图片失败: {e}", exc_info=True)
|
||||||
return {
|
return False, {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": f"[图片处理失败: {str(e)}]"
|
"text": f"[图片处理失败: {str(e)}]"
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]:
|
async def _process_document(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
处理文档文件(PDF、Word 等)
|
处理文档文件(PDF、Word 等)
|
||||||
|
|
||||||
@@ -446,7 +462,7 @@ class MultimodalService:
|
|||||||
Dict: 根据 provider 返回不同格式的文档内容
|
Dict: 根据 provider 返回不同格式的文档内容
|
||||||
"""
|
"""
|
||||||
if file.transfer_method == TransferMethod.REMOTE_URL:
|
if file.transfer_method == TransferMethod.REMOTE_URL:
|
||||||
return {
|
return True, {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
|
"text": f"<document url=\"{file.url}\">\n{await self._extract_document_text(file)}\n</document>"
|
||||||
}
|
}
|
||||||
@@ -464,7 +480,7 @@ class MultimodalService:
|
|||||||
# 使用策略格式化文档
|
# 使用策略格式化文档
|
||||||
return await strategy.format_document(file_name, text)
|
return await strategy.format_document(file_name, text)
|
||||||
|
|
||||||
async def _process_audio(self, file: FileInput, strategy) -> Dict[str, Any]:
|
async def _process_audio(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
处理音频文件
|
处理音频文件
|
||||||
|
|
||||||
@@ -476,28 +492,28 @@ class MultimodalService:
|
|||||||
Dict: 根据 provider 返回不同格式的音频内容
|
Dict: 根据 provider 返回不同格式的音频内容
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
url = await self.get_file_url(file)
|
# url = await self.get_file_url(file)
|
||||||
|
|
||||||
# 如果启用音频转文本且有 API Key
|
# 如果启用音频转文本且有 API Key
|
||||||
transcription = None
|
transcription = None
|
||||||
if self.enable_audio_transcription and self.audio_api_key:
|
if self.enable_audio_transcription and self.audio_api_key:
|
||||||
logger.info(f"开始音频转文本: {url}")
|
logger.info(f"开始音频转文本: {file.url}")
|
||||||
if self.provider == "dashscope":
|
if self.provider == "dashscope":
|
||||||
transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.audio_api_key)
|
transcription = await AudioTranscriptionService.transcribe_dashscope(file.url, self.audio_api_key)
|
||||||
elif self.provider == "openai":
|
elif self.provider == "openai":
|
||||||
transcription = await AudioTranscriptionService.transcribe_openai(url, self.audio_api_key)
|
transcription = await AudioTranscriptionService.transcribe_openai(file.url, self.audio_api_key)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"Provider {self.provider} 不支持音频转文本")
|
logger.warning(f"Provider {self.provider} 不支持音频转文本")
|
||||||
|
|
||||||
return await strategy.format_audio(file.file_type, url, file.get_content(), transcription)
|
return await strategy.format_audio(file.file_type, file.url, file.get_content(), transcription)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理音频失败: {e}", exc_info=True)
|
logger.error(f"处理音频失败: {e}", exc_info=True)
|
||||||
return {
|
return False, {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": f"[音频处理失败: {str(e)}]"
|
"text": f"[音频处理失败: {str(e)}]"
|
||||||
}
|
}
|
||||||
|
|
||||||
async def _process_video(self, file: FileInput, strategy) -> Dict[str, Any]:
|
async def _process_video(self, file: FileInput, strategy) -> tuple[bool, Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
处理视频文件
|
处理视频文件
|
||||||
|
|
||||||
@@ -509,11 +525,11 @@ class MultimodalService:
|
|||||||
Dict: 根据 provider 返回不同格式的视频内容
|
Dict: 根据 provider 返回不同格式的视频内容
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
url = await self.get_file_url(file)
|
# url = await self.get_file_url(file)
|
||||||
return await strategy.format_video(url)
|
return await strategy.format_video(file.url)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理视频失败: {e}", exc_info=True)
|
logger.error(f"处理视频失败: {e}", exc_info=True)
|
||||||
return {
|
return False, {
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": f"[视频处理失败: {str(e)}]"
|
"text": f"[视频处理失败: {str(e)}]"
|
||||||
}
|
}
|
||||||
@@ -575,8 +591,14 @@ class MultimodalService:
|
|||||||
return file_content.decode("utf-8")
|
return file_content.decode("utf-8")
|
||||||
elif file_mime_type in PDF_MIME:
|
elif file_mime_type in PDF_MIME:
|
||||||
return await self._extract_pdf_text(file_content)
|
return await self._extract_pdf_text(file_content)
|
||||||
elif file_mime_type in DOC_MIME:
|
elif file_mime_type in DOC_MIME and file.file_type.endswith(('docx', 'doc')):
|
||||||
return await self._extract_word_text(file_content)
|
return await self._extract_word_text(file_content)
|
||||||
|
elif file_mime_type in XLSX_MIME and file.file_type.endswith(("xlsx", "xls")):
|
||||||
|
return await self._extract_xlsx_text(file_content)
|
||||||
|
elif file_mime_type in CSV_MIME:
|
||||||
|
return await self._extract_csv_text(file_content)
|
||||||
|
elif file_mime_type in JSON_MIME:
|
||||||
|
return await self._extract_json_text(file_content)
|
||||||
else:
|
else:
|
||||||
return f"[Unsupported file type: {file_mime_type}]"
|
return f"[Unsupported file type: {file_mime_type}]"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -602,7 +624,6 @@ class MultimodalService:
|
|||||||
async def _extract_word_text(file_content: bytes) -> str:
|
async def _extract_word_text(file_content: bytes) -> str:
|
||||||
"""提取 Word 文档文本"""
|
"""提取 Word 文档文本"""
|
||||||
try:
|
try:
|
||||||
# 使用 BytesIO 读取 Word 文档
|
|
||||||
word_file = io.BytesIO(file_content)
|
word_file = io.BytesIO(file_content)
|
||||||
doc = Document(word_file)
|
doc = Document(word_file)
|
||||||
text_parts = [paragraph.text for paragraph in doc.paragraphs]
|
text_parts = [paragraph.text for paragraph in doc.paragraphs]
|
||||||
@@ -611,6 +632,42 @@ class MultimodalService:
|
|||||||
logger.error(f"提取 Word 文本失败: {e}")
|
logger.error(f"提取 Word 文本失败: {e}")
|
||||||
return f"[Word 提取失败: {str(e)}]"
|
return f"[Word 提取失败: {str(e)}]"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _extract_xlsx_text(file_content: bytes) -> str:
|
||||||
|
"""提取 Excel 文本"""
|
||||||
|
try:
|
||||||
|
wb = openpyxl.load_workbook(io.BytesIO(file_content), read_only=True, data_only=True)
|
||||||
|
parts = []
|
||||||
|
for sheet in wb.worksheets:
|
||||||
|
parts.append(f"[Sheet: {sheet.title}]")
|
||||||
|
for row in sheet.iter_rows(values_only=True):
|
||||||
|
parts.append('\t'.join('' if v is None else str(v) for v in row))
|
||||||
|
return '\n'.join(parts)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提取 Excel 文本失败: {e}")
|
||||||
|
return f"[Excel 提取失败: {str(e)}]"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _extract_csv_text(file_content: bytes) -> str:
|
||||||
|
"""提取 CSV 文本"""
|
||||||
|
try:
|
||||||
|
text = file_content.decode('utf-8-sig')
|
||||||
|
reader = csv.reader(io.StringIO(text))
|
||||||
|
return '\n'.join('\t'.join(row) for row in reader)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提取 CSV 文本失败: {e}")
|
||||||
|
return f"[CSV 提取失败: {str(e)}]"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _extract_json_text(file_content: bytes) -> str:
|
||||||
|
"""提取 JSON 文本"""
|
||||||
|
try:
|
||||||
|
data = json.loads(file_content.decode('utf-8'))
|
||||||
|
return json.dumps(data, ensure_ascii=False, indent=2)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提取 JSON 文本失败: {e}")
|
||||||
|
return f"[JSON 提取失败: {str(e)}]"
|
||||||
|
|
||||||
|
|
||||||
def get_multimodal_service(db: Session) -> MultimodalService:
|
def get_multimodal_service(db: Session) -> MultimodalService:
|
||||||
"""获取多模态服务实例(依赖注入)"""
|
"""获取多模态服务实例(依赖注入)"""
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ async def run_pilot_extraction(
|
|||||||
"pruning_scene": memory_config.pruning_scene,
|
"pruning_scene": memory_config.pruning_scene,
|
||||||
"pruning_threshold": memory_config.pruning_threshold,
|
"pruning_threshold": memory_config.pruning_threshold,
|
||||||
"scene_id": str(memory_config.scene_id) if memory_config.scene_id else None,
|
"scene_id": str(memory_config.scene_id) if memory_config.scene_id else None,
|
||||||
"ontology_classes": memory_config.ontology_classes,
|
"ontology_class_infos": memory_config.ontology_class_infos,
|
||||||
}
|
}
|
||||||
config = PruningConfig(**pruning_config_dict)
|
config = PruningConfig(**pruning_config_dict)
|
||||||
|
|
||||||
@@ -232,9 +232,11 @@ async def run_pilot_extraction(
|
|||||||
"chunker_strategy": memory_config.chunker_strategy,
|
"chunker_strategy": memory_config.chunker_strategy,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 添加剪枝统计信息
|
# 添加剪枝统计信息(始终包含 pruning 字段,确保前端不会因字段缺失报错)
|
||||||
if pruning_stats:
|
preprocessing_summary["pruning"] = pruning_stats if pruning_stats else {
|
||||||
preprocessing_summary["pruning"] = pruning_stats
|
"enabled": memory_config.pruning_enabled,
|
||||||
|
"deleted_count": 0,
|
||||||
|
}
|
||||||
|
|
||||||
await progress_callback("text_preprocessing_complete", "预处理文本完成(剪枝 + 分块)", preprocessing_summary)
|
await progress_callback("text_preprocessing_complete", "预处理文本完成(剪枝 + 分块)", preprocessing_summary)
|
||||||
|
|
||||||
|
|||||||
@@ -78,7 +78,7 @@ class ToolService:
|
|||||||
|
|
||||||
def get_tool_info(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolInfo]:
|
def get_tool_info(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolInfo]:
|
||||||
"""获取工具详情"""
|
"""获取工具详情"""
|
||||||
config = self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
|
config = self.tool_repo.find_by_id_and_tenant_all(self.db, uuid.UUID(tool_id), tenant_id)
|
||||||
return self._config_to_info(config) if config else None
|
return self._config_to_info(config) if config else None
|
||||||
|
|
||||||
def _check_name_duplicate(self, name: str, tool_type: ToolType, tenant_id: uuid.UUID, exclude_id: Optional[uuid.UUID] = None):
|
def _check_name_duplicate(self, name: str, tool_type: ToolType, tenant_id: uuid.UUID, exclude_id: Optional[uuid.UUID] = None):
|
||||||
@@ -237,7 +237,7 @@ class ToolService:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def delete_tool(self, tool_id: str, tenant_id: uuid.UUID) -> bool:
|
def delete_tool(self, tool_id: str, tenant_id: uuid.UUID) -> bool:
|
||||||
"""删除工具"""
|
"""删除工具(逻辑删除)"""
|
||||||
config = self._get_tool_config(tool_id, tenant_id)
|
config = self._get_tool_config(tool_id, tenant_id)
|
||||||
if not config:
|
if not config:
|
||||||
return False
|
return False
|
||||||
@@ -246,14 +246,7 @@ class ToolService:
|
|||||||
raise ValueError("内置工具不允许删除")
|
raise ValueError("内置工具不允许删除")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 删除关联表记录
|
config.is_active = False
|
||||||
if config.tool_type == ToolType.CUSTOM.value:
|
|
||||||
self.db.query(CustomToolConfig).filter(CustomToolConfig.id == config.id).delete()
|
|
||||||
elif config.tool_type == ToolType.MCP.value:
|
|
||||||
self.db.query(MCPToolConfig).filter(MCPToolConfig.id == config.id).delete()
|
|
||||||
|
|
||||||
# 删除主表记录(ToolExecution会通过cascade自动删除)
|
|
||||||
self.db.delete(config)
|
|
||||||
self._clear_tool_cache(tool_id)
|
self._clear_tool_cache(tool_id)
|
||||||
self.db.commit()
|
self.db.commit()
|
||||||
return True
|
return True
|
||||||
@@ -262,6 +255,27 @@ class ToolService:
|
|||||||
logger.error(f"删除工具失败: {tool_id}, {e}")
|
logger.error(f"删除工具失败: {tool_id}, {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def set_tool_active(self, tool_id: str, tenant_id: uuid.UUID, is_active: bool) -> bool:
|
||||||
|
"""设置工具可用状态(启用/禁用)"""
|
||||||
|
# 直接查询,包含 is_active=False 的记录
|
||||||
|
config = self.db.query(ToolConfig).filter(
|
||||||
|
ToolConfig.id == uuid.UUID(tool_id),
|
||||||
|
ToolConfig.tenant_id == tenant_id
|
||||||
|
).first()
|
||||||
|
if not config:
|
||||||
|
return False
|
||||||
|
if config.tool_type == ToolType.BUILTIN.value:
|
||||||
|
raise ValueError("内置工具不允许修改可用状态")
|
||||||
|
try:
|
||||||
|
config.is_active = is_active
|
||||||
|
self._clear_tool_cache(tool_id)
|
||||||
|
self.db.commit()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
self.db.rollback()
|
||||||
|
logger.error(f"设置工具状态失败: {tool_id}, {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
async def execute_tool(
|
async def execute_tool(
|
||||||
self,
|
self,
|
||||||
tool_id: str,
|
tool_id: str,
|
||||||
@@ -378,7 +392,7 @@ class ToolService:
|
|||||||
Returns:
|
Returns:
|
||||||
方法列表或None
|
方法列表或None
|
||||||
"""
|
"""
|
||||||
config = self._get_tool_config(tool_id, tenant_id)
|
config = self._get_tool_config_all(tool_id, tenant_id)
|
||||||
if not config:
|
if not config:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -857,16 +871,20 @@ class ToolService:
|
|||||||
}
|
}
|
||||||
|
|
||||||
def _get_tool_config(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
|
def _get_tool_config(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
|
||||||
"""获取工具配置"""
|
"""获取工具配置(仅返回 is_active=True)"""
|
||||||
return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
|
return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id)
|
||||||
|
|
||||||
|
def _get_tool_config_all(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[ToolConfig]:
|
||||||
|
"""获取工具配置(返回所有)"""
|
||||||
|
return self.tool_repo.find_by_id_and_tenant_all(self.db, uuid.UUID(tool_id), tenant_id)
|
||||||
|
|
||||||
def get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]:
|
def get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]:
|
||||||
"""获取工具实例"""
|
"""获取工具实例(仅返回 is_active=True 的工具)"""
|
||||||
if tool_id in self._tool_cache:
|
if tool_id in self._tool_cache:
|
||||||
return self._tool_cache[tool_id]
|
return self._tool_cache[tool_id]
|
||||||
|
|
||||||
config = self._get_tool_config(tool_id, tenant_id)
|
config = self._get_tool_config(tool_id, tenant_id)
|
||||||
if not config:
|
if not config or not config.is_active:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -980,6 +998,7 @@ class ToolService:
|
|||||||
tags=config.tags or [],
|
tags=config.tags or [],
|
||||||
tenant_id=str(config.tenant_id) if config.tenant_id else None,
|
tenant_id=str(config.tenant_id) if config.tenant_id else None,
|
||||||
config_data=config_data,
|
config_data=config_data,
|
||||||
|
is_active=config.is_active,
|
||||||
created_at=config.created_at
|
created_at=config.created_at
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from app.repositories.workflow_repository import (
|
|||||||
WorkflowExecutionRepository,
|
WorkflowExecutionRepository,
|
||||||
WorkflowNodeExecutionRepository
|
WorkflowNodeExecutionRepository
|
||||||
)
|
)
|
||||||
from app.schemas import DraftRunRequest, FileInput, FileType
|
from app.schemas import DraftRunRequest, FileInput
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.multi_agent_service import convert_uuids_to_str
|
from app.services.multi_agent_service import convert_uuids_to_str
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
@@ -570,6 +570,9 @@ class WorkflowService:
|
|||||||
message=f"工作流配置不存在: app_id={app_id}"
|
message=f"工作流配置不存在: app_id={app_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
feature_configs = config.features or {}
|
||||||
|
self._validate_file_upload(feature_configs, payload.files)
|
||||||
|
|
||||||
input_data = {
|
input_data = {
|
||||||
"message": payload.message, "variables": payload.variables,
|
"message": payload.message, "variables": payload.variables,
|
||||||
"conversation_id": payload.conversation_id,
|
"conversation_id": payload.conversation_id,
|
||||||
@@ -633,30 +636,33 @@ class WorkflowService:
|
|||||||
final_messages = result.get("messages", [])[init_message_length:]
|
final_messages = result.get("messages", [])[init_message_length:]
|
||||||
human_message = ""
|
human_message = ""
|
||||||
assistant_message = ""
|
assistant_message = ""
|
||||||
|
human_meta = {
|
||||||
|
"files": []
|
||||||
|
}
|
||||||
for message in final_messages:
|
for message in final_messages:
|
||||||
if message["role"] == "user":
|
if message["role"] == "user":
|
||||||
if isinstance(message["content"], str):
|
if isinstance(message["content"], str):
|
||||||
human_message += message["content"]
|
human_message += message["content"]
|
||||||
elif isinstance(message["content"], list):
|
elif isinstance(message["content"], list):
|
||||||
for file in message["content"]:
|
for file in message["content"]:
|
||||||
if file.get("type") == FileType.IMAGE:
|
human_meta["files"].append({
|
||||||
human_message += f"})"
|
"type": file.get("type"),
|
||||||
else:
|
"url": file.get("url")
|
||||||
human_message += f"[{file.get('type')}]({file.get('url', '')})"
|
})
|
||||||
if message["role"] == "assistant":
|
if message["role"] == "assistant":
|
||||||
assistant_message = message["content"]
|
assistant_message = message["content"]
|
||||||
self.conversation_service.add_message(
|
self.conversation_service.add_message(
|
||||||
conversation_id=conversation_id_uuid,
|
conversation_id=conversation_id_uuid,
|
||||||
role="user",
|
role="user",
|
||||||
content=human_message,
|
content=human_message,
|
||||||
meta_data=None
|
meta_data=human_meta
|
||||||
)
|
)
|
||||||
self.conversation_service.add_message(
|
self.conversation_service.add_message(
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
conversation_id=conversation_id_uuid,
|
conversation_id=conversation_id_uuid,
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=assistant_message,
|
content=assistant_message,
|
||||||
meta_data={"usage": token_usage}
|
meta_data={"usage": token_usage, "audio_url": None}
|
||||||
)
|
)
|
||||||
self.update_execution_status(
|
self.update_execution_status(
|
||||||
execution.execution_id,
|
execution.execution_id,
|
||||||
@@ -737,6 +743,8 @@ class WorkflowService:
|
|||||||
code=BizCode.CONFIG_MISSING,
|
code=BizCode.CONFIG_MISSING,
|
||||||
message=f"工作流配置不存在: app_id={app_id}"
|
message=f"工作流配置不存在: app_id={app_id}"
|
||||||
)
|
)
|
||||||
|
feature_configs = config.features or {}
|
||||||
|
self._validate_file_upload(feature_configs, payload.files)
|
||||||
|
|
||||||
input_data = {
|
input_data = {
|
||||||
"message": payload.message, "variables": payload.variables,
|
"message": payload.message, "variables": payload.variables,
|
||||||
@@ -797,30 +805,33 @@ class WorkflowService:
|
|||||||
final_messages = event.get("data", {}).get("messages", [])[init_message_length:]
|
final_messages = event.get("data", {}).get("messages", [])[init_message_length:]
|
||||||
human_message = ""
|
human_message = ""
|
||||||
assistant_message = ""
|
assistant_message = ""
|
||||||
|
human_meta = {
|
||||||
|
"files": []
|
||||||
|
}
|
||||||
for message in final_messages:
|
for message in final_messages:
|
||||||
if message["role"] == "user":
|
if message["role"] == "user":
|
||||||
if isinstance(message["content"], str):
|
if isinstance(message["content"], str):
|
||||||
human_message += message["content"]
|
human_message += message["content"]
|
||||||
elif isinstance(message["content"], list):
|
elif isinstance(message["content"], list):
|
||||||
for file in message["content"]:
|
for file in message["content"]:
|
||||||
if file.get("type") == FileType.IMAGE:
|
human_meta["files"].append({
|
||||||
human_message += f"})"
|
"type": file.get("type"),
|
||||||
else:
|
"url": file.get("url")
|
||||||
human_message += f"[{file.get('type')}]({file.get('url', '')})"
|
})
|
||||||
if message["role"] == "assistant":
|
if message["role"] == "assistant":
|
||||||
assistant_message = message["content"]
|
assistant_message = message["content"]
|
||||||
self.conversation_service.add_message(
|
self.conversation_service.add_message(
|
||||||
conversation_id=conversation_id_uuid,
|
conversation_id=conversation_id_uuid,
|
||||||
role="user",
|
role="user",
|
||||||
content=human_message,
|
content=human_message,
|
||||||
meta_data=None
|
meta_data=human_meta
|
||||||
)
|
)
|
||||||
self.conversation_service.add_message(
|
self.conversation_service.add_message(
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
conversation_id=conversation_id_uuid,
|
conversation_id=conversation_id_uuid,
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=assistant_message,
|
content=assistant_message,
|
||||||
meta_data={"usage": token_usage}
|
meta_data={"usage": token_usage, "audio_url": None}
|
||||||
)
|
)
|
||||||
self.update_execution_status(
|
self.update_execution_status(
|
||||||
execution.execution_id,
|
execution.execution_id,
|
||||||
@@ -845,7 +856,10 @@ class WorkflowService:
|
|||||||
yield event
|
yield event
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"工作流流式执行失败: execution_id={execution.execution_id}, error={e}", exc_info=True)
|
logger.error(
|
||||||
|
f"Workflow streaming execution failed: execution_id={execution.execution_id}, error={e}",
|
||||||
|
exc_info=True
|
||||||
|
)
|
||||||
self.update_execution_status(
|
self.update_execution_status(
|
||||||
execution.execution_id,
|
execution.execution_id,
|
||||||
"failed",
|
"failed",
|
||||||
@@ -868,6 +882,80 @@ class WorkflowService:
|
|||||||
return node.get("config", {}).get("variables", [])
|
return node.get("config", {}).get("variables", [])
|
||||||
raise BusinessException("workflow config error - start node not found")
|
raise BusinessException("workflow config error - start node not found")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def is_memory_enable(config: dict) -> bool:
|
||||||
|
nodes = config.get("nodes", [])
|
||||||
|
for node in nodes:
|
||||||
|
if node.get("type") in [NodeType.MEMORY_READ, NodeType.MEMORY_WRITE]:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _validate_file_upload(
|
||||||
|
features_config: dict[str, Any],
|
||||||
|
files: Optional[list[FileInput]]
|
||||||
|
) -> None:
|
||||||
|
"""校验上传文件是否符合 file_upload 配置"""
|
||||||
|
if not files:
|
||||||
|
return
|
||||||
|
fu = features_config.get("file_upload")
|
||||||
|
if fu is None:
|
||||||
|
return
|
||||||
|
if not (isinstance(fu, dict) and fu.get("enabled")):
|
||||||
|
raise BusinessException(
|
||||||
|
"The application does not have file upload functionality enabled",
|
||||||
|
BizCode.BAD_REQUEST
|
||||||
|
)
|
||||||
|
max_count = fu.get("max_file_count", 5)
|
||||||
|
if len(files) > max_count:
|
||||||
|
raise BusinessException(
|
||||||
|
f"File count exceeds limit (maximum {max_count} files)",
|
||||||
|
BizCode.BAD_REQUEST
|
||||||
|
)
|
||||||
|
|
||||||
|
# 校验传输方式
|
||||||
|
allowed_methods = fu.get("allowed_transfer_methods", ["local_file", "remote_url"])
|
||||||
|
for f in files:
|
||||||
|
if f.transfer_method.value not in allowed_methods:
|
||||||
|
raise BusinessException(
|
||||||
|
f"Unsupport file transfer method:{f.transfer_method.value},"
|
||||||
|
f"allowed method:{', '.join(allowed_methods)}",
|
||||||
|
BizCode.BAD_REQUEST
|
||||||
|
)
|
||||||
|
|
||||||
|
# 各类型对应的开关和大小限制配置键
|
||||||
|
type_cfg = {
|
||||||
|
"image": ("image_enabled", "image_max_size_mb", 20, "image"),
|
||||||
|
"audio": ("audio_enabled", "audio_max_size_mb", 50, "audio"),
|
||||||
|
"document": ("document_enabled", "document_max_size_mb", 100, "document"),
|
||||||
|
"video": ("video_enabled", "video_max_size_mb", 500, "video"),
|
||||||
|
}
|
||||||
|
|
||||||
|
for f in files:
|
||||||
|
ftype = str(f.type) # 如 "image", "audio", "document", "video"
|
||||||
|
cfg = type_cfg.get(ftype)
|
||||||
|
if cfg is None:
|
||||||
|
continue
|
||||||
|
enabled_key, size_key, default_max_mb, label = cfg
|
||||||
|
|
||||||
|
# 校验类型开关
|
||||||
|
if not fu.get(enabled_key):
|
||||||
|
raise BusinessException(
|
||||||
|
f"The application has not enabled {label} file upload",
|
||||||
|
BizCode.BAD_REQUEST
|
||||||
|
)
|
||||||
|
|
||||||
|
# 校验文件大小(仅当内容已加载时)
|
||||||
|
content = f.get_content()
|
||||||
|
if content is not None:
|
||||||
|
max_mb = fu.get(size_key, default_max_mb)
|
||||||
|
size_mb = len(content) / (1024 * 1024)
|
||||||
|
if size_mb > max_mb:
|
||||||
|
raise BusinessException(
|
||||||
|
f"{label} File size exceeds the limit (maximum {max_mb} MB, current {size_mb:.1f} MB)",
|
||||||
|
BizCode.BAD_REQUEST
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# ==================== 依赖注入函数 ====================
|
# ==================== 依赖注入函数 ====================
|
||||||
|
|
||||||
|
|||||||
@@ -1158,13 +1158,11 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s
|
|||||||
try:
|
try:
|
||||||
_r = get_sync_redis_client()
|
_r = get_sync_redis_client()
|
||||||
if _r is not None:
|
if _r is not None:
|
||||||
from datetime import timedelta as _td
|
|
||||||
from datetime import timezone as _tz
|
from datetime import timezone as _tz
|
||||||
_CST = _tz(_td(hours=8))
|
_now_utc = datetime.now(_tz.utc).isoformat()
|
||||||
_now_cst = datetime.now(_CST).replace(tzinfo=None).isoformat()
|
|
||||||
_r.set(
|
_r.set(
|
||||||
f"write_message:last_done:{end_user_id}",
|
f"write_message:last_done:{end_user_id}",
|
||||||
_now_cst,
|
_now_utc,
|
||||||
ex=86400 * 30,
|
ex=86400 * 30,
|
||||||
)
|
)
|
||||||
except Exception as _e:
|
except Exception as _e:
|
||||||
@@ -1294,9 +1292,9 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 2. 查询所有app下的end_user_id(去重)
|
# 2. 查询所有app下的end_user_id(去重)
|
||||||
app_ids = [app.id for app in apps]
|
# app_ids = [app.id for app in apps]
|
||||||
end_users = db.query(EndUser.id).filter(
|
end_users = db.query(EndUser.id).filter(
|
||||||
EndUser.app_id.in_(app_ids)
|
EndUser.workspace_id == workspace_id
|
||||||
).distinct().all()
|
).distinct().all()
|
||||||
|
|
||||||
# 3. 遍历所有end_user,查询每个宿主的记忆总量并累加
|
# 3. 遍历所有end_user,查询每个宿主的记忆总量并累加
|
||||||
@@ -1435,9 +1433,9 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]:
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# 2. 查询所有app下的end_user_id(去重)
|
# 2. 查询所有app下的end_user_id(去重)
|
||||||
app_ids = [app.id for app in apps]
|
# app_ids = [app.id for app in apps]
|
||||||
end_users = db.query(EndUser.id).filter(
|
end_users = db.query(EndUser.id).filter(
|
||||||
EndUser.app_id.in_(app_ids)
|
EndUser.workspace_id == workspace_id
|
||||||
).distinct().all()
|
).distinct().all()
|
||||||
|
|
||||||
# 3. 遍历所有end_user,查询每个宿主的记忆总量并累加
|
# 3. 遍历所有end_user,查询每个宿主的记忆总量并累加
|
||||||
|
|||||||
@@ -100,7 +100,8 @@ def agent_config_4_app_release(release: AppRelease) -> AgentConfig:
|
|||||||
memory=config_dict.get("memory"),
|
memory=config_dict.get("memory"),
|
||||||
variables=config_dict.get("variables", []),
|
variables=config_dict.get("variables", []),
|
||||||
tools=config_dict.get("tools", []),
|
tools=config_dict.get("tools", []),
|
||||||
skills=config_dict.get("skills", {})
|
skills=config_dict.get("skills", {}),
|
||||||
|
features=config_dict.get("features", {})
|
||||||
)
|
)
|
||||||
|
|
||||||
return agent_config
|
return agent_config
|
||||||
|
|||||||
50
api/migrations/versions/12114b3e953c_202603131647.py
Normal file
50
api/migrations/versions/12114b3e953c_202603131647.py
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
"""202603131647
|
||||||
|
|
||||||
|
Revision ID: 12114b3e953c
|
||||||
|
Revises: cd3a402c2f6c
|
||||||
|
Create Date: 2026-03-13 08:47:30.455956
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy import text
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '12114b3e953c'
|
||||||
|
down_revision: Union[str, None] = 'ef9d172cb753'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
conn = op.get_bind()
|
||||||
|
print("Step 1: 添加 workspace_id 列...")
|
||||||
|
op.add_column('end_users', sa.Column('workspace_id', sa.UUID(), nullable=True))
|
||||||
|
print("Step 2: 回填 workspace_id...")
|
||||||
|
conn.execute(text("""
|
||||||
|
UPDATE end_users
|
||||||
|
SET workspace_id = apps.workspace_id
|
||||||
|
FROM apps
|
||||||
|
WHERE end_users.app_id = apps.id
|
||||||
|
"""))
|
||||||
|
# Step 3: 设置 workspace_id 为 NOT NULL
|
||||||
|
print("Step 3: 设置 workspace_id 为 NOT NULL...")
|
||||||
|
op.alter_column('end_users', 'workspace_id', nullable=False)
|
||||||
|
op.alter_column('end_users', 'app_id', existing_type=sa.UUID(), nullable=True)
|
||||||
|
# Step 4: 添加外键约束
|
||||||
|
print("Step 4: 添加外键约束...")
|
||||||
|
op.create_foreign_key('fk_end_users_workspace_id','end_users', 'workspaces',
|
||||||
|
['workspace_id'], ['id']
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_constraint('fk_end_users_workspace_id', 'end_users', type_='foreignkey')
|
||||||
|
op.alter_column('end_users', 'app_id', existing_type=sa.UUID(), nullable=False)
|
||||||
|
op.drop_column('end_users', 'workspace_id')
|
||||||
|
# ### end Alembic commands ###
|
||||||
156
api/migrations/versions/74b51dfece29_20260311000.py
Normal file
156
api/migrations/versions/74b51dfece29_20260311000.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
"""20260311000
|
||||||
|
|
||||||
|
Revision ID: 74b51dfece29
|
||||||
|
Revises: f017efe4831c
|
||||||
|
Create Date: 2026-03-19 10:15:42.488027
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '74b51dfece29'
|
||||||
|
down_revision: Union[str, None] = 'f017efe4831c'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# 先删除旧的触发器(如果存在)
|
||||||
|
op.execute("DROP TRIGGER IF EXISTS tr_documents_update_stats ON documents;")
|
||||||
|
|
||||||
|
# 创建或更新 knowledges 统计信息的函数
|
||||||
|
op.execute("""
|
||||||
|
CREATE OR REPLACE FUNCTION update_knowledge_stats()
|
||||||
|
RETURNS TRIGGER AS $$
|
||||||
|
DECLARE
|
||||||
|
-- 声明变量用于存储当前处理的知识库ID
|
||||||
|
current_kb_id UUID;
|
||||||
|
-- 声明变量用于存储文件夹知识库ID(如果存在)
|
||||||
|
folder_kb_id UUID;
|
||||||
|
-- 声明变量用于存储递归查询结果
|
||||||
|
folder_ids UUID[];
|
||||||
|
BEGIN
|
||||||
|
-- 处理 documents 表的插入、更新或删除
|
||||||
|
IF TG_TABLE_NAME = 'documents' THEN
|
||||||
|
-- 1. 更新 knowledges 表的 doc_num
|
||||||
|
UPDATE knowledges SET doc_num = (
|
||||||
|
SELECT COUNT(*) FROM documents
|
||||||
|
WHERE kb_id = knowledges.id AND status = 1
|
||||||
|
)
|
||||||
|
WHERE id = NEW.kb_id OR id = OLD.kb_id;
|
||||||
|
|
||||||
|
-- 2. 更新 knowledges 表的 chunk_num
|
||||||
|
UPDATE knowledges SET chunk_num = (
|
||||||
|
SELECT COALESCE(SUM(chunk_num), 0) FROM documents
|
||||||
|
WHERE kb_id = knowledges.id AND status = 1
|
||||||
|
)
|
||||||
|
WHERE id = NEW.kb_id OR id = OLD.kb_id;
|
||||||
|
|
||||||
|
-- 通过 knowledge_shares 表同步统计信息
|
||||||
|
-- 1. 使用 source_kb_id 的 doc_num 更新 target_kb_id 的 doc_num
|
||||||
|
UPDATE knowledges AS target
|
||||||
|
SET doc_num = source.doc_num
|
||||||
|
FROM knowledge_shares ks
|
||||||
|
JOIN knowledges AS source ON source.id = ks.source_kb_id
|
||||||
|
WHERE ks.target_kb_id = target.id
|
||||||
|
AND (source.id = NEW.kb_id OR source.id = OLD.kb_id);
|
||||||
|
|
||||||
|
-- 2. 使用 source_kb_id 的 chunk_num 更新 target_kb_id 的 chunk_num
|
||||||
|
UPDATE knowledges AS target
|
||||||
|
SET chunk_num = source.chunk_num
|
||||||
|
FROM knowledge_shares ks
|
||||||
|
JOIN knowledges AS source ON source.id = ks.source_kb_id
|
||||||
|
WHERE ks.target_kb_id = target.id
|
||||||
|
AND (source.id = NEW.kb_id OR source.id = OLD.kb_id);
|
||||||
|
|
||||||
|
-- 处理文件夹知识库的统计更新
|
||||||
|
-- 获取当前处理的知识库ID(可能是NEW或OLD中的kb_id)
|
||||||
|
IF NEW.kb_id IS NOT NULL THEN
|
||||||
|
current_kb_id := NEW.kb_id;
|
||||||
|
ELSIF OLD.kb_id IS NOT NULL THEN
|
||||||
|
current_kb_id := OLD.kb_id;
|
||||||
|
ELSE
|
||||||
|
RETURN NULL;
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
-- 查找当前知识库的父文件夹(如果有)
|
||||||
|
SELECT id INTO folder_kb_id FROM knowledges
|
||||||
|
WHERE id IN (
|
||||||
|
SELECT parent_id FROM knowledges WHERE id = current_kb_id
|
||||||
|
) AND type = 'Folder';
|
||||||
|
|
||||||
|
-- 如果存在父文件夹,递归处理所有父文件夹
|
||||||
|
IF folder_kb_id IS NOT NULL THEN
|
||||||
|
-- 使用递归CTE获取所有父文件夹ID(包括多级嵌套)
|
||||||
|
WITH RECURSIVE folder_hierarchy AS (
|
||||||
|
-- 基础查询:获取直接父文件夹
|
||||||
|
SELECT id FROM knowledges
|
||||||
|
WHERE id = folder_kb_id AND type = 'Folder'
|
||||||
|
UNION ALL
|
||||||
|
-- 递归查询:获取父文件夹的父文件夹
|
||||||
|
SELECT k.id FROM knowledges k
|
||||||
|
JOIN folder_hierarchy fh ON k.id = k.parent_id
|
||||||
|
WHERE k.type = 'Folder'
|
||||||
|
)
|
||||||
|
-- 将结果存入数组以便处理
|
||||||
|
SELECT array_agg(id) INTO folder_ids FROM folder_hierarchy;
|
||||||
|
|
||||||
|
-- 遍历所有父文件夹并更新统计信息
|
||||||
|
FOR i IN 1..array_length(folder_ids, 1) LOOP
|
||||||
|
-- 更新文件夹的doc_num(汇总所有子知识库的doc_num)
|
||||||
|
UPDATE knowledges SET doc_num = (
|
||||||
|
-- 汇总直接子知识库的doc_num
|
||||||
|
SELECT COALESCE(SUM(child.doc_num), 0)
|
||||||
|
FROM knowledges child
|
||||||
|
WHERE child.parent_id = folder_ids[i] AND child.status = 1
|
||||||
|
-- 加上直接属于该文件夹的文档数(如果有)
|
||||||
|
UNION ALL
|
||||||
|
SELECT COALESCE(COUNT(*), 0)
|
||||||
|
FROM documents
|
||||||
|
WHERE kb_id = folder_ids[i] AND status = 1
|
||||||
|
LIMIT 1
|
||||||
|
)
|
||||||
|
WHERE id = folder_ids[i];
|
||||||
|
|
||||||
|
-- 更新文件夹的chunk_num(汇总所有子知识库的chunk_num)
|
||||||
|
UPDATE knowledges SET chunk_num = (
|
||||||
|
-- 汇总直接子知识库的chunk_num
|
||||||
|
SELECT COALESCE(SUM(child.chunk_num), 0)
|
||||||
|
FROM knowledges child
|
||||||
|
WHERE child.parent_id = folder_ids[i] AND child.status = 1
|
||||||
|
-- 加上直接属于该文件夹的文档的chunk_num(如果有)
|
||||||
|
UNION ALL
|
||||||
|
SELECT COALESCE(SUM(d.chunk_num), 0)
|
||||||
|
FROM documents d
|
||||||
|
WHERE d.kb_id = folder_ids[i] AND d.status = 1
|
||||||
|
LIMIT 1
|
||||||
|
)
|
||||||
|
WHERE id = folder_ids[i];
|
||||||
|
END LOOP;
|
||||||
|
END IF;
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
RETURN NULL;
|
||||||
|
END;
|
||||||
|
$$ LANGUAGE plpgsql;
|
||||||
|
""")
|
||||||
|
|
||||||
|
# documents 表上的触发器(插入、更新、删除后)
|
||||||
|
op.execute("""
|
||||||
|
CREATE TRIGGER tr_documents_update_stats
|
||||||
|
AFTER INSERT OR UPDATE OR DELETE ON documents
|
||||||
|
FOR EACH ROW
|
||||||
|
EXECUTE FUNCTION update_knowledge_stats();
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# 删除触发器
|
||||||
|
op.execute("DROP TRIGGER IF EXISTS tr_documents_update_stats ON documents;")
|
||||||
|
# 删除函数
|
||||||
|
op.execute("DROP FUNCTION IF EXISTS update_knowledge_stats();")
|
||||||
|
|
||||||
34
api/migrations/versions/818c6c535e14_202603161825.py
Normal file
34
api/migrations/versions/818c6c535e14_202603161825.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""202603161825
|
||||||
|
|
||||||
|
Revision ID: 818c6c535e14
|
||||||
|
Revises: 12114b3e953c
|
||||||
|
Create Date: 2026-03-16 18:33:41.883671
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '818c6c535e14'
|
||||||
|
down_revision: Union[str, None] = '12114b3e953c'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column('agent_configs', sa.Column('features', postgresql.JSON(astext_type=sa.Text()), nullable=True, comment='功能特性配置'))
|
||||||
|
op.add_column('tool_configs', sa.Column('is_active', sa.Boolean(), server_default='true', nullable=False, comment='是否可用,False表示已删除'))
|
||||||
|
op.create_index(op.f('ix_tool_configs_is_active'), 'tool_configs', ['is_active'], unique=False)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_index(op.f('ix_tool_configs_is_active'), table_name='tool_configs')
|
||||||
|
op.drop_column('tool_configs', 'is_active')
|
||||||
|
op.drop_column('agent_configs', 'features')
|
||||||
|
# ### end Alembic commands ###
|
||||||
30
api/migrations/versions/f017efe4831c_202603181652.py
Normal file
30
api/migrations/versions/f017efe4831c_202603181652.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""202603181652
|
||||||
|
|
||||||
|
Revision ID: f017efe4831c
|
||||||
|
Revises: 818c6c535e14
|
||||||
|
Create Date: 2026-03-18 16:52:21.639695
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'f017efe4831c'
|
||||||
|
down_revision: Union[str, None] = '818c6c535e14'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column('workflow_configs', sa.Column('features', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_column('workflow_configs', 'features')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@@ -303,7 +303,7 @@ async def test_get_node_output_not_exist_with_default():
|
|||||||
"""测试获取不存在的节点输出(使用默认值)"""
|
"""测试获取不存在的节点输出(使用默认值)"""
|
||||||
pool = VariablePool()
|
pool = VariablePool()
|
||||||
|
|
||||||
result = pool.get_node_output("nonexistent_node", defalut=None, strict=False)
|
result = pool.get_node_output("nonexistent_node", default=None, strict=False)
|
||||||
|
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|||||||
Submodule redbear-mem-benchmark updated: c3bbc6931c...e853d99ff0
@@ -46,6 +46,7 @@
|
|||||||
"lexical": "^0.39.0",
|
"lexical": "^0.39.0",
|
||||||
"mammoth": "^1.12.0",
|
"mammoth": "^1.12.0",
|
||||||
"mermaid": "^11.12.1",
|
"mermaid": "^11.12.1",
|
||||||
|
"pdfjs-dist": "4.10.38",
|
||||||
"react": "^18.2.0",
|
"react": "^18.2.0",
|
||||||
"react-dom": "^18.2.0",
|
"react-dom": "^18.2.0",
|
||||||
"react-i18next": "^15.0.0",
|
"react-i18next": "^15.0.0",
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 13:59:45
|
* @Date: 2026-02-03 13:59:45
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-13 17:07:54
|
* @Last Modified time: 2026-03-18 20:01:29
|
||||||
*/
|
*/
|
||||||
import { request } from '@/utils/request'
|
import { request } from '@/utils/request'
|
||||||
import type { ApplicationModalData } from '@/views/ApplicationManagement/types'
|
import type { ApplicationModalData } from '@/views/ApplicationManagement/types'
|
||||||
@@ -137,7 +137,7 @@ export const getExperienceConfig = (share_token: string) => {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
// Export application
|
// Export application
|
||||||
export const appExport = (app_id: string, appName: string, data?: { release_version: string }) => {
|
export const appExport = (app_id: string, appName: string, data?: { release_id: string }) => {
|
||||||
return request.getDownloadFile(`/apps/${app_id}/export`, `${appName}.yml`, data)
|
return request.getDownloadFile(`/apps/${app_id}/export`, `${appName}.yml`, data)
|
||||||
}
|
}
|
||||||
// Import application
|
// Import application
|
||||||
|
|||||||
@@ -2,10 +2,12 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-06 21:11:51
|
* @Date: 2026-02-06 21:11:51
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-13 17:11:14
|
* @Last Modified time: 2026-03-17 18:39:09
|
||||||
*/
|
*/
|
||||||
import { type FC, useRef, useState } from 'react'
|
import { type FC, useRef, useState } from 'react'
|
||||||
import RecordRTC from 'recordrtc'
|
import RecordRTC from 'recordrtc'
|
||||||
|
import { App } from 'antd'
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
|
||||||
import { fileUploadUrlWithoutApiPrefix } from '@/api/fileStorage'
|
import { fileUploadUrlWithoutApiPrefix } from '@/api/fileStorage'
|
||||||
import { request } from '@/utils/request'
|
import { request } from '@/utils/request'
|
||||||
@@ -19,14 +21,20 @@ interface AudioRecorderProps {
|
|||||||
action?: string;
|
action?: string;
|
||||||
/** Additional config passed to the upload request */
|
/** Additional config passed to the upload request */
|
||||||
requestConfig?: Record<string, any>;
|
requestConfig?: Record<string, any>;
|
||||||
|
disabled?: boolean;
|
||||||
|
maxSize?: number;
|
||||||
}
|
}
|
||||||
|
|
||||||
const AudioRecorder: FC<AudioRecorderProps> = ({
|
const AudioRecorder: FC<AudioRecorderProps> = ({
|
||||||
onRecordingComplete,
|
onRecordingComplete,
|
||||||
className = '',
|
className = '',
|
||||||
action = fileUploadUrlWithoutApiPrefix,
|
action = fileUploadUrlWithoutApiPrefix,
|
||||||
requestConfig = {}
|
requestConfig = {},
|
||||||
|
disabled = false,
|
||||||
|
maxSize,
|
||||||
}) => {
|
}) => {
|
||||||
|
const { message } = App.useApp()
|
||||||
|
const { t } = useTranslation();
|
||||||
// Whether the recorder is currently capturing audio
|
// Whether the recorder is currently capturing audio
|
||||||
const [isRecording, setIsRecording] = useState(false)
|
const [isRecording, setIsRecording] = useState(false)
|
||||||
// Holds the RecordRTC instance across renders
|
// Holds the RecordRTC instance across renders
|
||||||
@@ -34,6 +42,7 @@ const AudioRecorder: FC<AudioRecorderProps> = ({
|
|||||||
|
|
||||||
/** Request microphone access and start recording */
|
/** Request microphone access and start recording */
|
||||||
const startRecording = async () => {
|
const startRecording = async () => {
|
||||||
|
if (disabled) return
|
||||||
try {
|
try {
|
||||||
const stream = await navigator.mediaDevices.getUserMedia({ audio: true })
|
const stream = await navigator.mediaDevices.getUserMedia({ audio: true })
|
||||||
recorderRef.current = new RecordRTC(stream, {
|
recorderRef.current = new RecordRTC(stream, {
|
||||||
@@ -49,10 +58,17 @@ const AudioRecorder: FC<AudioRecorderProps> = ({
|
|||||||
|
|
||||||
/** Stop recording, upload the audio blob, then invoke the completion callback */
|
/** Stop recording, upload the audio blob, then invoke the completion callback */
|
||||||
const stopRecording = () => {
|
const stopRecording = () => {
|
||||||
|
if (disabled) return
|
||||||
if (recorderRef.current) {
|
if (recorderRef.current) {
|
||||||
recorderRef.current.stopRecording(() => {
|
recorderRef.current.stopRecording(() => {
|
||||||
const blob = recorderRef.current!.getBlob()
|
const blob = recorderRef.current!.getBlob()
|
||||||
const url = recorderRef.current!.toURL()
|
const url = recorderRef.current!.toURL()
|
||||||
|
|
||||||
|
if (maxSize && blob.size > maxSize * 1024 * 1024) {
|
||||||
|
message.error(t('common.fileSizeTip', { size: maxSize }));
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
const formData = new FormData()
|
const formData = new FormData()
|
||||||
formData.append('file', blob, `recording_${Date.now()}.webm`)
|
formData.append('file', blob, `recording_${Date.now()}.webm`)
|
||||||
request
|
request
|
||||||
@@ -76,7 +92,7 @@ const AudioRecorder: FC<AudioRecorderProps> = ({
|
|||||||
// swap background image to reflect current state
|
// swap background image to reflect current state
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
className={`rb:size-5.5 rb:cursor-pointer rb:bg-cover ${className} ${
|
className={`rb:size-5.5 rb:bg-cover ${disabled ? 'rb:opacity-65 rb:cursor-not-allowed' : 'rb:cursor-pointer'} ${className} ${
|
||||||
isRecording
|
isRecording
|
||||||
? `rb:bg-[url('@/assets/images/conversation/audio_ing.gif')]`
|
? `rb:bg-[url('@/assets/images/conversation/audio_ing.gif')]`
|
||||||
: `rb:bg-[url('@/assets/images/conversation/audio.svg')]`
|
: `rb:bg-[url('@/assets/images/conversation/audio.svg')]`
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-02 15:01:59
|
* @Date: 2026-02-02 15:01:59
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-12 14:59:38
|
* @Last Modified time: 2026-03-19 13:41:26
|
||||||
*/
|
*/
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -42,7 +42,8 @@ const ButtonCheckbox: FC<ButtonCheckboxProps> = ({
|
|||||||
icon,
|
icon,
|
||||||
checkedIcon,
|
checkedIcon,
|
||||||
children,
|
children,
|
||||||
cicle = false
|
cicle = false,
|
||||||
|
disabled,
|
||||||
}) => {
|
}) => {
|
||||||
// Listen to value changes and trigger side effects via onValueChange callback
|
// Listen to value changes and trigger side effects via onValueChange callback
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@@ -63,13 +64,14 @@ const ButtonCheckbox: FC<ButtonCheckboxProps> = ({
|
|||||||
align="center"
|
align="center"
|
||||||
justify={cicle ? 'center' : 'start'}
|
justify={cicle ? 'center' : 'start'}
|
||||||
gap={4}
|
gap={4}
|
||||||
className={clsx("rb:flex rb:items-center rb:cursor-pointer rb:border rb:hover:bg-[#F6F6F6]", {
|
className={clsx("rb:flex rb:items-center rb:cursor-pointer rb:px-2! rb:border rb:hover:bg-[#F6F6F6]", {
|
||||||
'rb:size-7 rb:rounded-[14px] rb:border-[0.5px] rb:border-[#EBEBEB]': cicle,
|
'rb:size-7 rb:rounded-[14px] rb:border-[0.5px] rb:border-[#EBEBEB]': cicle,
|
||||||
'rb:rounded-lg rb:px-2 rb:text-[12px] rb:h-6': !cicle,
|
'rb:rounded-lg rb:text-[12px] rb:h-6': !cicle,
|
||||||
// Checked state: blue background and border
|
// Checked state: blue background and border
|
||||||
"rb:bg-[rgba(21,94,239,0.06)] rb:border-[rgba(21,94,239,0.25)] rb:hover:bg-[rgba(21,94,239,0.06)] rb:text-[#155EEF]": checked,
|
"rb:bg-[rgba(21,94,239,0.06)] rb:border-[rgba(21,94,239,0.25)] rb:hover:bg-[rgba(21,94,239,0.06)] rb:text-[#155EEF]": checked,
|
||||||
// Unchecked state: gray border and dark text
|
// Unchecked state: gray border and dark text
|
||||||
"rb:border-[#DFE4ED] rb:text-[#212332]": !checked,
|
"rb:border-[#DFE4ED] rb:text-[#212332]": !checked,
|
||||||
|
"rb:opacity-65 rb:cursor-not-allowed!": disabled
|
||||||
})}
|
})}
|
||||||
onClick={handleChange}
|
onClick={handleChange}
|
||||||
>
|
>
|
||||||
|
|||||||
@@ -2,13 +2,19 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2025-12-10 16:46:17
|
* @Date: 2025-12-10 16:46:17
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-02-06 21:05:52
|
* @Last Modified time: 2026-03-19 13:38:20
|
||||||
*/
|
*/
|
||||||
import { type FC, useRef, useEffect } from 'react'
|
import { type FC, useRef, useEffect, useState } from 'react'
|
||||||
import clsx from 'clsx'
|
import clsx from 'clsx'
|
||||||
import Markdown from '@/components/Markdown'
|
import Markdown from '@/components/Markdown'
|
||||||
import type { ChatContentProps } from './types'
|
import type { ChatContentProps } from './types'
|
||||||
import { Spin } from 'antd'
|
import { Spin, Divider, Space, Image, Flex } from 'antd'
|
||||||
|
import { SoundOutlined } from '@ant-design/icons'
|
||||||
|
|
||||||
|
|
||||||
|
const getFileUrl = (file: any) => {
|
||||||
|
return file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined)
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Chat Content Display Component
|
* Chat Content Display Component
|
||||||
@@ -28,15 +34,33 @@ const ChatContent: FC<ChatContentProps> = ({
|
|||||||
// Scroll container reference for controlling auto-scroll to bottom
|
// Scroll container reference for controlling auto-scroll to bottom
|
||||||
const scrollContainerRef = useRef<(HTMLDivElement | null)>(null)
|
const scrollContainerRef = useRef<(HTMLDivElement | null)>(null)
|
||||||
const prevDataLengthRef = useRef(data.length);
|
const prevDataLengthRef = useRef(data.length);
|
||||||
const isScrolledToBottomRef = useRef(true); // Track if user is scrolled to bottom
|
const isScrolledToBottomRef = useRef(true);
|
||||||
|
const audioRef = useRef<HTMLAudioElement | null>(null)
|
||||||
|
const [playingIndex, setPlayingIndex] = useState<number | null>(null)
|
||||||
|
|
||||||
|
const handlePlay = (index: number, audio_url: string) => {
|
||||||
|
if (playingIndex === index) {
|
||||||
|
audioRef.current?.pause()
|
||||||
|
setPlayingIndex(null)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if (audioRef.current) {
|
||||||
|
audioRef.current.pause()
|
||||||
|
}
|
||||||
|
const audio = new Audio(audio_url)
|
||||||
|
audioRef.current = audio
|
||||||
|
audio.play()
|
||||||
|
setPlayingIndex(index)
|
||||||
|
audio.onended = () => setPlayingIndex(null)
|
||||||
|
}
|
||||||
|
|
||||||
// Track scroll position to determine if user is at bottom
|
// Track scroll position to determine if user is at bottom
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const handleScroll = () => {
|
const handleScroll = () => {
|
||||||
if (scrollContainerRef.current) {
|
if (scrollContainerRef.current) {
|
||||||
const { scrollTop, scrollHeight, clientHeight } = scrollContainerRef.current;
|
const { scrollTop, scrollHeight, clientHeight } = scrollContainerRef.current;
|
||||||
// Consider user is at bottom if within 20px of the bottom
|
// Consider user is at bottom if within 100px of the bottom
|
||||||
isScrolledToBottomRef.current = scrollHeight - scrollTop - clientHeight < 20;
|
isScrolledToBottomRef.current = scrollHeight - scrollTop - clientHeight < 100;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -64,11 +88,16 @@ const ChatContent: FC<ChatContentProps> = ({
|
|||||||
// Auto-scroll if data length changed OR user is currently at bottom
|
// Auto-scroll if data length changed OR user is currently at bottom
|
||||||
if (data.length !== prevDataLengthRef.current || isScrolledToBottomRef.current) {
|
if (data.length !== prevDataLengthRef.current || isScrolledToBottomRef.current) {
|
||||||
scrollContainerRef.current.scrollTop = scrollContainerRef.current.scrollHeight;
|
scrollContainerRef.current.scrollTop = scrollContainerRef.current.scrollHeight;
|
||||||
|
isScrolledToBottomRef.current = true;
|
||||||
}
|
}
|
||||||
prevDataLengthRef.current = data.length;
|
prevDataLengthRef.current = data.length;
|
||||||
}
|
}
|
||||||
}, 0);
|
}, 0);
|
||||||
}, [data])
|
}, [data])
|
||||||
|
|
||||||
|
const handleDownload = (file: any) => {
|
||||||
|
window.open(getFileUrl(file), '_blank')
|
||||||
|
}
|
||||||
return (
|
return (
|
||||||
<div ref={scrollContainerRef} className={clsx("rb:relative rb:overflow-y-auto", classNames)}>
|
<div ref={scrollContainerRef} className={clsx("rb:relative rb:overflow-y-auto", classNames)}>
|
||||||
{data.length === 0
|
{data.length === 0
|
||||||
@@ -89,6 +118,44 @@ const ChatContent: FC<ChatContentProps> = ({
|
|||||||
{labelFormat(item)}
|
{labelFormat(item)}
|
||||||
</div>
|
</div>
|
||||||
}
|
}
|
||||||
|
{item.meta_data?.files && item.meta_data?.files.length > 0 && <Flex gap={8} vertical align="end">
|
||||||
|
{item.meta_data?.files?.map((file) => {
|
||||||
|
if (file.type.includes('image')) {
|
||||||
|
return (
|
||||||
|
<div key={file.url || file.uid} className={`rb:inline-block rb:group rb:relative rb:rounded-lg ${contentClassNames}`}>
|
||||||
|
<Image src={getFileUrl(file)} alt={file.name} className="rb:w-full rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if (file.type.includes('video')) {
|
||||||
|
return (
|
||||||
|
<div key={file.url || file.uid} className="rb:inline-block rb:group rb:relative rb:rounded-lg">
|
||||||
|
<video src={getFileUrl(file)} controls className="rb:max-w-80 rb:rounded-lg rb:object-cover rb:cursor-pointer" />
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if (file.type.includes('audio')) {
|
||||||
|
return (
|
||||||
|
<div key={file.url || file.uid} className="rb:inline-flex rb:items-center rb:group rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:py-2 rb:px-2.5 rb:gap-2">
|
||||||
|
<audio src={getFileUrl(file)} controls className="rb:max-w-80" />
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<div key={file.url || file.uid} className="rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:p-1! rb:cursor-pointer" onClick={() => handleDownload(file)}>
|
||||||
|
{(file.type.includes('doc') || file.type.includes('docx') || file.type.includes('word') || file.type.includes('wordprocessingml.document')) && <div
|
||||||
|
className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/word.svg')]"
|
||||||
|
></div>}
|
||||||
|
{(file.type.includes('pdf')) && <div
|
||||||
|
className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf.svg')]"
|
||||||
|
></div>}
|
||||||
|
{(file.type.includes('excel') || file.type.includes('spreadsheetml.sheet') || file.type.includes('csv')) && <div
|
||||||
|
className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/excel.svg')]"
|
||||||
|
></div>}
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
})}
|
||||||
|
</Flex>}
|
||||||
{/* Message bubble */}
|
{/* Message bubble */}
|
||||||
<div className={clsx('rb:border rb:text-left rb:rounded-lg rb:mt-1.5 rb:leading-4.5 rb:p-[10px_12px_2px_12px] rb:inline-block rb:max-w-130 rb:wrap-break-word', contentClassNames, {
|
<div className={clsx('rb:border rb:text-left rb:rounded-lg rb:mt-1.5 rb:leading-4.5 rb:p-[10px_12px_2px_12px] rb:inline-block rb:max-w-130 rb:wrap-break-word', contentClassNames, {
|
||||||
// Error message style (content is null and not assistant message)
|
// Error message style (content is null and not assistant message)
|
||||||
@@ -101,6 +168,19 @@ const ChatContent: FC<ChatContentProps> = ({
|
|||||||
{item.subContent && renderRuntime && renderRuntime(item, index)}
|
{item.subContent && renderRuntime && renderRuntime(item, index)}
|
||||||
{/* Render message content using Markdown component */}
|
{/* Render message content using Markdown component */}
|
||||||
<Markdown content={renderRuntime ? item.content ?? '' : item.content ?? errorDesc ?? ''} />
|
<Markdown content={renderRuntime ? item.content ?? '' : item.content ?? errorDesc ?? ''} />
|
||||||
|
|
||||||
|
{item.meta_data?.audio_url && <>
|
||||||
|
<Divider className="rb:my-3!" />
|
||||||
|
<Space size={12} className="rb:pb-2 rb:pl-1">
|
||||||
|
{playingIndex !== index
|
||||||
|
? <SoundOutlined className="rb:cursor-pointer rb:hover:text-[#155EEF]! rb:size-5.5" onClick={() => handlePlay(index, item.meta_data?.audio_url!)} />
|
||||||
|
: <div
|
||||||
|
className="rb:size-5.5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/audio_ing.gif')]"
|
||||||
|
onClick={() => handlePlay(index, item.meta_data?.audio_url!)}
|
||||||
|
/>
|
||||||
|
}
|
||||||
|
</Space>
|
||||||
|
</>}
|
||||||
</div>
|
</div>
|
||||||
{/* Bottom label (such as timestamp, username, etc.) */}
|
{/* Bottom label (such as timestamp, username, etc.) */}
|
||||||
{labelPosition === 'bottom' &&
|
{labelPosition === 'bottom' &&
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user