Compare commits
349 Commits
release/v0
...
v0.2.8
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6056952936 | ||
|
|
0f092e08f4 | ||
|
|
8e7603bcc4 | ||
|
|
a079358028 | ||
|
|
fa29a39920 | ||
|
|
2146c555d2 | ||
|
|
240f1d431b | ||
|
|
726148d7ee | ||
|
|
0f1b1d7d10 | ||
|
|
11aa2e1f9e | ||
|
|
ca654cca74 | ||
|
|
bd1f649bd0 | ||
|
|
ea00747c66 | ||
|
|
3db031891e | ||
|
|
fb6ca3909a | ||
|
|
929afb1770 | ||
|
|
6235584b2e | ||
|
|
0b1ea33b41 | ||
|
|
3929f811b8 | ||
|
|
551a2b59a5 | ||
|
|
9a765ac71e | ||
|
|
83e26732de | ||
|
|
52fdfc7744 | ||
|
|
4e544325a0 | ||
|
|
99a2f396fd | ||
|
|
0157c9d262 | ||
|
|
5ddacab162 | ||
|
|
a51e34852c | ||
|
|
36f670b2e9 | ||
|
|
cbcbc8822c | ||
|
|
aa2d1e7a35 | ||
|
|
39b2f3ba0e | ||
|
|
43064ab71b | ||
|
|
4144f0b9b5 | ||
|
|
08f0be17ce | ||
|
|
2915e464bf | ||
|
|
152559ae46 | ||
|
|
1f531f1ace | ||
|
|
7ec947189c | ||
|
|
b4615bacdc | ||
|
|
e849fed5c1 | ||
|
|
0f5cae4590 | ||
|
|
1c3029f360 | ||
|
|
e2411e0bdd | ||
|
|
7af88b19cf | ||
|
|
c3f8dbd4bc | ||
|
|
c1e48fde86 | ||
|
|
f644c84fbb | ||
|
|
d0afce27c4 | ||
|
|
b84aba71e7 | ||
|
|
2e481df465 | ||
|
|
a322ec4fd5 | ||
|
|
bdbf9c0609 | ||
|
|
ef7d59e442 | ||
|
|
27b782e12a | ||
|
|
37a22fbfa9 | ||
|
|
d798d101f7 | ||
|
|
825f225f63 | ||
|
|
4d5e2958dc | ||
|
|
6105d46198 | ||
|
|
7aec157859 | ||
|
|
13abb03d87 | ||
|
|
e8947ad0bb | ||
|
|
7056865726 | ||
|
|
c2c832f8c9 | ||
|
|
6bc4f04293 | ||
|
|
9d150ab353 | ||
|
|
f045b59b2d | ||
|
|
d584b47280 | ||
|
|
3e995cd971 | ||
|
|
b018e35ada | ||
|
|
86a0aa1f9f | ||
|
|
d523e4f3c6 | ||
|
|
186d097e00 | ||
|
|
c5cfe557da | ||
|
|
f786a66a3c | ||
|
|
ebd51928d7 | ||
|
|
2258b5c43c | ||
|
|
8c804a1011 | ||
|
|
1a4c2d7cd0 | ||
|
|
83fcabadae | ||
|
|
33d522b387 | ||
|
|
5997458aaf | ||
|
|
68f9471caf | ||
|
|
ecbb61db27 | ||
|
|
b42815ee7a | ||
|
|
49d7398e14 | ||
|
|
91589c1497 | ||
|
|
18ca83d763 | ||
|
|
4bbc561625 | ||
|
|
f52b681133 | ||
|
|
f6efa0d711 | ||
|
|
0fccc91dac | ||
|
|
8d8c6c695a | ||
|
|
57342259ce | ||
|
|
be46ed8865 | ||
|
|
04b2205769 | ||
|
|
76ba357982 | ||
|
|
2c318f6e60 | ||
|
|
3df8af3852 | ||
|
|
8b9ab8a841 | ||
|
|
750dbcc7c3 | ||
|
|
291767031c | ||
|
|
22ffe6ef1d | ||
|
|
02df1a70f3 | ||
|
|
8c5fa9c441 | ||
|
|
e6c558c2a0 | ||
|
|
1089a52ca0 | ||
|
|
c7fb9ab8e3 | ||
|
|
e24217a6ba | ||
|
|
f042f44501 | ||
|
|
56c98648f9 | ||
|
|
956efe6a09 | ||
|
|
bb64ad23dd | ||
|
|
a97326df74 | ||
|
|
1503f8781a | ||
|
|
163ddbb6ed | ||
|
|
7bbfd33ca0 | ||
|
|
0ea47ce890 | ||
|
|
38f891235c | ||
|
|
4d83c074d9 | ||
|
|
0e9672df80 | ||
|
|
abc7460539 | ||
|
|
4bb2ccfba7 | ||
|
|
969d428320 | ||
|
|
ff64522c50 | ||
|
|
65dc1a8f48 | ||
|
|
859b7f3c7f | ||
|
|
da3f875555 | ||
|
|
44d63a44da | ||
|
|
7e5e1609b0 | ||
|
|
d94adcb19c | ||
|
|
83894df260 | ||
|
|
7b99a32a1e | ||
|
|
06d1f54030 | ||
|
|
599ccb6bde | ||
|
|
db9050c302 | ||
|
|
71b3b665b5 | ||
|
|
3b8a806661 | ||
|
|
774719fb50 | ||
|
|
8ddacb7bc9 | ||
|
|
262a9ddc48 | ||
|
|
70f84b65ec | ||
|
|
ec5cb42f67 | ||
|
|
0802481fd2 | ||
|
|
548ba0ae36 | ||
|
|
376d5ca7d0 | ||
|
|
55438136b0 | ||
|
|
82db3517d7 | ||
|
|
130490c022 | ||
|
|
ff6459e439 | ||
|
|
dfcc85a466 | ||
|
|
be2ce854a1 | ||
|
|
e492dcd968 | ||
|
|
55bfee856d | ||
|
|
f951075551 | ||
|
|
964086a08a | ||
|
|
67501025b3 | ||
|
|
e1cc5c841a | ||
|
|
6b839bd5a8 | ||
|
|
1e63dd8d2d | ||
|
|
fab9272124 | ||
|
|
2f66fd9aae | ||
|
|
5616583fa1 | ||
|
|
3f0e991112 | ||
|
|
72bba0662f | ||
|
|
090f46006a | ||
|
|
abe0c7e7d1 | ||
|
|
6516f56ada | ||
|
|
ea391dc44e | ||
|
|
e21f713de0 | ||
|
|
3498e2e884 | ||
|
|
ea8edc5914 | ||
|
|
b62c40dba3 | ||
|
|
0832337839 | ||
|
|
b82f4491fb | ||
|
|
bdf0c256b3 | ||
|
|
3d91a9e926 | ||
|
|
779dbdea26 | ||
|
|
e8e342c206 | ||
|
|
78829d36cc | ||
|
|
f7c2e82dc0 | ||
|
|
396493ad2b | ||
|
|
b1a7b58f97 | ||
|
|
e81f39b50e | ||
|
|
3c99fb116c | ||
|
|
e7e136036c | ||
|
|
ca84fc6c9d | ||
|
|
a0c4515a81 | ||
|
|
4bf418a3d6 | ||
|
|
f033607c8b | ||
|
|
32d612fbeb | ||
|
|
9ce3a881f3 | ||
|
|
860cd31799 | ||
|
|
d674b48f7d | ||
|
|
1635f9dbef | ||
|
|
07c899f0a9 | ||
|
|
382e4c5377 | ||
|
|
fe6518d052 | ||
|
|
dc513dfbeb | ||
|
|
3d9bc7a986 | ||
|
|
75e36173cd | ||
|
|
8097f227ca | ||
|
|
3d79b72d70 | ||
|
|
6eb9b772e7 | ||
|
|
90c8ff35d1 | ||
|
|
ad87fd96db | ||
|
|
fd1debe681 | ||
|
|
c7cc0cd922 | ||
|
|
81a232177e | ||
|
|
73aee97be5 | ||
|
|
39f3a85bb1 | ||
|
|
098a2e54ae | ||
|
|
d575478b53 | ||
|
|
aab54ca1a8 | ||
|
|
d4f2094ee0 | ||
|
|
c354618e20 | ||
|
|
5141a91041 | ||
|
|
668539e737 | ||
|
|
967139cea4 | ||
|
|
6d8b1aede4 | ||
|
|
744ba31ba6 | ||
|
|
db8257b67a | ||
|
|
85770dc037 | ||
|
|
69f976a79a | ||
|
|
fd7e77eff8 | ||
|
|
05c2a093c0 | ||
|
|
b71bc1f875 | ||
|
|
cbc8714414 | ||
|
|
065f8db2f7 | ||
|
|
0ac7f83726 | ||
|
|
d03473da10 | ||
|
|
dac1c01a2c | ||
|
|
a7a2dabc5a | ||
|
|
83015a3404 | ||
|
|
b88e9c5f5e | ||
|
|
8380a8a811 | ||
|
|
6c69181290 | ||
|
|
0694075447 | ||
|
|
d66b9dd8cb | ||
|
|
7267198a8c | ||
|
|
0f36c5c872 | ||
|
|
6a67f028ce | ||
|
|
5d82786c20 | ||
|
|
e368f1c1d6 | ||
|
|
572ce7f9ec | ||
|
|
a4c942a21f | ||
|
|
4859ab3ba7 | ||
|
|
983b5f5087 | ||
|
|
75b87955dd | ||
|
|
110de0afbc | ||
|
|
2c074cd5c1 | ||
|
|
73e51a9b0b | ||
|
|
3a47039919 | ||
|
|
2961ea4e44 | ||
|
|
af2ffc9737 | ||
|
|
d7911244fc | ||
|
|
2a66775e45 | ||
|
|
6029a5a9a8 | ||
|
|
71d9ae15a1 | ||
|
|
f0c3d5f308 | ||
|
|
4706ea59fe | ||
|
|
5774a95f61 | ||
|
|
d660521c5c | ||
|
|
5db2c5092e | ||
|
|
59618457df | ||
|
|
c612dfbc1f | ||
|
|
8d053c97a7 | ||
|
|
a3e6f67ff7 | ||
|
|
01da2e3eee | ||
|
|
168cce1678 | ||
|
|
7240dfe793 | ||
|
|
b9340ba02d | ||
|
|
4f5ee24bc5 | ||
|
|
6a1b8d3ee3 | ||
|
|
f1207dc8b9 | ||
|
|
86c51559bb | ||
|
|
8b0f806079 | ||
|
|
99e94b3567 | ||
|
|
cfd5c1bc93 | ||
|
|
45d9e45346 | ||
|
|
fcb3845543 | ||
|
|
97eabc0c36 | ||
|
|
5328163973 | ||
|
|
7ff9dfee8c | ||
|
|
1e1675ec12 | ||
|
|
f941541304 | ||
|
|
3f7083c5b3 | ||
|
|
e81faebf69 | ||
|
|
8a4d58c520 | ||
|
|
2ac29ee89c | ||
|
|
252cdcd6f5 | ||
|
|
16e2c95965 | ||
|
|
10560fb34c | ||
|
|
58aa60ca0e | ||
|
|
d24b186d3e | ||
|
|
b4e81615b1 | ||
|
|
424d2033ea | ||
|
|
fd556f9b00 | ||
|
|
e2f5fa87b1 | ||
|
|
e4a2bd3b9b | ||
|
|
e3ada17a78 | ||
|
|
3e5a7adfe4 | ||
|
|
3237f4cd6e | ||
|
|
beea826377 | ||
|
|
7cdbbefc64 | ||
|
|
18780622b3 | ||
|
|
f405ac4d84 | ||
|
|
9fe47e2fb2 | ||
|
|
e4aaa18f61 | ||
|
|
5c3d9717dd | ||
|
|
ac86bbd60c | ||
|
|
33d12c43b2 | ||
|
|
107c676185 | ||
|
|
0f221b7ee6 | ||
|
|
e1939ef472 | ||
|
|
5438d35f17 | ||
|
|
9c26d1f4c8 | ||
|
|
4c2b31f31f | ||
|
|
4f88a13256 | ||
|
|
21ae448ed7 | ||
|
|
50466124c8 | ||
|
|
ece88a3879 | ||
|
|
cedc4a92cc | ||
|
|
c8065b0c60 | ||
|
|
476632294f | ||
|
|
349d46e043 | ||
|
|
00e0201bf9 | ||
|
|
b9ebe22df1 | ||
|
|
389dd8d402 | ||
|
|
966bd8528d | ||
|
|
8f789d47a2 | ||
|
|
94a40e49a0 | ||
|
|
8429279eea | ||
|
|
cef14cda9e | ||
|
|
c14f067afb | ||
|
|
2612abc9d0 | ||
|
|
d080b44ac3 | ||
|
|
c01ad5a19e | ||
|
|
f56bc0f85a | ||
|
|
9600d687fa | ||
|
|
41a0036bf6 | ||
|
|
08e4ad6a7c | ||
|
|
314e6e29d5 | ||
|
|
391cd602a2 | ||
|
|
247db844a4 | ||
|
|
5495d32822 | ||
|
|
a496991400 | ||
|
|
0ea83b4364 |
@@ -45,7 +45,8 @@ RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \
|
|||||||
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
|
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
|
||||||
apt install -y libjemalloc-dev && \
|
apt install -y libjemalloc-dev && \
|
||||||
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
|
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
|
||||||
apt install -y ghostscript
|
apt install -y ghostscript && \
|
||||||
|
apt install -y libmagic1
|
||||||
|
|
||||||
RUN if [ "$NEED_MIRROR" == "1" ]; then \
|
RUN if [ "$NEED_MIRROR" == "1" ]; then \
|
||||||
pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
|
pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \
|
||||||
|
|||||||
@@ -60,7 +60,12 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
|
|||||||
# are written from script.py.mako
|
# are written from script.py.mako
|
||||||
# output_encoding = utf-8
|
# output_encoding = utf-8
|
||||||
|
|
||||||
sqlalchemy.url = postgresql://user:password@localhost/dbname
|
# Database connection URL - DO NOT hardcode credentials here!
|
||||||
|
# Connection string is set dynamically from environment variables in migrations/env.py
|
||||||
|
# Required env vars: DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME
|
||||||
|
# Example: postgresql://user:password@localhost:5432/dbname
|
||||||
|
; sqlalchemy.url = postgresql://user:password@host:port/dbname
|
||||||
|
sqlalchemy.url = driver://user:password@host:port/dbname
|
||||||
|
|
||||||
|
|
||||||
[post_write_hooks]
|
[post_write_hooks]
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
import os
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
import redis.asyncio as redis
|
import redis.asyncio as redis
|
||||||
from redis.asyncio import ConnectionPool
|
from redis.asyncio import ConnectionPool
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
# 设置日志记录器
|
# 设置日志记录器
|
||||||
|
|||||||
@@ -63,9 +63,9 @@ celery_app.conf.update(
|
|||||||
accept_content=['json'],
|
accept_content=['json'],
|
||||||
result_serializer='json',
|
result_serializer='json',
|
||||||
|
|
||||||
# 时区
|
# # 时区
|
||||||
timezone='Asia/Shanghai',
|
# timezone='Asia/Shanghai',
|
||||||
enable_utc=True,
|
# enable_utc=False,
|
||||||
|
|
||||||
# 任务追踪
|
# 任务追踪
|
||||||
task_track_started=True,
|
task_track_started=True,
|
||||||
@@ -96,6 +96,7 @@ celery_app.conf.update(
|
|||||||
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
|
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
|
||||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||||
|
'app.tasks.write_perceptual_memory': {'queue': 'memory_tasks'},
|
||||||
|
|
||||||
# Long-term storage tasks → memory_tasks queue (batched write strategies)
|
# Long-term storage tasks → memory_tasks queue (batched write strategies)
|
||||||
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
|
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
|
||||||
@@ -113,6 +114,9 @@ celery_app.conf.update(
|
|||||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
||||||
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
|
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
|
||||||
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
|
||||||
|
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
|
||||||
|
'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'},
|
||||||
|
'app.tasks.init_community_clustering_for_users': {'queue': 'periodic_tasks'},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -129,7 +133,7 @@ implicit_emotions_update_schedule = crontab(
|
|||||||
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
|
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
|
||||||
)
|
)
|
||||||
|
|
||||||
#构建定时任务配置
|
# 构建定时任务配置
|
||||||
beat_schedule_config = {
|
beat_schedule_config = {
|
||||||
"run-workspace-reflection": {
|
"run-workspace-reflection": {
|
||||||
"task": "app.tasks.workspace_reflection_task",
|
"task": "app.tasks.workspace_reflection_task",
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from . import (
|
|||||||
file_controller,
|
file_controller,
|
||||||
file_storage_controller,
|
file_storage_controller,
|
||||||
home_page_controller,
|
home_page_controller,
|
||||||
|
i18n_controller,
|
||||||
implicit_memory_controller,
|
implicit_memory_controller,
|
||||||
knowledge_controller,
|
knowledge_controller,
|
||||||
knowledgeshare_controller,
|
knowledgeshare_controller,
|
||||||
@@ -94,5 +95,6 @@ manager_router.include_router(memory_working_controller.router)
|
|||||||
manager_router.include_router(file_storage_controller.router)
|
manager_router.include_router(file_storage_controller.router)
|
||||||
manager_router.include_router(ontology_controller.router)
|
manager_router.include_router(ontology_controller.router)
|
||||||
manager_router.include_router(skill_controller.router)
|
manager_router.include_router(skill_controller.router)
|
||||||
|
manager_router.include_router(i18n_controller.router)
|
||||||
|
|
||||||
__all__ = ["manager_router"]
|
__all__ = ["manager_router"]
|
||||||
|
|||||||
@@ -1,10 +1,12 @@
|
|||||||
import uuid
|
import uuid
|
||||||
|
import io
|
||||||
from typing import Optional, Annotated
|
from typing import Optional, Annotated
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
|
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
@@ -25,6 +27,7 @@ from app.services.app_service import AppService
|
|||||||
from app.services.app_statistics_service import AppStatisticsService
|
from app.services.app_statistics_service import AppStatisticsService
|
||||||
from app.services.workflow_import_service import WorkflowImportService
|
from app.services.workflow_import_service import WorkflowImportService
|
||||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||||
|
from app.services.app_dsl_service import AppDslService
|
||||||
|
|
||||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||||
logger = get_business_logger()
|
logger = get_business_logger()
|
||||||
@@ -50,6 +53,7 @@ def list_apps(
|
|||||||
status: str | None = None,
|
status: str | None = None,
|
||||||
search: str | None = None,
|
search: str | None = None,
|
||||||
include_shared: bool = True,
|
include_shared: bool = True,
|
||||||
|
shared_only: bool = False,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
pagesize: int = 10,
|
pagesize: int = 10,
|
||||||
ids: Optional[str] = None,
|
ids: Optional[str] = None,
|
||||||
@@ -81,6 +85,7 @@ def list_apps(
|
|||||||
status=status,
|
status=status,
|
||||||
search=search,
|
search=search,
|
||||||
include_shared=include_shared,
|
include_shared=include_shared,
|
||||||
|
shared_only=shared_only,
|
||||||
page=page,
|
page=page,
|
||||||
pagesize=pagesize,
|
pagesize=pagesize,
|
||||||
)
|
)
|
||||||
@@ -90,6 +95,37 @@ def list_apps(
|
|||||||
return success(data=PageData(page=meta, items=items))
|
return success(data=PageData(page=meta, items=items))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/my-shared-out", summary="列出本工作空间主动分享出去的记录")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def list_my_shared_out(
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""列出本工作空间主动分享给其他工作空间的所有记录(我的共享)"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
service = app_service.AppService(db)
|
||||||
|
shares = service.list_my_shared_out(workspace_id=workspace_id)
|
||||||
|
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||||
|
return success(data=data)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/share/{target_workspace_id}", summary="取消对某工作空间的所有应用分享")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def unshare_all_apps_to_workspace(
|
||||||
|
target_workspace_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Cancel all app shares from current workspace to a target workspace."""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
service = app_service.AppService(db)
|
||||||
|
count = service.unshare_all_apps_to_workspace(
|
||||||
|
target_workspace_id=target_workspace_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
return success(msg=f"已取消 {count} 个应用的分享", data={"count": count})
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{app_id}", summary="获取应用详情")
|
@router.get("/{app_id}", summary="获取应用详情")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def get_app(
|
def get_app(
|
||||||
@@ -158,6 +194,7 @@ def delete_app(
|
|||||||
def copy_app(
|
def copy_app(
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
new_name: Optional[str] = None,
|
new_name: Optional[str] = None,
|
||||||
|
payload: app_schema.CopyAppRequest = None,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user=Depends(get_current_user),
|
current_user=Depends(get_current_user),
|
||||||
):
|
):
|
||||||
@@ -169,6 +206,8 @@ def copy_app(
|
|||||||
- 不影响原应用
|
- 不影响原应用
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
# body takes precedence over query param for backward compatibility
|
||||||
|
new_name = (payload.new_name if payload else None) or new_name
|
||||||
logger.info(
|
logger.info(
|
||||||
"用户请求复制应用",
|
"用户请求复制应用",
|
||||||
extra={
|
extra={
|
||||||
@@ -218,6 +257,27 @@ def get_agent_config(
|
|||||||
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
return success(data=app_schema.AgentConfig.model_validate(cfg))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{app_id}/opening", summary="获取应用开场白配置")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def get_opening(
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""返回开场白文本和预设问题,供前端对话界面初始化时展示"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
|
||||||
|
features = cfg.features or {}
|
||||||
|
if hasattr(features, "model_dump"):
|
||||||
|
features = features.model_dump()
|
||||||
|
opening = features.get("opening_statement", {})
|
||||||
|
return success(data=app_schema.OpeningResponse(
|
||||||
|
enabled=opening.get("enabled", False),
|
||||||
|
statement=opening.get("statement"),
|
||||||
|
suggested_questions=opening.get("suggested_questions", []),
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
|
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def publish_app(
|
def publish_app(
|
||||||
@@ -299,7 +359,8 @@ def share_app(
|
|||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
target_workspace_ids=payload.target_workspace_ids,
|
target_workspace_ids=payload.target_workspace_ids,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
workspace_id=workspace_id
|
workspace_id=workspace_id,
|
||||||
|
permission=payload.permission
|
||||||
)
|
)
|
||||||
|
|
||||||
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
data = [app_schema.AppShare.model_validate(s) for s in shares]
|
||||||
@@ -330,6 +391,32 @@ def unshare_app(
|
|||||||
return success(msg="应用分享已取消")
|
return success(msg="应用分享已取消")
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/{app_id}/share/{target_workspace_id}", summary="更新共享权限")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def update_share_permission(
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
target_workspace_id: uuid.UUID,
|
||||||
|
payload: app_schema.UpdateSharePermissionRequest,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""更新共享权限(readonly <-> editable)
|
||||||
|
|
||||||
|
- 只能修改自己工作空间应用的共享权限
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
|
service = app_service.AppService(db)
|
||||||
|
share = service.update_share_permission(
|
||||||
|
app_id=app_id,
|
||||||
|
target_workspace_id=target_workspace_id,
|
||||||
|
permission=payload.permission,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(data=app_schema.AppShare.model_validate(share))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{app_id}/shares", summary="列出应用的分享记录")
|
@router.get("/{app_id}/shares", summary="列出应用的分享记录")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def list_app_shares(
|
def list_app_shares(
|
||||||
@@ -353,6 +440,46 @@ def list_app_shares(
|
|||||||
return success(data=data)
|
return success(data=data)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/shared/{source_workspace_id}", summary="批量移除某来源工作空间的所有共享应用")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def remove_all_shared_apps_from_workspace(
|
||||||
|
source_workspace_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""Remove all shared apps from a specific source workspace (recipient operation)."""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
service = app_service.AppService(db)
|
||||||
|
count = service.remove_all_shared_apps_from_workspace(
|
||||||
|
source_workspace_id=source_workspace_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
return success(msg=f"已移除 {count} 个共享应用", data={"count": count})
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{app_id}/shared", summary="移除共享给我的应用")
|
||||||
|
@cur_workspace_access_guard()
|
||||||
|
def remove_shared_app(
|
||||||
|
app_id: uuid.UUID,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user=Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""被共享者从自己的工作空间移除共享应用
|
||||||
|
|
||||||
|
- 不会删除源应用,只删除共享记录
|
||||||
|
- 只能移除共享给自己工作空间的应用
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
|
service = app_service.AppService(db)
|
||||||
|
service.remove_shared_app(
|
||||||
|
app_id=app_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(msg="已移除共享应用")
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{app_id}/draft/run", summary="试运行 Agent(使用当前草稿配置)")
|
@router.post("/{app_id}/draft/run", summary="试运行 Agent(使用当前草稿配置)")
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
async def draft_run(
|
async def draft_run(
|
||||||
@@ -393,7 +520,7 @@ async def draft_run(
|
|||||||
# 提前验证和准备(在流式响应开始前完成)
|
# 提前验证和准备(在流式响应开始前完成)
|
||||||
from app.services.app_service import AppService
|
from app.services.app_service import AppService
|
||||||
from app.services.multi_agent_service import MultiAgentService
|
from app.services.multi_agent_service import MultiAgentService
|
||||||
from app.models import AgentConfig, ModelConfig
|
from app.models import AgentConfig, ModelConfig, AppRelease
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.services.draft_run_service import AgentRunService
|
from app.services.draft_run_service import AgentRunService
|
||||||
@@ -410,11 +537,12 @@ async def draft_run(
|
|||||||
service._validate_app_accessible(app, workspace_id)
|
service._validate_app_accessible(app, workspace_id)
|
||||||
|
|
||||||
if payload.user_id is None:
|
if payload.user_id is None:
|
||||||
|
# 先获取 app 的 workspace_id
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
|
workspace_id=app.workspace_id,
|
||||||
other_id=str(current_user.id),
|
other_id=str(current_user.id),
|
||||||
original_user_id=str(current_user.id) # Save original user_id to other_id
|
|
||||||
)
|
)
|
||||||
payload.user_id = str(new_end_user.id)
|
payload.user_id = str(new_end_user.id)
|
||||||
|
|
||||||
@@ -431,18 +559,29 @@ async def draft_run(
|
|||||||
service._check_agent_config(app_id)
|
service._check_agent_config(app_id)
|
||||||
|
|
||||||
# 2. 获取 Agent 配置
|
# 2. 获取 Agent 配置
|
||||||
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
# 共享应用:从最新发布版本读配置快照,而非草稿
|
||||||
agent_cfg = db.scalars(stmt).first()
|
is_shared = app.workspace_id != workspace_id
|
||||||
if not agent_cfg:
|
if is_shared:
|
||||||
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
if not app.current_release_id:
|
||||||
|
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
release = db.get(AppRelease, app.current_release_id)
|
||||||
|
if not release:
|
||||||
|
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
agent_cfg = service._agent_config_from_release(release)
|
||||||
|
model_config = db.get(ModelConfig, release.default_model_config_id) if release.default_model_config_id else None
|
||||||
|
else:
|
||||||
|
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
|
||||||
|
agent_cfg = db.scalars(stmt).first()
|
||||||
|
if not agent_cfg:
|
||||||
|
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
|
||||||
# 3. 获取模型配置
|
# 3. 获取模型配置
|
||||||
model_config = None
|
model_config = None
|
||||||
if agent_cfg.default_model_config_id:
|
if agent_cfg.default_model_config_id:
|
||||||
model_config = db.get(ModelConfig, agent_cfg.default_model_config_id)
|
model_config = db.get(ModelConfig, agent_cfg.default_model_config_id)
|
||||||
if not model_config:
|
if not model_config:
|
||||||
from app.core.exceptions import ResourceNotFoundException
|
from app.core.exceptions import ResourceNotFoundException
|
||||||
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
|
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
|
||||||
|
|
||||||
# 流式返回
|
# 流式返回
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
@@ -598,7 +737,17 @@ async def draft_run(
|
|||||||
msg="多 Agent 任务执行成功"
|
msg="多 Agent 任务执行成功"
|
||||||
)
|
)
|
||||||
elif app.type == AppType.WORKFLOW: # 工作流
|
elif app.type == AppType.WORKFLOW: # 工作流
|
||||||
config = workflow_service.check_config(app_id)
|
# 共享应用:从最新发布版本读配置快照,而非草稿
|
||||||
|
is_shared = app.workspace_id != workspace_id
|
||||||
|
if is_shared:
|
||||||
|
if not app.current_release_id:
|
||||||
|
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
release = db.get(AppRelease, app.current_release_id)
|
||||||
|
if not release:
|
||||||
|
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
|
||||||
|
config = service._workflow_config_from_release(release)
|
||||||
|
else:
|
||||||
|
config = workflow_service.check_config(app_id)
|
||||||
# 3. 流式返回
|
# 3. 流式返回
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@@ -741,6 +890,16 @@ async def draft_run_compare(
|
|||||||
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||||
service._validate_app_accessible(app, workspace_id)
|
service._validate_app_accessible(app, workspace_id)
|
||||||
|
|
||||||
|
if payload.user_id is None:
|
||||||
|
# 先获取 app 的 workspace_id
|
||||||
|
end_user_repo = EndUserRepository(db)
|
||||||
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
|
app_id=app_id,
|
||||||
|
workspace_id=app.workspace_id,
|
||||||
|
other_id=str(current_user.id),
|
||||||
|
)
|
||||||
|
payload.user_id = str(new_end_user.id)
|
||||||
|
|
||||||
# 2. 获取 Agent 配置
|
# 2. 获取 Agent 配置
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from app.models import AgentConfig
|
from app.models import AgentConfig
|
||||||
@@ -786,6 +945,13 @@ async def draft_run_compare(
|
|||||||
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
|
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
|
||||||
})
|
})
|
||||||
|
|
||||||
|
# 从 features 中读取功能开关(与 draft_run 保持一致)
|
||||||
|
features_config: dict = agent_cfg.features or {}
|
||||||
|
if hasattr(features_config, 'model_dump'):
|
||||||
|
features_config = features_config.model_dump()
|
||||||
|
web_search_feature = features_config.get("web_search", {})
|
||||||
|
web_search = isinstance(web_search_feature, dict) and web_search_feature.get("enabled", False)
|
||||||
|
|
||||||
# 流式返回
|
# 流式返回
|
||||||
if payload.stream:
|
if payload.stream:
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
@@ -797,11 +963,11 @@ async def draft_run_compare(
|
|||||||
message=payload.message,
|
message=payload.message,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
conversation_id=payload.conversation_id,
|
conversation_id=payload.conversation_id,
|
||||||
user_id=payload.user_id or str(current_user.id),
|
user_id=payload.user_id,
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
web_search=True,
|
web_search=web_search,
|
||||||
memory=True,
|
memory=True,
|
||||||
parallel=payload.parallel,
|
parallel=payload.parallel,
|
||||||
timeout=payload.timeout or 60,
|
timeout=payload.timeout or 60,
|
||||||
@@ -828,11 +994,11 @@ async def draft_run_compare(
|
|||||||
message=payload.message,
|
message=payload.message,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
conversation_id=payload.conversation_id,
|
conversation_id=payload.conversation_id,
|
||||||
user_id=payload.user_id or str(current_user.id),
|
user_id=payload.user_id,
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
web_search=True,
|
web_search=web_search,
|
||||||
memory=True,
|
memory=True,
|
||||||
parallel=payload.parallel,
|
parallel=payload.parallel,
|
||||||
timeout=payload.timeout or 60,
|
timeout=payload.timeout or 60,
|
||||||
@@ -1010,3 +1176,57 @@ def get_workspace_api_statistics(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return success(data=result)
|
return success(data=result)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{app_id}/export", summary="导出应用配置为 YAML 文件")
@cur_workspace_access_guard()
async def export_app(
    app_id: uuid.UUID,
    db: Annotated[Session, Depends(get_db)],
    current_user: Annotated[User, Depends(get_current_user)],
    release_id: Optional[uuid.UUID] = None
):
    """Export an agent / multi_agent / workflow app configuration as a YAML download.

    Args:
        app_id: ID of the application to export.
        db: Database session handed to ``AppDslService``.
        current_user: Authenticated user resolved by the dependency; not read
            in this body.
        release_id: ID of the published release to export. When omitted the
            current draft configuration is exported.

    Returns:
        A ``StreamingResponse`` carrying the UTF-8 encoded YAML as an attachment.
    """
    yaml_str, filename = AppDslService(db).export_dsl(app_id, release_id)
    yaml_bytes = yaml_str.encode("utf-8")

    # The file name may contain non-ASCII characters, so it is percent-encoded.
    # A plain ``filename=`` parameter is not required to be percent-decoded by
    # clients; ``filename*=UTF-8''...`` (RFC 6266 / RFC 5987) is the form that
    # is. The plain parameter is kept as a fallback for older clients.
    encoded = quote(filename, safe=".")
    content_disposition = f"attachment; filename={encoded}; filename*=UTF-8''{encoded}"

    return StreamingResponse(
        # A freshly constructed BytesIO is already positioned at offset 0.
        io.BytesIO(yaml_bytes),
        media_type="application/octet-stream; charset=utf-8",
        headers={
            "Content-Disposition": content_disposition,
            "Content-Length": str(len(yaml_bytes)),
        },
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/import", summary="从 YAML 文件导入应用")
@cur_workspace_access_guard()
async def import_app(
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
    current_user: User = Depends(get_current_user)
):
    """Import an agent / multi_agent / workflow app from an uploaded YAML file.

    When importing across workspaces or tenants, models, tools and knowledge
    bases are matched by name; anything that cannot be matched is left empty
    and reported in the returned ``warnings``.

    Args:
        file: Uploaded ``.yaml`` / ``.yml`` file containing the app DSL.
        db: Database session handed to ``AppDslService``.
        current_user: Authenticated user; supplies the target workspace,
            tenant and owner of the imported app.

    Returns:
        A success response with the created app and the list of warnings, or
        a ``BAD_REQUEST`` failure response when the upload is not a usable
        YAML document.
    """
    # The client may omit the file name, in which case ``filename`` is None.
    filename = file.filename or ""
    if not filename.lower().endswith((".yaml", ".yml")):
        return fail(msg="仅支持 YAML 文件", code=BizCode.BAD_REQUEST)

    # Undecodable bytes or malformed YAML are client errors, not server
    # errors: report them as BAD_REQUEST instead of letting them surface as 500.
    try:
        raw = (await file.read()).decode("utf-8")
        dsl = yaml.safe_load(raw)
    except (UnicodeDecodeError, yaml.YAMLError):
        return fail(msg="YAML 格式无效,无法解析文件内容", code=BizCode.BAD_REQUEST)

    # safe_load may return a str, list or scalar; only a mapping can carry the
    # required top-level "app" key.
    if not isinstance(dsl, dict) or "app" not in dsl:
        return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)

    new_app, warnings = AppDslService(db).import_dsl(
        dsl=dsl,
        workspace_id=current_user.current_workspace_id,
        tenant_id=current_user.tenant_id,
        user_id=current_user.id,
    )
    return success(
        data={"app": app_schema.App.model_validate(new_app), "warnings": warnings},
        msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
    )
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Callable
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -16,6 +17,7 @@ from app.core.exceptions import BusinessException
|
|||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.dependencies import get_current_user, oauth2_scheme
|
from app.dependencies import get_current_user, oauth2_scheme
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
|
from app.i18n.dependencies import get_translator
|
||||||
|
|
||||||
# 获取专用日志器
|
# 获取专用日志器
|
||||||
auth_logger = get_auth_logger()
|
auth_logger = get_auth_logger()
|
||||||
@@ -26,7 +28,8 @@ router = APIRouter(tags=["Authentication"])
|
|||||||
@router.post("/token", response_model=ApiResponse)
|
@router.post("/token", response_model=ApiResponse)
|
||||||
async def login_for_access_token(
|
async def login_for_access_token(
|
||||||
form_data: TokenRequest,
|
form_data: TokenRequest,
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""用户登录获取token"""
|
"""用户登录获取token"""
|
||||||
auth_logger.info(f"用户登录请求: {form_data.email}")
|
auth_logger.info(f"用户登录请求: {form_data.email}")
|
||||||
@@ -40,10 +43,10 @@ async def login_for_access_token(
|
|||||||
invite_info = workspace_service.validate_invite_token(db, form_data.invite)
|
invite_info = workspace_service.validate_invite_token(db, form_data.invite)
|
||||||
|
|
||||||
if not invite_info.is_valid:
|
if not invite_info.is_valid:
|
||||||
raise BusinessException("邀请码无效或已过期", code=BizCode.BAD_REQUEST)
|
raise BusinessException(t("auth.invite.invalid"), code=BizCode.BAD_REQUEST)
|
||||||
|
|
||||||
if invite_info.email != form_data.email:
|
if invite_info.email != form_data.email:
|
||||||
raise BusinessException("邀请邮箱与登录邮箱不匹配", code=BizCode.BAD_REQUEST)
|
raise BusinessException(t("auth.invite.email_mismatch"), code=BizCode.BAD_REQUEST)
|
||||||
auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}")
|
auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}")
|
||||||
try:
|
try:
|
||||||
# 尝试认证用户
|
# 尝试认证用户
|
||||||
@@ -69,7 +72,7 @@ async def login_for_access_token(
|
|||||||
elif e.code == BizCode.PASSWORD_ERROR:
|
elif e.code == BizCode.PASSWORD_ERROR:
|
||||||
# 用户存在但密码错误
|
# 用户存在但密码错误
|
||||||
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
|
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
|
||||||
raise BusinessException("接受邀请失败,密码验证错误", BizCode.LOGIN_FAILED)
|
raise BusinessException(t("auth.invite.password_verification_failed"), BizCode.LOGIN_FAILED)
|
||||||
else:
|
else:
|
||||||
# 其他认证失败情况,直接抛出
|
# 其他认证失败情况,直接抛出
|
||||||
raise
|
raise
|
||||||
@@ -82,7 +85,7 @@ async def login_for_access_token(
|
|||||||
except BusinessException as e:
|
except BusinessException as e:
|
||||||
|
|
||||||
# 其他认证失败情况,直接抛出
|
# 其他认证失败情况,直接抛出
|
||||||
raise BusinessException(e.message,BizCode.LOGIN_FAILED)
|
raise BusinessException(e.message, BizCode.LOGIN_FAILED)
|
||||||
|
|
||||||
# 创建 tokens
|
# 创建 tokens
|
||||||
access_token, access_token_id = security.create_access_token(subject=user.id)
|
access_token, access_token_id = security.create_access_token(subject=user.id)
|
||||||
@@ -110,14 +113,15 @@ async def login_for_access_token(
|
|||||||
expires_at=access_expires_at,
|
expires_at=access_expires_at,
|
||||||
refresh_expires_at=refresh_expires_at
|
refresh_expires_at=refresh_expires_at
|
||||||
),
|
),
|
||||||
msg="登录成功"
|
msg=t("auth.login.success")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/refresh", response_model=ApiResponse)
|
@router.post("/refresh", response_model=ApiResponse)
|
||||||
async def refresh_token(
|
async def refresh_token(
|
||||||
refresh_request: RefreshTokenRequest,
|
refresh_request: RefreshTokenRequest,
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""刷新token"""
|
"""刷新token"""
|
||||||
auth_logger.info("收到token刷新请求")
|
auth_logger.info("收到token刷新请求")
|
||||||
@@ -125,18 +129,18 @@ async def refresh_token(
|
|||||||
# 验证 refresh token
|
# 验证 refresh token
|
||||||
userId = security.verify_token(refresh_request.refresh_token, "refresh")
|
userId = security.verify_token(refresh_request.refresh_token, "refresh")
|
||||||
if not userId:
|
if not userId:
|
||||||
raise BusinessException("无效的refresh token", code=BizCode.TOKEN_INVALID)
|
raise BusinessException(t("auth.token.invalid_refresh_token"), code=BizCode.TOKEN_INVALID)
|
||||||
|
|
||||||
# 检查用户是否存在
|
# 检查用户是否存在
|
||||||
user = auth_service.get_user_by_id(db, userId)
|
user = auth_service.get_user_by_id(db, userId)
|
||||||
if not user:
|
if not user:
|
||||||
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
|
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
|
||||||
|
|
||||||
# 检查 refresh token 黑名单
|
# 检查 refresh token 黑名单
|
||||||
if settings.ENABLE_SINGLE_SESSION:
|
if settings.ENABLE_SINGLE_SESSION:
|
||||||
refresh_token_id = security.get_token_id(refresh_request.refresh_token)
|
refresh_token_id = security.get_token_id(refresh_request.refresh_token)
|
||||||
if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id):
|
if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id):
|
||||||
raise BusinessException("Refresh token已失效", code=BizCode.TOKEN_BLACKLISTED)
|
raise BusinessException(t("auth.token.refresh_token_blacklisted"), code=BizCode.TOKEN_BLACKLISTED)
|
||||||
|
|
||||||
# 生成新 tokens
|
# 生成新 tokens
|
||||||
new_access_token, new_access_token_id = security.create_access_token(subject=user.id)
|
new_access_token, new_access_token_id = security.create_access_token(subject=user.id)
|
||||||
@@ -167,7 +171,7 @@ async def refresh_token(
|
|||||||
expires_at=access_expires_at,
|
expires_at=access_expires_at,
|
||||||
refresh_expires_at=refresh_expires_at
|
refresh_expires_at=refresh_expires_at
|
||||||
),
|
),
|
||||||
msg="token刷新成功"
|
msg=t("auth.token.refresh_success")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -175,14 +179,15 @@ async def refresh_token(
|
|||||||
async def logout(
|
async def logout(
|
||||||
token: str = Depends(oauth2_scheme),
|
token: str = Depends(oauth2_scheme),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""登出当前用户:加入token黑名单并清理会话"""
|
"""登出当前用户:加入token黑名单并清理会话"""
|
||||||
auth_logger.info(f"用户 {current_user.username} 请求登出")
|
auth_logger.info(f"用户 {current_user.username} 请求登出")
|
||||||
|
|
||||||
token_id = security.get_token_id(token)
|
token_id = security.get_token_id(token)
|
||||||
if not token_id:
|
if not token_id:
|
||||||
raise BusinessException("无效的access token", code=BizCode.TOKEN_INVALID)
|
raise BusinessException(t("auth.token.invalid"), code=BizCode.TOKEN_INVALID)
|
||||||
|
|
||||||
# 加入黑名单
|
# 加入黑名单
|
||||||
await SessionService.blacklist_token(token_id)
|
await SessionService.blacklist_token(token_id)
|
||||||
@@ -192,5 +197,5 @@ async def logout(
|
|||||||
await SessionService.clear_user_session(current_user.username)
|
await SessionService.clear_user_session(current_user.username)
|
||||||
|
|
||||||
auth_logger.info(f"用户 {current_user.username} 登出成功")
|
auth_logger.info(f"用户 {current_user.username} 登出成功")
|
||||||
return success(msg="登出成功")
|
return success(msg=t("auth.logout.success"))
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import os
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
|
||||||
from fastapi.responses import FileResponse, RedirectResponse
|
from fastapi.responses import FileResponse, RedirectResponse
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
@@ -47,6 +47,19 @@ router = APIRouter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _match_scheme(request: Request, url: str) -> str:
|
||||||
|
"""
|
||||||
|
将 presigned URL 的协议替换为与当前请求一致的协议(http/https)。
|
||||||
|
解决反向代理场景下 presigned URL 协议与请求协议不匹配的问题。
|
||||||
|
"""
|
||||||
|
incoming_scheme = request.headers.get("x-forwarded-proto") or request.url.scheme
|
||||||
|
if url.startswith("http://") and incoming_scheme == "https":
|
||||||
|
return "https://" + url[7:]
|
||||||
|
if url.startswith("https://") and incoming_scheme == "http":
|
||||||
|
return "http://" + url[8:]
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
@router.post("/files", response_model=ApiResponse)
|
@router.post("/files", response_model=ApiResponse)
|
||||||
async def upload_file(
|
async def upload_file(
|
||||||
file: UploadFile = File(...),
|
file: UploadFile = File(...),
|
||||||
@@ -78,7 +91,7 @@ async def upload_file(
|
|||||||
|
|
||||||
if file_size > settings.MAX_FILE_SIZE:
|
if file_size > settings.MAX_FILE_SIZE:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_413_CONTENT_TOO_LARGE,
|
||||||
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -159,7 +172,6 @@ async def upload_file_with_share_token(
|
|||||||
|
|
||||||
# Get share and release info from share_token
|
# Get share and release info from share_token
|
||||||
service = ReleaseShareService(db)
|
service = ReleaseShareService(db)
|
||||||
share_info = service.get_shared_release_info(share_token=share_data.share_token)
|
|
||||||
|
|
||||||
# Get share object to access app_id
|
# Get share object to access app_id
|
||||||
share = service.repo.get_by_share_token(share_data.share_token)
|
share = service.repo.get_by_share_token(share_data.share_token)
|
||||||
@@ -280,6 +292,7 @@ async def upload_file_with_share_token(
|
|||||||
|
|
||||||
@router.get("/files/{file_id}", response_model=Any)
|
@router.get("/files/{file_id}", response_model=Any)
|
||||||
async def download_file(
|
async def download_file(
|
||||||
|
request: Request,
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
@@ -327,6 +340,7 @@ async def download_file(
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||||
|
presigned_url = _match_scheme(request, presigned_url)
|
||||||
api_logger.info(f"Redirecting to presigned URL: file_key={file_key}")
|
api_logger.info(f"Redirecting to presigned URL: file_key={file_key}")
|
||||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
@@ -400,6 +414,7 @@ async def delete_file(
|
|||||||
|
|
||||||
@router.get("/files/{file_id}/url", response_model=ApiResponse)
|
@router.get("/files/{file_id}/url", response_model=ApiResponse)
|
||||||
async def get_file_url(
|
async def get_file_url(
|
||||||
|
request: Request,
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
expires: int = None,
|
expires: int = None,
|
||||||
permanent: bool = False,
|
permanent: bool = False,
|
||||||
@@ -463,6 +478,7 @@ async def get_file_url(
|
|||||||
else:
|
else:
|
||||||
# For remote storage (OSS/S3), get presigned URL
|
# For remote storage (OSS/S3), get presigned URL
|
||||||
url = await storage_service.get_file_url(file_key, expires=expires)
|
url = await storage_service.get_file_url(file_key, expires=expires)
|
||||||
|
url = _match_scheme(request, url)
|
||||||
|
|
||||||
api_logger.info(f"Generated file URL: file_id={file_id}")
|
api_logger.info(f"Generated file URL: file_id={file_id}")
|
||||||
return success(
|
return success(
|
||||||
@@ -482,8 +498,54 @@ async def get_file_url(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/files/{file_id}/public-url", response_model=ApiResponse)
async def get_permanent_file_url(
    file_id: uuid.UUID,
    db: Session = Depends(get_db),
    storage_service: FileStorageService = Depends(get_file_storage_service),
):
    """Return a permanent (non-expiring) public URL for a stored file.

    - Local storage: the URL is built from ``settings.FILE_LOCAL_SERVER_URL``
      and points at the ``/storage/permanent/{file_id}`` route.
    - Any other backend: the URL comes from the backend's
      ``get_permanent_url``; the original notes say this needs a bucket with
      public-read access.

    Raises:
        HTTPException: 404 when no metadata row exists for ``file_id``, 400
            when the upload status is not ``"completed"``, 501 when the
            backend yields no permanent URL, 500 on any other failure.
    """
    record = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
    if not record:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="The file does not exist",
        )
    if record.status != "completed":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"File upload not completed, status: {record.status}",
        )

    file_key = record.file_key
    backend = storage_service.storage

    try:
        if isinstance(backend, LocalStorage):
            permanent_url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{file_id}"
        else:
            permanent_url = await backend.get_permanent_url(file_key)
            if not permanent_url:
                raise HTTPException(
                    status_code=status.HTTP_501_NOT_IMPLEMENTED,
                    detail="Permanent URL not supported for current storage backend",
                )

        api_logger.info(f"Generated permanent URL: file_id={file_id}")
        payload = {
            "url": permanent_url,
            "expires_in": None,
            "permanent": True,
            "file_name": record.file_name,
        }
        return success(data=payload, msg="Permanent file URL generated successfully")
    except HTTPException:
        # Deliberate HTTP errors raised above must reach the client unchanged.
        raise
    except Exception as e:
        api_logger.error(f"Failed to generate permanent URL: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to generate permanent URL: {e!s}",
        )
|
||||||
|
|
||||||
|
|
||||||
@router.get("/public/{file_id}", response_model=Any)
|
@router.get("/public/{file_id}", response_model=Any)
|
||||||
async def public_download_file(
|
async def public_download_file(
|
||||||
|
request: Request,
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
expires: int = 0,
|
expires: int = 0,
|
||||||
signature: str = "",
|
signature: str = "",
|
||||||
@@ -555,6 +617,7 @@ async def public_download_file(
|
|||||||
# For remote storage, redirect to presigned URL
|
# For remote storage, redirect to presigned URL
|
||||||
try:
|
try:
|
||||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||||
|
presigned_url = _match_scheme(request, presigned_url)
|
||||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||||
@@ -566,6 +629,7 @@ async def public_download_file(
|
|||||||
|
|
||||||
@router.get("/permanent/{file_id}", response_model=Any)
|
@router.get("/permanent/{file_id}", response_model=Any)
|
||||||
async def permanent_download_file(
|
async def permanent_download_file(
|
||||||
|
request: Request,
|
||||||
file_id: uuid.UUID,
|
file_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||||
@@ -625,6 +689,7 @@ async def permanent_download_file(
|
|||||||
try:
|
try:
|
||||||
# Use a very long expiration (7 days max for most cloud providers)
|
# Use a very long expiration (7 days max for most cloud providers)
|
||||||
presigned_url = await storage_service.get_file_url(file_key, expires=604800)
|
presigned_url = await storage_service.get_file_url(file_key, expires=604800)
|
||||||
|
presigned_url = _match_scheme(request, presigned_url)
|
||||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||||
|
|||||||
833
api/app/controllers/i18n_controller.py
Normal file
833
api/app/controllers/i18n_controller.py
Normal file
@@ -0,0 +1,833 @@
|
|||||||
|
"""
|
||||||
|
I18n Management API Controller
|
||||||
|
|
||||||
|
This module provides management APIs for:
|
||||||
|
- Language management (list, get, add, update languages)
|
||||||
|
- Translation management (get, update, reload translations)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
from app.core.logging_config import get_api_logger
|
||||||
|
from app.core.response_utils import success
|
||||||
|
from app.db import get_db
|
||||||
|
from app.dependencies import get_current_user, get_current_superuser
|
||||||
|
from app.i18n.dependencies import get_translator
|
||||||
|
from app.i18n.service import get_translation_service
|
||||||
|
from app.models.user_model import User
|
||||||
|
from app.schemas.i18n_schema import (
|
||||||
|
LanguageInfo,
|
||||||
|
LanguageListResponse,
|
||||||
|
LanguageCreateRequest,
|
||||||
|
LanguageUpdateRequest,
|
||||||
|
TranslationResponse,
|
||||||
|
TranslationUpdateRequest,
|
||||||
|
MissingTranslationsResponse,
|
||||||
|
ReloadResponse
|
||||||
|
)
|
||||||
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
|
||||||
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
|
router = APIRouter(
|
||||||
|
prefix="/i18n",
|
||||||
|
tags=["I18n Management"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Language Management APIs
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
@router.get("/languages", response_model=ApiResponse)
def get_languages(
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_user)
):
    """
    Get list of all supported languages.

    Every locale reported by the translation service is returned. A locale is
    flagged as enabled when it appears in ``settings.I18N_SUPPORTED_LANGUAGES``
    and as default when it equals ``settings.I18N_DEFAULT_LANGUAGE``.

    Args:
        t: Translator used for the response message.
        current_user: Authenticated user (used for logging).

    Returns:
        Success response whose data is the serialised ``LanguageListResponse``.
    """
    api_logger.info(f"Get languages request from user: {current_user.username}")

    from app.core.config import settings
    translation_service = get_translation_service()

    # Display names in each language's own script. Built once per request
    # rather than once per locale; unknown codes fall back to the code itself.
    native_names = {
        "zh": "中文(简体)",
        "en": "English",
        "ja": "日本語",
        "ko": "한국어",
        "fr": "Français",
        "de": "Deutsch",
        "es": "Español"
    }

    languages = [
        LanguageInfo(
            code=locale,
            name=f"{locale.upper()}",
            native_name=native_names.get(locale, locale),
            is_enabled=locale in settings.I18N_SUPPORTED_LANGUAGES,
            is_default=locale == settings.I18N_DEFAULT_LANGUAGE,
        )
        for locale in translation_service.get_available_locales()
    ]

    response = LanguageListResponse(languages=languages)

    api_logger.info(f"Returning {len(languages)} languages")
    return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/languages/{locale}", response_model=ApiResponse)
def get_language(
    locale: str,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_user)
):
    """
    Look up a single language by its code.

    Args:
        locale: Language code (e.g., 'zh', 'en')
        t: Translator used for response and error messages.
        current_user: Authenticated user (used for logging).

    Returns:
        Success response whose data is the serialised ``LanguageInfo``.

    Raises:
        HTTPException: 404 when the translation service does not list ``locale``.
    """
    api_logger.info(f"Get language info request: locale={locale}, user={current_user.username}")

    from app.core.config import settings
    service = get_translation_service()

    # Guard: unknown locales are rejected before any info is assembled.
    if locale not in service.get_available_locales():
        api_logger.warning(f"Language not found: {locale}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=t("i18n.language.not_found", locale=locale)
        )

    # Native display names; codes missing here are shown as the code itself.
    display_names = {
        "zh": "中文(简体)",
        "en": "English",
        "ja": "日本語",
        "ko": "한국어",
        "fr": "Français",
        "de": "Deutsch",
        "es": "Español"
    }
    info = LanguageInfo(
        code=locale,
        name=f"{locale.upper()}",
        native_name=display_names.get(locale, locale),
        is_enabled=locale in settings.I18N_SUPPORTED_LANGUAGES,
        is_default=locale == settings.I18N_DEFAULT_LANGUAGE,
    )

    api_logger.info(f"Returning language info for: {locale}")
    return success(data=info.dict(), msg=t("common.success.retrieved"))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/languages", response_model=ApiResponse)
def add_language(
    request: LanguageCreateRequest,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Validate a request to add a new language (superuser dependency required).

    Nothing is created here: the handler only checks that the code is not
    already available and replies with instructions. Translation files still
    have to be added to the locales directory by hand.

    Args:
        request: Language creation request; only ``request.code`` is read.
        t: Translator used for response and error messages.
        current_user: Authenticated superuser (used for logging).

    Returns:
        Success response carrying the "add instructions" message.

    Raises:
        HTTPException: 400 when ``request.code`` is already an available locale.
    """
    api_logger.info(
        f"Add language request: code={request.code}, admin={current_user.username}"
    )

    from app.core.config import settings
    service = get_translation_service()

    new_code = request.code
    if new_code in service.get_available_locales():
        api_logger.warning(f"Language already exists: {new_code}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=t("i18n.language.already_exists", locale=new_code)
        )

    # Validation only: this endpoint documents the manual step, it does not
    # write any translation files itself.
    api_logger.info(
        f"Language addition validated: {new_code}. "
        "Translation files need to be created manually."
    )

    instructions = t(
        "i18n.language.add_instructions",
        locale=new_code,
        dir=settings.I18N_CORE_LOCALES_DIR
    )
    return success(msg=instructions)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/languages/{locale}", response_model=ApiResponse)
def update_language(
    locale: str,
    request: LanguageUpdateRequest,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Validate a request to update a language's configuration (superuser
    dependency required).

    No configuration is changed here: the handler checks that the locale
    exists and replies with instructions. Actual changes require updating
    environment variables or config files.

    Args:
        locale: Language code
        request: Language update request; parsed by the framework but not
            read in this body.
        t: Translator used for response and error messages.
        current_user: Authenticated superuser (used for logging).

    Returns:
        Success response carrying the "update instructions" message.

    Raises:
        HTTPException: 404 when ``locale`` is not an available locale.
    """
    api_logger.info(
        f"Update language request: locale={locale}, admin={current_user.username}"
    )

    service = get_translation_service()

    known_locales = service.get_available_locales()
    locale_is_known = locale in known_locales
    if not locale_is_known:
        api_logger.warning(f"Language not found: {locale}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=t("i18n.language.not_found", locale=locale)
        )

    # Validation only: settings are not modified by this endpoint.
    api_logger.info(
        f"Language update validated: {locale}. "
        "Configuration changes require environment variable updates."
    )

    message = t("i18n.language.update_instructions", locale=locale)
    return success(msg=message)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Translation Management APIs
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
@router.get("/translations", response_model=ApiResponse)
def get_all_translations(
    locale: Optional[str] = None,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_user)
):
    """
    Return translations for every locale, or for one locale when filtered.

    Args:
        locale: Optional locale filter; a falsy value means "all locales".
        t: Translator used for response and error messages.
        current_user: Authenticated user (used for logging).

    Returns:
        Success response whose data is the serialised ``TranslationResponse``
        (a mapping keyed by locale).

    Raises:
        HTTPException: 404 when a ``locale`` filter is given but not available.
    """
    api_logger.info(
        f"Get all translations request: locale={locale}, user={current_user.username}"
    )

    service = get_translation_service()

    if not locale:
        # No filter: hand back the service's cache as-is.
        translations = service._cache
    else:
        if locale not in service.get_available_locales():
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=t("i18n.language.not_found", locale=locale)
            )
        # Keep the same locale-keyed shape as the unfiltered result.
        translations = {locale: service._cache.get(locale, {})}

    response = TranslationResponse(translations=translations)

    api_logger.info(f"Returning translations for: {locale or 'all locales'}")
    return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/translations/{locale}", response_model=ApiResponse)
def get_locale_translations(
    locale: str,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_user)
):
    """
    Return every translation stored for one locale.

    Args:
        locale: Language code
        t: Translator used for response and error messages.
        current_user: Authenticated user (used for logging).

    Returns:
        Success response with ``locale`` and its ``translations``, taken from
        the translation service's cache entry for that locale.

    Raises:
        HTTPException: 404 when ``locale`` is not an available locale.
    """
    api_logger.info(
        f"Get locale translations request: locale={locale}, user={current_user.username}"
    )

    service = get_translation_service()

    # Guard: reject locales the translation service does not know about.
    if locale not in service.get_available_locales():
        api_logger.warning(f"Language not found: {locale}")
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=t("i18n.language.not_found", locale=locale)
        )

    locale_entries = service._cache.get(locale, {})

    api_logger.info(f"Returning {len(locale_entries)} namespaces for locale: {locale}")
    payload = {"locale": locale, "translations": locale_entries}
    return success(data=payload, msg=t("common.success.retrieved"))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/translations/{locale}/{namespace}", response_model=ApiResponse)
|
||||||
|
def get_namespace_translations(
|
||||||
|
locale: str,
|
||||||
|
namespace: str,
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get translations for a specific namespace in a locale.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
locale: Language code
|
||||||
|
namespace: Translation namespace (e.g., 'common', 'auth')
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Translations for the specified namespace
|
||||||
|
"""
|
||||||
|
api_logger.info(
|
||||||
|
f"Get namespace translations request: locale={locale}, "
|
||||||
|
f"namespace={namespace}, user={current_user.username}"
|
||||||
|
)
|
||||||
|
|
||||||
|
translation_service = get_translation_service()
|
||||||
|
|
||||||
|
# Check if locale exists
|
||||||
|
available_locales = translation_service.get_available_locales()
|
||||||
|
if locale not in available_locales:
|
||||||
|
api_logger.warning(f"Language not found: {locale}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=t("i18n.language.not_found", locale=locale)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get namespace translations
|
||||||
|
locale_translations = translation_service._cache.get(locale, {})
|
||||||
|
namespace_translations = locale_translations.get(namespace, {})
|
||||||
|
|
||||||
|
if not namespace_translations:
|
||||||
|
api_logger.warning(f"Namespace not found: {namespace} in locale: {locale}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=t("i18n.namespace.not_found", namespace=namespace, locale=locale)
|
||||||
|
)
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"Returning translations for namespace: {namespace} in locale: {locale}"
|
||||||
|
)
|
||||||
|
return success(
|
||||||
|
data={
|
||||||
|
"locale": locale,
|
||||||
|
"namespace": namespace,
|
||||||
|
"translations": namespace_translations
|
||||||
|
},
|
||||||
|
msg=t("common.success.retrieved")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/translations/{locale}/{key:path}", response_model=ApiResponse)
|
||||||
|
def update_translation(
|
||||||
|
locale: str,
|
||||||
|
key: str,
|
||||||
|
request: TranslationUpdateRequest,
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
current_user: User = Depends(get_current_superuser)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update a single translation (admin only).
|
||||||
|
|
||||||
|
Note: This endpoint validates the request but actual translation updates
|
||||||
|
require modifying translation files in the locales directory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
locale: Language code
|
||||||
|
key: Translation key (format: "namespace.key.subkey")
|
||||||
|
request: Translation update request
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Success message
|
||||||
|
"""
|
||||||
|
api_logger.info(
|
||||||
|
f"Update translation request: locale={locale}, key={key}, "
|
||||||
|
f"admin={current_user.username}"
|
||||||
|
)
|
||||||
|
|
||||||
|
translation_service = get_translation_service()
|
||||||
|
|
||||||
|
# Check if locale exists
|
||||||
|
available_locales = translation_service.get_available_locales()
|
||||||
|
if locale not in available_locales:
|
||||||
|
api_logger.warning(f"Language not found: {locale}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=t("i18n.language.not_found", locale=locale)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate key format
|
||||||
|
if "." not in key:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=t("i18n.translation.invalid_key_format", key=key)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Note: Actual translation updates require modifying JSON files
|
||||||
|
# This endpoint serves as a validation and documentation point
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"Translation update validated: {locale}/{key}. "
|
||||||
|
"Translation files need to be updated manually."
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(
|
||||||
|
msg=t("i18n.translation.update_instructions", locale=locale, key=key)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/translations/missing", response_model=ApiResponse)
|
||||||
|
def get_missing_translations(
|
||||||
|
locale: Optional[str] = None,
|
||||||
|
t: Callable = Depends(get_translator),
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get list of missing translations.
|
||||||
|
|
||||||
|
Compares translations across locales to find missing keys.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
locale: Optional locale to check (defaults to checking all non-default locales)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of missing translation keys
|
||||||
|
"""
|
||||||
|
api_logger.info(
|
||||||
|
f"Get missing translations request: locale={locale}, user={current_user.username}"
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.core.config import settings
|
||||||
|
translation_service = get_translation_service()
|
||||||
|
|
||||||
|
default_locale = settings.I18N_DEFAULT_LANGUAGE
|
||||||
|
available_locales = translation_service.get_available_locales()
|
||||||
|
|
||||||
|
# Get default locale translations as reference
|
||||||
|
default_translations = translation_service._cache.get(default_locale, {})
|
||||||
|
|
||||||
|
# Collect all keys from default locale
|
||||||
|
def collect_keys(data, prefix=""):
|
||||||
|
keys = []
|
||||||
|
for key, value in data.items():
|
||||||
|
full_key = f"{prefix}.{key}" if prefix else key
|
||||||
|
if isinstance(value, dict):
|
||||||
|
keys.extend(collect_keys(value, full_key))
|
||||||
|
else:
|
||||||
|
keys.append(full_key)
|
||||||
|
return keys
|
||||||
|
|
||||||
|
default_keys = set()
|
||||||
|
for namespace, translations in default_translations.items():
|
||||||
|
namespace_keys = collect_keys(translations, namespace)
|
||||||
|
default_keys.update(namespace_keys)
|
||||||
|
|
||||||
|
# Find missing keys in target locale(s)
|
||||||
|
missing_by_locale = {}
|
||||||
|
|
||||||
|
target_locales = [locale] if locale else [
|
||||||
|
loc for loc in available_locales if loc != default_locale
|
||||||
|
]
|
||||||
|
|
||||||
|
for target_locale in target_locales:
|
||||||
|
if target_locale not in available_locales:
|
||||||
|
continue
|
||||||
|
|
||||||
|
target_translations = translation_service._cache.get(target_locale, {})
|
||||||
|
target_keys = set()
|
||||||
|
|
||||||
|
for namespace, translations in target_translations.items():
|
||||||
|
namespace_keys = collect_keys(translations, namespace)
|
||||||
|
target_keys.update(namespace_keys)
|
||||||
|
|
||||||
|
missing_keys = default_keys - target_keys
|
||||||
|
if missing_keys:
|
||||||
|
missing_by_locale[target_locale] = sorted(list(missing_keys))
|
||||||
|
|
||||||
|
response = MissingTranslationsResponse(missing_translations=missing_by_locale)
|
||||||
|
|
||||||
|
total_missing = sum(len(keys) for keys in missing_by_locale.values())
|
||||||
|
api_logger.info(f"Found {total_missing} missing translations across {len(missing_by_locale)} locales")
|
||||||
|
|
||||||
|
return success(data=response.dict(), msg=t("common.success.retrieved"))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/reload", response_model=ApiResponse)
def reload_translations(
    locale: Optional[str] = None,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Hot-reload translation files on demand (admin only).

    The call is refused with 403 when ``settings.I18N_ENABLE_HOT_RELOAD`` is
    falsy. Any exception raised while reloading or building the response is
    logged and converted into a 500 error.

    Args:
        locale: Locale to reload; when omitted every locale is reloaded
        t: Translator used for response and error messages
        current_user: Authenticated superuser issuing the request

    Returns:
        Reload status with the reloaded locales and the total locale count
    """
    api_logger.info(
        f"Reload translations request: locale={locale or 'all'}, "
        f"admin={current_user.username}"
    )

    from app.core.config import settings

    hot_reload_enabled = settings.I18N_ENABLE_HOT_RELOAD
    if not hot_reload_enabled:
        api_logger.warning("Hot reload is disabled in configuration")
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=t("i18n.reload.disabled")
        )

    service = get_translation_service()

    try:
        service.reload(locale)

        # Collect statistics for the response
        known_locales = service.get_available_locales()
        if locale:
            reloaded = [locale]
        else:
            reloaded = known_locales

        payload = ReloadResponse(
            success=True,
            reloaded_locales=reloaded,
            total_locales=len(known_locales)
        )

        api_logger.info(
            f"Successfully reloaded translations for: {', '.join(reloaded)}"
        )
        return success(data=payload.dict(), msg=t("i18n.reload.success"))

    except Exception as e:
        api_logger.error(f"Failed to reload translations: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=t("i18n.reload.failed", error=str(e))
        )
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Performance Monitoring APIs
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
@router.get("/metrics", response_model=ApiResponse)
def get_metrics(
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Return the i18n performance metrics summary (admin only).

    Args:
        t: Translator used for the response message
        current_user: Authenticated superuser issuing the request

    Returns:
        Success response whose data is whatever the translation service's
        ``get_metrics_summary()`` reports
    """
    api_logger.info(f"Get metrics request: admin={current_user.username}")

    summary = get_translation_service().get_metrics_summary()

    api_logger.info("Returning i18n metrics")
    return success(data=summary, msg=t("common.success.retrieved"))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/metrics/cache", response_model=ApiResponse)
def get_cache_stats(
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Return translation cache statistics (admin only).

    Args:
        t: Translator used for the response message
        current_user: Authenticated superuser issuing the request

    Returns:
        Success response whose data has two entries:
        - "cache": the service's cache statistics
        - "memory": the service's memory usage report
    """
    api_logger.info(f"Get cache stats request: admin={current_user.username}")

    service = get_translation_service()
    stats = service.get_cache_stats()
    memory = service.get_memory_usage()
    payload = {"cache": stats, "memory": memory}

    api_logger.info("Returning cache statistics")
    return success(data=payload, msg=t("common.success.retrieved"))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/metrics/prometheus")
def get_prometheus_metrics(
    current_user: User = Depends(get_current_superuser)
):
    """
    Export the i18n metrics in Prometheus format (admin only).

    Args:
        current_user: Authenticated superuser issuing the request

    Returns:
        Plain-text response carrying the output of ``export_prometheus()``
    """
    api_logger.info(f"Get Prometheus metrics request: admin={current_user.username}")

    # Aliased locally so it is not confused with the get_metrics endpoint above.
    from app.i18n.metrics import get_metrics as load_metrics
    exported = load_metrics().export_prometheus()

    from fastapi.responses import PlainTextResponse
    return PlainTextResponse(content=exported)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/metrics/reset", response_model=ApiResponse)
def reset_metrics(
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Reset the i18n metrics and the cache statistics (admin only).

    Args:
        t: Translator used for the response message
        current_user: Authenticated superuser issuing the request

    Returns:
        Success message
    """
    api_logger.info(f"Reset metrics request: admin={current_user.username}")

    # Aliased locally so it is not confused with the get_metrics endpoint.
    from app.i18n.metrics import get_metrics as load_metrics
    load_metrics().reset()

    # NOTE(review): other endpoints read `translation_service._cache`; this one
    # uses `.cache` — confirm both attributes exist on the service.
    get_translation_service().cache.reset_stats()

    api_logger.info("Metrics reset completed")
    return success(msg=t("i18n.metrics.reset_success"))
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# Missing Translation Logging and Reporting APIs
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
@router.get("/logs/missing", response_model=ApiResponse)
def get_missing_translation_logs(
    locale: Optional[str] = None,
    limit: Optional[int] = 100,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Return the logged missing translations (admin only).

    Args:
        locale: Optional locale filter
        limit: Maximum number of context entries to return (default: 100)
        t: Translator used for the response message
        current_user: Authenticated superuser issuing the request

    Returns:
        Success response whose data holds "missing_translations",
        "recent_context" and "statistics", all taken from the service's
        translation logger
    """
    api_logger.info(
        f"Get missing translation logs request: locale={locale}, "
        f"limit={limit}, admin={current_user.username}"
    )

    log_store = get_translation_service().translation_logger

    # Query the logger in a fixed order: keys, recent context, then totals.
    missing = log_store.get_missing_translations(locale)
    recent = log_store.get_missing_with_context(locale, limit)
    stats = log_store.get_statistics()

    payload = {
        "missing_translations": missing,
        "recent_context": recent,
        "statistics": stats
    }

    api_logger.info(
        f"Returning {stats['total_missing']} missing translations"
    )
    return success(data=payload, msg=t("common.success.retrieved"))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/logs/missing/report", response_model=ApiResponse)
def generate_missing_translation_report(
    locale: Optional[str] = None,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Build a missing-translation report (admin only).

    Args:
        locale: Optional locale filter
        t: Translator used for the response message
        current_user: Authenticated superuser issuing the request

    Returns:
        Success response whose data is the report produced by the
        translation logger's ``generate_report``
    """
    api_logger.info(
        f"Generate missing translation report request: locale={locale}, "
        f"admin={current_user.username}"
    )

    log_store = get_translation_service().translation_logger
    summary = log_store.generate_report(locale)

    # The report is indexed by "total_missing" for the log line below.
    api_logger.info(
        f"Generated report with {summary['total_missing']} missing translations"
    )
    return success(data=summary, msg=t("common.success.retrieved"))
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/logs/missing/export", response_model=ApiResponse)
def export_missing_translations(
    locale: Optional[str] = None,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Export missing translations to JSON file (admin only).

    The file is written to
    "logs/i18n/missing_translations_<locale|all>_<timestamp>.json".
    ``locale`` is request input, so every character outside ``[A-Za-z0-9_-]``
    is replaced with "_" before it becomes part of the file name; this keeps
    values such as "../../x" from escaping the export directory.

    NOTE(review): ``locale`` only influences the file name here; it is not
    passed to ``export_to_json`` — confirm whether the export is meant to be
    filtered by locale.

    Args:
        locale: Optional locale filter
        t: Translator used for the response message
        current_user: Authenticated superuser issuing the request

    Returns:
        Export status and file path
    """
    api_logger.info(
        f"Export missing translations request: locale={locale}, "
        f"admin={current_user.username}"
    )

    import re
    from datetime import datetime
    translation_service = get_translation_service()
    translation_logger = translation_service.translation_logger

    # Generate filename with timestamp; never trust the raw query value in a path
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    safe_locale = re.sub(r"[^A-Za-z0-9_-]", "_", locale) if locale else None
    locale_suffix = f"_{safe_locale}" if safe_locale else "_all"
    output_file = f"logs/i18n/missing_translations{locale_suffix}_{timestamp}.json"

    # Export to file
    translation_logger.export_to_json(output_file)

    api_logger.info(f"Missing translations exported to: {output_file}")
    return success(
        data={"file_path": output_file},
        msg=t("i18n.logs.export_success", file=output_file)
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/logs/missing", response_model=ApiResponse)
def clear_missing_translation_logs(
    locale: Optional[str] = None,
    t: Callable = Depends(get_translator),
    current_user: User = Depends(get_current_superuser)
):
    """
    Clear the logged missing translations (admin only).

    Args:
        locale: Locale whose log entries are cleared; everything is cleared
            when it is not specified
        t: Translator used for the response message
        current_user: Authenticated superuser issuing the request

    Returns:
        Success message
    """
    api_logger.info(
        f"Clear missing translation logs request: locale={locale or 'all'}, "
        f"admin={current_user.username}"
    )

    # The locale is handed to the logger unchanged (None when omitted).
    get_translation_service().translation_logger.clear(locale)

    api_logger.info(f"Cleared missing translation logs for: {locale or 'all locales'}")
    return success(msg=t("i18n.logs.clear_success"))
|
||||||
@@ -19,7 +19,7 @@ from app.models import mcp_market_config_model
|
|||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.schemas import mcp_market_config_schema
|
from app.schemas import mcp_market_config_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services import mcp_market_config_service
|
from app.services import mcp_market_config_service, mcp_market_service
|
||||||
|
|
||||||
# Obtain a dedicated API logger
|
# Obtain a dedicated API logger
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -55,6 +55,12 @@ async def get_mcp_servers(
|
|||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail="The paging parameter must be greater than 0"
|
detail="The paging parameter must be greater than 0"
|
||||||
)
|
)
|
||||||
|
if page * pagesize > 100:
|
||||||
|
api_logger.warning(f"Paging parameters exceed ModelScope limit: page={page}, pagesize={pagesize}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"The maximum number of MCP services can view is 100. Please visit the ModelScope MCP Plaza."
|
||||||
|
)
|
||||||
|
|
||||||
# 2. Query mcp market config information from the database
|
# 2. Query mcp market config information from the database
|
||||||
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
|
||||||
@@ -64,14 +70,16 @@ async def get_mcp_servers(
|
|||||||
if not db_mcp_market_config:
|
if not db_mcp_market_config:
|
||||||
api_logger.warning(
|
api_logger.warning(
|
||||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||||
raise HTTPException(
|
return success(msg='The mcp market config does not exist or access is denied')
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="The mcp market config does not exist or access is denied"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Execute paged query
|
# 3. Execute paged query
|
||||||
api = MCPApi()
|
|
||||||
token = db_mcp_market_config.token
|
token = db_mcp_market_config.token
|
||||||
|
if not token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="MCP market config token is not configured"
|
||||||
|
)
|
||||||
|
api = MCPApi()
|
||||||
api.login(token)
|
api.login(token)
|
||||||
|
|
||||||
body = {
|
body = {
|
||||||
@@ -115,6 +123,17 @@ async def get_mcp_servers(
|
|||||||
"has_next": True if page * pagesize < total else False
|
"has_next": True if page * pagesize < total else False
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
# 5. Update mcp_market.mcp_count
|
||||||
|
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=db_mcp_market_config.mcp_market_id, current_user=current_user)
|
||||||
|
if not db_mcp_market:
|
||||||
|
api_logger.warning(f"The mcp market does not exist or access is denied: mcp_market_id={db_mcp_market_config.mcp_market_id}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="The mcp market does not exist or access is denied"
|
||||||
|
)
|
||||||
|
db_mcp_market.mcp_count = total
|
||||||
|
db.commit()
|
||||||
|
db.refresh(db_mcp_market)
|
||||||
return success(data=result, msg="Query of mcp servers list successful")
|
return success(data=result, msg="Query of mcp servers list successful")
|
||||||
|
|
||||||
|
|
||||||
@@ -140,14 +159,16 @@ async def get_operational_mcp_servers(
|
|||||||
if not db_mcp_market_config:
|
if not db_mcp_market_config:
|
||||||
api_logger.warning(
|
api_logger.warning(
|
||||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||||
raise HTTPException(
|
return success(msg='The mcp market config does not exist or access is denied')
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="The mcp market config does not exist or access is denied"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Execute paged query
|
# 2. Execute paged query
|
||||||
api = MCPApi()
|
|
||||||
token = db_mcp_market_config.token
|
token = db_mcp_market_config.token
|
||||||
|
if not token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="MCP market config token is not configured"
|
||||||
|
)
|
||||||
|
api = MCPApi()
|
||||||
api.login(token)
|
api.login(token)
|
||||||
|
|
||||||
url = f'{api.mcp_base_url}/operational'
|
url = f'{api.mcp_base_url}/operational'
|
||||||
@@ -198,14 +219,16 @@ async def get_mcp_server(
|
|||||||
if not db_mcp_market_config:
|
if not db_mcp_market_config:
|
||||||
api_logger.warning(
|
api_logger.warning(
|
||||||
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||||
raise HTTPException(
|
return success(msg='The mcp market config does not exist or access is denied')
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="The mcp market config does not exist or access is denied"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Get detailed information for a specific MCP Server
|
# 2. Get detailed information for a specific MCP Server
|
||||||
api = MCPApi()
|
|
||||||
token = db_mcp_market_config.token
|
token = db_mcp_market_config.token
|
||||||
|
if not token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="MCP market config token is not configured"
|
||||||
|
)
|
||||||
|
api = MCPApi()
|
||||||
api.login(token)
|
api.login(token)
|
||||||
|
|
||||||
result = api.get_mcp_server(server_id=server_id)
|
result = api.get_mcp_server(server_id=server_id)
|
||||||
@@ -226,7 +249,26 @@ async def create_mcp_market_config(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
api_logger.debug(f"Start creating the mcp market config: {create_data.mcp_market_id}")
|
api_logger.debug(f"Start creating the mcp market config: {create_data.mcp_market_id}")
|
||||||
# 1. Check if the mcp market name already exists
|
# 1. Validate token can access ModelScope MCP market
|
||||||
|
if not create_data.token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Token is required to access ModelScope MCP market"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
api = MCPApi()
|
||||||
|
api.login(create_data.token)
|
||||||
|
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
||||||
|
cookies = api.get_cookies(create_data.token)
|
||||||
|
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
|
||||||
|
raise_for_http_status(r)
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Unable to access ModelScope MCP market with the provided token: {str(e)}"
|
||||||
|
)
|
||||||
|
# 2. Check if the mcp market name already exists
|
||||||
db_mcp_market_config_exist = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=create_data.mcp_market_id, current_user=current_user)
|
db_mcp_market_config_exist = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=create_data.mcp_market_id, current_user=current_user)
|
||||||
if db_mcp_market_config_exist:
|
if db_mcp_market_config_exist:
|
||||||
api_logger.warning(f"The mcp market id already exists: {create_data.mcp_market_id}")
|
api_logger.warning(f"The mcp market id already exists: {create_data.mcp_market_id}")
|
||||||
@@ -234,6 +276,30 @@ async def create_mcp_market_config(
|
|||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
detail=f"The mcp market id already exists: {create_data.mcp_market_id}"
|
detail=f"The mcp market id already exists: {create_data.mcp_market_id}"
|
||||||
)
|
)
|
||||||
|
# 2. verify token
|
||||||
|
create_data.status = 1
|
||||||
|
try:
|
||||||
|
api = MCPApi()
|
||||||
|
token = create_data.token
|
||||||
|
api.login(token)
|
||||||
|
|
||||||
|
body = {
|
||||||
|
'filter': {},
|
||||||
|
'page_number': 1,
|
||||||
|
'page_size': 20,
|
||||||
|
'search': ""
|
||||||
|
}
|
||||||
|
cookies = api.get_cookies(token)
|
||||||
|
r = api.session.put(
|
||||||
|
url=api.mcp_base_url,
|
||||||
|
headers=api.builder_headers(api.headers),
|
||||||
|
json=body,
|
||||||
|
cookies=cookies)
|
||||||
|
raise_for_http_status(r)
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
api_logger.error(f"Failed to get MCP servers: {str(e)}")
|
||||||
|
create_data.status = 0
|
||||||
|
# 3. create mcp_market_config
|
||||||
db_mcp_market_config = mcp_market_config_service.create_mcp_market_config(db=db, mcp_market_config=create_data, current_user=current_user)
|
db_mcp_market_config = mcp_market_config_service.create_mcp_market_config(db=db, mcp_market_config=create_data, current_user=current_user)
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"The mcp market config has been successfully created: (ID: {db_mcp_market_config.id})")
|
f"The mcp market config has been successfully created: (ID: {db_mcp_market_config.id})")
|
||||||
@@ -262,10 +328,7 @@ async def get_mcp_market_config(
|
|||||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||||
if not db_mcp_market_config:
|
if not db_mcp_market_config:
|
||||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
|
||||||
raise HTTPException(
|
return success(msg='The mcp market config does not exist or access is denied')
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="The mcp market config does not exist or access is denied"
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
||||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||||
@@ -295,10 +358,7 @@ async def get_mcp_market_config_by_mcp_market_id(
|
|||||||
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
|
||||||
if not db_mcp_market_config:
|
if not db_mcp_market_config:
|
||||||
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_id={mcp_market_id}")
|
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_id={mcp_market_id}")
|
||||||
raise HTTPException(
|
return success(msg='The mcp market config does not exist or access is denied')
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="The mcp market config does not exist or access is denied"
|
|
||||||
)
|
|
||||||
|
|
||||||
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
|
||||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||||
@@ -324,12 +384,25 @@ async def update_mcp_market_config(
|
|||||||
if not db_mcp_market_config:
|
if not db_mcp_market_config:
|
||||||
api_logger.warning(
|
api_logger.warning(
|
||||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
||||||
raise HTTPException(
|
return success(msg='The mcp market config does not exist or access is denied')
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="The mcp market config does not exist or you do not have permission to access it"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Update fields (only update non-null fields)
|
# 2. Validate new token if provided
|
||||||
|
if update_data.token is not None:
|
||||||
|
try:
|
||||||
|
api = MCPApi()
|
||||||
|
api.login(update_data.token)
|
||||||
|
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
|
||||||
|
cookies = api.get_cookies(update_data.token)
|
||||||
|
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
|
||||||
|
raise_for_http_status(r)
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Unable to access ModelScope MCP market with the provided token: {str(e)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Update fields (only update non-null fields)
|
||||||
api_logger.debug(f"Start updating the mcp market config fields: {mcp_market_config_id}")
|
api_logger.debug(f"Start updating the mcp market config fields: {mcp_market_config_id}")
|
||||||
update_dict = update_data.dict(exclude_unset=True)
|
update_dict = update_data.dict(exclude_unset=True)
|
||||||
updated_fields = []
|
updated_fields = []
|
||||||
@@ -344,7 +417,7 @@ async def update_mcp_market_config(
|
|||||||
if updated_fields:
|
if updated_fields:
|
||||||
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
|
||||||
|
|
||||||
# 3. Save to database
|
# 4. Save to database
|
||||||
try:
|
try:
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(db_mcp_market_config)
|
db.refresh(db_mcp_market_config)
|
||||||
@@ -357,7 +430,7 @@ async def update_mcp_market_config(
|
|||||||
detail=f"The mcp market config update failed: {str(e)}"
|
detail=f"The mcp market config update failed: {str(e)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. Return the updated mcp market config
|
# 5. Return the updated mcp market config
|
||||||
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
|
||||||
msg="The mcp market config information updated successfully")
|
msg="The mcp market config information updated successfully")
|
||||||
|
|
||||||
@@ -381,10 +454,7 @@ async def delete_mcp_market_config(
|
|||||||
if not db_mcp_market_config:
|
if not db_mcp_market_config:
|
||||||
api_logger.warning(
|
api_logger.warning(
|
||||||
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
|
||||||
raise HTTPException(
|
return success(msg='The mcp market config does not exist or access is denied')
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="The mcp market config does not exist or you do not have permission to access it"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Deleting mcp market config
|
# 2. Deleting mcp market config
|
||||||
mcp_market_config_service.delete_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
mcp_market_config_service.delete_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
@@ -149,6 +150,21 @@ async def get_workspace_end_users(
|
|||||||
|
|
||||||
return {uid: {"total": 0} for uid in end_user_ids}
|
return {uid: {"total": 0} for uid in end_user_ids}
|
||||||
|
|
||||||
|
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
|
||||||
|
try:
|
||||||
|
from app.celery_app import celery_app as _celery_app
|
||||||
|
_celery_app.send_task(
|
||||||
|
"app.tasks.init_implicit_emotions_for_users",
|
||||||
|
kwargs={"end_user_ids": end_user_ids},
|
||||||
|
)
|
||||||
|
_celery_app.send_task(
|
||||||
|
"app.tasks.init_interest_distribution_for_users",
|
||||||
|
kwargs={"end_user_ids": end_user_ids},
|
||||||
|
)
|
||||||
|
api_logger.info(f"已触发按需初始化任务,候选用户数: {len(end_user_ids)}")
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
|
||||||
|
|
||||||
# 并发执行配置查询和记忆数量查询
|
# 并发执行配置查询和记忆数量查询
|
||||||
memory_configs_map, memory_nums_map = await asyncio.gather(
|
memory_configs_map, memory_nums_map = await asyncio.gather(
|
||||||
get_memory_configs(),
|
get_memory_configs(),
|
||||||
@@ -178,6 +194,14 @@ async def get_workspace_end_users(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
|
||||||
|
|
||||||
|
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
||||||
|
try:
|
||||||
|
from app.tasks import init_community_clustering_for_users
|
||||||
|
init_community_clustering_for_users.delay(end_user_ids=end_user_ids, workspace_id=str(workspace_id))
|
||||||
|
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
|
||||||
|
|
||||||
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
|
||||||
return success(data=result, msg="宿主列表获取成功")
|
return success(data=result, msg="宿主列表获取成功")
|
||||||
|
|
||||||
@@ -387,14 +411,15 @@ def get_current_user_rag_total_num(
|
|||||||
@router.get("/rag_content", response_model=ApiResponse)
|
@router.get("/rag_content", response_model=ApiResponse)
|
||||||
def get_rag_content(
|
def get_rag_content(
|
||||||
end_user_id: str = Query(..., description="宿主ID"),
|
end_user_id: str = Query(..., description="宿主ID"),
|
||||||
limit: int = Query(15, description="返回记录数"),
|
page: int = Query(1, gt=0, description="页码,从1开始"),
|
||||||
|
pagesize: int = Query(15, gt=0, le=100, description="每页返回记录数"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取当前宿主知识库中的chunk内容
|
获取当前宿主知识库中的chunk内容(分页)
|
||||||
"""
|
"""
|
||||||
data = memory_dashboard_service.get_rag_content(end_user_id, limit, db, current_user)
|
data = memory_dashboard_service.get_rag_content(end_user_id, page, pagesize, db, current_user)
|
||||||
return success(data=data, msg="宿主RAGchunk数据获取成功")
|
return success(data=data, msg="宿主RAGchunk数据获取成功")
|
||||||
|
|
||||||
|
|
||||||
@@ -407,25 +432,17 @@ async def get_chunk_summary_tag(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取chunk总结、提取的标签和人物形象
|
读取RAG摘要、标签和人物形象(纯读库,不触发生成)。
|
||||||
|
|
||||||
返回格式:
|
返回格式:
|
||||||
{
|
{
|
||||||
"summary": "chunk内容的总结",
|
"summary": "用户摘要",
|
||||||
"tags": [
|
"tags": [{"tag": "标签1", "frequency": 5}, ...],
|
||||||
{"tag": "标签1", "frequency": 5},
|
"personas": ["产品设计师", ...],
|
||||||
{"tag": "标签2", "frequency": 3},
|
"generated": true/false // false表示尚未生产,请调用 /generate_rag_profile
|
||||||
...
|
|
||||||
],
|
|
||||||
"personas": [
|
|
||||||
"产品设计师",
|
|
||||||
"旅行爱好者",
|
|
||||||
"摄影发烧友",
|
|
||||||
...
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk摘要、标签和人物形象")
|
api_logger.info(f"用户 {current_user.username} 读取宿主 {end_user_id} 的RAG摘要/标签/人物形象")
|
||||||
|
|
||||||
data = await memory_dashboard_service.get_chunk_summary_and_tags(
|
data = await memory_dashboard_service.get_chunk_summary_and_tags(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
@@ -435,8 +452,7 @@ async def get_chunk_summary_tag(
|
|||||||
current_user=current_user
|
current_user=current_user
|
||||||
)
|
)
|
||||||
|
|
||||||
api_logger.info(f"成功获取chunk摘要、{len(data.get('tags', []))} 个标签和 {len(data.get('personas', []))} 个人物形象")
|
return success(data=data, msg="获取成功")
|
||||||
return success(data=data, msg="chunk摘要、标签和人物形象获取成功")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/chunk_insight", response_model=ApiResponse)
|
@router.get("/chunk_insight", response_model=ApiResponse)
|
||||||
@@ -447,14 +463,18 @@ async def get_chunk_insight(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取chunk的洞察内容
|
读取RAG洞察报告(纯读库,不触发生成)。
|
||||||
|
|
||||||
返回格式:
|
返回格式:
|
||||||
{
|
{
|
||||||
"insight": "对chunk内容的深度洞察分析"
|
"insight": "总体概述",
|
||||||
|
"behavior_pattern": "行为模式",
|
||||||
|
"key_findings": "关键发现",
|
||||||
|
"growth_trajectory": "成长轨迹",
|
||||||
|
"generated": true/false // false表示尚未生产,请调用 /generate_rag_profile
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id} 的chunk洞察")
|
api_logger.info(f"用户 {current_user.username} 读取宿主 {end_user_id} 的RAG洞察")
|
||||||
|
|
||||||
data = await memory_dashboard_service.get_chunk_insight(
|
data = await memory_dashboard_service.get_chunk_insight(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
@@ -463,8 +483,37 @@ async def get_chunk_insight(
|
|||||||
current_user=current_user
|
current_user=current_user
|
||||||
)
|
)
|
||||||
|
|
||||||
api_logger.info("成功获取chunk洞察")
|
return success(data=data, msg="获取成功")
|
||||||
return success(data=data, msg="chunk洞察获取成功")
|
|
||||||
|
|
||||||
|
class GenerateRagProfileRequest(BaseModel):
|
||||||
|
end_user_id: str = Field(..., description="宿主ID")
|
||||||
|
limit: int = Field(15, description="参与生成的chunk数量上限")
|
||||||
|
max_tags: int = Field(10, description="最大标签数量")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/generate_rag_profile", response_model=ApiResponse)
|
||||||
|
async def generate_rag_profile(
|
||||||
|
body: GenerateRagProfileRequest,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
生产接口:为RAG存储模式的宿主全量重新生成完整画像并持久化到end_user表。
|
||||||
|
每次请求都会重新生成,覆盖已有数据。
|
||||||
|
"""
|
||||||
|
api_logger.info(f"用户 {current_user.username} 触发RAG画像生产: end_user_id={body.end_user_id}")
|
||||||
|
|
||||||
|
data = await memory_dashboard_service.generate_rag_profile(
|
||||||
|
end_user_id=body.end_user_id,
|
||||||
|
limit=body.limit,
|
||||||
|
max_tags=body.max_tags,
|
||||||
|
db=db,
|
||||||
|
current_user=current_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
api_logger.info(f"RAG画像生产完成: {data}")
|
||||||
|
return success(data=data, msg="RAG画像生产完成")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/dashboard_data", response_model=ApiResponse)
|
@router.get("/dashboard_data", response_model=ApiResponse)
|
||||||
@@ -553,9 +602,12 @@ async def dashboard_data(
|
|||||||
)
|
)
|
||||||
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
|
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
|
||||||
# total_app: 统计当前空间下的所有app数量
|
# total_app: 统计当前空间下的所有app数量
|
||||||
from app.repositories import app_repository
|
# 包含自有app + 被分享给本工作空间的app
|
||||||
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
|
from app.services import app_service as _app_svc
|
||||||
neo4j_data["total_app"] = len(apps_orm)
|
_, total_app = _app_svc.AppService(db).list_apps(
|
||||||
|
workspace_id=workspace_id, include_shared=True, pagesize=1
|
||||||
|
)
|
||||||
|
neo4j_data["total_app"] = total_app
|
||||||
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}")
|
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.warning(f"获取记忆总量失败: {str(e)}")
|
api_logger.warning(f"获取记忆总量失败: {str(e)}")
|
||||||
|
|||||||
@@ -1,3 +1,19 @@
|
|||||||
|
"""
|
||||||
|
Memory Reflection Controller
|
||||||
|
|
||||||
|
This module provides REST API endpoints for managing memory reflection configurations
|
||||||
|
and operations. It handles reflection engine setup, configuration management, and
|
||||||
|
execution of self-reflection processes across memory systems.
|
||||||
|
|
||||||
|
Key Features:
|
||||||
|
- Reflection configuration management (save, retrieve, update)
|
||||||
|
- Workspace-wide reflection execution across multiple applications
|
||||||
|
- Individual configuration-based reflection runs
|
||||||
|
- Multi-language support for reflection outputs
|
||||||
|
- Integration with Neo4j memory storage and LLM models
|
||||||
|
- Comprehensive error handling and logging
|
||||||
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
@@ -28,9 +44,13 @@ from sqlalchemy.orm import Session
|
|||||||
|
|
||||||
from app.utils.config_utils import resolve_config_id
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
|
# Load environment variables for configuration
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
# Initialize API logger for request tracking and debugging
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
|
# Configure router with prefix and tags for API organization
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/memory",
|
prefix="/memory",
|
||||||
tags=["Memory"],
|
tags=["Memory"],
|
||||||
@@ -43,7 +63,38 @@ async def save_reflection_config(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Save reflection configuration to data_comfig table"""
|
"""
|
||||||
|
Save reflection configuration to memory config table
|
||||||
|
|
||||||
|
Persists reflection engine configuration settings to the memory_config table,
|
||||||
|
including reflection parameters, model settings, and evaluation criteria.
|
||||||
|
Validates configuration parameters and ensures data consistency.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Memory reflection configuration data including:
|
||||||
|
- config_id: Configuration identifier to update
|
||||||
|
- reflection_enabled: Whether reflection is enabled
|
||||||
|
- reflection_period_in_hours: Reflection execution interval
|
||||||
|
- reflexion_range: Scope of reflection (partial/all)
|
||||||
|
- baseline: Reflection strategy (time/fact/hybrid)
|
||||||
|
- reflection_model_id: LLM model for reflection operations
|
||||||
|
- memory_verify: Enable memory verification checks
|
||||||
|
- quality_assessment: Enable quality assessment evaluation
|
||||||
|
current_user: Authenticated user saving the configuration
|
||||||
|
db: Database session for data operations
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Success response with saved reflection configuration data
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException 400: If config_id is missing or parameters are invalid
|
||||||
|
HTTPException 500: If configuration save operation fails
|
||||||
|
|
||||||
|
Database Operations:
|
||||||
|
- Updates memory_config table with reflection settings
|
||||||
|
- Commits transaction and refreshes entity
|
||||||
|
- Maintains configuration consistency
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
config_id = request.config_id
|
config_id = request.config_id
|
||||||
config_id = resolve_config_id(config_id, db)
|
config_id = resolve_config_id(config_id, db)
|
||||||
@@ -54,6 +105,7 @@ async def save_reflection_config(
|
|||||||
)
|
)
|
||||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||||
|
|
||||||
|
# Update reflection configuration in database
|
||||||
memory_config = MemoryConfigRepository.update_reflection_config(
|
memory_config = MemoryConfigRepository.update_reflection_config(
|
||||||
db,
|
db,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
@@ -66,6 +118,7 @@ async def save_reflection_config(
|
|||||||
quality_assessment=request.quality_assessment
|
quality_assessment=request.quality_assessment
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Commit transaction and refresh entity
|
||||||
db.commit()
|
db.commit()
|
||||||
db.refresh(memory_config)
|
db.refresh(memory_config)
|
||||||
|
|
||||||
@@ -102,13 +155,55 @@ async def start_workspace_reflection(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""启动工作空间中所有匹配应用的反思功能"""
|
"""
|
||||||
|
Start reflection functionality for all matching applications in workspace
|
||||||
|
|
||||||
|
Initiates reflection processes across all applications within the user's current
|
||||||
|
workspace that have valid memory configurations. Processes each application's
|
||||||
|
configurations and associated end users, executing reflection operations
|
||||||
|
with proper error isolation and transaction management.
|
||||||
|
|
||||||
|
This endpoint serves as a workspace-wide reflection orchestrator, ensuring
|
||||||
|
that reflection failures for individual users don't affect other operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
current_user: Authenticated user initiating workspace reflection
|
||||||
|
db: Database session for configuration queries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Success response with reflection results for all processed applications:
|
||||||
|
- app_id: Application identifier
|
||||||
|
- config_id: Memory configuration identifier
|
||||||
|
- end_user_id: End user identifier
|
||||||
|
- reflection_result: Individual reflection operation result
|
||||||
|
|
||||||
|
Processing Logic:
|
||||||
|
1. Retrieve all applications in the current workspace
|
||||||
|
2. Filter applications with valid memory configurations
|
||||||
|
3. For each configuration, find matching releases
|
||||||
|
4. Execute reflection for each end user with isolated transactions
|
||||||
|
5. Aggregate results with error handling per user
|
||||||
|
|
||||||
|
Error Handling:
|
||||||
|
- Individual user reflection failures are isolated
|
||||||
|
- Failed operations are logged and included in results
|
||||||
|
- Database transactions are isolated per user to prevent cascading failures
|
||||||
|
- Comprehensive error reporting for debugging
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException 500: If workspace reflection initialization fails
|
||||||
|
|
||||||
|
Performance Notes:
|
||||||
|
- Uses independent database sessions for each user operation
|
||||||
|
- Prevents transaction failures from affecting other users
|
||||||
|
- Comprehensive logging for operation tracking
|
||||||
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
try:
|
try:
|
||||||
api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}")
|
api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}")
|
||||||
|
|
||||||
# 使用独立的数据库会话来获取工作空间应用详情,避免事务失败
|
# Use independent database session to get workspace app details, avoiding transaction failures
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
with get_db_context() as query_db:
|
with get_db_context() as query_db:
|
||||||
service = WorkspaceAppService(query_db)
|
service = WorkspaceAppService(query_db)
|
||||||
@@ -116,8 +211,9 @@ async def start_workspace_reflection(
|
|||||||
|
|
||||||
reflection_results = []
|
reflection_results = []
|
||||||
|
|
||||||
|
# Process each application in the workspace
|
||||||
for data in result['apps_detailed_info']:
|
for data in result['apps_detailed_info']:
|
||||||
# 跳过没有配置的应用
|
# Skip applications without configurations
|
||||||
if not data['memory_configs']:
|
if not data['memory_configs']:
|
||||||
api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过")
|
api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过")
|
||||||
continue
|
continue
|
||||||
@@ -126,22 +222,22 @@ async def start_workspace_reflection(
|
|||||||
memory_configs = data['memory_configs']
|
memory_configs = data['memory_configs']
|
||||||
end_users = data['end_users']
|
end_users = data['end_users']
|
||||||
|
|
||||||
# 为每个配置和用户组合执行反思
|
# Execute reflection for each configuration and user combination
|
||||||
for config in memory_configs:
|
for config in memory_configs:
|
||||||
config_id_str = str(config['config_id'])
|
config_id_str = str(config['config_id'])
|
||||||
|
|
||||||
# 找到匹配此配置的所有release
|
# Find all releases matching this configuration
|
||||||
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
|
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
|
||||||
|
|
||||||
if not matching_releases:
|
if not matching_releases:
|
||||||
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 为每个用户执行反思 - 使用独立的数据库会话
|
# Execute reflection for each user - using independent database sessions
|
||||||
for user in end_users:
|
for user in end_users:
|
||||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
||||||
|
|
||||||
# 为每个用户创建独立的数据库会话,避免事务失败影响其他用户
|
# Create an independent database session for each user so that one user's failed transaction does not affect other users
|
||||||
with get_db_context() as user_db:
|
with get_db_context() as user_db:
|
||||||
try:
|
try:
|
||||||
reflection_service = MemoryReflectionService(user_db)
|
reflection_service = MemoryReflectionService(user_db)
|
||||||
@@ -184,14 +280,51 @@ async def start_reflection_configs(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""通过config_id查询memory_config表中的反思配置信息"""
|
"""
|
||||||
|
Query reflection configuration information by config_id
|
||||||
|
|
||||||
|
Retrieves detailed reflection configuration settings from the memory_config
|
||||||
|
table for a specific configuration ID. Provides comprehensive reflection
|
||||||
|
parameters including model settings, evaluation criteria, and operational flags.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_id: Configuration identifier (UUID or integer) to query
|
||||||
|
current_user: Authenticated user making the request
|
||||||
|
db: Database session for data operations
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Success response with detailed reflection configuration:
|
||||||
|
- config_id: Resolved configuration identifier
|
||||||
|
- reflection_enabled: Whether reflection is enabled for this config
|
||||||
|
- reflection_period_in_hours: Reflection execution interval
|
||||||
|
- reflexion_range: Scope of reflection operations (partial/all)
|
||||||
|
- baseline: Reflection strategy (time/fact/hybrid)
|
||||||
|
- reflection_model_id: LLM model identifier for reflection
|
||||||
|
- memory_verify: Memory verification flag
|
||||||
|
- quality_assessment: Quality assessment flag
|
||||||
|
|
||||||
|
Database Operations:
|
||||||
|
- Queries memory_config table by resolved config_id
|
||||||
|
- Retrieves all reflection-related configuration fields
|
||||||
|
- Resolves configuration ID for consistent formatting
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException 404: If configuration with specified ID is not found
|
||||||
|
HTTPException 500: If configuration query operation fails
|
||||||
|
|
||||||
|
ID Resolution:
|
||||||
|
- Supports both UUID and integer config_id formats
|
||||||
|
- Automatically resolves to appropriate internal format
|
||||||
|
- Maintains consistency across different ID representations
|
||||||
|
"""
|
||||||
config_id = resolve_config_id(config_id, db)
|
config_id = resolve_config_id(config_id, db)
|
||||||
try:
|
try:
|
||||||
config_id=resolve_config_id(config_id,db)
|
config_id=resolve_config_id(config_id,db)
|
||||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||||
memory_config_id = resolve_config_id(result.config_id, db)
|
memory_config_id = resolve_config_id(result.config_id, db)
|
||||||
# 构建返回数据
|
|
||||||
|
# Build response data with comprehensive configuration details
|
||||||
reflection_config = {
|
reflection_config = {
|
||||||
"config_id": memory_config_id,
|
"config_id": memory_config_id,
|
||||||
"reflection_enabled": result.enable_self_reflexion,
|
"reflection_enabled": result.enable_self_reflexion,
|
||||||
@@ -205,9 +338,11 @@ async def start_reflection_configs(
|
|||||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||||
return success(data=reflection_config, msg="反思配置查询成功")
|
return success(data=reflection_config, msg="反思配置查询成功")
|
||||||
|
|
||||||
|
api_logger.info(f"Successfully queried reflection config, config_id: {config_id}")
|
||||||
|
return success(data=reflection_config, msg="Reflection configuration query successful")
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
# 重新抛出HTTP异常
|
# Re-raise HTTP exceptions without modification
|
||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"查询反思配置失败: {str(e)}")
|
api_logger.error(f"查询反思配置失败: {str(e)}")
|
||||||
@@ -223,13 +358,66 @@ async def reflection_run(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Activate the reflection function for all matching applications in the workspace"""
|
"""
|
||||||
# 使用集中化的语言校验
|
Execute reflection engine with specified configuration
|
||||||
|
|
||||||
|
Runs the reflection engine using configuration parameters from the database.
|
||||||
|
Validates model availability, sets up the reflection engine with proper
|
||||||
|
configuration, and executes the reflection process with multi-language support.
|
||||||
|
|
||||||
|
This endpoint provides a test run capability for reflection configurations,
|
||||||
|
allowing users to validate their reflection settings and see results before
|
||||||
|
deploying to production environments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_id: Configuration identifier (UUID or integer) for reflection settings
|
||||||
|
language_type: Language preference header for output localization (optional)
|
||||||
|
current_user: Authenticated user executing the reflection
|
||||||
|
db: Database session for configuration queries
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Success response with reflection execution results including:
|
||||||
|
- baseline: Reflection strategy used
|
||||||
|
- source_data: Input data processed
|
||||||
|
- memory_verifies: Memory verification results (if enabled)
|
||||||
|
- quality_assessments: Quality assessment results (if enabled)
|
||||||
|
- reflexion_data: Generated reflection insights and solutions
|
||||||
|
|
||||||
|
Configuration Validation:
|
||||||
|
- Verifies configuration exists in database
|
||||||
|
- Validates LLM model availability
|
||||||
|
- Falls back to default model if specified model is unavailable
|
||||||
|
- Ensures all required parameters are properly set
|
||||||
|
|
||||||
|
Reflection Engine Setup:
|
||||||
|
- Creates ReflectionConfig with database parameters
|
||||||
|
- Initializes Neo4j connector for memory access
|
||||||
|
- Sets up ReflectionEngine with validated model
|
||||||
|
- Configures language preferences for output
|
||||||
|
|
||||||
|
Error Handling:
|
||||||
|
- Model validation with fallback to default
|
||||||
|
- Configuration validation and error reporting
|
||||||
|
- Comprehensive logging for debugging
|
||||||
|
- Graceful handling of missing configurations
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException 404: If configuration is not found
|
||||||
|
HTTPException 500: If reflection execution fails
|
||||||
|
|
||||||
|
Performance Notes:
|
||||||
|
- Direct database query for configuration retrieval
|
||||||
|
- Model validation to prevent runtime failures
|
||||||
|
- Efficient reflection engine initialization
|
||||||
|
- Language-aware output processing
|
||||||
|
"""
|
||||||
|
# Use centralized language validation for consistent localization
|
||||||
language = get_language_from_header(language_type)
|
language = get_language_from_header(language_type)
|
||||||
|
|
||||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||||
config_id = resolve_config_id(config_id, db)
|
config_id = resolve_config_id(config_id, db)
|
||||||
# 使用MemoryConfigRepository查询反思配置
|
|
||||||
|
# Query reflection configuration using MemoryConfigRepository
|
||||||
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||||
if not result:
|
if not result:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -239,7 +427,7 @@ async def reflection_run(
|
|||||||
|
|
||||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||||
|
|
||||||
# 验证模型ID是否存在
|
# Validate model ID existence
|
||||||
model_id = result.reflection_model_id
|
model_id = result.reflection_model_id
|
||||||
if model_id:
|
if model_id:
|
||||||
try:
|
try:
|
||||||
@@ -250,6 +438,7 @@ async def reflection_run(
|
|||||||
# 可以设置为None,让反思引擎使用默认模型
|
# 可以设置为None,让反思引擎使用默认模型
|
||||||
model_id = None
|
model_id = None
|
||||||
|
|
||||||
|
# Create reflection configuration with database parameters
|
||||||
config = ReflectionConfig(
|
config = ReflectionConfig(
|
||||||
enabled=result.enable_self_reflexion,
|
enabled=result.enable_self_reflexion,
|
||||||
iteration_period=result.iteration_period,
|
iteration_period=result.iteration_period,
|
||||||
@@ -262,11 +451,13 @@ async def reflection_run(
|
|||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
language_type=language_type
|
language_type=language_type
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Initialize Neo4j connector and reflection engine
|
||||||
connector = Neo4jConnector()
|
connector = Neo4jConnector()
|
||||||
engine = ReflectionEngine(
|
engine = ReflectionEngine(
|
||||||
config=config,
|
config=config,
|
||||||
neo4j_connector=connector,
|
neo4j_connector=connector,
|
||||||
llm_client=model_id # 传入验证后的 model_id
|
llm_client=model_id # Pass validated model_id
|
||||||
)
|
)
|
||||||
|
|
||||||
result=await (engine.reflection_run())
|
result=await (engine.reflection_run())
|
||||||
|
|||||||
@@ -1,3 +1,18 @@
|
|||||||
|
"""
|
||||||
|
Memory Short Term Controller
|
||||||
|
|
||||||
|
This module provides REST API endpoints for managing short-term and long-term memory
|
||||||
|
data retrieval and analysis. It handles memory system statistics, data aggregation,
|
||||||
|
and provides comprehensive memory insights for end users.
|
||||||
|
|
||||||
|
Key Features:
|
||||||
|
- Short-term memory data retrieval and statistics
|
||||||
|
- Long-term memory data aggregation
|
||||||
|
- Entity count integration
|
||||||
|
- Multi-language response support
|
||||||
|
- Memory system analytics and reporting
|
||||||
|
"""
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
@@ -13,9 +28,13 @@ from app.models.user_model import User
|
|||||||
from app.services.memory_short_service import LongService, ShortService
|
from app.services.memory_short_service import LongService, ShortService
|
||||||
from app.services.memory_storage_service import search_entity
|
from app.services.memory_storage_service import search_entity
|
||||||
|
|
||||||
|
# Load environment variables for configuration
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
# Initialize API logger for request tracking and debugging
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
|
|
||||||
|
# Configure router with prefix and tags for API organization
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/memory/short",
|
prefix="/memory/short",
|
||||||
tags=["Memory"],
|
tags=["Memory"],
|
||||||
@@ -27,24 +46,73 @@ async def short_term_configs(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
# 使用集中化的语言校验
|
"""
|
||||||
|
Retrieve comprehensive short-term and long-term memory statistics
|
||||||
|
|
||||||
|
Provides a comprehensive overview of memory system data for a specific end user,
|
||||||
|
including short-term memory entries, long-term memory aggregations, entity counts,
|
||||||
|
and retrieval statistics. Supports multi-language responses based on request headers.
|
||||||
|
|
||||||
|
This endpoint serves as a central dashboard for memory system analytics, combining
|
||||||
|
data from multiple memory subsystems to provide a holistic view of user memory state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: Unique identifier for the end user whose memory data to retrieve
|
||||||
|
language_type: Language preference header for response localization (optional)
|
||||||
|
current_user: Authenticated user making the request (injected by dependency)
|
||||||
|
db: Database session for data operations (injected by dependency)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Success response containing comprehensive memory statistics:
|
||||||
|
- short_term: List of short-term memory entries with detailed data
|
||||||
|
- long_term: List of long-term memory aggregations and summaries
|
||||||
|
- entity: Count of entities associated with the end user
|
||||||
|
- retrieval_number: Total count of short-term memory retrievals
|
||||||
|
- long_term_number: Total count of long-term memory entries
|
||||||
|
|
||||||
|
Response Structure:
|
||||||
|
{
|
||||||
|
"code": 200,
|
||||||
|
"msg": "Short-term memory system data retrieved successfully",
|
||||||
|
"data": {
|
||||||
|
"short_term": [...], # Short-term memory entries
|
||||||
|
"long_term": [...], # Long-term memory data
|
||||||
|
"entity": 42, # Entity count
|
||||||
|
"retrieval_number": 156, # Short-term retrieval count
|
||||||
|
"long_term_number": 23 # Long-term memory count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If end_user_id is invalid or data retrieval fails
|
||||||
|
|
||||||
|
Performance Notes:
|
||||||
|
- Combines multiple service calls for comprehensive data
|
||||||
|
- Entity search is performed asynchronously for better performance
|
||||||
|
- Response time depends on memory data volume for the specified user
|
||||||
|
"""
|
||||||
|
# Use centralized language validation for consistent localization
|
||||||
language = get_language_from_header(language_type)
|
language = get_language_from_header(language_type)
|
||||||
|
|
||||||
# 获取短期记忆数据
|
# Retrieve short-term memory data and statistics
|
||||||
short_term=ShortService(end_user_id, db)
|
short_term = ShortService(end_user_id, db)
|
||||||
short_result=short_term.get_short_databasets()
|
short_result = short_term.get_short_databasets() # Get short-term memory entries
|
||||||
short_count=short_term.get_short_count()
|
short_count = short_term.get_short_count() # Get short-term retrieval count
|
||||||
|
|
||||||
long_term=LongService(end_user_id, db)
|
# Retrieve long-term memory data and aggregations
|
||||||
long_result=long_term.get_long_databasets()
|
long_term = LongService(end_user_id, db)
|
||||||
|
long_result = long_term.get_long_databasets() # Get long-term memory entries
|
||||||
|
|
||||||
|
# Get entity count for the specified end user
|
||||||
entity_result = await search_entity(end_user_id)
|
entity_result = await search_entity(end_user_id)
|
||||||
|
|
||||||
|
# Compile comprehensive memory statistics response
|
||||||
result = {
|
result = {
|
||||||
'short_term': short_result,
|
'short_term': short_result, # Short-term memory entries
|
||||||
'long_term': long_result,
|
'long_term': long_result, # Long-term memory data
|
||||||
'entity': entity_result.get('num', 0),
|
'entity': entity_result.get('num', 0), # Entity count (default to 0 if not found)
|
||||||
"retrieval_number":short_count,
|
"retrieval_number": short_count, # Short-term retrieval statistics
|
||||||
"long_term_number":len(long_result)
|
"long_term_number": len(long_result) # Long-term memory entry count
|
||||||
}
|
}
|
||||||
|
|
||||||
return success(data=result, msg="短期记忆系统数据获取成功")
|
return success(data=result, msg="短期记忆系统数据获取成功")
|
||||||
@@ -8,6 +8,7 @@ from app.core.response_utils import success
|
|||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
from app.dependencies import get_current_user
|
from app.dependencies import get_current_user
|
||||||
from app.models import User
|
from app.models import User
|
||||||
|
from app.schemas import conversation_schema
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
|
|
||||||
@@ -32,35 +33,47 @@ def get_memory_count(
|
|||||||
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
|
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
|
||||||
def get_conversations(
|
def get_conversations(
|
||||||
end_user_id: uuid.UUID,
|
end_user_id: uuid.UUID,
|
||||||
|
page: int = 1,
|
||||||
|
pagesize: int = 20,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Retrieve all conversations for the current user in a specific group.
|
Retrieve conversations for the current user in a specific group with pagination.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id (UUID): The group identifier.
|
end_user_id (UUID): The group identifier.
|
||||||
|
page (int): Page number (1-based). Defaults to 1.
|
||||||
|
pagesize (int): Number of items per page. Defaults to 20.
|
||||||
current_user (User, optional): The authenticated user.
|
current_user (User, optional): The authenticated user.
|
||||||
db (Session, optional): SQLAlchemy session.
|
db (Session, optional): SQLAlchemy session.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ApiResponse: Contains a list of conversation IDs.
|
ApiResponse: Contains a paginated list of conversations.
|
||||||
|
|
||||||
Notes:
|
|
||||||
- Initializes the ConversationService with the current DB session.
|
|
||||||
- Returns only conversation IDs for lightweight response.
|
|
||||||
- Logs can be added to trace requests in production.
|
|
||||||
"""
|
"""
|
||||||
|
page = max(1, page)
|
||||||
|
page_size = max(1, min(pagesize, 100)) # Limit page size between 1 and 100
|
||||||
conversation_service = ConversationService(db)
|
conversation_service = ConversationService(db)
|
||||||
conversations = conversation_service.get_user_conversations(
|
conversations, total = conversation_service.get_user_conversations(
|
||||||
end_user_id
|
end_user_id,
|
||||||
|
page=page,
|
||||||
|
page_size=page_size
|
||||||
)
|
)
|
||||||
return success(data=[
|
return success(data={
|
||||||
{
|
"items": [
|
||||||
"id": conversation.id,
|
{
|
||||||
"title": conversation.title
|
"id": conversation.id,
|
||||||
} for conversation in conversations
|
"title": conversation.title
|
||||||
], msg="get conversations success")
|
} for conversation in conversations
|
||||||
|
],
|
||||||
|
"total": total,
|
||||||
|
"page": {
|
||||||
|
"page": page,
|
||||||
|
"pagesize": page_size,
|
||||||
|
"total": total,
|
||||||
|
"hasnext": (page * page_size) < total
|
||||||
|
},
|
||||||
|
}, msg="get conversations success")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
|
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
|
||||||
@@ -90,11 +103,7 @@ def get_messages(
|
|||||||
conversation_id,
|
conversation_id,
|
||||||
)
|
)
|
||||||
messages = [
|
messages = [
|
||||||
{
|
conversation_schema.Message.model_validate(message)
|
||||||
"role": message.role,
|
|
||||||
"content": message.content,
|
|
||||||
"created_at": int(message.created_at.timestamp() * 1000),
|
|
||||||
}
|
|
||||||
for message in messages_obj
|
for message in messages_obj
|
||||||
]
|
]
|
||||||
return success(data=messages, msg="get conversation history success")
|
return success(data=messages, msg="get conversation history success")
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from app.core.logging_config import get_business_logger
|
|||||||
from app.core.response_utils import success, fail
|
from app.core.response_utils import success, fail
|
||||||
from app.db import get_db, get_db_read
|
from app.db import get_db, get_db_read
|
||||||
from app.dependencies import get_share_user_id, ShareTokenData
|
from app.dependencies import get_share_user_id, ShareTokenData
|
||||||
from app.models.app_model import App
|
|
||||||
from app.models.app_model import AppType
|
from app.models.app_model import AppType
|
||||||
from app.repositories import knowledge_repository
|
from app.repositories import knowledge_repository
|
||||||
from app.repositories.end_user_repository import EndUserRepository
|
from app.repositories.end_user_repository import EndUserRepository
|
||||||
@@ -22,6 +21,7 @@ from app.schemas import release_share_schema, conversation_schema
|
|||||||
from app.schemas.response_schema import PageData, PageMeta
|
from app.schemas.response_schema import PageData, PageMeta
|
||||||
from app.services import workspace_service
|
from app.services import workspace_service
|
||||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||||
|
from app.services.app_service import AppService
|
||||||
from app.services.auth_service import create_access_token
|
from app.services.auth_service import create_access_token
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.release_share_service import ReleaseShareService
|
from app.services.release_share_service import ReleaseShareService
|
||||||
@@ -215,8 +215,11 @@ def list_conversations(
|
|||||||
service = SharedChatService(db)
|
service = SharedChatService(db)
|
||||||
share, release = service.get_release_by_share_token(share_data.share_token, password)
|
share, release = service.get_release_by_share_token(share_data.share_token, password)
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
|
app_service = AppService(db)
|
||||||
|
app = app_service._get_app_or_404(share.app_id)
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
app_id=share.app_id,
|
app_id=share.app_id,
|
||||||
|
workspace_id=app.workspace_id,
|
||||||
other_id=other_id
|
other_id=other_id
|
||||||
)
|
)
|
||||||
logger.debug(new_end_user.id)
|
logger.debug(new_end_user.id)
|
||||||
@@ -308,25 +311,29 @@ async def chat(
|
|||||||
|
|
||||||
# Store end_user_id in database with original user_id
|
# Store end_user_id in database with original user_id
|
||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
|
app_service = AppService(db)
|
||||||
|
app = app_service._get_app_or_404(share.app_id)
|
||||||
|
workspace_id = app.workspace_id
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
app_id=share.app_id,
|
app_id=share.app_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
other_id=other_id,
|
other_id=other_id,
|
||||||
original_user_id=user_id # Save original user_id to other_id
|
original_user_id=user_id
|
||||||
)
|
)
|
||||||
end_user_id = str(new_end_user.id)
|
end_user_id = str(new_end_user.id)
|
||||||
|
|
||||||
appid = share.app_id
|
# appid = share.app_id
|
||||||
"""获取存储类型和工作空间的ID"""
|
"""获取存储类型和工作空间的ID"""
|
||||||
|
|
||||||
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
|
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
|
||||||
app = db.query(App).filter(
|
# app = db.query(App).filter(
|
||||||
App.id == appid,
|
# App.id == appid,
|
||||||
App.is_active.is_(True)
|
# App.is_active.is_(True)
|
||||||
).first()
|
# ).first()
|
||||||
if not app:
|
# if not app:
|
||||||
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
# raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
|
||||||
|
|
||||||
workspace_id = app.workspace_id
|
# workspace_id = app.workspace_id
|
||||||
|
|
||||||
# 直接从 workspace 获取 storage_type(公开分享场景无需权限检查)
|
# 直接从 workspace 获取 storage_type(公开分享场景无需权限检查)
|
||||||
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
||||||
@@ -610,11 +617,11 @@ async def chat(
|
|||||||
|
|
||||||
# 多 Agent 非流式返回
|
# 多 Agent 非流式返回
|
||||||
result = await app_chat_service.workflow_chat(
|
result = await app_chat_service.workflow_chat(
|
||||||
|
|
||||||
message=payload.message,
|
message=payload.message,
|
||||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||||
user_id=end_user_id, # 转换为字符串
|
user_id=end_user_id, # 转换为字符串
|
||||||
variables=payload.variables,
|
variables=payload.variables,
|
||||||
|
files=payload.files,
|
||||||
config=config,
|
config=config,
|
||||||
web_search=payload.web_search,
|
web_search=payload.web_search,
|
||||||
memory=payload.memory,
|
memory=payload.memory,
|
||||||
@@ -654,17 +661,21 @@ async def config_query(
|
|||||||
workflow_service = WorkflowService(db)
|
workflow_service = WorkflowService(db)
|
||||||
content = {
|
content = {
|
||||||
"app_type": release.app.type,
|
"app_type": release.app.type,
|
||||||
"variables": workflow_service.get_start_node_variables(release.config)
|
"variables": workflow_service.get_start_node_variables(release.config),
|
||||||
|
"memory": workflow_service.is_memory_enable(release.config),
|
||||||
|
"features": release.config.get("features")
|
||||||
}
|
}
|
||||||
elif release.app.type == AppType.AGENT:
|
elif release.app.type == AppType.AGENT:
|
||||||
content = {
|
content = {
|
||||||
"app_type": release.app.type,
|
"app_type": release.app.type,
|
||||||
"variables": release.config.get("variables")
|
"variables": release.config.get("variables"),
|
||||||
|
"features": release.config.get("features")
|
||||||
}
|
}
|
||||||
elif release.app.type == AppType.MULTI_AGENT:
|
elif release.app.type == AppType.MULTI_AGENT:
|
||||||
content = {
|
content = {
|
||||||
"app_type": release.app.type,
|
"app_type": release.app.type,
|
||||||
"variables": []
|
"variables": [],
|
||||||
|
"features": release.config.get("features")
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)
|
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||||
|
|||||||
@@ -95,8 +95,8 @@ async def chat(
|
|||||||
end_user_repo = EndUserRepository(db)
|
end_user_repo = EndUserRepository(db)
|
||||||
new_end_user = end_user_repo.get_or_create_end_user(
|
new_end_user = end_user_repo.get_or_create_end_user(
|
||||||
app_id=app.id,
|
app_id=app.id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
other_id=other_id,
|
other_id=other_id,
|
||||||
original_user_id=other_id # Save original user_id to other_id
|
|
||||||
)
|
)
|
||||||
end_user_id = str(new_end_user.id)
|
end_user_id = str(new_end_user.id)
|
||||||
web_search = True
|
web_search = True
|
||||||
@@ -280,6 +280,7 @@ async def chat(
|
|||||||
memory=memory,
|
memory=memory,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
user_rag_memory_id=user_rag_memory_id,
|
user_rag_memory_id=user_rag_memory_id,
|
||||||
|
files=payload.files,
|
||||||
app_id=app.id,
|
app_id=app.id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
release_id=app.current_release.id
|
release_id=app.current_release.id
|
||||||
|
|||||||
@@ -3,8 +3,11 @@ from typing import Optional
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.core.error_codes import BizCode
|
||||||
from app.schemas.tool_schema import (
|
from app.schemas.tool_schema import (
|
||||||
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest, CustomToolTestRequest
|
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest,
|
||||||
|
CustomToolTestRequest, ToolActiveUpdate
|
||||||
)
|
)
|
||||||
|
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
@@ -14,6 +17,7 @@ from app.models import User
|
|||||||
from app.models.tool_model import ToolType, ToolStatus, AuthType
|
from app.models.tool_model import ToolType, ToolStatus, AuthType
|
||||||
from app.services.tool_service import ToolService
|
from app.services.tool_service import ToolService
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
|
from app.core.exceptions import BusinessException
|
||||||
|
|
||||||
router = APIRouter(prefix="/tools", tags=["Tool System"])
|
router = APIRouter(prefix="/tools", tags=["Tool System"])
|
||||||
|
|
||||||
@@ -72,6 +76,8 @@ async def get_tool_methods(
|
|||||||
if methods is None:
|
if methods is None:
|
||||||
raise HTTPException(status_code=404, detail="工具不存在")
|
raise HTTPException(status_code=404, detail="工具不存在")
|
||||||
return success(data=methods, msg="获取工具方法成功")
|
return success(data=methods, msg="获取工具方法成功")
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -97,7 +103,13 @@ async def create_tool(
|
|||||||
):
|
):
|
||||||
"""创建工具"""
|
"""创建工具"""
|
||||||
try:
|
try:
|
||||||
tool_id = service.create_tool(
|
# 将 MCP 来源字段合并进 config
|
||||||
|
if request.tool_type == ToolType.MCP:
|
||||||
|
for key in ("source_channel", "market_id", "market_config_id", "mcp_service_id"):
|
||||||
|
val = getattr(request, key, None)
|
||||||
|
if val is not None:
|
||||||
|
request.config[key] = val
|
||||||
|
tool_id = await service.create_tool(
|
||||||
name=request.name,
|
name=request.name,
|
||||||
tool_type=request.tool_type,
|
tool_type=request.tool_type,
|
||||||
tenant_id=current_user.tenant_id,
|
tenant_id=current_user.tenant_id,
|
||||||
@@ -107,8 +119,12 @@ async def create_tool(
|
|||||||
tags=request.tags
|
tags=request.tags
|
||||||
)
|
)
|
||||||
return success(data={"tool_id": tool_id}, msg="工具创建成功")
|
return success(data={"tool_id": tool_id}, msg="工具创建成功")
|
||||||
|
except BusinessException as e:
|
||||||
|
raise HTTPException(status_code=400, detail=e.message)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -137,6 +153,8 @@ async def update_tool(
|
|||||||
return success(msg="工具更新成功")
|
return success(msg="工具更新成功")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -147,7 +165,7 @@ async def delete_tool(
|
|||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
service: ToolService = Depends(get_tool_service)
|
service: ToolService = Depends(get_tool_service)
|
||||||
):
|
):
|
||||||
"""删除工具"""
|
"""删除工具(逻辑删除,is_active=False)"""
|
||||||
try:
|
try:
|
||||||
success_flag = service.delete_tool(tool_id, current_user.tenant_id)
|
success_flag = service.delete_tool(tool_id, current_user.tenant_id)
|
||||||
if not success_flag:
|
if not success_flag:
|
||||||
@@ -159,6 +177,32 @@ async def delete_tool(
|
|||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.patch("/{tool_id}/active", response_model=ApiResponse)
|
||||||
|
async def set_tool_active(
|
||||||
|
tool_id: str,
|
||||||
|
request: ToolActiveUpdate,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
service: ToolService = Depends(get_tool_service)
|
||||||
|
):
|
||||||
|
"""设置工具可用状态(启用/禁用)
|
||||||
|
|
||||||
|
- is_active=true: 启用工具
|
||||||
|
- is_active=false: 禁用工具(等同于删除,但可恢复)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
success_flag = service.set_tool_active(tool_id, current_user.tenant_id, request.is_active)
|
||||||
|
if not success_flag:
|
||||||
|
raise HTTPException(status_code=404, detail="工具不存在")
|
||||||
|
action = "启用" if request.is_active else "禁用"
|
||||||
|
return success(msg=f"工具已{action}")
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/execution/execute", response_model=ApiResponse)
|
@router.post("/execution/execute", response_model=ApiResponse)
|
||||||
async def execute_tool(
|
async def execute_tool(
|
||||||
request: ToolExecuteRequest,
|
request: ToolExecuteRequest,
|
||||||
@@ -187,6 +231,8 @@ async def execute_tool(
|
|||||||
},
|
},
|
||||||
msg="工具执行完成"
|
msg="工具执行完成"
|
||||||
)
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -216,8 +262,10 @@ async def sync_mcp_tools(
|
|||||||
try:
|
try:
|
||||||
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
|
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
|
||||||
if not result.get("success", False):
|
if not result.get("success", False):
|
||||||
raise HTTPException(status_code=400, detail=result.get("message", "同步失败"))
|
raise BusinessException(result.get("message", "工具列表同步失败"), BizCode.BAD_REQUEST)
|
||||||
return success(data=result, msg="MCP工具列表同步完成")
|
return success(data=result, msg="MCP工具列表同步完成")
|
||||||
|
except BusinessException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@@ -240,8 +288,10 @@ async def test_tool_connection(
|
|||||||
# 普通连接测试
|
# 普通连接测试
|
||||||
result = await service.test_connection(tool_id, current_user.tenant_id)
|
result = await service.test_connection(tool_id, current_user.tenant_id)
|
||||||
if result["success"] is False:
|
if result["success"] is False:
|
||||||
raise HTTPException(status_code=400, detail=result["message"])
|
raise BusinessException(result["message"], BizCode.SERVICE_UNAVAILABLE)
|
||||||
return success(data=result, msg="连接测试完成")
|
return success(data=result, msg="连接测试完成")
|
||||||
|
except BusinessException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
import uuid
|
import uuid
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
@@ -19,6 +20,7 @@ from app.services import user_service
|
|||||||
from app.core.logging_config import get_api_logger
|
from app.core.logging_config import get_api_logger
|
||||||
from app.core.response_utils import success
|
from app.core.response_utils import success
|
||||||
from app.core.security import verify_password
|
from app.core.security import verify_password
|
||||||
|
from app.i18n.dependencies import get_translator
|
||||||
|
|
||||||
# 获取API专用日志器
|
# 获取API专用日志器
|
||||||
api_logger = get_api_logger()
|
api_logger = get_api_logger()
|
||||||
@@ -33,7 +35,8 @@ router = APIRouter(
|
|||||||
def create_superuser(
|
def create_superuser(
|
||||||
user: user_schema.UserCreate,
|
user: user_schema.UserCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_superuser: User = Depends(get_current_superuser)
|
current_superuser: User = Depends(get_current_superuser),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""创建超级管理员(仅超级管理员可访问)"""
|
"""创建超级管理员(仅超级管理员可访问)"""
|
||||||
api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}")
|
api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}")
|
||||||
@@ -42,7 +45,7 @@ def create_superuser(
|
|||||||
api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})")
|
api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})")
|
||||||
|
|
||||||
result_schema = user_schema.User.model_validate(result)
|
result_schema = user_schema.User.model_validate(result)
|
||||||
return success(data=result_schema, msg="超级管理员创建成功")
|
return success(data=result_schema, msg=t("users.create.superuser_success"))
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{user_id}", response_model=ApiResponse)
|
@router.delete("/{user_id}", response_model=ApiResponse)
|
||||||
@@ -50,6 +53,7 @@ def delete_user(
|
|||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""停用用户(软删除)"""
|
"""停用用户(软删除)"""
|
||||||
api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}")
|
api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||||
@@ -57,13 +61,14 @@ def delete_user(
|
|||||||
db=db, user_id_to_deactivate=user_id, current_user=current_user
|
db=db, user_id_to_deactivate=user_id, current_user=current_user
|
||||||
)
|
)
|
||||||
api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})")
|
api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})")
|
||||||
return success(msg="用户停用成功")
|
return success(msg=t("users.delete.deactivate_success"))
|
||||||
|
|
||||||
@router.post("/{user_id}/activate", response_model=ApiResponse)
|
@router.post("/{user_id}/activate", response_model=ApiResponse)
|
||||||
def activate_user(
|
def activate_user(
|
||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""激活用户"""
|
"""激活用户"""
|
||||||
api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}")
|
api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||||
@@ -74,13 +79,14 @@ def activate_user(
|
|||||||
api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})")
|
api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})")
|
||||||
|
|
||||||
result_schema = user_schema.User.model_validate(result)
|
result_schema = user_schema.User.model_validate(result)
|
||||||
return success(data=result_schema, msg="用户激活成功")
|
return success(data=result_schema, msg=t("users.activate.success"))
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=ApiResponse)
|
@router.get("", response_model=ApiResponse)
|
||||||
def get_current_user_info(
|
def get_current_user_info(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""获取当前用户信息"""
|
"""获取当前用户信息"""
|
||||||
api_logger.info(f"当前用户信息请求: {current_user.username}")
|
api_logger.info(f"当前用户信息请求: {current_user.username}")
|
||||||
@@ -105,7 +111,7 @@ def get_current_user_info(
|
|||||||
break
|
break
|
||||||
|
|
||||||
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
|
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
|
||||||
return success(data=result_schema, msg="用户信息获取成功")
|
return success(data=result_schema, msg=t("users.info.get_success"))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/superusers", response_model=ApiResponse)
|
@router.get("/superusers", response_model=ApiResponse)
|
||||||
@@ -113,6 +119,7 @@ def get_tenant_superusers(
|
|||||||
include_inactive: bool = False,
|
include_inactive: bool = False,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_superuser),
|
current_user: User = Depends(get_current_superuser),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""获取当前租户下的超管账号列表(仅超级管理员可访问)"""
|
"""获取当前租户下的超管账号列表(仅超级管理员可访问)"""
|
||||||
api_logger.info(f"获取租户超管列表请求: {current_user.username}")
|
api_logger.info(f"获取租户超管列表请求: {current_user.username}")
|
||||||
@@ -125,7 +132,7 @@ def get_tenant_superusers(
|
|||||||
api_logger.info(f"租户超管列表获取成功: count={len(superusers)}")
|
api_logger.info(f"租户超管列表获取成功: count={len(superusers)}")
|
||||||
|
|
||||||
superusers_schema = [user_schema.User.model_validate(u) for u in superusers]
|
superusers_schema = [user_schema.User.model_validate(u) for u in superusers]
|
||||||
return success(data=superusers_schema, msg="租户超管列表获取成功")
|
return success(data=superusers_schema, msg=t("users.list.superusers_success"))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -134,6 +141,7 @@ def get_user_info_by_id(
|
|||||||
user_id: uuid.UUID,
|
user_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""根据用户ID获取用户信息"""
|
"""根据用户ID获取用户信息"""
|
||||||
api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}")
|
api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}")
|
||||||
@@ -144,7 +152,7 @@ def get_user_info_by_id(
|
|||||||
api_logger.info(f"用户信息获取成功: {result.username}")
|
api_logger.info(f"用户信息获取成功: {result.username}")
|
||||||
|
|
||||||
result_schema = user_schema.User.model_validate(result)
|
result_schema = user_schema.User.model_validate(result)
|
||||||
return success(data=result_schema, msg="用户信息获取成功")
|
return success(data=result_schema, msg=t("users.info.get_success"))
|
||||||
|
|
||||||
|
|
||||||
@router.put("/change-password", response_model=ApiResponse)
|
@router.put("/change-password", response_model=ApiResponse)
|
||||||
@@ -152,6 +160,7 @@ async def change_password(
|
|||||||
request: ChangePasswordRequest,
|
request: ChangePasswordRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""修改当前用户密码"""
|
"""修改当前用户密码"""
|
||||||
api_logger.info(f"用户密码修改请求: {current_user.username}")
|
api_logger.info(f"用户密码修改请求: {current_user.username}")
|
||||||
@@ -164,7 +173,7 @@ async def change_password(
|
|||||||
current_user=current_user
|
current_user=current_user
|
||||||
)
|
)
|
||||||
api_logger.info(f"用户密码修改成功: {current_user.username}")
|
api_logger.info(f"用户密码修改成功: {current_user.username}")
|
||||||
return success(msg="密码修改成功")
|
return success(msg=t("auth.password.change_success"))
|
||||||
|
|
||||||
|
|
||||||
@router.put("/admin/change-password", response_model=ApiResponse)
|
@router.put("/admin/change-password", response_model=ApiResponse)
|
||||||
@@ -172,6 +181,7 @@ async def admin_change_password(
|
|||||||
request: AdminChangePasswordRequest,
|
request: AdminChangePasswordRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_superuser),
|
current_user: User = Depends(get_current_superuser),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""超级管理员修改指定用户的密码"""
|
"""超级管理员修改指定用户的密码"""
|
||||||
api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}")
|
api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}")
|
||||||
@@ -186,16 +196,17 @@ async def admin_change_password(
|
|||||||
# 根据是否生成了随机密码来构造响应
|
# 根据是否生成了随机密码来构造响应
|
||||||
if request.new_password:
|
if request.new_password:
|
||||||
api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}")
|
api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}")
|
||||||
return success(msg="密码修改成功")
|
return success(msg=t("auth.password.change_success"))
|
||||||
else:
|
else:
|
||||||
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
|
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
|
||||||
return success(data=generated_password, msg="密码重置成功")
|
return success(data=generated_password, msg=t("auth.password.reset_success"))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/verify_pwd", response_model=ApiResponse)
|
@router.post("/verify_pwd", response_model=ApiResponse)
|
||||||
def verify_pwd(
|
def verify_pwd(
|
||||||
request: VerifyPasswordRequest,
|
request: VerifyPasswordRequest,
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""验证当前用户密码"""
|
"""验证当前用户密码"""
|
||||||
api_logger.info(f"用户验证密码请求: {current_user.username}")
|
api_logger.info(f"用户验证密码请求: {current_user.username}")
|
||||||
@@ -203,8 +214,8 @@ def verify_pwd(
|
|||||||
is_valid = verify_password(request.password, current_user.hashed_password)
|
is_valid = verify_password(request.password, current_user.hashed_password)
|
||||||
api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}")
|
api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}")
|
||||||
if not is_valid:
|
if not is_valid:
|
||||||
raise BusinessException("密码验证失败", code=BizCode.VALIDATION_FAILED)
|
raise BusinessException(t("users.errors.password_verification_failed"), code=BizCode.VALIDATION_FAILED)
|
||||||
return success(data={"valid": is_valid}, msg="验证完成")
|
return success(data={"valid": is_valid}, msg=t("common.success.retrieved"))
|
||||||
|
|
||||||
|
|
||||||
@router.post("/send-email-code", response_model=ApiResponse)
|
@router.post("/send-email-code", response_model=ApiResponse)
|
||||||
@@ -212,6 +223,7 @@ async def send_email_code(
|
|||||||
request: SendEmailCodeRequest,
|
request: SendEmailCodeRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""发送邮箱验证码"""
|
"""发送邮箱验证码"""
|
||||||
api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}")
|
api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}")
|
||||||
@@ -219,7 +231,7 @@ async def send_email_code(
|
|||||||
await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id)
|
await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id)
|
||||||
|
|
||||||
api_logger.info(f"邮箱验证码已发送: {current_user.username}")
|
api_logger.info(f"邮箱验证码已发送: {current_user.username}")
|
||||||
return success(msg="验证码已发送到您的邮箱,请查收")
|
return success(msg=t("users.email.code_sent"))
|
||||||
|
|
||||||
|
|
||||||
@router.put("/change-email", response_model=ApiResponse)
|
@router.put("/change-email", response_model=ApiResponse)
|
||||||
@@ -227,6 +239,7 @@ async def change_email(
|
|||||||
request: VerifyEmailCodeRequest,
|
request: VerifyEmailCodeRequest,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""验证验证码并修改邮箱"""
|
"""验证验证码并修改邮箱"""
|
||||||
api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}")
|
api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}")
|
||||||
@@ -239,4 +252,51 @@ async def change_email(
|
|||||||
)
|
)
|
||||||
|
|
||||||
api_logger.info(f"用户邮箱修改成功: {current_user.username}")
|
api_logger.info(f"用户邮箱修改成功: {current_user.username}")
|
||||||
return success(msg="邮箱修改成功")
|
return success(msg=t("users.email.change_success"))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me/language", response_model=ApiResponse)
|
||||||
|
def get_current_user_language(
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
|
):
|
||||||
|
"""获取当前用户的语言偏好"""
|
||||||
|
api_logger.info(f"获取用户语言偏好: {current_user.username}")
|
||||||
|
|
||||||
|
language = user_service.get_user_language_preference(
|
||||||
|
db=db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
current_user=current_user
|
||||||
|
)
|
||||||
|
|
||||||
|
api_logger.info(f"用户语言偏好获取成功: {current_user.username}, language={language}")
|
||||||
|
return success(
|
||||||
|
data=user_schema.LanguagePreferenceResponse(language=language),
|
||||||
|
msg=t("users.language.get_success")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/me/language", response_model=ApiResponse)
|
||||||
|
def update_current_user_language(
|
||||||
|
request: user_schema.LanguagePreferenceRequest,
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
|
):
|
||||||
|
"""设置当前用户的语言偏好"""
|
||||||
|
api_logger.info(f"更新用户语言偏好: {current_user.username}, language={request.language}")
|
||||||
|
|
||||||
|
updated_user = user_service.update_user_language_preference(
|
||||||
|
db=db,
|
||||||
|
user_id=current_user.id,
|
||||||
|
language=request.language,
|
||||||
|
current_user=current_user
|
||||||
|
)
|
||||||
|
|
||||||
|
api_logger.info(f"用户语言偏好更新成功: {current_user.username}, language={request.language}")
|
||||||
|
return success(
|
||||||
|
data=user_schema.LanguagePreferenceResponse(language=updated_user.preferred_language),
|
||||||
|
msg=t("users.language.update_success")
|
||||||
|
)
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ from app.services.user_memory_service import (
|
|||||||
UserMemoryService,
|
UserMemoryService,
|
||||||
analytics_memory_types,
|
analytics_memory_types,
|
||||||
analytics_graph_data,
|
analytics_graph_data,
|
||||||
|
analytics_community_graph_data,
|
||||||
)
|
)
|
||||||
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
||||||
from app.schemas.response_schema import ApiResponse
|
from app.schemas.response_schema import ApiResponse
|
||||||
@@ -295,6 +296,42 @@ async def get_graph_data_api(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/analytics/community_graph", response_model=ApiResponse)
|
||||||
|
async def get_community_graph_data_api(
|
||||||
|
end_user_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db),
|
||||||
|
) -> dict:
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
|
||||||
|
if workspace_id is None:
|
||||||
|
api_logger.warning(f"用户 {current_user.username} 尝试查询社区图谱但未选择工作空间")
|
||||||
|
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"社区图谱查询请求: end_user_id={end_user_id}, user={current_user.username}, "
|
||||||
|
f"workspace={workspace_id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await analytics_community_graph_data(db=db, end_user_id=end_user_id)
|
||||||
|
|
||||||
|
if "message" in result and result["statistics"]["total_nodes"] == 0:
|
||||||
|
api_logger.warning(f"社区图谱查询返回空结果: {result.get('message')}")
|
||||||
|
return success(data=result, msg=result.get("message", "查询成功"))
|
||||||
|
|
||||||
|
api_logger.info(
|
||||||
|
f"成功获取社区图谱: end_user_id={end_user_id}, "
|
||||||
|
f"nodes={result['statistics']['total_nodes']}, "
|
||||||
|
f"edges={result['statistics']['total_edges']}"
|
||||||
|
)
|
||||||
|
return success(data=result, msg="查询成功")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}")
|
||||||
|
return fail(BizCode.INTERNAL_ERROR, "社区图谱查询失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/read_end_user/profile", response_model=ApiResponse)
|
@router.get("/read_end_user/profile", response_model=ApiResponse)
|
||||||
async def get_end_user_profile(
|
async def get_end_user_profile(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
|
|||||||
@@ -14,6 +14,12 @@ from app.dependencies import (
|
|||||||
get_current_user,
|
get_current_user,
|
||||||
workspace_access_guard,
|
workspace_access_guard,
|
||||||
)
|
)
|
||||||
|
from app.i18n.dependencies import get_current_language, get_translator
|
||||||
|
from app.i18n.serializers import (
|
||||||
|
WorkspaceSerializer,
|
||||||
|
WorkspaceMemberSerializer,
|
||||||
|
WorkspaceInviteSerializer
|
||||||
|
)
|
||||||
from app.models.tenant_model import Tenants
|
from app.models.tenant_model import Tenants
|
||||||
from app.models.user_model import User
|
from app.models.user_model import User
|
||||||
from app.models.workspace_model import InviteStatus
|
from app.models.workspace_model import InviteStatus
|
||||||
@@ -65,7 +71,9 @@ def get_workspaces(
|
|||||||
include_current: bool = Query(True, description="是否包含当前工作空间"),
|
include_current: bool = Query(True, description="是否包含当前工作空间"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
current_tenant: Tenants = Depends(get_current_tenant)
|
current_tenant: Tenants = Depends(get_current_tenant),
|
||||||
|
language: str = Depends(get_current_language),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""获取当前租户下用户参与的所有工作空间
|
"""获取当前租户下用户参与的所有工作空间
|
||||||
|
|
||||||
@@ -88,8 +96,13 @@ def get_workspaces(
|
|||||||
)
|
)
|
||||||
|
|
||||||
api_logger.info(f"成功获取 {len(workspaces)} 个工作空间")
|
api_logger.info(f"成功获取 {len(workspaces)} 个工作空间")
|
||||||
workspaces_schema = [WorkspaceResponse.model_validate(w) for w in workspaces]
|
|
||||||
return success(data=workspaces_schema, msg="工作空间列表获取成功")
|
# 使用序列化器添加国际化字段
|
||||||
|
serializer = WorkspaceSerializer()
|
||||||
|
workspaces_data = [WorkspaceResponse.model_validate(w).model_dump() for w in workspaces]
|
||||||
|
workspaces_i18n = serializer.serialize_list(workspaces_data, language)
|
||||||
|
|
||||||
|
return success(data=workspaces_i18n, msg=t("workspace.list_retrieved"))
|
||||||
|
|
||||||
|
|
||||||
@router.post("", response_model=ApiResponse)
|
@router.post("", response_model=ApiResponse)
|
||||||
@@ -98,6 +111,8 @@ def create_workspace(
|
|||||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_superuser),
|
current_user: User = Depends(get_current_superuser),
|
||||||
|
language: str = Depends(get_current_language),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""创建新的工作空间"""
|
"""创建新的工作空间"""
|
||||||
from app.core.language_utils import get_language_from_header
|
from app.core.language_utils import get_language_from_header
|
||||||
@@ -118,8 +133,13 @@ def create_workspace(
|
|||||||
f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, "
|
f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, "
|
||||||
f"创建者: {current_user.username}, language={language}"
|
f"创建者: {current_user.username}, language={language}"
|
||||||
)
|
)
|
||||||
result_schema = WorkspaceResponse.model_validate(result)
|
|
||||||
return success(data=result_schema, msg="工作空间创建成功")
|
# 使用序列化器添加国际化字段
|
||||||
|
serializer = WorkspaceSerializer()
|
||||||
|
result_data = WorkspaceResponse.model_validate(result).model_dump()
|
||||||
|
result_i18n = serializer.serialize(result_data, language)
|
||||||
|
|
||||||
|
return success(data=result_i18n, msg=t("workspace.created"))
|
||||||
|
|
||||||
@router.put("", response_model=ApiResponse)
|
@router.put("", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
@@ -127,6 +147,8 @@ def update_workspace(
|
|||||||
workspace: WorkspaceUpdate,
|
workspace: WorkspaceUpdate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
language: str = Depends(get_current_language),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""更新工作空间"""
|
"""更新工作空间"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -139,14 +161,21 @@ def update_workspace(
|
|||||||
user=current_user,
|
user=current_user,
|
||||||
)
|
)
|
||||||
api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}")
|
api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}")
|
||||||
result_schema = WorkspaceResponse.model_validate(result)
|
|
||||||
return success(data=result_schema, msg="工作空间更新成功")
|
# 使用序列化器添加国际化字段
|
||||||
|
serializer = WorkspaceSerializer()
|
||||||
|
result_data = WorkspaceResponse.model_validate(result).model_dump()
|
||||||
|
result_i18n = serializer.serialize(result_data, language)
|
||||||
|
|
||||||
|
return success(data=result_i18n, msg=t("workspace.updated"))
|
||||||
|
|
||||||
@router.get("/members", response_model=ApiResponse)
|
@router.get("/members", response_model=ApiResponse)
|
||||||
@cur_workspace_access_guard()
|
@cur_workspace_access_guard()
|
||||||
def get_cur_workspace_members(
|
def get_cur_workspace_members(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
language: str = Depends(get_current_language),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""获取工作空间成员列表(关系序列化)"""
|
"""获取工作空间成员列表(关系序列化)"""
|
||||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表")
|
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表")
|
||||||
@@ -157,8 +186,14 @@ def get_cur_workspace_members(
|
|||||||
user=current_user,
|
user=current_user,
|
||||||
)
|
)
|
||||||
api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}")
|
api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}")
|
||||||
|
|
||||||
|
# 转换为表格项并使用序列化器添加国际化字段
|
||||||
table_items = _convert_members_to_table_items(members)
|
table_items = _convert_members_to_table_items(members)
|
||||||
return success(data=table_items, msg="工作空间成员列表获取成功")
|
serializer = WorkspaceMemberSerializer()
|
||||||
|
members_data = [item.model_dump() for item in table_items]
|
||||||
|
members_i18n = serializer.serialize_list(members_data, language)
|
||||||
|
|
||||||
|
return success(data=members_i18n, msg=t("workspace.members.list_retrieved"))
|
||||||
|
|
||||||
|
|
||||||
@router.put("/members", response_model=ApiResponse)
|
@router.put("/members", response_model=ApiResponse)
|
||||||
@@ -168,6 +203,7 @@ def update_workspace_members(
|
|||||||
updates: List[WorkspaceMemberUpdate],
|
updates: List[WorkspaceMemberUpdate],
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色")
|
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色")
|
||||||
@@ -178,7 +214,7 @@ def update_workspace_members(
|
|||||||
user=current_user,
|
user=current_user,
|
||||||
)
|
)
|
||||||
api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}")
|
api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}")
|
||||||
return success(msg="成员角色更新成功")
|
return success(msg=t("workspace.members.role_updated"))
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/members/{member_id}", response_model=ApiResponse)
|
@router.delete("/members/{member_id}", response_model=ApiResponse)
|
||||||
@@ -187,6 +223,7 @@ def delete_workspace_member(
|
|||||||
member_id: uuid.UUID,
|
member_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
|
||||||
@@ -198,7 +235,7 @@ def delete_workspace_member(
|
|||||||
user=current_user,
|
user=current_user,
|
||||||
)
|
)
|
||||||
api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}")
|
api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}")
|
||||||
return success(msg="成员删除成功")
|
return success(msg=t("workspace.members.deleted"))
|
||||||
|
|
||||||
|
|
||||||
# 创建空间协作邀请
|
# 创建空间协作邀请
|
||||||
@@ -208,6 +245,8 @@ def create_workspace_invite(
|
|||||||
invite_data: WorkspaceInviteCreate,
|
invite_data: WorkspaceInviteCreate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
language: str = Depends(get_current_language),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""创建工作空间邀请"""
|
"""创建工作空间邀请"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -220,7 +259,12 @@ def create_workspace_invite(
|
|||||||
user=current_user
|
user=current_user
|
||||||
)
|
)
|
||||||
api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}")
|
api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}")
|
||||||
return success(data=result, msg="邀请创建成功")
|
|
||||||
|
# 使用序列化器添加国际化字段
|
||||||
|
serializer = WorkspaceInviteSerializer()
|
||||||
|
result_i18n = serializer.serialize(result, language)
|
||||||
|
|
||||||
|
return success(data=result_i18n, msg=t("workspace.invites.created"))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/invites", response_model=ApiResponse)
|
@router.get("/invites", response_model=ApiResponse)
|
||||||
@@ -232,6 +276,8 @@ def get_workspace_invites(
|
|||||||
offset: int = Query(0, ge=0),
|
offset: int = Query(0, ge=0),
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
language: str = Depends(get_current_language),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""获取工作空间邀请列表"""
|
"""获取工作空间邀请列表"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -246,18 +292,30 @@ def get_workspace_invites(
|
|||||||
offset=offset
|
offset=offset
|
||||||
)
|
)
|
||||||
api_logger.info(f"成功获取 {len(invites)} 个邀请记录")
|
api_logger.info(f"成功获取 {len(invites)} 个邀请记录")
|
||||||
return success(data=invites, msg="邀请列表获取成功")
|
|
||||||
|
# 使用序列化器添加国际化字段
|
||||||
|
serializer = WorkspaceInviteSerializer()
|
||||||
|
invites_i18n = serializer.serialize_list(invites, language)
|
||||||
|
|
||||||
|
return success(data=invites_i18n, msg=t("workspace.invites.list_retrieved"))
|
||||||
|
|
||||||
|
|
||||||
@public_router.get("/invites/validate/{token}", response_model=ApiResponse)
|
@public_router.get("/invites/validate/{token}", response_model=ApiResponse)
|
||||||
def get_workspace_invite_info(
|
def get_workspace_invite_info(
|
||||||
token: str,
|
token: str,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
|
language: str = Depends(get_current_language),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""获取工作空间邀请用户信息(无需认证)"""
|
"""获取工作空间邀请用户信息(无需认证)"""
|
||||||
result = workspace_service.validate_invite_token(db=db, token=token)
|
result = workspace_service.validate_invite_token(db=db, token=token)
|
||||||
api_logger.info(f"工作空间邀请验证成功 - 邀请: {token}")
|
api_logger.info(f"工作空间邀请验证成功 - 邀请: {token}")
|
||||||
return success(data=result, msg="邀请验证成功")
|
|
||||||
|
# 使用序列化器添加国际化字段
|
||||||
|
serializer = WorkspaceInviteSerializer()
|
||||||
|
result_i18n = serializer.serialize(result, language)
|
||||||
|
|
||||||
|
return success(data=result_i18n, msg=t("workspace.invites.validated"))
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/invites/{invite_id}", response_model=ApiResponse)
|
@router.delete("/invites/{invite_id}", response_model=ApiResponse)
|
||||||
@@ -267,6 +325,8 @@ def revoke_workspace_invite(
|
|||||||
invite_id: uuid.UUID,
|
invite_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
language: str = Depends(get_current_language),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""撤销工作空间邀请"""
|
"""撤销工作空间邀请"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -279,7 +339,12 @@ def revoke_workspace_invite(
|
|||||||
user=current_user
|
user=current_user
|
||||||
)
|
)
|
||||||
api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}")
|
api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}")
|
||||||
return success(data=result, msg="邀请撤销成功")
|
|
||||||
|
# 使用序列化器添加国际化字段
|
||||||
|
serializer = WorkspaceInviteSerializer()
|
||||||
|
result_i18n = serializer.serialize(result, language)
|
||||||
|
|
||||||
|
return success(data=result_i18n, msg=t("workspace.invites.revoked"))
|
||||||
|
|
||||||
# ==================== 公开邀请接口(无需认证) ====================
|
# ==================== 公开邀请接口(无需认证) ====================
|
||||||
|
|
||||||
@@ -302,6 +367,7 @@ def switch_workspace(
|
|||||||
workspace_id: uuid.UUID,
|
workspace_id: uuid.UUID,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""切换工作空间"""
|
"""切换工作空间"""
|
||||||
api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}")
|
api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}")
|
||||||
@@ -312,7 +378,7 @@ def switch_workspace(
|
|||||||
user=current_user,
|
user=current_user,
|
||||||
)
|
)
|
||||||
api_logger.info(f"成功切换工作空间为 {workspace_id}")
|
api_logger.info(f"成功切换工作空间为 {workspace_id}")
|
||||||
return success(msg="工作空间切换成功")
|
return success(msg=t("workspace.switched"))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/storage", response_model=ApiResponse)
|
@router.get("/storage", response_model=ApiResponse)
|
||||||
@@ -320,6 +386,7 @@ def switch_workspace(
|
|||||||
def get_workspace_storage_type(
|
def get_workspace_storage_type(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""获取当前工作空间的存储类型"""
|
"""获取当前工作空间的存储类型"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -331,7 +398,7 @@ def get_workspace_storage_type(
|
|||||||
user=current_user
|
user=current_user
|
||||||
)
|
)
|
||||||
api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}")
|
api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}")
|
||||||
return success(data={"storage_type": storage_type}, msg="存储类型获取成功")
|
return success(data={"storage_type": storage_type}, msg=t("workspace.storage.type_retrieved"))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/workspace_models", response_model=ApiResponse)
|
@router.get("/workspace_models", response_model=ApiResponse)
|
||||||
@@ -339,6 +406,8 @@ def get_workspace_storage_type(
|
|||||||
def workspace_models_configs(
|
def workspace_models_configs(
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
language: str = Depends(get_current_language),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""获取当前工作空间的模型配置(llm, embedding, rerank)"""
|
"""获取当前工作空间的模型配置(llm, embedding, rerank)"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -354,14 +423,14 @@ def workspace_models_configs(
|
|||||||
api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问")
|
api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="工作空间不存在或无权访问"
|
detail=t("workspace.not_found")
|
||||||
)
|
)
|
||||||
|
|
||||||
api_logger.info(
|
api_logger.info(
|
||||||
f"成功获取工作空间 {workspace_id} 的模型配置: "
|
f"成功获取工作空间 {workspace_id} 的模型配置: "
|
||||||
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
|
||||||
)
|
)
|
||||||
return success(data=WorkspaceModelsConfig.model_validate(configs), msg="模型配置获取成功")
|
return success(data=WorkspaceModelsConfig.model_validate(configs), msg=t("workspace.models.config_retrieved"))
|
||||||
|
|
||||||
|
|
||||||
@router.put("/workspace_models", response_model=ApiResponse)
|
@router.put("/workspace_models", response_model=ApiResponse)
|
||||||
@@ -370,6 +439,7 @@ def update_workspace_models_configs(
|
|||||||
models_update: WorkspaceModelsUpdate,
|
models_update: WorkspaceModelsUpdate,
|
||||||
db: Session = Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
|
t: callable = Depends(get_translator)
|
||||||
):
|
):
|
||||||
"""更新当前工作空间的模型配置(llm, embedding, rerank)"""
|
"""更新当前工作空间的模型配置(llm, embedding, rerank)"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
@@ -386,5 +456,5 @@ def update_workspace_models_configs(
|
|||||||
f"成功更新工作空间 {workspace_id} 的模型配置: "
|
f"成功更新工作空间 {workspace_id} 的模型配置: "
|
||||||
f"llm={updated_workspace.llm}, embedding={updated_workspace.embedding}, rerank={updated_workspace.rerank}"
|
f"llm={updated_workspace.llm}, embedding={updated_workspace.embedding}, rerank={updated_workspace.rerank}"
|
||||||
)
|
)
|
||||||
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg="模型配置更新成功")
|
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg=t("workspace.models.config_updated"))
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import json
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Annotated, Any, Dict, Optional
|
from typing import Annotated, Optional
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from pydantic import Field, TypeAdapter
|
from pydantic import Field, TypeAdapter
|
||||||
@@ -98,6 +97,7 @@ class Settings:
|
|||||||
|
|
||||||
# File Upload
|
# File Upload
|
||||||
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
||||||
|
MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
|
||||||
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
||||||
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
||||||
|
|
||||||
@@ -115,6 +115,7 @@ class Settings:
|
|||||||
S3_ACCESS_KEY_ID: str = os.getenv("S3_ACCESS_KEY_ID", "")
|
S3_ACCESS_KEY_ID: str = os.getenv("S3_ACCESS_KEY_ID", "")
|
||||||
S3_SECRET_ACCESS_KEY: str = os.getenv("S3_SECRET_ACCESS_KEY", "")
|
S3_SECRET_ACCESS_KEY: str = os.getenv("S3_SECRET_ACCESS_KEY", "")
|
||||||
S3_BUCKET_NAME: str = os.getenv("S3_BUCKET_NAME", "")
|
S3_BUCKET_NAME: str = os.getenv("S3_BUCKET_NAME", "")
|
||||||
|
S3_ENDPOINT_URL: str = os.getenv("S3_ENDPOINT_URL", "")
|
||||||
|
|
||||||
# VOLC ASR settings
|
# VOLC ASR settings
|
||||||
VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "")
|
VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "")
|
||||||
@@ -162,6 +163,44 @@ class Settings:
|
|||||||
# This controls the language used for memory summary titles and other generated content
|
# This controls the language used for memory summary titles and other generated content
|
||||||
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
|
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
|
||||||
|
|
||||||
|
# ========================================================================
|
||||||
|
# Internationalization (i18n) Configuration
|
||||||
|
# ========================================================================
|
||||||
|
# Default language for API responses
|
||||||
|
I18N_DEFAULT_LANGUAGE: str = os.getenv("I18N_DEFAULT_LANGUAGE", "zh")
|
||||||
|
|
||||||
|
# Supported languages (comma-separated)
|
||||||
|
I18N_SUPPORTED_LANGUAGES: list[str] = [
|
||||||
|
lang.strip()
|
||||||
|
for lang in os.getenv("I18N_SUPPORTED_LANGUAGES", "zh,en").split(",")
|
||||||
|
if lang.strip()
|
||||||
|
]
|
||||||
|
|
||||||
|
# Core locales directory (community edition)
|
||||||
|
# Use absolute path to work from any working directory
|
||||||
|
I18N_CORE_LOCALES_DIR: str = os.getenv(
|
||||||
|
"I18N_CORE_LOCALES_DIR",
|
||||||
|
os.path.join(os.path.dirname(os.path.dirname(__file__)), "locales")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Premium locales directory (enterprise edition, optional)
|
||||||
|
I18N_PREMIUM_LOCALES_DIR: Optional[str] = os.getenv("I18N_PREMIUM_LOCALES_DIR", None)
|
||||||
|
|
||||||
|
# Enable translation cache
|
||||||
|
I18N_ENABLE_TRANSLATION_CACHE: bool = os.getenv("I18N_ENABLE_TRANSLATION_CACHE", "true").lower() == "true"
|
||||||
|
|
||||||
|
# LRU cache size for hot translations
|
||||||
|
I18N_LRU_CACHE_SIZE: int = int(os.getenv("I18N_LRU_CACHE_SIZE", "1000"))
|
||||||
|
|
||||||
|
# Enable hot reload of translation files
|
||||||
|
I18N_ENABLE_HOT_RELOAD: bool = os.getenv("I18N_ENABLE_HOT_RELOAD", "false").lower() == "true"
|
||||||
|
|
||||||
|
# Fallback language when translation is missing
|
||||||
|
I18N_FALLBACK_LANGUAGE: str = os.getenv("I18N_FALLBACK_LANGUAGE", "zh")
|
||||||
|
|
||||||
|
# Log missing translations
|
||||||
|
I18N_LOG_MISSING_TRANSLATIONS: bool = os.getenv("I18N_LOG_MISSING_TRANSLATIONS", "true").lower() == "true"
|
||||||
|
|
||||||
# Logging settings
|
# Logging settings
|
||||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||||
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||||
|
|||||||
@@ -1,16 +1,45 @@
|
|||||||
from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
|
from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
|
||||||
|
from app.schemas.memory_agent_schema import AgentMemoryDataset
|
||||||
|
|
||||||
|
|
||||||
def content_input_node(state: ReadState) -> ReadState:
|
def content_input_node(state: ReadState) -> ReadState:
|
||||||
"""开始节点 - 提取内容并保持状态信息"""
|
"""
|
||||||
|
Start node - Extract content and maintain state information
|
||||||
|
|
||||||
|
Extracts the content from the first message in the state and returns it
|
||||||
|
as the data field while preserving all other state information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing messages and other state data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Updated state with extracted content in data field
|
||||||
|
"""
|
||||||
|
|
||||||
content = state['messages'][0].content if state.get('messages') else ''
|
content = state['messages'][0].content if state.get('messages') else ''
|
||||||
# 返回内容并保持所有状态信息
|
# Return content and maintain all state information
|
||||||
|
for pronoun in AgentMemoryDataset.PRONOUN:
|
||||||
|
content = content.replace(pronoun, AgentMemoryDataset.NAME)
|
||||||
|
|
||||||
return {"data": content}
|
return {"data": content}
|
||||||
|
|
||||||
|
|
||||||
def content_input_write(state: WriteState) -> WriteState:
|
def content_input_write(state: WriteState) -> WriteState:
|
||||||
"""开始节点 - 提取内容并保持状态信息"""
|
"""
|
||||||
|
Start node - Extract content and maintain state information for write operations
|
||||||
|
|
||||||
|
Extracts the content from the first message in the state for write operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: WriteState containing messages and other state data
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
WriteState: Updated state with extracted content in data field
|
||||||
|
"""
|
||||||
|
|
||||||
content = state['messages'][0].content if state.get('messages') else ''
|
content = state['messages'][0].content if state.get('messages') else ''
|
||||||
# 返回内容并保持所有状态信息
|
# Return content and maintain all state information
|
||||||
|
for pronoun in AgentMemoryDataset.PRONOUN:
|
||||||
|
content = content.replace(pronoun, AgentMemoryDataset.NAME)
|
||||||
|
|
||||||
return {"data": content}
|
return {"data": content}
|
||||||
@@ -19,19 +19,39 @@ logger = get_agent_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class ProblemNodeService(LLMServiceMixin):
|
class ProblemNodeService(LLMServiceMixin):
|
||||||
"""问题处理节点服务类"""
|
"""
|
||||||
|
Problem processing node service class
|
||||||
|
|
||||||
|
Handles problem decomposition and extension operations using LLM services.
|
||||||
|
Inherits from LLMServiceMixin to provide structured LLM calling capabilities.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
template_service: Service for rendering Jinja2 templates
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.template_service = TemplateService(template_root)
|
self.template_service = TemplateService(template_root)
|
||||||
|
|
||||||
|
|
||||||
# 创建全局服务实例
|
# Create global service instance
|
||||||
problem_service = ProblemNodeService()
|
problem_service = ProblemNodeService()
|
||||||
|
|
||||||
|
|
||||||
async def Split_The_Problem(state: ReadState) -> ReadState:
|
async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||||
"""问题分解节点"""
|
"""
|
||||||
|
Problem decomposition node
|
||||||
|
|
||||||
|
Breaks down complex user queries into smaller, more manageable sub-problems.
|
||||||
|
Uses LLM to analyze the input and generate structured problem decomposition
|
||||||
|
with question types and reasoning.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing user input and configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Updated state with problem decomposition results
|
||||||
|
"""
|
||||||
# 从状态中获取数据
|
# 从状态中获取数据
|
||||||
content = state.get('data', '')
|
content = state.get('data', '')
|
||||||
end_user_id = state.get('end_user_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
@@ -64,7 +84,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
# 添加更详细的日志记录
|
# 添加更详细的日志记录
|
||||||
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
||||||
|
|
||||||
# 验证结构化响应
|
# Validate structured response
|
||||||
if not structured or not hasattr(structured, 'root'):
|
if not structured or not hasattr(structured, 'root'):
|
||||||
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
|
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
|
||||||
split_result = json.dumps([], ensure_ascii=False)
|
split_result = json.dumps([], ensure_ascii=False)
|
||||||
@@ -106,7 +126,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
exc_info=True
|
exc_info=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# 提供更详细的错误信息
|
# Provide more detailed error information
|
||||||
error_details = {
|
error_details = {
|
||||||
"error_type": type(e).__name__,
|
"error_type": type(e).__name__,
|
||||||
"error_message": str(e),
|
"error_message": str(e),
|
||||||
@@ -116,7 +136,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
logger.error(f"Split_The_Problem error details: {error_details}")
|
logger.error(f"Split_The_Problem error details: {error_details}")
|
||||||
|
|
||||||
# 创建默认的空结果
|
# Create default empty result
|
||||||
result = {
|
result = {
|
||||||
"context": json.dumps([], ensure_ascii=False),
|
"context": json.dumps([], ensure_ascii=False),
|
||||||
"original": content,
|
"original": content,
|
||||||
@@ -130,13 +150,25 @@ async def Split_The_Problem(state: ReadState) -> ReadState:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# 返回更新后的状态,包含spit_context字段
|
# Return updated state including spit_context field
|
||||||
return {"spit_data": result}
|
return {"spit_data": result}
|
||||||
|
|
||||||
|
|
||||||
async def Problem_Extension(state: ReadState) -> ReadState:
|
async def Problem_Extension(state: ReadState) -> ReadState:
|
||||||
"""问题扩展节点"""
|
"""
|
||||||
# 获取原始数据和分解结果
|
Problem extension node
|
||||||
|
|
||||||
|
Extends the decomposed problems from Split_The_Problem node by generating
|
||||||
|
additional related questions and organizing them by original question.
|
||||||
|
Uses LLM to create comprehensive question extensions for better memory retrieval.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing decomposed problems and configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Updated state with extended problem results
|
||||||
|
"""
|
||||||
|
# Get original data and decomposition results
|
||||||
start = time.time()
|
start = time.time()
|
||||||
content = state.get('data', '')
|
content = state.get('data', '')
|
||||||
data = state.get('spit_data', '')['context']
|
data = state.get('spit_data', '')['context']
|
||||||
@@ -182,7 +214,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
||||||
|
|
||||||
# 验证结构化响应
|
# Validate structured response
|
||||||
if not response_content or not hasattr(response_content, 'root'):
|
if not response_content or not hasattr(response_content, 'root'):
|
||||||
logger.warning("Problem_Extension: 结构化响应为空或格式不正确")
|
logger.warning("Problem_Extension: 结构化响应为空或格式不正确")
|
||||||
aggregated_dict = {}
|
aggregated_dict = {}
|
||||||
@@ -216,7 +248,7 @@ async def Problem_Extension(state: ReadState) -> ReadState:
|
|||||||
exc_info=True
|
exc_info=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# 提供更详细的错误信息
|
# Provide more detailed error information
|
||||||
error_details = {
|
error_details = {
|
||||||
"error_type": type(e).__name__,
|
"error_type": type(e).__name__,
|
||||||
"error_message": str(e),
|
"error_message": str(e),
|
||||||
|
|||||||
@@ -29,6 +29,18 @@ logger = get_agent_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
async def rag_config(state):
|
async def rag_config(state):
|
||||||
|
"""
|
||||||
|
Configure RAG (Retrieval-Augmented Generation) settings
|
||||||
|
|
||||||
|
Creates configuration for knowledge base retrieval including similarity thresholds,
|
||||||
|
weights, and reranker settings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: Current state containing user_rag_memory_id
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: RAG configuration dictionary
|
||||||
|
"""
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
kb_config = {
|
kb_config = {
|
||||||
"knowledge_bases": [
|
"knowledge_bases": [
|
||||||
@@ -48,6 +60,19 @@ async def rag_config(state):
|
|||||||
|
|
||||||
|
|
||||||
async def rag_knowledge(state, question):
|
async def rag_knowledge(state, question):
|
||||||
|
"""
|
||||||
|
Retrieve knowledge using RAG approach
|
||||||
|
|
||||||
|
Performs knowledge retrieval from configured knowledge bases using the
|
||||||
|
provided question and returns formatted results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: Current state containing configuration
|
||||||
|
question: Question to search for
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
|
||||||
|
"""
|
||||||
kb_config = await rag_config(state)
|
kb_config = await rag_config(state)
|
||||||
end_user_id = state.get('end_user_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
@@ -68,12 +93,24 @@ async def rag_knowledge(state, question):
|
|||||||
|
|
||||||
|
|
||||||
async def llm_infomation(state: ReadState) -> ReadState:
|
async def llm_infomation(state: ReadState) -> ReadState:
|
||||||
|
"""
|
||||||
|
Get LLM configuration information from state
|
||||||
|
|
||||||
|
Retrieves model configuration details including model ID and tenant ID
|
||||||
|
from the memory configuration in the current state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing memory configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Model configuration as Pydantic model
|
||||||
|
"""
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
model_id = memory_config.llm_model_id
|
model_id = memory_config.llm_model_id
|
||||||
tenant_id = memory_config.tenant_id
|
tenant_id = memory_config.tenant_id
|
||||||
|
|
||||||
# 使用现有的 memory_config 而不是重新查询数据库
|
# Use existing memory_config instead of re-querying database
|
||||||
# 或者使用线程安全的数据库访问
|
# or use thread-safe database access
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id)
|
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id)
|
||||||
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
|
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
|
||||||
@@ -82,16 +119,20 @@ async def llm_infomation(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
async def clean_databases(data) -> str:
|
async def clean_databases(data) -> str:
|
||||||
"""
|
"""
|
||||||
简化的数据库搜索结果清理函数
|
Simplified database search result cleaning function
|
||||||
|
|
||||||
|
Processes and cleans search results from various sources including
|
||||||
|
reranked results and time-based search results. Extracts text content
|
||||||
|
from structured data and returns as formatted string.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: 搜索结果数据
|
data: Search result data (can be string, dict, or other types)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
清理后的内容字符串
|
str: Cleaned content string
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 解析JSON字符串
|
# Parse JSON string
|
||||||
if isinstance(data, str):
|
if isinstance(data, str):
|
||||||
try:
|
try:
|
||||||
data = json.loads(data)
|
data = json.loads(data)
|
||||||
@@ -101,24 +142,24 @@ async def clean_databases(data) -> str:
|
|||||||
if not isinstance(data, dict):
|
if not isinstance(data, dict):
|
||||||
return str(data)
|
return str(data)
|
||||||
|
|
||||||
# 获取结果数据
|
# Get result data
|
||||||
# with open("搜索结果.json","w",encoding='utf-8') as f:
|
# with open("搜索结果.json","w",encoding='utf-8') as f:
|
||||||
# f.write(json.dumps(data, indent=4, ensure_ascii=False))
|
# f.write(json.dumps(data, indent=4, ensure_ascii=False))
|
||||||
results = data.get('results', data)
|
results = data.get('results', data)
|
||||||
if not isinstance(results, dict):
|
if not isinstance(results, dict):
|
||||||
return str(results)
|
return str(results)
|
||||||
|
|
||||||
# 收集所有内容
|
# Collect all content
|
||||||
content_list = []
|
content_list = []
|
||||||
|
|
||||||
# 处理重排序结果
|
# Process reranked results
|
||||||
reranked = results.get('reranked_results', {})
|
reranked = results.get('reranked_results', {})
|
||||||
if reranked:
|
if reranked:
|
||||||
for category in ['summaries', 'statements', 'chunks', 'entities']:
|
for category in ['summaries', 'statements', 'chunks', 'entities']:
|
||||||
items = reranked.get(category, [])
|
items = reranked.get(category, [])
|
||||||
if isinstance(items, list):
|
if isinstance(items, list):
|
||||||
content_list.extend(items)
|
content_list.extend(items)
|
||||||
# 处理时间搜索结果
|
# Process time search results
|
||||||
time_search = results.get('time_search', {})
|
time_search = results.get('time_search', {})
|
||||||
if time_search:
|
if time_search:
|
||||||
if isinstance(time_search, dict):
|
if isinstance(time_search, dict):
|
||||||
@@ -128,7 +169,7 @@ async def clean_databases(data) -> str:
|
|||||||
elif isinstance(time_search, list):
|
elif isinstance(time_search, list):
|
||||||
content_list.extend(time_search)
|
content_list.extend(time_search)
|
||||||
|
|
||||||
# 提取文本内容
|
# Extract text content
|
||||||
text_parts = []
|
text_parts = []
|
||||||
for item in content_list:
|
for item in content_list:
|
||||||
if isinstance(item, dict):
|
if isinstance(item, dict):
|
||||||
@@ -146,10 +187,19 @@ async def clean_databases(data) -> str:
|
|||||||
|
|
||||||
|
|
||||||
async def retrieve_nodes(state: ReadState) -> ReadState:
|
async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||||
'''
|
"""
|
||||||
|
Retrieve information using simplified search approach
|
||||||
|
|
||||||
模型信息
|
Processes extended problems from previous nodes and performs retrieval
|
||||||
'''
|
using either RAG or hybrid search based on storage type. Handles concurrent
|
||||||
|
processing of multiple questions and deduplicates results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing problem extensions and configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Updated state with retrieval results and intermediate outputs
|
||||||
|
"""
|
||||||
|
|
||||||
problem_extension = state.get('problem_extension', '')['context']
|
problem_extension = state.get('problem_extension', '')['context']
|
||||||
storage_type = state.get('storage_type', '')
|
storage_type = state.get('storage_type', '')
|
||||||
@@ -163,7 +213,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
problem_list.append(data)
|
problem_list.append(data)
|
||||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||||
|
|
||||||
# 创建异步任务处理单个问题
|
# Create async task to process individual questions
|
||||||
async def process_question_nodes(idx, question):
|
async def process_question_nodes(idx, question):
|
||||||
try:
|
try:
|
||||||
# Prepare search parameters based on storage type
|
# Prepare search parameters based on storage type
|
||||||
@@ -209,7 +259,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# 并发处理所有问题
|
# Process all questions concurrently
|
||||||
tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)]
|
tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)]
|
||||||
databases_anser = await asyncio.gather(*tasks)
|
databases_anser = await asyncio.gather(*tasks)
|
||||||
databases_data = {
|
databases_data = {
|
||||||
@@ -257,7 +307,20 @@ async def retrieve_nodes(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
|
|
||||||
async def retrieve(state: ReadState) -> ReadState:
|
async def retrieve(state: ReadState) -> ReadState:
|
||||||
# 从state中获取end_user_id
|
"""
|
||||||
|
Advanced retrieve function using LangChain agents and tools
|
||||||
|
|
||||||
|
Uses LangChain agents with specialized retrieval tools (time-based and hybrid)
|
||||||
|
to perform sophisticated information retrieval. Supports both RAG and traditional
|
||||||
|
memory storage approaches with concurrent processing and result deduplication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing problem extensions and configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Updated state with retrieval results and intermediate outputs
|
||||||
|
"""
|
||||||
|
# Get end_user_id from state
|
||||||
import time
|
import time
|
||||||
start = time.time()
|
start = time.time()
|
||||||
problem_extension = state.get('problem_extension', '')['context']
|
problem_extension = state.get('problem_extension', '')['context']
|
||||||
@@ -299,21 +362,21 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建异步任务处理单个问题
|
# Create async task to process individual questions
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
# 在模块级别定义信号量,限制最大并发数
|
# Define semaphore at module level to limit maximum concurrency
|
||||||
SEMAPHORE = asyncio.Semaphore(5) # 限制最多5个并发数据库操作
|
SEMAPHORE = asyncio.Semaphore(5) # Limit to maximum 5 concurrent database operations
|
||||||
|
|
||||||
async def process_question(idx, question):
|
async def process_question(idx, question):
|
||||||
async with SEMAPHORE: # 限制并发
|
async with SEMAPHORE: # Limit concurrency
|
||||||
try:
|
try:
|
||||||
if storage_type == "rag" and user_rag_memory_id:
|
if storage_type == "rag" and user_rag_memory_id:
|
||||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
|
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
|
||||||
question)
|
question)
|
||||||
else:
|
else:
|
||||||
cleaned_query = question
|
cleaned_query = question
|
||||||
# 使用 asyncio 在线程池中运行同步的 agent.invoke
|
# Use asyncio to run synchronous agent.invoke in thread pool
|
||||||
import asyncio
|
import asyncio
|
||||||
response = await asyncio.get_event_loop().run_in_executor(
|
response = await asyncio.get_event_loop().run_in_executor(
|
||||||
None,
|
None,
|
||||||
@@ -362,7 +425,7 @@ async def retrieve(state: ReadState) -> ReadState:
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# 并发处理所有问题
|
# Process all questions concurrently
|
||||||
import asyncio
|
import asyncio
|
||||||
tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)]
|
tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)]
|
||||||
databases_anser = await asyncio.gather(*tasks)
|
databases_anser = await asyncio.gather(*tasks)
|
||||||
|
|||||||
@@ -23,18 +23,39 @@ logger = get_agent_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class SummaryNodeService(LLMServiceMixin):
|
class SummaryNodeService(LLMServiceMixin):
|
||||||
"""总结节点服务类"""
|
"""
|
||||||
|
Summary node service class
|
||||||
|
|
||||||
|
Handles summary generation operations using LLM services. Inherits from
|
||||||
|
LLMServiceMixin to provide structured LLM calling capabilities for
|
||||||
|
generating summaries from retrieved information.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
template_service: Service for rendering Jinja2 templates
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.template_service = TemplateService(template_root)
|
self.template_service = TemplateService(template_root)
|
||||||
|
|
||||||
|
|
||||||
# 创建全局服务实例
|
# Create global service instance
|
||||||
summary_service = SummaryNodeService()
|
summary_service = SummaryNodeService()
|
||||||
|
|
||||||
|
|
||||||
async def rag_config(state):
|
async def rag_config(state):
|
||||||
|
"""
|
||||||
|
Configure RAG (Retrieval-Augmented Generation) settings for summary operations
|
||||||
|
|
||||||
|
Creates configuration for knowledge base retrieval including similarity thresholds,
|
||||||
|
weights, and reranker settings specifically for summary generation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: Current state containing user_rag_memory_id
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: RAG configuration dictionary with knowledge base settings
|
||||||
|
"""
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
kb_config = {
|
kb_config = {
|
||||||
"knowledge_bases": [
|
"knowledge_bases": [
|
||||||
@@ -54,6 +75,23 @@ async def rag_config(state):
|
|||||||
|
|
||||||
|
|
||||||
async def rag_knowledge(state, question):
|
async def rag_knowledge(state, question):
|
||||||
|
"""
|
||||||
|
Retrieve knowledge using RAG approach for summary generation
|
||||||
|
|
||||||
|
Performs knowledge retrieval from configured knowledge bases using the
|
||||||
|
provided question and returns formatted results for summary processing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: Current state containing configuration
|
||||||
|
question: Question to search for in knowledge base
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
|
||||||
|
- retrieval_knowledge: List of retrieved knowledge chunks
|
||||||
|
- clean_content: Formatted content string
|
||||||
|
- cleaned_query: Processed query string
|
||||||
|
- raw_results: Raw retrieval results
|
||||||
|
"""
|
||||||
kb_config = await rag_config(state)
|
kb_config = await rag_config(state)
|
||||||
end_user_id = state.get('end_user_id', '')
|
end_user_id = state.get('end_user_id', '')
|
||||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
@@ -74,6 +112,18 @@ async def rag_knowledge(state, question):
|
|||||||
|
|
||||||
|
|
||||||
async def summary_history(state: ReadState) -> ReadState:
|
async def summary_history(state: ReadState) -> ReadState:
|
||||||
|
"""
|
||||||
|
Retrieve conversation history for summary context
|
||||||
|
|
||||||
|
Gets the conversation history for the current user to provide context
|
||||||
|
for summary generation operations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing end_user_id
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Conversation history data
|
||||||
|
"""
|
||||||
end_user_id = state.get("end_user_id", '')
|
end_user_id = state.get("end_user_id", '')
|
||||||
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
|
||||||
return history
|
return history
|
||||||
@@ -82,11 +132,26 @@ async def summary_history(state: ReadState) -> ReadState:
|
|||||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
|
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
|
||||||
search_mode) -> str:
|
search_mode) -> str:
|
||||||
"""
|
"""
|
||||||
增强的summary_llm函数,包含更好的错误处理和数据验证
|
Enhanced summary_llm function with better error handling and data validation
|
||||||
|
|
||||||
|
Generates summaries using LLM with structured output. Includes fallback mechanisms
|
||||||
|
for handling LLM failures and provides robust error recovery.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing current context
|
||||||
|
history: Conversation history for context
|
||||||
|
retrieve_info: Retrieved information to summarize
|
||||||
|
template_name: Jinja2 template name for prompt generation
|
||||||
|
operation_name: Type of operation (summary, input_summary, retrieve_summary)
|
||||||
|
response_model: Pydantic model for structured output
|
||||||
|
search_mode: Search mode flag ("0" for simple, "1" for complex)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Generated summary text or fallback message
|
||||||
"""
|
"""
|
||||||
data = state.get("data", '')
|
data = state.get("data", '')
|
||||||
|
|
||||||
# 构建系统提示词
|
# Build system prompt
|
||||||
if str(search_mode) == "0":
|
if str(search_mode) == "0":
|
||||||
system_prompt = await summary_service.template_service.render_template(
|
system_prompt = await summary_service.template_service.render_template(
|
||||||
template_name=template_name,
|
template_name=template_name,
|
||||||
@@ -103,7 +168,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
retrieve_info=retrieve_info
|
retrieve_info=retrieve_info
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
# 使用优化的LLM服务进行结构化输出
|
# Use optimized LLM service for structured output
|
||||||
with get_db_context() as db_session:
|
with get_db_context() as db_session:
|
||||||
structured = await summary_service.call_llm_structured(
|
structured = await summary_service.call_llm_structured(
|
||||||
state=state,
|
state=state,
|
||||||
@@ -112,23 +177,23 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
response_model=response_model,
|
response_model=response_model,
|
||||||
fallback_value=None
|
fallback_value=None
|
||||||
)
|
)
|
||||||
# 验证结构化响应
|
# Validate structured response
|
||||||
if structured is None:
|
if structured is None:
|
||||||
logger.warning("LLM返回None,使用默认回答")
|
logger.warning("LLM返回None,使用默认回答")
|
||||||
return "信息不足,无法回答"
|
return "信息不足,无法回答"
|
||||||
|
|
||||||
# 根据操作类型提取答案
|
# Extract answer based on operation type
|
||||||
if operation_name == "summary":
|
if operation_name == "summary":
|
||||||
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
|
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
|
||||||
else:
|
else:
|
||||||
# 处理RetrieveSummaryResponse
|
# Handle RetrieveSummaryResponse
|
||||||
if hasattr(structured, 'data') and structured.data:
|
if hasattr(structured, 'data') and structured.data:
|
||||||
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
|
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
|
||||||
else:
|
else:
|
||||||
logger.warning("结构化响应缺少data字段")
|
logger.warning("结构化响应缺少data字段")
|
||||||
aimessages = "信息不足,无法回答"
|
aimessages = "信息不足,无法回答"
|
||||||
|
|
||||||
# 验证答案不为空
|
# Validate answer is not empty
|
||||||
if not aimessages or aimessages.strip() == "":
|
if not aimessages or aimessages.strip() == "":
|
||||||
aimessages = "信息不足,无法回答"
|
aimessages = "信息不足,无法回答"
|
||||||
|
|
||||||
@@ -137,7 +202,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"结构化输出失败: {e}", exc_info=True)
|
logger.error(f"结构化输出失败: {e}", exc_info=True)
|
||||||
|
|
||||||
# 尝试非结构化输出作为fallback
|
# Try unstructured output as fallback
|
||||||
try:
|
try:
|
||||||
logger.info("尝试非结构化输出作为fallback")
|
logger.info("尝试非结构化输出作为fallback")
|
||||||
response = await summary_service.call_llm_simple(
|
response = await summary_service.call_llm_simple(
|
||||||
@@ -148,9 +213,9 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
)
|
)
|
||||||
|
|
||||||
if response and response.strip():
|
if response and response.strip():
|
||||||
# 简单清理响应
|
# Simple response cleaning
|
||||||
cleaned_response = response.strip()
|
cleaned_response = response.strip()
|
||||||
# 移除可能的JSON标记
|
# Remove possible JSON markers
|
||||||
if cleaned_response.startswith('```'):
|
if cleaned_response.startswith('```'):
|
||||||
lines = cleaned_response.split('\n')
|
lines = cleaned_response.split('\n')
|
||||||
cleaned_response = '\n'.join(lines[1:-1])
|
cleaned_response = '\n'.join(lines[1:-1])
|
||||||
@@ -165,6 +230,19 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o
|
|||||||
|
|
||||||
|
|
||||||
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
||||||
|
"""
|
||||||
|
Save summary results to Redis session storage
|
||||||
|
|
||||||
|
Stores the generated summary and user query in Redis for session management
|
||||||
|
and conversation history tracking.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing user and query information
|
||||||
|
aimessages: Generated summary message to save
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Updated state after saving to Redis
|
||||||
|
"""
|
||||||
data = state.get("data", '')
|
data = state.get("data", '')
|
||||||
end_user_id = state.get("end_user_id", '')
|
end_user_id = state.get("end_user_id", '')
|
||||||
await SessionService(store).save_session(
|
await SessionService(store).save_session(
|
||||||
@@ -179,6 +257,20 @@ async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
|
|||||||
|
|
||||||
|
|
||||||
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
|
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
|
||||||
|
"""
|
||||||
|
Format summary results for different output types
|
||||||
|
|
||||||
|
Creates structured output formats for both input summary and retrieval summary
|
||||||
|
operations, including metadata and intermediate results for frontend display.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing storage and user information
|
||||||
|
aimessages: Generated summary message
|
||||||
|
raw_results: Raw search/retrieval results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (input_summary, retrieve_summary) formatted result dictionaries
|
||||||
|
"""
|
||||||
storage_type = state.get("storage_type", '')
|
storage_type = state.get("storage_type", '')
|
||||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
data = state.get("data", '')
|
data = state.get("data", '')
|
||||||
@@ -217,6 +309,19 @@ async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState
|
|||||||
|
|
||||||
|
|
||||||
async def Input_Summary(state: ReadState) -> ReadState:
|
async def Input_Summary(state: ReadState) -> ReadState:
|
||||||
|
"""
|
||||||
|
Generate quick input summary from retrieved information
|
||||||
|
|
||||||
|
Performs fast retrieval and generates a quick summary response for user queries.
|
||||||
|
This function prioritizes speed by only searching summary nodes and provides
|
||||||
|
immediate feedback to users.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing user query, storage configuration, and context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Dictionary containing summary results with status and metadata
|
||||||
|
"""
|
||||||
start = time.time()
|
start = time.time()
|
||||||
storage_type = state.get("storage_type", '')
|
storage_type = state.get("storage_type", '')
|
||||||
memory_config = state.get('memory_config', None)
|
memory_config = state.get('memory_config', None)
|
||||||
@@ -266,6 +371,19 @@ async def Input_Summary(state: ReadState) -> ReadState:
|
|||||||
|
|
||||||
|
|
||||||
async def Retrieve_Summary(state: ReadState) -> ReadState:
|
async def Retrieve_Summary(state: ReadState) -> ReadState:
|
||||||
|
"""
|
||||||
|
Generate comprehensive summary from retrieved expansion issues
|
||||||
|
|
||||||
|
Processes retrieved expansion issues and generates a detailed summary using LLM.
|
||||||
|
This function handles complex retrieval results and provides comprehensive answers
|
||||||
|
based on expanded query results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing retrieve data with expansion issues
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Dictionary containing comprehensive summary results
|
||||||
|
"""
|
||||||
retrieve = state.get("retrieve", '')
|
retrieve = state.get("retrieve", '')
|
||||||
history = await summary_history(state)
|
history = await summary_history(state)
|
||||||
import json
|
import json
|
||||||
@@ -299,13 +417,26 @@ async def Retrieve_Summary(state: ReadState) -> ReadState:
|
|||||||
duration = 0.0
|
duration = 0.0
|
||||||
log_time('Retrieval summary', duration)
|
log_time('Retrieval summary', duration)
|
||||||
|
|
||||||
# 修复协程调用 - 先await,然后访问返回值
|
# Fixed coroutine call - await first, then access return value
|
||||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||||
summary = summary_result[1]
|
summary = summary_result[1]
|
||||||
return {"summary": summary}
|
return {"summary": summary}
|
||||||
|
|
||||||
|
|
||||||
async def Summary(state: ReadState) -> ReadState:
|
async def Summary(state: ReadState) -> ReadState:
|
||||||
|
"""
|
||||||
|
Generate final comprehensive summary from verified data
|
||||||
|
|
||||||
|
Creates the final summary using verified expansion issues and conversation history.
|
||||||
|
This function processes verified data to generate the most comprehensive and
|
||||||
|
accurate response to user queries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing verified data and query information
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Dictionary containing final summary results
|
||||||
|
"""
|
||||||
start = time.time()
|
start = time.time()
|
||||||
query = state.get("data", '')
|
query = state.get("data", '')
|
||||||
verify = state.get("verify", '')
|
verify = state.get("verify", '')
|
||||||
@@ -336,13 +467,26 @@ async def Summary(state: ReadState) -> ReadState:
|
|||||||
duration = 0.0
|
duration = 0.0
|
||||||
log_time('Retrieval summary', duration)
|
log_time('Retrieval summary', duration)
|
||||||
|
|
||||||
# 修复协程调用 - 先await,然后访问返回值
|
# Fixed coroutine call - await first, then access return value
|
||||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||||
summary = summary_result[1]
|
summary = summary_result[1]
|
||||||
return {"summary": summary}
|
return {"summary": summary}
|
||||||
|
|
||||||
|
|
||||||
async def Summary_fails(state: ReadState) -> ReadState:
|
async def Summary_fails(state: ReadState) -> ReadState:
|
||||||
|
"""
|
||||||
|
Generate fallback summary when normal summary process fails
|
||||||
|
|
||||||
|
Provides a fallback summary generation mechanism when the standard summary
|
||||||
|
process encounters errors or fails to produce satisfactory results. Uses
|
||||||
|
a specialized failure template to handle edge cases.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing verified data and failure context
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ReadState: Dictionary containing fallback summary results
|
||||||
|
"""
|
||||||
storage_type = state.get("storage_type", '')
|
storage_type = state.get("storage_type", '')
|
||||||
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
user_rag_memory_id = state.get("user_rag_memory_id", '')
|
||||||
history = await summary_history(state)
|
history = await summary_history(state)
|
||||||
|
|||||||
@@ -18,24 +18,46 @@ logger = get_agent_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class VerificationNodeService(LLMServiceMixin):
|
class VerificationNodeService(LLMServiceMixin):
|
||||||
"""验证节点服务类"""
|
"""
|
||||||
|
Verification node service class
|
||||||
|
|
||||||
|
Handles data verification operations using LLM services. Inherits from
|
||||||
|
LLMServiceMixin to provide structured LLM calling capabilities for
|
||||||
|
verifying and validating retrieved information.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
template_service: Service for rendering Jinja2 templates
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.template_service = TemplateService(template_root)
|
self.template_service = TemplateService(template_root)
|
||||||
|
|
||||||
|
|
||||||
# 创建全局服务实例
|
# Create global service instance
|
||||||
verification_service = VerificationNodeService()
|
verification_service = VerificationNodeService()
|
||||||
|
|
||||||
|
|
||||||
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||||
"""处理验证结果并生成输出格式"""
|
"""
|
||||||
|
Process verification results and generate output format
|
||||||
|
|
||||||
|
Transforms VerificationResult objects into structured output format suitable
|
||||||
|
for frontend consumption. Handles conversion of VerificationItem objects to
|
||||||
|
dictionary format and adds metadata for tracking.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: ReadState containing storage and user configuration
|
||||||
|
messages_deal: VerificationResult containing verification outcomes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Formatted verification result with status and metadata
|
||||||
|
"""
|
||||||
storage_type = state.get('storage_type', '')
|
storage_type = state.get('storage_type', '')
|
||||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||||
data = state.get('data', '')
|
data = state.get('data', '')
|
||||||
|
|
||||||
# 将 VerificationItem 对象转换为字典列表
|
# Convert VerificationItem objects to dictionary list
|
||||||
verified_data = []
|
verified_data = []
|
||||||
if messages_deal.expansion_issue:
|
if messages_deal.expansion_issue:
|
||||||
for item in messages_deal.expansion_issue:
|
for item in messages_deal.expansion_issue:
|
||||||
@@ -89,7 +111,7 @@ async def Verify(state: ReadState):
|
|||||||
|
|
||||||
logger.info("Verify: 开始渲染模板")
|
logger.info("Verify: 开始渲染模板")
|
||||||
|
|
||||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
# Generate JSON schema to guide LLM output format
|
||||||
json_schema = VerificationResult.model_json_schema()
|
json_schema = VerificationResult.model_json_schema()
|
||||||
|
|
||||||
system_prompt = await verification_service.template_service.render_template(
|
system_prompt = await verification_service.template_service.render_template(
|
||||||
@@ -104,8 +126,8 @@ async def Verify(state: ReadState):
|
|||||||
# 使用优化的LLM服务,添加超时保护
|
# 使用优化的LLM服务,添加超时保护
|
||||||
logger.info("Verify: 开始调用 LLM")
|
logger.info("Verify: 开始调用 LLM")
|
||||||
try:
|
try:
|
||||||
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
# Add asyncio.wait_for timeout wrapper to prevent infinite waiting
|
||||||
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
# Timeout set to 150 seconds (slightly longer than LLM config's 120 seconds)
|
||||||
|
|
||||||
with get_db_context() as db_session:
|
with get_db_context() as db_session:
|
||||||
structured = await asyncio.wait_for(
|
structured = await asyncio.wait_for(
|
||||||
@@ -122,7 +144,7 @@ async def Verify(state: ReadState):
|
|||||||
"reason": "验证失败或超时"
|
"reason": "验证失败或超时"
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
timeout=150.0 # 150秒超时
|
timeout=150.0 # 150 second timeout
|
||||||
)
|
)
|
||||||
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
|
|||||||
@@ -33,7 +33,19 @@ from app.core.memory.agent.langgraph_graph.routing.routers import (
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def make_read_graph():
|
async def make_read_graph():
|
||||||
"""创建并返回 LangGraph 工作流"""
|
"""
|
||||||
|
Create and return a LangGraph workflow for memory reading operations
|
||||||
|
|
||||||
|
Builds a state graph workflow that handles memory retrieval, problem analysis,
|
||||||
|
verification, and summarization. The workflow includes nodes for content input,
|
||||||
|
problem splitting, retrieval, verification, and various summary operations.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
StateGraph: Compiled LangGraph workflow for memory reading
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
Exception: If workflow creation fails
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
# Build workflow graph
|
# Build workflow graph
|
||||||
workflow = StateGraph(ReadState)
|
workflow = StateGraph(ReadState)
|
||||||
@@ -48,7 +60,7 @@ async def make_read_graph():
|
|||||||
workflow.add_node("Summary", Summary)
|
workflow.add_node("Summary", Summary)
|
||||||
workflow.add_node("Summary_fails", Summary_fails)
|
workflow.add_node("Summary_fails", Summary_fails)
|
||||||
|
|
||||||
# 添加边
|
# Add edges to define workflow flow
|
||||||
workflow.add_edge(START, "content_input")
|
workflow.add_edge(START, "content_input")
|
||||||
workflow.add_conditional_edges("content_input", Split_continue)
|
workflow.add_conditional_edges("content_input", Split_continue)
|
||||||
workflow.add_edge("Input_Summary", END)
|
workflow.add_edge("Input_Summary", END)
|
||||||
@@ -63,7 +75,7 @@ async def make_read_graph():
|
|||||||
'''-----'''
|
'''-----'''
|
||||||
# workflow.add_edge("Retrieve", END)
|
# workflow.add_edge("Retrieve", END)
|
||||||
|
|
||||||
# 编译工作流
|
# Compile workflow
|
||||||
graph = workflow.compile()
|
graph = workflow.compile()
|
||||||
yield graph
|
yield graph
|
||||||
|
|
||||||
@@ -72,108 +84,3 @@ async def make_read_graph():
|
|||||||
raise
|
raise
|
||||||
finally:
|
finally:
|
||||||
print("工作流创建完成")
|
print("工作流创建完成")
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
"""主函数 - 运行工作流"""
|
|
||||||
message = "昨天有什么好看的电影"
|
|
||||||
end_user_id = '88a459f5_text09' # 组ID
|
|
||||||
storage_type = 'neo4j' # 存储类型
|
|
||||||
search_switch = '1' # 搜索开关
|
|
||||||
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
|
|
||||||
|
|
||||||
# 获取数据库会话
|
|
||||||
db_session = next(get_db())
|
|
||||||
config_service = MemoryConfigService(db_session)
|
|
||||||
memory_config = config_service.load_memory_config(
|
|
||||||
config_id=17, # 改为整数
|
|
||||||
service_name="MemoryAgentService"
|
|
||||||
)
|
|
||||||
import time
|
|
||||||
start = time.time()
|
|
||||||
try:
|
|
||||||
async with make_read_graph() as graph:
|
|
||||||
config = {"configurable": {"thread_id": end_user_id}}
|
|
||||||
# 初始状态 - 包含所有必要字段
|
|
||||||
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
|
|
||||||
"end_user_id": end_user_id
|
|
||||||
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
|
|
||||||
"memory_config": memory_config}
|
|
||||||
# 获取节点更新信息
|
|
||||||
_intermediate_outputs = []
|
|
||||||
summary = ''
|
|
||||||
|
|
||||||
async for update_event in graph.astream(
|
|
||||||
initial_state,
|
|
||||||
stream_mode="updates",
|
|
||||||
config=config
|
|
||||||
):
|
|
||||||
for node_name, node_data in update_event.items():
|
|
||||||
print(f"处理节点: {node_name}")
|
|
||||||
|
|
||||||
# 处理不同Summary节点的返回结构
|
|
||||||
if 'Summary' in node_name:
|
|
||||||
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
|
|
||||||
summary = node_data['InputSummary']['summary_result']
|
|
||||||
elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
|
|
||||||
summary = node_data['RetrieveSummary']['summary_result']
|
|
||||||
elif 'summary' in node_data and 'summary_result' in node_data['summary']:
|
|
||||||
summary = node_data['summary']['summary_result']
|
|
||||||
elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
|
|
||||||
summary = node_data['SummaryFails']['summary_result']
|
|
||||||
|
|
||||||
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
|
|
||||||
if spit_data and spit_data != [] and spit_data != {}:
|
|
||||||
_intermediate_outputs.append(spit_data)
|
|
||||||
|
|
||||||
# Problem_Extension 节点
|
|
||||||
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
|
|
||||||
if problem_extension and problem_extension != [] and problem_extension != {}:
|
|
||||||
_intermediate_outputs.append(problem_extension)
|
|
||||||
|
|
||||||
# Retrieve 节点
|
|
||||||
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
|
|
||||||
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
|
|
||||||
_intermediate_outputs.extend(retrieve_node)
|
|
||||||
|
|
||||||
# Verify 节点
|
|
||||||
verify_n = node_data.get('verify', {}).get('_intermediate', None)
|
|
||||||
if verify_n and verify_n != [] and verify_n != {}:
|
|
||||||
_intermediate_outputs.append(verify_n)
|
|
||||||
|
|
||||||
# Summary 节点
|
|
||||||
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
|
||||||
if summary_n and summary_n != [] and summary_n != {}:
|
|
||||||
_intermediate_outputs.append(summary_n)
|
|
||||||
|
|
||||||
# # 过滤掉空值
|
|
||||||
# _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
|
|
||||||
#
|
|
||||||
# # 优化搜索结果
|
|
||||||
# print("=== 开始优化搜索结果 ===")
|
|
||||||
# optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
|
|
||||||
# result=reorder_output_results(optimized_outputs)
|
|
||||||
# # 保存优化后的结果到文件
|
|
||||||
# with open('_intermediate_outputs_optimized.json', 'w', encoding='utf-8') as f:
|
|
||||||
# import json
|
|
||||||
# f.write(json.dumps(result, indent=4, ensure_ascii=False))
|
|
||||||
#
|
|
||||||
print(f"=== 最终摘要 ===")
|
|
||||||
print(summary)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
import traceback
|
|
||||||
traceback.print_exc()
|
|
||||||
finally:
|
|
||||||
db_session.close()
|
|
||||||
|
|
||||||
end = time.time()
|
|
||||||
print(100 * 'y')
|
|
||||||
print(f"总耗时: {end - start}s")
|
|
||||||
print(100 * 'y')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
|
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
||||||
|
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
counter = COUNTState(limit=3)
|
counter = COUNTState(limit=3)
|
||||||
def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
|
||||||
|
|
||||||
|
def Split_continue(state: ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||||
"""
|
"""
|
||||||
Determine routing based on search_switch value.
|
Determine routing based on search_switch value.
|
||||||
|
|
||||||
@@ -25,6 +25,7 @@ def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summa
|
|||||||
return 'Input_Summary'
|
return 'Input_Summary'
|
||||||
return 'Split_The_Problem' # 默认情况
|
return 'Split_The_Problem' # 默认情况
|
||||||
|
|
||||||
|
|
||||||
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||||
"""
|
"""
|
||||||
Determine routing based on search_switch value.
|
Determine routing based on search_switch value.
|
||||||
@@ -43,8 +44,10 @@ def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
|||||||
elif search_switch == '1':
|
elif search_switch == '1':
|
||||||
return 'Retrieve_Summary'
|
return 'Retrieve_Summary'
|
||||||
return 'Retrieve_Summary' # Default based on business logic
|
return 'Retrieve_Summary' # Default based on business logic
|
||||||
|
|
||||||
|
|
||||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||||
status=state.get('verify', '')['status']
|
status = state.get('verify', '')['status']
|
||||||
# loop_count = counter.get_total()
|
# loop_count = counter.get_total()
|
||||||
if "success" in status:
|
if "success" in status:
|
||||||
# counter.reset()
|
# counter.reset()
|
||||||
@@ -53,7 +56,7 @@ def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "co
|
|||||||
# if loop_count < 2: # Maximum loop count is 3
|
# if loop_count < 2: # Maximum loop count is 3
|
||||||
# return "content_input"
|
# return "content_input"
|
||||||
# else:
|
# else:
|
||||||
# counter.reset()
|
# counter.reset()
|
||||||
return "Summary_fails"
|
return "Summary_fails"
|
||||||
else:
|
else:
|
||||||
# Add default return value to avoid returning None
|
# Add default return value to avoid returning None
|
||||||
|
|||||||
@@ -2,77 +2,104 @@ import json
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
from app.core.logging_config import get_agent_logger
|
from app.core.logging_config import get_agent_logger
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
|
|
||||||
|
|
||||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
|
||||||
from app.core.memory.agent.utils.redis_tool import count_store
|
from app.core.memory.agent.utils.redis_tool import count_store
|
||||||
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context, get_db
|
from app.db import get_db_context
|
||||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||||
from app.services.memory_konwledges_server import write_rag
|
from app.services.memory_konwledges_server import write_rag
|
||||||
from app.services.task_service import get_task_memory_write_result
|
from app.services.task_service import get_task_memory_write_result
|
||||||
from app.tasks import write_message_task
|
from app.tasks import write_message_task
|
||||||
from app.utils.config_utils import resolve_config_id
|
from app.utils.config_utils import resolve_config_id
|
||||||
|
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||||
|
|
||||||
|
|
||||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
||||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
"""
|
||||||
|
Write messages to RAG storage system
|
||||||
|
|
||||||
|
Combines user and AI messages into a single string format and stores them
|
||||||
|
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: User identifier for the conversation
|
||||||
|
user_message: User's input message content
|
||||||
|
ai_message: AI's response message content
|
||||||
|
user_rag_memory_id: RAG memory identifier for storage location
|
||||||
|
"""
|
||||||
|
# RAG mode: combine messages into string format (maintain original logic)
|
||||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||||
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
|
|
||||||
actual_config_id, long_term_messages=[]):
|
|
||||||
|
async def write(
|
||||||
|
storage_type,
|
||||||
|
end_user_id,
|
||||||
|
user_message,
|
||||||
|
ai_message,
|
||||||
|
user_rag_memory_id,
|
||||||
|
actual_end_user_id,
|
||||||
|
actual_config_id,
|
||||||
|
long_term_messages=None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
写入记忆(支持结构化消息)
|
Write memory with structured message support
|
||||||
|
|
||||||
|
Handles memory writing operations for different storage types (Neo4j/RAG).
|
||||||
|
Supports both individual message pairs and batch long-term message processing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
storage_type: 存储类型 (neo4j/rag)
|
storage_type: Storage type identifier ("neo4j" or "rag")
|
||||||
end_user_id: 终端用户ID
|
end_user_id: Terminal user identifier
|
||||||
user_message: 用户消息内容
|
user_message: User message content
|
||||||
ai_message: AI 回复内容
|
ai_message: AI response content
|
||||||
user_rag_memory_id: RAG 记忆ID
|
user_rag_memory_id: RAG memory identifier
|
||||||
actual_end_user_id: 实际用户ID
|
actual_end_user_id: Actual user identifier for storage
|
||||||
actual_config_id: 配置ID
|
actual_config_id: Configuration identifier
|
||||||
|
long_term_messages: Optional list of structured messages for batch processing
|
||||||
|
|
||||||
逻辑说明:
|
Logic explanation:
|
||||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
- RAG mode: Combines user_message and ai_message into string format, maintains original logic
|
||||||
- Neo4j 模式:使用结构化消息列表
|
- Neo4j mode: Uses structured message lists
|
||||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
1. If both user_message and ai_message are not empty: Creates paired messages [user, assistant]
|
||||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
2. If only user_message exists: Creates single user message [user] (for historical memory scenarios)
|
||||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
3. Each message is converted to independent Chunk, preserving speaker field
|
||||||
"""
|
"""
|
||||||
|
|
||||||
db = next(get_db())
|
if long_term_messages is None:
|
||||||
try:
|
long_term_messages = []
|
||||||
|
with get_db_context() as db:
|
||||||
actual_config_id = resolve_config_id(actual_config_id, db)
|
actual_config_id = resolve_config_id(actual_config_id, db)
|
||||||
# Neo4j 模式:使用结构化消息列表
|
# Neo4j mode: Use structured message lists
|
||||||
structured_messages = []
|
structured_messages = []
|
||||||
|
|
||||||
# 始终添加用户消息(如果不为空)
|
# Always add user message (if not empty)
|
||||||
if isinstance(user_message, str) and user_message.strip() != "":
|
if isinstance(user_message, str) and user_message.strip() != "":
|
||||||
structured_messages.append({"role": "user", "content": user_message})
|
structured_messages.append({"role": "user", "content": user_message})
|
||||||
|
|
||||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
# Only add assistant message when AI reply is not empty
|
||||||
if isinstance(ai_message, str) and ai_message.strip() != "":
|
if isinstance(ai_message, str) and ai_message.strip() != "":
|
||||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||||
|
|
||||||
# 如果提供了 long_term_messages,使用它替代 structured_messages
|
# If long_term_messages provided, use it to replace structured_messages
|
||||||
if long_term_messages and isinstance(long_term_messages, list):
|
if long_term_messages and isinstance(long_term_messages, list):
|
||||||
structured_messages = long_term_messages
|
structured_messages = long_term_messages
|
||||||
elif long_term_messages and isinstance(long_term_messages, str):
|
elif long_term_messages and isinstance(long_term_messages, str):
|
||||||
# 如果是 JSON 字符串,先解析
|
# If it's a JSON string, parse it first
|
||||||
try:
|
try:
|
||||||
structured_messages = json.loads(long_term_messages)
|
structured_messages = json.loads(long_term_messages)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
||||||
|
|
||||||
# 如果没有消息,直接返回
|
# If no messages, return directly
|
||||||
if not structured_messages:
|
if not structured_messages:
|
||||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||||
return
|
return
|
||||||
@@ -80,29 +107,41 @@ async def write(storage_type, end_user_id, user_message, ai_message, user_rag_me
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||||
write_id = write_message_task.delay(
|
write_id = write_message_task.delay(
|
||||||
actual_end_user_id, # end_user_id: 用户ID
|
actual_end_user_id, # end_user_id: User ID
|
||||||
structured_messages, # message: JSON 字符串格式的消息列表
|
structured_messages, # message: JSON string format message list
|
||||||
str(actual_config_id), # config_id: 配置ID字符串
|
str(actual_config_id), # config_id: Configuration ID string
|
||||||
storage_type, # storage_type: "neo4j"
|
storage_type, # storage_type: "neo4j"
|
||||||
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
|
||||||
)
|
)
|
||||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||||
write_status = get_task_memory_write_result(str(write_id))
|
write_status = get_task_memory_write_result(str(write_id))
|
||||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
|
|
||||||
|
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
|
||||||
|
"""
|
||||||
|
Save long-term memory data to database
|
||||||
|
|
||||||
|
Handles the storage of long-term memory data based on different strategies
|
||||||
|
(chunk-based or aggregate-based) and manages the transition from short-term
|
||||||
|
to long-term memory storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
long_term_messages: Long-term message data to be saved
|
||||||
|
actual_config_id: Configuration identifier for memory settings
|
||||||
|
end_user_id: User identifier for memory association
|
||||||
|
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
|
||||||
|
scope: Scope/window size for memory processing
|
||||||
|
"""
|
||||||
with get_db_context() as db_session:
|
with get_db_context() as db_session:
|
||||||
repo = LongTermMemoryRepository(db_session)
|
repo = LongTermMemoryRepository(db_session)
|
||||||
|
|
||||||
|
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
result = write_store.get_session_by_userid(end_user_id)
|
result = write_store.get_session_by_userid(end_user_id)
|
||||||
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||||
data = await format_parsing(result, "dict")
|
data = await format_parsing(result, "dict")
|
||||||
chunk_data = data[:scope]
|
chunk_data = data[:scope]
|
||||||
if len(chunk_data)==scope:
|
if len(chunk_data) == scope:
|
||||||
repo.upsert(end_user_id, chunk_data)
|
repo.upsert(end_user_id, chunk_data)
|
||||||
logger.info(f'---------写入短长期-----------')
|
logger.info(f'---------写入短长期-----------')
|
||||||
else:
|
else:
|
||||||
@@ -112,18 +151,23 @@ async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,
|
|||||||
logger.info(f'写入短长期:')
|
logger.info(f'写入短长期:')
|
||||||
|
|
||||||
|
|
||||||
|
"""Window-based dialogue processing"""
|
||||||
|
|
||||||
'''根据窗口'''
|
|
||||||
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
|
||||||
'''
|
"""
|
||||||
根据窗口获取redis数据,写入neo4j:
|
Process dialogue based on window size and write to Neo4j
|
||||||
Args:
|
|
||||||
end_user_id: 终端用户ID
|
Manages conversation data based on a sliding window approach. When the window
|
||||||
memory_config: 内存配置对象
|
reaches the specified scope size, it triggers long-term memory storage to Neo4j.
|
||||||
langchain_messages:原始数据LIST
|
|
||||||
scope:窗口大小
|
Args:
|
||||||
'''
|
end_user_id: Terminal user identifier
|
||||||
scope=scope
|
memory_config: Memory configuration object containing settings
|
||||||
|
langchain_messages: Original message data list
|
||||||
|
scope: Window size determining when to trigger long-term storage
|
||||||
|
"""
|
||||||
|
scope = scope
|
||||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||||
if is_end_user_id is not False:
|
if is_end_user_id is not False:
|
||||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||||
@@ -135,50 +179,72 @@ async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
|||||||
elif int(is_end_user_id) == int(scope):
|
elif int(is_end_user_id) == int(scope):
|
||||||
logger.info('写入长期记忆NEO4J')
|
logger.info('写入长期记忆NEO4J')
|
||||||
formatted_messages = (redis_messages)
|
formatted_messages = (redis_messages)
|
||||||
# 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用)
|
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
|
||||||
if hasattr(memory_config, 'config_id'):
|
if hasattr(memory_config, 'config_id'):
|
||||||
config_id = memory_config.config_id
|
config_id = memory_config.config_id
|
||||||
else:
|
else:
|
||||||
config_id = memory_config
|
config_id = memory_config
|
||||||
|
|
||||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
await write(
|
||||||
config_id, formatted_messages)
|
AgentMemory_Long_Term.STORAGE_NEO4J,
|
||||||
|
end_user_id,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
None,
|
||||||
|
end_user_id,
|
||||||
|
config_id,
|
||||||
|
formatted_messages
|
||||||
|
)
|
||||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||||
else:
|
else:
|
||||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||||
|
|
||||||
|
|
||||||
"""根据时间"""
|
"""Time-based memory processing"""
|
||||||
async def memory_long_term_storage(end_user_id,memory_config,time):
|
|
||||||
'''
|
|
||||||
根据时间获取redis数据,写入neo4j:
|
async def memory_long_term_storage(end_user_id, memory_config, time):
|
||||||
Args:
|
|
||||||
end_user_id: 终端用户ID
|
|
||||||
memory_config: 内存配置对象
|
|
||||||
'''
|
|
||||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
|
||||||
format_messages = (long_time_data)
|
|
||||||
messages=[]
|
|
||||||
memory_config=memory_config.config_id
|
|
||||||
for i in format_messages:
|
|
||||||
message=json.loads(i['Query'])
|
|
||||||
messages+= message
|
|
||||||
if format_messages!=[]:
|
|
||||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
|
||||||
memory_config, messages)
|
|
||||||
'''聚合判断'''
|
|
||||||
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
|
||||||
"""
|
"""
|
||||||
聚合判断函数:判断输入句子和历史消息是否描述同一事件
|
Process memory storage based on time intervals and write to Neo4j
|
||||||
|
|
||||||
|
Retrieves Redis data based on time intervals and writes it to Neo4j for
|
||||||
|
long-term storage. This function handles time-based memory consolidation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
end_user_id: 终端用户ID
|
end_user_id: Terminal user identifier
|
||||||
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
memory_config: Memory configuration object containing settings
|
||||||
memory_config: 内存配置对象
|
time: Time interval for data retrieval
|
||||||
"""
|
"""
|
||||||
|
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||||
|
format_messages = long_time_data
|
||||||
|
messages = []
|
||||||
|
memory_config = memory_config.config_id
|
||||||
|
for i in format_messages:
|
||||||
|
message = json.loads(i['Query'])
|
||||||
|
messages += message
|
||||||
|
if format_messages:
|
||||||
|
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||||
|
memory_config, messages)
|
||||||
|
|
||||||
|
|
||||||
|
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
||||||
|
"""
|
||||||
|
Aggregation judgment function: determine if input sentence and historical messages describe the same event
|
||||||
|
|
||||||
|
Uses LLM-based analysis to determine whether new messages should be aggregated with existing
|
||||||
|
historical data or stored as separate events. This helps optimize memory storage and retrieval.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: Terminal user identifier
|
||||||
|
ori_messages: Original message list, format like [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||||
|
memory_config: Memory configuration object containing LLM settings
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Aggregation judgment result containing is_same_event flag and processed output
|
||||||
|
"""
|
||||||
|
history = None
|
||||||
try:
|
try:
|
||||||
# 1. 获取历史会话数据(使用新方法)
|
# 1. Get historical session data (using new method)
|
||||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||||
history = await format_parsing(result)
|
history = await format_parsing(result)
|
||||||
if not result:
|
if not result:
|
||||||
|
|||||||
@@ -2,41 +2,53 @@ import asyncio
|
|||||||
import json
|
import json
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
|
||||||
from langchain.tools import tool
|
from langchain.tools import tool
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
from app.core.memory.src.search import (
|
from app.core.memory.src.search import (
|
||||||
search_by_temporal,
|
search_by_temporal,
|
||||||
search_by_keyword_temporal,
|
search_by_keyword_temporal,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def extract_tool_message_content(response):
|
def extract_tool_message_content(response):
|
||||||
"""从agent响应中提取ToolMessage内容和工具名称"""
|
"""
|
||||||
|
Extract ToolMessage content and tool names from agent response
|
||||||
|
|
||||||
|
Parses agent response messages to extract tool execution results and metadata.
|
||||||
|
Handles JSON parsing and provides structured access to tool output data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
response: Agent response dictionary containing messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Dictionary containing tool_name and parsed content, or None if no tool message found
|
||||||
|
- tool_name: Name of the executed tool
|
||||||
|
- content: Parsed tool execution result (JSON or raw text)
|
||||||
|
"""
|
||||||
messages = response.get('messages', [])
|
messages = response.get('messages', [])
|
||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if hasattr(message, 'tool_call_id') and hasattr(message, 'content'):
|
if hasattr(message, 'tool_call_id') and hasattr(message, 'content'):
|
||||||
# 这是一个ToolMessage
|
# This is a ToolMessage
|
||||||
tool_content = message.content
|
tool_content = message.content
|
||||||
tool_name = None
|
tool_name = None
|
||||||
|
|
||||||
# 尝试获取工具名称
|
# Try to get tool name
|
||||||
if hasattr(message, 'name'):
|
if hasattr(message, 'name'):
|
||||||
tool_name = message.name
|
tool_name = message.name
|
||||||
elif hasattr(message, 'tool_name'):
|
elif hasattr(message, 'tool_name'):
|
||||||
tool_name = message.tool_name
|
tool_name = message.tool_name
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 解析JSON内容
|
# Parse JSON content
|
||||||
parsed_content = json.loads(tool_content)
|
parsed_content = json.loads(tool_content)
|
||||||
return {
|
return {
|
||||||
'tool_name': tool_name,
|
'tool_name': tool_name,
|
||||||
'content': parsed_content
|
'content': parsed_content
|
||||||
}
|
}
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# 如果不是JSON格式,直接返回内容
|
# If not JSON format, return content directly
|
||||||
return {
|
return {
|
||||||
'tool_name': tool_name,
|
'tool_name': tool_name,
|
||||||
'content': tool_content
|
'content': tool_content
|
||||||
@@ -46,26 +58,49 @@ def extract_tool_message_content(response):
|
|||||||
|
|
||||||
|
|
||||||
class TimeRetrievalInput(BaseModel):
|
class TimeRetrievalInput(BaseModel):
|
||||||
"""时间检索工具的输入模式"""
|
"""
|
||||||
|
Input schema for time retrieval tool
|
||||||
|
|
||||||
|
Defines the expected input parameters for time-based retrieval operations.
|
||||||
|
Used for validation and documentation of tool parameters.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
context: User input query content for search
|
||||||
|
end_user_id: Group ID for filtering search results, defaults to test user
|
||||||
|
"""
|
||||||
context: str = Field(description="用户输入的查询内容")
|
context: str = Field(description="用户输入的查询内容")
|
||||||
end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
end_user_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||||
|
|
||||||
|
|
||||||
def create_time_retrieval_tool(end_user_id: str):
|
def create_time_retrieval_tool(end_user_id: str):
|
||||||
"""
|
"""
|
||||||
创建一个带有特定end_user_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
Create a TimeRetrieval tool with specific end_user_id (synchronous version) for searching statements by time range
|
||||||
|
|
||||||
|
Creates a specialized time-based retrieval tool that searches for statements within
|
||||||
|
specified time ranges. Includes field cleaning functionality to remove unnecessary
|
||||||
|
metadata from search results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: User identifier for scoping search results
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
function: Configured TimeRetrievalWithGroupId tool function
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def clean_temporal_result_fields(data):
|
def clean_temporal_result_fields(data):
|
||||||
"""
|
"""
|
||||||
清理时间搜索结果中不需要的字段,并修改结构
|
Clean unnecessary fields from temporal search results and modify structure
|
||||||
|
|
||||||
|
Removes metadata fields that are not needed for end-user consumption and
|
||||||
|
restructures the response format for better usability.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: 要清理的数据
|
data: Data to be cleaned (dict, list, or other types)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
清理后的数据
|
Cleaned data with unnecessary fields removed
|
||||||
"""
|
"""
|
||||||
# 需要过滤的字段列表
|
# List of fields to filter out
|
||||||
fields_to_remove = {
|
fields_to_remove = {
|
||||||
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
||||||
'valid_at', 'invalid_at', 'statement_ids'
|
'valid_at', 'invalid_at', 'statement_ids'
|
||||||
@@ -75,9 +110,9 @@ def create_time_retrieval_tool(end_user_id: str):
|
|||||||
cleaned = {}
|
cleaned = {}
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
if key == 'statements' and isinstance(value, dict) and 'statements' in value:
|
if key == 'statements' and isinstance(value, dict) and 'statements' in value:
|
||||||
# 将 statements: {"statements": [...]} 改为 time_search: {"statements": [...]}
|
# Change statements: {"statements": [...]} to time_search: {"statements": [...]}
|
||||||
cleaned_value = clean_temporal_result_fields(value)
|
cleaned_value = clean_temporal_result_fields(value)
|
||||||
# 进一步将内部的 statements 改为 time_search
|
# Further change internal statements to time_search
|
||||||
if 'statements' in cleaned_value:
|
if 'statements' in cleaned_value:
|
||||||
cleaned['results'] = {
|
cleaned['results'] = {
|
||||||
'time_search': cleaned_value['statements']
|
'time_search': cleaned_value['statements']
|
||||||
@@ -93,24 +128,33 @@ def create_time_retrieval_tool(end_user_id: str):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
|
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None,
|
||||||
|
end_user_id_param: str = None, clean_output: bool = True) -> str:
|
||||||
"""
|
"""
|
||||||
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
Optimized time retrieval tool, combines time range search only (synchronous version), automatically filters unnecessary metadata fields
|
||||||
显式接收参数:
|
|
||||||
- context: 查询上下文内容
|
Performs time-based search operations with automatic metadata filtering. Supports
|
||||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
flexible date range specification and provides clean, user-friendly output.
|
||||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
|
||||||
- end_user_id_param: 组ID(可选,用于覆盖默认组ID)
|
Explicit parameters:
|
||||||
- clean_output: 是否清理输出中的元数据字段
|
- context: Query context content
|
||||||
-end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
- start_date: Start time (optional, format: YYYY-MM-DD)
|
||||||
|
- end_date: End time (optional, format: YYYY-MM-DD)
|
||||||
|
- end_user_id_param: Group ID (optional, overrides default group ID)
|
||||||
|
- clean_output: Whether to clean metadata fields from output
|
||||||
|
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: JSON formatted search results with temporal data
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _async_search():
|
async def _async_search():
|
||||||
# 使用传入的参数或默认值
|
# Use passed parameters or default values
|
||||||
actual_end_user_id = end_user_id_param or end_user_id
|
actual_end_user_id = end_user_id_param or end_user_id
|
||||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||||
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
# 基本时间搜索
|
# Basic time search
|
||||||
results = await search_by_temporal(
|
results = await search_by_temporal(
|
||||||
end_user_id=actual_end_user_id,
|
end_user_id=actual_end_user_id,
|
||||||
start_date=actual_start_date,
|
start_date=actual_start_date,
|
||||||
@@ -118,7 +162,7 @@ def create_time_retrieval_tool(end_user_id: str):
|
|||||||
limit=10
|
limit=10
|
||||||
)
|
)
|
||||||
|
|
||||||
# 清理结果中不需要的字段
|
# Clean unnecessary fields from results
|
||||||
if clean_output:
|
if clean_output:
|
||||||
cleaned_results = clean_temporal_result_fields(results)
|
cleaned_results = clean_temporal_result_fields(results)
|
||||||
else:
|
else:
|
||||||
@@ -129,22 +173,32 @@ def create_time_retrieval_tool(end_user_id: str):
|
|||||||
return asyncio.run(_async_search())
|
return asyncio.run(_async_search())
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str:
|
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None,
|
||||||
|
clean_output: bool = True) -> str:
|
||||||
"""
|
"""
|
||||||
优化的关键词时间检索工具,结合关键词和时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
Optimized keyword time retrieval tool, combines keyword and time range search (synchronous version), automatically filters unnecessary metadata fields
|
||||||
显式接收参数:
|
|
||||||
- context: 查询内容
|
Performs combined keyword and temporal search operations with automatic metadata
|
||||||
- days_back: 向前搜索的天数,默认7天
|
filtering. Provides more targeted search results by combining content relevance
|
||||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
with time-based filtering.
|
||||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
|
||||||
- clean_output: 是否清理输出中的元数据字段
|
Explicit parameters:
|
||||||
- end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
- context: Query content for keyword matching
|
||||||
|
- days_back: Number of days to search backwards, default 7 days
|
||||||
|
- start_date: Start time (optional, format: YYYY-MM-DD)
|
||||||
|
- end_date: End time (optional, format: YYYY-MM-DD)
|
||||||
|
- clean_output: Whether to clean metadata fields from output
|
||||||
|
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: JSON formatted search results combining keyword and temporal data
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _async_search():
|
async def _async_search():
|
||||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||||
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
|
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
|
||||||
|
|
||||||
# 关键词时间搜索
|
# Keyword time search
|
||||||
results = await search_by_keyword_temporal(
|
results = await search_by_keyword_temporal(
|
||||||
query_text=context,
|
query_text=context,
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
@@ -153,7 +207,7 @@ def create_time_retrieval_tool(end_user_id: str):
|
|||||||
limit=15
|
limit=15
|
||||||
)
|
)
|
||||||
|
|
||||||
# 清理结果中不需要的字段
|
# Clean unnecessary fields from results
|
||||||
if clean_output:
|
if clean_output:
|
||||||
cleaned_results = clean_temporal_result_fields(results)
|
cleaned_results = clean_temporal_result_fields(results)
|
||||||
else:
|
else:
|
||||||
@@ -168,43 +222,52 @@ def create_time_retrieval_tool(end_user_id: str):
|
|||||||
|
|
||||||
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||||
"""
|
"""
|
||||||
创建混合检索工具,使用run_hybrid_search进行混合检索,优化输出格式并过滤不需要的字段
|
Create hybrid retrieval tool using run_hybrid_search for hybrid retrieval, optimize output format and filter unnecessary fields
|
||||||
|
|
||||||
|
Creates an advanced hybrid search tool that combines multiple search strategies
|
||||||
|
(keyword, vector, hybrid) with automatic result cleaning and formatting.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_config: 内存配置对象
|
memory_config: Memory configuration object containing LLM and search settings
|
||||||
**search_params: 搜索参数,包含end_user_id, limit, include等
|
**search_params: Search parameters including end_user_id, limit, include, etc.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
function: Configured HybridSearch tool function with async capabilities
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def clean_result_fields(data):
|
def clean_result_fields(data):
|
||||||
"""
|
"""
|
||||||
递归清理结果中不需要的字段
|
Recursively clean unnecessary fields from results
|
||||||
|
|
||||||
|
Removes metadata fields that are not needed for end-user consumption,
|
||||||
|
improving readability and reducing response size.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: 要清理的数据(可能是字典、列表或其他类型)
|
data: Data to be cleaned (can be dict, list, or other types)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
清理后的数据
|
Cleaned data with unnecessary fields removed
|
||||||
"""
|
"""
|
||||||
# 需要过滤的字段列表
|
# List of fields to filter out
|
||||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development
|
||||||
fields_to_remove = {
|
fields_to_remove = {
|
||||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
|
'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
|
||||||
}
|
}
|
||||||
|
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
# 对字典进行清理
|
# Clean dictionary
|
||||||
cleaned = {}
|
cleaned = {}
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
if key not in fields_to_remove:
|
if key not in fields_to_remove:
|
||||||
cleaned[key] = clean_result_fields(value) # 递归清理嵌套数据
|
cleaned[key] = clean_result_fields(value) # Recursively clean nested data
|
||||||
return cleaned
|
return cleaned
|
||||||
elif isinstance(data, list):
|
elif isinstance(data, list):
|
||||||
# 对列表中的每个元素进行清理
|
# Clean each element in list
|
||||||
return [clean_result_fields(item) for item in data]
|
return [clean_result_fields(item) for item in data]
|
||||||
else:
|
else:
|
||||||
# 其他类型直接返回
|
# Return other types directly
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -216,49 +279,55 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
|||||||
rerank_alpha: float = 0.6,
|
rerank_alpha: float = 0.6,
|
||||||
use_forgetting_rerank: bool = False,
|
use_forgetting_rerank: bool = False,
|
||||||
use_llm_rerank: bool = False,
|
use_llm_rerank: bool = False,
|
||||||
clean_output: bool = True # 新增:是否清理输出字段
|
clean_output: bool = True # New: whether to clean output fields
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
优化的混合检索工具,支持关键词、向量和混合搜索,自动过滤不需要的元数据字段
|
Optimized hybrid retrieval tool, supports keyword, vector and hybrid search, automatically filters unnecessary metadata fields
|
||||||
|
|
||||||
|
Provides comprehensive search capabilities combining multiple search strategies
|
||||||
|
with intelligent result ranking and automatic metadata filtering for clean output.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
context: 查询内容
|
context: Query content for search
|
||||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
search_type: Search type ('keyword', 'embedding', 'hybrid')
|
||||||
limit: 结果数量限制
|
limit: Result quantity limit
|
||||||
end_user_id: 组ID,用于过滤搜索结果
|
end_user_id: Group ID for filtering search results
|
||||||
rerank_alpha: 重排序权重参数
|
rerank_alpha: Reranking weight parameter for result scoring
|
||||||
use_forgetting_rerank: 是否使用遗忘重排序
|
use_forgetting_rerank: Whether to use forgetting-based reranking
|
||||||
use_llm_rerank: 是否使用LLM重排序
|
use_llm_rerank: Whether to use LLM-based reranking
|
||||||
clean_output: 是否清理输出中的元数据字段
|
clean_output: Whether to clean metadata fields from output
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: JSON formatted comprehensive search results
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 导入run_hybrid_search函数
|
# Import run_hybrid_search function
|
||||||
from app.core.memory.src.search import run_hybrid_search
|
from app.core.memory.src.search import run_hybrid_search
|
||||||
|
|
||||||
# 合并参数,优先使用传入的参数
|
# Merge parameters, prioritize passed parameters
|
||||||
final_params = {
|
final_params = {
|
||||||
"query_text": context,
|
"query_text": context,
|
||||||
"search_type": search_type,
|
"search_type": search_type,
|
||||||
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
"end_user_id": end_user_id or search_params.get("end_user_id"),
|
||||||
"limit": limit or search_params.get("limit", 10),
|
"limit": limit or search_params.get("limit", 10),
|
||||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
||||||
"output_path": None, # 不保存到文件
|
"output_path": None, # Don't save to file
|
||||||
"memory_config": memory_config,
|
"memory_config": memory_config,
|
||||||
"rerank_alpha": rerank_alpha,
|
"rerank_alpha": rerank_alpha,
|
||||||
"use_forgetting_rerank": use_forgetting_rerank,
|
"use_forgetting_rerank": use_forgetting_rerank,
|
||||||
"use_llm_rerank": use_llm_rerank
|
"use_llm_rerank": use_llm_rerank
|
||||||
}
|
}
|
||||||
|
|
||||||
# 执行混合检索
|
# Execute hybrid retrieval
|
||||||
raw_results = await run_hybrid_search(**final_params)
|
raw_results = await run_hybrid_search(**final_params)
|
||||||
|
|
||||||
# 清理结果中不需要的字段
|
# Clean unnecessary fields from results
|
||||||
if clean_output:
|
if clean_output:
|
||||||
cleaned_results = clean_result_fields(raw_results)
|
cleaned_results = clean_result_fields(raw_results)
|
||||||
else:
|
else:
|
||||||
cleaned_results = raw_results
|
cleaned_results = raw_results
|
||||||
|
|
||||||
# 格式化返回结果
|
# Format return results
|
||||||
formatted_results = {
|
formatted_results = {
|
||||||
"search_query": context,
|
"search_query": context,
|
||||||
"search_type": search_type,
|
"search_type": search_type,
|
||||||
@@ -281,32 +350,46 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
|||||||
|
|
||||||
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||||
"""
|
"""
|
||||||
创建同步版本的混合检索工具,优化输出格式并过滤不需要的字段
|
Create synchronous version of hybrid retrieval tool, optimize output format and filter unnecessary fields
|
||||||
|
|
||||||
|
Creates a synchronous wrapper around the async hybrid search functionality,
|
||||||
|
making it compatible with synchronous tool execution environments.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_config: 内存配置对象
|
memory_config: Memory configuration object containing search settings
|
||||||
**search_params: 搜索参数
|
**search_params: Search parameters for configuration
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
function: Configured HybridSearchSync tool function
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def HybridSearchSync(
|
def HybridSearchSync(
|
||||||
context: str,
|
context: str,
|
||||||
search_type: str = "hybrid",
|
search_type: str = "hybrid",
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
end_user_id: str = None,
|
end_user_id: str = None,
|
||||||
clean_output: bool = True
|
clean_output: bool = True
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
优化的混合检索工具(同步版本),自动过滤不需要的元数据字段
|
Optimized hybrid retrieval tool (synchronous version), automatically filters unnecessary metadata fields
|
||||||
|
|
||||||
|
Provides the same hybrid search capabilities as the async version but in a
|
||||||
|
synchronous execution context. Automatically handles async-to-sync conversion.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
context: 查询内容
|
context: Query content for search
|
||||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
search_type: Search type ('keyword', 'embedding', 'hybrid')
|
||||||
limit: 结果数量限制
|
limit: Result quantity limit
|
||||||
end_user_id: 组ID,用于过滤搜索结果
|
end_user_id: Group ID for filtering search results
|
||||||
clean_output: 是否清理输出中的元数据字段
|
clean_output: Whether to clean metadata fields from output
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: JSON formatted search results
|
||||||
"""
|
"""
|
||||||
|
|
||||||
async def _async_search():
|
async def _async_search():
|
||||||
# 创建异步工具并执行
|
# Create async tool and execute
|
||||||
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
|
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
|
||||||
return await async_tool.ainvoke({
|
return await async_tool.ainvoke({
|
||||||
"context": context,
|
"context": context,
|
||||||
|
|||||||
@@ -1,20 +1,28 @@
|
|||||||
import json
|
import json
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, AIMessage
|
from langchain_core.messages import HumanMessage, AIMessage
|
||||||
async def format_parsing(messages: list,type:str='string'):
|
|
||||||
|
|
||||||
|
async def format_parsing(messages: list, type: str = 'string'):
|
||||||
"""
|
"""
|
||||||
格式化解析消息列表
|
Format and parse message lists into different output types
|
||||||
|
|
||||||
|
Processes message lists from storage and converts them into either string format
|
||||||
|
or dictionary format based on the specified type parameter. Handles JSON parsing
|
||||||
|
and role-based message organization.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: 消息列表
|
messages: List of message objects from storage containing message data
|
||||||
type: 返回类型 ('string' 或 'dict')
|
type: Return type specification ('string' for text format, 'dict' for key-value pairs)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
格式化后的消息列表
|
list: Formatted message list in the specified format
|
||||||
|
- 'string': List of formatted text messages with role prefixes
|
||||||
|
- 'dict': List of dictionaries mapping user messages to AI responses
|
||||||
"""
|
"""
|
||||||
result = []
|
result = []
|
||||||
user=[]
|
user = []
|
||||||
ai=[]
|
ai = []
|
||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
hstory_messages = message['messages']
|
hstory_messages = message['messages']
|
||||||
@@ -24,25 +32,38 @@ async def format_parsing(messages: list,type:str='string'):
|
|||||||
role = content['role']
|
role = content['role']
|
||||||
content = content['content']
|
content = content['content']
|
||||||
if type == "string":
|
if type == "string":
|
||||||
if role == 'human' or role=="user":
|
if role == 'human' or role == "user":
|
||||||
content = '用户:' + content
|
content = '用户:' + content
|
||||||
else:
|
else:
|
||||||
content = 'AI:' + content
|
content = 'AI:' + content
|
||||||
result.append(content)
|
result.append(content)
|
||||||
if type == "dict" :
|
if type == "dict":
|
||||||
if role == 'human' or role=="user":
|
if role == 'human' or role == "user":
|
||||||
user.append( content)
|
user.append(content)
|
||||||
else:
|
else:
|
||||||
ai.append(content)
|
ai.append(content)
|
||||||
if type == "dict":
|
if type == "dict":
|
||||||
for key,values in zip(user,ai):
|
for key, values in zip(user, ai):
|
||||||
result.append({key:values})
|
result.append({key: values})
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def messages_parse(messages: list | dict):
|
async def messages_parse(messages: list | dict):
|
||||||
user=[]
|
"""
|
||||||
ai=[]
|
Parse messages from storage format into user-AI conversation pairs
|
||||||
database=[]
|
|
||||||
|
Extracts and organizes conversation data from stored message format,
|
||||||
|
separating user and AI messages and pairing them for database storage.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: List or dictionary containing stored message data with Query fields
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: List of dictionaries containing user-AI message pairs for database storage
|
||||||
|
"""
|
||||||
|
user = []
|
||||||
|
ai = []
|
||||||
|
database = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
Query = message['Query']
|
Query = message['Query']
|
||||||
Query = json.loads(Query)
|
Query = json.loads(Query)
|
||||||
@@ -54,10 +75,23 @@ async def messages_parse(messages: list | dict):
|
|||||||
ai.append(data['content'])
|
ai.append(data['content'])
|
||||||
for key, values in zip(user, ai):
|
for key, values in zip(user, ai):
|
||||||
database.append({key, values})
|
database.append({key, values})
|
||||||
return database
|
return database
|
||||||
|
|
||||||
|
|
||||||
async def agent_chat_messages(user_content,ai_content):
|
async def agent_chat_messages(user_content, ai_content):
|
||||||
|
"""
|
||||||
|
Create structured chat message format for agent conversations
|
||||||
|
|
||||||
|
Formats user and AI content into a standardized message structure suitable
|
||||||
|
for agent processing and storage. Creates role-based message objects.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_content: User's message content string
|
||||||
|
ai_content: AI's response content string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: List of structured message dictionaries with role and content fields
|
||||||
|
"""
|
||||||
messages = [
|
messages = [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
|||||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
|
||||||
|
|
||||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||||
logger = get_agent_logger(__name__)
|
logger = get_agent_logger(__name__)
|
||||||
|
|
||||||
@@ -42,10 +41,26 @@ async def make_write_graph():
|
|||||||
|
|
||||||
yield graph
|
yield graph
|
||||||
|
|
||||||
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
|
||||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
|
||||||
|
end_user_id: str = '', scope: int = 6):
|
||||||
|
"""
|
||||||
|
Handle long-term memory storage with different strategies
|
||||||
|
|
||||||
|
Supports multiple storage strategies including chunk-based, time-based,
|
||||||
|
and aggregate judgment approaches for long-term memory persistence.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
|
||||||
|
langchain_messages: List of messages to store
|
||||||
|
memory_config: Memory configuration identifier
|
||||||
|
end_user_id: User group identifier
|
||||||
|
scope: Scope parameter for chunk-based storage (default: 6)
|
||||||
|
"""
|
||||||
|
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
|
||||||
|
aggregate_judgment
|
||||||
from app.core.memory.agent.utils.redis_tool import write_store
|
from app.core.memory.agent.utils.redis_tool import write_store
|
||||||
write_store.save_session_write(end_user_id, (langchain_messages))
|
write_store.save_session_write(end_user_id, langchain_messages)
|
||||||
# 获取数据库会话
|
# 获取数据库会话
|
||||||
with get_db_context() as db_session:
|
with get_db_context() as db_session:
|
||||||
config_service = MemoryConfigService(db_session)
|
config_service = MemoryConfigService(db_session)
|
||||||
@@ -53,26 +68,39 @@ async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[
|
|||||||
config_id=memory_config, # 改为整数
|
config_id=memory_config, # 改为整数
|
||||||
service_name="MemoryAgentService"
|
service_name="MemoryAgentService"
|
||||||
)
|
)
|
||||||
if long_term_type=='chunk':
|
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
|
||||||
'''方案一:对话窗口6轮对话'''
|
'''Strategy 1: Dialogue window with 6 rounds of conversation'''
|
||||||
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
|
||||||
if long_term_type=='time':
|
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
|
||||||
"""时间"""
|
"""Time-based strategy"""
|
||||||
await memory_long_term_storage(end_user_id, memory_config,5)
|
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
|
||||||
if long_term_type=='aggregate':
|
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||||
"""方案三:聚合判断"""
|
"""Strategy 3: Aggregate judgment"""
|
||||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||||
|
|
||||||
|
|
||||||
|
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
|
||||||
|
"""
|
||||||
|
Write long-term memory with different storage types
|
||||||
|
|
||||||
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
|
Handles both RAG-based storage and traditional memory storage approaches.
|
||||||
|
For traditional storage, uses chunk-based strategy with paired user-AI messages.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
storage_type: Type of storage (RAG or traditional)
|
||||||
|
end_user_id: User group identifier
|
||||||
|
message_chat: User message content
|
||||||
|
aimessages: AI response messages
|
||||||
|
user_rag_memory_id: RAG memory identifier
|
||||||
|
actual_config_id: Actual configuration ID
|
||||||
|
"""
|
||||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||||
else:
|
else:
|
||||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
|
||||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||||
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
||||||
|
|||||||
@@ -8,10 +8,11 @@ from langgraph.graph import add_messages
|
|||||||
|
|
||||||
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
|
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
|
||||||
|
|
||||||
|
|
||||||
class WriteState(TypedDict):
|
class WriteState(TypedDict):
|
||||||
'''
|
"""
|
||||||
Langgrapg Writing TypedDict
|
Langgrapg Writing TypedDict
|
||||||
'''
|
"""
|
||||||
messages: Annotated[list[AnyMessage], add_messages]
|
messages: Annotated[list[AnyMessage], add_messages]
|
||||||
end_user_id: str
|
end_user_id: str
|
||||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||||
@@ -20,6 +21,7 @@ class WriteState(TypedDict):
|
|||||||
data: str
|
data: str
|
||||||
language: str # 语言类型 ("zh" 中文, "en" 英文)
|
language: str # 语言类型 ("zh" 中文, "en" 英文)
|
||||||
|
|
||||||
|
|
||||||
class ReadState(TypedDict):
|
class ReadState(TypedDict):
|
||||||
"""
|
"""
|
||||||
LangGraph 工作流状态定义
|
LangGraph 工作流状态定义
|
||||||
@@ -43,18 +45,20 @@ class ReadState(TypedDict):
|
|||||||
config_id: str
|
config_id: str
|
||||||
data: str # 新增字段用于传递内容
|
data: str # 新增字段用于传递内容
|
||||||
spit_data: dict # 新增字段用于传递问题分解结果
|
spit_data: dict # 新增字段用于传递问题分解结果
|
||||||
problem_extension:dict
|
problem_extension: dict
|
||||||
storage_type: str
|
storage_type: str
|
||||||
user_rag_memory_id: str
|
user_rag_memory_id: str
|
||||||
llm_id: str
|
llm_id: str
|
||||||
embedding_id: str
|
embedding_id: str
|
||||||
memory_config: object # 新增字段用于传递内存配置对象
|
memory_config: object # 新增字段用于传递内存配置对象
|
||||||
retrieve:dict
|
retrieve: dict
|
||||||
RetrieveSummary: dict
|
RetrieveSummary: dict
|
||||||
InputSummary: dict
|
InputSummary: dict
|
||||||
verify: dict
|
verify: dict
|
||||||
SummaryFails: dict
|
SummaryFails: dict
|
||||||
summary: dict
|
summary: dict
|
||||||
|
|
||||||
|
|
||||||
class COUNTState:
|
class COUNTState:
|
||||||
"""
|
"""
|
||||||
工作流对话检索内容计数器
|
工作流对话检索内容计数器
|
||||||
@@ -99,6 +103,7 @@ class COUNTState:
|
|||||||
self.total = 0
|
self.total = 0
|
||||||
print("[COUNTState] 已重置为 0")
|
print("[COUNTState] 已重置为 0")
|
||||||
|
|
||||||
|
|
||||||
def deduplicate_entries(entries):
|
def deduplicate_entries(entries):
|
||||||
seen = set()
|
seen = set()
|
||||||
deduped = []
|
deduped = []
|
||||||
@@ -109,6 +114,7 @@ def deduplicate_entries(entries):
|
|||||||
deduped.append(entry)
|
deduped.append(entry)
|
||||||
return deduped
|
return deduped
|
||||||
|
|
||||||
|
|
||||||
def merge_to_key_value_pairs(data, query_key, result_key):
|
def merge_to_key_value_pairs(data, query_key, result_key):
|
||||||
grouped = defaultdict(list)
|
grouped = defaultdict(list)
|
||||||
for item in data:
|
for item in data:
|
||||||
|
|||||||
@@ -165,7 +165,9 @@ async def write(
|
|||||||
statement_chunk_edges=all_statement_chunk_edges,
|
statement_chunk_edges=all_statement_chunk_edges,
|
||||||
statement_entity_edges=all_statement_entity_edges,
|
statement_entity_edges=all_statement_entity_edges,
|
||||||
entity_edges=all_entity_entity_edges,
|
entity_edges=all_entity_entity_edges,
|
||||||
connector=neo4j_connector
|
connector=neo4j_connector,
|
||||||
|
config_id=config_id,
|
||||||
|
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
|
||||||
)
|
)
|
||||||
if success:
|
if success:
|
||||||
logger.info("Successfully saved all data to Neo4j")
|
logger.info("Successfully saved all data to Neo4j")
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
from app.core.memory.storage_services.clustering_engine.label_propagation import LabelPropagationEngine
|
||||||
|
|
||||||
|
__all__ = ["LabelPropagationEngine"]
|
||||||
@@ -0,0 +1,508 @@
|
|||||||
|
"""标签传播聚类引擎
|
||||||
|
|
||||||
|
基于 ZEP 论文的动态标签传播算法,对 Neo4j 中的 ExtractedEntity 节点进行社区聚类。
|
||||||
|
|
||||||
|
支持两种模式:
|
||||||
|
- 全量初始化(full_clustering):首次运行,对所有实体做完整 LPA 迭代
|
||||||
|
- 增量更新(incremental_update):新实体到达时,只处理新实体及其邻居
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from math import sqrt
|
||||||
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
|
from app.repositories.neo4j.community_repository import CommunityRepository
|
||||||
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Upper bound on full-clustering iterations; guards against non-convergence
|
||||||
|
MAX_ITERATIONS = 10
|
||||||
|
# Number of core entities kept per community summary
|
||||||
|
CORE_ENTITY_LIMIT = 5
|
||||||
|
|
||||||
|
|
||||||
|
def _cosine_similarity(v1: Optional[List[float]], v2: Optional[List[float]]) -> float:
|
||||||
|
"""计算两个向量的余弦相似度,任一为空则返回 0。"""
|
||||||
|
if not v1 or not v2 or len(v1) != len(v2):
|
||||||
|
return 0.0
|
||||||
|
dot = sum(a * b for a, b in zip(v1, v2))
|
||||||
|
norm1 = sqrt(sum(a * a for a in v1))
|
||||||
|
norm2 = sqrt(sum(b * b for b in v2))
|
||||||
|
if norm1 == 0 or norm2 == 0:
|
||||||
|
return 0.0
|
||||||
|
return dot / (norm1 * norm2)
|
||||||
|
|
||||||
|
|
||||||
|
def _weighted_vote(
|
||||||
|
neighbors: List[Dict],
|
||||||
|
self_embedding: Optional[List[float]],
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
加权多数投票,选出得票最高的社区。
|
||||||
|
|
||||||
|
权重 = 语义相似度(name_embedding 余弦)* activation_value 加成
|
||||||
|
没有 community_id 的邻居不参与投票。
|
||||||
|
"""
|
||||||
|
votes: Dict[str, float] = {}
|
||||||
|
for nb in neighbors:
|
||||||
|
cid = nb.get("community_id")
|
||||||
|
if not cid:
|
||||||
|
continue
|
||||||
|
sem = _cosine_similarity(self_embedding, nb.get("name_embedding"))
|
||||||
|
act = nb.get("activation_value") or 0.5
|
||||||
|
# 语义相似度权重 0.6,激活值权重 0.4
|
||||||
|
weight = 0.6 * sem + 0.4 * act
|
||||||
|
votes[cid] = votes.get(cid, 0.0) + weight
|
||||||
|
|
||||||
|
if not votes:
|
||||||
|
return None
|
||||||
|
return max(votes, key=votes.__getitem__)
|
||||||
|
|
||||||
|
|
||||||
|
class LabelPropagationEngine:
    """Label-propagation clustering engine.

    Groups entities into communities through the community repository.
    Two modes are offered: ``full_clustering`` (complete iteration over all
    entities of a user) and ``incremental_update`` (only the given new
    entities are processed). ``run`` chooses between them.
    """

    def __init__(
        self,
        connector: Neo4jConnector,
        config_id: Optional[str] = None,
        llm_model_id: Optional[str] = None,
        embedding_model_id: Optional[str] = None,
    ):
        """Store the connector/model ids and build the community repository.

        Args:
            connector: Connector used for direct queries and by the repository.
            config_id: Optional configuration id (stored, not read in this class).
            llm_model_id: When set, an LLM is asked for community name/summary.
            embedding_model_id: When set, a summary embedding is generated.
        """
        self.connector = connector
        self.repo = CommunityRepository(connector)
        self.config_id = config_id
        self.llm_model_id = llm_model_id
        self.embedding_model_id = embedding_model_id

    # ──────────────────────────────────────────────────────────────────────────
    # Public interface
    # ──────────────────────────────────────────────────────────────────────────

    async def run(
        self,
        end_user_id: str,
        new_entity_ids: Optional[List[str]] = None,
    ) -> None:
        """Single entry point that picks full or incremental clustering.

        - No community exists yet for the user → full initialization.
        - Otherwise → incremental update, but only if ``new_entity_ids`` is
          non-empty (nothing happens when it is empty or None).
        """
        has_communities = await self.repo.has_communities(end_user_id)
        if not has_communities:
            logger.info(f"[Clustering] 用户 {end_user_id} 首次聚类,执行全量初始化")
            await self.full_clustering(end_user_id)
        else:
            if new_entity_ids:
                logger.info(
                    f"[Clustering] 增量更新,新实体数: {len(new_entity_ids)}"
                )
                await self.incremental_update(new_entity_ids, end_user_id)

    async def full_clustering(self, end_user_id: str) -> None:
        """Run label propagation over every entity of the user.

        1. Fetch all entities; each starts as its own community.
        2. Each round, every entity adopts the winner of its neighbors' vote.
        3. Stop when no label changes or after ``MAX_ITERATIONS`` rounds.
        4. Persist the labels, merge similar communities, generate metadata.
        """
        entities = await self.repo.get_all_entities(end_user_id)
        if not entities:
            logger.info(f"[Clustering] 用户 {end_user_id} 无实体,跳过全量聚类")
            return

        # Initial state: every entity uses its own id as community label
        labels: Dict[str, str] = {e["id"]: e["id"] for e in entities}
        embeddings: Dict[str, Optional[List[float]]] = {
            e["id"]: e.get("name_embedding") for e in entities
        }

        # Preload all neighbor lists once so the iterations below do not
        # issue one query per entity per round
        logger.info(f"[Clustering] 预加载 {len(entities)} 个实体的邻居图...")
        neighbors_cache: Dict[str, List[Dict]] = await self.repo.get_all_entity_neighbors_batch(end_user_id)
        logger.info(f"[Clustering] 邻居预加载完成,覆盖实体数: {len(neighbors_cache)}")

        for iteration in range(MAX_ITERATIONS):
            changed = 0
            # NOTE: entities are visited in the order returned by the
            # repository; no shuffling is applied. Labels are updated in place,
            # so later entities in the same round see earlier updates.
            for entity in entities:
                eid = entity["id"]
                # Neighbors come from the preloaded cache, not from a query
                neighbors = neighbors_cache.get(eid, [])

                # Overlay the current in-memory label of each neighbor on top
                # of whatever community_id the cached record carried
                enriched = []
                for nb in neighbors:
                    nb_copy = dict(nb)
                    nb_copy["community_id"] = labels.get(nb["id"], nb.get("community_id"))
                    enriched.append(nb_copy)

                new_label = _weighted_vote(enriched, embeddings.get(eid))
                if new_label and new_label != labels[eid]:
                    labels[eid] = new_label
                    changed += 1

            logger.info(
                f"[Clustering] 全量迭代 {iteration + 1}/{MAX_ITERATIONS},"
                f"标签变化数: {changed}"
            )
            if changed == 0:
                logger.info("[Clustering] 标签已收敛,提前结束迭代")
                break

        # Persist the final labels
        await self._flush_labels(labels, end_user_id)
        pre_merge_count = len(set(labels.values()))
        logger.info(
            f"[Clustering] 全量迭代完成,共 {pre_merge_count} 个社区,"
            f"{len(labels)} 个实体,开始后处理合并"
        )

        # One post-processing merge pass based on embedding similarity
        all_community_ids = list(set(labels.values()))
        await self._evaluate_merge(all_community_ids, end_user_id)

        logger.info(
            f"[Clustering] 全量聚类完成,合并前 {pre_merge_count} 个社区,"
            f"{len(labels)} 个实体"
        )
        # Generate metadata for every community that survived the merge.
        # labels.values() still contains dissolved ids, so the surviving ids
        # are re-read from the entities instead of being reused from memory.
        surviving_communities = await self.repo.get_all_entities(end_user_id)
        surviving_community_ids = list({
            e.get("community_id") for e in surviving_communities
            if e.get("community_id")
        })
        logger.info(f"[Clustering] 合并后实际存活社区数: {len(surviving_community_ids)}")
        for cid in surviving_community_ids:
            await self._generate_community_metadata(cid, end_user_id)

    async def incremental_update(
        self, new_entity_ids: List[str], end_user_id: str
    ) -> None:
        """Assign each new entity to a community without re-running the full graph.

        Entities are processed one at a time, in the given order, via
        ``_process_single_entity``.
        """
        for entity_id in new_entity_ids:
            await self._process_single_entity(entity_id, end_user_id)

    # ──────────────────────────────────────────────────────────────────────────
    # Internal helpers
    # ──────────────────────────────────────────────────────────────────────────

    async def _process_single_entity(
        self, entity_id: str, end_user_id: str
    ) -> None:
        """Assign a single new entity to a community.

        - No neighbors → a new single-member community.
        - Neighbors but none with a community → a new community holding the
          entity and all its neighbors.
        - Otherwise → join the vote winner; if the neighbors span several
          communities, evaluate merging them.
        """
        neighbors = await self.repo.get_entity_neighbors(entity_id, end_user_id)

        # The entity's own embedding is fetched separately from its neighbors
        self_embedding = await self._get_entity_embedding(entity_id, end_user_id)

        if not neighbors:
            # Isolated entity: create a single-member community
            new_cid = self._new_community_id()
            await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
            await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
            logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
            return

        # Distinct communities present among the neighbors
        community_ids_in_neighbors = set(
            nb["community_id"] for nb in neighbors if nb.get("community_id")
        )

        target_cid = _weighted_vote(neighbors, self_embedding)

        if target_cid is None:
            # No neighbor has a community: group the entity and its neighbors
            new_cid = self._new_community_id()
            await self.repo.upsert_community(new_cid, end_user_id)
            await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
            for nb in neighbors:
                await self.repo.assign_entity_to_community(
                    nb["id"], new_cid, end_user_id
                )
            await self.repo.refresh_member_count(new_cid, end_user_id)
            logger.debug(
                f"[Clustering] 新实体 {entity_id} 与 {len(neighbors)} 个无社区邻居 → 新社区 {new_cid}"
            )
            await self._generate_community_metadata(new_cid, end_user_id)
        else:
            # Join the community with the most votes
            await self.repo.assign_entity_to_community(entity_id, target_cid, end_user_id)
            await self.repo.refresh_member_count(target_cid, end_user_id)
            logger.debug(f"[Clustering] 新实体 {entity_id} → 社区 {target_cid}")

            # Neighbors span several communities: evaluate a merge
            if len(community_ids_in_neighbors) > 1:
                merged = await self._evaluate_merge(
                    list(community_ids_in_neighbors), end_user_id
                )
                # The voted community may itself have been dissolved by the
                # merge; follow it to the survivor so metadata is generated
                # for a community that still has members.
                target_cid = (merged or {}).get(target_cid, target_cid)
            await self._generate_community_metadata(target_cid, end_user_id)

    @staticmethod
    def _mean_embedding(members: List[Dict]) -> Optional[List[float]]:
        """Average the non-empty ``name_embedding`` vectors of the members.

        Returns None when no member has an embedding. The dimensionality is
        taken from the first embedding found.
        """
        valid_embeddings = [
            m["name_embedding"] for m in members if m.get("name_embedding")
        ]
        if not valid_embeddings:
            return None
        dim = len(valid_embeddings[0])
        count = len(valid_embeddings)
        return [sum(e[i] for e in valid_embeddings) / count for i in range(dim)]

    async def _evaluate_merge(
        self, community_ids: List[str], end_user_id: str
    ) -> Dict[str, str]:
        """Merge communities whose mean embeddings are sufficiently similar.

        Strategy: compute the mean member embedding of every community and
        collect all pairs whose cosine similarity exceeds ``MERGE_THRESHOLD``
        (0.85). Pairs are then processed one by one; before each merge the
        similarity is re-checked against the *current* (possibly already
        merged) embeddings so that A≈B and B≈C does not transitively merge an
        unrelated A and C. The larger community is kept and the members of the
        other are moved into it.

        When more than ``BATCH_THRESHOLD`` communities are given, members are
        fetched with a single batch query instead of one query per community.

        Returns:
            Mapping of each dissolved community id to the id of the community
            that finally absorbed it (empty when nothing was merged).
        """
        MERGE_THRESHOLD = 0.85
        BATCH_THRESHOLD = 20  # above this count, use the batch member query

        community_embeddings: Dict[str, Optional[List[float]]] = {}
        community_sizes: Dict[str, int] = {}

        if len(community_ids) > BATCH_THRESHOLD:
            # Batch path: fetch members of all communities at once
            all_members = await self.repo.get_all_community_members_batch(
                community_ids, end_user_id
            )
            for cid in community_ids:
                members = all_members.get(cid, [])
                community_sizes[cid] = len(members)
                community_embeddings[cid] = self._mean_embedding(members)
        else:
            # Incremental path: query each community separately
            for cid in community_ids:
                members = await self.repo.get_community_members(cid, end_user_id)
                community_sizes[cid] = len(members)
                community_embeddings[cid] = self._mean_embedding(members)

        # Collect candidate pairs above the threshold
        to_merge: List[tuple] = []
        cids = list(community_ids)
        for i in range(len(cids)):
            for j in range(i + 1, len(cids)):
                sim = _cosine_similarity(
                    community_embeddings[cids[i]],
                    community_embeddings[cids[j]],
                )
                if sim > MERGE_THRESHOLD:
                    to_merge.append((cids[i], cids[j]))

        logger.info(f"[Clustering] 发现 {len(to_merge)} 对可合并社区")

        merged_into: Dict[str, str] = {}  # dissolved id → id it was merged into

        def get_root(x: str) -> str:
            """Follow merged_into (compressing the path) to x's current root."""
            while x in merged_into:
                merged_into[x] = merged_into.get(merged_into[x], merged_into[x])
                x = merged_into[x]
            return x

        for c1, c2 in to_merge:
            root1, root2 = get_root(c1), get_root(c2)
            if root1 == root2:
                continue

            # Re-validate with the up-to-date mean vectors of both roots
            current_sim = _cosine_similarity(
                community_embeddings.get(root1),
                community_embeddings.get(root2),
            )
            if current_sim <= MERGE_THRESHOLD:
                # Embeddings drifted after earlier merges; skip this pair
                logger.debug(
                    f"[Clustering] 跳过合并 {root1} ↔ {root2},"
                    f"当前相似度 {current_sim:.3f} ≤ {MERGE_THRESHOLD}"
                )
                continue

            # Keep the larger community (root1 wins ties)
            keep = root1 if community_sizes.get(root1, 0) >= community_sizes.get(root2, 0) else root2
            dissolve = root2 if keep == root1 else root1
            merged_into[dissolve] = keep

            members = await self.repo.get_community_members(dissolve, end_user_id)
            for m in members:
                await self.repo.assign_entity_to_community(m["id"], keep, end_user_id)

            # Recompute keep's mean vector as a size-weighted average
            keep_emb = community_embeddings.get(keep)
            dissolve_emb = community_embeddings.get(dissolve)
            keep_size = community_sizes.get(keep, 0)
            dissolve_size = community_sizes.get(dissolve, 0)
            total_size = keep_size + dissolve_size
            if keep_emb and dissolve_emb and total_size > 0:
                dim = len(keep_emb)
                community_embeddings[keep] = [
                    (keep_emb[i] * keep_size + dissolve_emb[i] * dissolve_size) / total_size
                    for i in range(dim)
                ]
            community_embeddings[dissolve] = None

            community_sizes[keep] = total_size
            community_sizes[dissolve] = 0
            await self.repo.refresh_member_count(keep, end_user_id)
            logger.info(
                f"[Clustering] 社区合并: {dissolve} → {keep},"
                f"相似度={current_sim:.3f},迁移 {len(members)} 个成员"
            )

        # Resolve chains (A→B, B→C) so every dissolved id maps to its survivor
        return {dissolved: get_root(dissolved) for dissolved in list(merged_into)}

    async def _flush_labels(
        self, labels: Dict[str, str], end_user_id: str
    ) -> None:
        """Write the in-memory entity→community labels through the repository."""
        # 1) make sure every distinct community exists
        unique_communities = set(labels.values())
        for cid in unique_communities:
            await self.repo.upsert_community(cid, end_user_id)

        # 2) assign every entity to its community
        for entity_id, community_id in labels.items():
            await self.repo.assign_entity_to_community(
                entity_id, community_id, end_user_id
            )

        # 3) refresh the member counts
        for cid in unique_communities:
            await self.repo.refresh_member_count(cid, end_user_id)

    async def _get_entity_embedding(
        self, entity_id: str, end_user_id: str
    ) -> Optional[List[float]]:
        """Fetch the ``name_embedding`` of one entity.

        Best effort: returns None when the entity is not found or the query
        fails. Failures are logged instead of being silently discarded.
        """
        try:
            result = await self.connector.execute_query(
                "MATCH (e:ExtractedEntity {id: $eid, end_user_id: $uid}) "
                "RETURN e.name_embedding AS name_embedding",
                eid=entity_id,
                uid=end_user_id,
            )
            return result[0]["name_embedding"] if result else None
        except Exception as e:
            logger.warning(
                f"[Clustering] failed to fetch name_embedding for entity {entity_id}: {e}"
            )
            return None

    async def _describe_with_llm(
        self, all_names: List[str], fallback_name: str, fallback_summary: str
    ) -> tuple:
        """Ask the configured LLM for a community name and one-line summary.

        Returns ``(name, summary)``. Whenever the call fails, or the reply
        lacks a non-empty value for a field, the corresponding fallback is
        returned for that field.
        """
        name, summary = fallback_name, fallback_summary
        try:
            # Imported lazily so an import problem only disables the LLM step
            from app.db import get_db_context
            from app.core.memory.utils.llm.llm_utils import MemoryClientFactory

            entity_list_str = "、".join(all_names)
            prompt = (
                f"以下是一组语义相关的实体:{entity_list_str}\n\n"
                f"请为这组实体所代表的主题:\n"
                f"1. 起一个简洁的中文名称(不超过10个字)\n"
                f"2. 写一句话摘要(不超过50个字)\n\n"
                f"严格按以下格式输出,不要有其他内容:\n"
                f"名称:<名称>\n摘要:<摘要>"
            )
            with get_db_context() as db:
                factory = MemoryClientFactory(db)
                llm_client = factory.get_llm_client(self.llm_model_id)
                response = await llm_client.chat([{"role": "user", "content": prompt}])
                text = response.content if hasattr(response, "content") else str(response)

            for raw_line in text.strip().splitlines():
                line = raw_line.strip()
                # Accept both the full-width and the ASCII colon; each prefix
                # is exactly three characters long
                if line.startswith(("名称:", "名称:")):
                    candidate = line[3:].strip()
                    if candidate:  # never replace the fallback with an empty value
                        name = candidate
                elif line.startswith(("摘要:", "摘要:")):
                    candidate = line[3:].strip()
                    if candidate:
                        summary = candidate
        except Exception as e:
            logger.warning(f"[Clustering] LLM 生成社区元数据失败,使用兜底值: {e}")
            return fallback_name, fallback_summary
        return name, summary

    async def _embed_summary(
        self, community_id: str, summary: str
    ) -> Optional[List[float]]:
        """Embed the summary text; returns None when embedding fails or is empty."""
        try:
            # Imported lazily so an import problem only disables the embedding step
            from app.db import get_db_context
            from app.core.memory.utils.llm.llm_utils import MemoryClientFactory

            with get_db_context() as db:
                embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
                vectors = await embedder.response([summary])
                if vectors:
                    return vectors[0]
        except Exception as e:
            logger.warning(f"[Clustering] 社区 {community_id} 生成 summary_embedding 失败: {e}")
        return None

    async def _generate_community_metadata(
        self, community_id: str, end_user_id: str
    ) -> None:
        """Generate and store a community's name, summary and core entities.

        - core_entities: names of the top ``CORE_ENTITY_LIMIT`` members by
          ``activation_value`` (no LLM involved).
        - name / summary: produced by the LLM when ``llm_model_id`` is set,
          otherwise (or on failure) built from the member names.
        - summary_embedding: produced when ``embedding_model_id`` is set.

        Nothing is done when the repository reports the community as already
        complete or when it has no members. Any error is logged, not raised.
        """
        try:
            # Skip communities whose attributes are already complete
            check_embedding = bool(self.embedding_model_id)
            if await self.repo.is_community_complete(community_id, end_user_id, check_embedding=check_embedding):
                logger.debug(f"[Clustering] 社区 {community_id} 属性已完整,跳过生成")
                return

            members = await self.repo.get_community_members(community_id, end_user_id)
            if not members:
                return

            # Core entities: top-N by activation_value, descending
            sorted_members = sorted(
                members,
                key=lambda m: m.get("activation_value") or 0,
                reverse=True,
            )
            core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
            all_names = [m["name"] for m in members if m.get("name")]

            # Fallback name/summary derived purely from member names
            name = "、".join(core_entities[:3]) if core_entities else community_id[:8]
            summary = f"包含实体:{', '.join(all_names)}"

            # With an LLM configured, try to obtain a better name and summary
            if self.llm_model_id:
                name, summary = await self._describe_with_llm(all_names, name, summary)

            summary_embedding: Optional[List[float]] = None
            if self.embedding_model_id and summary:
                summary_embedding = await self._embed_summary(community_id, summary)

            await self.repo.update_community_metadata(
                community_id=community_id,
                end_user_id=end_user_id,
                name=name,
                summary=summary,
                core_entities=core_entities,
                summary_embedding=summary_embedding,
            )
            logger.debug(f"[Clustering] 社区 {community_id} 元数据已更新: name={name}")
        except Exception as e:
            logger.error(f"[Clustering] _generate_community_metadata failed for {community_id}: {e}")

    @staticmethod
    def _new_community_id() -> str:
        """Return a fresh random UUID4 string to use as a community id."""
        return str(uuid.uuid4())
|
||||||
@@ -33,6 +33,7 @@ class DialogExtractionResponse(BaseModel):
|
|||||||
|
|
||||||
- is_related:对话与场景的相关性判定。
|
- is_related:对话与场景的相关性判定。
|
||||||
- times / ids / amounts / contacts / addresses / keywords:重要信息片段,用来在不相关对话中保留关键消息。
|
- times / ids / amounts / contacts / addresses / keywords:重要信息片段,用来在不相关对话中保留关键消息。
|
||||||
|
- preserve_keywords:情绪/兴趣/爱好/个人观点相关词,包含这些词的消息必须强制保留。
|
||||||
"""
|
"""
|
||||||
is_related: bool = Field(...)
|
is_related: bool = Field(...)
|
||||||
times: List[str] = Field(default_factory=list)
|
times: List[str] = Field(default_factory=list)
|
||||||
@@ -41,6 +42,7 @@ class DialogExtractionResponse(BaseModel):
|
|||||||
contacts: List[str] = Field(default_factory=list)
|
contacts: List[str] = Field(default_factory=list)
|
||||||
addresses: List[str] = Field(default_factory=list)
|
addresses: List[str] = Field(default_factory=list)
|
||||||
keywords: List[str] = Field(default_factory=list)
|
keywords: List[str] = Field(default_factory=list)
|
||||||
|
preserve_keywords: List[str] = Field(default_factory=list, description="情绪/兴趣/爱好/个人观点相关词,包含这些词的消息强制保留")
|
||||||
|
|
||||||
|
|
||||||
class MessageImportanceResponse(BaseModel):
|
class MessageImportanceResponse(BaseModel):
|
||||||
@@ -86,26 +88,17 @@ class SemanticPruner:
|
|||||||
self._detailed_prune_logging = True # 是否启用详细日志
|
self._detailed_prune_logging = True # 是否启用详细日志
|
||||||
self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志
|
self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志
|
||||||
|
|
||||||
# 加载场景特定配置(内置场景走专门规则,自定义场景 fallback 到通用规则)
|
# 加载统一填充词库
|
||||||
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(
|
self.scene_config: ScenePatterns = SceneConfigRegistry.get_config(self.config.pruning_scene)
|
||||||
self.config.pruning_scene,
|
|
||||||
fallback_to_generic=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 判断是否为内置专门场景
|
# 本体类型列表(用于注入提示词,所有场景均支持)
|
||||||
self._is_builtin_scene = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene)
|
|
||||||
|
|
||||||
# 自定义场景的本体类型列表(用于注入提示词)
|
|
||||||
self._ontology_classes = getattr(self.config, "ontology_classes", None) or []
|
self._ontology_classes = getattr(self.config, "ontology_classes", None) or []
|
||||||
|
|
||||||
if self._is_builtin_scene:
|
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene}")
|
||||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用内置专门配置")
|
if self._ontology_classes:
|
||||||
|
self._log(f"[剪枝-初始化] 注入本体类型: {self._ontology_classes}")
|
||||||
else:
|
else:
|
||||||
self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 为自定义场景,使用通用规则 + 本体类型提示词注入")
|
self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
|
||||||
if self._ontology_classes:
|
|
||||||
self._log(f"[剪枝-初始化] 注入本体类型: {self._ontology_classes}")
|
|
||||||
else:
|
|
||||||
self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词")
|
|
||||||
|
|
||||||
# Load Jinja2 template
|
# Load Jinja2 template
|
||||||
self.template = prompt_env.get_template("extracat_Pruning.jinja2")
|
self.template = prompt_env.get_template("extracat_Pruning.jinja2")
|
||||||
@@ -117,98 +110,18 @@ class SemanticPruner:
|
|||||||
# 运行日志:收集关键终端输出,便于写入 JSON
|
# 运行日志:收集关键终端输出,便于写入 JSON
|
||||||
self.run_logs: List[str] = []
|
self.run_logs: List[str] = []
|
||||||
|
|
||||||
def _is_important_message(self, message: ConversationMessage) -> bool:
|
# _is_important_message 和 _importance_score 已移除:
|
||||||
"""基于启发式规则识别重要信息消息,优先保留。
|
# 重要性判断完全由 extracat_Pruning.jinja2 提示词 + LLM 的 preserve_tokens 机制承担。
|
||||||
|
# LLM 根据注入的本体工程类型语义识别需要保护的内容,无需硬编码正则规则。
|
||||||
改进版:使用场景特定的模式进行识别
|
|
||||||
- 根据 pruning_scene 动态加载对应的识别规则
|
|
||||||
- 支持教育、在线服务、外呼三个场景的特定模式
|
|
||||||
"""
|
|
||||||
text = message.msg.strip()
|
|
||||||
if not text:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 使用场景特定的模式
|
|
||||||
all_patterns = (
|
|
||||||
self.scene_config.high_priority_patterns +
|
|
||||||
self.scene_config.medium_priority_patterns +
|
|
||||||
self.scene_config.low_priority_patterns
|
|
||||||
)
|
|
||||||
|
|
||||||
for pattern, _ in all_patterns:
|
|
||||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
|
||||||
return True
|
|
||||||
|
|
||||||
# 检查是否为问句(以问号结尾或包含疑问词)
|
|
||||||
if text.endswith("?") or text.endswith("?"):
|
|
||||||
return True
|
|
||||||
|
|
||||||
# 检查是否包含问句关键词
|
|
||||||
if any(keyword in text for keyword in self.scene_config.question_keywords):
|
|
||||||
return True
|
|
||||||
|
|
||||||
# 检查是否包含决策性关键词
|
|
||||||
if any(keyword in text for keyword in self.scene_config.decision_keywords):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _importance_score(self, message: ConversationMessage) -> int:
|
|
||||||
"""为重要消息打分,用于在保留比例内优先保留更关键的内容。
|
|
||||||
|
|
||||||
改进版:使用场景特定的权重体系(0-10分)
|
|
||||||
- 根据场景动态调整不同信息类型的权重
|
|
||||||
- 高优先级模式:4-6分
|
|
||||||
- 中优先级模式:2-3分
|
|
||||||
- 低优先级模式:1分
|
|
||||||
"""
|
|
||||||
text = message.msg.strip()
|
|
||||||
score = 0
|
|
||||||
|
|
||||||
# 使用场景特定的权重
|
|
||||||
for pattern, weight in self.scene_config.high_priority_patterns:
|
|
||||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
|
||||||
score += weight
|
|
||||||
|
|
||||||
for pattern, weight in self.scene_config.medium_priority_patterns:
|
|
||||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
|
||||||
score += weight
|
|
||||||
|
|
||||||
for pattern, weight in self.scene_config.low_priority_patterns:
|
|
||||||
if re.search(pattern, text, flags=re.IGNORECASE):
|
|
||||||
score += weight
|
|
||||||
|
|
||||||
# 问句加分
|
|
||||||
if text.endswith("?") or text.endswith("?"):
|
|
||||||
score += 2
|
|
||||||
|
|
||||||
# 包含问句关键词加分
|
|
||||||
if any(keyword in text for keyword in self.scene_config.question_keywords):
|
|
||||||
score += 1
|
|
||||||
|
|
||||||
# 包含决策性关键词加分
|
|
||||||
if any(keyword in text for keyword in self.scene_config.decision_keywords):
|
|
||||||
score += 2
|
|
||||||
|
|
||||||
# 长度加分(较长的消息通常包含更多信息)
|
|
||||||
if len(text) > 50:
|
|
||||||
score += 1
|
|
||||||
if len(text) > 100:
|
|
||||||
score += 1
|
|
||||||
|
|
||||||
return min(score, 10) # 最高10分
|
|
||||||
|
|
||||||
def _is_filler_message(self, message: ConversationMessage) -> bool:
|
def _is_filler_message(self, message: ConversationMessage) -> bool:
|
||||||
"""检测典型寒暄/口头禅/确认类短消息。
|
"""检测典型寒暄/口头禅/确认类短消息。
|
||||||
|
|
||||||
改进版:更严格的填充消息判断,避免误删场景相关内容
|
判断顺序:
|
||||||
满足以下之一视为填充消息:
|
1. 空消息
|
||||||
- 纯标点或空白
|
2. 场景特定填充词库精确匹配
|
||||||
- 在场景特定填充词库中(精确匹配)
|
3. 常见寒暄精确匹配
|
||||||
- 纯表情符号
|
4. 纯表情/标点
|
||||||
- 常见寒暄(精确匹配短语)
|
|
||||||
|
|
||||||
注意:不再使用长度判断,避免误删短但重要的消息
|
|
||||||
"""
|
"""
|
||||||
t = message.msg.strip()
|
t = message.msg.strip()
|
||||||
if not t:
|
if not t:
|
||||||
@@ -234,20 +147,6 @@ class SemanticPruner:
|
|||||||
if re.fullmatch(r"(\[[^\]]+\])+", t):
|
if re.fullmatch(r"(\[[^\]]+\])+", t):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# 检查是否为纯emoji(Unicode表情)
|
|
||||||
emoji_pattern = re.compile(
|
|
||||||
"["
|
|
||||||
"\U0001F600-\U0001F64F" # 表情符号
|
|
||||||
"\U0001F300-\U0001F5FF" # 符号和象形文字
|
|
||||||
"\U0001F680-\U0001F6FF" # 交通和地图符号
|
|
||||||
"\U0001F1E0-\U0001F1FF" # 旗帜
|
|
||||||
"\U00002702-\U000027B0"
|
|
||||||
"\U000024C2-\U0001F251"
|
|
||||||
"]+", flags=re.UNICODE
|
|
||||||
)
|
|
||||||
if emoji_pattern.fullmatch(t):
|
|
||||||
return True
|
|
||||||
|
|
||||||
# 纯标点符号
|
# 纯标点符号
|
||||||
if re.fullmatch(r"[。!?,.!?…·\s]+", t):
|
if re.fullmatch(r"[。!?,.!?…·\s]+", t):
|
||||||
return True
|
return True
|
||||||
@@ -432,14 +331,12 @@ class SemanticPruner:
|
|||||||
|
|
||||||
rendered = self.template.render(
|
rendered = self.template.render(
|
||||||
pruning_scene=self.config.pruning_scene,
|
pruning_scene=self.config.pruning_scene,
|
||||||
is_builtin_scene=self._is_builtin_scene,
|
|
||||||
ontology_classes=self._ontology_classes,
|
ontology_classes=self._ontology_classes,
|
||||||
dialog_text=dialog_text,
|
dialog_text=dialog_text,
|
||||||
language=self.language
|
language=self.language
|
||||||
)
|
)
|
||||||
log_template_rendering("extracat_Pruning.jinja2", {
|
log_template_rendering("extracat_Pruning.jinja2", {
|
||||||
"pruning_scene": self.config.pruning_scene,
|
"pruning_scene": self.config.pruning_scene,
|
||||||
"is_builtin_scene": self._is_builtin_scene,
|
|
||||||
"ontology_classes_count": len(self._ontology_classes),
|
"ontology_classes_count": len(self._ontology_classes),
|
||||||
"language": self.language
|
"language": self.language
|
||||||
})
|
})
|
||||||
@@ -504,62 +401,56 @@ class SemanticPruner:
|
|||||||
# 相关对话不剪枝
|
# 相关对话不剪枝
|
||||||
return dialog
|
return dialog
|
||||||
|
|
||||||
# 在不相关对话中,识别重要/不重要消息
|
# 在不相关对话中,LLM 已通过 preserve_tokens 标记需要保护的内容
|
||||||
tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords
|
preserve_tokens = (
|
||||||
|
extraction.times + extraction.ids + extraction.amounts +
|
||||||
|
extraction.contacts + extraction.addresses + extraction.keywords +
|
||||||
|
extraction.preserve_keywords
|
||||||
|
)
|
||||||
msgs = dialog.context.msgs
|
msgs = dialog.context.msgs
|
||||||
imp_unrel_msgs: List[ConversationMessage] = []
|
|
||||||
unimp_unrel_msgs: List[ConversationMessage] = []
|
# 分类:填充 / 其他可删(LLM保护消息通过不加入任何桶来隐式保护)
|
||||||
|
filler_ids: set = set()
|
||||||
|
deletable: List[ConversationMessage] = []
|
||||||
|
|
||||||
for m in msgs:
|
for m in msgs:
|
||||||
if self._msg_matches_tokens(m, tokens) or self._is_important_message(m):
|
if self._msg_matches_tokens(m, preserve_tokens):
|
||||||
imp_unrel_msgs.append(m)
|
pass # 保护消息:不加入任何桶,不会被删除
|
||||||
|
elif self._is_filler_message(m):
|
||||||
|
filler_ids.add(id(m))
|
||||||
else:
|
else:
|
||||||
unimp_unrel_msgs.append(m)
|
deletable.append(m)
|
||||||
# 计算总删除目标数量
|
|
||||||
|
# 计算删除目标
|
||||||
total_unrel = len(msgs)
|
total_unrel = len(msgs)
|
||||||
delete_target = int(total_unrel * proportion)
|
delete_target = int(total_unrel * proportion)
|
||||||
if proportion > 0 and total_unrel > 0 and delete_target == 0:
|
if proportion > 0 and total_unrel > 0 and delete_target == 0:
|
||||||
delete_target = 1
|
delete_target = 1
|
||||||
imp_del_cap = min(int(len(imp_unrel_msgs) * proportion), len(imp_unrel_msgs))
|
max_deletable = min(len(filler_ids) + len(deletable), max(0, total_unrel - 1))
|
||||||
unimp_del_cap = len(unimp_unrel_msgs)
|
|
||||||
max_capacity = max(0, len(msgs) - 1)
|
|
||||||
max_deletable = min(imp_del_cap + unimp_del_cap, max_capacity)
|
|
||||||
delete_target = min(delete_target, max_deletable)
|
delete_target = min(delete_target, max_deletable)
|
||||||
# 删除配额分配
|
|
||||||
del_unimp = min(delete_target, unimp_del_cap)
|
|
||||||
rem = delete_target - del_unimp
|
|
||||||
del_imp = min(rem, imp_del_cap)
|
|
||||||
|
|
||||||
# 选取删除集合
|
# 优先删填充,再删其他可删消息(按出现顺序)
|
||||||
unimp_delete_ids = []
|
to_delete_ids: set = set()
|
||||||
imp_delete_ids = []
|
|
||||||
if del_unimp > 0:
|
|
||||||
# 按出现顺序选取前 del_unimp 条不重要消息进行删除(确定性、可复现)
|
|
||||||
unimp_delete_ids = [id(m) for m in unimp_unrel_msgs[:del_unimp]]
|
|
||||||
if del_imp > 0:
|
|
||||||
imp_sorted = sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))
|
|
||||||
imp_delete_ids = [id(m) for m in imp_sorted[:del_imp]]
|
|
||||||
|
|
||||||
# 统计实际删除数量(重要/不重要)
|
|
||||||
actual_unimp_deleted = 0
|
|
||||||
actual_imp_deleted = 0
|
|
||||||
kept_msgs = []
|
|
||||||
delete_targets = set(unimp_delete_ids) | set(imp_delete_ids)
|
|
||||||
for m in msgs:
|
for m in msgs:
|
||||||
mid = id(m)
|
if len(to_delete_ids) >= delete_target:
|
||||||
if mid in delete_targets:
|
break
|
||||||
if mid in set(unimp_delete_ids) and actual_unimp_deleted < del_unimp:
|
if id(m) in filler_ids:
|
||||||
actual_unimp_deleted += 1
|
to_delete_ids.add(id(m))
|
||||||
continue
|
for m in deletable:
|
||||||
if mid in set(imp_delete_ids) and actual_imp_deleted < del_imp:
|
if len(to_delete_ids) >= delete_target:
|
||||||
actual_imp_deleted += 1
|
break
|
||||||
continue
|
to_delete_ids.add(id(m))
|
||||||
kept_msgs.append(m)
|
|
||||||
|
kept_msgs = [m for m in msgs if id(m) not in to_delete_ids]
|
||||||
if not kept_msgs and msgs:
|
if not kept_msgs and msgs:
|
||||||
kept_msgs = [msgs[0]]
|
kept_msgs = [msgs[0]]
|
||||||
|
|
||||||
deleted_total = actual_unimp_deleted + actual_imp_deleted
|
deleted_total = len(msgs) - len(kept_msgs)
|
||||||
|
protected_count = len(msgs) - len(filler_ids) - len(deletable)
|
||||||
self._log(
|
self._log(
|
||||||
f"[剪枝-对话] 对话ID={dialog.id} 总消息={len(msgs)} 删除目标={delete_target} 实删={deleted_total} 保留={len(kept_msgs)}"
|
f"[剪枝-对话] 对话ID={dialog.id} 总消息={len(msgs)} "
|
||||||
|
f"(保护={protected_count} 填充={len(filler_ids)} 可删={len(deletable)}) "
|
||||||
|
f"删除目标={delete_target} 实删={deleted_total} 保留={len(kept_msgs)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
dialog.context = ConversationContext(msgs=kept_msgs)
|
dialog.context = ConversationContext(msgs=kept_msgs)
|
||||||
@@ -595,50 +486,63 @@ class SemanticPruner:
|
|||||||
total_original_msgs = 0
|
total_original_msgs = 0
|
||||||
total_deleted_msgs = 0
|
total_deleted_msgs = 0
|
||||||
|
|
||||||
for d_idx, dd in enumerate(dialogs):
|
# 并发执行所有对话的 LLM 抽取(获取 preserve_keywords 等保护信息)
|
||||||
|
semaphore = asyncio.Semaphore(self.max_concurrent)
|
||||||
|
|
||||||
|
async def extract_with_semaphore(dd: DialogData) -> DialogExtractionResponse:
|
||||||
|
async with semaphore:
|
||||||
|
try:
|
||||||
|
return await self._extract_dialog_important(dd.content)
|
||||||
|
except Exception as e:
|
||||||
|
self._log(f"[剪枝-LLM] 对话抽取失败,使用降级策略: {str(e)[:100]}")
|
||||||
|
return DialogExtractionResponse(is_related=True)
|
||||||
|
|
||||||
|
extraction_tasks = [extract_with_semaphore(dd) for dd in dialogs]
|
||||||
|
extraction_results: List[DialogExtractionResponse] = await asyncio.gather(*extraction_tasks)
|
||||||
|
|
||||||
|
for d_idx, (dd, extraction) in enumerate(zip(dialogs, extraction_results)):
|
||||||
msgs = dd.context.msgs
|
msgs = dd.context.msgs
|
||||||
original_count = len(msgs)
|
original_count = len(msgs)
|
||||||
total_original_msgs += original_count
|
total_original_msgs += original_count
|
||||||
|
|
||||||
# ========== 问答对保护(已注释,暂不启用,留作观察) ==========
|
# 从 LLM 抽取结果中获取所有需要保留的 token
|
||||||
# qa_pairs = self._identify_qa_pairs(msgs)
|
preserve_tokens = (
|
||||||
# protected_indices = self._get_protected_indices(msgs, qa_pairs, window_size=0)
|
extraction.times + extraction.ids + extraction.amounts +
|
||||||
# ========================================================
|
extraction.contacts + extraction.addresses + extraction.keywords +
|
||||||
|
extraction.preserve_keywords # 情绪/兴趣/爱好关键词
|
||||||
|
)
|
||||||
|
|
||||||
# 消息级分类:每条消息独立判断
|
# 判断是否需要详细日志
|
||||||
important_msgs = [] # 重要消息(保留)
|
|
||||||
unimportant_msgs = [] # 不重要消息(可删除)
|
|
||||||
filler_msgs = [] # 填充消息(优先删除)
|
|
||||||
|
|
||||||
# 判断是否需要详细日志(仅对前N条消息记录)
|
|
||||||
should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog
|
should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog
|
||||||
if self._detailed_prune_logging and original_count > self._max_debug_msgs_per_dialog:
|
if self._detailed_prune_logging and original_count > self._max_debug_msgs_per_dialog:
|
||||||
self._log(f" 对话[{d_idx}]消息数={original_count},仅采样前{self._max_debug_msgs_per_dialog}条进行详细日志")
|
self._log(f" 对话[{d_idx}]消息数={original_count},仅采样前{self._max_debug_msgs_per_dialog}条进行详细日志")
|
||||||
|
|
||||||
|
if extraction.preserve_keywords:
|
||||||
|
self._log(f" 对话[{d_idx}] LLM抽取到情绪/兴趣保护词: {extraction.preserve_keywords}")
|
||||||
|
|
||||||
|
# 消息级分类:LLM保护 / 填充 / 其他可删
|
||||||
|
llm_protected_msgs = [] # LLM 保护消息(preserve_tokens 命中):绝对不可删除
|
||||||
|
filler_msgs = [] # 填充消息(优先删除)
|
||||||
|
deletable_msgs = [] # 其余消息(按比例删除)
|
||||||
|
|
||||||
for idx, m in enumerate(msgs):
|
for idx, m in enumerate(msgs):
|
||||||
msg_text = m.msg.strip()
|
msg_text = m.msg.strip()
|
||||||
|
|
||||||
# ========== 问答对保护判断(已注释) ==========
|
if self._msg_matches_tokens(m, preserve_tokens):
|
||||||
# if idx in protected_indices:
|
llm_protected_msgs.append((idx, m))
|
||||||
# important_msgs.append((idx, m))
|
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||||
# self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(问答对保护)")
|
self._log(f" [{idx}] '{msg_text[:30]}...' → 保护(LLM,不可删)")
|
||||||
# ==========================================
|
elif self._is_filler_message(m):
|
||||||
|
|
||||||
# 填充消息(寒暄、表情等)
|
|
||||||
if self._is_filler_message(m):
|
|
||||||
filler_msgs.append((idx, m))
|
filler_msgs.append((idx, m))
|
||||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 填充")
|
self._log(f" [{idx}] '{msg_text[:30]}...' → 填充")
|
||||||
# 重要信息(学号、成绩、时间、金额等)
|
|
||||||
elif self._is_important_message(m):
|
|
||||||
important_msgs.append((idx, m))
|
|
||||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
|
||||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(场景规则)")
|
|
||||||
# 其他消息
|
|
||||||
else:
|
else:
|
||||||
unimportant_msgs.append((idx, m))
|
deletable_msgs.append((idx, m))
|
||||||
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
if should_log_details or idx < self._max_debug_msgs_per_dialog:
|
||||||
self._log(f" [{idx}] '{msg_text[:30]}...' → 不重要")
|
self._log(f" [{idx}] '{msg_text[:30]}...' → 可删")
|
||||||
|
|
||||||
|
# important_msgs 仅用于日志统计
|
||||||
|
important_msgs = llm_protected_msgs
|
||||||
|
|
||||||
# 计算删除配额
|
# 计算删除配额
|
||||||
delete_target = int(original_count * proportion)
|
delete_target = int(original_count * proportion)
|
||||||
@@ -649,37 +553,23 @@ class SemanticPruner:
|
|||||||
max_deletable = max(0, original_count - 1)
|
max_deletable = max(0, original_count - 1)
|
||||||
delete_target = min(delete_target, max_deletable)
|
delete_target = min(delete_target, max_deletable)
|
||||||
|
|
||||||
# 删除策略:优先删除填充消息,再删除不重要消息
|
# 删除策略:优先删填充消息,再按出现顺序删其余可删消息
|
||||||
to_delete_indices = set()
|
to_delete_indices = set()
|
||||||
deleted_details = [] # 记录删除的消息详情
|
deleted_details = []
|
||||||
|
|
||||||
# 第一步:删除填充消息
|
# 第一步:删除填充消息
|
||||||
filler_to_delete = min(len(filler_msgs), delete_target)
|
for idx, msg in filler_msgs:
|
||||||
for i in range(filler_to_delete):
|
if len(to_delete_indices) >= delete_target:
|
||||||
idx, msg = filler_msgs[i]
|
break
|
||||||
to_delete_indices.add(idx)
|
to_delete_indices.add(idx)
|
||||||
deleted_details.append(f"[{idx}] 填充: '{msg.msg[:50]}'")
|
deleted_details.append(f"[{idx}] 填充: '{msg.msg[:50]}'")
|
||||||
|
|
||||||
# 第二步:如果还需要删除,删除不重要消息
|
# 第二步:如果还需要删除,按出现顺序删可删消息
|
||||||
remaining_quota = delete_target - len(to_delete_indices)
|
for idx, msg in deletable_msgs:
|
||||||
if remaining_quota > 0:
|
if len(to_delete_indices) >= delete_target:
|
||||||
unimp_to_delete = min(len(unimportant_msgs), remaining_quota)
|
break
|
||||||
for i in range(unimp_to_delete):
|
to_delete_indices.add(idx)
|
||||||
idx, msg = unimportant_msgs[i]
|
deleted_details.append(f"[{idx}] 可删: '{msg.msg[:50]}'")
|
||||||
to_delete_indices.add(idx)
|
|
||||||
deleted_details.append(f"[{idx}] 不重要: '{msg.msg[:50]}'")
|
|
||||||
|
|
||||||
# 第三步:如果还需要删除,按重要性分数删除重要消息
|
|
||||||
remaining_quota = delete_target - len(to_delete_indices)
|
|
||||||
if remaining_quota > 0 and important_msgs:
|
|
||||||
# 按重要性分数排序(分数低的优先删除)
|
|
||||||
imp_sorted = sorted(important_msgs, key=lambda x: self._importance_score(x[1]))
|
|
||||||
imp_to_delete = min(len(imp_sorted), remaining_quota)
|
|
||||||
for i in range(imp_to_delete):
|
|
||||||
idx, msg = imp_sorted[i]
|
|
||||||
to_delete_indices.add(idx)
|
|
||||||
score = self._importance_score(msg)
|
|
||||||
deleted_details.append(f"[{idx}] 重要(分数{score}): '{msg.msg[:50]}'")
|
|
||||||
|
|
||||||
# 执行删除
|
# 执行删除
|
||||||
kept_msgs = []
|
kept_msgs = []
|
||||||
@@ -707,7 +597,7 @@ class SemanticPruner:
|
|||||||
|
|
||||||
self._log(
|
self._log(
|
||||||
f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} "
|
f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} "
|
||||||
f"(重要={len(important_msgs)} 不重要={len(unimportant_msgs)} 填充={len(filler_msgs)}) "
|
f"(保护={len(important_msgs)} 填充={len(filler_msgs)} 可删={len(deletable_msgs)}) "
|
||||||
f"删除={deleted_count} 保留={len(kept_msgs)}"
|
f"删除={deleted_count} 保留={len(kept_msgs)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,66 +1,25 @@
|
|||||||
"""
|
"""
|
||||||
场景特定配置 - 为不同场景提供定制化的剪枝规则
|
场景特定配置 - 统一填充词库
|
||||||
|
|
||||||
功能:
|
重要性判断已完全交由 extracat_Pruning.jinja2 提示词 + LLM preserve_tokens 机制承担。
|
||||||
- 场景特定的重要信息识别模式
|
本模块仅保留统一填充词库(filler_phrases),用于识别无意义寒暄/表情/口头禅。
|
||||||
- 场景特定的重要性评分权重
|
所有场景共用同一份词库,场景差异由 LLM 语义判断处理。
|
||||||
- 场景特定的填充词库
|
|
||||||
- 场景特定的问答对识别规则
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Dict, List, Set, Tuple
|
from typing import List, Set
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ScenePatterns:
|
class ScenePatterns:
|
||||||
"""场景特定的识别模式"""
|
"""场景特定的识别模式(仅保留填充词库)"""
|
||||||
|
|
||||||
# 重要信息的正则模式(优先级从高到低)
|
|
||||||
high_priority_patterns: List[Tuple[str, int]] = field(default_factory=list) # (pattern, weight)
|
|
||||||
medium_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
|
|
||||||
low_priority_patterns: List[Tuple[str, int]] = field(default_factory=list)
|
|
||||||
|
|
||||||
# 填充词库(无意义对话)
|
|
||||||
filler_phrases: Set[str] = field(default_factory=set)
|
filler_phrases: Set[str] = field(default_factory=set)
|
||||||
|
|
||||||
# 问句关键词(用于识别问答对)
|
|
||||||
question_keywords: Set[str] = field(default_factory=set)
|
|
||||||
|
|
||||||
# 决策性/承诺性关键词
|
|
||||||
decision_keywords: Set[str] = field(default_factory=set)
|
|
||||||
|
|
||||||
|
|
||||||
class SceneConfigRegistry:
|
class SceneConfigRegistry:
|
||||||
"""场景配置注册表 - 管理所有场景的特定配置"""
|
"""场景配置注册表 - 所有场景共用统一填充词库"""
|
||||||
|
|
||||||
# 基础通用模式(所有场景共享)
|
BASE_FILLERS: Set[str] = {
|
||||||
BASE_HIGH_PRIORITY = [
|
|
||||||
(r"订单号|工单|申请号|编号|ID|账号|账户", 5),
|
|
||||||
(r"金额|费用|价格|¥|¥|\d+元", 5),
|
|
||||||
(r"\d{11}", 4), # 手机号
|
|
||||||
(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", 4), # 邮箱
|
|
||||||
]
|
|
||||||
|
|
||||||
BASE_MEDIUM_PRIORITY = [
|
|
||||||
(r"\d{4}-\d{1,2}-\d{1,2}", 3), # 日期
|
|
||||||
(r"\d{4}年\d{1,2}月\d{1,2}日", 3),
|
|
||||||
(r"电话|手机号|微信|QQ|联系方式", 3),
|
|
||||||
(r"地址|地点|位置", 2),
|
|
||||||
(r"时间|日期|有效期|截止", 2),
|
|
||||||
(r"今天|明天|后天|昨天|前天", 3), # 相对时间(提高权重)
|
|
||||||
(r"下周|下月|下年|上周|上月|上年|本周|本月|本年", 3),
|
|
||||||
(r"今年|去年|明年", 3),
|
|
||||||
]
|
|
||||||
|
|
||||||
BASE_LOW_PRIORITY = [
|
|
||||||
(r"\d{1,2}:\d{2}", 2), # 时间点 HH:MM
|
|
||||||
(r"\d{1,2}点\d{0,2}分?", 2), # 时间点 X点Y分 或 X点
|
|
||||||
(r"上午|下午|中午|晚上|早上|傍晚|凌晨", 2), # 时段(提高权重并扩充)
|
|
||||||
(r"AM|PM|am|pm", 1),
|
|
||||||
]
|
|
||||||
|
|
||||||
BASE_FILLERS = {
|
|
||||||
# 基础寒暄
|
# 基础寒暄
|
||||||
"你好", "您好", "在吗", "在的", "在呢", "嗯", "嗯嗯", "哦", "哦哦",
|
"你好", "您好", "在吗", "在的", "在呢", "嗯", "嗯嗯", "哦", "哦哦",
|
||||||
"好的", "好", "行", "可以", "不可以", "谢谢", "多谢", "感谢",
|
"好的", "好", "行", "可以", "不可以", "谢谢", "多谢", "感谢",
|
||||||
@@ -69,7 +28,26 @@ class SceneConfigRegistry:
|
|||||||
"哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia",
|
"哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia",
|
||||||
"额", "呃", "啊", "诶", "唉", "哎", "嗯哼",
|
"额", "呃", "啊", "诶", "唉", "哎", "嗯哼",
|
||||||
# 确认词
|
# 确认词
|
||||||
"是的", "对", "对的", "没错", "嗯嗯", "好嘞", "收到", "明白", "了解", "知道了",
|
"是的", "对", "对的", "没错", "好嘞", "收到", "明白", "了解", "知道了",
|
||||||
|
# 服务类套话
|
||||||
|
"请问", "请稍等", "稍等", "马上", "立即",
|
||||||
|
"正在查询", "正在处理", "正在为您", "帮您查一下",
|
||||||
|
"还有其他问题吗", "还需要什么帮助", "很高兴为您服务",
|
||||||
|
"感谢您的耐心等待", "抱歉让您久等了",
|
||||||
|
"已记录", "已反馈", "已转接", "已升级",
|
||||||
|
"祝您生活愉快", "欢迎下次咨询",
|
||||||
|
# 外呼套话
|
||||||
|
"喂", "hello", "打扰了", "不好意思",
|
||||||
|
"方便接电话吗", "现在方便吗", "占用您一点时间",
|
||||||
|
"我是", "我们是", "我们公司", "我们这边",
|
||||||
|
"了解一下", "介绍一下", "简单说一下",
|
||||||
|
"考虑考虑", "想一想", "再说", "再看看",
|
||||||
|
"不需要", "不感兴趣", "没兴趣", "不用了",
|
||||||
|
"没问题", "那就这样", "再联系", "回头聊", "有需要再说",
|
||||||
|
# 教育场景套话
|
||||||
|
"老师好", "同学们好", "上课", "下课", "起立", "坐下",
|
||||||
|
"举手", "请坐", "很好", "不错", "继续",
|
||||||
|
"下一个", "下一题", "下一位", "还有吗", "还有问题吗",
|
||||||
# 标点和符号
|
# 标点和符号
|
||||||
"。。。", "...", "???", "???", "!!!", "!!!",
|
"。。。", "...", "???", "???", "!!!", "!!!",
|
||||||
# 表情符号
|
# 表情符号
|
||||||
@@ -82,245 +60,7 @@ class SceneConfigRegistry:
|
|||||||
"emmm", "emm", "em", "mmp", "wtf", "omg",
|
"emmm", "emm", "em", "mmp", "wtf", "omg",
|
||||||
}
|
}
|
||||||
|
|
||||||
BASE_QUESTION_KEYWORDS = {
|
|
||||||
"什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时", "吗"
|
|
||||||
}
|
|
||||||
|
|
||||||
BASE_DECISION_KEYWORDS = {
|
|
||||||
"必须", "一定", "务必", "需要", "要求", "规定", "应该",
|
|
||||||
"承诺", "保证", "确保", "负责", "同意", "答应"
|
|
||||||
}
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_education_config(cls) -> ScenePatterns:
|
def get_config(cls, scene: str = "") -> ScenePatterns:
|
||||||
"""教育场景配置"""
|
"""所有场景统一返回同一份填充词库"""
|
||||||
return ScenePatterns(
|
return ScenePatterns(filler_phrases=cls.BASE_FILLERS)
|
||||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
|
||||||
# 成绩相关(最高优先级)
|
|
||||||
(r"成绩|分数|得分|满分|及格|不及格", 6),
|
|
||||||
(r"GPA|绩点|学分|平均分", 6),
|
|
||||||
(r"\d+分|\d+\.?\d*分", 5), # 具体分数
|
|
||||||
(r"排名|名次|第.{1,3}名", 5), # 支持"第三名"、"第1名"等
|
|
||||||
|
|
||||||
# 学籍信息
|
|
||||||
(r"学号|学生证|教师工号|工号", 5),
|
|
||||||
(r"班级|年级|专业|院系", 4),
|
|
||||||
|
|
||||||
# 课程相关
|
|
||||||
(r"课程|科目|学科|必修|选修", 4),
|
|
||||||
(r"教材|课本|教科书|参考书", 4),
|
|
||||||
(r"章节|第.{1,3}章|第.{1,3}节", 3), # 支持"第三章"、"第1章"等
|
|
||||||
|
|
||||||
# 学科内容(新增)
|
|
||||||
(r"微积分|导数|积分|函数|极限|微分", 4),
|
|
||||||
(r"代数|几何|三角|概率|统计", 4),
|
|
||||||
(r"物理|化学|生物|历史|地理", 4),
|
|
||||||
(r"英语|语文|数学|政治|哲学", 4),
|
|
||||||
(r"定义|定理|公式|概念|原理|法则", 3),
|
|
||||||
(r"例题|解题|证明|推导|计算", 3),
|
|
||||||
],
|
|
||||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
|
||||||
# 教学活动
|
|
||||||
(r"作业|练习|习题|题目", 3),
|
|
||||||
(r"考试|测验|测试|考核|期中|期末", 3),
|
|
||||||
(r"上课|下课|课堂|讲课", 2),
|
|
||||||
(r"提问|回答|发言|讨论", 2),
|
|
||||||
(r"问一下|请教|咨询|询问", 2), # 新增:问询相关
|
|
||||||
(r"理解|明白|懂|掌握|学会", 2), # 新增:学习状态
|
|
||||||
|
|
||||||
# 时间安排
|
|
||||||
(r"课表|课程表|时间表", 3),
|
|
||||||
(r"第.{1,3}节课|第.{1,3}周", 2), # 支持"第三节课"、"第1周"等
|
|
||||||
],
|
|
||||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
|
||||||
(r"老师|教师|同学|学生", 1),
|
|
||||||
(r"教室|实验室|图书馆", 1),
|
|
||||||
],
|
|
||||||
filler_phrases=cls.BASE_FILLERS | {
|
|
||||||
# 教育场景特有填充词(移除了"明白了"、"懂了"、"不懂"等,这些在教育场景中有意义)
|
|
||||||
"老师好", "同学们好", "上课", "下课", "起立", "坐下",
|
|
||||||
"举手", "请坐", "很好", "不错", "继续",
|
|
||||||
"下一个", "下一题", "下一位", "还有吗", "还有问题吗",
|
|
||||||
},
|
|
||||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
|
||||||
"为啥", "咋", "咋办", "怎样", "如何做",
|
|
||||||
"能不能", "可不可以", "行不行", "对不对", "是不是",
|
|
||||||
},
|
|
||||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
|
||||||
"必考", "重点", "考点", "难点", "关键",
|
|
||||||
"记住", "背诵", "掌握", "理解", "复习",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_online_service_config(cls) -> ScenePatterns:
|
|
||||||
"""在线服务场景配置"""
|
|
||||||
return ScenePatterns(
|
|
||||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
|
||||||
# 工单相关(最高优先级)
|
|
||||||
(r"工单号|工单编号|ticket|TK\d+", 6),
|
|
||||||
(r"工单状态|处理中|已解决|已关闭|待处理", 5),
|
|
||||||
(r"优先级|紧急|高优先级|P0|P1|P2", 5),
|
|
||||||
|
|
||||||
# 产品信息
|
|
||||||
(r"产品型号|型号|SKU|产品编号", 5),
|
|
||||||
(r"序列号|SN|设备号", 5),
|
|
||||||
(r"版本号|软件版本|固件版本", 4),
|
|
||||||
|
|
||||||
# 问题描述
|
|
||||||
(r"故障|错误|异常|bug|问题", 4),
|
|
||||||
(r"错误代码|故障代码|error code", 5),
|
|
||||||
(r"无法|不能|失败|报错", 3),
|
|
||||||
],
|
|
||||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
|
||||||
# 服务相关
|
|
||||||
(r"退款|退货|换货|补发", 4),
|
|
||||||
(r"发票|收据|凭证", 3),
|
|
||||||
(r"物流|快递|运单号", 3),
|
|
||||||
(r"保修|质保|售后", 3),
|
|
||||||
|
|
||||||
# 时效相关
|
|
||||||
(r"SLA|响应时间|处理时长", 4),
|
|
||||||
(r"超时|延迟|等待", 2),
|
|
||||||
],
|
|
||||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
|
||||||
(r"客服|工程师|技术支持", 1),
|
|
||||||
(r"用户|客户|会员", 1),
|
|
||||||
],
|
|
||||||
filler_phrases=cls.BASE_FILLERS | {
|
|
||||||
# 在线服务特有填充词
|
|
||||||
"您好", "请问", "请稍等", "稍等", "马上", "立即",
|
|
||||||
"正在查询", "正在处理", "正在为您", "帮您查一下",
|
|
||||||
"还有其他问题吗", "还需要什么帮助", "很高兴为您服务",
|
|
||||||
"感谢您的耐心等待", "抱歉让您久等了",
|
|
||||||
"已记录", "已反馈", "已转接", "已升级",
|
|
||||||
"祝您生活愉快", "再见", "欢迎下次咨询",
|
|
||||||
},
|
|
||||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
|
||||||
"能否", "可否", "是否", "有没有", "能不能",
|
|
||||||
"怎么办", "如何处理", "怎么解决",
|
|
||||||
},
|
|
||||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
|
||||||
"立即处理", "马上解决", "尽快", "优先",
|
|
||||||
"升级", "转接", "派单", "跟进",
|
|
||||||
"补偿", "赔偿", "退款", "换货",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_outbound_config(cls) -> ScenePatterns:
|
|
||||||
"""外呼场景配置"""
|
|
||||||
return ScenePatterns(
|
|
||||||
high_priority_patterns=cls.BASE_HIGH_PRIORITY + [
|
|
||||||
# 意向相关(最高优先级)
|
|
||||||
(r"意向|意愿|兴趣|感兴趣", 6),
|
|
||||||
(r"A类|B类|C类|D类|高意向|低意向", 6),
|
|
||||||
(r"成交|签约|下单|购买|确认", 6),
|
|
||||||
|
|
||||||
# 联系信息(外呼场景中更重要)
|
|
||||||
(r"预约|约定|安排|确定时间", 5),
|
|
||||||
(r"下次联系|回访|跟进", 5),
|
|
||||||
(r"方便|有空|可以|时间", 4),
|
|
||||||
|
|
||||||
# 通话状态
|
|
||||||
(r"接通|未接通|占线|关机|停机", 4),
|
|
||||||
(r"通话时长|通话时间", 3),
|
|
||||||
],
|
|
||||||
medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [
|
|
||||||
# 客户信息
|
|
||||||
(r"姓名|称呼|先生|女士", 3),
|
|
||||||
(r"公司|单位|职位|职务", 3),
|
|
||||||
(r"需求|要求|期望", 3),
|
|
||||||
|
|
||||||
# 跟进状态
|
|
||||||
(r"跟进状态|进展|进度", 3),
|
|
||||||
(r"已联系|待联系|联系中", 2),
|
|
||||||
(r"拒绝|不感兴趣|考虑|再说", 3),
|
|
||||||
],
|
|
||||||
low_priority_patterns=cls.BASE_LOW_PRIORITY + [
|
|
||||||
(r"销售|客户经理|业务员", 1),
|
|
||||||
(r"产品|服务|方案", 1),
|
|
||||||
],
|
|
||||||
filler_phrases=cls.BASE_FILLERS | {
|
|
||||||
# 外呼场景特有填充词
|
|
||||||
"您好", "喂", "hello", "打扰了", "不好意思",
|
|
||||||
"方便接电话吗", "现在方便吗", "占用您一点时间",
|
|
||||||
"我是", "我们是", "我们公司", "我们这边",
|
|
||||||
"了解一下", "介绍一下", "简单说一下",
|
|
||||||
"考虑考虑", "想一想", "再说", "再看看",
|
|
||||||
"不需要", "不感兴趣", "没兴趣", "不用了",
|
|
||||||
"好的", "行", "可以", "没问题", "那就这样",
|
|
||||||
"再联系", "回头聊", "有需要再说",
|
|
||||||
},
|
|
||||||
question_keywords=cls.BASE_QUESTION_KEYWORDS | {
|
|
||||||
"有没有", "需不需要", "要不要", "考虑不考虑",
|
|
||||||
"了解吗", "知道吗", "听说过吗",
|
|
||||||
"方便吗", "有空吗", "在吗",
|
|
||||||
},
|
|
||||||
decision_keywords=cls.BASE_DECISION_KEYWORDS | {
|
|
||||||
"确定", "决定", "选择", "购买", "下单",
|
|
||||||
"预约", "安排", "约定", "确认",
|
|
||||||
"跟进", "回访", "联系", "沟通",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
def get_config(cls, scene: str, fallback_to_generic: bool = True) -> ScenePatterns:
    """Look up the pattern configuration for a named scene.

    Args:
        scene: Scene name. 'education', 'online_service' and 'outbound'
            have dedicated builders; any other value is treated as unknown.
        fallback_to_generic: What to do for an unknown scene. When truthy,
            the generic (base-rules-only) configuration is returned; when
            falsy, an error is raised instead.

    Returns:
        The configuration produced by the scene's builder, or the generic
        configuration when the scene is unknown and fallback is enabled.

    Raises:
        ValueError: If the scene is unknown and fallback_to_generic is falsy.
    """
    scene_map = {
        'education': cls.get_education_config,
        'online_service': cls.get_online_service_config,
        'outbound': cls.get_outbound_config,
    }

    builder = scene_map.get(scene)
    if builder is not None:
        return builder()

    if not fallback_to_generic:
        raise ValueError(f"不支持的场景: {scene},支持的场景: {list(scene_map.keys())}")

    # Unknown scene: fall back to the base rules only, without any
    # scene-specific additions.
    return cls.get_generic_config()
|
|
||||||
|
|
||||||
@classmethod
def get_generic_config(cls) -> ScenePatterns:
    """Return the generic configuration, built from the base rules only.

    This is the conservative fallback for scenes without a dedicated
    configuration (get_config returns it for unknown scene names): it
    applies only the most general rules, to avoid discarding important
    information.

    Each container handed to ScenePatterns is a shallow copy of the
    corresponding BASE_* class attribute. The scene-specific builders
    always create fresh containers (via `+` / `|`); copying here keeps this
    method consistent with them, so mutating a returned configuration can
    never alter the class-level base rules shared by every scene.

    Returns:
        ScenePatterns: a configuration containing only the base rules.
    """
    return ScenePatterns(
        high_priority_patterns=cls.BASE_HIGH_PRIORITY.copy(),
        medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY.copy(),
        low_priority_patterns=cls.BASE_LOW_PRIORITY.copy(),
        filler_phrases=cls.BASE_FILLERS.copy(),
        question_keywords=cls.BASE_QUESTION_KEYWORDS.copy(),
        decision_keywords=cls.BASE_DECISION_KEYWORDS.copy(),
    )
|
|
||||||
|
|
||||||
@classmethod
def get_all_scenes(cls) -> List[str]:
    """Return the names of all scenes that have a predefined configuration.

    A new list is created on every call, so callers may modify the result.
    """
    predefined_scenes = ['education', 'online_service', 'outbound']
    return predefined_scenes
|
|
||||||
|
|
||||||
@classmethod
def is_scene_supported(cls, scene: str) -> bool:
    """Tell whether a scene has its own dedicated configuration.

    Args:
        scene: Scene name to check.

    Returns:
        True if the name is among those returned by get_all_scenes();
        False otherwise, in which case the generic configuration is the
        one that applies.
    """
    predefined = cls.get_all_scenes()
    return scene in predefined
|
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import List, Dict, Optional
|
|||||||
from app.core.logging_config import get_memory_logger
|
from app.core.logging_config import get_memory_logger
|
||||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||||
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
|
from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_prompt
|
||||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
||||||
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||||
from app.core.memory.models.message_models import DialogData, Statement
|
from app.core.memory.models.message_models import DialogData, Statement
|
||||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||||
@@ -14,15 +14,15 @@ from app.core.memory.utils.log.logging_utils import prompt_logger
|
|||||||
logger = get_memory_logger(__name__)
|
logger = get_memory_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TripletExtractor:
|
class TripletExtractor:
|
||||||
"""Extracts knowledge triplets and entities from statements using LLM"""
|
"""Extracts knowledge triplets and entities from statements using LLM"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
llm_client: OpenAIClient,
|
llm_client: OpenAIClient,
|
||||||
ontology_types: Optional[OntologyTypeList] = None,
|
ontology_types: Optional[OntologyTypeList] = None,
|
||||||
language: str = "zh"):
|
language: str = "zh"
|
||||||
|
):
|
||||||
"""Initialize the TripletExtractor with an LLM client
|
"""Initialize the TripletExtractor with an LLM client
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -65,7 +65,8 @@ class TripletExtractor:
|
|||||||
|
|
||||||
# Create messages for LLM
|
# Create messages for LLM
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
|
{"role": "system",
|
||||||
|
"content": "You are an expert at extracting knowledge triplets and entities from text. Follow the provided instructions carefully and return valid JSON."},
|
||||||
{"role": "user", "content": prompt_content}
|
{"role": "user", "content": prompt_content}
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -116,7 +117,8 @@ class TripletExtractor:
|
|||||||
logger.error(f"Error processing statement: {e}", exc_info=True)
|
logger.error(f"Error processing statement: {e}", exc_info=True)
|
||||||
return TripletExtractionResponse(triplets=[], entities=[])
|
return TripletExtractionResponse(triplets=[], entities=[])
|
||||||
|
|
||||||
async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[str, TripletExtractionResponse]:
|
async def extract_triplets_from_statements(self, dialog_data: DialogData, limit_chunks: int = None) -> Dict[
|
||||||
|
str, TripletExtractionResponse]:
|
||||||
"""Extract triplets and entities from statements
|
"""Extract triplets and entities from statements
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
"""
|
"""
|
||||||
自我反思引擎实现
|
Self-Reflection Engine Implementation
|
||||||
|
|
||||||
该模块实现了记忆系统的自我反思功能,包括:
|
This module implements the self-reflection functionality of the memory system, including:
|
||||||
1. 基于时间的反思 - 根据时间周期触发反思
|
1. Time-based reflection - Triggers reflection based on time cycles
|
||||||
2. 基于事实的反思 - 检测记忆冲突并解决
|
2. Fact-based reflection - Detects and resolves memory conflicts
|
||||||
3. 综合反思 - 整合多种反思策略
|
3. Comprehensive reflection - Integrates multiple reflection strategies
|
||||||
4. 反思结果应用 - 更新记忆库
|
4. Reflection result application - Updates memory database
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
@@ -38,7 +38,7 @@ from app.schemas.memory_storage_schema import (
|
|||||||
)
|
)
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
# 配置日志
|
# Configure logging
|
||||||
_root_logger = logging.getLogger()
|
_root_logger = logging.getLogger()
|
||||||
if not _root_logger.handlers:
|
if not _root_logger.handlers:
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -49,35 +49,62 @@ else:
|
|||||||
_root_logger.setLevel(logging.INFO)
|
_root_logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
class TranslationResponse(BaseModel):
|
class TranslationResponse(BaseModel):
|
||||||
"""翻译响应模型"""
|
"""Translation response model for language conversion"""
|
||||||
data: str
|
data: str
|
||||||
|
|
||||||
class ReflectionRange(str, Enum):
|
class ReflectionRange(str, Enum):
|
||||||
"""反思范围枚举"""
|
"""
|
||||||
PARTIAL = "partial" # 从检索结果中反思
|
Reflection range enumeration
|
||||||
ALL = "all" # 从整个数据库中反思
|
|
||||||
|
Defines the scope of data to be included in reflection operations.
|
||||||
|
"""
|
||||||
|
PARTIAL = "partial" # Reflect from retrieval results
|
||||||
|
ALL = "all" # Reflect from entire database
|
||||||
|
|
||||||
|
|
||||||
class ReflectionBaseline(str, Enum):
|
class ReflectionBaseline(str, Enum):
|
||||||
"""反思基线枚举"""
|
"""
|
||||||
TIME = "TIME" # 基于时间的反思
|
Reflection baseline enumeration
|
||||||
FACT = "FACT" # 基于事实的反思
|
|
||||||
HYBRID = "HYBRID" # 混合反思
|
Defines the strategy or approach used for reflection operations.
|
||||||
|
"""
|
||||||
|
TIME = "TIME" # Time-based reflection
|
||||||
|
FACT = "FACT" # Fact-based reflection
|
||||||
|
HYBRID = "HYBRID" # Hybrid reflection combining multiple strategies
|
||||||
|
|
||||||
|
|
||||||
class ReflectionConfig(BaseModel):
|
class ReflectionConfig(BaseModel):
|
||||||
"""反思引擎配置"""
|
"""
|
||||||
|
Reflection engine configuration
|
||||||
|
|
||||||
|
Defines all configuration parameters for the reflection engine including
|
||||||
|
operation modes, model settings, and evaluation criteria.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
enabled: Whether reflection engine is enabled
|
||||||
|
iteration_period: Reflection cycle period (e.g., "3" hours)
|
||||||
|
reflexion_range: Scope of reflection (PARTIAL or ALL)
|
||||||
|
baseline: Reflection strategy (TIME, FACT, or HYBRID)
|
||||||
|
model_id: LLM model identifier for reflection operations
|
||||||
|
end_user_id: User identifier for scoped operations
|
||||||
|
output_example: Example output format for guidance
|
||||||
|
memory_verify: Enable memory verification checks
|
||||||
|
quality_assessment: Enable quality assessment evaluation
|
||||||
|
violation_handling_strategy: Strategy for handling violations
|
||||||
|
language_type: Language type for output ("zh" or "en")
|
||||||
|
"""
|
||||||
enabled: bool = False
|
enabled: bool = False
|
||||||
iteration_period: str = "3" # 反思周期
|
iteration_period: str = "3" # Reflection cycle period
|
||||||
reflexion_range: ReflectionRange = ReflectionRange.PARTIAL
|
reflexion_range: ReflectionRange = ReflectionRange.PARTIAL
|
||||||
baseline: ReflectionBaseline = ReflectionBaseline.TIME
|
baseline: ReflectionBaseline = ReflectionBaseline.TIME
|
||||||
model_id: Optional[str] = None # 模型ID
|
model_id: Optional[str] = None # Model ID
|
||||||
end_user_id: Optional[str] = None
|
end_user_id: Optional[str] = None
|
||||||
output_example: Optional[str] = None # 输出示例
|
output_example: Optional[str] = None # Output example
|
||||||
|
|
||||||
# 评估相关字段
|
# Evaluation related fields
|
||||||
memory_verify: bool = True # 记忆验证
|
memory_verify: bool = True # Memory verification
|
||||||
quality_assessment: bool = True # 质量评估
|
quality_assessment: bool = True # Quality assessment
|
||||||
violation_handling_strategy: str = "warn" # 违规处理策略
|
violation_handling_strategy: str = "warn" # Violation handling strategy
|
||||||
language_type: str = "zh"
|
language_type: str = "zh"
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@@ -85,7 +112,21 @@ class ReflectionConfig(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ReflectionResult(BaseModel):
|
class ReflectionResult(BaseModel):
|
||||||
"""反思结果"""
|
"""
|
||||||
|
Reflection operation result
|
||||||
|
|
||||||
|
Contains comprehensive information about the outcome of a reflection operation
|
||||||
|
including success status, metrics, and execution details.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
success: Whether the reflection operation succeeded
|
||||||
|
message: Descriptive message about the operation result
|
||||||
|
conflicts_found: Number of conflicts detected during reflection
|
||||||
|
conflicts_resolved: Number of conflicts successfully resolved
|
||||||
|
memories_updated: Number of memory entries updated in database
|
||||||
|
execution_time: Total time taken for the reflection operation
|
||||||
|
details: Additional details about the operation (optional)
|
||||||
|
"""
|
||||||
success: bool
|
success: bool
|
||||||
message: str
|
message: str
|
||||||
conflicts_found: int = 0
|
conflicts_found: int = 0
|
||||||
@@ -97,9 +138,22 @@ class ReflectionResult(BaseModel):
|
|||||||
|
|
||||||
class ReflectionEngine:
|
class ReflectionEngine:
|
||||||
"""
|
"""
|
||||||
自我反思引擎
|
Self-Reflection Engine
|
||||||
|
|
||||||
负责执行记忆系统的自我反思,包括冲突检测、冲突解决和记忆更新。
|
Responsible for executing memory system self-reflection operations including
|
||||||
|
conflict detection, conflict resolution, and memory updates. Supports multiple
|
||||||
|
reflection strategies and provides comprehensive result tracking.
|
||||||
|
|
||||||
|
The engine can operate in different modes:
|
||||||
|
- Time-based: Reflects on memories within specific time periods
|
||||||
|
- Fact-based: Detects and resolves factual conflicts in memories
|
||||||
|
- Hybrid: Combines multiple reflection strategies
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
config: Reflection engine configuration
|
||||||
|
neo4j_connector: Neo4j database connector
|
||||||
|
llm_client: Language model client for analysis
|
||||||
|
Various function handlers for data processing and prompt rendering
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -115,18 +169,21 @@ class ReflectionEngine:
|
|||||||
update_query: Optional[str] = None
|
update_query: Optional[str] = None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
初始化反思引擎
|
Initialize reflection engine
|
||||||
|
|
||||||
|
Sets up the reflection engine with configuration and optional dependencies.
|
||||||
|
Uses lazy initialization to avoid circular imports and optimize startup time.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config: 反思引擎配置
|
config: Reflection engine configuration object
|
||||||
neo4j_connector: Neo4j 连接器(可选)
|
neo4j_connector: Neo4j connector instance (optional, will be created if not provided)
|
||||||
llm_client: LLM 客户端(可选)
|
llm_client: LLM client instance (optional, will be created if not provided)
|
||||||
get_data_func: 获取数据的函数(可选)
|
get_data_func: Function for retrieving data (optional, uses default if not provided)
|
||||||
render_evaluate_prompt_func: 渲染评估提示词的函数(可选)
|
render_evaluate_prompt_func: Function for rendering evaluation prompts (optional)
|
||||||
render_reflexion_prompt_func: 渲染反思提示词的函数(可选)
|
render_reflexion_prompt_func: Function for rendering reflection prompts (optional)
|
||||||
conflict_schema: 冲突结果 Schema(可选)
|
conflict_schema: Schema for conflict result validation (optional)
|
||||||
reflexion_schema: 反思结果 Schema(可选)
|
reflexion_schema: Schema for reflection result validation (optional)
|
||||||
update_query: 更新查询语句(可选)
|
update_query: Query string for database updates (optional)
|
||||||
"""
|
"""
|
||||||
self.config = config
|
self.config = config
|
||||||
self.neo4j_connector = neo4j_connector
|
self.neo4j_connector = neo4j_connector
|
||||||
@@ -137,14 +194,20 @@ class ReflectionEngine:
|
|||||||
self.conflict_schema = conflict_schema
|
self.conflict_schema = conflict_schema
|
||||||
self.reflexion_schema = reflexion_schema
|
self.reflexion_schema = reflexion_schema
|
||||||
self.update_query = update_query
|
self.update_query = update_query
|
||||||
self._semaphore = asyncio.Semaphore(5) # 默认并发数为5
|
self._semaphore = asyncio.Semaphore(5) # Default concurrency limit of 5
|
||||||
|
|
||||||
|
|
||||||
# 延迟导入以避免循环依赖
|
# Lazy import to avoid circular dependencies
|
||||||
self._lazy_init_done = False
|
self._lazy_init_done = False
|
||||||
|
|
||||||
def _lazy_init(self):
|
def _lazy_init(self):
|
||||||
"""延迟初始化,避免循环导入"""
|
"""
|
||||||
|
Lazy initialization to avoid circular imports
|
||||||
|
|
||||||
|
Initializes dependencies only when needed, preventing circular import issues
|
||||||
|
and optimizing startup performance. Sets up default implementations for
|
||||||
|
any components not provided during construction.
|
||||||
|
"""
|
||||||
if self._lazy_init_done:
|
if self._lazy_init_done:
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -158,7 +221,7 @@ class ReflectionEngine:
|
|||||||
factory = MemoryClientFactory(db)
|
factory = MemoryClientFactory(db)
|
||||||
self.llm_client = factory.get_llm_client(self.config.model_id)
|
self.llm_client = factory.get_llm_client(self.config.model_id)
|
||||||
elif isinstance(self.llm_client, str):
|
elif isinstance(self.llm_client, str):
|
||||||
# 如果 llm_client 是字符串(model_id),则用它初始化客户端
|
# If llm_client is a string (model_id), use it to initialize the client
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.services.memory_config_service import MemoryConfigService
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
@@ -172,10 +235,10 @@ class ReflectionEngine:
|
|||||||
model_config = config_service.get_model_config(model_id)
|
model_config = config_service.get_model_config(model_id)
|
||||||
|
|
||||||
extra_params={
|
extra_params={
|
||||||
"temperature": 0.2, # 降低温度提高响应速度和一致性
|
"temperature": 0.2, # Lower temperature for faster response and consistency
|
||||||
"max_tokens": 600, # 限制最大token数
|
"max_tokens": 600, # Limit maximum token count
|
||||||
"top_p": 0.8, # 优化采样参数
|
"top_p": 0.8, # Optimize sampling parameters
|
||||||
"stream": False, # 确保非流式输出以获得最快响应
|
"stream": False, # Ensure non-streaming output for fastest response
|
||||||
}
|
}
|
||||||
|
|
||||||
self.llm_client = OpenAIClient(RedBearModelConfig(
|
self.llm_client = OpenAIClient(RedBearModelConfig(
|
||||||
@@ -191,7 +254,7 @@ class ReflectionEngine:
|
|||||||
if self.get_data_func is None:
|
if self.get_data_func is None:
|
||||||
self.get_data_func = get_data
|
self.get_data_func = get_data
|
||||||
|
|
||||||
# 导入get_data_statement函数
|
# Import get_data_statement function
|
||||||
if not hasattr(self, 'get_data_statement'):
|
if not hasattr(self, 'get_data_statement'):
|
||||||
self.get_data_statement = get_data_statement
|
self.get_data_statement = get_data_statement
|
||||||
|
|
||||||
@@ -223,13 +286,20 @@ class ReflectionEngine:
|
|||||||
|
|
||||||
async def execute_reflection(self, host_id) -> ReflectionResult:
|
async def execute_reflection(self, host_id) -> ReflectionResult:
|
||||||
"""
|
"""
|
||||||
执行完整的反思流程
|
Execute complete reflection workflow
|
||||||
|
|
||||||
|
Performs the full reflection process including data retrieval, conflict detection,
|
||||||
|
conflict resolution, and memory updates. This is the main entry point for
|
||||||
|
reflection operations.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
host_id: 主机ID
|
host_id: Host identifier for scoping reflection operations
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ReflectionResult: 反思结果
|
ReflectionResult: Comprehensive result of the reflection operation including
|
||||||
|
success status, conflict metrics, and execution time
|
||||||
"""
|
"""
|
||||||
# 延迟初始化
|
# Lazy initialization
|
||||||
self._lazy_init()
|
self._lazy_init()
|
||||||
|
|
||||||
if not self.config.enabled:
|
if not self.config.enabled:
|
||||||
@@ -243,7 +313,7 @@ class ReflectionEngine:
|
|||||||
|
|
||||||
print(self.config.baseline, self.config.memory_verify, self.config.quality_assessment)
|
print(self.config.baseline, self.config.memory_verify, self.config.quality_assessment)
|
||||||
try:
|
try:
|
||||||
# 1. 获取反思数据
|
# 1. Get reflection data
|
||||||
reflexion_data, statement_databasets = await self._get_reflexion_data(host_id)
|
reflexion_data, statement_databasets = await self._get_reflexion_data(host_id)
|
||||||
if not reflexion_data:
|
if not reflexion_data:
|
||||||
return ReflectionResult(
|
return ReflectionResult(
|
||||||
@@ -252,7 +322,7 @@ class ReflectionEngine:
|
|||||||
execution_time=asyncio.get_event_loop().time() - start_time
|
execution_time=asyncio.get_event_loop().time() - start_time
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. 检测冲突(基于事实的反思)
|
# 2. Detect conflicts (fact-based reflection)
|
||||||
conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets)
|
conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets)
|
||||||
conflict_list=[]
|
conflict_list=[]
|
||||||
for i in conflict_data:
|
for i in conflict_data:
|
||||||
@@ -261,7 +331,7 @@ class ReflectionEngine:
|
|||||||
|
|
||||||
|
|
||||||
conflicts_found=0
|
conflicts_found=0
|
||||||
# 3. 解决冲突
|
# 3. Resolve conflicts
|
||||||
solved_data = await self._resolve_conflicts(conflict_list, statement_databasets)
|
solved_data = await self._resolve_conflicts(conflict_list, statement_databasets)
|
||||||
|
|
||||||
if not solved_data:
|
if not solved_data:
|
||||||
@@ -276,7 +346,7 @@ class ReflectionEngine:
|
|||||||
logging.info(f"解决了 {conflicts_resolved} 个冲突")
|
logging.info(f"解决了 {conflicts_resolved} 个冲突")
|
||||||
|
|
||||||
|
|
||||||
# 4. 应用反思结果(更新记忆库)
|
# 4. Apply reflection results (update memory database)
|
||||||
memories_updated=await self._apply_reflection_results(solved_data)
|
memories_updated=await self._apply_reflection_results(solved_data)
|
||||||
|
|
||||||
execution_time = asyncio.get_event_loop().time() - start_time
|
execution_time = asyncio.get_event_loop().time() - start_time
|
||||||
@@ -302,7 +372,19 @@ class ReflectionEngine:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def Translate(self, text):
|
async def Translate(self, text):
|
||||||
# 翻译中文为英文
|
"""
|
||||||
|
Translate Chinese text to English
|
||||||
|
|
||||||
|
Uses the configured LLM to translate Chinese text to English with structured output.
|
||||||
|
Provides consistent translation format for reflection results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: Chinese text to be translated
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Translated English text
|
||||||
|
"""
|
||||||
|
# Translate Chinese to English
|
||||||
translation_messages = [
|
translation_messages = [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
@@ -316,6 +398,19 @@ class ReflectionEngine:
|
|||||||
)
|
)
|
||||||
return response.data
|
return response.data
|
||||||
async def extract_translation(self,data):
|
async def extract_translation(self,data):
|
||||||
|
"""
|
||||||
|
Extract and translate reflection data to English
|
||||||
|
|
||||||
|
Processes reflection data structure and translates all Chinese content to English.
|
||||||
|
Handles nested data structures including memory verifications, quality assessments,
|
||||||
|
and reflection data while preserving the original structure.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Dictionary containing reflection data with Chinese content
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Translated data structure with English content
|
||||||
|
"""
|
||||||
end_datas={}
|
end_datas={}
|
||||||
end_datas['source_data']=await self.Translate(data['source_data'])
|
end_datas['source_data']=await self.Translate(data['source_data'])
|
||||||
quality_assessments = []
|
quality_assessments = []
|
||||||
@@ -350,6 +445,18 @@ class ReflectionEngine:
|
|||||||
return end_datas
|
return end_datas
|
||||||
|
|
||||||
async def reflection_run(self):
|
async def reflection_run(self):
|
||||||
|
"""
|
||||||
|
Execute reflection workflow with comprehensive data processing
|
||||||
|
|
||||||
|
Performs a complete reflection operation including conflict detection, resolution,
|
||||||
|
and result formatting. Supports both Chinese and English output based on
|
||||||
|
configuration settings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: Comprehensive reflection results including source data, memory verifications,
|
||||||
|
quality assessments, and reflection data. Results are translated to English
|
||||||
|
if language_type is set to 'en'.
|
||||||
|
"""
|
||||||
self._lazy_init()
|
self._lazy_init()
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
memory_verifies_flag = self.config.memory_verify
|
memory_verifies_flag = self.config.memory_verify
|
||||||
@@ -367,7 +474,7 @@ class ReflectionEngine:
|
|||||||
result_data['source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合"
|
result_data['source_data'] = "我是 2023 年春天去北京工作的,后来基本一直都在北京上班,也没怎么换过城市。不过后来公司调整,2024 年上半年我被调到上海待了差不多半年,那段时间每天都是在上海办公室打卡。当时入职资料用的还是我之前的身份信息,身份证号是 11010119950308123X,银行卡是 6222023847595898,这些一直没变。对了,其实我 从 2023 年开始就一直在北京生活,从来没有长期离开过北京,上海那段更多算是远程配合"
|
||||||
# 2. 检测冲突(基于事实的反思)
|
# 2. 检测冲突(基于事实的反思)
|
||||||
conflict_data = await self._detect_conflicts(databasets, source_data)
|
conflict_data = await self._detect_conflicts(databasets, source_data)
|
||||||
# 遍历数据提取字段
|
# Traverse data to extract fields
|
||||||
quality_assessments = []
|
quality_assessments = []
|
||||||
memory_verifies = []
|
memory_verifies = []
|
||||||
for item in conflict_data:
|
for item in conflict_data:
|
||||||
@@ -375,9 +482,9 @@ class ReflectionEngine:
|
|||||||
memory_verifies.append(item['memory_verify'])
|
memory_verifies.append(item['memory_verify'])
|
||||||
result_data['memory_verifies'] = memory_verifies
|
result_data['memory_verifies'] = memory_verifies
|
||||||
result_data['quality_assessments'] = quality_assessments
|
result_data['quality_assessments'] = quality_assessments
|
||||||
conflicts_found = 0 # 初始化为整数0而不是空字符串
|
conflicts_found = 0 # Initialize as integer 0 instead of empty string
|
||||||
REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"}
|
REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"}
|
||||||
# Clearn conflict_data,And memory_verify和quality_assessment
|
# Clean conflict_data, and memory_verify and quality_assessment
|
||||||
cleaned_conflict_data = []
|
cleaned_conflict_data = []
|
||||||
for item in conflict_data:
|
for item in conflict_data:
|
||||||
cleaned_item = {
|
cleaned_item = {
|
||||||
@@ -389,7 +496,7 @@ class ReflectionEngine:
|
|||||||
for item in conflict_data:
|
for item in conflict_data:
|
||||||
cleaned_data = []
|
cleaned_data = []
|
||||||
for row in item.get("data", []):
|
for row in item.get("data", []):
|
||||||
# 删除 created_at / expired_at
|
# Remove created_at / expired_at
|
||||||
cleaned_row = {
|
cleaned_row = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in row.items()
|
for k, v in row.items()
|
||||||
@@ -402,7 +509,7 @@ class ReflectionEngine:
|
|||||||
}
|
}
|
||||||
cleaned_conflict_data_.append(cleaned_item)
|
cleaned_conflict_data_.append(cleaned_item)
|
||||||
print(cleaned_conflict_data_)
|
print(cleaned_conflict_data_)
|
||||||
# 3. 解决冲突
|
# 3. Resolve conflicts
|
||||||
solved_data = await self._resolve_conflicts(cleaned_conflict_data_, source_data)
|
solved_data = await self._resolve_conflicts(cleaned_conflict_data_, source_data)
|
||||||
if not solved_data:
|
if not solved_data:
|
||||||
return ReflectionResult(
|
return ReflectionResult(
|
||||||
@@ -413,7 +520,7 @@ class ReflectionEngine:
|
|||||||
)
|
)
|
||||||
reflexion_data = []
|
reflexion_data = []
|
||||||
|
|
||||||
# 遍历数据提取reflexion字段
|
# Traverse data to extract reflexion fields
|
||||||
for item in solved_data:
|
for item in solved_data:
|
||||||
if 'results' in item:
|
if 'results' in item:
|
||||||
for result in item['results']:
|
for result in item['results']:
|
||||||
@@ -431,15 +538,24 @@ class ReflectionEngine:
|
|||||||
|
|
||||||
|
|
||||||
async def extract_fields_from_json(self):
|
async def extract_fields_from_json(self):
|
||||||
"""从example.json中提取source_data和databasets字段"""
|
"""
|
||||||
|
Extract source_data and databasets fields from example.json
|
||||||
|
|
||||||
|
Reads reflection example data from the example.json file and extracts
|
||||||
|
the source data and database statements for testing and demonstration purposes.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (source_data, databasets) extracted from the example file
|
||||||
|
Returns empty lists if file reading fails
|
||||||
|
"""
|
||||||
|
|
||||||
prompt_dir = os.path.join(os.path.dirname(__file__), "example")
|
prompt_dir = os.path.join(os.path.dirname(__file__), "example")
|
||||||
try:
|
try:
|
||||||
# 读取JSON文件
|
# Read JSON file
|
||||||
with open(prompt_dir + '/example.json', 'r', encoding='utf-8') as f:
|
with open(prompt_dir + '/example.json', 'r', encoding='utf-8') as f:
|
||||||
data = json.loads(f.read())
|
data = json.loads(f.read())
|
||||||
|
|
||||||
# 提取memory_verify下的字段
|
# Extract fields under memory_verify
|
||||||
memory_verify = data.get("memory_verify", {})
|
memory_verify = data.get("memory_verify", {})
|
||||||
source_data = memory_verify.get("source_data", [])
|
source_data = memory_verify.get("source_data", [])
|
||||||
databasets = memory_verify.get("databasets", [])
|
databasets = memory_verify.get("databasets", [])
|
||||||
@@ -451,15 +567,17 @@ class ReflectionEngine:
|
|||||||
|
|
||||||
async def _get_reflexion_data(self, host_id: uuid.UUID) -> List[Any]:
|
async def _get_reflexion_data(self, host_id: uuid.UUID) -> List[Any]:
|
||||||
"""
|
"""
|
||||||
获取反思数据
|
Get reflection data from database
|
||||||
|
|
||||||
根据配置的反思范围获取需要反思的记忆数据。
|
Retrieves memory data for reflection based on the configured reflection range.
|
||||||
|
Supports both partial (from retrieval results) and full (entire database) modes.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
host_id: 主机ID
|
host_id: Host UUID identifier for scoping data retrieval
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Any]: 反思数据列表
|
tuple: (reflexion_data, statement_data) containing memory data for reflection
|
||||||
|
Returns empty lists if query fails
|
||||||
"""
|
"""
|
||||||
|
|
||||||
print("=== 获取反思数据 ===")
|
print("=== 获取反思数据 ===")
|
||||||
@@ -484,26 +602,29 @@ class ReflectionEngine:
|
|||||||
|
|
||||||
async def _detect_conflicts(self, data: List[Any], statement_databasets: List[Any]) -> List[Any]:
|
async def _detect_conflicts(self, data: List[Any], statement_databasets: List[Any]) -> List[Any]:
|
||||||
"""
|
"""
|
||||||
检测冲突(基于事实的反思)
|
Detect conflicts (fact-based reflection)
|
||||||
|
|
||||||
使用 LLM 分析记忆数据,检测其中的冲突。
|
Uses LLM to analyze memory data and detect conflicts within the memories.
|
||||||
|
Performs comprehensive conflict detection including memory verification and
|
||||||
|
quality assessment based on configuration settings.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: 待检测的记忆数据
|
data: Memory data to be analyzed for conflicts
|
||||||
|
statement_databasets: Statement database records for context
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Any]: 冲突记忆列表
|
List[Any]: List of detected conflicts with detailed analysis
|
||||||
"""
|
"""
|
||||||
if not data:
|
if not data:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 数据预处理:如果数据量太少,直接返回无冲突
|
# Data preprocessing: if data is too small, return no conflicts directly
|
||||||
if len(data) < 2:
|
if len(data) < 2:
|
||||||
logging.info("数据量不足,无需检测冲突")
|
logging.info("数据量不足,无需检测冲突")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 使用转换后的数据
|
# Use converted data
|
||||||
# print("转换后的数据:", data[:2] if len(data) > 2 else data) # 只打印前2条避免日志过长
|
# print("Converted data:", data[:2] if len(data) > 2 else data) # Only print first 2 to avoid long logs
|
||||||
memory_verify = self.config.memory_verify
|
memory_verify = self.config.memory_verify
|
||||||
|
|
||||||
logging.info("====== 冲突检测开始 ======")
|
logging.info("====== 冲突检测开始 ======")
|
||||||
@@ -512,7 +633,7 @@ class ReflectionEngine:
|
|||||||
language_type=self.config.language_type
|
language_type=self.config.language_type
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 渲染冲突检测提示词
|
# Render conflict detection prompt
|
||||||
rendered_prompt = await self.render_evaluate_prompt_func(
|
rendered_prompt = await self.render_evaluate_prompt_func(
|
||||||
data,
|
data,
|
||||||
self.conflict_schema,
|
self.conflict_schema,
|
||||||
@@ -526,7 +647,7 @@ class ReflectionEngine:
|
|||||||
messages = [{"role": "user", "content": rendered_prompt}]
|
messages = [{"role": "user", "content": rendered_prompt}]
|
||||||
logging.info(f"提示词长度: {len(rendered_prompt)}")
|
logging.info(f"提示词长度: {len(rendered_prompt)}")
|
||||||
|
|
||||||
# 调用 LLM 进行冲突检测
|
# Call LLM for conflict detection
|
||||||
response = await self.llm_client.response_structured(
|
response = await self.llm_client.response_structured(
|
||||||
messages,
|
messages,
|
||||||
self.conflict_schema
|
self.conflict_schema
|
||||||
@@ -539,7 +660,7 @@ class ReflectionEngine:
|
|||||||
logging.error("LLM 冲突检测输出解析失败")
|
logging.error("LLM 冲突检测输出解析失败")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 标准化返回格式
|
# Standardize return format
|
||||||
if isinstance(response, BaseModel):
|
if isinstance(response, BaseModel):
|
||||||
return [response.model_dump()]
|
return [response.model_dump()]
|
||||||
elif hasattr(response, 'dict'):
|
elif hasattr(response, 'dict'):
|
||||||
@@ -553,15 +674,17 @@ class ReflectionEngine:
|
|||||||
|
|
||||||
async def _resolve_conflicts(self, conflicts: List[Any], statement_databasets: List[Any]) -> List[Any]:
|
async def _resolve_conflicts(self, conflicts: List[Any], statement_databasets: List[Any]) -> List[Any]:
|
||||||
"""
|
"""
|
||||||
解决冲突
|
Resolve detected conflicts
|
||||||
|
|
||||||
使用 LLM 对检测到的冲突进行反思和解决。
|
Uses LLM to perform reflection and resolution on detected conflicts.
|
||||||
|
Processes conflicts in parallel for efficiency while respecting concurrency limits.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
conflicts: 冲突列表
|
conflicts: List of conflicts to be resolved
|
||||||
|
statement_databasets: Statement database records for context
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Any]: 解决方案列表
|
List[Any]: List of resolution solutions with reflection results
|
||||||
"""
|
"""
|
||||||
if not conflicts:
|
if not conflicts:
|
||||||
return []
|
return []
|
||||||
@@ -570,12 +693,12 @@ class ReflectionEngine:
|
|||||||
baseline = self.config.baseline
|
baseline = self.config.baseline
|
||||||
memory_verify = self.config.memory_verify
|
memory_verify = self.config.memory_verify
|
||||||
|
|
||||||
# 并行处理每个冲突
|
# Process each conflict in parallel
|
||||||
async def _resolve_one(conflict: Any) -> Optional[Dict[str, Any]]:
|
async def _resolve_one(conflict: Any) -> Optional[Dict[str, Any]]:
|
||||||
"""解决单个冲突"""
|
"""Resolve a single conflict"""
|
||||||
async with self._semaphore:
|
async with self._semaphore:
|
||||||
try:
|
try:
|
||||||
# 渲染反思提示词
|
# Render reflection prompt
|
||||||
rendered_prompt = await self.render_reflexion_prompt_func(
|
rendered_prompt = await self.render_reflexion_prompt_func(
|
||||||
[conflict],
|
[conflict],
|
||||||
self.reflexion_schema,
|
self.reflexion_schema,
|
||||||
@@ -587,7 +710,7 @@ class ReflectionEngine:
|
|||||||
|
|
||||||
messages = [{"role": "user", "content": rendered_prompt}]
|
messages = [{"role": "user", "content": rendered_prompt}]
|
||||||
|
|
||||||
# 调用 LLM 进行反思
|
# Call LLM for reflection
|
||||||
response = await self.llm_client.response_structured(
|
response = await self.llm_client.response_structured(
|
||||||
messages,
|
messages,
|
||||||
self.reflexion_schema
|
self.reflexion_schema
|
||||||
@@ -596,7 +719,7 @@ class ReflectionEngine:
|
|||||||
if not response:
|
if not response:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 标准化返回格式
|
# Standardize return format
|
||||||
if isinstance(response, BaseModel):
|
if isinstance(response, BaseModel):
|
||||||
return response.model_dump()
|
return response.model_dump()
|
||||||
elif hasattr(response, 'dict'):
|
elif hasattr(response, 'dict'):
|
||||||
@@ -610,11 +733,11 @@ class ReflectionEngine:
|
|||||||
logging.warning(f"解决单个冲突失败: {e}")
|
logging.warning(f"解决单个冲突失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 并发执行所有冲突解决任务
|
# Execute all conflict resolution tasks concurrently
|
||||||
tasks = [_resolve_one(conflict) for conflict in conflicts]
|
tasks = [_resolve_one(conflict) for conflict in conflicts]
|
||||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||||
|
|
||||||
# 过滤掉失败的结果
|
# Filter out failed results
|
||||||
solved = [r for r in results if r is not None]
|
solved = [r for r in results if r is not None]
|
||||||
|
|
||||||
logging.info(f"成功解决 {len(solved)}/{len(conflicts)} 个冲突")
|
logging.info(f"成功解决 {len(solved)}/{len(conflicts)} 个冲突")
|
||||||
@@ -626,15 +749,16 @@ class ReflectionEngine:
|
|||||||
solved_data: List[Dict[str, Any]]
|
solved_data: List[Dict[str, Any]]
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
应用反思结果(更新记忆库)
|
Apply reflection results (update memory database)
|
||||||
|
|
||||||
将解决冲突后的记忆更新到 Neo4j 数据库中。
|
Updates the Neo4j database with resolved conflicts and reflection results.
|
||||||
|
Processes the solved data and applies changes to the memory storage system.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
solved_data: 解决方案列表
|
solved_data: List of resolved conflict solutions with reflection data
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
int: 成功更新的记忆数量
|
int: Number of successfully updated memory entries
|
||||||
"""
|
"""
|
||||||
changes = extract_and_process_changes(solved_data)
|
changes = extract_and_process_changes(solved_data)
|
||||||
success_count = await neo4j_data(changes)
|
success_count = await neo4j_data(changes)
|
||||||
@@ -642,80 +766,86 @@ class ReflectionEngine:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 基于时间的反思方法
|
# Time-based reflection methods
|
||||||
async def time_based_reflection(
|
async def time_based_reflection(
|
||||||
self,
|
self,
|
||||||
host_id: uuid.UUID,
|
host_id: uuid.UUID,
|
||||||
time_period: Optional[str] = None
|
time_period: Optional[str] = None
|
||||||
) -> ReflectionResult:
|
) -> ReflectionResult:
|
||||||
"""
|
"""
|
||||||
基于时间的反思
|
Time-based reflection
|
||||||
|
|
||||||
根据时间周期触发反思,检查在指定时间段内的记忆。
|
Triggers reflection based on time cycles, checking memories within
|
||||||
|
specified time periods. Uses the configured iteration period if
|
||||||
|
no specific time period is provided.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
host_id: 主机ID
|
host_id: Host UUID identifier for scoping reflection
|
||||||
time_period: 时间周期(如"三小时"),如果不提供则使用配置中的值
|
time_period: Time period (e.g., "three hours"), uses config value if not provided
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ReflectionResult: 反思结果
|
ReflectionResult: Comprehensive reflection operation result
|
||||||
"""
|
"""
|
||||||
period = time_period or self.config.iteration_period
|
period = time_period or self.config.iteration_period
|
||||||
logging.info(f"执行基于时间的反思,周期: {period}")
|
logging.info(f"执行基于时间的反思,周期: {period}")
|
||||||
|
|
||||||
# 使用标准反思流程
|
# Use standard reflection workflow
|
||||||
return await self.execute_reflection(host_id)
|
return await self.execute_reflection(host_id)
|
||||||
|
|
||||||
# 基于事实的反思方法
|
# Fact-based reflection methods
|
||||||
async def fact_based_reflection(
|
async def fact_based_reflection(
|
||||||
self,
|
self,
|
||||||
host_id: uuid.UUID
|
host_id: uuid.UUID
|
||||||
) -> ReflectionResult:
|
) -> ReflectionResult:
|
||||||
"""
|
"""
|
||||||
基于事实的反思
|
Fact-based reflection
|
||||||
|
|
||||||
检测记忆中的事实冲突并解决。
|
Detects and resolves factual conflicts within memories. Analyzes
|
||||||
|
memory data for inconsistencies and contradictions that need resolution.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
host_id: 主机ID
|
host_id: Host UUID identifier for scoping reflection
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ReflectionResult: 反思结果
|
ReflectionResult: Comprehensive reflection operation result
|
||||||
"""
|
"""
|
||||||
logging.info("执行基于事实的反思")
|
logging.info("执行基于事实的反思")
|
||||||
|
|
||||||
# 使用标准反思流程
|
# Use standard reflection workflow
|
||||||
return await self.execute_reflection(host_id)
|
return await self.execute_reflection(host_id)
|
||||||
|
|
||||||
# 综合反思方法
|
# Comprehensive reflection methods
|
||||||
async def comprehensive_reflection(
|
async def comprehensive_reflection(
|
||||||
self,
|
self,
|
||||||
host_id: uuid.UUID
|
host_id: uuid.UUID
|
||||||
) -> ReflectionResult:
|
) -> ReflectionResult:
|
||||||
"""
|
"""
|
||||||
综合反思
|
Comprehensive reflection
|
||||||
|
|
||||||
整合基于时间和基于事实的反思策略。
|
Integrates time-based and fact-based reflection strategies based on
|
||||||
|
the configured baseline. Supports hybrid approaches that combine
|
||||||
|
multiple reflection methodologies.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
host_id: 主机ID
|
host_id: Host UUID identifier for scoping reflection
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ReflectionResult: 反思结果
|
ReflectionResult: Comprehensive reflection operation result combining
|
||||||
|
multiple strategies if using hybrid baseline
|
||||||
"""
|
"""
|
||||||
logging.info("执行综合反思")
|
logging.info("执行综合反思")
|
||||||
|
|
||||||
# 根据配置的基线选择反思策略
|
# Choose reflection strategy based on configured baseline
|
||||||
if self.config.baseline == ReflectionBaseline.TIME:
|
if self.config.baseline == ReflectionBaseline.TIME:
|
||||||
return await self.time_based_reflection(host_id)
|
return await self.time_based_reflection(host_id)
|
||||||
elif self.config.baseline == ReflectionBaseline.FACT:
|
elif self.config.baseline == ReflectionBaseline.FACT:
|
||||||
return await self.fact_based_reflection(host_id)
|
return await self.fact_based_reflection(host_id)
|
||||||
elif self.config.baseline == ReflectionBaseline.HYBRID:
|
elif self.config.baseline == ReflectionBaseline.HYBRID:
|
||||||
# 混合策略:先执行基于时间的反思,再执行基于事实的反思
|
# Hybrid strategy: execute time-based reflection first, then fact-based reflection
|
||||||
time_result = await self.time_based_reflection(host_id)
|
time_result = await self.time_based_reflection(host_id)
|
||||||
fact_result = await self.fact_based_reflection(host_id)
|
fact_result = await self.fact_based_reflection(host_id)
|
||||||
|
|
||||||
# 合并结果
|
# Merge results
|
||||||
return ReflectionResult(
|
return ReflectionResult(
|
||||||
success=time_result.success and fact_result.success,
|
success=time_result.success and fact_result.success,
|
||||||
message=f"时间反思: {time_result.message}; 事实反思: {fact_result.message}",
|
message=f"时间反思: {time_result.message}; 事实反思: {fact_result.message}",
|
||||||
|
|||||||
@@ -2,9 +2,17 @@ import json
|
|||||||
|
|
||||||
|
|
||||||
def escape_lucene_query(query: str) -> str:
|
def escape_lucene_query(query: str) -> str:
|
||||||
"""Escape Lucene special characters in a free-text query.
|
"""
|
||||||
|
Escape special characters in Lucene queries
|
||||||
|
|
||||||
This prevents ParseException when using Neo4j full-text procedures.
|
Prevents ParseException when using Neo4j full-text search procedures.
|
||||||
|
Escapes all Lucene reserved special characters and operators.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: Original query string
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Escaped query string safe for Lucene search
|
||||||
"""
|
"""
|
||||||
if query is None:
|
if query is None:
|
||||||
return ""
|
return ""
|
||||||
@@ -22,11 +30,21 @@ def escape_lucene_query(query: str) -> str:
|
|||||||
return s
|
return s
|
||||||
|
|
||||||
def extract_plain_query(query_input: str) -> str:
|
def extract_plain_query(query_input: str) -> str:
|
||||||
"""Extract clean, plain-text query from various input forms.
|
"""
|
||||||
|
Extract clean plain-text query from various input forms
|
||||||
|
|
||||||
|
Handles the following cases:
|
||||||
- Strips surrounding quotes and whitespace
|
- Strips surrounding quotes and whitespace
|
||||||
- If input looks like JSON, prefers the 'original' field
|
- If input looks like JSON, prefers the 'original' field
|
||||||
- Fallbacks to the raw string when parsing fails
|
- Falls back to raw string when parsing fails
|
||||||
|
- Handles dictionary-type input
|
||||||
|
- Best-effort unescape common escape characters
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_input: Query input in various forms (string, dict, etc.)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: Extracted plain-text query string
|
||||||
"""
|
"""
|
||||||
if query_input is None:
|
if query_input is None:
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -4,7 +4,13 @@ from datetime import datetime
|
|||||||
|
|
||||||
def validate_date_format(date_str: str) -> bool:
|
def validate_date_format(date_str: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Validate if the date string is in the format YYYY-MM-DD.
|
Validate that a date string matches the YYYY-MM-DD pattern (one- or two-digit month and day are accepted; values are not range-checked)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
date_str: Date string to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if format is correct, False otherwise
|
||||||
"""
|
"""
|
||||||
pattern = r"^\d{4}-\d{1,2}-\d{1,2}$"
|
pattern = r"^\d{4}-\d{1,2}-\d{1,2}$"
|
||||||
return bool(re.match(pattern, date_str))
|
return bool(re.match(pattern, date_str))
|
||||||
@@ -41,7 +47,20 @@ def normalize_date(date_str: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def preprocess_date_string(date_str: str) -> str:
|
def preprocess_date_string(date_str: str) -> str:
|
||||||
"""预处理日期字符串,处理特殊格式"""
|
"""
|
||||||
|
预处理日期字符串,处理特殊格式
|
||||||
|
|
||||||
|
处理以下特殊格式:
|
||||||
|
- 年份后直接跟月份没有分隔符的格式(如 "20259/28")
|
||||||
|
- 无分隔符的纯数字格式(如 "20251028", "251028")
|
||||||
|
- 混合分隔符,统一为 "-"
|
||||||
|
|
||||||
|
Args:
|
||||||
|
date_str: 原始日期字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 预处理后的日期字符串,格式为 "YYYY-MM-DD" 或 "YYYY-MM"
|
||||||
|
"""
|
||||||
|
|
||||||
# 处理类似 "20259/28" 的格式(年份后直接跟月份没有分隔)
|
# 处理类似 "20259/28" 的格式(年份后直接跟月份没有分隔)
|
||||||
match = re.match(r'^(\d{4,5})[/\.\-_]?(\d{1,2})[/\.\-_]?(\d{1,2})$', date_str)
|
match = re.match(r'^(\d{4,5})[/\.\-_]?(\d{1,2})[/\.\-_]?(\d{1,2})$', date_str)
|
||||||
@@ -78,7 +97,23 @@ def preprocess_date_string(date_str: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def fallback_parse(date_str: str) -> str:
|
def fallback_parse(date_str: str) -> str:
|
||||||
"""备选解析方案"""
|
"""
|
||||||
|
备选日期解析方案
|
||||||
|
|
||||||
|
当智能解析失败时,尝试使用预定义的日期格式进行解析。
|
||||||
|
支持多种常见的日期格式,包括:
|
||||||
|
- YYYY-MM-DD, YYYY/MM/DD, YYYY.MM.DD
|
||||||
|
- YYYYMMDD, YYMMDD
|
||||||
|
- MM-DD-YYYY, MM/DD/YYYY, MM.DD.YYYY
|
||||||
|
- DD-MM-YYYY, DD/MM/YYYY, DD.MM.YYYY
|
||||||
|
- YYYY-MM, YYYY/MM, YYYY.MM
|
||||||
|
|
||||||
|
Args:
|
||||||
|
date_str: 待解析的日期字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 标准化后的日期字符串(YYYY-MM-DD格式),解析失败时返回原字符串
|
||||||
|
"""
|
||||||
|
|
||||||
# Try common date formats
|
# Try common date formats
|
||||||
formats_to_try = [
|
formats_to_try = [
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{#
|
{#
|
||||||
对话级抽取与相关性判定模板(用于剪枝加速)
|
对话级抽取与相关性判定模板(用于剪枝加速)
|
||||||
输入:pruning_scene, is_builtin_scene, ontology_classes, dialog_text, language
|
输入:pruning_scene, ontology_classes, dialog_text, language
|
||||||
输出:严格 JSON(不要包含任何多余文本),字段:
|
输出:严格 JSON(不要包含任何多余文本),字段:
|
||||||
- is_related: bool,是否与所选场景相关
|
- is_related: bool,是否与所选场景相关
|
||||||
- times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等)
|
- times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等)
|
||||||
@@ -9,64 +9,71 @@
|
|||||||
- contacts: [string],联系方式(电话/手机号/邮箱/微信/QQ等)
|
- contacts: [string],联系方式(电话/手机号/邮箱/微信/QQ等)
|
||||||
- addresses: [string],地址/地点相关文本
|
- addresses: [string],地址/地点相关文本
|
||||||
- keywords: [string],其它有助于保留的重要关键词(与场景强相关的术语)
|
- keywords: [string],其它有助于保留的重要关键词(与场景强相关的术语)
|
||||||
|
- preserve_keywords: [string],必须保留的情绪/兴趣/爱好/个人偏好相关词或短语片段
|
||||||
|
|
||||||
要求:
|
要求:
|
||||||
- 必须只输出上述 JSON,且键名一致;不得输出解释、前后缀;不得包含注释。
|
- 必须只输出上述 JSON,且键名一致;不得输出解释、前后缀;不得包含注释。
|
||||||
- times/ids/amounts/contacts/addresses/keywords 仅抽取原文片段或规范化后的简单字符串。
|
- times/ids/amounts/contacts/addresses/keywords/preserve_keywords 仅抽取原文片段或规范化后的简单字符串。
|
||||||
- 仅输出上述键;避免多余解释或字段。
|
- 仅输出上述键;避免多余解释或字段。
|
||||||
#}
|
#}
|
||||||
|
|
||||||
{# ── 内置场景的固定说明 ── #}
|
{# ── 确定场景说明 ── #}
|
||||||
{% set builtin_scene_instructions = {
|
{% if ontology_classes and ontology_classes | length > 0 %}
|
||||||
'education': {
|
{% if language == 'en' %}
|
||||||
'zh': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。',
|
{% set custom_types_str = ontology_classes | join(', ') %}
|
||||||
'en': 'Education Scenario: Teaching, courses, exams, homework, teacher/student interaction, learning resources, school management, etc.'
|
{% set instruction = 'Scene "' ~ pruning_scene ~ '": The dialogue is related to this scene if it involves any of the following entity types: ' ~ custom_types_str ~ '.' %}
|
||||||
},
|
|
||||||
'online_service': {
|
|
||||||
'zh': '在线客服场景:客户咨询、问题排查、服务工单、售后支持、订单/退款、工单升级等。',
|
|
||||||
'en': 'Online Service Scenario: Customer inquiries, troubleshooting, service tickets, after-sales support, orders/refunds, ticket escalation, etc.'
|
|
||||||
},
|
|
||||||
'outbound': {
|
|
||||||
'zh': '外呼场景:电话外呼、邀约、调研问卷、线索跟进、对话脚本、回访记录等。',
|
|
||||||
'en': 'Outbound Scenario: Outbound calls, invitations, survey questionnaires, lead follow-up, call scripts, follow-up records, etc.'
|
|
||||||
}
|
|
||||||
} %}
|
|
||||||
|
|
||||||
{# ── 确定最终使用的场景说明 ── #}
|
|
||||||
{% if is_builtin_scene %}
|
|
||||||
{# 内置专门场景:使用固定说明 #}
|
|
||||||
{% set scene_key = pruning_scene %}
|
|
||||||
{% if scene_key not in builtin_scene_instructions %}{% set scene_key = 'education' %}{% endif %}
|
|
||||||
{% set instruction = builtin_scene_instructions[scene_key][language] if language in ['zh', 'en'] else builtin_scene_instructions[scene_key]['zh'] %}
|
|
||||||
{% set custom_types_str = '' %}
|
|
||||||
{% else %}
|
|
||||||
{# 自定义场景:使用场景名称 + 本体类型列表构建说明 #}
|
|
||||||
{% if ontology_classes and ontology_classes | length > 0 %}
|
|
||||||
{% if language == 'en' %}
|
|
||||||
{% set custom_types_str = ontology_classes | join(', ') %}
|
|
||||||
{% set instruction = 'Custom scene "' ~ pruning_scene ~ '": The dialogue is related to this scene if it involves any of the following entity types: ' ~ custom_types_str ~ '.' %}
|
|
||||||
{% else %}
|
|
||||||
{% set custom_types_str = ontology_classes | join('、') %}
|
|
||||||
{% set instruction = '自定义场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关:' ~ custom_types_str ~ '。' %}
|
|
||||||
{% endif %}
|
|
||||||
{% else %}
|
{% else %}
|
||||||
{# 无本体类型时退化为通用说明 #}
|
{% set custom_types_str = ontology_classes | join('、') %}
|
||||||
{% if language == 'en' %}
|
{% set instruction = '场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关:' ~ custom_types_str ~ '。' %}
|
||||||
{% set instruction = 'Custom scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
|
{% endif %}
|
||||||
{% else %}
|
{% else %}
|
||||||
{% set instruction = '自定义场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
|
{% if language == 'en' %}
|
||||||
{% endif %}
|
|
||||||
{% set custom_types_str = '' %}
|
{% set custom_types_str = '' %}
|
||||||
|
{% set instruction = 'Scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %}
|
||||||
|
{% else %}
|
||||||
|
{% set custom_types_str = '' %}
|
||||||
|
{% set instruction = '场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
{% if language == "zh" %}
|
{% if language == "zh" %}
|
||||||
请在下方对话全文基础上,按该场景进行一次性抽取并判定相关性:
|
你是一个对话内容分析助手。请对下方对话全文进行一次性分析,完成两项任务:
|
||||||
|
1. 判断对话是否与指定场景相关;
|
||||||
|
2. 从对话中抽取所有需要保留的重要信息片段。
|
||||||
|
|
||||||
场景说明:{{ instruction }}
|
场景说明:{{ instruction }}
|
||||||
{% if not is_builtin_scene and custom_types_str %}
|
{% if custom_types_str %}
|
||||||
重要提示:只要对话中出现与上述实体类型({{ custom_types_str }})相关的内容,即判定为相关(is_related=true)。
|
重要提示:只要对话中出现与上述实体类型({{ custom_types_str }})相关的内容,即判定为相关(is_related=true)。
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
|
---
|
||||||
|
【必须保留的内容(不可删除)】
|
||||||
|
以下类型的内容无论是否与场景直接相关,都必须保留,请将其关键词/短语抽取到对应字段:
|
||||||
|
- 时间信息:日期、时间点、时间段、有效期 → times 字段
|
||||||
|
- 编号信息:学号、工号、订单号、申请号、账号、ID → ids 字段
|
||||||
|
- 金额信息:价格、费用、金额(含货币符号或单位) → amounts 字段
|
||||||
|
- 联系方式:电话、手机号、邮箱、微信、QQ → contacts 字段
|
||||||
|
- 地址信息:地点、地址、位置 → addresses 字段
|
||||||
|
- 场景关键词:与场景强相关的专业术语、事件名称 → keywords 字段
|
||||||
|
- **情绪与情感**:喜悦、悲伤、愤怒、焦虑、开心、难过、委屈、兴奋、害怕、担心、压力、感动等情绪表达 → preserve_keywords 字段
|
||||||
|
- **兴趣与爱好**:喜欢、热爱、爱好、擅长、享受、沉迷、着迷、讨厌某事物等个人偏好表达 → preserve_keywords 字段
|
||||||
|
- **个人观点与态度**:对某事物的明确看法、评价、立场 → preserve_keywords 字段
|
||||||
|
|
||||||
|
【可以删除的内容】
|
||||||
|
以下类型的内容属于低价值信息,可以在剪枝时删除:
|
||||||
|
- 纯寒暄问候:如"你好"、"在吗"、"拜拜"、"嗯"、"好的"、"哦"等无实质内容的短语
|
||||||
|
- 纯表情/符号:如"[微笑]"、"😊"、"哈哈"等
|
||||||
|
- 重复确认:如"对对对"、"是的是的"、"嗯嗯嗯"等无新增信息的重复
|
||||||
|
- 无意义填充:如"啊"、"呢"、"嘛"等语气词单独成句
|
||||||
|
|
||||||
|
**注意:即使消息很短,只要包含情绪、兴趣、爱好、个人观点等有价值信息,就必须保留,不得删除。**
|
||||||
|
例如:
|
||||||
|
- "我好开心呀" → 包含情绪(开心),必须保留,preserve_keywords 中加入"开心"
|
||||||
|
- "好喜欢打羽毛球呀" → 包含兴趣爱好(喜欢打羽毛球),必须保留,preserve_keywords 中加入"喜欢打羽毛球"
|
||||||
|
- "我好难过" → 包含情绪(难过),必须保留,preserve_keywords 中加入"难过"
|
||||||
|
- "太好啦!看到你开心,我也跟着心情亮起来" → 包含情绪,必须保留,preserve_keywords 中加入"开心"
|
||||||
|
|
||||||
|
---
|
||||||
对话全文:
|
对话全文:
|
||||||
"""
|
"""
|
||||||
{{ dialog_text }}
|
{{ dialog_text }}
|
||||||
@@ -80,15 +87,46 @@
|
|||||||
"amounts": [<string>...],
|
"amounts": [<string>...],
|
||||||
"contacts": [<string>...],
|
"contacts": [<string>...],
|
||||||
"addresses": [<string>...],
|
"addresses": [<string>...],
|
||||||
"keywords": [<string>...]
|
"keywords": [<string>...],
|
||||||
|
"preserve_keywords": [<string>...]
|
||||||
}
|
}
|
||||||
{% else %}
|
{% else %}
|
||||||
Based on the full dialogue below, perform one-time extraction and relevance determination according to this scenario:
|
You are a dialogue content analysis assistant. Please analyze the full dialogue below in one pass and complete two tasks:
|
||||||
|
1. Determine whether the dialogue is relevant to the specified scene;
|
||||||
|
2. Extract all important information fragments that must be preserved.
|
||||||
|
|
||||||
Scenario Description: {{ instruction }}
|
Scenario Description: {{ instruction }}
|
||||||
{% if not is_builtin_scene and custom_types_str %}
|
{% if custom_types_str %}
|
||||||
Important: If the dialogue contains content related to any of the entity types above ({{ custom_types_str }}), mark it as relevant (is_related=true).
|
Important: If the dialogue contains content related to any of the entity types above ({{ custom_types_str }}), mark it as relevant (is_related=true).
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
|
---
|
||||||
|
[MUST PRESERVE (cannot be deleted)]
|
||||||
|
The following types of content must always be preserved regardless of scene relevance. Extract their keywords/phrases into the corresponding fields:
|
||||||
|
- Time information: dates, time points, durations, expiry dates → times field
|
||||||
|
- ID information: student IDs, employee IDs, order numbers, application numbers, account IDs → ids field
|
||||||
|
- Amount information: prices, fees, amounts (with currency symbols or units) → amounts field
|
||||||
|
- Contact information: phone numbers, emails, WeChat, QQ → contacts field
|
||||||
|
- Address information: locations, addresses, places → addresses field
|
||||||
|
- Scene keywords: professional terms and event names strongly related to the scene → keywords field
|
||||||
|
- **Emotions and feelings**: joy, sadness, anger, anxiety, happiness, sadness, excitement, fear, worry, stress, being moved, etc. → preserve_keywords field
|
||||||
|
- **Interests and hobbies**: likes, loves, hobbies, good at, enjoys, obsessed with, hates something, personal preferences → preserve_keywords field
|
||||||
|
- **Personal opinions and attitudes**: clear views, evaluations, or stances on something → preserve_keywords field
|
||||||
|
|
||||||
|
[CAN BE DELETED]
|
||||||
|
The following types of content are low-value and can be removed during pruning:
|
||||||
|
- Pure greetings: e.g., "hello", "are you there", "bye", "ok", "yeah" — short phrases with no substantive content
|
||||||
|
- Pure emojis/symbols: e.g., "[smile]", "😊", "haha"
|
||||||
|
- Repetitive confirmations: e.g., "yes yes yes", "right right", "uh huh" — repetitions with no new information
|
||||||
|
- Meaningless fillers: standalone interjections like "ah", "well", "hmm"
|
||||||
|
|
||||||
|
**Note: Even if a message is short, if it contains emotions, interests, hobbies, or personal opinions, it MUST be preserved.**
|
||||||
|
Examples:
|
||||||
|
- "I'm so happy!" → contains emotion (happy), must preserve; add "happy" to preserve_keywords
|
||||||
|
- "I love playing badminton!" → contains interest (love playing badminton), must preserve; add "love playing badminton" to preserve_keywords
|
||||||
|
- "I feel so sad" → contains emotion (sad), must preserve; add "sad" to preserve_keywords
|
||||||
|
|
||||||
|
---
|
||||||
Full Dialogue:
|
Full Dialogue:
|
||||||
"""
|
"""
|
||||||
{{ dialog_text }}
|
{{ dialog_text }}
|
||||||
@@ -102,6 +140,7 @@ Output strict JSON only (fixed keys, order doesn't matter):
|
|||||||
"amounts": [<string>...],
|
"amounts": [<string>...],
|
||||||
"contacts": [<string>...],
|
"contacts": [<string>...],
|
||||||
"addresses": [<string>...],
|
"addresses": [<string>...],
|
||||||
"keywords": [<string>...]
|
"keywords": [<string>...],
|
||||||
|
"preserve_keywords": [<string>...]
|
||||||
}
|
}
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|||||||
@@ -2,15 +2,15 @@ import os
|
|||||||
from jinja2 import Environment, FileSystemLoader
|
from jinja2 import Environment, FileSystemLoader
|
||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
|
||||||
# Setup Jinja2 environment
|
# Setup Jinja2 environment
|
||||||
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
prompt_dir = os.path.join(os.path.dirname(__file__), "prompts")
|
||||||
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
||||||
|
|
||||||
|
|
||||||
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
||||||
baseline: str = "TIME",
|
baseline: str = "TIME",
|
||||||
memory_verify: bool = False,quality_assessment:bool = False,
|
memory_verify: bool = False, quality_assessment: bool = False,
|
||||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
statement_databasets=None, language_type: str = "zh") -> str:
|
||||||
"""
|
"""
|
||||||
Renders the evaluate prompt using the evaluate.jinja2 template.
|
Renders the evaluate prompt using the evaluate.jinja2 template.
|
||||||
|
|
||||||
@@ -23,6 +23,8 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
|||||||
Returns:
|
Returns:
|
||||||
Rendered prompt content as string
|
Rendered prompt content as string
|
||||||
"""
|
"""
|
||||||
|
if statement_databasets is None:
|
||||||
|
statement_databasets = []
|
||||||
template = prompt_env.get_template("evaluate.jinja2")
|
template = prompt_env.get_template("evaluate.jinja2")
|
||||||
|
|
||||||
# Convert Pydantic model to JSON schema if needed
|
# Convert Pydantic model to JSON schema if needed
|
||||||
@@ -46,7 +48,7 @@ async def render_evaluate_prompt(evaluate_data: List[Any], schema: Any,
|
|||||||
|
|
||||||
|
|
||||||
async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False,
|
async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: str, memory_verify: bool = False,
|
||||||
statement_databasets: List[str] = [],language_type:str = "zh") -> str:
|
statement_databasets=None, language_type: str = "zh") -> str:
|
||||||
"""
|
"""
|
||||||
Renders the reflexion prompt using the reflexion.jinja2 template.
|
Renders the reflexion prompt using the reflexion.jinja2 template.
|
||||||
|
|
||||||
@@ -58,6 +60,8 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
|
|||||||
Returns:
|
Returns:
|
||||||
Rendered prompt content as a string.
|
Rendered prompt content as a string.
|
||||||
"""
|
"""
|
||||||
|
if statement_databasets is None:
|
||||||
|
statement_databasets = []
|
||||||
template = prompt_env.get_template("reflexion.jinja2")
|
template = prompt_env.get_template("reflexion.jinja2")
|
||||||
|
|
||||||
# Convert Pydantic model to JSON schema if needed
|
# Convert Pydantic model to JSON schema if needed
|
||||||
@@ -69,7 +73,7 @@ async def render_reflexion_prompt(data: Dict[str, Any], schema: Any, baseline: s
|
|||||||
json_schema = schema
|
json_schema = schema
|
||||||
|
|
||||||
rendered_prompt = template.render(data=data, json_schema=json_schema,
|
rendered_prompt = template.render(data=data, json_schema=json_schema,
|
||||||
baseline=baseline,memory_verify=memory_verify,
|
baseline=baseline, memory_verify=memory_verify,
|
||||||
statement_databasets=statement_databasets,language_type=language_type)
|
statement_databasets=statement_databasets, language_type=language_type)
|
||||||
|
|
||||||
return rendered_prompt
|
return rendered_prompt
|
||||||
|
|||||||
@@ -1,23 +1,19 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import os
|
import os
|
||||||
import time
|
from typing import Any, Dict, Optional, TypeVar
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, TypeVar
|
from langchain_aws import ChatBedrock
|
||||||
|
from langchain_community.chat_models import ChatTongyi
|
||||||
|
from langchain_core.embeddings import Embeddings
|
||||||
|
from langchain_core.language_models import BaseLLM
|
||||||
|
from langchain_ollama import OllamaLLM
|
||||||
|
from langchain_openai import ChatOpenAI, OpenAI
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
import httpx
|
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.models.models_model import ModelProvider, ModelType
|
from app.models.models_model import ModelProvider, ModelType
|
||||||
from langchain_community.document_compressors import JinaRerank
|
|
||||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
|
||||||
from langchain_core.embeddings import Embeddings
|
|
||||||
from langchain_core.language_models import BaseLanguageModel, BaseLLM
|
|
||||||
from langchain_core.outputs import Generation, LLMResult
|
|
||||||
from langchain_core.retrievers import BaseRetriever
|
|
||||||
from langchain_core.runnables import RunnableSerializable
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
@@ -163,25 +159,17 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy
|
|||||||
|
|
||||||
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
# dashscope 的 omni 模型使用 OpenAI 兼容模式
|
||||||
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
if provider == ModelProvider.DASHSCOPE and config.is_omni:
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
return ChatOpenAI
|
return ChatOpenAI
|
||||||
|
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]:
|
||||||
if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] :
|
|
||||||
if type == ModelType.LLM:
|
if type == ModelType.LLM:
|
||||||
from langchain_openai import OpenAI
|
|
||||||
return OpenAI
|
return OpenAI
|
||||||
elif type == ModelType.CHAT:
|
elif type == ModelType.CHAT:
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
return ChatOpenAI
|
return ChatOpenAI
|
||||||
elif provider == ModelProvider.DASHSCOPE:
|
elif provider == ModelProvider.DASHSCOPE:
|
||||||
from langchain_community.chat_models import ChatTongyi
|
|
||||||
return ChatTongyi
|
return ChatTongyi
|
||||||
elif provider == ModelProvider.OLLAMA:
|
elif provider == ModelProvider.OLLAMA:
|
||||||
from langchain_ollama import OllamaLLM
|
|
||||||
return OllamaLLM
|
return OllamaLLM
|
||||||
elif provider == ModelProvider.BEDROCK:
|
elif provider == ModelProvider.BEDROCK:
|
||||||
from langchain_aws import ChatBedrock, ChatBedrockConverse
|
|
||||||
|
|
||||||
return ChatBedrock
|
return ChatBedrock
|
||||||
else:
|
else:
|
||||||
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED)
|
||||||
|
|||||||
@@ -94,72 +94,16 @@ def knowledge_retrieval(
|
|||||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=kb_id)
|
db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=kb_id)
|
||||||
if db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1:
|
if db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1:
|
||||||
# Process shared knowledge base
|
# Process shared knowledge base
|
||||||
if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share:
|
rs, chat_model, embedding_model = _retrieve_for_knowledge(
|
||||||
knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db,
|
db=db,
|
||||||
knowledgeshare_id=db_knowledge.id)
|
db_knowledge=db_knowledge,
|
||||||
if knowledgeshare:
|
kb_config={**kb_config, "query": query}, # 或改为单独参数
|
||||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db,
|
file_names_filter=file_names_filter,
|
||||||
knowledge_id=knowledgeshare.source_kb_id)
|
chat_model=chat_model,
|
||||||
if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
|
embedding_model=embedding_model,
|
||||||
continue
|
kb_ids=kb_ids,
|
||||||
else:
|
workspace_ids=workspace_ids,
|
||||||
continue
|
)
|
||||||
|
|
||||||
if str(db_knowledge.id) not in kb_ids:
|
|
||||||
kb_ids.append(str(db_knowledge.id))
|
|
||||||
if str(db_knowledge.workspace_id) not in workspace_ids:
|
|
||||||
workspace_ids.append(str(db_knowledge.workspace_id))
|
|
||||||
if not chat_model:
|
|
||||||
chat_model = Base(
|
|
||||||
key=db_knowledge.llm.api_keys[0].api_key,
|
|
||||||
model_name=db_knowledge.llm.api_keys[0].model_name,
|
|
||||||
base_url=db_knowledge.llm.api_keys[0].api_base
|
|
||||||
)
|
|
||||||
if not embedding_model:
|
|
||||||
embedding_model = OpenAIEmbed(
|
|
||||||
key=db_knowledge.embedding.api_keys[0].api_key,
|
|
||||||
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
|
||||||
base_url=db_knowledge.embedding.api_keys[0].api_base
|
|
||||||
)
|
|
||||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
|
||||||
# Retrieve according to the configured retrieval type
|
|
||||||
match kb_config["retrieve_type"]:
|
|
||||||
case "participle":
|
|
||||||
rs = vector_service.search_by_full_text(
|
|
||||||
query=query,
|
|
||||||
top_k=kb_config["top_k"],
|
|
||||||
score_threshold=kb_config["similarity_threshold"],
|
|
||||||
file_names_filter=file_names_filter
|
|
||||||
)
|
|
||||||
case "semantic":
|
|
||||||
rs = vector_service.search_by_vector(
|
|
||||||
query=query,
|
|
||||||
top_k=kb_config["top_k"],
|
|
||||||
score_threshold=kb_config["vector_similarity_weight"],
|
|
||||||
file_names_filter=file_names_filter
|
|
||||||
)
|
|
||||||
case _: # hybrid
|
|
||||||
rs1 = vector_service.search_by_vector(
|
|
||||||
query=query,
|
|
||||||
top_k=kb_config["top_k"],
|
|
||||||
score_threshold=kb_config["vector_similarity_weight"],
|
|
||||||
file_names_filter=file_names_filter
|
|
||||||
)
|
|
||||||
rs2 = vector_service.search_by_full_text(
|
|
||||||
query=query,
|
|
||||||
top_k=kb_config["top_k"],
|
|
||||||
score_threshold=kb_config["similarity_threshold"],
|
|
||||||
file_names_filter=file_names_filter
|
|
||||||
)
|
|
||||||
|
|
||||||
# Deduplication of merge results
|
|
||||||
seen_ids = set()
|
|
||||||
unique_rs = []
|
|
||||||
for doc in rs1 + rs2:
|
|
||||||
if doc.metadata["doc_id"] not in seen_ids:
|
|
||||||
seen_ids.add(doc.metadata["doc_id"])
|
|
||||||
unique_rs.append(doc)
|
|
||||||
rs = unique_rs
|
|
||||||
|
|
||||||
all_results.extend(rs)
|
all_results.extend(rs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -199,6 +143,115 @@ def knowledge_retrieval(
|
|||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
def _retrieve_for_knowledge(
|
||||||
|
db: Session,
|
||||||
|
db_knowledge,
|
||||||
|
kb_config: Dict[str, Any],
|
||||||
|
file_names_filter: list[str],
|
||||||
|
chat_model: Base | None,
|
||||||
|
embedding_model: OpenAIEmbed | None,
|
||||||
|
kb_ids: list[str],
|
||||||
|
workspace_ids: list[str],
|
||||||
|
) -> tuple[list[DocumentChunk], Base | None, OpenAIEmbed | None]:
|
||||||
|
"""
|
||||||
|
对单个知识库进行检索。
|
||||||
|
- 处理共享知识库
|
||||||
|
- 如果是 Folder,则递归检索其子知识库
|
||||||
|
- 返回本知识库(含子库)的检索结果和可能更新后的 chat_model/embedding_model
|
||||||
|
"""
|
||||||
|
results: list[DocumentChunk] = []
|
||||||
|
|
||||||
|
# 处理共享知识库
|
||||||
|
if db_knowledge.permission_id.lower() == knowledge_model.PermissionType.Share:
|
||||||
|
knowledgeshare = knowledgeshare_repository.get_knowledgeshare_by_id(db=db, knowledgeshare_id=db_knowledge.id)
|
||||||
|
if not knowledgeshare:
|
||||||
|
return results, chat_model, embedding_model
|
||||||
|
|
||||||
|
db_knowledge = knowledge_repository.get_knowledge_by_id(db, knowledge_id=knowledgeshare.source_kb_id)
|
||||||
|
if not (db_knowledge and db_knowledge.chunk_num > 0 and db_knowledge.status == 1):
|
||||||
|
return results, chat_model, embedding_model
|
||||||
|
|
||||||
|
# Folder 类型:递归处理子知识库
|
||||||
|
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
|
||||||
|
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
|
||||||
|
for child in children:
|
||||||
|
if not (child and child.chunk_num > 0 and child.status == 1):
|
||||||
|
continue
|
||||||
|
# 递归处理子知识库(子库如果还是 Folder,会继续往下)
|
||||||
|
child_results, chat_model, embedding_model = _retrieve_for_knowledge(
|
||||||
|
db=db,
|
||||||
|
db_knowledge=child,
|
||||||
|
kb_config=kb_config,
|
||||||
|
file_names_filter=file_names_filter,
|
||||||
|
chat_model=chat_model,
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
kb_ids=kb_ids,
|
||||||
|
workspace_ids=workspace_ids,
|
||||||
|
)
|
||||||
|
results.extend(child_results)
|
||||||
|
return results, chat_model, embedding_model
|
||||||
|
|
||||||
|
# 普通知识库,执行一次检索
|
||||||
|
if str(db_knowledge.id) not in kb_ids:
|
||||||
|
kb_ids.append(str(db_knowledge.id))
|
||||||
|
if str(db_knowledge.workspace_id) not in workspace_ids:
|
||||||
|
workspace_ids.append(str(db_knowledge.workspace_id))
|
||||||
|
|
||||||
|
if not chat_model:
|
||||||
|
chat_model = Base(
|
||||||
|
key=db_knowledge.llm.api_keys[0].api_key,
|
||||||
|
model_name=db_knowledge.llm.api_keys[0].model_name,
|
||||||
|
base_url=db_knowledge.llm.api_keys[0].api_base,
|
||||||
|
)
|
||||||
|
if not embedding_model:
|
||||||
|
embedding_model = OpenAIEmbed(
|
||||||
|
key=db_knowledge.embedding.api_keys[0].api_key,
|
||||||
|
model_name=db_knowledge.embedding.api_keys[0].model_name,
|
||||||
|
base_url=db_knowledge.embedding.api_keys[0].api_base,
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||||
|
|
||||||
|
match kb_config["retrieve_type"]:
|
||||||
|
case "participle":
|
||||||
|
rs = vector_service.search_by_full_text(
|
||||||
|
query=kb_config["query"], # 或者直接把 query 作为额外参数传进来
|
||||||
|
top_k=kb_config["top_k"],
|
||||||
|
score_threshold=kb_config["similarity_threshold"],
|
||||||
|
file_names_filter=file_names_filter,
|
||||||
|
)
|
||||||
|
case "semantic":
|
||||||
|
rs = vector_service.search_by_vector(
|
||||||
|
query=kb_config["query"],
|
||||||
|
top_k=kb_config["top_k"],
|
||||||
|
score_threshold=kb_config["vector_similarity_weight"],
|
||||||
|
file_names_filter=file_names_filter,
|
||||||
|
)
|
||||||
|
case _:
|
||||||
|
rs1 = vector_service.search_by_vector(
|
||||||
|
query=kb_config["query"],
|
||||||
|
top_k=kb_config["top_k"],
|
||||||
|
score_threshold=kb_config["vector_similarity_weight"],
|
||||||
|
file_names_filter=file_names_filter,
|
||||||
|
)
|
||||||
|
rs2 = vector_service.search_by_full_text(
|
||||||
|
query=kb_config["query"],
|
||||||
|
top_k=kb_config["top_k"],
|
||||||
|
score_threshold=kb_config["similarity_threshold"],
|
||||||
|
file_names_filter=file_names_filter,
|
||||||
|
)
|
||||||
|
# 合并去重
|
||||||
|
seen_ids = set()
|
||||||
|
unique_rs = []
|
||||||
|
for doc in rs1 + rs2:
|
||||||
|
if doc.metadata["doc_id"] not in seen_ids:
|
||||||
|
seen_ids.add(doc.metadata["doc_id"])
|
||||||
|
unique_rs.append(doc)
|
||||||
|
rs = unique_rs
|
||||||
|
|
||||||
|
results.extend(rs)
|
||||||
|
return results, chat_model, embedding_model
|
||||||
|
|
||||||
|
|
||||||
def rerank(db: Session, reranker_id: uuid, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
|
def rerank(db: Session, reranker_id: uuid, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,11 +4,12 @@ RAG chunk analysis utilities.
|
|||||||
|
|
||||||
from .chunk_summary import generate_chunk_summary
|
from .chunk_summary import generate_chunk_summary
|
||||||
from .chunk_tags import extract_chunk_tags, extract_chunk_persona
|
from .chunk_tags import extract_chunk_tags, extract_chunk_persona
|
||||||
from .chunk_insight import generate_chunk_insight
|
from .chunk_insight import generate_chunk_insight, generate_chunk_insight_sections
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"generate_chunk_summary",
|
"generate_chunk_summary",
|
||||||
"extract_chunk_tags",
|
"extract_chunk_tags",
|
||||||
"extract_chunk_persona",
|
"extract_chunk_persona",
|
||||||
"generate_chunk_insight",
|
"generate_chunk_insight",
|
||||||
|
"generate_chunk_insight_sections",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,213 +1,207 @@
|
|||||||
"""
|
"""
|
||||||
Generate insights from RAG chunks.
|
Generate memory insight report for RAG chunks using memory_insight.jinja2 prompt template.
|
||||||
|
|
||||||
This module provides functionality to analyze chunk content and generate insights using LLM.
|
The memory_insight.jinja2 template produces a four-section report:
|
||||||
|
【总体概述】 → memory_insight
|
||||||
|
【行为模式】 → behavior_pattern
|
||||||
|
【关键发现】 → key_findings
|
||||||
|
【成长轨迹】 → growth_trajectory
|
||||||
|
|
||||||
|
generate_chunk_insight() returns the full raw text (stored in end_user.memory_insight).
|
||||||
|
generate_chunk_insight_sections() returns a dict with all four fields for richer storage.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
|
import re
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from typing import Any, Dict, List
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
business_logger = get_business_logger()
|
business_logger = get_business_logger()
|
||||||
|
|
||||||
|
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||||
|
|
||||||
def _get_llm_client():
|
|
||||||
"""Get LLM client using db context."""
|
# ── LLM client helper ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _get_llm_client(end_user_id: Optional[str] = None):
|
||||||
|
"""Get LLM client, preferring user-connected config with fallback to default."""
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
|
try:
|
||||||
|
if end_user_id:
|
||||||
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
workspace_id = connected_config.get("workspace_id")
|
||||||
|
if config_id or workspace_id:
|
||||||
|
config_service = MemoryConfigService(db)
|
||||||
|
memory_config = config_service.load_memory_config(
|
||||||
|
config_id=config_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
factory = MemoryClientFactory(db)
|
||||||
|
return factory.get_llm_client(memory_config.llm_model_id)
|
||||||
|
except Exception as e:
|
||||||
|
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
|
||||||
factory = MemoryClientFactory(db)
|
factory = MemoryClientFactory(db)
|
||||||
return factory.get_llm_client(None) # Uses default LLM
|
return factory.get_llm_client(DEFAULT_LLM_ID)
|
||||||
|
|
||||||
|
|
||||||
class ChunkInsight(BaseModel):
|
# ── Domain analysis helpers (kept for building prompt inputs) ─────────────────
|
||||||
"""Pydantic model for chunk insight."""
|
|
||||||
insight: str = Field(..., description="对chunk内容的深度洞察分析")
|
|
||||||
|
|
||||||
|
async def _classify_domain(chunk: str, llm_client) -> str:
|
||||||
|
"""Classify a single chunk into a domain category."""
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
class DomainClassification(BaseModel):
|
class _Domain(BaseModel):
|
||||||
"""Pydantic model for domain classification."""
|
domain: str = Field(..., description="领域分类")
|
||||||
domain: str = Field(
|
|
||||||
...,
|
|
||||||
description="内容所属的领域分类",
|
|
||||||
examples=["技术", "商业", "教育", "生活", "娱乐", "健康", "其他"]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def classify_chunk_domain(chunk: str) -> str:
|
|
||||||
"""
|
|
||||||
Classify a chunk into a specific domain.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chunk: Chunk content string
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Domain name
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
llm_client = _get_llm_client()
|
prompt = (
|
||||||
|
"请将以下文本归类到最合适的领域(技术/商业/教育/生活/娱乐/健康/其他)。\n\n"
|
||||||
prompt = f"""请将以下文本内容归类到最合适的领域中。
|
f"文本: {chunk[:500]}\n\n直接返回领域名称。"
|
||||||
|
|
||||||
可选领域及其关键词:
|
|
||||||
- 技术:编程、软件、硬件、算法、数据、网络、系统、开发、工程等
|
|
||||||
- 商业:市场、销售、管理、财务、投资、创业、营销、战略等
|
|
||||||
- 教育:学习、课程、培训、教学、知识、技能、考试、研究等
|
|
||||||
- 生活:日常、家庭、饮食、购物、旅行、休闲、娱乐等
|
|
||||||
- 娱乐:游戏、电影、音乐、体育、艺术、文化等
|
|
||||||
- 健康:医疗、养生、运动、心理、保健、疾病等
|
|
||||||
- 其他:无法归入以上类别的内容
|
|
||||||
|
|
||||||
文本内容: {chunk[:500]}...
|
|
||||||
|
|
||||||
请直接返回最合适的领域名称。"""
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "system", "content": "你是一个专业的文本分类助手。请仔细分析文本内容,选择最合适的领域分类。"},
|
|
||||||
{"role": "user", "content": prompt}
|
|
||||||
]
|
|
||||||
|
|
||||||
classification = await llm_client.response_structured(
|
|
||||||
messages=messages,
|
|
||||||
response_model=DomainClassification
|
|
||||||
)
|
)
|
||||||
|
result = await llm_client.response_structured(
|
||||||
return classification.domain if classification else "其他"
|
messages=[{"role": "user", "content": prompt}],
|
||||||
|
response_model=_Domain,
|
||||||
except Exception as e:
|
)
|
||||||
business_logger.error(f"分类chunk领域失败: {str(e)}")
|
return result.domain if result else "其他"
|
||||||
|
except Exception:
|
||||||
return "其他"
|
return "其他"
|
||||||
|
|
||||||
|
|
||||||
async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20) -> Dict[str, float]:
|
async def _build_insight_inputs(
|
||||||
|
chunks: List[str],
|
||||||
|
max_chunks: int,
|
||||||
|
end_user_id: Optional[str],
|
||||||
|
) -> Dict[str, Optional[str]]:
|
||||||
"""
|
"""
|
||||||
Analyze the domain distribution of chunks.
|
Derive domain_distribution, active_periods, social_connections strings
|
||||||
|
to feed into the memory_insight.jinja2 template.
|
||||||
Args:
|
|
||||||
chunks: List of chunk content strings
|
|
||||||
max_chunks: Maximum number of chunks to analyze
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary of domain -> percentage
|
|
||||||
"""
|
"""
|
||||||
if not chunks:
|
llm_client = _get_llm_client(end_user_id)
|
||||||
return {}
|
chunks_sample = chunks[:max_chunks]
|
||||||
|
|
||||||
try:
|
# Domain distribution
|
||||||
# 限制分析的chunk数量
|
domain_counts: Counter = Counter()
|
||||||
chunks_to_analyze = chunks[:max_chunks]
|
for chunk in chunks_sample:
|
||||||
|
domain = await _classify_domain(chunk, llm_client)
|
||||||
|
domain_counts[domain] += 1
|
||||||
|
|
||||||
# 为每个chunk分类
|
total = sum(domain_counts.values()) or 1
|
||||||
domain_counts = Counter()
|
domain_distribution = ", ".join(
|
||||||
for chunk in chunks_to_analyze:
|
f"{d}({c / total:.0%})" for d, c in domain_counts.most_common(3)
|
||||||
domain = await classify_chunk_domain(chunk)
|
)
|
||||||
domain_counts[domain] += 1
|
|
||||||
|
|
||||||
# 计算百分比
|
return {
|
||||||
total = sum(domain_counts.values())
|
"domain_distribution": domain_distribution,
|
||||||
domain_distribution = {
|
"active_periods": None, # RAG模式暂无时间维度数据
|
||||||
domain: count / total
|
"social_connections": None, # RAG模式暂无社交关联数据
|
||||||
for domain, count in domain_counts.items()
|
}
|
||||||
}
|
|
||||||
|
|
||||||
# 按百分比降序排序
|
|
||||||
return dict(sorted(domain_distribution.items(), key=lambda x: x[1], reverse=True))
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
business_logger.error(f"分析领域分布失败: {str(e)}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str:
|
# ── Section parser ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_ZH_SECTIONS = {
|
||||||
|
"memory_insight": r"【总体概述】(.*?)(?=【|$)",
|
||||||
|
"behavior_pattern": r"【行为模式】(.*?)(?=【|$)",
|
||||||
|
"key_findings": r"【关键发现】(.*?)(?=【|$)",
|
||||||
|
"growth_trajectory": r"【成长轨迹】(.*?)(?=【|$)",
|
||||||
|
}
|
||||||
|
|
||||||
|
_EN_SECTIONS = {
|
||||||
|
"memory_insight": r"【Overview】(.*?)(?=【|$)",
|
||||||
|
"behavior_pattern": r"【Behavior Pattern】(.*?)(?=【|$)",
|
||||||
|
"key_findings": r"【Key Findings】(.*?)(?=【|$)",
|
||||||
|
"growth_trajectory": r"【Growth Trajectory】(.*?)(?=【|$)",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_sections(text: str, language: str = "zh") -> Dict[str, str]:
|
||||||
|
"""Extract the four sections from the LLM output."""
|
||||||
|
patterns = _ZH_SECTIONS if language == "zh" else _EN_SECTIONS
|
||||||
|
result = {}
|
||||||
|
for key, pattern in patterns.items():
|
||||||
|
match = re.search(pattern, text, re.DOTALL)
|
||||||
|
result[key] = match.group(1).strip() if match else ""
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# ── Public API ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def generate_chunk_insight(
|
||||||
|
chunks: List[str],
|
||||||
|
max_chunks: int = 15,
|
||||||
|
end_user_id: Optional[str] = None,
|
||||||
|
language: str = "zh",
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Generate insights from the given chunks.
|
Generate a memory insight report from RAG chunks.
|
||||||
|
|
||||||
Args:
|
Returns the full raw report text (suitable for end_user.memory_insight).
|
||||||
chunks: List of chunk content strings
|
Use generate_chunk_insight_sections() when you need all four dimensions.
|
||||||
max_chunks: Maximum number of chunks to analyze
|
"""
|
||||||
|
sections = await generate_chunk_insight_sections(
|
||||||
|
chunks=chunks,
|
||||||
|
max_chunks=max_chunks,
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
language=language,
|
||||||
|
)
|
||||||
|
return sections.get("memory_insight") or sections.get("_raw", "洞察生成失败")
|
||||||
|
|
||||||
Returns:
|
|
||||||
A comprehensive insight report
|
async def generate_chunk_insight_sections(
|
||||||
|
chunks: List[str],
|
||||||
|
max_chunks: int = 15,
|
||||||
|
end_user_id: Optional[str] = None,
|
||||||
|
language: str = "zh",
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
"""
|
||||||
|
Generate a four-section memory insight report from RAG chunks.
|
||||||
|
|
||||||
|
Returns a dict with keys:
|
||||||
|
memory_insight, behavior_pattern, key_findings, growth_trajectory
|
||||||
|
(plus '_raw' containing the full LLM output for debugging)
|
||||||
"""
|
"""
|
||||||
if not chunks:
|
if not chunks:
|
||||||
business_logger.warning("没有提供chunk内容用于生成洞察")
|
business_logger.warning("没有提供chunk内容用于生成洞察")
|
||||||
return "暂无足够数据生成洞察报告"
|
empty = {k: "" for k in ("memory_insight", "behavior_pattern", "key_findings", "growth_trajectory")}
|
||||||
|
empty["_raw"] = "暂无足够数据生成洞察报告"
|
||||||
|
return empty
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. 分析领域分布
|
from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt
|
||||||
domain_dist = await analyze_domain_distribution(chunks, max_chunks=max_chunks)
|
|
||||||
|
|
||||||
# 2. 统计基本信息
|
# Build template inputs from chunk analysis
|
||||||
total_chunks = len(chunks)
|
inputs = await _build_insight_inputs(chunks, max_chunks, end_user_id)
|
||||||
avg_length = sum(len(chunk) for chunk in chunks) / total_chunks if total_chunks > 0 else 0
|
|
||||||
|
|
||||||
# 3. 构建洞察prompt
|
rendered_prompt = await render_memory_insight_prompt(
|
||||||
prompt_parts = []
|
domain_distribution=inputs["domain_distribution"],
|
||||||
|
active_periods=inputs["active_periods"],
|
||||||
|
social_connections=inputs["social_connections"],
|
||||||
|
language=language,
|
||||||
|
)
|
||||||
|
|
||||||
if domain_dist:
|
messages = [{"role": "user", "content": rendered_prompt}]
|
||||||
top_domains = ", ".join([f"{k}({v:.0%})" for k, v in list(domain_dist.items())[:3]])
|
llm_client = _get_llm_client(end_user_id)
|
||||||
prompt_parts.append(f"- 内容领域分布: {top_domains}")
|
|
||||||
|
|
||||||
prompt_parts.append(f"- 内容规模: 共{total_chunks}个知识片段,平均长度{avg_length:.0f}字")
|
|
||||||
|
|
||||||
# 添加部分chunk内容作为参考
|
|
||||||
sample_chunks = chunks[:5]
|
|
||||||
sample_content = "\n".join([f"示例{i+1}: {chunk[:200]}..." for i, chunk in enumerate(sample_chunks)])
|
|
||||||
prompt_parts.append(f"\n内容示例:\n{sample_content}")
|
|
||||||
|
|
||||||
system_prompt = """你是一位专业的知识内容分析师。你的任务是根据提供的信息,生成一段简洁、有洞察力的分析报告。
|
|
||||||
|
|
||||||
重要规则:
|
|
||||||
1. 报告需要将所有要点流畅地串联成一个段落
|
|
||||||
2. 语言风格要专业、客观,同时易于理解
|
|
||||||
3. 不要添加任何额外的解释或标题,直接输出报告内容
|
|
||||||
4. 基于提供的数据和示例内容进行分析,不要编造信息
|
|
||||||
5. 重点关注内容的主题、特点和价值
|
|
||||||
6. 报告长度控制在150-200字
|
|
||||||
|
|
||||||
例如,如果输入是:
|
|
||||||
- 内容领域分布: 技术(60%), 商业(25%), 教育(15%)
|
|
||||||
- 内容规模: 共50个知识片段,平均长度320字
|
|
||||||
内容示例: [示例内容...]
|
|
||||||
|
|
||||||
你的输出应该类似:
|
|
||||||
"该知识库主要聚焦于技术领域(60%),涵盖商业(25%)和教育(15%)相关内容。共包含50个知识片段,平均每个片段约320字,内容详实。从示例来看,内容涉及[具体主题],体现了[特点],对[目标用户]具有较高的参考价值。"
|
|
||||||
"""
|
|
||||||
|
|
||||||
user_prompt = "\n".join(prompt_parts)
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "system", "content": system_prompt},
|
|
||||||
{"role": "user", "content": user_prompt}
|
|
||||||
]
|
|
||||||
|
|
||||||
# 调用LLM生成洞察
|
|
||||||
llm_client = _get_llm_client()
|
|
||||||
response = await llm_client.chat(messages=messages)
|
response = await llm_client.chat(messages=messages)
|
||||||
|
raw_text = response.content.strip() if response and response.content else ""
|
||||||
|
|
||||||
insight = response.content.strip()
|
sections = _parse_sections(raw_text, language=language)
|
||||||
business_logger.info(f"成功生成chunk洞察,分析了 {min(len(chunks), max_chunks)} 个片段")
|
sections["_raw"] = raw_text
|
||||||
|
|
||||||
return insight
|
business_logger.info(
|
||||||
|
f"成功生成chunk洞察四维度,分析了 {min(len(chunks), max_chunks)} 个片段"
|
||||||
|
)
|
||||||
|
return sections
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
business_logger.error(f"生成chunk洞察失败: {str(e)}")
|
business_logger.error(f"生成chunk洞察失败: {str(e)}")
|
||||||
return "洞察生成失败"
|
empty = {k: "" for k in ("memory_insight", "behavior_pattern", "key_findings", "growth_trajectory")}
|
||||||
|
empty["_raw"] = "洞察生成失败"
|
||||||
|
return empty
|
||||||
if __name__ == "__main__":
|
|
||||||
# 测试代码
|
|
||||||
test_chunks = [
|
|
||||||
"Python是一种高级编程语言,以其简洁的语法和强大的功能而闻名。它广泛应用于Web开发、数据分析、人工智能等领域。",
|
|
||||||
"机器学习算法可以从数据中自动学习模式,无需显式编程。常见的算法包括决策树、随机森林、神经网络等。",
|
|
||||||
"深度学习是机器学习的一个分支,使用多层神经网络来学习数据的层次化表示。它在图像识别、语音识别等任务中表现出色。",
|
|
||||||
"自然语言处理技术使计算机能够理解和生成人类语言。应用包括机器翻译、情感分析、文本摘要等。",
|
|
||||||
"数据科学结合了统计学、计算机科学和领域知识,用于从数据中提取有价值的洞察。"
|
|
||||||
]
|
|
||||||
|
|
||||||
print("开始生成chunk洞察...")
|
|
||||||
insight = asyncio.run(generate_chunk_insight(test_chunks))
|
|
||||||
print(f"\n生成的洞察:\n{insight}")
|
|
||||||
|
|||||||
@@ -1,11 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
Generate summary for RAG chunks.
|
Generate summary for RAG chunks using memory_summary.jinja2 prompt template.
|
||||||
|
|
||||||
This module provides functionality to summarize chunk content using LLM.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any, Dict, List
|
import os
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
@@ -14,94 +13,135 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
business_logger = get_business_logger()
|
business_logger = get_business_logger()
|
||||||
|
|
||||||
|
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||||
def _get_llm_client():
|
|
||||||
"""Get LLM client using db context."""
|
|
||||||
with get_db_context() as db:
|
|
||||||
factory = MemoryClientFactory(db)
|
|
||||||
return factory.get_llm_client(None) # Uses default LLM
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkSummary(BaseModel):
|
# ── Schema ──────────────────────────────────────────────────────────────────
|
||||||
"""Pydantic model for chunk summary."""
|
|
||||||
summary: str = Field(..., description="简洁的chunk内容摘要")
|
class MemorySummaryStatement(BaseModel):
|
||||||
|
"""Single labelled statement extracted by memory_summary.jinja2."""
|
||||||
|
statement: str = Field(..., description="提取的陈述内容")
|
||||||
|
label: Optional[str] = Field(None, description="陈述标签")
|
||||||
|
|
||||||
|
|
||||||
async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str:
|
class MemorySummaryResponse(BaseModel):
|
||||||
"""
|
"""
|
||||||
Generate a summary for the given chunks.
|
Structured output expected from memory_summary.jinja2.
|
||||||
|
The template asks for a JSON array of labelled statements;
|
||||||
|
we wrap it in an object so response_structured can parse it.
|
||||||
|
"""
|
||||||
|
statements: List[MemorySummaryStatement] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="从chunk中提取的陈述列表"
|
||||||
|
)
|
||||||
|
summary: Optional[str] = Field(None, description="整体摘要文本(可选)")
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM client helper ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _get_llm_client(end_user_id: Optional[str] = None):
|
||||||
|
"""Get LLM client, preferring user-connected config with fallback to default."""
|
||||||
|
with get_db_context() as db:
|
||||||
|
try:
|
||||||
|
if end_user_id:
|
||||||
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
workspace_id = connected_config.get("workspace_id")
|
||||||
|
if config_id or workspace_id:
|
||||||
|
config_service = MemoryConfigService(db)
|
||||||
|
memory_config = config_service.load_memory_config(
|
||||||
|
config_id=config_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
factory = MemoryClientFactory(db)
|
||||||
|
return factory.get_llm_client(memory_config.llm_model_id)
|
||||||
|
except Exception as e:
|
||||||
|
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
|
||||||
|
factory = MemoryClientFactory(db)
|
||||||
|
return factory.get_llm_client(DEFAULT_LLM_ID)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Core function ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def generate_chunk_summary(
|
||||||
|
chunks: List[str],
|
||||||
|
max_chunks: int = 10,
|
||||||
|
end_user_id: Optional[str] = None,
|
||||||
|
language: str = "zh",
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Generate a user summary from RAG chunks using the memory_summary.jinja2 template.
|
||||||
|
|
||||||
|
The template extracts labelled statements from the chunks; we then join them
|
||||||
|
into a coherent summary string that can be stored in end_user.user_summary.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
chunks: List of chunk content strings
|
chunks: List of chunk content strings
|
||||||
max_chunks: Maximum number of chunks to process (default: 10)
|
max_chunks: Maximum number of chunks to process
|
||||||
|
end_user_id: Optional end-user ID for model selection
|
||||||
|
language: Output language ("zh" or "en")
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A concise summary of the chunks
|
Summary string (joined statements or fallback text)
|
||||||
"""
|
"""
|
||||||
if not chunks:
|
if not chunks:
|
||||||
business_logger.warning("没有提供chunk内容用于生成摘要")
|
business_logger.warning("没有提供chunk内容用于生成摘要")
|
||||||
return "暂无内容"
|
return "暂无内容"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 限制处理的chunk数量,避免token过多
|
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
|
||||||
|
|
||||||
chunks_to_process = chunks[:max_chunks]
|
chunks_to_process = chunks[:max_chunks]
|
||||||
|
chunk_texts = "\n\n".join(
|
||||||
# 合并chunk内容
|
[f"片段{i + 1}: {chunk}" for i, chunk in enumerate(chunks_to_process)]
|
||||||
combined_content = "\n\n".join([f"片段{i+1}: {chunk}" for i, chunk in enumerate(chunks_to_process)])
|
|
||||||
|
|
||||||
# 构建prompt
|
|
||||||
system_prompt = (
|
|
||||||
"你是一位专业的文本摘要助手。请基于提供的文本片段,生成简洁的摘要。要求:\n"
|
|
||||||
"- 摘要长度控制在100-150字;\n"
|
|
||||||
"- 提取核心信息和关键要点;\n"
|
|
||||||
"- 使用客观、清晰的语言;\n"
|
|
||||||
"- 避免冗余和重复;\n"
|
|
||||||
"- 如果内容涉及多个主题,按重要性排序呈现。"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
user_prompt = f"请为以下文本片段生成摘要:\n\n{combined_content}"
|
json_schema = MemorySummaryResponse.model_json_schema()
|
||||||
|
|
||||||
messages = [
|
rendered_prompt = await render_memory_summary_prompt(
|
||||||
{"role": "system", "content": system_prompt},
|
chunk_texts=chunk_texts,
|
||||||
{"role": "user", "content": user_prompt},
|
json_schema=json_schema,
|
||||||
]
|
max_words=200,
|
||||||
|
language=language,
|
||||||
|
)
|
||||||
|
|
||||||
# 调用LLM生成摘要
|
messages = [{"role": "user", "content": rendered_prompt}]
|
||||||
llm_client = _get_llm_client()
|
|
||||||
response = await llm_client.chat(messages=messages)
|
|
||||||
|
|
||||||
summary = response.content.strip()
|
llm_client = _get_llm_client(end_user_id)
|
||||||
business_logger.info(f"成功生成chunk摘要,处理了 {len(chunks_to_process)} 个片段")
|
|
||||||
|
|
||||||
|
# Try structured output; fall back to plain chat only for LLMClientException
|
||||||
|
# (indicates the model/provider doesn't support structured output).
|
||||||
|
# All other exceptions are re-raised so config/schema errors stay visible.
|
||||||
|
try:
|
||||||
|
response: MemorySummaryResponse = await llm_client.response_structured(
|
||||||
|
messages=messages,
|
||||||
|
response_model=MemorySummaryResponse,
|
||||||
|
)
|
||||||
|
if response.summary:
|
||||||
|
summary = response.summary.strip()
|
||||||
|
elif response.statements:
|
||||||
|
summary = ";".join(s.statement for s in response.statements)
|
||||||
|
else:
|
||||||
|
summary = "暂无内容"
|
||||||
|
except Exception as e:
|
||||||
|
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||||
|
if isinstance(e, LLMClientException):
|
||||||
|
business_logger.warning(
|
||||||
|
f"结构化输出不可用,降级为普通对话: end_user_id={end_user_id}, reason={e}"
|
||||||
|
)
|
||||||
|
raw = await llm_client.chat(messages=messages)
|
||||||
|
summary = raw.content.strip() if raw and raw.content else "暂无内容"
|
||||||
|
else:
|
||||||
|
business_logger.error(f"生成摘要时发生非预期异常: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
business_logger.info(
|
||||||
|
f"成功生成chunk摘要,处理了 {len(chunks_to_process)} 个片段"
|
||||||
|
)
|
||||||
return summary
|
return summary
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
business_logger.error(f"生成chunk摘要失败: {str(e)}")
|
business_logger.error(f"生成chunk摘要失败: {str(e)}")
|
||||||
return "摘要生成失败"
|
return "摘要生成失败"
|
||||||
|
|
||||||
|
|
||||||
async def generate_chunk_summary_batch(chunks_list: List[List[str]]) -> List[str]:
|
|
||||||
"""
|
|
||||||
Generate summaries for multiple chunk lists in batch.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chunks_list: List of chunk lists
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of summaries
|
|
||||||
"""
|
|
||||||
tasks = [generate_chunk_summary(chunks) for chunks in chunks_list]
|
|
||||||
return await asyncio.gather(*tasks)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# 测试代码
|
|
||||||
test_chunks = [
|
|
||||||
"这是第一段测试内容,讲述了关于机器学习的基础知识。",
|
|
||||||
"第二段内容介绍了深度学习的应用场景和发展历史。",
|
|
||||||
"第三段讨论了自然语言处理技术的最新进展。"
|
|
||||||
]
|
|
||||||
|
|
||||||
print("开始生成chunk摘要...")
|
|
||||||
summary = asyncio.run(generate_chunk_summary(test_chunks))
|
|
||||||
print(f"\n生成的摘要:\n{summary}")
|
|
||||||
|
|||||||
@@ -5,8 +5,9 @@ This module provides functionality to extract meaningful tags from chunk content
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from typing import List, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
@@ -15,12 +16,31 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
business_logger = get_business_logger()
|
business_logger = get_business_logger()
|
||||||
|
|
||||||
|
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||||
|
|
||||||
def _get_llm_client():
|
|
||||||
"""Get LLM client using db context."""
|
def _get_llm_client(end_user_id: Optional[str] = None):
|
||||||
|
"""Get LLM client, preferring user-connected config with fallback to default."""
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
|
try:
|
||||||
|
if end_user_id:
|
||||||
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
workspace_id = connected_config.get("workspace_id")
|
||||||
|
if config_id or workspace_id:
|
||||||
|
config_service = MemoryConfigService(db)
|
||||||
|
memory_config = config_service.load_memory_config(
|
||||||
|
config_id=config_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
factory = MemoryClientFactory(db)
|
||||||
|
return factory.get_llm_client(memory_config.llm_model_id)
|
||||||
|
except Exception as e:
|
||||||
|
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
|
||||||
factory = MemoryClientFactory(db)
|
factory = MemoryClientFactory(db)
|
||||||
return factory.get_llm_client(None) # Uses default LLM
|
return factory.get_llm_client(DEFAULT_LLM_ID)
|
||||||
|
|
||||||
|
|
||||||
class ExtractedTags(BaseModel):
|
class ExtractedTags(BaseModel):
|
||||||
@@ -33,7 +53,7 @@ class ExtractedPersona(BaseModel):
|
|||||||
personas: List[str] = Field(..., description="从文本中提取的人物形象列表,如'产品设计师'、'旅行爱好者'等")
|
personas: List[str] = Field(..., description="从文本中提取的人物形象列表,如'产品设计师'、'旅行爱好者'等")
|
||||||
|
|
||||||
|
|
||||||
async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10) -> List[Tuple[str, int]]:
|
async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10, end_user_id: Optional[str] = None) -> List[Tuple[str, int]]:
|
||||||
"""
|
"""
|
||||||
Extract meaningful tags from the given chunks.
|
Extract meaningful tags from the given chunks.
|
||||||
|
|
||||||
@@ -64,7 +84,7 @@ async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks:
|
|||||||
"标签应该是名词或名词短语,能够准确概括文本的核心内容。"
|
"标签应该是名词或名词短语,能够准确概括文本的核心内容。"
|
||||||
)
|
)
|
||||||
|
|
||||||
llm_client = _get_llm_client()
|
llm_client = _get_llm_client(end_user_id)
|
||||||
|
|
||||||
# 为每个chunk单独提取标签,然后统计频率
|
# 为每个chunk单独提取标签,然后统计频率
|
||||||
all_tags = []
|
all_tags = []
|
||||||
@@ -116,7 +136,7 @@ async def extract_chunk_tags_with_frequency(chunks: List[str], max_tags: int = 1
|
|||||||
return await extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=len(chunks))
|
return await extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=len(chunks))
|
||||||
|
|
||||||
|
|
||||||
async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20) -> List[str]:
|
async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20, end_user_id: Optional[str] = None) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Extract persona (人物形象) from the given chunks.
|
Extract persona (人物形象) from the given chunks.
|
||||||
|
|
||||||
@@ -159,7 +179,7 @@ async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_ch
|
|||||||
]
|
]
|
||||||
|
|
||||||
# 调用LLM提取人物形象
|
# 调用LLM提取人物形象
|
||||||
llm_client = _get_llm_client()
|
llm_client = _get_llm_client(end_user_id)
|
||||||
structured_response = await llm_client.response_structured(
|
structured_response = await llm_client.response_structured(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
response_model=ExtractedPersona
|
response_model=ExtractedPersona
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ file operations across different storage backends.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional
|
from typing import AsyncIterator, Optional
|
||||||
|
|
||||||
|
|
||||||
class StorageBackend(ABC):
|
class StorageBackend(ABC):
|
||||||
@@ -42,6 +42,26 @@ class StorageBackend(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def upload_stream(
|
||||||
|
self,
|
||||||
|
file_key: str,
|
||||||
|
stream: AsyncIterator[bytes],
|
||||||
|
content_type: Optional[str] = None,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Upload a file from an async byte stream.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_key: Unique identifier for the file.
|
||||||
|
stream: Async iterator yielding bytes chunks.
|
||||||
|
content_type: Optional MIME type of the file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total bytes written.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def download(self, file_key: str) -> bytes:
|
async def download(self, file_key: str) -> bytes:
|
||||||
"""
|
"""
|
||||||
@@ -101,3 +121,18 @@ class StorageBackend(ABC):
|
|||||||
URL for accessing the file.
|
URL for accessing the file.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def get_permanent_url(self, file_key: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Get a permanent public URL for the file (no expiration).
|
||||||
|
|
||||||
|
Returns None by default; remote storage backends should override this
|
||||||
|
if the bucket is configured for public read access.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_key: Unique identifier for the file in the storage system.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A permanent public URL, or None if not supported.
|
||||||
|
"""
|
||||||
|
return None
|
||||||
|
|||||||
@@ -85,6 +85,7 @@ class StorageFactory:
|
|||||||
access_key_id=settings.S3_ACCESS_KEY_ID,
|
access_key_id=settings.S3_ACCESS_KEY_ID,
|
||||||
secret_access_key=settings.S3_SECRET_ACCESS_KEY,
|
secret_access_key=settings.S3_SECRET_ACCESS_KEY,
|
||||||
bucket_name=settings.S3_BUCKET_NAME,
|
bucket_name=settings.S3_BUCKET_NAME,
|
||||||
|
endpoint_url=settings.S3_ENDPOINT_URL,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import aiofiles.os
|
import aiofiles.os
|
||||||
|
from typing import AsyncIterator
|
||||||
|
|
||||||
from app.core.storage.base import StorageBackend
|
from app.core.storage.base import StorageBackend
|
||||||
from app.core.storage_exceptions import (
|
from app.core.storage_exceptions import (
|
||||||
@@ -179,6 +180,36 @@ class LocalStorage(StorageBackend):
|
|||||||
full_path = self._get_full_path(file_key)
|
full_path = self._get_full_path(file_key)
|
||||||
return full_path.exists()
|
return full_path.exists()
|
||||||
|
|
||||||
|
async def upload_stream(
|
||||||
|
self,
|
||||||
|
file_key: str,
|
||||||
|
stream: AsyncIterator[bytes],
|
||||||
|
content_type: Optional[str] = None,
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
Upload a file from an async byte stream to the local file system.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Total bytes written.
|
||||||
|
"""
|
||||||
|
full_path = self._get_full_path(file_key)
|
||||||
|
try:
|
||||||
|
full_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
total = 0
|
||||||
|
async with aiofiles.open(full_path, "wb") as f:
|
||||||
|
async for chunk in stream:
|
||||||
|
await f.write(chunk)
|
||||||
|
total += len(chunk)
|
||||||
|
logger.info(f"File stream uploaded successfully: {file_key}")
|
||||||
|
return total
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to stream upload file {file_key}: {e}")
|
||||||
|
raise StorageUploadError(
|
||||||
|
message=f"Failed to stream upload file: {e}",
|
||||||
|
file_key=file_key,
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
async def get_url(self, file_key: str, expires: int = 3600) -> str:
|
async def get_url(self, file_key: str, expires: int = 3600) -> str:
|
||||||
"""
|
"""
|
||||||
Get an access URL for the file.
|
Get an access URL for the file.
|
||||||
|
|||||||
@@ -5,8 +5,9 @@ This module provides a storage backend that stores files on Aliyun Object
|
|||||||
Storage Service (OSS) using the oss2 SDK.
|
Storage Service (OSS) using the oss2 SDK.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import io
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import AsyncIterator, Optional
|
||||||
|
|
||||||
import oss2
|
import oss2
|
||||||
from oss2.exceptions import NoSuchKey, OssError
|
from oss2.exceptions import NoSuchKey, OssError
|
||||||
@@ -125,10 +126,39 @@ class OSSStorage(StorageBackend):
|
|||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def upload_stream(
|
||||||
|
self,
|
||||||
|
file_key: str,
|
||||||
|
stream: AsyncIterator[bytes],
|
||||||
|
content_type: Optional[str] = None,
|
||||||
|
) -> int:
|
||||||
|
"""Upload from async stream to OSS. Returns total bytes written."""
|
||||||
|
buf = io.BytesIO()
|
||||||
|
try:
|
||||||
|
async for chunk in stream:
|
||||||
|
buf.write(chunk)
|
||||||
|
content = buf.getvalue()
|
||||||
|
headers = {"Content-Type": content_type} if content_type else None
|
||||||
|
self.bucket.put_object(file_key, content, headers=headers)
|
||||||
|
logger.info(f"File stream uploaded to OSS successfully: {file_key}")
|
||||||
|
return len(content)
|
||||||
|
except OssError as e:
|
||||||
|
logger.error(f"OSS error stream uploading file {file_key}: {e}")
|
||||||
|
raise StorageUploadError(
|
||||||
|
message=f"Failed to stream upload file to OSS: {e.message}",
|
||||||
|
file_key=file_key,
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to stream upload file to OSS {file_key}: {e}")
|
||||||
|
raise StorageUploadError(
|
||||||
|
message=f"Failed to stream upload file to OSS: {e}",
|
||||||
|
file_key=file_key,
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
async def download(self, file_key: str) -> bytes:
|
async def download(self, file_key: str) -> bytes:
|
||||||
"""
|
"""
|
||||||
Download a file from OSS.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_key: Unique identifier for the file in the storage system.
|
file_key: Unique identifier for the file in the storage system.
|
||||||
|
|
||||||
@@ -231,3 +261,13 @@ class OSSStorage(StorageBackend):
|
|||||||
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
|
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
|
||||||
# Return a basic URL format as fallback
|
# Return a basic URL format as fallback
|
||||||
return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}"
|
return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}"
|
||||||
|
|
||||||
|
async def get_permanent_url(self, file_key: str) -> str:
|
||||||
|
"""
|
||||||
|
Get a permanent public URL for the file (requires bucket public read).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A permanent URL in the format: https://{bucket}.{endpoint}/{file_key}
|
||||||
|
"""
|
||||||
|
host = self.endpoint.replace("https://", "").replace("http://", "")
|
||||||
|
return f"https://{self.bucket_name}.{host}/{file_key}"
|
||||||
|
|||||||
@@ -5,8 +5,9 @@ This module provides a storage backend that stores files on AWS S3
|
|||||||
using the boto3 SDK.
|
using the boto3 SDK.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import io
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional
|
from typing import AsyncIterator, Optional
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
from botocore.exceptions import ClientError, NoCredentialsError, BotoCoreError
|
from botocore.exceptions import ClientError, NoCredentialsError, BotoCoreError
|
||||||
@@ -35,6 +36,19 @@ class S3Storage(StorageBackend):
|
|||||||
bucket_name: The name of the S3 bucket.
|
bucket_name: The name of the S3 bucket.
|
||||||
region: The AWS region.
|
region: The AWS region.
|
||||||
"""
|
"""
|
||||||
|
AMAZON_S3_ENDPOINT_MAP = {
|
||||||
|
"us-east-1": "https://s3.us-east-1.amazonaws.com", # 特殊:无地域后缀
|
||||||
|
"us-east-2": "https://s3.us-east-2.amazonaws.com",
|
||||||
|
"us-west-1": "https://s3.us-west-1.amazonaws.com",
|
||||||
|
"us-west-2": "https://s3.us-west-2.amazonaws.com",
|
||||||
|
"ap-east-1": "https://s3.ap-east-1.amazonaws.com", # 香港
|
||||||
|
"ap-southeast-1": "https://s3.ap-southeast-1.amazonaws.com", # 新加坡
|
||||||
|
"ap-southeast-2": "https://s3.ap-southeast-2.amazonaws.com", # 悉尼
|
||||||
|
"ap-northeast-1": "https://s3.ap-northeast-1.amazonaws.com", # 东京
|
||||||
|
"eu-central-1": "https://s3.eu-central-1.amazonaws.com", # 法兰克福
|
||||||
|
"eu-west-1": "https://s3.eu-west-1.amazonaws.com", # 爱尔兰
|
||||||
|
# 可根据需要扩展其他地域
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -42,6 +56,7 @@ class S3Storage(StorageBackend):
|
|||||||
access_key_id: str,
|
access_key_id: str,
|
||||||
secret_access_key: str,
|
secret_access_key: str,
|
||||||
bucket_name: str,
|
bucket_name: str,
|
||||||
|
endpoint_url: Optional[str] = None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize the S3Storage backend.
|
Initialize the S3Storage backend.
|
||||||
@@ -51,6 +66,7 @@ class S3Storage(StorageBackend):
|
|||||||
access_key_id: The AWS access key ID.
|
access_key_id: The AWS access key ID.
|
||||||
secret_access_key: The AWS secret access key.
|
secret_access_key: The AWS secret access key.
|
||||||
bucket_name: The name of the S3 bucket.
|
bucket_name: The name of the S3 bucket.
|
||||||
|
endpoint_url: The complete URL to use for the constructed client.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
StorageConfigError: If any required configuration is missing.
|
StorageConfigError: If any required configuration is missing.
|
||||||
@@ -69,10 +85,19 @@ class S3Storage(StorageBackend):
|
|||||||
self.region = region
|
self.region = region
|
||||||
self.bucket_name = bucket_name
|
self.bucket_name = bucket_name
|
||||||
|
|
||||||
|
if not endpoint_url:
|
||||||
|
# 优先匹配内置映射表(解决特殊地域)
|
||||||
|
if region in self.AMAZON_S3_ENDPOINT_MAP:
|
||||||
|
endpoint_url = self.AMAZON_S3_ENDPOINT_MAP[region]
|
||||||
|
# 兜底:通用拼接(适配未配置的新地域)
|
||||||
|
else:
|
||||||
|
endpoint_url = f"https://s3.{region}.amazonaws.com"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.client = boto3.client(
|
self.client = boto3.client(
|
||||||
"s3",
|
"s3",
|
||||||
region_name=region,
|
region_name=region,
|
||||||
|
endpoint_url=endpoint_url,
|
||||||
aws_access_key_id=access_key_id,
|
aws_access_key_id=access_key_id,
|
||||||
aws_secret_access_key=secret_access_key,
|
aws_secret_access_key=secret_access_key,
|
||||||
)
|
)
|
||||||
@@ -150,6 +175,62 @@ class S3Storage(StorageBackend):
|
|||||||
cause=e,
|
cause=e,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def upload_stream(
|
||||||
|
self,
|
||||||
|
file_key: str,
|
||||||
|
stream: AsyncIterator[bytes],
|
||||||
|
content_type: Optional[str] = None,
|
||||||
|
) -> int:
|
||||||
|
"""Upload from async stream to S3 via multipart upload. Returns total bytes written."""
|
||||||
|
extra_args = {"ContentType": content_type} if content_type else {}
|
||||||
|
mpu = self.client.create_multipart_upload(
|
||||||
|
Bucket=self.bucket_name, Key=file_key, **extra_args
|
||||||
|
)
|
||||||
|
upload_id = mpu["UploadId"]
|
||||||
|
parts = []
|
||||||
|
part_number = 1
|
||||||
|
buf = io.BytesIO()
|
||||||
|
total = 0
|
||||||
|
min_part_size = 5 * 1024 * 1024 # S3 最小分片 5MB
|
||||||
|
try:
|
||||||
|
async for chunk in stream:
|
||||||
|
buf.write(chunk)
|
||||||
|
total += len(chunk)
|
||||||
|
if buf.tell() >= min_part_size:
|
||||||
|
buf.seek(0)
|
||||||
|
resp = self.client.upload_part(
|
||||||
|
Bucket=self.bucket_name, Key=file_key,
|
||||||
|
UploadId=upload_id, PartNumber=part_number, Body=buf.read()
|
||||||
|
)
|
||||||
|
parts.append({"PartNumber": part_number, "ETag": resp["ETag"]})
|
||||||
|
part_number += 1
|
||||||
|
buf = io.BytesIO()
|
||||||
|
# 上传剩余数据(最后一片可小于 5MB)
|
||||||
|
remaining = buf.getvalue()
|
||||||
|
if remaining:
|
||||||
|
resp = self.client.upload_part(
|
||||||
|
Bucket=self.bucket_name, Key=file_key,
|
||||||
|
UploadId=upload_id, PartNumber=part_number, Body=remaining
|
||||||
|
)
|
||||||
|
parts.append({"PartNumber": part_number, "ETag": resp["ETag"]})
|
||||||
|
self.client.complete_multipart_upload(
|
||||||
|
Bucket=self.bucket_name, Key=file_key,
|
||||||
|
UploadId=upload_id,
|
||||||
|
MultipartUpload={"Parts": parts}
|
||||||
|
)
|
||||||
|
logger.info(f"File stream uploaded to S3 successfully: {file_key}")
|
||||||
|
return total
|
||||||
|
except Exception as e:
|
||||||
|
self.client.abort_multipart_upload(
|
||||||
|
Bucket=self.bucket_name, Key=file_key, UploadId=upload_id
|
||||||
|
)
|
||||||
|
logger.error(f"Failed to stream upload file to S3 {file_key}: {e}")
|
||||||
|
raise StorageUploadError(
|
||||||
|
message=f"Failed to stream upload file to S3: {e}",
|
||||||
|
file_key=file_key,
|
||||||
|
cause=e,
|
||||||
|
)
|
||||||
|
|
||||||
async def download(self, file_key: str) -> bytes:
|
async def download(self, file_key: str) -> bytes:
|
||||||
"""
|
"""
|
||||||
Download a file from S3.
|
Download a file from S3.
|
||||||
@@ -297,3 +378,12 @@ class S3Storage(StorageBackend):
|
|||||||
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
|
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
|
||||||
# Return a basic URL format as fallback
|
# Return a basic URL format as fallback
|
||||||
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"
|
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"
|
||||||
|
|
||||||
|
async def get_permanent_url(self, file_key: str) -> str:
|
||||||
|
"""
|
||||||
|
Get a permanent public URL for the file (requires bucket public read).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A permanent URL in the format: https://{bucket}.s3.{region}.amazonaws.com/{file_key}
|
||||||
|
"""
|
||||||
|
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"
|
||||||
|
|||||||
@@ -195,6 +195,6 @@ class MCPToolManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
return {
|
return {
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": "连接失败",
|
||||||
"message": "连接失败"
|
"message": str(e)
|
||||||
}
|
}
|
||||||
@@ -23,7 +23,7 @@ class SimpleMCPClient:
|
|||||||
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
|
def __init__(self, server_url: str, connection_config: Dict[str, Any] = None):
|
||||||
self.server_url = server_url
|
self.server_url = server_url
|
||||||
self.connection_config = connection_config or {}
|
self.connection_config = connection_config or {}
|
||||||
self.timeout = self.connection_config.get("timeout", 30)
|
self.timeout = self.connection_config.get("timeout", 10)
|
||||||
|
|
||||||
# 确定连接类型
|
# 确定连接类型
|
||||||
self.is_websocket = server_url.startswith(("ws://", "wss://"))
|
self.is_websocket = server_url.startswith(("ws://", "wss://"))
|
||||||
@@ -53,6 +53,7 @@ class SimpleMCPClient:
|
|||||||
else:
|
else:
|
||||||
await self._connect_http()
|
await self._connect_http()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
await self.disconnect()
|
||||||
logger.error(f"MCP连接失败: {self.server_url}, 错误: {e}")
|
logger.error(f"MCP连接失败: {self.server_url}, 错误: {e}")
|
||||||
raise MCPConnectionError(f"连接失败: {e}")
|
raise MCPConnectionError(f"连接失败: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -8,34 +8,60 @@ from typing import Any
|
|||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
from app.core.workflow.adapters.base_converter import BaseConverter
|
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||||
from app.core.workflow.adapters.errors import UnsupportVariableType, UnknowModelWarning, ExceptionDefineition, \
|
from app.core.workflow.adapters.errors import (
|
||||||
|
UnsupportVariableType,
|
||||||
|
UnknowModelWarning,
|
||||||
|
ExceptionDefineition,
|
||||||
ExceptionType
|
ExceptionType
|
||||||
from app.core.workflow.nodes.assigner import AssignerNodeConfig
|
)
|
||||||
from app.core.workflow.nodes.assigner.config import AssignmentItem
|
from app.core.workflow.nodes.assigner.config import AssignmentItem
|
||||||
from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig
|
from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig
|
||||||
from app.core.workflow.nodes.code import CodeNodeConfig
|
|
||||||
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
|
from app.core.workflow.nodes.code.config import InputVariable, OutputVariable
|
||||||
from app.core.workflow.nodes.configs import StartNodeConfig, LLMNodeConfig
|
from app.core.workflow.nodes.configs import (
|
||||||
from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig
|
StartNodeConfig,
|
||||||
from app.core.workflow.nodes.cycle_graph.config import ConditionDetail as LoopConditionDetail, ConditionsConfig, \
|
LLMNodeConfig,
|
||||||
|
AssignerNodeConfig,
|
||||||
|
CodeNodeConfig,
|
||||||
|
LoopNodeConfig,
|
||||||
|
IterationNodeConfig,
|
||||||
|
EndNodeConfig,
|
||||||
|
HttpRequestNodeConfig,
|
||||||
|
IfElseNodeConfig,
|
||||||
|
JinjaRenderNodeConfig,
|
||||||
|
KnowledgeRetrievalNodeConfig,
|
||||||
|
NoteNodeConfig,
|
||||||
|
ParameterExtractorNodeConfig,
|
||||||
|
QuestionClassifierNodeConfig,
|
||||||
|
VariableAggregatorNodeConfig
|
||||||
|
)
|
||||||
|
from app.core.workflow.nodes.cycle_graph.config import (
|
||||||
|
ConditionDetail as LoopConditionDetail,
|
||||||
|
ConditionsConfig,
|
||||||
CycleVariable
|
CycleVariable
|
||||||
from app.core.workflow.nodes.end import EndNodeConfig
|
)
|
||||||
from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, AssignmentOperator, HttpAuthType, \
|
from app.core.workflow.nodes.enums import (
|
||||||
HttpContentType, HttpErrorHandle
|
ValueInputType,
|
||||||
from app.core.workflow.nodes.http_request import HttpRequestNodeConfig
|
ComparisonOperator,
|
||||||
from app.core.workflow.nodes.http_request.config import HttpAuthConfig, HttpContentTypeConfig, HttpFormData, \
|
AssignmentOperator,
|
||||||
HttpTimeOutConfig, HttpRetryConfig, HttpErrorDefaultTamplete, HttpErrorHandleConfig
|
HttpAuthType,
|
||||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
HttpContentType,
|
||||||
|
HttpErrorHandle,
|
||||||
|
NodeType
|
||||||
|
)
|
||||||
|
from app.core.workflow.nodes.http_request.config import (
|
||||||
|
HttpAuthConfig,
|
||||||
|
HttpContentTypeConfig,
|
||||||
|
HttpFormData,
|
||||||
|
HttpTimeOutConfig,
|
||||||
|
HttpRetryConfig,
|
||||||
|
HttpErrorDefaultTamplete,
|
||||||
|
HttpErrorHandleConfig
|
||||||
|
)
|
||||||
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
|
from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig
|
||||||
from app.core.workflow.nodes.jinja_render import JinjaRenderNodeConfig
|
|
||||||
from app.core.workflow.nodes.jinja_render.config import VariablesMappingConfig
|
from app.core.workflow.nodes.jinja_render.config import VariablesMappingConfig
|
||||||
from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig
|
|
||||||
from app.core.workflow.nodes.llm.config import MemoryWindowSetting, MessageConfig
|
from app.core.workflow.nodes.llm.config import MemoryWindowSetting, MessageConfig
|
||||||
from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNodeConfig
|
|
||||||
from app.core.workflow.nodes.parameter_extractor.config import ParamsConfig
|
from app.core.workflow.nodes.parameter_extractor.config import ParamsConfig
|
||||||
from app.core.workflow.nodes.question_classifier import QuestionClassifierNodeConfig
|
|
||||||
from app.core.workflow.nodes.question_classifier.config import ClassifierConfig
|
from app.core.workflow.nodes.question_classifier.config import ClassifierConfig
|
||||||
from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNodeConfig
|
|
||||||
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE
|
||||||
|
|
||||||
|
|
||||||
@@ -48,24 +74,24 @@ class DifyConverter(BaseConverter):
|
|||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.CONFIG_CONVERT_MAP = {
|
self.CONFIG_CONVERT_MAP = {
|
||||||
"start": self.convert_start_node_config,
|
NodeType.START: self.convert_start_node_config,
|
||||||
"llm": self.convert_llm_node_config,
|
NodeType.LLM: self.convert_llm_node_config,
|
||||||
"answer": self.convert_end_node_config,
|
NodeType.END: self.convert_end_node_config,
|
||||||
"if-else": self.convert_if_else_node_config,
|
NodeType.IF_ELSE: self.convert_if_else_node_config,
|
||||||
"loop": self.convert_loop_node_config,
|
NodeType.LOOP: self.convert_loop_node_config,
|
||||||
"iteration": self.convert_iteration_node_config,
|
NodeType.ITERATION: self.convert_iteration_node_config,
|
||||||
"assigner": self.convert_assigner_node_config,
|
NodeType.ASSIGNER: self.convert_assigner_node_config,
|
||||||
"code": self.convert_code_node_config,
|
NodeType.CODE: self.convert_code_node_config,
|
||||||
"http-request": self.convert_http_node_config,
|
NodeType.HTTP_REQUEST: self.convert_http_node_config,
|
||||||
"template-transform": self.convert_jinja_render_node_config,
|
NodeType.JINJARENDER: self.convert_jinja_render_node_config,
|
||||||
"knowledge-retrieval": self.convert_knowledge_node_config,
|
NodeType.KNOWLEDGE_RETRIEVAL: self.convert_knowledge_node_config,
|
||||||
"parameter-extractor": self.convert_parameter_extractor_node_config,
|
NodeType.PARAMETER_EXTRACTOR: self.convert_parameter_extractor_node_config,
|
||||||
"question-classifier": self.convert_question_classifier_node_config,
|
NodeType.QUESTION_CLASSIFIER: self.convert_question_classifier_node_config,
|
||||||
"variable-aggregator": self.convert_variable_aggregator_node_config,
|
NodeType.VAR_AGGREGATOR: self.convert_variable_aggregator_node_config,
|
||||||
"tool": self.convert_tool_node_config,
|
NodeType.TOOL: self.convert_tool_node_config,
|
||||||
"loop-start": lambda x: {},
|
NodeType.NOTES: self.convert_notes_config,
|
||||||
"iteration-start": lambda x: {},
|
NodeType.CYCLE_START: lambda x: {},
|
||||||
"loop-end": lambda x: {},
|
NodeType.BREAK: lambda x: {},
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_node_convert(self, node_type):
|
def get_node_convert(self, node_type):
|
||||||
@@ -732,3 +758,16 @@ class DifyConverter(BaseConverter):
|
|||||||
detail=f"Please reconfigure the tool node.",
|
detail=f"Please reconfigure the tool node.",
|
||||||
))
|
))
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def convert_notes_config(node: dict):
|
||||||
|
node_data = node["data"]
|
||||||
|
result = NoteNodeConfig.model_construct(
|
||||||
|
author=node_data.get("author", ""),
|
||||||
|
text=node_data.get("text", ""),
|
||||||
|
width=node_data.get("width", 80),
|
||||||
|
height=node_data.get("height", 80),
|
||||||
|
theme=node_data.get("theme", "blue"),
|
||||||
|
show_author=node_data.get("showAuthor", True)
|
||||||
|
).model_dump()
|
||||||
|
return result
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
|
|
||||||
def __init__(self, config: dict[str, Any]):
|
def __init__(self, config: dict[str, Any]):
|
||||||
DifyConverter.__init__(self)
|
DifyConverter.__init__(self)
|
||||||
BasePlatformAdapter.__init__(self, config)
|
BasePlatformAdapter.__init__(self, config)
|
||||||
|
|
||||||
def get_metadata(self) -> PlatformMetadata:
|
def get_metadata(self) -> PlatformMetadata:
|
||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
@@ -59,7 +59,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
|
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
|
||||||
)
|
)
|
||||||
|
|
||||||
def map_node_type(self, platform_node_type) -> str:
|
def map_node_type(self, platform_node_type) -> NodeType:
|
||||||
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
|
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -84,7 +84,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
require_fields = frozenset({'app', 'kind', 'version', 'workflow'})
|
require_fields = frozenset({'app', 'kind', 'version', 'workflow'})
|
||||||
if not all(field in self.config for field in require_fields):
|
if not all(field in self.config for field in require_fields):
|
||||||
return False
|
return False
|
||||||
if self.config.get("app",{}).get("mode") == "workflow":
|
if self.config.get("app", {}).get("mode") == "workflow":
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefineition(
|
||||||
type=ExceptionType.PLATFORM,
|
type=ExceptionType.PLATFORM,
|
||||||
detail="workflow mode is not supported"
|
detail="workflow mode is not supported"
|
||||||
@@ -163,13 +163,14 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
|
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
|
||||||
node_data = node["data"]
|
node_data = node["data"]
|
||||||
try:
|
try:
|
||||||
|
node_type = self.map_node_type(node_data["type"])
|
||||||
return NodeDefinition(
|
return NodeDefinition(
|
||||||
id=node["id"],
|
id=node["id"],
|
||||||
type=self.map_node_type(node_data["type"]),
|
type=node_type,
|
||||||
name=node_data.get("title") or "notes",
|
name=node_data.get("title") or "notes",
|
||||||
cycle=node.get("parentId"),
|
cycle=node.get("parentId"),
|
||||||
description=None,
|
description=None,
|
||||||
config=self._convert_node_config(node),
|
config=self._convert_node_config(node_type, node),
|
||||||
position={
|
position={
|
||||||
"x": node["position"]["x"],
|
"x": node["position"]["x"],
|
||||||
"y": node["position"]["y"]
|
"y": node["position"]["y"]
|
||||||
@@ -183,17 +184,16 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"convert node error - {e}", exc_info=True)
|
logger.debug(f"convert node error - {e}", exc_info=True)
|
||||||
|
|
||||||
def _convert_node_config(self, node: dict):
|
def _convert_node_config(self, node_type: NodeType, node: dict):
|
||||||
node_data = node["data"]
|
|
||||||
node_type = node_data["type"]
|
|
||||||
try:
|
try:
|
||||||
|
node_data = node["data"]
|
||||||
converter = self.get_node_convert(node_type)
|
converter = self.get_node_convert(node_type)
|
||||||
if node_type not in self.CONFIG_CONVERT_MAP:
|
if node_type == NodeType.UNKNOWN:
|
||||||
self.errors.append(ExceptionDefineition(
|
self.errors.append(ExceptionDefineition(
|
||||||
type=ExceptionType.NODE,
|
type=ExceptionType.NODE,
|
||||||
node_id=node["id"],
|
node_id=node["id"],
|
||||||
node_name=node["data"]["title"],
|
node_name=node["data"]["title"],
|
||||||
detail=f"node type {node_type if node_type else 'notes'} is unsupported",
|
detail=f"node type {node_data.get('type')} is unsupported",
|
||||||
))
|
))
|
||||||
return converter(node)
|
return converter(node)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -214,7 +214,7 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
if source in self.branch_node_cache:
|
if source in self.branch_node_cache:
|
||||||
case_id = edge["sourceHandle"]
|
case_id = edge["sourceHandle"]
|
||||||
if case_id == "false":
|
if case_id == "false":
|
||||||
label = f'CASE{len(self.branch_node_cache[source])+1}'
|
label = f'CASE{len(self.branch_node_cache[source]) + 1}'
|
||||||
else:
|
else:
|
||||||
label = f'CASE{self.branch_node_cache[source].index(case_id) + 1}'
|
label = f'CASE{self.branch_node_cache[source].index(case_id) + 1}'
|
||||||
if source in self.error_branch_node_cache:
|
if source in self.error_branch_node_cache:
|
||||||
@@ -257,5 +257,3 @@ class DifyAdapter(BasePlatformAdapter, DifyConverter):
|
|||||||
|
|
||||||
def _convert_execution(self, execution: dict[str, Any]) -> ExecutionConfig:
|
def _convert_execution(self, execution: dict[str, Any]) -> ExecutionConfig:
|
||||||
return ExecutionConfig()
|
return ExecutionConfig()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,65 +4,145 @@
|
|||||||
# @Time : 2026/2/25 14:11
|
# @Time : 2026/2/25 14:11
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from app.core.logging_config import get_logger
|
||||||
from app.core.workflow.adapters.base_adapter import (
|
from app.core.workflow.adapters.base_adapter import (
|
||||||
PlatformMetadata,
|
PlatformMetadata,
|
||||||
PlatformType,
|
PlatformType,
|
||||||
BasePlatformAdapter,
|
BasePlatformAdapter,
|
||||||
WorkflowParserResult
|
WorkflowParserResult
|
||||||
)
|
)
|
||||||
from app.schemas.workflow_schema import ExecutionConfig
|
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType, UnsupportNodeType
|
||||||
|
from app.core.workflow.adapters.memory_bear.memory_bear_converter import MemoryBearConverter
|
||||||
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
|
from app.schemas.workflow_schema import ExecutionConfig, NodeDefinition, EdgeDefinition, VariableDefinition
|
||||||
|
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
VALID_NODE_TYPES = frozenset(t.value for t in NodeType if t != NodeType.UNKNOWN)
|
||||||
|
|
||||||
|
|
||||||
class MemoryBearAdapter(BasePlatformAdapter):
|
class MemoryBearAdapter(BasePlatformAdapter, MemoryBearConverter):
|
||||||
NODE_TYPE_MAPPING = {}
|
NODE_TYPE_MAPPING = {t.value: t for t in NodeType}
|
||||||
|
|
||||||
|
def __init__(self, config: dict[str, Any]):
|
||||||
|
MemoryBearConverter.__init__(self)
|
||||||
|
BasePlatformAdapter.__init__(self, config)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def origin_nodes(self):
|
def origin_nodes(self):
|
||||||
return self.config.get("workflow").get("nodes")
|
return self.config.get("workflow").get("nodes") or []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def origin_edges(self):
|
def origin_edges(self):
|
||||||
return self.config.get("workflow").get("edges")
|
return self.config.get("workflow").get("edges") or []
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def origin_variables(self):
|
def origin_variables(self):
|
||||||
return self.config.get("workflow").get("variables")
|
return self.config.get("workflow").get("variables") or []
|
||||||
|
|
||||||
def get_metadata(self) -> PlatformMetadata:
|
def get_metadata(self) -> PlatformMetadata:
|
||||||
return PlatformMetadata(
|
return PlatformMetadata(
|
||||||
platform_name=PlatformType.MEMORY_BEAR,
|
platform_name=PlatformType.MEMORY_BEAR,
|
||||||
version="0.2.5",
|
version="0.2.5",
|
||||||
support_node_types=list(self.NODE_TYPE_MAPPING.keys())
|
support_node_types=list(VALID_NODE_TYPES)
|
||||||
)
|
)
|
||||||
|
|
||||||
def map_node_type(self, platform_node_type) -> str:
|
def map_node_type(self, platform_node_type: str) -> NodeType:
|
||||||
return platform_node_type
|
return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _valid_nodes(node: dict[str, Any]):
|
def _valid_node(node: dict[str, Any]) -> bool:
|
||||||
if "type" not in node["data"]:
|
|
||||||
return False
|
|
||||||
if "id" not in node or "type" not in node:
|
if "id" not in node or "type" not in node:
|
||||||
return False
|
return False
|
||||||
|
if not isinstance(node.get("config"), dict):
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def validate_config(self) -> bool:
|
def validate_config(self) -> bool:
|
||||||
require_fields = frozenset({'app', 'workflow'})
|
require_fields = frozenset({'app', 'workflow'})
|
||||||
if not all(field in self.config for field in require_fields):
|
if not all(field in self.config for field in require_fields):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
for node in self.origin_nodes:
|
for node in self.origin_nodes:
|
||||||
if not self._valid_nodes(node):
|
if not self._valid_node(node):
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None:
|
||||||
|
node_id = node.get("id")
|
||||||
|
node_name = node.get("name")
|
||||||
|
try:
|
||||||
|
node_type = self.map_node_type(node["type"])
|
||||||
|
if node_type == NodeType.UNKNOWN:
|
||||||
|
self.errors.append(UnsupportNodeType(
|
||||||
|
node_id=node_id,
|
||||||
|
node_type=node["type"]
|
||||||
|
))
|
||||||
|
return None
|
||||||
|
|
||||||
|
config = node.get("config") or {}
|
||||||
|
converter = self.get_node_convert(node_type)
|
||||||
|
converter(node_id, node_name, config) # validates and appends errors if invalid
|
||||||
|
|
||||||
|
return NodeDefinition(**node)
|
||||||
|
except Exception as e:
|
||||||
|
self.errors.append(ExceptionDefineition(
|
||||||
|
type=ExceptionType.NODE,
|
||||||
|
node_id=node_id,
|
||||||
|
node_name=node_name,
|
||||||
|
detail=f"convert node error - {e}"
|
||||||
|
))
|
||||||
|
logger.debug(f"MemoryBear convert node error - {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _convert_edge(self, edge: dict[str, Any], valid_node_ids: set) -> EdgeDefinition | None:
|
||||||
|
try:
|
||||||
|
if edge.get("source") not in valid_node_ids or edge.get("target") not in valid_node_ids:
|
||||||
|
self.warnings.append(ExceptionDefineition(
|
||||||
|
type=ExceptionType.EDGE,
|
||||||
|
detail=f"edge {edge.get('id')} skipped: source or target node not found"
|
||||||
|
))
|
||||||
|
return None
|
||||||
|
return EdgeDefinition(**edge)
|
||||||
|
except Exception as e:
|
||||||
|
self.errors.append(ExceptionDefineition(
|
||||||
|
type=ExceptionType.EDGE,
|
||||||
|
detail=f"convert edge error - {e}"
|
||||||
|
))
|
||||||
|
logger.debug(f"MemoryBear convert edge error - {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _convert_variable(self, variable: dict[str, Any]) -> VariableDefinition | None:
|
||||||
|
try:
|
||||||
|
return VariableDefinition(**variable)
|
||||||
|
except Exception as e:
|
||||||
|
self.warnings.append(ExceptionDefineition(
|
||||||
|
type=ExceptionType.VARIABLE,
|
||||||
|
name=variable.get("name"),
|
||||||
|
detail=f"convert variable error - {e}"
|
||||||
|
))
|
||||||
|
logger.debug(f"MemoryBear convert variable error - {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
def parse_workflow(self) -> WorkflowParserResult:
|
def parse_workflow(self) -> WorkflowParserResult:
|
||||||
self.nodes = self.origin_nodes
|
for node in self.origin_nodes:
|
||||||
self.edges = self.origin_edges
|
converted = self._convert_node(node)
|
||||||
self.conv_variables = self.origin_variables
|
if converted:
|
||||||
|
self.nodes.append(converted)
|
||||||
|
|
||||||
|
valid_node_ids = {n.id for n in self.nodes}
|
||||||
|
|
||||||
|
for edge in self.origin_edges:
|
||||||
|
converted = self._convert_edge(edge, valid_node_ids)
|
||||||
|
if converted:
|
||||||
|
self.edges.append(converted)
|
||||||
|
|
||||||
|
for variable in self.origin_variables:
|
||||||
|
converted = self._convert_variable(variable)
|
||||||
|
if converted:
|
||||||
|
self.conv_variables.append(converted)
|
||||||
|
|
||||||
return WorkflowParserResult(
|
return WorkflowParserResult(
|
||||||
success=True,
|
success=not self.errors and not self.warnings,
|
||||||
platform=self.get_metadata(),
|
platform=self.get_metadata(),
|
||||||
execution_config=ExecutionConfig(),
|
execution_config=ExecutionConfig(),
|
||||||
origin_config=self.config,
|
origin_config=self.config,
|
||||||
@@ -72,5 +152,4 @@ class MemoryBearAdapter(BasePlatformAdapter):
|
|||||||
variables=self.conv_variables,
|
variables=self.conv_variables,
|
||||||
warnings=self.warnings,
|
warnings=self.warnings,
|
||||||
errors=self.errors,
|
errors=self.errors,
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -0,0 +1,85 @@
|
|||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
from app.core.workflow.adapters.base_converter import BaseConverter
|
||||||
|
from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType
|
||||||
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
|
from app.core.workflow.nodes.configs import (
|
||||||
|
StartNodeConfig,
|
||||||
|
EndNodeConfig,
|
||||||
|
LLMNodeConfig,
|
||||||
|
AgentNodeConfig,
|
||||||
|
IfElseNodeConfig,
|
||||||
|
KnowledgeRetrievalNodeConfig,
|
||||||
|
AssignerNodeConfig,
|
||||||
|
CodeNodeConfig,
|
||||||
|
HttpRequestNodeConfig,
|
||||||
|
JinjaRenderNodeConfig,
|
||||||
|
VariableAggregatorNodeConfig,
|
||||||
|
ParameterExtractorNodeConfig,
|
||||||
|
LoopNodeConfig,
|
||||||
|
IterationNodeConfig,
|
||||||
|
QuestionClassifierNodeConfig,
|
||||||
|
ToolNodeConfig,
|
||||||
|
MemoryReadNodeConfig,
|
||||||
|
MemoryWriteNodeConfig,
|
||||||
|
NoteNodeConfig,
|
||||||
|
)
|
||||||
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryBearConverter(BaseConverter):
|
||||||
|
errors: list
|
||||||
|
warnings: list
|
||||||
|
|
||||||
|
CONFIG_CLASS_MAP: dict[NodeType, type[BaseNodeConfig]] = {
|
||||||
|
NodeType.START: StartNodeConfig,
|
||||||
|
NodeType.END: EndNodeConfig,
|
||||||
|
NodeType.ANSWER: EndNodeConfig,
|
||||||
|
NodeType.LLM: LLMNodeConfig,
|
||||||
|
NodeType.AGENT: AgentNodeConfig,
|
||||||
|
NodeType.IF_ELSE: IfElseNodeConfig,
|
||||||
|
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNodeConfig,
|
||||||
|
NodeType.ASSIGNER: AssignerNodeConfig,
|
||||||
|
NodeType.CODE: CodeNodeConfig,
|
||||||
|
NodeType.HTTP_REQUEST: HttpRequestNodeConfig,
|
||||||
|
NodeType.JINJARENDER: JinjaRenderNodeConfig,
|
||||||
|
NodeType.VAR_AGGREGATOR: VariableAggregatorNodeConfig,
|
||||||
|
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNodeConfig,
|
||||||
|
NodeType.LOOP: LoopNodeConfig,
|
||||||
|
NodeType.ITERATION: IterationNodeConfig,
|
||||||
|
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNodeConfig,
|
||||||
|
NodeType.TOOL: ToolNodeConfig,
|
||||||
|
NodeType.MEMORY_READ: MemoryReadNodeConfig,
|
||||||
|
NodeType.MEMORY_WRITE: MemoryWriteNodeConfig,
|
||||||
|
NodeType.NOTES: NoteNodeConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_file(var):
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _convert_array_file(var):
|
||||||
|
return []
|
||||||
|
|
||||||
|
def config_validate(self, node_id: str, node_name: str, config_cls: type[BaseNodeConfig], value: dict):
|
||||||
|
try:
|
||||||
|
return config_cls.model_validate(value)
|
||||||
|
except Exception as e:
|
||||||
|
self.errors.append(ExceptionDefineition(
|
||||||
|
type=ExceptionType.CONFIG,
|
||||||
|
node_id=node_id,
|
||||||
|
node_name=node_name,
|
||||||
|
detail=str(e)
|
||||||
|
))
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_node_convert(self, node_type: NodeType):
|
||||||
|
config_cls = self.CONFIG_CLASS_MAP.get(node_type)
|
||||||
|
if not config_cls:
|
||||||
|
return lambda node_id, node_name, config: config
|
||||||
|
|
||||||
|
def validate(node_id: str, node_name: str, config: dict):
|
||||||
|
self.config_validate(node_id, node_name, config_cls, config)
|
||||||
|
return config
|
||||||
|
|
||||||
|
return validate
|
||||||
@@ -5,7 +5,7 @@
|
|||||||
import re
|
import re
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
|
|
||||||
from app.core.logging_config import get_logger
|
from app.core.logging_config import get_logger
|
||||||
from app.core.workflow.engine.variable_pool import VariablePool
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
@@ -52,10 +52,11 @@ class OutputContent(BaseModel):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
_SCOPE: str | None = None
|
_SCOPE: str | None = PrivateAttr(default=None)
|
||||||
|
|
||||||
def get_scope(self) -> str:
|
def get_scope(self) -> str | None:
|
||||||
self._SCOPE = SCOPE_PATTERN.findall(self.literal)[0]
|
matches = SCOPE_PATTERN.findall(self.literal)
|
||||||
|
self._SCOPE = matches[0] if matches else None
|
||||||
return self._SCOPE
|
return self._SCOPE
|
||||||
|
|
||||||
def depends_on_scope(self, scope: str) -> bool:
|
def depends_on_scope(self, scope: str) -> bool:
|
||||||
@@ -68,6 +69,8 @@ class OutputContent(BaseModel):
|
|||||||
Returns:
|
Returns:
|
||||||
bool: True if this segment references the given scope.
|
bool: True if this segment references the given scope.
|
||||||
"""
|
"""
|
||||||
|
if not self.is_variable:
|
||||||
|
return False
|
||||||
if self._SCOPE:
|
if self._SCOPE:
|
||||||
return self._SCOPE == scope
|
return self._SCOPE == scope
|
||||||
return self.get_scope() == scope
|
return self.get_scope() == scope
|
||||||
@@ -152,7 +155,7 @@ class StreamOutputConfig(BaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# Case 1: resolve control branch dependency
|
# Case 1: resolve control branch dependency
|
||||||
if scope in self.control_nodes.keys():
|
if scope in self.control_nodes:
|
||||||
if status is None:
|
if status is None:
|
||||||
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
raise RuntimeError("[Stream Output] Control node activation status not provided")
|
||||||
if status in self.control_nodes[scope]:
|
if status in self.control_nodes[scope]:
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
@@ -15,6 +16,7 @@ from app.core.workflow.variable.base_variable import VariableType, FileObject
|
|||||||
from app.db import get_db_read
|
from app.db import get_db_read
|
||||||
from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy
|
from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy
|
||||||
from app.schemas import FileInput
|
from app.schemas import FileInput
|
||||||
|
from app.schemas.model_schema import ModelInfo
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -619,11 +621,12 @@ class BaseNode(ABC):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def process_message(
|
async def process_message(
|
||||||
provider: str,
|
api_config: ModelInfo,
|
||||||
is_omni: bool,
|
|
||||||
content: str | dict | FileObject,
|
content: str | dict | FileObject,
|
||||||
|
end_user_id: str,
|
||||||
enable_file=False
|
enable_file=False
|
||||||
) -> list | str | None:
|
) -> list | str | None:
|
||||||
|
provider = api_config.provider
|
||||||
if isinstance(content, dict):
|
if isinstance(content, dict):
|
||||||
content = FileObject(
|
content = FileObject(
|
||||||
type=content.get("type"),
|
type=content.get("type"),
|
||||||
@@ -642,16 +645,20 @@ class BaseNode(ABC):
|
|||||||
if content.content_cache.get(provider):
|
if content.content_cache.get(provider):
|
||||||
return content.content_cache[provider]
|
return content.content_cache[provider]
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
multimodel_service = MultimodalService(db, provider, is_omni=is_omni)
|
multimodel_service = MultimodalService(db, api_config=api_config)
|
||||||
message = await multimodel_service.process_files(
|
file_obj = FileInput(
|
||||||
[FileInput.model_construct(
|
type=content.type,
|
||||||
type=content.type,
|
url=content.url,
|
||||||
url=content.url,
|
transfer_method=content.transfer_method,
|
||||||
transfer_method=content.transfer_method,
|
origin_file_type=content.origin_file_type,
|
||||||
file_type=content.origin_file_type,
|
upload_file_id=uuid.UUID(content.file_id) if content.file_id else None,
|
||||||
upload_file_id=content.file_id
|
|
||||||
)]
|
|
||||||
)
|
)
|
||||||
|
file_obj.set_content(content.get_content())
|
||||||
|
message = await multimodel_service.process_files(
|
||||||
|
end_user_id,
|
||||||
|
[file_obj],
|
||||||
|
)
|
||||||
|
content.set_content(file_obj.get_content())
|
||||||
if message:
|
if message:
|
||||||
content.content_cache[provider] = message
|
content.content_cache[provider] = message
|
||||||
return message
|
return message
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from app.core.workflow.nodes.question_classifier.config import QuestionClassifie
|
|||||||
from app.core.workflow.nodes.start.config import StartNodeConfig
|
from app.core.workflow.nodes.start.config import StartNodeConfig
|
||||||
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
from app.core.workflow.nodes.tool.config import ToolNodeConfig
|
||||||
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
from app.core.workflow.nodes.variable_aggregator.config import VariableAggregatorNodeConfig
|
||||||
|
from app.core.workflow.nodes.notes.config import NoteNodeConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# 基础类
|
# 基础类
|
||||||
@@ -47,5 +48,6 @@ __all__ = [
|
|||||||
"ToolNodeConfig",
|
"ToolNodeConfig",
|
||||||
"MemoryReadNodeConfig",
|
"MemoryReadNodeConfig",
|
||||||
"MemoryWriteNodeConfig",
|
"MemoryWriteNodeConfig",
|
||||||
"CodeNodeConfig"
|
"CodeNodeConfig",
|
||||||
|
"NoteNodeConfig"
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from pydantic import Field, BaseModel, field_validator
|
|||||||
|
|
||||||
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpAuthType, HttpContentType, HttpErrorHandle
|
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpAuthType, HttpContentType, HttpErrorHandle
|
||||||
|
from app.core.workflow.variable.base_variable import FileObject
|
||||||
|
|
||||||
|
|
||||||
class HttpAuthConfig(BaseModel):
|
class HttpAuthConfig(BaseModel):
|
||||||
@@ -260,6 +261,11 @@ class HttpRequestNodeOutput(BaseModel):
|
|||||||
description="Http response headers"
|
description="Http response headers"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
files: list[FileObject] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="List of files",
|
||||||
|
)
|
||||||
|
|
||||||
output: str = Field(
|
output: str = Field(
|
||||||
default="SUCCESS",
|
default="SUCCESS",
|
||||||
description="HTTP response body",
|
description="HTTP response body",
|
||||||
|
|||||||
@@ -1,24 +1,146 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import mimetypes
|
||||||
import uuid
|
import uuid
|
||||||
|
import imghdr
|
||||||
|
from email.message import Message
|
||||||
from typing import Any, Callable, Coroutine
|
from typing import Any, Callable, Coroutine
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
# import filetypes # TODO: File support (Feature)
|
|
||||||
from httpx import AsyncClient, Response, Timeout
|
from httpx import AsyncClient, Response, Timeout
|
||||||
|
import magic
|
||||||
|
|
||||||
from app.core.workflow.engine.state_manager import WorkflowState
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
from app.core.workflow.engine.variable_pool import VariablePool
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType
|
||||||
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.utils.file_processer import mime_to_file_type
|
||||||
|
from app.core.workflow.variable.base_variable import VariableType, FileObject
|
||||||
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable
|
||||||
|
from app.schemas import FileType, TransferMethod
|
||||||
|
|
||||||
logger = logging.getLogger(__file__)
|
logger = logging.getLogger(__file__)
|
||||||
|
|
||||||
|
|
||||||
|
class HttpResponse:
|
||||||
|
def __init__(self, response: httpx.Response):
|
||||||
|
self.response = response
|
||||||
|
self.headers = dict(response.headers)
|
||||||
|
|
||||||
|
self._is_file: bool | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def content_type(self) -> str:
|
||||||
|
return self.headers.get("content-type", "")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def content_disposition(self) -> Message | None:
|
||||||
|
content_disposition = self.headers.get("content-disposition", "")
|
||||||
|
if content_disposition:
|
||||||
|
msg = Message()
|
||||||
|
msg["content-disposition"] = content_disposition
|
||||||
|
return msg
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_file(self) -> bool:
|
||||||
|
if self._is_file is not None:
|
||||||
|
return self._is_file
|
||||||
|
content_type = self.content_type.split(";")[0].strip().lower()
|
||||||
|
|
||||||
|
parsed_content_disposition = self.content_disposition
|
||||||
|
if parsed_content_disposition:
|
||||||
|
disp_type = parsed_content_disposition.get_content_disposition()
|
||||||
|
filename = parsed_content_disposition.get_filename()
|
||||||
|
if disp_type == "attachment" or filename:
|
||||||
|
self._is_file = True
|
||||||
|
return True
|
||||||
|
|
||||||
|
if content_type.startswith("text/") and "csv" not in content_type:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if content_type.startswith("application/"):
|
||||||
|
if any(
|
||||||
|
text_type in content_type
|
||||||
|
for text_type in {"json", "xml", "javascript", "x-www-form-urlencoded", "yaml", "graphql"}
|
||||||
|
):
|
||||||
|
self._is_file = False
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
content_sample = self.response.content[:1024]
|
||||||
|
content_sample.decode("utf-8")
|
||||||
|
text_markers = (b"{", b"[", b"<", b"function", b"var ", b"const ", b"let ")
|
||||||
|
if any(marker in content_sample for marker in text_markers):
|
||||||
|
return False
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
self._is_file = True
|
||||||
|
return True
|
||||||
|
|
||||||
|
main_type, _ = mimetypes.guess_type("dummy" + (mimetypes.guess_extension(content_type) or ""))
|
||||||
|
if main_type:
|
||||||
|
self._is_file = main_type.split("/")[0] in ("application", "image", "audio", "video")
|
||||||
|
return self._is_file
|
||||||
|
self._is_file = any(media_type in content_type for media_type in ("image/", "audio/", "video/"))
|
||||||
|
return self._is_file
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_image(self):
|
||||||
|
if self.is_file:
|
||||||
|
kind = imghdr.what(None, h=self.response.content)
|
||||||
|
return kind is not None
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def url(self) -> str:
|
||||||
|
return str(self.response.url)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def body(self) -> str:
|
||||||
|
if self.is_file:
|
||||||
|
return f"{'!' if self.is_image else ''}[file]({self.url})"
|
||||||
|
return self.response.text
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_file_type(file_bytes) -> tuple[FileType | None, str | None]:
|
||||||
|
mime = magic.from_buffer(file_bytes, mime=True)
|
||||||
|
|
||||||
|
if mime.startswith("image"):
|
||||||
|
return FileType.IMAGE, mime
|
||||||
|
elif mime.startswith("video"):
|
||||||
|
return FileType.VIDEO, mime
|
||||||
|
elif mime.startswith("audio"):
|
||||||
|
return FileType.AUDIO, mime
|
||||||
|
elif mime in ["application/pdf",
|
||||||
|
"application/msword",
|
||||||
|
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||||
|
"application/vnd.ms-excel",
|
||||||
|
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||||
|
"text/plain"]:
|
||||||
|
return FileType.DOCUMENT, mime
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def files(self) -> list[FileObject]:
|
||||||
|
file_type, mime_type = self.get_file_type(self.response.content)
|
||||||
|
origin_file_type = mime_to_file_type(mime_type)
|
||||||
|
if self.is_file and file_type and origin_file_type:
|
||||||
|
file_obj = FileObject(
|
||||||
|
type=file_type,
|
||||||
|
url=self.url,
|
||||||
|
transfer_method=TransferMethod.REMOTE_URL.value,
|
||||||
|
origin_file_type=origin_file_type,
|
||||||
|
file_id=None,
|
||||||
|
is_file=True
|
||||||
|
)
|
||||||
|
file_obj.set_content(self.response.content)
|
||||||
|
return [
|
||||||
|
file_obj
|
||||||
|
]
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
class HttpRequestNode(BaseNode):
|
class HttpRequestNode(BaseNode):
|
||||||
"""
|
"""
|
||||||
HTTP Request Workflow Node.
|
HTTP Request Workflow Node.
|
||||||
@@ -44,6 +166,7 @@ class HttpRequestNode(BaseNode):
|
|||||||
"body": VariableType.STRING,
|
"body": VariableType.STRING,
|
||||||
"status_code": VariableType.NUMBER,
|
"status_code": VariableType.NUMBER,
|
||||||
"headers": VariableType.OBJECT,
|
"headers": VariableType.OBJECT,
|
||||||
|
"files": VariableType.ARRAY_FILE,
|
||||||
"output": VariableType.STRING
|
"output": VariableType.STRING
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -232,10 +355,12 @@ class HttpRequestNode(BaseNode):
|
|||||||
)
|
)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
logger.info(f"Node {self.node_id}: HTTP request succeeded")
|
||||||
|
response = HttpResponse(resp)
|
||||||
return HttpRequestNodeOutput(
|
return HttpRequestNodeOutput(
|
||||||
body=resp.text,
|
body=response.body,
|
||||||
status_code=resp.status_code,
|
status_code=resp.status_code,
|
||||||
headers=resp.headers,
|
headers=resp.headers,
|
||||||
|
files=response.files
|
||||||
).model_dump()
|
).model_dump()
|
||||||
except (httpx.HTTPStatusError, httpx.RequestError) as e:
|
except (httpx.HTTPStatusError, httpx.RequestError) as e:
|
||||||
logger.error(f"HTTP request node exception: {e}")
|
logger.error(f"HTTP request node exception: {e}")
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import Any
|
|||||||
from app.core.workflow.engine.state_manager import WorkflowState
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
from app.core.workflow.engine.variable_pool import VariablePool
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator
|
from app.core.workflow.nodes.enums import ComparisonOperator, LogicOperator, ValueInputType
|
||||||
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
from app.core.workflow.nodes.if_else import IfElseNodeConfig
|
||||||
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
|
from app.core.workflow.nodes.operators import ConditionExpressionResolver, CompareOperatorInstance
|
||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
@@ -23,6 +23,26 @@ class IfElseNode(BaseNode):
|
|||||||
"output": VariableType.STRING
|
"output": VariableType.STRING
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
|
result = []
|
||||||
|
for case in self.typed_config.cases:
|
||||||
|
expressions = []
|
||||||
|
for expression in case.expressions:
|
||||||
|
expressions.append({
|
||||||
|
"left": self.get_variable(expression.left, variable_pool, strict=False),
|
||||||
|
"right": expression.right
|
||||||
|
if expression.input_type == ValueInputType.CONSTANT
|
||||||
|
else self.get_variable(expression.right, variable_pool, strict=False),
|
||||||
|
"operator": expression.operator,
|
||||||
|
})
|
||||||
|
result.append({
|
||||||
|
"expressions": expressions,
|
||||||
|
"logical_operator": case.logical_operator,
|
||||||
|
})
|
||||||
|
return {
|
||||||
|
"cases": result
|
||||||
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
|
def _evaluate(operator, instance: CompareOperatorInstance) -> Any:
|
||||||
match operator:
|
match operator:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from typing import Any
|
|||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.models import RedBearRerank, RedBearModelConfig
|
from app.core.models import RedBearRerank, RedBearModelConfig
|
||||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector
|
||||||
from app.core.workflow.engine.state_manager import WorkflowState
|
from app.core.workflow.engine.state_manager import WorkflowState
|
||||||
from app.core.workflow.engine.variable_pool import VariablePool
|
from app.core.workflow.engine.variable_pool import VariablePool
|
||||||
from app.core.workflow.nodes.base_node import BaseNode
|
from app.core.workflow.nodes.base_node import BaseNode
|
||||||
@@ -24,12 +24,19 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
|
||||||
super().__init__(node_config, workflow_config)
|
super().__init__(node_config, workflow_config)
|
||||||
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
|
||||||
|
self.vector_service: ElasticSearchVector | None = None
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {
|
return {
|
||||||
"output": VariableType.ARRAY_STRING
|
"output": VariableType.ARRAY_STRING
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"query": self._render_template(self.typed_config.query, variable_pool),
|
||||||
|
"knowledge_bases": [kb_config.model_dump(mode="json") for kb_config in self.typed_config.knowledge_bases],
|
||||||
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
|
def _build_kb_filter(kb_ids: list[uuid.UUID], permission: knowledge_model.PermissionType):
|
||||||
"""
|
"""
|
||||||
@@ -157,6 +164,50 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
)
|
)
|
||||||
return reranker
|
return reranker
|
||||||
|
|
||||||
|
def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config):
|
||||||
|
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
|
||||||
|
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
|
||||||
|
for child in children:
|
||||||
|
if not (child and child.chunk_num > 0 and child.status == 1):
|
||||||
|
continue
|
||||||
|
kb_config.kb_id = child.id
|
||||||
|
self.knowledge_retrieval(db, query, rs, child, kb_config)
|
||||||
|
return
|
||||||
|
self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
||||||
|
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
|
||||||
|
match kb_config.retrieve_type:
|
||||||
|
case RetrieveType.PARTICIPLE:
|
||||||
|
rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=kb_config.similarity_threshold))
|
||||||
|
case RetrieveType.SEMANTIC:
|
||||||
|
rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=kb_config.vector_similarity_weight))
|
||||||
|
case RetrieveType.HYBRID:
|
||||||
|
rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=kb_config.vector_similarity_weight)
|
||||||
|
rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
||||||
|
indices=indices,
|
||||||
|
score_threshold=kb_config.similarity_threshold)
|
||||||
|
|
||||||
|
# Deduplicate hybrid retrieval results
|
||||||
|
unique_rs = self._deduplicate_docs(rs1, rs2)
|
||||||
|
if not unique_rs:
|
||||||
|
return
|
||||||
|
if self.typed_config.reranker_id:
|
||||||
|
self.vector_service.reranker = self.get_reranker_model()
|
||||||
|
rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
||||||
|
else:
|
||||||
|
rs.extend(sorted(
|
||||||
|
unique_rs,
|
||||||
|
key=lambda d: d.metadata.get("score", 0),
|
||||||
|
reverse=True
|
||||||
|
)[:kb_config.top_k])
|
||||||
|
case _:
|
||||||
|
raise RuntimeError("Unknown retrieval type")
|
||||||
|
|
||||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||||
"""
|
"""
|
||||||
Execute the knowledge retrieval workflow node.
|
Execute the knowledge retrieval workflow node.
|
||||||
@@ -185,56 +236,19 @@ class KnowledgeRetrievalNode(BaseNode):
|
|||||||
query = self._render_template(self.typed_config.query, variable_pool)
|
query = self._render_template(self.typed_config.query, variable_pool)
|
||||||
with get_db_read() as db:
|
with get_db_read() as db:
|
||||||
knowledge_bases = self.typed_config.knowledge_bases
|
knowledge_bases = self.typed_config.knowledge_bases
|
||||||
existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases])
|
|
||||||
|
|
||||||
if not existing_ids:
|
|
||||||
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
|
|
||||||
|
|
||||||
rs = []
|
rs = []
|
||||||
for kb_config in knowledge_bases:
|
for kb_config in knowledge_bases:
|
||||||
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
|
||||||
if not db_knowledge:
|
if not db_knowledge:
|
||||||
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
raise RuntimeError("The knowledge base does not exist or access is denied.")
|
||||||
|
self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config)
|
||||||
|
|
||||||
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
|
|
||||||
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
|
|
||||||
match kb_config.retrieve_type:
|
|
||||||
case RetrieveType.PARTICIPLE:
|
|
||||||
rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
|
||||||
indices=indices,
|
|
||||||
score_threshold=kb_config.similarity_threshold))
|
|
||||||
case RetrieveType.SEMANTIC:
|
|
||||||
rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
|
||||||
indices=indices,
|
|
||||||
score_threshold=kb_config.vector_similarity_weight))
|
|
||||||
case RetrieveType.HYBRID:
|
|
||||||
rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
|
|
||||||
indices=indices,
|
|
||||||
score_threshold=kb_config.vector_similarity_weight)
|
|
||||||
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
|
|
||||||
indices=indices,
|
|
||||||
score_threshold=kb_config.similarity_threshold)
|
|
||||||
|
|
||||||
# Deduplicate hybrid retrieval results
|
|
||||||
unique_rs = self._deduplicate_docs(rs1, rs2)
|
|
||||||
if not unique_rs:
|
|
||||||
continue
|
|
||||||
if self.typed_config.reranker_id:
|
|
||||||
vector_service.reranker = self.get_reranker_model()
|
|
||||||
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
|
|
||||||
else:
|
|
||||||
rs.extend(sorted(
|
|
||||||
unique_rs,
|
|
||||||
key=lambda d: d.metadata.get("score", 0),
|
|
||||||
reverse=True
|
|
||||||
)[:kb_config.top_k])
|
|
||||||
case _:
|
|
||||||
raise RuntimeError("Unknown retrieval type")
|
|
||||||
if not rs:
|
if not rs:
|
||||||
return []
|
return []
|
||||||
if self.typed_config.reranker_id:
|
if self.typed_config.reranker_id:
|
||||||
vector_service.reranker = self.get_reranker_model()
|
self.vector_service.reranker = self.get_reranker_model()
|
||||||
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
|
||||||
else:
|
else:
|
||||||
final_rs = sorted(
|
final_rs = sorted(
|
||||||
rs,
|
rs,
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
|||||||
from app.core.workflow.variable.base_variable import VariableType
|
from app.core.workflow.variable.base_variable import VariableType
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models import ModelType
|
from app.models import ModelType
|
||||||
|
from app.schemas.model_schema import ModelInfo
|
||||||
from app.services.model_service import ModelConfigService
|
from app.services.model_service import ModelConfigService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -113,12 +114,15 @@ class LLMNode(BaseNode):
|
|||||||
|
|
||||||
# 在 Session 关闭前提取所有需要的数据
|
# 在 Session 关闭前提取所有需要的数据
|
||||||
api_config = self.model_balance(config)
|
api_config = self.model_balance(config)
|
||||||
model_name = api_config.model_name
|
model_info = ModelInfo(
|
||||||
provider = api_config.provider
|
model_name=api_config.model_name,
|
||||||
api_key = api_config.api_key
|
model_type=ModelType(config.type),
|
||||||
api_base = api_config.api_base
|
api_key=api_config.api_key,
|
||||||
is_omni = api_config.is_omni
|
api_base=api_config.api_base,
|
||||||
model_type = config.type
|
provider=api_config.provider,
|
||||||
|
is_omni=api_config.is_omni,
|
||||||
|
capability=api_config.capability
|
||||||
|
)
|
||||||
|
|
||||||
# 4. 创建 LLM 实例(使用已提取的数据)
|
# 4. 创建 LLM 实例(使用已提取的数据)
|
||||||
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
|
# 注意:对于流式输出,需要在模型初始化时设置 streaming=True
|
||||||
@@ -126,17 +130,18 @@ class LLMNode(BaseNode):
|
|||||||
|
|
||||||
llm = RedBearLLM(
|
llm = RedBearLLM(
|
||||||
RedBearModelConfig(
|
RedBearModelConfig(
|
||||||
model_name=model_name,
|
model_name=model_info.model_name,
|
||||||
provider=provider,
|
provider=model_info.provider,
|
||||||
api_key=api_key,
|
api_key=model_info.api_key,
|
||||||
base_url=api_base,
|
base_url=model_info.api_base,
|
||||||
extra_params=extra_params,
|
extra_params=extra_params,
|
||||||
is_omni=is_omni
|
is_omni=model_info.is_omni
|
||||||
),
|
),
|
||||||
type=ModelType(model_type)
|
type=model_info.model_type
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"创建 LLM 实例: provider={provider}, model={model_name}, streaming={stream}")
|
logger.debug(
|
||||||
|
f"创建 LLM 实例: provider={model_info.provider}, model={model_info.model_name}, streaming={stream}")
|
||||||
|
|
||||||
messages_config = self.typed_config.messages
|
messages_config = self.typed_config.messages
|
||||||
|
|
||||||
@@ -148,35 +153,40 @@ class LLMNode(BaseNode):
|
|||||||
content_template = msg_config.content
|
content_template = msg_config.content
|
||||||
content_template = self._render_context(content_template, variable_pool)
|
content_template = self._render_context(content_template, variable_pool)
|
||||||
content = self._render_template(content_template, variable_pool)
|
content = self._render_template(content_template, variable_pool)
|
||||||
|
user_id = self.get_variable("sys.user_id", variable_pool)
|
||||||
# 根据角色创建对应的消息对象
|
# 根据角色创建对应的消息对象
|
||||||
if role == "system":
|
if role == "system":
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "system",
|
"role": "system",
|
||||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
"content": await self.process_message(
|
||||||
|
model_info,
|
||||||
|
content,
|
||||||
|
user_id,
|
||||||
|
self.typed_config.vision,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
elif role in ["user", "human"]:
|
elif role in ["user", "human"]:
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
elif role in ["ai", "assistant"]:
|
elif role in ["ai", "assistant"]:
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
logger.warning(f"未知的消息角色: {role},默认使用 user")
|
||||||
messages.append({
|
messages.append({
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": await self.process_message(provider, is_omni, content, self.typed_config.vision)
|
"content": await self.process_message(model_info, content, user_id, self.typed_config.vision)
|
||||||
})
|
})
|
||||||
|
|
||||||
if self.typed_config.vision_input and self.typed_config.vision:
|
if self.typed_config.vision_input and self.typed_config.vision:
|
||||||
file_content = []
|
file_content = []
|
||||||
files = variable_pool.get_instance(self.typed_config.vision_input)
|
files = variable_pool.get_instance(self.typed_config.vision_input)
|
||||||
for file in files.value:
|
for file in files.value:
|
||||||
content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision)
|
content = await self.process_message(model_info, file.value, user_id, self.typed_config.vision)
|
||||||
if content:
|
if content:
|
||||||
file_content.extend(content)
|
file_content.extend(content)
|
||||||
if messages and messages[-1]["role"] == 'user':
|
if messages and messages[-1]["role"] == 'user':
|
||||||
@@ -190,14 +200,19 @@ class LLMNode(BaseNode):
|
|||||||
if isinstance(message["content"], list):
|
if isinstance(message["content"], list):
|
||||||
file_content = []
|
file_content = []
|
||||||
for file in message["content"]:
|
for file in message["content"]:
|
||||||
content = await self.process_message(provider, is_omni, file, self.typed_config.vision)
|
content = await self.process_message(model_info, file, user_id, self.typed_config.vision)
|
||||||
if content:
|
if content:
|
||||||
file_content.extend(content)
|
file_content.extend(content)
|
||||||
history_message.append(
|
history_message.append(
|
||||||
{"role": message["role"], "content": file_content}
|
{"role": message["role"], "content": file_content}
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision)
|
message["content"] = await self.process_message(
|
||||||
|
model_info,
|
||||||
|
message["content"],
|
||||||
|
user_id,
|
||||||
|
self.typed_config.vision
|
||||||
|
)
|
||||||
history_message.append(message)
|
history_message.append(message)
|
||||||
messages = messages[:-1] + history_message + messages[-1:]
|
messages = messages[:-1] + history_message + messages[-1:]
|
||||||
self.messages = messages
|
self.messages = messages
|
||||||
@@ -293,7 +308,7 @@ class LLMNode(BaseNode):
|
|||||||
|
|
||||||
# 调用 LLM(流式,支持字符串或消息列表)
|
# 调用 LLM(流式,支持字符串或消息列表)
|
||||||
last_meta_data = {}
|
last_meta_data = {}
|
||||||
async for chunk in llm.astream(self.messages, stream_usage=True):
|
async for chunk in llm.astream(self.messages):
|
||||||
# 提取内容
|
# 提取内容
|
||||||
if hasattr(chunk, 'content'):
|
if hasattr(chunk, 'content'):
|
||||||
content = self.process_model_output(chunk.content)
|
content = self.process_model_output(chunk.content)
|
||||||
|
|||||||
0
api/app/core/workflow/nodes/notes/__init__.py
Normal file
0
api/app/core/workflow/nodes/notes/__init__.py
Normal file
12
api/app/core/workflow/nodes/notes/config.py
Normal file
12
api/app/core/workflow/nodes/notes/config.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from app.core.workflow.nodes.base_config import BaseNodeConfig
|
||||||
|
|
||||||
|
|
||||||
|
class NoteNodeConfig(BaseNodeConfig):
|
||||||
|
author: str = Field(default="", description="author")
|
||||||
|
text: str = Field(default="", description="note content")
|
||||||
|
width: int = Field(default=80)
|
||||||
|
height: int = Field(default=80)
|
||||||
|
theme: str = Field(default="blue")
|
||||||
|
show_author: bool = Field(default=True)
|
||||||
@@ -37,6 +37,14 @@ class ParameterExtractorNode(BaseNode):
|
|||||||
}
|
}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"text": self._render_template(self.typed_config.text, variable_pool),
|
||||||
|
"prompt": self._render_template(self.typed_config.prompt, variable_pool),
|
||||||
|
"params": [param.model_dump(mode="json") for param in self.typed_config.params],
|
||||||
|
"model_id": str(self.typed_config.model_id),
|
||||||
|
}
|
||||||
|
|
||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
outputs = {}
|
outputs = {}
|
||||||
for param in self.typed_config.params:
|
for param in self.typed_config.params:
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ class ToolNode(BaseNode):
|
|||||||
def _output_types(self) -> dict[str, VariableType]:
|
def _output_types(self) -> dict[str, VariableType]:
|
||||||
return {
|
return {
|
||||||
"data": VariableType.STRING,
|
"data": VariableType.STRING,
|
||||||
"error_code": VariableType.STRING,
|
|
||||||
"execution_time": VariableType.NUMBER
|
"execution_time": VariableType.NUMBER
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -48,10 +47,7 @@ class ToolNode(BaseNode):
|
|||||||
|
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
logger.error(f"节点 {self.node_id} 缺少租户ID")
|
logger.error(f"节点 {self.node_id} 缺少租户ID")
|
||||||
return {
|
raise ValueError("缺少租户ID")
|
||||||
"success": False,
|
|
||||||
"data": "缺少租户ID"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 渲染工具参数
|
# 渲染工具参数
|
||||||
rendered_parameters = {}
|
rendered_parameters = {}
|
||||||
@@ -83,13 +79,8 @@ class ToolNode(BaseNode):
|
|||||||
logger.info(f"节点 {self.node_id} 工具执行成功")
|
logger.info(f"节点 {self.node_id} 工具执行成功")
|
||||||
return {
|
return {
|
||||||
"data": result.data if isinstance(result.data, str) else json.dumps(result.data, ensure_ascii=False),
|
"data": result.data if isinstance(result.data, str) else json.dumps(result.data, ensure_ascii=False),
|
||||||
"error_code": "",
|
|
||||||
"execution_time": result.execution_time
|
"execution_time": result.execution_time
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
|
logger.error(f"节点 {self.node_id} 工具执行失败: {result.error}")
|
||||||
return {
|
raise ValueError(f"工具执行失败: {result.error if isinstance(result.error, str) else json.dumps(result.error, ensure_ascii=False)}")
|
||||||
"data": result.error if isinstance(result.error, str) else json.dumps(result.error, ensure_ascii=False),
|
|
||||||
"error_code": result.error_code,
|
|
||||||
"execution_time": result.execution_time
|
|
||||||
}
|
|
||||||
|
|||||||
56
api/app/core/workflow/utils/file_processer.py
Normal file
56
api/app/core/workflow/utils/file_processer.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
# Author: Eternity
|
||||||
|
# @Email: 1533512157@qq.com
|
||||||
|
# @Time : 2026/3/10 13:36
|
||||||
|
TRANSFORM_FILE_TYPE = {
|
||||||
|
'text/plain': 'document/text',
|
||||||
|
'text/markdown': 'document/markdown',
|
||||||
|
'text/x-markdown': 'document/x-markdown',
|
||||||
|
|
||||||
|
'application/pdf': 'document/pdf',
|
||||||
|
|
||||||
|
'application/msword': 'document/doc',
|
||||||
|
'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'document/docx',
|
||||||
|
|
||||||
|
'application/vnd.ms-powerpoint': 'document/ppt',
|
||||||
|
'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'document/pptx',
|
||||||
|
}
|
||||||
|
ALLOWED_FILE_TYPES = [
|
||||||
|
'text/plain',
|
||||||
|
'text/markdown',
|
||||||
|
'text/x-markdown',
|
||||||
|
'application/pdf',
|
||||||
|
'application/msword',
|
||||||
|
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
|
||||||
|
'application/vnd.ms-powerpoint',
|
||||||
|
'application/vnd.openxmlformats-officedocument.presentationml.presentation',
|
||||||
|
'image/jpg',
|
||||||
|
'image/jpeg',
|
||||||
|
'image/png',
|
||||||
|
'image/gif',
|
||||||
|
'image/bmp',
|
||||||
|
'image/webp',
|
||||||
|
'image/svg+xml',
|
||||||
|
'video/mp4',
|
||||||
|
'video/quicktime',
|
||||||
|
'video/x-msvideo',
|
||||||
|
'video/x-matroska',
|
||||||
|
'video/webm',
|
||||||
|
'video/x-flv',
|
||||||
|
'video/x-ms-wmv',
|
||||||
|
'audio/mpeg',
|
||||||
|
'audio/wav',
|
||||||
|
'audio/ogg',
|
||||||
|
'audio/aac',
|
||||||
|
'audio/flac',
|
||||||
|
'audio/mp4',
|
||||||
|
'audio/x-ms-wma',
|
||||||
|
'audio/x-m4a',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def mime_to_file_type(mime_type):
|
||||||
|
if mime_type not in ALLOWED_FILE_TYPES:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return TRANSFORM_FILE_TYPE.get(mime_type, mime_type)
|
||||||
@@ -114,9 +114,16 @@ class FileObject(BaseModel):
|
|||||||
file_id: str | None
|
file_id: str | None
|
||||||
|
|
||||||
content_cache: dict = Field(default_factory=dict)
|
content_cache: dict = Field(default_factory=dict)
|
||||||
|
|
||||||
is_file: bool
|
is_file: bool
|
||||||
|
|
||||||
|
_byte_content: bytes | None = None
|
||||||
|
|
||||||
|
def get_content(self):
|
||||||
|
return self._byte_content
|
||||||
|
|
||||||
|
def set_content(self, byte_content):
|
||||||
|
self._byte_content = byte_content
|
||||||
|
|
||||||
|
|
||||||
class BaseVariable(ABC):
|
class BaseVariable(ABC):
|
||||||
"""Abstract base class for all workflow variables.
|
"""Abstract base class for all workflow variables.
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ engine = create_engine(
|
|||||||
pool_recycle=settings.DB_POOL_RECYCLE,
|
pool_recycle=settings.DB_POOL_RECYCLE,
|
||||||
pool_timeout=settings.DB_POOL_TIMEOUT,
|
pool_timeout=settings.DB_POOL_TIMEOUT,
|
||||||
connect_args={
|
connect_args={
|
||||||
"options": "-c timezone=Asia/Shanghai -c statement_timeout=60000"
|
"options": "-c timezone=UTC -c statement_timeout=60000"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|||||||
61
api/app/i18n/README.md
Normal file
61
api/app/i18n/README.md
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
# Internationalization (i18n) Module
|
||||||
|
|
||||||
|
This module provides internationalization support for the MemoryBear API.
|
||||||
|
|
||||||
|
## Components
|
||||||
|
|
||||||
|
- `service.py` - Translation service and core translation logic
|
||||||
|
- `middleware.py` - Language detection middleware
|
||||||
|
- `dependencies.py` - FastAPI dependency injection functions
|
||||||
|
- `exceptions.py` - Internationalized exception classes
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Basic Translation
|
||||||
|
|
||||||
|
```python
|
||||||
|
from app.i18n import t
|
||||||
|
|
||||||
|
# Simple translation
|
||||||
|
message = t("common.success.created")
|
||||||
|
|
||||||
|
# Parameterized translation
|
||||||
|
message = t("common.validation.required", field="Name")
|
||||||
|
```
|
||||||
|
|
||||||
|
### Enum Translation
|
||||||
|
|
||||||
|
```python
|
||||||
|
from app.i18n import t_enum
|
||||||
|
|
||||||
|
# Translate enum value
|
||||||
|
role_display = t_enum("workspace_role", "manager")
|
||||||
|
```
|
||||||
|
|
||||||
|
### In FastAPI Endpoints
|
||||||
|
|
||||||
|
```python
|
||||||
|
from fastapi import Depends
|
||||||
|
from app.i18n.dependencies import get_translator
|
||||||
|
|
||||||
|
@router.post("/workspaces")
|
||||||
|
async def create_workspace(
|
||||||
|
data: WorkspaceCreate,
|
||||||
|
t: Callable = Depends(get_translator)
|
||||||
|
):
|
||||||
|
workspace = await workspace_service.create(data)
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": t("workspace.created_successfully"),
|
||||||
|
"data": workspace
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
See `app/core/config.py` for i18n configuration options:
|
||||||
|
|
||||||
|
- `I18N_DEFAULT_LANGUAGE` - Default language (default: "zh")
|
||||||
|
- `I18N_SUPPORTED_LANGUAGES` - Supported languages (default: "zh,en")
|
||||||
|
- `I18N_ENABLE_TRANSLATION_CACHE` - Enable caching (default: true)
|
||||||
|
- `I18N_LOG_MISSING_TRANSLATIONS` - Log missing translations (default: true)
|
||||||
124
api/app/i18n/__init__.py
Normal file
124
api/app/i18n/__init__.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
"""
|
||||||
|
Internationalization (i18n) module for MemoryBear Enterprise.
|
||||||
|
|
||||||
|
This module provides complete i18n support for the backend API including:
|
||||||
|
- Translation loading from multiple directories (community + enterprise)
|
||||||
|
- Translation service with caching and fallback
|
||||||
|
- Language detection middleware
|
||||||
|
- Dependency injection for FastAPI
|
||||||
|
- Convenience functions for easy usage
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from app.i18n import t, t_enum
|
||||||
|
|
||||||
|
# Simple translation
|
||||||
|
message = t("common.success.created")
|
||||||
|
|
||||||
|
# Parameterized translation
|
||||||
|
error = t("common.validation.required", field="名称")
|
||||||
|
|
||||||
|
# Enum translation
|
||||||
|
role_display = t_enum("workspace_role", "manager")
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.i18n.dependencies import (
|
||||||
|
get_current_language,
|
||||||
|
get_enum_translator,
|
||||||
|
get_translator,
|
||||||
|
)
|
||||||
|
from app.i18n.exceptions import (
|
||||||
|
BadRequestError,
|
||||||
|
ConflictError,
|
||||||
|
FileNotFoundError,
|
||||||
|
FileTooLargeError,
|
||||||
|
ForbiddenError,
|
||||||
|
I18nException,
|
||||||
|
InternalServerError,
|
||||||
|
InvalidCredentialsError,
|
||||||
|
InvalidFileTypeError,
|
||||||
|
NotFoundError,
|
||||||
|
QuotaExceededError,
|
||||||
|
RateLimitExceededError,
|
||||||
|
ServiceUnavailableError,
|
||||||
|
TenantNotFoundError,
|
||||||
|
TenantSuspendedError,
|
||||||
|
TokenExpiredError,
|
||||||
|
TokenInvalidError,
|
||||||
|
UnauthorizedError,
|
||||||
|
UserAlreadyExistsError,
|
||||||
|
UserNotFoundError,
|
||||||
|
ValidationError,
|
||||||
|
WorkspaceNotFoundError,
|
||||||
|
WorkspacePermissionDeniedError,
|
||||||
|
get_current_locale,
|
||||||
|
set_current_locale,
|
||||||
|
)
|
||||||
|
from app.i18n.loader import TranslationLoader
|
||||||
|
from app.i18n.logger import (
|
||||||
|
TranslationLogger,
|
||||||
|
get_translation_logger,
|
||||||
|
log_missing_translation,
|
||||||
|
log_translation_error,
|
||||||
|
)
|
||||||
|
from app.i18n.middleware import LanguageMiddleware
|
||||||
|
from app.i18n.serializers import (
|
||||||
|
I18nResponseMixin,
|
||||||
|
WorkspaceSerializer,
|
||||||
|
WorkspaceMemberSerializer,
|
||||||
|
WorkspaceInviteSerializer,
|
||||||
|
)
|
||||||
|
from app.i18n.service import (
|
||||||
|
TranslationService,
|
||||||
|
get_translation_service,
|
||||||
|
t,
|
||||||
|
t_enum,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TranslationLoader",
|
||||||
|
"LanguageMiddleware",
|
||||||
|
"TranslationService",
|
||||||
|
"get_translation_service",
|
||||||
|
"t",
|
||||||
|
"t_enum",
|
||||||
|
"get_current_language",
|
||||||
|
"get_translator",
|
||||||
|
"get_enum_translator",
|
||||||
|
# Context management
|
||||||
|
"get_current_locale",
|
||||||
|
"set_current_locale",
|
||||||
|
# Logging
|
||||||
|
"TranslationLogger",
|
||||||
|
"get_translation_logger",
|
||||||
|
"log_missing_translation",
|
||||||
|
"log_translation_error",
|
||||||
|
# Serializers
|
||||||
|
"I18nResponseMixin",
|
||||||
|
"WorkspaceSerializer",
|
||||||
|
"WorkspaceMemberSerializer",
|
||||||
|
"WorkspaceInviteSerializer",
|
||||||
|
# Exception classes
|
||||||
|
"I18nException",
|
||||||
|
"BadRequestError",
|
||||||
|
"UnauthorizedError",
|
||||||
|
"ForbiddenError",
|
||||||
|
"NotFoundError",
|
||||||
|
"ConflictError",
|
||||||
|
"ValidationError",
|
||||||
|
"InternalServerError",
|
||||||
|
"ServiceUnavailableError",
|
||||||
|
"WorkspaceNotFoundError",
|
||||||
|
"WorkspacePermissionDeniedError",
|
||||||
|
"UserNotFoundError",
|
||||||
|
"UserAlreadyExistsError",
|
||||||
|
"TenantNotFoundError",
|
||||||
|
"TenantSuspendedError",
|
||||||
|
"InvalidCredentialsError",
|
||||||
|
"TokenExpiredError",
|
||||||
|
"TokenInvalidError",
|
||||||
|
"FileNotFoundError",
|
||||||
|
"FileTooLargeError",
|
||||||
|
"InvalidFileTypeError",
|
||||||
|
"RateLimitExceededError",
|
||||||
|
"QuotaExceededError",
|
||||||
|
]
|
||||||
291
api/app/i18n/cache.py
Normal file
291
api/app/i18n/cache.py
Normal file
@@ -0,0 +1,291 @@
|
|||||||
|
"""
|
||||||
|
Advanced caching system for i18n translations.
|
||||||
|
|
||||||
|
This module provides:
|
||||||
|
- LRU cache for hot translations
|
||||||
|
- Lazy loading mechanism
|
||||||
|
- Memory optimization
|
||||||
|
- Cache statistics
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
from collections import OrderedDict
|
||||||
|
import time
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)


class TranslationCache:
    """
    Two-tier in-memory translation cache.

    Tiers:
    - Main store: the full translation tree per locale,
      shaped as {locale: {namespace: {key: value}}}.
    - LRU store: resolved translation strings, keyed by
      "locale:namespace:dotted.key.path", bounded by ``max_lru_size``.

    Hit/miss counters for both tiers are exposed through ``get_stats()``.
    """

    def __init__(self, max_lru_size: int = 1000, enable_lazy_load: bool = True):
        """
        Initialize the translation cache.

        Args:
            max_lru_size: Maximum number of entries kept in the LRU store.
                A value <= 0 disables the LRU store entirely.
            enable_lazy_load: Flag stored on the instance for callers;
                this class itself does not consult it.
        """
        self.max_lru_size = max_lru_size
        self.enable_lazy_load = enable_lazy_load

        # Main cache: {locale: {namespace: {key: value}}}
        self._main_cache: Dict[str, Dict[str, Any]] = {}

        # LRU cache for hot translations (oldest entry first)
        self._lru_cache: OrderedDict = OrderedDict()

        # Loaded locales tracker
        self._loaded_locales: set = set()

        # Statistics
        self._stats = self._new_stats()

        logger.info(
            f"TranslationCache initialized with LRU size: {max_lru_size}, "
            f"lazy loading: {enable_lazy_load}"
        )

    @staticmethod
    def _new_stats() -> Dict[str, int]:
        """Return a fresh, zeroed statistics dictionary."""
        return {
            "hits": 0,
            "misses": 0,
            "lru_hits": 0,
            "lru_misses": 0,
            "lazy_loads": 0
        }

    def _evict_locale_from_lru(self, locale: str) -> None:
        """Drop every LRU entry whose key belongs to ``locale``."""
        prefix = f"{locale}:"
        # Snapshot the keys first: the dict must not change size while iterated.
        for key in [k for k in self._lru_cache if k.startswith(prefix)]:
            del self._lru_cache[key]

    def set_locale_data(self, locale: str, data: Dict[str, Any]):
        """
        Set (or replace) translation data for a locale.

        When the locale was already loaded, its LRU entries are dropped so
        that strings resolved from the previous data are not served again.

        Args:
            locale: Locale code
            data: Translation data dictionary ({namespace: {key: value}})
        """
        if locale in self._main_cache:
            # Replacing existing data: resolved strings in the LRU are stale.
            self._evict_locale_from_lru(locale)

        self._main_cache[locale] = data
        self._loaded_locales.add(locale)
        logger.debug(f"Loaded locale '{locale}' into cache")

    def get_translation(
        self,
        locale: str,
        namespace: str,
        key_path: list
    ) -> Optional[str]:
        """
        Get translation from cache with LRU optimization.

        Args:
            locale: Locale code
            namespace: Translation namespace
            key_path: List of nested keys

        Returns:
            Translation string, or None if the path does not exist or does
            not resolve to a string.
        """
        # Build cache key for LRU
        cache_key = f"{locale}:{namespace}:{'.'.join(key_path)}"

        # Check LRU cache first (hot translations)
        if cache_key in self._lru_cache:
            self._stats["lru_hits"] += 1
            self._stats["hits"] += 1
            # Move to end (most recently used)
            self._lru_cache.move_to_end(cache_key)
            return self._lru_cache[cache_key]

        self._stats["lru_misses"] += 1

        # Check main cache
        if locale not in self._main_cache:
            self._stats["misses"] += 1
            return None

        if namespace not in self._main_cache[locale]:
            self._stats["misses"] += 1
            return None

        # Navigate through nested keys
        current = self._main_cache[locale][namespace]
        for key in key_path:
            if isinstance(current, dict) and key in current:
                current = current[key]
            else:
                self._stats["misses"] += 1
                return None

        # Return only if it's a string value
        if not isinstance(current, str):
            self._stats["misses"] += 1
            return None

        self._stats["hits"] += 1

        # Add to LRU cache
        self._add_to_lru(cache_key, current)

        return current

    def _add_to_lru(self, key: str, value: str):
        """
        Add translation to LRU cache, evicting the oldest entry when full.

        Args:
            key: Cache key
            value: Translation value
        """
        if self.max_lru_size <= 0:
            # LRU disabled; popitem() on an empty OrderedDict would raise.
            return

        if key in self._lru_cache:
            # Refreshing an existing entry never needs an eviction.
            self._lru_cache.move_to_end(key)
        elif len(self._lru_cache) >= self.max_lru_size:
            # Remove oldest if cache is full
            self._lru_cache.popitem(last=False)

        self._lru_cache[key] = value

    def is_locale_loaded(self, locale: str) -> bool:
        """
        Check if a locale is loaded.

        Args:
            locale: Locale code

        Returns:
            True if locale is loaded
        """
        return locale in self._loaded_locales

    def get_loaded_locales(self) -> list:
        """
        Get list of loaded locales.

        Returns:
            List of locale codes (order is not guaranteed)
        """
        return list(self._loaded_locales)

    def clear_lru(self):
        """Clear the LRU cache."""
        self._lru_cache.clear()
        logger.info("LRU cache cleared")

    def clear_locale(self, locale: str):
        """
        Clear cache for a specific locale.

        Args:
            locale: Locale code
        """
        if locale in self._main_cache:
            del self._main_cache[locale]
            self._loaded_locales.discard(locale)

        # Clear related LRU entries
        self._evict_locale_from_lru(locale)

        logger.info(f"Cleared cache for locale '{locale}'")

    def clear_all(self):
        """Clear all caches."""
        self._main_cache.clear()
        self._lru_cache.clear()
        self._loaded_locales.clear()
        logger.info("All caches cleared")

    def get_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dictionary with raw counters, hit rates as percentages rounded
            to two decimals, and current/maximum LRU sizes.
        """
        total_requests = self._stats["hits"] + self._stats["misses"]
        hit_rate = (
            self._stats["hits"] / total_requests * 100
            if total_requests > 0
            else 0
        )

        lru_total = self._stats["lru_hits"] + self._stats["lru_misses"]
        lru_hit_rate = (
            self._stats["lru_hits"] / lru_total * 100
            if lru_total > 0
            else 0
        )

        return {
            "total_requests": total_requests,
            "hits": self._stats["hits"],
            "misses": self._stats["misses"],
            "hit_rate": round(hit_rate, 2),
            "lru_hits": self._stats["lru_hits"],
            "lru_misses": self._stats["lru_misses"],
            "lru_hit_rate": round(lru_hit_rate, 2),
            "lru_size": len(self._lru_cache),
            "lru_max_size": self.max_lru_size,
            "loaded_locales": len(self._loaded_locales),
            "lazy_loads": self._stats["lazy_loads"]
        }

    def reset_stats(self):
        """Reset cache statistics."""
        self._stats = self._new_stats()
        logger.info("Cache statistics reset")

    def get_memory_usage(self) -> Dict[str, Any]:
        """
        Estimate memory usage of the cache.

        The estimate is shallow: it sums sys.getsizeof() of the containers
        down to namespace level and does not follow nested values or the
        strings held by the LRU store.

        Returns:
            Dictionary with memory usage information (bytes and MB)
        """
        import sys

        main_cache_size = sys.getsizeof(self._main_cache)
        lru_cache_size = sys.getsizeof(self._lru_cache)

        # Rough estimate of nested data
        for locale_data in self._main_cache.values():
            main_cache_size += sys.getsizeof(locale_data)
            for namespace_data in locale_data.values():
                main_cache_size += sys.getsizeof(namespace_data)

        return {
            "main_cache_bytes": main_cache_size,
            "lru_cache_bytes": lru_cache_size,
            "total_bytes": main_cache_size + lru_cache_size,
            "main_cache_mb": round(main_cache_size / 1024 / 1024, 2),
            "lru_cache_mb": round(lru_cache_size / 1024 / 1024, 2),
            "total_mb": round((main_cache_size + lru_cache_size) / 1024 / 1024, 2)
        }
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=128)
def get_cached_translation_key(locale: str, namespace: str, key: str) -> str:
    """
    Build the "locale:namespace:key" cache key, memoized with an LRU cache.

    Results for the 128 most recently used argument triples are memoized,
    so repeated lookups skip the string formatting.

    Args:
        locale: Locale code
        namespace: Translation namespace
        key: Translation key

    Returns:
        Cache key string
    """
    return "{}:{}:{}".format(locale, namespace, key)
|
||||||
158
api/app/i18n/dependencies.py
Normal file
158
api/app/i18n/dependencies.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
"""
|
||||||
|
FastAPI dependency injection functions for i18n.
|
||||||
|
|
||||||
|
This module provides dependency injection functions that can be used
|
||||||
|
in FastAPI route handlers to access the current language and translator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from app.i18n.service import get_translation_service
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_language(request: Request) -> str:
    """
    Resolve the language for the current request.

    Reads ``request.state.language``; when that attribute is missing or
    None, falls back to ``settings.I18N_DEFAULT_LANGUAGE`` and logs a
    warning.

    Args:
        request: FastAPI request object

    Returns:
        Language code (e.g., "zh", "en")

    Usage:
        @router.get("/example")
        async def example(language: str = Depends(get_current_language)):
            return {"language": language}
    """
    language = getattr(request.state, "language", None)
    if language is not None:
        return language

    # Nothing on request.state: use the configured default instead.
    from app.core.config import settings

    language = settings.I18N_DEFAULT_LANGUAGE
    logger.warning(
        "Language not found in request.state, using default: "
        f"{language}"
    )
    return language
|
||||||
|
|
||||||
|
|
||||||
|
async def get_translator(request: Request) -> Callable:
    """
    Build a translation function bound to the current request's language.

    The language is resolved once via ``get_current_language`` and captured
    by the returned closure, so handlers only pass the key and parameters.

    Args:
        request: FastAPI request object

    Returns:
        Translation function with signature: t(key: str, **params) -> str

    Usage:
        @router.get("/items")
        async def get_items(t: Callable = Depends(get_translator)):
            return {"message": t("items.found", count=5)}
    """
    # Resolve the language first, then fetch the translation service.
    bound_language = await get_current_language(request)
    translation_service = get_translation_service()

    def translate(key: str, **params) -> str:
        """
        Translate ``key`` in the language captured for this request.

        Args:
            key: Translation key (e.g., "common.success.created")
            **params: Parameters for parameterized messages

        Returns:
            Translated string
        """
        return translation_service.translate(key, bound_language, **params)

    return translate
|
||||||
|
|
||||||
|
|
||||||
|
async def get_enum_translator(request: Request) -> Callable:
    """
    Build an enum-translation function bound to the current request's language.

    The language is resolved once via ``get_current_language`` and captured
    by the returned closure.

    Args:
        request: FastAPI request object

    Returns:
        Enum translation function with signature:
        t_enum(enum_type: str, value: str) -> str

    Usage:
        @router.get("/workspace/{id}")
        async def get_workspace(
            id: str,
            t_enum: Callable = Depends(get_enum_translator)
        ):
            workspace = await workspace_service.get(id)
            return {"role_display": t_enum("workspace_role", workspace.role)}
    """
    # Resolve the language first, then fetch the translation service.
    bound_language = await get_current_language(request)
    translation_service = get_translation_service()

    def translate_enum(enum_type: str, value: str) -> str:
        """
        Translate an enum value in the language captured for this request.

        Args:
            enum_type: Enum type name (e.g., "workspace_role")
            value: Enum value (e.g., "manager")

        Returns:
            Translated enum display name
        """
        return translation_service.translate_enum(enum_type, value, bound_language)

    return translate_enum
|
||||||
495
api/app/i18n/exceptions.py
Normal file
495
api/app/i18n/exceptions.py
Normal file
@@ -0,0 +1,495 @@
|
|||||||
|
"""
|
||||||
|
Internationalized exception classes for i18n system.
|
||||||
|
|
||||||
|
This module provides exception classes that automatically translate
|
||||||
|
error messages based on the current request's language.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
|
from app.i18n.service import get_translation_service
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Context variable to store current locale
|
||||||
|
_current_locale: ContextVar[Optional[str]] = ContextVar("current_locale", default=None)
|
||||||
|
|
||||||
|
|
||||||
|
def set_current_locale(locale: str) -> None:
    """
    Store ``locale`` in the module's ``_current_locale`` context variable.

    Intended to be called by the LanguageMiddleware.

    Args:
        locale: Locale code (e.g., "zh", "en")
    """
    # The reset token returned by ContextVar.set() is intentionally discarded.
    _current_locale.set(locale)
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_locale() -> Optional[str]:
    """
    Read the locale stored in the module's ``_current_locale`` context variable.

    Returns:
        Locale code, or None (the variable's default) if none was set
    """
    active_locale = _current_locale.get()
    return active_locale
|
||||||
|
|
||||||
|
|
||||||
|
class I18nException(HTTPException):
    """
    Base HTTP exception whose message is translated at construction time.

    The locale used for translation is, in order:
    1. The ``locale`` argument, when given
    2. The value of the module's ``_current_locale`` context variable
    3. ``settings.I18N_DEFAULT_LANGUAGE``

    The resulting ``detail`` is a dict with "error_code" and "message",
    plus "params" when keyword parameters were supplied.

    Usage:
        # Simple error
        raise I18nException(
            error_key="errors.workspace.not_found",
            status_code=404
        )

        # Error with parameters and an explicit error code
        raise I18nException(
            error_key="errors.workspace.not_found",
            error_code="WORKSPACE_NOT_FOUND",
            status_code=404,
            workspace_id="123"
        )
    """

    def __init__(
        self,
        error_key: str,
        status_code: int = 400,
        error_code: Optional[str] = None,
        locale: Optional[str] = None,
        headers: Optional[Dict[str, str]] = None,
        **params
    ):
        """
        Initialize the i18n exception.

        Args:
            error_key: Translation key for the error message
                (e.g., "errors.workspace.not_found")
            status_code: HTTP status code (default: 400)
            error_code: Custom error code for API clients
                (default: derived from error_key)
            locale: Target locale for translation (optional);
                resolved via _get_current_locale() when omitted
            headers: Additional HTTP headers
            **params: Parameters for parameterized error messages
        """
        self.error_key = error_key
        self.error_code = error_code or self._generate_error_code(error_key)
        self.params = params

        # An explicit locale wins; otherwise consult the context / default.
        target_locale = self._get_current_locale() if locale is None else locale

        message = get_translation_service().translate(
            error_key,
            target_locale,
            **params
        )

        detail = {"error_code": self.error_code, "message": message}
        if params:
            # Expose the raw parameters alongside the rendered message.
            detail["params"] = params

        super().__init__(status_code=status_code, detail=detail, headers=headers)

        logger.debug(
            f"I18nException raised: {self.error_code} "
            f"(key: {error_key}, locale: {target_locale})"
        )

    def _get_current_locale(self) -> str:
        """
        Resolve the locale from the context variable, else the default.

        Returns:
            Locale code (e.g., "zh", "en")
        """
        context_locale = None
        try:
            context_locale = _current_locale.get()
        except Exception as e:
            logger.debug(f"Could not get locale from context: {e}")

        if context_locale:
            return context_locale

        # Unset or empty: fall back to the configured default.
        from app.core.config import settings

        return settings.I18N_DEFAULT_LANGUAGE

    def _generate_error_code(self, error_key: str) -> str:
        """
        Derive an error code from a translation key.

        Strips a leading "errors." and converts the dotted remainder to
        upper case with underscores, e.g.
        "errors.workspace.not_found" -> "WORKSPACE_NOT_FOUND".

        Args:
            error_key: Translation key

        Returns:
            Error code string
        """
        prefix = "errors."
        trimmed = error_key[len(prefix):] if error_key.startswith(prefix) else error_key
        return trimmed.replace(".", "_").upper()
|
||||||
|
|
||||||
|
|
||||||
|
# Specific exception classes for common errors
|
||||||
|
|
||||||
|
class BadRequestError(I18nException):
    """I18n exception fixed to HTTP 400 (Bad Request)."""

    def __init__(
        self,
        error_key: str = "errors.common.bad_request",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 400, error_code, **params)
|
||||||
|
|
||||||
|
|
||||||
|
class UnauthorizedError(I18nException):
    """I18n exception fixed to HTTP 401 (Unauthorized)."""

    def __init__(
        self,
        error_key: str = "errors.auth.unauthorized",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 401, error_code, **params)
|
||||||
|
|
||||||
|
|
||||||
|
class ForbiddenError(I18nException):
    """I18n exception fixed to HTTP 403 (Forbidden)."""

    def __init__(
        self,
        error_key: str = "errors.auth.forbidden",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 403, error_code, **params)
|
||||||
|
|
||||||
|
|
||||||
|
class NotFoundError(I18nException):
    """I18n exception fixed to HTTP 404 (Not Found)."""

    def __init__(
        self,
        error_key: str = "errors.common.not_found",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 404, error_code, **params)
|
||||||
|
|
||||||
|
|
||||||
|
class ConflictError(I18nException):
    """I18n exception fixed to HTTP 409 (Conflict)."""

    def __init__(
        self,
        error_key: str = "errors.common.conflict",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 409, error_code, **params)
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationError(I18nException):
    """I18n exception fixed to HTTP 422 (validation failed)."""

    def __init__(
        self,
        error_key: str = "errors.common.validation_failed",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 422, error_code, **params)
|
||||||
|
|
||||||
|
|
||||||
|
class InternalServerError(I18nException):
    """I18n exception fixed to HTTP 500 (Internal Server Error)."""

    def __init__(
        self,
        error_key: str = "errors.common.internal_error",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 500, error_code, **params)
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceUnavailableError(I18nException):
    """I18n exception fixed to HTTP 503 (Service Unavailable)."""

    def __init__(
        self,
        error_key: str = "errors.common.service_unavailable",
        error_code: Optional[str] = None,
        **params
    ):
        # Positional order matches I18nException: key, status, code.
        super().__init__(error_key, 503, error_code, **params)
|
||||||
|
|
||||||
|
|
||||||
|
# Domain-specific exception classes
|
||||||
|
|
||||||
|
class WorkspaceNotFoundError(NotFoundError):
    """Not-found error for a workspace (code WORKSPACE_NOT_FOUND)."""

    def __init__(self, workspace_id: Optional[str] = None, **params):
        # Only a truthy id is forwarded as a message parameter.
        if workspace_id:
            params.update(workspace_id=workspace_id)
        super().__init__(
            "errors.workspace.not_found",
            "WORKSPACE_NOT_FOUND",
            **params
        )
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspacePermissionDeniedError(ForbiddenError):
    """Forbidden error for a workspace (code WORKSPACE_PERMISSION_DENIED)."""

    def __init__(self, workspace_id: Optional[str] = None, **params):
        # Only a truthy id is forwarded as a message parameter.
        if workspace_id:
            params.update(workspace_id=workspace_id)
        super().__init__(
            "errors.workspace.permission_denied",
            "WORKSPACE_PERMISSION_DENIED",
            **params
        )
|
||||||
|
|
||||||
|
|
||||||
|
class UserNotFoundError(NotFoundError):
    """Not-found error for a user (code USER_NOT_FOUND)."""

    def __init__(self, user_id: Optional[str] = None, **params):
        # Only a truthy id is forwarded as a message parameter.
        if user_id:
            params.update(user_id=user_id)
        super().__init__(
            "errors.user.not_found",
            "USER_NOT_FOUND",
            **params
        )
|
||||||
|
|
||||||
|
|
||||||
|
class UserAlreadyExistsError(ConflictError):
    """Conflict error for an existing user (code USER_ALREADY_EXISTS)."""

    def __init__(self, identifier: Optional[str] = None, **params):
        # Only a truthy identifier is forwarded as a message parameter.
        if identifier:
            params.update(identifier=identifier)
        super().__init__(
            "errors.user.already_exists",
            "USER_ALREADY_EXISTS",
            **params
        )
|
||||||
|
|
||||||
|
|
||||||
|
class TenantNotFoundError(NotFoundError):
    """Not-found error for a tenant (code TENANT_NOT_FOUND)."""

    def __init__(self, tenant_id: Optional[str] = None, **params):
        # Only a truthy id is forwarded as a message parameter.
        if tenant_id:
            params.update(tenant_id=tenant_id)
        super().__init__(
            "errors.tenant.not_found",
            "TENANT_NOT_FOUND",
            **params
        )
|
||||||
|
|
||||||
|
|
||||||
|
class TenantSuspendedError(ForbiddenError):
    """Forbidden error for a suspended tenant (code TENANT_SUSPENDED)."""

    def __init__(self, tenant_id: Optional[str] = None, **params):
        # Only a truthy id is forwarded as a message parameter.
        if tenant_id:
            params.update(tenant_id=tenant_id)
        super().__init__(
            "errors.tenant.suspended",
            "TENANT_SUSPENDED",
            **params
        )
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidCredentialsError(UnauthorizedError):
    """Unauthorized error for bad credentials (code INVALID_CREDENTIALS)."""

    def __init__(self, **params):
        super().__init__(
            "errors.auth.invalid_credentials",
            "INVALID_CREDENTIALS",
            **params
        )
|
||||||
|
|
||||||
|
|
||||||
|
class TokenExpiredError(UnauthorizedError):
    """Unauthorized error for an expired token (code TOKEN_EXPIRED)."""

    def __init__(self, **params):
        super().__init__(
            "errors.auth.token_expired",
            "TOKEN_EXPIRED",
            **params
        )
|
||||||
|
|
||||||
|
|
||||||
|
class TokenInvalidError(UnauthorizedError):
    """Unauthorized error for an invalid token (code TOKEN_INVALID)."""

    def __init__(self, **params):
        super().__init__(
            "errors.auth.token_invalid",
            "TOKEN_INVALID",
            **params
        )
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE(review): this class name shadows Python's builtin FileNotFoundError
# inside this module and for any star-import of it. The name is kept to
# preserve the public interface; consider an alias in a follow-up.
class FileNotFoundError(NotFoundError):
    """Not-found error for a file (code FILE_NOT_FOUND)."""

    def __init__(self, file_id: Optional[str] = None, **params):
        # Only a truthy id is forwarded as a message parameter.
        if file_id:
            params.update(file_id=file_id)
        super().__init__(
            "errors.file.not_found",
            "FILE_NOT_FOUND",
            **params
        )
|
||||||
|
|
||||||
|
|
||||||
|
class FileTooLargeError(BadRequestError):
    """Bad-request error for an oversized file (code FILE_TOO_LARGE)."""

    def __init__(self, max_size: Optional[str] = None, **params):
        # Only a truthy limit is forwarded as a message parameter.
        if max_size:
            params.update(max_size=max_size)
        super().__init__(
            "errors.file.too_large",
            "FILE_TOO_LARGE",
            **params
        )
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidFileTypeError(BadRequestError):
    """Bad-request error for a rejected file type (code INVALID_FILE_TYPE)."""

    def __init__(self, file_type: Optional[str] = None, **params):
        # Only a truthy type is forwarded as a message parameter.
        if file_type:
            params.update(file_type=file_type)
        super().__init__(
            "errors.file.invalid_type",
            "INVALID_FILE_TYPE",
            **params
        )
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitExceededError(I18nException):
    """I18n exception fixed to HTTP 429 (code RATE_LIMIT_EXCEEDED)."""

    def __init__(self, **params):
        # Positional order matches I18nException: key, status, code.
        super().__init__(
            "errors.api.rate_limit_exceeded",
            429,
            "RATE_LIMIT_EXCEEDED",
            **params
        )
|
||||||
|
|
||||||
|
|
||||||
|
class QuotaExceededError(ForbiddenError):
    """Forbidden error for an exhausted quota (code QUOTA_EXCEEDED)."""

    def __init__(self, resource: Optional[str] = None, **params):
        # Only a truthy resource name is forwarded as a message parameter.
        if resource:
            params.update(resource=resource)
        super().__init__(
            "errors.api.quota_exceeded",
            "QUOTA_EXCEEDED",
            **params
        )
|
||||||
199
api/app/i18n/loader.py
Normal file
199
api/app/i18n/loader.py
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
"""
|
||||||
|
Translation file loader for i18n system.
|
||||||
|
|
||||||
|
This module handles loading translation files from multiple directories
|
||||||
|
(community edition + enterprise edition) and provides hot reload support.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TranslationLoader:
|
||||||
|
"""
|
||||||
|
Translation file loader that supports:
|
||||||
|
- Loading from multiple directories (community + enterprise)
|
||||||
|
- Hot reload of translation files
|
||||||
|
- Automatic locale detection
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, locales_dirs: Optional[List[str]] = None):
|
||||||
|
"""
|
||||||
|
Initialize the translation loader.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
locales_dirs: List of directories containing translation files.
|
||||||
|
If None, will auto-detect from settings.
|
||||||
|
"""
|
||||||
|
if locales_dirs is None:
|
||||||
|
locales_dirs = self._detect_locales_dirs()
|
||||||
|
|
||||||
|
self.locales_dirs = [Path(d) for d in locales_dirs]
|
||||||
|
logger.info(f"TranslationLoader initialized with directories: {self.locales_dirs}")
|
||||||
|
|
||||||
|
def _detect_locales_dirs(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Auto-detect translation directories from settings.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of translation directory paths
|
||||||
|
"""
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
dirs = []
|
||||||
|
|
||||||
|
# 1. Core locales directory (community edition, required)
|
||||||
|
core_dir = Path(settings.I18N_CORE_LOCALES_DIR)
|
||||||
|
if core_dir.exists():
|
||||||
|
dirs.append(str(core_dir))
|
||||||
|
logger.debug(f"Found core locales directory: {core_dir}")
|
||||||
|
else:
|
||||||
|
logger.warning(f"Core locales directory not found: {core_dir}")
|
||||||
|
|
||||||
|
# 2. Premium locales directory (enterprise edition, optional)
|
||||||
|
if settings.I18N_PREMIUM_LOCALES_DIR:
|
||||||
|
premium_dir = Path(settings.I18N_PREMIUM_LOCALES_DIR)
|
||||||
|
if premium_dir.exists():
|
||||||
|
dirs.append(str(premium_dir))
|
||||||
|
logger.debug(f"Found premium locales directory: {premium_dir}")
|
||||||
|
else:
|
||||||
|
# Auto-detect premium directory
|
||||||
|
premium_dir = Path("premium/locales")
|
||||||
|
if premium_dir.exists():
|
||||||
|
dirs.append(str(premium_dir))
|
||||||
|
logger.debug(f"Auto-detected premium locales directory: {premium_dir}")
|
||||||
|
|
||||||
|
if not dirs:
|
||||||
|
logger.error("No translation directories found!")
|
||||||
|
|
||||||
|
return dirs
|
||||||
|
|
||||||
|
def get_available_locales(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Get list of all available locales across all directories.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of locale codes (e.g., ['zh', 'en'])
|
||||||
|
"""
|
||||||
|
locales = set()
|
||||||
|
|
||||||
|
for locales_dir in self.locales_dirs:
|
||||||
|
if not locales_dir.exists():
|
||||||
|
continue
|
||||||
|
|
||||||
|
for locale_dir in locales_dir.iterdir():
|
||||||
|
if locale_dir.is_dir() and not locale_dir.name.startswith('.'):
|
||||||
|
locales.add(locale_dir.name)
|
||||||
|
|
||||||
|
return sorted(list(locales))
|
||||||
|
|
||||||
|
def load_locale(self, locale: str) -> Dict[str, Any]:
    """
    Collect every translation namespace available for ``locale``.

    Each entry of ``self.locales_dirs`` is visited in order. Inside its
    ``<dir>/<locale>`` folder, every ``*.json`` file contributes one
    namespace named after the file stem. If that namespace was already
    loaded from an earlier directory, the new content is combined with it
    through ``self._deep_merge`` (existing content as base, new content as
    override), so later directories take precedence.

    Files that cannot be read or parsed are logged and skipped.

    Args:
        locale: Locale code (e.g., 'zh', 'en')

    Returns:
        Dictionary of translations organized by namespace
        Format: {namespace: {key: value, ...}, ...}
    """
    collected = {}

    for locales_dir in self.locales_dirs:
        locale_dir = locales_dir / locale

        if not locale_dir.exists():
            logger.debug(f"Locale directory not found: {locale_dir}")
            continue

        # One namespace per JSON file in this locale folder
        for json_file in locale_dir.glob("*.json"):
            namespace = json_file.stem

            try:
                with open(json_file, "r", encoding="utf-8") as f:
                    new_translations = json.load(f)

                if namespace not in collected:
                    # First time this namespace is seen: take it as-is
                    collected[namespace] = new_translations
                    logger.debug(
                        f"Loaded translations: {locale}/{namespace} from {json_file}"
                    )
                else:
                    # Already present: layer the new file over the old data
                    collected[namespace] = self._deep_merge(
                        collected[namespace],
                        new_translations
                    )
                    logger.debug(
                        f"Merged translations: {locale}/{namespace} from {json_file}"
                    )

            except json.JSONDecodeError as e:
                logger.error(
                    f"Failed to parse JSON file {json_file}: {e}"
                )
            except Exception as e:
                logger.error(
                    f"Failed to load translation file {json_file}: {e}"
                )

    if not collected:
        logger.warning(f"No translations found for locale: {locale}")

    return collected
|
||||||
|
|
||||||
|
def reload(self, locale: Optional[str] = None) -> Dict[str, Dict[str, Any]]:
    """
    Reload translation files.

    With a non-empty ``locale`` only that locale is loaded via
    ``load_locale``; otherwise every locale reported by
    ``get_available_locales`` is loaded.

    Args:
        locale: Specific locale to reload. If None, reloads all locales.

    Returns:
        Dictionary of reloaded translations
        Format: {locale: {namespace: {key: value}}}
    """
    if not locale:
        logger.info("Reloading all translations")
        return {
            loc: self.load_locale(loc)
            for loc in self.get_available_locales()
        }

    logger.info(f"Reloading translations for locale: {locale}")
    return {locale: self.load_locale(locale)}
|
||||||
|
|
||||||
|
def _deep_merge(self, base: Dict, override: Dict) -> Dict:
|
||||||
|
"""
|
||||||
|
Deep merge two dictionaries.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base: Base dictionary
|
||||||
|
override: Dictionary with values to override
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Merged dictionary
|
||||||
|
"""
|
||||||
|
result = base.copy()
|
||||||
|
|
||||||
|
for key, value in override.items():
|
||||||
|
if (
|
||||||
|
key in result
|
||||||
|
and isinstance(result[key], dict)
|
||||||
|
and isinstance(value, dict)
|
||||||
|
):
|
||||||
|
result[key] = self._deep_merge(result[key], value)
|
||||||
|
else:
|
||||||
|
result[key] = value
|
||||||
|
|
||||||
|
return result
|
||||||
382
api/app/i18n/logger.py
Normal file
382
api/app/i18n/logger.py
Normal file
@@ -0,0 +1,382 @@
|
|||||||
|
"""
|
||||||
|
Translation logging for i18n system.
|
||||||
|
|
||||||
|
This module provides:
|
||||||
|
- TranslationLogger for recording missing translations
|
||||||
|
- Missing translation report generation
|
||||||
|
- Integration with existing logging system
|
||||||
|
- Structured logging for translation events
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Dict, List, Optional, Set
|
||||||
|
from datetime import datetime
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
import json
|
||||||
|
|
||||||
|
from app.core.logging_config import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TranslationLogger:
    """
    Logger for translation events and missing translations.

    Missing keys are tracked in memory (a de-duplicated set per locale plus
    a bounded list of recent occurrences with context) and written to a
    dedicated log file through the ``i18n.missing_translations`` logger,
    which does not propagate to the root logger.
    """

    def __init__(self, log_file: Optional[str] = None):
        """
        Initialize translation logger.

        Creates the log directory if needed and attaches a UTF-8 file
        handler to the ``i18n.missing_translations`` logger.

        Args:
            log_file: Optional custom log file path for missing translations
        """
        self.log_file = log_file or "logs/i18n/missing_translations.log"
        self._missing_translations: Dict[str, Set[str]] = defaultdict(set)
        self._missing_with_context: List[Dict] = []
        self._max_context_entries = 10000  # Keep last 10k entries

        # Ensure log directory exists
        log_path = Path(self.log_file)
        log_path.parent.mkdir(parents=True, exist_ok=True)

        # Create dedicated file handler for missing translations
        self._file_handler = logging.FileHandler(
            self.log_file,
            encoding='utf-8'
        )
        self._file_handler.setLevel(logging.WARNING)

        # Create formatter
        formatter = logging.Formatter(
            fmt='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'
        )
        self._file_handler.setFormatter(formatter)

        # Create dedicated logger for missing translations
        self._logger = logging.getLogger("i18n.missing_translations")
        self._logger.setLevel(logging.WARNING)
        self._logger.addHandler(self._file_handler)
        self._logger.propagate = False  # Don't propagate to root logger

        logger.info(f"TranslationLogger initialized with log file: {self.log_file}")

    def log_missing_translation(
        self,
        key: str,
        locale: str,
        context: Optional[Dict] = None
    ):
        """
        Log a missing translation.

        Records the key in the per-locale set, appends a timestamped entry
        to the bounded context list (dropping the oldest entry when full),
        and emits a WARNING on the dedicated logger.

        Args:
            key: Translation key that was not found
            locale: Locale code
            context: Optional context information (e.g., request path, user info)
        """
        # Add to missing set
        self._missing_translations[locale].add(key)

        # Create context entry
        entry = {
            "timestamp": datetime.now().isoformat(),
            "key": key,
            "locale": locale,
            "context": context or {}
        }

        # Keep only recent entries to avoid memory bloat
        if len(self._missing_with_context) >= self._max_context_entries:
            self._missing_with_context.pop(0)

        self._missing_with_context.append(entry)

        # Log to file
        context_str = f" (context: {context})" if context else ""
        self._logger.warning(
            f"Missing translation: key='{key}', locale='{locale}'{context_str}"
        )

    def log_translation_error(
        self,
        error_type: str,
        message: str,
        key: Optional[str] = None,
        locale: Optional[str] = None,
        context: Optional[Dict] = None
    ):
        """
        Log a translation error.

        Args:
            error_type: Type of error (e.g., "format_error", "parameter_missing")
            message: Error message
            key: Translation key (optional)
            locale: Locale code (optional)
            context: Optional context information
        """
        # Previously the context was gathered into an unused dict and never
        # reached the log; append it the same way log_missing_translation does.
        context_str = f" (context: {context})" if context else ""
        self._logger.error(
            f"Translation error: {error_type} - {message} "
            f"(key: {key}, locale: {locale}){context_str}"
        )

    def log_translation_success(
        self,
        key: str,
        locale: str,
        duration_ms: Optional[float] = None
    ):
        """
        Log a successful translation (debug level).

        Args:
            key: Translation key
            locale: Locale code
            duration_ms: Optional duration in milliseconds
        """
        # Compare against None so a measured duration of 0.0 is still shown
        duration_str = f" ({duration_ms:.3f}ms)" if duration_ms is not None else ""
        logger.debug(
            f"Translation success: key='{key}', locale='{locale}'{duration_str}"
        )

    def get_missing_translations(
        self,
        locale: Optional[str] = None
    ) -> Dict[str, List[str]]:
        """
        Get missing translations.

        Args:
            locale: Specific locale (optional, returns all if None)

        Returns:
            Dictionary of missing translations by locale, keys sorted
        """
        if locale:
            # .get() avoids creating an empty defaultdict entry as a side effect
            return {locale: sorted(self._missing_translations.get(locale, set()))}

        return {
            loc: sorted(keys)
            for loc, keys in self._missing_translations.items()
        }

    def get_missing_with_context(
        self,
        locale: Optional[str] = None,
        limit: Optional[int] = None
    ) -> List[Dict]:
        """
        Get missing translations with context.

        Args:
            locale: Filter by locale (optional)
            limit: Maximum number of entries to return (optional). The most
                recent entries are kept; 0 or a negative value yields an
                empty list.

        Returns:
            New list of missing translation entries with context, oldest first
        """
        if locale:
            entries = [e for e in self._missing_with_context if e["locale"] == locale]
        else:
            # Copy so callers cannot mutate the internal buffer
            entries = list(self._missing_with_context)

        if limit is not None:
            # entries[-0:] would return everything, so handle <= 0 explicitly
            entries = entries[-limit:] if limit > 0 else []

        return entries

    def _write_json(self, payload: Dict, output_file: str) -> None:
        """Write ``payload`` to ``output_file`` as indented UTF-8 JSON, creating parent directories."""
        output_path = Path(output_file)
        output_path.parent.mkdir(parents=True, exist_ok=True)

        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)

    def generate_report(
        self,
        locale: Optional[str] = None,
        output_file: Optional[str] = None
    ) -> Dict:
        """
        Generate a missing translation report.

        Args:
            locale: Specific locale (optional, generates for all if None)
            output_file: Optional file path to save report as JSON

        Returns:
            Report dictionary with generation time, totals, per-locale keys
            and up to 100 recent context entries
        """
        missing = self.get_missing_translations(locale)

        report = {
            "generated_at": datetime.now().isoformat(),
            "total_missing": sum(len(keys) for keys in missing.values()),
            "missing_by_locale": {
                loc: {
                    "count": len(keys),
                    "keys": keys
                }
                for loc, keys in missing.items()
            },
            "recent_context": self.get_missing_with_context(locale, limit=100)
        }

        # Save to file if specified
        if output_file:
            self._write_json(report, output_file)
            logger.info(f"Missing translation report saved to: {output_file}")

        return report

    def get_statistics(self) -> Dict:
        """
        Get statistics about missing translations.

        The namespace of a key is the text before its first '.'; keys
        without a '.' are counted under 'unknown'.

        Returns:
            Dictionary with statistics
        """
        total_missing = sum(len(keys) for keys in self._missing_translations.values())

        # Count by namespace
        namespace_counts = defaultdict(int)
        for keys in self._missing_translations.values():
            for key in keys:
                namespace = key.split('.')[0] if '.' in key else 'unknown'
                namespace_counts[namespace] += 1

        return {
            "total_missing": total_missing,
            "locales_affected": len(self._missing_translations),
            "missing_by_locale": {
                loc: len(keys)
                for loc, keys in self._missing_translations.items()
            },
            "missing_by_namespace": dict(namespace_counts),
            "total_context_entries": len(self._missing_with_context)
        }

    def clear(self, locale: Optional[str] = None):
        """
        Clear missing translation records.

        Args:
            locale: Specific locale to clear (optional, clears all if None)
        """
        if locale:
            self._missing_translations.pop(locale, None)
            self._missing_with_context = [
                e for e in self._missing_with_context
                if e["locale"] != locale
            ]
            logger.info(f"Cleared missing translations for locale: {locale}")
        else:
            self._missing_translations.clear()
            self._missing_with_context.clear()
            logger.info("Cleared all missing translations")

    def export_to_json(self, output_file: str):
        """
        Export all missing translations to JSON file.

        The export contains the per-locale keys, the statistics and up to
        1000 recent context entries.

        Args:
            output_file: Output file path
        """
        data = {
            "exported_at": datetime.now().isoformat(),
            "missing_translations": self.get_missing_translations(),
            "statistics": self.get_statistics(),
            "recent_context": self.get_missing_with_context(limit=1000)
        }

        self._write_json(data, output_file)

        logger.info(f"Missing translations exported to: {output_file}")

    def __del__(self):
        """Cleanup file handler on deletion."""
        handler = getattr(self, '_file_handler', None)
        if handler is None:
            # __init__ did not get as far as creating the handler
            return

        try:
            # Detach before closing so the logger never holds a closed handler
            owner = getattr(self, '_logger', None)
            if owner is not None:
                owner.removeHandler(handler)
            handler.close()
        except Exception:
            # Best-effort cleanup: a finalizer must not raise
            pass
|
||||||
|
|
||||||
|
|
||||||
|
# Module-wide TranslationLogger, created on first request
_translation_logger: Optional[TranslationLogger] = None


def get_translation_logger() -> TranslationLogger:
    """
    Return the shared translation logger, building it on first use.

    Returns:
        TranslationLogger singleton
    """
    global _translation_logger

    if _translation_logger is not None:
        return _translation_logger

    _translation_logger = TranslationLogger()
    return _translation_logger
|
||||||
|
|
||||||
|
|
||||||
|
def log_missing_translation(
    key: str,
    locale: str,
    context: Optional[Dict] = None
):
    """
    Record a missing translation on the global translation logger.

    Convenience wrapper that forwards its arguments unchanged to the
    ``log_missing_translation`` method of ``get_translation_logger()``.

    Args:
        key: Translation key
        locale: Locale code
        context: Optional context information
    """
    get_translation_logger().log_missing_translation(key, locale, context)
|
||||||
|
|
||||||
|
|
||||||
|
def log_translation_error(
    error_type: str,
    message: str,
    key: Optional[str] = None,
    locale: Optional[str] = None,
    context: Optional[Dict] = None
):
    """
    Record a translation error on the global translation logger.

    Convenience wrapper that forwards its arguments positionally to the
    ``log_translation_error`` method of ``get_translation_logger()``.

    Args:
        error_type: Type of error
        message: Error message
        key: Translation key (optional)
        locale: Locale code (optional)
        context: Optional context information
    """
    shared_logger = get_translation_logger()
    forwarded = (error_type, message, key, locale, context)
    shared_logger.log_translation_error(*forwarded)
|
||||||
337
api/app/i18n/metrics.py
Normal file
337
api/app/i18n/metrics.py
Normal file
@@ -0,0 +1,337 @@
|
|||||||
|
"""
|
||||||
|
Performance monitoring and metrics for i18n system.
|
||||||
|
|
||||||
|
This module provides:
|
||||||
|
- Translation request counters
|
||||||
|
- Translation timing metrics
|
||||||
|
- Missing translation tracking
|
||||||
|
- Performance monitoring decorators
|
||||||
|
- Prometheus-compatible metrics
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Any, Callable, Dict, Optional
|
||||||
|
from collections import defaultdict
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TranslationMetrics:
    """
    In-memory metrics collector for translation operations.

    Tracks:
    - Translation request counts per locale and per namespace
    - Translation timing samples (bounded to the most recent entries)
    - Missing translation keys per locale
    - Error counts per error type
    """

    def __init__(self):
        """Initialize metrics collector."""
        # Request counters by locale
        self._request_counts: Dict[str, int] = defaultdict(int)

        # Missing translation tracker
        self._missing_translations: Dict[str, set] = defaultdict(set)

        # Timing metrics (in milliseconds)
        self._timing_data: list = []
        self._max_timing_samples = 10000  # Keep last 10k samples

        # Locale usage
        self._locale_usage: Dict[str, int] = defaultdict(int)

        # Namespace usage
        self._namespace_usage: Dict[str, int] = defaultdict(int)

        # Error counts
        self._error_counts: Dict[str, int] = defaultdict(int)

        # Start time
        self._start_time = datetime.now()

        logger.info("TranslationMetrics initialized")

    @staticmethod
    def _escape_label_value(value: Any) -> str:
        """Escape backslash, double quote and newline for a Prometheus label value."""
        return (
            str(value)
            .replace("\\", "\\\\")
            .replace('"', '\\"')
            .replace("\n", "\\n")
        )

    def record_request(self, locale: str, namespace: Optional[str] = None):
        """
        Record a translation request.

        Args:
            locale: Locale code
            namespace: Translation namespace (optional)
        """
        self._request_counts[locale] += 1
        self._locale_usage[locale] += 1

        if namespace:
            self._namespace_usage[namespace] += 1

    def record_missing(self, key: str, locale: str):
        """
        Record a missing translation.

        Args:
            key: Translation key
            locale: Locale code
        """
        self._missing_translations[locale].add(key)
        logger.debug(f"Missing translation recorded: {key} (locale: {locale})")

    def record_timing(self, duration_ms: float, locale: str, operation: str = "translate"):
        """
        Record translation operation timing.

        Args:
            duration_ms: Duration in milliseconds
            locale: Locale code
            operation: Operation type
        """
        # Keep only recent samples to avoid memory bloat
        if len(self._timing_data) >= self._max_timing_samples:
            self._timing_data.pop(0)

        self._timing_data.append({
            "duration_ms": duration_ms,
            "locale": locale,
            "operation": operation,
            "timestamp": time.time()
        })

    def record_error(self, error_type: str):
        """
        Record an error.

        Args:
            error_type: Type of error
        """
        self._error_counts[error_type] += 1

    def get_summary(self) -> Dict[str, Any]:
        """
        Get metrics summary.

        Returns:
            Dictionary with uptime, request/missing/error counts, usage
            breakdowns and timing statistics
        """
        total_requests = sum(self._request_counts.values())
        total_missing = sum(len(keys) for keys in self._missing_translations.values())

        # Calculate timing statistics
        timing_stats = self._calculate_timing_stats()

        # Calculate uptime
        uptime_seconds = (datetime.now() - self._start_time).total_seconds()

        return {
            "uptime_seconds": round(uptime_seconds, 2),
            "total_requests": total_requests,
            "requests_per_locale": dict(self._request_counts),
            "total_missing_translations": total_missing,
            "missing_by_locale": {
                locale: len(keys)
                for locale, keys in self._missing_translations.items()
            },
            "timing": timing_stats,
            "locale_usage": dict(self._locale_usage),
            "namespace_usage": dict(self._namespace_usage),
            "error_counts": dict(self._error_counts)
        }

    def _calculate_timing_stats(self) -> Dict[str, Any]:
        """
        Calculate timing statistics over the retained samples.

        Returns:
            Dictionary with count, average, min, max and the p50/p95/p99
            durations in milliseconds (all zero when no samples exist)
        """
        if not self._timing_data:
            return {
                "count": 0,
                "avg_ms": 0,
                "min_ms": 0,
                "max_ms": 0,
                "p50_ms": 0,
                "p95_ms": 0,
                "p99_ms": 0
            }

        durations = sorted(d["duration_ms"] for d in self._timing_data)

        count = len(durations)
        avg = sum(durations) / count

        # Percentile index is int(count * q); for q < 1 it is always < count
        p50_idx = int(count * 0.50)
        p95_idx = int(count * 0.95)
        p99_idx = int(count * 0.99)

        return {
            "count": count,
            "avg_ms": round(avg, 3),
            "min_ms": round(durations[0], 3),
            "max_ms": round(durations[-1], 3),
            "p50_ms": round(durations[p50_idx], 3),
            "p95_ms": round(durations[p95_idx], 3),
            "p99_ms": round(durations[p99_idx], 3)
        }

    def get_missing_translations(self, locale: Optional[str] = None) -> Dict[str, list]:
        """
        Get missing translations.

        Args:
            locale: Specific locale (optional, returns all if None)

        Returns:
            Dictionary of missing translations by locale, keys sorted so the
            output is deterministic
        """
        if locale:
            # .get() avoids creating an empty defaultdict entry as a side effect
            return {locale: sorted(self._missing_translations.get(locale, set()))}

        return {
            loc: sorted(keys)
            for loc, keys in self._missing_translations.items()
        }

    def reset(self):
        """Reset all metrics and restart the uptime clock."""
        self._request_counts.clear()
        self._missing_translations.clear()
        self._timing_data.clear()
        self._locale_usage.clear()
        self._namespace_usage.clear()
        self._error_counts.clear()
        self._start_time = datetime.now()
        logger.info("Metrics reset")

    def export_prometheus(self) -> str:
        """
        Export metrics in Prometheus format.

        Label values are escaped (backslash, double quote, newline) so that
        unusual locale or error-type strings cannot break the output.

        Returns:
            Prometheus-formatted metrics string
        """
        esc = self._escape_label_value
        lines = []

        # Translation requests counter
        lines.append("# HELP i18n_translation_requests_total Total number of translation requests")
        lines.append("# TYPE i18n_translation_requests_total counter")
        for locale, count in self._request_counts.items():
            lines.append(f'i18n_translation_requests_total{{locale="{esc(locale)}"}} {count}')

        # Missing translations counter
        lines.append("# HELP i18n_missing_translations_total Total number of missing translations")
        lines.append("# TYPE i18n_missing_translations_total counter")
        for locale, keys in self._missing_translations.items():
            lines.append(f'i18n_missing_translations_total{{locale="{esc(locale)}"}} {len(keys)}')

        # Timing metrics
        timing_stats = self._calculate_timing_stats()
        lines.append("# HELP i18n_translation_duration_ms Translation operation duration in milliseconds")
        lines.append("# TYPE i18n_translation_duration_ms summary")
        lines.append(f'i18n_translation_duration_ms{{quantile="0.5"}} {timing_stats["p50_ms"]}')
        lines.append(f'i18n_translation_duration_ms{{quantile="0.95"}} {timing_stats["p95_ms"]}')
        lines.append(f'i18n_translation_duration_ms{{quantile="0.99"}} {timing_stats["p99_ms"]}')
        lines.append(f'i18n_translation_duration_ms_sum {sum(d["duration_ms"] for d in self._timing_data)}')
        lines.append(f'i18n_translation_duration_ms_count {timing_stats["count"]}')

        # Error counter
        lines.append("# HELP i18n_errors_total Total number of i18n errors")
        lines.append("# TYPE i18n_errors_total counter")
        for error_type, count in self._error_counts.items():
            lines.append(f'i18n_errors_total{{type="{esc(error_type)}"}} {count}')

        return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
# Module-wide TranslationMetrics, created on first request
_metrics: Optional[TranslationMetrics] = None


def get_metrics() -> TranslationMetrics:
    """
    Return the shared metrics collector, building it on first use.

    Returns:
        TranslationMetrics singleton
    """
    global _metrics

    if _metrics is not None:
        return _metrics

    _metrics = TranslationMetrics()
    return _metrics
|
||||||
|
|
||||||
|
|
||||||
|
def monitor_performance(operation: str = "translate"):
    """
    Decorator to monitor translation operation performance.

    On success the wrapped call's duration is recorded through
    ``get_metrics().record_timing`` together with the call's locale. On
    failure the exception class name is recorded through
    ``get_metrics().record_error`` and the exception is re-raised.

    The locale is taken from the ``locale`` keyword argument when given;
    otherwise it is resolved by parameter name from the wrapped function's
    signature, so positional calls and bound methods are handled. If no
    string locale can be found, "unknown" is recorded.

    Args:
        operation: Operation name for metrics

    Returns:
        Decorated function

    Example:
        @monitor_performance("translate")
        def translate(key: str, locale: str) -> str:
            ...
    """
    def decorator(func: Callable) -> Callable:
        # Function-scope import: only this decorator needs introspection
        import inspect

        try:
            signature = inspect.signature(func)
        except (TypeError, ValueError):
            # Some callables expose no introspectable signature
            signature = None

        def resolve_locale(args, kwargs):
            """Find the ``locale`` argument of one call, or "unknown"."""
            explicit = kwargs.get("locale")
            if explicit:
                return explicit

            if signature is not None:
                try:
                    bound = signature.bind_partial(*args, **kwargs)
                    # Include a declared default, e.g. locale="zh"
                    bound.apply_defaults()
                except TypeError:
                    bound = None
                if bound is not None:
                    positional = bound.arguments.get("locale")
                    if isinstance(positional, str) and positional:
                        return positional

            return "unknown"

        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.perf_counter()

            try:
                result = func(*args, **kwargs)

                # Record timing
                duration_ms = (time.perf_counter() - start_time) * 1000
                locale = resolve_locale(args, kwargs)

                metrics = get_metrics()
                metrics.record_timing(duration_ms, locale, operation)

                return result

            except Exception as e:
                # Record error, then let the caller see the original exception
                metrics = get_metrics()
                metrics.record_error(type(e).__name__)
                raise

        return wrapper
    return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def track_missing_translation(key: str, locale: str):
    """
    Register a missing translation with the global metrics collector.

    Forwards to ``record_missing`` on the object returned by
    ``get_metrics()``.

    Args:
        key: Translation key
        locale: Locale code
    """
    get_metrics().record_missing(key, locale)
|
||||||
|
|
||||||
|
|
||||||
|
def track_translation_request(locale: str, namespace: str = None):
    """
    Register a translation request with the global metrics collector.

    Forwards to ``record_request`` on the object returned by
    ``get_metrics()``.

    Args:
        locale: Locale code
        namespace: Translation namespace (optional)
    """
    get_metrics().record_request(locale, namespace)
|
||||||
202
api/app/i18n/middleware.py
Normal file
202
api/app/i18n/middleware.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
"""
|
||||||
|
Language detection middleware for i18n system.
|
||||||
|
|
||||||
|
This middleware determines the language to use for each request based on:
|
||||||
|
1. Query parameter (?lang=en)
|
||||||
|
2. Accept-Language HTTP header
|
||||||
|
3. User language preference (from database)
|
||||||
|
4. Tenant default language
|
||||||
|
5. System default language
|
||||||
|
|
||||||
|
The detected language is injected into request.state.language and
|
||||||
|
added to the response Content-Language header.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class LanguageMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""
|
||||||
|
Language detection middleware.
|
||||||
|
|
||||||
|
Determines the language for each request based on multiple sources
|
||||||
|
with a clear priority order, validates the language is supported,
|
||||||
|
and injects it into the request context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next):
    """
    Resolve the request language, expose it to handlers, and tag the response.

    The language returned by ``_determine_language`` is replaced with
    ``settings.I18N_DEFAULT_LANGUAGE`` when it is not listed in
    ``settings.I18N_SUPPORTED_LANGUAGES``. The final value is stored on
    ``request.state.language``, passed to ``set_current_locale``, and
    written to the response's ``Content-Language`` header.

    Args:
        request: The incoming request
        call_next: The next middleware/handler in the chain

    Returns:
        Response with Content-Language header added
    """
    language = await self._determine_language(request)

    # Settings are imported inside the method, as in _determine_language
    from app.core.config import settings

    is_supported = language in settings.I18N_SUPPORTED_LANGUAGES
    if not is_supported:
        logger.warning(
            f"Unsupported language '{language}' requested, "
            f"falling back to default: {settings.I18N_DEFAULT_LANGUAGE}"
        )
        language = settings.I18N_DEFAULT_LANGUAGE

    # Make the language available to downstream handlers
    request.state.language = language

    # Mirror it for code that reads the locale via app.i18n.exceptions
    from app.i18n.exceptions import set_current_locale
    set_current_locale(language)

    logger.debug(f"Request language set to: {language}")

    # Run the rest of the chain, then label the response
    response = await call_next(request)
    response.headers["Content-Language"] = language
    return response
|
||||||
|
|
||||||
|
async def _determine_language(self, request: Request) -> str:
    """
    Pick the language for a request; the first source that yields a value wins.

    Priority:
    1. Query parameter (?lang=en), stripped and lower-cased
    2. Accept-Language HTTP header, parsed by ``_parse_accept_language``
    3. ``request.state.user.preferred_language``, if present and truthy
    4. ``request.state.tenant.default_language``, if present and truthy
    5. ``settings.I18N_DEFAULT_LANGUAGE``

    The result is not checked against the supported languages here.

    Args:
        request: The incoming request

    Returns:
        Language code (e.g., "zh", "en")
    """
    from app.core.config import settings

    # 1. Explicit ?lang= query parameter
    if "lang" in request.query_params:
        lang = request.query_params["lang"].strip().lower()
        if lang:
            logger.debug(f"Language from query parameter: {lang}")
            return lang

    # 2. Accept-Language header
    if "Accept-Language" in request.headers:
        lang = self._parse_accept_language(
            request.headers["Accept-Language"]
        )
        if lang:
            logger.debug(f"Language from Accept-Language header: {lang}")
            return lang

    # 3. User preference; only available if request.state.user was set earlier
    user = getattr(request.state, "user", None)
    if user and getattr(user, "preferred_language", None):
        logger.debug(
            f"Language from user preference: {user.preferred_language}"
        )
        return user.preferred_language

    # 4. Tenant default; only available if request.state.tenant was set earlier
    tenant = getattr(request.state, "tenant", None)
    if tenant and getattr(tenant, "default_language", None):
        logger.debug(
            f"Language from tenant default: {tenant.default_language}"
        )
        return tenant.default_language

    # 5. Nothing matched: use the configured system default
    logger.debug(
        f"Using system default language: {settings.I18N_DEFAULT_LANGUAGE}"
    )
    return settings.I18N_DEFAULT_LANGUAGE
|
||||||
|
|
||||||
|
def _parse_accept_language(self, header: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Parse the Accept-Language HTTP header.
|
||||||
|
|
||||||
|
The Accept-Language header format:
|
||||||
|
Accept-Language: zh-CN,zh;q=0.9,en;q=0.8,en-US;q=0.7
|
||||||
|
|
||||||
|
This method:
|
||||||
|
1. Parses all language codes and their quality values
|
||||||
|
2. Extracts the base language code (zh-CN -> zh)
|
||||||
|
3. Sorts by quality value (higher first)
|
||||||
|
4. Returns the first supported language
|
||||||
|
|
||||||
|
Args:
|
||||||
|
header: Accept-Language header value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Language code if found and supported, None otherwise
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
_parse_accept_language("zh-CN,zh;q=0.9,en;q=0.8")
|
||||||
|
# => "zh" (if zh is supported)
|
||||||
|
|
||||||
|
_parse_accept_language("en-US,en;q=0.9")
|
||||||
|
# => "en" (if en is supported)
|
||||||
|
"""
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
if not header:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Parse language preferences with quality values
|
||||||
|
languages = []
|
||||||
|
|
||||||
|
for item in header.split(","):
|
||||||
|
item = item.strip()
|
||||||
|
if not item:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Split language code and quality value
|
||||||
|
parts = item.split(";")
|
||||||
|
lang_code = parts[0].strip()
|
||||||
|
|
||||||
|
# Extract base language code (zh-CN -> zh, en-US -> en)
|
||||||
|
base_lang = lang_code.split("-")[0].lower()
|
||||||
|
|
||||||
|
# Extract quality value (default: 1.0)
|
||||||
|
quality = 1.0
|
||||||
|
if len(parts) > 1:
|
||||||
|
# Look for q=0.9 pattern
|
||||||
|
q_match = re.search(r"q=([\d.]+)", parts[1])
|
||||||
|
if q_match:
|
||||||
|
try:
|
||||||
|
quality = float(q_match.group(1))
|
||||||
|
except ValueError:
|
||||||
|
quality = 1.0
|
||||||
|
|
||||||
|
languages.append((base_lang, quality))
|
||||||
|
|
||||||
|
# Sort by quality value (descending)
|
||||||
|
languages.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
# Return the first supported language
|
||||||
|
for lang_code, _ in languages:
|
||||||
|
if lang_code in settings.I18N_SUPPORTED_LANGUAGES:
|
||||||
|
return lang_code
|
||||||
|
|
||||||
|
return None
|
||||||
221
api/app/i18n/serializers.py
Normal file
221
api/app/i18n/serializers.py
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
"""
|
||||||
|
国际化响应序列化器
|
||||||
|
|
||||||
|
提供基础的 I18nResponseMixin 类,用于为 API 响应添加国际化字段。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Any, Dict, List, Union
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class I18nResponseMixin:
    """Mixin that adds translated ``<field>_display`` entries to response data.

    For every enum field declared by ``_get_enum_fields()``, the serialized
    dictionary gains a sibling key with the ``_display`` suffix that holds the
    translated label of the field's value.

    Usage:
        1. Inherit from this class.
        2. Override ``_get_enum_fields()`` to declare the enum fields to translate.
        3. Call ``serialize_with_i18n()`` to serialize the data.

    Example:
        class WorkspaceSerializer(I18nResponseMixin):
            def _get_enum_fields(self) -> Dict[str, str]:
                return {
                    "role": "workspace_role",
                    "status": "workspace_status"
                }

            def serialize(self, workspace: Workspace, locale: str = "zh") -> Dict:
                data = {
                    "id": str(workspace.id),
                    "name": workspace.name,
                    "role": workspace.role,
                    "status": workspace.status
                }
                return self.serialize_with_i18n(data, locale)
    """

    def serialize_with_i18n(
        self,
        data: Any,
        locale: str = "zh"
    ) -> Union[Dict, List[Dict], Any]:
        """Serialize data and add the i18n display fields.

        Args:
            data: A dict, a Pydantic model, or a list of those. Any other
                value is returned unchanged.
            locale: Language code used for the translated labels.

        Returns:
            A dict (or list of dicts) with ``_display`` fields added. List
            items that are neither dicts nor Pydantic models are passed
            through as-is.
        """
        # Pydantic models are converted to plain dicts first
        if isinstance(data, BaseModel):
            data = data.model_dump()

        if isinstance(data, dict):
            return self._serialize_dict(data, locale)
        if isinstance(data, list):
            return [self._serialize_item(item, locale) for item in data]
        return data

    def _serialize_item(self, item: Any, locale: str) -> Any:
        """Localize one list element; models are dumped to dicts first."""
        if isinstance(item, BaseModel):
            item = item.model_dump()
        if isinstance(item, dict):
            return self._serialize_dict(item, locale)
        return item

    def _serialize_dict(self, data: Dict, locale: str) -> Dict:
        """Return a shallow copy of ``data`` with ``_display`` fields added.

        Args:
            data: Dictionary to localize; it is not modified.
            locale: Language code used for the translated labels.

        Returns:
            A copy of ``data`` where each declared enum field that is present
            and not None has a ``<field>_display`` companion.
        """
        from enum import Enum

        from app.i18n.service import get_translation_service

        translation_service = get_translation_service()
        result = data.copy()

        for field, enum_type in self._get_enum_fields().items():
            value = result.get(field)
            if value is None:
                continue

            # str() on an Enum member yields "ClassName.MEMBER", which never
            # matches a translation key; look the label up by the member's
            # underlying value instead.
            raw_value = value.value if isinstance(value, Enum) else value

            result[f"{field}_display"] = translation_service.translate_enum(
                enum_type=enum_type,
                value=str(raw_value),
                locale=locale
            )

        return result

    def _get_enum_fields(self) -> Dict[str, str]:
        """Declare which fields hold enum values that need a translation.

        Subclasses are expected to override this; the base implementation
        declares no fields.

        Returns:
            Mapping of field name to enum type,
            e.g. {"role": "workspace_role", "status": "workspace_status"}
        """
        return {}
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceSerializer(I18nResponseMixin):
    """Serializer that adds i18n display fields to workspace responses."""

    def _get_enum_fields(self) -> Dict[str, str]:
        """Enum fields of a workspace and the enum type each one maps to."""
        enum_fields: Dict[str, str] = {}
        enum_fields["role"] = "workspace_role"
        enum_fields["status"] = "workspace_status"
        return enum_fields

    def serialize(self, workspace_data: Union[Dict, BaseModel], locale: str = "zh") -> Dict:
        """Serialize a single workspace.

        Args:
            workspace_data: Workspace data (dict or Pydantic model)
            locale: Language code

        Returns:
            The workspace data with i18n display fields added
        """
        localized = self.serialize_with_i18n(workspace_data, locale)
        return localized

    def serialize_list(self, workspaces: List[Union[Dict, BaseModel]], locale: str = "zh") -> List[Dict]:
        """Serialize a list of workspaces, one item at a time.

        Args:
            workspaces: List of workspaces
            locale: Language code

        Returns:
            List of serialized workspaces, in input order
        """
        serialized = []
        for ws in workspaces:
            serialized.append(self.serialize(ws, locale))
        return serialized
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceMemberSerializer(I18nResponseMixin):
    """Serializer that adds i18n display fields to workspace member responses."""

    def _get_enum_fields(self) -> Dict[str, str]:
        """Enum fields of a workspace member and the enum type each one maps to."""
        enum_fields: Dict[str, str] = {}
        enum_fields["role"] = "workspace_role"
        return enum_fields

    def serialize(self, member_data: Union[Dict, BaseModel], locale: str = "zh") -> Dict:
        """Serialize a single workspace member.

        Args:
            member_data: Member data (dict or Pydantic model)
            locale: Language code

        Returns:
            The member data with i18n display fields added
        """
        localized = self.serialize_with_i18n(member_data, locale)
        return localized

    def serialize_list(self, members: List[Union[Dict, BaseModel]], locale: str = "zh") -> List[Dict]:
        """Serialize a list of workspace members, one item at a time.

        Args:
            members: List of members
            locale: Language code

        Returns:
            List of serialized members, in input order
        """
        serialized = []
        for member in members:
            serialized.append(self.serialize(member, locale))
        return serialized
|
||||||
|
|
||||||
|
|
||||||
|
class WorkspaceInviteSerializer(I18nResponseMixin):
    """Serializer that adds i18n display fields to workspace invite responses."""

    def _get_enum_fields(self) -> Dict[str, str]:
        """Enum fields of a workspace invite and the enum type each one maps to."""
        enum_fields: Dict[str, str] = {}
        enum_fields["status"] = "invite_status"
        enum_fields["role"] = "workspace_role"
        return enum_fields

    def serialize(self, invite_data: Union[Dict, BaseModel], locale: str = "zh") -> Dict:
        """Serialize a single workspace invite.

        Args:
            invite_data: Invite data (dict or Pydantic model)
            locale: Language code

        Returns:
            The invite data with i18n display fields added
        """
        localized = self.serialize_with_i18n(invite_data, locale)
        return localized

    def serialize_list(self, invites: List[Union[Dict, BaseModel]], locale: str = "zh") -> List[Dict]:
        """Serialize a list of workspace invites, one item at a time.

        Args:
            invites: List of invites
            locale: Language code

        Returns:
            List of serialized invites, in input order
        """
        serialized = []
        for invite in invites:
            serialized.append(self.serialize(invite, locale))
        return serialized
|
||||||
370
api/app/i18n/service.py
Normal file
370
api/app/i18n/service.py
Normal file
@@ -0,0 +1,370 @@
|
|||||||
|
"""
|
||||||
|
Translation service for i18n system.
|
||||||
|
|
||||||
|
This module provides the core translation functionality including:
|
||||||
|
- Translation lookup with fallback mechanism
|
||||||
|
- Parameterized message support
|
||||||
|
- Enum value translation
|
||||||
|
- Memory caching for performance
|
||||||
|
- Performance monitoring and metrics
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from functools import lru_cache
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from app.i18n.loader import TranslationLoader
|
||||||
|
from app.i18n.cache import TranslationCache
|
||||||
|
from app.i18n.metrics import get_metrics, monitor_performance, track_missing_translation, track_translation_request
|
||||||
|
from app.i18n.logger import get_translation_logger
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TranslationService:
    """
    Translation service that provides:
    - Translation lookup through a TranslationCache populated at construction
    - Parameterized message support ({param} syntax, via str.format)
    - Fallback mechanism (requested locale → fallback locale → key)
    - Enum value translation (keys under the "enums" namespace)
    - Accessors for cache statistics, metrics and memory usage

    NOTE(review): the source this class was recovered from had lost its
    indentation; block nesting was reconstructed and should be confirmed
    against the real file where flagged below.
    """

    def __init__(self, locales_dirs: Optional[list] = None):
        """
        Initialize the translation service.

        Reads the i18n settings, builds the loader and cache, and loads every
        available locale into the cache before returning.

        Args:
            locales_dirs: List of directories containing translation files.
                         Passed straight to TranslationLoader; per the original
                         documentation, None means auto-detect from settings.
        """
        from app.core.config import settings

        self.loader = TranslationLoader(locales_dirs)
        self.default_locale = settings.I18N_DEFAULT_LANGUAGE
        self.fallback_locale = settings.I18N_FALLBACK_LANGUAGE
        self.log_missing = settings.I18N_LOG_MISSING_TRANSLATIONS
        # Stored for reference; no method of this class reads it.
        self.enable_cache = settings.I18N_ENABLE_TRANSLATION_CACHE

        # Initialize advanced cache with LRU (size is optional in settings)
        lru_cache_size = getattr(settings, 'I18N_LRU_CACHE_SIZE', 1000)
        self.cache = TranslationCache(
            max_lru_size=lru_cache_size,
            enable_lazy_load=False  # Load all at startup for now
        )

        # Load all translations into cache
        self._load_all_locales()

        # Initialize metrics
        self.metrics = get_metrics()

        # Initialize translation logger
        self.translation_logger = get_translation_logger()

        logger.info(
            f"TranslationService initialized with default locale: {self.default_locale}, "
            f"LRU cache size: {lru_cache_size}"
        )

    def _load_all_locales(self) -> None:
        """Load every locale reported by the loader into the cache."""
        available_locales = self.loader.get_available_locales()
        logger.info(f"Loading translations for locales: {available_locales}")

        for locale in available_locales:
            locale_data = self.loader.load_locale(locale)
            self.cache.set_locale_data(locale, locale_data)

        logger.info(f"Loaded {len(available_locales)} locales into cache")

    @monitor_performance("translate")
    def translate(
        self,
        key: str,
        locale: Optional[str] = None,
        **params
    ) -> str:
        """
        Translate a key to the target locale.

        Supports:
        - Dot-separated keys (e.g., "common.success.created"); the first
          segment is the namespace, the rest is the nested key path
        - Parameterized messages (e.g., "Hello {name}")
        - Fallback to ``self.fallback_locale`` when the requested locale
          has no entry

        Args:
            key: Translation key (format: "namespace.key.subkey")
            locale: Target locale (defaults to ``self.default_locale``)
            **params: Parameters for parameterized messages

        Returns:
            Translated string, or the key itself if the key has no namespace
            separator or no translation is found. If formatting with
            ``params`` fails, the error is logged and the unformatted
            template is returned.

        Examples:
            translate("common.success.created", "zh")
            # => "创建成功"

            translate("common.validation.required", "zh", field="名称")
            # => "名称不能为空"
        """
        if locale is None:
            locale = self.default_locale

        # Parse key (namespace.key.subkey)
        parts = key.split(".", 1)
        if len(parts) < 2:
            # No "." in the key: there is no namespace to look up
            if self.log_missing:
                logger.warning(f"Invalid translation key format: {key}")
            return key

        namespace = parts[0]
        key_path = parts[1].split(".")

        # Track request
        track_translation_request(locale, namespace)

        # Get translation from cache
        translation = self.cache.get_translation(locale, namespace, key_path)

        # Fall back to the configured fallback locale if not found
        if translation is None and locale != self.fallback_locale:
            translation = self.cache.get_translation(
                self.fallback_locale, namespace, key_path
            )

        # If still not found, return the key itself
        if translation is None:
            # NOTE(review): whether the metric and translation-logger calls
            # below are gated by log_missing could not be recovered from the
            # de-indented source; confirm the nesting against the real file.
            if self.log_missing:
                logger.warning(
                    f"Missing translation: {key} (locale: {locale})"
                )
                track_missing_translation(key, locale)

                # Log to translation logger with context
                self.translation_logger.log_missing_translation(
                    key=key,
                    locale=locale,
                    context={"namespace": namespace}
                )
            return key

        # Apply parameters if provided (templates are left untouched otherwise)
        if params:
            try:
                translation = translation.format(**params)
            except KeyError as e:
                # The template references a placeholder the caller did not pass
                error_msg = f"Missing parameter in translation '{key}': {e}"
                logger.error(error_msg)
                self.translation_logger.log_translation_error(
                    error_type="parameter_missing",
                    message=error_msg,
                    key=key,
                    locale=locale,
                    context={"params": list(params.keys())}
                )
            except Exception as e:
                # Any other formatting problem: log it and fall through so the
                # raw template is still returned
                error_msg = f"Error formatting translation '{key}': {e}"
                logger.error(error_msg)
                self.translation_logger.log_translation_error(
                    error_type="format_error",
                    message=error_msg,
                    key=key,
                    locale=locale
                )

        return translation

    def _get_translation(
        self,
        locale: str,
        namespace: str,
        key_path: list
    ) -> Optional[str]:
        """
        Get translation from cache (deprecated, use cache.get_translation).

        Thin wrapper kept alongside the cache API; it only delegates.

        Args:
            locale: Locale code
            namespace: Translation namespace
            key_path: List of nested keys

        Returns:
            Translation string or None if not found
        """
        return self.cache.get_translation(locale, namespace, key_path)

    @monitor_performance("translate_enum")
    def translate_enum(
        self,
        enum_type: str,
        value: str,
        locale: Optional[str] = None
    ) -> str:
        """
        Translate an enum value.

        Builds the key "enums.<enum_type>.<value>" and delegates to
        ``translate``, so the same fallback and key-echo behavior applies.

        Args:
            enum_type: Enum type name (e.g., "workspace_role")
            value: Enum value (e.g., "manager")
            locale: Target locale

        Returns:
            Translated enum display name

        Examples:
            translate_enum("workspace_role", "manager", "zh")
            # => "管理员"

            translate_enum("invite_status", "pending", "en")
            # => "Pending"
        """
        key = f"enums.{enum_type}.{value}"
        return self.translate(key, locale)

    def has_translation(self, key: str, locale: str) -> bool:
        """
        Check if a translation exists for the given key and locale.

        Only the given locale is consulted; the fallback locale is not.

        Args:
            key: Translation key
            locale: Locale code

        Returns:
            True if translation exists, False otherwise (including when the
            key has no namespace separator)
        """
        parts = key.split(".", 1)
        if len(parts) < 2:
            return False

        namespace = parts[0]
        key_path = parts[1].split(".")

        translation = self.cache.get_translation(locale, namespace, key_path)
        return translation is not None

    def reload(self, locale: Optional[str] = None) -> None:
        """
        Reload translation files.

        Args:
            locale: Specific locale to reload. If None, reloads all locales.
        """
        logger.info(f"Reloading translations for locale: {locale or 'all'}")

        if locale:
            locale_data = self.loader.load_locale(locale)
            self.cache.set_locale_data(locale, locale_data)
            # Clear LRU cache for this locale
            # NOTE(review): this runs after set_locale_data; confirm that
            # clear_locale only drops LRU entries and not the data just stored.
            self.cache.clear_locale(locale)
        else:
            self._load_all_locales()
            # Clear all LRU cache
            self.cache.clear_lru()

        logger.info("Translation reload completed")

    def get_available_locales(self) -> list:
        """
        Get list of all available locales.

        The list comes from the cache (locales actually loaded), not from
        the loader.

        Returns:
            List of locale codes
        """
        return self.cache.get_loaded_locales()

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dictionary with cache statistics, as reported by the cache
        """
        return self.cache.get_stats()

    def get_metrics_summary(self) -> Dict[str, Any]:
        """
        Get metrics summary.

        Returns:
            Dictionary with metrics summary, as reported by the metrics object
        """
        return self.metrics.get_summary()

    def get_memory_usage(self) -> Dict[str, Any]:
        """
        Get memory usage information.

        Returns:
            Dictionary with memory usage information, as reported by the cache
        """
        return self.cache.get_memory_usage()

    def get_loaded_dirs(self) -> list:
        """
        Get list of loaded translation directories.

        Returns:
            The loader's ``locales_dirs`` attribute
        """
        return self.loader.locales_dirs
|
||||||
|
|
||||||
|
|
||||||
|
# Module-wide service instance; stays None until first requested
_translation_service: Optional[TranslationService] = None


def get_translation_service() -> TranslationService:
    """
    Return the shared TranslationService, creating it on first use.

    Returns:
        TranslationService singleton
    """
    global _translation_service

    # Fast path: the instance has already been built
    if _translation_service is not None:
        return _translation_service

    _translation_service = TranslationService()
    return _translation_service
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience functions for easy access
|
||||||
|
def t(key: str, locale: Optional[str] = None, **params) -> str:
    """
    Shorthand for ``get_translation_service().translate(...)``.

    Args:
        key: Translation key
        locale: Target locale (optional, uses default if not provided)
        **params: Parameters for parameterized messages

    Returns:
        Translated string

    Examples:
        t("common.success.created")
        t("common.validation.required", field="名称")
        t("workspace.member_count", count=5)
    """
    return get_translation_service().translate(key, locale, **params)
|
||||||
|
|
||||||
|
|
||||||
|
def t_enum(enum_type: str, value: str, locale: Optional[str] = None) -> str:
    """
    Shorthand for ``get_translation_service().translate_enum(...)``.

    Args:
        enum_type: Enum type name
        value: Enum value
        locale: Target locale

    Returns:
        Translated enum display name

    Examples:
        t_enum("workspace_role", "manager")
        t_enum("invite_status", "pending", "en")
    """
    return get_translation_service().translate_enum(enum_type, value, locale)
|
||||||
26
api/app/locales/en/README.md
Normal file
26
api/app/locales/en/README.md
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# English Translation Files
|
||||||
|
|
||||||
|
This directory contains English translation files.
|
||||||
|
|
||||||
|
## File Structure
|
||||||
|
|
||||||
|
- `common.json` - Common translations (success messages, actions, validation)
|
||||||
|
- `auth.json` - Authentication module translations
|
||||||
|
- `workspace.json` - Workspace module translations
|
||||||
|
- `tenant.json` - Tenant module translations
|
||||||
|
- `errors.json` - Error message translations
|
||||||
|
- `enums.json` - Enum value translations
|
||||||
|
|
||||||
|
## Translation File Format
|
||||||
|
|
||||||
|
All translation files use JSON format and support nested structures.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": {
|
||||||
|
"created": "Created successfully",
|
||||||
|
"updated": "Updated successfully"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
55
api/app/locales/en/auth.json
Normal file
55
api/app/locales/en/auth.json
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
{
|
||||||
|
"login": {
|
||||||
|
"success": "Login successful",
|
||||||
|
"failed": "Login failed",
|
||||||
|
"invalid_credentials": "Invalid username or password",
|
||||||
|
"account_locked": "Account has been locked",
|
||||||
|
"account_disabled": "Account has been disabled"
|
||||||
|
},
|
||||||
|
"logout": {
|
||||||
|
"success": "Logout successful",
|
||||||
|
"failed": "Logout failed"
|
||||||
|
},
|
||||||
|
"token": {
|
||||||
|
"refresh_success": "Token refreshed successfully",
|
||||||
|
"invalid": "Invalid token",
|
||||||
|
"expired": "Token has expired",
|
||||||
|
"blacklisted": "Token has been invalidated",
|
||||||
|
"invalid_refresh_token": "Invalid refresh token",
|
||||||
|
"refresh_token_blacklisted": "Refresh token has been invalidated"
|
||||||
|
},
|
||||||
|
"registration": {
|
||||||
|
"success": "Registration successful",
|
||||||
|
"failed": "Registration failed",
|
||||||
|
"email_exists": "Email already in use",
|
||||||
|
"username_exists": "Username already taken"
|
||||||
|
},
|
||||||
|
"password": {
|
||||||
|
"reset_success": "Password reset successful",
|
||||||
|
"reset_failed": "Password reset failed",
|
||||||
|
"change_success": "Password changed successfully",
|
||||||
|
"change_failed": "Password change failed",
|
||||||
|
"incorrect": "Incorrect password",
|
||||||
|
"too_weak": "Password is too weak",
|
||||||
|
"mismatch": "Passwords do not match"
|
||||||
|
},
|
||||||
|
"invite": {
|
||||||
|
"invalid": "Invalid or expired invite code",
|
||||||
|
"email_mismatch": "Invite email does not match login email",
|
||||||
|
"accept_success": "Invite accepted successfully",
|
||||||
|
"accept_failed": "Failed to accept invite",
|
||||||
|
"password_verification_failed": "Failed to accept invite, password verification error",
|
||||||
|
"bind_workspace_success": "Workspace bound successfully",
|
||||||
|
"bind_workspace_failed": "Failed to bind workspace"
|
||||||
|
},
|
||||||
|
"user": {
|
||||||
|
"not_found": "User not found",
|
||||||
|
"already_exists": "User already exists",
|
||||||
|
"created_with_invite": "User created successfully and joined workspace"
|
||||||
|
},
|
||||||
|
"session": {
|
||||||
|
"expired": "Session expired, please log in again",
|
||||||
|
"invalid": "Invalid session",
|
||||||
|
"single_session_enabled": "Single-session login enabled, sessions on other devices will be logged out"
|
||||||
|
}
|
||||||
|
}
|
||||||
132
api/app/locales/en/common.json
Normal file
132
api/app/locales/en/common.json
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
{
|
||||||
|
"success": {
|
||||||
|
"created": "Created successfully",
|
||||||
|
"updated": "Updated successfully",
|
||||||
|
"deleted": "Deleted successfully",
|
||||||
|
"retrieved": "Retrieved successfully",
|
||||||
|
"saved": "Saved successfully",
|
||||||
|
"uploaded": "Uploaded successfully",
|
||||||
|
"downloaded": "Downloaded successfully",
|
||||||
|
"sent": "Sent successfully",
|
||||||
|
"completed": "Completed",
|
||||||
|
"confirmed": "Confirmed",
|
||||||
|
"cancelled": "Cancelled",
|
||||||
|
"archived": "Archived",
|
||||||
|
"restored": "Restored"
|
||||||
|
},
|
||||||
|
"actions": {
|
||||||
|
"create": "Create",
|
||||||
|
"update": "Update",
|
||||||
|
"delete": "Delete",
|
||||||
|
"view": "View",
|
||||||
|
"edit": "Edit",
|
||||||
|
"save": "Save",
|
||||||
|
"cancel": "Cancel",
|
||||||
|
"confirm": "Confirm",
|
||||||
|
"submit": "Submit",
|
||||||
|
"upload": "Upload",
|
||||||
|
"download": "Download",
|
||||||
|
"send": "Send",
|
||||||
|
"search": "Search",
|
||||||
|
"filter": "Filter",
|
||||||
|
"sort": "Sort",
|
||||||
|
"export": "Export",
|
||||||
|
"import": "Import",
|
||||||
|
"refresh": "Refresh",
|
||||||
|
"reset": "Reset",
|
||||||
|
"back": "Back",
|
||||||
|
"next": "Next",
|
||||||
|
"previous": "Previous",
|
||||||
|
"finish": "Finish",
|
||||||
|
"close": "Close",
|
||||||
|
"open": "Open",
|
||||||
|
"archive": "Archive",
|
||||||
|
"restore": "Restore",
|
||||||
|
"duplicate": "Duplicate",
|
||||||
|
"share": "Share",
|
||||||
|
"invite": "Invite",
|
||||||
|
"remove": "Remove",
|
||||||
|
"add": "Add",
|
||||||
|
"select": "Select",
|
||||||
|
"clear": "Clear"
|
||||||
|
},
|
||||||
|
"validation": {
|
||||||
|
"required": "{field} is required",
|
||||||
|
"invalid_format": "{field} format is invalid",
|
||||||
|
"too_long": "{field} cannot exceed {max} characters",
|
||||||
|
"too_short": "{field} must be at least {min} characters",
|
||||||
|
"invalid_email": "Invalid email format",
|
||||||
|
"invalid_url": "Invalid URL format",
|
||||||
|
"invalid_phone": "Invalid phone number format",
|
||||||
|
"invalid_date": "Invalid date format",
|
||||||
|
"invalid_number": "Must be a valid number",
|
||||||
|
"out_of_range": "{field} must be between {min} and {max}",
|
||||||
|
"already_exists": "{field} already exists",
|
||||||
|
"not_found": "{field} not found",
|
||||||
|
"invalid_value": "Invalid value for {field}",
|
||||||
|
"password_mismatch": "Passwords do not match",
|
||||||
|
"weak_password": "Password is too weak, please use a stronger password",
|
||||||
|
"invalid_credentials": "Invalid username or password",
|
||||||
|
"unauthorized": "Unauthorized access",
|
||||||
|
"forbidden": "Permission denied",
|
||||||
|
"expired": "{field} has expired",
|
||||||
|
"invalid_token": "Invalid token",
|
||||||
|
"file_too_large": "File size cannot exceed {max}",
|
||||||
|
"invalid_file_type": "Unsupported file type",
|
||||||
|
"duplicate": "Duplicate {field}"
|
||||||
|
},
|
||||||
|
"status": {
|
||||||
|
"active": "Active",
|
||||||
|
"inactive": "Inactive",
|
||||||
|
"pending": "Pending",
|
||||||
|
"processing": "Processing",
|
||||||
|
"completed": "Completed",
|
||||||
|
"failed": "Failed",
|
||||||
|
"cancelled": "Cancelled",
|
||||||
|
"archived": "Archived",
|
||||||
|
"deleted": "Deleted",
|
||||||
|
"draft": "Draft",
|
||||||
|
"published": "Published",
|
||||||
|
"suspended": "Suspended",
|
||||||
|
"expired": "Expired"
|
||||||
|
},
|
||||||
|
"messages": {
|
||||||
|
"loading": "Loading...",
|
||||||
|
"saving": "Saving...",
|
||||||
|
"processing": "Processing...",
|
||||||
|
"uploading": "Uploading...",
|
||||||
|
"downloading": "Downloading...",
|
||||||
|
"no_data": "No data available",
|
||||||
|
"no_results": "No results found",
|
||||||
|
"confirm_delete": "Are you sure you want to delete? This action cannot be undone.",
|
||||||
|
"confirm_action": "Are you sure you want to perform this action?",
|
||||||
|
"operation_success": "Operation successful",
|
||||||
|
"operation_failed": "Operation failed",
|
||||||
|
"please_wait": "Please wait...",
|
||||||
|
"try_again": "Please try again",
|
||||||
|
"contact_support": "If the problem persists, please contact support"
|
||||||
|
},
|
||||||
|
"pagination": {
|
||||||
|
"page": "Page {page}",
|
||||||
|
"of": "of {total}",
|
||||||
|
"items": "{total} items",
|
||||||
|
"per_page": "{count} per page",
|
||||||
|
"showing": "Showing {from} to {to} of {total}",
|
||||||
|
"first": "First",
|
||||||
|
"last": "Last",
|
||||||
|
"next": "Next",
|
||||||
|
"previous": "Previous"
|
||||||
|
},
|
||||||
|
"time": {
|
||||||
|
"just_now": "Just now",
|
||||||
|
"minutes_ago": "{count} minutes ago",
|
||||||
|
"hours_ago": "{count} hours ago",
|
||||||
|
"days_ago": "{count} days ago",
|
||||||
|
"weeks_ago": "{count} weeks ago",
|
||||||
|
"months_ago": "{count} months ago",
|
||||||
|
"years_ago": "{count} years ago",
|
||||||
|
"today": "Today",
|
||||||
|
"yesterday": "Yesterday",
|
||||||
|
"tomorrow": "Tomorrow"
|
||||||
|
}
|
||||||
|
}
|
||||||
132
api/app/locales/en/enums.json
Normal file
132
api/app/locales/en/enums.json
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
{
|
||||||
|
"workspace_role": {
|
||||||
|
"owner": "Owner",
|
||||||
|
"manager": "Manager",
|
||||||
|
"member": "Member",
|
||||||
|
"guest": "Guest"
|
||||||
|
},
|
||||||
|
"workspace_status": {
|
||||||
|
"active": "Active",
|
||||||
|
"inactive": "Inactive",
|
||||||
|
"archived": "Archived",
|
||||||
|
"suspended": "Suspended",
|
||||||
|
"deleted": "Deleted"
|
||||||
|
},
|
||||||
|
"invite_status": {
|
||||||
|
"pending": "Pending",
|
||||||
|
"accepted": "Accepted",
|
||||||
|
"rejected": "Rejected",
|
||||||
|
"revoked": "Revoked",
|
||||||
|
"expired": "Expired"
|
||||||
|
},
|
||||||
|
"user_status": {
|
||||||
|
"active": "Active",
|
||||||
|
"inactive": "Inactive",
|
||||||
|
"suspended": "Suspended",
|
||||||
|
"deleted": "Deleted",
|
||||||
|
"pending": "Pending"
|
||||||
|
},
|
||||||
|
"tenant_status": {
|
||||||
|
"active": "Active",
|
||||||
|
"inactive": "Inactive",
|
||||||
|
"suspended": "Suspended",
|
||||||
|
"expired": "Expired",
|
||||||
|
"trial": "Trial"
|
||||||
|
},
|
||||||
|
"file_status": {
|
||||||
|
"uploading": "Uploading",
|
||||||
|
"processing": "Processing",
|
||||||
|
"completed": "Completed",
|
||||||
|
"failed": "Failed",
|
||||||
|
"deleted": "Deleted"
|
||||||
|
},
|
||||||
|
"task_status": {
|
||||||
|
"pending": "Pending",
|
||||||
|
"running": "Running",
|
||||||
|
"completed": "Completed",
|
||||||
|
"failed": "Failed",
|
||||||
|
"cancelled": "Cancelled",
|
||||||
|
"paused": "Paused"
|
||||||
|
},
|
||||||
|
"priority": {
|
||||||
|
"low": "Low",
|
||||||
|
"medium": "Medium",
|
||||||
|
"high": "High",
|
||||||
|
"urgent": "Urgent"
|
||||||
|
},
|
||||||
|
"visibility": {
|
||||||
|
"public": "Public",
|
||||||
|
"private": "Private",
|
||||||
|
"internal": "Internal",
|
||||||
|
"shared": "Shared"
|
||||||
|
},
|
||||||
|
"permission": {
|
||||||
|
"read": "Read",
|
||||||
|
"write": "Write",
|
||||||
|
"delete": "Delete",
|
||||||
|
"admin": "Admin",
|
||||||
|
"owner": "Owner"
|
||||||
|
},
|
||||||
|
"notification_type": {
|
||||||
|
"info": "Info",
|
||||||
|
"warning": "Warning",
|
||||||
|
"error": "Error",
|
||||||
|
"success": "Success"
|
||||||
|
},
|
||||||
|
"language": {
|
||||||
|
"zh": "Chinese (Simplified)",
|
||||||
|
"en": "English",
|
||||||
|
"ja": "Japanese",
|
||||||
|
"ko": "Korean",
|
||||||
|
"fr": "French",
|
||||||
|
"de": "German",
|
||||||
|
"es": "Spanish"
|
||||||
|
},
|
||||||
|
"timezone": {
|
||||||
|
"utc": "UTC",
|
||||||
|
"asia_shanghai": "Asia/Shanghai",
|
||||||
|
"asia_tokyo": "Asia/Tokyo",
|
||||||
|
"america_new_york": "America/New_York",
|
||||||
|
"europe_london": "Europe/London"
|
||||||
|
},
|
||||||
|
"date_format": {
|
||||||
|
"short": "Short",
|
||||||
|
"medium": "Medium",
|
||||||
|
"long": "Long",
|
||||||
|
"full": "Full"
|
||||||
|
},
|
||||||
|
"sort_order": {
|
||||||
|
"asc": "Ascending",
|
||||||
|
"desc": "Descending"
|
||||||
|
},
|
||||||
|
"filter_operator": {
|
||||||
|
"equals": "Equals",
|
||||||
|
"not_equals": "Not Equals",
|
||||||
|
"contains": "Contains",
|
||||||
|
"not_contains": "Not Contains",
|
||||||
|
"starts_with": "Starts With",
|
||||||
|
"ends_with": "Ends With",
|
||||||
|
"greater_than": "Greater Than",
|
||||||
|
"less_than": "Less Than",
|
||||||
|
"greater_or_equal": "Greater or Equal",
|
||||||
|
"less_or_equal": "Less or Equal",
|
||||||
|
"in": "In",
|
||||||
|
"not_in": "Not In",
|
||||||
|
"is_null": "Is Null",
|
||||||
|
"is_not_null": "Is Not Null"
|
||||||
|
},
|
||||||
|
"log_level": {
|
||||||
|
"debug": "Debug",
|
||||||
|
"info": "Info",
|
||||||
|
"warning": "Warning",
|
||||||
|
"error": "Error",
|
||||||
|
"critical": "Critical"
|
||||||
|
},
|
||||||
|
"api_method": {
|
||||||
|
"get": "GET",
|
||||||
|
"post": "POST",
|
||||||
|
"put": "PUT",
|
||||||
|
"patch": "PATCH",
|
||||||
|
"delete": "DELETE"
|
||||||
|
}
|
||||||
|
}
|
||||||
138
api/app/locales/en/errors.json
Normal file
138
api/app/locales/en/errors.json
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
{
|
||||||
|
"common": {
|
||||||
|
"internal_error": "Internal server error",
|
||||||
|
"network_error": "Network connection error",
|
||||||
|
"timeout": "Request timeout",
|
||||||
|
"service_unavailable": "Service temporarily unavailable",
|
||||||
|
"bad_request": "Bad request parameters",
|
||||||
|
"unauthorized": "Unauthorized access",
|
||||||
|
"forbidden": "Access forbidden",
|
||||||
|
"not_found": "Resource not found",
|
||||||
|
"method_not_allowed": "Method not allowed",
|
||||||
|
"conflict": "Resource conflict",
|
||||||
|
"too_many_requests": "Too many requests, please try again later",
|
||||||
|
"validation_failed": "Validation failed",
|
||||||
|
"database_error": "Database operation failed",
|
||||||
|
"file_operation_error": "File operation failed"
|
||||||
|
},
|
||||||
|
"auth": {
|
||||||
|
"invalid_credentials": "Invalid username or password",
|
||||||
|
"token_expired": "Session expired, please log in again",
|
||||||
|
"token_invalid": "Invalid authentication token",
|
||||||
|
"token_missing": "Authentication token missing",
|
||||||
|
"unauthorized": "Unauthorized access",
|
||||||
|
"forbidden": "Permission denied",
|
||||||
|
"account_locked": "Account has been locked",
|
||||||
|
"account_disabled": "Account has been disabled",
|
||||||
|
"account_not_verified": "Account not verified",
|
||||||
|
"password_incorrect": "Incorrect password",
|
||||||
|
"password_too_weak": "Password is too weak",
|
||||||
|
"password_expired": "Password expired, please change it",
|
||||||
|
"email_not_verified": "Email not verified",
|
||||||
|
"phone_not_verified": "Phone number not verified",
|
||||||
|
"verification_code_invalid": "Invalid verification code",
|
||||||
|
"verification_code_expired": "Verification code expired",
|
||||||
|
"login_failed": "Login failed",
|
||||||
|
"logout_failed": "Logout failed",
|
||||||
|
"session_expired": "Session expired",
|
||||||
|
"already_logged_in": "Already logged in",
|
||||||
|
"not_logged_in": "Not logged in"
|
||||||
|
},
|
||||||
|
"user": {
|
||||||
|
"not_found": "User not found",
|
||||||
|
"already_exists": "User already exists",
|
||||||
|
"email_already_exists": "Email already in use",
|
||||||
|
"phone_already_exists": "Phone number already in use",
|
||||||
|
"username_already_exists": "Username already taken",
|
||||||
|
"invalid_email": "Invalid email format",
|
||||||
|
"invalid_phone": "Invalid phone number format",
|
||||||
|
"invalid_username": "Invalid username format",
|
||||||
|
"create_failed": "Failed to create user",
|
||||||
|
"update_failed": "Failed to update user",
|
||||||
|
"delete_failed": "Failed to delete user",
|
||||||
|
"cannot_delete_self": "Cannot delete yourself",
|
||||||
|
"cannot_update_self_role": "Cannot update your own role",
|
||||||
|
"profile_update_failed": "Failed to update profile",
|
||||||
|
"avatar_upload_failed": "Failed to upload avatar",
|
||||||
|
"password_change_failed": "Failed to change password",
|
||||||
|
"old_password_incorrect": "Old password is incorrect"
|
||||||
|
},
|
||||||
|
"workspace": {
|
||||||
|
"not_found": "Workspace not found",
|
||||||
|
"already_exists": "Workspace already exists",
|
||||||
|
"name_required": "Workspace name is required",
|
||||||
|
"name_too_long": "Workspace name is too long",
|
||||||
|
"create_failed": "Failed to create workspace",
|
||||||
|
"update_failed": "Failed to update workspace",
|
||||||
|
"delete_failed": "Failed to delete workspace",
|
||||||
|
"permission_denied": "Permission denied to access this workspace",
|
||||||
|
"not_member": "Not a workspace member",
|
||||||
|
"already_member": "Already a workspace member",
|
||||||
|
"member_limit_reached": "Member limit reached",
|
||||||
|
"cannot_leave_last_manager": "Cannot leave, you are the last manager",
|
||||||
|
"cannot_remove_last_manager": "Cannot remove the last manager",
|
||||||
|
"cannot_remove_self": "Cannot remove yourself",
|
||||||
|
"invite_not_found": "Invite not found",
|
||||||
|
"invite_expired": "Invite has expired",
|
||||||
|
"invite_already_accepted": "Invite already accepted",
|
||||||
|
"invite_already_revoked": "Invite already revoked",
|
||||||
|
"invite_send_failed": "Failed to send invite",
|
||||||
|
"archived": "Workspace is archived",
|
||||||
|
"suspended": "Workspace is suspended"
|
||||||
|
},
|
||||||
|
"tenant": {
|
||||||
|
"not_found": "Tenant not found",
|
||||||
|
"already_exists": "Tenant already exists",
|
||||||
|
"create_failed": "Failed to create tenant",
|
||||||
|
"update_failed": "Failed to update tenant",
|
||||||
|
"delete_failed": "Failed to delete tenant",
|
||||||
|
"suspended": "Tenant is suspended",
|
||||||
|
"expired": "Tenant has expired",
|
||||||
|
"license_invalid": "Invalid license",
|
||||||
|
"license_expired": "License has expired",
|
||||||
|
"quota_exceeded": "Quota exceeded"
|
||||||
|
},
|
||||||
|
"file": {
|
||||||
|
"not_found": "File not found",
|
||||||
|
"upload_failed": "File upload failed",
|
||||||
|
"download_failed": "File download failed",
|
||||||
|
"delete_failed": "File deletion failed",
|
||||||
|
"too_large": "File size exceeds limit",
|
||||||
|
"invalid_type": "Unsupported file type",
|
||||||
|
"invalid_format": "Invalid file format",
|
||||||
|
"corrupted": "File is corrupted",
|
||||||
|
"storage_full": "Storage is full",
|
||||||
|
"access_denied": "Access denied to this file"
|
||||||
|
},
|
||||||
|
"api": {
|
||||||
|
"rate_limit_exceeded": "API rate limit exceeded",
|
||||||
|
"quota_exceeded": "API quota exceeded",
|
||||||
|
"invalid_api_key": "Invalid API key",
|
||||||
|
"api_key_expired": "API key has expired",
|
||||||
|
"api_key_revoked": "API key has been revoked",
|
||||||
|
"endpoint_not_found": "API endpoint not found",
|
||||||
|
"method_not_allowed": "Method not allowed",
|
||||||
|
"invalid_request": "Invalid request",
|
||||||
|
"missing_parameter": "Missing required parameter: {param}",
|
||||||
|
"invalid_parameter": "Invalid parameter: {param}"
|
||||||
|
},
|
||||||
|
"database": {
|
||||||
|
"connection_failed": "Database connection failed",
|
||||||
|
"query_failed": "Database query failed",
|
||||||
|
"transaction_failed": "Database transaction failed",
|
||||||
|
"constraint_violation": "Data constraint violation",
|
||||||
|
"duplicate_key": "Duplicate data",
|
||||||
|
"foreign_key_violation": "Foreign key constraint violation",
|
||||||
|
"deadlock": "Database deadlock"
|
||||||
|
},
|
||||||
|
"validation": {
|
||||||
|
"invalid_input": "Invalid input data",
|
||||||
|
"missing_field": "Missing required field: {field}",
|
||||||
|
"invalid_field": "Invalid field: {field}",
|
||||||
|
"field_too_long": "Field too long: {field}",
|
||||||
|
"field_too_short": "Field too short: {field}",
|
||||||
|
"invalid_format": "Invalid format: {field}",
|
||||||
|
"invalid_value": "Invalid value: {field}",
|
||||||
|
"out_of_range": "Value out of range: {field}"
|
||||||
|
}
|
||||||
|
}
|
||||||
27
api/app/locales/en/i18n.json
Normal file
27
api/app/locales/en/i18n.json
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
{
|
||||||
|
"language": {
|
||||||
|
"not_found": "Language {locale} not found",
|
||||||
|
"already_exists": "Language {locale} already exists",
|
||||||
|
"add_instructions": "Language {locale} validated successfully. Please create translation files in {dir} directory to complete the addition.",
|
||||||
|
"update_instructions": "Language {locale} update validated successfully. Please update I18N_SUPPORTED_LANGUAGES environment variable to apply configuration changes."
|
||||||
|
},
|
||||||
|
"namespace": {
|
||||||
|
"not_found": "Namespace {namespace} not found in language {locale}"
|
||||||
|
},
|
||||||
|
"translation": {
|
||||||
|
"invalid_key_format": "Invalid translation key format: {key}. Should use format: namespace.key.subkey",
|
||||||
|
"update_instructions": "Translation {locale}/{key} update validated successfully. Please modify the corresponding JSON translation file to apply changes."
|
||||||
|
},
|
||||||
|
"reload": {
|
||||||
|
"disabled": "Translation hot reload is disabled. Please enable I18N_ENABLE_HOT_RELOAD in configuration.",
|
||||||
|
"success": "Translations reloaded successfully",
|
||||||
|
"failed": "Translation reload failed: {error}"
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"reset_success": "Performance metrics reset successfully"
|
||||||
|
},
|
||||||
|
"logs": {
|
||||||
|
"export_success": "Missing translations exported to: {file}",
|
||||||
|
"clear_success": "Missing translation logs cleared successfully"
|
||||||
|
}
|
||||||
|
}
|
||||||
63
api/app/locales/en/tenant.json
Normal file
63
api/app/locales/en/tenant.json
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
{
|
||||||
|
"info": {
|
||||||
|
"get_success": "Tenant information retrieved successfully",
|
||||||
|
"get_failed": "Failed to retrieve tenant information",
|
||||||
|
"update_success": "Tenant information updated successfully",
|
||||||
|
"update_failed": "Failed to update tenant information"
|
||||||
|
},
|
||||||
|
"create": {
|
||||||
|
"success": "Tenant created successfully",
|
||||||
|
"failed": "Failed to create tenant"
|
||||||
|
},
|
||||||
|
"delete": {
|
||||||
|
"success": "Tenant deleted successfully",
|
||||||
|
"failed": "Failed to delete tenant"
|
||||||
|
},
|
||||||
|
"status": {
|
||||||
|
"activate_success": "Tenant activated successfully",
|
||||||
|
"activate_failed": "Failed to activate tenant",
|
||||||
|
"deactivate_success": "Tenant deactivated successfully",
|
||||||
|
"deactivate_failed": "Failed to deactivate tenant"
|
||||||
|
},
|
||||||
|
"language": {
|
||||||
|
"get_success": "Tenant language configuration retrieved successfully",
|
||||||
|
"get_failed": "Failed to retrieve tenant language configuration",
|
||||||
|
"update_success": "Tenant language configuration updated successfully",
|
||||||
|
"update_failed": "Failed to update tenant language configuration",
|
||||||
|
"invalid_language": "Unsupported language code",
|
||||||
|
"default_not_in_supported": "Default language must be in the supported languages list"
|
||||||
|
},
|
||||||
|
"list": {
|
||||||
|
"get_success": "Tenant list retrieved successfully",
|
||||||
|
"get_failed": "Failed to retrieve tenant list"
|
||||||
|
},
|
||||||
|
"users": {
|
||||||
|
"list_success": "Tenant user list retrieved successfully",
|
||||||
|
"list_failed": "Failed to retrieve tenant user list",
|
||||||
|
"assign_success": "User assigned to tenant successfully",
|
||||||
|
"assign_failed": "Failed to assign user to tenant",
|
||||||
|
"remove_success": "User removed from tenant successfully",
|
||||||
|
"remove_failed": "Failed to remove user from tenant"
|
||||||
|
},
|
||||||
|
"statistics": {
|
||||||
|
"get_success": "Tenant statistics retrieved successfully",
|
||||||
|
"get_failed": "Failed to retrieve tenant statistics"
|
||||||
|
},
|
||||||
|
"validation": {
|
||||||
|
"name_required": "Tenant name is required",
|
||||||
|
"name_invalid": "Invalid tenant name format",
|
||||||
|
"name_too_long": "Tenant name cannot exceed {max} characters",
|
||||||
|
"description_too_long": "Tenant description cannot exceed {max} characters",
|
||||||
|
"language_code_invalid": "Invalid language code format",
|
||||||
|
"supported_languages_empty": "Supported languages list cannot be empty"
|
||||||
|
},
|
||||||
|
"errors": {
|
||||||
|
"not_found": "Tenant not found",
|
||||||
|
"already_exists": "Tenant name already exists",
|
||||||
|
"permission_denied": "Permission denied to access this tenant",
|
||||||
|
"has_users": "Cannot delete tenant, associated users exist",
|
||||||
|
"has_workspaces": "Cannot delete tenant, associated workspaces exist",
|
||||||
|
"already_active": "Tenant is already active",
|
||||||
|
"already_inactive": "Tenant is already inactive"
|
||||||
|
}
|
||||||
|
}
|
||||||
72
api/app/locales/en/users.json
Normal file
72
api/app/locales/en/users.json
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
{
|
||||||
|
"info": {
|
||||||
|
"get_success": "User information retrieved successfully",
|
||||||
|
"get_failed": "Failed to retrieve user information",
|
||||||
|
"update_success": "User information updated successfully",
|
||||||
|
"update_failed": "Failed to update user information"
|
||||||
|
},
|
||||||
|
"create": {
|
||||||
|
"success": "User created successfully",
|
||||||
|
"failed": "Failed to create user",
|
||||||
|
"superuser_success": "Superuser created successfully",
|
||||||
|
"superuser_failed": "Failed to create superuser"
|
||||||
|
},
|
||||||
|
"delete": {
|
||||||
|
"success": "User deleted successfully",
|
||||||
|
"failed": "Failed to delete user",
|
||||||
|
"deactivate_success": "User deactivated successfully",
|
||||||
|
"deactivate_failed": "Failed to deactivate user"
|
||||||
|
},
|
||||||
|
"activate": {
|
||||||
|
"success": "User activated successfully",
|
||||||
|
"failed": "Failed to activate user"
|
||||||
|
},
|
||||||
|
"language": {
|
||||||
|
"get_success": "Language preference retrieved successfully",
|
||||||
|
"get_failed": "Failed to retrieve language preference",
|
||||||
|
"update_success": "Language preference updated successfully",
|
||||||
|
"update_failed": "Failed to update language preference",
|
||||||
|
"invalid_language": "Unsupported language code",
|
||||||
|
"current": "Current language preference"
|
||||||
|
},
|
||||||
|
"email": {
|
||||||
|
"change_success": "Email changed successfully",
|
||||||
|
"change_failed": "Failed to change email",
|
||||||
|
"code_sent": "Verification code has been sent to your email",
|
||||||
|
"code_send_failed": "Failed to send verification code",
|
||||||
|
"code_invalid": "Invalid or expired verification code",
|
||||||
|
"already_exists": "Email already in use"
|
||||||
|
},
|
||||||
|
"list": {
|
||||||
|
"get_success": "User list retrieved successfully",
|
||||||
|
"get_failed": "Failed to retrieve user list",
|
||||||
|
"superusers_success": "Tenant superuser list retrieved successfully",
|
||||||
|
"superusers_failed": "Failed to retrieve tenant superuser list"
|
||||||
|
},
|
||||||
|
"validation": {
|
||||||
|
"username_required": "Username is required",
|
||||||
|
"username_invalid": "Invalid username format",
|
||||||
|
"username_too_long": "Username cannot exceed {max} characters",
|
||||||
|
"email_required": "Email is required",
|
||||||
|
"email_invalid": "Invalid email format",
|
||||||
|
"password_required": "Password is required",
|
||||||
|
"password_too_short": "Password must be at least {min} characters",
|
||||||
|
"password_too_long": "Password cannot exceed {max} characters",
|
||||||
|
"old_password_required": "Old password is required",
|
||||||
|
"new_password_required": "New password is required",
|
||||||
|
"verification_code_required": "Verification code is required",
|
||||||
|
"verification_code_invalid": "Invalid verification code format"
|
||||||
|
},
|
||||||
|
"errors": {
|
||||||
|
"not_found": "User not found",
|
||||||
|
"already_exists": "User already exists",
|
||||||
|
"permission_denied": "Permission denied to access this user",
|
||||||
|
"cannot_delete_self": "Cannot delete yourself",
|
||||||
|
"cannot_deactivate_self": "Cannot deactivate yourself",
|
||||||
|
"already_deactivated": "User is already deactivated",
|
||||||
|
"already_activated": "User is already activated",
|
||||||
|
"password_verification_failed": "Password verification failed",
|
||||||
|
"old_password_incorrect": "Old password is incorrect",
|
||||||
|
"same_as_old_password": "New password cannot be the same as old password"
|
||||||
|
}
|
||||||
|
}
|
||||||
44
api/app/locales/en/workspace.json
Normal file
44
api/app/locales/en/workspace.json
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
{
|
||||||
|
"list_retrieved": "Workspace list retrieved successfully",
|
||||||
|
"created": "Workspace created successfully",
|
||||||
|
"updated": "Workspace updated successfully",
|
||||||
|
"deleted": "Workspace deleted successfully",
|
||||||
|
"switched": "Workspace switched successfully",
|
||||||
|
"not_found": "Workspace not found or access denied",
|
||||||
|
"already_exists": "Workspace already exists",
|
||||||
|
"permission_denied": "No permission to access this workspace",
|
||||||
|
"name_required": "Workspace name is required",
|
||||||
|
"invalid_name": "Invalid workspace name format",
|
||||||
|
"members": {
|
||||||
|
"list_retrieved": "Workspace members list retrieved successfully",
|
||||||
|
"role_updated": "Member role updated successfully",
|
||||||
|
"deleted": "Member deleted successfully",
|
||||||
|
"not_found": "Member not found",
|
||||||
|
"cannot_remove_self": "Cannot remove yourself",
|
||||||
|
"cannot_remove_last_manager": "Cannot remove the last manager",
|
||||||
|
"already_member": "User is already a workspace member"
|
||||||
|
},
|
||||||
|
"invites": {
|
||||||
|
"created": "Invite created successfully",
|
||||||
|
"list_retrieved": "Invite list retrieved successfully",
|
||||||
|
"validated": "Invite validated successfully",
|
||||||
|
"revoked": "Invite revoked successfully",
|
||||||
|
"accepted": "Invite accepted",
|
||||||
|
"not_found": "Invite not found",
|
||||||
|
"expired": "Invite has expired",
|
||||||
|
"already_used": "Invite has already been used",
|
||||||
|
"invalid_token": "Invalid invite token",
|
||||||
|
"email_required": "Email address is required",
|
||||||
|
"invalid_email": "Invalid email address format"
|
||||||
|
},
|
||||||
|
"storage": {
|
||||||
|
"type_retrieved": "Storage type retrieved successfully",
|
||||||
|
"type_updated": "Storage type updated successfully",
|
||||||
|
"invalid_type": "Invalid storage type"
|
||||||
|
},
|
||||||
|
"models": {
|
||||||
|
"config_retrieved": "Model configuration retrieved successfully",
|
||||||
|
"config_updated": "Model configuration updated successfully",
|
||||||
|
"invalid_config": "Invalid model configuration"
|
||||||
|
}
|
||||||
|
}
|
||||||
26
api/app/locales/zh/README.md
Normal file
26
api/app/locales/zh/README.md
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# 中文翻译文件
|
||||||
|
|
||||||
|
此目录包含中文(简体)的翻译文件。
|
||||||
|
|
||||||
|
## 文件结构
|
||||||
|
|
||||||
|
- `common.json` - 通用翻译(成功消息、操作、验证)
|
||||||
|
- `auth.json` - 认证模块翻译
|
||||||
|
- `workspace.json` - 工作空间模块翻译
|
||||||
|
- `tenant.json` - 租户模块翻译
|
||||||
|
- `errors.json` - 错误消息翻译
|
||||||
|
- `enums.json` - 枚举值翻译
|
||||||
|
|
||||||
|
## 翻译文件格式
|
||||||
|
|
||||||
|
所有翻译文件使用 JSON 格式,支持嵌套结构。
|
||||||
|
|
||||||
|
示例:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": {
|
||||||
|
"created": "创建成功",
|
||||||
|
"updated": "更新成功"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
55
api/app/locales/zh/auth.json
Normal file
55
api/app/locales/zh/auth.json
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
{
|
||||||
|
"login": {
|
||||||
|
"success": "登录成功",
|
||||||
|
"failed": "登录失败",
|
||||||
|
"invalid_credentials": "用户名或密码错误",
|
||||||
|
"account_locked": "账户已被锁定",
|
||||||
|
"account_disabled": "账户已被禁用"
|
||||||
|
},
|
||||||
|
"logout": {
|
||||||
|
"success": "登出成功",
|
||||||
|
"failed": "登出失败"
|
||||||
|
},
|
||||||
|
"token": {
|
||||||
|
"refresh_success": "token刷新成功",
|
||||||
|
"invalid": "无效的token",
|
||||||
|
"expired": "token已过期",
|
||||||
|
"blacklisted": "token已失效",
|
||||||
|
"invalid_refresh_token": "无效的refresh token",
|
||||||
|
"refresh_token_blacklisted": "Refresh token已失效"
|
||||||
|
},
|
||||||
|
"registration": {
|
||||||
|
"success": "注册成功",
|
||||||
|
"failed": "注册失败",
|
||||||
|
"email_exists": "邮箱已被使用",
|
||||||
|
"username_exists": "用户名已被使用"
|
||||||
|
},
|
||||||
|
"password": {
|
||||||
|
"reset_success": "密码重置成功",
|
||||||
|
"reset_failed": "密码重置失败",
|
||||||
|
"change_success": "密码修改成功",
|
||||||
|
"change_failed": "密码修改失败",
|
||||||
|
"incorrect": "密码错误",
|
||||||
|
"too_weak": "密码强度不够",
|
||||||
|
"mismatch": "两次输入的密码不一致"
|
||||||
|
},
|
||||||
|
"invite": {
|
||||||
|
"invalid": "邀请码无效或已过期",
|
||||||
|
"email_mismatch": "邀请邮箱与登录邮箱不匹配",
|
||||||
|
"accept_success": "接受邀请成功",
|
||||||
|
"accept_failed": "接受邀请失败",
|
||||||
|
"password_verification_failed": "接受邀请失败,密码验证错误",
|
||||||
|
"bind_workspace_success": "绑定工作空间成功",
|
||||||
|
"bind_workspace_failed": "绑定工作空间失败"
|
||||||
|
},
|
||||||
|
"user": {
|
||||||
|
"not_found": "用户不存在",
|
||||||
|
"already_exists": "用户已存在",
|
||||||
|
"created_with_invite": "用户创建成功并已加入工作空间"
|
||||||
|
},
|
||||||
|
"session": {
|
||||||
|
"expired": "会话已过期,请重新登录",
|
||||||
|
"invalid": "无效的会话",
|
||||||
|
"single_session_enabled": "单设备登录已启用,其他设备的登录将被注销"
|
||||||
|
}
|
||||||
|
}
|
||||||
132
api/app/locales/zh/common.json
Normal file
132
api/app/locales/zh/common.json
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
{
|
||||||
|
"success": {
|
||||||
|
"created": "创建成功",
|
||||||
|
"updated": "更新成功",
|
||||||
|
"deleted": "删除成功",
|
||||||
|
"retrieved": "获取成功",
|
||||||
|
"saved": "保存成功",
|
||||||
|
"uploaded": "上传成功",
|
||||||
|
"downloaded": "下载成功",
|
||||||
|
"sent": "发送成功",
|
||||||
|
"completed": "完成",
|
||||||
|
"confirmed": "已确认",
|
||||||
|
"cancelled": "已取消",
|
||||||
|
"archived": "已归档",
|
||||||
|
"restored": "已恢复"
|
||||||
|
},
|
||||||
|
"actions": {
|
||||||
|
"create": "创建",
|
||||||
|
"update": "更新",
|
||||||
|
"delete": "删除",
|
||||||
|
"view": "查看",
|
||||||
|
"edit": "编辑",
|
||||||
|
"save": "保存",
|
||||||
|
"cancel": "取消",
|
||||||
|
"confirm": "确认",
|
||||||
|
"submit": "提交",
|
||||||
|
"upload": "上传",
|
||||||
|
"download": "下载",
|
||||||
|
"send": "发送",
|
||||||
|
"search": "搜索",
|
||||||
|
"filter": "筛选",
|
||||||
|
"sort": "排序",
|
||||||
|
"export": "导出",
|
||||||
|
"import": "导入",
|
||||||
|
"refresh": "刷新",
|
||||||
|
"reset": "重置",
|
||||||
|
"back": "返回",
|
||||||
|
"next": "下一步",
|
||||||
|
"previous": "上一步",
|
||||||
|
"finish": "完成",
|
||||||
|
"close": "关闭",
|
||||||
|
"open": "打开",
|
||||||
|
"archive": "归档",
|
||||||
|
"restore": "恢复",
|
||||||
|
"duplicate": "复制",
|
||||||
|
"share": "分享",
|
||||||
|
"invite": "邀请",
|
||||||
|
"remove": "移除",
|
||||||
|
"add": "添加",
|
||||||
|
"select": "选择",
|
||||||
|
"clear": "清除"
|
||||||
|
},
|
||||||
|
"validation": {
|
||||||
|
"required": "{field}不能为空",
|
||||||
|
"invalid_format": "{field}格式不正确",
|
||||||
|
"too_long": "{field}长度不能超过{max}个字符",
|
||||||
|
"too_short": "{field}长度不能少于{min}个字符",
|
||||||
|
"invalid_email": "邮箱格式不正确",
|
||||||
|
"invalid_url": "URL格式不正确",
|
||||||
|
"invalid_phone": "手机号格式不正确",
|
||||||
|
"invalid_date": "日期格式不正确",
|
||||||
|
"invalid_number": "必须是有效的数字",
|
||||||
|
"out_of_range": "{field}必须在{min}和{max}之间",
|
||||||
|
"already_exists": "{field}已存在",
|
||||||
|
"not_found": "{field}不存在",
|
||||||
|
"invalid_value": "{field}的值无效",
|
||||||
|
"password_mismatch": "两次输入的密码不一致",
|
||||||
|
"weak_password": "密码强度不够,请使用更复杂的密码",
|
||||||
|
"invalid_credentials": "用户名或密码错误",
|
||||||
|
"unauthorized": "未授权访问",
|
||||||
|
"forbidden": "没有权限执行此操作",
|
||||||
|
"expired": "{field}已过期",
|
||||||
|
"invalid_token": "无效的令牌",
|
||||||
|
"file_too_large": "文件大小不能超过{max}",
|
||||||
|
"invalid_file_type": "不支持的文件类型",
|
||||||
|
"duplicate": "重复的{field}"
|
||||||
|
},
|
||||||
|
"status": {
|
||||||
|
"active": "活跃",
|
||||||
|
"inactive": "未激活",
|
||||||
|
"pending": "待处理",
|
||||||
|
"processing": "处理中",
|
||||||
|
"completed": "已完成",
|
||||||
|
"failed": "失败",
|
||||||
|
"cancelled": "已取消",
|
||||||
|
"archived": "已归档",
|
||||||
|
"deleted": "已删除",
|
||||||
|
"draft": "草稿",
|
||||||
|
"published": "已发布",
|
||||||
|
"suspended": "已暂停",
|
||||||
|
"expired": "已过期"
|
||||||
|
},
|
||||||
|
"messages": {
|
||||||
|
"loading": "加载中...",
|
||||||
|
"saving": "保存中...",
|
||||||
|
"processing": "处理中...",
|
||||||
|
"uploading": "上传中...",
|
||||||
|
"downloading": "下载中...",
|
||||||
|
"no_data": "暂无数据",
|
||||||
|
"no_results": "没有找到结果",
|
||||||
|
"confirm_delete": "确定要删除吗?此操作不可恢复。",
|
||||||
|
"confirm_action": "确定要执行此操作吗?",
|
||||||
|
"operation_success": "操作成功",
|
||||||
|
"operation_failed": "操作失败",
|
||||||
|
"please_wait": "请稍候...",
|
||||||
|
"try_again": "请重试",
|
||||||
|
"contact_support": "如果问题持续,请联系技术支持"
|
||||||
|
},
|
||||||
|
"pagination": {
|
||||||
|
"page": "第{page}页",
|
||||||
|
"of": "共{total}页",
|
||||||
|
"items": "共{total}条",
|
||||||
|
"per_page": "每页{count}条",
|
||||||
|
"showing": "显示第{from}到第{to}条,共{total}条",
|
||||||
|
"first": "首页",
|
||||||
|
"last": "末页",
|
||||||
|
"next": "下一页",
|
||||||
|
"previous": "上一页"
|
||||||
|
},
|
||||||
|
"time": {
|
||||||
|
"just_now": "刚刚",
|
||||||
|
"minutes_ago": "{count}分钟前",
|
||||||
|
"hours_ago": "{count}小时前",
|
||||||
|
"days_ago": "{count}天前",
|
||||||
|
"weeks_ago": "{count}周前",
|
||||||
|
"months_ago": "{count}个月前",
|
||||||
|
"years_ago": "{count}年前",
|
||||||
|
"today": "今天",
|
||||||
|
"yesterday": "昨天",
|
||||||
|
"tomorrow": "明天"
|
||||||
|
}
|
||||||
|
}
|
||||||
132
api/app/locales/zh/enums.json
Normal file
132
api/app/locales/zh/enums.json
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
{
|
||||||
|
"workspace_role": {
|
||||||
|
"owner": "所有者",
|
||||||
|
"manager": "管理员",
|
||||||
|
"member": "成员",
|
||||||
|
"guest": "访客"
|
||||||
|
},
|
||||||
|
"workspace_status": {
|
||||||
|
"active": "活跃",
|
||||||
|
"inactive": "未激活",
|
||||||
|
"archived": "已归档",
|
||||||
|
"suspended": "已暂停",
|
||||||
|
"deleted": "已删除"
|
||||||
|
},
|
||||||
|
"invite_status": {
|
||||||
|
"pending": "待处理",
|
||||||
|
"accepted": "已接受",
|
||||||
|
"rejected": "已拒绝",
|
||||||
|
"revoked": "已撤销",
|
||||||
|
"expired": "已过期"
|
||||||
|
},
|
||||||
|
"user_status": {
|
||||||
|
"active": "活跃",
|
||||||
|
"inactive": "未激活",
|
||||||
|
"suspended": "已暂停",
|
||||||
|
"deleted": "已删除",
|
||||||
|
"pending": "待激活"
|
||||||
|
},
|
||||||
|
"tenant_status": {
|
||||||
|
"active": "活跃",
|
||||||
|
"inactive": "未激活",
|
||||||
|
"suspended": "已暂停",
|
||||||
|
"expired": "已过期",
|
||||||
|
"trial": "试用中"
|
||||||
|
},
|
||||||
|
"file_status": {
|
||||||
|
"uploading": "上传中",
|
||||||
|
"processing": "处理中",
|
||||||
|
"completed": "已完成",
|
||||||
|
"failed": "失败",
|
||||||
|
"deleted": "已删除"
|
||||||
|
},
|
||||||
|
"task_status": {
|
||||||
|
"pending": "待处理",
|
||||||
|
"running": "运行中",
|
||||||
|
"completed": "已完成",
|
||||||
|
"failed": "失败",
|
||||||
|
"cancelled": "已取消",
|
||||||
|
"paused": "已暂停"
|
||||||
|
},
|
||||||
|
"priority": {
|
||||||
|
"low": "低",
|
||||||
|
"medium": "中",
|
||||||
|
"high": "高",
|
||||||
|
"urgent": "紧急"
|
||||||
|
},
|
||||||
|
"visibility": {
|
||||||
|
"public": "公开",
|
||||||
|
"private": "私有",
|
||||||
|
"internal": "内部",
|
||||||
|
"shared": "共享"
|
||||||
|
},
|
||||||
|
"permission": {
|
||||||
|
"read": "读取",
|
||||||
|
"write": "写入",
|
||||||
|
"delete": "删除",
|
||||||
|
"admin": "管理",
|
||||||
|
"owner": "所有者"
|
||||||
|
},
|
||||||
|
"notification_type": {
|
||||||
|
"info": "信息",
|
||||||
|
"warning": "警告",
|
||||||
|
"error": "错误",
|
||||||
|
"success": "成功"
|
||||||
|
},
|
||||||
|
"language": {
|
||||||
|
"zh": "中文(简体)",
|
||||||
|
"en": "English",
|
||||||
|
"ja": "日本語",
|
||||||
|
"ko": "한국어",
|
||||||
|
"fr": "Français",
|
||||||
|
"de": "Deutsch",
|
||||||
|
"es": "Español"
|
||||||
|
},
|
||||||
|
"timezone": {
|
||||||
|
"utc": "UTC",
|
||||||
|
"asia_shanghai": "亚洲/上海",
|
||||||
|
"asia_tokyo": "亚洲/东京",
|
||||||
|
"america_new_york": "美洲/纽约",
|
||||||
|
"europe_london": "欧洲/伦敦"
|
||||||
|
},
|
||||||
|
"date_format": {
|
||||||
|
"short": "短日期",
|
||||||
|
"medium": "中等日期",
|
||||||
|
"long": "长日期",
|
||||||
|
"full": "完整日期"
|
||||||
|
},
|
||||||
|
"sort_order": {
|
||||||
|
"asc": "升序",
|
||||||
|
"desc": "降序"
|
||||||
|
},
|
||||||
|
"filter_operator": {
|
||||||
|
"equals": "等于",
|
||||||
|
"not_equals": "不等于",
|
||||||
|
"contains": "包含",
|
||||||
|
"not_contains": "不包含",
|
||||||
|
"starts_with": "开始于",
|
||||||
|
"ends_with": "结束于",
|
||||||
|
"greater_than": "大于",
|
||||||
|
"less_than": "小于",
|
||||||
|
"greater_or_equal": "大于等于",
|
||||||
|
"less_or_equal": "小于等于",
|
||||||
|
"in": "在列表中",
|
||||||
|
"not_in": "不在列表中",
|
||||||
|
"is_null": "为空",
|
||||||
|
"is_not_null": "不为空"
|
||||||
|
},
|
||||||
|
"log_level": {
|
||||||
|
"debug": "调试",
|
||||||
|
"info": "信息",
|
||||||
|
"warning": "警告",
|
||||||
|
"error": "错误",
|
||||||
|
"critical": "严重"
|
||||||
|
},
|
||||||
|
"api_method": {
|
||||||
|
"get": "GET",
|
||||||
|
"post": "POST",
|
||||||
|
"put": "PUT",
|
||||||
|
"patch": "PATCH",
|
||||||
|
"delete": "DELETE"
|
||||||
|
}
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user