Compare commits
341 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3f87c64e83 | ||
|
|
1795364f5f | ||
|
|
e69fbb2f97 | ||
|
|
32b40fc6bf | ||
|
|
f039ea7f56 | ||
|
|
41334f5f1e | ||
|
|
2103410694 | ||
|
|
2143d94e83 | ||
|
|
9ae2612945 | ||
|
|
e381449aec | ||
|
|
bacffc94d9 | ||
|
|
7044f705e7 | ||
|
|
6db4fe28a7 | ||
|
|
f966176694 | ||
|
|
bd24de4577 | ||
|
|
f6ad0aab94 | ||
|
|
371fdeb948 | ||
|
|
f7a0af75c4 | ||
|
|
26abf7b586 | ||
|
|
3ca3e8e023 | ||
|
|
3bd374495b | ||
|
|
b26f60ee8d | ||
|
|
df681eaf22 | ||
|
|
01458ac111 | ||
|
|
e3074b833f | ||
|
|
1097d699f8 | ||
|
|
55b4e0ebd3 | ||
|
|
0011a8ce9f | ||
|
|
100bf4fa49 | ||
|
|
6da5b81311 | ||
|
|
787adf5423 | ||
|
|
01b500e7d1 | ||
|
|
e64603ea27 | ||
|
|
4219e12cc0 | ||
|
|
c86ccf0931 | ||
|
|
d4571fb75b | ||
|
|
ec2369c397 | ||
|
|
6ebd48408b | ||
|
|
7e7b54593c | ||
|
|
f93c9f5cd2 | ||
|
|
a810fbe008 | ||
|
|
600a914bd9 | ||
|
|
b1688950c4 | ||
|
|
d8e3f9b7b8 | ||
|
|
08d55e4463 | ||
|
|
55e2baa865 | ||
|
|
55174dc707 | ||
|
|
d57e3b3f64 | ||
|
|
aa42cd0aec | ||
|
|
ac6d9a39ec | ||
|
|
9b07775395 | ||
|
|
936fb8b8a1 | ||
|
|
6c8318b696 | ||
|
|
d554079e2b | ||
|
|
37464a101e | ||
|
|
8326db1143 | ||
|
|
992e41e0a0 | ||
|
|
076e95d5c2 | ||
|
|
dfd79e5972 | ||
|
|
b16c9d53ef | ||
|
|
5fe85fb457 | ||
|
|
b45f470310 | ||
|
|
0ecda33ab8 | ||
|
|
7fcfca455a | ||
|
|
6a32154b8f | ||
|
|
132206677f | ||
|
|
30a8775548 | ||
|
|
045bc9aefc | ||
|
|
d5c46574cc | ||
|
|
37fea09403 | ||
|
|
063e8fae43 | ||
|
|
184c4fbf7f | ||
|
|
ea96830758 | ||
|
|
d2edbc738d | ||
|
|
03bc8c8280 | ||
|
|
68908213da | ||
|
|
b3d5add89a | ||
|
|
7fe2d8fbe1 | ||
|
|
bca03f1365 | ||
|
|
c89f55f0bd | ||
|
|
dcdc899528 | ||
|
|
b57aa55001 | ||
|
|
af596a09cf | ||
|
|
6849c620b8 | ||
|
|
12598f0dca | ||
|
|
3f4ce4f16f | ||
|
|
4aaf0d8d5c | ||
|
|
65db056e09 | ||
|
|
232cef7cb9 | ||
|
|
73a432879a | ||
|
|
09afec17f9 | ||
|
|
ac47ab3deb | ||
|
|
8b3d7c168a | ||
|
|
60e8eb63ac | ||
|
|
4f29cd24b8 | ||
|
|
ba73ade2a0 | ||
|
|
7559305fc9 | ||
|
|
6985f553f9 | ||
|
|
8fc15df6d0 | ||
|
|
eb8160a5af | ||
|
|
16cf6eee9b | ||
|
|
320f684354 | ||
|
|
12062a5440 | ||
|
|
4423a9d979 | ||
|
|
1eb44defb6 | ||
|
|
e253fba2e9 | ||
|
|
c05d95924f | ||
|
|
2db583d62d | ||
|
|
59d8e1bf9f | ||
|
|
1001344c27 | ||
|
|
8a0e2da03f | ||
|
|
f58886be6f | ||
|
|
3c1d3b4d6a | ||
|
|
bbba995ff7 | ||
|
|
0033b5be80 | ||
|
|
87d53fb9b7 | ||
|
|
157031f23e | ||
|
|
8a37869489 | ||
|
|
5c10f11681 | ||
|
|
7b72bf0cd0 | ||
|
|
be29666916 | ||
|
|
8d4c5b5b33 | ||
|
|
52260f469a | ||
|
|
c566d22836 | ||
|
|
75f59a86c8 | ||
|
|
1eaf12446f | ||
|
|
efdd42426e | ||
|
|
62c557deae | ||
|
|
db1da4a61a | ||
|
|
db46c186aa | ||
|
|
677a603835 | ||
|
|
447d8790ad | ||
|
|
7a78f15a90 | ||
|
|
c1941809e9 | ||
|
|
623aaf8a0e | ||
|
|
7b3bf41120 | ||
|
|
0c3960eb0b | ||
|
|
fe3c31c08c | ||
|
|
94600cdbfc | ||
|
|
4e7ab3d7e3 | ||
|
|
47b25d7a26 | ||
|
|
0249666fa4 | ||
|
|
2e8504ce2f | ||
|
|
aca7d25001 | ||
|
|
2444309bc2 | ||
|
|
97c5a78d48 | ||
|
|
effdb88455 | ||
|
|
2f0ce3852e | ||
|
|
5475496399 | ||
|
|
b569d77a23 | ||
|
|
dfa7a2d4cf | ||
|
|
169e01276d | ||
|
|
07e698265e | ||
|
|
0632d7611f | ||
|
|
b3f39eedac | ||
|
|
46ed7e38bf | ||
|
|
8c5199d32d | ||
|
|
36ed833d64 | ||
|
|
47969ce61e | ||
|
|
06731e2026 | ||
|
|
123347169d | ||
|
|
f9101a744c | ||
|
|
97eb33000f | ||
|
|
60231ec88d | ||
|
|
3364374dc6 | ||
|
|
a3cf773e75 | ||
|
|
4092d5fbaf | ||
|
|
07e9fde9e8 | ||
|
|
9b4613630b | ||
|
|
f125d11b6d | ||
|
|
657d48a5f9 | ||
|
|
3735bdde19 | ||
|
|
3f906d81cb | ||
|
|
7c1f622797 | ||
|
|
cfe696ae8d | ||
|
|
021c50a8f2 | ||
|
|
95745ba869 | ||
|
|
adfae54816 | ||
|
|
10ed093eb8 | ||
|
|
c96df6bfa5 | ||
|
|
0126d18525 | ||
|
|
9e6e8f50f8 | ||
|
|
7e0b31626f | ||
|
|
1d9e249a77 | ||
|
|
88b89ef315 | ||
|
|
62b7925cb0 | ||
|
|
cc1528f550 | ||
|
|
1c8a83140b | ||
|
|
34276e2066 | ||
|
|
71abd16ae7 | ||
|
|
918e7285c4 | ||
|
|
056d422c71 | ||
|
|
5ee54f4e0e | ||
|
|
260c75e70c | ||
|
|
2d7401922f | ||
|
|
8c7a1348cf | ||
|
|
24fbdbd716 | ||
|
|
aad8f0e36b | ||
|
|
15cad44f08 | ||
|
|
0271454671 | ||
|
|
d0ddf288ca | ||
|
|
bc250ac377 | ||
|
|
7922fc3b0e | ||
|
|
161da723b9 | ||
|
|
514c19a247 | ||
|
|
41550d4a41 | ||
|
|
33cc3c1c3f | ||
|
|
7d15182202 | ||
|
|
8f0a1d9c6e | ||
|
|
72b5e5cf8e | ||
|
|
62aba2dd38 | ||
|
|
cdd6b80089 | ||
|
|
333836f5e7 | ||
|
|
a2dfda3471 | ||
|
|
2d28b4b05c | ||
|
|
87f9bcc6a3 | ||
|
|
48aca996ff | ||
|
|
c8c7e9b304 | ||
|
|
97ff023995 | ||
|
|
e273a336f8 | ||
|
|
34f0c3b90c | ||
|
|
7c2902d2b8 | ||
|
|
8e41afdffc | ||
|
|
7268886294 | ||
|
|
cbae900866 | ||
|
|
ffff138a6f | ||
|
|
88c95db8d0 | ||
|
|
56e657a0bb | ||
|
|
bc36b79105 | ||
|
|
5694bc0230 | ||
|
|
36130031f9 | ||
|
|
b8f1095f53 | ||
|
|
442fa09533 | ||
|
|
42ef2efbc8 | ||
|
|
ead3080b2b | ||
|
|
c6ea31c296 | ||
|
|
21eae29bb7 | ||
|
|
406740b524 | ||
|
|
9d30bc4062 | ||
|
|
fad91b64ab | ||
|
|
2132e71a81 | ||
|
|
bd8a451879 | ||
|
|
24dafa7359 | ||
|
|
3b5df793fb | ||
|
|
da835b6138 | ||
|
|
7e650d86a5 | ||
|
|
308e28cecc | ||
|
|
9a3c74fb64 | ||
|
|
f571f0688a | ||
|
|
1e9c32a102 | ||
|
|
8c69199689 | ||
|
|
3efb3e8a35 | ||
|
|
cfcb278406 | ||
|
|
9e195ea63b | ||
|
|
dc0d34c281 | ||
|
|
72076c218f | ||
|
|
151fd3b950 | ||
|
|
2d484fcb30 | ||
|
|
6e0407f404 | ||
|
|
8670aaba1e | ||
|
|
f27de7df35 | ||
|
|
63fa4dc8ec | ||
|
|
a191e32f71 | ||
|
|
9a38e8a4a0 | ||
|
|
6194222289 | ||
|
|
0d077eaeb7 | ||
|
|
b2c7a9a005 | ||
|
|
be01f1869e | ||
|
|
9f2b6390b0 | ||
|
|
e196f86e30 | ||
|
|
ec41d45234 | ||
|
|
567d1ba18b | ||
|
|
df8706983b | ||
|
|
8697498b32 | ||
|
|
af917c538a | ||
|
|
034e97dfa6 | ||
|
|
5e1e5f68e1 | ||
|
|
fb76f765cc | ||
|
|
7a3f57261d | ||
|
|
a1a460625d | ||
|
|
3f42ea2c61 | ||
|
|
940c594066 | ||
|
|
5e47fc45ab | ||
|
|
b471d56a86 | ||
|
|
61f8029205 | ||
|
|
e2f047d035 | ||
|
|
1aff4eda67 | ||
|
|
a6c5c44ed8 | ||
|
|
3f389d685a | ||
|
|
5d5351f0bc | ||
|
|
1224802ac6 | ||
|
|
e919f89caf | ||
|
|
bb8e7a68ea | ||
|
|
48f95e0ea4 | ||
|
|
931e9bcf0d | ||
|
|
67a3351c4c | ||
|
|
dfe5eeed7b | ||
|
|
3464573f17 | ||
|
|
9cf49c9c75 | ||
|
|
4e837cb90c | ||
|
|
e4fb58496b | ||
|
|
15a254c0cd | ||
|
|
d62746fc8c | ||
|
|
4b8b6fe407 | ||
|
|
6754834eb3 | ||
|
|
be98db561d | ||
|
|
574d0afc72 | ||
|
|
31c8ad611c | ||
|
|
b23730388d | ||
|
|
36cb0a12ad | ||
|
|
5439eacf2d | ||
|
|
2687c3b80e | ||
|
|
fa009327ad | ||
|
|
838bd46e83 | ||
|
|
ccc2009aa8 | ||
|
|
d9aba92314 | ||
|
|
696b0475a8 | ||
|
|
e7370489e8 | ||
|
|
f1503b2238 | ||
|
|
cd4661e878 | ||
|
|
22151eb49b | ||
|
|
d0354345f6 | ||
|
|
b1e61eb1e4 | ||
|
|
36e0ed15b6 | ||
|
|
504d87b0b0 | ||
|
|
cfb7a40841 | ||
|
|
8267761890 | ||
|
|
a01911ba5f | ||
|
|
7347f9104c | ||
|
|
9206c7642a | ||
|
|
d1b4f2b6c2 | ||
|
|
cca3900678 | ||
|
|
4fe32b7dbc | ||
|
|
42b59a644d | ||
|
|
d9fa9039bb | ||
|
|
f3da8956d9 | ||
|
|
b1147d77af | ||
|
|
66bc2fb41f | ||
|
|
4e538a6df8 | ||
|
|
9c3e0b5541 | ||
|
|
33bfe33eb3 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -21,6 +21,7 @@ examples/
|
||||
|
||||
# Temporary outputs
|
||||
.DS_Store
|
||||
.hypothesis/
|
||||
time.log
|
||||
celerybeat-schedule.db
|
||||
search_results.json
|
||||
|
||||
28618
api/General_purpose_entity.ttl
Normal file
28618
api/General_purpose_entity.ttl
Normal file
File diff suppressed because it is too large
Load Diff
@@ -3,9 +3,14 @@ import platform
|
||||
from datetime import timedelta
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.config import settings
|
||||
from celery import Celery
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# macOS fork() safety - must be set before any Celery initialization
|
||||
if platform.system() == 'Darwin':
|
||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||
|
||||
# 创建 Celery 应用实例
|
||||
# broker: 任务队列(使用 Redis DB 0)
|
||||
# backend: 结果存储(使用 Redis DB 10)
|
||||
@@ -63,15 +68,21 @@ celery_app.conf.update(
|
||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||
|
||||
# Long-term storage tasks → memory_tasks queue (batched write strategies)
|
||||
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
|
||||
|
||||
# Document tasks → document_tasks queue (prefork worker)
|
||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
|
||||
|
||||
# Beat/periodic tasks → document_tasks queue (prefork worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'},
|
||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'},
|
||||
'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'},
|
||||
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
|
||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
|
||||
'app.controllers.memory_storage_controller.search_all': {'queue': 'periodic_tasks'},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -81,10 +92,11 @@ celery_app.autodiscover_tasks(['app'])
|
||||
# Celery Beat schedule for periodic tasks
|
||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||
# 这个30秒的设计不合理
|
||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
||||
|
||||
# 构建定时任务配置
|
||||
#构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
"run-workspace-reflection": {
|
||||
"task": "app.tasks.workspace_reflection_task",
|
||||
@@ -105,7 +117,7 @@ beat_schedule_config = {
|
||||
},
|
||||
}
|
||||
|
||||
# 如果配置了默认工作空间ID,则添加记忆总量统计任务
|
||||
#如果配置了默认工作空间ID,则添加记忆总量统计任务
|
||||
if settings.DEFAULT_WORKSPACE_ID:
|
||||
beat_schedule_config["write-total-memory"] = {
|
||||
"task": "app.controllers.memory_storage_controller.search_all",
|
||||
|
||||
@@ -24,9 +24,11 @@ from . import (
|
||||
memory_episodic_controller,
|
||||
memory_explicit_controller,
|
||||
memory_forget_controller,
|
||||
memory_perceptual_controller,
|
||||
memory_reflection_controller,
|
||||
memory_short_term_controller,
|
||||
memory_storage_controller,
|
||||
memory_working_controller,
|
||||
model_controller,
|
||||
multi_agent_controller,
|
||||
prompt_optimizer_controller,
|
||||
@@ -39,12 +41,9 @@ from . import (
|
||||
upload_controller,
|
||||
user_controller,
|
||||
user_memory_controllers,
|
||||
workflow_controller,
|
||||
workspace_controller,
|
||||
memory_forget_controller,
|
||||
home_page_controller,
|
||||
memory_perceptual_controller,
|
||||
memory_working_controller,
|
||||
ontology_controller,
|
||||
skill_controller
|
||||
)
|
||||
|
||||
# 创建管理端 API 路由器
|
||||
@@ -77,7 +76,6 @@ manager_router.include_router(release_share_controller.router)
|
||||
manager_router.include_router(public_share_controller.router) # 公开路由(无需认证)
|
||||
manager_router.include_router(memory_dashboard_controller.router)
|
||||
manager_router.include_router(multi_agent_controller.router)
|
||||
manager_router.include_router(workflow_controller.router)
|
||||
manager_router.include_router(emotion_controller.router)
|
||||
manager_router.include_router(emotion_config_controller.router)
|
||||
manager_router.include_router(prompt_optimizer_controller.router)
|
||||
@@ -90,5 +88,7 @@ manager_router.include_router(implicit_memory_controller.router)
|
||||
manager_router.include_router(memory_perceptual_controller.router)
|
||||
manager_router.include_router(memory_working_controller.router)
|
||||
manager_router.include_router(file_storage_controller.router)
|
||||
manager_router.include_router(ontology_controller.router)
|
||||
manager_router.include_router(skill_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -22,6 +22,7 @@ from app.services import app_service, workspace_service
|
||||
from app.services.agent_config_helper import enrich_agent_config
|
||||
from app.services.app_service import AppService
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["Apps"])
|
||||
logger = get_business_logger()
|
||||
@@ -454,7 +455,8 @@ async def draft_run(
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -475,7 +477,8 @@ async def draft_run(
|
||||
"app_id": str(app_id),
|
||||
"message_length": len(payload.message),
|
||||
"has_conversation_id": bool(payload.conversation_id),
|
||||
"has_variables": bool(payload.variables)
|
||||
"has_variables": bool(payload.variables),
|
||||
"has_files": bool(payload.files)
|
||||
}
|
||||
)
|
||||
|
||||
@@ -490,7 +493,8 @@ async def draft_run(
|
||||
user_id=payload.user_id or str(current_user.id),
|
||||
variables=payload.variables,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
@@ -798,7 +802,8 @@ async def draft_run_compare(
|
||||
web_search=True,
|
||||
memory=True,
|
||||
parallel=payload.parallel,
|
||||
timeout=payload.timeout or 60
|
||||
timeout=payload.timeout or 60,
|
||||
files=payload.files
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -901,15 +906,46 @@ def get_app_statistics(
|
||||
- total_tokens: 总token消耗
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
stats_service = AppStatisticsService(db)
|
||||
|
||||
|
||||
result = stats_service.get_app_statistics(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
|
||||
return success(data=result)
|
||||
|
||||
|
||||
@router.get("/workspace/api-statistics", summary="工作空间API调用统计")
|
||||
@cur_workspace_access_guard()
|
||||
def get_workspace_api_statistics(
|
||||
start_date: int,
|
||||
end_date: int,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""获取工作空间API调用统计
|
||||
|
||||
Args:
|
||||
start_date: 开始时间戳(毫秒)
|
||||
end_date: 结束时间戳(毫秒)
|
||||
|
||||
Returns:
|
||||
每日统计数据列表,每项包含:
|
||||
- date: 日期
|
||||
- total_calls: 当日总调用次数
|
||||
- app_calls: 当日应用调用次数
|
||||
- service_calls: 当日服务调用次数
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
stats_service = AppStatisticsService(db)
|
||||
|
||||
result = stats_service.get_workspace_api_statistics(
|
||||
workspace_id=workspace_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
return success(data=result)
|
||||
|
||||
@@ -11,6 +11,7 @@ Routes:
|
||||
"""
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.dependencies import get_current_user, get_db
|
||||
@@ -45,11 +46,14 @@ emotion_service = EmotionAnalyticsService()
|
||||
@router.post("/tags", response_model=ApiResponse)
|
||||
async def get_emotion_tags(
|
||||
request: EmotionTagsRequest,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
|
||||
try:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪标签统计",
|
||||
extra={
|
||||
@@ -57,7 +61,8 @@ async def get_emotion_tags(
|
||||
"emotion_type": request.emotion_type,
|
||||
"start_date": request.start_date,
|
||||
"end_date": request.end_date,
|
||||
"limit": request.limit
|
||||
"limit": request.limit,
|
||||
"language_type": language
|
||||
}
|
||||
)
|
||||
|
||||
@@ -67,7 +72,8 @@ async def get_emotion_tags(
|
||||
emotion_type=request.emotion_type,
|
||||
start_date=request.start_date,
|
||||
end_date=request.end_date,
|
||||
limit=request.limit
|
||||
limit=request.limit,
|
||||
language=language
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
@@ -97,11 +103,14 @@ async def get_emotion_tags(
|
||||
@router.post("/wordcloud", response_model=ApiResponse)
|
||||
async def get_emotion_wordcloud(
|
||||
request: EmotionWordcloudRequest,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
|
||||
try:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪词云数据",
|
||||
extra={
|
||||
@@ -144,11 +153,14 @@ async def get_emotion_wordcloud(
|
||||
@router.post("/health", response_model=ApiResponse)
|
||||
async def get_emotion_health(
|
||||
request: EmotionHealthRequest,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
|
||||
try:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
# 验证时间范围参数
|
||||
if request.time_range not in ["7d", "30d", "90d"]:
|
||||
raise HTTPException(
|
||||
@@ -174,7 +186,7 @@ async def get_emotion_health(
|
||||
"情绪健康指数获取成功",
|
||||
extra={
|
||||
"end_user_id": request.end_user_id,
|
||||
"health_score": data.get("health_score", 0),
|
||||
"health_score": data.get("health_score") or 0,
|
||||
"level": data.get("level", "未知")
|
||||
}
|
||||
)
|
||||
@@ -199,7 +211,7 @@ async def get_emotion_health(
|
||||
@router.post("/suggestions", response_model=ApiResponse)
|
||||
async def get_emotion_suggestions(
|
||||
request: EmotionSuggestionsRequest,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
@@ -214,6 +226,9 @@ async def get_emotion_suggestions(
|
||||
缓存的个性化情绪建议响应
|
||||
"""
|
||||
try:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
|
||||
extra={
|
||||
@@ -229,16 +244,46 @@ async def get_emotion_suggestions(
|
||||
)
|
||||
|
||||
if data is None:
|
||||
# 缓存不存在或已过期
|
||||
# 缓存不存在或已过期,自动触发生成
|
||||
api_logger.info(
|
||||
f"用户 {request.end_user_id} 的建议缓存不存在或已过期",
|
||||
f"用户 {request.end_user_id} 的建议缓存不存在或已过期,自动生成新建议",
|
||||
extra={"end_user_id": request.end_user_id}
|
||||
)
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
"建议缓存不存在或已过期,请右上角刷新生成新建议",
|
||||
""
|
||||
)
|
||||
try:
|
||||
data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=request.end_user_id,
|
||||
db=db,
|
||||
language=language
|
||||
)
|
||||
# 保存到缓存
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=request.end_user_id,
|
||||
suggestions_data=data,
|
||||
db=db,
|
||||
expires_hours=24
|
||||
)
|
||||
except (ValueError, KeyError) as gen_e:
|
||||
# 预期内的业务异常:配置缺失、数据格式问题等
|
||||
api_logger.warning(
|
||||
f"自动生成建议失败(业务异常): {str(gen_e)}",
|
||||
extra={"end_user_id": request.end_user_id}
|
||||
)
|
||||
return fail(
|
||||
BizCode.NOT_FOUND,
|
||||
f"自动生成建议失败: {str(gen_e)}",
|
||||
""
|
||||
)
|
||||
except Exception as gen_e:
|
||||
# 非预期异常:记录完整 traceback 便于排查
|
||||
api_logger.error(
|
||||
f"自动生成建议时发生未预期异常: {str(gen_e)}",
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"生成建议时发生内部错误: {str(gen_e)}"
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
"个性化建议获取成功(缓存)",
|
||||
@@ -265,7 +310,7 @@ async def get_emotion_suggestions(
|
||||
@router.post("/generate_suggestions", response_model=ApiResponse)
|
||||
async def generate_emotion_suggestions(
|
||||
request: EmotionGenerateSuggestionsRequest,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
@@ -280,6 +325,9 @@ async def generate_emotion_suggestions(
|
||||
新生成的个性化情绪建议响应
|
||||
"""
|
||||
try:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求生成个性化情绪建议",
|
||||
extra={
|
||||
@@ -290,7 +338,8 @@ async def generate_emotion_suggestions(
|
||||
# 调用服务层生成建议
|
||||
data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=request.end_user_id,
|
||||
db=db
|
||||
db=db,
|
||||
language=language
|
||||
)
|
||||
|
||||
# 保存到缓存
|
||||
|
||||
@@ -29,7 +29,7 @@ from app.core.storage_exceptions import (
|
||||
StorageUploadError,
|
||||
)
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.dependencies import get_current_user, get_share_user_id, ShareTokenData
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
@@ -143,6 +143,141 @@ async def upload_file(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/share/files", response_model=ApiResponse)
|
||||
async def upload_file_with_share_token(
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
):
|
||||
"""
|
||||
Upload a file to the configured storage backend using share_token authentication.
|
||||
"""
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.models.app_model import App
|
||||
from app.models.workspace_model import Workspace
|
||||
|
||||
# Get share and release info from share_token
|
||||
service = ReleaseShareService(db)
|
||||
share_info = service.get_shared_release_info(share_token=share_data.share_token)
|
||||
|
||||
# Get share object to access app_id
|
||||
share = service.repo.get_by_share_token(share_data.share_token)
|
||||
if not share:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Shared app not found"
|
||||
)
|
||||
|
||||
# Get app to access workspace_id
|
||||
app = db.query(App).filter(
|
||||
App.id == share.app_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="App not found"
|
||||
)
|
||||
|
||||
# Get workspace to access tenant_id
|
||||
workspace = db.query(Workspace).filter(
|
||||
Workspace.id == app.workspace_id
|
||||
).first()
|
||||
|
||||
if not workspace:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Workspace not found"
|
||||
)
|
||||
|
||||
tenant_id = workspace.tenant_id
|
||||
workspace_id = app.workspace_id
|
||||
|
||||
api_logger.info(
|
||||
f"Storage upload request (share): tenant_id={tenant_id}, workspace_id={workspace_id}, "
|
||||
f"filename={file.filename}, share_token={share_data.share_token}"
|
||||
)
|
||||
|
||||
# Read file contents
|
||||
contents = await file.read()
|
||||
file_size = len(contents)
|
||||
|
||||
# Validate file size
|
||||
if file_size == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The file is empty."
|
||||
)
|
||||
|
||||
if file_size > settings.MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
||||
)
|
||||
|
||||
# Extract file extension
|
||||
_, file_extension = os.path.splitext(file.filename)
|
||||
file_ext = file_extension.lower()
|
||||
|
||||
# Generate file_id and file_key
|
||||
file_id = uuid.uuid4()
|
||||
file_key = generate_file_key(
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_id=file_id,
|
||||
file_ext=file_ext,
|
||||
)
|
||||
|
||||
# Create file metadata record with pending status
|
||||
file_metadata = FileMetadata(
|
||||
id=file_id,
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_key=file_key,
|
||||
file_name=file.filename,
|
||||
file_ext=file_ext,
|
||||
file_size=file_size,
|
||||
content_type=file.content_type,
|
||||
status="pending",
|
||||
)
|
||||
db.add(file_metadata)
|
||||
db.commit()
|
||||
db.refresh(file_metadata)
|
||||
|
||||
# Upload file to storage backend
|
||||
try:
|
||||
await storage_service.upload_file(
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_id=file_id,
|
||||
file_ext=file_ext,
|
||||
content=contents,
|
||||
content_type=file.content_type,
|
||||
)
|
||||
# Update status to completed
|
||||
file_metadata.status = "completed"
|
||||
db.commit()
|
||||
api_logger.info(f"File uploaded to storage (share): file_key={file_key}")
|
||||
except StorageUploadError as e:
|
||||
# Update status to failed
|
||||
file_metadata.status = "failed"
|
||||
db.commit()
|
||||
api_logger.error(f"Storage upload failed (share): {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"File storage failed: {str(e)}"
|
||||
)
|
||||
|
||||
api_logger.info(f"File upload successful (share): {file.filename} (file_id: {file_id})")
|
||||
|
||||
return success(
|
||||
data={"file_id": str(file_id), "file_key": file_key},
|
||||
msg="File upload successful"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/{file_id}", response_model=Any)
|
||||
async def download_file(
|
||||
file_id: uuid.UUID,
|
||||
|
||||
@@ -9,13 +9,16 @@ from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.rag.common import settings
|
||||
from app.core.rag.integrations.feishu.client import FeishuAPIClient
|
||||
from app.core.rag.integrations.yuque.client import YuqueAPIClient
|
||||
from app.core.rag.llm.chat_model import Base
|
||||
from app.core.rag.nlp import rag_tokenizer, search
|
||||
from app.core.rag.prompts.generator import graph_entity_types
|
||||
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
|
||||
from app.core.response_utils import success
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import knowledge_model
|
||||
@@ -484,3 +487,99 @@ async def rebuild_knowledge_graph(
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to rebuild knowledge graph: knowledge_id={knowledge_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/check/yuque/auth", response_model=ApiResponse)
|
||||
async def check_yuque_auth(
|
||||
yuque_user_id: str,
|
||||
yuque_token: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
check yuque auth info
|
||||
"""
|
||||
api_logger.info(f"check yuque auth info, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_client = YuqueAPIClient(
|
||||
user_id=yuque_user_id,
|
||||
token=yuque_token
|
||||
)
|
||||
async with api_client as client:
|
||||
repos = await client.get_user_repos()
|
||||
if repos:
|
||||
return success(msg="Successfully auth yuque info")
|
||||
return fail(BizCode.UNAUTHORIZED, msg="auth yuque info failed", error="user_id or token is incorrect")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"auth yuque info failed: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.get("/check/feishu/auth", response_model=ApiResponse)
|
||||
async def check_feishu_auth(
|
||||
feishu_app_id: str,
|
||||
feishu_app_secret: str,
|
||||
feishu_folder_token: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
check feishu auth info
|
||||
"""
|
||||
api_logger.info(f"check feishu auth info, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
api_client = FeishuAPIClient(
|
||||
app_id=feishu_app_id,
|
||||
app_secret=feishu_app_secret
|
||||
)
|
||||
async with api_client as client:
|
||||
files = await client.list_all_folder_files(feishu_folder_token, recursive=True)
|
||||
if files:
|
||||
return success(msg="Successfully auth feishu info")
|
||||
return fail(BizCode.UNAUTHORIZED, msg="auth feishu info failed", error="app_id or app_secret or feishu_folder_token is incorrect")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"auth feishu info failed: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/{knowledge_id}/sync", response_model=ApiResponse)
|
||||
async def sync_knowledge(
|
||||
knowledge_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
sync knowledge base information based on knowledge_id
|
||||
"""
|
||||
api_logger.info(f"Obtain details of the knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}")
|
||||
|
||||
try:
|
||||
# 1. Query knowledge base information from the database
|
||||
api_logger.debug(f"Query knowledge base: {knowledge_id}")
|
||||
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
|
||||
if not db_knowledge:
|
||||
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={knowledge_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The knowledge base does not exist or access is denied"
|
||||
)
|
||||
|
||||
# 2. sync knowledge
|
||||
# from app.tasks import sync_knowledge_for_kb
|
||||
# sync_knowledge_for_kb(kb_id)
|
||||
task = celery_app.send_task("app.core.rag.tasks.sync_knowledge_for_kb", args=[knowledge_id])
|
||||
result = {
|
||||
"task_id": task.id
|
||||
}
|
||||
return success(data=result, msg="Task accepted. sync knowledge is being processed in the background.")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to sync knowledge: knowledge_id={knowledge_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import List, Optional
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.rag.llm.cv_model import QWenCV
|
||||
from app.core.response_utils import fail, success
|
||||
@@ -118,6 +119,7 @@ async def download_log(
|
||||
@cur_workspace_access_guard()
|
||||
async def write_server(
|
||||
user_input: Write_UserInput,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
@@ -126,13 +128,17 @@ async def write_server(
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and end_user_id
|
||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
|
||||
Returns:
|
||||
Response with write operation status
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
@@ -169,7 +175,8 @@ async def write_server(
|
||||
config_id,
|
||||
db,
|
||||
storage_type,
|
||||
user_rag_memory_id
|
||||
user_rag_memory_id,
|
||||
language
|
||||
)
|
||||
|
||||
return success(data=result, msg="写入成功")
|
||||
@@ -188,6 +195,7 @@ async def write_server(
|
||||
@cur_workspace_access_guard()
|
||||
async def write_server_async(
|
||||
user_input: Write_UserInput,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
@@ -196,14 +204,18 @@ async def write_server_async(
|
||||
|
||||
Args:
|
||||
user_input: Write request containing message and end_user_id
|
||||
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
|
||||
|
||||
Returns:
|
||||
Task ID for tracking async operation
|
||||
Use GET /memory/write_result/{task_id} to check task status and get result
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
config_id = user_input.config_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}")
|
||||
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
|
||||
|
||||
# 获取 storage_type,如果为 None 则使用默认值
|
||||
storage_type = workspace_service.get_workspace_storage_type(
|
||||
@@ -228,7 +240,7 @@ async def write_server_async(
|
||||
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.write_message",
|
||||
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id]
|
||||
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
|
||||
)
|
||||
api_logger.info(f"Write task queued: {task.id}")
|
||||
|
||||
@@ -653,7 +665,6 @@ async def get_knowledge_type_stats_api(
|
||||
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
|
||||
async def get_hot_memory_tags_by_user_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
limit: int = Query(20, description="返回标签数量限制"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session=Depends(get_db),
|
||||
@@ -661,28 +672,18 @@ async def get_hot_memory_tags_by_user_api(
|
||||
"""
|
||||
获取指定用户的热门记忆标签
|
||||
|
||||
注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译
|
||||
|
||||
返回格式:
|
||||
[
|
||||
{"name": "标签名", "frequency": 频次},
|
||||
...
|
||||
]
|
||||
"""
|
||||
|
||||
workspace_id=current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
|
||||
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
|
||||
try:
|
||||
result = await memory_agent_service.get_hot_memory_tags_by_user(
|
||||
end_user_id=end_user_id,
|
||||
language_type=language_type,
|
||||
model_id=model_id,
|
||||
limit=limit
|
||||
)
|
||||
return success(data=result, msg="获取热门记忆标签成功")
|
||||
|
||||
@@ -3,9 +3,10 @@
|
||||
包含情景记忆总览和详情查询接口
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Header
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.dependencies import get_current_user
|
||||
@@ -14,6 +15,7 @@ from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_episodic_schema import (
|
||||
EpisodicMemoryOverviewRequest,
|
||||
EpisodicMemoryDetailsRequest,
|
||||
translate_episodic_type,
|
||||
)
|
||||
from app.services.memory_episodic_service import memory_episodic_service
|
||||
|
||||
@@ -84,6 +86,7 @@ async def get_episodic_memory_overview_api(
|
||||
@router.post("/details", response_model=ApiResponse)
|
||||
async def get_episodic_memory_details_api(
|
||||
request: EpisodicMemoryDetailsRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
@@ -111,6 +114,11 @@ async def get_episodic_memory_details_api(
|
||||
summary_id=request.summary_id
|
||||
)
|
||||
|
||||
# 根据语言参数翻译 episodic_type
|
||||
language = get_language_from_header(language_type)
|
||||
if "episodic_type" in result:
|
||||
result["episodic_type"] = translate_episodic_type(result["episodic_type"], language)
|
||||
|
||||
api_logger.info(
|
||||
f"成功获取情景记忆详情: end_user_id={request.end_user_id}, summary_id={request.summary_id}"
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ import time
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
|
||||
ReflectionConfig,
|
||||
@@ -51,7 +52,6 @@ async def save_reflection_config(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="缺少必需参数: config_id"
|
||||
)
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||
|
||||
memory_config = MemoryConfigRepository.update_reflection_config(
|
||||
@@ -102,51 +102,71 @@ async def start_workspace_reflection(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
"""启动工作空间中所有匹配应用的反思功能"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
reflection_service = MemoryReflectionService(db)
|
||||
|
||||
try:
|
||||
api_logger.info(f"用户 {current_user.username} 启动workspace反思,workspace_id: {workspace_id}")
|
||||
|
||||
service = WorkspaceAppService(db)
|
||||
result = service.get_workspace_apps_detailed(workspace_id)
|
||||
# 使用独立的数据库会话来获取工作空间应用详情,避免事务失败
|
||||
from app.db import get_db_context
|
||||
with get_db_context() as query_db:
|
||||
service = WorkspaceAppService(query_db)
|
||||
result = service.get_workspace_apps_detailed(workspace_id)
|
||||
|
||||
reflection_results = []
|
||||
|
||||
for data in result['apps_detailed_info']:
|
||||
if data['memory_configs'] == []:
|
||||
# 跳过没有配置的应用
|
||||
if not data['memory_configs']:
|
||||
api_logger.debug(f"应用 {data['id']} 没有memory_configs,跳过")
|
||||
continue
|
||||
|
||||
|
||||
releases = data['releases']
|
||||
memory_configs = data['memory_configs']
|
||||
end_users = data['end_users']
|
||||
|
||||
for base, config, user in zip(releases, memory_configs, end_users):
|
||||
# 安全地转换为整数,处理空字符串和None的情况
|
||||
print(base['config'])
|
||||
try:
|
||||
base_config = int(base['config']) if base['config'] else 0
|
||||
config_id = int(config['config_id']) if config['config_id'] else 0
|
||||
except (ValueError, TypeError):
|
||||
api_logger.warning(f"无效的配置ID: base['config']={base.get('config')}, config['config_id']={config.get('config_id')}")
|
||||
|
||||
# 为每个配置和用户组合执行反思
|
||||
for config in memory_configs:
|
||||
config_id_str = str(config['config_id'])
|
||||
|
||||
# 找到匹配此配置的所有release
|
||||
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
|
||||
|
||||
if not matching_releases:
|
||||
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
|
||||
continue
|
||||
|
||||
if base_config == config_id and base['app_id'] == user['app_id']:
|
||||
# 调用反思服务
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||
|
||||
reflection_result = await reflection_service.start_text_reflection(
|
||||
config_data=config,
|
||||
end_user_id=user['id']
|
||||
)
|
||||
|
||||
reflection_results.append({
|
||||
"app_id": base['app_id'],
|
||||
"config_id": config['config_id'],
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": reflection_result
|
||||
})
|
||||
|
||||
# 为每个用户执行反思 - 使用独立的数据库会话
|
||||
for user in end_users:
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config_id_str}")
|
||||
|
||||
# 为每个用户创建独立的数据库会话,避免事务失败影响其他用户
|
||||
with get_db_context() as user_db:
|
||||
try:
|
||||
reflection_service = MemoryReflectionService(user_db)
|
||||
reflection_result = await reflection_service.start_text_reflection(
|
||||
config_data=config,
|
||||
end_user_id=user['id']
|
||||
)
|
||||
|
||||
reflection_results.append({
|
||||
"app_id": data['id'],
|
||||
"config_id": config_id_str,
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": reflection_result
|
||||
})
|
||||
except Exception as e:
|
||||
api_logger.error(f"用户 {user['id']} 反思失败: {str(e)}")
|
||||
reflection_results.append({
|
||||
"app_id": data['id'],
|
||||
"config_id": config_id_str,
|
||||
"end_user_id": user['id'],
|
||||
"reflection_result": {
|
||||
"status": "错误",
|
||||
"message": f"反思失败: {str(e)}"
|
||||
}
|
||||
})
|
||||
|
||||
return success(data=reflection_results, msg="反思配置成功")
|
||||
|
||||
@@ -199,11 +219,13 @@ async def start_reflection_configs(
|
||||
@router.get("/reflection/run")
|
||||
async def reflection_run(
|
||||
config_id: UUID|int,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Activate the reflection function for all matching applications in the workspace"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status,Header
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
@@ -20,10 +21,13 @@ router = APIRouter(
|
||||
@router.get("/short_term")
|
||||
async def short_term_configs(
|
||||
end_user_id: str,
|
||||
language_type:str = Header(default="zh", alias="X-Language-Type"),
|
||||
language_type:str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
# 获取短期记忆数据
|
||||
short_term=ShortService(end_user_id)
|
||||
short_result=short_term.get_short_databasets()
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
import os
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
@@ -11,7 +15,6 @@ from app.models.user_model import User
|
||||
from app.schemas.memory_storage_schema import (
|
||||
ConfigKey,
|
||||
ConfigParamsCreate,
|
||||
ConfigParamsDelete,
|
||||
ConfigPilotRun,
|
||||
ConfigUpdate,
|
||||
ConfigUpdateExtracted,
|
||||
@@ -31,7 +34,7 @@ from app.services.memory_storage_service import (
|
||||
search_entity,
|
||||
search_statement,
|
||||
)
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Header
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -72,68 +75,9 @@ async def get_storage_info(
|
||||
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
|
||||
|
||||
|
||||
# --- DB connection dependency ---
|
||||
_CONN: Optional[object] = None
|
||||
|
||||
|
||||
"""PostgreSQL 连接生成与管理(使用 psycopg2)。"""
|
||||
# 这个可以转移,可能是已经有的
|
||||
# PostgreSQL 数据库连接
|
||||
def _make_pgsql_conn() -> Optional[object]: # 创建 PostgreSQL 数据库连接
|
||||
host = os.getenv("DB_HOST")
|
||||
user = os.getenv("DB_USER")
|
||||
password = os.getenv("DB_PASSWORD")
|
||||
database = os.getenv("DB_NAME")
|
||||
port_str = os.getenv("DB_PORT")
|
||||
try:
|
||||
import psycopg2 # type: ignore
|
||||
port = int(port_str) if port_str else 5432
|
||||
conn = psycopg2.connect(
|
||||
host=host or "localhost",
|
||||
port=port,
|
||||
user=user,
|
||||
password=password,
|
||||
dbname=database,
|
||||
)
|
||||
# 设置自动提交,避免显式事务管理
|
||||
conn.autocommit = True
|
||||
# 设置会话时区为中国标准时间(Asia/Shanghai),便于直接以本地时区展示
|
||||
try:
|
||||
cur = conn.cursor()
|
||||
cur.execute("SET TIME ZONE 'Asia/Shanghai'")
|
||||
cur.close()
|
||||
except Exception:
|
||||
# 时区设置失败不影响连接,仅记录但不抛出
|
||||
pass
|
||||
return conn
|
||||
except Exception as e:
|
||||
try:
|
||||
print(f"[PostgreSQL] 连接失败: {e}")
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def get_db_conn() -> Optional[object]: # 获取 PostgreSQL 数据库连接
|
||||
global _CONN
|
||||
if _CONN is None:
|
||||
_CONN = _make_pgsql_conn()
|
||||
return _CONN
|
||||
|
||||
|
||||
def reset_db_conn() -> bool: # 重置 PostgreSQL 数据库连接
|
||||
"""Close and recreate the global DB connection."""
|
||||
global _CONN
|
||||
try:
|
||||
if _CONN:
|
||||
try:
|
||||
_CONN.close()
|
||||
except Exception:
|
||||
pass
|
||||
_CONN = _make_pgsql_conn()
|
||||
return _CONN is not None
|
||||
except Exception:
|
||||
_CONN = None
|
||||
return False
|
||||
|
||||
|
||||
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
|
||||
@@ -141,7 +85,7 @@ def create_config(
|
||||
payload: ConfigParamsCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
@@ -163,9 +107,20 @@ def create_config(
|
||||
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
|
||||
def delete_config(
|
||||
config_id: UUID|int,
|
||||
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
"""删除记忆配置(带终端用户保护)
|
||||
|
||||
- 检查是否为默认配置,默认配置不允许删除
|
||||
- 检查是否有终端用户连接到该配置
|
||||
- 如果有连接且 force=False,返回警告
|
||||
- 如果 force=True,清除终端用户引用后删除配置
|
||||
|
||||
Query Parameters:
|
||||
force: 设置为 true 可强制删除(即使有终端用户正在使用)
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
config_id=resolve_config_id(config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -173,21 +128,62 @@ def delete_config(
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: {config_id}")
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: "
|
||||
f"config_id={config_id}, force={force}"
|
||||
)
|
||||
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
result = svc.delete(ConfigParamsDelete(config_id=config_id))
|
||||
return success(data=result, msg="删除成功")
|
||||
# 使用带保护的删除服务
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
config_service = MemoryConfigService(db)
|
||||
result = config_service.delete_config(config_id=config_id, force=force)
|
||||
|
||||
if result["status"] == "error":
|
||||
api_logger.warning(
|
||||
f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}"
|
||||
)
|
||||
return fail(
|
||||
code=BizCode.FORBIDDEN,
|
||||
msg=result["message"],
|
||||
data={"config_id": str(config_id), "is_default": result.get("is_default", False)}
|
||||
)
|
||||
|
||||
if result["status"] == "warning":
|
||||
api_logger.warning(
|
||||
f"记忆配置正在使用,无法删除: config_id={config_id}, "
|
||||
f"connected_count={result['connected_count']}"
|
||||
)
|
||||
return fail(
|
||||
code=BizCode.RESOURCE_IN_USE,
|
||||
msg=result["message"],
|
||||
data={
|
||||
"connected_count": result["connected_count"],
|
||||
"force_required": result["force_required"]
|
||||
}
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"记忆配置删除成功: config_id={config_id}, "
|
||||
f"affected_users={result['affected_users']}"
|
||||
)
|
||||
return success(
|
||||
msg=result["message"],
|
||||
data={"affected_users": result["affected_users"]}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Delete config failed: {str(e)}")
|
||||
api_logger.error(f"Delete config failed: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e))
|
||||
|
||||
|
||||
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
|
||||
def update_config(
|
||||
payload: ConfigUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -195,6 +191,11 @@ def update_config(
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
|
||||
# 校验至少有一个字段需要更新
|
||||
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
|
||||
try:
|
||||
svc = DataConfigService(db)
|
||||
@@ -210,7 +211,7 @@ def update_config_extracted(
|
||||
payload: ConfigUpdateExtracted,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -237,7 +238,7 @@ def read_config_extracted(
|
||||
config_id: UUID | int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
config_id = resolve_config_id(config_id, db)
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -258,7 +259,7 @@ def read_config_extracted(
|
||||
def read_all_config(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -280,17 +281,22 @@ def read_all_config(
|
||||
@router.post("/pilot_run", response_model=None)
|
||||
async def pilot_run(
|
||||
payload: ConfigPilotRun,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> StreamingResponse:
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
api_logger.info(
|
||||
f"Pilot run requested: config_id={payload.config_id}, "
|
||||
f"dialogue_text_length={len(payload.dialogue_text)}"
|
||||
f"dialogue_text_length={len(payload.dialogue_text)}, "
|
||||
f"custom_text_length={len(payload.custom_text) if payload.custom_text else 0}"
|
||||
)
|
||||
payload.config_id = resolve_config_id(payload.config_id, db)
|
||||
svc = DataConfigService(db)
|
||||
return StreamingResponse(
|
||||
svc.pilot_run_stream(payload),
|
||||
svc.pilot_run_stream(payload, language=language),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
@@ -299,9 +305,8 @@ async def pilot_run(
|
||||
},
|
||||
)
|
||||
|
||||
"""
|
||||
以下为搜索与分析接口,直接挂载到同一 router,统一响应为 ApiResponse。
|
||||
"""
|
||||
|
||||
# ==================== Search & Analytics ====================
|
||||
|
||||
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
|
||||
async def get_kb_type_distribution(
|
||||
@@ -441,8 +446,9 @@ async def get_hot_memory_tags_api(
|
||||
|
||||
try:
|
||||
# 尝试从Redis缓存获取
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
import json
|
||||
|
||||
from app.aioRedis import aio_redis_get, aio_redis_set
|
||||
|
||||
cached_result = await aio_redis_get(cache_key)
|
||||
if cached_result:
|
||||
|
||||
1132
api/app/controllers/ontology_controller.py
Normal file
1132
api/app/controllers/ontology_controller.py
Normal file
File diff suppressed because it is too large
Load Diff
611
api/app/controllers/ontology_secondary_routes.py
Normal file
611
api/app/controllers/ontology_secondary_routes.py
Normal file
@@ -0,0 +1,611 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体场景和类型路由(续)
|
||||
|
||||
由于主Controller文件较大,将剩余路由放在此文件中。
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.ontology_schemas import (
|
||||
SceneResponse,
|
||||
SceneListResponse,
|
||||
PaginationInfo,
|
||||
ClassCreateRequest,
|
||||
ClassUpdateRequest,
|
||||
ClassResponse,
|
||||
ClassListResponse,
|
||||
ClassBatchCreateResponse,
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.ontology_service import OntologyService
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
|
||||
def _get_dummy_ontology_service(db: Session) -> OntologyService:
|
||||
"""获取OntologyService实例(不需要LLM)
|
||||
|
||||
场景和类型管理不需要LLM,创建一个dummy配置。
|
||||
"""
|
||||
dummy_config = RedBearModelConfig(
|
||||
model_name="dummy",
|
||||
provider="openai",
|
||||
api_key="dummy",
|
||||
base_url="https://api.openai.com/v1"
|
||||
)
|
||||
llm_client = OpenAIClient(model_config=dummy_config)
|
||||
return OntologyService(llm_client=llm_client, db=db)
|
||||
|
||||
|
||||
# 这些函数将被导入到主Controller中
|
||||
|
||||
async def scenes_handler(
|
||||
workspace_id: Optional[str] = None,
|
||||
scene_name: Optional[str] = None,
|
||||
page: Optional[int] = None,
|
||||
page_size: Optional[int] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取场景列表(支持模糊搜索和全量查询,全量查询支持分页)
|
||||
|
||||
当提供 scene_name 参数时,进行模糊搜索(不分页);
|
||||
当不提供 scene_name 参数时,返回所有场景(支持分页)。
|
||||
|
||||
Args:
|
||||
workspace_id: 工作空间ID(可选,默认当前用户工作空间)
|
||||
scene_name: 场景名称关键词(可选,支持模糊匹配)
|
||||
page: 页码(可选,从1开始,仅在全量查询时有效)
|
||||
page_size: 每页数量(可选,仅在全量查询时有效)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
"""
|
||||
operation = "search" if scene_name else "list"
|
||||
api_logger.info(
|
||||
f"Scene {operation} requested by user {current_user.id}, "
|
||||
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 确定工作空间ID
|
||||
if workspace_id:
|
||||
try:
|
||||
ws_uuid = UUID(workspace_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid workspace_id format: {workspace_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的工作空间ID格式")
|
||||
else:
|
||||
ws_uuid = current_user.current_workspace_id
|
||||
if not ws_uuid:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 根据是否提供 scene_name 决定查询方式
|
||||
if scene_name and scene_name.strip():
|
||||
# 验证分页参数(模糊搜索也支持分页)
|
||||
if page is not None and page < 1:
|
||||
api_logger.warning(f"Invalid page number: {page}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||
|
||||
if page_size is not None and page_size < 1:
|
||||
api_logger.warning(f"Invalid page_size: {page_size}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||
|
||||
# 如果只提供了page或page_size中的一个,返回错误
|
||||
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||
|
||||
# 模糊搜索场景(支持分页)
|
||||
scenes = service.search_scenes_by_name(scene_name.strip(), ws_uuid)
|
||||
total = len(scenes)
|
||||
|
||||
# 如果提供了分页参数,进行分页处理
|
||||
if page is not None and page_size is not None:
|
||||
start_idx = (page - 1) * page_size
|
||||
end_idx = start_idx + page_size
|
||||
scenes = scenes[start_idx:end_idx]
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for scene in scenes:
|
||||
# 获取前3个class_name作为entity_type
|
||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||
# 动态计算 type_num
|
||||
type_num = len(scene.classes) if scene.classes else 0
|
||||
|
||||
items.append(SceneResponse(
|
||||
scene_id=scene.scene_id,
|
||||
scene_name=scene.scene_name,
|
||||
scene_description=scene.scene_description,
|
||||
type_num=type_num,
|
||||
entity_type=entity_type,
|
||||
workspace_id=scene.workspace_id,
|
||||
created_at=scene.created_at,
|
||||
updated_at=scene.updated_at,
|
||||
classes_count=type_num
|
||||
))
|
||||
|
||||
# 构建响应(包含分页信息)
|
||||
if page is not None and page_size is not None:
|
||||
# 计算是否有下一页
|
||||
hasnext = (page * page_size) < total
|
||||
|
||||
pagination_info = PaginationInfo(
|
||||
page=page,
|
||||
pagesize=page_size,
|
||||
total=total,
|
||||
hasnext=hasnext
|
||||
)
|
||||
response = SceneListResponse(items=items, page=pagination_info)
|
||||
else:
|
||||
response = SceneListResponse(items=items)
|
||||
|
||||
api_logger.info(
|
||||
f"Scene search completed: found {len(items)} scenes matching '{scene_name}' "
|
||||
f"in workspace {ws_uuid}, total={total}"
|
||||
)
|
||||
else:
|
||||
# 获取所有场景(支持分页)
|
||||
# 验证分页参数
|
||||
if page is not None and page < 1:
|
||||
api_logger.warning(f"Invalid page number: {page}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
|
||||
|
||||
if page_size is not None and page_size < 1:
|
||||
api_logger.warning(f"Invalid page_size: {page_size}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
|
||||
|
||||
# 如果只提供了page或page_size中的一个,返回错误
|
||||
if (page is not None and page_size is None) or (page is None and page_size is not None):
|
||||
api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
|
||||
|
||||
scenes, total = service.list_scenes(ws_uuid, page, page_size)
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for scene in scenes:
|
||||
# 获取前3个class_name作为entity_type
|
||||
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
|
||||
# 动态计算 type_num
|
||||
type_num = len(scene.classes) if scene.classes else 0
|
||||
|
||||
items.append(SceneResponse(
|
||||
scene_id=scene.scene_id,
|
||||
scene_name=scene.scene_name,
|
||||
scene_description=scene.scene_description,
|
||||
type_num=type_num,
|
||||
entity_type=entity_type,
|
||||
workspace_id=scene.workspace_id,
|
||||
created_at=scene.created_at,
|
||||
updated_at=scene.updated_at,
|
||||
classes_count=type_num
|
||||
))
|
||||
|
||||
# 构建响应(包含分页信息)
|
||||
if page is not None and page_size is not None:
|
||||
# 计算是否有下一页
|
||||
hasnext = (page * page_size) < total
|
||||
|
||||
pagination_info = PaginationInfo(
|
||||
page=page,
|
||||
pagesize=page_size,
|
||||
total=total,
|
||||
hasnext=hasnext
|
||||
)
|
||||
response = SceneListResponse(items=items, page=pagination_info)
|
||||
else:
|
||||
response = SceneListResponse(items=items)
|
||||
|
||||
api_logger.info(f"Scene list retrieved successfully, count={len(items)}, total={total}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in scene {operation}: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in scene {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in scene {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
|
||||
# ==================== 本体类型管理接口 ====================
|
||||
|
||||
async def create_class_handler(
|
||||
request: ClassCreateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
|
||||
|
||||
# 根据列表长度判断是单个还是批量
|
||||
count = len(request.classes)
|
||||
mode = "single" if count == 1 else "batch"
|
||||
|
||||
api_logger.info(
|
||||
f"Class creation ({mode}) requested by user {current_user.id}, "
|
||||
f"scene_id={request.scene_id}, count={count}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 准备类型数据
|
||||
classes_data = [
|
||||
{
|
||||
"class_name": item.class_name,
|
||||
"class_description": item.class_description
|
||||
}
|
||||
for item in request.classes
|
||||
]
|
||||
|
||||
if count == 1:
|
||||
# 单个创建
|
||||
class_data = classes_data[0]
|
||||
ontology_class = service.create_class(
|
||||
scene_id=request.scene_id,
|
||||
class_name=class_data["class_name"],
|
||||
class_description=class_data["class_description"],
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 构建单个响应
|
||||
response = ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
)
|
||||
|
||||
api_logger.info(f"Class created successfully: {ontology_class.class_id}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="类型创建成功")
|
||||
|
||||
else:
|
||||
# 批量创建
|
||||
created_classes, errors = service.create_classes_batch(
|
||||
scene_id=request.scene_id,
|
||||
classes=classes_data,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 构建批量响应
|
||||
items = []
|
||||
for ontology_class in created_classes:
|
||||
items.append(ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
))
|
||||
|
||||
response = ClassBatchCreateResponse(
|
||||
total=len(classes_data),
|
||||
success_count=len(created_classes),
|
||||
failed_count=len(errors),
|
||||
items=items,
|
||||
errors=errors if errors else None
|
||||
)
|
||||
|
||||
api_logger.info(
|
||||
f"Batch class creation completed: "
|
||||
f"success={len(created_classes)}, failed={len(errors)}"
|
||||
)
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class creation: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
|
||||
|
||||
|
||||
async def update_class_handler(
|
||||
class_id: str,
|
||||
request: ClassUpdateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新本体类型"""
|
||||
api_logger.info(
|
||||
f"Class update requested by user {current_user.id}, "
|
||||
f"class_id={class_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
class_uuid = UUID(class_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid class_id format: {class_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 更新类型
|
||||
ontology_class = service.update_class(
|
||||
class_id=class_uuid,
|
||||
class_name=request.class_name,
|
||||
class_description=request.class_description,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
# 构建响应
|
||||
response = ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
)
|
||||
|
||||
api_logger.info(f"Class updated successfully: {class_id}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="类型更新成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class update: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class update: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class update: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
|
||||
|
||||
|
||||
async def delete_class_handler(
|
||||
class_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除本体类型"""
|
||||
api_logger.info(
|
||||
f"Class deletion requested by user {current_user.id}, "
|
||||
f"class_id={class_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
class_uuid = UUID(class_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid class_id format: {class_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 删除类型
|
||||
success_flag = service.delete_class(
|
||||
class_id=class_uuid,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
api_logger.info(f"Class deleted successfully: {class_id}")
|
||||
|
||||
return success(data={"deleted": success_flag}, msg="类型删除成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class deletion: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class deletion: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class deletion: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
|
||||
|
||||
|
||||
async def get_class_handler(
|
||||
class_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取单个本体类型"""
|
||||
api_logger.info(
|
||||
f"Get class requested by user {current_user.id}, "
|
||||
f"class_id={class_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
class_uuid = UUID(class_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid class_id format: {class_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 获取类型(会抛出ValueError如果不存在)
|
||||
ontology_class = service.get_class_by_id(class_uuid, workspace_id)
|
||||
|
||||
# 构建响应
|
||||
response = ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
)
|
||||
|
||||
api_logger.info(f"Class retrieved successfully: {class_id}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
# 类型不存在或无权限访问
|
||||
api_logger.warning(f"Validation error in get class: {str(e)}")
|
||||
return fail(BizCode.NOT_FOUND, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in get class: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in get class: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
|
||||
async def classes_handler(
|
||||
scene_id: str,
|
||||
class_name: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取类型列表(支持模糊搜索和全量查询)
|
||||
|
||||
当提供 class_name 参数时,进行模糊搜索;
|
||||
当不提供 class_name 参数时,返回场景下的所有类型。
|
||||
|
||||
Args:
|
||||
scene_id: 场景ID(必填)
|
||||
class_name: 类型名称关键词(可选,支持模糊匹配)
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
"""
|
||||
operation = "search" if class_name else "list"
|
||||
api_logger.info(
|
||||
f"Class {operation} requested by user {current_user.id}, "
|
||||
f"keyword={class_name}, scene_id={scene_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证UUID格式
|
||||
try:
|
||||
scene_uuid = UUID(scene_id)
|
||||
except ValueError:
|
||||
api_logger.warning(f"Invalid scene_id format: {scene_id}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的场景ID格式")
|
||||
|
||||
# 获取当前工作空间ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
if not workspace_id:
|
||||
api_logger.warning(f"User {current_user.id} has no current workspace")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
|
||||
|
||||
# 创建Service
|
||||
service = _get_dummy_ontology_service(db)
|
||||
|
||||
# 获取场景信息
|
||||
scene = service.get_scene_by_id(scene_uuid, workspace_id)
|
||||
if not scene:
|
||||
api_logger.warning(f"Scene not found: {scene_id}")
|
||||
return fail(BizCode.NOT_FOUND, "场景不存在", f"未找到ID为 {scene_id} 的场景")
|
||||
|
||||
# 根据是否提供 class_name 决定查询方式
|
||||
if class_name and class_name.strip():
|
||||
# 模糊搜索类型
|
||||
classes = service.search_classes_by_name(class_name.strip(), scene_uuid, workspace_id)
|
||||
else:
|
||||
# 获取所有类型
|
||||
classes = service.list_classes_by_scene(scene_uuid, workspace_id)
|
||||
|
||||
# 构建响应
|
||||
items = []
|
||||
for ontology_class in classes:
|
||||
items.append(ClassResponse(
|
||||
class_id=ontology_class.class_id,
|
||||
class_name=ontology_class.class_name,
|
||||
class_description=ontology_class.class_description,
|
||||
scene_id=ontology_class.scene_id,
|
||||
created_at=ontology_class.created_at,
|
||||
updated_at=ontology_class.updated_at
|
||||
))
|
||||
|
||||
response = ClassListResponse(
|
||||
total=len(items),
|
||||
scene_id=scene_uuid,
|
||||
scene_name=scene.scene_name,
|
||||
scene_description=scene.scene_description,
|
||||
items=items
|
||||
)
|
||||
|
||||
if class_name:
|
||||
api_logger.info(
|
||||
f"Class search completed: found {len(items)} classes matching '{class_name}' "
|
||||
f"in scene {scene_id}"
|
||||
)
|
||||
else:
|
||||
api_logger.info(f"Class list retrieved successfully, count={len(items)}")
|
||||
|
||||
return success(data=response.model_dump(mode='json'), msg="查询成功")
|
||||
|
||||
except ValueError as e:
|
||||
api_logger.warning(f"Validation error in class {operation}: {str(e)}")
|
||||
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
|
||||
|
||||
except RuntimeError as e:
|
||||
api_logger.error(f"Runtime error in class {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"Unexpected error in class {operation}: {str(e)}", exc_info=True)
|
||||
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
|
||||
@@ -1,5 +1,5 @@
|
||||
import uuid
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -8,9 +8,13 @@ from starlette.responses import StreamingResponse
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.dependencies import get_current_user, get_db
|
||||
from app.models.prompt_optimizer_model import RoleType
|
||||
from app.schemas.prompt_optimizer_schema import PromptOptMessage, PromptOptModelSet, CreateSessionResponse, \
|
||||
OptimizePromptResponse, SessionHistoryResponse, SessionMessage
|
||||
from app.schemas.prompt_optimizer_schema import (
|
||||
PromptOptMessage,
|
||||
CreateSessionResponse,
|
||||
SessionHistoryResponse,
|
||||
SessionMessage,
|
||||
PromptSaveRequest
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.prompt_optimizer_service import PromptOptimizerService
|
||||
|
||||
@@ -116,7 +120,8 @@ async def get_prompt_opt(
|
||||
session_id=session_id,
|
||||
user_id=current_user.id,
|
||||
current_prompt=data.current_prompt,
|
||||
user_require=data.message
|
||||
user_require=data.message,
|
||||
skill=data.skill
|
||||
):
|
||||
# chunk 是 prompt 的增量内容
|
||||
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
|
||||
@@ -135,3 +140,109 @@ async def get_prompt_opt(
|
||||
"X-Accel-Buffering": "no"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/releases",
|
||||
summary="Get prompt optimization",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
def save_prompt(
|
||||
data: PromptSaveRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Save a prompt release for the current tenant.
|
||||
|
||||
Args:
|
||||
data (PromptSaveRequest): Request body containing session_id, title, and prompt.
|
||||
db (Session): SQLAlchemy database session, injected via dependency.
|
||||
current_user: Currently authenticated user object, injected via dependency.
|
||||
|
||||
Returns:
|
||||
ApiResponse: Standard API response containing the saved prompt release info:
|
||||
- id: UUID of the prompt release
|
||||
- session_id: associated session
|
||||
- title: prompt title
|
||||
- prompt: prompt content
|
||||
- created_at: timestamp of creation
|
||||
|
||||
Raises:
|
||||
Any database or service exceptions are propagated to the global exception handler.
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
prompt_info = service.save_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
session_id=data.session_id,
|
||||
title=data.title,
|
||||
prompt=data.prompt
|
||||
)
|
||||
return success(data=prompt_info)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/releases/{prompt_id}",
|
||||
summary="Delete prompt (soft delete)",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
def delete_prompt(
|
||||
prompt_id: uuid.UUID = Path(..., description="Prompt ID"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Soft delete a prompt release.
|
||||
|
||||
Args:
|
||||
prompt_id
|
||||
db (Session): Database session
|
||||
current_user: Current logged-in user
|
||||
|
||||
Returns:
|
||||
ApiResponse: Success message confirming deletion
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
service.delete_prompt(
|
||||
tenant_id=current_user.tenant_id,
|
||||
prompt_id=prompt_id
|
||||
)
|
||||
return success(msg="Prompt deleted successfully")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/releases/list",
|
||||
summary="Get paginated list of released prompts with optional filter",
|
||||
response_model=ApiResponse
|
||||
)
|
||||
def get_release_list(
|
||||
page: int = 1,
|
||||
page_size: int = 20,
|
||||
keyword: str | None = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user=Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Retrieve paginated list of released prompts for the current tenant.
|
||||
Optionally filter by keyword in title.
|
||||
|
||||
Args:
|
||||
page (int): Page number (starting from 1)
|
||||
page_size (int): Number of items per page (max 100)
|
||||
keyword (str | None): Optional keyword to filter prompt titles
|
||||
db (Session): Database session
|
||||
current_user: Current logged-in user
|
||||
|
||||
Returns:
|
||||
ApiResponse: Contains paginated list of prompt releases with metadata
|
||||
"""
|
||||
service = PromptOptimizerService(db)
|
||||
result = service.get_release_list(
|
||||
tenant_id=current_user.tenant_id,
|
||||
page=max(1, page),
|
||||
page_size=min(max(1, page_size), 100),
|
||||
filter_keyword=keyword
|
||||
)
|
||||
return success(data=result)
|
||||
|
||||
|
||||
|
||||
@@ -438,7 +438,8 @@ async def chat(
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -475,7 +476,8 @@ async def chat(
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
)
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
@@ -578,6 +580,7 @@ async def chat(
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
files=payload.files,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
@@ -585,7 +588,8 @@ async def chat(
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=release.app_id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=release.id
|
||||
release_id=release.id,
|
||||
public=True
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
@@ -12,7 +12,6 @@ from app.core.exceptions import BusinessException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_app_or_workspace
|
||||
from app.models.app_model import App
|
||||
from app.models.app_model import AppType
|
||||
from app.repositories import knowledge_repository
|
||||
@@ -21,9 +20,10 @@ from app.schemas import AppChatRequest, conversation_schema
|
||||
from app.schemas.api_key_schema import ApiKeyAuth
|
||||
from app.services import workspace_service
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.services.conversation_service import ConversationService, get_conversation_service
|
||||
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
from app.services.app_service import get_app_service, AppService
|
||||
from app.services.conversation_service import ConversationService, get_conversation_service
|
||||
from app.utils.app_config_utils import workflow_config_4_app_release, \
|
||||
agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
|
||||
router = APIRouter(prefix="/app", tags=["V1 - App API"])
|
||||
logger = get_business_logger()
|
||||
@@ -34,6 +34,7 @@ async def list_apps():
|
||||
"""列出可访问的应用(占位)"""
|
||||
return success(data=[], msg="App API - Coming Soon")
|
||||
|
||||
|
||||
# /v1/app/chat
|
||||
|
||||
# @router.post("/chat")
|
||||
@@ -73,16 +74,17 @@ def _checkAppConfig(app: App):
|
||||
else:
|
||||
raise BusinessException("不支持的应用类型", BizCode.AGENT_CONFIG_MISSING)
|
||||
|
||||
|
||||
@router.post("/chat")
|
||||
@require_api_key(scopes=["app"])
|
||||
async def chat(
|
||||
request:Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
conversation_service: Annotated[ConversationService, Depends(get_conversation_service)] = None,
|
||||
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
|
||||
app_service: Annotated[AppService, Depends(get_app_service)] = None,
|
||||
message: str = Body(..., description="聊天消息内容"),
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
conversation_service: Annotated[ConversationService, Depends(get_conversation_service)] = None,
|
||||
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
|
||||
app_service: Annotated[AppService, Depends(get_app_service)] = None,
|
||||
message: str = Body(..., description="聊天消息内容"),
|
||||
):
|
||||
body = await request.json()
|
||||
payload = AppChatRequest(**body)
|
||||
@@ -98,8 +100,8 @@ async def chat(
|
||||
original_user_id=other_id # Save original user_id to other_id
|
||||
)
|
||||
end_user_id = str(new_end_user.id)
|
||||
web_search=True
|
||||
memory=True
|
||||
web_search = True
|
||||
memory = True
|
||||
# 提前验证和准备(在流式响应开始前完成)
|
||||
storage_type = workspace_service.get_workspace_storage_type_without_auth(
|
||||
db=db,
|
||||
@@ -146,16 +148,17 @@ async def chat(
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.agnet_chat_stream(
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id= end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
web_search=web_search,
|
||||
config=agent_config,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
web_search=web_search,
|
||||
config=agent_config,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -175,12 +178,13 @@ async def chat(
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config= agent_config,
|
||||
config=agent_config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
files=payload.files # 传递多模态文件
|
||||
)
|
||||
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
|
||||
elif app_type == AppType.MULTI_AGENT:
|
||||
@@ -190,15 +194,15 @@ async def chat(
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.multi_agent_chat_stream(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -232,19 +236,19 @@ async def chat(
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.workflow_chat_stream(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id,
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
files=payload.files,
|
||||
config=config,
|
||||
web_search=web_search,
|
||||
memory=memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id,
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
@@ -294,4 +298,3 @@ async def chat(
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
|
||||
|
||||
|
||||
@@ -246,3 +246,73 @@ async def rebuild_knowledge_graph(
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/check/yuque/auth", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def check_yuque_auth(
|
||||
yuque_user_id: str,
|
||||
yuque_token: str,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
check yuque auth info
|
||||
"""
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
api_logger.info(f"check yuque auth info, username: {current_user.username}")
|
||||
|
||||
return await knowledge_controller.check_yuque_auth(yuque_user_id=yuque_user_id,
|
||||
yuque_token=yuque_token,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.get("/check/feishu/auth", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def check_feishu_auth(
|
||||
feishu_app_id: str,
|
||||
feishu_app_secret: str,
|
||||
feishu_folder_token: str,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
check feishu auth info
|
||||
"""
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
api_logger.info(f"check feishu auth info, username: {current_user.username}")
|
||||
|
||||
return await knowledge_controller.check_feishu_auth(feishu_app_id=feishu_app_id,
|
||||
feishu_app_secret=feishu_app_secret,
|
||||
feishu_folder_token=feishu_folder_token,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
@router.post("/{knowledge_id}/sync", response_model=ApiResponse)
|
||||
@require_api_key(scopes=["rag"])
|
||||
async def sync_knowledge(
|
||||
knowledge_id: uuid.UUID,
|
||||
request: Request,
|
||||
api_key_auth: ApiKeyAuth = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
sync knowledge base information based on knowledge_id
|
||||
"""
|
||||
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
|
||||
current_user = api_key.creator
|
||||
current_user.current_workspace_id = api_key_auth.workspace_id
|
||||
|
||||
return await knowledge_controller.sync_knowledge(knowledge_id=knowledge_id,
|
||||
db=db,
|
||||
current_user=current_user)
|
||||
|
||||
|
||||
85
api/app/controllers/skill_controller.py
Normal file
85
api/app/controllers/skill_controller.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""Skill Controller - 技能市场管理"""
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
import uuid
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models import User
|
||||
from app.schemas import skill_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services.skill_service import SkillService
|
||||
from app.core.response_utils import success
|
||||
|
||||
router = APIRouter(prefix="/skills", tags=["Skills"])
|
||||
|
||||
|
||||
@router.post("", summary="创建技能")
|
||||
def create_skill(
|
||||
data: skill_schema.SkillCreate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""创建技能 - 可以关联现有工具(内置、MCP、自定义)"""
|
||||
tenant_id = current_user.tenant_id
|
||||
skill = SkillService.create_skill(db, data, tenant_id)
|
||||
return success(data=skill_schema.Skill.model_validate(skill), msg="技能创建成功")
|
||||
|
||||
|
||||
@router.get("", summary="技能列表")
|
||||
def list_skills(
|
||||
search: Optional[str] = Query(None, description="搜索关键词"),
|
||||
is_active: Optional[bool] = Query(None, description="是否激活"),
|
||||
is_public: Optional[bool] = Query(None, description="是否公开"),
|
||||
page: int = Query(1, ge=1, description="页码"),
|
||||
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""技能市场列表 - 包含本工作空间和公开的技能"""
|
||||
tenant_id = current_user.tenant_id
|
||||
skills, total = SkillService.list_skills(
|
||||
db, tenant_id, search, is_active, is_public, page, pagesize
|
||||
)
|
||||
|
||||
items = [skill_schema.Skill.model_validate(s) for s in skills]
|
||||
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
|
||||
return success(data=PageData(page=meta, items=items), msg="技能市场列表获取成功")
|
||||
|
||||
|
||||
@router.get("/{skill_id}", summary="获取技能详情")
|
||||
def get_skill(
|
||||
skill_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""获取技能详情"""
|
||||
tenant_id = current_user.tenant_id
|
||||
skill = SkillService.get_skill(db, skill_id, tenant_id)
|
||||
return success(data=skill_schema.Skill.model_validate(skill), msg="获取技能详情成功")
|
||||
|
||||
|
||||
@router.put("/{skill_id}", summary="更新技能")
|
||||
def update_skill(
|
||||
skill_id: uuid.UUID,
|
||||
data: skill_schema.SkillUpdate,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""更新技能"""
|
||||
tenant_id = current_user.tenant_id
|
||||
skill = SkillService.update_skill(db, skill_id, data, tenant_id)
|
||||
return success(data=skill_schema.Skill.model_validate(skill), msg="技能更新成功")
|
||||
|
||||
|
||||
@router.delete("/{skill_id}", summary="删除技能")
|
||||
def delete_skill(
|
||||
skill_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""删除技能"""
|
||||
tenant_id = current_user.tenant_id
|
||||
SkillService.delete_skill(db, skill_id, tenant_id)
|
||||
return success(msg="技能删除成功")
|
||||
@@ -8,11 +8,11 @@ from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Depends,Header
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.language_utils import get_language_from_header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.api_key_utils import timestamp_to_datetime
|
||||
from app.services.memory_base_service import Translation_English
|
||||
from app.services.user_memory_service import (
|
||||
UserMemoryService,
|
||||
analytics_memory_types,
|
||||
@@ -45,7 +45,6 @@ router = APIRouter(
|
||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||
async def get_memory_insight_report_api(
|
||||
end_user_id: str,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
@@ -55,18 +54,10 @@ async def get_memory_insight_report_api(
|
||||
此接口仅查询数据库中已缓存的记忆洞察数据,不执行生成操作。
|
||||
如需生成新的洞察报告,请使用专门的生成接口。
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_memory_insight(db, end_user_id,model_id,language_type)
|
||||
result = await user_memory_service.get_cached_memory_insight(db, end_user_id)
|
||||
|
||||
if result["is_cached"]:
|
||||
api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}")
|
||||
@@ -82,7 +73,7 @@ async def get_memory_insight_report_api(
|
||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||
async def get_user_summary_api(
|
||||
end_user_id: str,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
@@ -91,7 +82,14 @@ async def get_user_summary_api(
|
||||
|
||||
此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。
|
||||
如需生成新的用户摘要,请使用专门的生成接口。
|
||||
|
||||
语言控制:
|
||||
- 使用 X-Language-Type Header 指定语言
|
||||
- 如果未传 Header,默认使用中文 (zh)
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
@@ -103,7 +101,7 @@ async def get_user_summary_api(
|
||||
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language_type)
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language)
|
||||
|
||||
if result["is_cached"]:
|
||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||
@@ -119,6 +117,7 @@ async def get_user_summary_api(
|
||||
@router.post("/analytics/generate_cache", response_model=ApiResponse)
|
||||
async def generate_cache_api(
|
||||
request: GenerateCacheRequest,
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
@@ -127,7 +126,14 @@ async def generate_cache_api(
|
||||
|
||||
- 如果提供 end_user_id,只为该用户生成
|
||||
- 如果不提供,为当前工作空间的所有用户生成
|
||||
|
||||
语言控制:
|
||||
- 使用 X-Language-Type Header 指定语言 ("zh" 中文, "en" 英文)
|
||||
- 如果未传 Header,默认使用中文 (zh)
|
||||
"""
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
@@ -139,7 +145,7 @@ async def generate_cache_api(
|
||||
|
||||
api_logger.info(
|
||||
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
|
||||
f"end_user_id={end_user_id if end_user_id else '全部用户'}"
|
||||
f"end_user_id={end_user_id if end_user_id else '全部用户'}, language={language}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -148,10 +154,10 @@ async def generate_cache_api(
|
||||
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
|
||||
|
||||
# 生成记忆洞察
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id)
|
||||
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language)
|
||||
|
||||
# 生成用户摘要
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id)
|
||||
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language)
|
||||
|
||||
# 构建响应
|
||||
result = {
|
||||
@@ -185,7 +191,7 @@ async def generate_cache_api(
|
||||
# 为整个工作空间生成
|
||||
api_logger.info(f"开始为工作空间 {workspace_id} 批量生成缓存")
|
||||
|
||||
result = await user_memory_service.generate_cache_for_workspace(db, workspace_id)
|
||||
result = await user_memory_service.generate_cache_for_workspace(db, workspace_id, language=language)
|
||||
|
||||
# 记录统计信息
|
||||
api_logger.info(
|
||||
@@ -385,10 +391,13 @@ async def update_end_user_profile(
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
|
||||
|
||||
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
# 使用集中化的语言校验
|
||||
language = get_language_from_header(language_type)
|
||||
|
||||
workspace_id=current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
@@ -398,7 +407,7 @@ async def memory_space_timeline_of_shared_memories(id: str, label: str,language_
|
||||
else:
|
||||
model_id = None
|
||||
MemoryEntity = MemoryEntityService(id, label)
|
||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language_type)
|
||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
|
||||
|
||||
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
||||
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
||||
|
||||
@@ -1,610 +0,0 @@
|
||||
"""
|
||||
工作流 API 控制器
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Path, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
|
||||
from app.models.user_model import User
|
||||
from app.models.app_model import App
|
||||
from app.services.workflow_service import WorkflowService, get_workflow_service
|
||||
from app.schemas.workflow_schema import (
|
||||
WorkflowConfigCreate,
|
||||
WorkflowConfigUpdate,
|
||||
WorkflowConfig,
|
||||
WorkflowValidationResponse,
|
||||
WorkflowExecution,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowExecutionRequest,
|
||||
WorkflowExecutionResponse
|
||||
)
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.error_codes import BizCode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/apps", tags=["workflow"])
|
||||
|
||||
|
||||
# ==================== 工作流配置管理 ====================
|
||||
|
||||
@router.post("/{app_id}/workflow")
|
||||
@cur_workspace_access_guard()
|
||||
async def create_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
config: WorkflowConfigCreate,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""创建工作流配置
|
||||
|
||||
创建或更新应用的工作流配置。配置会进行基础验证,但允许保存不完整的配置(草稿)。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 验证应用类型
|
||||
if app.type != "workflow":
|
||||
return fail(
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
msg=f"应用类型必须为 workflow,当前为 {app.type}"
|
||||
)
|
||||
|
||||
# 创建工作流配置
|
||||
workflow_config = service.create_workflow_config(
|
||||
app_id=app_id,
|
||||
nodes=[node.model_dump() for node in config.nodes],
|
||||
edges=[edge.model_dump() for edge in config.edges],
|
||||
variables=[var.model_dump() for var in config.variables],
|
||||
execution_config=config.execution_config.model_dump(),
|
||||
triggers=[trigger.model_dump() for trigger in config.triggers],
|
||||
validate=True # 进行基础验证
|
||||
)
|
||||
|
||||
return success(
|
||||
data=WorkflowConfig.model_validate(workflow_config),
|
||||
msg="工作流配置创建成功"
|
||||
)
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"创建工作流配置失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"创建工作流配置异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"创建工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# @router.get("/{app_id}/workflow")
|
||||
# async def get_workflow_config(
|
||||
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
# db: Annotated[Session, Depends(get_db)],
|
||||
# current_user: Annotated[User, Depends(get_current_user)]
|
||||
#
|
||||
# ):
|
||||
# """获取工作流配置
|
||||
#
|
||||
# 获取应用的工作流配置详情。
|
||||
# """
|
||||
# try:
|
||||
# # 验证应用是否存在且属于当前工作空间
|
||||
# app = db.query(App).filter(
|
||||
# App.id == app_id,
|
||||
# App.workspace_id == current_user.current_workspace_id,
|
||||
# App.is_active == True
|
||||
# ).first()
|
||||
#
|
||||
# if not app:
|
||||
# return fail(
|
||||
# code=BizCode.NOT_FOUND,
|
||||
# msg="应用不存在或无权访问"
|
||||
# )
|
||||
#
|
||||
# # 获取工作流配置
|
||||
# service = WorkflowService(db)
|
||||
# workflow_config = service.get_workflow_config(app_id)
|
||||
#
|
||||
# if not workflow_config:
|
||||
# return fail(
|
||||
# code=BizCode.NOT_FOUND,
|
||||
# msg="工作流配置不存在"
|
||||
# )
|
||||
#
|
||||
# return success(
|
||||
# data=WorkflowConfig.model_validate(workflow_config)
|
||||
# )
|
||||
#
|
||||
# except Exception as e:
|
||||
# logger.error(f"获取工作流配置异常: {e}", exc_info=True)
|
||||
# return fail(
|
||||
# code=BizCode.INTERNAL_ERROR,
|
||||
# msg=f"获取工作流配置失败: {str(e)}"
|
||||
# )
|
||||
|
||||
|
||||
# @router.put("/{app_id}/workflow")
|
||||
# async def update_workflow_config(
|
||||
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
# config: WorkflowConfigUpdate,
|
||||
# db: Annotated[Session, Depends(get_db)],
|
||||
# current_user: Annotated[User, Depends(get_current_user)],
|
||||
# service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
# ):
|
||||
# """更新工作流配置
|
||||
|
||||
# 更新应用的工作流配置。可以部分更新,未提供的字段保持不变。
|
||||
# """
|
||||
# try:
|
||||
# # 验证应用是否存在且属于当前工作空间
|
||||
# app = db.query(App).filter(
|
||||
# App.id == app_id,
|
||||
# App.workspace_id == current_user.current_workspace_id,
|
||||
# App.is_active == True
|
||||
# ).first()
|
||||
|
||||
# if not app:
|
||||
# return fail(
|
||||
# code=BizCode.NOT_FOUND,
|
||||
# msg="应用不存在或无权访问"
|
||||
# )
|
||||
|
||||
# # 更新工作流配置
|
||||
# workflow_config = service.update_workflow_config(
|
||||
# app_id=app_id,
|
||||
# nodes=[node.model_dump() for node in config.nodes] if config.nodes else None,
|
||||
# edges=[edge.model_dump() for edge in config.edges] if config.edges else None,
|
||||
# variables=[var.model_dump() for var in config.variables] if config.variables else None,
|
||||
# execution_config=config.execution_config.model_dump() if config.execution_config else None,
|
||||
# triggers=[trigger.model_dump() for trigger in config.triggers] if config.triggers else None,
|
||||
# validate=True
|
||||
# )
|
||||
|
||||
# return success(
|
||||
# data=WorkflowConfig.model_validate(workflow_config),
|
||||
# msg="工作流配置更新成功"
|
||||
# )
|
||||
|
||||
# except BusinessException as e:
|
||||
# logger.warning(f"更新工作流配置失败: {e.message}")
|
||||
# return fail(code=e.error_code, msg=e.message)
|
||||
# except Exception as e:
|
||||
# logger.error(f"更新工作流配置异常: {e}", exc_info=True)
|
||||
# return fail(
|
||||
# code=BizCode.INTERNAL_ERROR,
|
||||
# msg=f"更新工作流配置失败: {str(e)}"
|
||||
# )
|
||||
|
||||
|
||||
@router.delete("/{app_id}/workflow")
|
||||
async def delete_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""删除工作流配置
|
||||
|
||||
删除应用的工作流配置。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 删除工作流配置
|
||||
deleted = service.delete_workflow_config(app_id)
|
||||
|
||||
if not deleted:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="工作流配置不存在"
|
||||
)
|
||||
|
||||
return success(msg="工作流配置删除成功")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除工作流配置异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"删除工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{app_id}/workflow/validate")
|
||||
async def validate_workflow_config(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
|
||||
):
|
||||
"""验证工作流配置
|
||||
|
||||
验证工作流配置是否有效。可以选择是否进行发布级别的严格验证。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 验证工作流配置
|
||||
|
||||
if for_publish:
|
||||
is_valid, errors = service.validate_workflow_config_for_publish(app_id)
|
||||
else:
|
||||
workflow_config = service.get_workflow_config(app_id)
|
||||
if not workflow_config:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="工作流配置不存在"
|
||||
)
|
||||
|
||||
from app.core.workflow.validator import validate_workflow_config as validate_config
|
||||
config_dict = {
|
||||
"nodes": workflow_config.nodes,
|
||||
"edges": workflow_config.edges,
|
||||
"variables": workflow_config.variables,
|
||||
"execution_config": workflow_config.execution_config,
|
||||
"triggers": workflow_config.triggers
|
||||
}
|
||||
is_valid, errors = validate_config(config_dict, for_publish=False)
|
||||
|
||||
return success(
|
||||
data=WorkflowValidationResponse(
|
||||
is_valid=is_valid,
|
||||
errors=errors,
|
||||
warnings=[]
|
||||
)
|
||||
)
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"验证工作流配置失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"验证工作流配置异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"验证工作流配置失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# ==================== 工作流执行管理 ====================
|
||||
|
||||
@router.get("/{app_id}/workflow/executions")
|
||||
async def get_workflow_executions(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)],
|
||||
limit: Annotated[int, Query(ge=1, le=100)] = 50,
|
||||
offset: Annotated[int, Query(ge=0)] = 0
|
||||
):
|
||||
"""获取工作流执行记录列表
|
||||
|
||||
获取应用的工作流执行历史记录。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 获取执行记录
|
||||
executions = service.get_executions_by_app(app_id, limit, offset)
|
||||
|
||||
# 获取统计信息
|
||||
statistics = service.get_execution_statistics(app_id)
|
||||
|
||||
return success(
|
||||
data={
|
||||
"executions": [WorkflowExecution.model_validate(e) for e in executions],
|
||||
"statistics": statistics,
|
||||
"pagination": {
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"total": statistics["total"]
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工作流执行记录异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"获取工作流执行记录失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/workflow/executions/{execution_id}")
|
||||
async def get_workflow_execution(
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""获取工作流执行详情
|
||||
|
||||
获取单个工作流执行的详细信息,包括所有节点的执行记录。
|
||||
"""
|
||||
try:
|
||||
# 获取执行记录
|
||||
execution = service.get_execution(execution_id)
|
||||
|
||||
if not execution:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="执行记录不存在"
|
||||
)
|
||||
|
||||
# 验证应用是否属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == execution.app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="无权访问该执行记录"
|
||||
)
|
||||
|
||||
# 获取节点执行记录
|
||||
node_executions = service.node_execution_repo.get_by_execution_id(execution.id)
|
||||
|
||||
return success(
|
||||
data={
|
||||
"execution": WorkflowExecution.model_validate(execution),
|
||||
"node_executions": [
|
||||
WorkflowNodeExecution.model_validate(ne) for ne in node_executions
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取工作流执行详情异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"获取工作流执行详情失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# ==================== 工作流执行 ====================
|
||||
@router.post("/{app_id}/workflow/run")
|
||||
async def run_workflow(
|
||||
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
|
||||
request: WorkflowExecutionRequest,
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""执行工作流
|
||||
|
||||
执行工作流并返回结果。支持流式和非流式两种模式。
|
||||
|
||||
**非流式模式**:等待工作流执行完成后返回完整结果。
|
||||
|
||||
**流式模式**:实时返回执行过程中的事件(节点开始、节点完成、工作流完成等)。
|
||||
"""
|
||||
try:
|
||||
# 验证应用是否存在且属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="应用不存在或无权访问"
|
||||
)
|
||||
|
||||
# 验证应用类型
|
||||
if app.type != "workflow":
|
||||
return fail(
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
msg=f"应用类型必须为 workflow,当前为 {app.type}"
|
||||
)
|
||||
|
||||
# 准备输入数据
|
||||
input_data = {
|
||||
"message": request.message or "",
|
||||
"variables": request.variables
|
||||
}
|
||||
|
||||
# 执行工作流
|
||||
|
||||
if request.stream:
|
||||
# 流式执行
|
||||
from fastapi.responses import StreamingResponse
|
||||
import json
|
||||
|
||||
async def event_generator():
|
||||
"""生成 SSE 事件
|
||||
|
||||
SSE 格式:
|
||||
event: <event_type>
|
||||
data: <json_data>
|
||||
|
||||
支持的事件类型:
|
||||
- workflow_start: 工作流开始
|
||||
- workflow_end: 工作流结束
|
||||
- node_start: 节点开始执行
|
||||
- node_end: 节点执行完成
|
||||
- node_chunk: 中间节点的流式输出
|
||||
- message: 最终消息的流式输出(End 节点及其相邻节点)
|
||||
"""
|
||||
try:
|
||||
async for event in await service.run_workflow(
|
||||
app_id=app_id,
|
||||
input_data=input_data,
|
||||
triggered_by=current_user.id,
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=True
|
||||
):
|
||||
# 提取事件类型和数据
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
# 转换为标准 SSE 格式(字符串)
|
||||
# event: <type>
|
||||
# data: <json>
|
||||
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
|
||||
yield sse_message
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式执行异常: {e}", exc_info=True)
|
||||
# 发送错误事件
|
||||
sse_error = f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
|
||||
yield sse_error
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no" # 禁用 nginx 缓冲
|
||||
}
|
||||
)
|
||||
else:
|
||||
# 非流式执行
|
||||
result = await service.run_workflow(
|
||||
app_id=app_id,
|
||||
input_data=input_data,
|
||||
triggered_by=current_user.id,
|
||||
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
|
||||
stream=False
|
||||
)
|
||||
|
||||
return success(
|
||||
data=WorkflowExecutionResponse(
|
||||
execution_id=result["execution_id"],
|
||||
status=result["status"],
|
||||
output=result.get("output"),
|
||||
output_data=result.get("output_data"),
|
||||
error_message=result.get("error_message"),
|
||||
elapsed_time=result.get("elapsed_time"),
|
||||
token_usage=result.get("token_usage")
|
||||
),
|
||||
msg="工作流执行完成"
|
||||
)
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"执行工作流失败: {e.message}")
|
||||
return fail(code=e.error_code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"执行工作流异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"执行工作流失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/workflow/executions/{execution_id}/cancel")
|
||||
async def cancel_workflow_execution(
|
||||
execution_id: Annotated[str, Path(description="执行 ID")],
|
||||
db: Annotated[Session, Depends(get_db)],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
service: Annotated[WorkflowService, Depends(get_workflow_service)]
|
||||
):
|
||||
"""取消工作流执行
|
||||
|
||||
取消正在运行的工作流执行。
|
||||
|
||||
**注意**:当前版本仅更新状态为 cancelled,实际的执行取消功能待实现。
|
||||
"""
|
||||
try:
|
||||
# 获取执行记录
|
||||
execution = service.get_execution(execution_id)
|
||||
|
||||
if not execution:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="执行记录不存在"
|
||||
)
|
||||
|
||||
# 验证应用是否属于当前工作空间
|
||||
app = db.query(App).filter(
|
||||
App.id == execution.app_id,
|
||||
App.workspace_id == current_user.current_workspace_id,
|
||||
App.is_active.is_(True)
|
||||
).first()
|
||||
|
||||
if not app:
|
||||
return fail(
|
||||
code=BizCode.NOT_FOUND,
|
||||
msg="无权访问该执行记录"
|
||||
)
|
||||
|
||||
# 检查执行状态
|
||||
if execution.status not in ["pending", "running"]:
|
||||
return fail(
|
||||
code=BizCode.INVALID_PARAMETER,
|
||||
msg=f"无法取消状态为 {execution.status} 的执行"
|
||||
)
|
||||
|
||||
# 更新状态为 cancelled
|
||||
service.update_execution_status(execution_id, "cancelled")
|
||||
|
||||
return success(msg="工作流执行已取消")
|
||||
|
||||
except BusinessException as e:
|
||||
logger.warning(f"取消工作流执行失败: {e.message}")
|
||||
return fail(code=e.code, msg=e.message)
|
||||
except Exception as e:
|
||||
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
|
||||
return fail(
|
||||
code=BizCode.INTERNAL_ERROR,
|
||||
msg=f"取消工作流执行失败: {str(e)}"
|
||||
)
|
||||
162
api/app/core/agent/agent_middleware.py
Normal file
162
api/app/core/agent/agent_middleware.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""Agent Middleware - 动态技能过滤"""
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional
|
||||
from langchain_core.runnables import RunnablePassthrough
|
||||
|
||||
from app.services.skill_service import SkillService
|
||||
from app.repositories.skill_repository import SkillRepository
|
||||
|
||||
|
||||
class AgentMiddleware:
|
||||
"""Agent 中间件 - 用于动态过滤和加载技能"""
|
||||
|
||||
def __init__(self, skills: Optional[dict] = None):
|
||||
"""
|
||||
初始化中间件
|
||||
|
||||
Args:
|
||||
skills: 技能配置字典 {"enabled": bool, "all_skills": bool, "skill_ids": [...]}
|
||||
"""
|
||||
self.skills = skills or {}
|
||||
self.enabled = self.skills.get('enabled', False)
|
||||
self.all_skills = self.skills.get('all_skills', False)
|
||||
self.skill_ids = self.skills.get('skill_ids', [])
|
||||
|
||||
@staticmethod
|
||||
def filter_tools(
|
||||
tools: List,
|
||||
message: str = "",
|
||||
skill_configs: Dict[str, Any] = None,
|
||||
tool_to_skill_map: Dict[str, str] = None
|
||||
) -> tuple[List, List[str]]:
|
||||
"""
|
||||
根据消息内容和技能配置动态过滤工具
|
||||
|
||||
Args:
|
||||
tools: 所有可用工具列表
|
||||
message: 用户消息(可用于智能过滤)
|
||||
skill_configs: 技能配置字典 {skill_id: {"keywords": [...], "enabled": True, "prompt": "..."}}
|
||||
tool_to_skill_map: 工具到技能的映射 {tool_name: skill_id}
|
||||
|
||||
Returns:
|
||||
(过滤后的工具列表, 激活的技能ID列表)
|
||||
"""
|
||||
if not tools:
|
||||
return [], []
|
||||
|
||||
# 如果没有技能配置,返回所有工具
|
||||
if not skill_configs:
|
||||
return tools, []
|
||||
|
||||
# 基于关键词匹配激活技能
|
||||
activated_skill_ids = []
|
||||
message_lower = message.lower()
|
||||
|
||||
for skill_id, config in skill_configs.items():
|
||||
if not config.get('enabled', True):
|
||||
continue
|
||||
|
||||
keywords = config.get('keywords', [])
|
||||
# 如果没有关键词限制,或消息包含关键词,则激活该技能
|
||||
if not keywords or any(kw.lower() in message_lower for kw in keywords):
|
||||
activated_skill_ids.append(skill_id)
|
||||
|
||||
# 如果没有工具映射关系,返回所有工具
|
||||
if not tool_to_skill_map:
|
||||
return tools, activated_skill_ids
|
||||
|
||||
# 根据激活的技能过滤工具
|
||||
filtered_tools = []
|
||||
for tool in tools:
|
||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||
# 如果工具不属于任何skill(base_tools),或者工具所属的skill被激活,则保留
|
||||
if tool_name not in tool_to_skill_map or tool_to_skill_map[tool_name] in activated_skill_ids:
|
||||
filtered_tools.append(tool)
|
||||
|
||||
return filtered_tools, activated_skill_ids
|
||||
|
||||
def load_skill_tools(self, db, tenant_id: uuid.UUID, base_tools: List = None) -> tuple[List, Dict[str, Any], Dict[str, str]]:
|
||||
"""
|
||||
加载技能关联的工具
|
||||
|
||||
Args:
|
||||
db: 数据库会话
|
||||
tenant_id: 租户id
|
||||
base_tools: 基础工具列表
|
||||
|
||||
Returns:
|
||||
(工具列表, 技能配置字典, 工具到技能的映射 {tool_name: skill_id})
|
||||
"""
|
||||
|
||||
tools_dict = {}
|
||||
tool_to_skill_map = {} # 工具名称到技能ID的映射
|
||||
|
||||
if base_tools:
|
||||
for tool in base_tools:
|
||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||
tools_dict[tool_name] = tool
|
||||
# base_tools 不属于任何 skill,不加入映射
|
||||
|
||||
skill_configs = {}
|
||||
skill_ids_to_load = []
|
||||
|
||||
# 如果启用技能且 all_skills 为 True,加载租户下所有激活的技能
|
||||
if self.enabled and self.all_skills:
|
||||
skills, _ = SkillRepository.list_skills(db, tenant_id, is_active=True, page=1, pagesize=1000)
|
||||
skill_ids_to_load = [str(skill.id) for skill in skills]
|
||||
elif self.enabled and self.skill_ids:
|
||||
skill_ids_to_load = self.skill_ids
|
||||
|
||||
if skill_ids_to_load:
|
||||
for skill_id in skill_ids_to_load:
|
||||
try:
|
||||
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
|
||||
if skill and skill.is_active:
|
||||
# 保存技能配置(包含prompt)
|
||||
config = skill.config or {}
|
||||
config['prompt'] = skill.prompt
|
||||
config['name'] = skill.name
|
||||
skill_configs[skill_id] = config
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 加载技能工具并获取映射关系
|
||||
skill_tools, skill_tool_map = SkillService.load_skill_tools(db, skill_ids_to_load, tenant_id)
|
||||
|
||||
# 只添加不冲突的 skill_tools
|
||||
for tool in skill_tools:
|
||||
tool_name = getattr(tool, 'name', str(id(tool)))
|
||||
if tool_name not in tools_dict:
|
||||
tools_dict[tool_name] = tool
|
||||
# 复制映射关系
|
||||
if tool_name in skill_tool_map:
|
||||
tool_to_skill_map[tool_name] = skill_tool_map[tool_name]
|
||||
|
||||
return list(tools_dict.values()), skill_configs, tool_to_skill_map
|
||||
|
||||
@staticmethod
|
||||
def get_active_prompts(activated_skill_ids: List[str], skill_configs: Dict[str, Any]) -> str:
|
||||
"""
|
||||
根据激活的技能ID获取对应的提示词
|
||||
|
||||
Args:
|
||||
activated_skill_ids: 被激活的技能ID列表
|
||||
skill_configs: 技能配置字典
|
||||
|
||||
Returns:
|
||||
合并后的提示词
|
||||
"""
|
||||
prompts = []
|
||||
for skill_id in activated_skill_ids:
|
||||
config = skill_configs.get(skill_id, {})
|
||||
prompt = config.get('prompt')
|
||||
name = config.get('name', 'Skill')
|
||||
if prompt:
|
||||
prompts.append(f"# {name}\n{prompt}")
|
||||
|
||||
return "\n\n".join(prompts) if prompts else ""
|
||||
|
||||
@staticmethod
|
||||
def create_runnable():
|
||||
"""创建可运行的中间件"""
|
||||
return RunnablePassthrough()
|
||||
@@ -7,29 +7,21 @@ LangChain Agent 封装
|
||||
- 支持流式输出
|
||||
- 使用 RedBearLLM 支持多提供商
|
||||
"""
|
||||
import os
|
||||
|
||||
import time
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
|
||||
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.models.models_model import ModelType
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
@@ -45,7 +37,9 @@ class LangChainAgent:
|
||||
max_tokens: int = 2000,
|
||||
system_prompt: Optional[str] = None,
|
||||
tools: Optional[Sequence[BaseTool]] = None,
|
||||
streaming: bool = False
|
||||
streaming: bool = False,
|
||||
max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算)
|
||||
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
|
||||
):
|
||||
"""初始化 LangChain Agent
|
||||
|
||||
@@ -58,13 +52,36 @@ class LangChainAgent:
|
||||
max_tokens: 最大 token 数
|
||||
system_prompt: 系统提示词
|
||||
tools: 工具列表(可选,框架自动走 ReAct 循环)
|
||||
streaming: 是否启用流式输出(默认 True)
|
||||
streaming: 是否启用流式输出
|
||||
max_iterations: 最大迭代次数(None 表示自动计算:基础 5 次 + 每个工具 2 次)
|
||||
max_tool_consecutive_calls: 单个工具最大连续调用次数(默认 3 次)
|
||||
"""
|
||||
self.model_name = model_name
|
||||
self.provider = provider
|
||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||
self.tools = tools or []
|
||||
self.streaming = streaming
|
||||
self.max_tool_consecutive_calls = max_tool_consecutive_calls
|
||||
|
||||
# 工具调用计数器:记录每个工具的连续调用次数
|
||||
self.tool_call_counter: Dict[str, int] = {}
|
||||
self.last_tool_called: Optional[str] = None
|
||||
|
||||
# 根据工具数量动态调整最大迭代次数
|
||||
# 基础值 + 每个工具额外的调用机会
|
||||
if max_iterations is None:
|
||||
# 自动计算:基础 5 次 + 每个工具 2 次额外机会
|
||||
self.max_iterations = 5 + len(self.tools) * 2
|
||||
else:
|
||||
self.max_iterations = max_iterations
|
||||
|
||||
self.system_prompt = system_prompt or "你是一个专业的AI助手"
|
||||
|
||||
logger.debug(
|
||||
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
|
||||
f"tool_count={len(self.tools)}, "
|
||||
f"max_tool_consecutive_calls={self.max_tool_consecutive_calls}, "
|
||||
f"auto_calculated={max_iterations is None}"
|
||||
)
|
||||
|
||||
# 创建 RedBearLLM(支持多提供商)
|
||||
model_config = RedBearModelConfig(
|
||||
@@ -88,11 +105,14 @@ class LangChainAgent:
|
||||
if streaming and hasattr(self._underlying_llm, 'streaming'):
|
||||
self._underlying_llm.streaming = True
|
||||
|
||||
# 包装工具以跟踪连续调用次数
|
||||
wrapped_tools = self._wrap_tools_with_tracking(self.tools) if self.tools else None
|
||||
|
||||
# 使用 create_agent 创建 agent graph(LangChain 1.x 标准方式)
|
||||
# 无论是否有工具,都使用 agent 统一处理
|
||||
self.agent = create_agent(
|
||||
model=self.llm,
|
||||
tools=self.tools if self.tools else None,
|
||||
tools=wrapped_tools,
|
||||
system_prompt=self.system_prompt
|
||||
)
|
||||
|
||||
@@ -104,17 +124,91 @@ class LangChainAgent:
|
||||
"has_api_base": bool(api_base),
|
||||
"temperature": temperature,
|
||||
"streaming": streaming,
|
||||
"max_iterations": self.max_iterations,
|
||||
"max_tool_consecutive_calls": self.max_tool_consecutive_calls,
|
||||
"tool_count": len(self.tools),
|
||||
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
|
||||
"tool_count": len(self.tools)
|
||||
# "tool_count": len(self.tools)
|
||||
}
|
||||
)
|
||||
|
||||
def _wrap_tools_with_tracking(self, tools: Sequence[BaseTool]) -> List[BaseTool]:
|
||||
"""包装工具以跟踪连续调用次数
|
||||
|
||||
Args:
|
||||
tools: 原始工具列表
|
||||
|
||||
Returns:
|
||||
List[BaseTool]: 包装后的工具列表
|
||||
"""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from functools import wraps
|
||||
|
||||
wrapped_tools = []
|
||||
|
||||
for original_tool in tools:
|
||||
tool_name = original_tool.name
|
||||
original_func = original_tool.func if hasattr(original_tool, 'func') else None
|
||||
|
||||
if not original_func:
|
||||
# 如果无法获取原始函数,直接使用原工具
|
||||
wrapped_tools.append(original_tool)
|
||||
continue
|
||||
|
||||
# 创建包装函数
|
||||
def make_wrapped_func(tool_name, original_func):
|
||||
"""创建包装函数的工厂函数,避免闭包问题"""
|
||||
@wraps(original_func)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
"""包装后的工具函数,跟踪连续调用次数"""
|
||||
# 检查是否是连续调用同一个工具
|
||||
if self.last_tool_called == tool_name:
|
||||
self.tool_call_counter[tool_name] = self.tool_call_counter.get(tool_name, 0) + 1
|
||||
else:
|
||||
# 切换到新工具,重置计数器
|
||||
self.tool_call_counter[tool_name] = 1
|
||||
self.last_tool_called = tool_name
|
||||
|
||||
current_count = self.tool_call_counter[tool_name]
|
||||
|
||||
logger.debug(
|
||||
f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}"
|
||||
)
|
||||
|
||||
# 检查是否超过最大连续调用次数
|
||||
if current_count > self.max_tool_consecutive_calls:
|
||||
logger.warning(
|
||||
f"工具 '{tool_name}' 连续调用次数已达上限 ({self.max_tool_consecutive_calls}),"
|
||||
f"返回提示信息"
|
||||
)
|
||||
return (
|
||||
f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次,"
|
||||
f"未找到有效结果。请尝试其他方法或直接回答用户的问题。"
|
||||
)
|
||||
|
||||
# 调用原始工具函数
|
||||
return original_func(*args, **kwargs)
|
||||
|
||||
return wrapped_func
|
||||
|
||||
# 使用 StructuredTool 创建新工具
|
||||
wrapped_tool = StructuredTool(
|
||||
name=original_tool.name,
|
||||
description=original_tool.description,
|
||||
func=make_wrapped_func(tool_name, original_func),
|
||||
args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None
|
||||
)
|
||||
|
||||
wrapped_tools.append(wrapped_tool)
|
||||
|
||||
return wrapped_tools
|
||||
|
||||
def _prepare_messages(
|
||||
self,
|
||||
message: str,
|
||||
history: Optional[List[Dict[str, str]]] = None,
|
||||
context: Optional[str] = None
|
||||
context: Optional[str] = None,
|
||||
files: Optional[List[Dict[str, Any]]] = None
|
||||
) -> List[BaseMessage]:
|
||||
"""准备消息列表
|
||||
|
||||
@@ -122,6 +216,7 @@ class LangChainAgent:
|
||||
message: 用户消息
|
||||
history: 历史消息列表
|
||||
context: 上下文信息
|
||||
files: 多模态文件内容列表(已处理)
|
||||
|
||||
Returns:
|
||||
List[BaseMessage]: 消息列表
|
||||
@@ -144,107 +239,47 @@ class LangChainAgent:
|
||||
if context:
|
||||
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
|
||||
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
# 构建用户消息(支持多模态)
|
||||
if files and len(files) > 0:
|
||||
content_parts = self._build_multimodal_content(user_content, files)
|
||||
messages.append(HumanMessage(content=content_parts))
|
||||
else:
|
||||
# 纯文本消息
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
|
||||
return messages
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# async def term_memory_save(self,messages,end_user_end,aimessages):
|
||||
# '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j'''
|
||||
# end_user_end=f"Term_{end_user_end}"
|
||||
# print(messages)
|
||||
# print(aimessages)
|
||||
# session_id = store.save_session(
|
||||
# userid=end_user_end,
|
||||
# messages=messages,
|
||||
# apply_id=end_user_end,
|
||||
# end_user_id=end_user_end,
|
||||
# aimessages=aimessages
|
||||
# )
|
||||
# store.delete_duplicate_sessions()
|
||||
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
||||
# return session_id
|
||||
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# async def term_memory_redis_read(self,end_user_end):
|
||||
# end_user_end = f"Term_{end_user_end}"
|
||||
# history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
||||
# # logger.info(f'Redis_Agent:{end_user_end};{history}')
|
||||
# messagss_list=[]
|
||||
# retrieved_content=[]
|
||||
# for messages in history:
|
||||
# query = messages.get("Query")
|
||||
# aimessages = messages.get("Answer")
|
||||
# messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
||||
# retrieved_content.append({query: aimessages})
|
||||
# return messagss_list,retrieved_content
|
||||
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
|
||||
|
||||
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
写入记忆(支持结构化消息)
|
||||
|
||||
构建多模态消息内容
|
||||
|
||||
Args:
|
||||
storage_type: 存储类型 (neo4j/rag)
|
||||
end_user_id: 终端用户ID
|
||||
user_message: 用户消息内容
|
||||
ai_message: AI 回复内容
|
||||
user_rag_memory_id: RAG 记忆ID
|
||||
actual_end_user_id: 实际用户ID
|
||||
actual_config_id: 配置ID
|
||||
|
||||
逻辑说明:
|
||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||
- Neo4j 模式:使用结构化消息列表
|
||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||
text: 文本内容
|
||||
files: 文件列表(已由 MultimodalService 处理为对应 provider 的格式)
|
||||
|
||||
Returns:
|
||||
List[Dict]: 消息内容列表
|
||||
"""
|
||||
# 根据 provider 使用不同的文本格式
|
||||
if self.provider.lower() in ["bedrock", "anthropic"]:
|
||||
# Anthropic/Bedrock: {"type": "text", "text": "..."}
|
||||
content_parts = [{"type": "text", "text": text}]
|
||||
else:
|
||||
# 通义千问等: {"text": "..."}
|
||||
content_parts = [{"text": text}]
|
||||
|
||||
# 添加文件内容
|
||||
# MultimodalService 已经根据 provider 返回了正确格式,直接使用
|
||||
content_parts.extend(files)
|
||||
|
||||
logger.debug(
|
||||
f"构建多模态消息: provider={self.provider}, "
|
||||
f"parts={len(content_parts)}, "
|
||||
f"files={len(files)}"
|
||||
)
|
||||
|
||||
return content_parts
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
actual_config_id=resolve_config_id(actual_config_id, db)
|
||||
|
||||
if storage_type == "rag":
|
||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
else:
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
structured_messages = []
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
if user_message:
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
if ai_message:
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
|
||||
# 调用 Celery 任务,传递结构化消息列表
|
||||
# 数据流:
|
||||
# 1. structured_messages 传递给 write_message_task
|
||||
# 2. write_message_task 调用 memory_agent_service.write_memory
|
||||
# 3. write_memory 调用 write_tools.write,传递 messages 参数
|
||||
# 4. write_tools.write 调用 get_chunked_dialogs,传递 messages 参数
|
||||
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk,设置 speaker 字段
|
||||
# 6. 每个 Chunk 保存到 Neo4j,包含 speaker 字段
|
||||
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # end_user_id: 用户ID
|
||||
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
actual_config_id, # config_id: 配置ID
|
||||
storage_type, # storage_type: "neo4j"
|
||||
user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||
)
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
finally:
|
||||
db.close()
|
||||
async def chat(
|
||||
self,
|
||||
message: str,
|
||||
@@ -254,7 +289,8 @@ class LangChainAgent:
|
||||
config_id: Optional[str] = None, # 添加这个参数
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
) -> Dict[str, Any]:
|
||||
"""执行对话
|
||||
|
||||
@@ -288,33 +324,9 @@ class LangChainAgent:
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码
|
||||
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
# history_term_memory = history_term_memory_result[0]
|
||||
# db_for_memory = next(get_db())
|
||||
# if memory_flag:
|
||||
# if len(history_term_memory)>=4 and storage_type != "rag":
|
||||
# history_term_memory = ';'.join(history_term_memory)
|
||||
# retrieved_content = history_term_memory_result[1]
|
||||
# print(retrieved_content)
|
||||
# # 为长期记忆操作获取新的数据库连接
|
||||
# try:
|
||||
# repo = LongTermMemoryRepository(db_for_memory)
|
||||
# repo.upsert(end_user_id, retrieved_content)
|
||||
# logger.info(
|
||||
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to write to LongTermMemory: {e}")
|
||||
# raise
|
||||
# finally:
|
||||
# db_for_memory.close()
|
||||
|
||||
# # 长期记忆写入(
|
||||
# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||
# # 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表
|
||||
messages = self._prepare_messages(message, history, context)
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
|
||||
logger.debug(
|
||||
"准备调用 LangChain Agent",
|
||||
@@ -322,27 +334,85 @@ class LangChainAgent:
|
||||
"has_context": bool(context),
|
||||
"has_history": bool(history),
|
||||
"has_tools": bool(self.tools),
|
||||
"message_count": len(messages)
|
||||
"has_files": bool(files),
|
||||
"message_count": len(messages),
|
||||
"max_iterations": self.max_iterations
|
||||
}
|
||||
)
|
||||
|
||||
# 统一使用 agent.invoke 调用
|
||||
result = await self.agent.ainvoke({"messages": messages})
|
||||
# 通过 recursion_limit 限制最大迭代次数,防止工具调用死循环
|
||||
try:
|
||||
result = await self.agent.ainvoke(
|
||||
{"messages": messages},
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
)
|
||||
except RecursionError as e:
|
||||
logger.warning(
|
||||
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),可能存在工具调用循环",
|
||||
extra={"error": str(e)}
|
||||
)
|
||||
# 返回一个友好的错误提示
|
||||
return {
|
||||
"content": f"抱歉,我在处理您的请求时遇到了问题。已达到最大处理步骤限制({self.max_iterations}次)。请尝试简化您的问题或稍后再试。",
|
||||
"model": self.model_name,
|
||||
"elapsed_time": time.time() - start_time,
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
}
|
||||
}
|
||||
|
||||
# 获取最后的 AI 消息
|
||||
output_messages = result.get("messages", [])
|
||||
content = ""
|
||||
|
||||
logger.debug(f"输出消息数量: {len(output_messages)}")
|
||||
total_tokens = 0
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
content = msg.content
|
||||
logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}")
|
||||
logger.debug(f"AI 消息内容: {msg.content}")
|
||||
|
||||
# 处理多模态响应:content 可能是字符串或列表
|
||||
if isinstance(msg.content, str):
|
||||
content = msg.content
|
||||
logger.debug(f"提取字符串内容,长度: {len(content)}")
|
||||
elif isinstance(msg.content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
logger.debug(f"多模态响应,列表长度: {len(msg.content)}")
|
||||
text_parts = []
|
||||
for item in msg.content:
|
||||
logger.debug(f"处理项: {item}")
|
||||
if isinstance(item, dict):
|
||||
# 通义千问格式: {"text": "..."}
|
||||
if "text" in item:
|
||||
text = item.get("text", "")
|
||||
text_parts.append(text)
|
||||
logger.debug(f"提取文本: {text[:100]}...")
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
text_parts.append(text)
|
||||
logger.debug(f"提取文本: {text[:100]}...")
|
||||
elif isinstance(item, str):
|
||||
text_parts.append(item)
|
||||
logger.debug(f"提取字符串: {item[:100]}...")
|
||||
content = "".join(text_parts)
|
||||
logger.debug(f"合并后内容长度: {len(content)}")
|
||||
else:
|
||||
content = str(msg.content)
|
||||
logger.debug(f"转换为字符串: {content[:100]}...")
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
|
||||
break
|
||||
|
||||
logger.info(f"最终提取的内容长度: {len(content)}")
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# await self.term_memory_save(message_chat, end_user_id, content)
|
||||
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id)
|
||||
response = {
|
||||
"content": content,
|
||||
"model": self.model_name,
|
||||
@@ -350,7 +420,7 @@ class LangChainAgent:
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0
|
||||
"total_tokens": total_tokens
|
||||
}
|
||||
}
|
||||
|
||||
@@ -377,7 +447,8 @@ class LangChainAgent:
|
||||
config_id: Optional[str] = None,
|
||||
storage_type:Optional[str] = None,
|
||||
user_rag_memory_id:Optional[str] = None,
|
||||
memory_flag: Optional[bool] = True
|
||||
memory_flag: Optional[bool] = True,
|
||||
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""执行流式对话
|
||||
|
||||
@@ -410,33 +481,15 @@ class LangChainAgent:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
# # TODO 乐力齐
|
||||
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
# history_term_memory = history_term_memory_result[0]
|
||||
# if memory_flag:
|
||||
# if len(history_term_memory) >= 4 and storage_type != "rag":
|
||||
# history_term_memory = ';'.join(history_term_memory)
|
||||
# retrieved_content = history_term_memory_result[1]
|
||||
# db_for_memory = next(get_db())
|
||||
# try:
|
||||
# repo = LongTermMemoryRepository(db_for_memory)
|
||||
# repo.upsert(end_user_id, retrieved_content)
|
||||
# logger.info(
|
||||
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
# # 长期记忆写入
|
||||
# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id)
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to write to long term memory: {e}")
|
||||
# finally:
|
||||
# db_for_memory.close()
|
||||
|
||||
|
||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表
|
||||
messages = self._prepare_messages(message, history, context)
|
||||
# 准备消息列表(支持多模态)
|
||||
messages = self._prepare_messages(message, history, context, files)
|
||||
|
||||
logger.debug(
|
||||
f"准备流式调用,has_tools={bool(self.tools)}, message_count={len(messages)}"
|
||||
f"准备流式调用,has_tools={bool(self.tools)}, has_files={bool(files)}, message_count={len(messages)}"
|
||||
)
|
||||
|
||||
chunk_count = 0
|
||||
@@ -444,11 +497,12 @@ class LangChainAgent:
|
||||
|
||||
# 统一使用 agent 的 astream_events 实现流式输出
|
||||
logger.debug("使用 Agent astream_events 实现流式输出")
|
||||
full_content=''
|
||||
full_content = ''
|
||||
try:
|
||||
async for event in self.agent.astream_events(
|
||||
{"messages": messages},
|
||||
version="v2"
|
||||
version="v2",
|
||||
config={"recursion_limit": self.max_iterations}
|
||||
):
|
||||
chunk_count += 1
|
||||
kind = event.get("event")
|
||||
@@ -457,20 +511,70 @@ class LangChainAgent:
|
||||
if kind == "on_chat_model_stream":
|
||||
# LLM 流式输出
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
full_content+=chunk.content
|
||||
if chunk and hasattr(chunk, "content") and chunk.content:
|
||||
yield chunk.content
|
||||
yielded_content = True
|
||||
if chunk and hasattr(chunk, "content"):
|
||||
# 处理多模态响应:content 可能是字符串或列表
|
||||
chunk_content = chunk.content
|
||||
if isinstance(chunk_content, str) and chunk_content:
|
||||
full_content += chunk_content
|
||||
yield chunk_content
|
||||
yielded_content = True
|
||||
elif isinstance(chunk_content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
for item in chunk_content:
|
||||
if isinstance(item, dict):
|
||||
# 通义千问格式: {"text": "..."}
|
||||
if "text" in item:
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
elif isinstance(item, str):
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
|
||||
elif kind == "on_llm_stream":
|
||||
# 另一种 LLM 流式事件
|
||||
chunk = event.get("data", {}).get("chunk")
|
||||
if chunk:
|
||||
if hasattr(chunk, "content") and chunk.content:
|
||||
full_content+=chunk.content
|
||||
yield chunk.content
|
||||
yielded_content = True
|
||||
if hasattr(chunk, "content"):
|
||||
chunk_content = chunk.content
|
||||
if isinstance(chunk_content, str) and chunk_content:
|
||||
full_content += chunk_content
|
||||
yield chunk_content
|
||||
yielded_content = True
|
||||
elif isinstance(chunk_content, list):
|
||||
# 多模态响应:提取文本部分
|
||||
for item in chunk_content:
|
||||
if isinstance(item, dict):
|
||||
# 通义千问格式: {"text": "..."}
|
||||
if "text" in item:
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
# OpenAI 格式: {"type": "text", "text": "..."}
|
||||
elif item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
if text:
|
||||
full_content += text
|
||||
yield text
|
||||
yielded_content = True
|
||||
elif isinstance(item, str):
|
||||
full_content += item
|
||||
yield item
|
||||
yielded_content = True
|
||||
elif isinstance(chunk, str):
|
||||
full_content += chunk
|
||||
yield chunk
|
||||
yielded_content = True
|
||||
|
||||
@@ -481,12 +585,17 @@ class LangChainAgent:
|
||||
logger.debug(f"工具调用结束: {event.get('name')}")
|
||||
|
||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||
# 统计token消耗
|
||||
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
|
||||
for msg in reversed(output_messages):
|
||||
if isinstance(msg, AIMessage):
|
||||
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
|
||||
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
|
||||
0) if response_meta else 0
|
||||
yield total_tokens
|
||||
break
|
||||
if memory_flag:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# await self.term_memory_save(message_chat, end_user_id, full_content)
|
||||
|
||||
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
@@ -157,6 +157,11 @@ class Settings:
|
||||
if origin.strip()
|
||||
]
|
||||
|
||||
# Language Configuration
|
||||
# Supported values: "zh" (Chinese), "en" (English)
|
||||
# This controls the language used for memory summary titles and other generated content
|
||||
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
|
||||
|
||||
# Logging settings
|
||||
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
|
||||
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
@@ -210,9 +215,34 @@ class Settings:
|
||||
# official environment system version
|
||||
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1")
|
||||
|
||||
# model square loading
|
||||
LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true"
|
||||
|
||||
# workflow config
|
||||
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
|
||||
|
||||
# ========================================================================
|
||||
# General Ontology Type Configuration
|
||||
# ========================================================================
|
||||
# 通用本体文件路径列表(逗号分隔)
|
||||
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "General_purpose_entity.ttl")
|
||||
|
||||
# 是否启用通用本体类型功能
|
||||
ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true"
|
||||
|
||||
# Prompt 中最大类型数量
|
||||
MAX_ONTOLOGY_TYPES_IN_PROMPT: int = int(os.getenv("MAX_ONTOLOGY_TYPES_IN_PROMPT", "50"))
|
||||
|
||||
# 核心通用类型列表(逗号分隔)
|
||||
CORE_GENERAL_TYPES: str = os.getenv(
|
||||
"CORE_GENERAL_TYPES",
|
||||
"Person,Organization,Company,GovernmentAgency,Place,Location,City,Country,Building,"
|
||||
"Event,SportsEvent,SocialEvent,Work,Book,Film,Software,Concept,TopicalConcept,AcademicSubject"
|
||||
)
|
||||
|
||||
# 实验模式开关(允许通过 API 动态切换本体配置)
|
||||
ONTOLOGY_EXPERIMENT_MODE: bool = os.getenv("ONTOLOGY_EXPERIMENT_MODE", "true").lower() == "true"
|
||||
|
||||
def get_memory_output_path(self, filename: str = "") -> str:
|
||||
"""
|
||||
Get the full path for memory module output files.
|
||||
|
||||
@@ -46,6 +46,7 @@ class BizCode(IntEnum):
|
||||
RESOURCE_ALREADY_EXISTS = 5002
|
||||
VERSION_ALREADY_EXISTS = 5003
|
||||
STATE_CONFLICT = 5004
|
||||
RESOURCE_IN_USE = 5005
|
||||
|
||||
# 应用发布(6xxx)
|
||||
PUBLISH_FAILED = 6001
|
||||
@@ -125,6 +126,7 @@ HTTP_MAPPING = {
|
||||
BizCode.RESOURCE_ALREADY_EXISTS: 409,
|
||||
BizCode.VERSION_ALREADY_EXISTS: 409,
|
||||
BizCode.STATE_CONFLICT: 409,
|
||||
BizCode.RESOURCE_IN_USE: 409,
|
||||
BizCode.PUBLISH_FAILED: 500,
|
||||
BizCode.NO_DRAFT_TO_PUBLISH: 400,
|
||||
BizCode.ROLLBACK_TARGET_NOT_FOUND: 400,
|
||||
|
||||
82
api/app/core/language_utils.py
Normal file
82
api/app/core/language_utils.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""语言处理工具模块
|
||||
|
||||
本模块提供集中化的语言校验和处理功能,确保整个应用中语言参数的一致性。
|
||||
|
||||
Functions:
|
||||
validate_language: 校验语言参数,确保其为有效值
|
||||
get_language_from_header: 从请求头获取并校验语言参数
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from app.core.logging_config import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 支持的语言列表
|
||||
SUPPORTED_LANGUAGES = {"zh", "en"}
|
||||
|
||||
# 默认回退语言
|
||||
DEFAULT_LANGUAGE = "zh"
|
||||
|
||||
|
||||
def validate_language(language: Optional[str]) -> str:
|
||||
"""
|
||||
校验语言参数,确保其为有效值。
|
||||
|
||||
Args:
|
||||
language: 待校验的语言代码,可以是 None、"zh"、"en" 或其他值
|
||||
|
||||
Returns:
|
||||
有效的语言代码("zh" 或 "en")
|
||||
|
||||
Examples:
|
||||
>>> validate_language("zh")
|
||||
'zh'
|
||||
>>> validate_language("en")
|
||||
'en'
|
||||
>>> validate_language("EN") # 大小写不敏感
|
||||
'en'
|
||||
>>> validate_language(None) # None 回退到默认值
|
||||
'zh'
|
||||
>>> validate_language("fr") # 不支持的语言回退到默认值
|
||||
'zh'
|
||||
"""
|
||||
if language is None:
|
||||
return DEFAULT_LANGUAGE
|
||||
|
||||
# 标准化:转小写并去除空白
|
||||
lang = str(language).lower().strip()
|
||||
|
||||
if lang in SUPPORTED_LANGUAGES:
|
||||
return lang
|
||||
|
||||
logger.warning(
|
||||
f"无效的语言参数 '{language}',已回退到默认值 '{DEFAULT_LANGUAGE}'。"
|
||||
f"支持的语言: {SUPPORTED_LANGUAGES}"
|
||||
)
|
||||
return DEFAULT_LANGUAGE
|
||||
|
||||
|
||||
def get_language_from_header(language_type: Optional[str]) -> str:
|
||||
"""
|
||||
从请求头获取并校验语言参数。
|
||||
|
||||
这是一个便捷函数,用于在 controller 层统一处理 X-Language-Type Header。
|
||||
|
||||
Args:
|
||||
language_type: 从 X-Language-Type Header 获取的语言值
|
||||
|
||||
Returns:
|
||||
有效的语言代码("zh" 或 "en")
|
||||
|
||||
Examples:
|
||||
>>> get_language_from_header(None) # Header 未传递
|
||||
'zh'
|
||||
>>> get_language_from_header("en")
|
||||
'en'
|
||||
>>> get_language_from_header("invalid") # 无效值回退
|
||||
'zh'
|
||||
"""
|
||||
return validate_language(language_type)
|
||||
@@ -38,6 +38,56 @@ class SensitiveDataLoggingFilter(logging.Filter):
|
||||
return True
|
||||
|
||||
|
||||
class Neo4jSuccessNotificationFilter(logging.Filter):
|
||||
"""Neo4j 日志过滤器:过滤成功/信息性状态的通知,保留真正的警告和错误
|
||||
|
||||
Neo4j 驱动会以 WARNING 级别记录所有数据库通知,包括成功的操作。
|
||||
这个过滤器会过滤掉以下 GQL 状态码的通知,只保留真正的警告和错误:
|
||||
- 00000: 成功完成 (successful completion)
|
||||
- 00N00: 无数据 (no data)
|
||||
- 00NA0: 无数据,信息性通知 (no data, informational notification)
|
||||
|
||||
使用正则表达式进行更严格的匹配,避免误过滤无关的警告。
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
# 编译正则表达式以提高性能
|
||||
# 匹配所有"成功/信息性"的 GQL 状态码:
|
||||
# 00000 = 成功完成, 00N00 = 无数据, 00NA0 = 无数据信息性通知
|
||||
GQL_STATUS_PATTERN = re.compile(r"gql_status=['\"](00000|00N00|00NA0)['\"]")
|
||||
|
||||
# 匹配 status_description 中的成功完成或信息性通知消息
|
||||
SUCCESS_DESC_PATTERN = re.compile(r"status_description=['\"]note:\s*(successful\s+completion|no\s+data)['\"]", re.IGNORECASE)
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
"""
|
||||
过滤 Neo4j 成功通知
|
||||
|
||||
Args:
|
||||
record: 日志记录
|
||||
|
||||
Returns:
|
||||
True表示允许记录,False表示拒绝(过滤掉)
|
||||
"""
|
||||
# 只处理 INFO 和 WARNING 级别的日志
|
||||
# Neo4j 驱动对 severity='INFORMATION' 的通知使用 INFO 级别,
|
||||
# 对 severity='WARNING' 的通知使用 WARNING 级别
|
||||
if record.levelno not in (logging.INFO, logging.WARNING):
|
||||
return True
|
||||
|
||||
# 检查是否是 Neo4j 的成功通知
|
||||
message = str(record.msg)
|
||||
|
||||
# 使用正则表达式进行更严格的匹配
|
||||
# 这样可以避免误过滤包含这些子字符串但不是 Neo4j 通知的日志
|
||||
if self.GQL_STATUS_PATTERN.search(message) or self.SUCCESS_DESC_PATTERN.search(message):
|
||||
return False # 过滤掉这条日志
|
||||
|
||||
# 保留其他所有日志(包括真正的警告和错误)
|
||||
return True
|
||||
|
||||
|
||||
class LoggingConfig:
|
||||
"""全局日志配置类"""
|
||||
|
||||
@@ -65,6 +115,22 @@ class LoggingConfig:
|
||||
# 清除现有处理器
|
||||
root_logger.handlers.clear()
|
||||
|
||||
# Neo4j 通知过滤器 - 挂在 handler 上确保所有传播上来的日志都能被过滤
|
||||
neo4j_filter = Neo4jSuccessNotificationFilter()
|
||||
|
||||
# 抑制 Neo4j 通知日志
|
||||
# Neo4j 驱动内部会给 neo4j.notifications logger 配置自己的 handler,
|
||||
# 导致日志绕过根 logger 的 filter 直接输出。
|
||||
# 多管齐下确保过滤生效:
|
||||
# 1. 设置 neo4j.notifications 级别为 WARNING(过滤 INFO 级别的 00NA0 通知)
|
||||
# 2. 在所有 neo4j logger 上添加 filter(过滤 WARNING 级别的成功通知)
|
||||
# 3. 在根 handler 上也添加 filter(兜底)
|
||||
neo4j_notifications_logger = logging.getLogger("neo4j.notifications")
|
||||
neo4j_notifications_logger.setLevel(logging.WARNING)
|
||||
for neo4j_logger_name in ["neo4j", "neo4j.io", "neo4j.pool", "neo4j.notifications"]:
|
||||
neo4j_logger = logging.getLogger(neo4j_logger_name)
|
||||
neo4j_logger.addFilter(neo4j_filter)
|
||||
|
||||
# 创建格式化器
|
||||
formatter = logging.Formatter(
|
||||
fmt=settings.LOG_FORMAT,
|
||||
@@ -80,6 +146,7 @@ class LoggingConfig:
|
||||
console_handler.setFormatter(formatter)
|
||||
console_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
|
||||
console_handler.addFilter(sensitive_filter)
|
||||
console_handler.addFilter(neo4j_filter)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# 文件处理器(带轮转)
|
||||
@@ -93,6 +160,7 @@ class LoggingConfig:
|
||||
file_handler.setFormatter(formatter)
|
||||
file_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
|
||||
file_handler.addFilter(sensitive_filter)
|
||||
file_handler.addFilter(neo4j_filter)
|
||||
root_logger.addHandler(file_handler)
|
||||
|
||||
cls._initialized = True
|
||||
|
||||
@@ -10,7 +10,7 @@ async def write_node(state: WriteState) -> WriteState:
|
||||
Write data to the database/file system.
|
||||
|
||||
Args:
|
||||
state: WriteState containing messages, end_user_id, and memory_config
|
||||
state: WriteState containing messages, end_user_id, memory_config, and language
|
||||
|
||||
Returns:
|
||||
dict: Contains 'write_result' with status and data fields
|
||||
@@ -18,6 +18,7 @@ async def write_node(state: WriteState) -> WriteState:
|
||||
messages = state.get('messages', [])
|
||||
end_user_id = state.get('end_user_id', '')
|
||||
memory_config = state.get('memory_config', '')
|
||||
language = state.get('language', 'zh') # 默认中文
|
||||
|
||||
# Convert LangChain messages to structured format expected by write()
|
||||
structured_messages = []
|
||||
@@ -35,6 +36,7 @@ async def write_node(state: WriteState) -> WriteState:
|
||||
messages=structured_messages,
|
||||
end_user_id=end_user_id,
|
||||
memory_config=memory_config,
|
||||
language=language,
|
||||
)
|
||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||
|
||||
|
||||
@@ -0,0 +1,238 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
|
||||
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
|
||||
|
||||
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
from app.core.memory.agent.utils.redis_tool import count_store
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context, get_db
|
||||
from app.repositories.memory_short_repository import LongTermMemoryRepository
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_konwledges_server import write_rag
|
||||
from app.services.task_service import get_task_memory_write_result
|
||||
from app.tasks import write_message_task
|
||||
from app.utils.config_utils import resolve_config_id
|
||||
logger = get_agent_logger(__name__)
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
|
||||
|
||||
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
|
||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
|
||||
actual_config_id, long_term_messages=[]):
|
||||
"""
|
||||
写入记忆(支持结构化消息)
|
||||
|
||||
Args:
|
||||
storage_type: 存储类型 (neo4j/rag)
|
||||
end_user_id: 终端用户ID
|
||||
user_message: 用户消息内容
|
||||
ai_message: AI 回复内容
|
||||
user_rag_memory_id: RAG 记忆ID
|
||||
actual_end_user_id: 实际用户ID
|
||||
actual_config_id: 配置ID
|
||||
|
||||
逻辑说明:
|
||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||
- Neo4j 模式:使用结构化消息列表
|
||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||
"""
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
actual_config_id = resolve_config_id(actual_config_id, db)
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
structured_messages = []
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
if isinstance(user_message, str) and user_message.strip() != "":
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
if isinstance(ai_message, str) and ai_message.strip() != "":
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
# 如果提供了 long_term_messages,使用它替代 structured_messages
|
||||
if long_term_messages and isinstance(long_term_messages, list):
|
||||
structured_messages = long_term_messages
|
||||
elif long_term_messages and isinstance(long_term_messages, str):
|
||||
# 如果是 JSON 字符串,先解析
|
||||
try:
|
||||
structured_messages = json.loads(long_term_messages)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # end_user_id: 用户ID
|
||||
structured_messages, # message: JSON 字符串格式的消息列表
|
||||
str(actual_config_id), # config_id: 配置ID字符串
|
||||
storage_type, # storage_type: "neo4j"
|
||||
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||
)
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
|
||||
with get_db_context() as db_session:
|
||||
repo = LongTermMemoryRepository(db_session)
|
||||
|
||||
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
result = write_store.get_session_by_userid(end_user_id)
|
||||
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
|
||||
data = await format_parsing(result, "dict")
|
||||
chunk_data = data[:scope]
|
||||
if len(chunk_data)==scope:
|
||||
repo.upsert(end_user_id, chunk_data)
|
||||
logger.info(f'---------写入短长期-----------')
|
||||
else:
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
|
||||
long_messages = await messages_parse(long_time_data)
|
||||
repo.upsert(end_user_id, long_messages)
|
||||
logger.info(f'写入短长期:')
|
||||
|
||||
|
||||
|
||||
'''根据窗口'''
|
||||
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
|
||||
'''
|
||||
根据窗口获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
langchain_messages:原始数据LIST
|
||||
scope:窗口大小
|
||||
'''
|
||||
scope=scope
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)
|
||||
if is_end_user_id is not False:
|
||||
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
|
||||
redis_messages = count_store.get_sessions_count(end_user_id)[1]
|
||||
if is_end_user_id and int(is_end_user_id) != int(scope):
|
||||
is_end_user_id += 1
|
||||
langchain_messages += redis_messages
|
||||
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
|
||||
elif int(is_end_user_id) == int(scope):
|
||||
logger.info('写入长期记忆NEO4J')
|
||||
formatted_messages = (redis_messages)
|
||||
# 获取 config_id(如果 memory_config 是对象,提取 config_id;否则直接使用)
|
||||
if hasattr(memory_config, 'config_id'):
|
||||
config_id = memory_config.config_id
|
||||
else:
|
||||
config_id = memory_config
|
||||
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
config_id, formatted_messages)
|
||||
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
|
||||
else:
|
||||
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
|
||||
|
||||
|
||||
"""根据时间"""
|
||||
async def memory_long_term_storage(end_user_id,memory_config,time):
|
||||
'''
|
||||
根据时间获取redis数据,写入neo4j:
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
memory_config: 内存配置对象
|
||||
'''
|
||||
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
|
||||
format_messages = (long_time_data)
|
||||
messages=[]
|
||||
memory_config=memory_config.config_id
|
||||
for i in format_messages:
|
||||
message=json.loads(i['Query'])
|
||||
messages+= message
|
||||
if format_messages!=[]:
|
||||
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
|
||||
memory_config, messages)
|
||||
'''聚合判断'''
|
||||
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
|
||||
"""
|
||||
聚合判断函数:判断输入句子和历史消息是否描述同一事件
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
memory_config: 内存配置对象
|
||||
"""
|
||||
|
||||
try:
|
||||
# 1. 获取历史会话数据(使用新方法)
|
||||
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
|
||||
history = await format_parsing(result)
|
||||
if not result:
|
||||
history = []
|
||||
else:
|
||||
history = await format_parsing(result)
|
||||
json_schema = WriteAggregateModel.model_json_schema()
|
||||
template_service = TemplateService(template_root)
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='write_aggregate_judgment.jinja2',
|
||||
operation_name='aggregate_judgment',
|
||||
history=history,
|
||||
sentence=ori_messages,
|
||||
json_schema=json_schema
|
||||
)
|
||||
with get_db_context() as db_session:
|
||||
factory = MemoryClientFactory(db_session)
|
||||
llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": system_prompt
|
||||
}
|
||||
]
|
||||
structured = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=WriteAggregateModel
|
||||
)
|
||||
output_value = structured.output
|
||||
if isinstance(output_value, list):
|
||||
output_value = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in output_value
|
||||
]
|
||||
|
||||
result_dict = {
|
||||
"is_same_event": structured.is_same_event,
|
||||
"output": output_value
|
||||
}
|
||||
if not structured.is_same_event:
|
||||
logger.info(result_dict)
|
||||
await write("neo4j", end_user_id, "", "", None, end_user_id,
|
||||
memory_config.config_id, output_value)
|
||||
return result_dict
|
||||
|
||||
except Exception as e:
|
||||
print(f"[aggregate_judgment] 发生错误: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
return {
|
||||
"is_same_event": False,
|
||||
"output": ori_messages,
|
||||
"messages": ori_messages,
|
||||
"history": history if 'history' in locals() else [],
|
||||
"error": str(e)
|
||||
}
|
||||
@@ -186,10 +186,11 @@ def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
清理后的数据
|
||||
"""
|
||||
# 需要过滤的字段列表
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
fields_to_remove = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary"
|
||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
|
||||
}
|
||||
|
||||
if isinstance(data, dict):
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
import json
|
||||
|
||||
from langchain_core.messages import HumanMessage, AIMessage
|
||||
async def format_parsing(messages: list,type:str='string'):
|
||||
"""
|
||||
格式化解析消息列表
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
type: 返回类型 ('string' 或 'dict')
|
||||
|
||||
Returns:
|
||||
格式化后的消息列表
|
||||
"""
|
||||
result = []
|
||||
user=[]
|
||||
ai=[]
|
||||
|
||||
for message in messages:
|
||||
hstory_messages = message['messages']
|
||||
for history_messag in hstory_messages.strip().splitlines():
|
||||
history_messag = json.loads(history_messag)
|
||||
for content in history_messag:
|
||||
role = content['role']
|
||||
content = content['content']
|
||||
if type == "string":
|
||||
if role == 'human' or role=="user":
|
||||
content = '用户:' + content
|
||||
else:
|
||||
content = 'AI:' + content
|
||||
result.append(content)
|
||||
if type == "dict" :
|
||||
if role == 'human' or role=="user":
|
||||
user.append( content)
|
||||
else:
|
||||
ai.append(content)
|
||||
if type == "dict":
|
||||
for key,values in zip(user,ai):
|
||||
result.append({key:values})
|
||||
return result
|
||||
|
||||
async def messages_parse(messages: list | dict):
|
||||
user=[]
|
||||
ai=[]
|
||||
database=[]
|
||||
for message in messages:
|
||||
Query = message['Query']
|
||||
Query = json.loads(Query)
|
||||
for data in Query:
|
||||
role = data['role']
|
||||
if role == "human":
|
||||
user.append(data['content'])
|
||||
if role == "ai":
|
||||
ai.append(data['content'])
|
||||
for key, values in zip(user, ai):
|
||||
database.append({key, values})
|
||||
return database
|
||||
|
||||
|
||||
async def agent_chat_messages(user_content,ai_content):
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"{user_content}"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"{ai_content}"
|
||||
}
|
||||
|
||||
]
|
||||
return messages
|
||||
@@ -1,27 +1,26 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
|
||||
from app.db import get_db
|
||||
from app.db import get_db, get_db_context
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_write
|
||||
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_write_graph():
|
||||
"""
|
||||
@@ -34,14 +33,6 @@ async def make_write_graph():
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
# workflow = StateGraph(WriteState)
|
||||
# workflow.add_node("content_input", content_input_write)
|
||||
# workflow.add_node("save_neo4j", write_node)
|
||||
# workflow.add_edge(START, "content_input")
|
||||
# workflow.add_edge("content_input", "save_neo4j")
|
||||
# workflow.add_edge("save_neo4j", END)
|
||||
#
|
||||
# graph = workflow.compile()
|
||||
workflow = StateGraph(WriteState)
|
||||
workflow.add_node("save_neo4j", write_node)
|
||||
workflow.add_edge(START, "save_neo4j")
|
||||
@@ -51,43 +42,63 @@ async def make_write_graph():
|
||||
|
||||
yield graph
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "今天周一"
|
||||
end_user_id = 'new_2025test1103' # 组ID
|
||||
|
||||
|
||||
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
|
||||
from app.core.memory.agent.utils.redis_tool import write_store
|
||||
write_store.save_session_write(end_user_id, (langchain_messages))
|
||||
# 获取数据库会话
|
||||
db_session = next(get_db())
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=17, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
try:
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": end_user_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "end_user_id": end_user_id, "memory_config": memory_config}
|
||||
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
if 'save_neo4j'==node_name:
|
||||
massages=node_data
|
||||
massages=massages.get('write_result')['status']
|
||||
print(massages) # | 更新数据: {node_data}
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
with get_db_context() as db_session:
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=memory_config, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
if long_term_type=='chunk':
|
||||
'''方案一:对话窗口6轮对话'''
|
||||
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
|
||||
if long_term_type=='time':
|
||||
"""时间"""
|
||||
await memory_long_term_storage(end_user_id, memory_config,5)
|
||||
if long_term_type=='aggregate':
|
||||
"""方案三:聚合判断"""
|
||||
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
|
||||
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
|
||||
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
|
||||
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
|
||||
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
|
||||
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
|
||||
else:
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
|
||||
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
|
||||
long_term_messages = await agent_chat_messages(message_chat, aimessages)
|
||||
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
|
||||
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
|
||||
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
|
||||
|
||||
# async def main():
|
||||
# """主函数 - 运行工作流"""
|
||||
# langchain_messages = [
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": "今天周五去爬山"
|
||||
# },
|
||||
# {
|
||||
# "role": "assistant",
|
||||
# "content": "好耶"
|
||||
# }
|
||||
#
|
||||
# ]
|
||||
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
|
||||
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
|
||||
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
|
||||
#
|
||||
#
|
||||
#
|
||||
# if __name__ == "__main__":
|
||||
# import asyncio
|
||||
# asyncio.run(main())
|
||||
28
api/app/core/memory/agent/models/write_aggregate_model.py
Normal file
28
api/app/core/memory/agent/models/write_aggregate_model.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""Pydantic models for write aggregate judgment operations."""
|
||||
|
||||
from typing import List, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MessageItem(BaseModel):
|
||||
"""Individual message item in conversation."""
|
||||
|
||||
role: str = Field(..., description="角色:user 或 assistant")
|
||||
content: str = Field(..., description="消息内容")
|
||||
|
||||
|
||||
class WriteAggregateResponse(BaseModel):
|
||||
"""Response model for aggregate judgment containing judgment result and output."""
|
||||
|
||||
is_same_event: bool = Field(
|
||||
...,
|
||||
description="是否是同一事件。True表示是同一事件,False表示不同事件"
|
||||
)
|
||||
output: Union[List[MessageItem], bool] = Field(
|
||||
...,
|
||||
description="如果is_same_event为True,返回False;如果is_same_event为False,返回消息列表"
|
||||
)
|
||||
|
||||
|
||||
# 为了保持向后兼容,保留旧的类名作为别名
|
||||
WriteAggregateModel = WriteAggregateResponse
|
||||
@@ -18,6 +18,7 @@ class WriteState(TypedDict):
|
||||
memory_config: object
|
||||
write_result: dict
|
||||
data: str
|
||||
language: str # 语言类型 ("zh" 中文, "en" 英文)
|
||||
|
||||
class ReadState(TypedDict):
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,57 @@
|
||||
输入句子:{{sentence}}
|
||||
历史消息:{{history}}
|
||||
|
||||
# 你的角色
|
||||
你是一个擅长事件聚合与语义判断的专家。
|
||||
|
||||
# 你的任务
|
||||
结合历史消息和输入句子,判断它们是否在描述**同一件事件或同一事件链**。
|
||||
|
||||
以下情况视为"同一事件"(需要返回 is_same_event=True, output=False):
|
||||
- 描述的是同一个具体事件或事实
|
||||
- 存在明显的因果关系、前后发展关系
|
||||
- 是对同一事件的补充、解释、追问或延展
|
||||
- 逻辑上属于同一语境下的连续讨论
|
||||
|
||||
以下情况视为"不同事件"(需要返回 is_same_event=False, output=消息列表):
|
||||
- 话题不同,事件主体不同
|
||||
- 时间、地点、对象明显不同
|
||||
- 只是语义相似,但并非同一具体事件
|
||||
- 无直接事件、因果或逻辑关联
|
||||
|
||||
# 输出规则(非常重要)
|
||||
你必须按照以下JSON格式输出:
|
||||
|
||||
**如果是同一事件:**
|
||||
```json
|
||||
{
|
||||
"is_same_event": true,
|
||||
"output": false
|
||||
}
|
||||
```
|
||||
|
||||
**如果不是同一事件:**
|
||||
```json
|
||||
{
|
||||
"is_same_event": false,
|
||||
"output": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "输入句子的内容"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "对应的回复内容"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
# JSON Schema
|
||||
{{json_schema}}
|
||||
|
||||
# 注意事项
|
||||
- 必须严格按照上述格式输出
|
||||
- output 字段:如果是同一事件返回 false,如果不是同一事件返回完整的消息列表
|
||||
- 消息列表必须包含 role 和 content 字段
|
||||
- 不要输出任何解释、分析或多余内容
|
||||
186
api/app/core/memory/agent/utils/redis_base.py
Normal file
186
api/app/core/memory/agent/utils/redis_base.py
Normal file
@@ -0,0 +1,186 @@
|
||||
import json
|
||||
from typing import Any, List, Dict, Optional
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
def serialize_messages(messages: Any) -> str:
|
||||
"""
|
||||
将消息序列化为 JSON 字符串,支持 LangChain 消息对象
|
||||
|
||||
Args:
|
||||
messages: 可以是 list、dict、string 或 LangChain 消息对象列表
|
||||
|
||||
Returns:
|
||||
str: JSON 字符串
|
||||
"""
|
||||
if isinstance(messages, str):
|
||||
return messages
|
||||
|
||||
if isinstance(messages, (list, tuple)):
|
||||
# 检查是否是 LangChain 消息对象列表
|
||||
serialized_list = []
|
||||
for msg in messages:
|
||||
if hasattr(msg, 'type') and hasattr(msg, 'content'):
|
||||
# LangChain 消息对象
|
||||
serialized_list.append({
|
||||
'type': msg.type,
|
||||
'content': msg.content,
|
||||
'role': getattr(msg, 'role', msg.type)
|
||||
})
|
||||
elif isinstance(msg, dict):
|
||||
serialized_list.append(msg)
|
||||
else:
|
||||
serialized_list.append(str(msg))
|
||||
return json.dumps(serialized_list, ensure_ascii=False)
|
||||
|
||||
if isinstance(messages, dict):
|
||||
return json.dumps(messages, ensure_ascii=False)
|
||||
|
||||
# 其他类型转为字符串
|
||||
return str(messages)
|
||||
|
||||
|
||||
def deserialize_messages(messages_str: str) -> Any:
|
||||
"""
|
||||
将 JSON 字符串反序列化为原始格式
|
||||
|
||||
Args:
|
||||
messages_str: JSON 字符串
|
||||
|
||||
Returns:
|
||||
反序列化后的对象(list、dict 或 string)
|
||||
"""
|
||||
if not messages_str:
|
||||
return []
|
||||
|
||||
try:
|
||||
return json.loads(messages_str)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return messages_str
|
||||
|
||||
|
||||
def fix_encoding(text: str) -> str:
|
||||
"""
|
||||
修复错误编码的文本
|
||||
|
||||
Args:
|
||||
text: 需要修复的文本
|
||||
|
||||
Returns:
|
||||
str: 修复后的文本
|
||||
"""
|
||||
if not text or not isinstance(text, str):
|
||||
return text
|
||||
try:
|
||||
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
|
||||
return text.encode('latin-1').decode('utf-8')
|
||||
except (UnicodeDecodeError, UnicodeEncodeError):
|
||||
# 如果修复失败,返回原文本
|
||||
return text
|
||||
|
||||
|
||||
def format_session_data(data: Dict[str, Any], include_time: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
格式化会话数据为统一的输出格式
|
||||
|
||||
Args:
|
||||
data: 原始会话数据
|
||||
include_time: 是否包含时间字段
|
||||
|
||||
Returns:
|
||||
Dict: 格式化后的数据 {"Query": "...", "Answer": "...", "starttime": "..."}
|
||||
"""
|
||||
result = {
|
||||
"Query": fix_encoding(data.get('messages', '')),
|
||||
"Answer": fix_encoding(data.get('aimessages', ''))
|
||||
}
|
||||
|
||||
if include_time:
|
||||
result["starttime"] = data.get('starttime', '')
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def filter_by_time_range(items: List[Dict], minutes: int) -> List[Dict]:
|
||||
"""
|
||||
根据时间范围过滤数据
|
||||
|
||||
Args:
|
||||
items: 包含 starttime 字段的数据列表
|
||||
minutes: 时间范围(分钟)
|
||||
|
||||
Returns:
|
||||
List[Dict]: 过滤后的数据列表
|
||||
"""
|
||||
time_threshold = datetime.now() - timedelta(minutes=minutes)
|
||||
time_threshold_str = time_threshold.strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
filtered_items = []
|
||||
for item in items:
|
||||
starttime = item.get('starttime', '')
|
||||
if starttime and starttime >= time_threshold_str:
|
||||
filtered_items.append(item)
|
||||
|
||||
return filtered_items
|
||||
|
||||
|
||||
def sort_and_limit_results(items: List[Dict], limit: int = 6,
|
||||
remove_time: bool = True) -> List[Dict]:
|
||||
"""
|
||||
对结果进行排序、限制数量并移除时间字段
|
||||
|
||||
Args:
|
||||
items: 数据列表
|
||||
limit: 最大返回数量
|
||||
remove_time: 是否移除 starttime 字段
|
||||
|
||||
Returns:
|
||||
List[Dict]: 处理后的数据列表
|
||||
"""
|
||||
# 按时间降序排序(最新的在前)
|
||||
items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
|
||||
# 限制数量
|
||||
result_items = items[:limit]
|
||||
|
||||
# 移除 starttime 字段
|
||||
if remove_time:
|
||||
for item in result_items:
|
||||
item.pop('starttime', None)
|
||||
|
||||
# 如果结果少于1条,返回空列表
|
||||
if len(result_items) < 1:
|
||||
return []
|
||||
|
||||
return result_items
|
||||
|
||||
|
||||
def generate_session_key(session_id: str, key_type: str = "session") -> str:
|
||||
"""
|
||||
生成 Redis key
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
key_type: key 类型 ("session", "read", "write", "count")
|
||||
|
||||
Returns:
|
||||
str: Redis key
|
||||
"""
|
||||
if key_type == "count":
|
||||
return f"session:count:{session_id}"
|
||||
elif key_type == "write":
|
||||
return f"session:write:{session_id}"
|
||||
elif key_type == "session" or key_type == "read":
|
||||
return f"session:{session_id}"
|
||||
else:
|
||||
return f"session:{session_id}"
|
||||
|
||||
|
||||
def get_current_timestamp() -> str:
|
||||
"""
|
||||
获取当前时间戳字符串
|
||||
|
||||
Returns:
|
||||
str: 格式化的时间字符串 "YYYY-MM-DD HH:MM:SS"
|
||||
"""
|
||||
return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
@@ -1,11 +1,36 @@
|
||||
import redis
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from app.core.config import settings
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
|
||||
from app.core.memory.agent.utils.redis_base import (
|
||||
serialize_messages,
|
||||
deserialize_messages,
|
||||
fix_encoding,
|
||||
format_session_data,
|
||||
filter_by_time_range,
|
||||
sort_and_limit_results,
|
||||
generate_session_key,
|
||||
get_current_timestamp
|
||||
)
|
||||
|
||||
|
||||
class RedisSessionStore:
|
||||
|
||||
|
||||
class RedisWriteStore:
|
||||
"""Redis Write 类型存储类,用于管理 save_session_write 相关的数据"""
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
|
||||
Args:
|
||||
host: Redis 主机地址
|
||||
port: Redis 端口
|
||||
db: Redis 数据库编号
|
||||
password: Redis 密码
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.r = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
@@ -16,32 +41,437 @@ class RedisSessionStore:
|
||||
)
|
||||
self.uudi = session_id
|
||||
|
||||
def _fix_encoding(self, text):
|
||||
"""修复错误编码的文本"""
|
||||
if not text or not isinstance(text, str):
|
||||
return text
|
||||
try:
|
||||
# 尝试修复 Latin-1 误编码为 UTF-8 的情况
|
||||
return text.encode('latin-1').decode('utf-8')
|
||||
except (UnicodeDecodeError, UnicodeEncodeError):
|
||||
# 如果修复失败,返回原文本
|
||||
return text
|
||||
|
||||
# 修改后的 save_session 方法
|
||||
def save_session(self, userid, messages, aimessages, apply_id, end_user_id):
|
||||
def save_session_write(self, userid: str, messages: str) -> str:
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
优化版本:确保写入时间不超过1秒
|
||||
|
||||
Args:
|
||||
userid: 用户ID
|
||||
messages: 用户消息
|
||||
|
||||
Returns:
|
||||
str: 新生成的 session_id
|
||||
"""
|
||||
try:
|
||||
session_id = str(uuid.uuid4()) # 为每次会话生成新的 ID
|
||||
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
key = f"session:{session_id}" # 使用新生成的 session_id 作为 key
|
||||
messages = serialize_messages(messages)
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="write")
|
||||
|
||||
# 使用 pipeline 批量写入,减少网络往返
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": userid,
|
||||
"messages": messages,
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
result = pipe.execute()
|
||||
|
||||
# 直接写入数据,decode_responses=True 已经处理了编码
|
||||
print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"[save_session_write] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]:
|
||||
"""
|
||||
通过 save_session_write 的 userid 获取 sessionid 和 messages
|
||||
|
||||
Args:
|
||||
userid: 用户ID (对应 sessionid 字段)
|
||||
|
||||
Returns:
|
||||
List[Dict] 或 False: 如果找到数据返回 [{"sessionid": "...", "messages": "..."}, ...],否则返回 False
|
||||
"""
|
||||
try:
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
return False
|
||||
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 筛选符合 userid 的数据
|
||||
results = []
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
session_id = key.split(':')[-1]
|
||||
results.append({
|
||||
"sessionid": session_id,
|
||||
"messages": fix_encoding(data.get('messages', ''))
|
||||
})
|
||||
|
||||
if not results:
|
||||
return False
|
||||
|
||||
print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"[get_session_by_userid] 查询失败: {e}")
|
||||
return False
|
||||
|
||||
def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]:
|
||||
"""
|
||||
通过 end_user_id 获取所有 write 类型的会话数据
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID (对应 sessionid 字段)
|
||||
|
||||
Returns:
|
||||
List[Dict] 或 False: 如果找到数据返回完整的会话信息列表,否则返回 False
|
||||
|
||||
返回格式:
|
||||
[
|
||||
{
|
||||
"session_id": "uuid",
|
||||
"id": "...",
|
||||
"sessionid": "end_user_id",
|
||||
"messages": "...",
|
||||
"starttime": "timestamp"
|
||||
},
|
||||
...
|
||||
]
|
||||
"""
|
||||
try:
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话")
|
||||
return False
|
||||
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 筛选符合 end_user_id 的数据
|
||||
results = []
|
||||
for key, data in zip(keys, all_data):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == end_user_id:
|
||||
# 从 key 中提取 session_id: session:write:{session_id}
|
||||
session_id = key.split(':')[-1]
|
||||
|
||||
# 构建完整的会话信息
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
"id": data.get('id', ''),
|
||||
"sessionid": data.get('sessionid', ''),
|
||||
"messages": fix_encoding(data.get('messages', '')),
|
||||
"starttime": data.get('starttime', '')
|
||||
}
|
||||
results.append(session_info)
|
||||
|
||||
if not results:
|
||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据")
|
||||
return False
|
||||
|
||||
# 按时间排序(最新的在前)
|
||||
results.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
|
||||
print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据")
|
||||
return results
|
||||
except Exception as e:
|
||||
print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def find_user_recent_sessions(self, userid: str,
|
||||
minutes: int = 5) -> List[Dict[str, str]]:
|
||||
"""
|
||||
根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据
|
||||
|
||||
Args:
|
||||
userid: 用户ID (对应 sessionid 字段)
|
||||
minutes: 查询最近几分钟的数据,默认5分钟
|
||||
|
||||
Returns:
|
||||
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# 只查询 write 类型的 key
|
||||
keys = self.r.keys('session:write:*')
|
||||
if not keys:
|
||||
print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 筛选符合 userid 的数据
|
||||
matched_items = []
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 从 write 类型读取,匹配 sessionid 字段
|
||||
if data.get('sessionid') == userid and data.get('starttime'):
|
||||
# write 类型没有 aimessages,所以 Answer 为空
|
||||
matched_items.append({
|
||||
"Query": fix_encoding(data.get('messages', '')),
|
||||
"Answer": "",
|
||||
"starttime": data.get('starttime', '')
|
||||
})
|
||||
|
||||
# 根据时间范围过滤
|
||||
filtered_items = filter_by_time_range(matched_items, minutes)
|
||||
# 排序并移除时间字段
|
||||
result_items = sort_and_limit_results(filtered_items, limit=None)
|
||||
print(result_items)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, "
|
||||
f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
def delete_all_write_sessions(self) -> int:
|
||||
"""
|
||||
删除所有 write 类型的会话
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
keys = self.r.keys('session:write:*')
|
||||
if keys:
|
||||
return self.r.delete(*keys)
|
||||
return 0
|
||||
|
||||
|
||||
class RedisCountStore:
|
||||
"""Redis Count 类型存储类,用于管理访问次数统计相关的数据"""
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
|
||||
Args:
|
||||
host: Redis 主机地址
|
||||
port: Redis 端口
|
||||
db: Redis 数据库编号
|
||||
password: Redis 密码
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.r = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password,
|
||||
decode_responses=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.uudi = session_id
|
||||
|
||||
def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str:
|
||||
"""
|
||||
保存用户访问次数统计
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
count: 访问次数
|
||||
messages: 消息内容
|
||||
|
||||
Returns:
|
||||
str: 新生成的 session_id
|
||||
"""
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
index_key = f'session:count:index:{end_user_id}' # 索引键
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"end_user_id": end_user_id,
|
||||
"count": int(count),
|
||||
"messages": serialize_messages(messages),
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期
|
||||
|
||||
# 创建索引:end_user_id -> session_id 映射
|
||||
pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60)
|
||||
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}")
|
||||
return session_id
|
||||
|
||||
def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]:
|
||||
"""
|
||||
通过 end_user_id 查询访问次数统计
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
list 或 False: 如果找到返回 [count, messages],否则返回 False
|
||||
"""
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
if key_type != 'string' and key_type != 'none':
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
except Exception as type_error:
|
||||
print(f"[get_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
if not session_id:
|
||||
return False
|
||||
|
||||
# 直接获取数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
if not data:
|
||||
# 索引存在但数据不存在,清理索引
|
||||
self.r.delete(index_key)
|
||||
return False
|
||||
|
||||
count = data.get('count')
|
||||
messages_str = data.get('messages')
|
||||
|
||||
if count is not None:
|
||||
messages = deserialize_messages(messages_str)
|
||||
return [int(count), messages]
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"[get_sessions_count] 查询失败: {e}")
|
||||
return False
|
||||
def update_sessions_count(self, end_user_id: str, new_count: int,
|
||||
messages: Any) -> bool:
|
||||
"""
|
||||
通过 end_user_id 修改访问次数统计(优化版:使用索引)
|
||||
|
||||
Args:
|
||||
end_user_id: 终端用户ID
|
||||
new_count: 新的 count 值
|
||||
messages: 消息内容
|
||||
|
||||
Returns:
|
||||
bool: 更新成功返回 True,未找到记录返回 False
|
||||
"""
|
||||
try:
|
||||
# 使用索引键快速查找
|
||||
index_key = f'session:count:index:{end_user_id}'
|
||||
|
||||
# 检查索引键类型,避免 WRONGTYPE 错误
|
||||
try:
|
||||
key_type = self.r.type(index_key)
|
||||
if key_type != 'string' and key_type != 'none':
|
||||
# 索引键类型错误,删除并返回 False
|
||||
print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引")
|
||||
self.r.delete(index_key)
|
||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
except Exception as type_error:
|
||||
print(f"[update_sessions_count] 检查键类型失败: {type_error}")
|
||||
|
||||
session_id = self.r.get(index_key)
|
||||
|
||||
if not session_id:
|
||||
print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}")
|
||||
return False
|
||||
|
||||
# 直接更新数据
|
||||
key = generate_session_key(session_id, key_type="count")
|
||||
messages_str = serialize_messages(messages)
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, 'count', int(new_count))
|
||||
pipe.hset(key, 'messages', messages_str)
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
print(f"[update_sessions_count] 更新失败: {e}")
|
||||
return False
|
||||
|
||||
def delete_all_count_sessions(self) -> int:
|
||||
"""
|
||||
删除所有 count 类型的会话
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
keys = self.r.keys('session:count:*')
|
||||
if keys:
|
||||
return self.r.delete(*keys)
|
||||
return 0
|
||||
|
||||
|
||||
class RedisSessionStore:
|
||||
"""Redis 会话存储类,用于管理会话数据"""
|
||||
|
||||
def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''):
|
||||
"""
|
||||
初始化 Redis 连接
|
||||
|
||||
Args:
|
||||
host: Redis 主机地址
|
||||
port: Redis 端口
|
||||
db: Redis 数据库编号
|
||||
password: Redis 密码
|
||||
session_id: 会话ID
|
||||
"""
|
||||
self.r = redis.Redis(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password,
|
||||
decode_responses=True,
|
||||
encoding='utf-8'
|
||||
)
|
||||
self.uudi = session_id
|
||||
|
||||
# ==================== 写入操作 ====================
|
||||
|
||||
def save_session(self, userid: str, messages: str, aimessages: str,
|
||||
apply_id: str, end_user_id: str) -> str:
|
||||
"""
|
||||
写入一条会话数据,返回 session_id
|
||||
|
||||
Args:
|
||||
userid: 用户ID
|
||||
messages: 用户消息
|
||||
aimessages: AI回复消息
|
||||
apply_id: 应用ID
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
str: 新生成的 session_id
|
||||
"""
|
||||
try:
|
||||
session_id = str(uuid.uuid4())
|
||||
key = generate_session_key(session_id, key_type="read")
|
||||
|
||||
pipe = self.r.pipeline()
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": userid,
|
||||
@@ -49,177 +479,195 @@ class RedisSessionStore:
|
||||
"end_user_id": end_user_id,
|
||||
"messages": messages,
|
||||
"aimessages": aimessages,
|
||||
"starttime": starttime
|
||||
"starttime": get_current_timestamp()
|
||||
})
|
||||
|
||||
# 可选:设置过期时间(例如30天),避免数据无限增长
|
||||
# pipe.expire(key, 30 * 24 * 60 * 60)
|
||||
|
||||
# 执行批量操作
|
||||
result = pipe.execute()
|
||||
|
||||
print(f"保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id # 返回新生成的 session_id
|
||||
print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}")
|
||||
return session_id
|
||||
except Exception as e:
|
||||
print(f"保存会话失败: {e}")
|
||||
print(f"[save_session] 保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
def save_sessions_batch(self, sessions_data):
|
||||
"""
|
||||
批量写入多条会话数据,返回 session_id 列表
|
||||
sessions_data: list of dict, 每个 dict 包含 userid, messages, aimessages, apply_id, end_user_id
|
||||
优化版本:批量操作,大幅提升性能
|
||||
"""
|
||||
try:
|
||||
session_ids = []
|
||||
pipe = self.r.pipeline()
|
||||
|
||||
for session in sessions_data:
|
||||
session_id = str(uuid.uuid4())
|
||||
starttime = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
key = f"session:{session_id}"
|
||||
|
||||
pipe.hset(key, mapping={
|
||||
"id": self.uudi,
|
||||
"sessionid": session.get('userid'),
|
||||
"apply_id": session.get('apply_id'),
|
||||
"end_user_id": session.get('end_user_id'),
|
||||
"messages": session.get('messages'),
|
||||
"aimessages": session.get('aimessages'),
|
||||
"starttime": starttime
|
||||
})
|
||||
|
||||
session_ids.append(session_id)
|
||||
|
||||
# 一次性执行所有写入操作
|
||||
results = pipe.execute()
|
||||
print(f"批量保存完成: {len(session_ids)} 条记录")
|
||||
return session_ids
|
||||
except Exception as e:
|
||||
print(f"批量保存会话失败: {e}")
|
||||
raise e
|
||||
|
||||
# ---------------- 读取 ----------------
|
||||
def get_session(self, session_id):
|
||||
# ==================== 读取操作 ====================
|
||||
|
||||
def get_session(self, session_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
读取一条会话数据
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
Dict 或 None: 会话数据
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
key = generate_session_key(session_id)
|
||||
data = self.r.hgetall(key)
|
||||
return data if data else None
|
||||
|
||||
def get_session_apply_group(self, sessionid, apply_id, end_user_id):
|
||||
def get_all_sessions(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""
|
||||
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据
|
||||
"""
|
||||
result_items = []
|
||||
|
||||
# 遍历所有会话数据
|
||||
for key in self.r.keys('session:*'):
|
||||
data = self.r.hgetall(key)
|
||||
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 检查三个条件是否都匹配
|
||||
if (data.get('sessionid') == sessionid and
|
||||
data.get('apply_id') == apply_id and
|
||||
data.get('end_user_id') == end_user_id):
|
||||
result_items.append(data)
|
||||
|
||||
return result_items
|
||||
|
||||
def get_all_sessions(self):
|
||||
"""
|
||||
获取所有会话数据
|
||||
获取所有会话数据(不包括 count 和 write 类型)
|
||||
|
||||
Returns:
|
||||
Dict: 所有会话数据,key 为 session_id
|
||||
"""
|
||||
sessions = {}
|
||||
for key in self.r.keys('session:*'):
|
||||
sid = key.split(':')[1]
|
||||
sessions[sid] = self.get_session(sid)
|
||||
# 排除 count 和 write 类型的 key
|
||||
if ':count:' not in key and ':write:' not in key:
|
||||
sid = key.split(':')[1]
|
||||
sessions[sid] = self.get_session(sid)
|
||||
return sessions
|
||||
|
||||
# ---------------- 更新 ----------------
|
||||
def update_session(self, session_id, field, value):
|
||||
def find_user_apply_group(self, sessionid: str, apply_id: str,
|
||||
end_user_id: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条
|
||||
|
||||
Args:
|
||||
sessionid: 会话ID(支持模糊匹配)
|
||||
apply_id: 应用ID
|
||||
end_user_id: 终端用户ID
|
||||
|
||||
Returns:
|
||||
List[Dict]: 会话列表 [{"Query": "...", "Answer": "..."}, ...]
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
keys = self.r.keys('session:*')
|
||||
if not keys:
|
||||
print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 批量获取数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
# 排除 count 和 write 类型
|
||||
if ':count:' not in key and ':write:' not in key:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 筛选符合条件的数据
|
||||
matched_items = []
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
if (data.get('apply_id') == apply_id and
|
||||
data.get('end_user_id') == end_user_id):
|
||||
# 支持模糊匹配或完全匹配 sessionid
|
||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||
matched_items.append(format_session_data(data, include_time=True))
|
||||
|
||||
# 排序、限制数量并移除时间字段
|
||||
result_items = sort_and_limit_results(matched_items, limit=6)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
# ==================== 更新操作 ====================
|
||||
|
||||
def update_session(self, session_id: str, field: str, value: Any) -> bool:
|
||||
"""
|
||||
更新单个字段
|
||||
优化版本:使用 pipeline 减少网络往返
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
field: 字段名
|
||||
value: 字段值
|
||||
|
||||
Returns:
|
||||
bool: 是否更新成功
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
key = generate_session_key(session_id)
|
||||
pipe = self.r.pipeline()
|
||||
pipe.exists(key)
|
||||
pipe.hset(key, field, value)
|
||||
results = pipe.execute()
|
||||
return bool(results[0]) # 返回 key 是否存在
|
||||
return bool(results[0])
|
||||
|
||||
# ---------------- 删除 ----------------
|
||||
def delete_session(self, session_id):
|
||||
# ==================== 删除操作 ====================
|
||||
|
||||
def delete_session(self, session_id: str) -> int:
|
||||
"""
|
||||
删除单条会话
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
key = f"session:{session_id}"
|
||||
key = generate_session_key(session_id)
|
||||
return self.r.delete(key)
|
||||
|
||||
def delete_all_sessions(self):
|
||||
def delete_all_sessions(self) -> int:
|
||||
"""
|
||||
删除所有会话
|
||||
删除所有会话(不包括 count 和 write 类型)
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
keys = self.r.keys('session:*')
|
||||
if keys:
|
||||
return self.r.delete(*keys)
|
||||
# 过滤掉 count 和 write 类型
|
||||
keys_to_delete = [k for k in keys if ':count:' not in k and ':write:' not in k]
|
||||
if keys_to_delete:
|
||||
return self.r.delete(*keys_to_delete)
|
||||
return 0
|
||||
|
||||
def delete_duplicate_sessions(self):
|
||||
def delete_duplicate_sessions(self) -> int:
|
||||
"""
|
||||
删除重复会话数据,条件:
|
||||
"sessionid"、"user_id"、"end_user_id"、"messages"、"aimessages" 五个字段都相同的只保留一个,其他删除
|
||||
优化版本:使用 pipeline 批量操作,确保在1秒内完成
|
||||
删除重复会话数据(不包括 count 和 write 类型)
|
||||
条件:sessionid、user_id、end_user_id、messages、aimessages 五个字段都相同的只保留一个
|
||||
|
||||
Returns:
|
||||
int: 删除的数量
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# 第一步:使用 pipeline 批量获取所有 key
|
||||
keys = self.r.keys('session:*')
|
||||
|
||||
if not keys:
|
||||
print("[delete_duplicate_sessions] 没有会话数据")
|
||||
return 0
|
||||
|
||||
# 第二步:使用 pipeline 批量获取所有数据
|
||||
# 批量获取所有数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
# 排除 count 和 write 类型
|
||||
if ':count:' not in key and ':write:' not in key:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 第三步:在内存中识别重复数据
|
||||
seen = {} # 用字典记录:identifier -> key(保留第一个出现的 key)
|
||||
keys_to_delete = [] # 需要删除的 key 列表
|
||||
# 识别重复数据
|
||||
seen = {}
|
||||
keys_to_delete = []
|
||||
|
||||
for key, data in zip(keys, all_data, strict=False):
|
||||
for key, data in zip([k for k in keys if ':count:' not in k and ':write:' not in k], all_data, strict=False):
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 获取五个字段的值
|
||||
sessionid = data.get('sessionid', '')
|
||||
user_id = data.get('id', '')
|
||||
end_user_id = data.get('end_user_id', '')
|
||||
messages = data.get('messages', '')
|
||||
aimessages = data.get('aimessages', '')
|
||||
|
||||
# 用五元组作为唯一标识
|
||||
identifier = (sessionid, user_id, end_user_id, messages, aimessages)
|
||||
identifier = (
|
||||
data.get('sessionid', ''),
|
||||
data.get('id', ''),
|
||||
data.get('end_user_id', ''),
|
||||
data.get('messages', ''),
|
||||
data.get('aimessages', '')
|
||||
)
|
||||
|
||||
if identifier in seen:
|
||||
# 重复,标记为待删除
|
||||
keys_to_delete.append(key)
|
||||
else:
|
||||
# 第一次出现,记录
|
||||
seen[identifier] = key
|
||||
|
||||
# 第四步:使用 pipeline 批量删除重复的 key
|
||||
# 批量删除重复的 key
|
||||
deleted_count = 0
|
||||
if keys_to_delete:
|
||||
# 分批删除,避免单次操作过大
|
||||
batch_size = 1000
|
||||
for i in range(0, len(keys_to_delete), batch_size):
|
||||
batch = keys_to_delete[i:i + batch_size]
|
||||
@@ -233,79 +681,28 @@ class RedisSessionStore:
|
||||
print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒")
|
||||
return deleted_count
|
||||
|
||||
def find_user_session(self, sessionid):
|
||||
user_id = sessionid
|
||||
|
||||
result_items = []
|
||||
for key, values in store.get_all_sessions().items():
|
||||
history = {}
|
||||
if user_id == str(values['sessionid']):
|
||||
history["Query"] = values['messages']
|
||||
history["Answer"] = values['aimessages']
|
||||
result_items.append(history)
|
||||
|
||||
if len(result_items) <= 1:
|
||||
result_items = []
|
||||
return (result_items)
|
||||
|
||||
def find_user_apply_group(self, sessionid, apply_id, end_user_id):
|
||||
"""
|
||||
根据 sessionid、apply_id 和 end_user_id 三个条件查询会话数据,返回最新的6条
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
# 使用 pipeline 批量获取数据,提高性能
|
||||
keys = self.r.keys('session:*')
|
||||
|
||||
if not keys:
|
||||
print(f"查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0")
|
||||
return []
|
||||
|
||||
# 使用 pipeline 批量获取所有 hash 数据
|
||||
pipe = self.r.pipeline()
|
||||
for key in keys:
|
||||
pipe.hgetall(key)
|
||||
all_data = pipe.execute()
|
||||
|
||||
# 解析并筛选符合条件的数据
|
||||
matched_items = []
|
||||
for data in all_data:
|
||||
if not data:
|
||||
continue
|
||||
|
||||
# 检查是否符合三个条件
|
||||
|
||||
if (data.get('apply_id') == apply_id and
|
||||
data.get('end_user_id') == end_user_id):
|
||||
# 支持模糊匹配 sessionid 或者完全匹配
|
||||
if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid:
|
||||
matched_items.append({
|
||||
"Query": self._fix_encoding(data.get('messages')),
|
||||
"Answer": self._fix_encoding(data.get('aimessages')),
|
||||
"starttime": data.get('starttime', '')
|
||||
})
|
||||
# 按时间降序排序(最新的在前)
|
||||
matched_items.sort(key=lambda x: x.get('starttime', ''), reverse=True)
|
||||
# 只保留最新的6条
|
||||
result_items = matched_items[:6]
|
||||
# # 移除 starttime 字段
|
||||
for item in result_items:
|
||||
item.pop('starttime', None)
|
||||
|
||||
# 如果结果少于等于1条,返回空列表
|
||||
if len(result_items) <= 1:
|
||||
result_items = []
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
print(f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}")
|
||||
|
||||
return result_items
|
||||
|
||||
|
||||
# 全局实例
|
||||
store = RedisSessionStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||
session_id=str(uuid.uuid4())
|
||||
)
|
||||
)
|
||||
|
||||
write_store = RedisWriteStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||
session_id=str(uuid.uuid4())
|
||||
)
|
||||
|
||||
count_store = RedisCountStore(
|
||||
host=settings.REDIS_HOST,
|
||||
port=settings.REDIS_PORT,
|
||||
db=settings.REDIS_DB,
|
||||
password=settings.REDIS_PASSWORD if settings.REDIS_PASSWORD else None,
|
||||
session_id=str(uuid.uuid4())
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ Write Tools for Memory Knowledge Extraction Pipeline
|
||||
This module provides the main write function for executing the knowledge extraction
|
||||
pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||
"""
|
||||
import asyncio
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
@@ -33,17 +34,17 @@ async def write(
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
ref_id: str = "wyl20251027",
|
||||
language: str = "zh",
|
||||
) -> None:
|
||||
"""
|
||||
Execute the complete knowledge extraction pipeline.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
end_user_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference ID, defaults to "wyl20251027"
|
||||
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
||||
"""
|
||||
# Extract config values
|
||||
embedding_model_id = str(memory_config.embedding_model_id)
|
||||
@@ -93,12 +94,39 @@ async def write(
|
||||
from app.core.memory.utils.config.config_utils import get_pipeline_config
|
||||
pipeline_config = get_pipeline_config(memory_config)
|
||||
|
||||
# Fetch ontology types if scene_id is configured
|
||||
ontology_types = None
|
||||
if memory_config.scene_id:
|
||||
try:
|
||||
from app.core.memory.ontology_services.ontology_type_loader import load_ontology_types_for_scene
|
||||
|
||||
with get_db_context() as db:
|
||||
ontology_types = load_ontology_types_for_scene(
|
||||
scene_id=memory_config.scene_id,
|
||||
workspace_id=memory_config.workspace_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
if ontology_types:
|
||||
logger.info(
|
||||
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {memory_config.scene_id}"
|
||||
)
|
||||
else:
|
||||
logger.info(f"No ontology classes found for scene_id: {memory_config.scene_id}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to fetch ontology types for scene_id {memory_config.scene_id}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
orchestrator = ExtractionOrchestrator(
|
||||
llm_client=llm_client,
|
||||
embedder_client=embedder_client,
|
||||
connector=neo4j_connector,
|
||||
config=pipeline_config,
|
||||
embedding_id=embedding_model_id,
|
||||
language=language,
|
||||
ontology_types=ontology_types,
|
||||
)
|
||||
|
||||
# Run the complete extraction pipeline
|
||||
@@ -123,23 +151,48 @@ async def write(
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating indexes: {e}", exc_info=True)
|
||||
|
||||
# 添加死锁重试机制
|
||||
max_retries = 3
|
||||
retry_delay = 1 # 秒
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=all_dialogue_nodes,
|
||||
chunk_nodes=all_chunk_nodes,
|
||||
statement_nodes=all_statement_nodes,
|
||||
entity_nodes=all_entity_nodes,
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
connector=neo4j_connector
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
break
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
if attempt < max_retries - 1:
|
||||
logger.info(f"Retrying... (attempt {attempt + 2}/{max_retries})")
|
||||
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
# 检查是否是死锁错误
|
||||
if "DeadlockDetected" in error_msg or "deadlock" in error_msg.lower():
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(f"Deadlock detected, retrying... (attempt {attempt + 2}/{max_retries})")
|
||||
await asyncio.sleep(retry_delay * (attempt + 1)) # 指数退避
|
||||
else:
|
||||
logger.error(f"Failed after {max_retries} attempts due to deadlock: {e}")
|
||||
raise
|
||||
else:
|
||||
# 非死锁错误,直接抛出
|
||||
raise
|
||||
|
||||
try:
|
||||
success = await save_dialog_and_statements_to_neo4j(
|
||||
dialogue_nodes=all_dialogue_nodes,
|
||||
chunk_nodes=all_chunk_nodes,
|
||||
statement_nodes=all_statement_nodes,
|
||||
entity_nodes=all_entity_nodes,
|
||||
statement_chunk_edges=all_statement_chunk_edges,
|
||||
statement_entity_edges=all_statement_entity_edges,
|
||||
entity_edges=all_entity_entity_edges,
|
||||
connector=neo4j_connector
|
||||
)
|
||||
if success:
|
||||
logger.info("Successfully saved all data to Neo4j")
|
||||
else:
|
||||
logger.warning("Failed to save some data to Neo4j")
|
||||
finally:
|
||||
await neo4j_connector.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing Neo4j connector: {e}")
|
||||
|
||||
log_time("Neo4j Database Save", time.time() - step_start, log_file)
|
||||
|
||||
@@ -147,7 +200,7 @@ async def write(
|
||||
step_start = time.time()
|
||||
try:
|
||||
summaries = await memory_summary_generation(
|
||||
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client
|
||||
chunked_dialogs, llm_client=llm_client, embedder_client=embedder_client, language=language
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -173,4 +226,4 @@ async def write(
|
||||
f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n")
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
@@ -39,16 +39,20 @@ async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
||||
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
workspace_id = connected_config.get("workspace_id")
|
||||
|
||||
if not config_id:
|
||||
if not config_id and not workspace_id:
|
||||
raise ValueError(
|
||||
f"No memory_config_id found for end_user_id: {end_user_id}. "
|
||||
"Please ensure the user has a valid memory configuration."
|
||||
)
|
||||
|
||||
# Use the config_id to get the proper LLM client
|
||||
# Use the config_id to get the proper LLM client with workspace fallback
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
if not memory_config.llm_model_id:
|
||||
raise ValueError(
|
||||
|
||||
@@ -108,7 +108,6 @@ class DimensionAnalyzer:
|
||||
|
||||
# Create dimension portrait
|
||||
portrait = DimensionPortrait(
|
||||
user_id=user_id,
|
||||
creativity=dimension_scores["creativity"],
|
||||
aesthetic=dimension_scores["aesthetic"],
|
||||
technology=dimension_scores["technology"],
|
||||
@@ -220,7 +219,7 @@ class DimensionAnalyzer:
|
||||
"""Create an empty dimension portrait when no data is available.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
user_id: Target user ID (used for logging only)
|
||||
|
||||
Returns:
|
||||
Empty DimensionPortrait
|
||||
@@ -228,7 +227,6 @@ class DimensionAnalyzer:
|
||||
current_time = datetime.now()
|
||||
|
||||
return DimensionPortrait(
|
||||
user_id=user_id,
|
||||
creativity=self._create_default_dimension_score("creativity"),
|
||||
aesthetic=self._create_default_dimension_score("aesthetic"),
|
||||
technology=self._create_default_dimension_score("technology"),
|
||||
|
||||
@@ -7,7 +7,7 @@ providing percentage distribution that totals 100%.
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from app.core.memory.analytics.implicit_memory.llm_client import ImplicitMemoryLLMClient
|
||||
from app.core.memory.llm_tools.llm_client import LLMClientException
|
||||
@@ -133,7 +133,6 @@ class InterestAnalyzer:
|
||||
|
||||
# Create interest area distribution
|
||||
distribution = InterestAreaDistribution(
|
||||
user_id=user_id,
|
||||
tech=interest_categories["tech"],
|
||||
lifestyle=interest_categories["lifestyle"],
|
||||
music=interest_categories["music"],
|
||||
@@ -251,7 +250,7 @@ class InterestAnalyzer:
|
||||
"""Create an empty interest distribution when no data is available.
|
||||
|
||||
Args:
|
||||
user_id: Target user ID
|
||||
user_id: Target user ID (used for logging only)
|
||||
|
||||
Returns:
|
||||
Empty InterestAreaDistribution with equal percentages
|
||||
@@ -259,15 +258,15 @@ class InterestAnalyzer:
|
||||
current_time = datetime.now()
|
||||
equal_percentage = 25.0 # 100% / 4 categories
|
||||
|
||||
default_category = lambda name: InterestCategory(
|
||||
category_name=name,
|
||||
percentage=equal_percentage,
|
||||
evidence=["Insufficient data for analysis"],
|
||||
trending_direction=None
|
||||
)
|
||||
def default_category(name: str) -> InterestCategory:
|
||||
return InterestCategory(
|
||||
category_name=name,
|
||||
percentage=equal_percentage,
|
||||
evidence=["Insufficient data for analysis"],
|
||||
trending_direction=None
|
||||
)
|
||||
|
||||
return InterestAreaDistribution(
|
||||
user_id=user_id,
|
||||
tech=default_category("tech"),
|
||||
lifestyle=default_category("lifestyle"),
|
||||
music=default_category("music"),
|
||||
|
||||
@@ -16,6 +16,7 @@ Summary {{ loop.index }}:
|
||||
3. DO NOT use long phrases - use short nouns or noun phrases
|
||||
4. Only include preferences with confidence_score >= 0.3
|
||||
5. **IMPORTANT: Output language MUST match the input language. If summaries are in Chinese, output in Chinese. If in English, output in English.**
|
||||
6. **CRITICAL: supporting_evidence must be DIRECT QUOTES or paraphrases from the user's actual statements. DO NOT reference summary numbers (e.g., "Summary 1", "摘要1"). DO NOT describe what the summary contains. Extract the actual user behavior or statement as evidence.**
|
||||
|
||||
## Output Format
|
||||
{
|
||||
@@ -38,6 +39,16 @@ Summary {{ loop.index }}:
|
||||
]
|
||||
}
|
||||
|
||||
## BAD supporting_evidence examples (DO NOT do this):
|
||||
- "Summary 1:西湖为核心景区" ❌
|
||||
- "摘要2中提到喜欢咖啡" ❌
|
||||
- "Based on Summary 3" ❌
|
||||
|
||||
## GOOD supporting_evidence examples:
|
||||
- "去过西湖断桥、苏堤" ✓
|
||||
- "每天早上喝咖啡" ✓
|
||||
- "mentioned visiting the lake twice" ✓
|
||||
|
||||
## Example (English input → English output)
|
||||
{
|
||||
"preferences": [
|
||||
|
||||
@@ -58,6 +58,25 @@ from app.core.memory.models.triplet_models import (
|
||||
TripletExtractionResponse,
|
||||
)
|
||||
|
||||
# Ontology scenario models (LLM extracted from scenarios)
|
||||
from app.core.memory.models.ontology_scenario_models import (
|
||||
OntologyClass,
|
||||
OntologyExtractionResponse,
|
||||
)
|
||||
|
||||
# Ontology extraction models (for extraction flow)
|
||||
from app.core.memory.models.ontology_extraction_models import (
|
||||
OntologyTypeInfo,
|
||||
OntologyTypeList,
|
||||
)
|
||||
|
||||
# Ontology general models (loaded from external ontology files)
|
||||
from app.core.memory.models.ontology_general_models import (
|
||||
OntologyFileFormat,
|
||||
GeneralOntologyType,
|
||||
GeneralOntologyTypeRegistry,
|
||||
)
|
||||
|
||||
# Variable configuration models
|
||||
from app.core.memory.models.variate_config import (
|
||||
StatementExtractionConfig,
|
||||
@@ -105,6 +124,16 @@ __all__ = [
|
||||
"Entity",
|
||||
"Triplet",
|
||||
"TripletExtractionResponse",
|
||||
# Ontology models
|
||||
"OntologyClass",
|
||||
"OntologyExtractionResponse",
|
||||
# Ontology type models for extraction flow
|
||||
"OntologyTypeInfo",
|
||||
"OntologyTypeList",
|
||||
# General ontology type models
|
||||
"OntologyFileFormat",
|
||||
"GeneralOntologyType",
|
||||
"GeneralOntologyTypeRegistry",
|
||||
# Variable configuration
|
||||
"StatementExtractionConfig",
|
||||
"ForgettingEngineConfig",
|
||||
|
||||
@@ -413,7 +413,8 @@ class ExtractedEntityNode(Node):
|
||||
description="Entity aliases - alternative names for this entity"
|
||||
)
|
||||
name_embedding: Optional[List[float]] = Field(default_factory=list, description="Name embedding vector")
|
||||
fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_summary: str = Field(default="", description="Summary of the fact about this entity")
|
||||
connect_strength: str = Field(..., description="Strong VS Weak about this entity")
|
||||
config_id: Optional[int | str] = Field(None, description="Configuration ID used to process this entity (integer or string)")
|
||||
|
||||
|
||||
105
api/app/core/memory/models/ontology_extraction_models.py
Normal file
105
api/app/core/memory/models/ontology_extraction_models.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体类型数据结构模块
|
||||
|
||||
本模块定义用于在萃取流程中传递本体类型信息的轻量级数据类。
|
||||
|
||||
Classes:
|
||||
OntologyTypeInfo: 单个本体类型信息
|
||||
OntologyTypeList: 本体类型列表
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
|
||||
@dataclass
|
||||
class OntologyTypeInfo:
|
||||
"""本体类型信息,用于萃取流程中传递。
|
||||
|
||||
Attributes:
|
||||
class_name: 类型名称
|
||||
class_description: 类型描述
|
||||
"""
|
||||
class_name: str
|
||||
class_description: str
|
||||
|
||||
def to_prompt_format(self) -> str:
|
||||
"""转换为提示词格式。
|
||||
|
||||
Returns:
|
||||
格式化的字符串,如 "- TypeName: Description"
|
||||
"""
|
||||
return f"- {self.class_name}: {self.class_description}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class OntologyTypeList:
|
||||
"""本体类型列表。
|
||||
|
||||
Attributes:
|
||||
types: 本体类型信息列表
|
||||
"""
|
||||
types: List[OntologyTypeInfo]
|
||||
|
||||
@classmethod
|
||||
def from_db_models(cls, ontology_classes: list) -> "OntologyTypeList":
|
||||
"""从数据库模型转换创建 OntologyTypeList。
|
||||
|
||||
Args:
|
||||
ontology_classes: OntologyClass 数据库模型列表,
|
||||
每个对象应包含 class_name 和 class_description 属性
|
||||
|
||||
Returns:
|
||||
包含转换后类型信息的 OntologyTypeList 实例
|
||||
"""
|
||||
types = [
|
||||
OntologyTypeInfo(
|
||||
class_name=oc.class_name,
|
||||
class_description=oc.class_description or ""
|
||||
)
|
||||
for oc in ontology_classes
|
||||
]
|
||||
return cls(types=types)
|
||||
|
||||
def to_prompt_section(self) -> str:
|
||||
"""转换为提示词中的类型列表部分。
|
||||
|
||||
Returns:
|
||||
格式化的类型列表字符串,每行一个类型;
|
||||
如果列表为空则返回空字符串
|
||||
"""
|
||||
if not self.types:
|
||||
return ""
|
||||
lines = [t.to_prompt_format() for t in self.types]
|
||||
return "\n".join(lines)
|
||||
|
||||
def get_type_names(self) -> List[str]:
|
||||
"""获取所有类型名称列表。
|
||||
|
||||
Returns:
|
||||
类型名称字符串列表
|
||||
"""
|
||||
return [t.class_name for t in self.types]
|
||||
|
||||
def get_type_hierarchy_hints(self) -> List[str]:
|
||||
"""获取类型层次结构提示列表。
|
||||
|
||||
尝试从通用本体注册表中获取每个类型的继承链信息。
|
||||
|
||||
Returns:
|
||||
层次提示字符串列表,格式为 "类型名 → 父类1 → 父类2"
|
||||
"""
|
||||
hints = []
|
||||
try:
|
||||
from app.core.memory.ontology_services.ontology_type_merger import OntologyTypeMerger
|
||||
|
||||
merger = OntologyTypeMerger()
|
||||
for type_info in self.types:
|
||||
hint = merger.get_type_hierarchy_hint(type_info.class_name)
|
||||
if hint:
|
||||
hints.append(hint)
|
||||
except Exception:
|
||||
# 如果无法获取层次信息,返回空列表
|
||||
pass
|
||||
|
||||
return hints
|
||||
223
api/app/core/memory/models/ontology_general_models.py
Normal file
223
api/app/core/memory/models/ontology_general_models.py
Normal file
@@ -0,0 +1,223 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""通用本体类型数据模型模块
|
||||
|
||||
本模块定义用于通用本体类型管理的数据结构,包括:
|
||||
- OntologyFileFormat: 本体文件格式枚举
|
||||
- GeneralOntologyType: 通用本体类型数据类
|
||||
- GeneralOntologyTypeRegistry: 通用本体类型注册表
|
||||
|
||||
Classes:
|
||||
OntologyFileFormat: 本体文件格式枚举,支持 TTL、OWL/XML、RDF/XML、N-Triples、JSON-LD
|
||||
GeneralOntologyType: 通用本体类型,包含类名、URI、标签、描述、父类等信息
|
||||
GeneralOntologyTypeRegistry: 类型注册表,管理类型集合和层次结构
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OntologyFileFormat(Enum):
|
||||
"""本体文件格式枚举
|
||||
|
||||
支持的格式:
|
||||
- TURTLE: Turtle 格式 (.ttl 文件)
|
||||
- RDF_XML: RDF/XML 格式 (.owl, .rdf 文件)
|
||||
- N_TRIPLES: N-Triples 格式 (.nt 文件)
|
||||
- JSON_LD: JSON-LD 格式 (.jsonld, .json 文件)
|
||||
"""
|
||||
TURTLE = "turtle" # .ttl 文件
|
||||
RDF_XML = "xml" # .owl, .rdf (RDF/XML 格式)
|
||||
N_TRIPLES = "nt" # .nt 文件
|
||||
JSON_LD = "json-ld" # .jsonld 文件
|
||||
|
||||
@classmethod
|
||||
def from_extension(cls, file_path: str) -> "OntologyFileFormat":
|
||||
"""根据文件扩展名推断格式
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
||||
Returns:
|
||||
推断出的文件格式,默认返回 RDF_XML
|
||||
"""
|
||||
ext = file_path.lower().split('.')[-1]
|
||||
format_map = {
|
||||
'ttl': cls.TURTLE,
|
||||
'owl': cls.RDF_XML,
|
||||
'rdf': cls.RDF_XML,
|
||||
'nt': cls.N_TRIPLES,
|
||||
'jsonld': cls.JSON_LD,
|
||||
'json': cls.JSON_LD,
|
||||
}
|
||||
return format_map.get(ext, cls.RDF_XML)
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneralOntologyType:
|
||||
"""通用本体类型
|
||||
|
||||
表示从本体文件中解析出的类型定义,包含类型的基本信息和层次关系。
|
||||
|
||||
Attributes:
|
||||
class_name: 类型名称,如 "Person"
|
||||
class_uri: 完整 URI,如 "http://dbpedia.org/ontology/Person"
|
||||
labels: 多语言标签字典,键为语言代码(如 "en", "zh"),值为标签文本
|
||||
description: 类型描述
|
||||
parent_class: 父类名称,用于构建类型层次
|
||||
source_file: 来源文件路径
|
||||
"""
|
||||
class_name: str # 类型名称,如 "Person"
|
||||
class_uri: str # 完整 URI
|
||||
labels: Dict[str, str] = field(default_factory=dict) # 多语言标签
|
||||
description: Optional[str] = None # 类型描述
|
||||
parent_class: Optional[str] = None # 父类名称
|
||||
source_file: Optional[str] = None # 来源文件
|
||||
|
||||
def get_label(self, lang: str = "en") -> str:
|
||||
"""获取指定语言的标签
|
||||
|
||||
优先返回指定语言的标签,如果不存在则尝试返回英文标签,
|
||||
最后返回类型名称作为默认值。
|
||||
|
||||
Args:
|
||||
lang: 语言代码,默认为 "en"
|
||||
|
||||
Returns:
|
||||
指定语言的标签,或默认值
|
||||
"""
|
||||
return self.labels.get(lang, self.labels.get("en", self.class_name))
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeneralOntologyTypeRegistry:
|
||||
"""通用本体类型注册表
|
||||
|
||||
管理解析后的本体类型集合,提供类型查询、层次遍历、注册表合并等功能。
|
||||
|
||||
Attributes:
|
||||
types: 类型字典,键为类型名称,值为 GeneralOntologyType 实例
|
||||
hierarchy: 层次结构字典,键为父类名称,值为子类名称集合
|
||||
source_files: 已加载的源文件路径列表
|
||||
"""
|
||||
types: Dict[str, GeneralOntologyType] = field(default_factory=dict)
|
||||
hierarchy: Dict[str, Set[str]] = field(default_factory=dict) # 父类 -> 子类集合
|
||||
source_files: List[str] = field(default_factory=list)
|
||||
|
||||
def get_type(self, name: str) -> Optional[GeneralOntologyType]:
|
||||
"""根据名称获取类型
|
||||
|
||||
Args:
|
||||
name: 类型名称
|
||||
|
||||
Returns:
|
||||
对应的 GeneralOntologyType 实例,如果不存在则返回 None
|
||||
"""
|
||||
return self.types.get(name)
|
||||
|
||||
def get_ancestors(self, name: str) -> List[str]:
|
||||
"""获取类型的所有祖先类型(防循环)
|
||||
|
||||
从当前类型开始,沿着父类链向上遍历,返回所有祖先类型名称。
|
||||
使用 visited 集合防止循环引用导致的无限循环。
|
||||
|
||||
Args:
|
||||
name: 类型名称
|
||||
|
||||
Returns:
|
||||
祖先类型名称列表,按从近到远的顺序排列
|
||||
"""
|
||||
ancestors = []
|
||||
current = name
|
||||
visited = set()
|
||||
while current and current not in visited:
|
||||
visited.add(current)
|
||||
type_info = self.types.get(current)
|
||||
if type_info and type_info.parent_class:
|
||||
# 检测循环引用
|
||||
if type_info.parent_class in visited:
|
||||
logger.warning(
|
||||
f"检测到类型层次循环引用: {current} -> {type_info.parent_class},"
|
||||
f"已遍历路径: {' -> '.join([name] + ancestors)}"
|
||||
)
|
||||
break
|
||||
ancestors.append(type_info.parent_class)
|
||||
current = type_info.parent_class
|
||||
else:
|
||||
break
|
||||
return ancestors
|
||||
|
||||
def get_descendants(self, name: str) -> Set[str]:
|
||||
"""获取类型的所有后代类型
|
||||
|
||||
从当前类型开始,沿着子类关系向下遍历,返回所有后代类型名称。
|
||||
使用广度优先搜索,避免重复处理已访问的类型。
|
||||
|
||||
Args:
|
||||
name: 类型名称
|
||||
|
||||
Returns:
|
||||
后代类型名称集合
|
||||
"""
|
||||
descendants: Set[str] = set()
|
||||
to_process = [name]
|
||||
while to_process:
|
||||
current = to_process.pop()
|
||||
children = self.hierarchy.get(current, set())
|
||||
new_children = children - descendants
|
||||
descendants.update(new_children)
|
||||
to_process.extend(new_children)
|
||||
return descendants
|
||||
|
||||
def merge(self, other: "GeneralOntologyTypeRegistry") -> None:
|
||||
"""合并另一个注册表(先加载的优先)
|
||||
|
||||
将另一个注册表的类型和层次结构合并到当前注册表。
|
||||
对于同名类型,保留当前注册表中已存在的定义(先加载优先)。
|
||||
层次结构会合并所有子类关系。
|
||||
|
||||
Args:
|
||||
other: 要合并的另一个注册表
|
||||
"""
|
||||
for name, type_info in other.types.items():
|
||||
if name not in self.types:
|
||||
self.types[name] = type_info
|
||||
for parent, children in other.hierarchy.items():
|
||||
if parent not in self.hierarchy:
|
||||
self.hierarchy[parent] = set()
|
||||
self.hierarchy[parent].update(children)
|
||||
self.source_files.extend(other.source_files)
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
"""获取注册表统计信息
|
||||
|
||||
Returns:
|
||||
包含以下键的字典:
|
||||
- total_types: 总类型数
|
||||
- root_types: 根类型数(无父类的类型)
|
||||
- max_depth: 类型层次的最大深度
|
||||
- source_files: 源文件列表
|
||||
"""
|
||||
return {
|
||||
"total_types": len(self.types),
|
||||
"root_types": len([t for t in self.types.values() if not t.parent_class]),
|
||||
"max_depth": self._calculate_max_depth(),
|
||||
"source_files": self.source_files,
|
||||
}
|
||||
|
||||
def _calculate_max_depth(self) -> int:
|
||||
"""计算类型层次的最大深度
|
||||
|
||||
遍历所有类型,计算每个类型到根的深度,返回最大值。
|
||||
|
||||
Returns:
|
||||
类型层次的最大深度
|
||||
"""
|
||||
max_depth = 0
|
||||
for type_name in self.types:
|
||||
depth = len(self.get_ancestors(type_name))
|
||||
max_depth = max(max_depth, depth)
|
||||
return max_depth
|
||||
138
api/app/core/memory/models/ontology_scenario_models.py
Normal file
138
api/app/core/memory/models/ontology_scenario_models.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Models for ontology classes and extraction responses.
|
||||
|
||||
This module contains Pydantic models for representing extracted ontology classes
|
||||
from scenario descriptions, following OWL ontology engineering standards.
|
||||
|
||||
Classes:
|
||||
OntologyClass: Represents an extracted ontology class
|
||||
OntologyExtractionResponse: Response model containing extracted ontology classes
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class OntologyClass(BaseModel):
|
||||
"""Represents an extracted ontology class from scenario description.
|
||||
|
||||
An ontology class represents an abstract category or concept in a domain,
|
||||
following OWL ontology engineering standards and naming conventions.
|
||||
|
||||
Attributes:
|
||||
id: Unique string identifier for the ontology class
|
||||
name: Name of the class in PascalCase format (e.g., 'MedicalProcedure')
|
||||
name_chinese: Chinese translation of the class name (e.g., '医疗程序')
|
||||
description: Textual description of the class
|
||||
examples: List of concrete instance examples of this class
|
||||
parent_class: Optional name of the parent class in the hierarchy
|
||||
entity_type: Type/category of the entity (e.g., 'Person', 'Organization', 'Concept')
|
||||
domain: Domain this class belongs to (e.g., 'Healthcare', 'Education')
|
||||
|
||||
Config:
|
||||
extra: Ignore extra fields from LLM output
|
||||
"""
|
||||
model_config = ConfigDict(extra='ignore')
|
||||
|
||||
id: str = Field(
|
||||
default_factory=lambda: uuid4().hex,
|
||||
description="Unique identifier for the ontology class"
|
||||
)
|
||||
name: str = Field(
|
||||
...,
|
||||
description="Name of the class in PascalCase format"
|
||||
)
|
||||
name_chinese: Optional[str] = Field(
|
||||
None,
|
||||
description="Chinese translation of the class name"
|
||||
)
|
||||
description: str = Field(
|
||||
...,
|
||||
description="Description of the class"
|
||||
)
|
||||
examples: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="List of concrete instance examples"
|
||||
)
|
||||
parent_class: Optional[str] = Field(
|
||||
None,
|
||||
description="Name of the parent class in the hierarchy"
|
||||
)
|
||||
entity_type: str = Field(
|
||||
...,
|
||||
description="Type/category of the entity"
|
||||
)
|
||||
domain: str = Field(
|
||||
...,
|
||||
description="Domain this class belongs to"
|
||||
)
|
||||
|
||||
@field_validator('name')
|
||||
@classmethod
|
||||
def validate_pascal_case(cls, v: str) -> str:
|
||||
"""Validate that the class name follows PascalCase convention.
|
||||
|
||||
PascalCase rules:
|
||||
- Must start with an uppercase letter (for English) or any character (for Chinese/Unicode)
|
||||
- Cannot contain spaces
|
||||
- Should not contain special characters except underscores
|
||||
|
||||
Args:
|
||||
v: The class name to validate
|
||||
|
||||
Returns:
|
||||
The validated class name
|
||||
|
||||
Raises:
|
||||
ValueError: If the name doesn't follow PascalCase convention
|
||||
"""
|
||||
if not v:
|
||||
raise ValueError("Class name cannot be empty")
|
||||
|
||||
# For Chinese/Unicode characters, skip the uppercase check
|
||||
# Only check uppercase for ASCII letters
|
||||
first_char = v[0]
|
||||
if first_char.isascii() and first_char.isalpha() and not first_char.isupper():
|
||||
raise ValueError(
|
||||
f"Class name '{v}' must start with an uppercase letter (PascalCase)"
|
||||
)
|
||||
|
||||
if ' ' in v:
|
||||
raise ValueError(
|
||||
f"Class name '{v}' cannot contain spaces (PascalCase)"
|
||||
)
|
||||
|
||||
# Check for invalid characters (allow alphanumeric, underscore, and Unicode characters)
|
||||
if not all(c.isalnum() or c == '_' or ord(c) > 127 for c in v):
|
||||
raise ValueError(
|
||||
f"Class name '{v}' contains invalid characters. "
|
||||
"Only alphanumeric characters, underscores, and Unicode characters are allowed"
|
||||
)
|
||||
|
||||
return v
|
||||
|
||||
|
||||
class OntologyExtractionResponse(BaseModel):
|
||||
"""Response model for ontology extraction from LLM.
|
||||
|
||||
This model represents the structured output from the LLM when
|
||||
extracting ontology classes from scenario descriptions.
|
||||
|
||||
Attributes:
|
||||
classes: List of extracted ontology classes
|
||||
domain: Domain/field the scenario belongs to
|
||||
|
||||
Config:
|
||||
extra: Ignore extra fields from LLM output
|
||||
"""
|
||||
model_config = ConfigDict(extra='ignore')
|
||||
|
||||
classes: List[OntologyClass] = Field(
|
||||
default_factory=list,
|
||||
description="List of extracted ontology classes"
|
||||
)
|
||||
domain: str = Field(
|
||||
...,
|
||||
description="Domain/field the scenario belongs to"
|
||||
)
|
||||
39
api/app/core/memory/ontology_services/__init__.py
Normal file
39
api/app/core/memory/ontology_services/__init__.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体类型服务模块
|
||||
|
||||
本模块提供本体类型相关的服务,包括:
|
||||
- OntologyTypeMerger: 本体类型合并服务
|
||||
- get_general_ontology_registry: 获取通用本体类型注册表(单例,懒加载)
|
||||
- get_ontology_type_merger: 获取类型合并服务实例
|
||||
- reload_ontology_registry: 重新加载本体注册表(实验模式)
|
||||
- clear_ontology_cache: 清除本体缓存
|
||||
- is_general_ontology_enabled: 检查通用本体类型功能是否启用
|
||||
- load_ontology_types_for_scene: 从数据库加载场景的本体类型
|
||||
- create_empty_ontology_type_list: 创建空的本体类型列表
|
||||
- load_ontology_types_with_fallback: 加载本体类型(带通用类型回退)
|
||||
"""
|
||||
|
||||
from .ontology_type_merger import OntologyTypeMerger, DEFAULT_CORE_GENERAL_TYPES
|
||||
from .ontology_type_loader import (
|
||||
get_general_ontology_registry,
|
||||
get_ontology_type_merger,
|
||||
reload_ontology_registry,
|
||||
clear_ontology_cache,
|
||||
is_general_ontology_enabled,
|
||||
load_ontology_types_for_scene,
|
||||
create_empty_ontology_type_list,
|
||||
load_ontology_types_with_fallback,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OntologyTypeMerger",
|
||||
"DEFAULT_CORE_GENERAL_TYPES",
|
||||
"get_general_ontology_registry",
|
||||
"get_ontology_type_merger",
|
||||
"reload_ontology_registry",
|
||||
"clear_ontology_cache",
|
||||
"is_general_ontology_enabled",
|
||||
"load_ontology_types_for_scene",
|
||||
"create_empty_ontology_type_list",
|
||||
"load_ontology_types_with_fallback",
|
||||
]
|
||||
270
api/app/core/memory/ontology_services/ontology_type_loader.py
Normal file
270
api/app/core/memory/ontology_services/ontology_type_loader.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""本体类型加载器
|
||||
|
||||
提供统一的本体类型加载逻辑,避免代码重复。
|
||||
|
||||
Functions:
|
||||
load_ontology_types_for_scene: 从数据库加载场景的本体类型
|
||||
is_general_ontology_enabled: 检查是否启用通用本体
|
||||
get_general_ontology_registry: 获取通用本体类型注册表(单例,懒加载)
|
||||
get_ontology_type_merger: 获取类型合并服务实例
|
||||
reload_ontology_registry: 重新加载本体注册表
|
||||
clear_ontology_cache: 清除本体缓存
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 模块级缓存(单例)
|
||||
_general_registry_cache = None
|
||||
_ontology_type_merger_cache = None
|
||||
|
||||
|
||||
def load_ontology_types_for_scene(
|
||||
scene_id: Optional[UUID],
|
||||
workspace_id: UUID,
|
||||
db: Session
|
||||
) -> Optional["OntologyTypeList"]:
|
||||
"""从数据库加载场景的本体类型
|
||||
|
||||
统一的本体类型加载逻辑,用于替代各处重复的加载代码。
|
||||
|
||||
Args:
|
||||
scene_id: 场景ID,如果为 None 则返回 None
|
||||
workspace_id: 工作空间ID
|
||||
db: 数据库会话
|
||||
|
||||
Returns:
|
||||
OntologyTypeList 如果场景有类型定义,否则返回 None
|
||||
|
||||
Examples:
|
||||
>>> ontology_types = load_ontology_types_for_scene(
|
||||
... scene_id=scene_uuid,
|
||||
... workspace_id=workspace_uuid,
|
||||
... db=db_session
|
||||
... )
|
||||
>>> if ontology_types:
|
||||
... print(f"Loaded {len(ontology_types.types)} types")
|
||||
"""
|
||||
if not scene_id:
|
||||
return None
|
||||
|
||||
try:
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
from app.repositories.ontology_class_repository import OntologyClassRepository
|
||||
|
||||
# 查询场景的本体类型
|
||||
ontology_repo = OntologyClassRepository(db)
|
||||
ontology_classes = ontology_repo.get_classes_by_scene(
|
||||
scene_id=scene_id
|
||||
)
|
||||
|
||||
if not ontology_classes:
|
||||
logger.info(f"No ontology types found for scene_id: {scene_id}")
|
||||
return None
|
||||
|
||||
# 转换为 OntologyTypeList
|
||||
ontology_types = OntologyTypeList.from_db_models(ontology_classes)
|
||||
logger.info(
|
||||
f"Loaded {len(ontology_types.types)} ontology types for scene_id: {scene_id}"
|
||||
)
|
||||
|
||||
return ontology_types
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load ontology types for scene_id {scene_id}: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def create_empty_ontology_type_list() -> Optional["OntologyTypeList"]:
|
||||
"""创建空的本体类型列表(用于仅使用通用类型的场景)
|
||||
|
||||
Returns:
|
||||
空的 OntologyTypeList 如果通用本体已启用,否则返回 None
|
||||
"""
|
||||
try:
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
|
||||
if is_general_ontology_enabled():
|
||||
logger.info("Creating empty OntologyTypeList for general types only")
|
||||
return OntologyTypeList(types=[])
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create empty OntologyTypeList: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def is_general_ontology_enabled() -> bool:
|
||||
"""检查是否启用了通用本体
|
||||
|
||||
通过配置开关和注册表是否可用来判断。
|
||||
|
||||
Returns:
|
||||
True 如果通用本体已启用,否则 False
|
||||
"""
|
||||
try:
|
||||
from app.core.config import settings
|
||||
|
||||
if not settings.ENABLE_GENERAL_ONTOLOGY_TYPES:
|
||||
return False
|
||||
|
||||
registry = get_general_ontology_registry()
|
||||
return registry is not None and len(registry.types) > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to check general ontology status: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_general_ontology_registry():
|
||||
"""获取通用本体类型注册表(单例,懒加载)
|
||||
|
||||
从配置的本体文件中解析并缓存注册表。
|
||||
|
||||
Returns:
|
||||
GeneralOntologyTypeRegistry 实例,如果加载失败则返回 None
|
||||
"""
|
||||
global _general_registry_cache
|
||||
|
||||
if _general_registry_cache is not None:
|
||||
return _general_registry_cache
|
||||
|
||||
try:
|
||||
from app.core.config import settings
|
||||
|
||||
if not settings.ENABLE_GENERAL_ONTOLOGY_TYPES:
|
||||
logger.info("通用本体类型功能已禁用")
|
||||
return None
|
||||
|
||||
# 解析本体文件路径
|
||||
file_names = [f.strip() for f in settings.GENERAL_ONTOLOGY_FILES.split(",") if f.strip()]
|
||||
if not file_names:
|
||||
logger.warning("未配置通用本体文件")
|
||||
return None
|
||||
|
||||
# 构建完整路径(相对于项目根目录)
|
||||
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
|
||||
file_paths = []
|
||||
for name in file_names:
|
||||
full_path = os.path.join(base_dir, name)
|
||||
if os.path.exists(full_path):
|
||||
file_paths.append(full_path)
|
||||
else:
|
||||
logger.warning(f"本体文件不存在: {full_path}")
|
||||
|
||||
if not file_paths:
|
||||
logger.warning("没有找到可用的通用本体文件")
|
||||
return None
|
||||
|
||||
# 解析本体文件
|
||||
from app.core.memory.utils.ontology.ontology_parser import MultiOntologyParser
|
||||
|
||||
parser = MultiOntologyParser(file_paths)
|
||||
_general_registry_cache = parser.parse_all()
|
||||
logger.info(f"通用本体注册表加载完成: {len(_general_registry_cache.types)} 个类型")
|
||||
|
||||
return _general_registry_cache
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"加载通用本体注册表失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def get_ontology_type_merger():
|
||||
"""获取类型合并服务实例(单例,懒加载)
|
||||
|
||||
Returns:
|
||||
OntologyTypeMerger 实例,如果通用本体未启用则返回 None
|
||||
"""
|
||||
global _ontology_type_merger_cache
|
||||
|
||||
if _ontology_type_merger_cache is not None:
|
||||
return _ontology_type_merger_cache
|
||||
|
||||
try:
|
||||
registry = get_general_ontology_registry()
|
||||
if registry is None:
|
||||
return None
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.memory.ontology_services.ontology_type_merger import OntologyTypeMerger
|
||||
|
||||
# 从配置读取核心类型
|
||||
core_types_str = settings.CORE_GENERAL_TYPES
|
||||
core_types = [t.strip() for t in core_types_str.split(",") if t.strip()] if core_types_str else None
|
||||
|
||||
_ontology_type_merger_cache = OntologyTypeMerger(
|
||||
general_registry=registry,
|
||||
max_types_in_prompt=settings.MAX_ONTOLOGY_TYPES_IN_PROMPT,
|
||||
core_types=core_types,
|
||||
)
|
||||
logger.info("OntologyTypeMerger 实例创建完成")
|
||||
|
||||
return _ontology_type_merger_cache
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建 OntologyTypeMerger 失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
def reload_ontology_registry():
|
||||
"""重新加载本体注册表(清除缓存后重新加载)
|
||||
|
||||
用于实验模式下动态更新本体配置。
|
||||
"""
|
||||
clear_ontology_cache()
|
||||
registry = get_general_ontology_registry()
|
||||
if registry:
|
||||
get_ontology_type_merger()
|
||||
logger.info("本体注册表已重新加载")
|
||||
return registry
|
||||
|
||||
|
||||
def clear_ontology_cache():
|
||||
"""清除本体缓存"""
|
||||
global _general_registry_cache, _ontology_type_merger_cache
|
||||
_general_registry_cache = None
|
||||
_ontology_type_merger_cache = None
|
||||
logger.info("本体缓存已清除")
|
||||
|
||||
|
||||
def load_ontology_types_with_fallback(
|
||||
scene_id: Optional[UUID],
|
||||
workspace_id: UUID,
|
||||
db: Session,
|
||||
enable_general_fallback: bool = True
|
||||
) -> Optional["OntologyTypeList"]:
|
||||
"""加载本体类型,如果场景没有类型则回退到通用类型
|
||||
|
||||
这是一个便捷函数,组合了场景类型加载和通用类型回退逻辑。
|
||||
|
||||
Args:
|
||||
scene_id: 场景ID
|
||||
workspace_id: 工作空间ID
|
||||
db: 数据库会话
|
||||
enable_general_fallback: 是否在没有场景类型时启用通用类型回退
|
||||
|
||||
Returns:
|
||||
OntologyTypeList 或 None
|
||||
"""
|
||||
# 首先尝试加载场景类型
|
||||
ontology_types = load_ontology_types_for_scene(
|
||||
scene_id=scene_id,
|
||||
workspace_id=workspace_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
# 如果没有场景类型且启用了回退,创建空列表以使用通用类型
|
||||
if ontology_types is None and enable_general_fallback:
|
||||
ontology_types = create_empty_ontology_type_list()
|
||||
if ontology_types:
|
||||
logger.info("No scene ontology types, will use general ontology types only")
|
||||
|
||||
return ontology_types
|
||||
231
api/app/core/memory/ontology_services/ontology_type_merger.py
Normal file
231
api/app/core/memory/ontology_services/ontology_type_merger.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体类型合并服务模块
|
||||
|
||||
本模块实现本体类型合并服务,负责按优先级合并场景类型与通用类型。
|
||||
|
||||
合并优先级:
|
||||
1. 场景特定类型(最高优先级)
|
||||
2. 核心通用类型
|
||||
3. 相关父类类型(最低优先级)
|
||||
|
||||
Classes:
|
||||
OntologyTypeMerger: 本体类型合并服务类
|
||||
|
||||
Constants:
|
||||
DEFAULT_CORE_GENERAL_TYPES: 默认核心通用类型集合
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Set
|
||||
|
||||
from app.core.memory.models.ontology_general_models import GeneralOntologyTypeRegistry
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeInfo, OntologyTypeList
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 默认核心通用类型
|
||||
DEFAULT_CORE_GENERAL_TYPES: Set[str] = {
|
||||
"Person", "Organization", "Company", "GovernmentAgency",
|
||||
"Place", "Location", "City", "Country", "Building",
|
||||
"Event", "SportsEvent", "MusicEvent", "SocialEvent",
|
||||
"Work", "Book", "Film", "Software", "Album",
|
||||
"Concept", "TopicalConcept", "AcademicSubject",
|
||||
"Device", "Food", "Drug", "ChemicalSubstance",
|
||||
"TimePeriod", "Year",
|
||||
}
|
||||
|
||||
|
||||
class OntologyTypeMerger:
|
||||
"""本体类型合并服务
|
||||
|
||||
负责按优先级合并场景类型与通用类型,生成用于三元组提取的类型列表。
|
||||
|
||||
合并优先级:
|
||||
1. 场景特定类型(最高优先级)- 标记为 [场景类型]
|
||||
2. 核心通用类型 - 标记为 [通用类型]
|
||||
3. 相关父类类型(最低优先级)- 标记为 [通用父类]
|
||||
|
||||
Attributes:
|
||||
general_registry: 通用本体类型注册表
|
||||
max_types_in_prompt: Prompt 中最大类型数量限制
|
||||
core_types: 核心通用类型集合
|
||||
|
||||
Example:
|
||||
>>> registry = GeneralOntologyTypeRegistry()
|
||||
>>> merger = OntologyTypeMerger(registry, max_types_in_prompt=50)
|
||||
>>> merged = merger.merge(scene_types)
|
||||
>>> print(len(merged.types))
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
general_registry: GeneralOntologyTypeRegistry,
|
||||
max_types_in_prompt: int = 50,
|
||||
core_types: Optional[List[str]] = None
|
||||
):
|
||||
"""初始化本体类型合并服务
|
||||
|
||||
Args:
|
||||
general_registry: 通用本体类型注册表
|
||||
max_types_in_prompt: Prompt 中最大类型数量,默认 50
|
||||
core_types: 自定义核心类型列表,如果为 None 则使用默认核心类型
|
||||
"""
|
||||
self.general_registry = general_registry
|
||||
self.max_types_in_prompt = max_types_in_prompt
|
||||
self.core_types: Set[str] = set(core_types) if core_types else DEFAULT_CORE_GENERAL_TYPES.copy()
|
||||
|
||||
def update_core_types(self, core_types: List[str]) -> None:
|
||||
"""动态更新核心类型列表
|
||||
|
||||
更新后立即生效,无需重启服务。
|
||||
|
||||
Args:
|
||||
core_types: 新的核心类型列表
|
||||
"""
|
||||
self.core_types = set(core_types)
|
||||
logger.info(f"核心类型已更新: {len(self.core_types)} 个类型")
|
||||
|
||||
def merge(
|
||||
self,
|
||||
scene_types: Optional[OntologyTypeList],
|
||||
include_related_types: bool = True
|
||||
) -> OntologyTypeList:
|
||||
"""合并场景类型与通用类型
|
||||
|
||||
按优先级合并类型:
|
||||
1. 场景特定类型(最高优先级)
|
||||
2. 核心通用类型
|
||||
3. 相关父类类型(可选)
|
||||
|
||||
合并后的类型总数不超过 max_types_in_prompt。
|
||||
|
||||
Args:
|
||||
scene_types: 场景特定类型列表,可以为 None
|
||||
include_related_types: 是否包含相关父类类型,默认 True
|
||||
|
||||
Returns:
|
||||
合并后的类型列表,每个类型带有来源标记
|
||||
"""
|
||||
merged_types: List[OntologyTypeInfo] = []
|
||||
seen_names: Set[str] = set()
|
||||
|
||||
# 1. 场景特定类型(最高优先级)
|
||||
scene_type_count = 0
|
||||
if scene_types and scene_types.types:
|
||||
for scene_type in scene_types.types:
|
||||
if scene_type.class_name not in seen_names:
|
||||
merged_types.append(OntologyTypeInfo(
|
||||
class_name=scene_type.class_name,
|
||||
class_description=f"[场景类型] {scene_type.class_description}"
|
||||
))
|
||||
seen_names.add(scene_type.class_name)
|
||||
scene_type_count += 1
|
||||
|
||||
# 2. 核心通用类型
|
||||
remaining_slots = self.max_types_in_prompt - len(merged_types)
|
||||
core_types_added: List[OntologyTypeInfo] = []
|
||||
|
||||
for type_name in self.core_types:
|
||||
if type_name not in seen_names and remaining_slots > 0:
|
||||
general_type = self.general_registry.get_type(type_name)
|
||||
if general_type:
|
||||
description = (
|
||||
general_type.labels.get("zh") or
|
||||
general_type.description or
|
||||
general_type.get_label("en") or
|
||||
type_name
|
||||
)
|
||||
core_types_added.append(OntologyTypeInfo(
|
||||
class_name=type_name,
|
||||
class_description=f"[通用类型] {description}"
|
||||
))
|
||||
seen_names.add(type_name)
|
||||
remaining_slots -= 1
|
||||
|
||||
merged_types.extend(core_types_added)
|
||||
|
||||
# 3. 相关父类类型
|
||||
related_types_added: List[OntologyTypeInfo] = []
|
||||
if include_related_types and scene_types and scene_types.types:
|
||||
for scene_type in scene_types.types:
|
||||
if remaining_slots <= 0:
|
||||
break
|
||||
general_type = self.general_registry.get_type(scene_type.class_name)
|
||||
if general_type and general_type.parent_class:
|
||||
parent_name = general_type.parent_class
|
||||
if parent_name not in seen_names:
|
||||
parent_type = self.general_registry.get_type(parent_name)
|
||||
if parent_type:
|
||||
description = (
|
||||
parent_type.labels.get("zh") or
|
||||
parent_type.description or
|
||||
parent_name
|
||||
)
|
||||
related_types_added.append(OntologyTypeInfo(
|
||||
class_name=parent_name,
|
||||
class_description=f"[通用父类] {description}"
|
||||
))
|
||||
seen_names.add(parent_name)
|
||||
remaining_slots -= 1
|
||||
|
||||
merged_types.extend(related_types_added)
|
||||
|
||||
logger.info(
|
||||
f"类型合并完成: 场景类型 {scene_type_count} 个, "
|
||||
f"核心通用类型 {len(core_types_added)} 个, "
|
||||
f"相关类型 {len(related_types_added)} 个, "
|
||||
f"总计 {len(merged_types)} 个"
|
||||
)
|
||||
|
||||
return OntologyTypeList(types=merged_types)
|
||||
|
||||
def get_type_hierarchy_hint(self, type_name: str) -> Optional[str]:
|
||||
"""获取类型的层次提示信息(最多 3 级)
|
||||
|
||||
返回类型的继承链信息,格式为 "类型名 → 父类1 → 父类2 → 父类3"。
|
||||
|
||||
Args:
|
||||
type_name: 类型名称
|
||||
|
||||
Returns:
|
||||
层次提示字符串,如果类型不存在或没有父类则返回 None
|
||||
"""
|
||||
general_type = self.general_registry.get_type(type_name)
|
||||
if not general_type:
|
||||
return None
|
||||
ancestors = self.general_registry.get_ancestors(type_name)
|
||||
if ancestors:
|
||||
# 限制最多 3 级祖先
|
||||
return f"{type_name} → {' → '.join(ancestors[:3])}"
|
||||
return None
|
||||
|
||||
def get_merge_statistics(self, scene_types: Optional[OntologyTypeList]) -> dict:
|
||||
"""获取合并统计信息
|
||||
|
||||
执行合并操作并返回各类型来源的数量统计。
|
||||
|
||||
Args:
|
||||
scene_types: 场景特定类型列表
|
||||
|
||||
Returns:
|
||||
包含以下键的统计字典:
|
||||
- total_types: 合并后总类型数
|
||||
- scene_types: 场景类型数量
|
||||
- general_types: 通用类型数量
|
||||
- parent_types: 父类类型数量
|
||||
- available_core_types: 可用核心类型数量
|
||||
- registry_total_types: 注册表中总类型数
|
||||
"""
|
||||
merged = self.merge(scene_types)
|
||||
scene_count = sum(1 for t in merged.types if "[场景类型]" in t.class_description)
|
||||
general_count = sum(1 for t in merged.types if "[通用类型]" in t.class_description)
|
||||
parent_count = sum(1 for t in merged.types if "[通用父类]" in t.class_description)
|
||||
|
||||
return {
|
||||
"total_types": len(merged.types),
|
||||
"scene_types": scene_count,
|
||||
"general_types": general_count,
|
||||
"parent_types": parent_count,
|
||||
"available_core_types": len(self.core_types),
|
||||
"registry_total_types": len(self.general_registry.types),
|
||||
}
|
||||
@@ -134,42 +134,45 @@ def _merge_attribute(canonical: ExtractedEntityNode, ent: ExtractedEntityNode):
|
||||
if len(desc_b) > len(desc_a):
|
||||
canonical.description = desc_b
|
||||
# 合并事实摘要:统一保留一个“实体: name”行,来源行去重保序
|
||||
fact_a = getattr(canonical, "fact_summary", "") or ""
|
||||
fact_b = getattr(ent, "fact_summary", "") or ""
|
||||
def _extract_sources(txt: str) -> List[str]:
|
||||
sources: List[str] = []
|
||||
if not txt:
|
||||
return sources
|
||||
for line in str(txt).splitlines():
|
||||
ln = line.strip()
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_a = getattr(canonical, "fact_summary", "") or ""
|
||||
# fact_b = getattr(ent, "fact_summary", "") or ""
|
||||
# def _extract_sources(txt: str) -> List[str]:
|
||||
# sources: List[str] = []
|
||||
# if not txt:
|
||||
# return sources
|
||||
# for line in str(txt).splitlines():
|
||||
# ln = line.strip()
|
||||
# 支持“来源:”或“来源:”前缀
|
||||
m = re.match(r"^来源[::]\s*(.+)$", ln)
|
||||
if m:
|
||||
content = m.group(1).strip()
|
||||
if content:
|
||||
sources.append(content)
|
||||
# m = re.match(r"^来源[::]\s*(.+)$", ln)
|
||||
# if m:
|
||||
# content = m.group(1).strip()
|
||||
# if content:
|
||||
# sources.append(content)
|
||||
# 如果不存在“来源”前缀,则将整体文本视为一个来源片段,避免信息丢失
|
||||
if not sources and txt.strip():
|
||||
sources.append(txt.strip())
|
||||
return sources
|
||||
# if not sources and txt.strip():
|
||||
# sources.append(txt.strip())
|
||||
# return sources
|
||||
try:
|
||||
src_a = _extract_sources(fact_a)
|
||||
src_b = _extract_sources(fact_b)
|
||||
seen = set()
|
||||
merged_sources: List[str] = []
|
||||
for s in src_a + src_b:
|
||||
if s and s not in seen:
|
||||
seen.add(s)
|
||||
merged_sources.append(s)
|
||||
if merged_sources:
|
||||
name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
|
||||
canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
|
||||
elif fact_b and not fact_a:
|
||||
canonical.fact_summary = fact_b
|
||||
# src_a = _extract_sources(fact_a)
|
||||
# src_b = _extract_sources(fact_b)
|
||||
# seen = set()
|
||||
# merged_sources: List[str] = []
|
||||
# for s in src_a + src_b:
|
||||
# if s and s not in seen:
|
||||
# seen.add(s)
|
||||
# merged_sources.append(s)
|
||||
# if merged_sources:
|
||||
# name_line = f"实体: {getattr(canonical, 'name', '')}".strip()
|
||||
# canonical.fact_summary = "\n".join([name_line] + [f"来源: {s}" for s in merged_sources])
|
||||
# elif fact_b and not fact_a:
|
||||
# canonical.fact_summary = fact_b
|
||||
pass
|
||||
except Exception:
|
||||
# 兜底:若解析失败,保留较长文本
|
||||
if len(fact_b) > len(fact_a):
|
||||
canonical.fact_summary = fact_b
|
||||
# if len(fact_b) > len(fact_a):
|
||||
# canonical.fact_summary = fact_b
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
@@ -145,10 +145,13 @@ def _choose_canonical(a: ExtractedEntityNode, b: ExtractedEntityNode) -> int: #
|
||||
# 2. 第二优先级:按“描述+事实摘要”的总长度排序(内容越长,信息越完整)
|
||||
desc_a = (getattr(a, "description", "") or "")
|
||||
desc_b = (getattr(b, "description", "") or "")
|
||||
fact_a = (getattr(a, "fact_summary", "") or "")
|
||||
fact_b = (getattr(b, "fact_summary", "") or "")
|
||||
score_a = len(desc_a) + len(fact_a)
|
||||
score_b = len(desc_b) + len(fact_b)
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_a = (getattr(a, "fact_summary", "") or "")
|
||||
# fact_b = (getattr(b, "fact_summary", "") or "")
|
||||
# score_a = len(desc_a) + len(fact_a)
|
||||
# score_b = len(desc_b) + len(fact_b)
|
||||
score_a = len(desc_a)
|
||||
score_b = len(desc_b)
|
||||
if score_a != score_b:
|
||||
return 0 if score_a >= score_b else 1
|
||||
return 0
|
||||
@@ -189,7 +192,8 @@ async def _judge_pair(
|
||||
"entity_type": getattr(a, "entity_type", None),
|
||||
"description": getattr(a, "description", None),
|
||||
"aliases": getattr(a, "aliases", None) or [],
|
||||
"fact_summary": getattr(a, "fact_summary", None),
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# "fact_summary": getattr(a, "fact_summary", None),
|
||||
"connect_strength": getattr(a, "connect_strength", None),
|
||||
}
|
||||
entity_b = {
|
||||
@@ -197,7 +201,8 @@ async def _judge_pair(
|
||||
"entity_type": getattr(b, "entity_type", None),
|
||||
"description": getattr(b, "description", None),
|
||||
"aliases": getattr(b, "aliases", None) or [],
|
||||
"fact_summary": getattr(b, "fact_summary", None),
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# "fact_summary": getattr(b, "fact_summary", None),
|
||||
"connect_strength": getattr(b, "connect_strength", None),
|
||||
}
|
||||
# 5. 渲染LLM提示词(用工具函数填充模板,包含实体信息、上下文、输出格式)
|
||||
@@ -248,7 +253,8 @@ async def _judge_pair_disamb(
|
||||
"entity_type": getattr(a, "entity_type", None),
|
||||
"description": getattr(a, "description", None),
|
||||
"aliases": getattr(a, "aliases", None) or [],
|
||||
"fact_summary": getattr(a, "fact_summary", None),
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# "fact_summary": getattr(a, "fact_summary", None),
|
||||
"connect_strength": getattr(a, "connect_strength", None),
|
||||
}
|
||||
entity_b = {
|
||||
@@ -256,7 +262,8 @@ async def _judge_pair_disamb(
|
||||
"entity_type": getattr(b, "entity_type", None),
|
||||
"description": getattr(b, "description", None),
|
||||
"aliases": getattr(b, "aliases", None) or [],
|
||||
"fact_summary": getattr(b, "fact_summary", None),
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# "fact_summary": getattr(b, "fact_summary", None),
|
||||
"connect_strength": getattr(b, "connect_strength", None),
|
||||
}
|
||||
prompt = render_entity_dedup_prompt(
|
||||
|
||||
@@ -72,7 +72,8 @@ def _row_to_entity(row: Dict[str, Any]) -> ExtractedEntityNode:
|
||||
description=row.get("description") or "",
|
||||
aliases=row.get("aliases") or [],
|
||||
name_embedding=row.get("name_embedding") or [],
|
||||
fact_summary=row.get("fact_summary") or "",
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_summary=row.get("fact_summary") or "",
|
||||
connect_strength=row.get("connect_strength") or "",
|
||||
)
|
||||
|
||||
|
||||
@@ -34,6 +34,8 @@ from app.core.memory.models.graph_models import (
|
||||
StatementNode,
|
||||
)
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
from app.core.memory.models.variate_config import (
|
||||
ExtractionPipelineConfig,
|
||||
)
|
||||
@@ -95,6 +97,9 @@ class ExtractionOrchestrator:
|
||||
config: Optional[ExtractionPipelineConfig] = None,
|
||||
progress_callback: Optional[Callable[[str, str, Optional[Dict[str, Any]]], Awaitable[None]]] = None,
|
||||
embedding_id: Optional[str] = None,
|
||||
ontology_types: Optional[OntologyTypeList] = None,
|
||||
enable_general_types: bool = True,
|
||||
language: str = "zh",
|
||||
):
|
||||
"""
|
||||
初始化流水线编排器
|
||||
@@ -108,6 +113,7 @@ class ExtractionOrchestrator:
|
||||
- 接受 (stage: str, message: str, data: Optional[Dict[str, Any]]) 并返回 Awaitable[None]
|
||||
- 在管线关键点调用以报告进度和结果数据
|
||||
embedding_id: 嵌入模型ID,如果为 None 则从全局配置获取(向后兼容)
|
||||
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
||||
"""
|
||||
self.llm_client = llm_client
|
||||
self.embedder_client = embedder_client
|
||||
@@ -116,6 +122,30 @@ class ExtractionOrchestrator:
|
||||
self.is_pilot_run = False # 默认非试运行模式
|
||||
self.progress_callback = progress_callback # 保存进度回调函数
|
||||
self.embedding_id = embedding_id # 保存嵌入模型ID
|
||||
self.language = language # 保存语言配置
|
||||
|
||||
# 处理本体类型配置
|
||||
# 根据 enable_general_types 参数决定是否将通用本体类型与场景特定类型合并
|
||||
# 如果启用合并且配置中开启了通用本体功能,则使用 OntologyTypeMerger 进行融合
|
||||
if enable_general_types and ontology_types:
|
||||
from app.core.memory.ontology_services.ontology_type_loader import (
|
||||
get_ontology_type_merger,
|
||||
is_general_ontology_enabled,
|
||||
)
|
||||
if is_general_ontology_enabled():
|
||||
merger = get_ontology_type_merger()
|
||||
self.ontology_types = merger.merge(ontology_types)
|
||||
logger.info(
|
||||
f"已启用通用本体类型融合: 场景类型 {len(ontology_types.types) if ontology_types.types else 0} 个 -> "
|
||||
f"合并后 {len(self.ontology_types.types) if self.ontology_types.types else 0} 个"
|
||||
)
|
||||
else:
|
||||
self.ontology_types = ontology_types
|
||||
logger.info("通用本体类型功能已在配置中禁用,仅使用场景类型")
|
||||
else:
|
||||
self.ontology_types = ontology_types
|
||||
if not enable_general_types and ontology_types:
|
||||
logger.info("enable_general_types=False,仅使用场景类型")
|
||||
|
||||
# 保存去重消歧的详细记录(内存中的数据结构)
|
||||
self.dedup_merge_records: List[Dict[str, Any]] = [] # 实体合并记录
|
||||
@@ -127,7 +157,7 @@ class ExtractionOrchestrator:
|
||||
llm_client=llm_client,
|
||||
config=self.config.statement_extraction,
|
||||
)
|
||||
self.triplet_extractor = TripletExtractor(llm_client=llm_client)
|
||||
self.triplet_extractor = TripletExtractor(llm_client=llm_client,ontology_types=self.ontology_types, language=language)
|
||||
self.temporal_extractor = TemporalExtractor(llm_client=llm_client)
|
||||
|
||||
logger.info("ExtractionOrchestrator 初始化完成")
|
||||
@@ -615,9 +645,25 @@ class ExtractionOrchestrator:
|
||||
logger.info(f"总陈述句: {total_statements}, 用户陈述句: {filtered_statements}, 开始全局并行提取情绪")
|
||||
|
||||
# 初始化情绪提取服务
|
||||
# 如果 emotion_model_id 为空,回退到工作空间默认 LLM
|
||||
from app.services.emotion_extraction_service import EmotionExtractionService
|
||||
|
||||
emotion_model_id = memory_config.emotion_model_id
|
||||
if not emotion_model_id and memory_config.workspace_id:
|
||||
from app.repositories.workspace_repository import get_workspace_models_configs
|
||||
from app.db import SessionLocal
|
||||
|
||||
db = SessionLocal()
|
||||
try:
|
||||
workspace_models = get_workspace_models_configs(db, memory_config.workspace_id)
|
||||
if workspace_models and workspace_models.get("llm"):
|
||||
emotion_model_id = workspace_models["llm"]
|
||||
logger.info(f"emotion_model_id 为空,使用工作空间默认 LLM: {emotion_model_id}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
emotion_service = EmotionExtractionService(
|
||||
llm_id=memory_config.emotion_model_id if memory_config.emotion_model_id else None
|
||||
llm_id=emotion_model_id if emotion_model_id else None
|
||||
)
|
||||
|
||||
# 全局并行处理所有陈述句
|
||||
@@ -1085,7 +1131,8 @@ class ExtractionOrchestrator:
|
||||
entity_type=getattr(entity, 'type', 'unknown'), # 使用 type 而不是 entity_type
|
||||
description=getattr(entity, 'description', ''), # 添加必需的 description 字段
|
||||
example=getattr(entity, 'example', ''), # 新增:传递示例字段
|
||||
fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# fact_summary=getattr(entity, 'fact_summary', ''), # 添加必需的 fact_summary 字段
|
||||
connect_strength=entity_connect_strength if entity_connect_strength is not None else 'Strong', # 添加必需的 connect_strength 字段
|
||||
aliases=getattr(entity, 'aliases', []) or [], # 传递从三元组提取阶段获取的aliases
|
||||
name_embedding=getattr(entity, 'name_embedding', None),
|
||||
|
||||
@@ -8,4 +8,5 @@
|
||||
- TemporalExtractor: 时间信息提取
|
||||
- EmbeddingGenerator: 嵌入向量生成
|
||||
- MemorySummaryGenerator: 记忆摘要生成
|
||||
- OntologyExtractor: 本体类提取
|
||||
"""
|
||||
|
||||
@@ -10,6 +10,7 @@ from app.core.memory.models.base_response import RobustLLMResponse
|
||||
from app.core.memory.models.graph_models import MemorySummaryNode
|
||||
from app.core.memory.models.message_models import DialogData
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_memory_summary_prompt
|
||||
from app.core.language_utils import validate_language # 使用集中化的语言校验
|
||||
from pydantic import Field
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
@@ -31,7 +32,8 @@ class MemorySummaryResponse(RobustLLMResponse):
|
||||
|
||||
async def generate_title_and_type_for_summary(
|
||||
content: str,
|
||||
llm_client
|
||||
llm_client,
|
||||
language: str = "zh"
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
为MemorySummary生成标题和类型
|
||||
@@ -41,12 +43,16 @@ async def generate_title_and_type_for_summary(
|
||||
Args:
|
||||
content: Summary的内容文本
|
||||
llm_client: LLM客户端实例
|
||||
language: 生成标题使用的语言 ("zh" 中文, "en" 英文),默认中文
|
||||
|
||||
Returns:
|
||||
(标题, 类型)元组
|
||||
"""
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_episodic_title_and_type_prompt
|
||||
|
||||
# 验证语言参数
|
||||
language = validate_language(language)
|
||||
|
||||
# 定义有效的类型集合
|
||||
VALID_TYPES = {
|
||||
"conversation", # 对话
|
||||
@@ -57,13 +63,19 @@ async def generate_title_and_type_for_summary(
|
||||
}
|
||||
DEFAULT_TYPE = "conversation" # 默认类型
|
||||
|
||||
# 根据语言设置默认标题
|
||||
DEFAULT_TITLE = "空内容" if language == "zh" else "Empty Content"
|
||||
PARSE_ERROR_TITLE = "解析失败" if language == "zh" else "Parse Failed"
|
||||
ERROR_TITLE = "错误" if language == "zh" else "Error"
|
||||
UNKNOWN_TITLE = "未知标题" if language == "zh" else "Unknown Title"
|
||||
|
||||
try:
|
||||
if not content:
|
||||
logger.warning("content为空,无法生成标题和类型")
|
||||
return ("空内容", DEFAULT_TYPE)
|
||||
logger.warning(f"content为空,无法生成标题和类型 (language={language})")
|
||||
return (DEFAULT_TITLE, DEFAULT_TYPE)
|
||||
|
||||
# 1. 渲染Jinja2提示词模板
|
||||
prompt = await render_episodic_title_and_type_prompt(content)
|
||||
# 1. 渲染Jinja2提示词模板,传递语言参数
|
||||
prompt = await render_episodic_title_and_type_prompt(content, language=language)
|
||||
|
||||
# 2. 调用LLM生成标题和类型
|
||||
messages = [
|
||||
@@ -102,7 +114,7 @@ async def generate_title_and_type_for_summary(
|
||||
json_str = json_str.strip()
|
||||
|
||||
result_data = json.loads(json_str)
|
||||
title = result_data.get("title", "未知标题")
|
||||
title = result_data.get("title", UNKNOWN_TITLE)
|
||||
episodic_type_raw = result_data.get("type", DEFAULT_TYPE)
|
||||
|
||||
# 5. 校验和归一化类型
|
||||
@@ -130,22 +142,23 @@ async def generate_title_and_type_for_summary(
|
||||
f"已归一化为 '{episodic_type}'"
|
||||
)
|
||||
|
||||
logger.info(f"成功生成标题和类型: title={title}, type={episodic_type}")
|
||||
logger.info(f"成功生成标题和类型 (language={language}): title={title}, type={episodic_type}")
|
||||
return (title, episodic_type)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"无法解析LLM响应为JSON: {full_response}")
|
||||
return ("解析失败", DEFAULT_TYPE)
|
||||
logger.error(f"无法解析LLM响应为JSON (language={language}): {full_response}")
|
||||
return (PARSE_ERROR_TITLE, DEFAULT_TYPE)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成标题和类型时出错: {str(e)}", exc_info=True)
|
||||
return ("错误", DEFAULT_TYPE)
|
||||
logger.error(f"生成标题和类型时出错 (language={language}): {str(e)}", exc_info=True)
|
||||
return (ERROR_TITLE, DEFAULT_TYPE)
|
||||
|
||||
async def _process_chunk_summary(
|
||||
dialog: DialogData,
|
||||
chunk,
|
||||
llm_client,
|
||||
embedder: OpenAIEmbedderClient,
|
||||
language: str = "zh",
|
||||
) -> Optional[MemorySummaryNode]:
|
||||
"""Process a single chunk to generate a memory summary node."""
|
||||
# Skip empty chunks
|
||||
@@ -153,11 +166,15 @@ async def _process_chunk_summary(
|
||||
return None
|
||||
|
||||
try:
|
||||
# 验证语言参数
|
||||
language = validate_language(language)
|
||||
|
||||
# Render prompt via Jinja2 for a single chunk
|
||||
prompt_content = await render_memory_summary_prompt(
|
||||
chunk_texts=chunk.content,
|
||||
json_schema=MemorySummaryResponse.model_json_schema(),
|
||||
max_words=200,
|
||||
language=language,
|
||||
)
|
||||
|
||||
messages = [
|
||||
@@ -178,9 +195,10 @@ async def _process_chunk_summary(
|
||||
try:
|
||||
title, episodic_type = await generate_title_and_type_for_summary(
|
||||
content=summary_text,
|
||||
llm_client=llm_client
|
||||
llm_client=llm_client,
|
||||
language=language
|
||||
)
|
||||
logger.info(f"Generated title and type for MemorySummary: title={title}, type={episodic_type}")
|
||||
logger.info(f"Generated title and type for MemorySummary (language={language}): title={title}, type={episodic_type}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to generate title and type for chunk {chunk.id}: {e}")
|
||||
# Continue without title and type
|
||||
@@ -219,13 +237,21 @@ async def memory_summary_generation(
|
||||
chunked_dialogs: List[DialogData],
|
||||
llm_client,
|
||||
embedder_client: OpenAIEmbedderClient,
|
||||
language: str = "zh",
|
||||
) -> List[MemorySummaryNode]:
|
||||
"""Generate memory summaries per chunk, embed them, and return nodes."""
|
||||
"""Generate memory summaries per chunk, embed them, and return nodes.
|
||||
|
||||
Args:
|
||||
chunked_dialogs: 分块后的对话数据
|
||||
llm_client: LLM客户端
|
||||
embedder_client: 嵌入客户端
|
||||
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
||||
"""
|
||||
# Collect all tasks for parallel processing
|
||||
tasks = []
|
||||
for dialog in chunked_dialogs:
|
||||
for chunk in dialog.chunks:
|
||||
tasks.append(_process_chunk_summary(dialog, chunk, llm_client, embedder_client))
|
||||
tasks.append(_process_chunk_summary(dialog, chunk, llm_client, embedder_client, language=language))
|
||||
|
||||
# Process all chunks in parallel
|
||||
results = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
|
||||
@@ -0,0 +1,489 @@
|
||||
"""Ontology class extraction from scenario descriptions using LLM.
|
||||
|
||||
This module provides the OntologyExtractor class for extracting ontology classes
|
||||
from natural language scenario descriptions. It uses LLM-driven extraction combined
|
||||
with two-layer validation (string validation + OWL semantic validation).
|
||||
|
||||
Classes:
|
||||
OntologyExtractor: Extracts ontology classes from scenario descriptions
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.models.ontology_scenario_models import (
|
||||
OntologyClass,
|
||||
OntologyExtractionResponse,
|
||||
)
|
||||
from app.core.memory.utils.validation.ontology_validator import OntologyValidator
|
||||
from app.core.memory.utils.validation.owl_validator import OWLValidator
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_ontology_extraction_prompt
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OntologyExtractor:
|
||||
"""Extractor for ontology classes from scenario descriptions.
|
||||
|
||||
This extractor uses LLM to identify abstract classes and concepts from
|
||||
natural language scenario descriptions, following OWL ontology engineering
|
||||
standards. It performs two-layer validation:
|
||||
1. String validation (naming conventions, reserved words, duplicates)
|
||||
2. OWL semantic validation (consistency checking, circular inheritance)
|
||||
|
||||
Attributes:
|
||||
llm_client: OpenAI client for LLM calls
|
||||
validator: String validator for class names and descriptions
|
||||
owl_validator: OWL validator for semantic validation
|
||||
"""
|
||||
|
||||
def __init__(self, llm_client: OpenAIClient):
|
||||
"""Initialize the OntologyExtractor.
|
||||
|
||||
Args:
|
||||
llm_client: OpenAIClient instance for LLM processing
|
||||
"""
|
||||
self.llm_client = llm_client
|
||||
self.validator = OntologyValidator()
|
||||
self.owl_validator = OWLValidator()
|
||||
|
||||
logger.info("OntologyExtractor initialized")
|
||||
|
||||
async def extract_ontology_classes(
|
||||
self,
|
||||
scenario: str,
|
||||
domain: Optional[str] = None,
|
||||
max_classes: int = 15,
|
||||
min_classes: int = 5,
|
||||
enable_owl_validation: bool = True,
|
||||
llm_temperature: float = 0.3,
|
||||
llm_max_tokens: int = 2000,
|
||||
max_description_length: int = 500,
|
||||
timeout: Optional[float] = None,
|
||||
language: str = "zh",
|
||||
) -> OntologyExtractionResponse:
|
||||
"""Extract ontology classes from a scenario description.
|
||||
|
||||
This is the main extraction method that orchestrates the entire process:
|
||||
1. Call LLM to extract ontology classes
|
||||
2. Perform first-layer validation (string validation and cleaning)
|
||||
3. Perform second-layer validation (OWL semantic validation)
|
||||
4. Filter invalid classes based on validation errors
|
||||
5. Return validated ontology classes
|
||||
|
||||
Args:
|
||||
scenario: Natural language scenario description
|
||||
domain: Optional domain hint (e.g., "Healthcare", "Education")
|
||||
max_classes: Maximum number of classes to extract (default: 15)
|
||||
min_classes: Minimum number of classes to extract (default: 5)
|
||||
enable_owl_validation: Whether to enable OWL validation (default: True)
|
||||
llm_temperature: LLM temperature parameter (default: 0.3)
|
||||
llm_max_tokens: LLM max tokens parameter (default: 2000)
|
||||
max_description_length: Maximum description length (default: 500)
|
||||
timeout: Optional timeout in seconds for LLM call (default: None, no timeout)
|
||||
language: Language for output ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
OntologyExtractionResponse containing validated ontology classes
|
||||
|
||||
Raises:
|
||||
ValueError: If scenario is empty or invalid
|
||||
asyncio.TimeoutError: If extraction times out
|
||||
|
||||
Examples:
|
||||
>>> extractor = OntologyExtractor(llm_client)
|
||||
>>> response = await extractor.extract_ontology_classes(
|
||||
... scenario="A hospital manages patient records...",
|
||||
... domain="Healthcare",
|
||||
... max_classes=10,
|
||||
... timeout=30.0
|
||||
... )
|
||||
>>> len(response.classes)
|
||||
7
|
||||
"""
|
||||
# Start timing
|
||||
start_time = time.time()
|
||||
|
||||
# Validate input
|
||||
if not scenario or not scenario.strip():
|
||||
logger.error("Scenario description is empty")
|
||||
raise ValueError("Scenario description cannot be empty")
|
||||
|
||||
scenario = scenario.strip()
|
||||
|
||||
logger.info(
|
||||
f"Starting ontology extraction - scenario_length={len(scenario)}, "
|
||||
f"domain={domain}, max_classes={max_classes}, min_classes={min_classes}, "
|
||||
f"timeout={timeout}, language={language}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Step 1: Call LLM for extraction with timeout
|
||||
logger.info("Step 1: Calling LLM for ontology extraction")
|
||||
llm_start_time = time.time()
|
||||
|
||||
if timeout is not None:
|
||||
# Wrap LLM call with timeout
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
self._call_llm_for_extraction(
|
||||
scenario=scenario,
|
||||
domain=domain,
|
||||
max_classes=max_classes,
|
||||
llm_temperature=llm_temperature,
|
||||
llm_max_tokens=llm_max_tokens,
|
||||
language=language,
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
llm_duration = time.time() - llm_start_time
|
||||
logger.error(
|
||||
f"LLM extraction timed out after {timeout} seconds "
|
||||
f"(actual duration: {llm_duration:.2f}s)"
|
||||
)
|
||||
# Return empty response on timeout
|
||||
return OntologyExtractionResponse(
|
||||
classes=[],
|
||||
domain=domain or "Unknown",
|
||||
)
|
||||
else:
|
||||
# No timeout specified, call directly
|
||||
response = await self._call_llm_for_extraction(
|
||||
scenario=scenario,
|
||||
domain=domain,
|
||||
max_classes=max_classes,
|
||||
llm_temperature=llm_temperature,
|
||||
llm_max_tokens=llm_max_tokens,
|
||||
language=language,
|
||||
)
|
||||
|
||||
llm_duration = time.time() - llm_start_time
|
||||
logger.info(
|
||||
f"LLM returned {len(response.classes)} classes in {llm_duration:.2f}s"
|
||||
)
|
||||
|
||||
# Step 2: First-layer validation (string validation and cleaning)
|
||||
logger.info("Step 2: Performing first-layer validation (string validation)")
|
||||
validation_start_time = time.time()
|
||||
|
||||
response = self._validate_and_clean(
|
||||
response=response,
|
||||
max_description_length=max_description_length,
|
||||
)
|
||||
|
||||
validation_duration = time.time() - validation_start_time
|
||||
logger.info(
|
||||
f"After first-layer validation: {len(response.classes)} classes remain "
|
||||
f"(validation took {validation_duration:.2f}s)"
|
||||
)
|
||||
|
||||
# Check if we have enough classes after first-layer validation
|
||||
if len(response.classes) < min_classes:
|
||||
logger.warning(
|
||||
f"Only {len(response.classes)} classes remain after validation, "
|
||||
f"which is below minimum of {min_classes}"
|
||||
)
|
||||
|
||||
# Step 3: Second-layer validation (OWL semantic validation)
|
||||
if enable_owl_validation and response.classes:
|
||||
logger.info("Step 3: Performing second-layer validation (OWL validation)")
|
||||
owl_start_time = time.time()
|
||||
|
||||
is_valid, errors, world = self.owl_validator.validate_ontology_classes(
|
||||
classes=response.classes,
|
||||
)
|
||||
|
||||
owl_duration = time.time() - owl_start_time
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
f"OWL validation found {len(errors)} issues in {owl_duration:.2f}s: {errors}"
|
||||
)
|
||||
|
||||
# Filter invalid classes based on errors
|
||||
response = self._filter_invalid_classes(
|
||||
response=response,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"After second-layer validation: {len(response.classes)} classes remain"
|
||||
)
|
||||
else:
|
||||
logger.info(f"OWL validation passed successfully in {owl_duration:.2f}s")
|
||||
else:
|
||||
if not enable_owl_validation:
|
||||
logger.info("Step 3: OWL validation disabled, skipping")
|
||||
else:
|
||||
logger.info("Step 3: No classes to validate, skipping OWL validation")
|
||||
|
||||
# Calculate total duration
|
||||
total_duration = time.time() - start_time
|
||||
|
||||
# Log extraction statistics
|
||||
logger.info(
|
||||
f"Ontology extraction completed - "
|
||||
f"final_class_count={len(response.classes)}, "
|
||||
f"domain={response.domain}, "
|
||||
f"total_duration={total_duration:.2f}s, "
|
||||
f"llm_duration={llm_duration:.2f}s"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# Re-raise timeout errors
|
||||
total_duration = time.time() - start_time
|
||||
logger.error(
|
||||
f"Ontology extraction timed out after {timeout} seconds "
|
||||
f"(total duration: {total_duration:.2f}s)",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
total_duration = time.time() - start_time
|
||||
logger.error(
|
||||
f"Ontology extraction failed after {total_duration:.2f}s: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty response on failure
|
||||
return OntologyExtractionResponse(
|
||||
classes=[],
|
||||
domain=domain or "Unknown",
|
||||
)
|
||||
|
||||
async def _call_llm_for_extraction(
|
||||
self,
|
||||
scenario: str,
|
||||
domain: Optional[str],
|
||||
max_classes: int,
|
||||
llm_temperature: float,
|
||||
llm_max_tokens: int,
|
||||
language: str = "zh",
|
||||
) -> OntologyExtractionResponse:
|
||||
"""Call LLM to extract ontology classes from scenario.
|
||||
|
||||
This method renders the extraction prompt using the Jinja2 template
|
||||
and calls the LLM with structured output to get ontology classes.
|
||||
|
||||
Args:
|
||||
scenario: Scenario description text
|
||||
domain: Optional domain hint
|
||||
max_classes: Maximum number of classes to extract
|
||||
llm_temperature: LLM temperature parameter
|
||||
llm_max_tokens: LLM max tokens parameter
|
||||
language: Language for output ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
OntologyExtractionResponse from LLM
|
||||
|
||||
Raises:
|
||||
Exception: If LLM call fails
|
||||
"""
|
||||
try:
|
||||
# Render prompt using template
|
||||
prompt_content = await render_ontology_extraction_prompt(
|
||||
scenario=scenario,
|
||||
domain=domain,
|
||||
max_classes=max_classes,
|
||||
json_schema=OntologyExtractionResponse.model_json_schema(),
|
||||
language=language,
|
||||
)
|
||||
|
||||
logger.debug(f"Rendered prompt length: {len(prompt_content)}")
|
||||
|
||||
# Create messages for LLM
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"You are an expert ontology engineer specializing in knowledge "
|
||||
"representation and OWL standards. Extract ontology classes from "
|
||||
"scenario descriptions following the provided instructions. "
|
||||
"Return valid JSON conforming to the schema."
|
||||
),
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt_content,
|
||||
},
|
||||
]
|
||||
|
||||
# Call LLM with structured output
|
||||
logger.debug(
|
||||
f"Calling LLM with temperature={llm_temperature}, "
|
||||
f"max_tokens={llm_max_tokens}"
|
||||
)
|
||||
|
||||
response = await self.llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=OntologyExtractionResponse,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"LLM extraction successful - extracted {len(response.classes)} classes"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM extraction failed: {str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
raise
|
||||
|
||||
def _validate_and_clean(
|
||||
self,
|
||||
response: OntologyExtractionResponse,
|
||||
max_description_length: int,
|
||||
) -> OntologyExtractionResponse:
|
||||
"""Perform first-layer validation: string validation and cleaning.
|
||||
|
||||
This method validates and cleans the extracted ontology classes:
|
||||
1. Validate class names (PascalCase, no reserved words)
|
||||
2. Sanitize invalid class names
|
||||
3. Truncate long descriptions
|
||||
4. Remove duplicate classes
|
||||
|
||||
Args:
|
||||
response: OntologyExtractionResponse from LLM
|
||||
max_description_length: Maximum description length
|
||||
|
||||
Returns:
|
||||
Cleaned OntologyExtractionResponse
|
||||
"""
|
||||
if not response.classes:
|
||||
logger.debug("No classes to validate")
|
||||
return response
|
||||
|
||||
logger.debug(f"Validating {len(response.classes)} classes")
|
||||
|
||||
validated_classes = []
|
||||
|
||||
for ontology_class in response.classes:
|
||||
# Validate class name
|
||||
is_valid, error_msg = self.validator.validate_class_name(
|
||||
ontology_class.name
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.warning(
|
||||
f"Invalid class name '{ontology_class.name}': {error_msg}"
|
||||
)
|
||||
|
||||
# Attempt to sanitize
|
||||
sanitized_name = self.validator.sanitize_class_name(
|
||||
ontology_class.name
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Sanitized class name: '{ontology_class.name}' -> '{sanitized_name}'"
|
||||
)
|
||||
|
||||
# Update class name
|
||||
ontology_class.name = sanitized_name
|
||||
|
||||
# Re-validate sanitized name
|
||||
is_valid, error_msg = self.validator.validate_class_name(
|
||||
sanitized_name
|
||||
)
|
||||
|
||||
if not is_valid:
|
||||
logger.error(
|
||||
f"Failed to sanitize class name '{ontology_class.name}': {error_msg}. "
|
||||
"Skipping this class."
|
||||
)
|
||||
continue
|
||||
|
||||
# Truncate description if too long
|
||||
if ontology_class.description:
|
||||
original_length = len(ontology_class.description)
|
||||
ontology_class.description = self.validator.truncate_description(
|
||||
ontology_class.description,
|
||||
max_length=max_description_length,
|
||||
)
|
||||
|
||||
if len(ontology_class.description) < original_length:
|
||||
logger.debug(
|
||||
f"Truncated description for '{ontology_class.name}': "
|
||||
f"{original_length} -> {len(ontology_class.description)} chars"
|
||||
)
|
||||
|
||||
validated_classes.append(ontology_class)
|
||||
|
||||
# Remove duplicates (case-insensitive)
|
||||
original_count = len(validated_classes)
|
||||
validated_classes = self.validator.remove_duplicates(validated_classes)
|
||||
|
||||
if len(validated_classes) < original_count:
|
||||
logger.info(
|
||||
f"Removed {original_count - len(validated_classes)} duplicate classes"
|
||||
)
|
||||
|
||||
# Return cleaned response
|
||||
return OntologyExtractionResponse(
|
||||
classes=validated_classes,
|
||||
domain=response.domain,
|
||||
)
|
||||
|
||||
def _filter_invalid_classes(
|
||||
self,
|
||||
response: OntologyExtractionResponse,
|
||||
errors: List[str],
|
||||
) -> OntologyExtractionResponse:
|
||||
"""Filter invalid classes based on OWL validation errors.
|
||||
|
||||
This method analyzes OWL validation errors and removes classes
|
||||
that caused validation failures (e.g., circular inheritance,
|
||||
inconsistencies).
|
||||
|
||||
Args:
|
||||
response: OntologyExtractionResponse to filter
|
||||
errors: List of error messages from OWL validation
|
||||
|
||||
Returns:
|
||||
Filtered OntologyExtractionResponse
|
||||
"""
|
||||
if not errors:
|
||||
return response
|
||||
|
||||
logger.debug(f"Filtering classes based on {len(errors)} OWL validation errors")
|
||||
|
||||
# Extract class names mentioned in errors
|
||||
invalid_class_names = set()
|
||||
|
||||
for error in errors:
|
||||
# Look for class names in error messages
|
||||
for ontology_class in response.classes:
|
||||
if ontology_class.name in error:
|
||||
invalid_class_names.add(ontology_class.name)
|
||||
logger.debug(
|
||||
f"Class '{ontology_class.name}' marked as invalid due to error: {error}"
|
||||
)
|
||||
|
||||
# Filter out invalid classes
|
||||
if invalid_class_names:
|
||||
original_count = len(response.classes)
|
||||
|
||||
filtered_classes = [
|
||||
c for c in response.classes
|
||||
if c.name not in invalid_class_names
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Filtered out {original_count - len(filtered_classes)} invalid classes: "
|
||||
f"{invalid_class_names}"
|
||||
)
|
||||
|
||||
return OntologyExtractionResponse(
|
||||
classes=filtered_classes,
|
||||
domain=response.domain,
|
||||
)
|
||||
|
||||
return response
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import asyncio
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from app.core.logging_config import get_memory_logger
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
@@ -8,6 +8,7 @@ from app.core.memory.utils.prompt.prompt_utils import render_triplet_extraction_
|
||||
from app.core.memory.utils.data.ontology import PREDICATE_DEFINITIONS, Predicate # 引入枚举 Predicate 白名单过滤
|
||||
from app.core.memory.models.triplet_models import TripletExtractionResponse
|
||||
from app.core.memory.models.message_models import DialogData, Statement
|
||||
from app.core.memory.models.ontology_extraction_models import OntologyTypeList
|
||||
from app.core.memory.utils.log.logging_utils import prompt_logger
|
||||
|
||||
logger = get_memory_logger(__name__)
|
||||
@@ -17,13 +18,30 @@ logger = get_memory_logger(__name__)
|
||||
class TripletExtractor:
|
||||
"""Extracts knowledge triplets and entities from statements using LLM"""
|
||||
|
||||
def __init__(self, llm_client: OpenAIClient):
|
||||
def __init__(
|
||||
self,
|
||||
llm_client: OpenAIClient,
|
||||
ontology_types: Optional[OntologyTypeList] = None,
|
||||
language: str = "zh"):
|
||||
"""Initialize the TripletExtractor with an LLM client
|
||||
|
||||
Args:
|
||||
llm_client: OpenAIClient instance for processing
|
||||
language: 语言类型 ("zh" 中文, "en" 英文),默认中文
|
||||
ontology_types: Optional OntologyTypeList containing predefined ontology types
|
||||
for entity classification guidance
|
||||
"""
|
||||
self.llm_client = llm_client
|
||||
self.ontology_types = ontology_types
|
||||
self.language = language
|
||||
|
||||
def _get_language(self) -> str:
|
||||
"""Get the configured language for entity descriptions
|
||||
|
||||
Returns:
|
||||
Language code ("zh" or "en")
|
||||
"""
|
||||
return self.language
|
||||
|
||||
async def _extract_triplets(self, statement: Statement, chunk_content: str) -> TripletExtractionResponse:
|
||||
"""Process a single statement and return extracted triplets and entities"""
|
||||
@@ -40,7 +58,9 @@ class TripletExtractor:
|
||||
statement=statement.statement,
|
||||
chunk_content=chunk_content,
|
||||
json_schema=TripletExtractionResponse.model_json_schema(),
|
||||
predicate_instructions=PREDICATE_DEFINITIONS
|
||||
predicate_instructions=PREDICATE_DEFINITIONS,
|
||||
language=self._get_language(),
|
||||
ontology_types=self.ontology_types,
|
||||
)
|
||||
|
||||
# Create messages for LLM
|
||||
|
||||
@@ -462,8 +462,8 @@ class ReflectionEngine:
|
||||
List[Any]: 反思数据列表
|
||||
"""
|
||||
|
||||
|
||||
|
||||
print("=== 获取反思数据 ===")
|
||||
print(f" 主机ID: {host_id}")
|
||||
if self.config.reflexion_range == ReflectionRange.PARTIAL:
|
||||
neo4j_query = neo4j_query_part.format(host_id)
|
||||
neo4j_statement = neo4j_statement_part.format(host_id)
|
||||
|
||||
@@ -296,7 +296,9 @@ def resolve_alias_cycles(entities: List[Any], cycles: Dict[str, Set[str]]) -> Li
|
||||
key=lambda eid: (
|
||||
_strength_rank(eid),
|
||||
len(getattr(entity_by_id.get(eid), 'description', '') or ''),
|
||||
len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
|
||||
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
|
||||
# len(getattr(entity_by_id.get(eid), 'fact_summary', '') or '')
|
||||
0 # 临时占位
|
||||
),
|
||||
reverse=True
|
||||
)
|
||||
|
||||
12
api/app/core/memory/utils/ontology/__init__.py
Normal file
12
api/app/core/memory/utils/ontology/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体解析工具模块
|
||||
|
||||
本模块提供本体文件解析功能,支持多种 RDF 格式的本体文件解析。
|
||||
|
||||
Modules:
|
||||
ontology_parser: 本体文件解析器
|
||||
"""
|
||||
|
||||
from .ontology_parser import MultiOntologyParser, OntologyParser
|
||||
|
||||
__all__ = ["OntologyParser", "MultiOntologyParser"]
|
||||
366
api/app/core/memory/utils/ontology/ontology_parser.py
Normal file
366
api/app/core/memory/utils/ontology/ontology_parser.py
Normal file
@@ -0,0 +1,366 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""本体文件解析器模块
|
||||
|
||||
本模块提供统一的本体文件解析功能,支持多种 RDF 格式:
|
||||
- Turtle (.ttl)
|
||||
- OWL/XML (.owl)
|
||||
- RDF/XML (.rdf)
|
||||
- N-Triples (.nt)
|
||||
- JSON-LD (.jsonld)
|
||||
|
||||
解析器会自动根据文件扩展名推断格式,并在解析失败时尝试其他格式。
|
||||
解析结果包含类定义的名称、URI、多语言标签、描述和父类信息。
|
||||
|
||||
Classes:
|
||||
OntologyParser: 统一本体文件解析器
|
||||
MultiOntologyParser: 多本体文件解析器
|
||||
|
||||
Example:
|
||||
>>> parser = OntologyParser("ontology.ttl")
|
||||
>>> registry = parser.parse()
|
||||
>>> print(f"解析了 {len(registry.types)} 个类型")
|
||||
|
||||
>>> multi_parser = MultiOntologyParser(["ontology1.ttl", "ontology2.owl"])
|
||||
>>> merged_registry = multi_parser.parse_all()
|
||||
>>> print(f"合并后共 {len(merged_registry.types)} 个类型")
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from rdflib import OWL, RDF, RDFS, Graph, URIRef
|
||||
|
||||
from app.core.memory.models.ontology_general_models import (
|
||||
GeneralOntologyType,
|
||||
GeneralOntologyTypeRegistry,
|
||||
OntologyFileFormat,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OntologyParser:
|
||||
"""统一本体文件解析器
|
||||
|
||||
解析本体文件并提取类定义,构建类型注册表。支持多种 RDF 格式,
|
||||
并提供格式自动推断和回退机制。
|
||||
|
||||
Attributes:
|
||||
file_path: 本体文件路径
|
||||
file_format: 文件格式,如果未指定则根据扩展名推断
|
||||
graph: rdflib Graph 实例,用于存储解析后的 RDF 数据
|
||||
|
||||
Example:
|
||||
>>> parser = OntologyParser("dbpedia.owl")
|
||||
>>> registry = parser.parse()
|
||||
>>> person_type = registry.get_type("Person")
|
||||
>>> if person_type:
|
||||
... print(f"Person URI: {person_type.class_uri}")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
file_path: str,
|
||||
file_format: Optional[OntologyFileFormat] = None,
|
||||
):
|
||||
"""初始化解析器
|
||||
|
||||
Args:
|
||||
file_path: 本体文件路径
|
||||
file_format: 文件格式,如果未指定则根据扩展名自动推断
|
||||
"""
|
||||
self.file_path = file_path
|
||||
self.file_format = file_format or OntologyFileFormat.from_extension(file_path)
|
||||
self.graph = Graph()
|
||||
|
||||
def parse(self) -> GeneralOntologyTypeRegistry:
|
||||
"""解析本体文件,返回类型注册表
|
||||
|
||||
首先尝试使用推断的格式解析文件,如果失败则尝试其他格式。
|
||||
解析成功后,遍历所有 owl:Class 和 rdfs:Class 定义,
|
||||
提取类信息并构建层次结构。
|
||||
|
||||
Returns:
|
||||
GeneralOntologyTypeRegistry: 包含所有解析出的类型和层次结构的注册表
|
||||
|
||||
Raises:
|
||||
ValueError: 当所有格式都无法解析文件时抛出
|
||||
"""
|
||||
logger.info(f"开始解析本体文件: {self.file_path}")
|
||||
|
||||
# 尝试解析,失败则尝试其他格式
|
||||
self._parse_with_fallback()
|
||||
|
||||
registry = GeneralOntologyTypeRegistry()
|
||||
registry.source_files.append(self.file_path)
|
||||
|
||||
# 遍历 owl:Class
|
||||
for class_uri in self.graph.subjects(RDF.type, OWL.Class):
|
||||
type_info = self._parse_class(class_uri)
|
||||
if type_info:
|
||||
registry.types[type_info.class_name] = type_info
|
||||
self._update_hierarchy(registry, type_info)
|
||||
|
||||
# 遍历 rdfs:Class(避免重复)
|
||||
for class_uri in self.graph.subjects(RDF.type, RDFS.Class):
|
||||
uri_str = str(class_uri)
|
||||
# 检查是否已经作为 owl:Class 解析过
|
||||
if uri_str not in [t.class_uri for t in registry.types.values()]:
|
||||
type_info = self._parse_class(class_uri)
|
||||
if type_info and type_info.class_name not in registry.types:
|
||||
registry.types[type_info.class_name] = type_info
|
||||
self._update_hierarchy(registry, type_info)
|
||||
|
||||
logger.info(f"本体解析完成: {len(registry.types)} 个类型")
|
||||
return registry
|
||||
|
||||
def _parse_with_fallback(self) -> None:
|
||||
"""尝试解析文件,失败时尝试其他格式
|
||||
|
||||
首先使用推断的格式解析,如果失败则依次尝试 RDF_XML 和 TURTLE 格式。
|
||||
|
||||
Raises:
|
||||
ValueError: 当所有格式都无法解析文件时抛出
|
||||
"""
|
||||
try:
|
||||
self.graph.parse(self.file_path, format=self.file_format.value)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"使用 {self.file_format.value} 格式解析失败: {e}")
|
||||
|
||||
# 尝试其他格式
|
||||
fallback_formats = [
|
||||
OntologyFileFormat.RDF_XML,
|
||||
OntologyFileFormat.TURTLE,
|
||||
OntologyFileFormat.N_TRIPLES,
|
||||
OntologyFileFormat.JSON_LD,
|
||||
]
|
||||
|
||||
for fmt in fallback_formats:
|
||||
if fmt != self.file_format:
|
||||
try:
|
||||
self.graph.parse(self.file_path, format=fmt.value)
|
||||
logger.info(f"使用回退格式 {fmt.value} 解析成功")
|
||||
return
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
raise ValueError(f"无法解析本体文件: {self.file_path}")
|
||||
|
||||
def _update_hierarchy(
|
||||
self,
|
||||
registry: GeneralOntologyTypeRegistry,
|
||||
type_info: GeneralOntologyType
|
||||
) -> None:
|
||||
"""更新层次结构
|
||||
|
||||
如果类型有父类,将其添加到层次结构中。
|
||||
|
||||
Args:
|
||||
registry: 类型注册表
|
||||
type_info: 类型信息
|
||||
"""
|
||||
if type_info.parent_class:
|
||||
if type_info.parent_class not in registry.hierarchy:
|
||||
registry.hierarchy[type_info.parent_class] = set()
|
||||
registry.hierarchy[type_info.parent_class].add(type_info.class_name)
|
||||
|
||||
def _parse_class(self, class_uri: URIRef) -> Optional[GeneralOntologyType]:
|
||||
"""解析单个类定义
|
||||
|
||||
从 RDF 图中提取类的名称、URI、标签、描述和父类信息。
|
||||
过滤空白节点和内置类型(Thing、Resource)。
|
||||
|
||||
Args:
|
||||
class_uri: 类的 URI 引用
|
||||
|
||||
Returns:
|
||||
GeneralOntologyType 实例,如果应该跳过该类则返回 None
|
||||
"""
|
||||
uri_str = str(class_uri)
|
||||
class_name = self._extract_local_name(uri_str)
|
||||
|
||||
# 过滤空白节点和内置类型
|
||||
if not class_name:
|
||||
return None
|
||||
if class_name.startswith('_:'):
|
||||
return None
|
||||
if class_name in ('Thing', 'Resource'):
|
||||
return None
|
||||
# 过滤空白节点 URI(以 _: 开头或包含空白节点标识)
|
||||
if uri_str.startswith('_:'):
|
||||
return None
|
||||
|
||||
# 提取标签
|
||||
labels = self._extract_labels(class_uri)
|
||||
|
||||
# 提取描述
|
||||
description = self._extract_description(class_uri)
|
||||
|
||||
# 提取父类
|
||||
parent_class = self._extract_parent_class(class_uri)
|
||||
|
||||
return GeneralOntologyType(
|
||||
class_name=class_name,
|
||||
class_uri=uri_str,
|
||||
labels=labels,
|
||||
description=description,
|
||||
parent_class=parent_class,
|
||||
source_file=self.file_path
|
||||
)
|
||||
|
||||
def _extract_labels(self, class_uri: URIRef) -> dict:
|
||||
"""提取类的多语言标签
|
||||
|
||||
从 rdfs:label 属性中提取所有语言的标签。
|
||||
如果没有标签,使用类名作为英文标签。
|
||||
|
||||
Args:
|
||||
class_uri: 类的 URI 引用
|
||||
|
||||
Returns:
|
||||
语言代码到标签文本的字典
|
||||
"""
|
||||
labels = {}
|
||||
for label in self.graph.objects(class_uri, RDFS.label):
|
||||
lang = getattr(label, 'language', None) or "en"
|
||||
labels[lang] = str(label)
|
||||
|
||||
# 如果没有标签,使用类名作为默认标签
|
||||
if not labels:
|
||||
class_name = self._extract_local_name(str(class_uri))
|
||||
if class_name:
|
||||
labels["en"] = class_name
|
||||
|
||||
return labels
|
||||
|
||||
def _extract_description(self, class_uri: URIRef) -> Optional[str]:
|
||||
"""提取类的描述
|
||||
|
||||
从 rdfs:comment 属性中提取描述,优先使用英文描述。
|
||||
|
||||
Args:
|
||||
class_uri: 类的 URI 引用
|
||||
|
||||
Returns:
|
||||
类的描述文本,如果没有则返回 None
|
||||
"""
|
||||
description = None
|
||||
for comment in self.graph.objects(class_uri, RDFS.comment):
|
||||
lang = getattr(comment, 'language', None)
|
||||
# 优先使用英文描述
|
||||
if lang == "en":
|
||||
return str(comment)
|
||||
# 如果还没有描述,使用无语言标记或其他语言的描述
|
||||
if description is None:
|
||||
description = str(comment)
|
||||
return description
|
||||
|
||||
def _extract_parent_class(self, class_uri: URIRef) -> Optional[str]:
|
||||
"""提取类的父类
|
||||
|
||||
从 rdfs:subClassOf 属性中提取第一个有效的父类。
|
||||
过滤内置类型(Thing、Resource)和空白节点。
|
||||
|
||||
Args:
|
||||
class_uri: 类的 URI 引用
|
||||
|
||||
Returns:
|
||||
父类名称,如果没有有效父类则返回 None
|
||||
"""
|
||||
for parent_uri in self.graph.objects(class_uri, RDFS.subClassOf):
|
||||
parent_uri_str = str(parent_uri)
|
||||
# 跳过空白节点
|
||||
if parent_uri_str.startswith('_:'):
|
||||
continue
|
||||
|
||||
parent_name = self._extract_local_name(parent_uri_str)
|
||||
# 过滤内置类型
|
||||
if parent_name and parent_name not in ('Thing', 'Resource'):
|
||||
return parent_name
|
||||
|
||||
return None
|
||||
|
||||
def _extract_local_name(self, uri: str) -> Optional[str]:
|
||||
"""从 URI 中提取本地名称
|
||||
|
||||
支持两种常见的 URI 格式:
|
||||
1. 使用 # 分隔的 URI,如 http://example.org/ontology#Person
|
||||
2. 使用 / 分隔的 URI,如 http://dbpedia.org/ontology/Person
|
||||
|
||||
Args:
|
||||
uri: 完整的 URI 字符串
|
||||
|
||||
Returns:
|
||||
本地名称,如果无法提取则返回 None
|
||||
"""
|
||||
# 处理空白节点
|
||||
if uri.startswith('_:'):
|
||||
return None
|
||||
|
||||
# 尝试使用 # 分隔
|
||||
if '#' in uri:
|
||||
local_name = uri.rsplit('#', 1)[1]
|
||||
if local_name:
|
||||
return local_name
|
||||
|
||||
# 尝试使用 / 分隔
|
||||
if '/' in uri:
|
||||
local_name = uri.rsplit('/', 1)[1]
|
||||
if local_name:
|
||||
return local_name
|
||||
|
||||
# 使用正则表达式作为最后手段
|
||||
match = re.search(r'[#/]([^#/]+)$', uri)
|
||||
return match.group(1) if match else None
|
||||
|
||||
|
||||
class MultiOntologyParser:
|
||||
"""多本体文件解析器
|
||||
|
||||
支持加载多个本体文件并将它们合并到一个统一的类型注册表中。
|
||||
先加载的文件中的类型定义优先保留(当存在同名类型时)。
|
||||
|
||||
Attributes:
|
||||
file_paths: 本体文件路径列表
|
||||
|
||||
Example:
|
||||
>>> parser = MultiOntologyParser([
|
||||
... "General_purpose_entity.ttl",
|
||||
... "domain_specific.owl"
|
||||
... ])
|
||||
>>> registry = parser.parse_all()
|
||||
>>> print(f"合并后共 {len(registry.types)} 个类型")
|
||||
"""
|
||||
|
||||
def __init__(self, file_paths: List[str]):
|
||||
"""初始化多文件解析器
|
||||
|
||||
Args:
|
||||
file_paths: 本体文件路径列表
|
||||
"""
|
||||
self.file_paths = file_paths
|
||||
|
||||
def parse_all(self) -> GeneralOntologyTypeRegistry:
|
||||
"""解析所有本体文件并合并
|
||||
|
||||
依次解析每个本体文件,并将结果合并到一个统一的注册表中。
|
||||
如果某个文件解析失败,会记录警告日志并跳过该文件继续处理。
|
||||
|
||||
Returns:
|
||||
GeneralOntologyTypeRegistry: 合并后的类型注册表
|
||||
"""
|
||||
merged_registry = GeneralOntologyTypeRegistry()
|
||||
|
||||
for file_path in self.file_paths:
|
||||
try:
|
||||
parser = OntologyParser(file_path)
|
||||
registry = parser.parse()
|
||||
merged_registry.merge(registry)
|
||||
logger.info(f"已合并本体文件: {file_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"跳过无法解析的本体文件 {file_path}: {e}")
|
||||
|
||||
logger.info(f"多本体合并完成: 共 {len(merged_registry.types)} 个类型")
|
||||
return merged_registry
|
||||
@@ -9,22 +9,29 @@ current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
prompt_dir = os.path.join(current_dir, "prompts")
|
||||
prompt_env = Environment(loader=FileSystemLoader(prompt_dir))
|
||||
|
||||
async def get_prompts(message: str) -> list[dict]:
|
||||
async def get_prompts(message: str, language: str = "zh") -> list[dict]:
|
||||
"""
|
||||
Renders system and user prompts using Jinja2 templates.
|
||||
|
||||
Args:
|
||||
message: The message content
|
||||
language: Language for output ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
List of message dictionaries with role and content
|
||||
"""
|
||||
system_template = prompt_env.get_template("system.jinja2")
|
||||
user_template = prompt_env.get_template("user.jinja2")
|
||||
|
||||
system_prompt = system_template.render()
|
||||
user_prompt = user_template.render(message=message)
|
||||
system_prompt = system_template.render(language=language)
|
||||
user_prompt = user_template.render(message=message, language=language)
|
||||
|
||||
# 记录渲染结果到提示日志(与示例日志结构一致)
|
||||
log_prompt_rendering('system', system_prompt)
|
||||
log_prompt_rendering('user', user_prompt)
|
||||
# 可选:记录模板渲染信息(仅当 prompt_templates.log 存在时生效)
|
||||
log_template_rendering('system.jinja2', {})
|
||||
log_template_rendering('user.jinja2', {'message': message})
|
||||
log_template_rendering('system.jinja2', {'language': language})
|
||||
log_template_rendering('user.jinja2', {'message': message, 'language': language})
|
||||
return [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
@@ -38,6 +45,7 @@ async def render_statement_extraction_prompt(
|
||||
include_dialogue_context: bool = False,
|
||||
dialogue_content: str | None = None,
|
||||
max_dialogue_chars: int | None = None,
|
||||
language: str = "zh",
|
||||
) -> str:
|
||||
"""
|
||||
Renders the statement extraction prompt using the extract_statement.jinja2 template.
|
||||
@@ -46,6 +54,11 @@ async def render_statement_extraction_prompt(
|
||||
chunk_content: The content of the chunk to process
|
||||
definitions: Label definitions for statement classification
|
||||
json_schema: JSON schema for the expected output format
|
||||
granularity: Extraction granularity level (1-3)
|
||||
include_dialogue_context: Whether to include full dialogue context
|
||||
dialogue_content: Full dialogue content for context
|
||||
max_dialogue_chars: Maximum characters for dialogue context
|
||||
language: Language for output ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
@@ -69,6 +82,7 @@ async def render_statement_extraction_prompt(
|
||||
granularity=granularity,
|
||||
include_dialogue_context=include_dialogue_context,
|
||||
dialogue_context=ctx,
|
||||
language=language,
|
||||
)
|
||||
# 记录渲染结果到提示日志(与示例日志结构一致)
|
||||
log_prompt_rendering('statement extraction', rendered_prompt)
|
||||
@@ -90,6 +104,7 @@ async def render_temporal_extraction_prompt(
|
||||
temporal_guide: dict,
|
||||
statement_guide: dict,
|
||||
json_schema: dict,
|
||||
language: str = "zh",
|
||||
) -> str:
|
||||
"""
|
||||
Renders the temporal extraction prompt using the extract_temporal.jinja2 template.
|
||||
@@ -100,6 +115,7 @@ async def render_temporal_extraction_prompt(
|
||||
temporal_guide: Guidance on temporal types.
|
||||
statement_guide: Guidance on statement types.
|
||||
json_schema: JSON schema for the expected output format.
|
||||
language: Language for output ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as a string.
|
||||
@@ -111,6 +127,7 @@ async def render_temporal_extraction_prompt(
|
||||
temporal_guide=temporal_guide,
|
||||
statement_guide=statement_guide,
|
||||
json_schema=json_schema,
|
||||
language=language,
|
||||
)
|
||||
# 记录渲染结果到提示日志(与示例日志结构一致)
|
||||
log_prompt_rendering('temporal extraction', rendered_prompt)
|
||||
@@ -130,6 +147,7 @@ def render_entity_dedup_prompt(
|
||||
context: dict,
|
||||
json_schema: dict,
|
||||
disambiguation_mode: bool = False,
|
||||
language: str = "zh",
|
||||
) -> str:
|
||||
"""
|
||||
Render the entity deduplication prompt using the entity_dedup.jinja2 template.
|
||||
@@ -139,6 +157,8 @@ def render_entity_dedup_prompt(
|
||||
entity_b: Dict of entity B attributes
|
||||
context: Dict of computed signals (group/type gate, similarities, co-occurrence, relation statements)
|
||||
json_schema: JSON schema for the structured output (EntityDedupDecision)
|
||||
disambiguation_mode: Whether to use disambiguation mode
|
||||
language: Language for output ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
@@ -157,6 +177,7 @@ def render_entity_dedup_prompt(
|
||||
relation_statements=context.get("relation_statements", []),
|
||||
json_schema=json_schema,
|
||||
disambiguation_mode=disambiguation_mode,
|
||||
language=language,
|
||||
)
|
||||
|
||||
# prompt_logger.info("\n=== RENDERED ENTITY DEDUP PROMPT ===")
|
||||
@@ -177,7 +198,14 @@ def render_entity_dedup_prompt(
|
||||
|
||||
# Args:
|
||||
# entity_a: Dict of entity A attributes
|
||||
async def render_triplet_extraction_prompt(statement: str, chunk_content: str, json_schema: dict, predicate_instructions: dict = None) -> str:
|
||||
async def render_triplet_extraction_prompt(
|
||||
statement: str,
|
||||
chunk_content: str,
|
||||
json_schema: dict,
|
||||
predicate_instructions: dict = None,
|
||||
language: str = "zh",
|
||||
ontology_types: "OntologyTypeList | None" = None,
|
||||
) -> str:
|
||||
"""
|
||||
Renders the triplet extraction prompt using the extract_triplet.jinja2 template.
|
||||
|
||||
@@ -186,16 +214,32 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
|
||||
chunk_content: The content of the chunk to process
|
||||
json_schema: JSON schema for the expected output format
|
||||
predicate_instructions: Optional predicate instructions
|
||||
language: The language to use for entity descriptions ("zh" for Chinese, "en" for English)
|
||||
ontology_types: Optional OntologyTypeList containing predefined ontology types for entity classification
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
template = prompt_env.get_template("extract_triplet.jinja2")
|
||||
|
||||
# 准备本体类型数据
|
||||
ontology_type_section = ""
|
||||
ontology_type_names = []
|
||||
type_hierarchy_hints = []
|
||||
if ontology_types and ontology_types.types:
|
||||
ontology_type_section = ontology_types.to_prompt_section()
|
||||
ontology_type_names = ontology_types.get_type_names()
|
||||
type_hierarchy_hints = ontology_types.get_type_hierarchy_hints()
|
||||
|
||||
rendered_prompt = template.render(
|
||||
statement=statement,
|
||||
chunk_content=chunk_content,
|
||||
json_schema=json_schema,
|
||||
predicate_instructions=predicate_instructions
|
||||
predicate_instructions=predicate_instructions,
|
||||
language=language,
|
||||
ontology_types=ontology_type_section,
|
||||
ontology_type_names=ontology_type_names,
|
||||
type_hierarchy_hints=type_hierarchy_hints,
|
||||
)
|
||||
# 记录渲染结果到提示日志(与示例日志结构一致)
|
||||
log_prompt_rendering('triplet extraction', rendered_prompt)
|
||||
@@ -204,7 +248,11 @@ async def render_triplet_extraction_prompt(statement: str, chunk_content: str, j
|
||||
'statement': 'str',
|
||||
'chunk_content': 'str',
|
||||
'json_schema': 'TripletExtractionResponse.schema',
|
||||
'predicate_instructions': 'PREDICATE_DEFINITIONS'
|
||||
'predicate_instructions': 'PREDICATE_DEFINITIONS',
|
||||
'language': language,
|
||||
'ontology_types': bool(ontology_type_section),
|
||||
'ontology_type_count': len(ontology_type_names),
|
||||
'type_hierarchy_hints_count': len(type_hierarchy_hints),
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
@@ -213,6 +261,7 @@ async def render_memory_summary_prompt(
|
||||
chunk_texts: str,
|
||||
json_schema: dict,
|
||||
max_words: int = 200,
|
||||
language: str = "zh",
|
||||
) -> str:
|
||||
"""
|
||||
Renders the memory summary prompt using the memory_summary.jinja2 template.
|
||||
@@ -221,6 +270,7 @@ async def render_memory_summary_prompt(
|
||||
chunk_texts: Concatenated text of conversation chunks
|
||||
json_schema: JSON schema for the expected output format
|
||||
max_words: Maximum words for the summary
|
||||
language: The language to use for summary generation ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string.
|
||||
@@ -230,19 +280,22 @@ async def render_memory_summary_prompt(
|
||||
chunk_texts=chunk_texts,
|
||||
json_schema=json_schema,
|
||||
max_words=max_words,
|
||||
language=language,
|
||||
)
|
||||
log_prompt_rendering('memory summary', rendered_prompt)
|
||||
log_template_rendering('memory_summary.jinja2', {
|
||||
'chunk_texts_len': len(chunk_texts or ""),
|
||||
'max_words': max_words,
|
||||
'json_schema': 'MemorySummaryResponse.schema'
|
||||
'json_schema': 'MemorySummaryResponse.schema',
|
||||
'language': language
|
||||
})
|
||||
return rendered_prompt
|
||||
|
||||
async def render_emotion_extraction_prompt(
|
||||
statement: str,
|
||||
extract_keywords: bool,
|
||||
enable_subject: bool
|
||||
enable_subject: bool,
|
||||
language: str = "zh"
|
||||
) -> str:
|
||||
"""
|
||||
Renders the emotion extraction prompt using the extract_emotion.jinja2 template.
|
||||
@@ -251,6 +304,7 @@ async def render_emotion_extraction_prompt(
|
||||
statement: The statement to analyze
|
||||
extract_keywords: Whether to extract emotion keywords
|
||||
enable_subject: Whether to enable subject classification
|
||||
language: Language for output ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
@@ -259,7 +313,8 @@ async def render_emotion_extraction_prompt(
|
||||
rendered_prompt = template.render(
|
||||
statement=statement,
|
||||
extract_keywords=extract_keywords,
|
||||
enable_subject=enable_subject
|
||||
enable_subject=enable_subject,
|
||||
language=language
|
||||
)
|
||||
|
||||
# 记录渲染结果到提示日志
|
||||
@@ -276,7 +331,8 @@ async def render_emotion_extraction_prompt(
|
||||
async def render_emotion_suggestions_prompt(
|
||||
health_data: dict,
|
||||
patterns: dict,
|
||||
user_profile: dict
|
||||
user_profile: dict,
|
||||
language: str = "zh"
|
||||
) -> str:
|
||||
"""
|
||||
Renders the emotion suggestions generation prompt using the generate_emotion_suggestions.jinja2 template.
|
||||
@@ -285,6 +341,7 @@ async def render_emotion_suggestions_prompt(
|
||||
health_data: 情绪健康数据
|
||||
patterns: 情绪模式分析结果
|
||||
user_profile: 用户画像数据
|
||||
language: 输出语言 ("zh" 中文, "en" 英文)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
@@ -292,18 +349,39 @@ async def render_emotion_suggestions_prompt(
|
||||
import json
|
||||
|
||||
# 预处理 emotion_distribution 为 JSON 字符串
|
||||
# 如果是中文,将 emotion_distribution 的 key 翻译为中文
|
||||
emotion_distribution = health_data.get('emotion_distribution', {})
|
||||
if language == "zh":
|
||||
emotion_type_zh = {
|
||||
'joy': '喜悦', 'sadness': '悲伤', 'anger': '愤怒',
|
||||
'fear': '恐惧', 'surprise': '惊讶', 'neutral': '中性'
|
||||
}
|
||||
emotion_distribution = {
|
||||
emotion_type_zh.get(k, k): v for k, v in emotion_distribution.items()
|
||||
}
|
||||
emotion_distribution_json = json.dumps(
|
||||
health_data.get('emotion_distribution', {}),
|
||||
emotion_distribution,
|
||||
ensure_ascii=False,
|
||||
indent=2
|
||||
)
|
||||
|
||||
# 翻译 dominant_negative_emotion
|
||||
dominant_negative_translated = None
|
||||
dominant_neg = patterns.get('dominant_negative_emotion')
|
||||
if dominant_neg and language == "zh":
|
||||
emotion_type_zh_map = {
|
||||
'sadness': '悲伤', 'anger': '愤怒', 'fear': '恐惧'
|
||||
}
|
||||
dominant_negative_translated = emotion_type_zh_map.get(dominant_neg, dominant_neg)
|
||||
|
||||
template = prompt_env.get_template("generate_emotion_suggestions.jinja2")
|
||||
rendered_prompt = template.render(
|
||||
health_data=health_data,
|
||||
patterns=patterns,
|
||||
user_profile=user_profile,
|
||||
emotion_distribution_json=emotion_distribution_json
|
||||
emotion_distribution_json=emotion_distribution_json,
|
||||
language=language,
|
||||
dominant_negative_translated=dominant_negative_translated
|
||||
)
|
||||
|
||||
# 记录渲染结果到提示日志
|
||||
@@ -321,7 +399,8 @@ async def render_emotion_suggestions_prompt(
|
||||
async def render_user_summary_prompt(
|
||||
user_id: str,
|
||||
entities: str,
|
||||
statements: str
|
||||
statements: str,
|
||||
language: str = "zh"
|
||||
) -> str:
|
||||
"""
|
||||
Renders the user summary prompt using the user_summary.jinja2 template.
|
||||
@@ -330,6 +409,7 @@ async def render_user_summary_prompt(
|
||||
user_id: User identifier
|
||||
entities: Core entities with frequency information
|
||||
statements: Representative statement samples
|
||||
language: The language to use for summary generation ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
@@ -338,7 +418,8 @@ async def render_user_summary_prompt(
|
||||
rendered_prompt = template.render(
|
||||
user_id=user_id,
|
||||
entities=entities,
|
||||
statements=statements
|
||||
statements=statements,
|
||||
language=language
|
||||
)
|
||||
|
||||
# 记录渲染结果到提示日志
|
||||
@@ -347,7 +428,8 @@ async def render_user_summary_prompt(
|
||||
log_template_rendering('user_summary.jinja2', {
|
||||
'user_id': user_id,
|
||||
'entities_len': len(entities),
|
||||
'statements_len': len(statements)
|
||||
'statements_len': len(statements),
|
||||
'language': language
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
@@ -356,7 +438,8 @@ async def render_user_summary_prompt(
|
||||
async def render_memory_insight_prompt(
|
||||
domain_distribution: str = None,
|
||||
active_periods: str = None,
|
||||
social_connections: str = None
|
||||
social_connections: str = None,
|
||||
language: str = "zh"
|
||||
) -> str:
|
||||
"""
|
||||
Renders the memory insight prompt using the memory_insight.jinja2 template.
|
||||
@@ -365,6 +448,7 @@ async def render_memory_insight_prompt(
|
||||
domain_distribution: 核心领域分布信息
|
||||
active_periods: 活跃时段信息
|
||||
social_connections: 社交关联信息
|
||||
language: The language to use for report generation ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
@@ -373,7 +457,8 @@ async def render_memory_insight_prompt(
|
||||
rendered_prompt = template.render(
|
||||
domain_distribution=domain_distribution,
|
||||
active_periods=active_periods,
|
||||
social_connections=social_connections
|
||||
social_connections=social_connections,
|
||||
language=language
|
||||
)
|
||||
|
||||
# 记录渲染结果到提示日志
|
||||
@@ -382,30 +467,76 @@ async def render_memory_insight_prompt(
|
||||
log_template_rendering('memory_insight.jinja2', {
|
||||
'has_domain_distribution': bool(domain_distribution),
|
||||
'has_active_periods': bool(active_periods),
|
||||
'has_social_connections': bool(social_connections)
|
||||
'has_social_connections': bool(social_connections),
|
||||
'language': language
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
|
||||
async def render_episodic_title_and_type_prompt(content: str) -> str:
|
||||
async def render_episodic_title_and_type_prompt(content: str, language: str = "zh") -> str:
|
||||
"""
|
||||
Renders the episodic title and type classification prompt using the episodic_type_classification.jinja2 template.
|
||||
|
||||
Args:
|
||||
content: The content of the episodic memory summary to analyze
|
||||
language: The language to use for title generation ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
template = prompt_env.get_template("episodic_type_classification.jinja2")
|
||||
rendered_prompt = template.render(content=content)
|
||||
rendered_prompt = template.render(content=content, language=language)
|
||||
|
||||
# 记录渲染结果到提示日志
|
||||
log_prompt_rendering('episodic title and type classification', rendered_prompt)
|
||||
# 可选:记录模板渲染信息
|
||||
log_template_rendering('episodic_type_classification.jinja2', {
|
||||
'content_len': len(content) if content else 0
|
||||
'content_len': len(content) if content else 0,
|
||||
'language': language
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
|
||||
async def render_ontology_extraction_prompt(
|
||||
scenario: str,
|
||||
domain: str | None = None,
|
||||
max_classes: int = 15,
|
||||
json_schema: dict | None = None,
|
||||
language: str = "zh"
|
||||
) -> str:
|
||||
"""
|
||||
Renders the ontology extraction prompt using the extract_ontology.jinja2 template.
|
||||
|
||||
Args:
|
||||
scenario: The scenario description text to extract ontology classes from
|
||||
domain: Optional domain hint for the scenario (e.g., "Healthcare", "Education")
|
||||
max_classes: Maximum number of classes to extract (default: 15)
|
||||
json_schema: JSON schema for the expected output format
|
||||
language: Language for output ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
template = prompt_env.get_template("extract_ontology.jinja2")
|
||||
rendered_prompt = template.render(
|
||||
scenario=scenario,
|
||||
domain=domain,
|
||||
max_classes=max_classes,
|
||||
json_schema=json_schema,
|
||||
language=language
|
||||
)
|
||||
|
||||
# 记录渲染结果到提示日志
|
||||
log_prompt_rendering('ontology extraction', rendered_prompt)
|
||||
# 可选:记录模板渲染信息
|
||||
log_template_rendering('extract_ontology.jinja2', {
|
||||
'scenario_len': len(scenario) if scenario else 0,
|
||||
'domain': domain,
|
||||
'max_classes': max_classes,
|
||||
'json_schema': 'OntologyExtractionResponse.schema',
|
||||
'language': language
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
@@ -1,15 +1,23 @@
|
||||
===任务===
|
||||
===Task===
|
||||
{% if language == "zh" %}
|
||||
你是一个实体去重/消歧判断助手。你将被提供两个实体的详细信息和上下文,请严格根据指引判断它们是否是同一真实世界实体,并在需要时进行类型消歧。
|
||||
|
||||
模式: {{ '消歧模式' if disambiguation_mode else '去重模式' }}
|
||||
{% else %}
|
||||
You are an entity deduplication/disambiguation assistant. You will be provided with detailed information and context for two entities. Please strictly follow the guidelines to determine whether they are the same real-world entity and perform type disambiguation when necessary.
|
||||
|
||||
===输入===
|
||||
Mode: {{ 'Disambiguation Mode' if disambiguation_mode else 'Deduplication Mode' }}
|
||||
{% endif %}
|
||||
|
||||
===Input===
|
||||
{% if language == "zh" %}
|
||||
实体A:
|
||||
- 名称: "{{ entity_a.name | default('') }}"
|
||||
- 类型: "{{ entity_a.entity_type | default('') }}"
|
||||
- 描述: "{{ entity_a.description | default('') }}"
|
||||
- 别名: {{ entity_a.aliases | default([]) }}
|
||||
- 摘要: "{{ entity_a.fact_summary | default('') }}"
|
||||
{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
|
||||
{# - 摘要: "{{ entity_a.fact_summary | default('') }}" #}
|
||||
- 连接强弱: "{{ entity_a.connect_strength | default('') }}"
|
||||
|
||||
实体B:
|
||||
@@ -17,7 +25,8 @@
|
||||
- 类型: "{{ entity_b.entity_type | default('') }}"
|
||||
- 描述: "{{ entity_b.description | default('') }}"
|
||||
- 别名: {{ entity_b.aliases | default([]) }}
|
||||
- 摘要: "{{ entity_b.fact_summary | default('') }}"
|
||||
{# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用 #}
|
||||
{# - 摘要: "{{ entity_b.fact_summary | default('') }}" #}
|
||||
- 连接强弱: "{{ entity_b.connect_strength | default('') }}"
|
||||
|
||||
上下文:
|
||||
@@ -32,8 +41,41 @@
|
||||
{% for s in relation_statements %}
|
||||
- {{ s }}
|
||||
{% endfor %}
|
||||
{% else %}
|
||||
Entity A:
|
||||
- Name: "{{ entity_a.name | default('') }}"
|
||||
- Type: "{{ entity_a.entity_type | default('') }}"
|
||||
- Description: "{{ entity_a.description | default('') }}"
|
||||
- Aliases: {{ entity_a.aliases | default([]) }}
|
||||
{# TODO: fact_summary feature temporarily disabled, to be enabled after future development #}
|
||||
{# - Summary: "{{ entity_a.fact_summary | default('') }}" #}
|
||||
- Connection Strength: "{{ entity_a.connect_strength | default('') }}"
|
||||
|
||||
===判定指引===
|
||||
Entity B:
|
||||
- Name: "{{ entity_b.name | default('') }}"
|
||||
- Type: "{{ entity_b.entity_type | default('') }}"
|
||||
- Description: "{{ entity_b.description | default('') }}"
|
||||
- Aliases: {{ entity_b.aliases | default([]) }}
|
||||
{# TODO: fact_summary feature temporarily disabled, to be enabled after future development #}
|
||||
{# - Summary: "{{ entity_b.fact_summary | default('') }}" #}
|
||||
- Connection Strength: "{{ entity_b.connect_strength | default('') }}"
|
||||
|
||||
Context:
|
||||
- Same Group: {{ same_group | default(false) }}
|
||||
- Type Consistent or Unknown: {{ type_ok | default(false) }}
|
||||
- Type Similarity (0-1): {{ type_similarity | default(0.0) }}
|
||||
- Name Text Similarity (0-1): {{ name_text_sim | default(0.0) }}
|
||||
- Name Embedding Similarity (0-1): {{ name_embed_sim | default(0.0) }}
|
||||
- Name Contains Relationship: {{ name_contains | default(false) }}
|
||||
- Context Co-occurrence (same statement refers to both): {{ co_occurrence | default(false) }}
|
||||
- Related Relationship Statements (from entity-entity edges):
|
||||
{% for s in relation_statements %}
|
||||
- {{ s }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
===Guidelines===
|
||||
{% if language == "zh" %}
|
||||
{% if disambiguation_mode %}
|
||||
- 这是"同名但类型不同"的消歧场景。请判断两者是否指向同一真实世界实体。
|
||||
- 综合名称文本/向量相似度、别名、描述、摘要与上下文关系(同源与关系陈述)进行判断。
|
||||
@@ -66,8 +108,43 @@
|
||||
- 优先保留连接强度更强(strong/both)者;其余相同则保留描述/摘要更丰富者;再相同时保留实体A(canonical_idx=0)。
|
||||
- **注意**:别名(aliases)已在三元组提取阶段获取,合并时会自动整合,无需在此阶段提取。
|
||||
{% endif %}
|
||||
{% else %}
|
||||
{% if disambiguation_mode %}
|
||||
- This is a disambiguation scenario for "same name but different types". Please determine whether they refer to the same real-world entity.
|
||||
- Make judgments based on name text/vector similarity, aliases, descriptions, summaries, and contextual relationships (co-occurrence and relationship statements).
|
||||
- **Alias Handling (High Priority)**:
|
||||
* If the alias lists of both entities have intersections, this is a strong signal of identity
|
||||
* If one entity's name appears in another entity's aliases, it should be considered a high-confidence match
|
||||
* If one entity's alias exactly matches another entity's name, it should be considered a high-confidence match
|
||||
* Alias matching weight should be higher than pure name text similarity
|
||||
- If unable to determine with sufficient confidence, handle conservatively: do not merge, and suggest blocking this pair in other fuzzy/heuristic merges (block_pair=true).
|
||||
- If merging is needed (should_merge=true), select the "canonical entity" (canonical_idx) and **must** provide a suggested unified type (suggested_type).
|
||||
- **Type Unification Principles (Important)**:
|
||||
* Prioritize more specific and accurate types (e.g., HistoricalPeriod over Organization, MilitaryCapability over Concept)
|
||||
* If both types are specific but different, choose the type that best matches the entity's core semantics
|
||||
* Generic types (Concept, Phenomenon, Condition, State, Attribute, Event) have lower priority than domain-specific types
|
||||
* Suggested type must be consistent with context and entity description
|
||||
- Canonical entity priority: higher connection strength (strong/both); if equal, retain the one with richer description/summary; if still equal, retain Entity A (canonical_idx=0).
|
||||
- **Note**: Aliases are already obtained during triplet extraction and will be automatically integrated during merging; no need to extract at this stage.
|
||||
{% else %}
|
||||
- If entity types are the same or either is UNKNOWN/empty, can proceed as candidates; if types clearly conflict (e.g., person vs. item), unless aliases and descriptions are highly consistent, determine as different entities.
|
||||
- **Alias Matching Priority (Highest Priority)**:
|
||||
* If Entity A's name exactly matches any of Entity B's aliases, it should be considered a high-confidence match
|
||||
* If Entity B's name exactly matches any of Entity A's aliases, it should be considered a high-confidence match
|
||||
* If any alias of Entity A exactly matches any alias of Entity B, it should be considered a high-confidence match
|
||||
* When aliases match exactly, merging should be considered even if name text similarity is low
|
||||
* Alias matching confidence should be higher than pure name similarity matching
|
||||
- Make judgments based on name text/vector similarity, aliases, descriptions, summaries, and contextual relationships.
|
||||
- When context co-occurs or there are clear relationship statements supporting identity (e.g., the same object is repeatedly mentioned or aliases correspond), the judgment threshold can be moderately lowered.
|
||||
- Conservative decision: when unable to determine with sufficient confidence, do not merge (same_entity=false).
|
||||
- If merging is needed, select the "canonical entity to retain" (canonical_idx) as the more appropriate one:
|
||||
- Prioritize retaining the one with stronger connection strength (strong/both); if equal, retain the one with richer description/summary; if still equal, retain Entity A (canonical_idx=0).
|
||||
- **Note**: Aliases are already obtained during triplet extraction and will be automatically integrated during merging; no need to extract at this stage.
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
**Output format**
|
||||
{% if language == "zh" %}
|
||||
{% if disambiguation_mode %}
|
||||
返回JSON格式,必须包含以下字段:
|
||||
{
|
||||
@@ -101,6 +178,41 @@
|
||||
- confidence: 决策的置信度,范围0.0-1.0
|
||||
- reason: 决策理由的简短说明
|
||||
{% endif %}
|
||||
{% else %}
|
||||
{% if disambiguation_mode %}
|
||||
Return JSON format with the following required fields:
|
||||
{
|
||||
"should_merge": boolean,
|
||||
"canonical_idx": 0 or 1,
|
||||
"confidence": float (0.0-1.0),
|
||||
"block_pair": boolean,
|
||||
"suggested_type": "string or null",
|
||||
"reason": "string"
|
||||
}
|
||||
|
||||
**Field Descriptions**:
|
||||
- should_merge: Whether these two entities should be merged (true/false)
|
||||
- canonical_idx: Index of the canonical entity, 0 for Entity A, 1 for Entity B
|
||||
- confidence: Confidence level of the decision, range 0.0-1.0
|
||||
- block_pair: Whether to block this pair in other fuzzy/heuristic merges (true/false)
|
||||
- suggested_type: Suggested unified type (string or null)
|
||||
- reason: Brief explanation of the decision
|
||||
{% else %}
|
||||
Return JSON format with the following required fields:
|
||||
{
|
||||
"same_entity": boolean,
|
||||
"canonical_idx": 0 or 1,
|
||||
"confidence": float (0.0-1.0),
|
||||
"reason": "string"
|
||||
}
|
||||
|
||||
**Field Descriptions**:
|
||||
- same_entity: Whether the two entities refer to the same real-world entity (true/false)
|
||||
- canonical_idx: Index of the canonical entity, 0 for Entity A, 1 for Entity B
|
||||
- confidence: Confidence level of the decision, range 0.0-1.0
|
||||
- reason: Brief explanation of the decision
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||
@@ -108,5 +220,9 @@
|
||||
3. Do not include line breaks within JSON string values
|
||||
4. Test your JSON output mentally to ensure it can be parsed correctly
|
||||
|
||||
{% if language == "zh" %}
|
||||
输出语言应始终与输入语言相同。
|
||||
{% else %}
|
||||
The output language should always be the same as the input language.
|
||||
{% endif %}
|
||||
{{ json_schema }}
|
||||
|
||||
@@ -1,8 +1,19 @@
|
||||
=== Task ===
|
||||
Generate a concise title and classify the episodic memory into the most appropriate category.
|
||||
|
||||
{% if language == "zh" %}
|
||||
**重要:请使用中文生成标题和分类。**
|
||||
{% else %}
|
||||
**Important: Please generate the title and classification in English.**
|
||||
{% endif %}
|
||||
|
||||
=== Requirements ===
|
||||
- Extract a clear, concise title (10-20 characters) that captures the core content
|
||||
{% if language == "zh" %}
|
||||
- 标题必须使用中文
|
||||
{% else %}
|
||||
- Title must be in English
|
||||
{% endif %}
|
||||
- Classify into exactly one category based on the primary theme
|
||||
- Be specific and avoid ambiguity
|
||||
- Output must be valid JSON conforming to the schema below
|
||||
|
||||
@@ -17,9 +17,18 @@
|
||||
#}
|
||||
|
||||
{% set scene_instructions = {
|
||||
'education': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。',
|
||||
'online_service': '在线客服场景:客户咨询、问题排查、服务工单、售后支持、订单/退款、工单升级等。',
|
||||
'outbound': '外呼场景:电话外呼、邀约、调研问卷、线索跟进、对话脚本、回访记录等。'
|
||||
'education': {
|
||||
'zh': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。',
|
||||
'en': 'Education Scenario: Teaching, courses, exams, homework, teacher/student interaction, learning resources, school management, etc.'
|
||||
},
|
||||
'online_service': {
|
||||
'zh': '在线客服场景:客户咨询、问题排查、服务工单、售后支持、订单/退款、工单升级等。',
|
||||
'en': 'Online Service Scenario: Customer inquiries, troubleshooting, service tickets, after-sales support, orders/refunds, ticket escalation, etc.'
|
||||
},
|
||||
'outbound': {
|
||||
'zh': '外呼场景:电话外呼、邀约、调研问卷、线索跟进、对话脚本、回访记录等。',
|
||||
'en': 'Outbound Scenario: Outbound calls, invitations, survey questionnaires, lead follow-up, call scripts, follow-up records, etc.'
|
||||
}
|
||||
} %}
|
||||
|
||||
{% set scene_key = pruning_scene %}
|
||||
@@ -27,8 +36,9 @@
|
||||
{% set scene_key = 'education' %}
|
||||
{% endif %}
|
||||
|
||||
{% set instruction = scene_instructions[scene_key] %}
|
||||
{% set instruction = scene_instructions[scene_key][language] if language in ['zh', 'en'] else scene_instructions[scene_key]['zh'] %}
|
||||
|
||||
{% if language == "zh" %}
|
||||
请在下方对话全文基础上,按该场景进行一次性抽取并判定相关性:
|
||||
场景说明:{{ instruction }}
|
||||
|
||||
@@ -46,4 +56,24 @@
|
||||
"contacts": [<string>...],
|
||||
"addresses": [<string>...],
|
||||
"keywords": [<string>...]
|
||||
}
|
||||
}
|
||||
{% else %}
|
||||
Based on the full dialogue below, perform one-time extraction and relevance determination according to this scenario:
|
||||
Scenario Description: {{ instruction }}
|
||||
|
||||
Full Dialogue:
|
||||
"""
|
||||
{{ dialog_text }}
|
||||
"""
|
||||
|
||||
Output strict JSON only (fixed keys, order doesn't matter):
|
||||
{
|
||||
"is_related": <true or false>,
|
||||
"times": [<string>...],
|
||||
"ids": [<string>...],
|
||||
"amounts": [<string>...],
|
||||
"contacts": [<string>...],
|
||||
"addresses": [<string>...],
|
||||
"keywords": [<string>...]
|
||||
}
|
||||
{% endif %}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
{% if language == "zh" %}
|
||||
你是一个专业的情绪分析专家。请分析以下陈述句的情绪信息。
|
||||
|
||||
陈述句:{{ statement }}
|
||||
@@ -55,3 +56,62 @@
|
||||
- 主体分类要准确,优先识别用户本人(self)
|
||||
|
||||
请以 JSON 格式返回结果。
|
||||
{% else %}
|
||||
You are a professional emotion analysis expert. Please analyze the emotional information in the following statement.
|
||||
|
||||
Statement: {{ statement }}
|
||||
|
||||
Please extract the following information:
|
||||
|
||||
1. emotion_type (Emotion Type):
|
||||
- joy: happiness, delight, pleasure, satisfaction, cheerfulness
|
||||
- sadness: sorrow, grief, disappointment, depression, regret
|
||||
- anger: rage, irritation, dissatisfaction, annoyance, frustration
|
||||
- fear: anxiety, worry, concern, nervousness, apprehension
|
||||
- surprise: astonishment, amazement, shock, wonder
|
||||
- neutral: neutral, objective statement, no obvious emotion
|
||||
|
||||
2. emotion_intensity (Emotion Intensity):
|
||||
- 0.0-0.3: weak emotion
|
||||
- 0.3-0.7: moderate emotion
|
||||
- 0.7-1.0: strong emotion
|
||||
|
||||
{% if extract_keywords %}
|
||||
3. emotion_keywords (Emotion Keywords):
|
||||
- Words directly expressing emotions in the original sentence
|
||||
- Extract up to 3 keywords
|
||||
- Return empty list if no obvious emotion words
|
||||
{% else %}
|
||||
3. emotion_keywords (Emotion Keywords):
|
||||
- Return empty list
|
||||
{% endif %}
|
||||
|
||||
{% if enable_subject %}
|
||||
4. emotion_subject (Emotion Subject):
|
||||
- self: user's own emotions (includes "I", "we", "us" and other first-person pronouns)
|
||||
- other: others' emotions (includes names, "he/she" and other third-person pronouns)
|
||||
- object: evaluation of things (for products, places, events, etc.)
|
||||
|
||||
Note:
|
||||
- If multiple subjects are present, prioritize identifying the user (self)
|
||||
- If the subject cannot be clearly determined, default to self
|
||||
|
||||
5. emotion_target (Emotion Target):
|
||||
- If there is a clear emotion target, extract its name
|
||||
- If there is no clear target, return null
|
||||
{% else %}
|
||||
4. emotion_subject (Emotion Subject):
|
||||
- Default to self
|
||||
|
||||
5. emotion_target (Emotion Target):
|
||||
- Return null
|
||||
{% endif %}
|
||||
|
||||
Notes:
|
||||
- If the statement is an objective factual statement with no obvious emotion, mark as neutral
|
||||
- Emotion intensity should match the context, do not over-interpret
|
||||
- Emotion keywords should be accurate, do not add words not in the original sentence
|
||||
- Subject classification should be accurate, prioritize identifying the user (self)
|
||||
|
||||
Please return the result in JSON format.
|
||||
{% endif %}
|
||||
|
||||
445
api/app/core/memory/utils/prompt/prompts/extract_ontology.jinja2
Normal file
445
api/app/core/memory/utils/prompt/prompts/extract_ontology.jinja2
Normal file
@@ -0,0 +1,445 @@
|
||||
===Task===
|
||||
{% if language == "zh" %}
|
||||
从给定的场景描述中提取本体类,遵循本体工程标准。
|
||||
{% else %}
|
||||
Extract ontology classes from the given scenario description following ontology engineering standards.
|
||||
{% endif %}
|
||||
|
||||
===Role===
|
||||
{% if language == "zh" %}
|
||||
你是一位专业的本体工程师,精通知识表示和OWL(Web本体语言)标准。你的任务是从场景描述中识别抽象类和概念,而不是具体实例。
|
||||
{% else %}
|
||||
You are a professional ontology engineer with expertise in knowledge representation and OWL (Web Ontology Language) standards. Your task is to identify abstract classes and concepts from scenario descriptions, not concrete instances.
|
||||
{% endif %}
|
||||
|
||||
===Scenario Description===
|
||||
{{ scenario }}
|
||||
|
||||
{% if domain -%}
|
||||
===Domain Hint===
|
||||
{% if language == "zh" %}
|
||||
此场景属于 **{{ domain }}** 领域。提取类时请考虑领域特定的概念和术语。
|
||||
{% else %}
|
||||
This scenario belongs to the **{{ domain }}** domain. Consider domain-specific concepts and terminology when extracting classes.
|
||||
{% endif %}
|
||||
{%- endif %}
|
||||
|
||||
===Output Language===
|
||||
{% if language == "en" -%}
|
||||
**IMPORTANT: All output content MUST be in English.**
|
||||
- Class names (name field): English in PascalCase format
|
||||
- Chinese name (name_chinese field): Provide Chinese translation
|
||||
- Descriptions: MUST be in English
|
||||
- Examples: MUST be in English
|
||||
- Domain: MUST be in English
|
||||
{%- else -%}
|
||||
**IMPORTANT: Output content language requirements:**
|
||||
- Class names (name field): English in PascalCase format
|
||||
- Chinese name (name_chinese field): Chinese translation
|
||||
- Descriptions: MUST be in Chinese (中文)
|
||||
- Examples: MUST be in Chinese (中文)
|
||||
- Domain: Can be in Chinese or English
|
||||
{%- endif %}
|
||||
|
||||
===Extraction Rules===
|
||||
|
||||
{% if language == "zh" %}
|
||||
**1. 抽象类,而非实例:**
|
||||
- 提取抽象类别和概念(如"医疗程序"、"患者"、"诊断")
|
||||
- 不要提取具体实例(如"张三"、"301房间"、"2024-01-15")
|
||||
- 以"事物的类型"而非"具体事物"的角度思考
|
||||
|
||||
**2. 命名规范:**
|
||||
- "name"字段使用中文名称
|
||||
- 使用清晰、描述性的中文名称
|
||||
- 示例:"医疗程序"、"医疗服务提供者"、"诊断测试"
|
||||
|
||||
**3. 领域相关性:**
|
||||
- 专注于场景领域的核心类
|
||||
- 优先提取代表关键概念、实体或关系的类
|
||||
- 避免过于通用的类(如"事物"、"对象"),除非它们在领域中有特定含义
|
||||
|
||||
**4. 类数量:**
|
||||
- 提取5到{{ max_classes }}个类
|
||||
- 目标是覆盖场景主要概念的平衡集合
|
||||
- 质量优于数量:优先选择定义明确的类
|
||||
|
||||
**5. 清晰的描述:**
|
||||
- 用中文提供简洁、信息丰富的描述(最多500字)
|
||||
- 描述类代表什么,而不是具体实例
|
||||
- 使用清晰、自然的中文解释类在领域中的作用
|
||||
|
||||
**6. 具体示例:**
|
||||
- 为每个类提供2-5个中文具体实例示例
|
||||
- 示例应该是该类的具体、现实的实例
|
||||
- 示例有助于阐明类的范围和含义
|
||||
- 示例格式:["示例1", "示例2", "示例3"]
|
||||
|
||||
**7. 类层次结构:**
|
||||
- 在适用的情况下识别父子关系
|
||||
- 使用parent_class字段指定继承关系
|
||||
- 父类必须是提取的类之一或标准OWL类
|
||||
- 顶级类的parent_class设为null
|
||||
|
||||
**8. 实体类型:**
|
||||
- 为每个类分配适当的entity_type
|
||||
- 常见类型:"人物"、"组织"、"地点"、"事件"、"概念"、"过程"、"对象"、"角色"
|
||||
- 选择最具体的适用类型
|
||||
|
||||
**9. 语言一致性:**
|
||||
- 所有字段内容必须使用中文
|
||||
- "name"字段使用中文名称
|
||||
- "description"字段使用中文描述
|
||||
- "examples"字段使用中文示例
|
||||
- "entity_type"字段使用中文类型名称
|
||||
- "domain"字段使用中文领域名称
|
||||
|
||||
{% else %}
|
||||
**1. Abstract Classes, Not Instances:**
|
||||
- Extract abstract categories and concepts (e.g., "MedicalProcedure", "Patient", "Diagnosis")
|
||||
- Do NOT extract concrete instances (e.g., "John Smith", "Room 301", "2024-01-15")
|
||||
- Think in terms of "types of things" rather than "specific things"
|
||||
|
||||
**2. Naming Convention (PascalCase):**
|
||||
- Use PascalCase format for the "name" field: start with uppercase letter, capitalize each word, no spaces
|
||||
- Examples: "MedicalProcedure", "HealthcareProvider", "DiagnosticTest"
|
||||
- Avoid: "medical procedure", "healthcare_provider", "diagnostic-test"
|
||||
- Use clear, descriptive names in English
|
||||
|
||||
**3. Domain Relevance:**
|
||||
- Focus on classes that are central to the scenario's domain
|
||||
- Prioritize classes that represent key concepts, entities, or relationships
|
||||
- Avoid overly generic classes (e.g., "Thing", "Object") unless they have specific domain meaning
|
||||
|
||||
**4. Class Quantity:**
|
||||
- Extract between 5 and {{ max_classes }} classes
|
||||
- Aim for a balanced set covering the main concepts in the scenario
|
||||
- Quality over quantity: prefer well-defined classes over exhaustive lists
|
||||
|
||||
|
||||
**5. Clear Descriptions:**
|
||||
{% if language == "en" -%}
|
||||
- Provide concise, informative descriptions in English (max 500 characters)
|
||||
- Describe what the class represents, not specific instances
|
||||
- Use clear, natural English language that explains the class's role in the domain
|
||||
{%- else -%}
|
||||
- Provide concise, informative descriptions in English (max 500 characters)
|
||||
- Describe what the class represents, not specific instances
|
||||
- Use clear, natural English language
|
||||
{%- endif %}
|
||||
|
||||
**6. Concrete Examples:**
|
||||
{% if language == "en" -%}
|
||||
- Provide 2-5 concrete instance examples in English for each class
|
||||
- Examples should be specific, realistic instances of the class
|
||||
- Examples help clarify the class's scope and meaning
|
||||
- Use natural English language for examples
|
||||
- Example format: ["Example1", "Example2", "Example3"]
|
||||
{%- else -%}
|
||||
- Provide 2-5 concrete instance examples in English for each class
|
||||
- Examples should be specific, realistic instances of the class
|
||||
- Examples help clarify the class's scope and meaning
|
||||
- Example format: ["Example1", "Example2", "Example3"]
|
||||
{%- endif %}
|
||||
|
||||
**7. Class Hierarchy:**
|
||||
- Identify parent-child relationships where applicable
|
||||
- Use the parent_class field to specify inheritance
|
||||
- Parent class must be one of the extracted classes or a standard OWL class
|
||||
- Leave parent_class as null for top-level classes
|
||||
|
||||
**8. Entity Types:**
|
||||
- Classify each class with an appropriate entity_type
|
||||
- Common types: "Person", "Organization", "Location", "Event", "Concept", "Process", "Object", "Role"
|
||||
- Choose the most specific type that applies
|
||||
|
||||
**9. Language Consistency:**
|
||||
- All field content must be in English
|
||||
- "name" field uses English PascalCase names
|
||||
- "description" field uses English descriptions
|
||||
- "examples" field uses English examples
|
||||
- "entity_type" field uses English type names
|
||||
- "domain" field uses English domain names
|
||||
{% endif %}
|
||||
|
||||
===Examples===
|
||||
|
||||
{% if language == "zh" %}
|
||||
**示例1(医疗领域):**
|
||||
场景:"一家医院管理患者记录,安排预约,并协调医疗程序。医生诊断病情并开具治疗方案。"
|
||||
|
||||
输出:
|
||||
{
|
||||
"classes": [
|
||||
{
|
||||
"name": "患者",
|
||||
"description": "在医疗机构接受医疗护理或治疗的人",
|
||||
"examples": ["张三", "李四", "患有糖尿病的老年患者"],
|
||||
"parent_class": null,
|
||||
"entity_type": "人物",
|
||||
"domain": "医疗"
|
||||
},
|
||||
{
|
||||
"name": "医疗程序",
|
||||
"description": "为医疗诊断或治疗而执行的系统性操作流程",
|
||||
"examples": ["手术", "血液检查", "X光检查", "疫苗接种"],
|
||||
"parent_class": null,
|
||||
"entity_type": "过程",
|
||||
"domain": "医疗"
|
||||
},
|
||||
{
|
||||
"name": "诊断",
|
||||
"description": "基于症状和检查结果对疾病或状况的识别",
|
||||
"examples": ["糖尿病诊断", "癌症诊断", "流感诊断"],
|
||||
"parent_class": null,
|
||||
"entity_type": "概念",
|
||||
"domain": "医疗"
|
||||
},
|
||||
{
|
||||
"name": "医生",
|
||||
"description": "诊断和治疗患者的持证医疗专业人员",
|
||||
"examples": ["全科医生", "外科医生", "心脏病专家"],
|
||||
"parent_class": null,
|
||||
"entity_type": "角色",
|
||||
"domain": "医疗"
|
||||
},
|
||||
{
|
||||
"name": "治疗",
|
||||
"description": "为治愈或管理疾病状况而提供的医疗护理或疗法",
|
||||
"examples": ["药物治疗", "物理治疗", "化疗", "手术治疗"],
|
||||
"parent_class": null,
|
||||
"entity_type": "过程",
|
||||
"domain": "医疗"
|
||||
}
|
||||
],
|
||||
"domain": "医疗"
|
||||
}
|
||||
|
||||
**示例2(教育领域):**
|
||||
场景:"一所大学提供由教授教授的课程。学生注册项目,参加讲座,并完成作业以获得学位。"
|
||||
|
||||
输出:
|
||||
{
|
||||
"classes": [
|
||||
{
|
||||
"name": "学生",
|
||||
"description": "在教育机构注册学习的人",
|
||||
"examples": ["本科生", "研究生", "在职学生"],
|
||||
"parent_class": null,
|
||||
"entity_type": "角色",
|
||||
"domain": "教育"
|
||||
},
|
||||
{
|
||||
"name": "课程",
|
||||
"description": "涵盖特定学科或主题的结构化教育课程",
|
||||
"examples": ["计算机科学导论", "微积分I", "世界历史"],
|
||||
"parent_class": null,
|
||||
"entity_type": "概念",
|
||||
"domain": "教育"
|
||||
},
|
||||
{
|
||||
"name": "教授",
|
||||
"description": "教授课程并进行研究的学术教师",
|
||||
"examples": ["助理教授", "副教授", "正教授"],
|
||||
"parent_class": null,
|
||||
"entity_type": "角色",
|
||||
"domain": "教育"
|
||||
},
|
||||
{
|
||||
"name": "学术项目",
|
||||
"description": "通向学位或证书的结构化课程体系",
|
||||
"examples": ["理学学士", "文学硕士", "博士项目"],
|
||||
"parent_class": null,
|
||||
"entity_type": "概念",
|
||||
"domain": "教育"
|
||||
},
|
||||
{
|
||||
"name": "作业",
|
||||
"description": "分配给学生以评估学习成果的任务或项目",
|
||||
"examples": ["论文", "习题集", "研究报告", "实验报告"],
|
||||
"parent_class": null,
|
||||
"entity_type": "对象",
|
||||
"domain": "教育"
|
||||
}
|
||||
],
|
||||
"domain": "教育"
|
||||
}
|
||||
|
||||
{% else %}
|
||||
|
||||
{% if language == "en" -%}
|
||||
**Example 1 (Healthcare Domain):**
|
||||
Scenario: "A hospital manages patient records, schedules appointments, and coordinates medical procedures. Doctors diagnose conditions and prescribe treatments."
|
||||
|
||||
Output:
|
||||
{
|
||||
"classes": [
|
||||
{
|
||||
"name": "Patient",
|
||||
"name_chinese": "患者",
|
||||
"description": "A person who receives medical care or treatment at a healthcare facility",
|
||||
"examples": ["Outpatient", "Inpatient", "Emergency patient", "Chronic disease patient"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Person",
|
||||
"domain": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "MedicalProcedure",
|
||||
"name_chinese": "医疗程序",
|
||||
"description": "A systematic operation or process performed for medical diagnosis or treatment",
|
||||
"examples": ["Surgery", "Blood test", "X-ray examination", "Vaccination"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Process",
|
||||
"domain": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "Diagnosis",
|
||||
"name_chinese": "诊断",
|
||||
"description": "The identification of a disease or condition based on symptoms and examination results",
|
||||
"examples": ["Diabetes diagnosis", "Cancer diagnosis", "Flu diagnosis"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Concept",
|
||||
"domain": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "Doctor",
|
||||
"name_chinese": "医生",
|
||||
"description": "A licensed medical professional who diagnoses and treats patients",
|
||||
"examples": ["General practitioner", "Surgeon", "Cardiologist"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Role",
|
||||
"domain": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "Treatment",
|
||||
"name_chinese": "治疗",
|
||||
"description": "Medical care or therapy provided to cure or manage a disease condition",
|
||||
"examples": ["Medication therapy", "Physical therapy", "Chemotherapy", "Surgical treatment"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Process",
|
||||
"domain": "Healthcare"
|
||||
}
|
||||
],
|
||||
"domain": "Healthcare",
|
||||
"namespace": "http://example.org/healthcare#"
|
||||
}
|
||||
{%- else -%}
|
||||
**Example 1 (Healthcare Domain):**
|
||||
Scenario: "A hospital manages patient records, schedules appointments, and coordinates medical procedures. Doctors diagnose conditions and prescribe treatments."
|
||||
|
||||
Output:
|
||||
{
|
||||
"classes": [
|
||||
{
|
||||
"name": "Patient",
|
||||
"description": "A person receiving medical care or treatment at a healthcare facility",
|
||||
"examples": ["John Smith", "Jane Doe", "Elderly patient with diabetes"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Person",
|
||||
"domain": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "MedicalProcedure",
|
||||
"description": "A systematic operation performed for medical diagnosis or treatment",
|
||||
"examples": ["Surgery", "Blood test", "X-ray examination", "Vaccination"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Process",
|
||||
"domain": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "Diagnosis",
|
||||
"description": "Identification of a disease or condition based on symptoms and examination results",
|
||||
"examples": ["Diabetes diagnosis", "Cancer diagnosis", "Flu diagnosis"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Concept",
|
||||
"domain": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "Doctor",
|
||||
"description": "A licensed medical professional who diagnoses and treats patients",
|
||||
"examples": ["General practitioner", "Surgeon", "Cardiologist"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Role",
|
||||
"domain": "Healthcare"
|
||||
},
|
||||
{
|
||||
"name": "Treatment",
|
||||
"description": "Medical care or therapy provided to cure or manage a disease condition",
|
||||
"examples": ["Medication therapy", "Physical therapy", "Chemotherapy", "Surgical treatment"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Process",
|
||||
"domain": "Healthcare"
|
||||
}
|
||||
],
|
||||
"domain": "Healthcare"
|
||||
}
|
||||
|
||||
**Example 2 (Education Domain):**
|
||||
Scenario: "A university offers courses taught by professors. Students enroll in programs, attend lectures, and complete assignments to earn degrees."
|
||||
|
||||
Output:
|
||||
{
|
||||
"classes": [
|
||||
{
|
||||
"name": "Student",
|
||||
"description": "A person enrolled in an educational institution for learning",
|
||||
"examples": ["Undergraduate student", "Graduate student", "Part-time student"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Role",
|
||||
"domain": "Education"
|
||||
},
|
||||
{
|
||||
"name": "Course",
|
||||
"description": "A structured educational program covering a specific subject or topic",
|
||||
"examples": ["Introduction to Computer Science", "Calculus I", "World History"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Concept",
|
||||
"domain": "Education"
|
||||
},
|
||||
{
|
||||
"name": "Professor",
|
||||
"description": "An academic teacher who teaches courses and conducts research",
|
||||
"examples": ["Assistant professor", "Associate professor", "Full professor"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Role",
|
||||
"domain": "Education"
|
||||
},
|
||||
{
|
||||
"name": "AcademicProgram",
|
||||
"description": "A structured curriculum leading to a degree or certificate",
|
||||
"examples": ["Bachelor of Science", "Master of Arts", "PhD program"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Concept",
|
||||
"domain": "Education"
|
||||
},
|
||||
{
|
||||
"name": "Assignment",
|
||||
"description": "A task or project assigned to students to assess learning outcomes",
|
||||
"examples": ["Essay", "Problem set", "Research paper", "Lab report"],
|
||||
"parent_class": null,
|
||||
"entity_type": "Object",
|
||||
"domain": "Education"
|
||||
}
|
||||
],
|
||||
"domain": "Education"
|
||||
}
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
===Output Format===
|
||||
|
||||
**JSON Requirements:**
|
||||
- Use only ASCII double quotes (") for JSON structure
|
||||
- Never use Chinese quotation marks ("") or Unicode quotes
|
||||
- Escape quotation marks in text with backslashes (\")
|
||||
- Ensure proper string closure and comma separation
|
||||
- No line breaks within JSON string values
|
||||
- All class names must be unique (case-insensitive)
|
||||
- Extract between 5 and {{ max_classes }} classes
|
||||
{% if language == "zh" %}
|
||||
- 所有字段内容必须使用中文
|
||||
{% else %}
|
||||
- All field content must be in English
|
||||
{% endif %}
|
||||
|
||||
{{ json_schema }}
|
||||
@@ -5,8 +5,13 @@
|
||||
|
||||
===Tasks===
|
||||
|
||||
{% if language == "zh" %}
|
||||
你的任务是根据详细的提取指南,从提供的对话片段中识别和提取陈述句。
|
||||
每个陈述句必须按照下面提到的标准进行标记。
|
||||
{% else %}
|
||||
Your task is to identify and extract declarative statements from the provided conversational chunk based on the detailed extraction guidelines.
|
||||
Each statement must be labeled as per the criteria mentioned below.
|
||||
{% endif %}
|
||||
|
||||
===Inputs===
|
||||
{% if inputs %}
|
||||
@@ -17,6 +22,32 @@ Each statement must be labeled as per the criteria mentioned below.
|
||||
|
||||
|
||||
===Extraction Instructions===
|
||||
{% if language == "zh" %}
|
||||
{% if granularity %}
|
||||
{% if granularity == 3 %}
|
||||
原子化和清晰:构建陈述句以清楚地显示单一的主谓宾关系。最好有多个较小的陈述句,而不是一个复杂的陈述句。
|
||||
上下文独立:陈述句必须在不需要阅读整个对话的情况下可以理解。
|
||||
{% elif granularity == 2 %}
|
||||
在句子级别提取陈述句。每个陈述句应对应一个单一、完整的思想(通常是来源中的一个完整句子),但要重新表述以获得最大的清晰度,删除对话填充词(例如,"嗯"、"像"、感叹词)。
|
||||
{% elif granularity == 1 %}
|
||||
仅提取精华句子,并将片段总结为多个独立的陈述句,每个陈述句关注事实陈述、用户偏好、关系和显著的时间上下文。
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
上下文解析要求:
|
||||
- 将指示代词("那个"、"这个"、"那些"、"这些")解析为其具体指代对象
|
||||
- 如果陈述句包含无法从对话上下文中解析的模糊引用,则:
|
||||
a) 扩展陈述句以包含对话早期的缺失上下文
|
||||
b) 标记陈述句为需要额外上下文
|
||||
c) 如果陈述句在没有上下文的情况下变得无意义,则跳过提取
|
||||
|
||||
对话上下文和共指消解:
|
||||
- 将每个陈述句归属于说出它的参与者。
|
||||
- 如果参与者列表为说话者提供了名称(例如,"李雪(用户)"),请在提取的陈述句中使用具体名称("李雪"),而不是通用角色("用户")。
|
||||
- 将所有代词解析为对话上下文中的具体人物或实体。
|
||||
- 识别并将抽象引用解析为其具体名称(如果提到)。
|
||||
- 将缩写和首字母缩略词扩展为其完整形式。
|
||||
{% else %}
|
||||
{% if granularity %}
|
||||
{% if granularity == 3 %}
|
||||
Atomic & Clear: Structure statements to clearly show a single subject-predicate-object relationship. It is better to have multiple smaller statements than one complex one.
|
||||
@@ -29,7 +60,7 @@ Extract only essence sentences and summarize the chunk into multiple, standalone
|
||||
{% endif %}
|
||||
|
||||
Context Resolution Requirements:
|
||||
- Resolve demonstrative pronouns ("that," "this," "those","这个", "那个") to their specific referents
|
||||
- Resolve demonstrative pronouns ("that," "this," "those") to their specific referents
|
||||
- If a statement contains vague references that cannot be resolved from the conversation context, either:
|
||||
a) Expand the statement to include the missing context from earlier in the conversation
|
||||
b) Mark the statement as requiring additional context
|
||||
@@ -41,16 +72,36 @@ Conversational Context & Co-reference Resolution:
|
||||
- Resolve all pronouns to the specific person or entity from the conversation's context.
|
||||
- Identify and resolve abstract references to their specific names if mentioned.
|
||||
- Expand abbreviations and acronyms to their full form.
|
||||
{% endif %}
|
||||
|
||||
{% if include_dialogue_context %}
|
||||
{% if language == "zh" %}
|
||||
===完整对话上下文===
|
||||
以下是完整的对话上下文,以帮助您理解引用、代词和对话流程:
|
||||
{% else %}
|
||||
===Full Dialogue Context===
|
||||
The following is the complete dialogue context to help you understand references, pronouns, and conversational flow:
|
||||
{% endif %}
|
||||
|
||||
{{ dialogue_context }}
|
||||
|
||||
{% if language == "zh" %}
|
||||
===对话上下文结束===
|
||||
{% else %}
|
||||
===End of Dialogue Context===
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
{% if language == "zh" %}
|
||||
过滤和格式化:
|
||||
|
||||
- 仅提取陈述句。
|
||||
不要提取问题、命令、问候语或对话填充词。
|
||||
时间精度:
|
||||
|
||||
包括任何明确的日期、时间或定量限定符。
|
||||
如果一个句子既描述了事件的开始(静态)又描述了其持续性质(动态),则将两者提取为单独的陈述句。
|
||||
{% else %}
|
||||
Filtering and Formatting:
|
||||
|
||||
- Extract only declarative statements.
|
||||
@@ -59,18 +110,114 @@ Temporal Precision:
|
||||
|
||||
Include any explicit dates, times, or quantitative qualifiers.
|
||||
If a sentence describes both the start of an event (static) and its ongoing nature (dynamic), extract both as separate statements.
|
||||
{% endif %}
|
||||
|
||||
{%- if definitions %}
|
||||
{%- for section_key, section_dict in definitions.items() %}
|
||||
==== {{ tidy(section_key) | upper }} DEFINITIONS & GUIDANCE ====
|
||||
==== {{ tidy(section_key) | upper }} {% if language == "zh" %}定义和指导{% else %}DEFINITIONS & GUIDANCE{% endif %} ====
|
||||
{%- for category, details in section_dict.items() %}
|
||||
{{ loop.index }}. {{ category }}
|
||||
- Definition: {{ details.get("definition", "") }}
|
||||
- {% if language == "zh" %}定义{% else %}Definition{% endif %}: {{ details.get("definition", "") }}
|
||||
{% endfor -%}
|
||||
{% endfor -%}
|
||||
{% endif -%}
|
||||
|
||||
===Examples===
|
||||
{% if language == "zh" %}
|
||||
示例 1: 英文对话
|
||||
示例片段: """
|
||||
日期: 2024年3月15日
|
||||
参与者:
|
||||
- Sarah Chen (用户)
|
||||
- 助手 (AI)
|
||||
|
||||
用户: "我最近一直在尝试水彩画,画了一些花朵。"
|
||||
AI: "水彩画很有趣!水彩颜料通常由颜料与阿拉伯树胶等粘合剂混合而成。你觉得怎么样?"
|
||||
用户: "我认为色彩组合可以改进,但我真的很喜欢玫瑰和百合。"
|
||||
"""
|
||||
|
||||
示例输出: {
|
||||
"statements": [
|
||||
{
|
||||
"statement": "Sarah Chen 最近一直在尝试水彩画。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen 画了一些花朵。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "水彩颜料通常由颜料与阿拉伯树胶等粘合剂混合而成。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "ATEMPORAL",
|
||||
"relevance": "IRRELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen 认为她的水彩画中的色彩组合可以改进。",
|
||||
"statement_type": "OPINION",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "Sarah Chen 真的很喜欢玫瑰和百合。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
示例 2: 中文对话示例
|
||||
示例片段: """
|
||||
日期: 2024年3月15日
|
||||
参与者:
|
||||
- 张曼婷 (用户)
|
||||
- 小助手 (AI助手)
|
||||
|
||||
用户: "我最近在尝试水彩画,画了一些花朵。"
|
||||
AI: "水彩画很有趣!水彩颜料通常由颜料和阿拉伯树胶等粘合剂混合而成。你觉得怎么样?"
|
||||
用户: "我觉得色彩搭配还有提升的空间,不过我很喜欢玫瑰和百合这两种花。"
|
||||
"""
|
||||
|
||||
示例输出: {
|
||||
"statements": [
|
||||
{
|
||||
"statement": "张曼婷最近在尝试水彩画。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷画了一些花朵。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "DYNAMIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "水彩颜料通常由颜料和阿拉伯树胶等粘合剂混合而成。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "ATEMPORAL",
|
||||
"relevance": "IRRELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷觉得水彩画的色彩搭配还有提升的空间。",
|
||||
"statement_type": "OPINION",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
},
|
||||
{
|
||||
"statement": "张曼婷很喜欢玫瑰和百合。",
|
||||
"statement_type": "FACT",
|
||||
"temporal_type": "STATIC",
|
||||
"relevance": "RELEVANT"
|
||||
}
|
||||
]
|
||||
}
|
||||
{% else %}
|
||||
Example 1: English Conversation
|
||||
Example Chunk: """
|
||||
Date: March 15, 2024
|
||||
@@ -164,8 +311,33 @@ Example Output: {
|
||||
}
|
||||
]
|
||||
}
|
||||
{% endif %}
|
||||
===End of Examples===
|
||||
|
||||
{% if language == "zh" %}
|
||||
===反思过程===
|
||||
|
||||
提取陈述句后,执行以下自我审查步骤:
|
||||
|
||||
**步骤 1: 归属检查**
|
||||
- 确认每个陈述句都正确归属于正确的说话者
|
||||
- 验证说话者名称在整个过程中使用一致
|
||||
- 检查 AI 助手陈述句是否正确归属
|
||||
|
||||
**步骤 2: 完整性审查**
|
||||
- 确保没有遗漏重要的陈述句
|
||||
- 检查时间信息是否保留
|
||||
|
||||
**步骤 3: 分类验证**
|
||||
- 审查 statement_type 分类(FACT/OPINION/PREDICTION/SUGGESTION)
|
||||
- 验证 temporal_type 分配(STATIC/DYNAMIC/ATEMPORAL)
|
||||
- 确保分类与提供的定义一致
|
||||
|
||||
**步骤 4: 最终质量检查**
|
||||
- 删除任何问题、命令或对话填充词
|
||||
- 验证 JSON 格式合规性
|
||||
- 确认输出语言与输入语言匹配
|
||||
{% else %}
|
||||
===Reflection Process===
|
||||
|
||||
After extracting statements, perform the following self-review steps:
|
||||
@@ -188,6 +360,7 @@ After extracting statements, perform the following self-review steps:
|
||||
- Remove any questions, commands, or conversational filler
|
||||
- Verify JSON format compliance
|
||||
- Confirm output language matches input language
|
||||
{% endif %}
|
||||
|
||||
**Output format**
|
||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||
@@ -198,10 +371,21 @@ After extracting statements, perform the following self-review steps:
|
||||
5. Example of proper escaping: "statement": "John said: \"I really like this book.\""
|
||||
|
||||
**LANGUAGE REQUIREMENT:**
|
||||
{% if language == "zh" %}
|
||||
- 输出语言应始终与输入语言匹配
|
||||
- 如果输入是中文,则用中文提取陈述句
|
||||
- 如果输入是英文,则用英文提取陈述句
|
||||
- 保留原始语言,不要翻译
|
||||
{% else %}
|
||||
- The output language should ALWAYS match the input language
|
||||
- If input is in English, extract statements in English
|
||||
- If input is in Chinese, extract statements in Chinese
|
||||
- Preserve the original language and do not translate
|
||||
{% endif %}
|
||||
|
||||
{% if language == "zh" %}
|
||||
仅返回与以下架构匹配的 JSON 对象数组中提取的标记陈述句列表:
|
||||
{% else %}
|
||||
Return only a list of extracted labelled statements in the JSON ARRAY of objects that match the schema below:
|
||||
{{ json_schema }}
|
||||
{% endif %}
|
||||
{{ json_schema }}
|
||||
|
||||
@@ -14,68 +14,113 @@
|
||||
#}
|
||||
# Task
|
||||
|
||||
{% if language == "zh" %}
|
||||
从提供的陈述句中提取时间信息(日期和时间范围)。确定所描述的关系或事件何时生效以及何时结束(如果适用)。
|
||||
{% else %}
|
||||
Extract temporal information (dates and time ranges) from the provided statement. Determine when the relationship or event described became valid and when it ended (if applicable).
|
||||
{% endif %}
|
||||
|
||||
# Input Data
|
||||
# {% if language == "zh" %}输入数据{% else %}Input Data{% endif %}
|
||||
{% if inputs %}
|
||||
{% for key, val in inputs.items() %}
|
||||
- {{ key }}: {{val}}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
# Temporal Fields
|
||||
# {% if language == "zh" %}时间字段{% else %}Temporal Fields{% endif %}
|
||||
|
||||
{% if language == "zh" %}
|
||||
- **valid_at**: 关系/事件开始或成为真实的时间(ISO 8601 格式)
|
||||
- **invalid_at**: 关系/事件结束或停止为真的时间(ISO 8601 格式,如果正在进行则为 null)
|
||||
{% else %}
|
||||
- **valid_at**: When the relationship/event started or became true (ISO 8601 format)
|
||||
- **invalid_at**: When the relationship/event ended or stopped being true (ISO 8601 format, or null if ongoing)
|
||||
{% endif %}
|
||||
|
||||
# Extraction Rules
|
||||
# {% if language == "zh" %}提取规则{% else %}Extraction Rules{% endif %}
|
||||
|
||||
## Core Principles
|
||||
## {% if language == "zh" %}核心原则{% else %}Core Principles{% endif %}
|
||||
{% if language == "zh" %}
|
||||
1. **仅使用明确陈述的时间信息** - 不要从外部知识推断日期
|
||||
2. **使用参考/发布日期作为"现在"** 解释相对时间时
|
||||
3. **仅在日期与关系的有效性相关时设置日期** - 忽略偶然的时间提及
|
||||
4. **对于时间点事件**,仅设置 `valid_at`
|
||||
{% else %}
|
||||
1. **Only use explicitly stated temporal information** - do not infer dates from external knowledge
|
||||
2. **Use the reference/publication date as "now"** when interpreting relative times
|
||||
3. **Set dates only if they relate to the validity of the relationship** - ignore incidental time mentions
|
||||
4. **For point-in-time events**, set only `valid_at`
|
||||
{% endif %}
|
||||
|
||||
## Date Format Requirements
|
||||
## {% if language == "zh" %}日期格式要求{% else %}Date Format Requirements{% endif %}
|
||||
{% if language == "zh" %}
|
||||
- 使用 ISO 8601: `YYYY-MM-DDTHH:MM:SS.SSSSSSZ`
|
||||
- 如果未指定时间,使用 `00:00:00`(午夜)
|
||||
- 如果仅提及年份,根据情况使用 `YYYY-01-01`(开始)或 `YYYY-12-31`(结束)
|
||||
- 如果仅提及月份,使用月份的第一天或最后一天
|
||||
- 始终包含时区(如果未指定,使用 `Z` 表示 UTC)
|
||||
- 根据参考日期将相对时间("两周前"、"去年")转换为绝对日期
|
||||
{% else %}
|
||||
- Use ISO 8601: `YYYY-MM-DDTHH:MM:SS.SSSSSSZ`
|
||||
- If no time specified, use `00:00:00` (midnight)
|
||||
- If only year mentioned, use `YYYY-01-01` (start) or `YYYY-12-31` (end) as appropriate
|
||||
- If only month mentioned, use first or last day of month
|
||||
- Always include timezone (use `Z` for UTC if unspecified)
|
||||
- Convert relative times ("two weeks ago", "last year") to absolute dates based on reference date
|
||||
{% endif %}
|
||||
|
||||
## Statement Type Rules
|
||||
## {% if language == "zh" %}陈述句类型规则{% else %}Statement Type Rules{% endif %}
|
||||
|
||||
{{ inputs.get("statement_type") | upper }} Statement Guidance:
|
||||
{{ inputs.get("statement_type") | upper }} {% if language == "zh" %}陈述句指导{% else %}Statement Guidance{% endif %}:
|
||||
{%for key, guide in statement_guide.items() %}
|
||||
- {{ tidy(key) | capitalize }}: {{ guide }}
|
||||
{% endfor %}
|
||||
|
||||
**Special Cases:**
|
||||
**{% if language == "zh" %}特殊情况{% else %}Special Cases{% endif %}:**
|
||||
{% if language == "zh" %}
|
||||
- **意见陈述句**: 仅设置 `valid_at`(意见表达的时间)
|
||||
- **预测陈述句**: 如果明确提及,将 `invalid_at` 设置为预测窗口的结束
|
||||
{% else %}
|
||||
- **Opinion statements**: Set only `valid_at` (when opinion was expressed)
|
||||
- **Prediction statements**: Set `invalid_at` to the end of the prediction window if explicitly mentioned
|
||||
{% endif %}
|
||||
|
||||
## Temporal Type Rules
|
||||
## {% if language == "zh" %}时间类型规则{% else %}Temporal Type Rules{% endif %}
|
||||
|
||||
{{ inputs.get("temporal_type") | upper }} Temporal Type Guidance:
|
||||
{{ inputs.get("temporal_type") | upper }} {% if language == "zh" %}时间类型指导{% else %}Temporal Type Guidance{% endif %}:
|
||||
{% for key, guide in temporal_guide.items() %}
|
||||
- {{ tidy(key) | capitalize }}: {{ guide }}
|
||||
{% endfor %}
|
||||
|
||||
{% if inputs.get('quarter') and inputs.get('publication_date') %}
|
||||
## Quarter Reference
|
||||
## {% if language == "zh" %}季度参考{% else %}Quarter Reference{% endif %}
|
||||
{% if language == "zh" %}
|
||||
假设 {{ inputs.quarter }} 在 {{ inputs.publication_date }} 结束。从此基线计算任何季度引用(Q1、Q2 等)的日期。
|
||||
{% else %}
|
||||
Assume {{ inputs.quarter }} ends on {{ inputs.publication_date }}. Calculate dates for any quarter references (Q1, Q2, etc.) from this baseline.
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
# Output Requirements
|
||||
# {% if language == "zh" %}输出要求{% else %}Output Requirements{% endif %}
|
||||
|
||||
## JSON Formatting (CRITICAL)
|
||||
## {% if language == "zh" %}JSON 格式化(关键){% else %}JSON Formatting (CRITICAL){% endif %}
|
||||
{% if language == "zh" %}
|
||||
1. 使用**仅标准 ASCII 双引号** (") - 永远不要使用中文引号("")或其他 Unicode 变体
|
||||
2. 使用反斜杠转义内部引号: `\"`
|
||||
3. JSON 字符串值中不要有换行符
|
||||
4. 正确关闭并用逗号分隔所有字段
|
||||
{% else %}
|
||||
1. Use **only standard ASCII double quotes** (") - never use Chinese quotes ("") or other Unicode variants
|
||||
2. Escape internal quotes with backslash: `\"`
|
||||
3. No line breaks within JSON string values
|
||||
4. Properly close and comma-separate all fields
|
||||
{% endif %}
|
||||
|
||||
## Language
|
||||
## {% if language == "zh" %}语言{% else %}Language{% endif %}
|
||||
{% if language == "zh" %}
|
||||
输出语言必须与输入语言匹配。
|
||||
{% else %}
|
||||
Output language must match input language.
|
||||
{% endif %}
|
||||
|
||||
{{ json_schema }}
|
||||
|
||||
@@ -5,52 +5,97 @@
|
||||
===Task===
|
||||
Extract entities and knowledge triplets from the given statement.
|
||||
|
||||
{% if language == "zh" %}
|
||||
**重要:请使用中文生成实体名称(name)、描述(description)和示例(example)。**
|
||||
{% else %}
|
||||
**Important: Please generate entity names, descriptions and examples in English. If the original text is in Chinese, translate entity names to English.**
|
||||
{% endif %}
|
||||
|
||||
===Inputs===
|
||||
**Chunk Content:** "{{ chunk_content }}"
|
||||
**Statement:** "{{ statement }}"
|
||||
|
||||
{% if ontology_types %}
|
||||
===Ontology Type Guidance===
|
||||
|
||||
**CRITICAL RULE: You MUST ONLY use the predefined ontology type names listed below for the entity "type" field. Do NOT use any other type names, even if they seem reasonable.**
|
||||
|
||||
**If no predefined type fits an entity, use the CLOSEST matching predefined type. NEVER invent new type names.**
|
||||
|
||||
**Type Priority (from highest to lowest):**
|
||||
1. **[场景类型] Scene Types** - Domain-specific types, ALWAYS prefer these first
|
||||
2. **[通用类型] General Types** - Common types from standard ontologies (DBpedia)
|
||||
3. **[通用父类] Parent Types** - Provide type hierarchy context
|
||||
|
||||
**Type Matching Rules:**
|
||||
- Entity type MUST exactly match one of the predefined type names below
|
||||
- Do NOT use types like "Equipment", "Component", "Concept", "Action", "Condition", "Data", "Duration" unless they appear in the predefined list
|
||||
- Do NOT modify, translate, abbreviate, or create variations of type names
|
||||
- Prefer scene types (marked [场景类型]) over general types when both could apply
|
||||
- If uncertain, check the type description to find the best match
|
||||
|
||||
**Predefined Ontology Types:**
|
||||
{{ ontology_types }}
|
||||
|
||||
{% if type_hierarchy_hints %}
|
||||
**Type Hierarchy Reference:**
|
||||
The following shows type inheritance relationships (Child → Parent → Grandparent):
|
||||
{% for hint in type_hierarchy_hints %}
|
||||
- {{ hint }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
**ALLOWED Type Names (use EXACTLY one of these, no exceptions):**
|
||||
{{ ontology_type_names | join(', ') }}
|
||||
|
||||
{% endif %}
|
||||
===Guidelines===
|
||||
|
||||
**Entity Extraction:**
|
||||
- Extract entities with their types, context-independent descriptions, **concise examples**, aliases, and semantic memory classification
|
||||
{% if language == "zh" %}
|
||||
- **实体名称(name)必须使用中文**
|
||||
- **实体描述(description)必须使用中文**
|
||||
- **示例(example)必须使用中文**
|
||||
{% else %}
|
||||
- **Entity names must be in English** (translate if the original is in another language)
|
||||
- **Entity descriptions must be in English**
|
||||
- **Examples must be in English**
|
||||
{% endif %}
|
||||
- **Semantic Memory Classification (is_explicit_memory):**
|
||||
* Set to `true` if the entity represents **explicit/semantic memory**:
|
||||
- **Concepts:** "Machine Learning", "Photosynthesis", "Democracy", "人工智能", "光合作用", "民主"
|
||||
- **Knowledge:** "Python Programming Language", "Theory of Relativity", "Python编程语言", "相对论"
|
||||
- **Definitions:** "API (Application Programming Interface)", "REST API", "应用程序接口"
|
||||
- **Principles:** "SOLID Principles", "First Law of Thermodynamics", "SOLID原则", "热力学第一定律"
|
||||
- **Theories:** "Evolution Theory", "Quantum Mechanics", "进化论", "量子力学"
|
||||
- **Methods/Techniques:** "Agile Development", "Machine Learning Algorithm", "敏捷开发", "机器学习算法"
|
||||
- **Technical Terms:** "Neural Network", "Database", "神经网络", "数据库"
|
||||
- **Concepts:** "Machine Learning", "Photosynthesis", "Democracy"
|
||||
- **Knowledge:** "Python Programming Language", "Theory of Relativity"
|
||||
- **Definitions:** "API (Application Programming Interface)", "REST API"
|
||||
- **Principles:** "SOLID Principles", "First Law of Thermodynamics"
|
||||
- **Theories:** "Evolution Theory", "Quantum Mechanics"
|
||||
- **Methods/Techniques:** "Agile Development", "Machine Learning Algorithm"
|
||||
- **Technical Terms:** "Neural Network", "Database"
|
||||
* Set to `false` for:
|
||||
- **People:** "John Smith", "Dr. Wang", "张明", "王博士"
|
||||
- **Organizations:** "Microsoft", "Harvard University", "微软", "哈佛大学"
|
||||
- **Locations:** "Beijing", "Central Park", "北京", "中央公园"
|
||||
- **Events:** "2024 Conference", "Project Meeting", "2024会议", "项目会议"
|
||||
- **Specific objects:** "iPhone 15", "Building A", "iPhone 15", "A栋"
|
||||
- **People:** "John Smith", "Dr. Wang"
|
||||
- **Organizations:** "Microsoft", "Harvard University"
|
||||
- **Locations:** "Beijing", "Central Park"
|
||||
- **Events:** "2024 Conference", "Project Meeting"
|
||||
- **Specific objects:** "iPhone 15", "Building A"
|
||||
- **Example Generation (IMPORTANT for semantic memory entities):**
|
||||
* For entities where `is_explicit_memory=true`, generate a **concise example (around 20 characters)** to help understand the concept
|
||||
* The example should be:
|
||||
- **Specific and concrete**: Use real-world scenarios or applications
|
||||
- **Brief**: Around 20 characters (can be slightly longer if needed for clarity)
|
||||
- **In the same language as the entity name**
|
||||
* Examples:
|
||||
- Entity: "机器学习" → example: "如:用神经网络识别图片中的猫狗"
|
||||
- Entity: "SOLID Principles" → example: "e.g., Single Responsibility, Open-Closed"
|
||||
- Entity: "Photosynthesis" → example: "e.g., plants convert sunlight to energy"
|
||||
- Entity: "人工智能" → example: "如:智能客服、自动驾驶"
|
||||
{% if language == "zh" %}
|
||||
- **使用中文**
|
||||
{% else %}
|
||||
- **In English**
|
||||
{% endif %}
|
||||
* For non-semantic entities (`is_explicit_memory=false`), the example field can be empty
|
||||
- **Aliases Extraction (Important):**
|
||||
* **CRITICAL: Extract aliases ONLY in the SAME LANGUAGE as the input text**
|
||||
* **DO NOT translate or add aliases in different languages**
|
||||
* Include common alternative names in the same language (e.g., "北京" → aliases: ["北平", "京城"])
|
||||
* Include abbreviations and full names in the same language (e.g., "联合国" → aliases: ["联合国组织"])
|
||||
* Include nicknames and common variations in the same language (e.g., "纽约" → aliases: ["纽约市", "大苹果"])
|
||||
* If no aliases exist in the same language, use empty array: []
|
||||
* **Examples:**
|
||||
- Chinese input "北京" → aliases: ["北平", "京城"] (NOT ["Beijing", "Peking"])
|
||||
- English input "Beijing" → aliases: ["Peking"] (NOT ["北京", "北平"])
|
||||
- Chinese input "苹果公司" → aliases: ["苹果"] (NOT ["Apple Inc.", "Apple"])
|
||||
- **Aliases Extraction:**
|
||||
{% if language == "zh" %}
|
||||
* 别名使用中文
|
||||
{% else %}
|
||||
* Aliases should be in English
|
||||
{% endif %}
|
||||
* Include common alternative names, abbreviations and full names
|
||||
* If no aliases exist, use empty array: []
|
||||
- Exclude lengthy quotes, calendar dates, temporal ranges, and temporal expressions
|
||||
- For numeric values: extract as separate entities (instance_of: 'Numeric', name: units, numeric_value: value)
|
||||
Example: £30 → name: 'GBP', numeric_value: 30, instance_of: 'Numeric'
|
||||
@@ -60,6 +105,11 @@ Extract entities and knowledge triplets from the given statement.
|
||||
- Subject: main entity performing the action or being described
|
||||
- Predicate: relationship between entities (e.g., 'is', 'works at', 'believes')
|
||||
- Object: entity, value, or concept affected by the predicate
|
||||
{% if language == "zh" %}
|
||||
- subject_name 和 object_name 必须使用中文
|
||||
{% else %}
|
||||
- subject_name and object_name must be in English (translate if original is in another language)
|
||||
{% endif %}
|
||||
- Exclude all temporal expressions from every field
|
||||
- Use ONLY the predicates listed in "Predicate Instructions" (uppercase English tokens)
|
||||
- Do NOT translate predicate tokens
|
||||
@@ -68,7 +118,7 @@ Extract entities and knowledge triplets from the given statement.
|
||||
**When NOT to extract triplets:**
|
||||
- Non-propositional utterances (emotions, fillers, onomatopoeia)
|
||||
- No clear predicate from the given definitions applies
|
||||
- Standalone noun phrases or checklist items (e.g., "三脚架", "备用电池") → extract as entities only
|
||||
- Standalone noun phrases or checklist items → extract as entities only
|
||||
- Do NOT invent generic predicates (e.g., "IS_DOING", "FEELS", "MENTIONS")
|
||||
|
||||
**If no valid triplet exists:** Return triplets: [], extract entities if present, otherwise both arrays empty.
|
||||
@@ -83,248 +133,86 @@ Use ONLY these predicates. If none fits, set triplets to [].
|
||||
|
||||
|
||||
===Examples===
|
||||
|
||||
**Example 1 (English):** "I plan to travel to Paris next week and visit the Louvre."
|
||||
{% if language == "en" %}
|
||||
**Example 1 (English output):** "I plan to travel to Paris next week and visit the Louvre."
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
{
|
||||
"subject_name": "I",
|
||||
"subject_id": 0,
|
||||
"predicate": "PLANS_TO_VISIT",
|
||||
"object_name": "Paris",
|
||||
"object_id": 1,
|
||||
"value": null
|
||||
},
|
||||
{
|
||||
"subject_name": "I",
|
||||
"subject_id": 0,
|
||||
"predicate": "PLANS_TO_VISIT",
|
||||
"object_name": "Louvre",
|
||||
"object_id": 2,
|
||||
"value": null
|
||||
}
|
||||
{"subject_name": "I", "subject_id": 0, "predicate": "PLANS_TO_VISIT", "object_name": "Paris", "object_id": 1, "value": null},
|
||||
{"subject_name": "I", "subject_id": 0, "predicate": "PLANS_TO_VISIT", "object_name": "Louvre", "object_id": 2, "value": null}
|
||||
],
|
||||
"entities": [
|
||||
{
|
||||
"entity_idx": 0,
|
||||
"name": "I",
|
||||
"type": "Person",
|
||||
"description": "The user",
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
},
|
||||
{
|
||||
"entity_idx": 1,
|
||||
"name": "Paris",
|
||||
"type": "Location",
|
||||
"description": "Capital city of France",
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
},
|
||||
{
|
||||
"entity_idx": 2,
|
||||
"name": "Louvre",
|
||||
"type": "Location",
|
||||
"description": "World-famous museum located in Paris",
|
||||
"example": "",
|
||||
"aliases": ["Louvre Museum"],
|
||||
"is_explicit_memory": false
|
||||
}
|
||||
{"entity_idx": 0, "name": "I", "type": "Person", "description": "The user", "example": "", "aliases": [], "is_explicit_memory": false},
|
||||
{"entity_idx": 1, "name": "Paris", "type": "Location", "description": "Capital city of France", "example": "", "aliases": [], "is_explicit_memory": false},
|
||||
{"entity_idx": 2, "name": "Louvre", "type": "Location", "description": "World-famous museum located in Paris", "example": "", "aliases": ["Louvre Museum"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 2 (English):** "John Smith works at Google and is responsible for AI product development."
|
||||
**Example 2 (Chinese input → English output - IMPORTANT: translate entity names):** "张明在腾讯工作,负责AI产品开发。"
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
{
|
||||
"subject_name": "John Smith",
|
||||
"subject_id": 0,
|
||||
"predicate": "WORKS_AT",
|
||||
"object_name": "Google",
|
||||
"object_id": 1,
|
||||
"value": null
|
||||
},
|
||||
{
|
||||
"subject_name": "John Smith",
|
||||
"subject_id": 0,
|
||||
"predicate": "RESPONSIBLE_FOR",
|
||||
"object_name": "AI product development",
|
||||
"object_id": 2,
|
||||
"value": null
|
||||
}
|
||||
{"subject_name": "Zhang Ming", "subject_id": 0, "predicate": "WORKS_AT", "object_name": "Tencent", "object_id": 1, "value": null},
|
||||
{"subject_name": "Zhang Ming", "subject_id": 0, "predicate": "RESPONSIBLE_FOR", "object_name": "AI product development", "object_id": 2, "value": null}
|
||||
],
|
||||
"entities": [
|
||||
{
|
||||
"entity_idx": 0,
|
||||
"name": "John Smith",
|
||||
"type": "Person",
|
||||
"description": "Individual person name",
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
},
|
||||
{
|
||||
"entity_idx": 1,
|
||||
"name": "Google",
|
||||
"type": "Organization",
|
||||
"description": "American technology company",
|
||||
"example": "",
|
||||
"aliases": ["Google LLC", "Alphabet Inc."],
|
||||
"is_explicit_memory": false
|
||||
},
|
||||
{
|
||||
"entity_idx": 2,
|
||||
"name": "AI product development",
|
||||
"type": "Concept",
|
||||
"description": "Artificial intelligence product development work",
|
||||
"example": "e.g., developing chatbots, recommendation systems",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": true
|
||||
}
|
||||
{"entity_idx": 0, "name": "Zhang Ming", "type": "Person", "description": "Individual person name", "example": "", "aliases": [], "is_explicit_memory": false},
|
||||
{"entity_idx": 1, "name": "Tencent", "type": "Organization", "description": "Chinese technology company", "example": "", "aliases": ["Tencent Holdings"], "is_explicit_memory": false},
|
||||
{"entity_idx": 2, "name": "AI product development", "type": "Concept", "description": "Artificial intelligence product development work", "example": "e.g., developing chatbots", "aliases": [], "is_explicit_memory": true}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 3 (Chinese):** "我计划下周去巴黎旅行,参观卢浮宫。"
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
{
|
||||
"subject_name": "我",
|
||||
"subject_id": 0,
|
||||
"predicate": "PLANS_TO_VISIT",
|
||||
"object_name": "巴黎",
|
||||
"object_id": 1,
|
||||
"value": null
|
||||
},
|
||||
{
|
||||
"subject_name": "我",
|
||||
"subject_id": 0,
|
||||
"predicate": "PLANS_TO_VISIT",
|
||||
"object_name": "卢浮宫",
|
||||
"object_id": 2,
|
||||
"value": null
|
||||
}
|
||||
],
|
||||
"entities": [
|
||||
{
|
||||
"entity_idx": 0,
|
||||
"name": "我",
|
||||
"type": "Person",
|
||||
"description": "用户本人",
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
},
|
||||
{
|
||||
"entity_idx": 1,
|
||||
"name": "巴黎",
|
||||
"type": "Location",
|
||||
"description": "法国首都城市",
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
},
|
||||
{
|
||||
"entity_idx": 2,
|
||||
"name": "卢浮宫",
|
||||
"type": "Location",
|
||||
"description": "位于巴黎的世界著名博物馆",
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 4 (Chinese):** "张明在腾讯工作,负责AI产品开发。"
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
{
|
||||
"subject_name": "张明",
|
||||
"subject_id": 0,
|
||||
"predicate": "WORKS_AT",
|
||||
"object_name": "腾讯",
|
||||
"object_id": 1,
|
||||
"value": null
|
||||
},
|
||||
{
|
||||
"subject_name": "张明",
|
||||
"subject_id": 0,
|
||||
"predicate": "RESPONSIBLE_FOR",
|
||||
"object_name": "AI产品开发",
|
||||
"object_id": 2,
|
||||
"value": null
|
||||
}
|
||||
],
|
||||
"entities": [
|
||||
{
|
||||
"entity_idx": 0,
|
||||
"name": "张明",
|
||||
"type": "Person",
|
||||
"description": "个人姓名",
|
||||
"example": "",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": false
|
||||
},
|
||||
{
|
||||
"entity_idx": 1,
|
||||
"name": "腾讯",
|
||||
"type": "Organization",
|
||||
"description": "中国科技公司",
|
||||
"example": "",
|
||||
"aliases": ["腾讯控股", "腾讯公司"],
|
||||
"is_explicit_memory": false
|
||||
},
|
||||
{
|
||||
"entity_idx": 2,
|
||||
"name": "AI产品开发",
|
||||
"type": "Concept",
|
||||
"description": "人工智能产品研发工作",
|
||||
"example": "如:开发智能客服机器人、推荐系统",
|
||||
"aliases": [],
|
||||
"is_explicit_memory": true
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 5 (Entity Only - English):** "Tripod"
|
||||
**Example 3 (Chinese input → English output):** "三脚架"
|
||||
Output:
|
||||
{
|
||||
"triplets": [],
|
||||
"entities": [
|
||||
{
|
||||
"entity_idx": 0,
|
||||
"name": "Tripod",
|
||||
"type": "Equipment",
|
||||
"description": "Photography equipment accessory",
|
||||
"example": "",
|
||||
"aliases": ["Camera Tripod"],
|
||||
"is_explicit_memory": false
|
||||
}
|
||||
{"entity_idx": 0, "name": "Tripod", "type": "Equipment", "description": "Photography equipment accessory", "example": "", "aliases": ["Camera Tripod"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
{% else %}
|
||||
**Example 1 (English input → Chinese output):** "I plan to travel to Paris next week and visit the Louvre."
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
{"subject_name": "我", "subject_id": 0, "predicate": "PLANS_TO_VISIT", "object_name": "巴黎", "object_id": 1, "value": null},
|
||||
{"subject_name": "我", "subject_id": 0, "predicate": "PLANS_TO_VISIT", "object_name": "卢浮宫", "object_id": 2, "value": null}
|
||||
],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "我", "type": "Person", "description": "用户本人", "example": "", "aliases": [], "is_explicit_memory": false},
|
||||
{"entity_idx": 1, "name": "巴黎", "type": "Location", "description": "法国首都城市", "example": "", "aliases": [], "is_explicit_memory": false},
|
||||
{"entity_idx": 2, "name": "卢浮宫", "type": "Location", "description": "位于巴黎的世界著名博物馆", "example": "", "aliases": [], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 6 (Entity Only - Chinese):** "三脚架"
|
||||
**Example 2 (Chinese input → Chinese output):** "张明在腾讯工作,负责AI产品开发。"
|
||||
Output:
|
||||
{
|
||||
"triplets": [
|
||||
{"subject_name": "张明", "subject_id": 0, "predicate": "WORKS_AT", "object_name": "腾讯", "object_id": 1, "value": null},
|
||||
{"subject_name": "张明", "subject_id": 0, "predicate": "RESPONSIBLE_FOR", "object_name": "AI产品开发", "object_id": 2, "value": null}
|
||||
],
|
||||
"entities": [
|
||||
{"entity_idx": 0, "name": "张明", "type": "Person", "description": "个人姓名", "example": "", "aliases": [], "is_explicit_memory": false},
|
||||
{"entity_idx": 1, "name": "腾讯", "type": "Organization", "description": "中国科技公司", "example": "", "aliases": ["腾讯控股", "腾讯公司"], "is_explicit_memory": false},
|
||||
{"entity_idx": 2, "name": "AI产品开发", "type": "Concept", "description": "人工智能产品研发工作", "example": "如:开发智能客服机器人", "aliases": [], "is_explicit_memory": true}
|
||||
]
|
||||
}
|
||||
|
||||
**Example 3 (Entity Only - Chinese):** "三脚架"
|
||||
Output:
|
||||
{
|
||||
"triplets": [],
|
||||
"entities": [
|
||||
{
|
||||
"entity_idx": 0,
|
||||
"name": "三脚架",
|
||||
"type": "Equipment",
|
||||
"description": "摄影器材配件",
|
||||
"example": "",
|
||||
"aliases": ["相机三脚架"],
|
||||
"is_explicit_memory": false
|
||||
}
|
||||
{"entity_idx": 0, "name": "三脚架", "type": "Equipment", "description": "摄影器材配件", "example": "", "aliases": ["相机三脚架"], "is_explicit_memory": false}
|
||||
]
|
||||
}
|
||||
{% endif %}
|
||||
===End of Examples===
|
||||
|
||||
{% if ontology_types %}
|
||||
**⚠️ REMINDER: The examples above use generic type names for illustration only. You MUST use ONLY the predefined ontology type names from the "ALLOWED Type Names" list above. For example, use "PredictiveMaintenance" instead of "Concept", use "ProductionLine" instead of "Equipment", etc. Map each entity to the closest matching predefined type.**
|
||||
{% endif %}
|
||||
|
||||
===Output Format===
|
||||
|
||||
@@ -334,9 +222,11 @@ Output:
|
||||
- Escape quotation marks in text with backslashes (\")
|
||||
- Ensure proper string closure and comma separation
|
||||
- No line breaks within JSON string values
|
||||
- The output language should ALWAYS match the input language
|
||||
- If input is in English, extract statements in English
|
||||
- If input is in Chinese, extract statements in Chinese
|
||||
- Preserve the original language and do not translate
|
||||
{% if language == "zh" %}
|
||||
- **语言要求:实体名称(name)、描述(description)、示例(example)、subject_name、object_name 必须使用中文**
|
||||
{% else %}
|
||||
- **Language Requirement: Entity names, descriptions, examples, subject_name, object_name must be in English**
|
||||
- **If the original text is in Chinese, translate all names to English**
|
||||
{% endif %}
|
||||
|
||||
{{ json_schema }}
|
||||
{{ json_schema }}
|
||||
|
||||
@@ -1,9 +1,103 @@
|
||||
{% if language == "en" %}
|
||||
You are a professional mental health consultant. Based on the following user's emotional health data and personal information, generate 3-5 personalized emotional improvement suggestions.
|
||||
|
||||
## Core Principle (Highest Priority)
|
||||
|
||||
**You must strictly base your suggestions on the emotion distribution data provided below. As long as any emotion type has a count ≥ 1, that emotion EXISTS and you must acknowledge and address it in your suggestions. You must NEVER claim an emotion is "zero" or "absent" when its count is ≥ 1.**
|
||||
|
||||
Specific rules:
|
||||
1. Carefully check the count for each emotion type in "Emotion Distribution" — count ≥ 1 means the emotion exists
|
||||
2. Even if an emotion appeared only once, you must mention it in health_summary or suggestions and provide targeted advice
|
||||
3. Never state that an emotion is "zero" or "non-existent" unless its count in the distribution data is truly 0
|
||||
4. If positive emotions (e.g., Joy) exist, health_summary must affirm this positive signal
|
||||
5. If negative emotions (e.g., Sadness, Anger, Fear) exist even once, you must provide targeted improvement suggestions
|
||||
6. A high proportion of neutral emotions does NOT mean other emotions are absent — address all non-zero emotions
|
||||
|
||||
## User Emotional Health Data
|
||||
|
||||
Health Score: {{ health_data.health_score }}/100
|
||||
Health Level: {{ health_data.level }}
|
||||
Total Emotion Records: {{ health_data.dimensions.positivity_rate.positive_count + health_data.dimensions.positivity_rate.negative_count + health_data.dimensions.positivity_rate.neutral_count }}
|
||||
|
||||
Dimension Analysis:
|
||||
- Positivity Rate: {{ health_data.dimensions.positivity_rate.score }}/100
|
||||
- Positive Emotions: {{ health_data.dimensions.positivity_rate.positive_count }} times
|
||||
- Negative Emotions: {{ health_data.dimensions.positivity_rate.negative_count }} times
|
||||
- Neutral Emotions: {{ health_data.dimensions.positivity_rate.neutral_count }} times
|
||||
|
||||
- Stability: {{ health_data.dimensions.stability.score }}/100
|
||||
- Standard Deviation: {{ health_data.dimensions.stability.std_deviation }}
|
||||
|
||||
- Resilience: {{ health_data.dimensions.resilience.score }}/100
|
||||
- Recovery Rate: {{ health_data.dimensions.resilience.recovery_rate }}
|
||||
|
||||
Emotion Distribution (check each item — every emotion with count ≥ 1 must be reflected in suggestions):
|
||||
{{ emotion_distribution_json }}
|
||||
|
||||
## Emotion Pattern Analysis
|
||||
|
||||
Dominant Negative Emotion: {{ patterns.dominant_negative_emotion|default('None') }}
|
||||
Emotion Volatility: {{ patterns.emotion_volatility|default('Unknown') }}
|
||||
High Intensity Emotion Count: {{ patterns.high_intensity_emotions|default([])|length }}
|
||||
|
||||
## User Interests
|
||||
|
||||
{{ user_profile.interests|default(['Unknown'])|join(', ') }}
|
||||
|
||||
## Task Requirements
|
||||
|
||||
Please generate 3-5 personalized suggestions, each containing:
|
||||
1. type: Suggestion type (Emotion Balance/Activity Recommendation/Social Connection/Stress Management)
|
||||
2. title: Suggestion title (short and impactful)
|
||||
3. content: Suggestion content (detailed explanation, 50-100 words)
|
||||
4. priority: Priority level (High/Medium/Low)
|
||||
5. actionable_steps: 3 specific executable steps
|
||||
|
||||
Also provide a health_summary (no more than 50 words) summarizing the user's overall emotional state.
|
||||
**The health_summary must truthfully reflect ALL non-zero emotions from the distribution data. Do not omit any emotion type that has appeared.**
|
||||
|
||||
Please return in JSON format as follows:
|
||||
{
|
||||
"health_summary": "Your emotional health status...",
|
||||
"suggestions": [
|
||||
{
|
||||
"type": "Emotion Balance",
|
||||
"title": "Suggestion Title",
|
||||
"content": "Suggestion content...",
|
||||
"priority": "High",
|
||||
"actionable_steps": ["Step 1", "Step 2", "Step 3"]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
Notes:
|
||||
- CRITICAL: Any emotion with count ≥ 1 in the distribution MUST be acknowledged and addressed — never ignore or claim it is zero
|
||||
- Suggestions should be specific and actionable, avoid vague advice
|
||||
- Provide personalized suggestions based on user's interests and hobbies
|
||||
- Provide targeted suggestions for main issues (such as dominant negative emotions)
|
||||
- Allocate priorities reasonably (at least 1 high, 1-2 medium, rest low)
|
||||
- The 3 steps for each suggestion should be progressive and easy to implement
|
||||
- All output must be in English
|
||||
{% else %}
|
||||
你是一位专业的心理健康顾问。请根据以下用户的情绪健康数据和个人信息,生成3-5条个性化的情绪改善建议。
|
||||
|
||||
## 核心原则(最高优先级)
|
||||
|
||||
**你必须严格基于下方提供的情绪分布数据来生成建议。只要某种情绪的出现次数 ≥ 1,就代表该情绪确实存在,你必须在建议中承认并回应这一情绪,绝对不能说"该情绪为零"或"没有该情绪"。**
|
||||
|
||||
具体规则:
|
||||
1. 仔细查看"情绪分布"中每种情绪的出现次数,次数 ≥ 1 即表示该情绪存在
|
||||
2. 即使某种情绪只出现了1次,也必须在 health_summary 或建议中提及并给出针对性建议
|
||||
3. 严禁在输出中声称某种情绪"为零"或"不存在",除非该情绪在分布数据中确实为0次
|
||||
4. 如果正面情绪(如喜悦)存在,health_summary 中必须肯定这一积极信号
|
||||
5. 如果负面情绪(如悲伤、愤怒、恐惧)存在,即使只有1次,也必须给出针对性的改善建议
|
||||
6. 中性情绪占比高不代表没有其他情绪,必须同时关注所有非零情绪
|
||||
|
||||
## 用户情绪健康数据
|
||||
|
||||
健康分数:{{ health_data.health_score }}/100
|
||||
健康等级:{{ health_data.level }}
|
||||
情绪记录总数:{{ health_data.dimensions.positivity_rate.positive_count + health_data.dimensions.positivity_rate.negative_count + health_data.dimensions.positivity_rate.neutral_count }}条
|
||||
|
||||
维度分析:
|
||||
- 积极率:{{ health_data.dimensions.positivity_rate.score }}/100
|
||||
@@ -17,12 +111,12 @@
|
||||
- 恢复力:{{ health_data.dimensions.resilience.score }}/100
|
||||
- 恢复率:{{ health_data.dimensions.resilience.recovery_rate }}
|
||||
|
||||
情绪分布:
|
||||
情绪分布(请逐项检查,次数≥1的情绪都必须在建议中体现):
|
||||
{{ emotion_distribution_json }}
|
||||
|
||||
## 情绪模式分析
|
||||
|
||||
主要负面情绪:{{ patterns.dominant_negative_emotion|default('无') }}
|
||||
主要负面情绪:{{ dominant_negative_translated|default(patterns.dominant_negative_emotion)|default('无') }}
|
||||
情绪波动性:{{ patterns.emotion_volatility|default('未知') }}
|
||||
高强度情绪次数:{{ patterns.high_intensity_emotions|default([])|length }}
|
||||
|
||||
@@ -33,31 +127,35 @@
|
||||
## 任务要求
|
||||
|
||||
请生成3-5条个性化建议,每条建议包含:
|
||||
1. type: 建议类型(emotion_balance/activity_recommendation/social_connection/stress_management)
|
||||
1. type: 建议类型(情绪平衡/活动建议/社交联系/压力管理)
|
||||
2. title: 建议标题(简短有力)
|
||||
3. content: 建议内容(详细说明,50-100字)
|
||||
4. priority: 优先级(high/medium/low)
|
||||
4. priority: 优先级(高/中/低)
|
||||
5. actionable_steps: 3个可执行的具体步骤
|
||||
|
||||
同时提供一个health_summary(不超过50字),概括用户的整体情绪状态。
|
||||
**health_summary 必须如实反映情绪分布中所有非零情绪的存在,不得遗漏任何已出现的情绪类型。**
|
||||
|
||||
请以JSON格式返回,格式如下:
|
||||
{
|
||||
"health_summary": "您的情绪健康状况...",
|
||||
"suggestions": [
|
||||
{
|
||||
"type": "emotion_balance",
|
||||
"type": "情绪平衡",
|
||||
"title": "建议标题",
|
||||
"content": "建议内容...",
|
||||
"priority": "high",
|
||||
"priority": "高",
|
||||
"actionable_steps": ["步骤1", "步骤2", "步骤3"]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
注意事项:
|
||||
- 所有输出内容必须完全使用中文,严禁出现任何英文单词或短语(包括情绪类型名称如fear、sadness、anger等,必须使用对应的中文:恐惧、悲伤、愤怒等)
|
||||
- 再次强调:情绪分布中出现次数≥1的情绪必须在建议中被提及和回应,绝不能忽略或声称为零
|
||||
- 建议要具体、可执行,避免空泛
|
||||
- 结合用户的兴趣爱好提供个性化建议
|
||||
- 针对主要问题(如主要负面情绪)提供针对性建议
|
||||
- 优先级要合理分配(至少1个high,1-2个medium,其余low)
|
||||
- 优先级要合理分配(至少1个高,1-2个中,其余低)
|
||||
- 每个建议的3个步骤要循序渐进、易于实施
|
||||
{% endif %}
|
||||
|
||||
@@ -7,6 +7,12 @@
|
||||
|
||||
Your task is to generate a comprehensive memory insight report based on the provided data analysis. The report should include four distinct sections that capture different aspects of the user's memory patterns and characteristics.
|
||||
|
||||
{% if language == "zh" %}
|
||||
**重要:请使用中文生成记忆洞察报告内容。**
|
||||
{% else %}
|
||||
**Important: Please generate the memory insight report content in English.**
|
||||
{% endif %}
|
||||
|
||||
|
||||
===Inputs===
|
||||
{% if domain_distribution %}
|
||||
@@ -31,56 +37,105 @@ Your task is to generate a comprehensive memory insight report based on the prov
|
||||
|
||||
**Section-Specific Requirements:**
|
||||
|
||||
1. **总体概述 (Overview)** (100-150 Chinese characters)
|
||||
- Focus on: Overall analysis of user profile based on interaction logs
|
||||
- Describe the user's main role, work network, and collaboration spirit
|
||||
- Use professional, data-driven language style
|
||||
- Example reference: "通过对156次交互日志的深度分析,系统发现三层一位主要用户档案和数据分析的产品经理。他的工作网络体现出鲜明的目标导向和团队协作精神。"
|
||||
{% if language == "zh" %}
|
||||
1. **总体概述** (100-150字)
|
||||
- 重点:基于交互日志对用户档案进行整体分析
|
||||
- 描述用户的主要角色、工作网络和协作精神
|
||||
- 使用专业、数据驱动的语言风格
|
||||
- 示例参考:"通过对156次交互日志的深度分析,系统发现张三是一位主要从事用户档案和数据分析的产品经理。他的工作网络体现出鲜明的目标导向和团队协作精神。"
|
||||
|
||||
2. **行为模式 (Behavior Pattern)** (80-120 Chinese characters)
|
||||
- Focus on: Work patterns, time regularity, and behavioral characteristics
|
||||
- Describe weekly work patterns and time preferences
|
||||
- Use objective, analytical language
|
||||
- Example reference: "张三的工作模式呈现出鲜明的周期性:周一通常用于规划和会议,周三周四专注于产品设计和用户研究,周五进行总结和复盘。他倾向于在上午进行头脑风暴,下午处理执行性工作。"
|
||||
2. **行为模式** (80-120字)
|
||||
- 重点:工作模式、时间规律和行为特征
|
||||
- 描述每周工作模式和时间偏好
|
||||
- 使用客观、分析性的语言
|
||||
- 示例参考:"张三的工作模式呈现出鲜明的周期性:周一通常用于规划和会议,周三周四专注于产品设计和用户研究,周五进行总结和复盘。他倾向于在上午进行头脑风暴,下午处理执行性工作。"
|
||||
|
||||
3. **关键发现 (Key Findings)** (3-4 bullet points, 30-50 characters each)
|
||||
- Focus on: Specific, insightful observations about user behavior and preferences
|
||||
- Use bullet points (•) format
|
||||
- Each finding should be concrete and data-supported
|
||||
- Example reference:
|
||||
3. **关键发现** (3-4个要点,每个30-50字)
|
||||
- 重点:关于用户行为和偏好的具体、有洞察力的观察
|
||||
- 使用项目符号(•)格式
|
||||
- 每个发现应具体且有数据支持
|
||||
- 示例参考:
|
||||
"• 在产品决策中,张三总是优先考虑用户反应,这在68%的决策记录中得到体现
|
||||
• 他善于使用数据可视化工具来支持论点,这种习惯在项目管理中发挥了重要作用
|
||||
• 团队成员对他的评价中,"思路清晰"和"思路敏捷"两个关键词出现频率最高
|
||||
• 他对AI机器学习领域保持持续关注,近3个月参加了7次相关培训"
|
||||
|
||||
4. **成长轨迹 (Growth Trajectory)** (100-150 Chinese characters)
|
||||
4. **成长轨迹** (100-150字)
|
||||
- 重点:用户的成长历程、关键里程碑和能力提升
|
||||
- 按时间顺序组织内容
|
||||
- 突出角色变化和成就
|
||||
- 使用积极、鼓励的语气
|
||||
- 示例参考:"从入职时的产品经理成长为高级产品经理,张三在产品规划、团队管理和技术理解三个方面都有显著提升。特别是在最近一年,他开始独立主导更复杂的项目,展现出更强的战略思维能力。他的成长轨迹显示出对新技术的持续学习和对产品思维的不断深化。"
|
||||
{% else %}
|
||||
1. **Overview** (100-150 words)
|
||||
- Focus on: Overall analysis of user profile based on interaction logs
|
||||
- Describe the user's main role, work network, and collaboration spirit
|
||||
- Use professional, data-driven language style
|
||||
- Example reference: "Through in-depth analysis of 156 interaction logs, the system identified Zhang San as a product manager primarily focused on user profiling and data analysis. His work network demonstrates a clear goal-oriented approach and team collaboration spirit."
|
||||
|
||||
2. **Behavior Pattern** (80-120 words)
|
||||
- Focus on: Work patterns, time regularity, and behavioral characteristics
|
||||
- Describe weekly work patterns and time preferences
|
||||
- Use objective, analytical language
|
||||
- Example reference: "Zhang San's work pattern shows distinct periodicity: Mondays are typically used for planning and meetings, Wednesdays and Thursdays focus on product design and user research, and Fridays are for summary and review. He tends to brainstorm in the morning and handle execution tasks in the afternoon."
|
||||
|
||||
3. **Key Findings** (3-4 bullet points, 30-50 words each)
|
||||
- Focus on: Specific, insightful observations about user behavior and preferences
|
||||
- Use bullet points (•) format
|
||||
- Each finding should be concrete and data-supported
|
||||
- Example reference:
|
||||
"• In product decisions, Zhang San always prioritizes user feedback, as evidenced in 68% of decision records
|
||||
• He excels at using data visualization tools to support arguments, a habit that plays an important role in project management
|
||||
• Among team member evaluations, 'clear thinking' and 'quick thinking' are the most frequently mentioned keywords
|
||||
• He maintains continuous attention to AI and machine learning, attending 7 related training sessions in the past 3 months"
|
||||
|
||||
4. **Growth Trajectory** (100-150 words)
|
||||
- Focus on: User's growth journey, key milestones, and capability improvements
|
||||
- Organize content chronologically
|
||||
- Highlight role changes and achievements
|
||||
- Use positive, encouraging tone
|
||||
- Example reference: "从入职时的产品经理成长为高级产品经理,张三在产品单独、团队管理和技术理解三个方面都有显著提升。特别是在最近一年,他开始独立主导更复杂的项目,展现出更强的战略思维能力。他的成长轨迹显示出对新技术的持续学习和对产品思维的不断深化。"
|
||||
- Example reference: "Growing from a product manager at entry to a senior product manager, Zhang San has shown significant improvement in product planning, team management, and technical understanding. Especially in the past year, he has begun to independently lead more complex projects, demonstrating stronger strategic thinking capabilities. His growth trajectory shows continuous learning of new technologies and deepening of product thinking."
|
||||
{% endif %}
|
||||
|
||||
|
||||
===Output Format (MUST STRICTLY FOLLOW)===
|
||||
|
||||
{% if language == "zh" %}
|
||||
【总体概述】
|
||||
[100-150 characters describing overall user profile and work network based on interaction analysis]
|
||||
[100-150字,基于交互分析描述用户整体档案和工作网络]
|
||||
|
||||
【行为模式】
|
||||
[80-120 characters describing work patterns, time regularity, and behavioral characteristics]
|
||||
[80-120字,描述工作模式、时间规律和行为特征]
|
||||
|
||||
【关键发现】
|
||||
• [First key finding with data support, 30-50 characters]
|
||||
• [Second key finding with data support, 30-50 characters]
|
||||
• [Third key finding with data support, 30-50 characters]
|
||||
• [Fourth key finding with data support, 30-50 characters]
|
||||
• [第一个关键发现,有数据支持,30-50字]
|
||||
• [第二个关键发现,有数据支持,30-50字]
|
||||
• [第三个关键发现,有数据支持,30-50字]
|
||||
• [第四个关键发现,有数据支持,30-50字]
|
||||
|
||||
【成长轨迹】
|
||||
[100-150 characters describing growth journey, milestones, and capability improvements]
|
||||
[100-150字,描述成长历程、关键里程碑和能力提升]
|
||||
{% else %}
|
||||
【Overview】
|
||||
[100-150 words describing overall user profile and work network based on interaction analysis]
|
||||
|
||||
【Behavior Pattern】
|
||||
[80-120 words describing work patterns, time regularity, and behavioral characteristics]
|
||||
|
||||
【Key Findings】
|
||||
• [First key finding with data support, 30-50 words]
|
||||
• [Second key finding with data support, 30-50 words]
|
||||
• [Third key finding with data support, 30-50 words]
|
||||
• [Fourth key finding with data support, 30-50 words]
|
||||
|
||||
【Growth Trajectory】
|
||||
[100-150 words describing growth journey, milestones, and capability improvements]
|
||||
{% endif %}
|
||||
|
||||
|
||||
===Example===
|
||||
|
||||
{% if language == "zh" %}
|
||||
Example Input:
|
||||
- 核心领域分布: 产品管理(38%), 数据分析(24%), 团队协作(21%)
|
||||
- 活跃时段: 用户在每年的 4 和 10 月最为活跃
|
||||
@@ -101,6 +156,28 @@ Example Output:
|
||||
|
||||
【成长轨迹】
|
||||
从入职时的产品经理成长为高级产品经理,张三在产品规划、团队管理和技术理解三个方面都有显著提升。特别是在最近一年,他开始独立主导更复杂的项目,展现出更强的战略思维能力。他与李明的47条共同记忆见证了他的成长历程。
|
||||
{% else %}
|
||||
Example Input:
|
||||
- Core Domain Distribution: Product Management (38%), Data Analysis (24%), Team Collaboration (21%)
|
||||
- Active Periods: User is most active in April and October each year
|
||||
- Social Connections: Has the most shared memories (47 entries) with user "Li Ming", primarily during 2020-2023
|
||||
|
||||
Example Output:
|
||||
【Overview】
|
||||
Through in-depth analysis of 156 interaction logs, the system identified Zhang San as a product manager primarily focused on user profiling and data analysis. His work network demonstrates a clear goal-oriented approach and team collaboration spirit, with deep practical experience in product management, data analysis, and team collaboration.
|
||||
|
||||
【Behavior Pattern】
|
||||
Zhang San's work pattern shows distinct periodicity: Mondays are typically used for planning and meetings, Wednesdays and Thursdays focus on product design and user research, and Fridays are for summary and review. He tends to brainstorm in the morning and handle execution tasks in the afternoon. April and October are his most active periods each year.
|
||||
|
||||
【Key Findings】
|
||||
• In product decisions, Zhang San always prioritizes user feedback, as evidenced in 68% of decision records
|
||||
• He excels at using data visualization tools to support arguments, a habit that plays an important role in project management
|
||||
• Among team member evaluations, "clear thinking" and "quick thinking" are the most frequently mentioned keywords
|
||||
• He maintains continuous attention to AI and machine learning, attending 7 related training sessions in the past 3 months
|
||||
|
||||
【Growth Trajectory】
|
||||
Growing from a product manager at entry to a senior product manager, Zhang San has shown significant improvement in product planning, team management, and technical understanding. Especially in the past year, he has begun to independently lead more complex projects, demonstrating stronger strategic thinking capabilities. His 47 shared memories with Li Ming bear witness to his growth journey.
|
||||
{% endif %}
|
||||
|
||||
===End of Example===
|
||||
|
||||
@@ -133,20 +210,40 @@ After generating the report, perform the following self-review steps:
|
||||
|
||||
===Output Requirements===
|
||||
|
||||
{% if language == "zh" %}
|
||||
**语言要求:**
|
||||
- 输出语言必须始终为简体中文
|
||||
- 所有章节内容必须使用中文
|
||||
- 章节标题必须使用指定的中文格式:【总体概述】【行为模式】【关键发现】【成长轨迹】
|
||||
|
||||
**格式要求:**
|
||||
- 每个章节必须以标题开头,标题独占一行
|
||||
- 内容紧跟标题之后
|
||||
- 章节之间用空行分隔
|
||||
- 关键发现章节必须使用项目符号(•)
|
||||
- 严格遵守每个章节的字数限制
|
||||
|
||||
**内容要求:**
|
||||
- 仅使用提供的数据点
|
||||
- 不得捏造或推测信息
|
||||
- 如果某个章节数据不足,请简要说明或跳过
|
||||
- 全文保持专业、分析性的语气
|
||||
{% else %}
|
||||
**LANGUAGE REQUIREMENT:**
|
||||
- The output language should ALWAYS be Chinese (Simplified)
|
||||
- All section content must be in Chinese
|
||||
- Section headers must use the specified Chinese format: 【总体概述】【行为模式】【关键发现】【成长轨迹】
|
||||
- The output language must ALWAYS be English
|
||||
- All section content must be in English
|
||||
- Section headers must use the specified English format: 【Overview】【Behavior Pattern】【Key Findings】【Growth Trajectory】
|
||||
|
||||
**FORMAT REQUIREMENT:**
|
||||
- Each section must start with its header on a new line
|
||||
- Content follows immediately after the header
|
||||
- Sections are separated by blank lines
|
||||
- Key Findings section must use bullet points (•)
|
||||
- Strictly adhere to character limits for each section
|
||||
- Strictly adhere to word limits for each section
|
||||
|
||||
**CONTENT REQUIREMENT:**
|
||||
- Only use provided data points
|
||||
- Do not fabricate or speculate information
|
||||
- If data is insufficient for a section, provide a brief note or skip
|
||||
- Maintain professional, analytical tone throughout
|
||||
{% endif %}
|
||||
|
||||
@@ -5,10 +5,21 @@
|
||||
=== Task ===
|
||||
Summarize the provided conversation chunks into a concise Memory summary.
|
||||
|
||||
{% if language == "zh" %}
|
||||
**重要:请使用中文生成摘要内容。**
|
||||
{% else %}
|
||||
**Important: Please generate the summary content in English.**
|
||||
{% endif %}
|
||||
|
||||
=== Requirements ===
|
||||
- Focus on factual statements, user preferences, relationships, and salient temporal context.
|
||||
- Avoid repetition and filler; be specific.
|
||||
- Keep it under {{ max_words or 200 }} words.
|
||||
{% if language == "zh" %}
|
||||
- 摘要内容必须使用中文
|
||||
{% else %}
|
||||
- Summary content must be in English
|
||||
{% endif %}
|
||||
- Output must be valid JSON conforming to the schema below.
|
||||
|
||||
=== Input ===
|
||||
@@ -24,6 +35,11 @@ Summarize the provided conversation chunks into a concise Memory summary.
|
||||
4. Do not include line breaks within JSON string values
|
||||
5. Example of proper escaping: "statement": "张曼婷说:\"我很喜欢这本书。\""
|
||||
|
||||
The output language should always be the same as the input language.
|
||||
{% if language == "zh" %}
|
||||
**语言要求:输出内容必须使用中文。**
|
||||
{% else %}
|
||||
**Language Requirement: The output content must be in English.**
|
||||
{% endif %}
|
||||
|
||||
Return only a list of extracted labelled statements in the JSON ARRAY of objects that match the schema below:
|
||||
{{ json_schema }}
|
||||
@@ -1,2 +1,7 @@
|
||||
{% if language == "zh" %}
|
||||
你是一个从对话消息中提取实体节点的 AI 助手。
|
||||
你的主要任务是提取和分类说话者以及对话中提到的其他重要实体。
|
||||
{% else %}
|
||||
You are an AI assistant that extracts entity nodes from conversational messages.
|
||||
Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation.
|
||||
Your primary task is to extract and classify the speaker and other significant entities mentioned in the conversation.
|
||||
{% endif %}
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
{% if language == "zh" %}
|
||||
给定一个对话上下文和一个当前消息。
|
||||
你的任务是提取在当前消息中**明确或隐含**提到的用户名称和年龄。
|
||||
代词引用(如 he/she/they 或 this/that/those)应消歧为引用实体的名称。
|
||||
|
||||
{{ message }}
|
||||
{% else %}
|
||||
You are given a conversation context and a CURRENT MESSAGE.
|
||||
Your task is to extract user name and age mentioned **explicitly or implicitly** in the CURRENT MESSAGE.
|
||||
Pronoun references such as he/she/they or this/that/those should be disambiguated to the names of the reference entities.
|
||||
|
||||
{{ message }}
|
||||
{{ message }}
|
||||
{% endif %}
|
||||
|
||||
@@ -7,6 +7,11 @@
|
||||
|
||||
Your task is to generate a comprehensive user profile based on the provided entities and statements. The profile should include four distinct sections that capture different aspects of the user's identity and characteristics.
|
||||
|
||||
{% if language == "zh" %}
|
||||
**重要:请使用中文生成用户画像内容。**
|
||||
{% else %}
|
||||
**Important: Please generate the user profile content in English.**
|
||||
{% endif %}
|
||||
|
||||
===Inputs===
|
||||
{% if user_id %}
|
||||
@@ -30,40 +35,73 @@ Your task is to generate a comprehensive user profile based on the provided enti
|
||||
|
||||
**Section-Specific Requirements:**
|
||||
|
||||
1. **Basic Introduction** (4-5 sentences, max 150 Chinese characters)
|
||||
{% if language == "zh" %}
|
||||
1. **基本介绍** (4-5句话,最多150字)
|
||||
- 重点:身份、职业、地点及其他基本人口统计信息
|
||||
- 提供关于用户是谁的事实背景
|
||||
|
||||
2. **性格特点** (2-3句话,最多80字)
|
||||
- 重点:性格特征、行为习惯、沟通风格
|
||||
- 描述用户互动和行为中可观察到的模式
|
||||
|
||||
3. **核心价值观** (1-2句话,最多50字)
|
||||
- 重点:价值观、信念、目标和愿望
|
||||
- 捕捉对用户最重要的内容以及驱动其决策的因素
|
||||
|
||||
4. **一句话总结** (1句话,最多40字)
|
||||
- 提供对用户核心特质的高度浓缩描述
|
||||
- 类似于捕捉其本质的个人标语或座右铭
|
||||
{% else %}
|
||||
1. **Basic Introduction** (4-5 sentences, max 150 words)
|
||||
- Focus on: identity, occupation, location, and other basic demographic information
|
||||
- Provide factual background about who the user is
|
||||
|
||||
2. **Personality Traits** (2-3 sentences, max 80 Chinese characters)
|
||||
2. **Personality Traits** (2-3 sentences, max 80 words)
|
||||
- Focus on: personality characteristics, behavioral habits, communication style
|
||||
- Describe observable patterns in how the user interacts and behaves
|
||||
|
||||
3. **Core Values** (1-2 sentences, max 50 Chinese characters)
|
||||
3. **Core Values** (1-2 sentences, max 50 words)
|
||||
- Focus on: values, beliefs, goals, and aspirations
|
||||
- Capture what matters most to the user and what drives their decisions
|
||||
|
||||
4. **One-Sentence Summary** (1 sentence, max 40 Chinese characters)
|
||||
4. **One-Sentence Summary** (1 sentence, max 40 words)
|
||||
- Provide a highly condensed characterization of the user's core traits
|
||||
- Similar to a personal tagline or motto that captures their essence
|
||||
{% endif %}
|
||||
|
||||
|
||||
===Output Format (MUST STRICTLY FOLLOW)===
|
||||
|
||||
{% if language == "zh" %}
|
||||
【基本介绍】
|
||||
[4-5 sentences describing the user's basic identity, occupation, and location]
|
||||
[4-5句话描述用户的基本身份、职业和地点]
|
||||
|
||||
【性格特点】
|
||||
[2-3 sentences describing the user's personality traits, behavioral habits, and communication style]
|
||||
[2-3句话描述用户的性格特征、行为习惯和沟通风格]
|
||||
|
||||
【核心价值观】
|
||||
[1-2 sentences describing the user's values, beliefs, and goals]
|
||||
[1-2句话描述用户的价值观、信念和目标]
|
||||
|
||||
【一句话总结】
|
||||
[1句话提供对用户核心特质的高度浓缩总结]
|
||||
{% else %}
|
||||
【Basic Introduction】
|
||||
[4-5 sentences describing the user's basic identity, occupation, and location]
|
||||
|
||||
【Personality Traits】
|
||||
[2-3 sentences describing the user's personality traits, behavioral habits, and communication style]
|
||||
|
||||
【Core Values】
|
||||
[1-2 sentences describing the user's values, beliefs, and goals]
|
||||
|
||||
【One-Sentence Summary】
|
||||
[1 sentence providing a highly condensed summary of the user's core characteristics]
|
||||
{% endif %}
|
||||
|
||||
|
||||
===Example===
|
||||
|
||||
{% if language == "zh" %}
|
||||
Example Input:
|
||||
- User ID: user_12345
|
||||
- Core Entities & Frequency: 产品经理 (15), AI (12), 深圳 (10), 数据分析 (8), 团队协作 (7)
|
||||
@@ -81,6 +119,25 @@ Example Output:
|
||||
|
||||
【一句话总结】
|
||||
"让每一个产品决策都充满温度。"
|
||||
{% else %}
|
||||
Example Input:
|
||||
- User ID: user_12345
|
||||
- Core Entities & Frequency: Product Manager (15), AI (12), San Francisco (10), Data Analysis (8), Team Collaboration (7)
|
||||
- Representative Statement Samples: I have been working as a product manager in San Francisco for 5 years | I believe good products come from deep understanding of user needs | I enjoy playing a coordinating role in teams | Data-driven decision making is my work principle
|
||||
|
||||
Example Output:
|
||||
【Basic Introduction】
|
||||
This is a passionate senior product manager based in San Francisco. Over the past 5 years, they have focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. They believe good products stem from deep understanding of user needs and continuous exploration of technological possibilities.
|
||||
|
||||
【Personality Traits】
|
||||
Outgoing personality with excellent communication skills and attention to detail. Enjoys playing a coordinating role in teams, helping everyone reach consensus. Maintains optimism when facing challenges, believing every problem has a solution.
|
||||
|
||||
【Core Values】
|
||||
User-first, data-driven, continuous learning, team collaboration
|
||||
|
||||
【One-Sentence Summary】
|
||||
"Making every product decision with warmth and purpose."
|
||||
{% endif %}
|
||||
|
||||
===End of Example===
|
||||
|
||||
@@ -91,7 +148,7 @@ Before generating your final output, internally verify:
|
||||
1. All content is grounded in provided data (no fabrication)
|
||||
2. Format follows the specified structure with correct headers
|
||||
3. Tone is objective, third-person, and neutral
|
||||
4. All four sections are complete and within character limits
|
||||
4. All four sections are complete and within character/word limits
|
||||
|
||||
**IMPORTANT: These checks are for your internal use only. DO NOT include them in your output.**
|
||||
|
||||
@@ -101,14 +158,24 @@ Before generating your final output, internally verify:
|
||||
**CRITICAL: Your response must ONLY contain the four sections below. Do not include any reflection, self-review, or meta-commentary.**
|
||||
|
||||
**LANGUAGE REQUIREMENT:**
|
||||
- The output language should ALWAYS be Chinese (Simplified)
|
||||
- All section content must be in Chinese
|
||||
- Section headers must use the specified Chinese format: 【基本介绍】【性格特点】【核心价值观】【一句话总结】
|
||||
{% if language == "zh" %}
|
||||
- 输出语言必须为简体中文
|
||||
- 所有部分内容必须使用中文
|
||||
- 部分标题必须使用指定的中文格式:【基本介绍】【性格特点】【核心价值观】【一句话总结】
|
||||
{% else %}
|
||||
- The output language must be English
|
||||
- All section content must be in English
|
||||
- Section headers must use the specified format: 【Basic Introduction】【Personality Traits】【Core Values】【One-Sentence Summary】
|
||||
{% endif %}
|
||||
|
||||
**FORMAT REQUIREMENT:**
|
||||
- Each section must start with its header on a new line
|
||||
- Content follows immediately after the header
|
||||
- Sections are separated by blank lines
|
||||
- Strictly adhere to character limits for each section
|
||||
- **DO NOT include any text after the 【一句话总结】 section**
|
||||
{% if language == "zh" %}
|
||||
- 严格遵守每个部分的字数限制
|
||||
{% else %}
|
||||
- Strictly adhere to word limits for each section
|
||||
{% endif %}
|
||||
- **DO NOT include any text after the final section**
|
||||
- **DO NOT output reflection steps, self-review, or verification notes**
|
||||
|
||||
10
api/app/core/memory/utils/validation/__init__.py
Normal file
10
api/app/core/memory/utils/validation/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
"""Validation utilities for ontology extraction.
|
||||
|
||||
This module provides validation classes for ontology class names,
|
||||
descriptions, and OWL compliance checking.
|
||||
"""
|
||||
|
||||
from .ontology_validator import OntologyValidator
|
||||
from .owl_validator import OWLValidator
|
||||
|
||||
__all__ = ['OntologyValidator', 'OWLValidator']
|
||||
270
api/app/core/memory/utils/validation/ontology_validator.py
Normal file
270
api/app/core/memory/utils/validation/ontology_validator.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""String validation for ontology class names and descriptions.
|
||||
|
||||
This module provides the OntologyValidator class for validating and sanitizing
|
||||
ontology class names according to OWL standards and naming conventions.
|
||||
|
||||
Classes:
|
||||
OntologyValidator: Validates class names, removes duplicates, and truncates descriptions
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import List, Tuple
|
||||
|
||||
from app.core.memory.models.ontology_scenario_models import OntologyClass
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OntologyValidator:
|
||||
"""Validator for ontology class names and descriptions.
|
||||
|
||||
This validator performs string-level validation including:
|
||||
- PascalCase naming convention validation
|
||||
- OWL reserved word checking
|
||||
- Duplicate class name removal
|
||||
- Description length truncation
|
||||
|
||||
Attributes:
|
||||
OWL_RESERVED_WORDS: Set of OWL reserved words that cannot be used as class names
|
||||
"""
|
||||
|
||||
# OWL reserved words that cannot be used as class names
|
||||
OWL_RESERVED_WORDS = {
|
||||
'Thing', 'Nothing', 'Class', 'Property',
|
||||
'ObjectProperty', 'DatatypeProperty', 'FunctionalProperty',
|
||||
'InverseFunctionalProperty', 'TransitiveProperty', 'SymmetricProperty',
|
||||
'AsymmetricProperty', 'ReflexiveProperty', 'IrreflexiveProperty',
|
||||
'Restriction', 'Ontology', 'Individual', 'NamedIndividual',
|
||||
'Annotation', 'AnnotationProperty', 'Axiom',
|
||||
'AllDifferent', 'AllDisjointClasses', 'AllDisjointProperties',
|
||||
'Datatype', 'DataRange', 'Literal',
|
||||
'DeprecatedClass', 'DeprecatedProperty',
|
||||
'Imports', 'IncompatibleWith', 'PriorVersion', 'VersionInfo',
|
||||
'BackwardCompatibleWith', 'OntologyProperty',
|
||||
}
|
||||
|
||||
def validate_class_name(self, name: str) -> Tuple[bool, str]:
|
||||
"""Validate that a class name follows OWL naming conventions.
|
||||
|
||||
Validation rules:
|
||||
1. Must not be empty
|
||||
2. Must start with an uppercase letter (PascalCase)
|
||||
3. Cannot contain spaces
|
||||
4. Can only contain alphanumeric characters and underscores
|
||||
5. Cannot be an OWL reserved word
|
||||
|
||||
Args:
|
||||
name: The class name to validate
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
- is_valid: True if the name is valid, False otherwise
|
||||
- error_message: Empty string if valid, error description if invalid
|
||||
|
||||
Examples:
|
||||
>>> validator = OntologyValidator()
|
||||
>>> validator.validate_class_name("MedicalProcedure")
|
||||
(True, "")
|
||||
>>> validator.validate_class_name("medical procedure")
|
||||
(False, "Class name 'medical procedure' cannot contain spaces")
|
||||
>>> validator.validate_class_name("Thing")
|
||||
(False, "Class name 'Thing' is an OWL reserved word")
|
||||
"""
|
||||
logger.debug(f"Validating class name: '{name}'")
|
||||
|
||||
# Check if empty
|
||||
if not name or not name.strip():
|
||||
error_msg = "Class name cannot be empty"
|
||||
logger.warning(f"Validation failed: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
name = name.strip()
|
||||
|
||||
# Check if it's an OWL reserved word
|
||||
if name in self.OWL_RESERVED_WORDS:
|
||||
error_msg = f"Class name '{name}' is an OWL reserved word"
|
||||
logger.warning(f"Validation failed: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
# Check if starts with uppercase letter (only for ASCII letters)
|
||||
# For Chinese/Unicode characters, skip this check
|
||||
first_char = name[0]
|
||||
if first_char.isascii() and first_char.isalpha() and not first_char.isupper():
|
||||
error_msg = f"Class name '{name}' must start with an uppercase letter (PascalCase)"
|
||||
logger.warning(f"Validation failed: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
# Check for spaces
|
||||
if ' ' in name:
|
||||
error_msg = f"Class name '{name}' cannot contain spaces"
|
||||
logger.warning(f"Validation failed: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
# Check for invalid characters (allow alphanumeric, underscore, and Unicode characters)
|
||||
if not re.match(r'^[A-Za-z0-9_\u4e00-\u9fff]+$', name):
|
||||
error_msg = f"Class name '{name}' contains invalid characters. Only alphanumeric characters, underscores, and Chinese characters are allowed"
|
||||
logger.warning(f"Validation failed: {error_msg}")
|
||||
return False, error_msg
|
||||
|
||||
logger.debug(f"Class name '{name}' is valid")
|
||||
return True, ""
|
||||
|
||||
def sanitize_class_name(self, name: str) -> str:
|
||||
"""Attempt to sanitize an invalid class name into a valid format.
|
||||
|
||||
Sanitization steps:
|
||||
1. Strip whitespace
|
||||
2. Remove invalid characters
|
||||
3. Replace spaces with empty string (PascalCase)
|
||||
4. Capitalize first letter of each word
|
||||
5. If result is empty or starts with number, prefix with 'Class'
|
||||
|
||||
Args:
|
||||
name: The class name to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized class name that should pass validation
|
||||
|
||||
Examples:
|
||||
>>> validator = OntologyValidator()
|
||||
>>> validator.sanitize_class_name("medical procedure")
|
||||
'MedicalProcedure'
|
||||
>>> validator.sanitize_class_name("patient-record")
|
||||
'PatientRecord'
|
||||
>>> validator.sanitize_class_name("123invalid")
|
||||
'Class123Invalid'
|
||||
"""
|
||||
logger.debug(f"Sanitizing class name: '{name}'")
|
||||
|
||||
if not name or not name.strip():
|
||||
logger.warning("Empty class name provided for sanitization, returning 'UnnamedClass'")
|
||||
return "UnnamedClass"
|
||||
|
||||
# Strip whitespace
|
||||
name = name.strip()
|
||||
original_name = name
|
||||
|
||||
# Split on spaces, hyphens, and underscores, then capitalize each word
|
||||
words = re.split(r'[\s\-_]+', name)
|
||||
|
||||
# Capitalize first letter of each word and keep rest as is
|
||||
sanitized_words = []
|
||||
for word in words:
|
||||
if word:
|
||||
# Remove non-alphanumeric characters except underscore
|
||||
clean_word = re.sub(r'[^A-Za-z0-9_]', '', word)
|
||||
if clean_word:
|
||||
# Capitalize first letter
|
||||
sanitized_words.append(clean_word[0].upper() + clean_word[1:])
|
||||
|
||||
# Join words
|
||||
sanitized = ''.join(sanitized_words)
|
||||
|
||||
# If empty or starts with number, prefix with 'Class'
|
||||
if not sanitized or sanitized[0].isdigit():
|
||||
sanitized = 'Class' + sanitized
|
||||
logger.info(f"Prefixed class name with 'Class': '{original_name}' -> '{sanitized}'")
|
||||
|
||||
# If it's a reserved word, append 'Class' suffix
|
||||
if sanitized in self.OWL_RESERVED_WORDS:
|
||||
sanitized = sanitized + 'Class'
|
||||
logger.info(f"Appended 'Class' suffix to reserved word: '{original_name}' -> '{sanitized}'")
|
||||
|
||||
logger.info(f"Sanitized class name: '{original_name}' -> '{sanitized}'")
|
||||
return sanitized
|
||||
|
||||
def remove_duplicates(self, classes: List[OntologyClass]) -> List[OntologyClass]:
|
||||
"""Remove duplicate ontology classes based on case-insensitive name comparison.
|
||||
|
||||
When duplicates are found, keeps the first occurrence and discards subsequent ones.
|
||||
Comparison is case-insensitive to catch variations like 'Patient' and 'patient'.
|
||||
|
||||
Args:
|
||||
classes: List of OntologyClass objects
|
||||
|
||||
Returns:
|
||||
List of OntologyClass objects with duplicates removed
|
||||
|
||||
Examples:
|
||||
>>> validator = OntologyValidator()
|
||||
>>> classes = [
|
||||
... OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare"),
|
||||
... OntologyClass(name="patient", description="Another patient", entity_type="Person", domain="Healthcare"),
|
||||
... OntologyClass(name="Doctor", description="A doctor", entity_type="Person", domain="Healthcare"),
|
||||
... ]
|
||||
>>> unique = validator.remove_duplicates(classes)
|
||||
>>> len(unique)
|
||||
2
|
||||
>>> [c.name for c in unique]
|
||||
['Patient', 'Doctor']
|
||||
"""
|
||||
if not classes:
|
||||
logger.debug("No classes to check for duplicates")
|
||||
return classes
|
||||
|
||||
logger.debug(f"Checking {len(classes)} classes for duplicates")
|
||||
|
||||
seen_names = set()
|
||||
unique_classes = []
|
||||
duplicates_found = []
|
||||
|
||||
for ontology_class in classes:
|
||||
# Use lowercase for comparison
|
||||
name_lower = ontology_class.name.lower()
|
||||
|
||||
if name_lower not in seen_names:
|
||||
seen_names.add(name_lower)
|
||||
unique_classes.append(ontology_class)
|
||||
else:
|
||||
duplicates_found.append(ontology_class.name)
|
||||
logger.debug(f"Duplicate class found and removed: '{ontology_class.name}'")
|
||||
|
||||
if duplicates_found:
|
||||
logger.info(
|
||||
f"Removed {len(duplicates_found)} duplicate classes: {duplicates_found}"
|
||||
)
|
||||
else:
|
||||
logger.debug("No duplicate classes found")
|
||||
|
||||
return unique_classes
|
||||
|
||||
def truncate_description(self, description: str, max_length: int = 500) -> str:
|
||||
"""Truncate a description to a maximum length.
|
||||
|
||||
If the description exceeds max_length, it will be truncated and
|
||||
an ellipsis (...) will be appended to indicate truncation.
|
||||
|
||||
Args:
|
||||
description: The description text to truncate
|
||||
max_length: Maximum allowed length (default: 500)
|
||||
|
||||
Returns:
|
||||
Truncated description string
|
||||
|
||||
Examples:
|
||||
>>> validator = OntologyValidator()
|
||||
>>> long_desc = "A" * 600
|
||||
>>> truncated = validator.truncate_description(long_desc, max_length=500)
|
||||
>>> len(truncated)
|
||||
500
|
||||
>>> truncated.endswith("...")
|
||||
True
|
||||
"""
|
||||
if not description:
|
||||
return ""
|
||||
|
||||
if len(description) <= max_length:
|
||||
return description
|
||||
|
||||
# Truncate and add ellipsis
|
||||
# Reserve 3 characters for "..."
|
||||
truncate_at = max_length - 3
|
||||
truncated = description[:truncate_at] + "..."
|
||||
|
||||
logger.debug(
|
||||
f"Truncated description from {len(description)} to {len(truncated)} characters"
|
||||
)
|
||||
|
||||
return truncated
|
||||
738
api/app/core/memory/utils/validation/owl_validator.py
Normal file
738
api/app/core/memory/utils/validation/owl_validator.py
Normal file
@@ -0,0 +1,738 @@
|
||||
"""OWL semantic validation for ontology classes using Owlready2.
|
||||
|
||||
This module provides the OWLValidator class for validating ontology classes
|
||||
against OWL standards using the Owlready2 library. It performs semantic
|
||||
validation including consistency checking, circular inheritance detection,
|
||||
and OWL file export.
|
||||
|
||||
Classes:
|
||||
OWLValidator: Validates ontology classes using OWL reasoning and exports to OWL formats
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from owlready2 import (
|
||||
World,
|
||||
Thing,
|
||||
get_ontology,
|
||||
sync_reasoner_pellet,
|
||||
OwlReadyInconsistentOntologyError,
|
||||
)
|
||||
|
||||
from app.core.memory.models.ontology_scenario_models import OntologyClass
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OWLValidator:
|
||||
"""Validator for OWL semantic validation of ontology classes.
|
||||
|
||||
This validator performs semantic-level validation using Owlready2 including:
|
||||
- Creating OWL classes from ontology class definitions
|
||||
- Running consistency checking with Pellet reasoner
|
||||
- Detecting circular inheritance
|
||||
- Validating Protégé compatibility
|
||||
- Exporting ontologies to various OWL formats (RDF/XML, Turtle, N-Triples)
|
||||
|
||||
Attributes:
|
||||
base_namespace: Base URI for the ontology namespace
|
||||
"""
|
||||
|
||||
def __init__(self, base_namespace: str = "http://example.org/ontology#"):
|
||||
"""Initialize the OWL validator.
|
||||
|
||||
Args:
|
||||
base_namespace: Base URI for the ontology namespace (default: http://example.org/ontology#)
|
||||
"""
|
||||
self.base_namespace = base_namespace
|
||||
|
||||
def validate_ontology_classes(
|
||||
self,
|
||||
classes: List[OntologyClass],
|
||||
) -> Tuple[bool, List[str], Optional[World]]:
|
||||
"""Validate extracted ontology classes against OWL standards.
|
||||
|
||||
This method creates an OWL ontology from the provided classes using Owlready2,
|
||||
runs consistency checking with the Pellet reasoner, and detects common issues
|
||||
like circular inheritance.
|
||||
|
||||
Args:
|
||||
classes: List of OntologyClass objects to validate
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_messages, world):
|
||||
- is_valid: True if ontology is valid and consistent, False otherwise
|
||||
- error_messages: List of error/warning messages
|
||||
- world: Owlready2 World object containing the ontology (None if validation failed)
|
||||
|
||||
Examples:
|
||||
>>> validator = OWLValidator()
|
||||
>>> classes = [
|
||||
... OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare"),
|
||||
... OntologyClass(name="Doctor", description="A doctor", entity_type="Person", domain="Healthcare"),
|
||||
... ]
|
||||
>>> is_valid, errors, world = validator.validate_ontology_classes(classes)
|
||||
>>> is_valid
|
||||
True
|
||||
>>> len(errors)
|
||||
0
|
||||
"""
|
||||
if not classes:
|
||||
return False, ["No classes provided for validation"], None
|
||||
|
||||
errors = []
|
||||
|
||||
try:
|
||||
# Create a new world (isolated ontology environment)
|
||||
world = World()
|
||||
|
||||
# Use a proper ontology IRI
|
||||
# Owlready2 expects the IRI to end with .owl or similar
|
||||
onto_iri = self.base_namespace.rstrip('#/')
|
||||
if not onto_iri.endswith('.owl'):
|
||||
onto_iri = onto_iri + '.owl'
|
||||
|
||||
# Create ontology
|
||||
onto = world.get_ontology(onto_iri)
|
||||
|
||||
with onto:
|
||||
# Dictionary to store created OWL classes for parent reference
|
||||
owl_classes = {}
|
||||
|
||||
# First pass: Create all classes without parent relationships
|
||||
for ontology_class in classes:
|
||||
try:
|
||||
# Create OWL class dynamically using type() with Thing as base
|
||||
# The key is to NOT set namespace in the dict, let Owlready2 handle it
|
||||
owl_class = type(
|
||||
ontology_class.name, # Class name
|
||||
(Thing,), # Base classes
|
||||
{} # Class dict (empty, let Owlready2 manage)
|
||||
)
|
||||
|
||||
# Add label (rdfs:label) - include both English and Chinese names
|
||||
labels = [ontology_class.name]
|
||||
if ontology_class.name_chinese:
|
||||
labels.append(ontology_class.name_chinese)
|
||||
owl_class.label = labels
|
||||
|
||||
# Add comment (rdfs:comment) with description
|
||||
if ontology_class.description:
|
||||
owl_class.comment = [ontology_class.description]
|
||||
|
||||
# Store for parent relationship setup
|
||||
owl_classes[ontology_class.name] = owl_class
|
||||
|
||||
logger.debug(
|
||||
f"Created OWL class: {ontology_class.name} "
|
||||
f"(Chinese: {ontology_class.name_chinese}) "
|
||||
f"IRI: {owl_class.iri if hasattr(owl_class, 'iri') else 'N/A'}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to create OWL class '{ontology_class.name}': {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(error_msg, exc_info=True)
|
||||
|
||||
# Second pass: Set up parent relationships
|
||||
for ontology_class in classes:
|
||||
if ontology_class.parent_class and ontology_class.name in owl_classes:
|
||||
parent_name = ontology_class.parent_class
|
||||
|
||||
# Check if parent exists
|
||||
if parent_name in owl_classes:
|
||||
try:
|
||||
child_class = owl_classes[ontology_class.name]
|
||||
parent_class = owl_classes[parent_name]
|
||||
|
||||
# Set parent by modifying is_a
|
||||
child_class.is_a = [parent_class]
|
||||
|
||||
logger.debug(
|
||||
f"Set parent relationship: {ontology_class.name} -> {parent_name}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = (
|
||||
f"Failed to set parent relationship "
|
||||
f"'{ontology_class.name}' -> '{parent_name}': {str(e)}"
|
||||
)
|
||||
errors.append(error_msg)
|
||||
logger.warning(error_msg)
|
||||
else:
|
||||
warning_msg = (
|
||||
f"Parent class '{parent_name}' not found for '{ontology_class.name}'"
|
||||
)
|
||||
errors.append(warning_msg)
|
||||
logger.warning(warning_msg)
|
||||
|
||||
# Check for circular inheritance
|
||||
for class_name, owl_class in owl_classes.items():
|
||||
if self._has_circular_inheritance(owl_class):
|
||||
error_msg = f"Circular inheritance detected for class '{class_name}'"
|
||||
errors.append(error_msg)
|
||||
logger.error(error_msg)
|
||||
|
||||
# Run consistency checking with Pellet reasoner
|
||||
try:
|
||||
logger.info("Running Pellet reasoner for consistency checking...")
|
||||
sync_reasoner_pellet(world, infer_property_values=True, infer_data_property_values=True)
|
||||
logger.info("Consistency check passed")
|
||||
|
||||
except OwlReadyInconsistentOntologyError as e:
|
||||
error_msg = f"Ontology is inconsistent: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(error_msg)
|
||||
return False, errors, world
|
||||
|
||||
except Exception as e:
|
||||
# Reasoner errors are often due to Java not being installed or configured
|
||||
# Log as warning but don't fail validation - ontology structure is still valid
|
||||
warning_msg = f"Reasoner check skipped: {str(e)}"
|
||||
if str(e).strip(): # Only log if there's an actual error message
|
||||
logger.warning(warning_msg)
|
||||
else:
|
||||
logger.warning("Reasoner check skipped: Java may not be installed or configured")
|
||||
# Continue - ontology structure is valid even without reasoner check
|
||||
|
||||
# If we have errors (excluding warnings), validation failed
|
||||
is_valid = len(errors) == 0
|
||||
|
||||
return is_valid, errors, world
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"OWL validation failed: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(error_msg, exc_info=True)
|
||||
return False, errors, None
|
||||
|
||||
def _has_circular_inheritance(self, owl_class) -> bool:
|
||||
"""Check if an OWL class has circular inheritance.
|
||||
|
||||
Circular inheritance occurs when a class inherits from itself through
|
||||
a chain of parent relationships (e.g., A -> B -> C -> A).
|
||||
|
||||
Args:
|
||||
owl_class: Owlready2 class object to check
|
||||
|
||||
Returns:
|
||||
True if circular inheritance is detected, False otherwise
|
||||
"""
|
||||
visited = set()
|
||||
current = owl_class
|
||||
|
||||
while current:
|
||||
# Get class IRI or name as identifier
|
||||
class_id = str(current.iri) if hasattr(current, 'iri') else str(current)
|
||||
|
||||
if class_id in visited:
|
||||
# Found a cycle
|
||||
return True
|
||||
|
||||
visited.add(class_id)
|
||||
|
||||
# Get parent classes (is_a relationship)
|
||||
parents = getattr(current, 'is_a', [])
|
||||
|
||||
# Filter out Thing and other base classes
|
||||
parent_classes = [p for p in parents if p != Thing and hasattr(p, 'is_a')]
|
||||
|
||||
if not parent_classes:
|
||||
# No more parents, no cycle
|
||||
break
|
||||
|
||||
# Check first parent (in single inheritance)
|
||||
current = parent_classes[0] if parent_classes else None
|
||||
|
||||
return False
|
||||
|
||||
def export_to_owl(
|
||||
self,
|
||||
world: World,
|
||||
output_path: Optional[str] = None,
|
||||
format: str = "rdfxml",
|
||||
classes: Optional[List] = None
|
||||
) -> str:
|
||||
"""Export ontology to OWL file in specified format.
|
||||
|
||||
Supported formats:
|
||||
- rdfxml: RDF/XML format (default, most compatible)
|
||||
- turtle: Turtle format (more readable)
|
||||
- ntriples: N-Triples format (simplest)
|
||||
- json: JSON format (simplified, human-readable)
|
||||
|
||||
Args:
|
||||
world: Owlready2 World object containing the ontology
|
||||
output_path: Optional file path to save the ontology (if None, returns string)
|
||||
format: Export format - "rdfxml", "turtle", "ntriples", or "json" (default: "rdfxml")
|
||||
classes: Optional list of OntologyClass objects (required for json format)
|
||||
|
||||
Returns:
|
||||
String representation of the exported ontology
|
||||
|
||||
Raises:
|
||||
ValueError: If format is not supported
|
||||
RuntimeError: If export fails
|
||||
|
||||
Examples:
|
||||
>>> validator = OWLValidator()
|
||||
>>> is_valid, errors, world = validator.validate_ontology_classes(classes)
|
||||
>>> owl_content = validator.export_to_owl(world, "ontology.owl", format="rdfxml")
|
||||
"""
|
||||
# Validate format
|
||||
valid_formats = ["rdfxml", "turtle", "ntriples", "json"]
|
||||
if format not in valid_formats:
|
||||
raise ValueError(
|
||||
f"Unsupported format '{format}'. Must be one of: {', '.join(valid_formats)}"
|
||||
)
|
||||
|
||||
# JSON format doesn't need OWL processing
|
||||
if format == "json":
|
||||
if not classes:
|
||||
raise ValueError("Classes list is required for JSON format export")
|
||||
return self._export_to_json(classes)
|
||||
|
||||
# For OWL formats, world is required
|
||||
if not world:
|
||||
raise ValueError("World object is None. Cannot export ontology.")
|
||||
|
||||
# Note: Owlready2 has issues with turtle format export
|
||||
# We'll handle it specially by converting from rdfxml
|
||||
use_conversion = (format == "turtle")
|
||||
|
||||
try:
|
||||
# Get all ontologies in the world
|
||||
ontologies = list(world.ontologies.values())
|
||||
|
||||
if not ontologies:
|
||||
raise RuntimeError("No ontologies found in world")
|
||||
|
||||
# Find the ontology with classes (skip anonymous/empty ontologies)
|
||||
onto = None
|
||||
for ont in ontologies:
|
||||
classes_count = len(list(ont.classes()))
|
||||
logger.debug(f"Checking ontology {ont.base_iri}: {classes_count} classes")
|
||||
if classes_count > 0:
|
||||
onto = ont
|
||||
break
|
||||
|
||||
# If no ontology with classes found, use the last non-anonymous one
|
||||
if onto is None:
|
||||
for ont in reversed(ontologies):
|
||||
if ont.base_iri != "http://anonymous/":
|
||||
onto = ont
|
||||
break
|
||||
|
||||
# If still no ontology, use the first one
|
||||
if onto is None:
|
||||
onto = ontologies[0]
|
||||
|
||||
# Log ontology contents for debugging
|
||||
logger.info(f"Ontology IRI: {onto.base_iri}")
|
||||
logger.info(f"Ontology contains {len(list(onto.classes()))} classes")
|
||||
|
||||
# List all classes in the ontology
|
||||
all_classes = list(onto.classes())
|
||||
for cls in all_classes:
|
||||
logger.info(f"Class in ontology: {cls.name} (IRI: {cls.iri})")
|
||||
if hasattr(cls, 'label'):
|
||||
logger.debug(f" Labels: {cls.label}")
|
||||
if hasattr(cls, 'comment'):
|
||||
logger.debug(f" Comments: {cls.comment}")
|
||||
|
||||
if len(all_classes) == 0:
|
||||
logger.warning("No classes found in ontology! This may indicate a problem with class creation.")
|
||||
|
||||
if output_path:
|
||||
# Save to file
|
||||
export_format = "rdfxml" if use_conversion else format
|
||||
logger.info(f"Exporting ontology to {output_path} in {export_format} format")
|
||||
onto.save(file=output_path, format=export_format)
|
||||
|
||||
# Read back the file content to return
|
||||
with open(output_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
# Convert to turtle if needed
|
||||
if use_conversion:
|
||||
content = self._convert_to_turtle(content)
|
||||
|
||||
logger.info(f"Successfully exported ontology to {output_path}")
|
||||
|
||||
# Format the content for better readability
|
||||
content = self._format_owl_content(content, format)
|
||||
|
||||
return content
|
||||
else:
|
||||
# Export to string (save to temporary location and read)
|
||||
import tempfile
|
||||
import os
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.owl', delete=False) as tmp:
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
export_format = "rdfxml" if use_conversion else format
|
||||
onto.save(file=tmp_path, format=export_format)
|
||||
|
||||
with open(tmp_path, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
|
||||
# Convert to turtle if needed
|
||||
if use_conversion:
|
||||
content = self._convert_to_turtle(content)
|
||||
|
||||
# Format the content for better readability
|
||||
content = self._format_owl_content(content, format)
|
||||
|
||||
return content
|
||||
|
||||
finally:
|
||||
# Clean up temporary file
|
||||
if os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to export ontology: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
def _export_to_json(self, classes: List) -> str:
|
||||
"""Export ontology classes to simplified JSON format.
|
||||
|
||||
This format is more compact and easier to parse than OWL XML.
|
||||
|
||||
Args:
|
||||
classes: List of OntologyClass objects
|
||||
|
||||
Returns:
|
||||
JSON string representation (compact format)
|
||||
"""
|
||||
import json
|
||||
|
||||
result = {
|
||||
"ontology": {
|
||||
"namespace": self.base_namespace,
|
||||
"classes": []
|
||||
}
|
||||
}
|
||||
|
||||
for cls in classes:
|
||||
class_data = {
|
||||
"name": cls.name,
|
||||
"name_chinese": cls.name_chinese,
|
||||
"description": cls.description,
|
||||
"entity_type": cls.entity_type,
|
||||
"domain": cls.domain,
|
||||
"parent_class": cls.parent_class,
|
||||
"examples": cls.examples if hasattr(cls, 'examples') else []
|
||||
}
|
||||
result["ontology"]["classes"].append(class_data)
|
||||
|
||||
# 使用紧凑格式:无缩进,使用分隔符减少空格
|
||||
return json.dumps(result, ensure_ascii=False, separators=(',', ':'))
|
||||
|
||||
def _convert_to_turtle(self, rdfxml_content: str) -> str:
|
||||
"""Convert RDF/XML content to Turtle format using rdflib.
|
||||
|
||||
Args:
|
||||
rdfxml_content: RDF/XML format content
|
||||
|
||||
Returns:
|
||||
Turtle format content
|
||||
"""
|
||||
try:
|
||||
from rdflib import Graph
|
||||
|
||||
# Parse RDF/XML
|
||||
g = Graph()
|
||||
g.parse(data=rdfxml_content, format="xml")
|
||||
|
||||
# Serialize to Turtle
|
||||
turtle_content = g.serialize(format="turtle")
|
||||
|
||||
# Handle bytes vs string
|
||||
if isinstance(turtle_content, bytes):
|
||||
turtle_content = turtle_content.decode('utf-8')
|
||||
|
||||
return turtle_content
|
||||
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"rdflib is not installed. Cannot convert to Turtle format. "
|
||||
"Install with: pip install rdflib"
|
||||
)
|
||||
return rdfxml_content
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to convert to Turtle format: {e}")
|
||||
return rdfxml_content
|
||||
|
||||
def _format_owl_content(self, content: str, format: str) -> str:
|
||||
"""Format OWL content for better readability.
|
||||
|
||||
Args:
|
||||
content: Raw OWL content string
|
||||
format: Format type (rdfxml, turtle, ntriples)
|
||||
|
||||
Returns:
|
||||
Formatted OWL content string
|
||||
"""
|
||||
if format == "rdfxml":
|
||||
# Format XML with proper indentation
|
||||
try:
|
||||
import xml.dom.minidom as minidom
|
||||
dom = minidom.parseString(content)
|
||||
# Pretty print with 2-space indentation
|
||||
formatted = dom.toprettyxml(indent=" ", encoding="utf-8").decode("utf-8")
|
||||
|
||||
# Remove extra blank lines
|
||||
lines = []
|
||||
prev_blank = False
|
||||
for line in formatted.split('\n'):
|
||||
is_blank = not line.strip()
|
||||
if not (is_blank and prev_blank): # Skip consecutive blank lines
|
||||
lines.append(line)
|
||||
prev_blank = is_blank
|
||||
|
||||
formatted = '\n'.join(lines)
|
||||
|
||||
return formatted
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to format XML content: {e}")
|
||||
return content
|
||||
|
||||
elif format == "turtle":
|
||||
# Turtle format is already relatively readable
|
||||
# Just ensure consistent line endings and not empty
|
||||
if not content or content.strip() == "":
|
||||
logger.warning("Turtle content is empty, this may indicate an export issue")
|
||||
return content.strip() + '\n' if content.strip() else content
|
||||
|
||||
elif format == "ntriples":
|
||||
# N-Triples format is line-based, ensure proper line endings
|
||||
return content.strip() + '\n' if content.strip() else content
|
||||
|
||||
return content
|
||||
|
||||
def validate_with_protege_compatibility(
|
||||
self,
|
||||
classes: List[OntologyClass]
|
||||
) -> Tuple[bool, List[str]]:
|
||||
"""Validate that ontology classes are compatible with Protégé editor.
|
||||
|
||||
Protégé compatibility checks:
|
||||
- Class names are valid OWL identifiers
|
||||
- No special characters that Protégé cannot handle
|
||||
- Namespace is properly formatted
|
||||
- Labels and comments are properly encoded
|
||||
|
||||
Args:
|
||||
classes: List of OntologyClass objects to validate
|
||||
|
||||
Returns:
|
||||
Tuple of (is_compatible, warnings):
|
||||
- is_compatible: True if compatible with Protégé, False otherwise
|
||||
- warnings: List of compatibility warning messages
|
||||
|
||||
Examples:
|
||||
>>> validator = OWLValidator()
|
||||
>>> classes = [OntologyClass(name="Patient", description="A patient", entity_type="Person", domain="Healthcare")]
|
||||
>>> is_compatible, warnings = validator.validate_with_protege_compatibility(classes)
|
||||
>>> is_compatible
|
||||
True
|
||||
"""
|
||||
warnings = []
|
||||
|
||||
# Check namespace format
|
||||
if not self.base_namespace.startswith(('http://', 'https://')):
|
||||
warnings.append(
|
||||
f"Namespace '{self.base_namespace}' should start with http:// or https:// "
|
||||
"for Protégé compatibility"
|
||||
)
|
||||
|
||||
if not self.base_namespace.endswith(('#', '/')):
|
||||
warnings.append(
|
||||
f"Namespace '{self.base_namespace}' should end with # or / "
|
||||
"for Protégé compatibility"
|
||||
)
|
||||
|
||||
# Check each class
|
||||
for ontology_class in classes:
|
||||
# Check for special characters that might cause issues
|
||||
if any(char in ontology_class.name for char in ['<', '>', '"', '{', '}', '|', '^', '`']):
|
||||
warnings.append(
|
||||
f"Class name '{ontology_class.name}' contains special characters "
|
||||
"that may cause issues in Protégé"
|
||||
)
|
||||
|
||||
# Check description length (Protégé can handle long descriptions but may display poorly)
|
||||
if ontology_class.description and len(ontology_class.description) > 1000:
|
||||
warnings.append(
|
||||
f"Class '{ontology_class.name}' has a very long description ({len(ontology_class.description)} chars) "
|
||||
"which may display poorly in Protégé"
|
||||
)
|
||||
|
||||
# Check for non-ASCII characters (Protégé supports them but encoding issues may occur)
|
||||
if not ontology_class.name.isascii():
|
||||
warnings.append(
|
||||
f"Class name '{ontology_class.name}' contains non-ASCII characters "
|
||||
"which may cause encoding issues in some Protégé versions"
|
||||
)
|
||||
|
||||
# If no warnings, it's compatible
|
||||
is_compatible = len(warnings) == 0
|
||||
|
||||
return is_compatible, warnings
|
||||
|
||||
def parse_owl_content(
|
||||
self,
|
||||
owl_content: str,
|
||||
format: str = "rdfxml"
|
||||
) -> List[dict]:
|
||||
"""从 OWL 内容解析出本体类型
|
||||
|
||||
支持解析 RDF/XML、Turtle 和 JSON 格式的 OWL 文件,
|
||||
提取其中定义的 owl:Class 及其 rdfs:label 和 rdfs:comment。
|
||||
|
||||
Args:
|
||||
owl_content: OWL 文件内容字符串
|
||||
format: 文件格式,支持 "rdfxml"、"turtle"、"json"
|
||||
|
||||
Returns:
|
||||
解析出的类型列表,每个元素包含:
|
||||
- name: 类型名称(英文标识符)
|
||||
- name_chinese: 中文名称(如果有)
|
||||
- description: 类型描述
|
||||
- parent_class: 父类名称
|
||||
|
||||
Raises:
|
||||
ValueError: 如果格式不支持或解析失败
|
||||
|
||||
Examples:
|
||||
>>> validator = OWLValidator()
|
||||
>>> classes = validator.parse_owl_content(owl_xml, format="rdfxml")
|
||||
>>> for cls in classes:
|
||||
... print(cls["name"], cls["description"])
|
||||
"""
|
||||
valid_formats = ["rdfxml", "turtle", "json"]
|
||||
if format not in valid_formats:
|
||||
raise ValueError(
|
||||
f"Unsupported format '{format}'. Must be one of: {', '.join(valid_formats)}"
|
||||
)
|
||||
|
||||
# JSON 格式单独处理
|
||||
if format == "json":
|
||||
return self._parse_json_owl(owl_content)
|
||||
|
||||
# 使用 rdflib 解析 RDF/XML 或 Turtle
|
||||
try:
|
||||
from rdflib import Graph, RDF, RDFS, OWL, Namespace
|
||||
|
||||
g = Graph()
|
||||
rdf_format = "xml" if format == "rdfxml" else "turtle"
|
||||
g.parse(data=owl_content, format=rdf_format)
|
||||
|
||||
classes = []
|
||||
|
||||
# 查找所有 owl:Class
|
||||
for cls_uri in g.subjects(RDF.type, OWL.Class):
|
||||
cls_str = str(cls_uri)
|
||||
|
||||
# 跳过空节点和 OWL 内置类
|
||||
if cls_str.startswith("http://www.w3.org/") or "/.well-known/" in cls_str:
|
||||
continue
|
||||
|
||||
# 提取类名(从 URI 中获取本地名称)
|
||||
if '#' in cls_str:
|
||||
name = cls_str.split('#')[-1]
|
||||
else:
|
||||
name = cls_str.split('/')[-1]
|
||||
|
||||
# 跳过空名称
|
||||
if not name or name == "Thing":
|
||||
continue
|
||||
|
||||
# 获取 rdfs:label(可能有多个,包括中英文)
|
||||
labels = list(g.objects(cls_uri, RDFS.label))
|
||||
name_chinese = None
|
||||
label_str = name # 默认使用 URI 中的名称
|
||||
|
||||
for label in labels:
|
||||
label_text = str(label)
|
||||
# 检查是否包含中文
|
||||
if any('\u4e00' <= char <= '\u9fff' for char in label_text):
|
||||
name_chinese = label_text
|
||||
else:
|
||||
label_str = label_text
|
||||
|
||||
# 获取 rdfs:comment(描述)
|
||||
comments = list(g.objects(cls_uri, RDFS.comment))
|
||||
description = str(comments[0]) if comments else None
|
||||
|
||||
# 获取父类(rdfs:subClassOf)
|
||||
parent_class = None
|
||||
for parent_uri in g.objects(cls_uri, RDFS.subClassOf):
|
||||
parent_str = str(parent_uri)
|
||||
# 跳过 owl:Thing
|
||||
if parent_str == str(OWL.Thing) or parent_str.endswith("#Thing"):
|
||||
continue
|
||||
# 提取父类名称
|
||||
if '#' in parent_str:
|
||||
parent_class = parent_str.split('#')[-1]
|
||||
else:
|
||||
parent_class = parent_str.split('/')[-1]
|
||||
break # 只取第一个非 Thing 的父类
|
||||
|
||||
classes.append({
|
||||
"name": name,
|
||||
"name_chinese": name_chinese,
|
||||
"description": description,
|
||||
"parent_class": parent_class
|
||||
})
|
||||
|
||||
logger.info(f"Parsed {len(classes)} classes from OWL content (format: {format})")
|
||||
return classes
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to parse OWL(文档格式不正确) content: {str(e)}"
|
||||
logger.error(error_msg, exc_info=True)
|
||||
raise ValueError(error_msg) from e
|
||||
|
||||
def _parse_json_owl(self, json_content: str) -> List[dict]:
|
||||
"""解析 JSON 格式的 OWL 内容
|
||||
|
||||
JSON 格式是简化的本体表示,由 export_to_owl 的 json 格式导出。
|
||||
|
||||
Args:
|
||||
json_content: JSON 格式的 OWL 内容
|
||||
|
||||
Returns:
|
||||
解析出的类型列表
|
||||
"""
|
||||
import json
|
||||
|
||||
try:
|
||||
data = json.loads(json_content)
|
||||
|
||||
# 检查是否是我们导出的 JSON 格式
|
||||
if "ontology" in data and "classes" in data["ontology"]:
|
||||
raw_classes = data["ontology"]["classes"]
|
||||
elif "classes" in data:
|
||||
raw_classes = data["classes"]
|
||||
else:
|
||||
raise ValueError("Invalid JSON format: missing 'classes' field")
|
||||
|
||||
classes = []
|
||||
for cls in raw_classes:
|
||||
classes.append({
|
||||
"name": cls.get("name", ""),
|
||||
"name_chinese": cls.get("name_chinese"),
|
||||
"description": cls.get("description"),
|
||||
"parent_class": cls.get("parent_class")
|
||||
})
|
||||
|
||||
logger.info(f"Parsed {len(classes)} classes from JSON content")
|
||||
return classes
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"Invalid JSON content: {str(e)}") from e
|
||||
@@ -81,6 +81,8 @@ class RedBearModelFactory:
|
||||
# api_key 格式: "access_key_id:secret_access_key" 或只是 access_key_id
|
||||
# region 从 base_url 或 extra_params 获取
|
||||
from botocore.config import Config as BotoConfig
|
||||
from app.core.models.bedrock_model_mapper import normalize_bedrock_model_id
|
||||
|
||||
max_pool_connections = int(os.getenv("BEDROCK_MAX_POOL_CONNECTIONS", "50"))
|
||||
max_retries = int(os.getenv("BEDROCK_MAX_RETRIES", "2"))
|
||||
# Configure with increased connection pool
|
||||
@@ -89,8 +91,11 @@ class RedBearModelFactory:
|
||||
retries={'max_attempts': max_retries, 'mode': 'adaptive'}
|
||||
)
|
||||
|
||||
# 标准化模型 ID(自动转换简化名称为完整 Bedrock Model ID)
|
||||
model_id = normalize_bedrock_model_id(config.model_name)
|
||||
|
||||
params = {
|
||||
"model_id": config.model_name,
|
||||
"model_id": model_id,
|
||||
"config": boto_config,
|
||||
**config.extra_params
|
||||
}
|
||||
|
||||
188
api/app/core/models/bedrock_model_mapper.py
Normal file
188
api/app/core/models/bedrock_model_mapper.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""
|
||||
AWS Bedrock 模型名称映射器
|
||||
|
||||
将简化的模型名称自动转换为正确的 Bedrock Model ID
|
||||
"""
|
||||
from typing import Optional
|
||||
from app.core.logging_config import get_business_logger
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
# Bedrock 模型名称映射表
|
||||
BEDROCK_MODEL_MAPPING = {
|
||||
# Claude 3.5 系列
|
||||
"claude-3.5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"claude-3-5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"claude-sonnet-3.5": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
"claude-sonnet-3-5": "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
|
||||
# Claude 3 系列
|
||||
"claude-3-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"claude-3-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"claude-3-opus": "anthropic.claude-3-opus-20240229-v1:0",
|
||||
"claude-sonnet": "anthropic.claude-3-sonnet-20240229-v1:0",
|
||||
"claude-haiku": "anthropic.claude-3-haiku-20240307-v1:0",
|
||||
"claude-opus": "anthropic.claude-3-opus-20240229-v1:0",
|
||||
|
||||
# Claude 2 系列
|
||||
"claude-2": "anthropic.claude-v2",
|
||||
"claude-2.1": "anthropic.claude-v2:1",
|
||||
"claude-instant": "anthropic.claude-instant-v1",
|
||||
|
||||
# Amazon Titan 系列
|
||||
"titan-text-express": "amazon.titan-text-express-v1",
|
||||
"titan-text-lite": "amazon.titan-text-lite-v1",
|
||||
"titan-embed-text": "amazon.titan-embed-text-v1",
|
||||
"titan-embed-image": "amazon.titan-embed-image-v1",
|
||||
|
||||
# Meta Llama 系列
|
||||
"llama3-70b": "meta.llama3-70b-instruct-v1:0",
|
||||
"llama3-8b": "meta.llama3-8b-instruct-v1:0",
|
||||
"llama2-70b": "meta.llama2-70b-chat-v1",
|
||||
"llama2-13b": "meta.llama2-13b-chat-v1",
|
||||
|
||||
# Mistral 系列
|
||||
"mistral-7b": "mistral.mistral-7b-instruct-v0:2",
|
||||
"mixtral-8x7b": "mistral.mixtral-8x7b-instruct-v0:1",
|
||||
"mistral-large": "mistral.mistral-large-2402-v1:0",
|
||||
|
||||
# 常见错误格式的映射
|
||||
"claude-sonnet-4-5": "anthropic.claude-3-5-sonnet-20240620-v1:0", # 常见错误
|
||||
"claude-4-5-sonnet": "anthropic.claude-3-5-sonnet-20240620-v1:0", # 常见错误
|
||||
"claude-sonnet-4.5": "anthropic.claude-3-5-sonnet-20240620-v1:0", # 常见错误
|
||||
}
|
||||
|
||||
|
||||
def normalize_bedrock_model_id(model_name: str, region: Optional[str] = None) -> str:
|
||||
"""
|
||||
标准化 Bedrock 模型 ID
|
||||
|
||||
将简化的模型名称转换为正确的 Bedrock Model ID 格式
|
||||
|
||||
Args:
|
||||
model_name: 模型名称(可能是简化格式或完整格式)
|
||||
region: AWS 区域(可选,如 "us", "eu", "apac")
|
||||
|
||||
Returns:
|
||||
str: 标准化的 Bedrock Model ID
|
||||
|
||||
Examples:
|
||||
>>> normalize_bedrock_model_id("claude-sonnet-4-5")
|
||||
'anthropic.claude-3-5-sonnet-20240620-v1:0'
|
||||
|
||||
>>> normalize_bedrock_model_id("claude-3.5-sonnet", region="eu")
|
||||
'eu.anthropic.claude-3-5-sonnet-20240620-v1:0'
|
||||
|
||||
>>> normalize_bedrock_model_id("anthropic.claude-3-5-sonnet-20240620-v1:0")
|
||||
'anthropic.claude-3-5-sonnet-20240620-v1:0'
|
||||
"""
|
||||
# 如果已经是正确的格式(包含 provider),直接返回
|
||||
if "." in model_name and not model_name.startswith(("us.", "eu.", "apac.", "sa.", "amer.", "global.", "us-gov.")):
|
||||
# 检查是否是有效的 provider
|
||||
provider = model_name.split(".", 1)[0]
|
||||
valid_providers = ["anthropic", "amazon", "meta", "mistral", "deepseek", "openai", "ai21", "cohere", "stability"]
|
||||
if provider in valid_providers:
|
||||
logger.debug(f"Model ID 已经是正确格式: {model_name}")
|
||||
return model_name
|
||||
|
||||
# 移除区域前缀(如果存在)
|
||||
original_model_name = model_name
|
||||
region_prefix = None
|
||||
if model_name.startswith(("us.", "eu.", "apac.", "sa.", "amer.", "global.", "us-gov.")):
|
||||
parts = model_name.split(".", 1)
|
||||
region_prefix = parts[0]
|
||||
model_name = parts[1] if len(parts) > 1 else model_name
|
||||
|
||||
# 转换为小写进行匹配
|
||||
model_name_lower = model_name.lower()
|
||||
|
||||
# 尝试从映射表中查找
|
||||
if model_name_lower in BEDROCK_MODEL_MAPPING:
|
||||
mapped_id = BEDROCK_MODEL_MAPPING[model_name_lower]
|
||||
logger.info(f"映射模型名称: {original_model_name} -> {mapped_id}")
|
||||
|
||||
# 如果指定了区域或原始名称包含区域前缀,添加区域前缀
|
||||
if region:
|
||||
mapped_id = f"{region}.{mapped_id}"
|
||||
elif region_prefix:
|
||||
mapped_id = f"{region_prefix}.{mapped_id}"
|
||||
|
||||
return mapped_id
|
||||
|
||||
# 如果没有找到映射,返回原始名称并记录警告
|
||||
logger.warning(
|
||||
f"未找到模型名称映射: {original_model_name}。"
|
||||
f"请确保使用正确的 Bedrock Model ID 格式,如 'anthropic.claude-3-5-sonnet-20240620-v1:0'"
|
||||
)
|
||||
return original_model_name
|
||||
|
||||
|
||||
def is_bedrock_model_id(model_name: str) -> bool:
|
||||
"""
|
||||
检查是否是 Bedrock Model ID 格式
|
||||
|
||||
Args:
|
||||
model_name: 模型名称
|
||||
|
||||
Returns:
|
||||
bool: 是否是 Bedrock Model ID 格式
|
||||
"""
|
||||
# 移除区域前缀
|
||||
if model_name.startswith(("us.", "eu.", "apac.", "sa.", "amer.", "global.", "us-gov.")):
|
||||
model_name = model_name.split(".", 1)[1]
|
||||
|
||||
# 检查是否包含 provider
|
||||
if "." not in model_name:
|
||||
return False
|
||||
|
||||
provider = model_name.split(".", 1)[0]
|
||||
valid_providers = ["anthropic", "amazon", "meta", "mistral", "deepseek", "openai", "ai21", "cohere", "stability"]
|
||||
return provider in valid_providers
|
||||
|
||||
|
||||
def get_provider_from_model_id(model_id: str) -> str:
|
||||
"""
|
||||
从 Bedrock Model ID 中提取 provider
|
||||
|
||||
Args:
|
||||
model_id: Bedrock Model ID
|
||||
|
||||
Returns:
|
||||
str: Provider 名称
|
||||
|
||||
Examples:
|
||||
>>> get_provider_from_model_id("anthropic.claude-3-5-sonnet-20240620-v1:0")
|
||||
'anthropic'
|
||||
|
||||
>>> get_provider_from_model_id("eu.anthropic.claude-3-5-sonnet-20240620-v1:0")
|
||||
'anthropic'
|
||||
"""
|
||||
# 移除区域前缀
|
||||
if model_id.startswith(("us.", "eu.", "apac.", "sa.", "amer.", "global.", "us-gov.")):
|
||||
parts = model_id.split(".", 2)
|
||||
return parts[1] if len(parts) > 1 else model_id.split(".", 1)[0]
|
||||
|
||||
return model_id.split(".", 1)[0]
|
||||
|
||||
|
||||
# 添加更多映射的辅助函数
|
||||
def add_model_mapping(short_name: str, full_model_id: str) -> None:
|
||||
"""
|
||||
添加自定义模型名称映射
|
||||
|
||||
Args:
|
||||
short_name: 简化的模型名称
|
||||
full_model_id: 完整的 Bedrock Model ID
|
||||
"""
|
||||
BEDROCK_MODEL_MAPPING[short_name.lower()] = full_model_id
|
||||
logger.info(f"添加模型映射: {short_name} -> {full_model_id}")
|
||||
|
||||
|
||||
def get_all_mappings() -> dict:
|
||||
"""
|
||||
获取所有模型名称映射
|
||||
|
||||
Returns:
|
||||
dict: 模型名称映射字典
|
||||
"""
|
||||
return BEDROCK_MODEL_MAPPING.copy()
|
||||
@@ -1,5 +1,4 @@
|
||||
provider: bedrock
|
||||
enabled: true
|
||||
models:
|
||||
- name: ai21
|
||||
type: llm
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
provider: dashscope
|
||||
enabled: true
|
||||
models:
|
||||
- name: deepseek-r1-distill-qwen-14b
|
||||
type: llm
|
||||
@@ -285,7 +284,7 @@ models:
|
||||
- stream-tool-call
|
||||
logo: dashscope
|
||||
- name: qwen-vl-max
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen-vl-max多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
@@ -298,7 +297,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen-vl-plus-0809
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen-vl-plus-0809多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,已废弃
|
||||
is_deprecated: true
|
||||
@@ -311,7 +310,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen-vl-plus-2025-01-02
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen-vl-plus-2025-01-02多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
@@ -324,7 +323,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen-vl-plus-2025-01-25
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen-vl-plus-2025-01-25多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
@@ -337,7 +336,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen-vl-plus-latest
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen-vl-plus-latest多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
@@ -350,7 +349,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen-vl-plus
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen-vl-plus多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃
|
||||
is_deprecated: false
|
||||
@@ -616,7 +615,7 @@ models:
|
||||
- audio
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-235b-a22b-instruct
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen3-vl-235b-a22b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
@@ -631,7 +630,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-235b-a22b-thinking
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen3-vl-235b-a22b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
@@ -646,7 +645,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-30b-a3b-instruct
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen3-vl-30b-a3b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
@@ -661,7 +660,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-30b-a3b-thinking
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen3-vl-30b-a3b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
@@ -676,7 +675,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-flash
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen3-vl-flash多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
@@ -691,7 +690,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-plus-2025-09-23
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen3-vl-plus-2025-09-23多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
@@ -704,7 +703,7 @@ models:
|
||||
- video
|
||||
logo: dashscope
|
||||
- name: qwen3-vl-plus
|
||||
type: llm
|
||||
type: chat
|
||||
provider: dashscope
|
||||
description: qwen3-vl-plus多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式
|
||||
is_deprecated: false
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""模型配置加载器 - 用于将预定义模型批量导入到数据库"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.models_model import ModelBase, ModelProvider
|
||||
|
||||
|
||||
@@ -19,31 +19,9 @@ def _load_yaml_config(provider: ModelProvider) -> list[dict]:
|
||||
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
# 检查是否需要加载(默认为 true)
|
||||
if not data.get('enabled', True):
|
||||
return []
|
||||
|
||||
return data.get('models', [])
|
||||
|
||||
|
||||
def _disable_yaml_config(provider: ModelProvider) -> None:
|
||||
"""将YAML文件的enabled标志设置为false"""
|
||||
config_dir = Path(__file__).parent
|
||||
config_file = config_dir / f"{provider.value}_models.yaml"
|
||||
|
||||
if not config_file.exists():
|
||||
return
|
||||
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
data = yaml.safe_load(f)
|
||||
|
||||
data['enabled'] = False
|
||||
|
||||
with open(config_file, 'w', encoding='utf-8') as f:
|
||||
yaml.dump(data, f, allow_unicode=True, sort_keys=False)
|
||||
|
||||
|
||||
def load_models(db: Session, providers: list[str] = None, silent: bool = False) -> dict:
|
||||
"""
|
||||
加载模型配置到数据库
|
||||
@@ -75,8 +53,7 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
|
||||
|
||||
if not silent:
|
||||
print(f"\n正在加载 {provider.value} 的 {len(models)} 个模型...")
|
||||
|
||||
# provider_success = 0
|
||||
|
||||
for model_data in models:
|
||||
try:
|
||||
# 检查模型是否已存在
|
||||
@@ -93,7 +70,6 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
|
||||
if not silent:
|
||||
print(f"更新成功: {model_data['name']}")
|
||||
result["success"] += 1
|
||||
# provider_success += 1
|
||||
else:
|
||||
# 创建新模型
|
||||
model = ModelBase(**model_data)
|
||||
@@ -102,17 +78,12 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False)
|
||||
if not silent:
|
||||
print(f"添加成功: {model_data['name']}")
|
||||
result["success"] += 1
|
||||
# provider_success += 1
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
if not silent:
|
||||
print(f"添加失败: {model_data['name']} - {str(e)}")
|
||||
result["failed"] += 1
|
||||
|
||||
# 如果该供应商的模型全部加载成功,将enabled设置为false
|
||||
# if provider_success == len(models):
|
||||
_disable_yaml_config(provider)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
provider: openai
|
||||
enabled: true
|
||||
models:
|
||||
- name: chatgpt-4o-latest
|
||||
type: llm
|
||||
|
||||
@@ -670,7 +670,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
||||
with open(filename, "rb") as f:
|
||||
binary = f.read()
|
||||
excel_parser = ExcelParser()
|
||||
if parser_config.get("html4excel"):
|
||||
if parser_config.get("html4excel") and parser_config.get("html4excel").lower() == "true":
|
||||
sections = [(_, "") for _ in excel_parser.html(binary, 12) if _]
|
||||
parser_config["chunk_token_num"] = 0
|
||||
else:
|
||||
|
||||
89
api/app/core/rag/crawler/__main__.py
Normal file
89
api/app/core/rag/crawler/__main__.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Command-line interface for web crawler."""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import sys
|
||||
from app.core.rag.crawler.web_crawler import WebCrawler
|
||||
|
||||
|
||||
def setup_logging(verbose: bool = False):
|
||||
"""Set up logging configuration."""
|
||||
level = logging.DEBUG if verbose else logging.INFO
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(sys.stdout)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def main(entry_url: str,
|
||||
max_pages: int = 200,
|
||||
delay_seconds: float = 1.0,
|
||||
timeout_seconds: int = 10,
|
||||
user_agent: str = "KnowledgeBaseCrawler/1.0"):
|
||||
"""Main entry point for the crawler."""
|
||||
# Create crawler
|
||||
crawler = WebCrawler(
|
||||
entry_url=entry_url,
|
||||
max_pages=max_pages,
|
||||
delay_seconds=delay_seconds,
|
||||
timeout_seconds=timeout_seconds,
|
||||
user_agent=user_agent
|
||||
)
|
||||
|
||||
# Crawl and collect documents
|
||||
documents = []
|
||||
try:
|
||||
for doc in crawler.crawl():
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"URL: {doc.url}")
|
||||
print(f"Title: {doc.title}")
|
||||
print(f"Content Length: {doc.content_length} characters")
|
||||
print(f"Word Count: {doc.metadata.get('word_count', 0)} words")
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
documents.append({
|
||||
'url': doc.url,
|
||||
'title': doc.title,
|
||||
'content': doc.content,
|
||||
'content_length': doc.content_length,
|
||||
'crawl_timestamp': doc.crawl_timestamp.isoformat(),
|
||||
'http_status': doc.http_status,
|
||||
'metadata': doc.metadata
|
||||
})
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nCrawl interrupted by user.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n\nError during crawl: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
# Get summary
|
||||
summary = crawler.get_summary()
|
||||
print(f"\n{'=' * 80}")
|
||||
print("CRAWL SUMMARY")
|
||||
print(f"{'=' * 80}")
|
||||
print(f"Total Pages Processed: {summary.total_pages_processed}")
|
||||
print(f"Total Errors: {summary.total_errors}")
|
||||
print(f"Total Skipped: {summary.total_skipped}")
|
||||
print(f"Total URLs Discovered: {summary.total_urls_discovered}")
|
||||
print(f"Duration: {summary.duration_seconds:.2f} seconds")
|
||||
print(f"documents: {documents}")
|
||||
|
||||
if summary.error_breakdown:
|
||||
print(f"\nError Breakdown:")
|
||||
for error_type, count in summary.error_breakdown.items():
|
||||
print(f" {error_type}: {count}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
entry_url = "https://www.xxx.com"
|
||||
max_pages = 20
|
||||
delay_seconds = 1.0
|
||||
timeout_seconds = 10
|
||||
user_agent = "KnowledgeBaseCrawler/1.0"
|
||||
|
||||
main(entry_url, max_pages, delay_seconds, timeout_seconds, user_agent)
|
||||
233
api/app/core/rag/crawler/content_extractor.py
Normal file
233
api/app/core/rag/crawler/content_extractor.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""Content extractor for web crawler."""
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
import re
|
||||
import logging
|
||||
|
||||
from app.core.rag.crawler.models import ExtractedContent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ContentExtractor:
|
||||
"""Extract clean, readable text from HTML pages."""
|
||||
|
||||
# Tags to remove completely
|
||||
REMOVE_TAGS = ['script', 'style', 'nav', 'header', 'footer', 'aside']
|
||||
|
||||
# Tags that typically contain main content
|
||||
MAIN_CONTENT_TAGS = ['article', 'main']
|
||||
|
||||
# Content extraction tags
|
||||
CONTENT_TAGS = ['p', 'div', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li', 'td', 'th', 'section']
|
||||
|
||||
def is_static_content(self, html: str) -> bool:
|
||||
"""
|
||||
Determine if the HTML represents static content.
|
||||
|
||||
Detects JavaScript-rendered content by checking for minimal body
|
||||
with heavy script tag presence.
|
||||
|
||||
Args:
|
||||
html: Raw HTML string
|
||||
|
||||
Returns:
|
||||
bool: True if static, False if JavaScript-rendered
|
||||
"""
|
||||
try:
|
||||
soup = BeautifulSoup(html, 'lxml')
|
||||
|
||||
# Count script tags
|
||||
script_tags = soup.find_all('script')
|
||||
script_count = len(script_tags)
|
||||
|
||||
# Get body content (excluding scripts and styles)
|
||||
body = soup.find('body')
|
||||
if not body:
|
||||
return False
|
||||
|
||||
# Remove scripts and styles temporarily for text check
|
||||
for tag in body.find_all(['script', 'style']):
|
||||
tag.decompose()
|
||||
|
||||
# Get text content
|
||||
text = body.get_text(strip=True)
|
||||
text_length = len(text)
|
||||
|
||||
# If there's very little text but many scripts, likely JS-rendered
|
||||
if script_count > 5 and text_length < 200:
|
||||
logger.warning("Detected JavaScript-rendered content (many scripts, little text)")
|
||||
return False
|
||||
|
||||
# If there's no meaningful text, likely JS-rendered
|
||||
if text_length < 50:
|
||||
logger.warning("Detected JavaScript-rendered content (minimal text)")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking if content is static: {e}")
|
||||
return True # Assume static on error
|
||||
|
||||
def extract(self, html: str, url: str) -> ExtractedContent:
|
||||
"""
|
||||
Extract clean text content from HTML.
|
||||
|
||||
Args:
|
||||
html: Raw HTML string
|
||||
url: Source URL (for context)
|
||||
|
||||
Returns:
|
||||
ExtractedContent: Contains title, text, metadata
|
||||
"""
|
||||
try:
|
||||
soup = BeautifulSoup(html, 'lxml')
|
||||
|
||||
# Check if content is static
|
||||
is_static = self.is_static_content(html)
|
||||
|
||||
# Extract title
|
||||
title = self._extract_title(soup)
|
||||
|
||||
# Remove unwanted tags
|
||||
for tag_name in self.REMOVE_TAGS:
|
||||
for tag in soup.find_all(tag_name):
|
||||
tag.decompose()
|
||||
|
||||
# Extract main content
|
||||
text = self._extract_main_content(soup)
|
||||
|
||||
# Normalize whitespace
|
||||
text = self._normalize_whitespace(text)
|
||||
|
||||
# Count words
|
||||
word_count = len(text.split())
|
||||
|
||||
logger.info(f"Extracted {word_count} words from {url}")
|
||||
|
||||
return ExtractedContent(
|
||||
title=title,
|
||||
text=text,
|
||||
is_static=is_static,
|
||||
word_count=word_count,
|
||||
metadata={'url': url}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting content from {url}: {e}")
|
||||
return ExtractedContent(
|
||||
title=url,
|
||||
text="",
|
||||
is_static=False,
|
||||
word_count=0,
|
||||
metadata={'url': url, 'error': str(e)}
|
||||
)
|
||||
|
||||
def _extract_title(self, soup: BeautifulSoup) -> str:
|
||||
"""
|
||||
Extract title from HTML.
|
||||
|
||||
Tries <title> tag first, then first <h1>.
|
||||
|
||||
Args:
|
||||
soup: BeautifulSoup object
|
||||
|
||||
Returns:
|
||||
str: Page title
|
||||
"""
|
||||
# Try <title> tag
|
||||
title_tag = soup.find('title')
|
||||
if title_tag and title_tag.string:
|
||||
return title_tag.string.strip()
|
||||
|
||||
# Try first <h1>
|
||||
h1_tag = soup.find('h1')
|
||||
if h1_tag:
|
||||
return h1_tag.get_text(strip=True)
|
||||
|
||||
# Default to empty string
|
||||
return ""
|
||||
|
||||
def _extract_main_content(self, soup: BeautifulSoup) -> str:
|
||||
"""
|
||||
Extract main content from HTML.
|
||||
|
||||
Prioritizes semantic HTML5 elements like <article> and <main>.
|
||||
|
||||
Args:
|
||||
soup: BeautifulSoup object
|
||||
|
||||
Returns:
|
||||
str: Extracted text content
|
||||
"""
|
||||
# Try to find main content area
|
||||
main_content = None
|
||||
|
||||
# Priority 1: <article> or <main> tags
|
||||
for tag_name in self.MAIN_CONTENT_TAGS:
|
||||
main_content = soup.find(tag_name)
|
||||
if main_content:
|
||||
logger.debug(f"Found main content in <{tag_name}> tag")
|
||||
break
|
||||
|
||||
# Priority 2: div with role="main"
|
||||
if not main_content:
|
||||
main_content = soup.find('div', role='main')
|
||||
if main_content:
|
||||
logger.debug("Found main content in div[role='main']")
|
||||
|
||||
# Priority 3: Common class/id patterns
|
||||
if not main_content:
|
||||
for pattern in ['content', 'main', 'article', 'post']:
|
||||
main_content = soup.find(['div', 'section'], class_=re.compile(pattern, re.I))
|
||||
if main_content:
|
||||
logger.debug(f"Found main content with class pattern '{pattern}'")
|
||||
break
|
||||
|
||||
main_content = soup.find(['div', 'section'], id=re.compile(pattern, re.I))
|
||||
if main_content:
|
||||
logger.debug(f"Found main content with id pattern '{pattern}'")
|
||||
break
|
||||
|
||||
# Fallback: use body
|
||||
if not main_content:
|
||||
main_content = soup.find('body')
|
||||
logger.debug("Using <body> as main content (no specific content area found)")
|
||||
|
||||
# Extract text from content tags
|
||||
if main_content:
|
||||
text_parts = []
|
||||
for tag in main_content.find_all(self.CONTENT_TAGS):
|
||||
text = tag.get_text(strip=True)
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
|
||||
return '\n'.join(text_parts)
|
||||
|
||||
return ""
|
||||
|
||||
def _normalize_whitespace(self, text: str) -> str:
|
||||
"""
|
||||
Normalize whitespace in text.
|
||||
|
||||
- Collapse multiple spaces to single space
|
||||
- Reduce excessive newlines to maximum 2
|
||||
- Strip leading/trailing whitespace
|
||||
|
||||
Args:
|
||||
text: Text to normalize
|
||||
|
||||
Returns:
|
||||
str: Normalized text
|
||||
"""
|
||||
# Collapse multiple spaces to single space
|
||||
text = re.sub(r' +', ' ', text)
|
||||
|
||||
# Reduce excessive newlines to maximum 2
|
||||
text = re.sub(r'\n{3,}', '\n\n', text)
|
||||
|
||||
# Strip leading/trailing whitespace
|
||||
text = text.strip()
|
||||
|
||||
return text
|
||||
302
api/app/core/rag/crawler/http_fetcher.py
Normal file
302
api/app/core/rag/crawler/http_fetcher.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""HTTP fetcher for web crawler."""
|
||||
|
||||
import requests
|
||||
import time
|
||||
import logging
|
||||
import re
|
||||
from typing import Optional, Dict
|
||||
|
||||
|
||||
from app.core.rag.crawler.models import FetchResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HTTPFetcher:
|
||||
"""Handle HTTP requests with retries, error handling, and response validation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timeout: int = 10,
|
||||
max_retries: int = 3,
|
||||
user_agent: str = "KnowledgeBaseCrawler/1.0"
|
||||
):
|
||||
"""
|
||||
Initialize HTTP fetcher.
|
||||
|
||||
Args:
|
||||
timeout: Request timeout in seconds
|
||||
max_retries: Maximum number of retry attempts
|
||||
user_agent: User-Agent header value
|
||||
"""
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
self.user_agent = user_agent
|
||||
|
||||
# Create session for connection pooling
|
||||
self.session = requests.Session()
|
||||
self.session.headers.update({
|
||||
'User-Agent': user_agent
|
||||
})
|
||||
|
||||
def fetch(self, url: str) -> FetchResult:
|
||||
"""
|
||||
Fetch a URL with retry logic and error handling.
|
||||
|
||||
Args:
|
||||
url: URL to fetch
|
||||
|
||||
Returns:
|
||||
FetchResult: Contains status_code, content, headers, error info
|
||||
"""
|
||||
last_error = None
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
# Calculate backoff delay for retries
|
||||
if attempt > 0:
|
||||
backoff_delay = 2 ** (attempt - 1) # 1s, 2s, 4s
|
||||
logger.info(f"Retry attempt {attempt + 1}/{self.max_retries} for {url} after {backoff_delay}s")
|
||||
time.sleep(backoff_delay)
|
||||
|
||||
# Make HTTP request
|
||||
response = self.session.get(
|
||||
url,
|
||||
timeout=self.timeout,
|
||||
allow_redirects=True
|
||||
)
|
||||
|
||||
# Handle different status codes
|
||||
if response.status_code == 429:
|
||||
# Too Many Requests - backoff and retry
|
||||
logger.warning(f"429 Too Many Requests for {url}, backing off")
|
||||
if attempt < self.max_retries - 1:
|
||||
continue
|
||||
|
||||
if response.status_code == 503:
|
||||
# Service Unavailable - pause and retry
|
||||
logger.warning(f"503 Service Unavailable for {url}")
|
||||
if attempt < self.max_retries - 1:
|
||||
time.sleep(5) # Longer pause for 503
|
||||
continue
|
||||
|
||||
# Success or client error (don't retry 4xx except 429)
|
||||
if 200 <= response.status_code < 300:
|
||||
logger.info(f"Successfully fetched {url} (status: {response.status_code})")
|
||||
|
||||
# Get correctly encoded content
|
||||
content = self._get_decoded_content(response)
|
||||
|
||||
return FetchResult(
|
||||
url=url,
|
||||
final_url=response.url,
|
||||
status_code=response.status_code,
|
||||
content=content,
|
||||
headers=dict(response.headers),
|
||||
error=None,
|
||||
success=True
|
||||
)
|
||||
elif response.status_code == 404:
|
||||
logger.info(f"404 Not Found: {url}")
|
||||
return FetchResult(
|
||||
url=url,
|
||||
final_url=response.url,
|
||||
status_code=response.status_code,
|
||||
content=None,
|
||||
headers=dict(response.headers),
|
||||
error="Not Found",
|
||||
success=False
|
||||
)
|
||||
elif 400 <= response.status_code < 500:
|
||||
logger.warning(f"Client error {response.status_code} for {url}")
|
||||
return FetchResult(
|
||||
url=url,
|
||||
final_url=response.url,
|
||||
status_code=response.status_code,
|
||||
content=None,
|
||||
headers=dict(response.headers),
|
||||
error=f"Client error: {response.status_code}",
|
||||
success=False
|
||||
)
|
||||
elif 500 <= response.status_code < 600:
|
||||
logger.error(f"Server error {response.status_code} for {url}")
|
||||
last_error = f"Server error: {response.status_code}"
|
||||
if attempt < self.max_retries - 1:
|
||||
continue
|
||||
return FetchResult(
|
||||
url=url,
|
||||
final_url=url,
|
||||
status_code=response.status_code,
|
||||
content=None,
|
||||
headers={},
|
||||
error=last_error,
|
||||
success=False
|
||||
)
|
||||
|
||||
except requests.exceptions.Timeout:
|
||||
last_error = "Request timeout"
|
||||
logger.warning(f"Timeout fetching {url} (attempt {attempt + 1}/{self.max_retries})")
|
||||
if attempt >= self.max_retries - 1:
|
||||
break
|
||||
continue
|
||||
|
||||
except requests.exceptions.SSLError as e:
|
||||
last_error = f"SSL/TLS error: {str(e)}"
|
||||
logger.error(f"SSL/TLS error for {url}: {e}")
|
||||
return FetchResult(
|
||||
url=url,
|
||||
final_url=url,
|
||||
status_code=0,
|
||||
content=None,
|
||||
headers={},
|
||||
error=last_error,
|
||||
success=False
|
||||
)
|
||||
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
last_error = f"Connection error: {str(e)}"
|
||||
logger.warning(f"Connection error for {url} (attempt {attempt + 1}/{self.max_retries}): {e}")
|
||||
if attempt >= self.max_retries - 1:
|
||||
break
|
||||
continue
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
last_error = f"Request error: {str(e)}"
|
||||
logger.error(f"Request error for {url}: {e}")
|
||||
if attempt >= self.max_retries - 1:
|
||||
break
|
||||
continue
|
||||
|
||||
# All retries exhausted
|
||||
logger.error(f"Failed to fetch {url} after {self.max_retries} attempts: {last_error}")
|
||||
return FetchResult(
|
||||
url=url,
|
||||
final_url=url,
|
||||
status_code=0,
|
||||
content=None,
|
||||
headers={},
|
||||
error=last_error or "Unknown error",
|
||||
success=False
|
||||
)
|
||||
|
||||
def _get_decoded_content(self, response) -> str:
|
||||
"""
|
||||
Get correctly decoded content from response.
|
||||
|
||||
Handles encoding detection and fallback strategies:
|
||||
1. Try encoding from HTML meta tags
|
||||
2. Try response.encoding (from Content-Type header or detected)
|
||||
3. Try UTF-8
|
||||
4. Try common encodings (GB2312, GBK for Chinese, etc.)
|
||||
5. Fall back to latin-1 with error replacement
|
||||
|
||||
Args:
|
||||
response: requests.Response object
|
||||
|
||||
Returns:
|
||||
str: Decoded content
|
||||
"""
|
||||
# Try to detect encoding from HTML meta tags
|
||||
meta_encoding = self._detect_encoding_from_meta(response.content)
|
||||
if meta_encoding:
|
||||
try:
|
||||
content = response.content.decode(meta_encoding)
|
||||
logger.info(f"Successfully decoded with meta tag encoding: {meta_encoding}")
|
||||
return content
|
||||
except (UnicodeDecodeError, LookupError) as e:
|
||||
logger.warning(f"Failed to decode with meta encoding {meta_encoding}: {e}")
|
||||
|
||||
# Try response.encoding (from Content-Type header or detected by requests)
|
||||
if response.encoding and response.encoding.lower() != 'iso-8859-1':
|
||||
# Note: requests defaults to ISO-8859-1 if no charset in Content-Type,
|
||||
# so we skip it here and try UTF-8 first
|
||||
try:
|
||||
return response.text
|
||||
except (UnicodeDecodeError, LookupError) as e:
|
||||
logger.warning(f"Failed to decode with detected encoding {response.encoding}: {e}")
|
||||
|
||||
# Try UTF-8 first (most common)
|
||||
try:
|
||||
return response.content.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
logger.debug("UTF-8 decoding failed, trying other encodings")
|
||||
|
||||
# Try common encodings for different languages
|
||||
encodings_to_try = [
|
||||
'gbk', # Chinese (Simplified)
|
||||
'gb2312', # Chinese (Simplified, older)
|
||||
'gb18030', # Chinese (Simplified, extended)
|
||||
'big5', # Chinese (Traditional)
|
||||
'shift_jis', # Japanese
|
||||
'euc-jp', # Japanese
|
||||
'euc-kr', # Korean
|
||||
'iso-8859-1', # Western European
|
||||
'windows-1252', # Windows Western European
|
||||
'windows-1251', # Cyrillic
|
||||
]
|
||||
|
||||
for encoding in encodings_to_try:
|
||||
try:
|
||||
content = response.content.decode(encoding)
|
||||
logger.info(f"Successfully decoded with {encoding}")
|
||||
return content
|
||||
except (UnicodeDecodeError, LookupError):
|
||||
continue
|
||||
|
||||
# Last resort: use latin-1 with error replacement
|
||||
logger.warning("All encoding attempts failed, using latin-1 with error replacement")
|
||||
return response.content.decode('latin-1', errors='replace')
|
||||
|
||||
def _detect_encoding_from_meta(self, content: bytes) -> Optional[str]:
|
||||
"""
|
||||
Detect encoding from HTML meta tags.
|
||||
|
||||
Looks for:
|
||||
- <meta charset="...">
|
||||
- <meta http-equiv="Content-Type" content="...; charset=...">
|
||||
|
||||
Args:
|
||||
content: Raw response content (bytes)
|
||||
|
||||
Returns:
|
||||
Optional[str]: Detected encoding or None
|
||||
"""
|
||||
try:
|
||||
# Only check first 2KB for performance
|
||||
head = content[:2048]
|
||||
|
||||
# Try to decode as ASCII/Latin-1 to search for meta tags
|
||||
try:
|
||||
head_str = head.decode('ascii', errors='ignore')
|
||||
except:
|
||||
head_str = head.decode('latin-1', errors='ignore')
|
||||
|
||||
# Look for <meta charset="...">
|
||||
charset_match = re.search(
|
||||
r'<meta[^>]+charset=["\']?([a-zA-Z0-9_-]+)',
|
||||
head_str,
|
||||
re.IGNORECASE
|
||||
)
|
||||
if charset_match:
|
||||
encoding = charset_match.group(1).lower()
|
||||
logger.debug(f"Found charset in meta tag: {encoding}")
|
||||
return encoding
|
||||
|
||||
# Look for <meta http-equiv="Content-Type" content="...; charset=...">
|
||||
content_type_match = re.search(
|
||||
r'<meta[^>]+http-equiv=["\']?content-type["\']?[^>]+content=["\']([^"\']+)',
|
||||
head_str,
|
||||
re.IGNORECASE
|
||||
)
|
||||
if content_type_match:
|
||||
content_value = content_type_match.group(1)
|
||||
charset_match = re.search(r'charset=([a-zA-Z0-9_-]+)', content_value, re.IGNORECASE)
|
||||
if charset_match:
|
||||
encoding = charset_match.group(1).lower()
|
||||
logger.debug(f"Found charset in Content-Type meta: {encoding}")
|
||||
return encoding
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error detecting encoding from meta tags: {e}")
|
||||
|
||||
return None
|
||||
52
api/app/core/rag/crawler/models.py
Normal file
52
api/app/core/rag/crawler/models.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Data models for web crawler."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
class CrawledDocument:
|
||||
"""Represents a successfully processed web page with extracted content."""
|
||||
url: str
|
||||
title: str
|
||||
content: str
|
||||
content_length: int
|
||||
crawl_timestamp: datetime
|
||||
http_status: int
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FetchResult:
|
||||
"""Represents the result of an HTTP fetch operation."""
|
||||
url: str
|
||||
final_url: str
|
||||
status_code: int
|
||||
content: Optional[str]
|
||||
headers: Dict[str, str]
|
||||
error: Optional[str]
|
||||
success: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedContent:
|
||||
"""Represents content extracted from HTML."""
|
||||
title: str
|
||||
text: str
|
||||
is_static: bool
|
||||
word_count: int
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CrawlSummary:
|
||||
"""Represents statistics from a completed crawl."""
|
||||
total_pages_processed: int
|
||||
total_errors: int
|
||||
total_skipped: int
|
||||
total_urls_discovered: int
|
||||
start_time: datetime
|
||||
end_time: datetime
|
||||
duration_seconds: float
|
||||
error_breakdown: Dict[str, int] = field(default_factory=dict)
|
||||
57
api/app/core/rag/crawler/rate_limiter.py
Normal file
57
api/app/core/rag/crawler/rate_limiter.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Rate limiter for web crawler."""
|
||||
|
||||
import time
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Enforce delays between requests to be polite to servers."""
|
||||
|
||||
def __init__(self, delay_seconds: float = 1.0):
|
||||
"""
|
||||
Initialize rate limiter.
|
||||
|
||||
Args:
|
||||
delay_seconds: Minimum delay between requests
|
||||
"""
|
||||
self.delay_seconds = delay_seconds
|
||||
self.last_request_time = 0.0
|
||||
self.max_delay = 60.0 # Cap maximum delay at 60 seconds
|
||||
|
||||
def wait(self):
|
||||
"""
|
||||
Block until enough time has passed since last request.
|
||||
Respects the configured delay.
|
||||
"""
|
||||
current_time = time.time()
|
||||
elapsed = current_time - self.last_request_time
|
||||
|
||||
if elapsed < self.delay_seconds:
|
||||
sleep_time = self.delay_seconds - elapsed
|
||||
logger.debug(f"Rate limiting: sleeping for {sleep_time:.2f} seconds")
|
||||
time.sleep(sleep_time)
|
||||
|
||||
self.last_request_time = time.time()
|
||||
|
||||
def set_delay(self, delay_seconds: float):
|
||||
"""
|
||||
Update the delay (useful for respecting Crawl-delay from robots.txt).
|
||||
|
||||
Args:
|
||||
delay_seconds: New delay in seconds
|
||||
"""
|
||||
self.delay_seconds = min(delay_seconds, self.max_delay)
|
||||
logger.info(f"Rate limiter delay updated to {self.delay_seconds} seconds")
|
||||
|
||||
def backoff(self, multiplier: float = 2.0):
|
||||
"""
|
||||
Increase delay exponentially for backoff scenarios (429, 503 responses).
|
||||
|
||||
Args:
|
||||
multiplier: Factor to multiply current delay by
|
||||
"""
|
||||
old_delay = self.delay_seconds
|
||||
self.delay_seconds = min(self.delay_seconds * multiplier, self.max_delay)
|
||||
logger.warning(f"Rate limiter backing off: {old_delay:.2f}s -> {self.delay_seconds:.2f}s")
|
||||
118
api/app/core/rag/crawler/robots_parser.py
Normal file
118
api/app/core/rag/crawler/robots_parser.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Robots.txt parser for web crawler."""
|
||||
|
||||
from urllib.robotparser import RobotFileParser
|
||||
from urllib.parse import urlparse, urljoin
|
||||
from typing import Optional
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RobotsParser:
|
||||
"""Parse and check robots.txt compliance for URLs."""
|
||||
|
||||
def __init__(self, user_agent: str, timeout: int = 10):
|
||||
"""
|
||||
Initialize robots.txt parser.
|
||||
|
||||
Args:
|
||||
user_agent: User agent string to check permissions for
|
||||
timeout: Timeout for fetching robots.txt
|
||||
"""
|
||||
self.user_agent = user_agent
|
||||
self.timeout = timeout
|
||||
self._parsers = {} # Cache parsers by domain
|
||||
|
||||
def _get_robots_url(self, url: str) -> str:
|
||||
"""
|
||||
Get the robots.txt URL for a given URL.
|
||||
|
||||
Args:
|
||||
url: URL to get robots.txt for
|
||||
|
||||
Returns:
|
||||
str: robots.txt URL
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
robots_url = f"{parsed.scheme}://{parsed.netloc}/robots.txt"
|
||||
return robots_url
|
||||
|
||||
def _get_parser(self, url: str) -> RobotFileParser:
|
||||
"""
|
||||
Get or create a RobotFileParser for the domain.
|
||||
|
||||
Args:
|
||||
url: URL to get parser for
|
||||
|
||||
Returns:
|
||||
RobotFileParser: Parser for the domain
|
||||
"""
|
||||
robots_url = self._get_robots_url(url)
|
||||
|
||||
# Return cached parser if available
|
||||
if robots_url in self._parsers:
|
||||
return self._parsers[robots_url]
|
||||
|
||||
# Create new parser
|
||||
parser = RobotFileParser()
|
||||
parser.set_url(robots_url)
|
||||
|
||||
try:
|
||||
# Fetch and parse robots.txt
|
||||
parser.read()
|
||||
logger.info(f"Successfully fetched robots.txt from {robots_url}")
|
||||
except Exception as e:
|
||||
# If robots.txt cannot be fetched, assume all URLs are allowed
|
||||
logger.warning(f"Could not fetch robots.txt from {robots_url}: {e}. Assuming all URLs allowed.")
|
||||
# Create a permissive parser
|
||||
parser = RobotFileParser()
|
||||
parser.parse([]) # Empty robots.txt allows everything
|
||||
|
||||
# Cache the parser
|
||||
self._parsers[robots_url] = parser
|
||||
return parser
|
||||
|
||||
def can_fetch(self, url: str) -> bool:
|
||||
"""
|
||||
Check if the given URL can be fetched according to robots.txt.
|
||||
|
||||
Args:
|
||||
url: URL to check
|
||||
|
||||
Returns:
|
||||
bool: True if allowed, False if disallowed
|
||||
"""
|
||||
try:
|
||||
parser = self._get_parser(url)
|
||||
allowed = parser.can_fetch(self.user_agent, url)
|
||||
|
||||
if not allowed:
|
||||
logger.info(f"URL disallowed by robots.txt: {url}")
|
||||
|
||||
return allowed
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking robots.txt for {url}: {e}")
|
||||
# On error, assume allowed
|
||||
return True
|
||||
|
||||
def get_crawl_delay(self, url: str) -> Optional[float]:
|
||||
"""
|
||||
Get the Crawl-delay directive from robots.txt if present.
|
||||
|
||||
Args:
|
||||
url: URL to get crawl delay for
|
||||
|
||||
Returns:
|
||||
Optional[float]: Delay in seconds, or None if not specified
|
||||
"""
|
||||
try:
|
||||
parser = self._get_parser(url)
|
||||
delay = parser.crawl_delay(self.user_agent)
|
||||
|
||||
if delay is not None:
|
||||
logger.info(f"Crawl-delay from robots.txt: {delay} seconds")
|
||||
|
||||
return delay
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting crawl delay for {url}: {e}")
|
||||
return None
|
||||
171
api/app/core/rag/crawler/url_normalizer.py
Normal file
171
api/app/core/rag/crawler/url_normalizer.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""URL normalization and validation for web crawler."""
|
||||
|
||||
from typing import Optional, List
|
||||
from urllib.parse import urlparse, urlunparse, parse_qs, urlencode, urljoin
|
||||
from bs4 import BeautifulSoup
|
||||
|
||||
|
||||
class URLNormalizer:
|
||||
"""Normalize and validate URLs for deduplication and domain checking."""
|
||||
|
||||
# Common tracking parameters to remove
|
||||
TRACKING_PARAMS = {
|
||||
'utm_source', 'utm_medium', 'utm_campaign', 'utm_term', 'utm_content',
|
||||
'fbclid', 'gclid', 'msclkid', '_ga', 'mc_cid', 'mc_eid'
|
||||
}
|
||||
|
||||
def __init__(self, base_domain: str):
|
||||
"""
|
||||
Initialize URL normalizer with base domain.
|
||||
|
||||
Args:
|
||||
base_domain: The domain to use for same-domain checks
|
||||
"""
|
||||
parsed = urlparse(base_domain)
|
||||
self.base_domain = parsed.netloc.lower() # example.com:8000
|
||||
self.base_scheme = parsed.scheme or 'https' # https
|
||||
|
||||
def normalize(self, url: str) -> Optional[str]:
|
||||
"""
|
||||
Normalize a URL for deduplication.
|
||||
|
||||
Normalization rules:
|
||||
1. Convert domain to lowercase
|
||||
2. Remove fragments (#section)
|
||||
3. Remove default ports (80 for http, 443 for https)
|
||||
4. Remove trailing slashes (except for root)
|
||||
5. Sort query parameters alphabetically
|
||||
6. Remove common tracking parameters
|
||||
|
||||
Args:
|
||||
url: URL to normalize
|
||||
|
||||
Returns:
|
||||
Optional[str]: Normalized URL, or None if invalid
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
|
||||
# Validate scheme
|
||||
if parsed.scheme not in ('http', 'https'):
|
||||
return None
|
||||
|
||||
# Normalize domain to lowercase
|
||||
netloc = parsed.netloc.lower()
|
||||
|
||||
# Remove default ports
|
||||
if ':' in netloc:
|
||||
host, port = netloc.rsplit(':', 1)
|
||||
if (parsed.scheme == 'http' and port == '80') or \
|
||||
(parsed.scheme == 'https' and port == '443'):
|
||||
netloc = host
|
||||
|
||||
# Normalize path
|
||||
path = parsed.path
|
||||
# Remove trailing slash except for root
|
||||
if path != '/' and path.endswith('/'):
|
||||
path = path.rstrip('/')
|
||||
# Ensure path starts with /
|
||||
if not path:
|
||||
path = '/'
|
||||
|
||||
# Process query parameters
|
||||
query = ''
|
||||
if parsed.query:
|
||||
# Parse query parameters
|
||||
params = parse_qs(parsed.query, keep_blank_values=True)
|
||||
# Remove tracking parameters
|
||||
filtered_params = {
|
||||
k: v for k, v in params.items()
|
||||
if k not in self.TRACKING_PARAMS
|
||||
}
|
||||
# Sort parameters alphabetically
|
||||
if filtered_params:
|
||||
sorted_params = sorted(filtered_params.items())
|
||||
query = urlencode(sorted_params, doseq=True)
|
||||
|
||||
# Reconstruct URL without fragment
|
||||
normalized = urlunparse((
|
||||
parsed.scheme,
|
||||
netloc,
|
||||
path,
|
||||
parsed.params,
|
||||
query,
|
||||
'' # Remove fragment
|
||||
))
|
||||
|
||||
return normalized
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def is_same_domain(self, url: str) -> bool:
|
||||
"""
|
||||
Check if URL belongs to the same domain as base_domain.
|
||||
|
||||
Args:
|
||||
url: URL to check
|
||||
|
||||
Returns:
|
||||
bool: True if same domain, False otherwise
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
domain = parsed.netloc.lower()
|
||||
|
||||
# Remove port if present
|
||||
if ':' in domain:
|
||||
domain = domain.split(':')[0]
|
||||
|
||||
# Check if domains match
|
||||
return domain == self.base_domain or domain == self.base_domain.split(':')[0]
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def extract_links(self, html: str, base_url: str) -> List[str]:
|
||||
"""
|
||||
Extract and normalize all links from HTML.
|
||||
|
||||
Args:
|
||||
html: HTML content
|
||||
base_url: Base URL for resolving relative links
|
||||
|
||||
Returns:
|
||||
List[str]: List of normalized absolute URLs
|
||||
"""
|
||||
links = []
|
||||
|
||||
try:
|
||||
soup = BeautifulSoup(html, 'lxml')
|
||||
|
||||
# Find all anchor tags
|
||||
for anchor in soup.find_all('a', href=True):
|
||||
href = anchor['href']
|
||||
|
||||
# Skip empty hrefs
|
||||
if not href or href.strip() == '':
|
||||
continue
|
||||
|
||||
# Skip javascript: and mailto: links
|
||||
if href.startswith(('javascript:', 'mailto:', 'tel:')):
|
||||
continue
|
||||
|
||||
normalized_url = None
|
||||
# Check if href starts with http/https (absolute URL)
|
||||
if href.startswith(('http://', 'https://')):
|
||||
if self.is_same_domain(href):
|
||||
normalized_url = self.normalize(href)
|
||||
else:
|
||||
# Convert relative URL to absolute
|
||||
absolute_url = urljoin(base_url, href)
|
||||
# Normalize the URL
|
||||
normalized_url = self.normalize(absolute_url)
|
||||
|
||||
if normalized_url:
|
||||
links.append(normalized_url)
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return links
|
||||
215
api/app/core/rag/crawler/web_crawler.py
Normal file
215
api/app/core/rag/crawler/web_crawler.py
Normal file
@@ -0,0 +1,215 @@
|
||||
"""Main web crawler orchestrator."""
|
||||
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from typing import Iterator, Optional, List, Set
|
||||
from urllib.parse import urlparse
|
||||
import logging
|
||||
|
||||
from app.core.rag.crawler.url_normalizer import URLNormalizer
|
||||
from app.core.rag.crawler.robots_parser import RobotsParser
|
||||
from app.core.rag.crawler.rate_limiter import RateLimiter
|
||||
from app.core.rag.crawler.http_fetcher import HTTPFetcher
|
||||
from app.core.rag.crawler.content_extractor import ContentExtractor
|
||||
from app.core.rag.crawler.models import CrawledDocument, CrawlSummary
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WebCrawler:
|
||||
"""Main orchestrator for web crawling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
entry_url: str,
|
||||
max_pages: int = 200,
|
||||
delay_seconds: float = 1.0,
|
||||
timeout_seconds: int = 10,
|
||||
user_agent: str = "KnowledgeBaseCrawler/1.0",
|
||||
include_patterns: Optional[List[str]] = None,
|
||||
exclude_patterns: Optional[List[str]] = None,
|
||||
content_extractor: Optional[ContentExtractor] = None
|
||||
):
|
||||
"""
|
||||
Initialize the web crawler.
|
||||
|
||||
Args:
|
||||
entry_url: Starting URL for the crawl
|
||||
max_pages: Maximum number of pages to crawl (default: 200)
|
||||
delay_seconds: Delay between requests in seconds (default: 1.0)
|
||||
timeout_seconds: HTTP request timeout (default: 10)
|
||||
user_agent: User-Agent header string
|
||||
include_patterns: List of regex patterns for URLs to include
|
||||
exclude_patterns: List of regex patterns for URLs to exclude
|
||||
content_extractor: Custom content extractor (optional)
|
||||
"""
|
||||
# Validate entry URL
|
||||
parsed = urlparse(entry_url)
|
||||
if not parsed.scheme or not parsed.netloc:
|
||||
raise ValueError(f"Invalid entry URL: {entry_url}")
|
||||
|
||||
self.entry_url = entry_url
|
||||
self.max_pages = max_pages
|
||||
self.user_agent = user_agent
|
||||
|
||||
# Extract domain from entry URL
|
||||
self.domain = parsed.netloc
|
||||
|
||||
# Initialize components
|
||||
self.url_normalizer = URLNormalizer(entry_url)
|
||||
self.robots_parser = RobotsParser(user_agent, timeout_seconds)
|
||||
self.rate_limiter = RateLimiter(delay_seconds)
|
||||
self.http_fetcher = HTTPFetcher(timeout_seconds, max_retries=3, user_agent=user_agent)
|
||||
self.content_extractor = content_extractor or ContentExtractor()
|
||||
|
||||
# State management
|
||||
self.url_queue: deque = deque()
|
||||
self.visited_urls: Set[str] = set()
|
||||
self.pages_processed = 0
|
||||
|
||||
# Statistics
|
||||
self.stats = {
|
||||
'success': 0,
|
||||
'errors': 0,
|
||||
'skipped': 0,
|
||||
'urls_discovered': 0,
|
||||
'error_breakdown': {}
|
||||
}
|
||||
self.start_time: Optional[datetime] = None
|
||||
self.end_time: Optional[datetime] = None
|
||||
|
||||
def crawl(self) -> Iterator[CrawledDocument]:
|
||||
"""
|
||||
Execute the crawl and yield documents as they are processed.
|
||||
|
||||
Yields:
|
||||
CrawledDocument: Structured document with extracted content
|
||||
"""
|
||||
logger.info(f"Starting crawl from {self.entry_url} (max_pages: {self.max_pages})")
|
||||
self.start_time = datetime.now()
|
||||
|
||||
# Add entry URL to queue
|
||||
normalized_entry = self.url_normalizer.normalize(self.entry_url)
|
||||
if normalized_entry:
|
||||
self.url_queue.append(normalized_entry)
|
||||
self.stats['urls_discovered'] += 1
|
||||
|
||||
# Check robots.txt and update rate limiter if needed
|
||||
crawl_delay = self.robots_parser.get_crawl_delay(self.entry_url)
|
||||
if crawl_delay:
|
||||
self.rate_limiter.set_delay(crawl_delay)
|
||||
|
||||
# Main crawl loop
|
||||
while self.url_queue and self.pages_processed < self.max_pages:
|
||||
url = self.url_queue.popleft()
|
||||
|
||||
# Skip if already visited
|
||||
if url in self.visited_urls:
|
||||
continue
|
||||
|
||||
# Mark as visited
|
||||
self.visited_urls.add(url)
|
||||
|
||||
# Check robots.txt permission
|
||||
if not self.robots_parser.can_fetch(url):
|
||||
logger.info(f"Skipping {url} (disallowed by robots.txt)")
|
||||
self.stats['skipped'] += 1
|
||||
continue
|
||||
|
||||
# Apply rate limiting
|
||||
self.rate_limiter.wait()
|
||||
|
||||
# Fetch URL
|
||||
logger.info(f"Fetching {url} ({self.pages_processed + 1}/{self.max_pages})")
|
||||
fetch_result = self.http_fetcher.fetch(url)
|
||||
|
||||
# Handle fetch errors
|
||||
if not fetch_result.success:
|
||||
self._record_error(fetch_result.error or "Unknown error")
|
||||
continue
|
||||
|
||||
# Check Content-Type
|
||||
content_type = fetch_result.headers.get('Content-Type', '').lower()
|
||||
if not any(substring in content_type for substring in ['text/html', 'application/xhtml+xml']):
|
||||
logger.warning(f"Skipping {url} (Content-Type: {content_type})")
|
||||
self.stats['skipped'] += 1
|
||||
continue
|
||||
|
||||
# Extract content
|
||||
try:
|
||||
extracted = self.content_extractor.extract(fetch_result.content, url)
|
||||
|
||||
# Check if static content
|
||||
if not extracted.is_static:
|
||||
logger.warning(f"Skipping {url} (JavaScript-rendered content)")
|
||||
self.stats['skipped'] += 1
|
||||
continue
|
||||
|
||||
# Create document
|
||||
document = CrawledDocument(
|
||||
url=url,
|
||||
title=extracted.title,
|
||||
content=extracted.text,
|
||||
content_length=len(extracted.text),
|
||||
crawl_timestamp=datetime.now(),
|
||||
http_status=fetch_result.status_code,
|
||||
metadata={
|
||||
'word_count': extracted.word_count,
|
||||
'final_url': fetch_result.final_url
|
||||
}
|
||||
)
|
||||
|
||||
# Update statistics
|
||||
self.pages_processed += 1
|
||||
self.stats['success'] += 1
|
||||
|
||||
# Extract and queue links
|
||||
links = self.url_normalizer.extract_links(fetch_result.content, url)
|
||||
for link in links:
|
||||
if link not in self.visited_urls and self.url_normalizer.is_same_domain(link):
|
||||
if link not in self.url_queue:
|
||||
self.url_queue.append(link)
|
||||
self.stats['urls_discovered'] += 1
|
||||
|
||||
# Yield document
|
||||
yield document
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing {url}: {e}")
|
||||
self._record_error(f"Processing error: {str(e)}")
|
||||
continue
|
||||
|
||||
self.end_time = datetime.now()
|
||||
logger.info(f"Crawl completed. Processed {self.pages_processed} pages.")
|
||||
|
||||
def get_summary(self) -> CrawlSummary:
|
||||
"""
|
||||
Get summary statistics after crawl completion.
|
||||
|
||||
Returns:
|
||||
CrawlSummary: Statistics including success/error/skip counts
|
||||
"""
|
||||
if not self.start_time:
|
||||
self.start_time = datetime.now()
|
||||
if not self.end_time:
|
||||
self.end_time = datetime.now()
|
||||
|
||||
duration = (self.end_time - self.start_time).total_seconds()
|
||||
|
||||
return CrawlSummary(
|
||||
total_pages_processed=self.stats['success'],
|
||||
total_errors=self.stats['errors'],
|
||||
total_skipped=self.stats['skipped'],
|
||||
total_urls_discovered=self.stats['urls_discovered'],
|
||||
start_time=self.start_time,
|
||||
end_time=self.end_time,
|
||||
duration_seconds=duration,
|
||||
error_breakdown=self.stats['error_breakdown']
|
||||
)
|
||||
|
||||
def _record_error(self, error: str):
|
||||
"""Record an error in statistics."""
|
||||
self.stats['errors'] += 1
|
||||
error_type = error.split(':')[0] if ':' in error else error
|
||||
self.stats['error_breakdown'][error_type] = \
|
||||
self.stats['error_breakdown'].get(error_type, 0) + 1
|
||||
1
api/app/core/rag/integrations/__init__.py
Normal file
1
api/app/core/rag/integrations/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Integrations package for external services."""
|
||||
1
api/app/core/rag/integrations/feishu/__init__.py
Normal file
1
api/app/core/rag/integrations/feishu/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Feishu integration module for document synchronization."""
|
||||
84
api/app/core/rag/integrations/feishu/__main__.py
Normal file
84
api/app/core/rag/integrations/feishu/__main__.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Command-line interface for feishu integration."""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from app.core.rag.integrations.feishu.client import FeishuAPIClient
|
||||
from app.core.rag.integrations.feishu.models import FileInfo
|
||||
|
||||
|
||||
def main(feishu_app_id: str, # Feishu application ID
|
||||
feishu_app_secret: str, # Feishu application secret
|
||||
feishu_folder_token: str, # Feishu Folder Token
|
||||
save_dir: str, # save file directory
|
||||
feishu_api_base_url: str = "https://open.feishu.cn/open-apis", # Feishu API base URL
|
||||
timeout: int = 30, # Request timeout in seconds
|
||||
max_retries: int = 3, # Maximum number of retries
|
||||
recursive: bool = True # recursive: Whether to sync subfolders recursively,
|
||||
):
|
||||
"""Main entry point for the feishuAPIClient."""
|
||||
# Create feishuAPIClient
|
||||
api_client = FeishuAPIClient(
|
||||
app_id=feishu_app_id,
|
||||
app_secret=feishu_app_secret,
|
||||
api_base_url=feishu_api_base_url,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries
|
||||
)
|
||||
|
||||
# Get all files from folder
|
||||
async def async_get_files(api_client: FeishuAPIClient, feishu_folder_token: str):
|
||||
async with api_client as client:
|
||||
if recursive:
|
||||
files = await client.list_all_folder_files(feishu_folder_token, recursive=True)
|
||||
else:
|
||||
all_files = []
|
||||
page_token = None
|
||||
while True:
|
||||
files_page, page_token = await client.list_folder_files(
|
||||
feishu_folder_token, page_token
|
||||
)
|
||||
all_files.extend(files_page)
|
||||
if not page_token:
|
||||
break
|
||||
files = all_files
|
||||
return files
|
||||
files = asyncio.run(async_get_files(api_client,feishu_folder_token))
|
||||
|
||||
# Filter out folders, only sync documents
|
||||
# documents = [f for f in files if f.type in ["doc", "docx", "sheet", "bitable", "file", "slides"]]
|
||||
documents = [f for f in files if f.type in ["doc", "docx", "sheet", "bitable", "file"]]
|
||||
|
||||
try:
|
||||
for doc in documents:
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"token: {doc.token}")
|
||||
print(f"name: {doc.name}")
|
||||
print(f"type: {doc.type}")
|
||||
print(f"created_time: {doc.created_time}")
|
||||
print(f"modified_time: {doc.modified_time}")
|
||||
print(f"owner_id: {doc.owner_id}")
|
||||
print(f"url: {doc.url}")
|
||||
print(f"{'=' * 80}\n")
|
||||
# download document from Feishu FileInfo
|
||||
async def async_download_document(api_client: FeishuAPIClient, doc: FileInfo, save_dir: str):
|
||||
async with api_client as client:
|
||||
file_path = await client.download_document(document=doc, save_dir=save_dir)
|
||||
return file_path
|
||||
|
||||
file_path = asyncio.run(async_download_document(api_client, doc, save_dir))
|
||||
print(file_path)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\nfeishu integration interrupted by user.")
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n\nError during feishu integration: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
feishu_app_id = ""
|
||||
feishu_app_secret = ""
|
||||
feishu_folder_token = ""
|
||||
save_dir = "/Volumes/MacintoshBD/Repository/RedBearAI/MemoryBear/api/files/"
|
||||
main(feishu_app_id, feishu_app_secret, feishu_folder_token, save_dir)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user