Compare commits
112 Commits
release/v0
...
release/v0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
988a41f5e4 | ||
|
|
14946d9a1d | ||
|
|
c8591d7bca | ||
|
|
27d1174dbb | ||
|
|
c5e0df12ad | ||
|
|
d56e168df9 | ||
|
|
5dcc815240 | ||
|
|
ac160b6b41 | ||
|
|
acecdcc041 | ||
|
|
5ced11999e | ||
|
|
4923708515 | ||
|
|
2cbbb829f7 | ||
|
|
1eacd3abe6 | ||
|
|
c5c2f84356 | ||
|
|
742e2f037b | ||
|
|
e3110d2f48 | ||
|
|
29718b1c03 | ||
|
|
cd3b4d8dde | ||
|
|
5a3cddab0f | ||
|
|
15221005d1 | ||
|
|
da75abb223 | ||
|
|
8b32f80e27 | ||
|
|
ab9c2d81b0 | ||
|
|
5ff8cdb13a | ||
|
|
44783574c0 | ||
|
|
1e7c53d944 | ||
|
|
655ae796fd | ||
|
|
93686dbc1e | ||
|
|
0356add7e0 | ||
|
|
9bea74fcef | ||
|
|
c08b10c20f | ||
|
|
16c0d9bb6c | ||
|
|
9f0d1616a8 | ||
|
|
fafab973ee | ||
|
|
4648ec04c7 | ||
|
|
64e4411048 | ||
|
|
e901d3c9d6 | ||
|
|
fb25495f1b | ||
|
|
b6e6dbf27f | ||
|
|
bd5b97e69b | ||
|
|
1e5acd85ff | ||
|
|
6e1f6d886d | ||
|
|
940af67a87 | ||
|
|
c24fb73147 | ||
|
|
4e96c12634 | ||
|
|
37ef497f4c | ||
|
|
2e504f9c48 | ||
|
|
3be3604125 | ||
|
|
6920deef63 | ||
|
|
6c30347219 | ||
|
|
d6b08b3c5c | ||
|
|
21ec923f24 | ||
|
|
3a0eab068c | ||
|
|
8aa496f588 | ||
|
|
af7b9ee41c | ||
|
|
9e64cb574a | ||
|
|
783593a79d | ||
|
|
afed5e10fc | ||
|
|
a7c0789e36 | ||
|
|
b5b1a98bc4 | ||
|
|
91d3758691 | ||
|
|
c6030bbec8 | ||
|
|
cb62608dbd | ||
|
|
83fe793e72 | ||
|
|
9d36ec70bc | ||
|
|
6b95cd05c8 | ||
|
|
804d87bca2 | ||
|
|
e518b57dea | ||
|
|
642587fc97 | ||
|
|
cd1a50a1d1 | ||
|
|
8881daf592 | ||
|
|
3ced895c9c | ||
|
|
75c1892611 | ||
|
|
9f0c4410f7 | ||
|
|
4976fccf7d | ||
|
|
ee2d3fd53a | ||
|
|
63baf3bd40 | ||
|
|
b37ad0e145 | ||
|
|
c255be8d09 | ||
|
|
12a27dbcf7 | ||
|
|
547ce858e7 | ||
|
|
995b896b9d | ||
|
|
2d90b0c752 | ||
|
|
9d25b08641 | ||
|
|
004ec0da6d | ||
|
|
3da990ec77 | ||
|
|
ff6bdc1bed | ||
|
|
2891f2c068 | ||
|
|
9353053a23 | ||
|
|
de058e3b1d | ||
|
|
16fb9f59fe | ||
|
|
eb58e0ea63 | ||
|
|
6ba4b9e7bd | ||
|
|
26dd15ef83 | ||
|
|
46752420da | ||
|
|
49f6f27ffc | ||
|
|
3670674e6b | ||
|
|
3606000740 | ||
|
|
622e67e952 | ||
|
|
546d52149d | ||
|
|
825f257cf4 | ||
|
|
0489013ddd | ||
|
|
07760d55b7 | ||
|
|
2aca4ed67e | ||
|
|
c2c2b306a2 | ||
|
|
2b017139ef | ||
|
|
034559aac7 | ||
|
|
a6a18b7304 | ||
|
|
67d0b196b8 | ||
|
|
ba30161559 | ||
|
|
85e3d5a392 | ||
|
|
0b685b136f |
14
README.md
14
README.md
@@ -334,7 +334,13 @@ step6: Log In to the Frontend Interface.
|
||||
## License
|
||||
This project is licensed under the Apache License 2.0. For details, see the LICENSE file.
|
||||
|
||||
## Acknowledgements & Community
|
||||
- Feedback & Issues: Please submit an Issue in the repository for bug reports or discussions.
|
||||
- Contributions Welcome: When submitting a Pull Request, please create a feature branch and follow conventional commit message guidelines.
|
||||
- Contact: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com
|
||||
## Community & Support
|
||||
|
||||
Join our community to ask questions, share your work, and connect with fellow developers.
|
||||
|
||||
- **GitHub Issues**: Report bugs, request features, or track known issues via [GitHub Issues](https://github.com/SuanmoSuanyangTechnology/MemoryBear/issues).
|
||||
- **GitHub Pull Requests**: Contribute code improvements or fixes through [Pull Requests](https://github.com/SuanmoSuanyangTechnology/MemoryBear/pulls).
|
||||
- **GitHub Discussions**: Ask questions, share ideas, and engage with the community in [GitHub Discussions](https://github.com/SuanmoSuanyangTechnology/MemoryBear/discussions).
|
||||
- **WeChat**: Scan the QR code below to join our WeChat community group.
|
||||
- 
|
||||
- **Contact**: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com
|
||||
|
||||
11
api/app/cache/__init__.py
vendored
Normal file
11
api/app/cache/__init__.py
vendored
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
Cache 缓存模块
|
||||
|
||||
提供各种缓存功能的统一入口
|
||||
"""
|
||||
from .memory import EmotionMemoryCache, ImplicitMemoryCache
|
||||
|
||||
__all__ = [
|
||||
"EmotionMemoryCache",
|
||||
"ImplicitMemoryCache",
|
||||
]
|
||||
12
api/app/cache/memory/__init__.py
vendored
Normal file
12
api/app/cache/memory/__init__.py
vendored
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Memory 缓存模块
|
||||
|
||||
提供记忆系统相关的缓存功能
|
||||
"""
|
||||
from .emotion_memory import EmotionMemoryCache
|
||||
from .implicit_memory import ImplicitMemoryCache
|
||||
|
||||
__all__ = [
|
||||
"EmotionMemoryCache",
|
||||
"ImplicitMemoryCache",
|
||||
]
|
||||
134
api/app/cache/memory/emotion_memory.py
vendored
Normal file
134
api/app/cache/memory/emotion_memory.py
vendored
Normal file
@@ -0,0 +1,134 @@
|
||||
"""
|
||||
Emotion Suggestions Cache
|
||||
|
||||
情绪个性化建议缓存模块
|
||||
用于缓存用户的情绪个性化建议数据
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EmotionMemoryCache:
|
||||
"""情绪建议缓存类"""
|
||||
|
||||
# Key 前缀
|
||||
PREFIX = "cache:memory:emotion_memory"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, *parts: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
*parts: key 的各个部分
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return ":".join([cls.PREFIX] + list(parts))
|
||||
|
||||
@classmethod
|
||||
async def set_emotion_suggestions(
|
||||
cls,
|
||||
user_id: str,
|
||||
suggestions_data: Dict[str, Any],
|
||||
expire: int = 86400
|
||||
) -> bool:
|
||||
"""设置用户情绪建议缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
suggestions_data: 建议数据字典,包含:
|
||||
- health_summary: 健康状态摘要
|
||||
- suggestions: 建议列表
|
||||
- generated_at: 生成时间(可选)
|
||||
expire: 过期时间(秒),默认24小时(86400秒)
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
|
||||
# 添加生成时间戳
|
||||
if "generated_at" not in suggestions_data:
|
||||
suggestions_data["generated_at"] = datetime.now().isoformat()
|
||||
|
||||
# 添加缓存标记
|
||||
suggestions_data["cached"] = True
|
||||
|
||||
value = json.dumps(suggestions_data, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
logger.info(f"设置情绪建议缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置情绪建议缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_emotion_suggestions(cls, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取用户情绪建议缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
建议数据字典,如果不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
value = await aio_redis.get(key)
|
||||
|
||||
if value:
|
||||
data = json.loads(value)
|
||||
logger.info(f"成功获取情绪建议缓存: {key}")
|
||||
return data
|
||||
|
||||
logger.info(f"情绪建议缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取情绪建议缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_emotion_suggestions(cls, user_id: str) -> bool:
|
||||
"""删除用户情绪建议缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
result = await aio_redis.delete(key)
|
||||
logger.info(f"删除情绪建议缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除情绪建议缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_suggestions_ttl(cls, user_id: str) -> int:
|
||||
"""获取情绪建议缓存的剩余过期时间
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
剩余秒数,-1表示永不过期,-2表示key不存在
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("suggestions", user_id)
|
||||
ttl = await aio_redis.ttl(key)
|
||||
logger.debug(f"情绪建议缓存TTL: {key} = {ttl}秒")
|
||||
return ttl
|
||||
except Exception as e:
|
||||
logger.error(f"获取情绪建议缓存TTL失败: {e}")
|
||||
return -2
|
||||
136
api/app/cache/memory/implicit_memory.py
vendored
Normal file
136
api/app/cache/memory/implicit_memory.py
vendored
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Implicit Memory Profile Cache
|
||||
|
||||
隐式记忆用户画像缓存模块
|
||||
用于缓存用户的完整画像数据(偏好标签、四维画像、兴趣领域、行为习惯)
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImplicitMemoryCache:
|
||||
"""隐式记忆用户画像缓存类"""
|
||||
|
||||
# Key 前缀
|
||||
PREFIX = "cache:memory:implicit_memory"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, *parts: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
*parts: key 的各个部分
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return ":".join([cls.PREFIX] + list(parts))
|
||||
|
||||
@classmethod
|
||||
async def set_user_profile(
|
||||
cls,
|
||||
user_id: str,
|
||||
profile_data: Dict[str, Any],
|
||||
expire: int = 86400
|
||||
) -> bool:
|
||||
"""设置用户完整画像缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
profile_data: 画像数据字典,包含:
|
||||
- preferences: 偏好标签列表
|
||||
- portrait: 四维画像对象
|
||||
- interest_areas: 兴趣领域分布对象
|
||||
- habits: 行为习惯列表
|
||||
- generated_at: 生成时间(可选)
|
||||
expire: 过期时间(秒),默认24小时(86400秒)
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
|
||||
# 添加生成时间戳
|
||||
if "generated_at" not in profile_data:
|
||||
profile_data["generated_at"] = datetime.now().isoformat()
|
||||
|
||||
# 添加缓存标记
|
||||
profile_data["cached"] = True
|
||||
|
||||
value = json.dumps(profile_data, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
logger.info(f"设置用户画像缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置用户画像缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_user_profile(cls, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取用户完整画像缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
画像数据字典,如果不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
value = await aio_redis.get(key)
|
||||
|
||||
if value:
|
||||
data = json.loads(value)
|
||||
logger.info(f"成功获取用户画像缓存: {key}")
|
||||
return data
|
||||
|
||||
logger.info(f"用户画像缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户画像缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_user_profile(cls, user_id: str) -> bool:
|
||||
"""删除用户完整画像缓存
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
result = await aio_redis.delete(key)
|
||||
logger.info(f"删除用户画像缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除用户画像缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_profile_ttl(cls, user_id: str) -> int:
|
||||
"""获取用户画像缓存的剩余过期时间
|
||||
|
||||
Args:
|
||||
user_id: 用户ID(end_user_id)
|
||||
|
||||
Returns:
|
||||
剩余秒数,-1表示永不过期,-2表示key不存在
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key("profile", user_id)
|
||||
ttl = await aio_redis.ttl(key)
|
||||
logger.debug(f"用户画像缓存TTL: {key} = {ttl}秒")
|
||||
return ttl
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户画像缓存TTL失败: {e}")
|
||||
return -2
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import platform
|
||||
from datetime import timedelta
|
||||
from urllib.parse import quote
|
||||
|
||||
@@ -14,27 +15,12 @@ celery_app = Celery(
|
||||
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
|
||||
)
|
||||
|
||||
# 配置使用本地队列,避免与远程 worker 冲突
|
||||
celery_app.conf.task_default_queue = 'localhost_test_wyl'
|
||||
celery_app.conf.task_default_exchange = 'localhost_test_wyl'
|
||||
celery_app.conf.task_default_routing_key = 'localhost_test_wyl'
|
||||
# Default queue for unrouted tasks
|
||||
celery_app.conf.task_default_queue = 'memory_tasks'
|
||||
|
||||
# macOS 兼容性配置
|
||||
import platform
|
||||
|
||||
if platform.system() == 'Darwin': # macOS
|
||||
# 设置环境变量解决 fork 问题
|
||||
if platform.system() == 'Darwin':
|
||||
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
|
||||
|
||||
# 使用 solo 池避免多进程问题
|
||||
celery_app.conf.worker_pool = 'solo'
|
||||
|
||||
# 设置唯一的节点名称
|
||||
import socket
|
||||
import time
|
||||
hostname = socket.gethostname()
|
||||
timestamp = int(time.time())
|
||||
celery_app.conf.worker_name = f"celery@{hostname}-{timestamp}"
|
||||
|
||||
# Celery 配置
|
||||
celery_app.conf.update(
|
||||
@@ -52,36 +38,47 @@ celery_app.conf.update(
|
||||
task_ignore_result=False,
|
||||
|
||||
# 超时设置
|
||||
task_time_limit=30 * 60, # 30 分钟硬超时
|
||||
task_soft_time_limit=25 * 60, # 25 分钟软超时
|
||||
task_time_limit=1800, # 30分钟硬超时
|
||||
task_soft_time_limit=1500, # 25分钟软超时
|
||||
|
||||
# Worker 设置 - 针对 macOS 优化
|
||||
worker_prefetch_multiplier=1, # 减少预取任务数,避免内存堆积
|
||||
worker_max_tasks_per_child=10, # 大幅减少每个 worker 执行的任务数,频繁重启防止内存泄漏
|
||||
worker_max_memory_per_child=200000, # 200MB 内存限制,超过后重启 worker
|
||||
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||
|
||||
# 结果过期时间
|
||||
result_expires=3600, # 结果保存 1 小时
|
||||
result_expires=3600, # 结果保存1小时
|
||||
|
||||
# 任务确认设置
|
||||
task_acks_late=True, # 任务完成后才确认,避免任务丢失
|
||||
worker_disable_rate_limits=True, # 禁用速率限制
|
||||
task_acks_late=True,
|
||||
task_reject_on_worker_lost=True,
|
||||
worker_disable_rate_limits=True,
|
||||
|
||||
# 任务路由(可选,用于不同队列)
|
||||
# task_routes={
|
||||
# 'app.core.rag.tasks.parse_document': {'queue': 'document_processing'},
|
||||
# 'app.core.memory.agent.read_message': {'queue': 'memory_processing'},
|
||||
# 'app.core.memory.agent.write_message': {'queue': 'memory_processing'},
|
||||
# 'tasks.process_item': {'queue': 'default'},
|
||||
# },
|
||||
# FLower setting
|
||||
worker_send_task_events=True,
|
||||
task_send_sent_event=True,
|
||||
|
||||
# task routing
|
||||
task_routes={
|
||||
# Memory tasks → memory_tasks queue (threads worker)
|
||||
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
|
||||
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
|
||||
|
||||
# Document tasks → document_tasks queue (prefork worker)
|
||||
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
|
||||
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
|
||||
|
||||
# Beat/periodic tasks → document_tasks queue (prefork worker)
|
||||
'app.tasks.workspace_reflection_task': {'queue': 'document_tasks'},
|
||||
'app.tasks.regenerate_memory_cache': {'queue': 'document_tasks'},
|
||||
'app.tasks.run_forgetting_cycle_task': {'queue': 'document_tasks'},
|
||||
'app.controllers.memory_storage_controller.search_all': {'queue': 'document_tasks'},
|
||||
},
|
||||
)
|
||||
|
||||
# 自动发现任务模块
|
||||
celery_app.autodiscover_tasks(['app'])
|
||||
|
||||
# Celery Beat schedule for periodic tasks
|
||||
reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS)
|
||||
health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
|
||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
@@ -89,12 +86,6 @@ forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘
|
||||
|
||||
# 构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
|
||||
# "check-read-service": {
|
||||
# "task": "app.core.memory.agent.health.check_read_service",
|
||||
# "schedule": health_schedule,
|
||||
# "args": (),
|
||||
# },
|
||||
"run-workspace-reflection": {
|
||||
"task": "app.tasks.workspace_reflection_task",
|
||||
"schedule": workspace_reflection_schedule,
|
||||
|
||||
@@ -3,6 +3,12 @@ Celery Worker 入口点
|
||||
用于启动 Celery Worker: celery -A app.celery_worker worker --loglevel=info
|
||||
"""
|
||||
from app.celery_app import celery_app
|
||||
from app.core.logging_config import LoggingConfig, get_logger
|
||||
|
||||
# Initialize logging system for Celery worker
|
||||
LoggingConfig.setup_logging()
|
||||
logger = get_logger(__name__)
|
||||
logger.info("Celery worker logging initialized")
|
||||
|
||||
# 导入任务模块以注册任务
|
||||
import app.tasks
|
||||
|
||||
@@ -14,6 +14,7 @@ from . import (
|
||||
emotion_config_controller,
|
||||
emotion_controller,
|
||||
file_controller,
|
||||
file_storage_controller,
|
||||
home_page_controller,
|
||||
implicit_memory_controller,
|
||||
knowledge_controller,
|
||||
@@ -88,5 +89,6 @@ manager_router.include_router(home_page_controller.router)
|
||||
manager_router.include_router(implicit_memory_controller.router)
|
||||
manager_router.include_router(memory_perceptual_controller.router)
|
||||
manager_router.include_router(memory_working_controller.router)
|
||||
manager_router.include_router(file_storage_controller.router)
|
||||
|
||||
__all__ = ["manager_router"]
|
||||
|
||||
@@ -7,7 +7,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.response_utils import success, fail
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user, cur_workspace_access_guard
|
||||
from app.models import User
|
||||
@@ -661,6 +661,11 @@ async def draft_run(
|
||||
data=result,
|
||||
msg="工作流任务执行成功"
|
||||
)
|
||||
else:
|
||||
return fail(
|
||||
msg="未知应用类型",
|
||||
code=422
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{app_id}/draft/run/compare", summary="多模型对比试运行")
|
||||
|
||||
@@ -24,7 +24,7 @@ from app.schemas.emotion_schema import (
|
||||
)
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.emotion_analytics_service import EmotionAnalyticsService
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status,Header
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
# 获取API专用日志器
|
||||
@@ -45,6 +45,7 @@ emotion_service = EmotionAnalyticsService()
|
||||
@router.post("/tags", response_model=ApiResponse)
|
||||
async def get_emotion_tags(
|
||||
request: EmotionTagsRequest,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
|
||||
@@ -59,7 +60,7 @@ async def get_emotion_tags(
|
||||
"limit": request.limit
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.get_emotion_tags(
|
||||
end_user_id=request.group_id,
|
||||
@@ -68,7 +69,7 @@ async def get_emotion_tags(
|
||||
end_date=request.end_date,
|
||||
limit=request.limit
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(
|
||||
"情绪标签统计获取成功",
|
||||
extra={
|
||||
@@ -77,9 +78,9 @@ async def get_emotion_tags(
|
||||
"tags_count": len(data.get("tags", []))
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return success(data=data, msg="情绪标签获取成功")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪标签统计失败: {str(e)}",
|
||||
@@ -96,6 +97,7 @@ async def get_emotion_tags(
|
||||
@router.post("/wordcloud", response_model=ApiResponse)
|
||||
async def get_emotion_wordcloud(
|
||||
request: EmotionWordcloudRequest,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
|
||||
@@ -108,14 +110,14 @@ async def get_emotion_wordcloud(
|
||||
"limit": request.limit
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.get_emotion_wordcloud(
|
||||
end_user_id=request.group_id,
|
||||
emotion_type=request.emotion_type,
|
||||
limit=request.limit
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(
|
||||
"情绪词云数据获取成功",
|
||||
extra={
|
||||
@@ -123,9 +125,9 @@ async def get_emotion_wordcloud(
|
||||
"total_keywords": data.get("total_keywords", 0)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return success(data=data, msg="情绪词云获取成功")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取情绪词云数据失败: {str(e)}",
|
||||
@@ -142,6 +144,7 @@ async def get_emotion_wordcloud(
|
||||
@router.post("/health", response_model=ApiResponse)
|
||||
async def get_emotion_health(
|
||||
request: EmotionHealthRequest,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
|
||||
@@ -152,7 +155,7 @@ async def get_emotion_health(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="时间范围参数无效,必须是 7d、30d 或 90d"
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求获取情绪健康指数",
|
||||
extra={
|
||||
@@ -160,13 +163,13 @@ async def get_emotion_health(
|
||||
"time_range": request.time_range
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 调用服务层
|
||||
data = await emotion_service.calculate_emotion_health_index(
|
||||
end_user_id=request.group_id,
|
||||
time_range=request.time_range
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(
|
||||
"情绪健康指数获取成功",
|
||||
extra={
|
||||
@@ -175,9 +178,9 @@ async def get_emotion_health(
|
||||
"level": data.get("level", "未知")
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return success(data=data, msg="情绪健康指数获取成功")
|
||||
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -196,16 +199,17 @@ async def get_emotion_health(
|
||||
@router.post("/suggestions", response_model=ApiResponse)
|
||||
async def get_emotion_suggestions(
|
||||
request: EmotionSuggestionsRequest,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""获取个性化情绪建议(从缓存读取)
|
||||
|
||||
|
||||
Args:
|
||||
request: 包含 group_id 和可选的 config_id
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
|
||||
Returns:
|
||||
缓存的个性化情绪建议响应
|
||||
"""
|
||||
@@ -217,13 +221,13 @@ async def get_emotion_suggestions(
|
||||
"config_id": request.config_id
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 从缓存获取建议
|
||||
data = await emotion_service.get_cached_suggestions(
|
||||
end_user_id=request.group_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
|
||||
if data is None:
|
||||
# 缓存不存在或已过期
|
||||
api_logger.info(
|
||||
@@ -231,11 +235,11 @@ async def get_emotion_suggestions(
|
||||
extra={"group_id": request.group_id}
|
||||
)
|
||||
return fail(
|
||||
BizCode.RESOURCE_NOT_FOUND,
|
||||
"建议缓存不存在或已过期,请调用 /generate_suggestions 接口生成新建议",
|
||||
None
|
||||
BizCode.NOT_FOUND,
|
||||
"建议缓存不存在或已过期,请右上角刷新生成新建议",
|
||||
""
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(
|
||||
"个性化建议获取成功(缓存)",
|
||||
extra={
|
||||
@@ -243,9 +247,9 @@ async def get_emotion_suggestions(
|
||||
"suggestions_count": len(data.get("suggestions", []))
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return success(data=data, msg="个性化建议获取成功(缓存)")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"获取个性化建议失败: {str(e)}",
|
||||
@@ -261,80 +265,56 @@ async def get_emotion_suggestions(
|
||||
@router.post("/generate_suggestions", response_model=ApiResponse)
|
||||
async def generate_emotion_suggestions(
|
||||
request: EmotionGenerateSuggestionsRequest,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""生成个性化情绪建议(调用LLM并缓存)
|
||||
|
||||
|
||||
Args:
|
||||
request: 包含 group_id、可选的 config_id 和 force_refresh
|
||||
request: 包含 end_user_id
|
||||
db: 数据库会话
|
||||
current_user: 当前用户
|
||||
|
||||
|
||||
Returns:
|
||||
新生成的个性化情绪建议响应
|
||||
"""
|
||||
try:
|
||||
# 验证 config_id(如果提供)
|
||||
# 获取终端用户关联的配置
|
||||
config_id = request.config_id
|
||||
if config_id is None:
|
||||
# 如果没有提供 config_id,尝试获取用户关联的配置
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(request.group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
except ValueError as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "无法获取用户关联的配置", str(e))
|
||||
else:
|
||||
# 如果提供了 config_id,验证其有效性
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
try:
|
||||
config_service = MemoryConfigService(db)
|
||||
config = config_service.get_config_by_id(config_id)
|
||||
if not config:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", f"配置 {config_id} 不存在")
|
||||
except Exception as e:
|
||||
return fail(BizCode.INVALID_PARAMETER, "配置ID验证失败", str(e))
|
||||
|
||||
api_logger.info(
|
||||
f"用户 {current_user.username} 请求生成个性化情绪建议",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"config_id": config_id
|
||||
"end_user_id": request.end_user_id
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# 调用服务层生成建议
|
||||
data = await emotion_service.generate_emotion_suggestions(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
db=db
|
||||
)
|
||||
|
||||
|
||||
# 保存到缓存
|
||||
await emotion_service.save_suggestions_cache(
|
||||
end_user_id=request.group_id,
|
||||
end_user_id=request.end_user_id,
|
||||
suggestions_data=data,
|
||||
db=db,
|
||||
expires_hours=24
|
||||
)
|
||||
|
||||
|
||||
api_logger.info(
|
||||
"个性化建议生成成功",
|
||||
extra={
|
||||
"group_id": request.group_id,
|
||||
"end_user_id": request.end_user_id,
|
||||
"suggestions_count": len(data.get("suggestions", []))
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
return success(data=data, msg="个性化建议生成成功")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(
|
||||
f"生成个性化建议失败: {str(e)}",
|
||||
extra={"group_id": request.group_id},
|
||||
extra={"end_user_id": request.end_user_id},
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
|
||||
499
api/app/controllers/file_storage_controller.py
Normal file
499
api/app/controllers/file_storage_controller.py
Normal file
@@ -0,0 +1,499 @@
|
||||
"""
|
||||
File storage controller module.
|
||||
|
||||
This module provides API endpoints for file storage operations using the
|
||||
configurable storage backend. It is a new controller that does not modify
|
||||
the existing file_controller.py.
|
||||
|
||||
Routes:
|
||||
POST /storage/files - Upload a file
|
||||
GET /storage/files/{file_id} - Download a file
|
||||
DELETE /storage/files/{file_id} - Delete a file
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
|
||||
from fastapi.responses import FileResponse, RedirectResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.core.storage import LocalStorage
|
||||
from app.core.storage.url_signer import generate_signed_url, verify_signed_url
|
||||
from app.core.storage_exceptions import (
|
||||
StorageDeleteError,
|
||||
StorageUploadError,
|
||||
)
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.file_metadata_model import FileMetadata
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services.file_storage_service import (
|
||||
FileStorageService,
|
||||
generate_file_key,
|
||||
get_file_storage_service,
|
||||
)
|
||||
|
||||
api_logger = get_api_logger()
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/storage",
|
||||
tags=["storage"]
|
||||
)
|
||||
|
||||
|
||||
@router.post("/files", response_model=ApiResponse)
|
||||
async def upload_file(
|
||||
file: UploadFile = File(...),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
):
|
||||
"""
|
||||
Upload a file to the configured storage backend.
|
||||
"""
|
||||
tenant_id = current_user.tenant_id
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
api_logger.info(
|
||||
f"Storage upload request: tenant_id={tenant_id}, workspace_id={workspace_id}, "
|
||||
f"filename={file.filename}, username={current_user.username}"
|
||||
)
|
||||
|
||||
# Read file contents
|
||||
contents = await file.read()
|
||||
file_size = len(contents)
|
||||
|
||||
# Validate file size
|
||||
if file_size == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="The file is empty."
|
||||
)
|
||||
|
||||
if file_size > settings.MAX_FILE_SIZE:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
|
||||
)
|
||||
|
||||
# Extract file extension
|
||||
_, file_extension = os.path.splitext(file.filename)
|
||||
file_ext = file_extension.lower()
|
||||
|
||||
# Generate file_id and file_key
|
||||
file_id = uuid.uuid4()
|
||||
file_key = generate_file_key(
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_id=file_id,
|
||||
file_ext=file_ext,
|
||||
)
|
||||
|
||||
# Create file metadata record with pending status
|
||||
file_metadata = FileMetadata(
|
||||
id=file_id,
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_key=file_key,
|
||||
file_name=file.filename,
|
||||
file_ext=file_ext,
|
||||
file_size=file_size,
|
||||
content_type=file.content_type,
|
||||
status="pending",
|
||||
)
|
||||
db.add(file_metadata)
|
||||
db.commit()
|
||||
db.refresh(file_metadata)
|
||||
|
||||
# Upload file to storage backend
|
||||
try:
|
||||
await storage_service.upload_file(
|
||||
tenant_id=tenant_id,
|
||||
workspace_id=workspace_id,
|
||||
file_id=file_id,
|
||||
file_ext=file_ext,
|
||||
content=contents,
|
||||
content_type=file.content_type,
|
||||
)
|
||||
# Update status to completed
|
||||
file_metadata.status = "completed"
|
||||
db.commit()
|
||||
api_logger.info(f"File uploaded to storage: file_key={file_key}")
|
||||
except StorageUploadError as e:
|
||||
# Update status to failed
|
||||
file_metadata.status = "failed"
|
||||
db.commit()
|
||||
api_logger.error(f"Storage upload failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"File storage failed: {str(e)}"
|
||||
)
|
||||
|
||||
api_logger.info(f"File upload successful: {file.filename} (file_id: {file_id})")
|
||||
|
||||
return success(
|
||||
data={"file_id": str(file_id), "file_key": file_key},
|
||||
msg="File upload successful"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/files/{file_id}", response_model=Any)
|
||||
async def download_file(
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
) -> Any:
|
||||
"""
|
||||
Download a file from the configured storage backend.
|
||||
"""
|
||||
api_logger.info(f"Storage download request: file_id={file_id}")
|
||||
|
||||
# Query file metadata from database
|
||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||
if not file_metadata:
|
||||
api_logger.warning(f"File not found in database: file_id={file_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The file does not exist"
|
||||
)
|
||||
|
||||
if file_metadata.status != "completed":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"File upload not completed, status: {file_metadata.status}"
|
||||
)
|
||||
|
||||
file_key = file_metadata.file_key
|
||||
storage = storage_service.storage
|
||||
|
||||
if isinstance(storage, LocalStorage):
|
||||
full_path = storage._get_full_path(file_key)
|
||||
|
||||
if not full_path.exists():
|
||||
api_logger.warning(f"File not found on disk: file_key={file_key}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="File not found (possibly deleted)"
|
||||
)
|
||||
|
||||
api_logger.info(f"Serving local file: file_key={file_key}")
|
||||
return FileResponse(
|
||||
path=str(full_path),
|
||||
filename=file_metadata.file_name,
|
||||
media_type=file_metadata.content_type or "application/octet-stream"
|
||||
)
|
||||
else:
|
||||
try:
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||
api_logger.info(f"Redirecting to presigned URL: file_key={file_key}")
|
||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||
except FileNotFoundError:
|
||||
api_logger.warning(f"File not found in remote storage: file_key={file_key}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="File not found in storage"
|
||||
)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve file: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/files/{file_id}", response_model=ApiResponse)
|
||||
async def delete_file(
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
):
|
||||
"""
|
||||
Delete a file from the configured storage backend.
|
||||
"""
|
||||
api_logger.info(
|
||||
f"Storage delete request: file_id={file_id}, username={current_user.username}"
|
||||
)
|
||||
|
||||
# Query file metadata from database
|
||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||
if not file_metadata:
|
||||
api_logger.warning(f"File not found in database: file_id={file_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The file does not exist"
|
||||
)
|
||||
|
||||
file_key = file_metadata.file_key
|
||||
|
||||
# Delete file from storage
|
||||
try:
|
||||
deleted = await storage_service.delete_file(file_key)
|
||||
if deleted:
|
||||
api_logger.info(f"File deleted from storage: file_key={file_key}")
|
||||
else:
|
||||
api_logger.info(f"File did not exist in storage: file_key={file_key}")
|
||||
except StorageDeleteError as e:
|
||||
api_logger.error(f"Storage delete failed: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to delete file from storage: {str(e)}"
|
||||
)
|
||||
|
||||
# Delete database record
|
||||
try:
|
||||
db.delete(file_metadata)
|
||||
db.commit()
|
||||
api_logger.info(f"File record deleted from database: file_id={file_id}")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Database delete failed: {e}")
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to delete file record: {str(e)}"
|
||||
)
|
||||
|
||||
return success(msg="File deleted successfully")
|
||||
|
||||
|
||||
@router.get("/files/{file_id}/url", response_model=ApiResponse)
|
||||
async def get_file_url(
|
||||
file_id: uuid.UUID,
|
||||
expires: int = None,
|
||||
permanent: bool = False,
|
||||
db: Session = Depends(get_db),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
):
|
||||
"""
|
||||
Get an access URL for a file (no authentication required).
|
||||
|
||||
Args:
|
||||
file_id: The UUID of the file.
|
||||
expires: URL validity period in seconds (default from FILE_URL_EXPIRES env).
|
||||
permanent: If True, return a permanent URL without expiration.
|
||||
db: Database session.
|
||||
storage_service: The file storage service.
|
||||
|
||||
Returns:
|
||||
ApiResponse with the access URL.
|
||||
"""
|
||||
if expires is None:
|
||||
expires = settings.FILE_URL_EXPIRES
|
||||
|
||||
api_logger.info(f"Get file URL request: file_id={file_id}, expires={expires}, permanent={permanent}")
|
||||
|
||||
# Query file metadata from database
|
||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||
if not file_metadata:
|
||||
api_logger.warning(f"File not found in database: file_id={file_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The file does not exist"
|
||||
)
|
||||
|
||||
if file_metadata.status != "completed":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"File upload not completed, status: {file_metadata.status}"
|
||||
)
|
||||
|
||||
file_key = file_metadata.file_key
|
||||
storage = storage_service.storage
|
||||
|
||||
try:
|
||||
if permanent:
|
||||
# Generate permanent URL (no expiration check)
|
||||
server_url = f"http://{settings.SERVER_IP}:8000/api"
|
||||
url = f"{server_url}/storage/permanent/{file_id}"
|
||||
return success(
|
||||
data={
|
||||
"url": url,
|
||||
"expires_in": None,
|
||||
"permanent": True,
|
||||
"file_name": file_metadata.file_name,
|
||||
},
|
||||
msg="Permanent file URL generated successfully"
|
||||
)
|
||||
|
||||
if isinstance(storage, LocalStorage):
|
||||
# For local storage, generate signed URL with expiration
|
||||
url = generate_signed_url(str(file_id), expires)
|
||||
else:
|
||||
# For remote storage (OSS/S3), get presigned URL
|
||||
url = await storage_service.get_file_url(file_key, expires=expires)
|
||||
|
||||
api_logger.info(f"Generated file URL: file_id={file_id}")
|
||||
return success(
|
||||
data={
|
||||
"url": url,
|
||||
"expires_in": expires,
|
||||
"permanent": False,
|
||||
"file_name": file_metadata.file_name,
|
||||
},
|
||||
msg="File URL generated successfully"
|
||||
)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to generate file URL: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to generate file URL: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/public/{file_id}", response_model=Any)
|
||||
async def public_download_file(
|
||||
file_id: uuid.UUID,
|
||||
expires: int = 0,
|
||||
signature: str = "",
|
||||
db: Session = Depends(get_db),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
) -> Any:
|
||||
"""
|
||||
Public file download endpoint with signature verification.
|
||||
|
||||
This endpoint allows downloading files without authentication,
|
||||
but requires a valid signature and non-expired timestamp.
|
||||
|
||||
Args:
|
||||
file_id: The UUID of the file.
|
||||
expires: Expiration timestamp.
|
||||
signature: HMAC signature for verification.
|
||||
db: Database session.
|
||||
storage_service: The file storage service.
|
||||
|
||||
Returns:
|
||||
FileResponse for the requested file.
|
||||
"""
|
||||
api_logger.info(f"Public download request: file_id={file_id}")
|
||||
|
||||
# Verify signature
|
||||
is_valid, error_msg = verify_signed_url(str(file_id), expires, signature)
|
||||
if not is_valid:
|
||||
api_logger.warning(f"Invalid signed URL: file_id={file_id}, error={error_msg}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=error_msg
|
||||
)
|
||||
|
||||
# Query file metadata from database
|
||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||
if not file_metadata:
|
||||
api_logger.warning(f"File not found in database: file_id={file_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The file does not exist"
|
||||
)
|
||||
|
||||
if file_metadata.status != "completed":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"File upload not completed, status: {file_metadata.status}"
|
||||
)
|
||||
|
||||
file_key = file_metadata.file_key
|
||||
storage = storage_service.storage
|
||||
|
||||
if isinstance(storage, LocalStorage):
|
||||
full_path = storage._get_full_path(file_key)
|
||||
|
||||
if not full_path.exists():
|
||||
api_logger.warning(f"File not found on disk: file_key={file_key}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="File not found"
|
||||
)
|
||||
|
||||
api_logger.info(f"Serving public file: file_key={file_key}")
|
||||
return FileResponse(
|
||||
path=str(full_path),
|
||||
filename=file_metadata.file_name,
|
||||
media_type=file_metadata.content_type or "application/octet-stream"
|
||||
)
|
||||
else:
|
||||
# For remote storage, redirect to presigned URL
|
||||
try:
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
|
||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve file: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/permanent/{file_id}", response_model=Any)
|
||||
async def permanent_download_file(
|
||||
file_id: uuid.UUID,
|
||||
db: Session = Depends(get_db),
|
||||
storage_service: FileStorageService = Depends(get_file_storage_service),
|
||||
) -> Any:
|
||||
"""
|
||||
Permanent file download endpoint (no expiration, no signature required).
|
||||
|
||||
This endpoint allows downloading files without authentication or expiration.
|
||||
Use with caution as URLs are permanently accessible.
|
||||
|
||||
Args:
|
||||
file_id: The UUID of the file.
|
||||
db: Database session.
|
||||
storage_service: The file storage service.
|
||||
|
||||
Returns:
|
||||
FileResponse for the requested file.
|
||||
"""
|
||||
api_logger.info(f"Permanent download request: file_id={file_id}")
|
||||
|
||||
# Query file metadata from database
|
||||
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
|
||||
if not file_metadata:
|
||||
api_logger.warning(f"File not found in database: file_id={file_id}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="The file does not exist"
|
||||
)
|
||||
|
||||
if file_metadata.status != "completed":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"File upload not completed, status: {file_metadata.status}"
|
||||
)
|
||||
|
||||
file_key = file_metadata.file_key
|
||||
storage = storage_service.storage
|
||||
|
||||
if isinstance(storage, LocalStorage):
|
||||
full_path = storage._get_full_path(file_key)
|
||||
|
||||
if not full_path.exists():
|
||||
api_logger.warning(f"File not found on disk: file_key={file_key}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="File not found"
|
||||
)
|
||||
|
||||
api_logger.info(f"Serving permanent file: file_key={file_key}")
|
||||
return FileResponse(
|
||||
path=str(full_path),
|
||||
filename=file_metadata.file_name,
|
||||
media_type=file_metadata.content_type or "application/octet-stream"
|
||||
)
|
||||
else:
|
||||
# For remote storage, redirect to presigned URL with long expiration
|
||||
try:
|
||||
# Use a very long expiration (7 days max for most cloud providers)
|
||||
presigned_url = await storage_service.get_file_url(file_key, expires=604800)
|
||||
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
|
||||
except Exception as e:
|
||||
api_logger.error(f"Failed to get presigned URL: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve file: {str(e)}"
|
||||
)
|
||||
@@ -161,9 +161,9 @@ async def get_preference_tags(
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.RESOURCE_NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请调用 /generate_profile 接口生成新画像",
|
||||
None
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
|
||||
# Extract preferences from cache
|
||||
@@ -232,9 +232,9 @@ async def get_dimension_portrait(
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.RESOURCE_NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请调用 /generate_profile 接口生成新画像",
|
||||
None
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
|
||||
# Extract portrait from cache
|
||||
@@ -280,9 +280,9 @@ async def get_interest_area_distribution(
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.RESOURCE_NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请调用 /generate_profile 接口生成新画像",
|
||||
None
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
|
||||
# Extract interest areas from cache
|
||||
@@ -332,9 +332,9 @@ async def get_behavior_habits(
|
||||
if cached_profile is None:
|
||||
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
|
||||
return fail(
|
||||
BizCode.RESOURCE_NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请调用 /generate_profile 接口生成新画像",
|
||||
None
|
||||
BizCode.NOT_FOUND,
|
||||
"画像缓存不存在或已过期,请右上角刷新生成新画像",
|
||||
""
|
||||
)
|
||||
|
||||
# Extract habits from cache
|
||||
|
||||
@@ -9,14 +9,16 @@ from app.db import get_db
|
||||
from app.dependencies import cur_workspace_access_guard, get_current_user
|
||||
from app.models import ModelApiKey
|
||||
from app.models.user_model import User
|
||||
from app.repositories import knowledge_repository
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.repositories import knowledge_repository, WorkspaceRepository
|
||||
from app.schemas.memory_agent_schema import UserInput, Write_UserInput
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.services import task_service, workspace_service
|
||||
from app.services.memory_agent_service import MemoryAgentService
|
||||
from app.services.model_service import ModelConfigService
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile
|
||||
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile,Header
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
@@ -160,9 +162,12 @@ async def write_server(
|
||||
|
||||
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
result = await memory_agent_service.write_memory(
|
||||
user_input.group_id,
|
||||
user_input.message,
|
||||
messages_list, # 传递结构化消息列表
|
||||
config_id,
|
||||
db,
|
||||
storage_type,
|
||||
@@ -219,9 +224,12 @@ async def write_server_async(
|
||||
if knowledge: user_rag_memory_id = str(knowledge.id)
|
||||
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
task = celery_app.send_task(
|
||||
"app.core.memory.agent.write_message",
|
||||
args=[user_input.group_id, user_input.message, config_id, storage_type, user_rag_memory_id]
|
||||
args=[user_input.group_id, messages_list, config_id, storage_type, user_rag_memory_id]
|
||||
)
|
||||
api_logger.info(f"Write task queued: {task.id}")
|
||||
|
||||
@@ -285,6 +293,19 @@ async def read_server(
|
||||
storage_type,
|
||||
user_rag_memory_id
|
||||
)
|
||||
if str(user_input.search_switch) == "2":
|
||||
retrieve_info = result['answer']
|
||||
history = await SessionService(store).get_history(user_input.group_id, user_input.group_id, user_input.group_id)
|
||||
query = user_input.message
|
||||
|
||||
# 调用 memory_agent_service 的方法生成最终答案
|
||||
result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
|
||||
retrieve_info=retrieve_info,
|
||||
history=history,
|
||||
query=query,
|
||||
config_id=config_id,
|
||||
db=db
|
||||
)
|
||||
return success(data=result, msg="回复对话消息成功")
|
||||
except BaseException as e:
|
||||
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
|
||||
@@ -564,8 +585,23 @@ async def status_type(
|
||||
"""
|
||||
api_logger.info(f"Status type check requested for group {user_input.group_id}")
|
||||
try:
|
||||
# 获取标准化的消息列表
|
||||
messages_list = memory_agent_service.get_messages_list(user_input)
|
||||
|
||||
# 将消息列表转换为字符串用于分类
|
||||
# 只取最后一条用户消息进行分类
|
||||
last_user_message = ""
|
||||
for msg in reversed(messages_list):
|
||||
if msg.get('role') == 'user':
|
||||
last_user_message = msg.get('content', '')
|
||||
break
|
||||
|
||||
if not last_user_message:
|
||||
# 如果没有用户消息,使用所有消息的内容
|
||||
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
|
||||
|
||||
result = await memory_agent_service.classify_message_type(
|
||||
user_input.message,
|
||||
last_user_message,
|
||||
user_input.config_id,
|
||||
db
|
||||
)
|
||||
@@ -616,8 +652,10 @@ async def get_knowledge_type_stats_api(
|
||||
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
|
||||
async def get_hot_memory_tags_by_user_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
limit: int = Query(20, description="返回标签数量限制"),
|
||||
current_user: User = Depends(get_current_user)
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session=Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取指定用户的热门记忆标签
|
||||
@@ -628,10 +666,22 @@ async def get_hot_memory_tags_by_user_api(
|
||||
...
|
||||
]
|
||||
"""
|
||||
|
||||
workspace_id=current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
|
||||
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
|
||||
try:
|
||||
result = await memory_agent_service.get_hot_memory_tags_by_user(
|
||||
end_user_id=end_user_id,
|
||||
language_type=language_type,
|
||||
model_id=model_id,
|
||||
limit=limit
|
||||
)
|
||||
return success(data=result, msg="获取热门记忆标签成功")
|
||||
@@ -647,7 +697,7 @@ async def get_user_profile_api(
|
||||
current_user: User = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
获取用户详情,包含:
|
||||
获取工作空间下Popular Memory Tags,包含:
|
||||
- name: 用户名字(直接使用 end_user_id)
|
||||
- tags: 3个用户特征标签(从语句和实体中LLM总结)
|
||||
- hot_tags: 4个热门记忆标签
|
||||
|
||||
@@ -5,7 +5,6 @@ from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.memory_agent_schema import End_User_Information
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
|
||||
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
|
||||
@@ -40,54 +39,7 @@ def get_workspace_total_end_users(
|
||||
api_logger.info(f"成功获取最新用户总数: total_num={total_end_users.get('total_num', 0)}")
|
||||
return success(data=total_end_users, msg="用户数量获取成功")
|
||||
|
||||
@router.post("/update/end_users", response_model=ApiResponse)
|
||||
async def update_workspace_end_users(
|
||||
user_input: End_User_Information,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
更新工作空间的宿主信息
|
||||
"""
|
||||
username = user_input.end_user_name # 要更新的用户名
|
||||
end_user_input_id = user_input.id # 宿主ID
|
||||
workspace_id = current_user.current_workspace_id
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的宿主信息")
|
||||
api_logger.info(f"更新参数: username={username}, end_user_id={end_user_input_id}")
|
||||
|
||||
try:
|
||||
# 导入更新函数
|
||||
from app.repositories.end_user_repository import update_end_user_other_name
|
||||
import uuid
|
||||
|
||||
# 转换 end_user_id 为 UUID 类型
|
||||
end_user_uuid = uuid.UUID(end_user_input_id)
|
||||
|
||||
# 直接更新数据库中的 other_name 字段
|
||||
updated_count = update_end_user_other_name(
|
||||
db=db,
|
||||
end_user_id=end_user_uuid,
|
||||
other_name=username
|
||||
)
|
||||
|
||||
api_logger.info(f"成功更新宿主 {end_user_input_id} 的 other_name 为: {username}")
|
||||
|
||||
return success(
|
||||
data={
|
||||
"updated_count": updated_count,
|
||||
"end_user_id": end_user_input_id,
|
||||
"updated_other_name": username
|
||||
},
|
||||
msg=f"成功更新 {updated_count} 个宿主的信息"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
api_logger.error(f"更新宿主信息失败: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"更新宿主信息失败: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
|
||||
ReflectionConfig,
|
||||
ReflectionEngine,
|
||||
ReflectionEngine, ReflectionRange, ReflectionBaseline,
|
||||
)
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
@@ -19,7 +20,7 @@ from app.services.memory_reflection_service import (
|
||||
)
|
||||
from app.services.model_service import ModelConfigService
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status,Header
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -39,9 +40,6 @@ async def save_reflection_config(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Save reflection configuration to data_comfig table"""
|
||||
|
||||
|
||||
|
||||
try:
|
||||
config_id = request.config_id
|
||||
if not config_id:
|
||||
@@ -52,51 +50,30 @@ async def save_reflection_config(
|
||||
|
||||
api_logger.info(f"用户 {current_user.username} 保存反思配置,config_id: {config_id}")
|
||||
|
||||
update_params = {
|
||||
"enable_self_reflexion": request.reflection_enabled,
|
||||
"iteration_period": request.reflection_period_in_hours,
|
||||
"reflexion_range": request.reflexion_range,
|
||||
"baseline": request.baseline,
|
||||
"reflection_model_id": request.reflection_model_id,
|
||||
"memory_verify": request.memory_verify,
|
||||
"quality_assessment": request.quality_assessment,
|
||||
}
|
||||
data_config = DataConfigRepository.update_reflection_config(
|
||||
db,
|
||||
config_id=config_id,
|
||||
enable_self_reflexion=request.reflection_enabled,
|
||||
iteration_period=request.reflection_period_in_hours,
|
||||
reflexion_range=request.reflexion_range,
|
||||
baseline=request.baseline,
|
||||
reflection_model_id=request.reflection_model_id,
|
||||
memory_verify=request.memory_verify,
|
||||
quality_assessment=request.quality_assessment
|
||||
)
|
||||
|
||||
|
||||
|
||||
query, params = DataConfigRepository.build_update_reflection(config_id, **update_params)
|
||||
|
||||
result = db.execute(text(query), params)
|
||||
if result.rowcount == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"未找到config_id为 {config_id} 的配置"
|
||||
)
|
||||
|
||||
db.commit()
|
||||
|
||||
# 查询更新后的配置
|
||||
select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
|
||||
result = db.execute(text(select_query), select_params).fetchone()
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"更新后未找到config_id为 {config_id} 的配置"
|
||||
)
|
||||
|
||||
api_logger.info(f"成功保存反思配置到数据库,config_id: {config_id}")
|
||||
db.refresh(data_config)
|
||||
|
||||
reflection_result={
|
||||
"config_id": result.config_id,
|
||||
"enable_self_reflexion": result.enable_self_reflexion,
|
||||
"iteration_period": result.iteration_period,
|
||||
"reflexion_range": result.reflexion_range,
|
||||
"baseline": result.baseline,
|
||||
"reflection_model_id": result.reflection_model_id,
|
||||
"memory_verify": result.memory_verify,
|
||||
"quality_assessment": result.quality_assessment,
|
||||
"user_id": result.user_id}
|
||||
"config_id": data_config.config_id,
|
||||
"enable_self_reflexion": data_config.enable_self_reflexion,
|
||||
"iteration_period": data_config.iteration_period,
|
||||
"reflexion_range": data_config.reflexion_range,
|
||||
"baseline": data_config.baseline,
|
||||
"reflection_model_id": data_config.reflection_model_id,
|
||||
"memory_verify": data_config.memory_verify,
|
||||
"quality_assessment": data_config.quality_assessment}
|
||||
|
||||
return success(data=reflection_result, msg="反思配置成功")
|
||||
|
||||
@@ -116,9 +93,8 @@ async def save_reflection_config(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/reflection")
|
||||
@router.get("/reflection")
|
||||
async def start_workspace_reflection(
|
||||
config_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
@@ -143,11 +119,20 @@ async def start_workspace_reflection(
|
||||
end_users = data['end_users']
|
||||
|
||||
for base, config, user in zip(releases, data_configs, end_users):
|
||||
if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']:
|
||||
# 安全地转换为整数,处理空字符串和None的情况
|
||||
print(base['config'])
|
||||
try:
|
||||
base_config = int(base['config']) if base['config'] else 0
|
||||
config_id = int(config['config_id']) if config['config_id'] else 0
|
||||
except (ValueError, TypeError):
|
||||
api_logger.warning(f"无效的配置ID: base['config']={base.get('config')}, config['config_id']={config.get('config_id')}")
|
||||
continue
|
||||
|
||||
if base_config == config_id and base['app_id'] == user['app_id']:
|
||||
# 调用反思服务
|
||||
api_logger.info(f"为用户 {user['id']} 启动反思,config_id: {config['config_id']}")
|
||||
|
||||
reflection_result = await reflection_service.start_reflection_from_data(
|
||||
reflection_result = await reflection_service.start_text_reflection(
|
||||
config_data=config,
|
||||
end_user_id=user['id']
|
||||
)
|
||||
@@ -178,17 +163,7 @@ async def start_reflection_configs(
|
||||
"""通过config_id查询data_config表中的反思配置信息"""
|
||||
try:
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
|
||||
# 使用DataConfigRepository查询反思配置
|
||||
select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
|
||||
result = db.execute(text(select_query), select_params).fetchone()
|
||||
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"未找到config_id为 {config_id} 的配置"
|
||||
)
|
||||
|
||||
result = DataConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
# 构建返回数据
|
||||
reflection_config = {
|
||||
"config_id": result.config_id,
|
||||
@@ -198,8 +173,7 @@ async def start_reflection_configs(
|
||||
"baseline": result.baseline,
|
||||
"reflection_model_id": result.reflection_model_id,
|
||||
"memory_verify": result.memory_verify,
|
||||
"quality_assessment": result.quality_assessment,
|
||||
"user_id": result.user_id
|
||||
"quality_assessment": result.quality_assessment
|
||||
}
|
||||
api_logger.info(f"成功查询反思配置,config_id: {config_id}")
|
||||
return success(data=reflection_config, msg="反思配置查询成功")
|
||||
@@ -218,7 +192,7 @@ async def start_reflection_configs(
|
||||
@router.get("/reflection/run")
|
||||
async def reflection_run(
|
||||
config_id: int,
|
||||
language_type: str = "zh",
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
@@ -227,9 +201,7 @@ async def reflection_run(
|
||||
api_logger.info(f"用户 {current_user.username} 查询反思配置,config_id: {config_id}")
|
||||
|
||||
# 使用DataConfigRepository查询反思配置
|
||||
select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
|
||||
result = db.execute(text(select_query), select_params).fetchone()
|
||||
|
||||
result = DataConfigRepository.query_reflection_config_by_id(db, config_id)
|
||||
if not result:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
@@ -242,7 +214,7 @@ async def reflection_run(
|
||||
model_id = result.reflection_model_id
|
||||
if model_id:
|
||||
try:
|
||||
ModelConfigService.get_model_by_id(db=db, model_id=model_id)
|
||||
ModelConfigService.get_model_by_id(db=db, model_id=uuid.UUID(model_id))
|
||||
api_logger.info(f"模型ID验证成功: {model_id}")
|
||||
except Exception as e:
|
||||
api_logger.warning(f"模型ID '{model_id}' 不存在,将使用默认模型: {str(e)}")
|
||||
@@ -252,8 +224,8 @@ async def reflection_run(
|
||||
config = ReflectionConfig(
|
||||
enabled=result.enable_self_reflexion,
|
||||
iteration_period=result.iteration_period,
|
||||
reflexion_range=result.reflexion_range,
|
||||
baseline=result.baseline,
|
||||
reflexion_range=ReflectionRange(result.reflexion_range),
|
||||
baseline=ReflectionBaseline(result.baseline),
|
||||
output_example='',
|
||||
memory_verify=result.memory_verify,
|
||||
quality_assessment=result.quality_assessment,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, HTTPException, status,Header
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
@@ -20,6 +20,7 @@ router = APIRouter(
|
||||
@router.get("/short_term")
|
||||
async def short_term_configs(
|
||||
end_user_id: str,
|
||||
language_type:str = Header(default="zh", alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import os
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.memory.utils.self_reflexion_utils import self_reflexion
|
||||
from app.core.response_utils import fail, success
|
||||
from app.db import get_db
|
||||
from app.dependencies import get_current_user
|
||||
@@ -30,7 +28,6 @@ from app.services.memory_storage_service import (
|
||||
search_dialogue,
|
||||
search_edges,
|
||||
search_entity,
|
||||
search_entity_graph,
|
||||
search_statement,
|
||||
)
|
||||
from fastapi import APIRouter, Depends
|
||||
@@ -414,21 +411,7 @@ async def search_entity_edges(
|
||||
api_logger.error(f"Search edges failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
|
||||
|
||||
@router.get("/search/entity_graph", response_model=ApiResponse)
|
||||
async def search_for_entity_graph(
|
||||
end_user_id: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_user),
|
||||
) -> dict:
|
||||
"""
|
||||
搜索所有实体之间的关系网络
|
||||
"""
|
||||
api_logger.info(f"Search entity graph requested for end_user_id: {end_user_id}")
|
||||
try:
|
||||
result = await search_entity_graph(end_user_id)
|
||||
return success(data=result, msg="查询成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Search entity graph failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "实体图查询失败", str(e))
|
||||
|
||||
|
||||
|
||||
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
|
||||
@@ -458,18 +441,3 @@ async def get_recent_activity_stats_api(
|
||||
api_logger.error(f"Recent activity stats failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("/self_reflexion")
|
||||
async def self_reflexion_endpoint(host_id: uuid.UUID) -> str:
|
||||
"""
|
||||
自我反思接口,自动对检索出的信息进行自我反思并返回自我反思结果。
|
||||
|
||||
Args:
|
||||
None
|
||||
Returns:
|
||||
自我反思结果。
|
||||
"""
|
||||
return await self_reflexion(host_id)
|
||||
|
||||
|
||||
@@ -8,9 +8,10 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.response_utils import success
|
||||
from app.db import get_db
|
||||
from app.db import get_db, get_db_read
|
||||
from app.dependencies import get_share_user_id, ShareTokenData
|
||||
from app.repositories import knowledge_repository
|
||||
from app.repositories.workflow_repository import WorkflowConfigRepository
|
||||
from app.schemas import release_share_schema, conversation_schema
|
||||
from app.schemas.response_schema import PageData, PageMeta
|
||||
from app.services import workspace_service
|
||||
@@ -19,7 +20,8 @@ from app.services.conversation_service import ConversationService
|
||||
from app.services.release_share_service import ReleaseShareService
|
||||
from app.services.shared_chat_service import SharedChatService
|
||||
from app.services.app_chat_service import AppChatService, get_app_chat_service
|
||||
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, \
|
||||
agent_config_4_app_release, multi_agent_config_4_app_release
|
||||
|
||||
router = APIRouter(prefix="/public/share", tags=["Public Share"])
|
||||
logger = get_business_logger()
|
||||
@@ -65,10 +67,10 @@ def get_or_generate_user_id(payload_user_id: str, request: Request) -> str:
|
||||
summary="获取访问 token"
|
||||
)
|
||||
def get_access_token(
|
||||
share_token: str,
|
||||
payload: release_share_schema.TokenRequest,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
share_token: str,
|
||||
payload: release_share_schema.TokenRequest,
|
||||
request: Request,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取访问 token
|
||||
|
||||
@@ -113,9 +115,9 @@ def get_access_token(
|
||||
response_model=None
|
||||
)
|
||||
def get_shared_release(
|
||||
password: str = Query(None, description="访问密码(如果需要)"),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
password: str = Query(None, description="访问密码(如果需要)"),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取公开分享的发布版本信息
|
||||
|
||||
@@ -137,9 +139,9 @@ def get_shared_release(
|
||||
summary="验证访问密码"
|
||||
)
|
||||
def verify_password(
|
||||
payload: release_share_schema.PasswordVerifyRequest,
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
payload: release_share_schema.PasswordVerifyRequest,
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""验证分享的访问密码
|
||||
|
||||
@@ -159,11 +161,11 @@ def verify_password(
|
||||
summary="获取嵌入代码"
|
||||
)
|
||||
def get_embed_code(
|
||||
width: str = Query("100%", description="iframe 宽度"),
|
||||
height: str = Query("600px", description="iframe 高度"),
|
||||
request: Request = None,
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
width: str = Query("100%", description="iframe 宽度"),
|
||||
height: str = Query("600px", description="iframe 高度"),
|
||||
request: Request = None,
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取嵌入代码
|
||||
|
||||
@@ -183,7 +185,6 @@ def get_embed_code(
|
||||
return success(data=embed_code)
|
||||
|
||||
|
||||
|
||||
# ---------- 会话管理接口 ----------
|
||||
|
||||
@router.get(
|
||||
@@ -191,11 +192,11 @@ def get_embed_code(
|
||||
summary="获取会话列表"
|
||||
)
|
||||
def list_conversations(
|
||||
password: str = Query(None, description="访问密码"),
|
||||
page: int = Query(1, ge=1),
|
||||
pagesize: int = Query(20, ge=1, le=100),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
password: str = Query(None, description="访问密码"),
|
||||
page: int = Query(1, ge=1),
|
||||
pagesize: int = Query(20, ge=1, le=100),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取分享应用的会话列表
|
||||
|
||||
@@ -209,9 +210,9 @@ def list_conversations(
|
||||
from app.repositories.end_user_repository import EndUserRepository
|
||||
end_user_repo = EndUserRepository(db)
|
||||
new_end_user = end_user_repo.get_or_create_end_user(
|
||||
app_id=share.app_id,
|
||||
other_id=other_id
|
||||
)
|
||||
app_id=share.app_id,
|
||||
other_id=other_id
|
||||
)
|
||||
logger.debug(new_end_user.id)
|
||||
service = SharedChatService(db)
|
||||
conversations, total = service.list_conversations(
|
||||
@@ -233,10 +234,10 @@ def list_conversations(
|
||||
summary="获取会话详情(含消息)"
|
||||
)
|
||||
def get_conversation(
|
||||
conversation_id: uuid.UUID,
|
||||
password: str = Query(None, description="访问密码"),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
conversation_id: uuid.UUID,
|
||||
password: str = Query(None, description="访问密码"),
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""获取会话详情和消息历史"""
|
||||
chat_service = SharedChatService(db)
|
||||
@@ -266,10 +267,10 @@ def get_conversation(
|
||||
summary="发送消息(支持流式和非流式)"
|
||||
)
|
||||
async def chat(
|
||||
payload: conversation_schema.ChatRequest,
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
|
||||
payload: conversation_schema.ChatRequest,
|
||||
share_data: ShareTokenData = Depends(get_share_user_id),
|
||||
db: Session = Depends(get_db),
|
||||
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
|
||||
):
|
||||
"""发送消息并获取回复
|
||||
|
||||
@@ -313,7 +314,7 @@ async def chat(
|
||||
)
|
||||
end_user_id = str(new_end_user.id)
|
||||
|
||||
appid=share.app_id
|
||||
appid = share.app_id
|
||||
"""获取存储类型和工作空间的ID"""
|
||||
|
||||
# 直接通过 SQLAlchemy 查询 app
|
||||
@@ -425,16 +426,16 @@ async def chat(
|
||||
# )
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.agnet_chat_stream(
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id= str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
web_search=payload.web_search,
|
||||
config=agent_config,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
web_search=payload.web_search,
|
||||
config=agent_config,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
workspace_id=workspace_id
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -481,15 +482,15 @@ async def chat(
|
||||
async def event_generator():
|
||||
async for event in app_chat_service.multi_agent_chat_stream(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=str(new_end_user.id), # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id
|
||||
):
|
||||
yield event
|
||||
|
||||
@@ -561,24 +562,27 @@ async def chat(
|
||||
|
||||
# return success(data=conversation_schema.ChatResponse(**result))
|
||||
elif app_type == AppType.WORKFLOW:
|
||||
|
||||
config = workflow_config_4_app_release(release)
|
||||
if not config.id:
|
||||
with get_db_read() as db:
|
||||
source_config = WorkflowConfigRepository(db).get_by_app_id(release.app_id)
|
||||
config.id = source_config.id
|
||||
config.id = uuid.UUID(config.id)
|
||||
if payload.stream:
|
||||
async def event_generator():
|
||||
|
||||
async for event in app_chat_service.workflow_chat_stream(
|
||||
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=release.app_id,
|
||||
workspace_id=workspace_id
|
||||
message=payload.message,
|
||||
conversation_id=conversation.id, # 使用已创建的会话 ID
|
||||
user_id=end_user_id, # 转换为字符串
|
||||
variables=payload.variables,
|
||||
config=config,
|
||||
web_search=payload.web_search,
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=release.app_id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=release.id
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
@@ -610,7 +614,8 @@ async def chat(
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=release.app_id,
|
||||
workspace_id=workspace_id
|
||||
workspace_id=workspace_id,
|
||||
release_id=release.id
|
||||
)
|
||||
logger.debug(
|
||||
"工作流试运行返回结果",
|
||||
|
||||
@@ -242,8 +242,9 @@ async def chat(
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.app_id,
|
||||
workspace_id=workspace_id
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id,
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
@@ -274,8 +275,9 @@ async def chat(
|
||||
memory=payload.memory,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
app_id=app.app_id,
|
||||
workspace_id=workspace_id
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id
|
||||
)
|
||||
logger.debug(
|
||||
"工作流试运行返回结果",
|
||||
|
||||
@@ -5,13 +5,14 @@
|
||||
from typing import Optional
|
||||
import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends,Header
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_api_logger
|
||||
from app.core.response_utils import success, fail
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.api_key_utils import timestamp_to_datetime
|
||||
from app.services.memory_base_service import Translation_English
|
||||
from app.services.user_memory_service import (
|
||||
UserMemoryService,
|
||||
analytics_memory_types,
|
||||
@@ -20,7 +21,7 @@ from app.services.user_memory_service import (
|
||||
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
from app.schemas.memory_storage_schema import GenerateCacheRequest
|
||||
|
||||
from app.repositories.workspace_repository import WorkspaceRepository
|
||||
from app.schemas.end_user_schema import (
|
||||
EndUserProfileResponse,
|
||||
EndUserProfileUpdate,
|
||||
@@ -44,6 +45,7 @@ router = APIRouter(
|
||||
@router.get("/analytics/memory_insight/report", response_model=ApiResponse)
|
||||
async def get_memory_insight_report_api(
|
||||
end_user_id: str,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
@@ -53,10 +55,18 @@ async def get_memory_insight_report_api(
|
||||
此接口仅查询数据库中已缓存的记忆洞察数据,不执行生成操作。
|
||||
如需生成新的洞察报告,请使用专门的生成接口。
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
api_logger.info(f"记忆洞察报告查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_memory_insight(db, end_user_id)
|
||||
result = await user_memory_service.get_cached_memory_insight(db, end_user_id,model_id,language_type)
|
||||
|
||||
if result["is_cached"]:
|
||||
api_logger.info(f"成功返回缓存的记忆洞察报告: end_user_id={end_user_id}")
|
||||
@@ -72,6 +82,7 @@ async def get_memory_insight_report_api(
|
||||
@router.get("/analytics/user_summary", response_model=ApiResponse)
|
||||
async def get_user_summary_api(
|
||||
end_user_id: str,
|
||||
language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
@@ -81,10 +92,18 @@ async def get_user_summary_api(
|
||||
此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。
|
||||
如需生成新的用户摘要,请使用专门的生成接口。
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
|
||||
try:
|
||||
# 调用服务层获取缓存数据
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id)
|
||||
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language_type)
|
||||
|
||||
if result["is_cached"]:
|
||||
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
|
||||
@@ -253,7 +272,6 @@ async def get_graph_data_api(
|
||||
depth=depth,
|
||||
center_node_id=center_node_id
|
||||
)
|
||||
|
||||
# 检查是否有错误消息
|
||||
if "message" in result and result["statistics"]["total_nodes"] == 0:
|
||||
api_logger.warning(f"图数据查询返回空结果: {result.get('message')}")
|
||||
@@ -278,7 +296,13 @@ async def get_end_user_profile(
|
||||
db: Session = Depends(get_db),
|
||||
) -> dict:
|
||||
workspace_id = current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
# 检查用户是否已选择工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间")
|
||||
@@ -296,7 +320,6 @@ async def get_end_user_profile(
|
||||
if not end_user:
|
||||
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
|
||||
|
||||
# 构建响应数据
|
||||
profile_data = EndUserProfileResponse(
|
||||
id=end_user.id,
|
||||
@@ -328,12 +351,11 @@ async def update_end_user_profile(
|
||||
|
||||
该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。
|
||||
所有字段都是可选的,只更新提供的字段。
|
||||
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
end_user_id = profile_update.end_user_id
|
||||
|
||||
# 检查用户是否已选择工作空间
|
||||
# 验证工作空间
|
||||
if workspace_id is None:
|
||||
api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
|
||||
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
|
||||
@@ -343,65 +365,41 @@ async def update_end_user_profile(
|
||||
f"workspace={workspace_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 查询终端用户
|
||||
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
|
||||
# 调用 Service 层处理业务逻辑
|
||||
result = user_memory_service.update_end_user_profile(db, end_user_id, profile_update)
|
||||
|
||||
if not end_user:
|
||||
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
|
||||
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
|
||||
|
||||
# 更新字段(只更新提供的字段,排除 end_user_id)
|
||||
# 允许 None 值来重置字段(如 hire_date)
|
||||
update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
|
||||
|
||||
# 特殊处理 hire_date:如果提供了时间戳,转换为 DateTime
|
||||
if 'hire_date' in update_data:
|
||||
hire_date_timestamp = update_data['hire_date']
|
||||
if hire_date_timestamp is not None:
|
||||
update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp)
|
||||
# 如果是 None,保持 None(允许清空)
|
||||
|
||||
for field, value in update_data.items():
|
||||
setattr(end_user, field, value)
|
||||
|
||||
# 更新 updated_at 时间戳
|
||||
end_user.updated_at = datetime.datetime.now()
|
||||
|
||||
# 更新 updatetime_profile 为当前时间
|
||||
end_user.updatetime_profile = datetime.datetime.now()
|
||||
|
||||
# 提交更改
|
||||
db.commit()
|
||||
db.refresh(end_user)
|
||||
|
||||
# 构建响应数据
|
||||
profile_data = EndUserProfileResponse(
|
||||
id=end_user.id,
|
||||
other_name=end_user.other_name,
|
||||
position=end_user.position,
|
||||
department=end_user.department,
|
||||
contact=end_user.contact,
|
||||
phone=end_user.phone,
|
||||
hire_date=end_user.hire_date,
|
||||
updatetime_profile=end_user.updatetime_profile
|
||||
)
|
||||
|
||||
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}")
|
||||
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="更新成功")
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e))
|
||||
if result["success"]:
|
||||
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}")
|
||||
return success(data=result["data"], msg="更新成功")
|
||||
else:
|
||||
error_msg = result["error"]
|
||||
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
|
||||
|
||||
# 根据错误类型映射到合适的业务错误码
|
||||
if error_msg == "终端用户不存在":
|
||||
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg)
|
||||
elif error_msg == "无效的用户ID格式":
|
||||
return fail(BizCode.INVALID_USER_ID, "无效的用户ID格式", error_msg)
|
||||
else:
|
||||
# 只有未预期的错误才使用 INTERNAL_ERROR
|
||||
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
|
||||
|
||||
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
|
||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,
|
||||
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default="zh", alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
workspace_id=current_user.current_workspace_id
|
||||
workspace_repo = WorkspaceRepository(db)
|
||||
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
|
||||
|
||||
if workspace_models:
|
||||
model_id = workspace_models.get("llm", None)
|
||||
else:
|
||||
model_id = None
|
||||
MemoryEntity = MemoryEntityService(id, label)
|
||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server()
|
||||
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language_type)
|
||||
|
||||
return success(data=timeline_memories_result, msg="共同记忆时间线")
|
||||
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
|
||||
async def memory_space_relationship_evolution(id: str, label: str,
|
||||
|
||||
@@ -145,44 +145,98 @@ class LangChainAgent:
|
||||
messages.append(HumanMessage(content=user_content))
|
||||
|
||||
return messages
|
||||
async def term_memory_save(self,messages,end_user_end,aimessages):
|
||||
'''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j'''
|
||||
end_user_end=f"Term_{end_user_end}"
|
||||
print(messages)
|
||||
print(aimessages)
|
||||
session_id = store.save_session(
|
||||
userid=end_user_end,
|
||||
messages=messages,
|
||||
apply_id=end_user_end,
|
||||
group_id=end_user_end,
|
||||
aimessages=aimessages
|
||||
)
|
||||
store.delete_duplicate_sessions()
|
||||
# logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
||||
return session_id
|
||||
async def term_memory_redis_read(self,end_user_end):
|
||||
end_user_end = f"Term_{end_user_end}"
|
||||
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
||||
# logger.info(f'Redis_Agent:{end_user_end};{history}')
|
||||
messagss_list=[]
|
||||
retrieved_content=[]
|
||||
for messages in history:
|
||||
query = messages.get("Query")
|
||||
aimessages = messages.get("Answer")
|
||||
messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
||||
retrieved_content.append({query: aimessages})
|
||||
return messagss_list,retrieved_content
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# async def term_memory_save(self,messages,end_user_end,aimessages):
|
||||
# '''短长期存储redis,为不影响正常使用6句一段话,存储用户名加一个前缀,当数据存够6条返回给neo4j'''
|
||||
# end_user_end=f"Term_{end_user_end}"
|
||||
# print(messages)
|
||||
# print(aimessages)
|
||||
# session_id = store.save_session(
|
||||
# userid=end_user_end,
|
||||
# messages=messages,
|
||||
# apply_id=end_user_end,
|
||||
# group_id=end_user_end,
|
||||
# aimessages=aimessages
|
||||
# )
|
||||
# store.delete_duplicate_sessions()
|
||||
# # logger.info(f'Redis_Agent:{end_user_end};{session_id}')
|
||||
# return session_id
|
||||
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# async def term_memory_redis_read(self,end_user_end):
|
||||
# end_user_end = f"Term_{end_user_end}"
|
||||
# history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
|
||||
# # logger.info(f'Redis_Agent:{end_user_end};{history}')
|
||||
# messagss_list=[]
|
||||
# retrieved_content=[]
|
||||
# for messages in history:
|
||||
# query = messages.get("Query")
|
||||
# aimessages = messages.get("Answer")
|
||||
# messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
|
||||
# retrieved_content.append({query: aimessages})
|
||||
# return messagss_list,retrieved_content
|
||||
|
||||
|
||||
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id):
|
||||
async def write(self, storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id, actual_config_id):
|
||||
"""
|
||||
写入记忆(支持结构化消息)
|
||||
|
||||
Args:
|
||||
storage_type: 存储类型 (neo4j/rag)
|
||||
end_user_id: 终端用户ID
|
||||
user_message: 用户消息内容
|
||||
ai_message: AI 回复内容
|
||||
user_rag_memory_id: RAG 记忆ID
|
||||
actual_end_user_id: 实际用户ID
|
||||
actual_config_id: 配置ID
|
||||
|
||||
逻辑说明:
|
||||
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
|
||||
- Neo4j 模式:使用结构化消息列表
|
||||
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
|
||||
2. 如果只有 user_message:创建单条用户消息 [user](用于历史记忆场景)
|
||||
3. 每条消息会被转换为独立的 Chunk,保留 speaker 字段
|
||||
"""
|
||||
if storage_type == "rag":
|
||||
await write_rag(end_user_id, message, user_rag_memory_id)
|
||||
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
|
||||
combined_message = f"user: {user_message}\nassistant: {ai_message}"
|
||||
await write_rag(end_user_id, combined_message, user_rag_memory_id)
|
||||
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
|
||||
else:
|
||||
write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type,
|
||||
user_rag_memory_id)
|
||||
# Neo4j 模式:使用结构化消息列表
|
||||
structured_messages = []
|
||||
|
||||
# 始终添加用户消息(如果不为空)
|
||||
if user_message:
|
||||
structured_messages.append({"role": "user", "content": user_message})
|
||||
|
||||
# 只有当 AI 回复不为空时才添加 assistant 消息
|
||||
if ai_message:
|
||||
structured_messages.append({"role": "assistant", "content": ai_message})
|
||||
|
||||
# 如果没有消息,直接返回
|
||||
if not structured_messages:
|
||||
logger.warning(f"No messages to write for user {actual_end_user_id}")
|
||||
return
|
||||
|
||||
# 调用 Celery 任务,传递结构化消息列表
|
||||
# 数据流:
|
||||
# 1. structured_messages 传递给 write_message_task
|
||||
# 2. write_message_task 调用 memory_agent_service.write_memory
|
||||
# 3. write_memory 调用 write_tools.write,传递 messages 参数
|
||||
# 4. write_tools.write 调用 get_chunked_dialogs,传递 messages 参数
|
||||
# 5. get_chunked_dialogs 为每条消息创建独立的 Chunk,设置 speaker 字段
|
||||
# 6. 每个 Chunk 保存到 Neo4j,包含 speaker 字段
|
||||
logger.info(f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
|
||||
write_id = write_message_task.delay(
|
||||
actual_end_user_id, # group_id: 用户ID
|
||||
structured_messages, # message: 结构化消息列表 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
|
||||
actual_config_id, # config_id: 配置ID
|
||||
storage_type, # storage_type: "neo4j"
|
||||
user_rag_memory_id # user_rag_memory_id: RAG记忆ID(Neo4j模式下不使用)
|
||||
)
|
||||
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
|
||||
write_status = get_task_memory_write_result(str(write_id))
|
||||
logger.info(f'Agent:{actual_end_user_id};{write_status}')
|
||||
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
|
||||
|
||||
async def chat(
|
||||
self,
|
||||
@@ -227,29 +281,30 @@ class LangChainAgent:
|
||||
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
|
||||
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
|
||||
# # TODO 乐力齐,在长短期记忆存储的时候再使用此代码
|
||||
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
# history_term_memory = history_term_memory_result[0]
|
||||
# db_for_memory = next(get_db())
|
||||
# if memory_flag:
|
||||
# if len(history_term_memory)>=4 and storage_type != "rag":
|
||||
# history_term_memory = ';'.join(history_term_memory)
|
||||
# retrieved_content = history_term_memory_result[1]
|
||||
# print(retrieved_content)
|
||||
# # 为长期记忆操作获取新的数据库连接
|
||||
# try:
|
||||
# repo = LongTermMemoryRepository(db_for_memory)
|
||||
# repo.upsert(end_user_id, retrieved_content)
|
||||
# logger.info(
|
||||
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to write to LongTermMemory: {e}")
|
||||
# raise
|
||||
# finally:
|
||||
# db_for_memory.close()
|
||||
|
||||
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
history_term_memory = history_term_memory_result[0]
|
||||
db_for_memory = next(get_db())
|
||||
if memory_flag:
|
||||
if len(history_term_memory)>=4 and storage_type != "rag":
|
||||
history_term_memory = ';'.join(history_term_memory)
|
||||
retrieved_content = history_term_memory_result[1]
|
||||
print(retrieved_content)
|
||||
# 为长期记忆操作获取新的数据库连接
|
||||
try:
|
||||
repo = LongTermMemoryRepository(db_for_memory)
|
||||
repo.upsert(end_user_id, retrieved_content)
|
||||
logger.info(
|
||||
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write to LongTermMemory: {e}")
|
||||
raise
|
||||
finally:
|
||||
db_for_memory.close()
|
||||
|
||||
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
|
||||
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
|
||||
# # 长期记忆写入(
|
||||
# await self.write(storage_type, actual_end_user_id, history_term_memory, "", user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||
# # 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表
|
||||
messages = self._prepare_messages(message, history, context)
|
||||
@@ -277,8 +332,10 @@ class LangChainAgent:
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
if memory_flag:
|
||||
await self.write(storage_type,end_user_id,content,user_rag_memory_id,actual_end_user_id,content,actual_config_id)
|
||||
await self.term_memory_save(message_chat,end_user_id,content)
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
await self.write(storage_type, actual_end_user_id, message_chat, content, user_rag_memory_id, actual_end_user_id, actual_config_id)
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# await self.term_memory_save(message_chat, end_user_id, content)
|
||||
response = {
|
||||
"content": content,
|
||||
"model": self.model_name,
|
||||
@@ -346,27 +403,27 @@ class LangChainAgent:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get db session: {e}")
|
||||
# # TODO 乐力齐
|
||||
# history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
# history_term_memory = history_term_memory_result[0]
|
||||
# if memory_flag:
|
||||
# if len(history_term_memory) >= 4 and storage_type != "rag":
|
||||
# history_term_memory = ';'.join(history_term_memory)
|
||||
# retrieved_content = history_term_memory_result[1]
|
||||
# db_for_memory = next(get_db())
|
||||
# try:
|
||||
# repo = LongTermMemoryRepository(db_for_memory)
|
||||
# repo.upsert(end_user_id, retrieved_content)
|
||||
# logger.info(
|
||||
# f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
# # 长期记忆写入
|
||||
# await self.write(storage_type, end_user_id, history_term_memory, "", user_rag_memory_id, end_user_id, actual_config_id)
|
||||
# except Exception as e:
|
||||
# logger.error(f"Failed to write to long term memory: {e}")
|
||||
# finally:
|
||||
# db_for_memory.close()
|
||||
|
||||
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
|
||||
history_term_memory = history_term_memory_result[0]
|
||||
if memory_flag:
|
||||
if len(history_term_memory) >= 4 and storage_type != "rag":
|
||||
history_term_memory = ';'.join(history_term_memory)
|
||||
retrieved_content = history_term_memory_result[1]
|
||||
db_for_memory = next(get_db())
|
||||
try:
|
||||
repo = LongTermMemoryRepository(db_for_memory)
|
||||
repo.upsert(end_user_id, retrieved_content)
|
||||
logger.info(
|
||||
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
|
||||
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
|
||||
history_term_memory, actual_config_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write to long term memory: {e}")
|
||||
finally:
|
||||
db_for_memory.close()
|
||||
|
||||
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
|
||||
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
|
||||
try:
|
||||
# 准备消息列表
|
||||
messages = self._prepare_messages(message, history, context)
|
||||
@@ -418,8 +475,10 @@ class LangChainAgent:
|
||||
|
||||
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
|
||||
if memory_flag:
|
||||
await self.write(storage_type, end_user_id,full_content, user_rag_memory_id, end_user_id,full_content, actual_config_id)
|
||||
await self.term_memory_save(message_chat, end_user_id, full_content)
|
||||
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
|
||||
await self.write(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, end_user_id, actual_config_id)
|
||||
# TODO 乐力齐 - 累积多组对话批量写入功能已禁用
|
||||
# await self.term_memory_save(message_chat, end_user_id, full_content)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
|
||||
|
||||
@@ -38,6 +38,7 @@ class Settings:
|
||||
REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379"))
|
||||
REDIS_DB: int = int(os.getenv("REDIS_DB", "1"))
|
||||
REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "")
|
||||
|
||||
|
||||
# ElasticSearch configuration
|
||||
ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1")
|
||||
@@ -75,6 +76,22 @@ class Settings:
|
||||
# File Upload
|
||||
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
|
||||
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
|
||||
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
|
||||
|
||||
# Storage Configuration
|
||||
STORAGE_TYPE: str = os.getenv("STORAGE_TYPE", "local")
|
||||
|
||||
# Aliyun OSS Configuration
|
||||
OSS_ENDPOINT: str = os.getenv("OSS_ENDPOINT", "")
|
||||
OSS_ACCESS_KEY_ID: str = os.getenv("OSS_ACCESS_KEY_ID", "")
|
||||
OSS_ACCESS_KEY_SECRET: str = os.getenv("OSS_ACCESS_KEY_SECRET", "")
|
||||
OSS_BUCKET_NAME: str = os.getenv("OSS_BUCKET_NAME", "")
|
||||
|
||||
# AWS S3 Configuration
|
||||
S3_REGION: str = os.getenv("S3_REGION", "")
|
||||
S3_ACCESS_KEY_ID: str = os.getenv("S3_ACCESS_KEY_ID", "")
|
||||
S3_SECRET_ACCESS_KEY: str = os.getenv("S3_SECRET_ACCESS_KEY", "")
|
||||
S3_BUCKET_NAME: str = os.getenv("S3_BUCKET_NAME", "")
|
||||
|
||||
# VOLC ASR settings
|
||||
VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "")
|
||||
@@ -146,6 +163,7 @@ class Settings:
|
||||
# Celery configuration (internal)
|
||||
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
|
||||
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
|
||||
|
||||
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
"""
|
||||
LangGraph Graph package for memory agent.
|
||||
|
||||
This package provides the LangGraph workflow orchestrator with modular
|
||||
node implementations, routing logic, and state management.
|
||||
|
||||
Package structure:
|
||||
- read_graph: Main graph factory for read operations
|
||||
- write_graph: Main graph factory for write operations
|
||||
- nodes: LangGraph node implementations
|
||||
- routing: State routing logic
|
||||
- state: State management utilities
|
||||
"""
|
||||
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
|
||||
|
||||
__all__ = ['make_read_graph']
|
||||
@@ -4,7 +4,7 @@ LangGraph node implementations.
|
||||
This module contains custom node implementations for the LangGraph workflow.
|
||||
"""
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode
|
||||
from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message
|
||||
|
||||
__all__ = ["ToolExecutionNode", "create_input_message"]
|
||||
# from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode
|
||||
# from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message
|
||||
#
|
||||
# __all__ = ["ToolExecutionNode", "create_input_message"]
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
|
||||
|
||||
|
||||
def content_input_node(state: ReadState) -> ReadState:
|
||||
"""开始节点 - 提取内容并保持状态信息"""
|
||||
|
||||
content = state['messages'][0].content if state.get('messages') else ''
|
||||
# 返回内容并保持所有状态信息
|
||||
return {"data": content}
|
||||
|
||||
def content_input_write(state: WriteState) -> WriteState:
|
||||
"""开始节点 - 提取内容并保持状态信息"""
|
||||
|
||||
content = state['messages'][0].content if state.get('messages') else ''
|
||||
# 返回内容并保持所有状态信息
|
||||
return {"data": content}
|
||||
@@ -1,150 +0,0 @@
|
||||
"""
|
||||
Input node for LangGraph workflow entry point.
|
||||
|
||||
This module provides the create_input_message function which processes initial
|
||||
user input with multimodal support and creates the first tool call message.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict
|
||||
|
||||
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def create_input_message(
|
||||
state: Dict[str, Any],
|
||||
tool_name: str,
|
||||
session_id: str,
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
multimodal_processor: MultimodalProcessor,
|
||||
memory_config: MemoryConfig,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create initial tool call message from user input.
|
||||
|
||||
This function:
|
||||
1. Extracts the last message content from state
|
||||
2. Processes multimodal inputs (images/audio) using the multimodal processor
|
||||
3. Generates a unique message ID
|
||||
4. Extracts namespace from session_id
|
||||
5. Handles verified_data extraction for backward compatibility
|
||||
6. Returns AIMessage with complete tool_calls structure
|
||||
|
||||
Args:
|
||||
state: LangGraph state dictionary containing messages
|
||||
tool_name: Name of the tool to invoke (typically "Split_The_Problem")
|
||||
session_id: Session identifier (format: "call_id_{namespace}")
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
multimodal_processor: Processor for handling image/audio inputs
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
|
||||
Returns:
|
||||
State update with AIMessage containing tool_call
|
||||
|
||||
Examples:
|
||||
>>> state = {"messages": [HumanMessage(content="What is AI?")]}
|
||||
>>> result = await create_input_message(
|
||||
... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor, config
|
||||
... )
|
||||
>>> result["messages"][0].tool_calls[0]["name"]
|
||||
'Split_The_Problem'
|
||||
"""
|
||||
messages = state.get("messages", [])
|
||||
|
||||
# Extract last message content
|
||||
if messages:
|
||||
last_message = messages[-1].content if hasattr(messages[-1], 'content') else str(messages[-1])
|
||||
else:
|
||||
logger.warning("[create_input_message] No messages in state, using empty string")
|
||||
last_message = ""
|
||||
|
||||
logger.debug(f"[create_input_message] Original input: {last_message[:100]}...")
|
||||
|
||||
# Process multimodal input (images/audio)
|
||||
try:
|
||||
processed_content = await multimodal_processor.process_input(last_message)
|
||||
if processed_content != last_message:
|
||||
logger.info(
|
||||
f"[create_input_message] Multimodal processing converted input "
|
||||
f"from {len(last_message)} to {len(processed_content)} chars"
|
||||
)
|
||||
last_message = processed_content
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[create_input_message] Multimodal processing failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Continue with original content
|
||||
|
||||
# Generate unique message ID
|
||||
uuid_str = uuid.uuid4()
|
||||
time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
|
||||
# Extract namespace from session_id
|
||||
# Expected format: "call_id_{namespace}" or similar
|
||||
try:
|
||||
namespace = str(session_id).split('_id_')[1]
|
||||
except (IndexError, AttributeError):
|
||||
logger.warning(
|
||||
f"[create_input_message] Could not extract namespace from session_id: {session_id}"
|
||||
)
|
||||
namespace = "unknown"
|
||||
|
||||
# Handle verified_data extraction (backward compatibility)
|
||||
# This regex-based extraction is kept for compatibility with existing data formats
|
||||
if 'verified_data' in str(last_message):
|
||||
try:
|
||||
messages_last = str(last_message).replace('\\n', '').replace('\\', '')
|
||||
query_match = re.findall(r'"query": "(.*?)",', messages_last)
|
||||
if query_match:
|
||||
last_message = query_match[0]
|
||||
logger.debug(
|
||||
f"[create_input_message] Extracted query from verified_data: {last_message}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[create_input_message] Failed to extract query from verified_data: {e}"
|
||||
)
|
||||
|
||||
# Construct tool call message
|
||||
tool_call_id = f"{session_id}_{uuid_str}"
|
||||
|
||||
logger.info(
|
||||
f"[create_input_message] Creating tool call for '{tool_name}' "
|
||||
f"with ID: {tool_call_id}"
|
||||
)
|
||||
|
||||
# Build tool arguments
|
||||
tool_args = {
|
||||
"sentence": last_message,
|
||||
"sessionid": session_id,
|
||||
"messages_id": str(uuid_str),
|
||||
"search_switch": search_switch,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id,
|
||||
"memory_config": memory_config,
|
||||
}
|
||||
|
||||
return {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{
|
||||
"name": tool_name,
|
||||
"args": tool_args,
|
||||
"id": tool_call_id
|
||||
}]
|
||||
)
|
||||
]
|
||||
}
|
||||
249
api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py
Normal file
249
api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py
Normal file
@@ -0,0 +1,249 @@
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.db import get_db
|
||||
|
||||
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
PROJECT_ROOT_,
|
||||
ReadState,
|
||||
)
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||
db_session = next(get_db())
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class ProblemNodeService(LLMServiceMixin):
|
||||
"""问题处理节点服务类"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
|
||||
# 创建全局服务实例
|
||||
problem_service = ProblemNodeService()
|
||||
|
||||
|
||||
async def Split_The_Problem(state: ReadState) -> ReadState:
|
||||
"""问题分解节点"""
|
||||
# 从状态中获取数据
|
||||
content = state.get('data', '')
|
||||
group_id = state.get('group_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||
|
||||
system_prompt = await problem_service.template_service.render_template(
|
||||
template_name='problem_breakdown_prompt.jinja2',
|
||||
operation_name='split_the_problem',
|
||||
history=history,
|
||||
sentence=content,
|
||||
json_schema=json_schema
|
||||
)
|
||||
|
||||
try:
|
||||
# 使用优化的LLM服务
|
||||
structured = await problem_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=ProblemExtensionResponse,
|
||||
fallback_value=[]
|
||||
)
|
||||
|
||||
# 添加更详细的日志记录
|
||||
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
|
||||
|
||||
# 验证结构化响应
|
||||
if not structured or not hasattr(structured, 'root'):
|
||||
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
elif not structured.root:
|
||||
logger.warning("Split_The_Problem: 结构化响应的root为空")
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
else:
|
||||
split_result = json.dumps(
|
||||
[item.model_dump() for item in structured.root],
|
||||
ensure_ascii=False
|
||||
)
|
||||
|
||||
split_result_dict = []
|
||||
for index, item in enumerate(json.loads(split_result)):
|
||||
split_data = {
|
||||
"id": f"Q{index + 1}",
|
||||
"question": item['extended_question'],
|
||||
"type": item['type'],
|
||||
"reason": item['reason']
|
||||
}
|
||||
split_result_dict.append(split_data)
|
||||
|
||||
logger.info(f"Split_The_Problem: 成功生成 {len(structured.root) if structured.root else 0} 个分解项")
|
||||
|
||||
result = {
|
||||
"context": split_result,
|
||||
"original": content,
|
||||
"_intermediate": {
|
||||
"type": "problem_split",
|
||||
"title": "问题拆分",
|
||||
"data": split_result_dict,
|
||||
"original_query": content
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Split_The_Problem failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# 提供更详细的错误信息
|
||||
error_details = {
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
"content_length": len(content),
|
||||
"llm_model_id": memory_config.llm_model_id if memory_config else None
|
||||
}
|
||||
|
||||
logger.error(f"Split_The_Problem error details: {error_details}")
|
||||
|
||||
# 创建默认的空结果
|
||||
result = {
|
||||
"context": json.dumps([], ensure_ascii=False),
|
||||
"original": content,
|
||||
"error": str(e),
|
||||
"_intermediate": {
|
||||
"type": "problem_split",
|
||||
"title": "问题拆分",
|
||||
"data": [],
|
||||
"original_query": content,
|
||||
"error": error_details
|
||||
}
|
||||
}
|
||||
|
||||
# 返回更新后的状态,包含spit_context字段
|
||||
return {"spit_data": result}
|
||||
|
||||
|
||||
async def Problem_Extension(state: ReadState) -> ReadState:
|
||||
"""问题扩展节点"""
|
||||
# 获取原始数据和分解结果
|
||||
start = time.time()
|
||||
content = state.get('data', '')
|
||||
data = state.get('spit_data', '')['context']
|
||||
group_id = state.get('group_id', '')
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
|
||||
databasets = {}
|
||||
try:
|
||||
data = json.loads(data)
|
||||
for i in data:
|
||||
databasets[i['extended_question']] = i['type']
|
||||
except (json.JSONDecodeError, KeyError, TypeError) as e:
|
||||
logger.error(f"Problem_Extension: 数据解析失败: {e}")
|
||||
# 使用空字典作为fallback
|
||||
databasets = {}
|
||||
data = []
|
||||
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
json_schema = ProblemExtensionResponse.model_json_schema()
|
||||
|
||||
system_prompt = await problem_service.template_service.render_template(
|
||||
template_name='Problem_Extension_prompt.jinja2',
|
||||
operation_name='problem_extension',
|
||||
history=history,
|
||||
questions=databasets,
|
||||
json_schema=json_schema
|
||||
)
|
||||
|
||||
try:
|
||||
# 使用优化的LLM服务
|
||||
response_content = await problem_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=ProblemExtensionResponse,
|
||||
fallback_value=[]
|
||||
)
|
||||
|
||||
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
|
||||
|
||||
# 验证结构化响应
|
||||
if not response_content or not hasattr(response_content, 'root'):
|
||||
logger.warning("Problem_Extension: 结构化响应为空或格式不正确")
|
||||
aggregated_dict = {}
|
||||
elif not response_content.root:
|
||||
logger.warning("Problem_Extension: 结构化响应的root为空")
|
||||
aggregated_dict = {}
|
||||
else:
|
||||
# Aggregate results by original question
|
||||
aggregated_dict = {}
|
||||
for item in response_content.root:
|
||||
try:
|
||||
key = getattr(item, "original_question", None) or (
|
||||
item.get("original_question") if isinstance(item, dict) else None
|
||||
)
|
||||
value = getattr(item, "extended_question", None) or (
|
||||
item.get("extended_question") if isinstance(item, dict) else None
|
||||
)
|
||||
if not key or not value:
|
||||
logger.warning(f"Problem_Extension: 跳过无效项: key={key}, value={value}")
|
||||
continue
|
||||
aggregated_dict.setdefault(key, []).append(value)
|
||||
except Exception as item_error:
|
||||
logger.warning(f"Problem_Extension: 处理项目时出错: {item_error}")
|
||||
continue
|
||||
|
||||
logger.info(f"Problem_Extension: 成功生成 {len(aggregated_dict)} 个扩展问题组")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM call failed for Problem_Extension: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# 提供更详细的错误信息
|
||||
error_details = {
|
||||
"error_type": type(e).__name__,
|
||||
"error_message": str(e),
|
||||
"questions_count": len(databasets),
|
||||
"llm_model_id": memory_config.llm_model_id if memory_config else None
|
||||
}
|
||||
|
||||
logger.error(f"Problem_Extension error details: {error_details}")
|
||||
aggregated_dict = {}
|
||||
|
||||
logger.info("Problem extension")
|
||||
logger.info(f"Problem extension result: {aggregated_dict}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
print(time.time() - start)
|
||||
result = {
|
||||
"context": aggregated_dict,
|
||||
"original": data,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "problem_extension",
|
||||
"title": "问题扩展",
|
||||
"data": aggregated_dict,
|
||||
"original_query": content,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
return {"problem_extension": result}
|
||||
@@ -0,0 +1,417 @@
|
||||
# ===== 标准库 =====
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
|
||||
# ===== 第三方库 =====
|
||||
from langchain.agents import create_agent
|
||||
from langchain_openai import ChatOpenAI
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.db import get_db, get_db_context
|
||||
|
||||
from app.schemas import model_schema
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
from app.core.memory.agent.services.search_service import SearchService
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
COUNTState,
|
||||
ReadState,
|
||||
deduplicate_entries,
|
||||
merge_to_key_value_pairs,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.tools.tool import (
|
||||
create_hybrid_retrieval_tool_sync,
|
||||
create_time_retrieval_tool,
|
||||
extract_tool_message_content,
|
||||
)
|
||||
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
db = next(get_db())
|
||||
|
||||
|
||||
|
||||
async def rag_config(state):
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
kb_config = {
|
||||
"knowledge_bases": [
|
||||
{
|
||||
"kb_id": user_rag_memory_id,
|
||||
"similarity_threshold": 0.7,
|
||||
"vector_similarity_weight": 0.5,
|
||||
"top_k": 10,
|
||||
"retrieve_type": "participle"
|
||||
}
|
||||
],
|
||||
"merge_strategy": "weight",
|
||||
"reranker_id": os.getenv('reranker_id'),
|
||||
"reranker_top_k": 10
|
||||
}
|
||||
return kb_config
|
||||
async def rag_knowledge(state,question):
|
||||
kb_config = await rag_config(state)
|
||||
group_id = state.get('group_id', '')
|
||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(group_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||
cleaned_query = question
|
||||
raw_results = clean_content
|
||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except Exception :
|
||||
retrieval_knowledge=[]
|
||||
clean_content = ''
|
||||
raw_results = ''
|
||||
cleaned_query = question
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
return retrieval_knowledge,clean_content,cleaned_query,raw_results
|
||||
|
||||
|
||||
async def llm_infomation(state: ReadState) -> ReadState:
|
||||
memory_config = state.get('memory_config', None)
|
||||
model_id = memory_config.llm_model_id
|
||||
tenant_id = memory_config.tenant_id
|
||||
|
||||
# 使用现有的 memory_config 而不是重新查询数据库
|
||||
# 或者使用线程安全的数据库访问
|
||||
with get_db_context() as db:
|
||||
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id)
|
||||
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
|
||||
return result_pydantic
|
||||
|
||||
|
||||
async def clean_databases(data) -> str:
|
||||
"""
|
||||
简化的数据库搜索结果清理函数
|
||||
|
||||
Args:
|
||||
data: 搜索结果数据
|
||||
|
||||
Returns:
|
||||
清理后的内容字符串
|
||||
"""
|
||||
try:
|
||||
# 解析JSON字符串
|
||||
if isinstance(data, str):
|
||||
try:
|
||||
data = json.loads(data)
|
||||
except json.JSONDecodeError:
|
||||
return data
|
||||
|
||||
if not isinstance(data, dict):
|
||||
return str(data)
|
||||
|
||||
# 获取结果数据
|
||||
# with open("搜索结果.json","w",encoding='utf-8') as f:
|
||||
# f.write(json.dumps(data, indent=4, ensure_ascii=False))
|
||||
results = data.get('results', data)
|
||||
if not isinstance(results, dict):
|
||||
return str(results)
|
||||
|
||||
# 收集所有内容
|
||||
content_list = []
|
||||
|
||||
# 处理重排序结果
|
||||
reranked = results.get('reranked_results', {})
|
||||
if reranked:
|
||||
for category in ['summaries', 'statements', 'chunks', 'entities']:
|
||||
items = reranked.get(category, [])
|
||||
if isinstance(items, list):
|
||||
content_list.extend(items)
|
||||
# 处理时间搜索结果
|
||||
time_search = results.get('time_search', {})
|
||||
if time_search:
|
||||
if isinstance(time_search, dict):
|
||||
statements = time_search.get('statements', time_search.get('time_search', []))
|
||||
if isinstance(statements, list):
|
||||
content_list.extend(statements)
|
||||
elif isinstance(time_search, list):
|
||||
content_list.extend(time_search)
|
||||
|
||||
# 提取文本内容
|
||||
text_parts = []
|
||||
for item in content_list:
|
||||
if isinstance(item, dict):
|
||||
text = item.get('statement') or item.get('content', '')
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
elif isinstance(item, str):
|
||||
text_parts.append(item)
|
||||
|
||||
|
||||
return '\n'.join(text_parts).strip()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"clean_databases failed: {e}", exc_info=True)
|
||||
return str(data)
|
||||
|
||||
|
||||
async def retrieve_nodes(state: ReadState) -> ReadState:
|
||||
|
||||
'''
|
||||
|
||||
模型信息
|
||||
'''
|
||||
|
||||
problem_extension=state.get('problem_extension', '')['context']
|
||||
storage_type=state.get('storage_type', '')
|
||||
user_rag_memory_id=state.get('user_rag_memory_id', '')
|
||||
group_id=state.get('group_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
original=state.get('data', '')
|
||||
problem_list=[]
|
||||
for key,values in problem_extension.items():
|
||||
for data in values:
|
||||
problem_list.append(data)
|
||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
# 创建异步任务处理单个问题
|
||||
async def process_question_nodes(idx, question):
|
||||
try:
|
||||
# Prepare search parameters based on storage type
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"question": question,
|
||||
"return_raw_results": True
|
||||
}
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question)
|
||||
else:
|
||||
clean_content, cleaned_query, raw_results = await SearchService().execute_hybrid_search(
|
||||
**search_params, memory_config=memory_config
|
||||
)
|
||||
|
||||
return {
|
||||
"Query_small": cleaned_query,
|
||||
"Result_small": clean_content,
|
||||
"_intermediate": {
|
||||
"type": "search_result",
|
||||
"query": cleaned_query,
|
||||
"raw_results": raw_results,
|
||||
"index": idx + 1,
|
||||
"total": len(problem_list)
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve: hybrid_search failed for question '{question}': {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty result for this question
|
||||
return {
|
||||
"Query_small": question,
|
||||
"Result_small": "",
|
||||
"_intermediate": {
|
||||
"type": "search_result",
|
||||
"query": question,
|
||||
"raw_results": [],
|
||||
"index": idx + 1,
|
||||
"total": len(problem_list)
|
||||
}
|
||||
}
|
||||
|
||||
# 并发处理所有问题
|
||||
tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)]
|
||||
databases_anser = await asyncio.gather(*tasks)
|
||||
databases_data = {
|
||||
"Query": original,
|
||||
"Expansion_issue": databases_anser
|
||||
}
|
||||
|
||||
# Collect intermediate outputs before deduplication
|
||||
intermediate_outputs = []
|
||||
for item in databases_anser:
|
||||
if '_intermediate' in item:
|
||||
intermediate_outputs.append(item['_intermediate'])
|
||||
|
||||
# Deduplicate and merge results
|
||||
deduplicated_data = deduplicate_entries(databases_data['Expansion_issue'])
|
||||
deduplicated_data_merged = merge_to_key_value_pairs(
|
||||
deduplicated_data,
|
||||
'Query_small',
|
||||
'Result_small'
|
||||
)
|
||||
|
||||
# Restructure for Verify/Retrieve_Summary compatibility
|
||||
keys, val = [], []
|
||||
for item in deduplicated_data_merged:
|
||||
for items_key, items_value in item.items():
|
||||
keys.append(items_key)
|
||||
val.append(items_value)
|
||||
|
||||
send_verify = []
|
||||
for i, j in zip(keys, val, strict=False):
|
||||
if j!=['']:
|
||||
send_verify.append({
|
||||
"Query_small": i,
|
||||
"Answer_Small": j
|
||||
})
|
||||
|
||||
dup_databases = {
|
||||
"Query": original,
|
||||
"Expansion_issue": send_verify,
|
||||
"_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs
|
||||
}
|
||||
|
||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||
return {'retrieve':dup_databases}
|
||||
|
||||
|
||||
|
||||
|
||||
async def retrieve(state: ReadState) -> ReadState:
|
||||
# 从state中获取group_id
|
||||
import time
|
||||
start=time.time()
|
||||
problem_extension = state.get('problem_extension', '')['context']
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
group_id = state.get('group_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
original = state.get('data', '')
|
||||
problem_list = []
|
||||
for key, values in problem_extension.items():
|
||||
for data in values:
|
||||
problem_list.append(data)
|
||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
databases_anser = []
|
||||
|
||||
async def get_llm_info():
|
||||
with get_db_context() as db: # 使用同步数据库上下文管理器
|
||||
config_service = MemoryConfigService(db)
|
||||
return await llm_infomation(state)
|
||||
llm_config = await get_llm_info()
|
||||
api_key_obj = llm_config.api_keys[0]
|
||||
api_key = api_key_obj.api_key
|
||||
api_base = api_key_obj.api_base
|
||||
model_name = api_key_obj.model_name
|
||||
llm = ChatOpenAI(
|
||||
model=model_name,
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
temperature=0.2,
|
||||
)
|
||||
|
||||
time_retrieval_tool = create_time_retrieval_tool(group_id)
|
||||
search_params = { "group_id": group_id, "return_raw_results": True }
|
||||
hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params)
|
||||
agent = create_agent(
|
||||
llm,
|
||||
tools=[time_retrieval_tool,hybrid_retrieval],
|
||||
system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的group_id是: {group_id}"
|
||||
)
|
||||
|
||||
# 创建异步任务处理单个问题
|
||||
import asyncio
|
||||
|
||||
# 在模块级别定义信号量,限制最大并发数
|
||||
SEMAPHORE = asyncio.Semaphore(5) # 限制最多5个并发数据库操作
|
||||
|
||||
async def process_question(idx, question):
|
||||
async with SEMAPHORE: # 限制并发
|
||||
try:
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question)
|
||||
else:
|
||||
cleaned_query = question
|
||||
# 使用 asyncio 在线程池中运行同步的 agent.invoke
|
||||
import asyncio
|
||||
response = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: agent.invoke({"messages": question})
|
||||
)
|
||||
tool_results = extract_tool_message_content(response)
|
||||
if tool_results == None:
|
||||
raw_results = []
|
||||
clean_content = ''
|
||||
else:
|
||||
raw_results = tool_results['content']
|
||||
clean_content = await clean_databases(raw_results)
|
||||
|
||||
try:
|
||||
raw_results = raw_results['results']
|
||||
except Exception:
|
||||
raw_results = []
|
||||
|
||||
return {
|
||||
"Query_small": cleaned_query,
|
||||
"Result_small": clean_content,
|
||||
"_intermediate": {
|
||||
"type": "search_result",
|
||||
"query": cleaned_query,
|
||||
"raw_results": raw_results,
|
||||
"index": idx + 1,
|
||||
"total": len(problem_list)
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve: hybrid_search failed for question '{question}': {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty result for this question
|
||||
return {
|
||||
"Query_small": question,
|
||||
"Result_small": "",
|
||||
"_intermediate": {
|
||||
"type": "search_result",
|
||||
"query": question,
|
||||
"raw_results": [],
|
||||
"index": idx + 1,
|
||||
"total": len(problem_list)
|
||||
}
|
||||
}
|
||||
|
||||
# 并发处理所有问题
|
||||
import asyncio
|
||||
tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)]
|
||||
databases_anser = await asyncio.gather(*tasks)
|
||||
databases_data = {
|
||||
"Query": original,
|
||||
"Expansion_issue": databases_anser
|
||||
}
|
||||
|
||||
# Collect intermediate outputs before deduplication
|
||||
intermediate_outputs = []
|
||||
for item in databases_anser:
|
||||
if '_intermediate' in item:
|
||||
intermediate_outputs.append(item['_intermediate'])
|
||||
|
||||
# Deduplicate and merge results
|
||||
deduplicated_data = deduplicate_entries(databases_data['Expansion_issue'])
|
||||
deduplicated_data_merged = merge_to_key_value_pairs(
|
||||
deduplicated_data,
|
||||
'Query_small',
|
||||
'Result_small'
|
||||
)
|
||||
|
||||
# Restructure for Verify/Retrieve_Summary compatibility
|
||||
keys, val = [], []
|
||||
for item in deduplicated_data_merged:
|
||||
for items_key, items_value in item.items():
|
||||
keys.append(items_key)
|
||||
val.append(items_value)
|
||||
|
||||
send_verify = []
|
||||
for i, j in zip(keys, val, strict=False):
|
||||
if j != ['']:
|
||||
send_verify.append({
|
||||
"Query_small": i,
|
||||
"Answer_Small": j
|
||||
})
|
||||
|
||||
dup_databases = {
|
||||
"Query": original,
|
||||
"Expansion_issue": send_verify,
|
||||
"_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs
|
||||
}
|
||||
# with open('retrieve_text.json', 'w') as f:
|
||||
# json.dump(dup_databases, f, indent=4)
|
||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||
return {'retrieve': dup_databases}
|
||||
|
||||
|
||||
304
api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py
Normal file
304
api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py
Normal file
@@ -0,0 +1,304 @@
|
||||
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.models.summary_models import (
|
||||
RetrieveSummaryResponse,
|
||||
SummaryResponse,
|
||||
)
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
from app.core.memory.agent.services.search_service import SearchService
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
PROJECT_ROOT_,
|
||||
ReadState,
|
||||
)
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.db import get_db
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||
logger = get_agent_logger(__name__)
|
||||
db_session = next(get_db())
|
||||
|
||||
class SummaryNodeService(LLMServiceMixin):
|
||||
"""总结节点服务类"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
# 创建全局服务实例
|
||||
summary_service = SummaryNodeService()
|
||||
|
||||
async def summary_history(state: ReadState) -> ReadState:
|
||||
group_id = state.get("group_id", '')
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
return history
|
||||
|
||||
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str:
|
||||
"""
|
||||
增强的summary_llm函数,包含更好的错误处理和数据验证
|
||||
"""
|
||||
data = state.get("data", '')
|
||||
|
||||
# 构建系统提示词
|
||||
if str(search_mode) == "0":
|
||||
system_prompt = await summary_service.template_service.render_template(
|
||||
template_name=template_name,
|
||||
operation_name=operation_name,
|
||||
data=retrieve_info,
|
||||
query=data
|
||||
)
|
||||
else:
|
||||
system_prompt = await summary_service.template_service.render_template(
|
||||
template_name=template_name,
|
||||
operation_name=operation_name,
|
||||
query=data,
|
||||
history=history,
|
||||
retrieve_info=retrieve_info
|
||||
)
|
||||
try:
|
||||
# 使用优化的LLM服务进行结构化输出
|
||||
structured = await summary_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=response_model,
|
||||
fallback_value=None
|
||||
)
|
||||
# 验证结构化响应
|
||||
if structured is None:
|
||||
logger.warning(f"LLM返回None,使用默认回答")
|
||||
return "信息不足,无法回答"
|
||||
|
||||
# 根据操作类型提取答案
|
||||
if operation_name == "summary":
|
||||
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
|
||||
else:
|
||||
# 处理RetrieveSummaryResponse
|
||||
if hasattr(structured, 'data') and structured.data:
|
||||
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
|
||||
else:
|
||||
logger.warning(f"结构化响应缺少data字段")
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
# 验证答案不为空
|
||||
if not aimessages or aimessages.strip() == "":
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
return aimessages
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"结构化输出失败: {e}", exc_info=True)
|
||||
|
||||
# 尝试非结构化输出作为fallback
|
||||
try:
|
||||
logger.info("尝试非结构化输出作为fallback")
|
||||
response = await summary_service.call_llm_simple(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
fallback_message="信息不足,无法回答"
|
||||
)
|
||||
|
||||
if response and response.strip():
|
||||
# 简单清理响应
|
||||
cleaned_response = response.strip()
|
||||
# 移除可能的JSON标记
|
||||
if cleaned_response.startswith('```'):
|
||||
lines = cleaned_response.split('\n')
|
||||
cleaned_response = '\n'.join(lines[1:-1])
|
||||
|
||||
return cleaned_response
|
||||
else:
|
||||
return "信息不足,无法回答"
|
||||
|
||||
except Exception as fallback_error:
|
||||
logger.error(f"Fallback也失败: {fallback_error}")
|
||||
return "信息不足,无法回答"
|
||||
|
||||
async def summary_redis_save(state: ReadState,aimessages) -> ReadState:
|
||||
data = state.get("data", '')
|
||||
group_id = state.get("group_id", '')
|
||||
await SessionService(store).save_session(
|
||||
user_id=group_id,
|
||||
query=data,
|
||||
apply_id=group_id,
|
||||
group_id=group_id,
|
||||
ai_response=aimessages
|
||||
)
|
||||
await SessionService(store).cleanup_duplicates()
|
||||
logger.info(f"sessionid: {aimessages} 写入成功")
|
||||
async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState:
|
||||
storage_type=state.get("storage_type",'')
|
||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||
data=state.get("data", '')
|
||||
input_summary = {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "input_summary",
|
||||
"title": "快速答案",
|
||||
"summary": aimessages,
|
||||
"query": data,
|
||||
"raw_results": raw_results,
|
||||
"search_mode": "quick_search",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
retrieve={
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "retrieval_summary",
|
||||
"title":"快速检索",
|
||||
"summary": aimessages,
|
||||
"query": data,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
return input_summary,retrieve
|
||||
|
||||
async def Input_Summary(state: ReadState) -> ReadState:
|
||||
start=time.time()
|
||||
storage_type=state.get("storage_type",'')
|
||||
memory_config = state.get('memory_config', None)
|
||||
user_rag_memory_id=state.get("user_rag_memory_id",'')
|
||||
data=state.get("data", '')
|
||||
group_id=state.get("group_id", '')
|
||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
history = await summary_history( state)
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"question": data,
|
||||
"return_raw_results": True,
|
||||
"include": ["summaries"] # Only search summary nodes for faster performance
|
||||
}
|
||||
|
||||
try:
|
||||
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config)
|
||||
except Exception as e:
|
||||
logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True )
|
||||
retrieve_info, question, raw_results = "", data, []
|
||||
|
||||
|
||||
try:
|
||||
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
|
||||
# 'input_summary',RetrieveSummaryResponse)
|
||||
# logger.info(f"快速答案总结==>>:{storage_type}--{user_rag_memory_id}--{aimessages}")
|
||||
summary_result = await summary_prompt(state, retrieve_info, retrieve_info)
|
||||
summary = summary_result[0]
|
||||
except Exception as e:
|
||||
logger.error( f"Input_Summary failed: {e}", exc_info=True )
|
||||
summary= {
|
||||
"status": "fail",
|
||||
"summary_result": "信息不足,无法回答",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"error": str(e)
|
||||
}
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('检索', duration)
|
||||
return {"summary":summary}
|
||||
|
||||
async def Retrieve_Summary(state: ReadState)-> ReadState:
|
||||
retrieve=state.get("retrieve", '')
|
||||
history = await summary_history( state)
|
||||
import json
|
||||
with open("检索.json","w",encoding='utf-8') as f:
|
||||
f.write(json.dumps(retrieve, indent=4, ensure_ascii=False))
|
||||
retrieve=retrieve.get("Expansion_issue", [])
|
||||
start=time.time()
|
||||
retrieve_info_str=[]
|
||||
for data in retrieve:
|
||||
if data=='':
|
||||
retrieve_info_str=''
|
||||
else:
|
||||
for key, value in data.items():
|
||||
if key=='Answer_Small':
|
||||
for i in value:
|
||||
retrieve_info_str.append(i)
|
||||
retrieve_info_str=list(set(retrieve_info_str))
|
||||
retrieve_info_str='\n'.join(retrieve_info_str)
|
||||
|
||||
aimessages=await summary_llm(state,history,retrieve_info_str,
|
||||
'Retrieve_Summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1")
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||
await summary_redis_save(state, aimessages)
|
||||
if aimessages == '':
|
||||
aimessages = '信息不足,无法回答'
|
||||
logger.info(f"Summary after retrieval: {aimessages}")
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
# 修复协程调用 - 先await,然后访问返回值
|
||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||
summary = summary_result[1]
|
||||
return {"summary":summary}
|
||||
|
||||
|
||||
async def Summary(state: ReadState)-> ReadState:
|
||||
start=time.time()
|
||||
query = state.get("data", '')
|
||||
verify=state.get("verify", '')
|
||||
verify_expansion_issue=verify.get("verified_data", '')
|
||||
retrieve_info_str=''
|
||||
for data in verify_expansion_issue:
|
||||
for key, value in data.items():
|
||||
if key=='answer_small':
|
||||
for i in value:
|
||||
retrieve_info_str+=i+'\n'
|
||||
history=await summary_history(state)
|
||||
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
"retrieve_info": retrieve_info_str
|
||||
}
|
||||
aimessages=await summary_llm(state,history,data,
|
||||
'summary_prompt.jinja2','summary',SummaryResponse,0)
|
||||
|
||||
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
|
||||
await summary_redis_save(state, aimessages)
|
||||
if aimessages == '':
|
||||
aimessages = '信息不足,无法回答'
|
||||
try:
|
||||
duration = time.time() - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
# 修复协程调用 - 先await,然后访问返回值
|
||||
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
|
||||
summary = summary_result[1]
|
||||
return {"summary":summary}
|
||||
|
||||
async def Summary_fails(state: ReadState)-> ReadState:
|
||||
storage_type=state.get("storage_type", '')
|
||||
user_rag_memory_id=state.get("user_rag_memory_id", '')
|
||||
result= {
|
||||
"status": "success",
|
||||
"summary_result": "没有相关数据",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
return {"summary":result}
|
||||
@@ -1,234 +0,0 @@
|
||||
"""
|
||||
Tool execution node for LangGraph workflow.
|
||||
|
||||
This module provides the ToolExecutionNode class which wraps tool execution
|
||||
with parameter transformation logic using the ParameterBuilder service.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.state.extractors import (
|
||||
extract_content_payload,
|
||||
extract_tool_call_id,
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolExecutionNode:
|
||||
"""
|
||||
Custom LangGraph node that wraps tool execution with parameter transformation.
|
||||
|
||||
This node extracts content from previous tool results, transforms parameters
|
||||
based on tool type using ParameterBuilder, and invokes the tool with the
|
||||
correct argument structure.
|
||||
|
||||
Attributes:
|
||||
tool_node: LangGraph ToolNode wrapping the actual tool
|
||||
id: Node identifier for message IDs
|
||||
tool_name: Name of the tool being executed
|
||||
namespace: Namespace for session management
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
parameter_builder: Service for building tool-specific arguments
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool: Callable,
|
||||
node_id: str,
|
||||
namespace: str,
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
parameter_builder: ParameterBuilder,
|
||||
storage_type: str,
|
||||
user_rag_memory_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
):
|
||||
"""
|
||||
Initialize the tool execution node.
|
||||
|
||||
Args:
|
||||
tool: The tool function to execute
|
||||
node_id: Identifier for this node (used in message IDs)
|
||||
namespace: Namespace for session management
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
parameter_builder: Service for building tool-specific arguments
|
||||
storage_type: Storage type for the workspace
|
||||
user_rag_memory_id: User RAG memory identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
"""
|
||||
self.tool_node = ToolNode([tool])
|
||||
self.id = node_id
|
||||
self.tool_name = tool.name if hasattr(tool, 'name') else str(tool)
|
||||
self.namespace = namespace
|
||||
self.search_switch = search_switch
|
||||
self.apply_id = apply_id
|
||||
self.group_id = group_id
|
||||
self.parameter_builder = parameter_builder
|
||||
self.storage_type = storage_type
|
||||
self.user_rag_memory_id = user_rag_memory_id
|
||||
self.memory_config = memory_config
|
||||
|
||||
logger.info(
|
||||
f"[ToolExecutionNode] Initialized node '{self.id}' for tool '{self.tool_name}'"
|
||||
)
|
||||
|
||||
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute the tool with transformed parameters.
|
||||
|
||||
This method:
|
||||
1. Extracts the last message from state
|
||||
2. Extracts tool call ID using state extractors
|
||||
3. Extracts content payload using state extractors
|
||||
4. Builds tool arguments using parameter builder
|
||||
5. Constructs AIMessage with tool_calls
|
||||
6. Invokes the tool and returns the result
|
||||
|
||||
Args:
|
||||
state: LangGraph state dictionary
|
||||
|
||||
Returns:
|
||||
Updated state with tool result in messages
|
||||
"""
|
||||
messages = state.get("messages", [])
|
||||
logger.debug( self.tool_name)
|
||||
|
||||
if not messages:
|
||||
logger.warning(f"[ToolExecutionNode] {self.id} - No messages in state")
|
||||
return {"messages": [AIMessage(content="Error: No messages in state")]}
|
||||
|
||||
last_message = messages[-1]
|
||||
logger.debug(
|
||||
f"[ToolExecutionNode] {self.id} - Processing message at {time.time()}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Extract tool call ID using state extractors
|
||||
tool_call_id = extract_tool_call_id(last_message)
|
||||
logger.debug(f"[ToolExecutionNode] {self.id} - Extracted tool_call_id: {tool_call_id}")
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"[ToolExecutionNode] {self.id} - Failed to extract tool call ID: {e}"
|
||||
)
|
||||
return {"messages": [AIMessage(content=f"Error: {str(e)}")]}
|
||||
|
||||
try:
|
||||
# Extract content payload using state extractors
|
||||
content = extract_content_payload(last_message)
|
||||
logger.debug(
|
||||
f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}, content_keys: {list(content.keys()) if isinstance(content, dict) else 'N/A'}"
|
||||
)
|
||||
# Log raw message content for debugging
|
||||
if hasattr(last_message, 'content'):
|
||||
raw = last_message.content
|
||||
logger.debug(f"[ToolExecutionNode] {self.id} - Raw message content (first 500 chars): {str(raw)[:500]}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[ToolExecutionNode] {self.id} - Failed to extract content: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
content = {}
|
||||
|
||||
try:
|
||||
# Build tool arguments using parameter builder
|
||||
tool_args = self.parameter_builder.build_tool_args(
|
||||
tool_name=self.tool_name,
|
||||
content=content,
|
||||
tool_call_id=tool_call_id,
|
||||
search_switch=self.search_switch,
|
||||
apply_id=self.apply_id,
|
||||
group_id=self.group_id,
|
||||
memory_config=self.memory_config,
|
||||
storage_type=self.storage_type,
|
||||
user_rag_memory_id=self.user_rag_memory_id,
|
||||
)
|
||||
logger.debug(
|
||||
f"[ToolExecutionNode] {self.id} - Built tool args with keys: {list(tool_args.keys())}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[ToolExecutionNode] {self.id} - Failed to build tool args: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {"messages": [AIMessage(content=f"Error building arguments: {str(e)}")]}
|
||||
|
||||
# Construct tool input message
|
||||
tool_input = {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{
|
||||
"name": self.tool_name,
|
||||
"args": tool_args,
|
||||
"id": f"{self.id}_{tool_call_id}",
|
||||
}]
|
||||
)
|
||||
]
|
||||
}
|
||||
|
||||
try:
|
||||
# Invoke the tool
|
||||
result = await self.tool_node.ainvoke(tool_input)
|
||||
|
||||
logger.debug(
|
||||
f"[ToolExecutionNode] {self.id} - Tool execution completed"
|
||||
)
|
||||
|
||||
# Check for error in tool response
|
||||
error_entry = None
|
||||
if result and "messages" in result:
|
||||
for msg in result["messages"]:
|
||||
if hasattr(msg, 'content'):
|
||||
try:
|
||||
import json
|
||||
content = msg.content
|
||||
if isinstance(content, str):
|
||||
parsed = json.loads(content)
|
||||
if isinstance(parsed, dict) and "error" in parsed:
|
||||
error_msg = parsed["error"]
|
||||
logger.warning(
|
||||
f"[ToolExecutionNode] {self.id} - Tool returned error: {error_msg}"
|
||||
)
|
||||
error_entry = {"tool": self.tool_name, "error": error_msg, "node_id": self.id}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# Return result with error tracking if error was found
|
||||
if error_entry:
|
||||
result["errors"] = [error_entry]
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[ToolExecutionNode] {self.id} - Tool execution failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Track error in state and return error message
|
||||
from langchain_core.messages import ToolMessage
|
||||
error_entry = {"tool": self.tool_name, "error": str(e), "node_id": self.id}
|
||||
return {
|
||||
"messages": [
|
||||
ToolMessage(
|
||||
content=f"Error executing tool: {str(e)}",
|
||||
tool_call_id=f"{self.id}_{tool_call_id}"
|
||||
)
|
||||
],
|
||||
"errors": [error_entry]
|
||||
}
|
||||
@@ -0,0 +1,155 @@
|
||||
import os
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.db import get_db
|
||||
|
||||
from app.core.memory.agent.models.verification_models import VerificationResult
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
PROJECT_ROOT_,
|
||||
ReadState,
|
||||
)
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
from app.core.memory.agent.utils.session_tools import SessionService
|
||||
from app.core.memory.agent.utils.template_tools import TemplateService
|
||||
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
|
||||
|
||||
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt')
|
||||
db_session = next(get_db())
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
class VerificationNodeService(LLMServiceMixin):
|
||||
"""验证节点服务类"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.template_service = TemplateService(template_root)
|
||||
|
||||
# 创建全局服务实例
|
||||
verification_service = VerificationNodeService()
|
||||
|
||||
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
|
||||
"""处理验证结果并生成输出格式"""
|
||||
storage_type = state.get('storage_type', '')
|
||||
user_rag_memory_id = state.get('user_rag_memory_id', '')
|
||||
data = state.get('data', '')
|
||||
|
||||
# 将 VerificationItem 对象转换为字典列表
|
||||
verified_data = []
|
||||
if messages_deal.expansion_issue:
|
||||
for item in messages_deal.expansion_issue:
|
||||
if hasattr(item, 'model_dump'):
|
||||
verified_data.append(item.model_dump())
|
||||
elif isinstance(item, dict):
|
||||
verified_data.append(item)
|
||||
|
||||
Verify_result = {
|
||||
"status": messages_deal.split_result,
|
||||
"verified_data": verified_data,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "verification",
|
||||
"title": "Data Verification",
|
||||
"result": messages_deal.split_result,
|
||||
"reason": messages_deal.reason or "验证完成",
|
||||
"query": messages_deal.query,
|
||||
"verified_count": len(verified_data),
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
return Verify_result
|
||||
async def Verify(state: ReadState):
|
||||
logger.info("=== Verify 节点开始执行 ===")
|
||||
try:
|
||||
content = state.get('data', '')
|
||||
group_id = state.get('group_id', '')
|
||||
memory_config = state.get('memory_config', None)
|
||||
|
||||
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., group_id={group_id}")
|
||||
|
||||
history = await SessionService(store).get_history(group_id, group_id, group_id)
|
||||
logger.info(f"Verify: 获取历史记录完成,history length={len(history)}")
|
||||
|
||||
retrieve = state.get("retrieve", {})
|
||||
logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
|
||||
|
||||
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
|
||||
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
|
||||
|
||||
messages = {
|
||||
"Query": content,
|
||||
"Expansion_issue": retrieve_expansion
|
||||
}
|
||||
|
||||
logger.info("Verify: 开始渲染模板")
|
||||
|
||||
# 生成 JSON schema 以指导 LLM 输出正确格式
|
||||
json_schema = VerificationResult.model_json_schema()
|
||||
|
||||
system_prompt = await verification_service.template_service.render_template(
|
||||
template_name='split_verify_prompt.jinja2',
|
||||
operation_name='split_verify_prompt',
|
||||
history=history,
|
||||
sentence=messages,
|
||||
json_schema=json_schema
|
||||
)
|
||||
logger.info(f"Verify: 模板渲染完成,prompt length={len(system_prompt)}")
|
||||
|
||||
# 使用优化的LLM服务,添加超时保护
|
||||
logger.info("Verify: 开始调用 LLM")
|
||||
try:
|
||||
# 添加 asyncio.wait_for 超时包裹,防止无限等待
|
||||
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
|
||||
import asyncio
|
||||
structured = await asyncio.wait_for(
|
||||
verification_service.call_llm_structured(
|
||||
state=state,
|
||||
db_session=db_session,
|
||||
system_prompt=system_prompt,
|
||||
response_model=VerificationResult,
|
||||
fallback_value={
|
||||
"query": content,
|
||||
"history": history if isinstance(history, list) else [],
|
||||
"expansion_issue": [],
|
||||
"split_result": "failed",
|
||||
"reason": "验证失败或超时"
|
||||
}
|
||||
),
|
||||
timeout=150.0 # 150秒超时
|
||||
)
|
||||
logger.info(f"Verify: LLM 调用完成,result={structured}")
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Verify: LLM 调用超时(150秒),使用 fallback 值")
|
||||
structured = VerificationResult(
|
||||
query=content,
|
||||
history=history if isinstance(history, list) else [],
|
||||
expansion_issue=[],
|
||||
split_result="failed",
|
||||
reason="LLM调用超时"
|
||||
)
|
||||
|
||||
result = await Verify_prompt(state, structured)
|
||||
logger.info("=== Verify 节点执行完成 ===")
|
||||
return {"verify": result}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Verify 节点执行失败: {e}", exc_info=True)
|
||||
# 返回失败的验证结果
|
||||
return {
|
||||
"verify": {
|
||||
"status": "failed",
|
||||
"verified_data": [],
|
||||
"storage_type": state.get('storage_type', ''),
|
||||
"user_rag_memory_id": state.get('user_rag_memory_id', ''),
|
||||
"_intermediate": {
|
||||
"type": "verification",
|
||||
"title": "Data Verification",
|
||||
"result": "failed",
|
||||
"reason": f"验证过程出错: {str(e)}",
|
||||
"query": state.get('data', ''),
|
||||
"verified_count": 0,
|
||||
"storage_type": state.get('storage_type', ''),
|
||||
"user_rag_memory_id": state.get('user_rag_memory_id', '')
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,56 @@
|
||||
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.utils.write_tools import write
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
async def write_node(state: WriteState) -> WriteState:
|
||||
"""
|
||||
Write data to the database/file system.
|
||||
|
||||
Args:
|
||||
state: WriteState containing messages, group_id, and memory_config
|
||||
|
||||
Returns:
|
||||
dict: Contains 'write_result' with status and data fields
|
||||
"""
|
||||
messages = state.get('messages', [])
|
||||
group_id = state.get('group_id', '')
|
||||
memory_config = state.get('memory_config', '')
|
||||
|
||||
# Convert LangChain messages to structured format expected by write()
|
||||
structured_messages = []
|
||||
for msg in messages:
|
||||
if hasattr(msg, 'type') and hasattr(msg, 'content'):
|
||||
# Map LangChain message types to role names
|
||||
role = 'user' if msg.type == 'human' else 'assistant' if msg.type == 'ai' else msg.type
|
||||
structured_messages.append({
|
||||
"role": role,
|
||||
"content": msg.content # content is now guaranteed to be a string
|
||||
})
|
||||
|
||||
try:
|
||||
result = await write(
|
||||
messages=structured_messages,
|
||||
user_id=group_id,
|
||||
apply_id=group_id,
|
||||
group_id=group_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||
|
||||
write_result = {
|
||||
"status": "success",
|
||||
"data": structured_messages,
|
||||
"config_id": memory_config.config_id,
|
||||
"config_name": memory_config.config_name,
|
||||
}
|
||||
return {"write_result": write_result}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Data_write failed: {e}", exc_info=True)
|
||||
write_result = {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
}
|
||||
return {"write_result": write_result}
|
||||
@@ -1,469 +1,177 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import warnings
|
||||
#!/usr/bin/env python3
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Literal
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.langgraph_graph.nodes import (
|
||||
ToolExecutionNode,
|
||||
create_input_message,
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
|
||||
from app.core.memory.agent.utils.llm_tools import COUNTState, ReadState
|
||||
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from dotenv import load_dotenv
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.constants import END, START
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.constants import START, END
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
load_dotenv()
|
||||
redishost=os.getenv("REDISHOST")
|
||||
redisport=os.getenv('REDISPORT')
|
||||
redisdb=os.getenv('REDISDB')
|
||||
redispassword=os.getenv('REDISPASSWORD')
|
||||
counter = COUNTState(limit=3)
|
||||
|
||||
# Update loop count in workflow
|
||||
async def update_loop_count(state):
|
||||
"""Update loop counter"""
|
||||
current_count = state.get("loop_count", 0)
|
||||
return {"loop_count": current_count + 1}
|
||||
|
||||
|
||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||
messages = state["messages"]
|
||||
from app.db import get_db
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
# Add boundary check
|
||||
if not messages:
|
||||
return END
|
||||
counter.add(1) # Increment by 1
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_node
|
||||
from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
|
||||
Split_The_Problem,
|
||||
Problem_Extension,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
|
||||
retrieve,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
|
||||
Input_Summary,
|
||||
Retrieve_Summary,
|
||||
Summary_fails,
|
||||
Summary,
|
||||
)
|
||||
from app.core.memory.agent.langgraph_graph.nodes.verification_nodes import Verify
|
||||
from app.core.memory.agent.langgraph_graph.routing.routers import (
|
||||
Split_continue,
|
||||
Retrieve_continue,
|
||||
Verify_continue,
|
||||
)
|
||||
|
||||
loop_count = counter.get_total()
|
||||
logger.debug(f"[should_continue] Current loop count: {loop_count}")
|
||||
|
||||
last_message = messages[-1]
|
||||
last_message_str = str(last_message).replace('\\', '')
|
||||
status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str)
|
||||
logger.debug(f"Status tools: {status_tools}")
|
||||
|
||||
if "success" in status_tools:
|
||||
counter.reset()
|
||||
return "Summary"
|
||||
elif "failed" in status_tools:
|
||||
if loop_count < 2: # Maximum loop count is 3
|
||||
return "content_input"
|
||||
else:
|
||||
counter.reset()
|
||||
return "Summary_fails"
|
||||
else:
|
||||
# Add default return value to avoid returning None
|
||||
counter.reset()
|
||||
return "Summary" # Default based on business requirements
|
||||
|
||||
|
||||
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
|
||||
Args:
|
||||
state: State dictionary containing search_switch
|
||||
|
||||
Returns:
|
||||
Next node to execute
|
||||
"""
|
||||
# Direct dictionary access instead of regex parsing
|
||||
search_switch = state.get("search_switch")
|
||||
|
||||
# Handle case where search_switch might be in messages
|
||||
if search_switch is None and "messages" in state:
|
||||
messages = state.get("messages", [])
|
||||
if messages:
|
||||
last_message = messages[-1]
|
||||
# Try to extract from tool_calls args
|
||||
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
||||
for tool_call in last_message.tool_calls:
|
||||
if isinstance(tool_call, dict) and "args" in tool_call:
|
||||
search_switch = tool_call["args"].get("search_switch")
|
||||
break
|
||||
|
||||
# Convert to string for comparison if needed
|
||||
if search_switch is not None:
|
||||
search_switch = str(search_switch)
|
||||
if search_switch == '0':
|
||||
return 'Verify'
|
||||
elif search_switch == '1':
|
||||
return 'Retrieve_Summary'
|
||||
|
||||
# Add default return value to avoid returning None
|
||||
return 'Retrieve_Summary' # Default based on business logic
|
||||
|
||||
|
||||
def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
|
||||
Args:
|
||||
state: State dictionary containing search_switch
|
||||
|
||||
Returns:
|
||||
Next node to execute
|
||||
"""
|
||||
logger.debug(f"Split_continue state: {state}")
|
||||
|
||||
# Direct dictionary access instead of regex parsing
|
||||
search_switch = state.get("search_switch")
|
||||
|
||||
# Handle case where search_switch might be in messages
|
||||
if search_switch is None and "messages" in state:
|
||||
messages = state.get("messages", [])
|
||||
if messages:
|
||||
last_message = messages[-1]
|
||||
# Try to extract from tool_calls args
|
||||
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
|
||||
for tool_call in last_message.tool_calls:
|
||||
if isinstance(tool_call, dict) and "args" in tool_call:
|
||||
search_switch = tool_call["args"].get("search_switch")
|
||||
break
|
||||
|
||||
# Convert to string for comparison if needed
|
||||
if search_switch is not None:
|
||||
search_switch = str(search_switch)
|
||||
if search_switch == '2':
|
||||
return 'Input_Summary'
|
||||
return 'Split_The_Problem' # Default case
|
||||
|
||||
|
||||
class ProblemExtensionNode:
|
||||
def __init__(self, tool, id, namespace, search_switch, apply_id, group_id, storage_type="", user_rag_memory_id=""):
|
||||
self.tool_node = ToolNode([tool])
|
||||
self.id = id
|
||||
self.tool_name = tool.name if hasattr(tool, 'name') else str(tool)
|
||||
self.namespace = namespace
|
||||
self.search_switch = search_switch
|
||||
self.apply_id = apply_id
|
||||
self.group_id = group_id
|
||||
self.storage_type = storage_type
|
||||
self.user_rag_memory_id = user_rag_memory_id
|
||||
|
||||
async def __call__(self, state):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1] if messages else ""
|
||||
logger.debug(f"ProblemExtensionNode {self.id} - Current time: {time.time()} - Message: {last_message}")
|
||||
if self.tool_name == 'Input_Summary':
|
||||
tool_call = re.findall("'id': '(.*?)'", str(last_message))[0]
|
||||
else:
|
||||
tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1]
|
||||
|
||||
# Try to extract actual content payload from previous tool result
|
||||
raw_msg = last_message.content if hasattr(last_message, 'content') else str(last_message)
|
||||
extracted_payload = None
|
||||
# Capture ToolMessage content field (supports single/double quotes), avoid greedy matching
|
||||
m = re.search(r"content=(?:\"|\')(.*?)(?:\"|\'),\s*name=", raw_msg, flags=re.S)
|
||||
if m:
|
||||
extracted_payload = m.group(1)
|
||||
else:
|
||||
# Fallback: use raw string directly
|
||||
extracted_payload = raw_msg
|
||||
|
||||
# Try to parse content as JSON first
|
||||
try:
|
||||
content = json.loads(extracted_payload)
|
||||
except Exception:
|
||||
# Try to extract JSON fragment from text and parse
|
||||
parsed = None
|
||||
candidates = re.findall(r"[\[{].*[\]}]", extracted_payload, flags=re.S)
|
||||
for cand in candidates:
|
||||
try:
|
||||
parsed = json.loads(cand)
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
# If still fails, use raw string as content
|
||||
content = parsed if parsed is not None else extracted_payload
|
||||
|
||||
# Build correct parameters based on tool name
|
||||
tool_args = {}
|
||||
|
||||
if self.tool_name == "Verify":
|
||||
# Verify tool requires context and usermessages parameters
|
||||
if isinstance(content, dict):
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
tool_args["context"] = {"content": content}
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name == "Retrieve":
|
||||
# Retrieve tool requires context and usermessages parameters
|
||||
if isinstance(content, dict):
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
tool_args["context"] = {"content": content}
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["search_switch"] = str(self.search_switch)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name == "Summary":
|
||||
# Summary tool requires string type context parameter
|
||||
if isinstance(content, dict):
|
||||
# Convert dict to JSON string
|
||||
tool_args["context"] = json.dumps(content, ensure_ascii=False)
|
||||
else:
|
||||
tool_args["context"] = str(content)
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name == "Summary_fails":
|
||||
# Summary_fails tool requires string type context parameter
|
||||
if isinstance(content, dict):
|
||||
# Convert dict to JSON string
|
||||
tool_args["context"] = json.dumps(content, ensure_ascii=False)
|
||||
else:
|
||||
tool_args["context"] = str(content)
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
elif self.tool_name == 'Input_Summary':
|
||||
tool_args["context"] = str(last_message)
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["search_switch"] = str(self.search_switch)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
tool_args["storage_type"] = getattr(self, 'storage_type', "")
|
||||
tool_args["user_rag_memory_id"] = getattr(self, 'user_rag_memory_id', "")
|
||||
elif self.tool_name == 'Retrieve_Summary':
|
||||
# Retrieve_Summary expects dict directly, not JSON string
|
||||
# content might be a JSON string, try to parse it
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
parsed_content = json.loads(content)
|
||||
# Check if it has a "context" key
|
||||
if isinstance(parsed_content, dict) and "context" in parsed_content:
|
||||
tool_args["context"] = parsed_content["context"]
|
||||
else:
|
||||
tool_args["context"] = parsed_content
|
||||
except json.JSONDecodeError:
|
||||
# If parsing fails, wrap the string
|
||||
tool_args["context"] = {"content": content}
|
||||
elif isinstance(content, dict):
|
||||
# Check if content has a "context" key that needs unwrapping
|
||||
if "context" in content:
|
||||
tool_args["context"] = content["context"]
|
||||
else:
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
tool_args["context"] = {"content": str(content)}
|
||||
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
else:
|
||||
# Other tools use context parameter
|
||||
if isinstance(content, dict):
|
||||
tool_args["context"] = content
|
||||
else:
|
||||
tool_args["context"] = {"content": content}
|
||||
tool_args["usermessages"] = str(tool_call)
|
||||
tool_args["apply_id"] = str(self.apply_id)
|
||||
tool_args["group_id"] = str(self.group_id)
|
||||
|
||||
|
||||
tool_input = {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[{
|
||||
"name": self.tool_name,
|
||||
"args": tool_args,
|
||||
"id": self.id + f"{tool_call}",
|
||||
}]
|
||||
)
|
||||
]
|
||||
}
|
||||
result = await self.tool_node.ainvoke(tool_input)
|
||||
result_text = str(result)
|
||||
|
||||
return {"messages": [AIMessage(content=result_text)]}
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_read_graph(namespace, tools, search_switch, apply_id, group_id, memory_config: MemoryConfig, storage_type=None, user_rag_memory_id=None):
|
||||
"""
|
||||
Create a read graph workflow for memory operations.
|
||||
|
||||
Args:
|
||||
namespace: Namespace identifier
|
||||
tools: MCP tools loaded from session
|
||||
search_switch: Search mode switch ("0", "1", or "2")
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type (optional)
|
||||
user_rag_memory_id: User RAG memory ID (optional)
|
||||
"""
|
||||
memory = InMemorySaver()
|
||||
tool = [i.name for i in tools]
|
||||
logger.info(f"Initializing read graph with tools: {tool}")
|
||||
logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})")
|
||||
|
||||
# Extract tool functions
|
||||
Split_The_Problem_ = next((t for t in tools if t.name == "Split_The_Problem"), None)
|
||||
Problem_Extension_ = next((t for t in tools if t.name == "Problem_Extension"), None)
|
||||
Retrieve_ = next((t for t in tools if t.name == "Retrieve"), None)
|
||||
Verify_ = next((t for t in tools if t.name == "Verify"), None)
|
||||
Summary_ = next((t for t in tools if t.name == "Summary"), None)
|
||||
Summary_fails_ = next((t for t in tools if t.name == "Summary_fails"), None)
|
||||
Retrieve_Summary_ = next((t for t in tools if t.name == "Retrieve_Summary"), None)
|
||||
Input_Summary_ = next((t for t in tools if t.name == "Input_Summary"), None)
|
||||
|
||||
# Instantiate services
|
||||
parameter_builder = ParameterBuilder()
|
||||
multimodal_processor = MultimodalProcessor()
|
||||
|
||||
# Create nodes using new modular components
|
||||
Split_The_Problem_node = ToolNode([Split_The_Problem_])
|
||||
|
||||
Problem_Extension_node = ToolExecutionNode(
|
||||
tool=Problem_Extension_,
|
||||
node_id="Problem_Extension_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
async def make_read_graph():
|
||||
"""创建并返回 LangGraph 工作流"""
|
||||
try:
|
||||
# Build workflow graph
|
||||
workflow = StateGraph(ReadState)
|
||||
workflow.add_node("content_input", content_input_node)
|
||||
workflow.add_node("Split_The_Problem", Split_The_Problem)
|
||||
workflow.add_node("Problem_Extension", Problem_Extension)
|
||||
workflow.add_node("Input_Summary", Input_Summary)
|
||||
# workflow.add_node("Retrieve", retrieve_nodes)
|
||||
workflow.add_node("Retrieve", retrieve)
|
||||
workflow.add_node("Verify", Verify)
|
||||
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
|
||||
workflow.add_node("Summary", Summary)
|
||||
workflow.add_node("Summary_fails", Summary_fails)
|
||||
|
||||
# 添加边
|
||||
workflow.add_edge(START, "content_input")
|
||||
workflow.add_conditional_edges("content_input", Split_continue)
|
||||
workflow.add_edge("Input_Summary", END)
|
||||
workflow.add_edge("Split_The_Problem", "Problem_Extension")
|
||||
workflow.add_edge("Problem_Extension", "Retrieve")
|
||||
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
||||
workflow.add_edge("Retrieve_Summary", END)
|
||||
workflow.add_conditional_edges("Verify", Verify_continue)
|
||||
workflow.add_edge("Summary_fails", END)
|
||||
workflow.add_edge("Summary", END)
|
||||
|
||||
|
||||
'''-----'''
|
||||
# workflow.add_edge("Retrieve", END)
|
||||
|
||||
# 编译工作流
|
||||
graph = workflow.compile()
|
||||
yield graph
|
||||
|
||||
except Exception as e:
|
||||
print(f"创建工作流失败: {e}")
|
||||
raise
|
||||
finally:
|
||||
print("工作流创建完成")
|
||||
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "昨天有什么好看的电影"
|
||||
group_id = '88a459f5_text09' # 组ID
|
||||
storage_type = 'neo4j' # 存储类型
|
||||
search_switch = '1' # 搜索开关
|
||||
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
|
||||
|
||||
# 获取数据库会话
|
||||
db_session = next(get_db())
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=17, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
import time
|
||||
start=time.time()
|
||||
try:
|
||||
async with make_read_graph() as graph:
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"group_id":group_id
|
||||
,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config}
|
||||
# 获取节点更新信息
|
||||
_intermediate_outputs = []
|
||||
summary = ''
|
||||
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
print(f"处理节点: {node_name}")
|
||||
|
||||
# 处理不同Summary节点的返回结构
|
||||
if 'Summary' in node_name:
|
||||
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
|
||||
summary = node_data['InputSummary']['summary_result']
|
||||
elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
|
||||
summary = node_data['RetrieveSummary']['summary_result']
|
||||
elif 'summary' in node_data and 'summary_result' in node_data['summary']:
|
||||
summary = node_data['summary']['summary_result']
|
||||
elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
|
||||
summary = node_data['SummaryFails']['summary_result']
|
||||
|
||||
Retrieve_node = ToolExecutionNode(
|
||||
tool=Retrieve_,
|
||||
node_id="Retrieve_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
|
||||
if spit_data and spit_data != [] and spit_data != {}:
|
||||
_intermediate_outputs.append(spit_data)
|
||||
|
||||
# Problem_Extension 节点
|
||||
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
|
||||
if problem_extension and problem_extension != [] and problem_extension != {}:
|
||||
_intermediate_outputs.append(problem_extension)
|
||||
|
||||
# Retrieve 节点
|
||||
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
|
||||
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
|
||||
_intermediate_outputs.extend(retrieve_node)
|
||||
|
||||
# Verify 节点
|
||||
verify_n = node_data.get('verify', {}).get('_intermediate', None)
|
||||
if verify_n and verify_n != [] and verify_n != {}:
|
||||
_intermediate_outputs.append(verify_n)
|
||||
|
||||
Verify_node = ToolExecutionNode(
|
||||
tool=Verify_,
|
||||
node_id="Verify_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
Summary_node = ToolExecutionNode(
|
||||
tool=Summary_,
|
||||
node_id="Summary_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
# Summary 节点
|
||||
summary_n = node_data.get('summary', {}).get('_intermediate', None)
|
||||
if summary_n and summary_n != [] and summary_n != {}:
|
||||
_intermediate_outputs.append(summary_n)
|
||||
|
||||
Summary_fails_node = ToolExecutionNode(
|
||||
tool=Summary_fails_,
|
||||
node_id="Summary_fails_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
# # 过滤掉空值
|
||||
# _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
|
||||
#
|
||||
# # 优化搜索结果
|
||||
# print("=== 开始优化搜索结果 ===")
|
||||
# optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
|
||||
# result=reorder_output_results(optimized_outputs)
|
||||
# # 保存优化后的结果到文件
|
||||
# with open('_intermediate_outputs_optimized.json', 'w', encoding='utf-8') as f:
|
||||
# import json
|
||||
# f.write(json.dumps(result, indent=4, ensure_ascii=False))
|
||||
#
|
||||
print(f"=== 最终摘要 ===")
|
||||
print(summary)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
Retrieve_Summary_node = ToolExecutionNode(
|
||||
tool=Retrieve_Summary_,
|
||||
node_id="Retrieve_Summary_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
end=time.time()
|
||||
print(100*'y')
|
||||
print(f"总耗时: {end-start}s")
|
||||
print(100*'y')
|
||||
|
||||
Input_Summary_node = ToolExecutionNode(
|
||||
tool=Input_Summary_,
|
||||
node_id="Input_Summary_id",
|
||||
namespace=namespace,
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
parameter_builder=parameter_builder,
|
||||
storage_type=storage_type,
|
||||
user_rag_memory_id=user_rag_memory_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
async def content_input_node(state):
|
||||
state_search_switch = state.get("search_switch", search_switch)
|
||||
|
||||
tool_name = "Input_Summary" if state_search_switch == '2' else "Split_The_Problem"
|
||||
session_prefix = "input_summary_call_id" if state_search_switch == '2' else "split_call_id"
|
||||
|
||||
return await create_input_message(
|
||||
state=state,
|
||||
tool_name=tool_name,
|
||||
session_id=f"{session_prefix}_{namespace}",
|
||||
search_switch=search_switch,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
multimodal_processor=multimodal_processor,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
|
||||
|
||||
# Build workflow graph
|
||||
workflow = StateGraph(ReadState)
|
||||
workflow.add_node("content_input", content_input_node)
|
||||
workflow.add_node("Split_The_Problem", Split_The_Problem_node)
|
||||
workflow.add_node("Problem_Extension", Problem_Extension_node)
|
||||
workflow.add_node("Retrieve", Retrieve_node)
|
||||
workflow.add_node("Verify", Verify_node)
|
||||
workflow.add_node("Summary", Summary_node)
|
||||
workflow.add_node("Summary_fails", Summary_fails_node)
|
||||
workflow.add_node("Retrieve_Summary", Retrieve_Summary_node)
|
||||
workflow.add_node("Input_Summary", Input_Summary_node)
|
||||
|
||||
# Add edges using imported routers
|
||||
workflow.add_edge(START, "content_input")
|
||||
workflow.add_conditional_edges("content_input", Split_continue)
|
||||
workflow.add_edge("Input_Summary", END)
|
||||
workflow.add_edge("Split_The_Problem", "Problem_Extension")
|
||||
workflow.add_edge("Problem_Extension", "Retrieve")
|
||||
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
|
||||
workflow.add_edge("Retrieve_Summary", END)
|
||||
workflow.add_conditional_edges("Verify", Verify_continue)
|
||||
workflow.add_edge("Summary_fails", END)
|
||||
workflow.add_edge("Summary", END)
|
||||
|
||||
graph = workflow.compile(checkpointer=memory)
|
||||
yield graph
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
"""LangGraph routing logic."""
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.routing.routers import (
|
||||
Verify_continue,
|
||||
Retrieve_continue,
|
||||
Split_continue,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Verify_continue",
|
||||
"Retrieve_continue",
|
||||
"Split_continue",
|
||||
]
|
||||
@@ -1,123 +1,61 @@
|
||||
"""
|
||||
Routing functions for LangGraph conditional edges.
|
||||
|
||||
This module provides routing functions that determine the next node to execute
|
||||
based on state values. All functions return Literal types for type safety.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.state.extractors import extract_search_switch
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global counter for Verify routing
|
||||
logger = get_agent_logger(__name__)
|
||||
counter = COUNTState(limit=3)
|
||||
def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
|
||||
Args:
|
||||
state: State dictionary containing search_switch
|
||||
|
||||
Returns:
|
||||
Next node to execute
|
||||
"""
|
||||
logger.debug(f"Split_continue state: {state}")
|
||||
search_switch = state.get('search_switch', '')
|
||||
if search_switch is not None:
|
||||
search_switch = str(search_switch)
|
||||
if search_switch == '2':
|
||||
return 'Input_Summary'
|
||||
return 'Split_The_Problem' # 默认情况
|
||||
|
||||
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
"""
|
||||
Determine routing based on search_switch value.
|
||||
|
||||
Args:
|
||||
state: State dictionary containing search_switch
|
||||
|
||||
Returns:
|
||||
Next node to execute
|
||||
"""
|
||||
search_switch = state.get('search_switch', '')
|
||||
if search_switch is not None:
|
||||
search_switch = str(search_switch)
|
||||
if search_switch == '0':
|
||||
return 'Verify'
|
||||
elif search_switch == '1':
|
||||
return 'Retrieve_Summary'
|
||||
return 'Retrieve_Summary' # Default based on business logic
|
||||
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
|
||||
"""
|
||||
Determine routing after Verify node based on verification result.
|
||||
|
||||
This function checks the verification result in the last message and routes to:
|
||||
- Summary: if verification succeeded
|
||||
- content_input: if verification failed and retry limit not reached
|
||||
- Summary_fails: if verification failed and retry limit reached
|
||||
|
||||
Args:
|
||||
state: LangGraph state containing messages
|
||||
|
||||
Returns:
|
||||
Next node name as Literal type
|
||||
"""
|
||||
messages = state.get("messages", [])
|
||||
|
||||
# Boundary check
|
||||
if not messages:
|
||||
logger.warning("[Verify_continue] No messages in state, defaulting to Summary")
|
||||
counter.reset()
|
||||
status=state.get('verify', '')['status']
|
||||
# loop_count = counter.get_total()
|
||||
if "success" in status:
|
||||
# counter.reset()
|
||||
return "Summary"
|
||||
|
||||
# Increment counter
|
||||
counter.add(1)
|
||||
loop_count = counter.get_total()
|
||||
logger.debug(f"[Verify_continue] Current loop count: {loop_count}")
|
||||
|
||||
# Extract verification result from last message
|
||||
last_message = messages[-1]
|
||||
last_message_str = str(last_message).replace('\\', '')
|
||||
status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str)
|
||||
logger.debug(f"[Verify_continue] Status tools: {status_tools}")
|
||||
|
||||
# Route based on verification result
|
||||
if "success" in status_tools:
|
||||
counter.reset()
|
||||
return "Summary"
|
||||
elif "failed" in status_tools:
|
||||
if loop_count < 2: # Max retry count is 2
|
||||
return "content_input"
|
||||
else:
|
||||
counter.reset()
|
||||
return "Summary_fails"
|
||||
elif "failed" in status:
|
||||
# if loop_count < 2: # Maximum loop count is 3
|
||||
# return "content_input"
|
||||
# else:
|
||||
# counter.reset()
|
||||
return "Summary_fails"
|
||||
else:
|
||||
# Default to Summary if status is unclear
|
||||
counter.reset()
|
||||
return "Summary"
|
||||
|
||||
|
||||
def Retrieve_continue(state: dict) -> Literal["Verify", "Retrieve_Summary"]:
|
||||
"""
|
||||
Determine routing after Retrieve node based on search_switch value.
|
||||
|
||||
This function routes based on the search_switch parameter:
|
||||
- search_switch == '0': Route to Verify (verification needed)
|
||||
- search_switch == '1': Route to Retrieve_Summary (direct summary)
|
||||
|
||||
Args:
|
||||
state: LangGraph state dictionary
|
||||
|
||||
Returns:
|
||||
Next node name as Literal type
|
||||
"""
|
||||
search_switch = extract_search_switch(state)
|
||||
|
||||
logger.debug(f"[Retrieve_continue] search_switch: {search_switch}")
|
||||
|
||||
if search_switch == '0':
|
||||
return 'Verify'
|
||||
elif search_switch == '1':
|
||||
return 'Retrieve_Summary'
|
||||
|
||||
# Default to Retrieve_Summary
|
||||
logger.debug("[Retrieve_continue] No valid search_switch, defaulting to Retrieve_Summary")
|
||||
return 'Retrieve_Summary'
|
||||
|
||||
|
||||
def Split_continue(state: dict) -> Literal["Split_The_Problem", "Input_Summary"]:
|
||||
"""
|
||||
Determine routing after content_input node based on search_switch value.
|
||||
|
||||
This function routes based on the search_switch parameter:
|
||||
- search_switch == '2': Route to Input_Summary (direct input summary)
|
||||
- Otherwise: Route to Split_The_Problem (problem decomposition)
|
||||
|
||||
Args:
|
||||
state: LangGraph state dictionary
|
||||
|
||||
Returns:
|
||||
Next node name as Literal type
|
||||
"""
|
||||
logger.debug(f"[Split_continue] state keys: {state.keys()}")
|
||||
|
||||
search_switch = extract_search_switch(state)
|
||||
|
||||
logger.debug(f"[Split_continue] search_switch: {search_switch}")
|
||||
|
||||
if search_switch == '2':
|
||||
return 'Input_Summary'
|
||||
|
||||
# Default to Split_The_Problem
|
||||
return 'Split_The_Problem'
|
||||
# Add default return value to avoid returning None
|
||||
# counter.reset()
|
||||
return "Summary" # Default based on business requirements
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
"""LangGraph state management utilities."""
|
||||
|
||||
from app.core.memory.agent.langgraph_graph.state.extractors import (
|
||||
extract_search_switch,
|
||||
extract_tool_call_id,
|
||||
extract_content_payload,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"extract_search_switch",
|
||||
"extract_tool_call_id",
|
||||
"extract_content_payload",
|
||||
]
|
||||
@@ -1,179 +0,0 @@
|
||||
"""
|
||||
State extraction utilities for type-safe access to LangGraph state values.
|
||||
|
||||
This module provides utility functions for extracting values from LangGraph state
|
||||
dictionaries with proper error handling and sensible defaults.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def extract_search_switch(state: dict) -> Optional[str]:
|
||||
"""
|
||||
Extract search_switch from state or messages.
|
||||
"""
|
||||
|
||||
search_switch = state.get("search_switch")
|
||||
|
||||
if search_switch is not None:
|
||||
return str(search_switch)
|
||||
|
||||
# Try to extract from messages
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# 从最新的消息开始查找
|
||||
for message in reversed(messages):
|
||||
# 尝试从 tool_calls 中提取
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
# 从 tool_call 的 args 中提取
|
||||
if "args" in tool_call and isinstance(tool_call["args"], dict):
|
||||
search_switch = tool_call["args"].get("search_switch")
|
||||
if search_switch is not None:
|
||||
return str(search_switch)
|
||||
# 直接从 tool_call 中提取
|
||||
search_switch = tool_call.get("search_switch")
|
||||
if search_switch is not None:
|
||||
return str(search_switch)
|
||||
|
||||
# 尝试从 content 中提取(如果是 JSON 格式)
|
||||
if hasattr(message, "content"):
|
||||
try:
|
||||
import json
|
||||
if isinstance(message.content, str):
|
||||
content_data = json.loads(message.content)
|
||||
if isinstance(content_data, dict):
|
||||
search_switch = content_data.get("search_switch")
|
||||
if search_switch is not None:
|
||||
return str(search_switch)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_tool_call_id(message: Any) -> str:
|
||||
"""
|
||||
Extract tool call ID from message using structured attributes.
|
||||
|
||||
This function extracts the tool call ID from a message object, handling both
|
||||
direct attribute access and tool_calls list structures.
|
||||
|
||||
Args:
|
||||
message: Message object (typically ToolMessage or AIMessage)
|
||||
|
||||
Returns:
|
||||
Tool call ID as string
|
||||
|
||||
Raises:
|
||||
ValueError: If tool call ID cannot be extracted
|
||||
|
||||
Examples:
|
||||
>>> message = ToolMessage(content="...", tool_call_id="call_123")
|
||||
>>> extract_tool_call_id(message)
|
||||
'call_123'
|
||||
"""
|
||||
# Try direct attribute access for ToolMessage
|
||||
if hasattr(message, "tool_call_id"):
|
||||
tool_call_id = message.tool_call_id
|
||||
if tool_call_id:
|
||||
return str(tool_call_id)
|
||||
|
||||
# Try extracting from tool_calls list for AIMessage
|
||||
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||
tool_call = message.tool_calls[0]
|
||||
if isinstance(tool_call, dict) and "id" in tool_call:
|
||||
return str(tool_call["id"])
|
||||
|
||||
# Try extracting from id attribute
|
||||
if hasattr(message, "id"):
|
||||
message_id = message.id
|
||||
if message_id:
|
||||
return str(message_id)
|
||||
|
||||
# If all else fails, raise an error
|
||||
raise ValueError(f"Could not extract tool call ID from message: {type(message)}")
|
||||
|
||||
|
||||
def extract_content_payload(message: Any) -> Any:
|
||||
"""
|
||||
Extract content payload from ToolMessage, parsing JSON if needed.
|
||||
|
||||
This function extracts the content from a message and attempts to parse it as JSON
|
||||
if it appears to be a JSON string. It handles various message formats and provides
|
||||
sensible fallbacks.
|
||||
|
||||
Args:
|
||||
message: Message object (typically ToolMessage)
|
||||
|
||||
Returns:
|
||||
Parsed content (dict, list, or str)
|
||||
|
||||
Examples:
|
||||
>>> message = ToolMessage(content='{"key": "value"}')
|
||||
>>> extract_content_payload(message)
|
||||
{'key': 'value'}
|
||||
|
||||
>>> message = ToolMessage(content='plain text')
|
||||
>>> extract_content_payload(message)
|
||||
'plain text'
|
||||
"""
|
||||
# Extract raw content
|
||||
# For ToolMessages (responses from tools), extract from content
|
||||
if hasattr(message, "content"):
|
||||
raw_content = message.content
|
||||
logger.info(f"extract_content_payload: raw_content type={type(raw_content)}, value={str(raw_content)[:500]}")
|
||||
|
||||
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
|
||||
if isinstance(raw_content, list):
|
||||
for block in raw_content:
|
||||
if isinstance(block, dict) and block.get('type') == 'text':
|
||||
raw_content = block.get('text', '')
|
||||
logger.info(f"extract_content_payload: extracted text from MCP format: {str(raw_content)[:300]}")
|
||||
break
|
||||
|
||||
# If content is empty and this is an AIMessage with tool_calls,
|
||||
# extract from args (this handles the initial tool call from content_input)
|
||||
if not raw_content and hasattr(message, "tool_calls") and message.tool_calls:
|
||||
tool_call = message.tool_calls[0]
|
||||
if isinstance(tool_call, dict) and "args" in tool_call:
|
||||
return tool_call["args"]
|
||||
else:
|
||||
raw_content = str(message)
|
||||
|
||||
# If content is already a dict or list, return it directly
|
||||
if isinstance(raw_content, (dict, list)):
|
||||
logger.info(f"extract_content_payload: returning raw dict/list with keys={list(raw_content.keys()) if isinstance(raw_content, dict) else 'list'}")
|
||||
return raw_content
|
||||
|
||||
# Try to parse as JSON
|
||||
if isinstance(raw_content, str):
|
||||
# First, try direct JSON parsing
|
||||
try:
|
||||
parsed = json.loads(raw_content)
|
||||
logger.info(f"extract_content_payload: parsed JSON, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}")
|
||||
return parsed
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
pass
|
||||
|
||||
# If that fails, try to extract JSON from the string
|
||||
# This handles cases where the content is embedded in a larger string
|
||||
import re
|
||||
json_candidates = re.findall(r'[\[{].*[\]}]', raw_content, flags=re.DOTALL)
|
||||
for candidate in json_candidates:
|
||||
try:
|
||||
parsed = json.loads(candidate)
|
||||
logger.info(f"extract_content_payload: parsed JSON from candidate, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}")
|
||||
return parsed
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
continue
|
||||
|
||||
# If all parsing attempts fail, return the raw content
|
||||
logger.info(f"extract_content_payload: returning raw content (parsing failed)")
|
||||
return raw_content
|
||||
320
api/app/core/memory/agent/langgraph_graph/tools/tool.py
Normal file
320
api/app/core/memory/agent/langgraph_graph/tools/tool.py
Normal file
@@ -0,0 +1,320 @@
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
from langchain.tools import tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
from app.core.memory.src.search import (
|
||||
search_by_temporal,
|
||||
search_by_keyword_temporal,
|
||||
)
|
||||
|
||||
def extract_tool_message_content(response):
|
||||
"""从agent响应中提取ToolMessage内容和工具名称"""
|
||||
messages = response.get('messages', [])
|
||||
|
||||
for message in messages:
|
||||
if hasattr(message, 'tool_call_id') and hasattr(message, 'content'):
|
||||
# 这是一个ToolMessage
|
||||
tool_content = message.content
|
||||
tool_name = None
|
||||
|
||||
# 尝试获取工具名称
|
||||
if hasattr(message, 'name'):
|
||||
tool_name = message.name
|
||||
elif hasattr(message, 'tool_name'):
|
||||
tool_name = message.tool_name
|
||||
|
||||
try:
|
||||
# 解析JSON内容
|
||||
parsed_content = json.loads(tool_content)
|
||||
return {
|
||||
'tool_name': tool_name,
|
||||
'content': parsed_content
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
# 如果不是JSON格式,直接返回内容
|
||||
return {
|
||||
'tool_name': tool_name,
|
||||
'content': tool_content
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class TimeRetrievalInput(BaseModel):
|
||||
"""时间检索工具的输入模式"""
|
||||
context: str = Field(description="用户输入的查询内容")
|
||||
group_id: str = Field(default="88a459f5_text09", description="组ID,用于过滤搜索结果")
|
||||
|
||||
def create_time_retrieval_tool(group_id: str):
|
||||
"""
|
||||
创建一个带有特定group_id的TimeRetrieval工具(同步版本),用于按时间范围搜索语句(Statements)
|
||||
"""
|
||||
|
||||
def clean_temporal_result_fields(data):
|
||||
"""
|
||||
清理时间搜索结果中不需要的字段,并修改结构
|
||||
|
||||
Args:
|
||||
data: 要清理的数据
|
||||
|
||||
Returns:
|
||||
清理后的数据
|
||||
"""
|
||||
# 需要过滤的字段列表
|
||||
fields_to_remove = {
|
||||
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
|
||||
'valid_at', 'invalid_at', 'statement_ids'
|
||||
}
|
||||
|
||||
if isinstance(data, dict):
|
||||
cleaned = {}
|
||||
for key, value in data.items():
|
||||
if key == 'statements' and isinstance(value, dict) and 'statements' in value:
|
||||
# 将 statements: {"statements": [...]} 改为 time_search: {"statements": [...]}
|
||||
cleaned_value = clean_temporal_result_fields(value)
|
||||
# 进一步将内部的 statements 改为 time_search
|
||||
if 'statements' in cleaned_value:
|
||||
cleaned['results'] = {
|
||||
'time_search': cleaned_value['statements']
|
||||
}
|
||||
else:
|
||||
cleaned['results'] = cleaned_value
|
||||
elif key not in fields_to_remove:
|
||||
cleaned[key] = clean_temporal_result_fields(value)
|
||||
return cleaned
|
||||
elif isinstance(data, list):
|
||||
return [clean_temporal_result_fields(item) for item in data]
|
||||
else:
|
||||
return data
|
||||
|
||||
@tool
|
||||
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, group_id_param: str = None, clean_output: bool = True) -> str:
|
||||
"""
|
||||
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||
显式接收参数:
|
||||
- context: 查询上下文内容
|
||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||
- group_id_param: 组ID(可选,用于覆盖默认组ID)
|
||||
- clean_output: 是否清理输出中的元数据字段
|
||||
-end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||
"""
|
||||
async def _async_search():
|
||||
# 使用传入的参数或默认值
|
||||
actual_group_id = group_id_param or group_id
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
|
||||
|
||||
# 基本时间搜索
|
||||
results = await search_by_temporal(
|
||||
group_id=actual_group_id,
|
||||
start_date=actual_start_date,
|
||||
end_date=actual_end_date,
|
||||
limit=10
|
||||
)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
if clean_output:
|
||||
cleaned_results = clean_temporal_result_fields(results)
|
||||
else:
|
||||
cleaned_results = results
|
||||
|
||||
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
@tool
|
||||
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str:
|
||||
"""
|
||||
优化的关键词时间检索工具,结合关键词和时间范围搜索(同步版本),自动过滤不需要的元数据字段
|
||||
显式接收参数:
|
||||
- context: 查询内容
|
||||
- days_back: 向前搜索的天数,默认7天
|
||||
- start_date: 开始时间(可选,格式:YYYY-MM-DD)
|
||||
- end_date: 结束时间(可选,格式:YYYY-MM-DD)
|
||||
- clean_output: 是否清理输出中的元数据字段
|
||||
- end_date 需要根据用户的描述获取结束的时间,输出格式用strftime("%Y-%m-%d")
|
||||
"""
|
||||
async def _async_search():
|
||||
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
|
||||
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
|
||||
|
||||
# 关键词时间搜索
|
||||
results = await search_by_keyword_temporal(
|
||||
query_text=context,
|
||||
group_id=group_id,
|
||||
start_date=actual_start_date,
|
||||
end_date=actual_end_date,
|
||||
limit=15
|
||||
)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
if clean_output:
|
||||
cleaned_results = clean_temporal_result_fields(results)
|
||||
else:
|
||||
cleaned_results = results
|
||||
|
||||
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
return TimeRetrievalWithGroupId
|
||||
|
||||
|
||||
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
|
||||
"""
|
||||
创建混合检索工具,使用run_hybrid_search进行混合检索,优化输出格式并过滤不需要的字段
|
||||
|
||||
Args:
|
||||
memory_config: 内存配置对象
|
||||
**search_params: 搜索参数,包含group_id, limit, include等
|
||||
"""
|
||||
|
||||
def clean_result_fields(data):
|
||||
"""
|
||||
递归清理结果中不需要的字段
|
||||
|
||||
Args:
|
||||
data: 要清理的数据(可能是字典、列表或其他类型)
|
||||
|
||||
Returns:
|
||||
清理后的数据
|
||||
"""
|
||||
# 需要过滤的字段列表
|
||||
fields_to_remove = {
|
||||
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
|
||||
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
|
||||
'user_id', 'statement_ids', 'updated_at',"chunk_ids","fact_summary"
|
||||
}
|
||||
|
||||
if isinstance(data, dict):
|
||||
# 对字典进行清理
|
||||
cleaned = {}
|
||||
for key, value in data.items():
|
||||
if key not in fields_to_remove:
|
||||
cleaned[key] = clean_result_fields(value) # 递归清理嵌套数据
|
||||
return cleaned
|
||||
elif isinstance(data, list):
|
||||
# 对列表中的每个元素进行清理
|
||||
return [clean_result_fields(item) for item in data]
|
||||
else:
|
||||
# 其他类型直接返回
|
||||
return data
|
||||
|
||||
@tool
|
||||
async def HybridSearch(
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
group_id: str = None,
|
||||
rerank_alpha: float = 0.6,
|
||||
use_forgetting_rerank: bool = False,
|
||||
use_llm_rerank: bool = False,
|
||||
clean_output: bool = True # 新增:是否清理输出字段
|
||||
) -> str:
|
||||
"""
|
||||
优化的混合检索工具,支持关键词、向量和混合搜索,自动过滤不需要的元数据字段
|
||||
|
||||
Args:
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
group_id: 组ID,用于过滤搜索结果
|
||||
rerank_alpha: 重排序权重参数
|
||||
use_forgetting_rerank: 是否使用遗忘重排序
|
||||
use_llm_rerank: 是否使用LLM重排序
|
||||
clean_output: 是否清理输出中的元数据字段
|
||||
"""
|
||||
try:
|
||||
# 导入run_hybrid_search函数
|
||||
from app.core.memory.src.search import run_hybrid_search
|
||||
|
||||
# 合并参数,优先使用传入的参数
|
||||
final_params = {
|
||||
"query_text": context,
|
||||
"search_type": search_type,
|
||||
"group_id": group_id or search_params.get("group_id"),
|
||||
"limit": limit or search_params.get("limit", 10),
|
||||
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
|
||||
"output_path": None, # 不保存到文件
|
||||
"memory_config": memory_config,
|
||||
"rerank_alpha": rerank_alpha,
|
||||
"use_forgetting_rerank": use_forgetting_rerank,
|
||||
"use_llm_rerank": use_llm_rerank
|
||||
}
|
||||
|
||||
# 执行混合检索
|
||||
raw_results = await run_hybrid_search(**final_params)
|
||||
|
||||
# 清理结果中不需要的字段
|
||||
if clean_output:
|
||||
cleaned_results = clean_result_fields(raw_results)
|
||||
else:
|
||||
cleaned_results = raw_results
|
||||
|
||||
# 格式化返回结果
|
||||
formatted_results = {
|
||||
"search_query": context,
|
||||
"search_type": search_type,
|
||||
"results": cleaned_results
|
||||
}
|
||||
|
||||
return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
except Exception as e:
|
||||
error_result = {
|
||||
"error": f"混合检索失败: {str(e)}",
|
||||
"search_query": context,
|
||||
"search_type": search_type,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
return json.dumps(error_result, ensure_ascii=False, indent=2)
|
||||
|
||||
return HybridSearch
|
||||
|
||||
|
||||
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
|
||||
"""
|
||||
创建同步版本的混合检索工具,优化输出格式并过滤不需要的字段
|
||||
|
||||
Args:
|
||||
memory_config: 内存配置对象
|
||||
**search_params: 搜索参数
|
||||
"""
|
||||
@tool
|
||||
def HybridSearchSync(
|
||||
context: str,
|
||||
search_type: str = "hybrid",
|
||||
limit: int = 10,
|
||||
group_id: str = None,
|
||||
clean_output: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
优化的混合检索工具(同步版本),自动过滤不需要的元数据字段
|
||||
|
||||
Args:
|
||||
context: 查询内容
|
||||
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
|
||||
limit: 结果数量限制
|
||||
group_id: 组ID,用于过滤搜索结果
|
||||
clean_output: 是否清理输出中的元数据字段
|
||||
"""
|
||||
async def _async_search():
|
||||
# 创建异步工具并执行
|
||||
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
|
||||
return await async_tool.ainvoke({
|
||||
"context": context,
|
||||
"search_type": search_type,
|
||||
"limit": limit,
|
||||
"group_id": group_id,
|
||||
"clean_output": clean_output
|
||||
})
|
||||
|
||||
return asyncio.run(_async_search())
|
||||
|
||||
return HybridSearchSync
|
||||
@@ -1,80 +1,80 @@
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import warnings
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.constants import END, START
|
||||
from langgraph.graph import StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
|
||||
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.llm_tools import WriteState
|
||||
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
|
||||
warnings.filterwarnings("ignore", category=RuntimeWarning)
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
if sys.platform.startswith("win"):
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def make_write_graph(user_id, tools, apply_id, group_id, memory_config: MemoryConfig):
|
||||
async def make_write_graph():
|
||||
"""
|
||||
Create a write graph workflow for memory operations.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
tools: MCP tools loaded from session
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
|
||||
The workflow directly processes messages from the initial state
|
||||
and saves them to Neo4j storage.
|
||||
"""
|
||||
logger.info("Loading MCP tools: %s", [t.name for t in tools])
|
||||
logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})")
|
||||
|
||||
data_write_tool = next((t for t in tools if t.name == "Data_write"), None)
|
||||
|
||||
if not data_write_tool:
|
||||
logger.error("Data_write tool not found", exc_info=True)
|
||||
raise ValueError("Data_write tool not found")
|
||||
|
||||
write_node = ToolNode([data_write_tool])
|
||||
|
||||
async def call_model(state):
|
||||
messages = state["messages"]
|
||||
last_message = messages[-1]
|
||||
content = last_message[1] if isinstance(last_message, tuple) else last_message.content
|
||||
|
||||
# Call Data_write directly with memory_config
|
||||
write_params = {
|
||||
"content": content,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
"memory_config": memory_config,
|
||||
}
|
||||
logger.debug(f"Passing memory_config to Data_write: {memory_config.config_id}")
|
||||
|
||||
write_result = await data_write_tool.ainvoke(write_params)
|
||||
|
||||
if isinstance(write_result, dict):
|
||||
result_content = write_result.get("data", str(write_result))
|
||||
else:
|
||||
result_content = str(write_result)
|
||||
logger.info("Write content: %s", result_content)
|
||||
return {"messages": [AIMessage(content=result_content)]}
|
||||
|
||||
workflow = StateGraph(WriteState)
|
||||
workflow.add_node("content_input", call_model)
|
||||
workflow.add_node("save_neo4j", write_node)
|
||||
workflow.add_edge(START, "content_input")
|
||||
workflow.add_edge("content_input", "save_neo4j")
|
||||
workflow.add_edge(START, "save_neo4j")
|
||||
workflow.add_edge("save_neo4j", END)
|
||||
|
||||
graph = workflow.compile()
|
||||
|
||||
|
||||
yield graph
|
||||
|
||||
|
||||
async def main():
|
||||
"""主函数 - 运行工作流"""
|
||||
message = "今天周一"
|
||||
group_id = 'new_2025test1103' # 组ID
|
||||
|
||||
|
||||
# 获取数据库会话
|
||||
db_session = next(get_db())
|
||||
config_service = MemoryConfigService(db_session)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=17, # 改为整数
|
||||
service_name="MemoryAgentService"
|
||||
)
|
||||
try:
|
||||
async with make_write_graph() as graph:
|
||||
config = {"configurable": {"thread_id": group_id}}
|
||||
# 初始状态 - 包含所有必要字段
|
||||
initial_state = {"messages": [HumanMessage(content=message)], "group_id": group_id, "memory_config": memory_config}
|
||||
|
||||
# 获取节点更新信息
|
||||
async for update_event in graph.astream(
|
||||
initial_state,
|
||||
stream_mode="updates",
|
||||
config=config
|
||||
):
|
||||
for node_name, node_data in update_event.items():
|
||||
if 'save_neo4j'==node_name:
|
||||
massages=node_data
|
||||
massages=massages.get('write_result')['status']
|
||||
print(massages) # | 更新数据: {node_data}
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
asyncio.run(main())
|
||||
@@ -1,28 +0,0 @@
|
||||
"""
|
||||
MCP Server package for memory agent.
|
||||
|
||||
This package provides the FastMCP server implementation with context-based
|
||||
dependency injection for tool functions.
|
||||
|
||||
Package structure:
|
||||
- server: FastMCP server initialization and context setup
|
||||
- tools: MCP tool implementations
|
||||
- models: Pydantic response models
|
||||
- services: Business logic services
|
||||
"""
|
||||
# from app.core.memory.agent.mcp_server.server import (
|
||||
# mcp,
|
||||
# initialize_context,
|
||||
# main,
|
||||
# get_context_resource
|
||||
# )
|
||||
|
||||
# # Import tools to register them (but don't export them)
|
||||
# from app.core.memory.agent.mcp_server import tools
|
||||
|
||||
# __all__ = [
|
||||
# 'mcp',
|
||||
# 'initialize_context',
|
||||
# 'main',
|
||||
# 'get_context_resource',
|
||||
# ]
|
||||
@@ -1,11 +0,0 @@
|
||||
"""
|
||||
MCP Server Instance
|
||||
|
||||
This module contains the FastMCP server instance that is shared across all modules.
|
||||
It's in a separate file to avoid circular import issues.
|
||||
"""
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
# Initialize FastMCP server instance
|
||||
# This instance is shared across all tool modules
|
||||
mcp = FastMCP('data_flow')
|
||||
@@ -1,14 +0,0 @@
|
||||
"""Pydantic models for verification operations."""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VerificationResult(BaseModel):
|
||||
"""Result model for verification operation."""
|
||||
|
||||
query: str
|
||||
expansion_issue: List[Dict[str, Any]]
|
||||
split_result: str
|
||||
reason: Optional[str] = None
|
||||
history: List[Dict[str, Any]] = Field(default_factory=list)
|
||||
@@ -1,159 +0,0 @@
|
||||
"""
|
||||
MCP Server initialization with FastMCP context setup.
|
||||
|
||||
This module initializes the FastMCP server and registers shared resources
|
||||
in the context for dependency injection into tool functions.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.services.search_service import SearchService
|
||||
from app.core.memory.agent.mcp_server.services.session_service import SessionService
|
||||
from app.core.memory.agent.mcp_server.services.template_service import TemplateService
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.redis_tool import store
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
def get_context_resource(ctx, resource_name: str):
|
||||
"""
|
||||
Helper function to retrieve a resource from the FastMCP context.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP Context object (passed to tool functions)
|
||||
resource_name: Name of the resource to retrieve
|
||||
|
||||
Returns:
|
||||
The requested resource
|
||||
|
||||
Raises:
|
||||
AttributeError: If the resource doesn't exist
|
||||
|
||||
Example:
|
||||
@mcp.tool()
|
||||
async def my_tool(ctx: Context):
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
llm_client = get_context_resource(ctx, 'llm_client')
|
||||
"""
|
||||
if not hasattr(ctx, 'fastmcp') or ctx.fastmcp is None:
|
||||
raise RuntimeError("Context does not have fastmcp attribute")
|
||||
|
||||
if not hasattr(ctx.fastmcp, resource_name):
|
||||
raise AttributeError(
|
||||
f"Resource '{resource_name}' not found in context. "
|
||||
f"Available resources: {[k for k in dir(ctx.fastmcp) if not k.startswith('_')]}"
|
||||
)
|
||||
|
||||
return getattr(ctx.fastmcp, resource_name)
|
||||
|
||||
|
||||
def initialize_context():
|
||||
"""
|
||||
Initialize and register shared resources in FastMCP context.
|
||||
|
||||
This function sets up all shared resources that will be available
|
||||
to tool functions via dependency injection through the context parameter.
|
||||
|
||||
Resources are stored as attributes on the FastMCP instance and can be
|
||||
accessed via ctx.fastmcp in tool functions.
|
||||
|
||||
Resources registered:
|
||||
- session_store: RedisSessionStore for session management
|
||||
- llm_client: LLM client for structured API calls
|
||||
- app_settings: Application settings (renamed to avoid conflict with FastMCP settings)
|
||||
- template_service: Service for template rendering
|
||||
- search_service: Service for hybrid search
|
||||
- session_service: Service for session operations
|
||||
"""
|
||||
try:
|
||||
# Register Redis session store
|
||||
logger.info("Registering session_store in context")
|
||||
mcp.session_store = store
|
||||
|
||||
# Note: LLM client is NOT loaded at server startup
|
||||
# It should be loaded dynamically when needed, with config_id passed explicitly
|
||||
# to make_write_graph or make_read_graph functions
|
||||
logger.info("LLM client will be loaded dynamically with config_id when needed")
|
||||
mcp.llm_client = None # Placeholder - actual client loaded per-request with config_id
|
||||
|
||||
# Register application settings (renamed to avoid conflict with FastMCP's settings)
|
||||
logger.info("Registering app_settings in context")
|
||||
mcp.app_settings = settings
|
||||
|
||||
# Register template service
|
||||
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
|
||||
# logger.info(f"Registering template_service in context with root: {template_root}")
|
||||
template_service = TemplateService(template_root)
|
||||
mcp.template_service = template_service
|
||||
|
||||
# Register search service
|
||||
# logger.info("Registering search_service in context")
|
||||
search_service = SearchService()
|
||||
mcp.search_service = search_service
|
||||
|
||||
# Register session service
|
||||
# logger.info("Registering session_service in context")
|
||||
session_service = SessionService(store)
|
||||
mcp.session_service = session_service
|
||||
|
||||
# logger.info("All context resources registered successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize context: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
Main entry point for the MCP server.
|
||||
|
||||
Initializes context and starts the server with SSE transport.
|
||||
"""
|
||||
try:
|
||||
logger.info("Starting MCP server initialization")
|
||||
# Initialize context resources
|
||||
initialize_context()
|
||||
|
||||
# Import and register tools (imports trigger tool registration)
|
||||
from app.core.memory.agent.mcp_server.tools import ( # noqa: F401
|
||||
data_tools,
|
||||
problem_tools,
|
||||
retrieval_tools,
|
||||
summary_tools,
|
||||
verification_tools,
|
||||
)
|
||||
|
||||
# Tools are registered via imports above
|
||||
|
||||
# Get MCP port from environment (default: 8081)
|
||||
mcp_port = int(os.getenv("MCP_PORT", "8081"))
|
||||
logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport")
|
||||
|
||||
# Configure DNS rebinding protection for Docker container compatibility
|
||||
from mcp.server.fastmcp.server import TransportSecuritySettings
|
||||
|
||||
# Disable DNS rebinding protection to allow Docker container hostnames
|
||||
# This allows containers to connect using service names like 'mcp-server'
|
||||
mcp.settings.transport_security = TransportSecuritySettings(
|
||||
enable_dns_rebinding_protection=False,
|
||||
)
|
||||
logger.info("DNS rebinding protection: disabled for Docker container compatibility")
|
||||
|
||||
# logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport")
|
||||
|
||||
# Run the server with SSE transport for HTTP connections
|
||||
import uvicorn
|
||||
app = mcp.sse_app()
|
||||
uvicorn.run(app, host=settings.SERVER_IP, port=mcp_port, log_level="info")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start MCP server: {e}", exc_info=True)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,27 +0,0 @@
|
||||
"""
|
||||
MCP Tools module.
|
||||
|
||||
This module contains all MCP tool implementations organized by functionality.
|
||||
|
||||
Tools are organized into the following modules:
|
||||
- problem_tools: Question segmentation and extension
|
||||
- retrieval_tools: Database and context retrieval
|
||||
- verification_tools: Data verification
|
||||
- summary_tools: Summarization and summary retrieval
|
||||
- data_tools: Data type differentiation and writing
|
||||
"""
|
||||
|
||||
# Import all tool modules to register them with the MCP server
|
||||
from . import problem_tools
|
||||
from . import retrieval_tools
|
||||
from . import verification_tools
|
||||
from . import summary_tools
|
||||
from . import data_tools
|
||||
|
||||
__all__ = [
|
||||
'problem_tools',
|
||||
'retrieval_tools',
|
||||
'verification_tools',
|
||||
'summary_tools',
|
||||
'data_tools',
|
||||
]
|
||||
@@ -1,155 +0,0 @@
|
||||
"""
|
||||
Data Tools for data type differentiation and writing.
|
||||
|
||||
This module contains MCP tools for distinguishing data types and writing data.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.models.retrieval_models import (
|
||||
DistinguishTypeResponse,
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.write_tools import write
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Data_type_differentiation(
|
||||
ctx: Context,
|
||||
context: str,
|
||||
memory_config: MemoryConfig,
|
||||
) -> dict:
|
||||
"""
|
||||
Distinguish the type of data (read or write).
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Text to analyze for type differentiation
|
||||
memory_config: MemoryConfig object containing LLM configuration
|
||||
|
||||
Returns:
|
||||
dict: Contains 'context' with the original text and 'type' field
|
||||
"""
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, 'template_service')
|
||||
|
||||
# Get LLM client from memory_config using factory pattern
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Render template
|
||||
try:
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='distinguish_types_prompt.jinja2',
|
||||
operation_name='status_typle',
|
||||
user_query=context
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for Data_type_differentiation: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"type": "error",
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=DistinguishTypeResponse
|
||||
)
|
||||
|
||||
result = structured.model_dump()
|
||||
|
||||
# Add context to result
|
||||
result["context"] = context
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM call failed for Data_type_differentiation: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": context,
|
||||
"type": "error",
|
||||
"message": f"LLM call failed: {str(e)}"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Data_type_differentiation failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": context,
|
||||
"type": "error",
|
||||
"message": str(e)
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Data_write(
|
||||
ctx: Context,
|
||||
content: str,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
) -> dict:
|
||||
"""
|
||||
Write data to the database/file system.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
content: Data content to write
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
|
||||
Returns:
|
||||
dict: Contains 'status', 'saved_to', and 'data' fields
|
||||
"""
|
||||
try:
|
||||
# Ensure output directory exists
|
||||
os.makedirs("data_output", exist_ok=True)
|
||||
file_path = os.path.join("data_output", "user_data.csv")
|
||||
|
||||
# Write data - clients are constructed inside write() from memory_config
|
||||
await write(
|
||||
content=content,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
memory_config=memory_config,
|
||||
)
|
||||
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"saved_to": file_path,
|
||||
"data": content,
|
||||
"config_id": memory_config.config_id,
|
||||
"config_name": memory_config.config_name,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Data_write failed: {e}", exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
}
|
||||
@@ -1,304 +0,0 @@
|
||||
"""
|
||||
Problem Tools for question segmentation and extension.
|
||||
|
||||
This module contains MCP tools for breaking down and extending user questions.
|
||||
LLM clients are constructed from MemoryConfig when needed.
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.models.problem_models import (
|
||||
ProblemBreakdownResponse,
|
||||
ProblemExtensionResponse,
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Split_The_Problem(
|
||||
ctx: Context,
|
||||
sentence: str,
|
||||
sessionid: str,
|
||||
messages_id: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
) -> dict:
|
||||
"""
|
||||
Segment the dialogue or sentence into sub-problems.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
sentence: Original sentence to split
|
||||
sessionid: Session identifier
|
||||
messages_id: Message identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
|
||||
Returns:
|
||||
dict: Contains 'context' (JSON string of split results) and 'original' sentence
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Extract user ID from session
|
||||
user_id = session_service.resolve_user_id(sessionid)
|
||||
|
||||
# Get conversation history
|
||||
history = await session_service.get_history(user_id, apply_id, group_id)
|
||||
# Override with empty list for now (as in original)
|
||||
history = []
|
||||
|
||||
# Render template
|
||||
try:
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='problem_breakdown_prompt.jinja2',
|
||||
operation_name='split_the_problem',
|
||||
history=history,
|
||||
sentence=sentence
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for Split_The_Problem: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": json.dumps([], ensure_ascii=False),
|
||||
"original": sentence,
|
||||
"error": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=ProblemBreakdownResponse
|
||||
)
|
||||
|
||||
# Handle RootModel response with .root attribute access
|
||||
if structured is None:
|
||||
# LLM returned None, use empty list as fallback
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
elif hasattr(structured, 'root') and structured.root is not None:
|
||||
split_result = json.dumps(
|
||||
[item.model_dump() for item in structured.root],
|
||||
ensure_ascii=False
|
||||
)
|
||||
elif isinstance(structured, list):
|
||||
# Fallback: treat structured itself as the list
|
||||
split_result = json.dumps(
|
||||
[item.model_dump() for item in structured],
|
||||
ensure_ascii=False
|
||||
)
|
||||
else:
|
||||
# Last resort: use empty list
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM call failed for Split_The_Problem: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
split_result = json.dumps([], ensure_ascii=False)
|
||||
|
||||
logger.info("Problem splitting")
|
||||
logger.info(f"Problem split result: {split_result}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
result = {
|
||||
"context": split_result,
|
||||
"original": sentence,
|
||||
"_intermediate": {
|
||||
"type": "problem_split",
|
||||
"data": json.loads(split_result) if split_result else [],
|
||||
"original_query": sentence
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Split_The_Problem failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": json.dumps([], ensure_ascii=False),
|
||||
"original": sentence,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Problem splitting', duration)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Problem_Extension(
|
||||
ctx: Context,
|
||||
context: dict,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Extend the problem with additional sub-questions.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Dictionary containing split problem results
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'context' (aggregated questions) and 'original' question
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID from usermessages
|
||||
from app.core.memory.agent.utils.messages_tool import Resolve_username
|
||||
sessionid = Resolve_username(usermessages)
|
||||
|
||||
# Get conversation history
|
||||
history = await session_service.get_history(sessionid, apply_id, group_id)
|
||||
# Override with empty list for now (as in original)
|
||||
history = []
|
||||
|
||||
# Process context to extract questions
|
||||
extent_quest, original = await Problem_Extension_messages_deal(context)
|
||||
|
||||
# Format questions for template rendering
|
||||
questions_formatted = []
|
||||
for msg in extent_quest:
|
||||
if msg.get("role") == "user":
|
||||
questions_formatted.append(msg.get("content", ""))
|
||||
|
||||
# Render template
|
||||
try:
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='Problem_Extension_prompt.jinja2',
|
||||
operation_name='problem_extension',
|
||||
history=history,
|
||||
questions=questions_formatted
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for Problem_Extension: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": {},
|
||||
"original": original,
|
||||
"error": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
# Call LLM with structured response
|
||||
try:
|
||||
response_content = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=ProblemExtensionResponse
|
||||
)
|
||||
|
||||
# Aggregate results by original question
|
||||
aggregated_dict = {}
|
||||
for item in response_content.root:
|
||||
key = getattr(item, "original_question", None) or (
|
||||
item.get("original_question") if isinstance(item, dict) else None
|
||||
)
|
||||
value = getattr(item, "extended_question", None) or (
|
||||
item.get("extended_question") if isinstance(item, dict) else None
|
||||
)
|
||||
if not key or not value:
|
||||
continue
|
||||
aggregated_dict.setdefault(key, []).append(value)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM call failed for Problem_Extension: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
aggregated_dict = {}
|
||||
|
||||
logger.info("Problem extension")
|
||||
logger.info(f"Problem extension result: {aggregated_dict}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
result = {
|
||||
"context": aggregated_dict,
|
||||
"original": original,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "problem_extension",
|
||||
"data": aggregated_dict,
|
||||
"original_query": original,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Problem_Extension failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": {},
|
||||
"original": context.get("original", ""),
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Problem extension', duration)
|
||||
@@ -1,294 +0,0 @@
|
||||
"""
|
||||
Retrieval Tools for database and context retrieval.
|
||||
|
||||
This module contains MCP tools for retrieving data using hybrid search.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.llm_tools import (
|
||||
deduplicate_entries,
|
||||
merge_to_key_value_pairs,
|
||||
)
|
||||
from app.core.memory.agent.utils.messages_tool import Retriev_messages_deal
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from dotenv import load_dotenv
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
load_dotenv()
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Retrieve(
|
||||
ctx: Context,
|
||||
context,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Retrieve data from the database using hybrid search.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Dictionary or string containing query information
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
|
||||
user_rag_memory_id: User RAG memory identifier
|
||||
|
||||
Returns:
|
||||
dict: Contains 'context' with Query and Expansion_issue results
|
||||
"""
|
||||
kb_config = {
|
||||
"knowledge_bases": [
|
||||
{
|
||||
"kb_id": user_rag_memory_id,
|
||||
"similarity_threshold": 0.7,
|
||||
"vector_similarity_weight": 0.5,
|
||||
"top_k": 10,
|
||||
"retrieve_type": "participle"
|
||||
}
|
||||
],
|
||||
"merge_strategy": "weight",
|
||||
"reranker_id": os.getenv('reranker_id'),
|
||||
"reranker_top_k": 10
|
||||
}
|
||||
start = time.time()
|
||||
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
logger.info(f"Retrieve: context type={type(context)}, context={str(context)[:500]}")
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
search_service = get_context_resource(ctx, 'search_service')
|
||||
|
||||
databases_anser = []
|
||||
|
||||
# Handle both dict and string context
|
||||
if isinstance(context, dict):
|
||||
# Process dict context with extended questions
|
||||
all_items = []
|
||||
logger.info(f"Retrieve: context keys={list(context.keys())}")
|
||||
content, original = await Retriev_messages_deal(context)
|
||||
logger.info(f"Retrieve: after Retriev_messages_deal - content_type={type(content)}, content={str(content)[:300]}")
|
||||
logger.info(f"Retrieve: original='{original[:100] if original else 'EMPTY'}'")
|
||||
|
||||
if not original:
|
||||
logger.warning(f"Retrieve: original query is empty! context={context}")
|
||||
|
||||
# Extract all query items from content
|
||||
# content is like {original_question: [extended_questions...], ...}
|
||||
for key, values in content.items():
|
||||
if isinstance(values, list):
|
||||
all_items.extend(values)
|
||||
elif isinstance(values, str):
|
||||
all_items.append(values)
|
||||
elif values is not None:
|
||||
# Fallback: convert non-empty non-list values to string
|
||||
all_items.append(str(values))
|
||||
|
||||
# Execute search for each question
|
||||
for idx, question in enumerate(all_items):
|
||||
try:
|
||||
# Prepare search parameters based on storage type
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"question": question,
|
||||
"return_raw_results": True
|
||||
}
|
||||
|
||||
# Add storage-specific parameters
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
retrieve_chunks_result = knowledge_retrieval(question, kb_config,[str(group_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||
cleaned_query=question
|
||||
raw_results=clean_content
|
||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except:
|
||||
clean_content = ''
|
||||
raw_results=''
|
||||
cleaned_query = question
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
else:
|
||||
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(
|
||||
**search_params, memory_config=memory_config
|
||||
)
|
||||
|
||||
databases_anser.append({
|
||||
"Query_small": cleaned_query,
|
||||
"Result_small": clean_content,
|
||||
"_intermediate": {
|
||||
"type": "search_result",
|
||||
"query": cleaned_query,
|
||||
"raw_results": raw_results,
|
||||
"index": idx + 1,
|
||||
"total": len(all_items)
|
||||
}
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve: hybrid_search failed for question '{question}': {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Continue with empty result for this question
|
||||
databases_anser.append({
|
||||
"Query_small": question,
|
||||
"Result_small": ""
|
||||
})
|
||||
|
||||
# Build initial database data structure
|
||||
databases_data = {
|
||||
"Query": original,
|
||||
"Expansion_issue": databases_anser
|
||||
}
|
||||
|
||||
# Collect intermediate outputs before deduplication
|
||||
intermediate_outputs = []
|
||||
for item in databases_anser:
|
||||
if '_intermediate' in item:
|
||||
intermediate_outputs.append(item['_intermediate'])
|
||||
|
||||
# Deduplicate and merge results
|
||||
deduplicated_data = deduplicate_entries(databases_data['Expansion_issue'])
|
||||
deduplicated_data_merged = merge_to_key_value_pairs(
|
||||
deduplicated_data,
|
||||
'Query_small',
|
||||
'Result_small'
|
||||
)
|
||||
|
||||
# Restructure for Verify/Retrieve_Summary compatibility
|
||||
keys, val = [], []
|
||||
for item in deduplicated_data_merged:
|
||||
for items_key, items_value in item.items():
|
||||
keys.append(items_key)
|
||||
val.append(items_value)
|
||||
|
||||
send_verify = []
|
||||
for i, j in zip(keys, val, strict=False):
|
||||
send_verify.append({
|
||||
"Query_small": i,
|
||||
"Answer_Small": j
|
||||
})
|
||||
|
||||
dup_databases = {
|
||||
"Query": original,
|
||||
"Expansion_issue": send_verify,
|
||||
"_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs
|
||||
}
|
||||
|
||||
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
|
||||
|
||||
else:
|
||||
# Handle string context (simple query)
|
||||
query = str(context).strip()
|
||||
|
||||
try:
|
||||
# Prepare search parameters based on storage type
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"question": query,
|
||||
"return_raw_results": True
|
||||
}
|
||||
|
||||
# Add storage-specific parameters
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
clean_content = '\n\n'.join(retrieval_knowledge)
|
||||
cleaned_query = query
|
||||
raw_results = clean_content
|
||||
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except:
|
||||
clean_content = ''
|
||||
raw_results = ''
|
||||
cleaned_query = query
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
else:
|
||||
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(
|
||||
**search_params, memory_config=memory_config
|
||||
)
|
||||
# Keep structure for Verify/Retrieve_Summary compatibility
|
||||
dup_databases = {
|
||||
"Query": cleaned_query,
|
||||
"Expansion_issue": [{
|
||||
"Query_small": cleaned_query,
|
||||
"Answer_Small": clean_content,
|
||||
"_intermediate": {
|
||||
"type": "search_result",
|
||||
"query": cleaned_query,
|
||||
"raw_results": raw_results,
|
||||
"index": 1,
|
||||
"total": 1
|
||||
}
|
||||
}]
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve: hybrid_search failed for query '{query}': {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty results on failure
|
||||
dup_databases = {
|
||||
"Query": query,
|
||||
"Expansion_issue": []
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Retrieval: {storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, "
|
||||
f"Expansion_issue count={len(dup_databases.get('Expansion_issue', []))}"
|
||||
)
|
||||
|
||||
# Build result with intermediate outputs
|
||||
result = {
|
||||
"context": dup_databases,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
|
||||
# Add intermediate outputs list if they exist
|
||||
intermediate_outputs = dup_databases.get('_intermediate_outputs', [])
|
||||
if intermediate_outputs:
|
||||
result['_intermediates'] = intermediate_outputs
|
||||
logger.info(f"Adding {len(intermediate_outputs)} intermediate outputs to result")
|
||||
else:
|
||||
logger.warning("No intermediate outputs found in dup_databases")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"context": {
|
||||
"Query": "",
|
||||
"Expansion_issue": []
|
||||
},
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Retrieval', duration)
|
||||
@@ -1,640 +0,0 @@
|
||||
"""
|
||||
Summary Tools for data summarization.
|
||||
|
||||
This module contains MCP tools for summarizing retrieved data and generating responses.
|
||||
LLM clients are constructed from MemoryConfig when needed.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.models.summary_models import (
|
||||
RetrieveSummaryResponse,
|
||||
SummaryResponse,
|
||||
)
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.messages_tool import (
|
||||
Resolve_username,
|
||||
Summary_messages_deal,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.rag.nlp.search import knowledge_retrieval
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from dotenv import load_dotenv
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
load_dotenv()
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Summary(
|
||||
ctx: Context,
|
||||
context: str,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Summarize the verified data.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: JSON string containing verified data
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'status' and 'summary_result'
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
|
||||
# Process context to extract answer and query
|
||||
answer_small, query = await Summary_messages_deal(context)
|
||||
|
||||
|
||||
start_time= time.time()
|
||||
history = await session_service.get_history(sessionid, apply_id, group_id)
|
||||
end_time=time.time()
|
||||
logger.info(f"Retrieve_Summary-REDIS搜索:{end_time - start_time}")
|
||||
data = {
|
||||
"query": query,
|
||||
"history": history,
|
||||
"retrieve_info": answer_small
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Summary: initialization failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"summary_result": "信息不足,无法回答"
|
||||
}
|
||||
|
||||
try:
|
||||
# Render template
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='summary_prompt.jinja2',
|
||||
operation_name='summary',
|
||||
data=data,
|
||||
query=query
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for Summary: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
try:
|
||||
# Call LLM with structured response
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=SummaryResponse
|
||||
)
|
||||
|
||||
aimessages = structured.query_answer or ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"LLM call failed for Summary: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
aimessages = ""
|
||||
|
||||
try:
|
||||
# Save session
|
||||
if aimessages != "":
|
||||
await session_service.save_session(
|
||||
user_id=sessionid,
|
||||
query=query,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
ai_response=aimessages
|
||||
)
|
||||
logger.info(f"sessionid: {aimessages} 写入成功")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"sessionid: {sessionid} 写入失败,错误信息:{str(e)}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e)
|
||||
}
|
||||
|
||||
# Cleanup duplicate sessions
|
||||
await session_service.cleanup_duplicates()
|
||||
|
||||
# Use fallback if empty
|
||||
if aimessages == '':
|
||||
aimessages = '信息不足,无法回答'
|
||||
|
||||
logger.info(f"Summary after verification: {aimessages}")
|
||||
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Summary', duration)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Retrieve_Summary(
|
||||
ctx: Context,
|
||||
context: dict,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Summarize data directly from retrieval results.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Dictionary containing Query and Expansion_issue from Retrieve
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'status' and 'summary_result'
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
template_service = get_context_resource(ctx, "template_service")
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
|
||||
# Get LLM client from memory_config
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client_from_config(memory_config)
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
|
||||
|
||||
|
||||
# Handle both 'content' and 'context' keys (LangGraph uses 'content')
|
||||
logger.debug(f"Retrieve_Summary: raw context type={type(context)}, keys={list(context.keys()) if isinstance(context, dict) else 'N/A'}")
|
||||
|
||||
if isinstance(context, dict):
|
||||
if "content" in context:
|
||||
inner = context["content"]
|
||||
# If it's a JSON string, parse it
|
||||
if isinstance(inner, str):
|
||||
try:
|
||||
parsed = json.loads(inner)
|
||||
logger.info("Retrieve_Summary: successfully parsed JSON")
|
||||
except json.JSONDecodeError:
|
||||
# Try unescaping first
|
||||
try:
|
||||
unescaped = inner.encode('utf-8').decode('unicode_escape')
|
||||
parsed = json.loads(unescaped)
|
||||
logger.info("Retrieve_Summary: parsed after unescaping")
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||
logger.error(
|
||||
f"Retrieve_Summary: parsing failed even after unescape: {e}"
|
||||
)
|
||||
context_dict = {"Query": "", "Expansion_issue": []}
|
||||
parsed = None
|
||||
|
||||
if parsed:
|
||||
# Check if parsed has 'context' wrapper
|
||||
if isinstance(parsed, dict) and "context" in parsed:
|
||||
context_dict = parsed["context"]
|
||||
else:
|
||||
context_dict = parsed
|
||||
elif isinstance(inner, dict):
|
||||
context_dict = inner
|
||||
else:
|
||||
context_dict = {"Query": "", "Expansion_issue": []}
|
||||
elif "context" in context:
|
||||
context_dict = context["context"] if isinstance(context["context"], dict) else context
|
||||
else:
|
||||
context_dict = context
|
||||
else:
|
||||
context_dict = {"Query": "", "Expansion_issue": []}
|
||||
|
||||
query = context_dict.get("Query", "")
|
||||
expansion_issue = context_dict.get("Expansion_issue", [])
|
||||
|
||||
logger.debug(f"Retrieve_Summary: query='{query}', expansion_issue count={len(expansion_issue)}")
|
||||
logger.debug(f"Retrieve_Summary: expansion_issue={expansion_issue[:2] if expansion_issue else 'empty'}")
|
||||
|
||||
# Extract retrieve_info from expansion_issue
|
||||
retrieve_info = []
|
||||
for item in expansion_issue:
|
||||
# Check for both Answer_Small and Answer_Small (typo) for backward compatibility
|
||||
answer = None
|
||||
if isinstance(item, dict):
|
||||
if "Answer_Small" in item:
|
||||
answer = item["Answer_Small"]
|
||||
|
||||
|
||||
if answer is not None:
|
||||
# Handle both string and list formats
|
||||
if isinstance(answer, list):
|
||||
# Join list of characters/strings into a single string
|
||||
retrieve_info.append(''.join(str(x) for x in answer))
|
||||
elif isinstance(answer, str):
|
||||
retrieve_info.append(answer)
|
||||
else:
|
||||
retrieve_info.append(str(answer))
|
||||
|
||||
# Join all retrieve_info into a single string
|
||||
retrieve_info_str = '\n\n'.join(retrieve_info) if retrieve_info else ""
|
||||
|
||||
start_time=time.time()
|
||||
history = await session_service.get_history(sessionid, apply_id, group_id)
|
||||
# Override with empty list for now (as in original)
|
||||
end_time=time.time()
|
||||
logger.info(f"Retrieve_Summary-REDIS搜索:{end_time - start_time}")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve_Summary: initialization failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"summary_result": "信息不足,无法回答"
|
||||
}
|
||||
|
||||
try:
|
||||
# Render template
|
||||
system_prompt = await template_service.render_template(
|
||||
template_name='Retrieve_Summary_prompt.jinja2',
|
||||
operation_name='retrieve_summary',
|
||||
query=query,
|
||||
history=history,
|
||||
retrieve_info=retrieve_info_str
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for Retrieve_Summary: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Prompt rendering failed: {str(e)}"
|
||||
}
|
||||
|
||||
try:
|
||||
# Call LLM with structured response
|
||||
structured = await llm_client.response_structured(
|
||||
messages=[{"role": "system", "content": system_prompt}],
|
||||
response_model=RetrieveSummaryResponse
|
||||
)
|
||||
|
||||
# Handle case where structured response might be None or incomplete
|
||||
if structured and hasattr(structured, 'data') and structured.data:
|
||||
aimessages = structured.data.query_answer or ""
|
||||
else:
|
||||
logger.warning("Structured response is None or incomplete, using default message")
|
||||
aimessages = "信息不足,无法回答"
|
||||
|
||||
|
||||
# Check for insufficient information response
|
||||
if '信息不足,无法回答' not in str(aimessages) or str(aimessages)!="":
|
||||
# Save session
|
||||
await session_service.save_session(
|
||||
user_id=sessionid,
|
||||
query=query,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
ai_response=aimessages
|
||||
)
|
||||
logger.info(f"sessionid: {aimessages} 写入成功")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve_Summary: LLM call failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
aimessages = ""
|
||||
# Cleanup duplicate sessions
|
||||
await session_service.cleanup_duplicates()
|
||||
|
||||
# Use fallback if empty
|
||||
if aimessages == '':
|
||||
aimessages = '信息不足,无法回答'
|
||||
|
||||
logger.info(f"Summary after retrieval: {aimessages}")
|
||||
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Retrieval summary', duration)
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "retrieval_summary",
|
||||
"summary": aimessages,
|
||||
"query": query,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Input_Summary(
|
||||
ctx: Context,
|
||||
context: str,
|
||||
usermessages: str,
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Generate a quick summary for direct input without verification.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: String containing the input sentence
|
||||
usermessages: User messages identifier
|
||||
search_switch: Search switch value for routing ('2' for summaries only)
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
|
||||
user_rag_memory_id: User RAG memory identifier
|
||||
|
||||
Returns:
|
||||
dict: Contains 'query_answer' with the summary result
|
||||
"""
|
||||
start = time.time()
|
||||
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
session_service = get_context_resource(ctx, "session_service")
|
||||
search_service = get_context_resource(ctx, "search_service")
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages) or ""
|
||||
sessionid = sessionid.replace('call_id_', '')
|
||||
|
||||
start_time=time.time()
|
||||
history = await session_service.get_history(
|
||||
str(sessionid),
|
||||
str(apply_id),
|
||||
str(group_id)
|
||||
)
|
||||
end_time=time.time()
|
||||
logger.info(f"Input_Summary-REDIS搜索:{end_time - start_time}")
|
||||
# Override with empty list for now (as in original)
|
||||
|
||||
# Log the raw context for debugging
|
||||
logger.info(f"Input_Summary: Received context type={type(context)}, value={context[:200] if isinstance(context, str) else context}")
|
||||
|
||||
# Extract sentence from context
|
||||
# Context can be a string or might contain the sentence in various formats
|
||||
try:
|
||||
# Try to parse as JSON first
|
||||
if isinstance(context, str) and (context.startswith('{') or context.startswith('[')):
|
||||
try:
|
||||
import json
|
||||
context_dict = json.loads(context)
|
||||
if isinstance(context_dict, dict):
|
||||
query = context_dict.get('sentence', context_dict.get('content', context))
|
||||
else:
|
||||
query = context
|
||||
except json.JSONDecodeError:
|
||||
# Not valid JSON, try regex
|
||||
match = re.search(r"'sentence':\s*['\"]?(.*?)['\"]?\s*,", context)
|
||||
query = match.group(1) if match else context
|
||||
else:
|
||||
query = context
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract query from context: {e}")
|
||||
query = context
|
||||
|
||||
# Clean query
|
||||
query = str(query).strip().strip("\"'")
|
||||
|
||||
logger.debug(f"Input_Summary: Extracted query='{query}' from context type={type(context)}")
|
||||
|
||||
# Execute search based on search_switch and storage_type
|
||||
try:
|
||||
logger.info(f"search_switch: {search_switch}, storage_type: {storage_type}")
|
||||
|
||||
# Prepare search parameters based on storage type
|
||||
search_params = {
|
||||
"group_id": group_id,
|
||||
"question": query,
|
||||
"return_raw_results": True
|
||||
}
|
||||
|
||||
# Add storage-specific parameters
|
||||
|
||||
# Retrieval
|
||||
if search_switch == '2':
|
||||
search_params["include"] = ["summaries"]
|
||||
if storage_type == "rag" and user_rag_memory_id:
|
||||
raw_results = []
|
||||
retrieve_info = ""
|
||||
kb_config={
|
||||
"knowledge_bases": [
|
||||
{
|
||||
"kb_id": user_rag_memory_id,
|
||||
"similarity_threshold": 0.7,
|
||||
"vector_similarity_weight": 0.5,
|
||||
"top_k": 10,
|
||||
"retrieve_type": "participle"
|
||||
}
|
||||
],
|
||||
"merge_strategy": "weight",
|
||||
"reranker_id":os.getenv('reranker_id'),
|
||||
"reranker_top_k": 10
|
||||
}
|
||||
|
||||
retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)])
|
||||
try:
|
||||
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
|
||||
retrieve_info = '\n\n'.join(retrieval_knowledge)
|
||||
raw_results=[retrieve_info]
|
||||
logger.info(f"Input_Summary: Using RAG storage with memory_id={user_rag_memory_id}")
|
||||
except:
|
||||
retrieve_info=''
|
||||
raw_results=['']
|
||||
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
|
||||
else:
|
||||
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(
|
||||
**search_params, memory_config=memory_config
|
||||
)
|
||||
logger.info("Input_Summary: Using summary for retrieval")
|
||||
else:
|
||||
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(
|
||||
**search_params, memory_config=memory_config
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Input_Summary: hybrid_search failed, using empty results: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
retrieve_info, question, raw_results = "", query, []
|
||||
|
||||
# Return retrieved information directly without LLM processing
|
||||
# Use the raw retrieved info as the answer
|
||||
aimessages = retrieve_info if retrieve_info else "信息不足,无法回答"
|
||||
|
||||
logger.info(f"Quick answer (no LLM): {storage_type}--{user_rag_memory_id}--{aimessages[:500]}...")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
"status": "success",
|
||||
"summary_result": aimessages,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "input_summary",
|
||||
"title": "快速答案",
|
||||
"summary": aimessages,
|
||||
"query": query,
|
||||
"raw_results": raw_results,
|
||||
"search_mode": "quick_search",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Input_Summary failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "fail",
|
||||
"summary_result": "信息不足,无法回答",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Retrieval', duration)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Summary_fails(
|
||||
ctx: Context,
|
||||
context: str,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
) -> dict:
|
||||
"""
|
||||
Handle workflow failure when summary cannot be generated.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Failure context string
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'query_answer' with failure message
|
||||
"""
|
||||
try:
|
||||
# Extract services from context
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
|
||||
# Parse session ID from usermessages
|
||||
usermessages_parts = usermessages.split('_')[1:]
|
||||
sessionid = '_'.join(usermessages_parts[:-1])
|
||||
|
||||
# Cleanup duplicate sessions
|
||||
await session_service.cleanup_duplicates()
|
||||
|
||||
logger.info("没有相关数据")
|
||||
logger.debug(f"Summary_fails called with apply_id: {apply_id}, group_id: {group_id}")
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"summary_result": "没有相关数据",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Summary_fails failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "fail",
|
||||
"summary_result": "没有相关数据",
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"error": str(e)
|
||||
}
|
||||
@@ -1,174 +0,0 @@
|
||||
"""
|
||||
Verification Tools for data verification.
|
||||
|
||||
This module contains MCP tools for verifying retrieved data.
|
||||
"""
|
||||
import time
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_time
|
||||
from app.core.memory.agent.mcp_server.mcp_instance import mcp
|
||||
from app.core.memory.agent.mcp_server.server import get_context_resource
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.messages_tool import (
|
||||
Resolve_username,
|
||||
Retrieve_verify_tool_messages_deal,
|
||||
Verify_messages_deal,
|
||||
)
|
||||
from app.core.memory.agent.utils.verify_tool import VerifyTool
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from jinja2 import Template
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def Verify(
|
||||
ctx: Context,
|
||||
context: dict,
|
||||
usermessages: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: str = "",
|
||||
user_rag_memory_id: str = ""
|
||||
) -> dict:
|
||||
"""
|
||||
Verify the retrieved data.
|
||||
|
||||
Args:
|
||||
ctx: FastMCP context for dependency injection
|
||||
context: Dictionary containing query and expansion issues
|
||||
usermessages: User messages identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory identifier (optional)
|
||||
|
||||
Returns:
|
||||
dict: Contains 'status' and 'verified_data' with verification results
|
||||
"""
|
||||
start = time.time()
|
||||
|
||||
|
||||
try:
|
||||
# Extract services from context
|
||||
session_service = get_context_resource(ctx, 'session_service')
|
||||
|
||||
# Load verification prompt template
|
||||
file_path = PROJECT_ROOT_ + '/agent/utils/prompt/split_verify_prompt.jinja2'
|
||||
|
||||
# Read template file directly (VerifyTool expects raw template content)
|
||||
from app.core.memory.agent.utils.messages_tool import read_template_file
|
||||
system_prompt = await read_template_file(file_path)
|
||||
|
||||
|
||||
|
||||
# Resolve session ID
|
||||
sessionid = Resolve_username(usermessages)
|
||||
|
||||
# Get conversation history
|
||||
history = await session_service.get_history(sessionid, apply_id, group_id)
|
||||
|
||||
template = Template(system_prompt)
|
||||
system_prompt = template.render(history=history, sentence=context)
|
||||
|
||||
# Process context to extract query and results
|
||||
Query_small, Result_small, query = await Verify_messages_deal(context)
|
||||
|
||||
# Build query list for verification
|
||||
query_list = []
|
||||
for query_small, anser in zip(Query_small, Result_small, strict=False):
|
||||
query_list.append({
|
||||
'Query_small': query_small,
|
||||
'Answer_Small': anser
|
||||
})
|
||||
|
||||
messages = {
|
||||
"Query": query,
|
||||
"Expansion_issue": query_list
|
||||
}
|
||||
|
||||
|
||||
|
||||
# Call verification workflow with LLM model ID from memory_config
|
||||
verify_tool = VerifyTool(
|
||||
system_prompt=system_prompt,
|
||||
verify_data=messages,
|
||||
llm_model_id=str(memory_config.llm_model_id)
|
||||
)
|
||||
verify_result = await verify_tool.verify()
|
||||
|
||||
# Parse LLM verification result with error handling
|
||||
try:
|
||||
messages_deal = await Retrieve_verify_tool_messages_deal(
|
||||
verify_result,
|
||||
history,
|
||||
query
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Retrieve_verify_tool_messages_deal parsing failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Fallback to avoid 500 errors
|
||||
messages_deal = {
|
||||
"data": {
|
||||
"query": query,
|
||||
"expansion_issue": []
|
||||
},
|
||||
"split_result": "failed",
|
||||
"reason": str(e),
|
||||
"history": history,
|
||||
}
|
||||
|
||||
logger.info(f"Verification result: {messages_deal}")
|
||||
|
||||
# Emit intermediate output for frontend
|
||||
return {
|
||||
"status": "success",
|
||||
"verified_data": messages_deal,
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"_intermediate": {
|
||||
"type": "verification",
|
||||
"title": "Data Verification",
|
||||
"result": messages_deal.get("split_result", "unknown"),
|
||||
"reason": messages_deal.get("reason", ""),
|
||||
"query": query,
|
||||
"verified_count": len(query_list),
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Verify failed: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
"storage_type": storage_type,
|
||||
"user_rag_memory_id": user_rag_memory_id,
|
||||
"verified_data": {
|
||||
"data": {
|
||||
"query": "",
|
||||
"expansion_issue": []
|
||||
},
|
||||
"split_result": "failed",
|
||||
"reason": str(e),
|
||||
"history": [],
|
||||
}
|
||||
}
|
||||
|
||||
finally:
|
||||
# Log execution time
|
||||
end = time.time()
|
||||
try:
|
||||
duration = end - start
|
||||
except Exception:
|
||||
duration = 0.0
|
||||
log_time('Verification', duration)
|
||||
32
api/app/core/memory/agent/models/verification_models.py
Normal file
32
api/app/core/memory/agent/models/verification_models.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Pydantic models for verification operations."""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VerificationItem(BaseModel):
|
||||
"""Individual verification item for a query-answer pair."""
|
||||
|
||||
query_small: str = Field(..., description="子问题")
|
||||
answer_small: str = Field(..., description="子问题的回答")
|
||||
status: str = Field(..., description="验证状态:True 或 False")
|
||||
query_answer: str = Field(..., description="问题的答案(与 answer_small 相同)")
|
||||
|
||||
|
||||
class VerificationResult(BaseModel):
|
||||
"""Result model for verification operation."""
|
||||
|
||||
query: str = Field(..., description="原始查询问题")
|
||||
history: List[Dict[str, Any]] = Field(default_factory=list, description="历史对话记录")
|
||||
expansion_issue: List[VerificationItem] = Field(
|
||||
default_factory=list,
|
||||
description="验证后的数据列表,包含所有通过验证的问答对"
|
||||
)
|
||||
split_result: str = Field(
|
||||
...,
|
||||
description="验证结果状态:success(expansion_issue 非空)或 failed(expansion_issue 为空)"
|
||||
)
|
||||
reason: Optional[str] = Field(
|
||||
None,
|
||||
description="验证结果的说明和分析"
|
||||
)
|
||||
@@ -1,114 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import requests
|
||||
|
||||
# from qcloud_cos import CosConfig, CosS3Client
|
||||
# from qcloud_cos.cos_exception import CosClientError, CosServiceError
|
||||
|
||||
# from config.paths import BASE_DIR
|
||||
BASE_DIR = os.path.dirname(os.path.realpath(sys.argv[0]))
|
||||
|
||||
class OSSUploader:
|
||||
"""对象存储文件上传工具类"""
|
||||
|
||||
def __init__(self, env):
|
||||
api = {
|
||||
"test": "https://testlingqi.redbearai.com/api/user/file/common/upload/v2/anon",
|
||||
"prod": "https://lingqi.redbearai.com/api/user/file/common/upload/v2/anon"
|
||||
}
|
||||
self.api = api.get(env, "https://testlingqi.redbearai.com/api/user/file/common/upload/v2/anon")
|
||||
self.privacy = "false"
|
||||
self.headers = {
|
||||
"User-Agent": 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) '
|
||||
'AppleWebKit/537.36 (KHTML, like Gecko)'
|
||||
' Chrome/133.0.6833.84 Safari/537.36'
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _generate_object_key(file_path, prefix='xhs_'):
|
||||
"""
|
||||
生成对象存储的Key
|
||||
|
||||
:param file_path: 本地文件路径
|
||||
:param prefix: 存储前缀,用于分类存储
|
||||
:return: 生成的对象Key
|
||||
"""
|
||||
# 文件md5值.后缀名
|
||||
filename = os.path.basename(file_path)
|
||||
filename = f"{filename}"
|
||||
|
||||
# 组合成完整的对象Key
|
||||
return f"{prefix}{filename}"
|
||||
|
||||
def upload_image(self, file_name, prefix='jd_'):
|
||||
"""
|
||||
上传文件到COS并返回可访问的URL
|
||||
|
||||
:param file_url: 文件路径
|
||||
:param file_name: 文件名称
|
||||
:param media_type: 文件类型
|
||||
:param prefix: 存储前缀,用于分类存储
|
||||
:return: 文件访问URL
|
||||
"""
|
||||
# 检查文件是否存在
|
||||
|
||||
|
||||
|
||||
file_path = os.path.join(BASE_DIR, file_name)
|
||||
|
||||
# response = requests.get(url, headers=self.headers, stream=True)
|
||||
|
||||
# if response.status_code == 200:
|
||||
# with open(file_path, "wb") as f:
|
||||
# for chunk in response.iter_content(1024): # 分块写入,避免内存占用过大
|
||||
# f.write(chunk)
|
||||
# else:
|
||||
# raise Exception(f"文件下载失败,{file_name}")
|
||||
|
||||
# 生成对象Key
|
||||
object_key = self._generate_object_key(file_path, prefix +file_name.split('.')[-1])
|
||||
|
||||
try:
|
||||
upload_response = requests.post(
|
||||
self.api,
|
||||
data={
|
||||
"privacy": self.privacy,
|
||||
"fileName": object_key,
|
||||
}
|
||||
)
|
||||
|
||||
if upload_response.status_code != 200:
|
||||
raise Exception('上传接口请求失败')
|
||||
resp = upload_response.json()
|
||||
name = resp["data"]["name"]
|
||||
file_url = resp["data"]["path"]
|
||||
policy = resp["data"]["policy"]
|
||||
with open(file_path, 'rb') as f:
|
||||
oss_push_resp = requests.post(
|
||||
policy["host"],
|
||||
files={
|
||||
"key": policy["dir"],
|
||||
"OSSAccessKeyId": policy["accessid"],
|
||||
"name": name,
|
||||
"policy": policy["policy"],
|
||||
"success_action_status": 200,
|
||||
"signature": policy["signature"],
|
||||
"file": f,
|
||||
}
|
||||
)
|
||||
if oss_push_resp.status_code == 200:
|
||||
return file_url
|
||||
raise Exception("OSS上传失败")
|
||||
except Exception:
|
||||
raise Exception(f"上传失败: \n{traceback.format_exc()}")
|
||||
finally:
|
||||
print('success')
|
||||
# os.remove(file_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
cos_uploader = OSSUploader("prod")
|
||||
url =cos_uploader.upload_image('./example01.jpg')
|
||||
print(url)
|
||||
@@ -1,121 +0,0 @@
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_, picture_model_requests,Picture_recognize, Voice_recognize
|
||||
from app.core.memory.agent.utils.messages_tool import read_template_file
|
||||
|
||||
import requests
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
# file_urls = [
|
||||
# "https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_female2.wav",
|
||||
# "https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_male2.wav",
|
||||
# ]
|
||||
class Vico_recognition:
|
||||
def __init__(self,file_urls):
|
||||
self.api_key=''
|
||||
self.backend_model_name=''
|
||||
self.api_base=''
|
||||
self.file_urls=file_urls
|
||||
|
||||
# 提交文件转写任务,包含待转写文件url列表
|
||||
async def submit_task(self) -> str:
|
||||
self.api_key, self.backend_model_name, self.api_base =await Voice_recognize()
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"X-DashScope-Async": "enable",
|
||||
}
|
||||
data = {
|
||||
"model": self.backend_model_name,
|
||||
"input": {"file_urls": self.file_urls},
|
||||
"parameters": {
|
||||
"channel_id": [0],
|
||||
"vocabulary_id": "vocab-Xxxx",
|
||||
},
|
||||
}
|
||||
# 录音文件转写服务url
|
||||
service_url = (
|
||||
"https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription"
|
||||
)
|
||||
response = requests.post(
|
||||
service_url, headers=headers, data=json.dumps(data)
|
||||
)
|
||||
|
||||
# 打印响应内容
|
||||
if response.status_code == 200:
|
||||
return response.json()["output"]["task_id"]
|
||||
else:
|
||||
print("task failed!")
|
||||
print(response.json())
|
||||
return None
|
||||
|
||||
async def download_transcription_result(self, transcription_url):
|
||||
"""
|
||||
Args:
|
||||
transcription_url (str): 转写结果文件URL
|
||||
Returns:
|
||||
dict: 转写结果内容
|
||||
"""
|
||||
try:
|
||||
response = requests.get(transcription_url)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
print(f"下载转写结果失败: {e}")
|
||||
return None
|
||||
|
||||
# 循环查询任务状态直到成功
|
||||
async def wait_for_complete(self,task_id):
|
||||
self.api_key, self.backend_model_name, self.api_base = await Voice_recognize()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
"X-DashScope-Async": "enable",
|
||||
}
|
||||
|
||||
pending = True
|
||||
while pending:
|
||||
# 查询任务状态服务url
|
||||
service_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"
|
||||
response = requests.post(
|
||||
service_url, headers=headers
|
||||
)
|
||||
if response.status_code == 200:
|
||||
status = response.json()['output']['task_status']
|
||||
if status == 'SUCCEEDED':
|
||||
print("task succeeded!")
|
||||
pending = False
|
||||
return response.json()['output']['results']
|
||||
elif status == 'RUNNING' or status == 'PENDING':
|
||||
pass
|
||||
else:
|
||||
print("task failed!")
|
||||
pending = False
|
||||
else:
|
||||
print("query failed!")
|
||||
pending = False
|
||||
time.sleep(0.1)
|
||||
async def run(self):
|
||||
self.api_key, self.backend_model_name, self.api_base = await Voice_recognize()
|
||||
task_id=await self.submit_task()
|
||||
result=await self.wait_for_complete(task_id)
|
||||
result_context=[]
|
||||
for i in result:
|
||||
transcription_url=i['transcription_url']
|
||||
print(f"转写URL: {transcription_url}")
|
||||
|
||||
# 下载并打印转写内容
|
||||
content = await self.download_transcription_result(transcription_url)
|
||||
if content:
|
||||
content=json.dumps(content, indent=2, ensure_ascii=False)
|
||||
context=re.findall(r'"text": "(.*?)"', content)
|
||||
result_context.append(context[0])
|
||||
result=''.join(result_context)
|
||||
return (result)
|
||||
|
||||
|
||||
|
||||
|
||||
277
api/app/core/memory/agent/services/optimized_llm_service.py
Normal file
277
api/app/core/memory/agent/services/optimized_llm_service.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""
|
||||
优化的LLM服务类,用于压缩和统一LLM调用
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
|
||||
T = TypeVar('T', bound=BaseModel)
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class OptimizedLLMService:
|
||||
"""
|
||||
优化的LLM服务类,提供统一的LLM调用接口
|
||||
|
||||
特性:
|
||||
1. 客户端复用 - 避免重复创建LLM客户端
|
||||
2. 批量处理 - 支持并发处理多个请求
|
||||
3. 错误处理 - 统一的错误处理和降级策略
|
||||
4. 性能优化 - 缓存和连接池优化
|
||||
"""
|
||||
|
||||
def __init__(self, db_session: Session):
|
||||
self.db_session = db_session
|
||||
self.client_factory = MemoryClientFactory(db_session)
|
||||
self._client_cache: Dict[str, OpenAIClient] = {}
|
||||
|
||||
def _get_cached_client(self, llm_model_id: str) -> OpenAIClient:
|
||||
"""获取缓存的LLM客户端,避免重复创建"""
|
||||
if llm_model_id not in self._client_cache:
|
||||
self._client_cache[llm_model_id] = self.client_factory.get_llm_client(llm_model_id)
|
||||
return self._client_cache[llm_model_id]
|
||||
|
||||
async def structured_response(
|
||||
self,
|
||||
llm_model_id: str,
|
||||
system_prompt: str,
|
||||
response_model: Type[T],
|
||||
user_message: Optional[str] = None,
|
||||
fallback_value: Optional[Any] = None
|
||||
) -> T:
|
||||
"""
|
||||
统一的结构化响应接口
|
||||
|
||||
Args:
|
||||
llm_model_id: LLM模型ID
|
||||
system_prompt: 系统提示词
|
||||
response_model: 响应模型类
|
||||
user_message: 用户消息(可选)
|
||||
fallback_value: 失败时的降级值
|
||||
|
||||
Returns:
|
||||
结构化响应对象
|
||||
"""
|
||||
try:
|
||||
llm_client = self._get_cached_client(llm_model_id)
|
||||
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
if user_message:
|
||||
messages.append({"role": "user", "content": user_message})
|
||||
|
||||
logger.debug(f"LLM调用: model={llm_model_id}, prompt_length={len(system_prompt)}")
|
||||
|
||||
structured = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=response_model
|
||||
)
|
||||
|
||||
if structured is None:
|
||||
logger.warning(f"LLM返回None,使用降级值")
|
||||
return self._create_fallback_response(response_model, fallback_value)
|
||||
|
||||
return structured
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"结构化响应失败: {e}", exc_info=True)
|
||||
return self._create_fallback_response(response_model, fallback_value)
|
||||
|
||||
async def batch_structured_response(
|
||||
self,
|
||||
llm_model_id: str,
|
||||
requests: List[Dict[str, Any]],
|
||||
response_model: Type[T],
|
||||
max_concurrent: int = 5
|
||||
) -> List[T]:
|
||||
"""
|
||||
批量处理结构化响应
|
||||
|
||||
Args:
|
||||
llm_model_id: LLM模型ID
|
||||
requests: 请求列表,每个请求包含system_prompt等参数
|
||||
response_model: 响应模型类
|
||||
max_concurrent: 最大并发数
|
||||
|
||||
Returns:
|
||||
结构化响应列表
|
||||
"""
|
||||
semaphore = asyncio.Semaphore(max_concurrent)
|
||||
|
||||
async def process_single_request(request: Dict[str, Any]) -> T:
|
||||
async with semaphore:
|
||||
return await self.structured_response(
|
||||
llm_model_id=llm_model_id,
|
||||
system_prompt=request.get('system_prompt', ''),
|
||||
response_model=response_model,
|
||||
user_message=request.get('user_message'),
|
||||
fallback_value=request.get('fallback_value')
|
||||
)
|
||||
|
||||
tasks = [process_single_request(req) for req in requests]
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
async def simple_response(
|
||||
self,
|
||||
llm_model_id: str,
|
||||
system_prompt: str,
|
||||
user_message: Optional[str] = None,
|
||||
fallback_message: str = "信息不足,无法回答"
|
||||
) -> str:
|
||||
"""
|
||||
简单的文本响应接口
|
||||
|
||||
Args:
|
||||
llm_model_id: LLM模型ID
|
||||
system_prompt: 系统提示词
|
||||
user_message: 用户消息(可选)
|
||||
fallback_message: 失败时的降级消息
|
||||
|
||||
Returns:
|
||||
响应文本
|
||||
"""
|
||||
try:
|
||||
llm_client = self._get_cached_client(llm_model_id)
|
||||
|
||||
messages = [{"role": "system", "content": system_prompt}]
|
||||
if user_message:
|
||||
messages.append({"role": "user", "content": user_message})
|
||||
|
||||
response = await llm_client.response(messages=messages)
|
||||
|
||||
if not response or not response.strip():
|
||||
return fallback_message
|
||||
|
||||
return response.strip()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"简单响应失败: {e}", exc_info=True)
|
||||
return fallback_message
|
||||
|
||||
def _create_fallback_response(self, response_model: Type[T], fallback_value: Optional[Any]) -> T:
|
||||
"""创建降级响应"""
|
||||
try:
|
||||
if fallback_value is not None:
|
||||
if isinstance(fallback_value, response_model):
|
||||
return fallback_value
|
||||
elif isinstance(fallback_value, dict):
|
||||
return response_model(**fallback_value)
|
||||
|
||||
# 尝试创建空的响应模型
|
||||
if hasattr(response_model, 'root'):
|
||||
# RootModel类型
|
||||
return response_model([])
|
||||
else:
|
||||
# 普通BaseModel类型
|
||||
return response_model()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"创建降级响应失败: {e}")
|
||||
# 最后的降级策略
|
||||
if hasattr(response_model, 'root'):
|
||||
return response_model([])
|
||||
else:
|
||||
return response_model()
|
||||
|
||||
def clear_cache(self):
|
||||
"""清理客户端缓存"""
|
||||
self._client_cache.clear()
|
||||
|
||||
|
||||
class LLMServiceMixin:
|
||||
"""
|
||||
LLM服务混入类,为节点提供便捷的LLM调用方法
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._llm_service: Optional[OptimizedLLMService] = None
|
||||
|
||||
def get_llm_service(self, db_session: Session) -> OptimizedLLMService:
|
||||
"""获取LLM服务实例"""
|
||||
if self._llm_service is None:
|
||||
self._llm_service = OptimizedLLMService(db_session)
|
||||
return self._llm_service
|
||||
|
||||
async def call_llm_structured(
|
||||
self,
|
||||
state: Dict[str, Any],
|
||||
db_session: Session,
|
||||
system_prompt: str,
|
||||
response_model: Type[T],
|
||||
user_message: Optional[str] = None,
|
||||
fallback_value: Optional[Any] = None
|
||||
) -> T:
|
||||
"""
|
||||
便捷的结构化LLM调用方法
|
||||
|
||||
Args:
|
||||
state: 状态字典,包含memory_config
|
||||
db_session: 数据库会话
|
||||
system_prompt: 系统提示词
|
||||
response_model: 响应模型类
|
||||
user_message: 用户消息(可选)
|
||||
fallback_value: 失败时的降级值
|
||||
|
||||
Returns:
|
||||
结构化响应对象
|
||||
"""
|
||||
memory_config = state.get('memory_config')
|
||||
if not memory_config:
|
||||
raise ValueError("State中缺少memory_config")
|
||||
|
||||
llm_model_id = memory_config.llm_model_id
|
||||
if not llm_model_id:
|
||||
raise ValueError("Memory config中缺少llm_model_id")
|
||||
|
||||
llm_service = self.get_llm_service(db_session)
|
||||
return await llm_service.structured_response(
|
||||
llm_model_id=llm_model_id,
|
||||
system_prompt=system_prompt,
|
||||
response_model=response_model,
|
||||
user_message=user_message,
|
||||
fallback_value=fallback_value
|
||||
)
|
||||
|
||||
async def call_llm_simple(
|
||||
self,
|
||||
state: Dict[str, Any],
|
||||
db_session: Session,
|
||||
system_prompt: str,
|
||||
user_message: Optional[str] = None,
|
||||
fallback_message: str = "信息不足,无法回答"
|
||||
) -> str:
|
||||
"""
|
||||
便捷的简单LLM调用方法
|
||||
|
||||
Args:
|
||||
state: 状态字典,包含memory_config
|
||||
db_session: 数据库会话
|
||||
system_prompt: 系统提示词
|
||||
user_message: 用户消息(可选)
|
||||
fallback_message: 失败时的降级消息
|
||||
|
||||
Returns:
|
||||
响应文本
|
||||
"""
|
||||
memory_config = state.get('memory_config')
|
||||
if not memory_config:
|
||||
raise ValueError("State中缺少memory_config")
|
||||
|
||||
llm_model_id = memory_config.llm_model_id
|
||||
if not llm_model_id:
|
||||
raise ValueError("Memory config中缺少llm_model_id")
|
||||
|
||||
llm_service = self.get_llm_service(db_session)
|
||||
return await llm_service.simple_response(
|
||||
llm_model_id=llm_model_id,
|
||||
system_prompt=system_prompt,
|
||||
user_message=user_message,
|
||||
fallback_message=fallback_message
|
||||
)
|
||||
@@ -4,22 +4,19 @@ Parameter Builder for constructing tool call arguments.
|
||||
This service provides tool-specific parameter transformation logic
|
||||
to build correct arguments for each tool type.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class ParameterBuilder:
|
||||
"""Service for building tool call arguments based on tool type."""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the parameter builder."""
|
||||
logger.info("ParameterBuilder initialized")
|
||||
|
||||
|
||||
def build_tool_args(
|
||||
self,
|
||||
tool_name: str,
|
||||
@@ -28,9 +25,8 @@ class ParameterBuilder:
|
||||
search_switch: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
storage_type: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None,
|
||||
user_rag_memory_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build tool arguments based on tool type.
|
||||
@@ -49,7 +45,6 @@ class ParameterBuilder:
|
||||
search_switch: Search routing parameter
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
storage_type: Storage type for the workspace (optional)
|
||||
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
|
||||
|
||||
@@ -60,19 +55,18 @@ class ParameterBuilder:
|
||||
base_args = {
|
||||
"usermessages": tool_call_id,
|
||||
"apply_id": apply_id,
|
||||
"group_id": group_id,
|
||||
"memory_config": memory_config,
|
||||
"group_id": group_id
|
||||
}
|
||||
|
||||
|
||||
# Always add storage_type and user_rag_memory_id (with defaults if None)
|
||||
base_args["storage_type"] = storage_type if storage_type is not None else ""
|
||||
base_args["user_rag_memory_id"] = user_rag_memory_id if user_rag_memory_id is not None else ""
|
||||
|
||||
# Tool-specific argument construction
|
||||
if tool_name in ["Verify", "Summary", "Summary_fails", "Retrieve_Summary", "Problem_Extension"]:
|
||||
# These tools expect dict context
|
||||
if tool_name in ["Verify","Summary", "Summary_fails",'Retrieve_Summary']:
|
||||
# Verify expects dict context
|
||||
return {
|
||||
"context": content if isinstance(content, dict) else {"content": content},
|
||||
"context": content if isinstance(content, dict) else {},
|
||||
**base_args
|
||||
}
|
||||
|
||||
@@ -4,31 +4,21 @@ Search Service for executing hybrid search and processing results.
|
||||
This service provides clean search result processing with content extraction
|
||||
and deduplication.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.src.search import run_hybrid_search
|
||||
from app.core.memory.utils.data.text_utils import escape_lucene_query
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class SearchService:
|
||||
"""Service for executing hybrid search and processing results."""
|
||||
|
||||
def __init__(self, memory_config: "MemoryConfig" = None):
|
||||
"""
|
||||
Initialize the search service.
|
||||
|
||||
Args:
|
||||
memory_config: Optional MemoryConfig for embedding model configuration.
|
||||
If not provided, must be passed to execute_hybrid_search.
|
||||
"""
|
||||
self.memory_config = memory_config
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the search service."""
|
||||
logger.info("SearchService initialized")
|
||||
|
||||
def extract_content_from_result(self, result: dict) -> str:
|
||||
@@ -103,49 +93,40 @@ class SearchService:
|
||||
self,
|
||||
group_id: str,
|
||||
question: str,
|
||||
limit: int = 15,
|
||||
limit: int = 5,
|
||||
search_type: str = "hybrid",
|
||||
include: Optional[List[str]] = None,
|
||||
rerank_alpha: float = 0.6,
|
||||
activation_boost_factor: float = 0.8,
|
||||
rerank_alpha: float = 0.4,
|
||||
output_path: str = "search_results.json",
|
||||
return_raw_results: bool = False,
|
||||
memory_config: "MemoryConfig" = None,
|
||||
memory_config = None
|
||||
) -> Tuple[str, str, Optional[dict]]:
|
||||
"""
|
||||
Execute hybrid search with two-stage ranking.
|
||||
|
||||
Stage 1: Filter by content relevance (BM25 + Embedding)
|
||||
Stage 2: Rerank by activation values (ACTR)
|
||||
Execute hybrid search and return clean content.
|
||||
|
||||
Args:
|
||||
group_id: Group identifier for filtering
|
||||
group_id: Group identifier for filtering results
|
||||
question: Search query text
|
||||
limit: Max results per category (default: 15)
|
||||
search_type: "hybrid", "keyword", or "embedding" (default: "hybrid")
|
||||
include: Result types (default: ["statements", "chunks", "entities", "summaries"])
|
||||
rerank_alpha: BM25 weight (default: 0.6)
|
||||
activation_boost_factor: Activation impact on memory strength (default: 0.8)
|
||||
output_path: JSON output path (default: "search_results.json")
|
||||
return_raw_results: Return full metadata (default: False)
|
||||
memory_config: MemoryConfig for embedding model
|
||||
limit: Maximum number of results to return (default: 5)
|
||||
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
|
||||
include: List of result types to include (default: ["statements", "chunks", "entities", "summaries"])
|
||||
rerank_alpha: Weight for BM25 scores in reranking (default: 0.4)
|
||||
output_path: Path to save search results (default: "search_results.json")
|
||||
return_raw_results: If True, also return the raw search results as third element (default: False)
|
||||
memory_config: Memory configuration object (required)
|
||||
|
||||
Returns:
|
||||
Tuple[str, str, Optional[dict]]: (clean_content, cleaned_query, raw_results)
|
||||
Tuple of (clean_content, cleaned_query, raw_results)
|
||||
raw_results is None if return_raw_results=False
|
||||
"""
|
||||
if include is None:
|
||||
include = ["statements", "chunks", "entities", "summaries"]
|
||||
|
||||
# Use provided memory_config or fall back to instance config
|
||||
config = memory_config or self.memory_config
|
||||
if not config:
|
||||
raise ValueError("memory_config is required for search - either pass it to __init__ or execute_hybrid_search")
|
||||
|
||||
|
||||
# Clean query
|
||||
cleaned_query = self.clean_query(question)
|
||||
|
||||
|
||||
try:
|
||||
# Execute search using memory_config
|
||||
# Execute search
|
||||
answer = await run_hybrid_search(
|
||||
query_text=cleaned_query,
|
||||
search_type=search_type,
|
||||
@@ -153,9 +134,8 @@ class SearchService:
|
||||
limit=limit,
|
||||
include=include,
|
||||
output_path=output_path,
|
||||
memory_config=config,
|
||||
rerank_alpha=rerank_alpha,
|
||||
activation_boost_factor=activation_boost_factor,
|
||||
memory_config=memory_config,
|
||||
rerank_alpha=rerank_alpha
|
||||
)
|
||||
|
||||
# Extract results based on search type and include parameter
|
||||
@@ -3,12 +3,22 @@ Template Service for loading and rendering Jinja2 templates.
|
||||
|
||||
This service provides centralized template management with caching and error handling.
|
||||
"""
|
||||
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
from jinja2 import Environment, FileSystemLoader, Template, TemplateNotFound
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_prompt_rendering
|
||||
from jinja2 import (
|
||||
Environment,
|
||||
FileSystemLoader,
|
||||
Template,
|
||||
TemplateNotFound,
|
||||
)
|
||||
|
||||
from app.core.logging_config import (
|
||||
get_agent_logger,
|
||||
log_prompt_rendering,
|
||||
)
|
||||
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
@@ -1,7 +0,0 @@
|
||||
"""Agent utilities."""
|
||||
|
||||
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
|
||||
|
||||
__all__ = [
|
||||
"MultimodalProcessor",
|
||||
]
|
||||
@@ -12,32 +12,49 @@ async def get_chunked_dialogs(
|
||||
group_id: str = "group_1",
|
||||
user_id: str = "user1",
|
||||
apply_id: str = "applyid",
|
||||
content: str = "这是用户的输入",
|
||||
messages: list = None,
|
||||
ref_id: str = "wyl_20251027",
|
||||
config_id: str = None
|
||||
) -> List[DialogData]:
|
||||
"""Generate chunks from all test data entries using the specified chunker strategy.
|
||||
"""Generate chunks from structured messages using the specified chunker strategy.
|
||||
|
||||
Args:
|
||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||
group_id: Group identifier
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
content: Dialog content
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference identifier
|
||||
config_id: Configuration ID for processing
|
||||
|
||||
Returns:
|
||||
List of DialogData objects with generated chunks for each test entry
|
||||
List of DialogData objects with generated chunks
|
||||
"""
|
||||
dialog_data_list = []
|
||||
messages = []
|
||||
|
||||
messages.append(ConversationMessage(role="用户", msg=content))
|
||||
|
||||
# Create DialogData
|
||||
conversation_context = ConversationContext(msgs=messages)
|
||||
# Create DialogData with group_id based on the entry's id for uniqueness
|
||||
from app.core.logging_config import get_agent_logger
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
if not messages or not isinstance(messages, list) or len(messages) == 0:
|
||||
raise ValueError("messages parameter must be a non-empty list")
|
||||
|
||||
conversation_messages = []
|
||||
|
||||
for idx, msg in enumerate(messages):
|
||||
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
|
||||
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
|
||||
|
||||
role = msg['role']
|
||||
content = msg['content']
|
||||
|
||||
if role not in ['user', 'assistant']:
|
||||
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
|
||||
|
||||
if content.strip():
|
||||
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
|
||||
|
||||
if not conversation_messages:
|
||||
raise ValueError("Message list cannot be empty after filtering")
|
||||
|
||||
conversation_context = ConversationContext(msgs=conversation_messages)
|
||||
dialog_data = DialogData(
|
||||
context=conversation_context,
|
||||
ref_id=ref_id,
|
||||
@@ -46,25 +63,11 @@ async def get_chunked_dialogs(
|
||||
apply_id=apply_id,
|
||||
config_id=config_id
|
||||
)
|
||||
# Create DialogueChunker and process the dialogue
|
||||
|
||||
chunker = DialogueChunker(chunker_strategy)
|
||||
extracted_chunks = await chunker.process_dialogue(dialog_data)
|
||||
dialog_data.chunks = extracted_chunks
|
||||
|
||||
logger.info(f"DialogData created with {len(extracted_chunks)} chunks")
|
||||
|
||||
dialog_data_list.append(dialog_data)
|
||||
|
||||
# Convert to dict with datetime serialized
|
||||
def serialize_datetime(obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
|
||||
|
||||
combined_output = [dd.model_dump() for dd in dialog_data_list]
|
||||
|
||||
print(dialog_data_list)
|
||||
|
||||
# with open(os.path.join(os.path.dirname(__file__), "chunker_test_output.txt"), "w", encoding="utf-8") as f:
|
||||
# json.dump(combined_output, f, ensure_ascii=False, indent=4, default=serialize_datetime)
|
||||
|
||||
|
||||
return dialog_data_list
|
||||
return [dialog_data]
|
||||
|
||||
56
api/app/core/memory/agent/utils/llm_client_pool.py
Normal file
56
api/app/core/memory/agent/utils/llm_client_pool.py
Normal file
@@ -0,0 +1,56 @@
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Optional
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client_fast
|
||||
from app.db import get_db
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
class LLMClientPool:
|
||||
"""LLM客户端连接池"""
|
||||
|
||||
def __init__(self, max_size: int = 5):
|
||||
self.max_size = max_size
|
||||
self.pools: Dict[str, asyncio.Queue] = {}
|
||||
self.active_clients: Dict[str, int] = {}
|
||||
|
||||
async def get_client(self, llm_model_id: str):
|
||||
"""获取LLM客户端"""
|
||||
if llm_model_id not in self.pools:
|
||||
self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size)
|
||||
self.active_clients[llm_model_id] = 0
|
||||
|
||||
pool = self.pools[llm_model_id]
|
||||
|
||||
try:
|
||||
# 尝试从池中获取客户端
|
||||
client = pool.get_nowait()
|
||||
logger.debug(f"从池中获取LLM客户端: {llm_model_id}")
|
||||
return client
|
||||
except asyncio.QueueEmpty:
|
||||
# 池为空,创建新客户端
|
||||
if self.active_clients[llm_model_id] < self.max_size:
|
||||
db_session = next(get_db())
|
||||
client = get_llm_client_fast(llm_model_id, db_session)
|
||||
self.active_clients[llm_model_id] += 1
|
||||
logger.debug(f"创建新LLM客户端: {llm_model_id}")
|
||||
return client
|
||||
else:
|
||||
# 等待可用客户端
|
||||
logger.debug(f"等待LLM客户端可用: {llm_model_id}")
|
||||
return await pool.get()
|
||||
|
||||
async def return_client(self, llm_model_id: str, client):
|
||||
"""归还LLM客户端到池中"""
|
||||
if llm_model_id in self.pools:
|
||||
try:
|
||||
self.pools[llm_model_id].put_nowait(client)
|
||||
logger.debug(f"归还LLM客户端到池: {llm_model_id}")
|
||||
except asyncio.QueueFull:
|
||||
# 池已满,丢弃客户端
|
||||
self.active_clients[llm_model_id] -= 1
|
||||
logger.debug(f"池已满,丢弃LLM客户端: {llm_model_id}")
|
||||
|
||||
# 全局客户端池
|
||||
llm_client_pool = LLMClientPool()
|
||||
@@ -1,40 +1,12 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Annotated, TypedDict
|
||||
|
||||
from app.core.memory.agent.utils.messages_tool import read_template_file
|
||||
from app.core.memory.utils.config.config_utils import (
|
||||
get_picture_config,
|
||||
get_voice_config,
|
||||
)
|
||||
|
||||
# Removed global variable imports - use dependency injection instead
|
||||
from dotenv import load_dotenv
|
||||
from langchain_core.messages import AnyMessage
|
||||
from langgraph.graph import add_messages
|
||||
from openai import OpenAI
|
||||
|
||||
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
async def picture_model_requests(image_url):
|
||||
'''
|
||||
|
||||
Args:
|
||||
image_url:
|
||||
Returns:
|
||||
|
||||
'''
|
||||
file_path = PROJECT_ROOT_ + '/agent/utils/prompt/Template_for_image_recognition_prompt.jinja2 '
|
||||
system_prompt = await read_template_file(file_path)
|
||||
result = await Picture_recognize(image_url,system_prompt)
|
||||
return (result)
|
||||
class WriteState(TypedDict):
|
||||
'''
|
||||
Langgrapg Writing TypedDict
|
||||
@@ -44,39 +16,69 @@ class WriteState(TypedDict):
|
||||
apply_id:str
|
||||
group_id:str
|
||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||
memory_config: object
|
||||
write_result: dict
|
||||
data:str
|
||||
|
||||
class ReadState(TypedDict):
|
||||
'''
|
||||
Langgrapg READING TypedDict
|
||||
name:
|
||||
id:user id
|
||||
loop_count:Traverse times
|
||||
search_switch:type
|
||||
config_id: configuration id for filtering results
|
||||
errors: list of errors that occurred during workflow execution
|
||||
'''
|
||||
messages: Annotated[list[AnyMessage], add_messages] #消息追加的模式增加消息
|
||||
name: str
|
||||
id: str
|
||||
loop_count:int
|
||||
"""
|
||||
LangGraph 工作流状态定义
|
||||
|
||||
Attributes:
|
||||
messages: 消息列表,支持自动追加
|
||||
loop_count: 遍历次数
|
||||
search_switch: 搜索类型开关
|
||||
group_id: 组标识
|
||||
config_id: 配置ID,用于过滤结果
|
||||
data: 从content_input_node传递的内容数据
|
||||
spit_data: 从Split_The_Problem传递的分解结果
|
||||
tool_calls: 工具调用请求列表
|
||||
tool_results: 工具执行结果列表
|
||||
memory_config: 内存配置对象
|
||||
"""
|
||||
messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式
|
||||
loop_count: int
|
||||
search_switch: str
|
||||
user_id: str
|
||||
apply_id: str
|
||||
group_id: str
|
||||
config_id: str
|
||||
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
|
||||
|
||||
|
||||
data: str # 新增字段用于传递内容
|
||||
spit_data: dict # 新增字段用于传递问题分解结果
|
||||
problem_extension:dict
|
||||
storage_type: str
|
||||
user_rag_memory_id: str
|
||||
llm_id: str
|
||||
embedding_id: str
|
||||
memory_config: object # 新增字段用于传递内存配置对象
|
||||
retrieve:dict
|
||||
RetrieveSummary: dict
|
||||
InputSummary: dict
|
||||
verify: dict
|
||||
SummaryFails: dict
|
||||
summary: dict
|
||||
class COUNTState:
|
||||
'''
|
||||
The number of times the workflow dialogue retrieval content has no correct message recall traversal
|
||||
'''
|
||||
"""
|
||||
工作流对话检索内容计数器
|
||||
|
||||
用于记录工作流对话检索内容没有正确消息召回遍历的次数。
|
||||
"""
|
||||
|
||||
def __init__(self, limit: int = 5):
|
||||
"""
|
||||
初始化计数器
|
||||
|
||||
Args:
|
||||
limit: 最大计数限制,默认为5
|
||||
"""
|
||||
self.total: int = 0 # 当前累加值
|
||||
self.limit: int = limit # 最大上限
|
||||
|
||||
def add(self, value: int = 1):
|
||||
"""累加数字,如果达到上限就保持最大值"""
|
||||
def add(self, value: int = 1) -> None:
|
||||
"""
|
||||
累加数字,如果达到上限就保持最大值
|
||||
|
||||
Args:
|
||||
value: 要累加的值,默认为1
|
||||
"""
|
||||
self.total += value
|
||||
print(f"[COUNTState] 当前值: {self.total}")
|
||||
if self.total >= self.limit:
|
||||
@@ -84,21 +86,19 @@ class COUNTState:
|
||||
self.total = self.limit # 达到上限不再增加
|
||||
|
||||
def get_total(self) -> int:
|
||||
"""获取当前累加值"""
|
||||
"""
|
||||
获取当前累加值
|
||||
|
||||
Returns:
|
||||
当前累加值
|
||||
"""
|
||||
return self.total
|
||||
|
||||
def reset(self):
|
||||
def reset(self) -> None:
|
||||
"""手动重置累加值"""
|
||||
self.total = 0
|
||||
print("[COUNTState] 已重置为 0")
|
||||
|
||||
|
||||
def merge_to_key_value_pairs(data, query_key, result_key):
|
||||
grouped = defaultdict(list)
|
||||
for item in data:
|
||||
grouped[item[query_key]].append(item[result_key])
|
||||
return [{key: values} for key, values in grouped.items()]
|
||||
|
||||
def deduplicate_entries(entries):
|
||||
seen = set()
|
||||
deduped = []
|
||||
@@ -109,70 +109,37 @@ def deduplicate_entries(entries):
|
||||
deduped.append(entry)
|
||||
return deduped
|
||||
|
||||
def merge_to_key_value_pairs(data, query_key, result_key):
|
||||
grouped = defaultdict(list)
|
||||
for item in data:
|
||||
grouped[item[query_key]].append(item[result_key])
|
||||
return [{key: values} for key, values in grouped.items()]
|
||||
|
||||
|
||||
async def Picture_recognize(image_path, PROMPT_TICKET_EXTRACTION, picture_model_name: str) -> str:
|
||||
def convert_extended_question_to_question(data):
|
||||
"""
|
||||
Updated to eliminate global variables in favor of explicit parameters.
|
||||
|
||||
递归地将数据中的 extended_question 字段转换为 question 字段
|
||||
|
||||
Args:
|
||||
image_path: Path to image file
|
||||
PROMPT_TICKET_EXTRACTION: Extraction prompt
|
||||
picture_model_name: Picture model name (required, no longer from global variables)
|
||||
data: 要转换的数据(可能是字典、列表或其他类型)
|
||||
|
||||
Returns:
|
||||
转换后的数据
|
||||
"""
|
||||
try:
|
||||
model_config = get_picture_config(picture_model_name)
|
||||
except Exception as e:
|
||||
err = f"LLM配置不可用:{str(e)}。请检查 config.json 和 runtime.json。"
|
||||
logger.error(err)
|
||||
return err
|
||||
api_key = os.getenv(model_config["api_key"]) # 从环境变量读取对应后端的 API key
|
||||
backend_model_name = model_config["llm_name"].split("/")[-1]
|
||||
api_base=model_config['api_base']
|
||||
|
||||
logger.info(f"model_name: {backend_model_name}")
|
||||
logger.info(f"api_key set: {'yes' if api_key else 'no'}")
|
||||
logger.info(f"base_url: {model_config['api_base']}")
|
||||
|
||||
client = OpenAI(
|
||||
api_key=api_key, base_url=api_base,
|
||||
)
|
||||
completion = client.chat.completions.create(
|
||||
model=backend_model_name,
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url":image_path,
|
||||
},
|
||||
{"type": "text",
|
||||
"text": PROMPT_TICKET_EXTRACTION}
|
||||
]
|
||||
}
|
||||
])
|
||||
picture_text = completion.choices[0].message.content
|
||||
picture_text = picture_text.replace('```json', '').replace('```', '')
|
||||
picture_text = json.loads(picture_text)
|
||||
return (picture_text['statement'])
|
||||
|
||||
async def Voice_recognize(voice_model_name: str):
|
||||
"""
|
||||
Updated to eliminate global variables in favor of explicit parameters.
|
||||
|
||||
Args:
|
||||
voice_model_name: Voice model name (required, no longer from global variables)
|
||||
"""
|
||||
try:
|
||||
model_config = get_voice_config(voice_model_name)
|
||||
except Exception as e:
|
||||
err = f"LLM配置不可用:{str(e)}。请检查 config.json 和 runtime.json。"
|
||||
logger.error(err)
|
||||
return err
|
||||
api_key = os.getenv(model_config["api_key"]) # 从环境变量读取对应后端的 API key
|
||||
backend_model_name = model_config["llm_name"].split("/")[-1]
|
||||
api_base = model_config['api_base']
|
||||
return api_key,backend_model_name,api_base
|
||||
|
||||
|
||||
if isinstance(data, dict):
|
||||
# 创建新字典来存储转换后的数据
|
||||
converted = {}
|
||||
for key, value in data.items():
|
||||
if key == 'extended_question':
|
||||
# 将 extended_question 转换为 question
|
||||
converted['question'] = convert_extended_question_to_question(value)
|
||||
else:
|
||||
# 递归处理其他字段
|
||||
converted[key] = convert_extended_question_to_question(value)
|
||||
return converted
|
||||
elif isinstance(data, list):
|
||||
# 递归处理列表中的每个元素
|
||||
return [convert_extended_question_to_question(item) for item in data]
|
||||
else:
|
||||
# 其他类型直接返回
|
||||
return data
|
||||
@@ -1,33 +0,0 @@
|
||||
import os
|
||||
from app.core.config import settings
|
||||
|
||||
def get_mcp_server_config():
|
||||
"""
|
||||
Get the MCP server configuration.
|
||||
|
||||
Uses MCP_SERVER_URL environment variable if set (for Docker),
|
||||
otherwise falls back to SERVER_IP and MCP_PORT (for local development).
|
||||
"""
|
||||
# Get MCP port from environment (default: 8081)
|
||||
mcp_port = os.getenv("MCP_PORT", "8081")
|
||||
|
||||
# In Docker: MCP_SERVER_URL=http://mcp-server:8081
|
||||
# In local dev: uses SERVER_IP (127.0.0.1 or localhost)
|
||||
mcp_server_url = os.getenv("MCP_SERVER_URL")
|
||||
|
||||
if mcp_server_url:
|
||||
# Docker environment: use full URL from environment
|
||||
base_url = mcp_server_url
|
||||
else:
|
||||
# Local development: build URL from SERVER_IP and MCP_PORT
|
||||
base_url = f"http://{settings.SERVER_IP}:{mcp_port}"
|
||||
|
||||
mcp_server_config = {
|
||||
"data_flow": {
|
||||
"url": f"{base_url}/sse",
|
||||
"transport": "sse",
|
||||
"timeout": 15000,
|
||||
"sse_read_timeout": 15000,
|
||||
}
|
||||
}
|
||||
return mcp_server_config
|
||||
@@ -1,260 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, List
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from langchain_core.messages import AnyMessage
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
def _to_openai_messages(msgs: List[AnyMessage]) -> List[dict]:
|
||||
out = []
|
||||
for m in msgs:
|
||||
if hasattr(m, "content"):
|
||||
out.append({"role": "user", "content": getattr(m, "content", "")})
|
||||
elif isinstance(m, dict) and "role" in m and "content" in m:
|
||||
out.append(m)
|
||||
else:
|
||||
out.append({"role": "user", "content": str(m)})
|
||||
return out
|
||||
|
||||
|
||||
def _extract_content(resp: Any) -> str:
|
||||
"""Extract LLM content and sanitize to raw JSON/text.
|
||||
|
||||
- Supports both object and dict response shapes.
|
||||
- Removes leading role labels (e.g., "Assistant:").
|
||||
- Strips Markdown code fences like ```json ... ```.
|
||||
- Attempts to isolate the first valid JSON array/object block when extra text is present.
|
||||
"""
|
||||
|
||||
def _to_text(r: Any) -> str:
|
||||
try:
|
||||
# 对象形式: resp.choices[0].message.content
|
||||
if hasattr(r, "choices") and getattr(r, "choices", None):
|
||||
msg = r.choices[0].message
|
||||
if hasattr(msg, "content"):
|
||||
return msg.content
|
||||
if isinstance(msg, dict) and "content" in msg:
|
||||
return msg["content"]
|
||||
# 字典形式: resp["choices"][0]["message"]["content"]
|
||||
if isinstance(r, dict):
|
||||
return r.get("choices", [{}])[0].get("message", {}).get("content", "")
|
||||
except Exception:
|
||||
pass
|
||||
return str(r)
|
||||
|
||||
def _clean_text(text: str) -> str:
|
||||
s = str(text).strip()
|
||||
# 移除可能的角色前缀
|
||||
s = re.sub(r"^\s*(Assistant|assistant)\s*:\s*", "", s)
|
||||
# 提取 ```json ... ``` 代码块
|
||||
m = re.search(r"```json\s*(.*?)\s*```", s, flags=re.S | re.I)
|
||||
if m:
|
||||
s = m.group(1).strip()
|
||||
# 如果仍然包含多余文本,尝试截取第一个 JSON 数组/对象片段
|
||||
if not (s.startswith("{") or s.startswith("[")):
|
||||
left = s.find("[")
|
||||
right = s.rfind("]")
|
||||
if left != -1 and right != -1 and right > left:
|
||||
s = s[left:right + 1].strip()
|
||||
else:
|
||||
left = s.find("{")
|
||||
right = s.rfind("}")
|
||||
if left != -1 and right != -1 and right > left:
|
||||
s = s[left:right + 1].strip()
|
||||
return s
|
||||
|
||||
raw = _to_text(resp)
|
||||
return _clean_text(raw)
|
||||
|
||||
def Resolve_username(usermessages):
|
||||
'''
|
||||
Extract username
|
||||
Args:
|
||||
usermessages: user name
|
||||
|
||||
Returns:
|
||||
|
||||
'''
|
||||
usermessages = usermessages.split('_')[1:]
|
||||
sessionid = '_'.join(usermessages[:-1])
|
||||
return sessionid
|
||||
|
||||
|
||||
# TODO: USE app.core.memory.src.utils.render_template instead
|
||||
async def read_template_file(template_path: str) -> str:
|
||||
"""
|
||||
读取模板文件
|
||||
|
||||
Args:
|
||||
template_path: 模板文件路径
|
||||
|
||||
Returns:
|
||||
模板内容字符串
|
||||
|
||||
Note:
|
||||
建议使用 app.core.memory.utils.template_render 中的统一模板渲染功能
|
||||
"""
|
||||
try:
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
logger.error(f"模板文件未找到: {template_path}")
|
||||
raise
|
||||
except IOError as e:
|
||||
logger.error(f"读取模板文件失败: {template_path}, 错误: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
async def Problem_Extension_messages_deal(context):
|
||||
'''
|
||||
Extract data
|
||||
Args:
|
||||
context:
|
||||
Returns:
|
||||
'''
|
||||
extent_quest = []
|
||||
original = context.get('original', '')
|
||||
messages = context.get('context', '')
|
||||
|
||||
# Handle empty or non-string messages
|
||||
if not messages:
|
||||
return extent_quest, original
|
||||
|
||||
if isinstance(messages, str):
|
||||
try:
|
||||
messages = json.loads(messages)
|
||||
except json.JSONDecodeError:
|
||||
# If JSON parsing fails, return empty list
|
||||
return extent_quest, original
|
||||
|
||||
if isinstance(messages, list):
|
||||
for message in messages:
|
||||
question = message.get('question', '')
|
||||
type = message.get('type', '')
|
||||
extent_quest.append({"role": "user", "content": f"问题:{question};问题类型:{type}"})
|
||||
|
||||
return extent_quest, original
|
||||
|
||||
|
||||
async def Retriev_messages_deal(context):
|
||||
'''
|
||||
Extract data
|
||||
Args:
|
||||
context:
|
||||
Returns:
|
||||
'''
|
||||
logger.info(f"Retriev_messages_deal input: type={type(context)}, value={str(context)[:500]}")
|
||||
|
||||
if isinstance(context, dict):
|
||||
logger.info(f"Retriev_messages_deal: context is dict with keys={list(context.keys())}")
|
||||
if 'context' in context or 'original' in context:
|
||||
content = context.get('context', {})
|
||||
original = context.get('original', '')
|
||||
logger.info(f"Retriev_messages_deal output: content_type={type(content)}, content={str(content)[:300]}, original='{original[:50] if original else ''}'")
|
||||
return content, original
|
||||
|
||||
# Return empty defaults if context is not a dict or doesn't have expected keys
|
||||
logger.warning(f"Retriev_messages_deal: context missing expected keys, returning empty defaults")
|
||||
return {}, ''
|
||||
|
||||
async def Verify_messages_deal(context):
|
||||
'''
|
||||
Extract data
|
||||
Args:
|
||||
context:
|
||||
Returns:
|
||||
'''
|
||||
|
||||
query = context['context']['Query']
|
||||
Query_small_list = context['context']['Expansion_issue']
|
||||
Result_small = []
|
||||
Query_small = []
|
||||
for i in Query_small_list:
|
||||
Result_small.append(i['Answer_Small'][0])
|
||||
Query_small.append(i['Query_small'])
|
||||
return Query_small, Result_small, query
|
||||
|
||||
|
||||
async def Summary_messages_deal(context):
|
||||
'''
|
||||
Extract data
|
||||
Args:
|
||||
context:
|
||||
Returns:
|
||||
'''
|
||||
messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '')
|
||||
query = re.findall(r'"query": (.*?),', messages)[0]
|
||||
query = query.replace('[', '').replace(']', '').strip()
|
||||
matches = re.findall(r'"answer_small"\s*:\s*"(\[.*?\])"', messages)
|
||||
answer_small_texts = []
|
||||
for m in matches:
|
||||
try:
|
||||
parsed = json.loads(m)
|
||||
for item in parsed:
|
||||
answer_small_texts.append(item.strip().replace('\\', '').replace('[', '').replace(']', ''))
|
||||
except Exception:
|
||||
answer_small_texts.append(m.strip().replace('\\', '').replace('[', '').replace(']', ''))
|
||||
|
||||
return answer_small_texts, query
|
||||
|
||||
|
||||
async def VerifyTool_messages_deal(context):
|
||||
'''
|
||||
Extract data
|
||||
Args:
|
||||
context:
|
||||
Returns:
|
||||
'''
|
||||
messages = str(context).replace('\\n', '').replace('\n', '').replace('\\', '')
|
||||
content_messages = messages.split('"context":')[1].replace('""', '"')
|
||||
messages = str(content_messages).split("name='Retrieve'")[0]
|
||||
query = re.findall('"Query": "(.*?)"', messages)[0]
|
||||
Query_small = re.findall('"Query_small": "(.*?)"', messages)
|
||||
Result_small = re.findall('"Result_small": "(.*?)"', messages)
|
||||
return Query_small, Result_small, query
|
||||
|
||||
|
||||
async def Retrieve_Summary_messages_deal(context):
|
||||
pass
|
||||
|
||||
|
||||
async def Retrieve_verify_tool_messages_deal(context, history, query):
|
||||
'''
|
||||
Extract data
|
||||
Args:
|
||||
context:
|
||||
Returns:
|
||||
'''
|
||||
results = []
|
||||
# 统一转为字符串,避免 None 或非字符串导致正则报错
|
||||
text = str(context)
|
||||
blocks = re.findall(r'\{(.*?)\}', text, flags=re.S)
|
||||
for block in blocks:
|
||||
query_small = re.search(r'"Query_small"\s*:\s*"([^"]*)"', block)
|
||||
answer_small = re.search(r'"Answer_Small"\s*:\s*(\[[^\]]*\])', block)
|
||||
status = re.search(r'"status"\s*:\s*"([^"]*)"', block)
|
||||
query_answer = re.search(r'"Query_answer"\s*:\s*"([^"]*)"', block)
|
||||
|
||||
results.append({
|
||||
"query_small": query_small.group(1) if query_small else None,
|
||||
"answer_small": answer_small.group(1) if answer_small else None,
|
||||
# 将缺失的 status 统一为空字符串,后续用字符串判定,避免 NoneType 错误
|
||||
"status": status.group(1) if status else "",
|
||||
"query_answer": query_answer.group(1) if query_answer else None
|
||||
})
|
||||
result = []
|
||||
for r in results:
|
||||
# 统一按字符串判定状态,兼容大小写和缺失情况
|
||||
status_str = str(r.get('status', '')).strip().lower()
|
||||
if status_str == 'false':
|
||||
continue
|
||||
else:
|
||||
result.append(r)
|
||||
split_result = 'failed' if not result else 'success'
|
||||
result = {"data": {"query": query, "expansion_issue": result}, "split_result": split_result, "reason": "",
|
||||
"history": history}
|
||||
return result
|
||||
194
api/app/core/memory/agent/utils/messages_tools.py
Normal file
194
api/app/core/memory/agent/utils/messages_tools.py
Normal file
@@ -0,0 +1,194 @@
|
||||
from typing import List, Dict, Any
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
async def read_template_file(template_path: str) -> str:
|
||||
"""
|
||||
读取模板文件
|
||||
|
||||
Args:
|
||||
template_path: 模板文件路径
|
||||
|
||||
Returns:
|
||||
模板内容字符串
|
||||
|
||||
Note:
|
||||
建议使用 app.core.memory.utils.template_render 中的统一模板渲染功能
|
||||
"""
|
||||
try:
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
logger.error(f"模板文件未找到: {template_path}")
|
||||
raise
|
||||
except IOError as e:
|
||||
logger.error(f"读取模板文件失败: {template_path}, 错误: {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
def reorder_output_results(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
重新排序输出结果,将 retrieval_summary 类型的数据放到最后面
|
||||
|
||||
Args:
|
||||
results: 原始输出结果列表
|
||||
|
||||
Returns:
|
||||
重新排序后的结果列表
|
||||
"""
|
||||
retrieval_summaries = []
|
||||
other_results = []
|
||||
|
||||
# 分离 retrieval_summary 和其他类型的结果
|
||||
for result in results:
|
||||
if 'summary' in result.get('type'):
|
||||
retrieval_summaries.append(result)
|
||||
else:
|
||||
other_results.append(result)
|
||||
|
||||
# 将 retrieval_summary 放到最后
|
||||
return other_results + retrieval_summaries
|
||||
|
||||
def optimize_search_results(intermediate_outputs):
|
||||
"""
|
||||
优化检索结果,合并多个搜索结果,过滤空结果,统一格式
|
||||
|
||||
Args:
|
||||
intermediate_outputs: 原始的中间输出列表
|
||||
|
||||
Returns:
|
||||
优化后的检索结果列表
|
||||
"""
|
||||
optimized_results = []
|
||||
|
||||
for item in intermediate_outputs:
|
||||
if not item or item == [] or item == {}:
|
||||
continue
|
||||
|
||||
# 检查是否是搜索结果类型
|
||||
if isinstance(item, dict) and item.get('type') == 'search_result':
|
||||
raw_results = item.get('raw_results', {})
|
||||
|
||||
# 如果 raw_results 为空,跳过
|
||||
if not raw_results or raw_results == [] or raw_results == {}:
|
||||
continue
|
||||
|
||||
# 创建优化后的结果结构
|
||||
optimized_item = {
|
||||
"type": "search_result",
|
||||
"title": f"检索结果 ({item.get('index', 1)}/{item.get('total', 1)})",
|
||||
"query": item.get('query', ''),
|
||||
"raw_results": {},
|
||||
"index": item.get('index', 1),
|
||||
"total": item.get('total', 1)
|
||||
}
|
||||
|
||||
# 合并所有搜索结果类型到一个 raw_results 中
|
||||
merged_raw_results = {}
|
||||
|
||||
# 处理 time_search
|
||||
if 'time_search' in raw_results and raw_results['time_search']:
|
||||
merged_raw_results['time_search'] = raw_results['time_search']
|
||||
|
||||
# 处理 keyword_search
|
||||
if 'keyword_search' in raw_results and raw_results['keyword_search']:
|
||||
merged_raw_results['keyword_search'] = raw_results['keyword_search']
|
||||
|
||||
# 处理 embedding_search
|
||||
if 'embedding_search' in raw_results and raw_results['embedding_search']:
|
||||
merged_raw_results['embedding_search'] = raw_results['embedding_search']
|
||||
|
||||
# 处理 combined_summary
|
||||
if 'combined_summary' in raw_results and raw_results['combined_summary']:
|
||||
merged_raw_results['combined_summary'] = raw_results['combined_summary']
|
||||
|
||||
# 处理 reranked_results
|
||||
if 'reranked_results' in raw_results and raw_results['reranked_results']:
|
||||
merged_raw_results['reranked_results'] = raw_results['reranked_results']
|
||||
|
||||
# 如果合并后的结果不为空,添加到优化结果中
|
||||
if merged_raw_results:
|
||||
optimized_item['raw_results'] = merged_raw_results
|
||||
optimized_results.append(optimized_item)
|
||||
else:
|
||||
# 非搜索结果类型,直接添加
|
||||
optimized_results.append(item)
|
||||
|
||||
return optimized_results
|
||||
|
||||
|
||||
def merge_multiple_search_results(intermediate_outputs):
|
||||
"""
|
||||
将多个搜索结果合并为一个统一的搜索结果
|
||||
|
||||
Args:
|
||||
intermediate_outputs: 原始的中间输出列表
|
||||
|
||||
Returns:
|
||||
合并后的结果列表
|
||||
"""
|
||||
search_results = []
|
||||
other_results = []
|
||||
|
||||
# 分离搜索结果和其他结果
|
||||
for item in intermediate_outputs:
|
||||
if isinstance(item, dict) and item.get('type') == 'search_result':
|
||||
raw_results = item.get('raw_results', {})
|
||||
# 只保留有内容的搜索结果
|
||||
if raw_results and raw_results != [] and raw_results != {}:
|
||||
search_results.append(item)
|
||||
else:
|
||||
other_results.append(item)
|
||||
|
||||
# 如果没有搜索结果,返回原始结果
|
||||
if not search_results:
|
||||
return intermediate_outputs
|
||||
|
||||
# 如果只有一个搜索结果,优化格式后返回
|
||||
if len(search_results) == 1:
|
||||
optimized = optimize_search_results(search_results)
|
||||
return other_results + optimized
|
||||
|
||||
# 合并多个搜索结果
|
||||
merged_raw_results = {}
|
||||
all_queries = []
|
||||
|
||||
for result in search_results:
|
||||
query = result.get('query', '')
|
||||
if query:
|
||||
all_queries.append(query)
|
||||
|
||||
raw_results = result.get('raw_results', {})
|
||||
|
||||
# 合并各种搜索类型的结果
|
||||
for search_type in ['time_search', 'keyword_search', 'embedding_search', 'combined_summary',
|
||||
'reranked_results']:
|
||||
if search_type in raw_results and raw_results[search_type]:
|
||||
if search_type not in merged_raw_results:
|
||||
merged_raw_results[search_type] = raw_results[search_type]
|
||||
else:
|
||||
# 如果是字典类型,需要合并
|
||||
if isinstance(raw_results[search_type], dict) and isinstance(merged_raw_results[search_type], dict):
|
||||
for key, value in raw_results[search_type].items():
|
||||
if key not in merged_raw_results[search_type]:
|
||||
merged_raw_results[search_type][key] = value
|
||||
elif isinstance(value, list) and isinstance(merged_raw_results[search_type][key], list):
|
||||
merged_raw_results[search_type][key].extend(value)
|
||||
elif isinstance(raw_results[search_type], list):
|
||||
if isinstance(merged_raw_results[search_type], list):
|
||||
merged_raw_results[search_type].extend(raw_results[search_type])
|
||||
else:
|
||||
merged_raw_results[search_type] = raw_results[search_type]
|
||||
|
||||
# 创建合并后的结果
|
||||
if merged_raw_results:
|
||||
merged_result = {
|
||||
"type": "search_result",
|
||||
"title": f"合并检索结果 (共{len(search_results)}个查询)",
|
||||
"query": " | ".join(all_queries),
|
||||
"raw_results": merged_raw_results,
|
||||
"index": 1,
|
||||
"total": 1
|
||||
}
|
||||
return other_results + [merged_result]
|
||||
|
||||
return other_results
|
||||
@@ -1,38 +0,0 @@
|
||||
|
||||
|
||||
# project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
# sys.path.insert(0, project_root)
|
||||
|
||||
# load_dotenv()
|
||||
|
||||
# async def llm_client_chat(messages: List[dict]) -> str:
|
||||
# """使用 OpenAI 兼容接口进行对话,返回内容字符串。"""
|
||||
# try:
|
||||
# cfg = get_model_config(SELECTED_LLM_ID)
|
||||
# rb_config = RedBearModelConfig(
|
||||
# model_name=cfg["model_name"],
|
||||
# provider=cfg["provider"],
|
||||
# api_key=cfg["api_key"],
|
||||
# base_url=cfg["base_url"],
|
||||
# )
|
||||
# client = OpenAIClient(model_config=rb_config, type_="chat")
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"获取模型配置失败:{e}")
|
||||
# err = f"获取模型配置失败:{str(e)}。请检查!!!"
|
||||
# return err
|
||||
# try:
|
||||
# response = await client.chat(messages)
|
||||
# print(f"model_tool's llm_client_chat response ======>:\n {response}")
|
||||
# return _extract_content(response)
|
||||
# # return _extract_content(result)
|
||||
# except Exception as e:
|
||||
# logger.error(f"LLM调用失败:{str(e)}。请检查 model_name、api_key、api_base 是否正确。")
|
||||
# return f"LLM调用失败:{str(e)}。请检查 model_name、api_key、api_base 是否正确。"
|
||||
|
||||
# async def main(image_url):
|
||||
# await llm_client_chat(image_url)
|
||||
#
|
||||
# # 运行主函数
|
||||
# asyncio.run(main(['https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_male2.wav']))
|
||||
#
|
||||
@@ -1,131 +0,0 @@
|
||||
"""
|
||||
Multimodal input processor for handling image and audio content.
|
||||
|
||||
This module provides utilities for detecting and processing multimodal inputs
|
||||
(images and audio files) by converting them to text using appropriate models.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from app.core.memory.agent.multimodal.speech_model import Vico_recognition
|
||||
from app.core.memory.agent.utils.llm_tools import picture_model_requests
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MultimodalProcessor:
|
||||
"""
|
||||
Processor for handling multimodal inputs (images and audio).
|
||||
|
||||
This class detects image and audio file paths in input content and converts
|
||||
them to text using appropriate recognition models.
|
||||
"""
|
||||
|
||||
# Supported file extensions
|
||||
IMAGE_EXTENSIONS = ['.jpg', '.png']
|
||||
AUDIO_EXTENSIONS = [
|
||||
'aac', 'amr', 'avi', 'flac', 'flv', 'm4a', 'mkv', 'mov',
|
||||
'mp3', 'mp4', 'mpeg', 'ogg', 'opus', 'wav', 'webm', 'wma', 'wmv'
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the multimodal processor."""
|
||||
pass
|
||||
|
||||
def is_image(self, content: str) -> bool:
|
||||
"""
|
||||
Check if content is an image file path.
|
||||
|
||||
Args:
|
||||
content: Input string to check
|
||||
|
||||
Returns:
|
||||
True if content ends with a supported image extension
|
||||
|
||||
Examples:
|
||||
>>> processor = MultimodalProcessor()
|
||||
>>> processor.is_image("photo.jpg")
|
||||
True
|
||||
>>> processor.is_image("document.pdf")
|
||||
False
|
||||
"""
|
||||
if not isinstance(content, str):
|
||||
return False
|
||||
|
||||
content_lower = content.lower()
|
||||
return any(content_lower.endswith(ext) for ext in self.IMAGE_EXTENSIONS)
|
||||
|
||||
def is_audio(self, content: str) -> bool:
|
||||
"""
|
||||
Check if content is an audio file path.
|
||||
|
||||
Args:
|
||||
content: Input string to check
|
||||
|
||||
Returns:
|
||||
True if content ends with a supported audio extension
|
||||
|
||||
Examples:
|
||||
>>> processor = MultimodalProcessor()
|
||||
>>> processor.is_audio("recording.mp3")
|
||||
True
|
||||
>>> processor.is_audio("video.mp4")
|
||||
True
|
||||
>>> processor.is_audio("document.txt")
|
||||
False
|
||||
"""
|
||||
if not isinstance(content, str):
|
||||
return False
|
||||
|
||||
content_lower = content.lower()
|
||||
return any(content_lower.endswith(f'.{ext}') for ext in self.AUDIO_EXTENSIONS)
|
||||
|
||||
async def process_input(self, content: str) -> str:
|
||||
"""
|
||||
Process input content, converting images/audio to text if needed.
|
||||
|
||||
This method detects if the input is an image or audio file and converts
|
||||
it to text using the appropriate recognition model. If processing fails
|
||||
or the content is not multimodal, it returns the original content.
|
||||
|
||||
Args:
|
||||
content: Input string (may be file path or regular text)
|
||||
|
||||
Returns:
|
||||
Text content (original or converted from image/audio)
|
||||
|
||||
Examples:
|
||||
>>> processor = MultimodalProcessor()
|
||||
>>> await processor.process_input("photo.jpg")
|
||||
"Recognized text from image..."
|
||||
|
||||
>>> await processor.process_input("Hello world")
|
||||
"Hello world"
|
||||
"""
|
||||
if not isinstance(content, str):
|
||||
logger.warning(f"[MultimodalProcessor] Content is not a string: {type(content)}")
|
||||
return str(content)
|
||||
|
||||
try:
|
||||
# Check for image input
|
||||
if self.is_image(content):
|
||||
logger.info(f"[MultimodalProcessor] Detected image input: {content}")
|
||||
result = await picture_model_requests(content)
|
||||
logger.info(f"[MultimodalProcessor] Image recognition result: {result[:100]}...")
|
||||
return result
|
||||
|
||||
# Check for audio input
|
||||
if self.is_audio(content):
|
||||
logger.info(f"[MultimodalProcessor] Detected audio input: {content}")
|
||||
result = await Vico_recognition([content]).run()
|
||||
logger.info(f"[MultimodalProcessor] Audio recognition result: {result[:100]}...")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[MultimodalProcessor] Error processing multimodal input: {e}", exc_info=True)
|
||||
logger.info("[MultimodalProcessor] Falling back to original content")
|
||||
return content
|
||||
|
||||
# Return original content if not multimodal
|
||||
return content
|
||||
56
api/app/core/memory/agent/utils/performance_monitor.py
Normal file
56
api/app/core/memory/agent/utils/performance_monitor.py
Normal file
@@ -0,0 +1,56 @@
|
||||
|
||||
import time
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List
|
||||
from app.core.logging_config import get_agent_logger
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
class ProblemExtensionMonitor:
|
||||
"""Problem_Extension性能监控器"""
|
||||
|
||||
def __init__(self):
|
||||
self.metrics = defaultdict(list)
|
||||
self.slow_queries = []
|
||||
self.error_count = 0
|
||||
|
||||
def record_execution(self, duration: float, question_count: int, success: bool):
|
||||
"""记录执行指标"""
|
||||
self.metrics['durations'].append(duration)
|
||||
self.metrics['question_counts'].append(question_count)
|
||||
|
||||
if not success:
|
||||
self.error_count += 1
|
||||
|
||||
# 记录慢查询(超过10秒)
|
||||
if duration > 10.0:
|
||||
self.slow_queries.append({
|
||||
'duration': duration,
|
||||
'question_count': question_count,
|
||||
'timestamp': time.time()
|
||||
})
|
||||
|
||||
def get_stats(self) -> Dict:
|
||||
"""获取统计信息"""
|
||||
durations = self.metrics['durations']
|
||||
if not durations:
|
||||
return {"message": "暂无数据"}
|
||||
|
||||
return {
|
||||
"total_executions": len(durations),
|
||||
"avg_duration": sum(durations) / len(durations),
|
||||
"max_duration": max(durations),
|
||||
"min_duration": min(durations),
|
||||
"slow_queries_count": len(self.slow_queries),
|
||||
"error_rate": self.error_count / len(durations) if durations else 0,
|
||||
"recent_slow_queries": self.slow_queries[-5:] # 最近5个慢查询
|
||||
}
|
||||
|
||||
def log_stats(self):
|
||||
"""记录统计信息到日志"""
|
||||
stats = self.get_stats()
|
||||
logger.info(f"Problem_Extension性能统计: {json.dumps(stats, indent=2)}")
|
||||
|
||||
# 全局监控器实例
|
||||
performance_monitor = ProblemExtensionMonitor()
|
||||
@@ -0,0 +1,81 @@
|
||||
|
||||
你是一个高效的问题拆分助手,任务是根据用户提供的原始问题和问题类型,生成可操作的扩展问题,用于精确回答原问题。请严格遵循以下规则:
|
||||
|
||||
角色:
|
||||
- 你是“问题拆分专家”,专注于逻辑、信息完整性和可操作性。
|
||||
- 你能够结合【历史信息】、【上下文】、【背景知识】进行分析,以保持问题拆分的连贯性和相关性。
|
||||
- 如果历史信息或上下文与当前问题无关,可忽略。
|
||||
|
||||
---
|
||||
|
||||
### 历史信息参考
|
||||
在生成扩展问题时,你可以参考以下历史数据(如果提供):
|
||||
- 历史对话或任务的主题;
|
||||
- 历史中出现的关键实体(时间、人物、地点、研究主题等);
|
||||
- 历史中已解答的问题(避免重复);
|
||||
- 历史推理链(保持逻辑一致性)。
|
||||
|
||||
> 如果没有提供历史信息,则仅根据当前输入问题进行分析。
|
||||
输入历史信息内容:{{history}}
|
||||
|
||||
## User Input
|
||||
{% if questions is string %}
|
||||
{{ questions }}
|
||||
{% else %}
|
||||
{% for question in questions %}
|
||||
- {{ question }}
|
||||
{% endfor %}
|
||||
{% endif %}
|
||||
|
||||
需求:
|
||||
- 如果问题是单跳问题(单步可答),直接保留原问题提取重要提问部分作为拆分/扩展问题。
|
||||
- 如果问题是多跳问题(需多个信息点才能回答),对问题进行扩展拆分。
|
||||
- 扩展问题必须完整覆盖原问题的所有关键要素,包括时间、主体、动作、目标等,不得遗漏。
|
||||
- 扩展问题不得冗余:避免重复询问相同信息或过度拆分同一主题。
|
||||
- 扩展问题必须高度相关:每个子问题直接服务于原问题,不引入未提及的新概念、人物或细节。
|
||||
- 扩展问题必须可操作:每个子问题能在有限资源下独立解答。
|
||||
- 子问题数量不超过4个。
|
||||
- 拆分问题的时候可以考虑输入的历史内容,以保持逻辑连贯。
|
||||
比如:输入历史信息内容:[{'Query': '4月27日,我和你推荐过一本书,书名是什么?', 'ANswer': '张曼玉推荐了《小王子》'}]
|
||||
拆分问题:4月27日,我和你推荐过一本书,书名是什么?,可以拆分为:4月27日,张曼玉推荐过一本书,书名是什么?
|
||||
|
||||
|
||||
|
||||
输出要求:
|
||||
- 仅输出 JSON 数组,不要包含任何解释或代码块。
|
||||
- 每个元素包含:
|
||||
- `original_question`: 原始问题
|
||||
- `extended_question`: 扩展后的问题
|
||||
- `type`: 类型(事实检索/澄清/定义/比较/行动建议)
|
||||
- `reason`: 生成该扩展问题的简短理由
|
||||
- 使用标准 ASCII 双引号,无换行;确保字符串正确关闭并以逗号分隔。
|
||||
|
||||
示例:
|
||||
输入:
|
||||
[
|
||||
"问题:今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?;问题类型:多跳",
|
||||
]
|
||||
|
||||
输出:
|
||||
[
|
||||
{
|
||||
"original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?",
|
||||
"extended_question": "今年诺贝尔物理学奖的获奖者有哪些人?",
|
||||
"type": "多跳",
|
||||
"reason": "输出原问题的关键要素"
|
||||
},
|
||||
{
|
||||
"original_question": "今年诺贝尔物理学奖的获奖者是谁,他们因为什么贡献获奖?",
|
||||
"extended_question": "今年诺贝尔物理学奖的获奖者是因哪些具体贡献获奖的?",
|
||||
"type": "多跳",
|
||||
"reason": "输出原问题的关键要素"
|
||||
}
|
||||
]
|
||||
**Output format**
|
||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
|
||||
3. Ensure all JSON strings are properly closed and comma-separated
|
||||
4. Do not include line breaks within JSON string values
|
||||
|
||||
The output language should always be the same as the input language.{{ json_schema }}
|
||||
@@ -1,13 +1,10 @@
|
||||
# 角色
|
||||
你是一个专业的问答助手,擅长基于检索信息和历史对话回答用户问题。
|
||||
|
||||
# 任务
|
||||
根据提供的上下文信息回答用户的问题。
|
||||
|
||||
# 输入信息
|
||||
- 历史对话:{{history}}
|
||||
- 检索信息:{{retrieve_info}}
|
||||
|
||||
## User Query
|
||||
{{query}}
|
||||
|
||||
|
||||
@@ -9,8 +9,8 @@
|
||||
3. 判断Answer_Small和Query_Small之间分析出来的关系状态
|
||||
4. 如果是True保留,否则不要相对应的问题和回答
|
||||
5. 输出,需要严格按照模版
|
||||
输入:{{history}}
|
||||
历史消息:{"history":{{sentence}}}
|
||||
输入:{{sentence}}
|
||||
历史消息:{"history":{{history}}}
|
||||
### 第一步 获取用户的输入
|
||||
获取用户的输入提取对应的Query_Small和Answer_Small
|
||||
### 第二步 分析验证
|
||||
@@ -42,19 +42,33 @@
|
||||
如果状态是TRUE保留这条数据,否则需不需要这条数据
|
||||
### 第五步 输出格式
|
||||
按照json的形式输出
|
||||
{"data":"Query":原来Query的字段,"history":原来的history字段,
|
||||
"expansion_issue":以为列表的形式存储验证之后的数据比如[
|
||||
{"query_small": query_small,
|
||||
"answer_small": answer_small,,
|
||||
"status": 回答的结果是否符合query_small,填写状态,
|
||||
"query_answer": answer_small},
|
||||
{"query":"原来Query的字段",
|
||||
"history":"原来的history字段",
|
||||
"expansion_issue":以列表的形式存储验证之后的数据比如[
|
||||
{
|
||||
"query_small": "张曼婷生日是什么时候?",
|
||||
"answer_small": "张曼婷喜欢绘画。",
|
||||
"status": "True",
|
||||
"query_answer": "张曼 婷喜欢绘画。"
|
||||
},{}......]
|
||||
,
|
||||
"split_result":如果expansion_issue是空的列表返回failed,不是空列表返回success,
|
||||
"reason": 为以上分析完之后的结果给一个说明
|
||||
}
|
||||
"query_small": "子问题",
|
||||
"answer_small": "子问题的回答",
|
||||
"status": "True或False,表示回答是否符合query_small",
|
||||
"query_answer": "问题的答案(与answer_small相同)"
|
||||
},
|
||||
{
|
||||
"query_small": "张曼婷生日是什么时候?",
|
||||
"answer_small": "张曼婷喜欢绘画。",
|
||||
"status": "False",
|
||||
"query_answer": "张曼婷喜欢绘画。"
|
||||
}
|
||||
],
|
||||
"split_result":"如果expansion_issue是空的列表返回failed,不是空列表返回success",
|
||||
"reason": "为以上分析完之后的结果给一个说明"
|
||||
}
|
||||
|
||||
**输出格式要求**
|
||||
**CRITICAL JSON FORMATTING REQUIREMENTS:**
|
||||
1. Use only standard ASCII double quotes (") for JSON structure - never use Chinese quotation marks ("") or other Unicode quotes
|
||||
2. If the extracted statement text contains quotation marks, escape them properly using backslashes (\")
|
||||
3. Ensure all JSON strings are properly closed and comma-separated
|
||||
4. Do not include line breaks within JSON string values
|
||||
5. The output language should always be the same as the input language
|
||||
|
||||
**JSON Schema:**
|
||||
{{ json_schema }}
|
||||
169
api/app/core/memory/agent/utils/session_tools.py
Normal file
169
api/app/core/memory/agent/utils/session_tools.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
Session Service for managing user sessions and conversation history.
|
||||
|
||||
This service provides clean Redis interactions with error handling and
|
||||
session management utilities.
|
||||
"""
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.redis_tool import RedisSessionStore
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class SessionService:
|
||||
"""Service for managing user sessions and conversation history."""
|
||||
|
||||
def __init__(self, store: RedisSessionStore):
|
||||
"""
|
||||
Initialize the session service.
|
||||
|
||||
Args:
|
||||
store: Redis session store instance
|
||||
"""
|
||||
self.store = store
|
||||
logger.info("SessionService initialized")
|
||||
|
||||
def resolve_user_id(self, session_string: str) -> str:
|
||||
"""
|
||||
Extract user ID from session string.
|
||||
|
||||
Handles formats like:
|
||||
- 'call_id_user123' -> 'user123'
|
||||
- 'prefix_id_user456_suffix' -> 'user456_suffix'
|
||||
|
||||
Args:
|
||||
session_string: Session identifier string
|
||||
|
||||
Returns:
|
||||
Extracted user ID
|
||||
"""
|
||||
try:
|
||||
# Split by '_id_' and take everything after it
|
||||
parts = session_string.split('_id_')
|
||||
if len(parts) > 1:
|
||||
return parts[1]
|
||||
|
||||
# Fallback: return original string
|
||||
return session_string
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to parse user ID from session string '{session_string}': {e}"
|
||||
)
|
||||
return session_string
|
||||
|
||||
async def get_history(
|
||||
self,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Retrieve conversation history from Redis.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
|
||||
Returns:
|
||||
List of conversation history items with Query and Answer keys
|
||||
Returns empty list if no history found or on error
|
||||
"""
|
||||
try:
|
||||
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
|
||||
|
||||
# Validate history structure
|
||||
if not isinstance(history, list):
|
||||
logger.warning(
|
||||
f"Invalid history format for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
|
||||
)
|
||||
return []
|
||||
|
||||
return history
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to retrieve history for user {user_id}, "
|
||||
f"apply {apply_id}, group {group_id}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
# Return empty list on error to allow execution to continue
|
||||
return []
|
||||
|
||||
async def save_session(
|
||||
self,
|
||||
user_id: str,
|
||||
query: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
ai_response: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Save conversation turn to Redis.
|
||||
|
||||
Args:
|
||||
user_id: User identifier
|
||||
query: User query/message
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
ai_response: AI response/answer
|
||||
|
||||
Returns:
|
||||
Session ID if successful, None on error
|
||||
"""
|
||||
try:
|
||||
# Validate required fields
|
||||
if not user_id:
|
||||
logger.warning("Cannot save session: user_id is empty")
|
||||
return None
|
||||
|
||||
if not query:
|
||||
logger.warning("Cannot save session: query is empty")
|
||||
return None
|
||||
|
||||
# Save session
|
||||
session_id = self.store.save_session(
|
||||
userid=user_id,
|
||||
messages=query,
|
||||
apply_id=apply_id,
|
||||
group_id=group_id,
|
||||
aimessages=ai_response
|
||||
)
|
||||
|
||||
logger.info(f"Session saved successfully: {session_id}")
|
||||
return session_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to save session for user {user_id}: {e}",
|
||||
exc_info=True
|
||||
)
|
||||
return None
|
||||
|
||||
async def cleanup_duplicates(self) -> int:
|
||||
"""
|
||||
Remove duplicate session entries.
|
||||
|
||||
Duplicates are identified by matching:
|
||||
- sessionid
|
||||
- user_id (id field)
|
||||
- group_id
|
||||
- messages
|
||||
- aimessages
|
||||
|
||||
Returns:
|
||||
Number of duplicate sessions deleted
|
||||
"""
|
||||
try:
|
||||
deleted_count = self.store.delete_duplicate_sessions()
|
||||
logger.info(f"Cleaned up {deleted_count} duplicate sessions")
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup duplicate sessions: {e}", exc_info=True)
|
||||
return 0
|
||||
117
api/app/core/memory/agent/utils/template_tools.py
Normal file
117
api/app/core/memory/agent/utils/template_tools.py
Normal file
@@ -0,0 +1,117 @@
|
||||
"""
|
||||
Template Service for loading and rendering Jinja2 templates.
|
||||
|
||||
This service provides centralized template management with caching and error handling.
|
||||
"""
|
||||
# 标准库
|
||||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader, Template, TemplateNotFound
|
||||
|
||||
from app.core.logging_config import get_agent_logger, log_prompt_rendering
|
||||
|
||||
|
||||
logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
class TemplateRenderError(Exception):
|
||||
"""Exception raised when template rendering fails."""
|
||||
|
||||
def __init__(self, template_name: str, error: Exception, variables: dict):
|
||||
self.template_name = template_name
|
||||
self.error = error
|
||||
self.variables = variables
|
||||
super().__init__(
|
||||
f"Failed to render template '{template_name}': {str(error)}"
|
||||
)
|
||||
|
||||
|
||||
class TemplateService:
|
||||
"""Service for loading and rendering Jinja2 templates with caching."""
|
||||
|
||||
def __init__(self, template_root: str):
|
||||
"""
|
||||
Initialize the template service.
|
||||
|
||||
Args:
|
||||
template_root: Root directory containing template files
|
||||
"""
|
||||
self.template_root = template_root
|
||||
self.env = Environment(
|
||||
loader=FileSystemLoader(template_root),
|
||||
autoescape=False # Disable autoescape for prompt templates
|
||||
)
|
||||
logger.info(f"TemplateService initialized with root: {template_root}")
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def _load_template(self, template_name: str) -> Template:
|
||||
"""
|
||||
Load a template from disk with caching.
|
||||
|
||||
Args:
|
||||
template_name: Relative path to template file
|
||||
|
||||
Returns:
|
||||
Loaded Jinja2 Template object
|
||||
|
||||
Raises:
|
||||
TemplateNotFound: If template file doesn't exist
|
||||
"""
|
||||
try:
|
||||
return self.env.get_template(template_name)
|
||||
except TemplateNotFound as e:
|
||||
expected_path = os.path.join(self.template_root, template_name)
|
||||
logger.error(
|
||||
f"Template not found: {template_name}. "
|
||||
f"Expected path: {expected_path}"
|
||||
)
|
||||
raise
|
||||
|
||||
async def render_template(
|
||||
self,
|
||||
template_name: str,
|
||||
operation_name: str,
|
||||
**variables
|
||||
) -> str:
|
||||
"""
|
||||
Load and render a Jinja2 template.
|
||||
|
||||
Args:
|
||||
template_name: Relative path to template file
|
||||
operation_name: Name for logging (e.g., "split_the_problem")
|
||||
**variables: Template variables to render
|
||||
|
||||
Returns:
|
||||
Rendered template string
|
||||
|
||||
Raises:
|
||||
TemplateRenderError: If template loading or rendering fails
|
||||
"""
|
||||
try:
|
||||
# Load template (cached)
|
||||
template = self._load_template(template_name)
|
||||
|
||||
# Render template
|
||||
rendered = template.render(**variables)
|
||||
|
||||
# Log rendered prompt
|
||||
log_prompt_rendering(operation_name, rendered)
|
||||
|
||||
return rendered
|
||||
|
||||
except TemplateNotFound as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for {operation_name} "
|
||||
f"({template_name}): Template not found",
|
||||
exc_info=True
|
||||
)
|
||||
raise TemplateRenderError(template_name, e, variables)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Template rendering failed for {operation_name} "
|
||||
f"({template_name}): {e}",
|
||||
exc_info=True
|
||||
)
|
||||
raise TemplateRenderError(template_name, e, variables)
|
||||
@@ -1,10 +1,9 @@
|
||||
"""
|
||||
Type classification utility for distinguishing read/write operations.
|
||||
"""
|
||||
from app.core.config import settings
|
||||
from app.core.logging_config import get_agent_logger, log_prompt_rendering
|
||||
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
|
||||
from app.core.memory.agent.utils.messages_tool import read_template_file
|
||||
from app.core.memory.agent.utils.messages_tools import read_template_file
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from jinja2 import Template
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from sqlalchemy.orm import Session
|
||||
import logging
|
||||
import json
|
||||
|
||||
from app.db import get_db
|
||||
from app.models.retrieval_info import RetrievalInfo
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
async def write_to_database(host_id: uuid.UUID, data: Any) -> str:
|
||||
"""
|
||||
将数据写入数据库
|
||||
:param host_id: 宿主 ID
|
||||
:param data: 要写入的数据
|
||||
:return: 写入数据库的结果
|
||||
"""
|
||||
# 从数据库会话中获取会话
|
||||
db: Session = next(get_db())
|
||||
try:
|
||||
if isinstance(data, (dict, list)):
|
||||
serialized = json.dumps(data, ensure_ascii=False)
|
||||
elif isinstance(data, str):
|
||||
serialized = data
|
||||
else:
|
||||
serialized = str(data)
|
||||
|
||||
new_retrieval_info = RetrievalInfo(
|
||||
# host_id=host_id,
|
||||
host_id=uuid.UUID("2f6ff1eb-50c7-4765-8e89-e4566be19122"),
|
||||
retrieve_info=serialized,
|
||||
created_at=datetime.now()
|
||||
)
|
||||
db.add(new_retrieval_info)
|
||||
db.commit()
|
||||
logger.info(f"success to write data to database, host_id: {host_id}, retrieve_info: {serialized}")
|
||||
return "success to write data to database"
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
logger.error(f"failed to write data to database, host_id: {host_id}, retrieve_info: {data}, error: {e}")
|
||||
raise e
|
||||
finally:
|
||||
try:
|
||||
db.close()
|
||||
except Exception:
|
||||
pass
|
||||
@@ -7,14 +7,12 @@ pipeline. Only MemoryConfig is needed - clients are constructed internally.
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from app.core.logging_config import get_agent_logger
|
||||
from app.core.memory.agent.utils.get_dialogs import get_chunked_dialogs
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import (
|
||||
ExtractionOrchestrator,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import (
|
||||
memory_summary_generation,
|
||||
)
|
||||
from app.core.memory.storage_services.extraction_engine.extraction_orchestrator import ExtractionOrchestrator
|
||||
from app.core.memory.storage_services.extraction_engine.knowledge_extraction.memory_summary import memory_summary_generation
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.db import get_db_context
|
||||
@@ -23,7 +21,7 @@ from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.schemas.memory_config_schema import MemoryConfig
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -31,25 +29,22 @@ logger = get_agent_logger(__name__)
|
||||
|
||||
|
||||
async def write(
|
||||
content: str,
|
||||
user_id: str,
|
||||
apply_id: str,
|
||||
group_id: str,
|
||||
memory_config: MemoryConfig,
|
||||
messages: list,
|
||||
ref_id: str = "wyl20251027",
|
||||
) -> None:
|
||||
"""
|
||||
Execute the complete knowledge extraction pipeline.
|
||||
|
||||
Only MemoryConfig is needed - LLM and embedding clients are constructed
|
||||
internally from the config.
|
||||
|
||||
Args:
|
||||
content: Dialogue content to process
|
||||
user_id: User identifier
|
||||
apply_id: Application identifier
|
||||
group_id: Group identifier
|
||||
memory_config: MemoryConfig object containing all configuration
|
||||
messages: Structured message list [{"role": "user", "content": "..."}, ...]
|
||||
ref_id: Reference ID, defaults to "wyl20251027"
|
||||
"""
|
||||
# Extract config values
|
||||
@@ -91,7 +86,7 @@ async def write(
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
apply_id=apply_id,
|
||||
content=content,
|
||||
messages=messages,
|
||||
ref_id=ref_id,
|
||||
config_id=config_id,
|
||||
)
|
||||
|
||||
@@ -1,48 +1,15 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from typing import List, Tuple
|
||||
|
||||
from neo4j import GraphDatabase
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# ------------------- 自包含路径解析 -------------------
|
||||
# 这个代码块确保脚本可以从任何地方运行,并且仍然可以在项目结构中找到它需要的模块。
|
||||
try:
|
||||
# 假设脚本在 /path/to/project/src/analytics/
|
||||
# 上升3个级别以到达项目根目录。
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
|
||||
src_path = os.path.join(project_root, 'src')
|
||||
|
||||
# 将 'src' 和 'project_root' 都添加到路径中。
|
||||
# 'src' 目录对于像 'from utils.config_utils import ...' 这样的导入是必需的。
|
||||
# 'project_root' 目录对于像 'from variate_config import ...' 这样的导入是必需的。
|
||||
if src_path not in sys.path:
|
||||
sys.path.insert(0, src_path)
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
except NameError:
|
||||
# 为 __file__ 未定义的环境(例如某些交互式解释器)提供回退方案
|
||||
project_root = os.path.abspath(os.path.join(os.getcwd()))
|
||||
src_path = os.path.join(project_root, 'src')
|
||||
if src_path not in sys.path:
|
||||
sys.path.insert(0, src_path)
|
||||
if project_root not in sys.path:
|
||||
sys.path.insert(0, project_root)
|
||||
# ---------------------------------------------------------------------
|
||||
|
||||
# 现在路径已经配置好,我们可以使用绝对导入
|
||||
import json
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
from app.services.memory_config_service import MemoryConfigService
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
#TODO: Fix this
|
||||
# Default values (previously from definitions.py)
|
||||
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||
DEFAULT_GROUP_ID = os.getenv("SELECTED_GROUP_ID", "group_123")
|
||||
|
||||
# 定义用于LLM结构化输出的Pydantic模型
|
||||
class FilteredTags(BaseModel):
|
||||
@@ -52,34 +19,45 @@ class FilteredTags(BaseModel):
|
||||
async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
||||
"""
|
||||
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
||||
|
||||
Args:
|
||||
tags: 原始标签列表
|
||||
group_id: 用户组ID,用于获取配置
|
||||
|
||||
Returns:
|
||||
筛选后的标签列表
|
||||
|
||||
Raises:
|
||||
ValueError: 如果无法获取有效的LLM配置
|
||||
"""
|
||||
try:
|
||||
# Get config_id using get_end_user_connected_config
|
||||
with get_db_context() as db:
|
||||
try:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if not config_id:
|
||||
raise ValueError(
|
||||
f"No memory_config_id found for group_id: {group_id}. "
|
||||
"Please ensure the user has a valid memory configuration."
|
||||
)
|
||||
connected_config = get_end_user_connected_config(group_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
|
||||
if config_id:
|
||||
# Use the config_id to get the proper LLM client
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
else:
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM if no config found
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
except Exception as e:
|
||||
print(f"Failed to get user connected config, using default LLM: {e}")
|
||||
# TODO: Remove DEFAULT_LLM_ID fallback once all users have proper config
|
||||
# Fallback to default LLM
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(DEFAULT_LLM_ID)
|
||||
|
||||
# Use the config_id to get the proper LLM client
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(config_id)
|
||||
|
||||
if not memory_config.llm_model_id:
|
||||
raise ValueError(
|
||||
f"No llm_model_id found in memory config {config_id}. "
|
||||
"Please configure a valid LLM model."
|
||||
)
|
||||
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
|
||||
# 3. 构建Prompt
|
||||
tag_list_str = ", ".join(tags)
|
||||
@@ -107,33 +85,26 @@ async def filter_tags_with_llm(tags: List[str], group_id: str) -> List[str]:
|
||||
# 在LLM失败时返回原始标签,确保流程继续
|
||||
return tags
|
||||
|
||||
def get_db_connection():
|
||||
"""
|
||||
使用项目的标准配置方法建立与Neo4j数据库的连接。
|
||||
"""
|
||||
# 从全局配置获取 Neo4j 连接信息
|
||||
uri = settings.NEO4J_URI
|
||||
user = settings.NEO4J_USERNAME
|
||||
|
||||
# 密码必须为了安全从环境变量加载
|
||||
password = os.getenv("NEO4J_PASSWORD")
|
||||
|
||||
if not uri or not user:
|
||||
raise ValueError("在 config.json 中未找到 Neo4j 的 'uri' 或 'username'。")
|
||||
if not password:
|
||||
raise ValueError("NEO4J_PASSWORD 环境变量未设置。")
|
||||
|
||||
# 为此脚本使用同步驱动
|
||||
return GraphDatabase.driver(uri, auth=(user, password))
|
||||
|
||||
def get_raw_tags_from_db(group_id: str, limit: int, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
async def get_raw_tags_from_db(
|
||||
connector: Neo4jConnector,
|
||||
group_id: str,
|
||||
limit: int,
|
||||
by_user: bool = False
|
||||
) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
TODO: not accurate tag extraction
|
||||
从数据库查询原始的、未经过滤的实体标签及其频率。
|
||||
|
||||
使用项目的Neo4jConnector进行查询,遵循仓储模式。
|
||||
|
||||
Args:
|
||||
connector: Neo4j连接器实例
|
||||
group_id: 如果by_user=False,则为group_id;如果by_user=True,则为user_id
|
||||
limit: 返回的标签数量限制
|
||||
by_user: 是否按user_id查询(默认False,按group_id查询)
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, int]]: 标签名称和频率的元组列表
|
||||
"""
|
||||
names_to_exclude = ['AI', 'Caroline', 'Melanie', 'Jon', 'Gina', '用户', 'AI助手', 'John', 'Maria']
|
||||
|
||||
@@ -154,83 +125,55 @@ def get_raw_tags_from_db(group_id: str, limit: int, by_user: bool = False) -> Li
|
||||
"LIMIT $limit"
|
||||
)
|
||||
|
||||
driver = None
|
||||
try:
|
||||
driver = get_db_connection()
|
||||
with driver.session() as session:
|
||||
result = session.run(query, id=group_id, limit=limit, names_to_exclude=names_to_exclude)
|
||||
return [(record["name"], record["frequency"]) for record in result]
|
||||
finally:
|
||||
if driver:
|
||||
driver.close()
|
||||
# 使用项目的Neo4jConnector执行查询
|
||||
results = await connector.execute_query(
|
||||
query,
|
||||
id=group_id,
|
||||
limit=limit,
|
||||
names_to_exclude=names_to_exclude
|
||||
)
|
||||
|
||||
return [(record["name"], record["frequency"]) for record in results]
|
||||
|
||||
async def get_hot_memory_tags(group_id: str | None = None, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
async def get_hot_memory_tags(group_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。
|
||||
查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。
|
||||
|
||||
Args:
|
||||
group_id: 如果by_user=False,则为group_id;如果by_user=True,则为user_id
|
||||
group_id: 必需参数。如果by_user=False,则为group_id;如果by_user=True,则为user_id
|
||||
limit: 返回的标签数量限制
|
||||
by_user: 是否按user_id查询(默认False,按group_id查询)
|
||||
|
||||
Raises:
|
||||
ValueError: 如果group_id未提供或为空
|
||||
"""
|
||||
# 默认从环境变量读取
|
||||
group_id = group_id or DEFAULT_GROUP_ID
|
||||
# 1. 从数据库获取原始排名靠前的标签
|
||||
raw_tags_with_freq = get_raw_tags_from_db(group_id, limit, by_user=by_user)
|
||||
if not raw_tags_with_freq:
|
||||
return []
|
||||
|
||||
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
||||
|
||||
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
|
||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
|
||||
|
||||
# 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序)
|
||||
final_tags = []
|
||||
for tag, freq in raw_tags_with_freq:
|
||||
if tag in meaningful_tag_names:
|
||||
final_tags.append((tag, freq))
|
||||
|
||||
return final_tags
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("开始获取热门记忆标签...")
|
||||
# 验证group_id必须提供且不为空
|
||||
if not group_id or not group_id.strip():
|
||||
raise ValueError(
|
||||
"group_id is required. Please provide a valid group_id or user_id."
|
||||
)
|
||||
|
||||
# 使用项目的Neo4jConnector
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
# 直接使用环境变量中的 group_id
|
||||
group_id_to_query = DEFAULT_GROUP_ID
|
||||
# 使用 asyncio.run 来执行异步主函数
|
||||
top_tags = asyncio.run(get_hot_memory_tags(group_id=group_id_to_query))
|
||||
# 1. 从数据库获取原始排名靠前的标签
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, group_id, limit, by_user=by_user)
|
||||
if not raw_tags_with_freq:
|
||||
return []
|
||||
|
||||
if top_tags:
|
||||
print(f"热门记忆标签 (Group ID: {group_id_to_query}, 经LLM筛选):")
|
||||
for tag, frequency in top_tags:
|
||||
print(f"- {tag} (数量: {frequency})")
|
||||
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
||||
|
||||
# --- 将结果写入统一的 Signboard.json 到 logs/memory-output ---
|
||||
from app.core.config import settings
|
||||
settings.ensure_memory_output_dir()
|
||||
signboard_path = settings.get_memory_output_path("Signboard.json")
|
||||
payload = {
|
||||
"group_id": group_id_to_query,
|
||||
"hot_tags": [{"name": t, "frequency": f} for t, f in top_tags]
|
||||
}
|
||||
try:
|
||||
existing = {}
|
||||
if os.path.exists(signboard_path):
|
||||
with open(signboard_path, "r", encoding="utf-8") as rf:
|
||||
existing = json.load(rf)
|
||||
existing["hot_memory_tags"] = payload
|
||||
with open(signboard_path, "w", encoding="utf-8") as wf:
|
||||
json.dump(existing, wf, ensure_ascii=False, indent=2)
|
||||
print(f"已写入 {signboard_path} -> hot_memory_tags")
|
||||
except Exception as e:
|
||||
print(f"写入 Signboard.json 失败: {e}")
|
||||
else:
|
||||
print(f"在 Group ID '{group_id_to_query}' 中没有找到符合条件的实体标签。")
|
||||
except Exception as e:
|
||||
print(f"执行过程中发生严重错误: {e}")
|
||||
print("请检查:")
|
||||
print("1. Neo4j数据库服务是否正在运行。")
|
||||
print("2. 'config.json'中的配置是否正确。")
|
||||
print("3. 相关的环境变量 (如 NEO4J_PASSWORD, DASHSCOPE_API_KEY) 是否已正确设置。")
|
||||
# 2. 初始化LLM客户端并使用LLM筛选出有意义的标签
|
||||
meaningful_tag_names = await filter_tags_with_llm(raw_tag_names, group_id)
|
||||
|
||||
# 3. 根据LLM的筛选结果,构建最终的标签列表(保留原始频率和顺序)
|
||||
final_tags = []
|
||||
for tag, freq in raw_tags_with_freq:
|
||||
if tag in meaningful_tag_names:
|
||||
final_tags.append((tag, freq))
|
||||
|
||||
return final_tags
|
||||
finally:
|
||||
# 确保关闭连接
|
||||
await connector.close()
|
||||
|
||||
@@ -4,6 +4,7 @@ import os
|
||||
import asyncio
|
||||
import json
|
||||
import numpy as np
|
||||
import logging
|
||||
|
||||
# Fix tokenizer parallelism warning
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
@@ -23,28 +24,29 @@ from app.core.memory.models.message_models import DialogData, Chunk
|
||||
try:
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
except Exception:
|
||||
# 在测试或无可用依赖(如 langfuse)环境下,允许惰性导入
|
||||
OpenAIClient = Any
|
||||
|
||||
# Initialize logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LLMChunker:
|
||||
"""基于LLM的智能分块策略"""
|
||||
"""LLM-based intelligent chunking strategy"""
|
||||
def __init__(self, llm_client: OpenAIClient, chunk_size: int = 1000):
|
||||
self.llm_client = llm_client
|
||||
self.chunk_size = chunk_size
|
||||
|
||||
async def __call__(self, text: str) -> List[Any]:
|
||||
# 使用LLM分析文本结构并进行智能分块
|
||||
prompt = f"""
|
||||
请将以下文本分割成语义连贯的段落。每个段落应该围绕一个主题,长度大约在{self.chunk_size}字符左右。
|
||||
请以JSON格式返回结果,包含chunks数组,每个chunk有text字段。
|
||||
Split the following text into semantically coherent paragraphs. Each paragraph should focus on one topic, approximately {self.chunk_size} characters long.
|
||||
Return results in JSON format with a chunks array, each chunk having a text field.
|
||||
|
||||
文本内容:
|
||||
Text content:
|
||||
{text[:5000]}
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一个专业的文本分析助手,擅长将长文本分割成语义连贯的段落。"},
|
||||
{"role": "system", "content": "You are a professional text analysis assistant, skilled at splitting long texts into semantically coherent paragraphs."},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
@@ -171,8 +173,6 @@ class ChunkerClient:
|
||||
base_chunk_size=self.chunk_size,
|
||||
)
|
||||
elif chunker_config.chunker_strategy == "SentenceChunker":
|
||||
# 某些 chonkie 版本的 SentenceChunker 不支持 tokenizer_or_token_counter 参数
|
||||
# 为了兼容不同版本,这里仅传递广泛支持的参数
|
||||
self.chunker = SentenceChunker(
|
||||
chunk_size=self.chunk_size,
|
||||
chunk_overlap=self.chunk_overlap,
|
||||
@@ -186,100 +186,93 @@ class ChunkerClient:
|
||||
|
||||
async def generate_chunks(self, dialogue: DialogData):
|
||||
"""
|
||||
生成分块,支持异步操作
|
||||
Generate chunks following 1 Message = 1 Chunk strategy.
|
||||
|
||||
Each message creates one chunk, directly inheriting role information.
|
||||
If a message is too long, it will be split into multiple sub-chunks,
|
||||
each maintaining the same speaker.
|
||||
|
||||
Raises:
|
||||
ValueError: If dialogue has no messages or chunking fails
|
||||
"""
|
||||
try:
|
||||
# 预处理文本:确保对话标记格式统一
|
||||
content = dialogue.content
|
||||
content = content.replace('AI:', 'AI:').replace('用户:', '用户:') # 统一冒号
|
||||
content = re.sub(r'(\n\s*)+\n', '\n\n', content) # 合并多个空行
|
||||
|
||||
if hasattr(self.chunker, '__call__') and not asyncio.iscoroutinefunction(self.chunker.__call__):
|
||||
# 同步分块器
|
||||
chunks = self.chunker(content)
|
||||
# Validate dialogue has messages
|
||||
if not dialogue.context or not dialogue.context.msgs:
|
||||
raise ValueError(
|
||||
f"Dialogue {dialogue.ref_id} has no messages. "
|
||||
f"Cannot generate chunks from empty dialogue."
|
||||
)
|
||||
|
||||
dialogue.chunks = []
|
||||
|
||||
# 按消息分块:每个消息创建一个或多个 chunk,直接继承角色
|
||||
for msg_idx, msg in enumerate(dialogue.context.msgs):
|
||||
# Validate message has required attributes
|
||||
if not hasattr(msg, 'role') or not hasattr(msg, 'msg'):
|
||||
raise ValueError(
|
||||
f"Message {msg_idx} in dialogue {dialogue.ref_id} "
|
||||
f"missing 'role' or 'msg' attribute"
|
||||
)
|
||||
|
||||
msg_content = msg.msg.strip()
|
||||
|
||||
# Skip empty messages
|
||||
if not msg_content:
|
||||
continue
|
||||
|
||||
# 如果消息太长,可以进一步分块
|
||||
if len(msg_content) > self.chunk_size:
|
||||
# 对单个消息的内容进行分块
|
||||
try:
|
||||
sub_chunks = self.chunker(msg_content)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Failed to chunk long message {msg_idx} in dialogue {dialogue.ref_id}: {e}"
|
||||
)
|
||||
|
||||
for idx, sub_chunk in enumerate(sub_chunks):
|
||||
sub_chunk_text = sub_chunk.text if hasattr(sub_chunk, 'text') else str(sub_chunk)
|
||||
sub_chunk_text = sub_chunk_text.strip()
|
||||
|
||||
if len(sub_chunk_text) < (self.min_characters_per_chunk or 50):
|
||||
continue
|
||||
|
||||
chunk = Chunk(
|
||||
content=f"{msg.role}: {sub_chunk_text}",
|
||||
speaker=msg.role, # 直接继承角色
|
||||
metadata={
|
||||
"message_index": msg_idx,
|
||||
"message_role": msg.role,
|
||||
"sub_chunk_index": idx,
|
||||
"total_sub_chunks": len(sub_chunks),
|
||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||
},
|
||||
)
|
||||
dialogue.chunks.append(chunk)
|
||||
else:
|
||||
# 异步分块器(如LLMChunker)
|
||||
chunks = await self.chunker(content)
|
||||
|
||||
# 过滤空块和过小的块
|
||||
valid_chunks = []
|
||||
for c in chunks:
|
||||
chunk_text = getattr(c, 'text', str(c)) if not isinstance(c, str) else c
|
||||
if isinstance(chunk_text, str) and len(chunk_text.strip()) >= (self.min_characters_per_chunk or 50):
|
||||
valid_chunks.append(c)
|
||||
|
||||
dialogue.chunks = [
|
||||
Chunk(
|
||||
content=c.text if hasattr(c, 'text') else str(c),
|
||||
# 消息不长,直接作为一个 chunk
|
||||
chunk = Chunk(
|
||||
content=f"{msg.role}: {msg_content}",
|
||||
speaker=msg.role, # 直接继承角色
|
||||
metadata={
|
||||
"start_index": getattr(c, "start_index", None),
|
||||
"end_index": getattr(c, "end_index", None),
|
||||
"message_index": msg_idx,
|
||||
"message_role": msg.role,
|
||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||
},
|
||||
)
|
||||
for c in valid_chunks
|
||||
]
|
||||
return dialogue
|
||||
|
||||
except Exception as e:
|
||||
print(f"分块失败: {e}")
|
||||
|
||||
# 改进的后备方案:尝试按对话回合分割
|
||||
try:
|
||||
# 简单的按对话分割
|
||||
dialogue_pattern = r'(AI:|用户:)(.*?)(?=AI:|用户:|$)'
|
||||
matches = re.findall(dialogue_pattern, dialogue.content, re.DOTALL)
|
||||
|
||||
class SimpleChunk:
|
||||
def __init__(self, text, start_index, end_index):
|
||||
self.text = text
|
||||
self.start_index = start_index
|
||||
self.end_index = end_index
|
||||
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
current_start = 0
|
||||
|
||||
for match in matches:
|
||||
speaker, ct = match[0], match[1].strip()
|
||||
turn_text = f"{speaker} {ct}"
|
||||
|
||||
if len(current_chunk) + len(turn_text) > (self.chunk_size or 500):
|
||||
if current_chunk:
|
||||
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
|
||||
current_chunk = turn_text
|
||||
current_start = dialogue.content.find(turn_text, current_start)
|
||||
else:
|
||||
current_chunk += ("\n" + turn_text) if current_chunk else turn_text
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(SimpleChunk(current_chunk, current_start, current_start + len(current_chunk)))
|
||||
|
||||
dialogue.chunks = [
|
||||
Chunk(
|
||||
content=c.text,
|
||||
metadata={
|
||||
"start_index": c.start_index,
|
||||
"end_index": c.end_index,
|
||||
"chunker_strategy": "DialogueTurnFallback",
|
||||
},
|
||||
)
|
||||
for c in chunks
|
||||
]
|
||||
|
||||
except Exception:
|
||||
# 最后的手段:单一大块
|
||||
dialogue.chunks = [Chunk(
|
||||
content=dialogue.content,
|
||||
metadata={"chunker_strategy": "SingleChunkFallback"},
|
||||
)]
|
||||
|
||||
return dialogue
|
||||
dialogue.chunks.append(chunk)
|
||||
|
||||
# Validate we generated at least one chunk
|
||||
if not dialogue.chunks:
|
||||
raise ValueError(
|
||||
f"No valid chunks generated for dialogue {dialogue.ref_id}. "
|
||||
f"All messages were either empty or too short. "
|
||||
f"Messages count: {len(dialogue.context.msgs)}"
|
||||
)
|
||||
|
||||
return dialogue
|
||||
|
||||
def evaluate_chunking(self, dialogue: DialogData) -> dict:
|
||||
"""
|
||||
评估分块质量
|
||||
"""
|
||||
"""Evaluate chunking quality."""
|
||||
if not getattr(dialogue, 'chunks', None):
|
||||
return {}
|
||||
|
||||
@@ -304,11 +297,8 @@ class ChunkerClient:
|
||||
return metrics
|
||||
|
||||
def save_chunking_results(self, dialogue: DialogData, output_path: str):
|
||||
"""
|
||||
保存分块结果到文件,文件名包含策略名称
|
||||
"""
|
||||
"""Save chunking results to file with strategy name in filename."""
|
||||
strategy_name = self.chunker_config.chunker_strategy
|
||||
# 在文件名中添加策略名称
|
||||
base_name, ext = os.path.splitext(output_path)
|
||||
strategy_output_path = f"{base_name}_{strategy_name}{ext}"
|
||||
|
||||
|
||||
@@ -92,8 +92,6 @@ class OpenAIClient(LLMClient):
|
||||
config["callbacks"] = [self.langfuse_handler]
|
||||
|
||||
response = await chain.ainvoke({"messages": messages}, config=config)
|
||||
|
||||
logger.debug(f"LLM 响应成功: {len(str(response))} 字符")
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
@@ -149,13 +147,10 @@ class OpenAIClient(LLMClient):
|
||||
config=config
|
||||
)
|
||||
|
||||
logger.debug(f"使用 PydanticOutputParser 解析成功")
|
||||
return parsed
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"PydanticOutputParser 解析失败,尝试其他方法: {e}"
|
||||
)
|
||||
logger.debug(f"PydanticOutputParser 解析失败,尝试备用方法: {e}")
|
||||
|
||||
# 方法 2: 使用 LangChain 的 with_structured_output
|
||||
template = """{question}"""
|
||||
@@ -173,13 +168,17 @@ class OpenAIClient(LLMClient):
|
||||
|
||||
# 验证并返回结果
|
||||
try:
|
||||
return response_model.model_validate(parsed)
|
||||
result = response_model.model_validate(parsed)
|
||||
return result
|
||||
except Exception:
|
||||
# 如果已经是 Pydantic 实例,直接返回
|
||||
if hasattr(parsed, "model_dump"):
|
||||
return parsed
|
||||
# 尝试从 JSON 解析
|
||||
return response_model.model_validate_json(json.dumps(parsed))
|
||||
result = response_model.model_validate_json(json.dumps(parsed))
|
||||
return result
|
||||
else:
|
||||
logger.warning("with_structured_output 方法不可用")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"结构化输出失败: {e}")
|
||||
|
||||
@@ -224,6 +224,7 @@ class StatementNode(Node):
|
||||
chunk_id: ID of the parent chunk this statement belongs to
|
||||
stmt_type: Type of the statement (from ontology)
|
||||
statement: The actual statement text content
|
||||
speaker: Optional speaker identifier ('用户' for user messages, 'AI' for AI responses)
|
||||
emotion_intensity: Optional emotion intensity (0.0-1.0) - displayed on node
|
||||
emotion_target: Optional emotion target (person or object name)
|
||||
emotion_subject: Optional emotion subject (self/other/object)
|
||||
@@ -249,6 +250,12 @@ class StatementNode(Node):
|
||||
stmt_type: str = Field(..., description="Type of the statement")
|
||||
statement: str = Field(..., description="The statement text content")
|
||||
|
||||
# Speaker identification
|
||||
speaker: Optional[str] = Field(
|
||||
None,
|
||||
description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses"
|
||||
)
|
||||
|
||||
# Emotion fields (ordered as requested, emotion_intensity first for display)
|
||||
emotion_intensity: Optional[float] = Field(
|
||||
None,
|
||||
|
||||
@@ -25,10 +25,10 @@ class ConversationMessage(BaseModel):
|
||||
"""Represents a single message in a conversation.
|
||||
|
||||
Attributes:
|
||||
role: Role of the speaker (e.g., '用户' for user, 'AI' for assistant)
|
||||
role: Role of the speaker (e.g., 'user' for user, 'assistant' for AI assistant)
|
||||
msg: Text content of the message
|
||||
"""
|
||||
role: str = Field(..., description="The role of the speaker (e.g., '用户', 'AI').")
|
||||
role: str = Field(..., description="The role of the speaker (e.g., 'user', 'assistant').")
|
||||
msg: str = Field(..., description="The text content of the message.")
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ class Statement(BaseModel):
|
||||
chunk_id: ID of the parent chunk this statement belongs to
|
||||
group_id: Optional group ID for multi-tenancy
|
||||
statement: The actual statement text content
|
||||
speaker: Optional speaker identifier ('用户' for user, 'AI' for AI responses)
|
||||
statement_embedding: Optional embedding vector for the statement
|
||||
stmt_type: Type of the statement (from ontology)
|
||||
temporal_info: Temporal information extracted from the statement
|
||||
@@ -74,6 +75,7 @@ class Statement(BaseModel):
|
||||
chunk_id: str = Field(..., description="ID of the parent chunk this statement belongs to.")
|
||||
group_id: Optional[str] = Field(None, description="ID of the group this statement belongs to.")
|
||||
statement: str = Field(..., description="The text content of the statement.")
|
||||
speaker: Optional[str] = Field(None, description="Speaker identifier: 'user' for user messages, 'assistant' for AI responses")
|
||||
statement_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the statement.")
|
||||
stmt_type: StatementType = Field(..., description="The type of the statement.")
|
||||
temporal_info: TemporalInfo = Field(..., description="The temporal information of the statement.")
|
||||
@@ -118,36 +120,36 @@ class Chunk(BaseModel):
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for the chunk
|
||||
text: List of messages in the chunk
|
||||
content: The content of the chunk as a formatted string
|
||||
speaker: The speaker/role for this chunk (user/assistant)
|
||||
statements: List of statements extracted from this chunk
|
||||
chunk_embedding: Optional embedding vector for the chunk
|
||||
metadata: Additional metadata as key-value pairs
|
||||
"""
|
||||
id: str = Field(default_factory=lambda: uuid4().hex, description="A unique identifier for the chunk.")
|
||||
text: List[ConversationMessage] = Field(default_factory=list, description="A list of messages in the chunk.")
|
||||
content: str = Field(..., description="The content of the chunk as a string.")
|
||||
speaker: Optional[str] = Field(None, description="The speaker/role for this chunk (user/assistant).")
|
||||
statements: List[Statement] = Field(default_factory=list, description="A list of statements in the chunk.")
|
||||
chunk_embedding: Optional[List[float]] = Field(None, description="The embedding vector of the chunk.")
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional metadata for the chunk.")
|
||||
|
||||
@classmethod
|
||||
def from_messages(cls, messages: List[ConversationMessage], metadata: Optional[Dict[str, Any]] = None):
|
||||
"""Create a chunk from a list of messages.
|
||||
def from_single_message(cls, message: ConversationMessage, metadata: Optional[Dict[str, Any]] = None):
|
||||
"""Create a chunk from a single message (1 Message = 1 Chunk).
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages
|
||||
message: Single conversation message
|
||||
metadata: Optional metadata dictionary
|
||||
|
||||
Returns:
|
||||
Chunk instance with formatted content
|
||||
Chunk instance with speaker directly from message.role
|
||||
"""
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
# Generate content from messages
|
||||
content = "\n".join([f"{msg.role}: {msg.msg}" for msg in messages])
|
||||
return cls(text=messages, content=content, metadata=metadata)
|
||||
|
||||
return cls(
|
||||
content=f"{message.role}: {message.msg}",
|
||||
speaker=message.role,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
|
||||
class DialogData(BaseModel):
|
||||
"""Represents the complete data structure for a dialog record.
|
||||
|
||||
@@ -131,179 +131,60 @@ def normalize_scores(results: List[Dict[str, Any]], score_field: str = "score")
|
||||
return results
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# 以下函数已被 rerank_with_activation 替代,暂时保留以供参考
|
||||
# ============================================================================
|
||||
|
||||
# def rerank_hybrid_results(
|
||||
# keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
# embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
# alpha: float = 0.6,
|
||||
# limit: int = 10
|
||||
# ) -> Dict[str, List[Dict[str, Any]]]:
|
||||
# """
|
||||
# Rerank hybrid search results by combining BM25 and embedding scores.
|
||||
#
|
||||
# 已废弃:此函数功能已被 rerank_with_activation 完全替代
|
||||
#
|
||||
# Args:
|
||||
# keyword_results: Results from keyword/BM25 search
|
||||
# embedding_results: Results from embedding search
|
||||
# alpha: Weight for BM25 scores (1-alpha for embedding scores)
|
||||
# limit: Maximum number of results to return per category
|
||||
#
|
||||
# Returns:
|
||||
# Reranked results with combined scores
|
||||
# """
|
||||
# reranked = {}
|
||||
#
|
||||
# for category in ["statements", "chunks", "entities","summaries"]:
|
||||
# keyword_items = keyword_results.get(category, [])
|
||||
# embedding_items = embedding_results.get(category, [])
|
||||
#
|
||||
# # Normalize scores within each search type
|
||||
# keyword_items = normalize_scores(keyword_items, "score")
|
||||
# embedding_items = normalize_scores(embedding_items, "score")
|
||||
#
|
||||
# # Create a combined pool of unique items
|
||||
# combined_items = {}
|
||||
#
|
||||
# # Add keyword results with BM25 scores
|
||||
# for item in keyword_items:
|
||||
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
# if item_id:
|
||||
# combined_items[item_id] = item.copy()
|
||||
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
# combined_items[item_id]["embedding_score"] = 0 # Default
|
||||
#
|
||||
# # Add or update with embedding results
|
||||
# for item in embedding_items:
|
||||
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
# if item_id:
|
||||
# if item_id in combined_items:
|
||||
# # Update existing item with embedding score
|
||||
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
# else:
|
||||
# # New item from embedding search only
|
||||
# combined_items[item_id] = item.copy()
|
||||
# combined_items[item_id]["bm25_score"] = 0 # Default
|
||||
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
#
|
||||
# # Calculate combined scores and rank
|
||||
# for item_id, item in combined_items.items():
|
||||
# bm25_score = item.get("bm25_score", 0)
|
||||
# embedding_score = item.get("embedding_score", 0)
|
||||
#
|
||||
# # Combined score: weighted average of normalized scores
|
||||
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
|
||||
# item["combined_score"] = combined_score
|
||||
#
|
||||
# # Keep original score for reference
|
||||
# if "score" not in item and bm25_score > 0:
|
||||
# item["score"] = bm25_score
|
||||
# elif "score" not in item and embedding_score > 0:
|
||||
# item["score"] = embedding_score
|
||||
#
|
||||
# # Sort by combined score and limit results
|
||||
# sorted_items = sorted(
|
||||
# combined_items.values(),
|
||||
# key=lambda x: x.get("combined_score", 0),
|
||||
# reverse=True
|
||||
# )[:limit]
|
||||
#
|
||||
# reranked[category] = sorted_items
|
||||
#
|
||||
# return reranked
|
||||
|
||||
# def rerank_with_forgetting_curve(
|
||||
# keyword_results: Dict[str, List[Dict[str, Any]]],
|
||||
# embedding_results: Dict[str, List[Dict[str, Any]]],
|
||||
# alpha: float = 0.6,
|
||||
# limit: int = 10,
|
||||
# forgetting_config: ForgettingEngineConfig | None = None,
|
||||
# now: datetime | None = None,
|
||||
# ) -> Dict[str, List[Dict[str, Any]]]:
|
||||
# """
|
||||
# Rerank hybrid results with a forgetting curve applied to combined scores.
|
||||
#
|
||||
# 已废弃:此函数功能已被 rerank_with_activation 完全替代
|
||||
# rerank_with_activation 提供了更完整的遗忘曲线支持(结合激活度)
|
||||
#
|
||||
# The forgetting curve reduces scores for older memories or weaker connections.
|
||||
#
|
||||
# Args:
|
||||
# keyword_results: Results from keyword/BM25 search
|
||||
# embedding_results: Results from embedding search
|
||||
# alpha: Weight for BM25 scores (1-alpha for embedding scores)
|
||||
# limit: Maximum number of results to return per category
|
||||
# forgetting_config: Configuration for the forgetting engine
|
||||
# now: Optional current time override for testing
|
||||
#
|
||||
# Returns:
|
||||
# Reranked results with combined and final scores (after forgetting)
|
||||
# """
|
||||
# engine = ForgettingEngine(forgetting_config or ForgettingEngineConfig())
|
||||
# now_dt = now or datetime.now()
|
||||
#
|
||||
# reranked: Dict[str, List[Dict[str, Any]]] = {}
|
||||
#
|
||||
# for category in ["statements", "chunks", "entities","summaries"]:
|
||||
# keyword_items = keyword_results.get(category, [])
|
||||
# embedding_items = embedding_results.get(category, [])
|
||||
#
|
||||
# # Normalize scores within each search type
|
||||
# keyword_items = normalize_scores(keyword_items, "score")
|
||||
# embedding_items = normalize_scores(embedding_items, "score")
|
||||
#
|
||||
# combined_items: Dict[str, Dict[str, Any]] = {}
|
||||
#
|
||||
# # Combine two result sets by ID
|
||||
# for src_items, is_embedding in (
|
||||
# (keyword_items, False), (embedding_items, True)
|
||||
# ):
|
||||
# for item in src_items:
|
||||
# item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
# if not item_id:
|
||||
# continue
|
||||
# existing = combined_items.get(item_id)
|
||||
# if not existing:
|
||||
# combined_items[item_id] = item.copy()
|
||||
# combined_items[item_id]["bm25_score"] = 0
|
||||
# combined_items[item_id]["embedding_score"] = 0
|
||||
# # Update normalized score from the right source
|
||||
# if is_embedding:
|
||||
# combined_items[item_id]["embedding_score"] = item.get("normalized_score", 0)
|
||||
# else:
|
||||
# combined_items[item_id]["bm25_score"] = item.get("normalized_score", 0)
|
||||
#
|
||||
# # Calculate scores and apply forgetting weights
|
||||
# for item_id, item in combined_items.items():
|
||||
# bm25_score = float(item.get("bm25_score", 0) or 0)
|
||||
# embedding_score = float(item.get("embedding_score", 0) or 0)
|
||||
# combined_score = alpha * bm25_score + (1 - alpha) * embedding_score
|
||||
#
|
||||
# # Estimate time elapsed in days
|
||||
# dt = _parse_datetime(item.get("created_at"))
|
||||
# if dt is None:
|
||||
# time_elapsed_days = 0.0
|
||||
# else:
|
||||
# time_elapsed_days = max(0.0, (now_dt - dt).total_seconds() / 86400.0)
|
||||
#
|
||||
# # Memory strength (currently set to default value)
|
||||
# memory_strength = 1.0
|
||||
# forgetting_weight = engine.calculate_weight(
|
||||
# time_elapsed=time_elapsed_days, memory_strength=memory_strength
|
||||
# )
|
||||
# final_score = combined_score * forgetting_weight
|
||||
# item["combined_score"] = final_score
|
||||
#
|
||||
# sorted_items = sorted(
|
||||
# combined_items.values(), key=lambda x: x.get("combined_score", 0), reverse=True
|
||||
# )[:limit]
|
||||
#
|
||||
# reranked[category] = sorted_items
|
||||
#
|
||||
# return reranked
|
||||
def _deduplicate_results(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Remove duplicate items from search results based on content.
|
||||
|
||||
Deduplication strategy:
|
||||
1. First try to deduplicate by ID (id, uuid, or chunk_id)
|
||||
2. Then deduplicate by content hash (text, content, statement, or name fields)
|
||||
|
||||
Args:
|
||||
items: List of search result items
|
||||
|
||||
Returns:
|
||||
Deduplicated list of items, preserving the order of first occurrence
|
||||
"""
|
||||
seen_ids = set()
|
||||
seen_content = set()
|
||||
deduplicated = []
|
||||
|
||||
for item in items:
|
||||
# Try multiple ID fields to identify unique items
|
||||
item_id = item.get("id") or item.get("uuid") or item.get("chunk_id")
|
||||
|
||||
# Extract content from various possible fields
|
||||
content = (
|
||||
item.get("text") or
|
||||
item.get("content") or
|
||||
item.get("statement") or
|
||||
item.get("name") or
|
||||
""
|
||||
)
|
||||
|
||||
# Normalize content for comparison (strip whitespace and lowercase)
|
||||
normalized_content = str(content).strip().lower() if content else ""
|
||||
|
||||
# Check if we've seen this ID or content before
|
||||
is_duplicate = False
|
||||
|
||||
if item_id and item_id in seen_ids:
|
||||
is_duplicate = True
|
||||
elif normalized_content and normalized_content in seen_content:
|
||||
# Only check content duplication if content is not empty
|
||||
is_duplicate = True
|
||||
|
||||
if not is_duplicate:
|
||||
# Mark as seen
|
||||
if item_id:
|
||||
seen_ids.add(item_id)
|
||||
if normalized_content: # Only track non-empty content
|
||||
seen_content.add(normalized_content)
|
||||
|
||||
deduplicated.append(item)
|
||||
|
||||
return deduplicated
|
||||
|
||||
|
||||
def rerank_with_activation(
|
||||
@@ -364,7 +245,7 @@ def rerank_with_activation(
|
||||
keyword_items = normalize_scores(keyword_items, "score")
|
||||
embedding_items = normalize_scores(embedding_items, "score")
|
||||
|
||||
# 步骤 2: 按 ID 合并结果
|
||||
# 步骤 2: 按 ID 合并结果(去重)
|
||||
combined_items: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
# 添加关键词结果
|
||||
@@ -507,6 +388,9 @@ def rerank_with_activation(
|
||||
# 无激活值:使用内容相关性分数
|
||||
item["final_score"] = item.get("base_score", 0)
|
||||
|
||||
# 最终去重确保没有重复项
|
||||
sorted_items = _deduplicate_results(sorted_items)
|
||||
|
||||
reranked[category] = sorted_items
|
||||
|
||||
return reranked
|
||||
@@ -1144,96 +1028,3 @@ async def search_chunk_by_chunk_id(
|
||||
)
|
||||
return {"chunks": chunks}
|
||||
|
||||
|
||||
# def main():
|
||||
# """Main entry point for the hybrid graph search CLI.
|
||||
|
||||
# Parses command line arguments and executes search with specified parameters.
|
||||
# Supports keyword, embedding, and hybrid search modes.
|
||||
# """
|
||||
# parser = argparse.ArgumentParser(description="Hybrid graph search with keyword and embedding options")
|
||||
# parser.add_argument(
|
||||
# "--query", "-q", required=True, help="Free-text query to search"
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--search-type",
|
||||
# "-t",
|
||||
# choices=["keyword", "embedding", "hybrid"],
|
||||
# default="hybrid",
|
||||
# help="Search type: keyword (text matching), embedding (semantic), or hybrid (both) (default: hybrid)"
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--config-id",
|
||||
# "-c",
|
||||
# type=int,
|
||||
# required=True,
|
||||
# help="Database configuration ID (required)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--group-id",
|
||||
# "-g",
|
||||
# default=None,
|
||||
# help="Optional group_id to filter results (default: None)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--limit",
|
||||
# "-k",
|
||||
# type=int,
|
||||
# default=5,
|
||||
# help="Max number of results per type (default: 5)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--include",
|
||||
# "-i",
|
||||
# nargs="+",
|
||||
# default=["statements", "chunks", "entities", "summaries"],
|
||||
# choices=["statements", "chunks", "entities", "summaries"],
|
||||
# help="Which targets to search for embedding search (default: statements chunks entities summaries)"
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--output",
|
||||
# "-o",
|
||||
# default="search_results.json",
|
||||
# help="Path to save the search results JSON (default: search_results.json)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--rerank-alpha",
|
||||
# "-a",
|
||||
# type=float,
|
||||
# default=0.6,
|
||||
# help="Weight for BM25 scores in reranking (0.0-1.0, higher values favor keyword search) (default: 0.6)",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--forgetting-rerank",
|
||||
# action="store_true",
|
||||
# help="Apply forgetting curve during reranking for hybrid search.",
|
||||
# )
|
||||
# parser.add_argument(
|
||||
# "--llm-rerank",
|
||||
# action="store_true",
|
||||
# help="Apply LLM-based reranking for hybrid search.",
|
||||
# )
|
||||
# args = parser.parse_args()
|
||||
|
||||
# # Load memory config from database
|
||||
# from app.services.memory_config_service import MemoryConfigService
|
||||
# memory_config = MemoryConfigService.load_memory_config(args.config_id)
|
||||
|
||||
# asyncio.run(
|
||||
# run_hybrid_search(
|
||||
# query_text=args.query,
|
||||
# search_type=args.search_type,
|
||||
# group_id=args.group_id,
|
||||
# limit=args.limit,
|
||||
# include=args.include,
|
||||
# output_path=args.output,
|
||||
# memory_config=memory_config,
|
||||
# rerank_alpha=args.rerank_alpha,
|
||||
# use_forgetting_rerank=args.forgetting_rerank,
|
||||
# use_llm_rerank=args.llm_rerank,
|
||||
# )
|
||||
# )
|
||||
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# main()
|
||||
|
||||
@@ -550,7 +550,7 @@ class ExtractionOrchestrator:
|
||||
self, dialog_data_list: List[DialogData]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
从对话中提取情绪信息(优化版:全局陈述句级并行)
|
||||
从对话中提取情绪信息(仅针对用户消息,全局陈述句级并行)
|
||||
|
||||
Args:
|
||||
dialog_data_list: 对话数据列表
|
||||
@@ -558,7 +558,7 @@ class ExtractionOrchestrator:
|
||||
Returns:
|
||||
情绪信息映射列表,每个对话对应一个字典
|
||||
"""
|
||||
logger.info("开始情绪信息提取(全局陈述句级并行)")
|
||||
logger.info("开始情绪信息提取(仅处理用户消息)")
|
||||
|
||||
# 收集所有陈述句及其配置
|
||||
all_statements = []
|
||||
@@ -597,15 +597,22 @@ class ExtractionOrchestrator:
|
||||
if not data_config or not data_config.emotion_enabled:
|
||||
logger.info("情绪提取未启用,跳过")
|
||||
return [{} for _ in dialog_data_list]
|
||||
|
||||
# 收集所有陈述句(只收集 speaker 为 "user" 的)
|
||||
total_statements = 0
|
||||
filtered_statements = 0
|
||||
|
||||
# 收集所有陈述句
|
||||
for d_idx, dialog in enumerate(dialog_data_list):
|
||||
for chunk in dialog.chunks:
|
||||
for statement in chunk.statements:
|
||||
all_statements.append((statement, data_config))
|
||||
statement_metadata.append((d_idx, statement.id))
|
||||
total_statements += 1
|
||||
# 只处理用户的陈述句 (role 为 "user")
|
||||
if hasattr(statement, 'speaker') and statement.speaker == "user":
|
||||
all_statements.append((statement, data_config))
|
||||
statement_metadata.append((d_idx, statement.id))
|
||||
filtered_statements += 1
|
||||
|
||||
logger.info(f"收集到 {len(all_statements)} 个陈述句,开始全局并行提取情绪")
|
||||
logger.info(f"总陈述句: {total_statements}, 用户陈述句: {filtered_statements}, 开始全局并行提取情绪")
|
||||
|
||||
# 初始化情绪提取服务
|
||||
from app.services.emotion_extraction_service import EmotionExtractionService
|
||||
@@ -1033,6 +1040,7 @@ class ExtractionOrchestrator:
|
||||
apply_id=dialog_data.apply_id,
|
||||
run_id=dialog_data.run_id, # 使用 dialog_data 的 run_id
|
||||
statement=statement.statement,
|
||||
speaker=getattr(statement, 'speaker', None), # 添加 speaker 字段
|
||||
statement_embedding=statement.statement_embedding,
|
||||
valid_at=statement.temporal_validity.valid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
|
||||
invalid_at=statement.temporal_validity.invalid_at if hasattr(statement, 'temporal_validity') and statement.temporal_validity else None,
|
||||
|
||||
@@ -22,12 +22,12 @@ class DialogueChunker:
|
||||
|
||||
Args:
|
||||
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
|
||||
Options include: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
|
||||
Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker
|
||||
"""
|
||||
self.chunker_strategy = chunker_strategy
|
||||
chunker_config_dict = get_chunker_config(chunker_strategy)
|
||||
self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict)
|
||||
# 对于 LLMChunker,需要传入 llm_client
|
||||
|
||||
if self.chunker_config.chunker_strategy == "LLMChunker":
|
||||
self.chunker_client = ChunkerClient(self.chunker_config, llm_client)
|
||||
else:
|
||||
@@ -41,29 +41,19 @@ class DialogueChunker:
|
||||
|
||||
Returns:
|
||||
A list of Chunk objects
|
||||
|
||||
Raises:
|
||||
ValueError: If chunking fails or returns empty chunks
|
||||
"""
|
||||
result_dialogue = await self.chunker_client.generate_chunks(dialogue)
|
||||
# Defensive fallback: ensure at least one chunk is returned for non-empty content
|
||||
try:
|
||||
chunks = result_dialogue.chunks
|
||||
except Exception:
|
||||
chunks = []
|
||||
chunks = result_dialogue.chunks
|
||||
|
||||
if not chunks or len(chunks) == 0:
|
||||
# If the dialogue has content, return a single fallback chunk built from messages
|
||||
content_str = getattr(result_dialogue, "content", "") or getattr(dialogue, "content", "")
|
||||
if content_str and len(content_str.strip()) > 0:
|
||||
fallback_chunk = Chunk.from_messages(
|
||||
dialogue.context.msgs,
|
||||
metadata={
|
||||
"fallback": "single_chunk",
|
||||
"chunker_strategy": self.chunker_config.chunker_strategy,
|
||||
"source": "DialogueChunkerFallback",
|
||||
},
|
||||
)
|
||||
return [fallback_chunk]
|
||||
# No content: return empty list
|
||||
return []
|
||||
raise ValueError(
|
||||
f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. "
|
||||
f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}, "
|
||||
f"Strategy: {self.chunker_config.chunker_strategy}"
|
||||
)
|
||||
|
||||
return chunks
|
||||
|
||||
@@ -72,22 +62,25 @@ class DialogueChunker:
|
||||
|
||||
Args:
|
||||
dialogue: The processed DialogData object with chunks
|
||||
output_path: Optional path to save the output (default: chunker_output_{strategy}.txt)
|
||||
output_path: Optional path to save the output
|
||||
|
||||
Returns:
|
||||
The path where the output was saved
|
||||
"""
|
||||
if not output_path:
|
||||
output_path = os.path.join(os.path.dirname(__file__), "..", "..",
|
||||
f"chunker_output_{self.chunker_strategy.lower()}.txt")
|
||||
output_path = os.path.join(
|
||||
os.path.dirname(__file__), "..", "..",
|
||||
f"chunker_output_{self.chunker_strategy.lower()}.txt"
|
||||
)
|
||||
|
||||
output_lines = []
|
||||
output_lines.append(f"=== Chunking Results ({self.chunker_strategy}) ===")
|
||||
output_lines.append(f"Dialogue ID: {dialogue.ref_id}")
|
||||
output_lines.append(f"Original conversation has {len(dialogue.context.msgs)} messages")
|
||||
output_lines.append(f"Total characters: {len(dialogue.content)}")
|
||||
|
||||
output_lines.append(f"Generated {len(dialogue.chunks)} chunks:")
|
||||
output_lines = [
|
||||
f"=== Chunking Results ({self.chunker_strategy}) ===",
|
||||
f"Dialogue ID: {dialogue.ref_id}",
|
||||
f"Original conversation has {len(dialogue.context.msgs)} messages",
|
||||
f"Total characters: {len(dialogue.content)}",
|
||||
f"Generated {len(dialogue.chunks)} chunks:"
|
||||
]
|
||||
|
||||
for i, chunk in enumerate(dialogue.chunks):
|
||||
output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters")
|
||||
output_lines.append(f" Content preview: {chunk.content}...")
|
||||
|
||||
@@ -5,8 +5,6 @@ from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.models.message_models import DialogData, Statement
|
||||
|
||||
#避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
|
||||
from app.core.memory.models.variate_config import StatementExtractionConfig
|
||||
from app.core.memory.utils.data.ontology import (
|
||||
LABEL_DEFINITIONS,
|
||||
@@ -22,11 +20,10 @@ logger = logging.getLogger(__name__)
|
||||
class ExtractedStatement(BaseModel):
|
||||
"""Schema for extracted statement from LLM"""
|
||||
statement: str = Field(..., description="The extracted statement text")
|
||||
statement_type: str = Field(..., description="FACT, OPINION,SUGGESTION or PREDICTION")
|
||||
statement_type: str = Field(..., description="FACT, OPINION, SUGGESTION or PREDICTION")
|
||||
temporal_type: str = Field(..., description="STATIC, DYNAMIC, ATEMPORAL")
|
||||
relevence: str = Field(..., description="RELEVANT or IRRELEVANT")
|
||||
|
||||
# 统一使用 StatementExtractionResponse 作为 LLM 的结构化返回(仅语句)
|
||||
class StatementExtractionResponse(BaseModel):
|
||||
statements: List[ExtractedStatement] = Field(default_factory=list, description="List of extracted statements")
|
||||
|
||||
@@ -58,10 +55,9 @@ class StatementExtractionResponse(BaseModel):
|
||||
return v
|
||||
|
||||
class StatementExtractor:
|
||||
"""Class for extracting statements from dialog chunks using LLM (relations separated)"""
|
||||
"""Class for extracting statements from dialog chunks using LLM"""
|
||||
|
||||
def __init__(self, llm_client: Any, config: StatementExtractionConfig = None):
|
||||
# 避免在测试收集阶段因为 OpenAIClient 间接引入 langfuse 导致 ModuleNotFoundError 。这只是类型注解与导入时机的调整,不改变实现。
|
||||
"""Initialize the StatementExtractor with an LLM client and configuration
|
||||
|
||||
Args:
|
||||
@@ -71,6 +67,21 @@ class StatementExtractor:
|
||||
self.llm_client = llm_client
|
||||
self.config = config or StatementExtractionConfig()
|
||||
|
||||
def _get_speaker_from_chunk(self, chunk) -> Optional[str]:
|
||||
"""Get speaker directly from Chunk
|
||||
|
||||
Args:
|
||||
chunk: Chunk object containing speaker field
|
||||
|
||||
Returns:
|
||||
Speaker role ("user"/"assistant") or None if cannot be determined
|
||||
"""
|
||||
if hasattr(chunk, 'speaker') and chunk.speaker:
|
||||
return chunk.speaker
|
||||
|
||||
logger.warning(f"Chunk {getattr(chunk, 'id', 'unknown')} has no speaker field or is empty")
|
||||
return None
|
||||
|
||||
async def _extract_statements(self, chunk, group_id: Optional[str] = None, dialogue_content: str = None) -> List[Statement]:
|
||||
"""Process a single chunk and return extracted statements
|
||||
|
||||
@@ -82,10 +93,12 @@ class StatementExtractor:
|
||||
Returns:
|
||||
List of ExtractedStatement objects extracted from the chunk
|
||||
"""
|
||||
# Prepare the chunk content for processing
|
||||
chunk_content = chunk.content
|
||||
|
||||
if not chunk_content or len(chunk_content.strip()) < 5:
|
||||
logger.warning(f"Chunk {chunk.id} content too short or empty, skipping")
|
||||
return []
|
||||
|
||||
# Render the prompt using helper function
|
||||
prompt_content = await render_statement_extraction_prompt(
|
||||
chunk_content=chunk_content,
|
||||
definitions=LABEL_DEFINITIONS,
|
||||
@@ -136,7 +149,9 @@ class StatementExtractor:
|
||||
relevence_info = RelevenceInfo[relevence_str] if relevence_str in RelevenceInfo.__members__ else RelevenceInfo.RELEVANT
|
||||
except (KeyError, ValueError):
|
||||
relevence_info = RelevenceInfo.RELEVANT
|
||||
|
||||
|
||||
chunk_speaker = self._get_speaker_from_chunk(chunk)
|
||||
|
||||
chunk_statement = Statement(
|
||||
statement=extracted_stmt.statement,
|
||||
stmt_type=stmt_type,
|
||||
@@ -144,7 +159,9 @@ class StatementExtractor:
|
||||
relevence_info=relevence_info,
|
||||
chunk_id=chunk.id,
|
||||
group_id=group_id,
|
||||
speaker=chunk_speaker,
|
||||
)
|
||||
|
||||
chunk_statements.append(chunk_statement)
|
||||
|
||||
# 分离强弱关系分类:不在句子提取阶段进行,也不写入 chunk.metadata
|
||||
@@ -226,12 +243,7 @@ class StatementExtractor:
|
||||
return output_path
|
||||
|
||||
def save_relations(self, dialogs: List[DialogData], output_path: str = None) -> str:
|
||||
"""按对话分组聚合强/弱关系并写入 TXT 文件。
|
||||
- 每个对话单独成段:输出该对话的 `Dialog ID`、`Group ID`、`Content`
|
||||
- 在该对话段内再分为 Strong Relations / Weak Relations 两部分
|
||||
- Strong: 逐条输出 `Chunk ID` 与 `Triple`
|
||||
- Weak: 逐条输出 `Chunk ID` 与 `Entity`
|
||||
"""
|
||||
"""Group and aggregate strong/weak relations by dialogue and write to TXT file."""
|
||||
print("\n=== Relations Classify ===")
|
||||
|
||||
# 使用全局配置的输出路径
|
||||
|
||||
@@ -18,21 +18,13 @@ from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.core.memory.llm_tools.openai_client import OpenAIClient
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.memory.utils.config.get_data import (
|
||||
extract_and_process_changes,
|
||||
get_data,
|
||||
get_data_statement,
|
||||
)
|
||||
from app.core.memory.utils.llm.llm_utils import get_llm_client
|
||||
from app.core.memory.utils.prompt.template_render import (
|
||||
render_evaluate_prompt,
|
||||
render_reflexion_prompt,
|
||||
)
|
||||
from app.core.models.base import RedBearModelConfig
|
||||
from app.core.response_utils import success
|
||||
from app.repositories.neo4j.cypher_queries import (
|
||||
UPDATE_STATEMENT_INVALID_AT,
|
||||
neo4j_query_all,
|
||||
neo4j_query_part,
|
||||
neo4j_statement_all,
|
||||
@@ -160,12 +152,11 @@ class ReflectionEngine:
|
||||
self.neo4j_connector = Neo4jConnector()
|
||||
|
||||
if self.llm_client is None:
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
self.llm_client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
self.llm_client = factory.get_llm_client(self.config.model_id)
|
||||
elif isinstance(self.llm_client, str):
|
||||
# 如果 llm_client 是字符串(model_id),则用它初始化客户端
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
@@ -263,25 +254,23 @@ class ReflectionEngine:
|
||||
|
||||
# 2. 检测冲突(基于事实的反思)
|
||||
conflict_data = await self._detect_conflicts(reflexion_data, statement_databasets)
|
||||
print(100 * '-')
|
||||
print(conflict_data)
|
||||
print(100 * '-')
|
||||
# # 检查是否真的有冲突
|
||||
conflicts_found=''
|
||||
conflict_list=[]
|
||||
for i in conflict_data:
|
||||
conflict_list.append(i['data'])
|
||||
|
||||
conflicts_found=''
|
||||
|
||||
|
||||
conflicts_found=0
|
||||
# 3. 解决冲突
|
||||
solved_data = await self._resolve_conflicts(conflict_data, statement_databasets)
|
||||
solved_data = await self._resolve_conflicts(conflict_list, statement_databasets)
|
||||
|
||||
if not solved_data:
|
||||
return ReflectionResult(
|
||||
success=False,
|
||||
message="反思失败,未解决冲突",
|
||||
message=f"没有{self.config.baseline}相关的冲突数据",
|
||||
conflicts_found=conflicts_found,
|
||||
execution_time=asyncio.get_event_loop().time() - start_time
|
||||
)
|
||||
print(100 * '*')
|
||||
print(solved_data)
|
||||
print(100 * '*')
|
||||
|
||||
conflicts_resolved = len(solved_data)
|
||||
logging.info(f"解决了 {conflicts_resolved} 个冲突")
|
||||
@@ -386,7 +375,7 @@ class ReflectionEngine:
|
||||
memory_verifies.append(item['memory_verify'])
|
||||
result_data['memory_verifies'] = memory_verifies
|
||||
result_data['quality_assessments'] = quality_assessments
|
||||
conflicts_found=''
|
||||
conflicts_found = 0 # 初始化为整数0而不是空字符串
|
||||
REMOVE_KEYS = {"created_at", "expired_at","relationship","predicate","statement_id","id","statement_id","relationship_statement_id"}
|
||||
# Clearn conflict_data,And memory_verify和quality_assessment
|
||||
cleaned_conflict_data = []
|
||||
@@ -414,7 +403,7 @@ class ReflectionEngine:
|
||||
cleaned_conflict_data_.append(cleaned_item)
|
||||
print(cleaned_conflict_data_)
|
||||
# 3. 解决冲突
|
||||
solved_data = await self._resolve_conflicts(cleaned_conflict_data, source_data)
|
||||
solved_data = await self._resolve_conflicts(cleaned_conflict_data_, source_data)
|
||||
if not solved_data:
|
||||
return ReflectionResult(
|
||||
success=False,
|
||||
@@ -739,4 +728,3 @@ class ReflectionEngine:
|
||||
|
||||
raise ValueError(f"未知的反思基线: {self.config.baseline}")
|
||||
|
||||
|
||||
|
||||
@@ -14,28 +14,8 @@ from .config_utils import (
|
||||
get_pruning_config,
|
||||
get_voice_config,
|
||||
)
|
||||
|
||||
# DEPRECATED: Global configuration variables removed
|
||||
# Use MemoryConfig objects with dependency injection instead
|
||||
# from .definitions import (
|
||||
# CONFIG, # DEPRECATED - empty dict for backward compatibility
|
||||
# RUNTIME_CONFIG, # DEPRECATED - minimal for backward compatibility
|
||||
# PROJECT_ROOT, # Still needed for file paths
|
||||
# reload_configuration_from_database, # DEPRECATED - returns False
|
||||
# )
|
||||
# DEPRECATED: overrides module removed - use MemoryConfig with dependency injection
|
||||
from .get_data import get_data
|
||||
|
||||
# litellm_config 需要时动态导入,避免循环依赖
|
||||
# from .litellm_config import (
|
||||
# LiteLLMConfig,
|
||||
# setup_litellm_enhanced,
|
||||
# get_usage_summary,
|
||||
# print_usage_summary,
|
||||
# get_instant_qps,
|
||||
# print_instant_qps,
|
||||
# )
|
||||
|
||||
__all__ = [
|
||||
# config_utils
|
||||
"get_model_config",
|
||||
@@ -45,18 +25,5 @@ __all__ = [
|
||||
"get_pruning_config",
|
||||
"get_picture_config",
|
||||
"get_voice_config",
|
||||
# definitions (DEPRECATED - use MemoryConfig objects instead)
|
||||
# "CONFIG", # DEPRECATED
|
||||
# "RUNTIME_CONFIG", # DEPRECATED
|
||||
# "PROJECT_ROOT",
|
||||
# "reload_configuration_from_database", # DEPRECATED
|
||||
# get_data
|
||||
"get_data",
|
||||
# litellm_config - 需要时从 .litellm_config 直接导入
|
||||
# "LiteLLMConfig",
|
||||
# "setup_litellm_enhanced",
|
||||
# "get_usage_summary",
|
||||
# "print_usage_summary",
|
||||
# "get_instant_qps",
|
||||
# "print_instant_qps",
|
||||
]
|
||||
|
||||
@@ -1,398 +0,0 @@
|
||||
"""
|
||||
配置管理优化模块
|
||||
|
||||
提供可选的配置管理优化功能,包括:
|
||||
- LRU 缓存策略
|
||||
- 缓存预热
|
||||
- 缓存监控指标
|
||||
- 动态 TTL 策略
|
||||
- 配置版本控制
|
||||
|
||||
这些优化是可选的,当前的基础实现已经满足大多数需求。
|
||||
"""
|
||||
import logging
|
||||
import statistics
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LRUConfigCache:
|
||||
"""
|
||||
LRU(Least Recently Used)配置缓存
|
||||
|
||||
当缓存达到最大容量时,自动淘汰最少使用的配置
|
||||
"""
|
||||
|
||||
def __init__(self, max_size: int = 100, ttl: timedelta = timedelta(minutes=5)):
|
||||
"""
|
||||
初始化 LRU 缓存
|
||||
|
||||
Args:
|
||||
max_size: 最大缓存容量
|
||||
ttl: 缓存过期时间
|
||||
"""
|
||||
self.max_size = max_size
|
||||
self.ttl = ttl
|
||||
self._cache: OrderedDict[str, Dict[str, Any]] = OrderedDict()
|
||||
self._timestamps: Dict[str, datetime] = {}
|
||||
self._lock = threading.RLock()
|
||||
|
||||
# 统计信息
|
||||
self._stats = {
|
||||
'hits': 0,
|
||||
'misses': 0,
|
||||
'evictions': 0,
|
||||
'load_times': []
|
||||
}
|
||||
|
||||
def get(self, config_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
获取配置(如果存在且未过期)
|
||||
|
||||
Args:
|
||||
config_id: 配置 ID
|
||||
|
||||
Returns:
|
||||
配置字典,如果不存在或已过期则返回 None
|
||||
"""
|
||||
with self._lock:
|
||||
if config_id not in self._cache:
|
||||
self._stats['misses'] += 1
|
||||
return None
|
||||
|
||||
# 检查是否过期
|
||||
timestamp = self._timestamps.get(config_id)
|
||||
if timestamp and (datetime.now() - timestamp) >= self.ttl:
|
||||
# 过期,移除
|
||||
self._cache.pop(config_id, None)
|
||||
self._timestamps.pop(config_id, None)
|
||||
self._stats['misses'] += 1
|
||||
return None
|
||||
|
||||
# 命中,移动到末尾(标记为最近使用)
|
||||
self._cache.move_to_end(config_id)
|
||||
self._stats['hits'] += 1
|
||||
return self._cache[config_id]
|
||||
|
||||
def put(self, config_id: str, config: Dict[str, Any]) -> None:
|
||||
"""
|
||||
添加或更新配置
|
||||
|
||||
Args:
|
||||
config_id: 配置 ID
|
||||
config: 配置字典
|
||||
"""
|
||||
with self._lock:
|
||||
if config_id in self._cache:
|
||||
# 更新现有配置
|
||||
self._cache.move_to_end(config_id)
|
||||
else:
|
||||
# 添加新配置
|
||||
if len(self._cache) >= self.max_size:
|
||||
# 缓存已满,移除最旧的配置
|
||||
oldest_id, _ = self._cache.popitem(last=False)
|
||||
self._timestamps.pop(oldest_id, None)
|
||||
self._stats['evictions'] += 1
|
||||
logger.debug(f"[LRUCache] 淘汰配置: {oldest_id}")
|
||||
|
||||
self._cache[config_id] = config
|
||||
self._timestamps[config_id] = datetime.now()
|
||||
|
||||
def clear(self, config_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
清除缓存
|
||||
|
||||
Args:
|
||||
config_id: 如果指定,只清除该配置;否则清除所有
|
||||
"""
|
||||
with self._lock:
|
||||
if config_id:
|
||||
self._cache.pop(config_id, None)
|
||||
self._timestamps.pop(config_id, None)
|
||||
else:
|
||||
self._cache.clear()
|
||||
self._timestamps.clear()
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取缓存统计信息
|
||||
|
||||
Returns:
|
||||
统计信息字典
|
||||
"""
|
||||
with self._lock:
|
||||
total = self._stats['hits'] + self._stats['misses']
|
||||
hit_rate = (self._stats['hits'] / total * 100) if total > 0 else 0
|
||||
|
||||
return {
|
||||
'cache_size': len(self._cache),
|
||||
'max_size': self.max_size,
|
||||
'total_requests': total,
|
||||
'cache_hits': self._stats['hits'],
|
||||
'cache_misses': self._stats['misses'],
|
||||
'evictions': self._stats['evictions'],
|
||||
'hit_rate': hit_rate,
|
||||
'avg_load_time': statistics.mean(self._stats['load_times']) if self._stats['load_times'] else 0
|
||||
}
|
||||
|
||||
def record_load_time(self, load_time_ms: float) -> None:
|
||||
"""
|
||||
记录加载时间
|
||||
|
||||
Args:
|
||||
load_time_ms: 加载时间(毫秒)
|
||||
"""
|
||||
with self._lock:
|
||||
self._stats['load_times'].append(load_time_ms)
|
||||
# 只保留最近 1000 次的记录
|
||||
if len(self._stats['load_times']) > 1000:
|
||||
self._stats['load_times'] = self._stats['load_times'][-1000:]
|
||||
|
||||
|
||||
class ConfigCacheWarmer:
|
||||
"""
|
||||
配置缓存预热器
|
||||
|
||||
在系统启动时预加载常用配置,减少首次请求延迟
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def warmup(config_ids: List[str], load_func) -> Dict[str, bool]:
|
||||
"""
|
||||
预热缓存
|
||||
|
||||
Args:
|
||||
config_ids: 要预加载的配置 ID 列表
|
||||
load_func: 配置加载函数
|
||||
|
||||
Returns:
|
||||
每个配置的加载结果
|
||||
"""
|
||||
results = {}
|
||||
|
||||
logger.info(f"[CacheWarmer] 开始预热 {len(config_ids)} 个配置")
|
||||
|
||||
for config_id in config_ids:
|
||||
try:
|
||||
result = load_func(config_id)
|
||||
results[config_id] = result
|
||||
if result:
|
||||
logger.debug(f"[CacheWarmer] 成功预热配置: {config_id}")
|
||||
else:
|
||||
logger.warning(f"[CacheWarmer] 预热配置失败: {config_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"[CacheWarmer] 预热配置异常: {config_id}, 错误: {e}")
|
||||
results[config_id] = False
|
||||
|
||||
success_count = sum(1 for r in results.values() if r)
|
||||
logger.info(f"[CacheWarmer] 预热完成: {success_count}/{len(config_ids)} 成功")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class DynamicTTLStrategy:
|
||||
"""
|
||||
动态 TTL 策略
|
||||
|
||||
根据配置类型和更新频率动态调整缓存过期时间
|
||||
"""
|
||||
|
||||
# 预定义的 TTL 策略
|
||||
TTL_STRATEGIES = {
|
||||
'production': timedelta(minutes=30), # 生产配置较稳定
|
||||
'staging': timedelta(minutes=15), # 预发布配置中等稳定
|
||||
'development': timedelta(minutes=5), # 开发配置频繁变化
|
||||
'testing': timedelta(minutes=1), # 测试配置快速过期
|
||||
'default': timedelta(minutes=5) # 默认策略
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_ttl(cls, config_id: str, config_type: Optional[str] = None) -> timedelta:
|
||||
"""
|
||||
获取配置的 TTL
|
||||
|
||||
Args:
|
||||
config_id: 配置 ID
|
||||
config_type: 配置类型(production/staging/development/testing)
|
||||
|
||||
Returns:
|
||||
TTL 时间间隔
|
||||
"""
|
||||
if config_type and config_type in cls.TTL_STRATEGIES:
|
||||
return cls.TTL_STRATEGIES[config_type]
|
||||
|
||||
# 根据 config_id 推断类型
|
||||
if 'prod' in config_id.lower():
|
||||
return cls.TTL_STRATEGIES['production']
|
||||
elif 'stag' in config_id.lower():
|
||||
return cls.TTL_STRATEGIES['staging']
|
||||
elif 'dev' in config_id.lower():
|
||||
return cls.TTL_STRATEGIES['development']
|
||||
elif 'test' in config_id.lower():
|
||||
return cls.TTL_STRATEGIES['testing']
|
||||
|
||||
return cls.TTL_STRATEGIES['default']
|
||||
|
||||
|
||||
class ConfigVersionManager:
|
||||
"""
|
||||
配置版本管理器
|
||||
|
||||
跟踪配置版本,当配置更新时自动失效旧版本缓存
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._versions: Dict[str, str] = {}
|
||||
self._lock = threading.RLock()
|
||||
|
||||
def get_version(self, config_id: str) -> Optional[str]:
|
||||
"""
|
||||
获取配置版本
|
||||
|
||||
Args:
|
||||
config_id: 配置 ID
|
||||
|
||||
Returns:
|
||||
版本号,如果不存在则返回 None
|
||||
"""
|
||||
with self._lock:
|
||||
return self._versions.get(config_id)
|
||||
|
||||
def set_version(self, config_id: str, version: str) -> None:
|
||||
"""
|
||||
设置配置版本
|
||||
|
||||
Args:
|
||||
config_id: 配置 ID
|
||||
version: 版本号
|
||||
"""
|
||||
with self._lock:
|
||||
old_version = self._versions.get(config_id)
|
||||
self._versions[config_id] = version
|
||||
|
||||
if old_version and old_version != version:
|
||||
logger.info(f"[VersionManager] 配置版本更新: {config_id} {old_version} -> {version}")
|
||||
|
||||
def check_version(self, config_id: str, cached_version: Optional[str]) -> bool:
|
||||
"""
|
||||
检查缓存版本是否有效
|
||||
|
||||
Args:
|
||||
config_id: 配置 ID
|
||||
cached_version: 缓存的版本号
|
||||
|
||||
Returns:
|
||||
True 如果版本匹配,False 如果版本不匹配或不存在
|
||||
"""
|
||||
with self._lock:
|
||||
current_version = self._versions.get(config_id)
|
||||
|
||||
if not current_version or not cached_version:
|
||||
return False
|
||||
|
||||
return current_version == cached_version
|
||||
|
||||
def invalidate(self, config_id: str) -> None:
|
||||
"""
|
||||
使配置版本失效
|
||||
|
||||
Args:
|
||||
config_id: 配置 ID
|
||||
"""
|
||||
with self._lock:
|
||||
if config_id in self._versions:
|
||||
# 生成新版本号
|
||||
import uuid
|
||||
new_version = str(uuid.uuid4())
|
||||
self._versions[config_id] = new_version
|
||||
logger.info(f"[VersionManager] 配置版本失效: {config_id} -> {new_version}")
|
||||
|
||||
|
||||
class CacheMonitor:
|
||||
"""
|
||||
缓存监控器
|
||||
|
||||
提供缓存性能监控和报告功能
|
||||
"""
|
||||
|
||||
def __init__(self, cache: LRUConfigCache):
|
||||
self.cache = cache
|
||||
|
||||
def get_report(self) -> str:
|
||||
"""
|
||||
生成缓存性能报告
|
||||
|
||||
Returns:
|
||||
格式化的报告字符串
|
||||
"""
|
||||
stats = self.cache.get_stats()
|
||||
|
||||
report = f"""
|
||||
配置缓存性能报告
|
||||
================
|
||||
缓存容量: {stats['cache_size']}/{stats['max_size']}
|
||||
总请求数: {stats['total_requests']}
|
||||
缓存命中: {stats['cache_hits']}
|
||||
缓存未命中: {stats['cache_misses']}
|
||||
缓存命中率: {stats['hit_rate']:.2f}%
|
||||
淘汰次数: {stats['evictions']}
|
||||
平均加载时间: {stats['avg_load_time']:.2f}ms
|
||||
"""
|
||||
return report
|
||||
|
||||
def log_stats(self) -> None:
|
||||
"""记录统计信息到日志"""
|
||||
stats = self.cache.get_stats()
|
||||
logger.info(
|
||||
f"[CacheMonitor] 缓存统计 - "
|
||||
f"容量: {stats['cache_size']}/{stats['max_size']}, "
|
||||
f"命中率: {stats['hit_rate']:.2f}%, "
|
||||
f"淘汰: {stats['evictions']}"
|
||||
)
|
||||
|
||||
|
||||
# 使用示例
|
||||
def example_usage():
|
||||
"""
|
||||
优化功能使用示例
|
||||
"""
|
||||
# 1. 使用 LRU 缓存
|
||||
lru_cache = LRUConfigCache(max_size=100, ttl=timedelta(minutes=5))
|
||||
|
||||
# 获取配置
|
||||
config = lru_cache.get("config_001")
|
||||
if config is None:
|
||||
# 缓存未命中,从数据库加载
|
||||
config = {"llm_name": "openai/gpt-4"}
|
||||
lru_cache.put("config_001", config)
|
||||
|
||||
# 2. 预热缓存
|
||||
def load_config(config_id):
|
||||
# 实际的配置加载逻辑
|
||||
return True
|
||||
|
||||
warmer = ConfigCacheWarmer()
|
||||
results = warmer.warmup(["config_001", "config_002"], load_config)
|
||||
|
||||
# 3. 动态 TTL
|
||||
ttl = DynamicTTLStrategy.get_ttl("prod_config_001", "production")
|
||||
print(f"TTL: {ttl}")
|
||||
|
||||
# 4. 版本管理
|
||||
version_manager = ConfigVersionManager()
|
||||
version_manager.set_version("config_001", "v1.0.0")
|
||||
|
||||
# 检查版本
|
||||
is_valid = version_manager.check_version("config_001", "v1.0.0")
|
||||
|
||||
# 5. 监控
|
||||
monitor = CacheMonitor(lru_cache)
|
||||
print(monitor.get_report())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
example_usage()
|
||||
@@ -1,268 +0,0 @@
|
||||
# """
|
||||
# 配置加载模块 - DEPRECATED
|
||||
|
||||
# ⚠️ DEPRECATION NOTICE ⚠️
|
||||
# This module is deprecated and will be removed in a future version.
|
||||
# Global configuration variables have been eliminated in favor of dependency injection.
|
||||
|
||||
# Use the new MemoryConfig system instead:
|
||||
# - app.schemas.memory_config_schema.MemoryConfig for configuration objects
|
||||
# - config_service = MemoryConfigService(db); config_service.load_memory_config(config_id)
|
||||
|
||||
# 阶段 1: 从 runtime.json 加载配置(路径 A)- DEPRECATED
|
||||
# 阶段 2: 从数据库加载配置(路径 B,基于 dbrun.json 中的 config_id)- DEPRECATED
|
||||
# 阶段 3: 暴露配置常量供项目使用(路径 A 和 B 的汇合点)- DEPRECATED
|
||||
# """
|
||||
# import json
|
||||
# import os
|
||||
# import threading
|
||||
# from datetime import datetime, timedelta
|
||||
# from typing import Any, Dict, Optional
|
||||
|
||||
# #TODO: Fix this
|
||||
|
||||
# try:
|
||||
# from dotenv import load_dotenv
|
||||
# load_dotenv()
|
||||
# except Exception:
|
||||
# pass
|
||||
|
||||
# # Import unified configuration system
|
||||
# try:
|
||||
# from app.core.config import settings
|
||||
# USE_UNIFIED_CONFIG = True
|
||||
# except ImportError:
|
||||
# USE_UNIFIED_CONFIG = False
|
||||
# settings = None
|
||||
|
||||
# # PROJECT_ROOT 应该指向 app/core/memory/ 目录
|
||||
# # __file__ = app/core/memory/utils/config/definitions.py
|
||||
# # os.path.dirname(__file__) = app/core/memory/utils/config
|
||||
# # os.path.dirname(...) = app/core/memory/utils
|
||||
# # os.path.dirname(...) = app/core/memory
|
||||
# PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
# # DEPRECATED: Global configuration lock removed
|
||||
# # Use MemoryConfig objects with dependency injection instead
|
||||
|
||||
# # DEPRECATED: Legacy config.json loading removed
|
||||
# # Use MemoryConfig objects with dependency injection instead
|
||||
# CONFIG = {}
|
||||
|
||||
# DEFAULT_VALUES = {
|
||||
# "llm_name": "openai/qwen-plus",
|
||||
# "embedding_name": "openai/nomic-embed-text:v1.5",
|
||||
# "chunker_strategy": "RecursiveChunker",
|
||||
# "group_id": "group_123",
|
||||
# "user_id": "default_user",
|
||||
# "apply_id": "default_apply",
|
||||
# "llm_agent_name": "openai/qwen-plus",
|
||||
# "llm_verify_name": "openai/qwen-plus",
|
||||
# "llm_image_recognition": "openai/qwen-plus",
|
||||
# "llm_voice_recognition": "openai/qwen-plus",
|
||||
# "prompt_level": "DEBUG",
|
||||
# "reflexion_iteration_period": "3",
|
||||
# "reflexion_range": "retrieval",
|
||||
# "reflexion_baseline": "TIME",
|
||||
# }
|
||||
|
||||
# # DEPRECATED: Legacy global variables for backward compatibility only
|
||||
# # These will be removed in a future version
|
||||
# # Use MemoryConfig objects with dependency injection instead
|
||||
# # LANGFUSE_ENABLED = os.getenv("LANGFUSE_ENABLED", "false").lower() == "true"
|
||||
# # SELECTED_LLM_ID = os.getenv("SELECTED_LLM_ID", DEFAULT_VALUES["llm_name"])
|
||||
|
||||
|
||||
# # 阶段 1: 从 runtime.json 加载配置(路径 A)
|
||||
# def _load_from_runtime_json() -> Dict[str, Any]:
|
||||
# """
|
||||
# DEPRECATED: Legacy runtime.json loading
|
||||
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
# Returns:
|
||||
# Dict[str, Any]: Empty configuration (legacy support only)
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "Runtime JSON loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# return {"selections": {}}
|
||||
|
||||
|
||||
# # 阶段 2: 从数据库加载配置(路径 B)- 已整合到统一加载器
|
||||
# # 注意:此函数已被 _load_from_runtime_json 中的统一配置加载器替代
|
||||
# # 保留此函数仅为向后兼容
|
||||
# def _load_from_database() -> Optional[Dict[str, Any]]:
|
||||
# """
|
||||
# DEPRECATED: Legacy database configuration loading
|
||||
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
# Returns:
|
||||
# Optional[Dict[str, Any]]: None (deprecated functionality)
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "Database configuration loading is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# return None
|
||||
|
||||
|
||||
# # 阶段 3: 暴露配置常量(路径 A 和 B 的汇合点)- DEPRECATED
|
||||
# def _expose_runtime_constants(runtime_cfg: Dict[str, Any]) -> None:
|
||||
# """
|
||||
# DEPRECATED: 将运行时配置暴露为全局常量供项目使用
|
||||
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Global configuration variables have been eliminated in favor of dependency injection.
|
||||
|
||||
# Use the new MemoryConfig system instead:
|
||||
# - app.core.memory_config.config.MemoryConfig for configuration objects
|
||||
# - Pass configuration objects as parameters instead of using global variables
|
||||
|
||||
# Args:
|
||||
# runtime_cfg: 运行时配置字典
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "Global configuration variables are deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
|
||||
# # Keep minimal global state for backward compatibility only
|
||||
# # These will be removed in a future version
|
||||
# global RUNTIME_CONFIG, SELECTIONS
|
||||
|
||||
# RUNTIME_CONFIG = runtime_cfg
|
||||
# SELECTIONS = RUNTIME_CONFIG.get("selections", {})
|
||||
|
||||
# # All other global variables have been removed
|
||||
# # Use MemoryConfig objects instead
|
||||
|
||||
|
||||
# # 初始化:使用统一配置加载器
|
||||
# def _initialize_configuration() -> None:
|
||||
# """
|
||||
# DEPRECATED: Legacy configuration initialization
|
||||
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "Global configuration initialization is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# # Initialize with empty configuration for backward compatibility
|
||||
# _expose_runtime_constants({"selections": {}})
|
||||
|
||||
|
||||
# # 模块加载时自动初始化配置
|
||||
# _initialize_configuration()
|
||||
|
||||
# # DEPRECATED: Global variables removed
|
||||
# # These variables have been eliminated in favor of dependency injection
|
||||
# # Use MemoryConfig objects instead of accessing global variables
|
||||
|
||||
|
||||
# # 公共 API:动态重新加载配置
|
||||
# def reload_configuration_from_database(config_id, force_reload: bool = False) -> bool:
|
||||
# """
|
||||
# DEPRECATED: Legacy configuration reloading
|
||||
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
# For new code, use:
|
||||
# - app.services.memory_agent_service.MemoryAgentService.load_memory_config()
|
||||
# - app.services.memory_storage_service.MemoryStorageService.load_memory_config()
|
||||
|
||||
# Args:
|
||||
# config_id: Configuration ID (deprecated)
|
||||
# force_reload: Force reload flag (deprecated)
|
||||
|
||||
# Returns:
|
||||
# bool: Always returns False (deprecated functionality)
|
||||
# """
|
||||
# import logging
|
||||
# import warnings
|
||||
|
||||
# logger = logging.getLogger(__name__)
|
||||
|
||||
# warnings.warn(
|
||||
# "reload_configuration_from_database is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
|
||||
# logger.warning(f"Deprecated function reload_configuration_from_database called with config_id={config_id}. "
|
||||
# "Use MemoryConfig objects with dependency injection instead.")
|
||||
|
||||
# return False
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# def get_current_config_id() -> Optional[str]:
|
||||
# """
|
||||
# DEPRECATED: Legacy config ID retrieval
|
||||
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
# Returns:
|
||||
# Optional[str]: None (deprecated functionality)
|
||||
# """
|
||||
# import warnings
|
||||
# warnings.warn(
|
||||
# "get_current_config_id is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
# return None
|
||||
|
||||
|
||||
# def ensure_fresh_config(config_id = None) -> bool:
|
||||
# """
|
||||
# DEPRECATED: Legacy configuration freshness check
|
||||
|
||||
# ⚠️ This function is deprecated and will be removed in a future version.
|
||||
# Use MemoryConfig objects with dependency injection instead.
|
||||
|
||||
# For new code, use:
|
||||
# - app.services.memory_agent_service.MemoryAgentService.load_memory_config()
|
||||
# - app.services.memory_storage_service.MemoryStorageService.load_memory_config()
|
||||
|
||||
# Args:
|
||||
# config_id: Configuration ID (deprecated)
|
||||
|
||||
# Returns:
|
||||
# bool: Always returns False (deprecated functionality)
|
||||
# """
|
||||
# import logging
|
||||
# import warnings
|
||||
|
||||
# logger = logging.getLogger(__name__)
|
||||
|
||||
# warnings.warn(
|
||||
# "ensure_fresh_config is deprecated. Use MemoryConfig objects with dependency injection instead.",
|
||||
# DeprecationWarning,
|
||||
# stacklevel=2
|
||||
# )
|
||||
|
||||
# logger.warning(f"Deprecated function ensure_fresh_config called with config_id={config_id}. "
|
||||
# "Use MemoryConfig objects with dependency injection instead.")
|
||||
|
||||
# return False
|
||||
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import random
|
||||
import string
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
# 生成包含字母(大小写)和数字的随机字符串
|
||||
def generate_random_string(length=16):
|
||||
characters = string.ascii_letters + string.digits
|
||||
return ''.join(random.choice(characters) for _ in range(length))
|
||||
|
||||
def get_example_data() -> List[Dict[str, Optional[str]]]:
|
||||
"""
|
||||
从句子提取日志中获取数据
|
||||
Content: 在苹果公司中国总部,用户和李华偶遇了从美国来的技术专家约翰·史密斯。
|
||||
Created At: 2025-11-28 19:28:38.256421
|
||||
Expired At: None
|
||||
Valid At: None
|
||||
Invalid At: None
|
||||
将数据构造成如下形式:
|
||||
[
|
||||
{
|
||||
"id":id,
|
||||
"group_id":group_id,
|
||||
"statement": Content,
|
||||
"created_at": Created At,
|
||||
"expired_at": Expired At,
|
||||
"valid_at": Valid At,
|
||||
"invalid_at": Invalid At,
|
||||
"chunk_id": "86da9022710c40eaa5f518a294c398d2",
|
||||
"entity_ids": []
|
||||
},
|
||||
...
|
||||
]
|
||||
"""
|
||||
# 获取日志文件路径
|
||||
log_file_path = os.path.join("logs", "memory-output", "statement_extraction.txt")
|
||||
|
||||
# 检查文件是否存在
|
||||
if not os.path.exists(log_file_path):
|
||||
return []
|
||||
|
||||
# 读取日志文件
|
||||
with open(log_file_path, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
|
||||
# 解析数据
|
||||
results = []
|
||||
|
||||
# 使用正则表达式分割每个 Statement
|
||||
statement_blocks = re.split(r"Statement \d+:", content)
|
||||
|
||||
for block in statement_blocks[1:]: # 跳过第一个空块
|
||||
# 提取各个字段
|
||||
id_match = re.search(r"Id:\s*(.+?)(?=\n)", block)
|
||||
group_id_match = re.search(r"Group Id:\s*(.+?)(?=\n)", block)
|
||||
statement_match = re.search(r"Content:\s*(.+?)(?=\n)", block)
|
||||
created_at_match = re.search(r"Created At:\s*(.+?)(?=\n)", block)
|
||||
expired_at_match = re.search(r"Expired At:\s*(.+?)(?=\n)", block)
|
||||
valid_at_match = re.search(r"Valid At:\s*(.+?)(?=\n)", block)
|
||||
invalid_at_match = re.search(r"Invalid At:\s*(.+?)(?=\n)", block)
|
||||
chunk_id_match = re.search(r"Chunk Id:\s*(.+?)(?=\n)", block)
|
||||
|
||||
# 构造字典
|
||||
if statement_match:
|
||||
statement_data = {
|
||||
"id": id_match.group(1).strip() if id_match else generate_random_string(),
|
||||
"group_id": group_id_match.group(1).strip() if group_id_match else "group_example",
|
||||
"statement": statement_match.group(1).strip(),
|
||||
"created_at": created_at_match.group(1).strip() if created_at_match else None,
|
||||
"expired_at": expired_at_match.group(1).strip() if expired_at_match else None,
|
||||
"valid_at": valid_at_match.group(1).strip() if valid_at_match else None,
|
||||
"invalid_at": invalid_at_match.group(1).strip() if invalid_at_match else None,
|
||||
"chunk_id": chunk_id_match.group(1).strip() if chunk_id_match else "chunk_example",
|
||||
"entity_ids": []
|
||||
}
|
||||
|
||||
# 将 "None" 字符串转换为 None
|
||||
for key in ["created_at", "expired_at", "valid_at", "invalid_at"]:
|
||||
if statement_data[key] == "None":
|
||||
statement_data[key] = None
|
||||
|
||||
results.append(statement_data)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"获取数据如下:\n {get_example_data()}")
|
||||
@@ -1,516 +0,0 @@
|
||||
"""
|
||||
LiteLLM Configuration for Enhanced Retry Logic and Usage Tracking with Native QPS Monitoring
|
||||
"""
|
||||
|
||||
import litellm
|
||||
from typing import Dict, Any, List
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
import threading
|
||||
from queue import Queue
|
||||
|
||||
class LiteLLMConfig:
|
||||
"""Configuration class for LiteLLM with enhanced retry and tracking capabilities"""
|
||||
|
||||
def __init__(self):
|
||||
self.usage_data = []
|
||||
self.error_data = []
|
||||
self.module_stats = defaultdict(lambda: {
|
||||
'requests': 0,
|
||||
'tokens_in': 0,
|
||||
'tokens_out': 0,
|
||||
'cost': 0.0,
|
||||
'errors': 0,
|
||||
'start_time': None,
|
||||
'last_request_time': None,
|
||||
'request_timestamps': [], # Store precise timestamps
|
||||
'current_qps': 0.0,
|
||||
'max_qps': 0.0,
|
||||
'qps_history': [] # Store QPS measurements over time
|
||||
})
|
||||
self.start_time = datetime.now()
|
||||
self.global_request_timestamps = []
|
||||
self.global_max_qps = 0.0
|
||||
|
||||
# Rate limiting for AWS Bedrock (conservative limits)
|
||||
self.rate_limits = {
|
||||
'bedrock': {
|
||||
'requests_per_minute': 2, # AWS Bedrock default is very low
|
||||
'requests_per_second': 0.033, # 2/60 = 0.033 RPS
|
||||
'last_request_time': 0,
|
||||
'request_queue': Queue(),
|
||||
'lock': threading.Lock()
|
||||
}
|
||||
}
|
||||
self.rate_limiting_enabled = True
|
||||
|
||||
def setup_enhanced_config(self, max_retries: int = 3):
|
||||
"""Configure LiteLLM with retry logic and instant QPS tracking"""
|
||||
|
||||
litellm.num_retries = max_retries
|
||||
litellm.request_timeout = 300
|
||||
|
||||
litellm.retry_policy = {
|
||||
"RateLimitError": {
|
||||
"max_retries": 5,
|
||||
"exponential_backoff": True,
|
||||
"initial_delay": 1,
|
||||
"max_delay": 60,
|
||||
"jitter": True
|
||||
},
|
||||
"APIConnectionError": {
|
||||
"max_retries": 3,
|
||||
"exponential_backoff": True,
|
||||
"initial_delay": 2,
|
||||
"max_delay": 30,
|
||||
"jitter": True
|
||||
},
|
||||
"InternalServerError": {
|
||||
"max_retries": 2,
|
||||
"exponential_backoff": True,
|
||||
"initial_delay": 5,
|
||||
"max_delay": 60,
|
||||
"jitter": True
|
||||
},
|
||||
"BadRequestError": {
|
||||
"max_retries": 1,
|
||||
"exponential_backoff": False,
|
||||
"initial_delay": 1,
|
||||
"max_delay": 5
|
||||
}
|
||||
}
|
||||
|
||||
litellm.success_callback = [self._success_callback]
|
||||
litellm.failure_callback = [self._failure_callback]
|
||||
litellm.completion_cost_tracking = True
|
||||
litellm.set_verbose = False
|
||||
litellm.modify_params = True
|
||||
|
||||
print("✅ LiteLLM configured with instant QPS tracking and rate limiting")
|
||||
|
||||
def _success_callback(self, kwargs, completion_response, start_time, end_time):
|
||||
"""Callback for successful requests with module-specific QPS tracking"""
|
||||
try:
|
||||
# Extract usage information
|
||||
usage = completion_response.get('usage', {})
|
||||
model = kwargs.get('model', 'unknown')
|
||||
|
||||
# Extract module information from metadata or model name
|
||||
module = self._extract_module_name(kwargs, model)
|
||||
|
||||
# Calculate cost
|
||||
cost = 0.0
|
||||
try:
|
||||
cost = litellm.completion_cost(completion_response)
|
||||
except:
|
||||
pass
|
||||
|
||||
# Calculate duration
|
||||
duration_seconds = (end_time - start_time).total_seconds() if hasattr(end_time - start_time, 'total_seconds') else float(end_time - start_time)
|
||||
|
||||
# Record usage data
|
||||
usage_record = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"model": model,
|
||||
"module": module,
|
||||
"input_tokens": usage.get('prompt_tokens', 0),
|
||||
"output_tokens": usage.get('completion_tokens', 0),
|
||||
"total_tokens": usage.get('total_tokens', 0),
|
||||
"cost": cost,
|
||||
"duration_seconds": duration_seconds,
|
||||
"status": "success"
|
||||
}
|
||||
|
||||
self.usage_data.append(usage_record)
|
||||
|
||||
# Update module-specific stats for QPS tracking
|
||||
self._update_module_stats(module, usage_record, success=True)
|
||||
|
||||
# Print real-time feedback
|
||||
print(f"✓ {model}: {usage_record['input_tokens']}→{usage_record['output_tokens']} tokens, ${cost:.4f}, {usage_record['duration_seconds']:.2f}s")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Success callback failed: {e}")
|
||||
|
||||
def _failure_callback(self, kwargs, completion_response, start_time, end_time):
|
||||
"""Callback for failed requests with module-specific error tracking"""
|
||||
try:
|
||||
model = kwargs.get('model', 'unknown')
|
||||
module = self._extract_module_name(kwargs, model)
|
||||
|
||||
duration_seconds = (end_time - start_time).total_seconds() if hasattr(end_time - start_time, 'total_seconds') else float(end_time - start_time)
|
||||
|
||||
# Handle different error response formats
|
||||
error_message = "Unknown error"
|
||||
error_type = "UnknownError"
|
||||
|
||||
# According to LiteLLM docs, completion_response contains the exception for failures
|
||||
if completion_response is not None:
|
||||
error_message = str(completion_response)
|
||||
error_type = type(completion_response).__name__
|
||||
|
||||
# Also check kwargs for exception (LiteLLM passes exception in kwargs for failure events)
|
||||
elif 'exception' in kwargs:
|
||||
exception = kwargs['exception']
|
||||
error_message = str(exception)
|
||||
error_type = type(exception).__name__
|
||||
|
||||
# Check for other error formats in kwargs
|
||||
elif 'error' in kwargs:
|
||||
error = kwargs['error']
|
||||
error_message = str(error)
|
||||
error_type = type(error).__name__
|
||||
|
||||
# Check log_event_type to confirm this is a failure event
|
||||
log_event_type = kwargs.get('log_event_type', '')
|
||||
if log_event_type == 'failed_api_call' and 'exception' in kwargs:
|
||||
exception = kwargs['exception']
|
||||
error_message = str(exception)
|
||||
error_type = type(exception).__name__
|
||||
|
||||
error_record = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"model": model,
|
||||
"module": module,
|
||||
"error": error_message,
|
||||
"error_type": error_type,
|
||||
"duration_seconds": duration_seconds,
|
||||
"status": "failed"
|
||||
}
|
||||
|
||||
self.error_data.append(error_record)
|
||||
|
||||
# Update module-specific stats for error tracking
|
||||
self._update_module_stats(module, error_record, success=False)
|
||||
|
||||
# Print error feedback
|
||||
print(f"✗ {model}: {error_type} - {error_message[:100]}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Warning: Failure callback failed: {e}")
|
||||
# Debug: print the actual parameters to understand the structure
|
||||
print(f"Debug - kwargs keys: {list(kwargs.keys()) if kwargs else 'None'}")
|
||||
print(f"Debug - completion_response type: {type(completion_response)}")
|
||||
print(f"Debug - completion_response: {completion_response}")
|
||||
|
||||
def _should_rate_limit(self, model: str) -> bool:
|
||||
"""Check if the model should be rate limited"""
|
||||
if not self.rate_limiting_enabled:
|
||||
return False
|
||||
return model.startswith('bedrock/') or 'bedrock' in model.lower()
|
||||
|
||||
def _enforce_rate_limit(self, model: str):
|
||||
"""Enforce rate limiting for AWS Bedrock models"""
|
||||
if not self._should_rate_limit(model):
|
||||
return
|
||||
|
||||
provider = 'bedrock'
|
||||
if provider not in self.rate_limits:
|
||||
return
|
||||
|
||||
rate_config = self.rate_limits[provider]
|
||||
|
||||
with rate_config['lock']:
|
||||
current_time = time.time()
|
||||
time_since_last = current_time - rate_config['last_request_time']
|
||||
min_interval = 1.0 / rate_config['requests_per_second']
|
||||
|
||||
if time_since_last < min_interval:
|
||||
sleep_time = min_interval - time_since_last
|
||||
print(f"⏳ Rate limiting: sleeping {sleep_time:.2f}s for {model}")
|
||||
time.sleep(sleep_time)
|
||||
|
||||
rate_config['last_request_time'] = time.time()
|
||||
|
||||
def _extract_module_name(self, kwargs: Dict[str, Any], model: str) -> str:
|
||||
"""Extract module name from request context"""
|
||||
# Try to get module from metadata
|
||||
metadata = kwargs.get('metadata', {})
|
||||
if 'module' in metadata:
|
||||
return metadata['module']
|
||||
|
||||
# Try to infer from model name or other context
|
||||
if 'claude' in model.lower():
|
||||
return 'bedrock_client'
|
||||
elif 'gpt' in model.lower() or 'openai' in model.lower():
|
||||
return 'openai_client'
|
||||
elif 'embed' in model.lower():
|
||||
return 'embedder'
|
||||
else:
|
||||
return 'unknown'
|
||||
|
||||
def _update_module_stats(self, module: str, record: Dict[str, Any], success: bool):
|
||||
"""Update module-specific statistics with instant QPS tracking"""
|
||||
current_timestamp = time.time()
|
||||
current_time = datetime.now()
|
||||
|
||||
# Initialize module stats if first request
|
||||
if self.module_stats[module]['start_time'] is None:
|
||||
self.module_stats[module]['start_time'] = current_time
|
||||
|
||||
# Update counters
|
||||
self.module_stats[module]['requests'] += 1
|
||||
self.module_stats[module]['last_request_time'] = current_time
|
||||
self.module_stats[module]['request_timestamps'].append(current_timestamp)
|
||||
self.global_request_timestamps.append(current_timestamp)
|
||||
|
||||
# Calculate instant QPS for this module
|
||||
self._calculate_instant_qps(module, current_timestamp)
|
||||
|
||||
# Calculate global instant QPS
|
||||
self._calculate_global_instant_qps(current_timestamp)
|
||||
|
||||
if success:
|
||||
self.module_stats[module]['tokens_in'] += record.get('input_tokens', 0)
|
||||
self.module_stats[module]['tokens_out'] += record.get('output_tokens', 0)
|
||||
self.module_stats[module]['cost'] += record.get('cost', 0.0)
|
||||
else:
|
||||
self.module_stats[module]['errors'] += 1
|
||||
|
||||
def _calculate_instant_qps(self, module: str, current_timestamp: float):
|
||||
"""Calculate instant QPS for a specific module using sliding window"""
|
||||
# Keep only timestamps from last 1 second for instant QPS
|
||||
cutoff_time = current_timestamp - 1.0
|
||||
timestamps = self.module_stats[module]['request_timestamps']
|
||||
|
||||
# Remove old timestamps
|
||||
self.module_stats[module]['request_timestamps'] = [
|
||||
ts for ts in timestamps if ts >= cutoff_time
|
||||
]
|
||||
|
||||
# Calculate current QPS (requests in last second)
|
||||
current_qps = len(self.module_stats[module]['request_timestamps'])
|
||||
self.module_stats[module]['current_qps'] = current_qps
|
||||
|
||||
# Update max QPS if current is higher
|
||||
if current_qps > self.module_stats[module]['max_qps']:
|
||||
self.module_stats[module]['max_qps'] = current_qps
|
||||
|
||||
# Store QPS history (keep last 60 measurements)
|
||||
self.module_stats[module]['qps_history'].append(current_qps)
|
||||
if len(self.module_stats[module]['qps_history']) > 60:
|
||||
self.module_stats[module]['qps_history'].pop(0)
|
||||
|
||||
def _calculate_global_instant_qps(self, current_timestamp: float):
|
||||
"""Calculate global instant QPS across all modules"""
|
||||
# Keep only timestamps from last 1 second
|
||||
cutoff_time = current_timestamp - 1.0
|
||||
self.global_request_timestamps = [
|
||||
ts for ts in self.global_request_timestamps if ts >= cutoff_time
|
||||
]
|
||||
|
||||
# Calculate current global QPS
|
||||
current_global_qps = len(self.global_request_timestamps)
|
||||
|
||||
# Update max global QPS
|
||||
if current_global_qps > self.global_max_qps:
|
||||
self.global_max_qps = current_global_qps
|
||||
|
||||
def get_instant_qps(self, module: str = None) -> Dict[str, Any]:
|
||||
"""Get instant QPS data for modules"""
|
||||
if module:
|
||||
if module in self.module_stats:
|
||||
return {
|
||||
'module': module,
|
||||
'current_qps': self.module_stats[module]['current_qps'],
|
||||
'max_qps': self.module_stats[module]['max_qps'],
|
||||
'avg_qps_last_minute': sum(self.module_stats[module]['qps_history'][-60:]) / min(60, len(self.module_stats[module]['qps_history'])) if self.module_stats[module]['qps_history'] else 0
|
||||
}
|
||||
else:
|
||||
return {'module': module, 'current_qps': 0, 'max_qps': 0, 'avg_qps_last_minute': 0}
|
||||
else:
|
||||
# Return data for all modules plus global
|
||||
result = {
|
||||
'global': {
|
||||
'current_qps': len([ts for ts in self.global_request_timestamps if ts >= time.time() - 1.0]),
|
||||
'max_qps': self.global_max_qps
|
||||
},
|
||||
'modules': {}
|
||||
}
|
||||
|
||||
for mod in self.module_stats:
|
||||
result['modules'][mod] = {
|
||||
'current_qps': self.module_stats[mod]['current_qps'],
|
||||
'max_qps': self.module_stats[mod]['max_qps'],
|
||||
'avg_qps_last_minute': sum(self.module_stats[mod]['qps_history'][-60:]) / min(60, len(self.module_stats[mod]['qps_history'])) if self.module_stats[mod]['qps_history'] else 0
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def get_usage_summary(self) -> Dict[str, Any]:
|
||||
"""Get essential usage statistics"""
|
||||
if not self.usage_data:
|
||||
return {
|
||||
"total_requests": 0,
|
||||
"total_cost": 0.0,
|
||||
"error_rate": 0.0,
|
||||
"message": "No usage data available"
|
||||
}
|
||||
|
||||
total_requests = len(self.usage_data)
|
||||
total_errors = len(self.error_data)
|
||||
total_cost = sum(record['cost'] for record in self.usage_data)
|
||||
total_input_tokens = sum(record['input_tokens'] for record in self.usage_data)
|
||||
total_output_tokens = sum(record['output_tokens'] for record in self.usage_data)
|
||||
|
||||
# Calculate session duration
|
||||
duration_minutes = (datetime.now() - self.start_time).total_seconds() / 60
|
||||
|
||||
# Build module statistics
|
||||
module_stats = {}
|
||||
for module, stats in self.module_stats.items():
|
||||
if stats['requests'] > 0:
|
||||
module_stats[module] = {
|
||||
"requests": stats['requests'],
|
||||
"errors": stats['errors'],
|
||||
"success_rate": ((stats['requests'] - stats['errors']) / stats['requests'] * 100) if stats['requests'] > 0 else 0,
|
||||
"tokens_in": stats['tokens_in'],
|
||||
"tokens_out": stats['tokens_out'],
|
||||
"cost": stats['cost'],
|
||||
"current_qps": stats['current_qps'],
|
||||
"max_qps": stats['max_qps']
|
||||
}
|
||||
|
||||
return {
|
||||
"session_duration_minutes": duration_minutes,
|
||||
"total_requests": total_requests,
|
||||
"total_errors": total_errors,
|
||||
"error_rate": (total_errors / total_requests * 100) if total_requests > 0 else 0,
|
||||
"total_input_tokens": total_input_tokens,
|
||||
"total_output_tokens": total_output_tokens,
|
||||
"total_cost": total_cost,
|
||||
"module_stats": module_stats,
|
||||
"global_max_qps": self.global_max_qps
|
||||
}
|
||||
|
||||
def print_usage_summary(self):
|
||||
"""Print essential usage summary"""
|
||||
stats = self.get_usage_summary()
|
||||
|
||||
if stats.get('message'):
|
||||
print(f"📊 {stats['message']}")
|
||||
return
|
||||
|
||||
print("\n📊 USAGE SUMMARY")
|
||||
print(f"{'='*50}")
|
||||
print(f"⏱️ Duration: {stats['session_duration_minutes']:.1f} min")
|
||||
print(f"📈 Requests: {stats['total_requests']}")
|
||||
print(f"❌ Errors: {stats['total_errors']}")
|
||||
print(f"💰 Cost: ${stats['total_cost']:.4f}")
|
||||
print(f"🏆 Global Max QPS: {stats['global_max_qps']}")
|
||||
|
||||
# Module statistics
|
||||
if stats.get('module_stats'):
|
||||
print("\n📦 MODULES:")
|
||||
for module, mod_stats in stats['module_stats'].items():
|
||||
print(f" {module}: {mod_stats['requests']} req, Max QPS: {mod_stats['max_qps']}, Current: {mod_stats['current_qps']}")
|
||||
|
||||
print(f"{'='*50}")
|
||||
|
||||
def save_usage_data(self, filename: str = "litellm_usage.json"):
|
||||
"""Save usage data to JSON file"""
|
||||
data = {
|
||||
"summary": self.get_usage_summary(),
|
||||
"detailed_usage": self.usage_data,
|
||||
"errors": self.error_data,
|
||||
"export_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
with open(filename, 'w') as f:
|
||||
json.dump(data, f, indent=2)
|
||||
|
||||
print(f"📁 Usage data saved to {filename}")
|
||||
|
||||
def reset_tracking(self):
|
||||
"""Reset all tracking data"""
|
||||
self.usage_data = []
|
||||
self.error_data = []
|
||||
self.module_stats = defaultdict(lambda: {
|
||||
'requests': 0,
|
||||
'tokens_in': 0,
|
||||
'tokens_out': 0,
|
||||
'cost': 0.0,
|
||||
'errors': 0,
|
||||
'start_time': None,
|
||||
'last_request_time': None,
|
||||
'request_timestamps': [],
|
||||
'current_qps': 0.0,
|
||||
'max_qps': 0.0,
|
||||
'qps_history': []
|
||||
})
|
||||
self.global_request_timestamps = []
|
||||
self.global_max_qps = 0.0
|
||||
self.start_time = datetime.now()
|
||||
print("🔄 All tracking data reset")
|
||||
|
||||
# Global instance for easy access
|
||||
litellm_config = LiteLLMConfig()
|
||||
|
||||
def setup_litellm_enhanced(max_retries: int = 3):
|
||||
"""
|
||||
Quick setup function for LiteLLM enhanced configuration
|
||||
|
||||
Args:
|
||||
max_retries: Maximum number of retries for failed requests
|
||||
"""
|
||||
litellm_config.setup_enhanced_config(max_retries)
|
||||
return litellm_config
|
||||
|
||||
def get_usage_summary():
|
||||
"""Get current usage summary"""
|
||||
return litellm_config.get_usage_summary()
|
||||
|
||||
def print_usage_summary():
|
||||
"""Print current usage summary"""
|
||||
litellm_config.print_usage_summary()
|
||||
|
||||
def save_usage_data(filename: str = "litellm_usage.json"):
|
||||
"""Save usage data to file"""
|
||||
litellm_config.save_usage_data(filename)
|
||||
|
||||
def get_instant_qps(module: str = None) -> Dict[str, Any]:
|
||||
"""Get instant QPS data for modules"""
|
||||
return litellm_config.get_instant_qps(module)
|
||||
|
||||
def print_instant_qps(module: str = None):
|
||||
"""Print instant QPS information"""
|
||||
qps_data = get_instant_qps(module)
|
||||
|
||||
print("\n⚡ INSTANT QPS MONITOR")
|
||||
print(f"{'='*60}")
|
||||
|
||||
if module:
|
||||
print(f"Module: {qps_data['module']}")
|
||||
print(f" Current QPS: {qps_data['current_qps']}")
|
||||
print(f" Max QPS: {qps_data['max_qps']}")
|
||||
print(f" Avg (1min): {qps_data['avg_qps_last_minute']:.2f}")
|
||||
else:
|
||||
# Global stats
|
||||
global_data = qps_data.get('global', {})
|
||||
print("🌍 GLOBAL:")
|
||||
print(f" Current QPS: {global_data.get('current_qps', 0)}")
|
||||
print(f" Max QPS: {global_data.get('max_qps', 0)}")
|
||||
|
||||
# Module stats
|
||||
modules = qps_data.get('modules', {})
|
||||
if modules:
|
||||
print("\n📦 MODULES:")
|
||||
for mod, data in modules.items():
|
||||
print(f" {mod}:")
|
||||
print(f" Current: {data['current_qps']} QPS")
|
||||
print(f" Max: {data['max_qps']} QPS")
|
||||
print(f" Avg: {data['avg_qps_last_minute']:.2f} QPS")
|
||||
|
||||
print(f"{'='*60}")
|
||||
|
||||
def reset_tracking():
|
||||
"""Reset all tracking data"""
|
||||
litellm_config.reset_tracking()
|
||||
|
||||
def get_module_stats() -> Dict[str, Dict[str, Any]]:
|
||||
"""Get detailed module statistics"""
|
||||
summary = get_usage_summary()
|
||||
return summary.get('module_stats', {})
|
||||
@@ -24,7 +24,8 @@
|
||||
- **身份冲突**: 同一实体被赋予不同类型或角色
|
||||
- **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候
|
||||
### 混合冲突
|
||||
检测所有逻辑不一致或相互矛盾的记录。
|
||||
- 检测所有逻辑不一致或相互矛盾的记录。
|
||||
- **隐私审核**: 存在隐私信息也作为冲突输出当{{ memory_verify }}是true的时候
|
||||
**检测原则**:
|
||||
- 重点检查相同实体的记录
|
||||
- 分析description字段语义冲突
|
||||
|
||||
@@ -63,7 +63,7 @@
|
||||
**脱敏字段**: name、entity1_name、entity2_name、description、relationship
|
||||
|
||||
## 4. 处理流程
|
||||
|
||||
###如果存在冲突数据执行以下步骤,不存在返回【】在data中
|
||||
### 步骤1: 类型匹配验证
|
||||
**匹配规则**:
|
||||
- baseline="TIME": 只处理时间相关冲突(涉及时间表达式、日期、时间点)
|
||||
@@ -78,7 +78,7 @@
|
||||
|
||||
### 步骤2: 冲突数据分组
|
||||
**分组策略**:
|
||||
- 时间冲突组: 涉及用户时间的记录
|
||||
- 时间冲突组: 涉及用户时间的记录比如(生日在2月17...)
|
||||
- 活动时间冲突组: 同一活动不同时间的记录
|
||||
- 事实冲突组: 同一实体不同属性的记录
|
||||
- 其他冲突组: 其他类型冲突记录
|
||||
@@ -97,11 +97,12 @@
|
||||
### 处理规则
|
||||
|
||||
** baseline是TIME
|
||||
-保留正确记录不变修改错误记录的expired_at为当前时间(2025-12-16T12:00:00),以及name需要修改成正确的
|
||||
** baseline不是TIME
|
||||
- 只处理时间相关的内容,比如时间表达式、日期、时间点
|
||||
-保留正确记录不变修改错误记录的expired_at为当前时间,比如(2025-12-16T12:00:00)
|
||||
** baseline是FACT或者HYBRID
|
||||
- 处理不是时间相关的内容
|
||||
- 修改字段内容( name、entity1_name、entity2_name、description、relationship)字段内容是否正确,如果不正确,需要对这些字段的内容重新生成,则不需要修改expired_at字段,
|
||||
如果涉及到修改entity1_name/entity2_name字段的时候,同时也需要修改description字段,输出修改前和修改后的放入change里面的field
|
||||
|
||||
**核心原则**:
|
||||
- 只输出需要修改的记录
|
||||
- 优先保留策略: 时间冲突保留最可信created_at时间,事实冲突选择最新且可信度最高记录
|
||||
@@ -110,22 +111,26 @@
|
||||
- 脱敏变更记录: 隐私脱敏变更也必须在change字段中记录{% endif %}
|
||||
- 不可修改数据: 数据被判定为正确时不可修改,无数据可输出时为空
|
||||
- 输出的结果reflexion字段中的reason字段和solution不允许含有(expired_at设为2024-01-01T00:00:00Z、memory_verify=true、memory_verify=false)等原数据字段以及涉及需要修改的字段以及内容,
|
||||
,如果是FACT,只记录事实冲突相关的数据;如果是TIME,只记录时间冲突相关的数据;如果是HYBRID,则记录所有冲突相关的数据
|
||||
,如果是FACT,只记录事实冲突相关的数据;如果是TIME,只记录时间冲突相关的数据;如果是HYBRID,则记录所有冲突相关的数据,如果存在隐私审核,隐私审核是true,也需要放到reflexion的reason字段和solution
|
||||
|
||||
**变更记录格式**:
|
||||
```json
|
||||
"change": [
|
||||
{
|
||||
"field": [
|
||||
{"id":修改字段对应的ID}
|
||||
{"statement_id":需要修改的对象对应的statement_id}
|
||||
{"字段名1": ["修改前的值1","修改后的值1"]},
|
||||
{"字段名2": ["修改前的值2","修改后的值2"]}
|
||||
{"id": "修改字段对应的ID"},
|
||||
{"字段名1": ["修改前的值1", "修改后的值1"]},
|
||||
{"字段名2": ["修改前的值2", "修改后的值2"]}
|
||||
]
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
**resolved_memory格式说明**:
|
||||
- 对于TIME类型冲突: 只需expired_at字段即可
|
||||
- 对于FACT/HYBRID类型冲突: 需要包含完整的记录对象(包括name、entity1_name、entity2_name、description、relationship等所有相关字段)
|
||||
- resolved_memory中只包含需要修改的记录,不需要修改的记录不要包含在内
|
||||
|
||||
**类型不匹配处理**:
|
||||
- 冲突类型与baseline不匹配时,resolved设为null
|
||||
- reflexion.reason说明类型不匹配原因
|
||||
@@ -157,7 +162,8 @@
|
||||
"conflict": true
|
||||
},
|
||||
"reflexion": {
|
||||
"reason": "该冲突类型的原因分析,如果是FACT就是存在事实冲突,分析该冲突原因,如果是TIME就是存在时间冲突,分析该冲突原因,如果是HYBRID,可以输出存在时间与事实的混合冲突再添加上原因分析,
|
||||
"reason": "该冲突类型的原因分析,如果是FACT就是存在事实冲突,分析该冲突原因,如果是TIME就是存在时间冲突,分析该冲突原因,如果是HYBRID,可以输出存在时间与事实的混合冲突再添加上原因分析,如果
|
||||
隐私审核打开的时候如果存在冲突,分析该冲突的原因
|
||||
不可以随意分配冲突类型以及原因,不允许输出字段比如(statement、description、entity1_name、entity2_name、name、memory_verify、expired_at、conflict)等类似这种",
|
||||
"solution": "该冲突类型的解决方案(不允许输出字段比如(statement、description、entity1_name、entity2_name、name、memory_verify、expired_at、conflict)等类似这种)"
|
||||
},
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""自我反思工具模块
|
||||
|
||||
本模块提供自我反思引擎的核心功能,包括:
|
||||
- 记忆冲突判定
|
||||
- 反思执行
|
||||
- 记忆更新
|
||||
|
||||
从 app.core.memory.src.data_config_api 迁移而来。
|
||||
"""
|
||||
|
||||
from app.core.memory.utils.self_reflexion_utils.evaluate import conflict
|
||||
from app.core.memory.utils.self_reflexion_utils.reflexion import reflexion
|
||||
from app.core.memory.utils.self_reflexion_utils.self_reflexion import self_reflexion
|
||||
|
||||
__all__ = ["conflict", "reflexion", "self_reflexion"]
|
||||
@@ -1,52 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""记忆冲突判定模块
|
||||
|
||||
本模块提供记忆冲突判定功能,使用LLM判断记忆数据中是否存在冲突。
|
||||
从 app.core.memory.src.data_config_api.evaluate 迁移而来。
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, List
|
||||
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.prompt.template_render import render_evaluate_prompt
|
||||
from app.db import get_db_context
|
||||
from app.schemas.memory_storage_schema import ConflictResultSchema
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
async def conflict(evaluate_data: List[Any]) -> List[Any]:
|
||||
"""
|
||||
Evaluates memory conflict using the evaluate.jinja2 template.
|
||||
|
||||
Args:
|
||||
evaluate_data: 反思数据列表。
|
||||
Returns:
|
||||
冲突记忆列表(JSON 数组)。
|
||||
"""
|
||||
from app.core.memory.utils.config import definitions as config_defs
|
||||
with get_db_context() as db:
|
||||
factory = MemoryClientFactory(db)
|
||||
client = factory.get_llm_client(config_defs.SELECTED_LLM_ID)
|
||||
rendered_prompt = await render_evaluate_prompt(evaluate_data, ConflictResultSchema)
|
||||
messages = [{"role": "user", "content": rendered_prompt}]
|
||||
print(f"提示词长度: {len(rendered_prompt)}")
|
||||
print(f"====== 冲突判定开始 ======\n")
|
||||
start_time = time.time()
|
||||
response = await client.response_structured(messages, ConflictResultSchema)
|
||||
end_time = time.time()
|
||||
print(f"冲突判定耗时: {end_time - start_time} 秒")
|
||||
print(f"冲突判定原始输出:(type={type(response)})\n{response}")
|
||||
|
||||
if not response:
|
||||
logging.error("LLM 冲突判定输出解析失败,返回空列表以继续流程。")
|
||||
return []
|
||||
try:
|
||||
return [response.model_dump()] if isinstance(response, BaseModel) else [response]
|
||||
except Exception:
|
||||
try:
|
||||
return [response.dict()]
|
||||
except Exception:
|
||||
logging.warning("无法标准化冲突判定返回类型,尝试直接封装为列表。")
|
||||
return [response]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user