Compare commits

..

1 Commits

Author SHA1 Message Date
yingzhao
543be4d610 Merge pull request #141 from SuanmoSuanyangTechnology/fix/web_zy
Fix/web zy
2026-01-17 11:43:58 +08:00
974 changed files with 33486 additions and 108532 deletions

4
.gitignore vendored
View File

@@ -21,7 +21,6 @@ examples/
# Temporary outputs
.DS_Store
.hypothesis/
time.log
celerybeat-schedule.db
search_results.json
@@ -29,7 +28,6 @@ search_results.json
api/migrations/versions
tmp
files
powers/
# Exclude dep files
huggingface.co/
@@ -37,5 +35,3 @@ nltk_data/
tika-server*.jar*
cl100k_base.tiktoken
libssl*.deb
sandbox/lib/seccomp_redbear/target

View File

@@ -226,8 +226,8 @@ REDIS_PORT=6379
REDIS_DB=1
# Celery (Using Redis as broker)
REDIS_DB_CELERY_BROKER=1
REDIS_DB_CELERY_BACKEND=2
BROKER_URL=redis://127.0.0.1:6379/0
RESULT_BACKEND=redis://127.0.0.1:6379/0
# JWT Secret Key (Formation method: openssl rand -hex 32)
SECRET_KEY=your-secret-key-here
@@ -334,13 +334,7 @@ step6: Log In to the Frontend Interface.
## License
This project is licensed under the Apache License 2.0. For details, see the LICENSE file.
## Community & Support
Join our community to ask questions, share your work, and connect with fellow developers.
- **GitHub Issues**: Report bugs, request features, or track known issues via [GitHub Issues](https://github.com/SuanmoSuanyangTechnology/MemoryBear/issues).
- **GitHub Pull Requests**: Contribute code improvements or fixes through [Pull Requests](https://github.com/SuanmoSuanyangTechnology/MemoryBear/pulls).
- **GitHub Discussions**: Ask questions, share ideas, and engage with the community in [GitHub Discussions](https://github.com/SuanmoSuanyangTechnology/MemoryBear/discussions).
- **WeChat**: Scan the QR code below to join our WeChat community group.
- ![wecom-temp-114020-47fe87a75da439f09f5dc93a01593046](https://github.com/user-attachments/assets/8c81885c-4134-40d5-96e2-7f78cc082dc6)
- **Contact**: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com
## Acknowledgements & Community
- Feedback & Issues: Please submit an Issue in the repository for bug reports or discussions.
- Contributions Welcome: When submitting a Pull Request, please create a feature branch and follow conventional commit message guidelines.
- Contact: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com

View File

@@ -201,8 +201,8 @@ REDIS_PORT=6379
REDIS_DB=1
# Celery (使用Redis作为broker)
REDIS_DB_CELERY_BROKER=1
REDIS_DB_CELERY_BACKEND=2
BROKER_URL=redis://127.0.0.1:6379/0
RESULT_BACKEND=redis://127.0.0.1:6379/0
# JWT密钥 (生成方式: openssl rand -hex 32)
SECRET_KEY=your-secret-key-here

View File

@@ -10,6 +10,7 @@ from app.core.config import settings
# 设置日志记录器
logger = logging.getLogger(__name__)
# 创建连接池
pool = ConnectionPool.from_url(
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}",
@@ -20,7 +21,6 @@ pool = ConnectionPool.from_url(
)
aio_redis = redis.StrictRedis(connection_pool=pool)
async def get_redis_connection():
"""获取Redis连接"""
try:
@@ -29,8 +29,7 @@ async def get_redis_connection():
logger.error(f"Redis连接失败: {str(e)}")
return None
async def aio_redis_set(key: str, val: str | dict, expire: int = None):
async def aio_redis_set(key: str, val: str|dict, expire: int = None):
"""设置Redis键值
Args:
@@ -41,7 +40,7 @@ async def aio_redis_set(key: str, val: str | dict, expire: int = None):
try:
if isinstance(val, dict):
val = json.dumps(val, ensure_ascii=False)
if expire is not None:
# 设置带过期时间的键值
await aio_redis.set(key, val, ex=expire)
@@ -51,7 +50,6 @@ async def aio_redis_set(key: str, val: str | dict, expire: int = None):
except Exception as e:
logger.error(f"Redis set错误: {str(e)}")
async def aio_redis_get(key: str):
"""获取Redis键值"""
try:
@@ -60,7 +58,6 @@ async def aio_redis_get(key: str):
logger.error(f"Redis get错误: {str(e)}")
return None
async def aio_redis_delete(key: str):
"""删除Redis键"""
try:
@@ -69,7 +66,6 @@ async def aio_redis_delete(key: str):
logger.error(f"Redis delete错误: {str(e)}")
return None
async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
"""发布消息到Redis频道"""
try:
@@ -82,10 +78,9 @@ async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
logger.error(f"Redis发布错误: {str(e)}")
return False
class RedisSubscriber:
"""Redis订阅器"""
def __init__(self, channel: str):
self.channel = channel
self.conn = None
@@ -93,25 +88,25 @@ class RedisSubscriber:
self.is_closed = False
self._queue = asyncio.Queue()
self._task = None
async def start(self):
"""开始订阅"""
if self.is_closed or self._task:
return
self._task = asyncio.create_task(self._receive_messages())
logger.info(f"开始订阅: {self.channel}")
async def _receive_messages(self):
"""接收消息"""
try:
self.conn = await get_redis_connection()
if not self.conn:
return
self.pubsub = self.conn.pubsub()
await self.pubsub.subscribe(self.channel)
while not self.is_closed:
try:
message = await self.pubsub.get_message(ignore_subscribe_messages=True, timeout=0.01)
@@ -132,7 +127,7 @@ class RedisSubscriber:
finally:
await self._queue.put(None)
await self._cleanup()
async def _cleanup(self):
"""清理资源"""
if self.pubsub:
@@ -146,7 +141,7 @@ class RedisSubscriber:
await self.conn.close()
except Exception:
pass
async def get_message(self) -> Optional[Dict[str, Any]]:
"""获取消息"""
if self.is_closed:
@@ -158,7 +153,7 @@ class RedisSubscriber:
except Exception as e:
logger.error(f"获取消息错误: {str(e)}")
return None
async def close(self):
"""关闭订阅器"""
if self.is_closed:
@@ -168,33 +163,32 @@ class RedisSubscriber:
self._task.cancel()
await self._cleanup()
class RedisPubSubManager:
"""Redis发布订阅管理器"""
def __init__(self):
self.subscribers = {}
async def publish(self, channel: str, message: Dict[str, Any]) -> bool:
return await aio_redis_publish(channel, message)
def get_subscriber(self, channel: str) -> RedisSubscriber:
if channel in self.subscribers:
subscriber = self.subscribers[channel]
if not subscriber.is_closed:
return subscriber
subscriber = RedisSubscriber(channel)
self.subscribers[channel] = subscriber
return subscriber
def cancel_subscription(self, channel: str) -> bool:
if channel in self.subscribers:
asyncio.create_task(self.subscribers[channel].close())
del self.subscribers[channel]
return True
return False
def cancel_all_subscriptions(self) -> int:
count = len(self.subscribers)
for subscriber in self.subscribers.values():
@@ -202,6 +196,6 @@ class RedisPubSubManager:
self.subscribers.clear()
return count
# 全局实例
pubsub_manager = RedisPubSubManager()

View File

@@ -1,10 +0,0 @@
"""
Cache 缓存模块
提供各种缓存功能的统一入口
"""
from .memory import InterestMemoryCache
__all__ = [
"InterestMemoryCache",
]

View File

@@ -1,12 +0,0 @@
"""
Memory 缓存模块
提供记忆系统相关的缓存功能
"""
from .interest_memory import InterestMemoryCache
from .activity_stats_cache import ActivityStatsCache
__all__ = [
"InterestMemoryCache",
"ActivityStatsCache",
]

View File

@@ -1,124 +0,0 @@
"""
Recent Activity Stats Cache
记忆提取活动统计缓存模块
用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储24小时后释放
查询命令cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843
"""
import json
import logging
from typing import Optional, Dict, Any
from datetime import datetime
from app.aioRedis import aio_redis
logger = logging.getLogger(__name__)
# 缓存过期时间24小时
ACTIVITY_STATS_CACHE_EXPIRE = 86400
class ActivityStatsCache:
"""记忆提取活动统计缓存类"""
PREFIX = "cache:memory:activity_stats"
@classmethod
def _get_key(cls, workspace_id: str) -> str:
"""生成 Redis key
Args:
workspace_id: 工作空间ID
Returns:
完整的 Redis key
"""
return f"{cls.PREFIX}:by_workspace:{workspace_id}"
@classmethod
async def set_activity_stats(
cls,
workspace_id: str,
stats: Dict[str, Any],
expire: int = ACTIVITY_STATS_CACHE_EXPIRE,
) -> bool:
"""设置记忆提取活动统计缓存
Args:
workspace_id: 工作空间ID
stats: 统计数据,格式:
{
"chunk_count": int,
"statements_count": int,
"triplet_entities_count": int,
"triplet_relations_count": int,
"temporal_count": int,
}
expire: 过期时间默认24小时
Returns:
是否设置成功
"""
try:
key = cls._get_key(workspace_id)
payload = {
"stats": stats,
"generated_at": datetime.now().isoformat(),
"workspace_id": workspace_id,
"cached": True,
}
value = json.dumps(payload, ensure_ascii=False)
await aio_redis.set(key, value, ex=expire)
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}")
return True
except Exception as e:
logger.error(f"设置活动统计缓存失败: {e}", exc_info=True)
return False
@classmethod
async def get_activity_stats(
cls,
workspace_id: str,
) -> Optional[Dict[str, Any]]:
"""获取记忆提取活动统计缓存
Args:
workspace_id: 工作空间ID
Returns:
统计数据字典,缓存不存在或已过期返回 None
"""
try:
key = cls._get_key(workspace_id)
value = await aio_redis.get(key)
if value:
payload = json.loads(value)
logger.info(f"命中活动统计缓存: {key}")
return payload
logger.info(f"活动统计缓存不存在或已过期: {key}")
return None
except Exception as e:
logger.error(f"获取活动统计缓存失败: {e}", exc_info=True)
return None
@classmethod
async def delete_activity_stats(
cls,
workspace_id: str,
) -> bool:
"""删除记忆提取活动统计缓存
Args:
workspace_id: 工作空间ID
Returns:
是否删除成功
"""
try:
key = cls._get_key(workspace_id)
result = await aio_redis.delete(key)
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
return result > 0
except Exception as e:
logger.error(f"删除活动统计缓存失败: {e}", exc_info=True)
return False

View File

@@ -1,122 +0,0 @@
"""
Interest Distribution Cache
兴趣分布缓存模块
用于缓存用户的兴趣分布标签数据,避免重复调用模型生成
"""
import json
import logging
from typing import Optional, List, Dict, Any
from datetime import datetime
from app.aioRedis import aio_redis
logger = logging.getLogger(__name__)
# 缓存过期时间24小时
INTEREST_CACHE_EXPIRE = 86400
class InterestMemoryCache:
"""兴趣分布缓存类"""
PREFIX = "cache:memory:interest_distribution"
@classmethod
def _get_key(cls, end_user_id: str, language: str) -> str:
"""生成 Redis key
Args:
end_user_id: 用户ID
language: 语言类型
Returns:
完整的 Redis key
"""
return f"{cls.PREFIX}:by_user:{end_user_id}:{language}"
@classmethod
async def set_interest_distribution(
cls,
end_user_id: str,
language: str,
data: List[Dict[str, Any]],
expire: int = INTEREST_CACHE_EXPIRE,
) -> bool:
"""设置用户兴趣分布缓存
Args:
end_user_id: 用户ID
language: 语言类型
data: 兴趣分布列表,格式 [{"name": "...", "frequency": ...}, ...]
expire: 过期时间默认24小时
Returns:
是否设置成功
"""
try:
key = cls._get_key(end_user_id, language)
payload = {
"data": data,
"generated_at": datetime.now().isoformat(),
"cached": True,
}
value = json.dumps(payload, ensure_ascii=False)
await aio_redis.set(key, value, ex=expire)
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}")
return True
except Exception as e:
logger.error(f"设置兴趣分布缓存失败: {e}", exc_info=True)
return False
@classmethod
async def get_interest_distribution(
cls,
end_user_id: str,
language: str,
) -> Optional[List[Dict[str, Any]]]:
"""获取用户兴趣分布缓存
Args:
end_user_id: 用户ID
language: 语言类型
Returns:
兴趣分布列表,缓存不存在或已过期返回 None
"""
try:
key = cls._get_key(end_user_id, language)
value = await aio_redis.get(key)
if value:
payload = json.loads(value)
logger.info(f"命中兴趣分布缓存: {key}")
return payload.get("data")
logger.info(f"兴趣分布缓存不存在或已过期: {key}")
return None
except Exception as e:
logger.error(f"获取兴趣分布缓存失败: {e}", exc_info=True)
return None
@classmethod
async def delete_interest_distribution(
cls,
end_user_id: str,
language: str,
) -> bool:
"""删除用户兴趣分布缓存
Args:
end_user_id: 用户ID
language: 语言类型
Returns:
是否删除成功
"""
try:
key = cls._get_key(end_user_id, language)
result = await aio_redis.delete(key)
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
return result > 0
except Exception as e:
logger.error(f"删除兴趣分布缓存失败: {e}", exc_info=True)
return False

View File

@@ -1,60 +1,40 @@
import os
import platform
from datetime import timedelta
from urllib.parse import quote
from celery import Celery
from celery.schedules import crontab
from app.core.config import settings
from app.core.logging_config import get_logger
logger = get_logger(__name__)
# macOS fork() safety - must be set before any Celery initialization
if platform.system() == 'Darwin':
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
from celery import Celery
# 创建 Celery 应用实例
# broker: 任务队列(使用 Redis DB,由 CELERY_BROKER_DB 指定
# backend: 结果存储(使用 Redis DB,由 CELERY_BACKEND_DB 指定
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
# Build canonical broker/backend URLs and force them into os.environ so that
# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first)
# cannot be overridden by stray env vars.
# See: https://github.com/celery/celery/issues/4284
_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
os.environ["CELERY_BROKER_URL"] = _broker_url
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
# Neutralize legacy Celery env vars that can be hijacked by Celery's CLI/Click
# integration and accidentally override our canonical URLs.
os.environ.pop("BROKER_URL", None)
os.environ.pop("RESULT_BACKEND", None)
os.environ.pop("CELERY_BROKER", None)
os.environ.pop("CELERY_BACKEND", None)
# broker: 任务队列(使用 Redis DB 0
# backend: 结果存储(使用 Redis DB 10
celery_app = Celery(
"redbear_tasks",
broker=_broker_url,
backend=_backend_url,
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}",
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
)
logger.info(
"Celery app initialized",
extra={
"broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"),
"backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"),
},
)
# Default queue for unrouted tasks
celery_app.conf.task_default_queue = 'memory_tasks'
# 配置使用本地队列,避免与远程 worker 冲突
celery_app.conf.task_default_queue = 'localhost_test_wyl'
celery_app.conf.task_default_exchange = 'localhost_test_wyl'
celery_app.conf.task_default_routing_key = 'localhost_test_wyl'
# macOS 兼容性配置
if platform.system() == 'Darwin':
import platform
if platform.system() == 'Darwin': # macOS
# 设置环境变量解决 fork 问题
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
# 使用 solo 池避免多进程问题
celery_app.conf.worker_pool = 'solo'
# 设置唯一的节点名称
import socket
import time
hostname = socket.gethostname()
timestamp = int(time.time())
celery_app.conf.worker_name = f"celery@{hostname}-{timestamp}"
# Celery 配置
celery_app.conf.update(
@@ -72,65 +52,49 @@ celery_app.conf.update(
task_ignore_result=False,
# 超时设置
task_time_limit=3600, # 60分钟硬超时
task_soft_time_limit=3000, # 50分钟软超时
task_time_limit=30 * 60, # 30 分钟硬超时
task_soft_time_limit=25 * 60, # 25 分钟软超时
# Worker 设置 (per-worker settings are in docker-compose command line)
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
# Worker 设置 - 针对 macOS 优化
worker_prefetch_multiplier=1, # 减少预取任务数,避免内存堆积
worker_max_tasks_per_child=10, # 大幅减少每个 worker 执行的任务数,频繁重启防止内存泄漏
worker_max_memory_per_child=200000, # 200MB 内存限制,超过后重启 worker
# 结果过期时间
result_expires=3600, # 结果保存1小时
result_expires=3600, # 结果保存 1 小时
# 任务确认设置
task_acks_late=True,
task_reject_on_worker_lost=True,
worker_disable_rate_limits=True,
task_acks_late=True, # 任务完成后才确认,避免任务丢失
worker_disable_rate_limits=True, # 禁用速率限制
# FLower setting
worker_send_task_events=True,
task_send_sent_event=True,
# task routing
task_routes={
# Memory tasks → memory_tasks queue (threads worker)
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
# Long-term storage tasks → memory_tasks queue (batched write strategies)
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
# Document tasks → document_tasks queue (prefork worker)
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
},
# 任务路由(可选,用于不同队列)
# task_routes={
# 'app.core.rag.tasks.parse_document': {'queue': 'document_processing'},
# 'app.core.memory.agent.read_message': {'queue': 'memory_processing'},
# 'app.core.memory.agent.write_message': {'queue': 'memory_processing'},
# 'tasks.process_item': {'queue': 'default'},
# },
)
# 自动发现任务模块
celery_app.autodiscover_tasks(['app'])
# Celery Beat schedule for periodic tasks
memory_increment_schedule = crontab(hour=settings.MEMORY_INCREMENT_HOUR, minute=settings.MEMORY_INCREMENT_MINUTE)
reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS)
health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
workspace_reflection_schedule = timedelta(seconds=settings.WORKSPACE_REFLECTION_INTERVAL_SECONDS)
forgetting_cycle_schedule = timedelta(hours=settings.FORGETTING_CYCLE_INTERVAL_HOURS)
implicit_emotions_update_schedule = crontab(
hour=settings.IMPLICIT_EMOTIONS_UPDATE_HOUR,
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
)
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
#构建定时任务配置
# 构建定时任务配置
beat_schedule_config = {
# "check-read-service": {
# "task": "app.core.memory.agent.health.check_read_service",
# "schedule": health_schedule,
# "args": (),
# },
"run-workspace-reflection": {
"task": "app.tasks.workspace_reflection_task",
"schedule": workspace_reflection_schedule,
@@ -148,16 +112,16 @@ beat_schedule_config = {
"config_id": None, # 使用默认配置,可以通过环境变量配置
},
},
"write-all-workspaces-memory": {
"task": "app.tasks.write_all_workspaces_memory_task",
"schedule": memory_increment_schedule,
"args": (),
},
"update-implicit-emotions-storage": {
"task": "app.tasks.update_implicit_emotions_storage",
"schedule": implicit_emotions_update_schedule,
"args": (),
},
}
# 如果配置了默认工作空间ID则添加记忆总量统计任务
if settings.DEFAULT_WORKSPACE_ID:
beat_schedule_config["write-total-memory"] = {
"task": "app.controllers.memory_storage_controller.search_all",
"schedule": memory_increment_schedule,
"kwargs": {
"workspace_id": settings.DEFAULT_WORKSPACE_ID,
},
}
celery_app.conf.beat_schedule = beat_schedule_config

View File

@@ -3,12 +3,6 @@ Celery Worker 入口点
用于启动 Celery Worker: celery -A app.celery_worker worker --loglevel=info
"""
from app.celery_app import celery_app
from app.core.logging_config import LoggingConfig, get_logger
# Initialize logging system for Celery worker
LoggingConfig.setup_logging()
logger = get_logger(__name__)
logger.info("Celery worker logging initialized")
# 导入任务模块以注册任务
import app.tasks

View File

@@ -1 +0,0 @@
"""Configuration module for application settings."""

View File

@@ -1,239 +0,0 @@
"""默认本体场景配置
本模块定义系统预设的本体场景和实体类型配置。
这些配置用于在工作空间创建时自动初始化默认场景。
支持中英文双语配置,根据用户语言偏好创建对应语言的场景。
"""
# 在线教育场景配置
ONLINE_EDUCATION_SCENE = {
"name_chinese": "在线教育",
"name_english": "Online Education",
"description_chinese": "适用于在线教育平台的本体建模,包含学生、教师、课程等核心实体类型",
"description_english": "Ontology modeling for online education platforms, including core entity types such as students, teachers, and courses",
"types": [
{
"name_chinese": "学生",
"name_english": "Student",
"description_chinese": "在教育系统中接受教育的个体,包含姓名、学号、年级、班级等属性",
"description_english": "Individuals receiving education in the education system, including attributes such as name, student ID, grade, and class"
},
{
"name_chinese": "教师",
"name_english": "Teacher",
"description_chinese": "在教育系统中提供教学服务的个体,包含姓名、工号、任教学科、职称等属性",
"description_english": "Individuals providing teaching services in the education system, including attributes such as name, employee ID, teaching subject, and title"
},
{
"name_chinese": "课程",
"name_english": "Course",
"description_chinese": "教育系统中的教学内容单元,包含课程名称、课程代码、学分、学时等属性",
"description_english": "Teaching content units in the education system, including attributes such as course name, course code, credits, and class hours"
},
{
"name_chinese": "作业",
"name_english": "Assignment",
"description_chinese": "课程中布置的学习任务,包含作业标题、截止日期、所属课程、提交状态等属性",
"description_english": "Learning tasks assigned in courses, including attributes such as assignment title, deadline, course, and submission status"
},
{
"name_chinese": "成绩",
"name_english": "Grade",
"description_chinese": "学生学习成果的评价结果,包含分数、评级、考试类型、所属课程等属性",
"description_english": "Evaluation results of student learning outcomes, including attributes such as score, rating, exam type, and course"
},
{
"name_chinese": "考试",
"name_english": "Exam",
"description_chinese": "评估学生学习成果的测试活动,包含考试名称、时间、地点、科目等属性",
"description_english": "Test activities to assess student learning outcomes, including attributes such as exam name, time, location, and subject"
},
{
"name_chinese": "教室",
"name_english": "Classroom",
"description_chinese": "进行教学活动的物理或虚拟空间,包含教室编号、容量、设备等属性",
"description_english": "Physical or virtual spaces for teaching activities, including attributes such as classroom number, capacity, and equipment"
},
{
"name_chinese": "学科",
"name_english": "Subject",
"description_chinese": "知识的分类领域,包含学科名称、代码、所属院系等属性",
"description_english": "Classification domains of knowledge, including attributes such as subject name, code, and department"
},
{
"name_chinese": "教材",
"name_english": "Textbook",
"description_chinese": "教学使用的书籍或资料包含书名、作者、出版社、ISBN等属性",
"description_english": "Books or materials used for teaching, including attributes such as title, author, publisher, and ISBN"
},
{
"name_chinese": "班级",
"name_english": "Class",
"description_chinese": "学生的组织单位,包含班级名称、年级、人数、班主任等属性",
"description_english": "Organizational units of students, including attributes such as class name, grade, number of students, and class teacher"
},
{
"name_chinese": "学期",
"name_english": "Semester",
"description_chinese": "教学时间的划分单位,包含学期名称、开始时间、结束时间等属性",
"description_english": "Time division units for teaching, including attributes such as semester name, start time, and end time"
},
{
"name_chinese": "课时",
"name_english": "Class Hour",
"description_chinese": "课程的时间单位,包含上课时间、地点、教师、课程等属性",
"description_english": "Time units of courses, including attributes such as class time, location, teacher, and course"
},
{
"name_chinese": "教学计划",
"name_english": "Teaching Plan",
"description_chinese": "课程的教学安排,包含教学目标、内容安排、进度计划等属性",
"description_english": "Teaching arrangements for courses, including attributes such as teaching objectives, content arrangement, and progress plan"
}
]
}
# 情感陪伴场景配置
EMOTIONAL_COMPANION_SCENE = {
"name_chinese": "情感陪伴",
"name_english": "Emotional Companion",
"description_chinese": "适用于情感陪伴应用的本体建模,包含用户、情绪、活动等核心实体类型",
"description_english": "Ontology modeling for emotional companion applications, including core entity types such as users, emotions, and activities",
"types": [
{
"name_chinese": "用户",
"name_english": "User",
"description_chinese": "使用情感陪伴服务的个体,包含姓名、昵称、性格特征、偏好等属性",
"description_english": "Individuals using emotional companion services, including attributes such as name, nickname, personality traits, and preferences"
},
{
"name_chinese": "情绪",
"name_english": "Emotion",
"description_chinese": "用户的情感状态,包含情绪类型、强度、触发原因、持续时间等属性",
"description_english": "Emotional states of users, including attributes such as emotion type, intensity, trigger cause, and duration"
},
{
"name_chinese": "活动",
"name_english": "Activity",
"description_chinese": "用户参与的各类活动,包含活动名称、类型、参与者、时间地点等属性",
"description_english": "Various activities users participate in, including attributes such as activity name, type, participants, time, and location"
},
{
"name_chinese": "对话",
"name_english": "Conversation",
"description_chinese": "用户之间的交流记录,包含对话主题、参与者、时间、关键内容等属性",
"description_english": "Communication records between users, including attributes such as conversation topic, participants, time, and key content"
},
{
"name_chinese": "兴趣爱好",
"name_english": "Hobby",
"description_chinese": "用户的兴趣和爱好,包含爱好名称、类别、熟练程度、相关活动等属性",
"description_english": "User interests and hobbies, including attributes such as hobby name, category, proficiency level, and related activities"
},
{
"name_chinese": "日常事件",
"name_english": "Daily Event",
"description_chinese": "用户日常生活中的事件,包含事件描述、时间、地点、相关人物等属性",
"description_english": "Events in users' daily lives, including attributes such as event description, time, location, and related people"
},
{
"name_chinese": "关系",
"name_english": "Relationship",
"description_chinese": "用户之间的社会关系,包含关系类型、亲密度、建立时间等属性",
"description_english": "Social relationships between users, including attributes such as relationship type, intimacy, and establishment time"
},
{
"name_chinese": "回忆",
"name_english": "Memory",
"description_chinese": "用户的重要记忆片段,包含回忆内容、时间、地点、相关人物等属性",
"description_english": "Important memory fragments of users, including attributes such as memory content, time, location, and related people"
},
{
"name_chinese": "地点",
"name_english": "Location",
"description_chinese": "用户活动的地理位置,包含地点名称、地址、类型、相关事件等属性",
"description_english": "Geographic locations of user activities, including attributes such as location name, address, type, and related events"
},
{
"name_chinese": "时间节点",
"name_english": "Time Point",
"description_chinese": "重要的时间标记,包含日期、事件、意义等属性",
"description_english": "Important time markers, including attributes such as date, event, and significance"
},
{
"name_chinese": "目标",
"name_english": "Goal",
"description_chinese": "用户设定的目标,包含目标描述、截止时间、完成状态、相关活动等属性",
"description_english": "Goals set by users, including attributes such as goal description, deadline, completion status, and related activities"
},
{
"name_chinese": "成就",
"name_english": "Achievement",
"description_chinese": "用户获得的成就,包含成就名称、获得时间、描述、相关目标等属性",
"description_english": "Achievements obtained by users, including attributes such as achievement name, acquisition time, description, and related goals"
}
]
}
# 导出默认场景列表
DEFAULT_SCENES = [ONLINE_EDUCATION_SCENE, EMOTIONAL_COMPANION_SCENE]
def get_scene_name(scene_config: dict, language: str = "zh") -> str:
"""获取场景名称(根据语言)
Args:
scene_config: 场景配置字典
language: 语言类型 ("zh""en")
Returns:
对应语言的场景名称
"""
if language == "en":
return scene_config.get("name_english", scene_config.get("name_chinese"))
return scene_config.get("name_chinese")
def get_scene_description(scene_config: dict, language: str = "zh") -> str:
"""获取场景描述(根据语言)
Args:
scene_config: 场景配置字典
language: 语言类型 ("zh""en")
Returns:
对应语言的场景描述
"""
if language == "en":
return scene_config.get("description_english", scene_config.get("description_chinese"))
return scene_config.get("description_chinese")
def get_type_name(type_config: dict, language: str = "zh") -> str:
"""获取类型名称(根据语言)
Args:
type_config: 类型配置字典
language: 语言类型 ("zh""en")
Returns:
对应语言的类型名称
"""
if language == "en":
return type_config.get("name_english", type_config.get("name_chinese"))
return type_config.get("name_chinese")
def get_type_description(type_config: dict, language: str = "zh") -> str:
"""获取类型描述(根据语言)
Args:
type_config: 类型配置字典
language: 语言类型 ("zh""en")
Returns:
对应语言的类型描述
"""
if language == "en":
return type_config.get("description_english", type_config.get("description_chinese"))
return type_config.get("description_chinese")

View File

@@ -1,249 +0,0 @@
# -*- coding: utf-8 -*-
"""默认本体场景初始化器
本模块提供默认本体场景和类型的自动初始化功能。
在工作空间创建时,自动添加预设的本体场景和实体类型。
Classes:
DefaultOntologyInitializer: 默认本体场景初始化器
"""
import logging
from typing import List, Optional, Tuple
from uuid import UUID
from sqlalchemy.orm import Session
from app.config.default_ontology_config import (
DEFAULT_SCENES,
get_scene_name,
get_scene_description,
get_type_name,
get_type_description,
)
from app.core.logging_config import get_business_logger
from app.repositories.ontology_scene_repository import OntologySceneRepository
from app.repositories.ontology_class_repository import OntologyClassRepository
class DefaultOntologyInitializer:
"""默认本体场景初始化器
负责在工作空间创建时自动初始化默认的本体场景和类型。
遵循最小侵入原则,确保初始化失败不阻止工作空间创建。
Attributes:
db: 数据库会话
scene_repo: 场景Repository
class_repo: 类型Repository
logger: 业务日志记录器
"""
def __init__(self, db: Session):
"""初始化
Args:
db: 数据库会话
"""
self.db = db
self.scene_repo = OntologySceneRepository(db)
self.class_repo = OntologyClassRepository(db)
self.logger = get_business_logger()
def initialize_default_scenes(
self,
workspace_id: UUID,
language: str = "zh"
) -> Tuple[bool, str]:
"""为工作空间初始化默认场景
创建两个默认场景(在线教育、情感陪伴)及其对应的实体类型。
如果创建失败,记录错误日志但不抛出异常。
Args:
workspace_id: 工作空间ID
language: 语言类型 ("zh""en"),默认为 "zh"
Returns:
Tuple[bool, str]: (是否成功, 错误信息)
"""
try:
self.logger.info(
f"开始初始化默认本体场景 - workspace_id={workspace_id}, language={language}"
)
scenes_created = 0
total_types_created = 0
# 遍历默认场景配置
for scene_config in DEFAULT_SCENES:
scene_name = get_scene_name(scene_config, language)
# 创建场景及其类型
scene_id = self._create_scene_with_types(workspace_id, scene_config, language)
if scene_id:
scenes_created += 1
# 统计类型数量
types_count = len(scene_config.get("types", []))
total_types_created += types_count
self.logger.info(
f"场景创建成功 - scene_name={scene_name}, "
f"scene_id={scene_id}, types_count={types_count}, language={language}"
)
else:
self.logger.warning(
f"场景创建失败 - scene_name={scene_name}, "
f"workspace_id={workspace_id}, language={language}"
)
# 记录总体结果
self.logger.info(
f"默认场景初始化完成 - workspace_id={workspace_id}, "
f"language={language}, scenes_created={scenes_created}, "
f"total_types_created={total_types_created}"
)
# 如果至少创建了一个场景,视为成功
if scenes_created > 0:
return True, ""
else:
error_msg = "所有默认场景创建失败"
self.logger.error(
f"默认场景初始化失败 - workspace_id={workspace_id}, "
f"language={language}, error={error_msg}"
)
return False, error_msg
except Exception as e:
error_msg = f"默认场景初始化异常: {str(e)}"
self.logger.error(
f"默认场景初始化异常 - workspace_id={workspace_id}, "
f"language={language}, error={str(e)}",
exc_info=True
)
return False, error_msg
def _create_scene_with_types(
self,
workspace_id: UUID,
scene_config: dict,
language: str = "zh"
) -> Optional[UUID]:
"""创建场景及其类型
Args:
workspace_id: 工作空间ID
scene_config: 场景配置字典
language: 语言类型 ("zh""en")
Returns:
Optional[UUID]: 创建的场景ID失败返回None
"""
try:
scene_name = get_scene_name(scene_config, language)
scene_description = get_scene_description(scene_config, language)
# 检查是否已存在同名场景(支持向后兼容)
existing_scene = self.scene_repo.get_by_name(scene_name, workspace_id)
if existing_scene:
self.logger.info(
f"场景已存在,跳过创建 - scene_name={scene_name}, "
f"workspace_id={workspace_id}, scene_id={existing_scene.scene_id}, "
f"language={language}"
)
return None
# 创建场景记录,设置 is_system_default=true
scene_data = {
"scene_name": scene_name,
"scene_description": scene_description
}
scene = self.scene_repo.create(scene_data, workspace_id)
# 设置系统默认标识
scene.is_system_default = True
self.db.flush()
self.logger.info(
f"场景创建成功 - scene_name={scene_name}, "
f"scene_id={scene.scene_id}, is_system_default=True, language={language}"
)
# 批量创建类型
types_config = scene_config.get("types", [])
types_created = self._batch_create_types(scene.scene_id, types_config, language)
self.logger.info(
f"场景类型创建完成 - scene_id={scene.scene_id}, "
f"types_created={types_created}/{len(types_config)}, language={language}"
)
return scene.scene_id
except Exception as e:
scene_name = get_scene_name(scene_config, language)
self.logger.error(
f"场景创建失败 - scene_name={scene_name}, "
f"workspace_id={workspace_id}, language={language}, error={str(e)}",
exc_info=True
)
return None
def _batch_create_types(
self,
scene_id: UUID,
types_config: List[dict],
language: str = "zh"
) -> int:
"""批量创建实体类型
Args:
scene_id: 场景ID
types_config: 类型配置列表
language: 语言类型 ("zh""en")
Returns:
int: 成功创建的类型数量
"""
created_count = 0
for type_config in types_config:
try:
type_name = get_type_name(type_config, language)
type_description = get_type_description(type_config, language)
# 创建类型数据
class_data = {
"class_name": type_name,
"class_description": type_description
}
# 创建类型
ontology_class = self.class_repo.create(class_data, scene_id)
# 设置系统默认标识
ontology_class.is_system_default = True
self.db.flush()
created_count += 1
self.logger.debug(
f"类型创建成功 - class_name={type_name}, "
f"class_id={ontology_class.class_id}, "
f"scene_id={scene_id}, is_system_default=True, language={language}"
)
except Exception as e:
type_name = get_type_name(type_config, language)
self.logger.warning(
f"单个类型创建失败,继续创建其他类型 - "
f"class_name={type_name}, scene_id={scene_id}, "
f"language={language}, error={str(e)}"
)
# 继续创建其他类型
continue
return created_count

View File

@@ -14,23 +14,18 @@ from . import (
emotion_config_controller,
emotion_controller,
file_controller,
file_storage_controller,
home_page_controller,
implicit_memory_controller,
knowledge_controller,
knowledgeshare_controller,
mcp_market_controller,
mcp_market_config_controller,
memory_agent_controller,
memory_dashboard_controller,
memory_episodic_controller,
memory_explicit_controller,
memory_forget_controller,
memory_perceptual_controller,
memory_reflection_controller,
memory_short_term_controller,
memory_storage_controller,
memory_working_controller,
model_controller,
multi_agent_controller,
prompt_optimizer_controller,
@@ -43,9 +38,12 @@ from . import (
upload_controller,
user_controller,
user_memory_controllers,
workflow_controller,
workspace_controller,
ontology_controller,
skill_controller
memory_forget_controller,
home_page_controller,
memory_perceptual_controller,
memory_working_controller,
)
# 创建管理端 API 路由器
@@ -62,8 +60,6 @@ manager_router.include_router(model_controller.router)
manager_router.include_router(file_controller.router)
manager_router.include_router(document_controller.router)
manager_router.include_router(knowledge_controller.router)
manager_router.include_router(mcp_market_controller.router)
manager_router.include_router(mcp_market_config_controller.router)
manager_router.include_router(chunk_controller.router)
manager_router.include_router(test_controller.router)
manager_router.include_router(knowledgeshare_controller.router)
@@ -80,6 +76,7 @@ manager_router.include_router(release_share_controller.router)
manager_router.include_router(public_share_controller.router) # 公开路由(无需认证)
manager_router.include_router(memory_dashboard_controller.router)
manager_router.include_router(multi_agent_controller.router)
manager_router.include_router(workflow_controller.router)
manager_router.include_router(emotion_controller.router)
manager_router.include_router(emotion_config_controller.router)
manager_router.include_router(prompt_optimizer_controller.router)
@@ -91,8 +88,5 @@ manager_router.include_router(home_page_controller.router)
manager_router.include_router(implicit_memory_controller.router)
manager_router.include_router(memory_perceptual_controller.router)
manager_router.include_router(memory_working_controller.router)
manager_router.include_router(file_storage_controller.router)
manager_router.include_router(ontology_controller.router)
manager_router.include_router(skill_controller.router)
__all__ = ["manager_router"]

View File

@@ -1,14 +1,13 @@
import uuid
from typing import Optional, Annotated
import yaml
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
from fastapi import APIRouter, Depends, Path
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
from app.core.response_utils import success, fail
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models import User
@@ -18,12 +17,10 @@ from app.repositories.end_user_repository import EndUserRepository
from app.schemas import app_schema
from app.schemas.response_schema import PageData, PageMeta
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
from app.schemas.workflow_schema import WorkflowConfigUpdate, WorkflowImportSave
from app.schemas.workflow_schema import WorkflowConfigUpdate
from app.services import app_service, workspace_service
from app.services.agent_config_helper import enrich_agent_config
from app.services.app_service import AppService
from app.services.app_statistics_service import AppStatisticsService
from app.services.workflow_import_service import WorkflowImportService
from app.services.workflow_service import WorkflowService, get_workflow_service
router = APIRouter(prefix="/apps", tags=["Apps"])
@@ -67,7 +64,7 @@ def list_apps(
# 当 ids 存在且不为 None 时,根据 ids 获取应用
if ids is not None:
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
app_ids = [id.strip() for id in ids.split(',') if id.strip()]
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
return success(data=items)
@@ -396,10 +393,10 @@ async def draft_run(
from app.models import AgentConfig, ModelConfig
from sqlalchemy import select
from app.core.exceptions import BusinessException
from app.services.draft_run_service import AgentRunService
from app.services.draft_run_service import DraftRunService
service = AppService(db)
draft_service = AgentRunService(db)
draft_service = DraftRunService(db)
# 1. 验证应用
app = service._get_app_or_404(app_id)
@@ -457,8 +454,7 @@ async def draft_run(
user_id=payload.user_id or str(current_user.id),
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
files=payload.files # 传递多模态文件
user_rag_memory_id=user_rag_memory_id
):
yield event
@@ -479,13 +475,12 @@ async def draft_run(
"app_id": str(app_id),
"message_length": len(payload.message),
"has_conversation_id": bool(payload.conversation_id),
"has_variables": bool(payload.variables),
"has_files": bool(payload.files)
"has_variables": bool(payload.variables)
}
)
from app.services.draft_run_service import AgentRunService
draft_service = AgentRunService(db)
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
result = await draft_service.run(
agent_config=agent_cfg,
model_config=model_config,
@@ -495,8 +490,7 @@ async def draft_run(
user_id=payload.user_id or str(current_user.id),
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
files=payload.files # 传递多模态文件
user_rag_memory_id=user_rag_memory_id
)
logger.debug(
@@ -667,11 +661,6 @@ async def draft_run(
data=result,
msg="工作流任务执行成功"
)
else:
return fail(
msg="未知应用类型",
code=422
)
@router.post("/{app_id}/draft/run/compare", summary="多模型对比试运行")
@@ -789,8 +778,8 @@ async def draft_run_compare(
# 流式返回
if payload.stream:
async def event_generator():
from app.services.draft_run_service import AgentRunService
draft_service = AgentRunService(db)
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
async for event in draft_service.run_compare_stream(
agent_config=agent_cfg,
models=model_configs,
@@ -804,8 +793,7 @@ async def draft_run_compare(
web_search=True,
memory=True,
parallel=payload.parallel,
timeout=payload.timeout or 60,
files=payload.files
timeout=payload.timeout or 60
):
yield event
@@ -820,8 +808,8 @@ async def draft_run_compare(
)
# 非流式返回
from app.services.draft_run_service import AgentRunService
draft_service = AgentRunService(db)
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
result = await draft_service.run_compare(
agent_config=agent_cfg,
models=model_configs,
@@ -835,8 +823,7 @@ async def draft_run_compare(
web_search=True,
memory=True,
parallel=payload.parallel,
timeout=payload.timeout or 60,
files=payload.files
timeout=payload.timeout or 60
)
logger.info(
@@ -880,133 +867,3 @@ async def update_workflow_config(
workspace_id = current_user.current_workspace_id
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
return success(data=WorkflowConfigSchema.model_validate(cfg))
@router.get("/{app_id}/workflow/export")
@cur_workspace_access_guard()
async def export_workflow_config(
app_id: uuid.UUID,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
):
"""导出工作流配置为YAML文件"""
workflow_service = WorkflowService(db)
return success(data={
"content": workflow_service.export_workflow_dsl(app_id=app_id),
})
@router.post("/workflow/import")
@cur_workspace_access_guard()
async def import_workflow_config(
file: UploadFile = File(...),
platform: str = Form(...),
app_id: str = Form(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""从YAML内容导入工作流配置"""
if not file.filename.lower().endswith((".yaml", ".yml")):
return fail(msg="Only yaml file is allowed", code=BizCode.BAD_REQUEST)
raw_text = (await file.read()).decode("utf-8")
import_service = WorkflowImportService(db)
config = yaml.safe_load(raw_text)
result = await import_service.upload_config(platform, config)
return success(data=result)
@router.post("/workflow/import/save")
@cur_workspace_access_guard()
async def save_workflow_import(
data: WorkflowImportSave,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
import_service = WorkflowImportService(db)
app = await import_service.save_workflow(
user_id=current_user.id,
workspace_id=current_user.current_workspace_id,
temp_id=data.temp_id,
name=data.name,
description=data.description,
)
return success(data=app_schema.App.model_validate(app))
@router.get("/{app_id}/statistics", summary="应用统计数据")
@cur_workspace_access_guard()
def get_app_statistics(
app_id: uuid.UUID,
start_date: int,
end_date: int,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""获取应用统计数据
Args:
app_id: 应用ID
start_date: 开始时间戳(毫秒)
end_date: 结束时间戳(毫秒)
db: 数据库连接
current_user: 当前用户
Returns:
- daily_conversations: 每日会话数统计
- total_conversations: 总会话数
- daily_new_users: 每日新增用户数
- total_new_users: 总新增用户数
- daily_api_calls: 每日API调用次数
- total_api_calls: 总API调用次数
- daily_tokens: 每日token消耗
- total_tokens: 总token消耗
"""
workspace_id = current_user.current_workspace_id
stats_service = AppStatisticsService(db)
result = stats_service.get_app_statistics(
app_id=app_id,
workspace_id=workspace_id,
start_date=start_date,
end_date=end_date
)
return success(data=result)
@router.get("/workspace/api-statistics", summary="工作空间API调用统计")
@cur_workspace_access_guard()
def get_workspace_api_statistics(
start_date: int,
end_date: int,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""获取工作空间API调用统计
Args:
start_date: 开始时间戳(毫秒)
end_date: 结束时间戳(毫秒)
db: 数据库连接
current_user: 当前用户
Returns:
每日统计数据列表,每项包含:
- date: 日期
- total_calls: 当日总调用次数
- app_calls: 当日应用调用次数
- service_calls: 当日服务调用次数
"""
workspace_id = current_user.current_workspace_id
stats_service = AppStatisticsService(db)
result = stats_service.get_workspace_api_statistics(
workspace_id=workspace_id,
start_date=start_date,
end_date=end_date
)
return success(data=result)

View File

@@ -61,7 +61,6 @@ async def login_for_access_token(
user = auth_service.register_user_with_invite(
db=db,
email=form_data.email,
username=form_data.username,
password=form_data.password,
invite_token=form_data.invite,
workspace_id=invite_info.workspace_id

View File

@@ -441,14 +441,14 @@ async def retrieve_chunks(
# 1 participle search, 2 semantic search, 3 hybrid search
match retrieve_data.retrieve_type:
case chunk_schema.RetrieveType.PARTICIPLE:
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
return success(data=rs, msg="retrieval successful")
case chunk_schema.RetrieveType.SEMANTIC:
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
return success(data=rs, msg="retrieval successful")
case _:
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
# Efficient deduplication
seen_ids = set()
unique_rs = []

View File

@@ -7,13 +7,11 @@ Routes:
GET /memory/config/emotion - 获取情绪引擎配置
POST /memory/config/emotion - 更新情绪引擎配置
"""
import uuid
from fastapi import APIRouter, Depends, Query, HTTPException, status
from pydantic import BaseModel, Field
from typing import Optional, Union
from typing import Optional
from sqlalchemy.orm import Session
from uuid import UUID
from app.core.response_utils import success
from app.dependencies import get_current_user
@@ -22,7 +20,6 @@ from app.schemas.response_schema import ApiResponse
from app.services.emotion_config_service import EmotionConfigService
from app.core.logging_config import get_api_logger
from app.db import get_db
from app.utils.config_utils import resolve_config_id
# 获取API专用日志器
api_logger = get_api_logger()
@@ -35,11 +32,11 @@ router = APIRouter(
class EmotionConfigQuery(BaseModel):
"""情绪配置查询请求模型"""
config_id: UUID = Field(..., description="配置ID")
config_id: int = Field(..., description="配置ID")
class EmotionConfigUpdate(BaseModel):
"""情绪配置更新请求模型"""
config_id: Union[uuid.UUID, int, str]= Field(..., description="配置ID")
config_id: int = Field(..., description="配置ID")
emotion_enabled: bool = Field(..., description="是否启用情绪提取")
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
@@ -48,7 +45,7 @@ class EmotionConfigUpdate(BaseModel):
@router.get("/read_config", response_model=ApiResponse)
def get_emotion_config(
config_id: UUID|int = Query(..., description="配置ID"),
config_id: int = Query(..., description="配置ID"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
@@ -81,7 +78,7 @@ def get_emotion_config(
f"用户 {current_user.username} 请求获取情绪配置",
extra={"config_id": config_id}
)
config_id=resolve_config_id(config_id, db)
# 初始化服务
config_service = EmotionConfigService(db)
@@ -160,7 +157,6 @@ def update_emotion_config(
}
}
"""
config.config_id=resolve_config_id(config.config_id, db)
try:
api_logger.info(
f"用户 {current_user.username} 请求更新情绪配置",

View File

@@ -11,7 +11,6 @@ Routes:
"""
from app.core.error_codes import BizCode
from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger
from app.core.response_utils import fail, success
from app.dependencies import get_current_user, get_db
@@ -25,7 +24,7 @@ from app.schemas.emotion_schema import (
)
from app.schemas.response_schema import ApiResponse
from app.services.emotion_analytics_service import EmotionAnalyticsService
from fastapi import APIRouter, Depends, HTTPException, status,Header
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
# 获取API专用日志器
@@ -46,51 +45,45 @@ emotion_service = EmotionAnalyticsService()
@router.post("/tags", response_model=ApiResponse)
async def get_emotion_tags(
request: EmotionTagsRequest,
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
):
try:
# 使用集中化的语言校验
language = get_language_from_header(language_type)
api_logger.info(
f"用户 {current_user.username} 请求获取情绪标签统计",
extra={
"end_user_id": request.end_user_id,
"group_id": request.group_id,
"emotion_type": request.emotion_type,
"start_date": request.start_date,
"end_date": request.end_date,
"limit": request.limit,
"language_type": language
"limit": request.limit
}
)
# 调用服务层
data = await emotion_service.get_emotion_tags(
end_user_id=request.end_user_id,
end_user_id=request.group_id,
emotion_type=request.emotion_type,
start_date=request.start_date,
end_date=request.end_date,
limit=request.limit,
language=language
limit=request.limit
)
api_logger.info(
"情绪标签统计获取成功",
extra={
"end_user_id": request.end_user_id,
"group_id": request.group_id,
"total_count": data.get("total_count", 0),
"tags_count": len(data.get("tags", []))
}
)
return success(data=data, msg="情绪标签获取成功")
except Exception as e:
api_logger.error(
f"获取情绪标签统计失败: {str(e)}",
extra={"end_user_id": request.end_user_id},
extra={"group_id": request.group_id},
exc_info=True
)
raise HTTPException(
@@ -103,44 +96,40 @@ async def get_emotion_tags(
@router.post("/wordcloud", response_model=ApiResponse)
async def get_emotion_wordcloud(
request: EmotionWordcloudRequest,
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
):
try:
# 使用集中化的语言校验
language = get_language_from_header(language_type)
api_logger.info(
f"用户 {current_user.username} 请求获取情绪词云数据",
extra={
"end_user_id": request.end_user_id,
"group_id": request.group_id,
"emotion_type": request.emotion_type,
"limit": request.limit
}
)
# 调用服务层
data = await emotion_service.get_emotion_wordcloud(
end_user_id=request.end_user_id,
end_user_id=request.group_id,
emotion_type=request.emotion_type,
limit=request.limit
)
api_logger.info(
"情绪词云数据获取成功",
extra={
"end_user_id": request.end_user_id,
"group_id": request.group_id,
"total_keywords": data.get("total_keywords", 0)
}
)
return success(data=data, msg="情绪词云获取成功")
except Exception as e:
api_logger.error(
f"获取情绪词云数据失败: {str(e)}",
extra={"end_user_id": request.end_user_id},
extra={"group_id": request.group_id},
exc_info=True
)
raise HTTPException(
@@ -153,52 +142,48 @@ async def get_emotion_wordcloud(
@router.post("/health", response_model=ApiResponse)
async def get_emotion_health(
request: EmotionHealthRequest,
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
):
try:
# 使用集中化的语言校验
language = get_language_from_header(language_type)
# 验证时间范围参数
if request.time_range not in ["7d", "30d", "90d"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="时间范围参数无效,必须是 7d、30d 或 90d"
)
api_logger.info(
f"用户 {current_user.username} 请求获取情绪健康指数",
extra={
"end_user_id": request.end_user_id,
"group_id": request.group_id,
"time_range": request.time_range
}
)
# 调用服务层
data = await emotion_service.calculate_emotion_health_index(
end_user_id=request.end_user_id,
end_user_id=request.group_id,
time_range=request.time_range
)
api_logger.info(
"情绪健康指数获取成功",
extra={
"end_user_id": request.end_user_id,
"health_score": data.get("health_score") or 0,
"group_id": request.group_id,
"health_score": data.get("health_score", 0),
"level": data.get("level", "未知")
}
)
return success(data=data, msg="情绪健康指数获取成功")
except HTTPException:
raise
except Exception as e:
api_logger.error(
f"获取情绪健康指数失败: {str(e)}",
extra={"end_user_id": request.end_user_id},
extra={"group_id": request.group_id},
exc_info=True
)
raise HTTPException(
@@ -208,112 +193,63 @@ async def get_emotion_health(
# @router.post("/check-data", response_model=ApiResponse)
# async def check_emotion_data_exists(
# request: EmotionSuggestionsRequest,
# db: Session = Depends(get_db),
# current_user: User = Depends(get_current_user),
# ):
# """检查用户情绪建议数据是否存在
# Args:
# request: 包含 end_user_id
# db: 数据库会话
# current_user: 当前用户
# Returns:
# 数据存在状态
# """
# try:
# api_logger.info(
# f"检查用户情绪建议数据是否存在: {request.end_user_id}",
# extra={"end_user_id": request.end_user_id}
# )
# # 从数据库获取建议
# data = await emotion_service.get_cached_suggestions(
# end_user_id=request.end_user_id,
# db=db
# )
# if data is None:
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据不存在")
# return fail(
# BizCode.NOT_FOUND,
# "情绪建议数据不存在,请点击右上角刷新进行初始化",
# {"exists": False}
# )
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据存在")
# return success(data={"exists": True}, msg="情绪建议数据已存在")
# except Exception as e:
# api_logger.error(
# f"检查情绪建议数据失败: {str(e)}",
# extra={"end_user_id": request.end_user_id},
# exc_info=True
# )
# raise HTTPException(
# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
# detail=f"检查情绪建议数据失败: {str(e)}"
# )
@router.post("/suggestions", response_model=ApiResponse)
async def get_emotion_suggestions(
request: EmotionSuggestionsRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取个性化情绪建议(从数据库读取)
"""获取个性化情绪建议(从缓存读取)
Args:
request: 包含 end_user_id 和可选的 config_id
request: 包含 group_id 和可选的 config_id
db: 数据库会话
current_user: 当前用户
Returns:
的个性化情绪建议响应
存的个性化情绪建议响应
"""
try:
api_logger.info(
f"用户 {current_user.username} 请求获取个性化情绪建议",
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
extra={
"end_user_id": request.end_user_id,
"group_id": request.group_id,
"config_id": request.config_id
}
)
# 从数据库获取建议
# 从缓存获取建议
data = await emotion_service.get_cached_suggestions(
end_user_id=request.end_user_id,
end_user_id=request.group_id,
db=db
)
if data is None:
# 缓存不存在或已过期
api_logger.info(
f"用户 {request.end_user_id} 的建议数据不存在",
extra={"end_user_id": request.end_user_id}
f"用户 {request.group_id} 的建议缓存不存在或已过期",
extra={"group_id": request.group_id}
)
return success(
data={"exists": False},
msg="情绪建议数据不存在,请点击右上角刷新进行初始化"
return fail(
BizCode.RESOURCE_NOT_FOUND,
"建议缓存不存在或已过期,请调用 /generate_suggestions 接口生成新建议",
None
)
api_logger.info(
"个性化建议获取成功",
"个性化建议获取成功(缓存)",
extra={
"end_user_id": request.end_user_id,
"group_id": request.group_id,
"suggestions_count": len(data.get("suggestions", []))
}
)
return success(data=data, msg="个性化建议获取成功")
return success(data=data, msg="个性化建议获取成功(缓存)")
except Exception as e:
api_logger.error(
f"获取个性化建议失败: {str(e)}",
extra={"end_user_id": request.end_user_id},
extra={"group_id": request.group_id},
exc_info=True
)
raise HTTPException(
@@ -325,62 +261,83 @@ async def get_emotion_suggestions(
@router.post("/generate_suggestions", response_model=ApiResponse)
async def generate_emotion_suggestions(
request: EmotionGenerateSuggestionsRequest,
language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""生成个性化情绪建议调用LLM并保存到数据库
"""生成个性化情绪建议调用LLM并缓存
Args:
request: 包含 end_user_id
request: 包含 group_id、可选的 config_id 和 force_refresh
db: 数据库会话
current_user: 当前用户
Returns:
新生成的个性化情绪建议响应
"""
try:
# 使用集中化的语言校验
language = get_language_from_header(language_type)
# 验证 config_id如果提供
# 获取终端用户关联的配置
config_id = request.config_id
if config_id is None:
# 如果没有提供 config_id尝试获取用户关联的配置
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(request.group_id, db)
config_id = connected_config.get("memory_config_id")
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "无法获取用户关联的配置", str(e))
else:
# 如果提供了 config_id验证其有效性
from app.services.memory_config_service import MemoryConfigService
try:
config_service = MemoryConfigService(db)
config = config_service.get_config_by_id(config_id)
if not config:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", f"配置 {config_id} 不存在")
except Exception as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID验证失败", str(e))
api_logger.info(
f"用户 {current_user.username} 请求生成个性化情绪建议",
extra={
"end_user_id": request.end_user_id
"group_id": request.group_id,
"config_id": config_id
}
)
# 调用服务层生成建议
data = await emotion_service.generate_emotion_suggestions(
end_user_id=request.end_user_id,
db=db,
language=language
)
# 保存到数据库
await emotion_service.save_suggestions_cache(
end_user_id=request.end_user_id,
suggestions_data=data,
end_user_id=request.group_id,
db=db
)
# 保存到缓存
await emotion_service.save_suggestions_cache(
end_user_id=request.group_id,
suggestions_data=data,
db=db,
expires_hours=24
)
api_logger.info(
"个性化建议生成成功",
extra={
"end_user_id": request.end_user_id,
"group_id": request.group_id,
"suggestions_count": len(data.get("suggestions", []))
}
)
return success(data=data, msg="个性化建议生成成功")
except Exception as e:
api_logger.error(
f"生成个性化建议失败: {str(e)}",
extra={"end_user_id": request.end_user_id},
extra={"group_id": request.group_id},
exc_info=True
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"生成个性化建议失败: {str(e)}"
)
)

View File

@@ -1,634 +0,0 @@
"""
File storage controller module.
This module provides API endpoints for file storage operations using the
configurable storage backend. It is a new controller that does not modify
the existing file_controller.py.
Routes:
POST /storage/files - Upload a file
GET /storage/files/{file_id} - Download a file
DELETE /storage/files/{file_id} - Delete a file
"""
import os
import uuid
from typing import Any
from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
from fastapi.responses import FileResponse, RedirectResponse
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.core.storage import LocalStorage
from app.core.storage.url_signer import generate_signed_url, verify_signed_url
from app.core.storage_exceptions import (
StorageDeleteError,
StorageUploadError,
)
from app.db import get_db
from app.dependencies import get_current_user, get_share_user_id, ShareTokenData
from app.models.file_metadata_model import FileMetadata
from app.models.user_model import User
from app.schemas.response_schema import ApiResponse
from app.services.file_storage_service import (
FileStorageService,
generate_file_key,
get_file_storage_service,
)
api_logger = get_api_logger()
router = APIRouter(
prefix="/storage",
tags=["storage"]
)
@router.post("/files", response_model=ApiResponse)
async def upload_file(
file: UploadFile = File(...),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
):
"""
Upload a file to the configured storage backend.
"""
tenant_id = current_user.tenant_id
workspace_id = current_user.current_workspace_id
api_logger.info(
f"Storage upload request: tenant_id={tenant_id}, workspace_id={workspace_id}, "
f"filename={file.filename}, username={current_user.username}"
)
# Read file contents
contents = await file.read()
file_size = len(contents)
# Validate file size
if file_size == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The file is empty."
)
if file_size > settings.MAX_FILE_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
)
# Extract file extension
_, file_extension = os.path.splitext(file.filename)
file_ext = file_extension.lower()
# Generate file_id and file_key
file_id = uuid.uuid4()
file_key = generate_file_key(
tenant_id=tenant_id,
workspace_id=workspace_id,
file_id=file_id,
file_ext=file_ext,
)
# Create file metadata record with pending status
file_metadata = FileMetadata(
id=file_id,
tenant_id=tenant_id,
workspace_id=workspace_id,
file_key=file_key,
file_name=file.filename,
file_ext=file_ext,
file_size=file_size,
content_type=file.content_type,
status="pending",
)
db.add(file_metadata)
db.commit()
db.refresh(file_metadata)
# Upload file to storage backend
try:
await storage_service.upload_file(
tenant_id=tenant_id,
workspace_id=workspace_id,
file_id=file_id,
file_ext=file_ext,
content=contents,
content_type=file.content_type,
)
# Update status to completed
file_metadata.status = "completed"
db.commit()
api_logger.info(f"File uploaded to storage: file_key={file_key}")
except StorageUploadError as e:
# Update status to failed
file_metadata.status = "failed"
db.commit()
api_logger.error(f"Storage upload failed: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"File storage failed: {str(e)}"
)
api_logger.info(f"File upload successful: {file.filename} (file_id: {file_id})")
return success(
data={"file_id": str(file_id), "file_key": file_key},
msg="File upload successful"
)
@router.post("/share/files", response_model=ApiResponse)
async def upload_file_with_share_token(
file: UploadFile = File(...),
db: Session = Depends(get_db),
share_data: ShareTokenData = Depends(get_share_user_id),
storage_service: FileStorageService = Depends(get_file_storage_service),
):
"""
Upload a file to the configured storage backend using share_token authentication.
"""
from app.services.release_share_service import ReleaseShareService
from app.models.app_model import App
from app.models.workspace_model import Workspace
# Get share and release info from share_token
service = ReleaseShareService(db)
share_info = service.get_shared_release_info(share_token=share_data.share_token)
# Get share object to access app_id
share = service.repo.get_by_share_token(share_data.share_token)
if not share:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Shared app not found"
)
# Get app to access workspace_id
app = db.query(App).filter(
App.id == share.app_id,
App.is_active.is_(True)
).first()
if not app:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="App not found"
)
# Get workspace to access tenant_id
workspace = db.query(Workspace).filter(
Workspace.id == app.workspace_id
).first()
if not workspace:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Workspace not found"
)
tenant_id = workspace.tenant_id
workspace_id = app.workspace_id
api_logger.info(
f"Storage upload request (share): tenant_id={tenant_id}, workspace_id={workspace_id}, "
f"filename={file.filename}, share_token={share_data.share_token}"
)
# Read file contents
contents = await file.read()
file_size = len(contents)
# Validate file size
if file_size == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The file is empty."
)
if file_size > settings.MAX_FILE_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
)
# Extract file extension
_, file_extension = os.path.splitext(file.filename)
file_ext = file_extension.lower()
# Generate file_id and file_key
file_id = uuid.uuid4()
file_key = generate_file_key(
tenant_id=tenant_id,
workspace_id=workspace_id,
file_id=file_id,
file_ext=file_ext,
)
# Create file metadata record with pending status
file_metadata = FileMetadata(
id=file_id,
tenant_id=tenant_id,
workspace_id=workspace_id,
file_key=file_key,
file_name=file.filename,
file_ext=file_ext,
file_size=file_size,
content_type=file.content_type,
status="pending",
)
db.add(file_metadata)
db.commit()
db.refresh(file_metadata)
# Upload file to storage backend
try:
await storage_service.upload_file(
tenant_id=tenant_id,
workspace_id=workspace_id,
file_id=file_id,
file_ext=file_ext,
content=contents,
content_type=file.content_type,
)
# Update status to completed
file_metadata.status = "completed"
db.commit()
api_logger.info(f"File uploaded to storage (share): file_key={file_key}")
except StorageUploadError as e:
# Update status to failed
file_metadata.status = "failed"
db.commit()
api_logger.error(f"Storage upload failed (share): {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"File storage failed: {str(e)}"
)
api_logger.info(f"File upload successful (share): {file.filename} (file_id: {file_id})")
return success(
data={"file_id": str(file_id), "file_key": file_key},
msg="File upload successful"
)
@router.get("/files/{file_id}", response_model=Any)
async def download_file(
file_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
) -> Any:
"""
Download a file from the configured storage backend.
"""
api_logger.info(f"Storage download request: file_id={file_id}")
# Query file metadata from database
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
if not file_metadata:
api_logger.warning(f"File not found in database: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist"
)
if file_metadata.status != "completed":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"File upload not completed, status: {file_metadata.status}"
)
file_key = file_metadata.file_key
storage = storage_service.storage
if isinstance(storage, LocalStorage):
full_path = storage._get_full_path(file_key)
if not full_path.exists():
api_logger.warning(f"File not found on disk: file_key={file_key}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found (possibly deleted)"
)
api_logger.info(f"Serving local file: file_key={file_key}")
return FileResponse(
path=str(full_path),
filename=file_metadata.file_name,
media_type=file_metadata.content_type or "application/octet-stream"
)
else:
try:
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
api_logger.info(f"Redirecting to presigned URL: file_key={file_key}")
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
except FileNotFoundError:
api_logger.warning(f"File not found in remote storage: file_key={file_key}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found in storage"
)
except Exception as e:
api_logger.error(f"Failed to get presigned URL: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve file: {str(e)}"
)
@router.delete("/files/{file_id}", response_model=ApiResponse)
async def delete_file(
file_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
):
"""
Delete a file from the configured storage backend.
"""
api_logger.info(
f"Storage delete request: file_id={file_id}, username={current_user.username}"
)
# Query file metadata from database
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
if not file_metadata:
api_logger.warning(f"File not found in database: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist"
)
file_key = file_metadata.file_key
# Delete file from storage
try:
deleted = await storage_service.delete_file(file_key)
if deleted:
api_logger.info(f"File deleted from storage: file_key={file_key}")
else:
api_logger.info(f"File did not exist in storage: file_key={file_key}")
except StorageDeleteError as e:
api_logger.error(f"Storage delete failed: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete file from storage: {str(e)}"
)
# Delete database record
try:
db.delete(file_metadata)
db.commit()
api_logger.info(f"File record deleted from database: file_id={file_id}")
except Exception as e:
api_logger.error(f"Database delete failed: {e}")
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete file record: {str(e)}"
)
return success(msg="File deleted successfully")
@router.get("/files/{file_id}/url", response_model=ApiResponse)
async def get_file_url(
file_id: uuid.UUID,
expires: int = None,
permanent: bool = False,
db: Session = Depends(get_db),
storage_service: FileStorageService = Depends(get_file_storage_service),
):
"""
Get an access URL for a file (no authentication required).
Args:
file_id: The UUID of the file.
expires: URL validity period in seconds (default from FILE_URL_EXPIRES env).
permanent: If True, return a permanent URL without expiration.
db: Database session.
storage_service: The file storage service.
Returns:
ApiResponse with the access URL.
"""
if expires is None:
expires = settings.FILE_URL_EXPIRES
api_logger.info(f"Get file URL request: file_id={file_id}, expires={expires}, permanent={permanent}")
# Query file metadata from database
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
if not file_metadata:
api_logger.warning(f"File not found in database: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist"
)
if file_metadata.status != "completed":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"File upload not completed, status: {file_metadata.status}"
)
file_key = file_metadata.file_key
storage = storage_service.storage
try:
if permanent:
# Generate permanent URL (no expiration check)
server_url = settings.FILE_LOCAL_SERVER_URL
url = f"{server_url}/storage/permanent/{file_id}"
return success(
data={
"url": url,
"expires_in": None,
"permanent": True,
"file_name": file_metadata.file_name,
},
msg="Permanent file URL generated successfully"
)
if isinstance(storage, LocalStorage):
# For local storage, generate signed URL with expiration
url = generate_signed_url(str(file_id), expires)
else:
# For remote storage (OSS/S3), get presigned URL
url = await storage_service.get_file_url(file_key, expires=expires)
api_logger.info(f"Generated file URL: file_id={file_id}")
return success(
data={
"url": url,
"expires_in": expires,
"permanent": False,
"file_name": file_metadata.file_name,
},
msg="File URL generated successfully"
)
except Exception as e:
api_logger.error(f"Failed to generate file URL: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to generate file URL: {str(e)}"
)
@router.get("/public/{file_id}", response_model=Any)
async def public_download_file(
file_id: uuid.UUID,
expires: int = 0,
signature: str = "",
db: Session = Depends(get_db),
storage_service: FileStorageService = Depends(get_file_storage_service),
) -> Any:
"""
Public file download endpoint with signature verification.
This endpoint allows downloading files without authentication,
but requires a valid signature and non-expired timestamp.
Args:
file_id: The UUID of the file.
expires: Expiration timestamp.
signature: HMAC signature for verification.
db: Database session.
storage_service: The file storage service.
Returns:
FileResponse for the requested file.
"""
api_logger.info(f"Public download request: file_id={file_id}")
# Verify signature
is_valid, error_msg = verify_signed_url(str(file_id), expires, signature)
if not is_valid:
api_logger.warning(f"Invalid signed URL: file_id={file_id}, error={error_msg}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=error_msg
)
# Query file metadata from database
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
if not file_metadata:
api_logger.warning(f"File not found in database: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist"
)
if file_metadata.status != "completed":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"File upload not completed, status: {file_metadata.status}"
)
file_key = file_metadata.file_key
storage = storage_service.storage
if isinstance(storage, LocalStorage):
full_path = storage._get_full_path(file_key)
if not full_path.exists():
api_logger.warning(f"File not found on disk: file_key={file_key}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found"
)
api_logger.info(f"Serving public file: file_key={file_key}")
return FileResponse(
path=str(full_path),
filename=file_metadata.file_name,
media_type=file_metadata.content_type or "application/octet-stream"
)
else:
# For remote storage, redirect to presigned URL
try:
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
except Exception as e:
api_logger.error(f"Failed to get presigned URL: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve file: {str(e)}"
)
@router.get("/permanent/{file_id}", response_model=Any)
async def permanent_download_file(
file_id: uuid.UUID,
db: Session = Depends(get_db),
storage_service: FileStorageService = Depends(get_file_storage_service),
) -> Any:
"""
Permanent file download endpoint (no expiration, no signature required).
This endpoint allows downloading files without authentication or expiration.
Use with caution as URLs are permanently accessible.
Args:
file_id: The UUID of the file.
db: Database session.
storage_service: The file storage service.
Returns:
FileResponse for the requested file.
"""
api_logger.info(f"Permanent download request: file_id={file_id}")
# Query file metadata from database
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
if not file_metadata:
api_logger.warning(f"File not found in database: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist"
)
if file_metadata.status != "completed":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"File upload not completed, status: {file_metadata.status}"
)
file_key = file_metadata.file_key
storage = storage_service.storage
if isinstance(storage, LocalStorage):
full_path = storage._get_full_path(file_key)
if not full_path.exists():
api_logger.warning(f"File not found on disk: file_key={file_key}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found"
)
api_logger.info(f"Serving permanent file: file_key={file_key}")
return FileResponse(
path=str(full_path),
filename=file_metadata.file_name,
media_type=file_metadata.content_type or "application/octet-stream"
)
else:
# For remote storage, redirect to presigned URL with long expiration
try:
# Use a very long expiration (7 days max for most cloud providers)
presigned_url = await storage_service.get_file_url(file_key, expires=604800)
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
except Exception as e:
api_logger.error(f"Failed to get presigned URL: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve file: {str(e)}"
)

View File

@@ -122,52 +122,10 @@ def validate_confidence_threshold(threshold: float) -> None:
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
@router.get("/check-data/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def check_user_data_exists(
end_user_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
检查用户画像数据是否存在
Args:
end_user_id: 目标用户ID
Returns:
数据存在状态
"""
api_logger.info(f"检查用户画像数据是否存在: {end_user_id}")
try:
# Validate inputs
validate_user_id(end_user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
return success(
data={"exists": False},
msg="画像数据不存在,请点击右上角刷新进行初始化"
)
api_logger.info(f"用户 {end_user_id} 的画像数据存在")
return success(data={"exists": True}, msg="画像数据已存在")
except Exception as e:
return handle_implicit_memory_error(e, "检查画像数据", end_user_id)
@router.get("/preferences/{end_user_id}", response_model=ApiResponse)
@router.get("/preferences/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_preference_tags(
end_user_id: str,
user_id: str,
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"),
tag_category: Optional[str] = Query(None, description="Filter by tag category"),
start_date: Optional[datetime] = Query(None, description="Filter start date"),
@@ -179,7 +137,7 @@ async def get_preference_tags(
Get user preference tags from cache.
Args:
end_user_id: Target end user ID
user_id: Target user ID
confidence_threshold: Minimum confidence score (0.0-1.0)
tag_category: Optional category filter
start_date: Optional start date filter
@@ -188,21 +146,25 @@ async def get_preference_tags(
Returns:
List of preference tags from cache
"""
api_logger.info(f"Preference tags requested for user: {end_user_id} (from cache)")
api_logger.info(f"Preference tags requested for user: {user_id} (from cache)")
try:
# Validate inputs
validate_user_id(end_user_id)
validate_user_id(user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
service = ImplicitMemoryService(db=db, end_user_id=user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
return fail(BizCode.NOT_FOUND, "", "")
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.RESOURCE_NOT_FOUND,
"画像缓存不存在或已过期,请调用 /generate_profile 接口生成新画像",
None
)
# Extract preferences from cache
preferences = cached_profile.get("preferences", [])
@@ -230,17 +192,17 @@ async def get_preference_tags(
filtered_preferences.append(pref)
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {end_user_id} (from cache)")
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {user_id} (from cache)")
return success(data=filtered_preferences, msg="偏好标签获取成功(缓存)")
except Exception as e:
return handle_implicit_memory_error(e, "偏好标签获取", end_user_id)
return handle_implicit_memory_error(e, "偏好标签获取", user_id)
@router.get("/portrait/{end_user_id}", response_model=ApiResponse)
@router.get("/portrait/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_dimension_portrait(
end_user_id: str,
user_id: str,
include_history: bool = Query(False, description="Include historical trends"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
@@ -249,42 +211,46 @@ async def get_dimension_portrait(
Get user's four-dimension personality portrait from cache.
Args:
end_user_id: Target end user ID
user_id: Target user ID
include_history: Whether to include historical trend data (ignored for cached data)
Returns:
Four-dimension personality portrait from cache
"""
api_logger.info(f"Dimension portrait requested for user: {end_user_id} (from cache)")
api_logger.info(f"Dimension portrait requested for user: {user_id} (from cache)")
try:
# Validate inputs
validate_user_id(end_user_id)
validate_user_id(user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
service = ImplicitMemoryService(db=db, end_user_id=user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
return fail(BizCode.NOT_FOUND, "", "")
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.RESOURCE_NOT_FOUND,
"画像缓存不存在或已过期,请调用 /generate_profile 接口生成新画像",
None
)
# Extract portrait from cache
portrait = cached_profile.get("portrait", {})
api_logger.info(f"Dimension portrait retrieved for user: {end_user_id} (from cache)")
api_logger.info(f"Dimension portrait retrieved for user: {user_id} (from cache)")
return success(data=portrait, msg="四维画像获取成功(缓存)")
except Exception as e:
return handle_implicit_memory_error(e, "四维画像获取", end_user_id)
return handle_implicit_memory_error(e, "四维画像获取", user_id)
@router.get("/interest-areas/{end_user_id}", response_model=ApiResponse)
@router.get("/interest-areas/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_interest_area_distribution(
end_user_id: str,
user_id: str,
include_trends: bool = Query(False, description="Include trend analysis"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
@@ -293,42 +259,46 @@ async def get_interest_area_distribution(
Get user's interest area distribution from cache.
Args:
end_user_id: Target end user ID
user_id: Target user ID
include_trends: Whether to include trend analysis data (ignored for cached data)
Returns:
Interest area distribution from cache
"""
api_logger.info(f"Interest area distribution requested for user: {end_user_id} (from cache)")
api_logger.info(f"Interest area distribution requested for user: {user_id} (from cache)")
try:
# Validate inputs
validate_user_id(end_user_id)
validate_user_id(user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
service = ImplicitMemoryService(db=db, end_user_id=user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
return fail(BizCode.NOT_FOUND, "", "")
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.RESOURCE_NOT_FOUND,
"画像缓存不存在或已过期,请调用 /generate_profile 接口生成新画像",
None
)
# Extract interest areas from cache
interest_areas = cached_profile.get("interest_areas", {})
api_logger.info(f"Interest area distribution retrieved for user: {end_user_id} (from cache)")
api_logger.info(f"Interest area distribution retrieved for user: {user_id} (from cache)")
return success(data=interest_areas, msg="兴趣领域分布获取成功(缓存)")
except Exception as e:
return handle_implicit_memory_error(e, "兴趣领域分布获取", end_user_id)
return handle_implicit_memory_error(e, "兴趣领域分布获取", user_id)
@router.get("/habits/{end_user_id}", response_model=ApiResponse)
@router.get("/habits/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_behavior_habits(
end_user_id: str,
user_id: str,
confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"),
frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"),
time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"),
@@ -339,7 +309,7 @@ async def get_behavior_habits(
Get user's behavioral habits from cache.
Args:
end_user_id: Target end user ID
user_id: Target user ID
confidence_level: Filter by confidence level (high, medium, low)
frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered)
time_period: Filter by time period (current, past)
@@ -347,21 +317,25 @@ async def get_behavior_habits(
Returns:
List of behavioral habits from cache
"""
api_logger.info(f"Behavior habits requested for user: {end_user_id} (from cache)")
api_logger.info(f"Behavior habits requested for user: {user_id} (from cache)")
try:
# Validate inputs
validate_user_id(end_user_id)
validate_user_id(user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
service = ImplicitMemoryService(db=db, end_user_id=user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
return fail(BizCode.NOT_FOUND, "", "")
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.RESOURCE_NOT_FOUND,
"画像缓存不存在或已过期,请调用 /generate_profile 接口生成新画像",
None
)
# Extract habits from cache
habits = cached_profile.get("habits", [])
@@ -394,11 +368,11 @@ async def get_behavior_habits(
filtered_habits.append(habit)
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {end_user_id} (from cache)")
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {user_id} (from cache)")
return success(data=filtered_habits, msg="行为习惯获取成功(缓存)")
except Exception as e:
return handle_implicit_memory_error(e, "行为习惯获取", end_user_id)
return handle_implicit_memory_error(e, "行为习惯获取", user_id)

View File

@@ -9,16 +9,13 @@ from sqlalchemy import or_
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.rag.common import settings
from app.core.rag.integrations.feishu.client import FeishuAPIClient
from app.core.rag.integrations.yuque.client import YuqueAPIClient
from app.core.rag.llm.chat_model import Base
from app.core.rag.nlp import rag_tokenizer, search
from app.core.rag.prompts.generator import graph_entity_types
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.response_utils import success, fail
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models import knowledge_model
@@ -487,99 +484,3 @@ async def rebuild_knowledge_graph(
except Exception as e:
api_logger.error(f"Failed to rebuild knowledge graph: knowledge_id={knowledge_id} - {str(e)}")
raise
@router.get("/check/yuque/auth", response_model=ApiResponse)
async def check_yuque_auth(
yuque_user_id: str,
yuque_token: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
check yuque auth info
"""
api_logger.info(f"check yuque auth info, username: {current_user.username}")
try:
api_client = YuqueAPIClient(
user_id=yuque_user_id,
token=yuque_token
)
async with api_client as client:
repos = await client.get_user_repos()
if repos:
return success(msg="Successfully auth yuque info")
return fail(BizCode.UNAUTHORIZED, msg="auth yuque info failed", error="user_id or token is incorrect")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"auth yuque info failed: {str(e)}")
raise
@router.get("/check/feishu/auth", response_model=ApiResponse)
async def check_feishu_auth(
feishu_app_id: str,
feishu_app_secret: str,
feishu_folder_token: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
check feishu auth info
"""
api_logger.info(f"check feishu auth info, username: {current_user.username}")
try:
api_client = FeishuAPIClient(
app_id=feishu_app_id,
app_secret=feishu_app_secret
)
async with api_client as client:
files = await client.list_all_folder_files(feishu_folder_token, recursive=True)
if files:
return success(msg="Successfully auth feishu info")
return fail(BizCode.UNAUTHORIZED, msg="auth feishu info failed", error="app_id or app_secret or feishu_folder_token is incorrect")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"auth feishu info failed: {str(e)}")
raise
@router.post("/{knowledge_id}/sync", response_model=ApiResponse)
async def sync_knowledge(
knowledge_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
sync knowledge base information based on knowledge_id
"""
api_logger.info(f"Obtain details of the knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}")
try:
# 1. Query knowledge base information from the database
api_logger.debug(f"Query knowledge base: {knowledge_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
if not db_knowledge:
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={knowledge_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
# 2. sync knowledge
# from app.tasks import sync_knowledge_for_kb
# sync_knowledge_for_kb(kb_id)
task = celery_app.send_task("app.core.rag.tasks.sync_knowledge_for_kb", args=[knowledge_id])
result = {
"task_id": task.id
}
return success(data=result, msg="Task accepted. sync knowledge is being processed in the background.")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"Failed to sync knowledge: knowledge_id={knowledge_id} - {str(e)}")
raise

View File

@@ -1,395 +0,0 @@
import datetime
import json
from typing import Optional
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.encoders import jsonable_encoder
import requests
from sqlalchemy import or_
from sqlalchemy.orm import Session
from modelscope.hub.errors import raise_for_http_status
from modelscope.hub.mcp_api import MCPApi
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.db import get_db
from app.dependencies import get_current_user
from app.models import mcp_market_config_model
from app.models.user_model import User
from app.schemas import mcp_market_config_schema
from app.schemas.response_schema import ApiResponse
from app.services import mcp_market_config_service
# Obtain a dedicated API logger
api_logger = get_api_logger()
router = APIRouter(
prefix="/mcp_market_configs",
tags=["mcp_market_configs"],
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
)
@router.get("/mcp_servers", response_model=ApiResponse)
async def get_mcp_servers(
mcp_market_config_id: uuid.UUID,
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
keywords: Optional[str] = Query(None, description="Search keywords (Optional search query string,e.g. Chinese service name, English service name, author/owner username)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Query the mcp servers list in pages
- Support keyword search for name,author,owner
- Return paging metadata + mcp server list
"""
api_logger.info(
f"Query mcp server list: tenant_id={current_user.tenant_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
# 1. parameter validation
if page < 1 or pagesize < 1:
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0"
)
# 2. Query mcp market config information from the database
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
mcp_market_config_id=mcp_market_config_id,
current_user=current_user)
if not db_mcp_market_config:
api_logger.warning(
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or access is denied"
)
# 3. Execute paged query
api = MCPApi()
token = db_mcp_market_config.token
api.login(token)
body = {
'filter': {},
'page_number': page,
'page_size': pagesize,
'search': keywords
}
try:
cookies = api.get_cookies(token)
r = api.session.put(
url=api.mcp_base_url,
headers=api.builder_headers(api.headers),
json=body,
cookies=cookies)
raise_for_http_status(r)
except requests.exceptions.RequestException as e:
api_logger.error(f"Failed to get MCP servers: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get MCP servers: {str(e)}"
)
data = api._handle_response(r)
total = data.get('total_count', 0)
mcp_server_list = data.get('mcp_server_list', [])
# items = [{
# 'name': item.get('name', ''),
# 'id': item.get('id', ''),
# 'description': item.get('description', '')
# } for item in mcp_server_list]
# 4. Return structured response
result = {
"items": mcp_server_list,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"has_next": True if page * pagesize < total else False
}
}
return success(data=result, msg="Query of mcp servers list successful")
@router.get("/operational_mcp_servers", response_model=ApiResponse)
async def get_operational_mcp_servers(
mcp_market_config_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Query the operational mcp servers list in pages
- Support keyword search for name,author,owner
- Return paging metadata + operational mcp server list
"""
api_logger.info(
f"Query operational mcp server list: tenant_id={current_user.tenant_id}, username: {current_user.username}")
# 1. Query mcp market config information from the database
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
mcp_market_config_id=mcp_market_config_id,
current_user=current_user)
if not db_mcp_market_config:
api_logger.warning(
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or access is denied"
)
# 2. Execute paged query
api = MCPApi()
token = db_mcp_market_config.token
api.login(token)
url = f'{api.mcp_base_url}/operational'
headers = api.builder_headers(api.headers)
try:
cookies = api.get_cookies(access_token=token, cookies_required=True)
r = api.session.get(url, headers=headers, cookies=cookies)
raise_for_http_status(r)
except requests.exceptions.RequestException as e:
api_logger.error(f"Failed to get operational MCP servers: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get operational MCP servers: {str(e)}"
)
data = api._handle_response(r)
total = data.get('total_count', 0)
mcp_server_list = data.get('mcp_server_list', [])
# items = [{
# 'name': item.get('name', ''),
# 'id': item.get('id', ''),
# 'description': item.get('description', '')
# } for item in mcp_server_list]
# 3. Return structured response
return success(data=mcp_server_list, msg="Query of operational mcp servers list successful")
@router.get("/mcp_server", response_model=ApiResponse)
async def get_mcp_server(
mcp_market_config_id: uuid.UUID,
server_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Get detailed information for a specific MCP Server
"""
api_logger.info(
f"Query mcp server: tenant_id={current_user.tenant_id}, mcp_market_config_id={mcp_market_config_id}, server_id={server_id}, username: {current_user.username}")
# 1. Query mcp market config information from the database
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
mcp_market_config_id=mcp_market_config_id,
current_user=current_user)
if not db_mcp_market_config:
api_logger.warning(
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or access is denied"
)
# 2. Get detailed information for a specific MCP Server
api = MCPApi()
token = db_mcp_market_config.token
api.login(token)
result = api.get_mcp_server(server_id=server_id)
return success(data=result, msg="Query of mcp servers list successful")
@router.post("/mcp_market_config", response_model=ApiResponse)
async def create_mcp_market_config(
create_data: mcp_market_config_schema.McpMarketConfigCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
create mcp market config
"""
api_logger.info(
f"Request to create a mcp market config: mcp_market_id={create_data.mcp_market_id}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
try:
api_logger.debug(f"Start creating the mcp market config: {create_data.mcp_market_id}")
# 1. Check if the mcp market name already exists
db_mcp_market_config_exist = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=create_data.mcp_market_id, current_user=current_user)
if db_mcp_market_config_exist:
api_logger.warning(f"The mcp market id already exists: {create_data.mcp_market_id}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The mcp market id already exists: {create_data.mcp_market_id}"
)
db_mcp_market_config = mcp_market_config_service.create_mcp_market_config(db=db, mcp_market_config=create_data, current_user=current_user)
api_logger.info(
f"The mcp market config has been successfully created: (ID: {db_mcp_market_config.id})")
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
msg="The mcp market config has been successfully created")
except Exception as e:
api_logger.error(f"The creation of the mcp market config failed: {create_data.mcp_market_id} - {str(e)}")
raise
@router.get("/{mcp_market_config_id}", response_model=ApiResponse)
async def get_mcp_market_config(
mcp_market_config_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Retrieve mcp market config information based on mcp_market_config_id
"""
api_logger.info(
f"Obtain details of the mcp market config: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}")
try:
# 1. Query mcp market config information from the database
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
if not db_mcp_market_config:
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or access is denied"
)
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
msg="Successfully obtained mcp market config information")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"mcp market config query failed: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
raise
@router.get("/mcp_market_id/{mcp_market_id}", response_model=ApiResponse)
async def get_mcp_market_config_by_mcp_market_id(
mcp_market_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Retrieve mcp market config information based on mcp_market_id
"""
api_logger.info(
f"Request to create a mcp market config: mcp_market_id={mcp_market_id}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
try:
# 1. Query mcp market config information from the database
api_logger.debug(f"Query mcp market config: mcp_market_id={mcp_market_id}")
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
if not db_mcp_market_config:
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_id={mcp_market_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or access is denied"
)
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
msg="Successfully obtained mcp market config information")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"mcp market config query failed: mcp_market_id={mcp_market_id} - {str(e)}")
raise
@router.put("/{mcp_market_config_id}", response_model=ApiResponse)
async def update_mcp_market_config(
mcp_market_config_id: uuid.UUID,
update_data: mcp_market_config_schema.McpMarketConfigUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
# 1. Check if the mcp market config exists
api_logger.debug(f"Query the mcp market config to be updated: {mcp_market_config_id}")
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
if not db_mcp_market_config:
api_logger.warning(
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or you do not have permission to access it"
)
# 2. Update fields (only update non-null fields)
api_logger.debug(f"Start updating the mcp market config fields: {mcp_market_config_id}")
update_dict = update_data.dict(exclude_unset=True)
updated_fields = []
for field, value in update_dict.items():
if hasattr(db_mcp_market_config, field):
old_value = getattr(db_mcp_market_config, field)
if old_value != value:
# update value
setattr(db_mcp_market_config, field, value)
updated_fields.append(f"{field}: {old_value} -> {value}")
if updated_fields:
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
# 3. Save to database
try:
db.commit()
db.refresh(db_mcp_market_config)
api_logger.info(f"The mcp market config has been successfully updated: (ID: {db_mcp_market_config.id})")
except Exception as e:
db.rollback()
api_logger.error(f"The mcp market config update failed: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"The mcp market config update failed: {str(e)}"
)
# 4. Return the updated mcp market config
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
msg="The mcp market config information updated successfully")
@router.delete("/{mcp_market_config_id}", response_model=ApiResponse)
async def delete_mcp_market_config(
mcp_market_config_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
delete mcp market config
"""
api_logger.info(f"Request to delete mcp market config: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}")
try:
# 1. Check whether the mcp market config exists
api_logger.debug(f"Check whether the mcp market config exists: {mcp_market_config_id}")
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
if not db_mcp_market_config:
api_logger.warning(
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market config does not exist or you do not have permission to access it"
)
# 2. Deleting mcp market config
mcp_market_config_service.delete_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
api_logger.info(f"The mcp market config has been successfully deleted: (ID: {mcp_market_config_id})")
return success(msg="The mcp market config has been successfully deleted")
except Exception as e:
api_logger.error(f"Failed to delete from the mcp market config: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
raise

View File

@@ -1,262 +0,0 @@
import datetime
import json
from typing import Optional
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.encoders import jsonable_encoder
from sqlalchemy import or_
from sqlalchemy.orm import Session
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.db import get_db
from app.dependencies import get_current_user
from app.models import mcp_market_model
from app.models.user_model import User
from app.schemas import mcp_market_schema
from app.schemas.response_schema import ApiResponse
from app.services import mcp_market_service
# Obtain a dedicated API logger
api_logger = get_api_logger()
router = APIRouter(
prefix="/mcp_markets",
tags=["mcp_markets"],
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
)
@router.get("/mcp_markets", response_model=ApiResponse)
async def get_mcp_markets(
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
orderby: Optional[str] = Query(None, description="Sort fields, such as: category, created_at"),
desc: Optional[bool] = Query(False, description="Is it descending order"),
keywords: Optional[str] = Query(None, description="Search keywords (mcp_market base name)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Query the mcp markets list in pages
- Support keyword search for name,description
- Support dynamic sorting
- Return paging metadata + mcp_market list
"""
api_logger.info(
f"Query mcp market list: tenant_id={current_user.tenant_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
# 1. parameter validation
if page < 1 or pagesize < 1:
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0"
)
# 2. Construct query conditions
filters = []
# Keyword search (fuzzy matching of mcp market name,description)
if keywords:
api_logger.debug(f"Add keyword search criteria: {keywords}")
filters.append(
or_(
mcp_market_model.McpMarket.name.ilike(f"%{keywords}%"),
mcp_market_model.McpMarket.description.ilike(f"%{keywords}%")
)
)
# 3. Execute paged query
try:
api_logger.debug("Start executing mcp market paging query")
total, items = mcp_market_service.get_mcp_markets_paginated(
db=db,
filters=filters,
page=page,
pagesize=pagesize,
orderby=orderby,
desc=desc,
current_user=current_user
)
api_logger.info(f"mcp market query successful: total={total}, returned={len(items)} records")
except Exception as e:
api_logger.error(f"mcp market query failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Query failed: {str(e)}"
)
# 4. Return structured response
result = {
"items": items,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"has_next": True if page * pagesize < total else False
}
}
return success(data=jsonable_encoder(result), msg="Query of mcp market list successful")
@router.post("/mcp_market", response_model=ApiResponse)
async def create_mcp_market(
create_data: mcp_market_schema.McpMarketCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
create mcp market
"""
api_logger.info(
f"Request to create a mcp market: name={create_data.name}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
try:
api_logger.debug(f"Start creating the mcp market: {create_data.name}")
# 1. Check if the mcp market name already exists
db_mcp_market_exist = mcp_market_service.get_mcp_market_by_name(db, name=create_data.name, current_user=current_user)
if db_mcp_market_exist:
api_logger.warning(f"The mcp market name already exists: {create_data.name}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The mcp market name already exists: {create_data.name}"
)
db_mcp_market = mcp_market_service.create_mcp_market(db=db, mcp_market=create_data, current_user=current_user)
api_logger.info(
f"The mcp market has been successfully created: {db_mcp_market.name} (ID: {db_mcp_market.id})")
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
msg="The mcp market has been successfully created")
except Exception as e:
api_logger.error(f"The creation of the mcp market failed: {create_data.name} - {str(e)}")
raise
@router.get("/{mcp_market_id}", response_model=ApiResponse)
async def get_mcp_market(
mcp_market_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Retrieve mcp market information based on mcp_market_id
"""
api_logger.info(
f"Obtain details of the mcp market: mcp_market_id={mcp_market_id}, username: {current_user.username}")
try:
# 1. Query mcp market information from the database
api_logger.debug(f"Query mcp market: {mcp_market_id}")
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
if not db_mcp_market:
api_logger.warning(f"The mcp market does not exist or access is denied: mcp_market_id={mcp_market_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market does not exist or access is denied"
)
api_logger.info(f"mcp market query successful: {db_mcp_market.name} (ID: {db_mcp_market.id})")
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
msg="Successfully obtained mcp market information")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"mcp market query failed: mcp_market_id={mcp_market_id} - {str(e)}")
raise
@router.put("/{mcp_market_id}", response_model=ApiResponse)
async def update_mcp_market(
mcp_market_id: uuid.UUID,
update_data: mcp_market_schema.McpMarketUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
# 1. Check if the mcp market exists
api_logger.debug(f"Query the mcp market to be updated: {mcp_market_id}")
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
if not db_mcp_market:
api_logger.warning(
f"The mcp market does not exist or you do not have permission to access it: mcp_market_id={mcp_market_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market does not exist or you do not have permission to access it"
)
# 2. not updating the name (name already exists)
update_dict = update_data.dict(exclude_unset=True)
if "name" in update_dict:
name = update_dict["name"]
if name != db_mcp_market.name:
# Check if the mcp market name already exists
db_mcp_market_exist = mcp_market_service.get_mcp_market_by_name(db, name=name, current_user=current_user)
if db_mcp_market_exist:
api_logger.warning(f"The mcp market name already exists: {name}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The mcp market name already exists: {name}"
)
# 3. Update fields (only update non-null fields)
api_logger.debug(f"Start updating the mcp market fields: {mcp_market_id}")
updated_fields = []
for field, value in update_dict.items():
if hasattr(db_mcp_market, field):
old_value = getattr(db_mcp_market, field)
if old_value != value:
# update value
setattr(db_mcp_market, field, value)
updated_fields.append(f"{field}: {old_value} -> {value}")
if updated_fields:
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
# 4. Save to database
try:
db.commit()
db.refresh(db_mcp_market)
api_logger.info(f"The mcp market has been successfully updated: {db_mcp_market.name} (ID: {db_mcp_market.id})")
except Exception as e:
db.rollback()
api_logger.error(f"The mcp market update failed: mcp_market_id={mcp_market_id} - {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"The mcp market update failed: {str(e)}"
)
# 5. Return the updated mcp market
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
msg="The mcp market information updated successfully")
@router.delete("/{mcp_market_id}", response_model=ApiResponse)
async def delete_mcp_market(
mcp_market_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
delete mcp market
"""
api_logger.info(f"Request to delete mcp market: mcp_market_id={mcp_market_id}, username: {current_user.username}")
try:
# 1. Check whether the mcp market exists
api_logger.debug(f"Check whether the mcp market exists: {mcp_market_id}")
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
if not db_mcp_market:
api_logger.warning(
f"The mcp market does not exist or you do not have permission to access it: mcp_market_id={mcp_market_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market does not exist or you do not have permission to access it"
)
# 2. Deleting mcp market
mcp_market_service.delete_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
api_logger.info(f"The mcp market has been successfully deleted: (ID: {mcp_market_id})")
return success(msg="The mcp market has been successfully deleted")
except Exception as e:
api_logger.error(f"Failed to delete from the mcp market: mcp_market_id={mcp_market_id} - {str(e)}")
raise

View File

@@ -1,17 +1,8 @@
from typing import List, Optional
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.cache.memory.interest_memory import InterestMemoryCache
from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.rag.llm.cv_model import QWenCV
from app.core.response_utils import fail, success
from app.db import get_db
@@ -24,6 +15,10 @@ from app.schemas.response_schema import ApiResponse
from app.services import task_service, workspace_service
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_service import ModelConfigService
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
load_dotenv()
api_logger = get_api_logger()
@@ -38,7 +33,7 @@ router = APIRouter(
@router.get("/health/status", response_model=ApiResponse)
async def get_health_status(
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user)
):
"""
Get latest health status written by Celery periodic task
@@ -56,9 +51,8 @@ async def get_health_status(
@router.get("/download_log")
async def download_log(
log_type: str = Query("file", regex="^(file|transmission)$",
description="日志类型: file=完整文件, transmission=实时流式传输"),
current_user: User = Depends(get_current_user)
log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"),
current_user: User = Depends(get_current_user)
):
"""
Download or stream agent service log file
@@ -77,16 +71,16 @@ async def download_log(
- transmission mode: StreamingResponse with SSE
"""
api_logger.info(f"Log download requested with log_type={log_type}")
# Validate log_type parameter (FastAPI Query regex already validates, but explicit check for clarity)
if log_type not in ["file", "transmission"]:
api_logger.warning(f"Invalid log_type parameter: {log_type}")
return fail(
BizCode.BAD_REQUEST,
"无效的log_type参数",
BizCode.BAD_REQUEST,
"无效的log_type参数",
"log_type必须是'file''transmission'"
)
# Route to appropriate mode
if log_type == "file":
# File mode: Return complete log file content
@@ -121,28 +115,23 @@ async def download_log(
@router.post("/writer_service", response_model=ApiResponse)
@cur_workspace_access_guard()
async def write_server(
user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
user_input: Write_UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Write service endpoint - processes write operations synchronously
Args:
user_input: Write request containing message and end_user_id
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
user_input: Write request containing message and group_id
Returns:
Response with write operation status
"""
# 使用集中化的语言校验
language = get_language_from_header(language_type)
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
@@ -151,7 +140,7 @@ async def write_server(
)
if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = ''
# 如果 storage_type 是 rag必须确保有有效的 user_rag_memory_id
if storage_type == 'rag':
if workspace_id:
@@ -163,27 +152,22 @@ async def write_server(
if knowledge:
user_rag_memory_id = str(knowledge.id)
else:
api_logger.warning(
f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
storage_type = 'neo4j'
else:
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j'
api_logger.info(
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try:
messages_list = memory_agent_service.get_messages_list(user_input)
result = await memory_agent_service.write_memory(
user_input.end_user_id,
messages_list,
user_input.group_id,
user_input.message,
config_id,
db,
storage_type,
user_rag_memory_id,
language
storage_type,
user_rag_memory_id
)
return success(data=result, msg="写入成功")
except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
@@ -199,29 +183,23 @@ async def write_server(
@router.post("/writer_service_async", response_model=ApiResponse)
@cur_workspace_access_guard()
async def write_server_async(
user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
user_input: Write_UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Async write service endpoint - enqueues write processing to Celery
Args:
user_input: Write request containing message and end_user_id
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
user_input: Write request containing message and group_id
Returns:
Task ID for tracking async operation
Use GET /memory/write_result/{task_id} to check task status and get result
"""
# 使用集中化的语言校验
language = get_language_from_header(language_type)
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
@@ -241,15 +219,12 @@ async def write_server_async(
if knowledge: user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
task = celery_app.send_task(
"app.core.memory.agent.write_message",
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
args=[user_input.group_id, user_input.message, config_id, storage_type, user_rag_memory_id]
)
api_logger.info(f"Write task queued: {task.id}")
return success(data={"task_id": task.id}, msg="写入任务已提交")
except Exception as e:
api_logger.error(f"Async write operation failed: {str(e)}")
@@ -259,9 +234,9 @@ async def write_server_async(
@router.post("/read_service", response_model=ApiResponse)
@cur_workspace_access_guard()
async def read_server(
user_input: UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
user_input: UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Read service endpoint - processes read operations synchronously
@@ -272,14 +247,16 @@ async def read_server(
- "2": Direct answer based on context
Args:
user_input: Read request with message, history, search_switch, and end_user_id
user_input: Read request with message, history, search_switch, and group_id
Returns:
Response with query answer
"""
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
@@ -294,14 +271,12 @@ async def read_server(
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge:
user_rag_memory_id = str(knowledge.id)
api_logger.info(
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
if knowledge: user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
try:
result = await memory_agent_service.read_memory(
user_input.end_user_id,
user_input.group_id,
user_input.message,
user_input.history,
user_input.search_switch,
@@ -310,23 +285,6 @@ async def read_server(
storage_type,
user_rag_memory_id
)
if str(user_input.search_switch) == "2":
retrieve_info = result['answer']
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
user_input.end_user_id)
query = user_input.message
# 调用 memory_agent_service 的方法生成最终答案
result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
end_user_id=user_input.end_user_id,
retrieve_info=retrieve_info,
history=history,
query=query,
config_id=config_id,
db=db
)
if "信息不足,无法回答" in result['answer']:
result['answer'] = retrieve_info
return success(data=result, msg="回复对话消息成功")
except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
@@ -342,10 +300,9 @@ async def read_server(
@router.post("/file", response_model=ApiResponse)
async def file_update(
files: List[UploadFile] = File(..., description="要上传的文件"),
model_id: str = Form(..., description="模型ID"),
model_id:str = Form(..., description="模型ID"),
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
文件上传接口 - 支持图片识别
@@ -358,6 +315,9 @@ async def file_update(
Returns:
文件处理结果
"""
db_gen = get_db() # get_db 通常是一个生成器
db = next(db_gen)
api_logger.info(f"File upload requested, file count: {len(files)}")
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
apiConfig: ModelApiKey = config.api_keys[0]
@@ -366,7 +326,7 @@ async def file_update(
for file in files:
api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}")
content = await file.read()
if file.content_type and file.content_type.startswith("image/"):
vision_model = QWenCV(
key=apiConfig.api_key,
@@ -380,12 +340,12 @@ async def file_update(
else:
api_logger.warning(f"Unsupported file type: {file.content_type}")
file_content.append(f"[不支持的文件类型: {file.content_type}]")
result_text = ';'.join(file_content)
api_logger.info(f"File processing completed, result length: {len(result_text)}")
return success(data=result_text, msg="转换文本成功")
except Exception as e:
api_logger.error(f"File processing failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e))
@@ -422,7 +382,7 @@ async def read_server_async(
try:
task = celery_app.send_task(
"app.core.memory.agent.read_message",
args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch,
args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch,
config_id, storage_type, user_rag_memory_id]
)
api_logger.info(f"Read task queued: {task.id}")
@@ -435,8 +395,8 @@ async def read_server_async(
@router.get("/read_result/", response_model=ApiResponse)
async def get_read_task_result(
task_id: str,
current_user: User = Depends(get_current_user)
task_id: str,
current_user: User = Depends(get_current_user)
):
"""
Get the status and result of an async read task
@@ -457,7 +417,7 @@ async def get_read_task_result(
try:
result = task_service.get_task_memory_read_result(task_id)
status = result.get("status")
if status == "SUCCESS":
# 任务成功完成
task_result = result.get("result", {})
@@ -466,7 +426,7 @@ async def get_read_task_result(
return success(
data={
"result": task_result.get("result"),
"end_user_id": task_result.get("end_user_id"),
"group_id": task_result.get("group_id"),
"elapsed_time": task_result.get("elapsed_time"),
"task_id": task_id
},
@@ -475,7 +435,7 @@ async def get_read_task_result(
else:
# 旧格式:直接返回结果
return success(data=task_result, msg="查询任务已完成")
elif status == "FAILURE":
# 任务失败
error_info = result.get("result", "Unknown error")
@@ -484,7 +444,7 @@ async def get_read_task_result(
else:
error_msg = str(error_info)
return fail(BizCode.INTERNAL_ERROR, "查询任务失败", error_msg)
elif status in ["PENDING", "STARTED"]:
# 任务进行中
return success(
@@ -504,7 +464,7 @@ async def get_read_task_result(
},
msg=f"任务状态: {status}"
)
except Exception as e:
api_logger.error(f"Read task status check failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
@@ -512,8 +472,8 @@ async def get_read_task_result(
@router.get("/write_result/", response_model=ApiResponse)
async def get_write_task_result(
task_id: str,
current_user: User = Depends(get_current_user)
task_id: str,
current_user: User = Depends(get_current_user)
):
"""
Get the status and result of an async write task
@@ -534,7 +494,7 @@ async def get_write_task_result(
try:
result = task_service.get_task_memory_write_result(task_id)
status = result.get("status")
if status == "SUCCESS":
# 任务成功完成
task_result = result.get("result", {})
@@ -543,7 +503,7 @@ async def get_write_task_result(
return success(
data={
"result": task_result.get("result"),
"end_user_id": task_result.get("end_user_id"),
"group_id": task_result.get("group_id"),
"elapsed_time": task_result.get("elapsed_time"),
"task_id": task_id
},
@@ -552,7 +512,7 @@ async def get_write_task_result(
else:
# 旧格式:直接返回结果
return success(data=task_result, msg="写入任务已完成")
elif status == "FAILURE":
# 任务失败
error_info = result.get("result", "Unknown error")
@@ -561,7 +521,7 @@ async def get_write_task_result(
else:
error_msg = str(error_info)
return fail(BizCode.INTERNAL_ERROR, "写入任务失败", error_msg)
elif status in ["PENDING", "STARTED"]:
# 任务进行中
return success(
@@ -581,7 +541,7 @@ async def get_write_task_result(
},
msg=f"任务状态: {status}"
)
except Exception as e:
api_logger.error(f"Write task status check failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
@@ -589,38 +549,23 @@ async def get_write_task_result(
@router.post("/status_type", response_model=ApiResponse)
async def status_type(
user_input: Write_UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
user_input: Write_UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Determine the type of user message (read or write)
Args:
user_input: Request containing user message and end_user_id
user_input: Request containing user message and group_id
Returns:
Type classification result
"""
api_logger.info(f"Status type check requested for group {user_input.end_user_id}")
api_logger.info(f"Status type check requested for group {user_input.group_id}")
try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
# 将消息列表转换为字符串用于分类
# 只取最后一条用户消息进行分类
last_user_message = ""
for msg in reversed(messages_list):
if msg.get('role') == 'user':
last_user_message = msg.get('content', '')
break
if not last_user_message:
# 如果没有用户消息,使用所有消息的内容
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
result = await memory_agent_service.classify_message_type(
last_user_message,
user_input.message,
user_input.config_id,
db
)
@@ -634,21 +579,26 @@ async def status_type(
@router.get("/stats/types", response_model=ApiResponse)
async def get_knowledge_type_stats_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
current_user: User = Depends(get_current_user)
):
"""
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder | memory
会对缺失类型补 0返回字典形式。
可选按状态过滤。
- 知识库类型根据当前用户的 current_workspace_id 过滤
- 如果用户没有当前工作空间,对应的统计返回 0
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤
- 如果用户没有当前工作空间或未提供 end_user_id对应的统计返回 0
"""
api_logger.info(
f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
try:
from app.db import get_db
# 获取数据库会话
db_gen = get_db()
db = next(db_gen)
# 调用service层函数
result = await memory_agent_service.get_knowledge_type_stats(
end_user_id=end_user_id,
@@ -656,70 +606,45 @@ async def get_knowledge_type_stats_api(
current_workspace_id=current_user.current_workspace_id,
db=db
)
return success(data=result, msg="获取知识库类型统计成功")
except Exception as e:
api_logger.error(f"Knowledge type stats failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e))
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
async def get_interest_distribution_by_user_api(
end_user_id: str = Query(..., description="用户ID必填"),
limit: int = Query(5, le=5, description="返回兴趣标签数量限制最多5个"),
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
async def get_hot_memory_tags_by_user_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
limit: int = Query(20, description="返回标签数量限制"),
current_user: User = Depends(get_current_user)
):
"""
获取指定用户的兴趣分布标签
与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习、创作等),
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
获取指定用户的热门记忆标签
返回格式:
[
{"name": "兴趣活动", "frequency": 频次},
{"name": "标签", "frequency": 频次},
...
]
"""
language = get_language_from_header(language_type)
api_logger.info(f"Interest distribution by user requested: end_user_id={end_user_id}, language={language}")
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
try:
# 优先读取缓存
cached = await InterestMemoryCache.get_interest_distribution(
result = await memory_agent_service.get_hot_memory_tags_by_user(
end_user_id=end_user_id,
language=language,
limit=limit
)
if cached is not None:
api_logger.info(f"Interest distribution cache hit: end_user_id={end_user_id}")
return success(data=cached, msg="获取兴趣分布标签成功")
# 缓存未命中,调用模型生成
result = await memory_agent_service.get_interest_distribution_by_user(
end_user_id=end_user_id,
limit=limit,
language=language
)
# 写入缓存24小时过期
await InterestMemoryCache.set_interest_distribution(
end_user_id=end_user_id,
language=language,
data=result,
)
return success(data=result, msg="获取兴趣分布标签成功")
return success(data=result, msg="获取热门记忆标签成功")
except Exception as e:
api_logger.error(f"Interest distribution by user failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取兴趣分布标签失败", str(e))
api_logger.error(f"Hot memory tags by user failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e))
@router.get("/analytics/user_profile", response_model=ApiResponse)
async def get_user_profile_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取用户详情,包含:
@@ -757,17 +682,17 @@ async def get_user_profile_api(
# ):
# """
# Get parsed API documentation (Public endpoint - no authentication required)
# Args:
# file_path: Optional path to API docs file. If None, uses default path.
# Returns:
# Parsed API documentation including title, meta info, and sections
# """
# api_logger.info(f"API docs requested, file_path: {file_path or 'default'}")
# try:
# result = await memory_agent_service.get_api_docs(file_path)
# if result.get("success"):
# return success(msg=result["msg"], data=result["data"])
# else:
@@ -783,9 +708,9 @@ async def get_user_profile_api(
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
async def get_end_user_connected_config(
end_user_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
end_user_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取终端用户关联的记忆配置
@@ -804,9 +729,9 @@ async def get_end_user_connected_config(
from app.services.memory_agent_service import (
get_end_user_connected_config as get_config,
)
api_logger.info(f"Getting connected config for end_user: {end_user_id}")
try:
result = get_config(end_user_id, db)
return success(data=result, msg="获取终端用户关联配置成功")
@@ -815,4 +740,4 @@ async def get_end_user_connected_config(
return fail(BizCode.NOT_FOUND, str(e))
except Exception as e:
api_logger.error(f"Failed to get end user connected config: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))

View File

@@ -5,11 +5,11 @@ from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.memory_agent_schema import End_User_Information
from app.schemas.response_schema import ApiResponse
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
from app.services.memory_agent_service import get_end_users_connected_configs_batch
from app.services.app_statistics_service import AppStatisticsService
from app.core.logging_config import get_api_logger
# 获取API专用日志器
@@ -40,7 +40,54 @@ def get_workspace_total_end_users(
api_logger.info(f"成功获取最新用户总数: total_num={total_end_users.get('total_num', 0)}")
return success(data=total_end_users, msg="用户数量获取成功")
@router.post("/update/end_users", response_model=ApiResponse)
async def update_workspace_end_users(
user_input: End_User_Information,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
更新工作空间的宿主信息
"""
username = user_input.end_user_name # 要更新的用户名
end_user_input_id = user_input.id # 宿主ID
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的宿主信息")
api_logger.info(f"更新参数: username={username}, end_user_id={end_user_input_id}")
try:
# 导入更新函数
from app.repositories.end_user_repository import update_end_user_other_name
import uuid
# 转换 end_user_id 为 UUID 类型
end_user_uuid = uuid.UUID(end_user_input_id)
# 直接更新数据库中的 other_name 字段
updated_count = update_end_user_other_name(
db=db,
end_user_id=end_user_uuid,
other_name=username
)
api_logger.info(f"成功更新宿主 {end_user_input_id} 的 other_name 为: {username}")
return success(
data={
"updated_count": updated_count,
"end_user_id": end_user_input_id,
"updated_other_name": username
},
msg=f"成功更新 {updated_count} 个宿主的信息"
)
except Exception as e:
api_logger.error(f"更新宿主信息失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"更新宿主信息失败: {str(e)}"
)
@@ -50,134 +97,63 @@ async def get_workspace_end_users(
current_user: User = Depends(get_current_user),
):
"""
获取工作空间的宿主列表(高性能优化版本 v2
获取工作空间的宿主列表
优化策略:
1. 批量查询 end_users一次查询而非循环
2. 并发查询所有用户的记忆数量Neo4j
3. RAG 模式使用批量查询(一次 SQL
4. 只返回必要字段减少数据传输
5. 添加短期缓存减少重复查询
6. 并发执行配置查询和记忆数量查询
返回格式:
{
"end_user": {"id": "uuid", "other_name": "名称"},
"memory_num": {"total": 数量},
"memory_config": {"memory_config_id": "id", "memory_config_name": "名称"}
}
返回格式与原 memory_list 接口中的 end_users 字段相同,
并包含每个用户的记忆配置信息memory_config_id 和 memory_config_name
"""
import asyncio
import json
from app.aioRedis import aio_redis_get, aio_redis_set
workspace_id = current_user.current_workspace_id
# 尝试从缓存获取30秒缓存
cache_key = f"end_users:workspace:{workspace_id}"
try:
cached_data = await aio_redis_get(cache_key)
if cached_data:
api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
return success(data=json.loads(cached_data), msg="宿主列表获取成功")
except Exception as e:
api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
# 获取当前空间类型
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
# 获取 end_users已优化为批量查询
end_users = memory_dashboard_service.get_workspace_end_users(
db=db,
workspace_id=workspace_id,
current_user=current_user
)
if not end_users:
api_logger.info("工作空间下没有宿主")
# 缓存空结果,避免重复查询
try:
await aio_redis_set(cache_key, json.dumps([]), expire=30)
except Exception as e:
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
return success(data=[], msg="宿主列表获取成功")
# 批量获取所有用户的记忆配置信息(优化:一次查询而非 N 次)
end_user_ids = [str(user.id) for user in end_users]
# 并发执行两个独立的查询任务
async def get_memory_configs():
"""获取记忆配置(在线程池中执行同步查询)"""
memory_configs_map = {}
if end_user_ids:
try:
return await asyncio.to_thread(
get_end_users_connected_configs_batch,
end_user_ids, db
)
memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db)
except Exception as e:
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
return {}
# 失败时使用空字典,不影响其他数据返回
async def get_memory_nums():
"""获取记忆数量"""
if current_workspace_type == "rag":
# RAG 模式:批量查询
try:
chunk_map = await asyncio.to_thread(
memory_dashboard_service.get_users_total_chunk_batch,
end_user_ids, db, current_user
)
return {uid: {"total": count} for uid, count in chunk_map.items()}
except Exception as e:
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
return {uid: {"total": 0} for uid in end_user_ids}
elif current_workspace_type == "neo4j":
# Neo4j 模式:并发查询(带并发限制)
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
MAX_CONCURRENT_QUERIES = 10
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
async def get_neo4j_memory_num(end_user_id: str):
async with semaphore:
try:
return await memory_storage_service.search_all(end_user_id)
except Exception as e:
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
return {"total": 0}
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
return {uid: {"total": 0} for uid in end_user_ids}
# 并发执行配置查询和记忆数量查询
memory_configs_map, memory_nums_map = await asyncio.gather(
get_memory_configs(),
get_memory_nums()
)
# 构建结果(优化:使用列表推导式)
result = []
for end_user in end_users:
user_id = str(end_user.id)
config_info = memory_configs_map.get(user_id, {})
result.append({
'end_user': {
'id': user_id,
'other_name': end_user.other_name
},
'memory_num': memory_nums_map.get(user_id, {"total": 0}),
'memory_config': {
"memory_config_id": config_info.get("memory_config_id"),
"memory_config_name": config_info.get("memory_config_name")
memory_num = {}
if current_workspace_type == "neo4j":
# EndUser 是 Pydantic 模型,直接访问属性而不是使用 .get()
memory_num = await memory_storage_service.search_all(str(end_user.id))
elif current_workspace_type == "rag":
memory_num = {
"total":memory_dashboard_service.get_current_user_total_chunk(str(end_user.id), db, current_user)
}
# 从批量查询结果中获取配置信息
user_id = str(end_user.id)
memory_config_info = memory_configs_map.get(user_id, {
"memory_config_id": None,
"memory_config_name": None
})
# 写入缓存30秒过期
try:
await aio_redis_set(cache_key, json.dumps(result), expire=30)
except Exception as e:
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
# 只保留需要的字段,移除 error 字段(如果有
memory_config = {
"memory_config_id": memory_config_info.get("memory_config_id"),
"memory_config_name": memory_config_info.get("memory_config_name")
}
result.append(
{
'end_user': end_user,
'memory_num': memory_num,
'memory_config': memory_config
}
)
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
return success(data=result, msg="宿主列表获取成功")
@@ -470,8 +446,6 @@ async def get_chunk_insight(
@router.get("/dashboard_data", response_model=ApiResponse)
async def dashboard_data(
end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
start_date: Optional[int] = Query(None, description="开始时间戳(毫秒)"),
end_date: Optional[int] = Query(None, description="结束时间戳(毫秒)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
@@ -506,15 +480,6 @@ async def dashboard_data(
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的dashboard整合数据")
# 如果没有提供时间范围默认使用最近30天
if start_date is None or end_date is None:
from datetime import datetime, timedelta
end_dt = datetime.now()
start_dt = end_dt - timedelta(days=30)
end_date = int(end_dt.timestamp() * 1000)
start_date = int(start_dt.timestamp() * 1000)
api_logger.info(f"使用默认时间范围: {start_dt}{end_dt}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
@@ -575,22 +540,17 @@ async def dashboard_data(
except Exception as e:
api_logger.warning(f"获取知识库类型统计失败: {str(e)}")
# 3. 获取API调用统计total_api_call
# 3. 获取API调用增量total_api_call,转换为整数
try:
# 使用 AppStatisticsService 获取真实的API调用统计
app_stats_service = AppStatisticsService(db)
api_stats = app_stats_service.get_workspace_api_statistics(
api_increment = memory_dashboard_service.get_workspace_api_increment(
db=db,
workspace_id=workspace_id,
start_date=start_date,
end_date=end_date
current_user=current_user
)
# 计算总调用次数
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
neo4j_data["total_api_call"] = total_api_calls
api_logger.info(f"成功获取API调用统计: {neo4j_data['total_api_call']}")
neo4j_data["total_api_call"] = api_increment
api_logger.info(f"成功获取API调用增量: {neo4j_data['total_api_call']}")
except Exception as e:
api_logger.error(f"获取API调用统计失败: {str(e)}")
neo4j_data["total_api_call"] = 0
api_logger.warning(f"获取API调用增量失败: {str(e)}")
result["neo4j_data"] = neo4j_data
api_logger.info("成功获取neo4j_data")
@@ -606,8 +566,8 @@ async def dashboard_data(
# 获取RAG相关数据
try:
# total_memory: 只统计用户知识库permission_id='Memory')的chunk数
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
# total_memory: 使用 total_chunkchunk数
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
rag_data["total_memory"] = total_chunk
# total_app: 统计当前空间下的所有app数量
@@ -619,23 +579,10 @@ async def dashboard_data(
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
rag_data["total_knowledge"] = total_kb
# total_api_call: 使用 AppStatisticsService 获取真实的API调用统计
try:
app_stats_service = AppStatisticsService(db)
api_stats = app_stats_service.get_workspace_api_statistics(
workspace_id=workspace_id,
start_date=start_date,
end_date=end_date
)
# 计算总调用次数
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
rag_data["total_api_call"] = total_api_calls
api_logger.info(f"成功获取RAG模式API调用统计: {rag_data['total_api_call']}")
except Exception as e:
api_logger.warning(f"获取RAG模式API调用统计失败使用默认值: {str(e)}")
rag_data["total_api_call"] = 0
# total_api_call: 固定值
rag_data["total_api_call"] = 1024
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}")
except Exception as e:
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")

View File

@@ -3,10 +3,9 @@
包含情景记忆总览和详情查询接口
"""
from fastapi import APIRouter, Depends, Header
from fastapi import APIRouter, Depends
from app.core.error_codes import BizCode
from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger
from app.core.response_utils import fail, success
from app.dependencies import get_current_user
@@ -15,7 +14,6 @@ from app.schemas.response_schema import ApiResponse
from app.schemas.memory_episodic_schema import (
EpisodicMemoryOverviewRequest,
EpisodicMemoryDetailsRequest,
translate_episodic_type,
)
from app.services.memory_episodic_service import memory_episodic_service
@@ -86,7 +84,6 @@ async def get_episodic_memory_overview_api(
@router.post("/details", response_model=ApiResponse)
async def get_episodic_memory_details_api(
request: EpisodicMemoryDetailsRequest,
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
) -> dict:
"""
@@ -114,11 +111,6 @@ async def get_episodic_memory_details_api(
summary_id=request.summary_id
)
# 根据语言参数翻译 episodic_type
language = get_language_from_header(language_type)
if "episodic_type" in result:
result["episodic_type"] = translate_episodic_type(result["episodic_type"], language)
api_logger.info(
f"成功获取情景记忆详情: end_user_id={request.end_user_id}, summary_id={request.summary_id}"
)

View File

@@ -11,7 +11,6 @@
"""
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
@@ -34,7 +33,7 @@ from app.schemas.memory_storage_schema import (
)
from app.schemas.response_schema import ApiResponse
from app.services.memory_forget_service import MemoryForgetService
from app.utils.config_utils import resolve_config_id
# 获取API专用日志器
api_logger = get_api_logger()
@@ -84,8 +83,7 @@ async def trigger_forgetting_cycle(
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
config_id = resolve_config_id((config_id), db)
if config_id is None:
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
@@ -108,7 +106,7 @@ async def trigger_forgetting_cycle(
# 调用服务层执行遗忘周期
report = await forget_service.trigger_forgetting_cycle(
db=db,
end_user_id=end_user_id, # 服务层方法的参数名是 end_user_id
group_id=end_user_id, # 服务层方法的参数名是 group_id
max_merge_batch_size=payload.max_merge_batch_size,
min_days_since_access=payload.min_days_since_access,
config_id=config_id
@@ -130,7 +128,7 @@ async def trigger_forgetting_cycle(
@router.get("/read_config", response_model=ApiResponse)
async def read_forgetting_config(
config_id: UUID|int,
config_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
@@ -159,7 +157,6 @@ async def read_forgetting_config(
)
try:
config_id=resolve_config_id(config_id, db)
# 调用服务层读取配置
config = forget_service.read_forgetting_config(db=db, config_id=config_id)
@@ -197,8 +194,6 @@ async def update_forgetting_config(
ApiResponse: 包含更新结果的响应
"""
workspace_id = current_user.current_workspace_id
payload.config_id=resolve_config_id((payload.config_id), db)
# 检查用户是否已选择工作空间
if workspace_id is None:
@@ -241,7 +236,7 @@ async def update_forgetting_config(
@router.get("/stats", response_model=ApiResponse)
async def get_forgetting_stats(
end_user_id: Optional[str] = None,
group_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
@@ -251,7 +246,7 @@ async def get_forgetting_stats(
返回知识层节点统计、激活值分布等信息。
Args:
end_user_id: 组ID即 end_user_id可选
group_id: 组ID即 end_user_id可选
current_user: 当前用户
db: 数据库会话
@@ -259,25 +254,26 @@ async def get_forgetting_stats(
ApiResponse: 包含统计信息的响应
"""
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 如果提供了 end_user_id通过它获取 config_id
# 如果提供了 group_id通过它获取 config_id
config_id = None
if end_user_id:
if group_id:
try:
from app.services.memory_agent_service import get_end_user_connected_config
connected_config = get_end_user_connected_config(end_user_id, db)
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
config_id = resolve_config_id(config_id, db)
if config_id is None:
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
api_logger.warning(f"终端用户 {group_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {group_id} 未关联记忆配置", "memory_config_id is None")
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
api_logger.debug(f"通过 group_id={group_id} 获取到 config_id={config_id}")
except ValueError as e:
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
@@ -287,14 +283,14 @@ async def get_forgetting_stats(
api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
f"end_user_id={end_user_id}, config_id={config_id}"
f"group_id={group_id}, config_id={config_id}"
)
try:
# 调用服务层获取统计信息
stats = await forget_service.get_forgetting_stats(
db=db,
end_user_id=end_user_id,
group_id=group_id,
config_id=config_id
)
@@ -328,7 +324,7 @@ async def get_forgetting_curve(
ApiResponse: 包含遗忘曲线数据的响应
"""
workspace_id = current_user.current_workspace_id
request.config_id = resolve_config_id((request.config_id), db)
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间")

View File

@@ -27,27 +27,27 @@ router = APIRouter(
)
@router.get("/{end_user_id}/count", response_model=ApiResponse)
@router.get("/{group_id}/count", response_model=ApiResponse)
def get_memory_count(
end_user_id: uuid.UUID,
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve perceptual memory statistics for a user group.
Args:
end_user_id: ID of the user group (usually end_user_id in this context)
group_id: ID of the user group (usually end_user_id in this context)
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Response containing memory count statistics
"""
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, end_user_id={end_user_id}")
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, group_id={group_id}")
try:
service = MemoryPerceptualService(db)
count_stats = service.get_memory_count(end_user_id)
count_stats = service.get_memory_count(group_id)
api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}")
@@ -57,37 +57,37 @@ def get_memory_count(
)
except Exception as e:
api_logger.error(f"Failed to fetch memory statistics: end_user_id={end_user_id}, error={str(e)}")
api_logger.error(f"Failed to fetch memory statistics: group_id={group_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch memory statistics",
)
@router.get("/{end_user_id}/last_visual", response_model=ApiResponse)
@router.get("/{group_id}/last_visual", response_model=ApiResponse)
def get_last_visual_memory(
end_user_id: uuid.UUID,
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent VISION-type memory for a user.
Args:
end_user_id: ID of the user group
group_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest visual memory
"""
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, end_user_id={end_user_id}")
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, group_id={group_id}")
try:
service = MemoryPerceptualService(db)
visual_memory = service.get_latest_visual_memory(end_user_id)
visual_memory = service.get_latest_visual_memory(group_id)
if visual_memory is None:
api_logger.info(f"No visual memory found: end_user_id={end_user_id}")
api_logger.info(f"No visual memory found: group_id={group_id}")
return success(
data=None,
msg="No visual memory available"
@@ -101,37 +101,37 @@ def get_last_visual_memory(
)
except Exception as e:
api_logger.error(f"Failed to fetch latest visual memory: end_user_id={end_user_id}, error={str(e)}")
api_logger.error(f"Failed to fetch latest visual memory: group_id={group_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest visual memory",
)
@router.get("/{end_user_id}/last_listen", response_model=ApiResponse)
@router.get("/{group_id}/last_listen", response_model=ApiResponse)
def get_last_memory_listen(
end_user_id: uuid.UUID,
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent AUDIO-type memory for a user.
Args:
end_user_id: ID of the user group
group_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest audio memory
"""
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, end_user_id={end_user_id}")
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, group_id={group_id}")
try:
service = MemoryPerceptualService(db)
audio_memory = service.get_latest_audio_memory(end_user_id)
audio_memory = service.get_latest_audio_memory(group_id)
if audio_memory is None:
api_logger.info(f"No audio memory found: end_user_id={end_user_id}")
api_logger.info(f"No audio memory found: group_id={group_id}")
return success(
data=None,
msg="No audio memory available"
@@ -145,38 +145,38 @@ def get_last_memory_listen(
)
except Exception as e:
api_logger.error(f"Failed to fetch latest audio memory: end_user_id={end_user_id}, error={str(e)}")
api_logger.error(f"Failed to fetch latest audio memory: group_id={group_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest audio memory",
)
@router.get("/{end_user_id}/last_text", response_model=ApiResponse)
@router.get("/{group_id}/last_text", response_model=ApiResponse)
def get_last_text_memory(
end_user_id: uuid.UUID,
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent TEXT-type memory for a user.
Args:
end_user_id: ID of the user group
group_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest text memory
"""
api_logger.info(f"Fetching latest text memory: user={current_user.username}, end_user_id={end_user_id}")
api_logger.info(f"Fetching latest text memory: user={current_user.username}, group_id={group_id}")
try:
# 调用服务层获取最近的文本记忆
service = MemoryPerceptualService(db)
text_memory = service.get_latest_text_memory(end_user_id)
text_memory = service.get_latest_text_memory(group_id)
if text_memory is None:
api_logger.info(f"No text memory found: end_user_id={end_user_id}")
api_logger.info(f"No text memory found: group_id={group_id}")
return success(
data=None,
msg="No text memory available"
@@ -190,16 +190,16 @@ def get_last_text_memory(
)
except Exception as e:
api_logger.error(f"Failed to fetch latest text memory: end_user_id={end_user_id}, error={str(e)}")
api_logger.error(f"Failed to fetch latest text memory: group_id={group_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest text memory",
)
@router.get("/{end_user_id}/timeline", response_model=ApiResponse)
@router.get("/{group_id}/timeline", response_model=ApiResponse)
def get_memory_time_line(
end_user_id: uuid.UUID,
group_id: uuid.UUID,
perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"),
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(10, ge=1, le=100, description="每页大小"),
@@ -209,7 +209,7 @@ def get_memory_time_line(
"""Retrieve a timeline of perceptual memories for a user group.
Args:
end_user_id: ID of the user group
group_id: ID of the user group
perceptual_type: Optional filter for perceptual type
page: Page number for pagination
page_size: Number of items per page
@@ -221,7 +221,7 @@ def get_memory_time_line(
"""
api_logger.info(
f"Fetching perceptual memory timeline: user={current_user.username}, "
f"end_user_id={end_user_id}, type={perceptual_type}, page={page}"
f"group_id={group_id}, type={perceptual_type}, page={page}"
)
try:
@@ -232,7 +232,7 @@ def get_memory_time_line(
)
service = MemoryPerceptualService(db)
timeline_data = service.get_time_line(end_user_id, query)
timeline_data = service.get_time_line(group_id, query)
api_logger.info(
f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, "
@@ -246,7 +246,7 @@ def get_memory_time_line(
except Exception as e:
api_logger.error(
f"Failed to fetch perceptual memory timeline: end_user_id={end_user_id}, "
f"Failed to fetch perceptual memory timeline: group_id={group_id}, "
f"error={str(e)}"
)
return fail(

View File

@@ -1,19 +1,16 @@
import asyncio
import time
import uuid
from uuid import UUID
from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
ReflectionConfig,
ReflectionEngine, ReflectionRange, ReflectionBaseline,
ReflectionEngine,
)
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.repositories.memory_config_repository import MemoryConfigRepository
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_reflection_schemas import Memory_Reflection
from app.services.memory_reflection_service import (
@@ -22,12 +19,10 @@ from app.services.memory_reflection_service import (
)
from app.services.model_service import ModelConfigService
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, HTTPException, status,Header
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import text
from sqlalchemy.orm import Session
from app.utils.config_utils import resolve_config_id
load_dotenv()
api_logger = get_api_logger()
@@ -44,40 +39,64 @@ async def save_reflection_config(
db: Session = Depends(get_db),
) -> dict:
"""Save reflection configuration to data_comfig table"""
try:
config_id = request.config_id
config_id = resolve_config_id(config_id, db)
if not config_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="缺少必需参数: config_id"
)
api_logger.info(f"用户 {current_user.username} 保存反思配置config_id: {config_id}")
memory_config = MemoryConfigRepository.update_reflection_config(
db,
config_id=config_id,
enable_self_reflexion=request.reflection_enabled,
iteration_period=request.reflection_period_in_hours,
reflexion_range=request.reflexion_range,
baseline=request.baseline,
reflection_model_id=request.reflection_model_id,
memory_verify=request.memory_verify,
quality_assessment=request.quality_assessment
)
update_params = {
"enable_self_reflexion": request.reflection_enabled,
"iteration_period": request.reflection_period_in_hours,
"reflexion_range": request.reflexion_range,
"baseline": request.baseline,
"reflection_model_id": request.reflection_model_id,
"memory_verify": request.memory_verify,
"quality_assessment": request.quality_assessment,
}
query, params = DataConfigRepository.build_update_reflection(config_id, **update_params)
result = db.execute(text(query), params)
if result.rowcount == 0:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"未找到config_id为 {config_id} 的配置"
)
db.commit()
db.refresh(memory_config)
# 查询更新后的配置
select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
result = db.execute(text(select_query), select_params).fetchone()
if not result:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"更新后未找到config_id为 {config_id} 的配置"
)
api_logger.info(f"成功保存反思配置到数据库config_id: {config_id}")
reflection_result={
"config_id": memory_config.config_id,
"enable_self_reflexion": memory_config.enable_self_reflexion,
"iteration_period": memory_config.iteration_period,
"reflexion_range": memory_config.reflexion_range,
"baseline": memory_config.baseline,
"reflection_model_id": memory_config.reflection_model_id,
"memory_verify": memory_config.memory_verify,
"quality_assessment": memory_config.quality_assessment}
"config_id": result.config_id,
"enable_self_reflexion": result.enable_self_reflexion,
"iteration_period": result.iteration_period,
"reflexion_range": result.reflexion_range,
"baseline": result.baseline,
"reflection_model_id": result.reflection_model_id,
"memory_verify": result.memory_verify,
"quality_assessment": result.quality_assessment,
"user_id": result.user_id}
return success(data=reflection_result, msg="反思配置成功")
@@ -97,76 +116,48 @@ async def save_reflection_config(
)
@router.get("/reflection")
@router.post("/reflection")
async def start_workspace_reflection(
config_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""启动工作空间中所有匹配应用的反思功能"""
"""Activate the reflection function for all matching applications in the workspace"""
workspace_id = current_user.current_workspace_id
reflection_service = MemoryReflectionService(db)
try:
api_logger.info(f"用户 {current_user.username} 启动workspace反思workspace_id: {workspace_id}")
# 使用独立的数据库会话来获取工作空间应用详情,避免事务失败
from app.db import get_db_context
with get_db_context() as query_db:
service = WorkspaceAppService(query_db)
result = service.get_workspace_apps_detailed(workspace_id)
service = WorkspaceAppService(db)
result = service.get_workspace_apps_detailed(workspace_id)
reflection_results = []
for data in result['apps_detailed_info']:
# 跳过没有配置的应用
if not data['memory_configs']:
api_logger.debug(f"应用 {data['id']} 没有memory_configs跳过")
if data['data_configs'] == []:
continue
releases = data['releases']
memory_configs = data['memory_configs']
data_configs = data['data_configs']
end_users = data['end_users']
# 为每个配置和用户组合执行反思
for config in memory_configs:
config_id_str = str(config['config_id'])
# 找到匹配此配置的所有release
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
if not matching_releases:
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
continue
# 为每个用户执行反思 - 使用独立的数据库会话
for user in end_users:
api_logger.info(f"为用户 {user['id']} 启动反思config_id: {config_id_str}")
# 为每个用户创建独立的数据库会话,避免事务失败影响其他用户
with get_db_context() as user_db:
try:
reflection_service = MemoryReflectionService(user_db)
reflection_result = await reflection_service.start_text_reflection(
config_data=config,
end_user_id=user['id']
)
reflection_results.append({
"app_id": data['id'],
"config_id": config_id_str,
"end_user_id": user['id'],
"reflection_result": reflection_result
})
except Exception as e:
api_logger.error(f"用户 {user['id']} 反思失败: {str(e)}")
reflection_results.append({
"app_id": data['id'],
"config_id": config_id_str,
"end_user_id": user['id'],
"reflection_result": {
"status": "错误",
"message": f"反思失败: {str(e)}"
}
})
for base, config, user in zip(releases, data_configs, end_users):
if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']:
# 调用反思服务
api_logger.info(f"为用户 {user['id']} 启动反思config_id: {config['config_id']}")
reflection_result = await reflection_service.start_reflection_from_data(
config_data=config,
end_user_id=user['id']
)
reflection_results.append({
"app_id": base['app_id'],
"config_id": config['config_id'],
"end_user_id": user['id'],
"reflection_result": reflection_result
})
return success(data=reflection_results, msg="反思配置成功")
@@ -180,27 +171,35 @@ async def start_workspace_reflection(
@router.get("/reflection/configs")
async def start_reflection_configs(
config_id: uuid.UUID|int,
config_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""通过config_id查询memory_config表中的反思配置信息"""
config_id = resolve_config_id(config_id, db)
"""通过config_id查询data_config表中的反思配置信息"""
try:
config_id=resolve_config_id(config_id,db)
api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
memory_config_id = resolve_config_id(result.config_id, db)
# 使用DataConfigRepository查询反思配置
select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
result = db.execute(text(select_query), select_params).fetchone()
if not result:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"未找到config_id为 {config_id} 的配置"
)
# 构建返回数据
reflection_config = {
"config_id": memory_config_id,
"config_id": result.config_id,
"reflection_enabled": result.enable_self_reflexion,
"reflection_period_in_hours": result.iteration_period,
"reflexion_range": result.reflexion_range,
"baseline": result.baseline,
"reflection_model_id": result.reflection_model_id,
"memory_verify": result.memory_verify,
"quality_assessment": result.quality_assessment
"quality_assessment": result.quality_assessment,
"user_id": result.user_id
}
api_logger.info(f"成功查询反思配置config_id: {config_id}")
return success(data=reflection_config, msg="反思配置查询成功")
@@ -218,19 +217,19 @@ async def start_reflection_configs(
@router.get("/reflection/run")
async def reflection_run(
config_id: UUID|int,
language_type: str = Header(default=None, alias="X-Language-Type"),
config_id: int,
language_type: str = "zh",
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""Activate the reflection function for all matching applications in the workspace"""
# 使用集中化的语言校验
language = get_language_from_header(language_type)
api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")
config_id = resolve_config_id(config_id, db)
# 使用MemoryConfigRepository查询反思配置
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
# 使用DataConfigRepository查询反思配置
select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
result = db.execute(text(select_query), select_params).fetchone()
if not result:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@@ -243,7 +242,7 @@ async def reflection_run(
model_id = result.reflection_model_id
if model_id:
try:
ModelConfigService.get_model_by_id(db=db, model_id=uuid.UUID(model_id))
ModelConfigService.get_model_by_id(db=db, model_id=model_id)
api_logger.info(f"模型ID验证成功: {model_id}")
except Exception as e:
api_logger.warning(f"模型ID '{model_id}' 不存在,将使用默认模型: {str(e)}")
@@ -253,8 +252,8 @@ async def reflection_run(
config = ReflectionConfig(
enabled=result.enable_self_reflexion,
iteration_period=result.iteration_period,
reflexion_range=ReflectionRange(result.reflexion_range),
baseline=ReflectionBaseline(result.baseline),
reflexion_range=result.reflexion_range,
baseline=result.baseline,
output_example='',
memory_verify=result.memory_verify,
quality_assessment=result.quality_assessment,

View File

@@ -1,18 +1,15 @@
from typing import Optional
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, Header, HTTPException, status
from sqlalchemy.orm import Session
from app.core.language_utils import get_language_from_header
from fastapi import APIRouter, Depends, HTTPException, status
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.services.memory_short_service import LongService, ShortService
from app.services.memory_storage_service import search_entity
from app.services.memory_storage_service import search_entity
from app.services.memory_short_service import ShortService,LongService
from dotenv import load_dotenv
from sqlalchemy.orm import Session
from typing import Optional
load_dotenv()
api_logger = get_api_logger()
@@ -23,19 +20,15 @@ router = APIRouter(
@router.get("/short_term")
async def short_term_configs(
end_user_id: str,
language_type:str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
# 使用集中化的语言校验
language = get_language_from_header(language_type)
# 获取短期记忆数据
short_term=ShortService(end_user_id, db)
short_term=ShortService(end_user_id)
short_result=short_term.get_short_databasets()
short_count=short_term.get_short_count()
long_term=LongService(end_user_id, db)
long_term=LongService(end_user_id)
long_result=long_term.get_long_databasets()
entity_result = await search_entity(end_user_id)

View File

@@ -1,13 +1,10 @@
import os
import uuid
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse, JSONResponse
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger
from app.core.memory.utils.self_reflexion_utils import self_reflexion
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import get_current_user
@@ -15,6 +12,7 @@ from app.models.user_model import User
from app.schemas.memory_storage_schema import (
ConfigKey,
ConfigParamsCreate,
ConfigParamsDelete,
ConfigPilotRun,
ConfigUpdate,
ConfigUpdateExtracted,
@@ -32,14 +30,13 @@ from app.services.memory_storage_service import (
search_dialogue,
search_edges,
search_entity,
search_entity_graph,
search_statement,
)
from fastapi import APIRouter, Depends, Header
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from app.utils.config_utils import resolve_config_id
# Get API logger
api_logger = get_api_logger()
@@ -75,9 +72,68 @@ async def get_storage_info(
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
# --- DB connection dependency ---
_CONN: Optional[object] = None
"""PostgreSQL 连接生成与管理(使用 psycopg2"""
# 这个可以转移,可能是已经有的
# PostgreSQL 数据库连接
def _make_pgsql_conn() -> Optional[object]: # 创建 PostgreSQL 数据库连接
host = os.getenv("DB_HOST")
user = os.getenv("DB_USER")
password = os.getenv("DB_PASSWORD")
database = os.getenv("DB_NAME")
port_str = os.getenv("DB_PORT")
try:
import psycopg2 # type: ignore
port = int(port_str) if port_str else 5432
conn = psycopg2.connect(
host=host or "localhost",
port=port,
user=user,
password=password,
dbname=database,
)
# 设置自动提交,避免显式事务管理
conn.autocommit = True
# 设置会话时区为中国标准时间Asia/Shanghai便于直接以本地时区展示
try:
cur = conn.cursor()
cur.execute("SET TIME ZONE 'Asia/Shanghai'")
cur.close()
except Exception:
# 时区设置失败不影响连接,仅记录但不抛出
pass
return conn
except Exception as e:
try:
print(f"[PostgreSQL] 连接失败: {e}")
except Exception:
pass
return None
def get_db_conn() -> Optional[object]: # 获取 PostgreSQL 数据库连接
global _CONN
if _CONN is None:
_CONN = _make_pgsql_conn()
return _CONN
def reset_db_conn() -> bool: # 重置 PostgreSQL 数据库连接
"""Close and recreate the global DB connection."""
global _CONN
try:
if _CONN:
try:
_CONN.close()
except Exception:
pass
_CONN = _make_pgsql_conn()
return _CONN is not None
except Exception:
_CONN = None
return False
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
@@ -85,9 +141,9 @@ def create_config(
payload: ConfigParamsCreate,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
) -> dict:
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
@@ -100,125 +156,46 @@ def create_config(
svc = DataConfigService(db)
result = svc.create(payload)
return success(data=result, msg="创建成功")
except ValueError as e:
err_str = str(e)
if err_str.startswith("DUPLICATE_CONFIG_NAME:"):
config_name = err_str.split(":", 1)[1]
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {err_str}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
except Exception as e:
from sqlalchemy.exc import IntegrityError
if isinstance(e, IntegrityError) and "uq_workspace_config_name" in str(getattr(e, 'orig', '')):
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
def delete_config(
config_id: UUID|int,
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
config_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""删除记忆配置(带终端用户保护)
- 检查是否为默认配置,默认配置不允许删除
- 检查是否有终端用户连接到该配置
- 如果有连接且 force=False返回警告
- 如果 force=True清除终端用户引用后删除配置
Query Parameters:
force: 设置为 true 可强制删除(即使有终端用户正在使用)
"""
) -> dict:
workspace_id = current_user.current_workspace_id
config_id=resolve_config_id(config_id, db)
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: "
f"config_id={config_id}, force={force}"
)
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: {config_id}")
try:
# 使用带保护的删除服务
from app.services.memory_config_service import MemoryConfigService
config_service = MemoryConfigService(db)
result = config_service.delete_config(config_id=config_id, force=force)
if result["status"] == "error":
api_logger.warning(
f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}"
)
return fail(
code=BizCode.FORBIDDEN,
msg=result["message"],
data={"config_id": str(config_id), "is_default": result.get("is_default", False)}
)
if result["status"] == "warning":
api_logger.warning(
f"记忆配置正在使用,无法删除: config_id={config_id}, "
f"connected_count={result['connected_count']}"
)
return fail(
code=BizCode.RESOURCE_IN_USE,
msg=result["message"],
data={
"connected_count": result["connected_count"],
"force_required": result["force_required"]
}
)
api_logger.info(
f"记忆配置删除成功: config_id={config_id}, "
f"affected_users={result['affected_users']}"
)
return success(
msg=result["message"],
data={"affected_users": result["affected_users"]}
)
svc = DataConfigService(db)
result = svc.delete(ConfigParamsDelete(config_id=config_id))
return success(data=result, msg="删除成功")
except Exception as e:
api_logger.error(f"Delete config failed: {str(e)}", exc_info=True)
api_logger.error(f"Delete config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e))
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
def update_config(
payload: ConfigUpdate,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
) -> dict:
workspace_id = current_user.current_workspace_id
payload.config_id = resolve_config_id(payload.config_id, db)
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 校验至少有一个字段需要更新
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
try:
svc = DataConfigService(db)
@@ -234,9 +211,9 @@ def update_config_extracted(
payload: ConfigUpdateExtracted,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
) -> dict:
workspace_id = current_user.current_workspace_id
payload.config_id = resolve_config_id(payload.config_id, db)
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
@@ -258,12 +235,12 @@ def update_config_extracted(
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
def read_config_extracted(
config_id: UUID | int,
config_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
) -> dict:
workspace_id = current_user.current_workspace_id
config_id = resolve_config_id(config_id, db)
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
@@ -282,7 +259,7 @@ def read_config_extracted(
def read_all_config(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
@@ -304,22 +281,16 @@ def read_all_config(
@router.post("/pilot_run", response_model=None)
async def pilot_run(
payload: ConfigPilotRun,
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> StreamingResponse:
# 使用集中化的语言校验
language = get_language_from_header(language_type)
api_logger.info(
f"Pilot run requested: config_id={payload.config_id}, "
f"dialogue_text_length={len(payload.dialogue_text)}, "
f"custom_text_length={len(payload.custom_text) if payload.custom_text else 0}"
f"dialogue_text_length={len(payload.dialogue_text)}"
)
payload.config_id = resolve_config_id(payload.config_id, db)
svc = DataConfigService(db)
return StreamingResponse(
svc.pilot_run_stream(payload, language=language),
svc.pilot_run_stream(payload),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
@@ -328,8 +299,9 @@ async def pilot_run(
},
)
# ==================== Search & Analytics ====================
"""
以下为搜索与分析接口,直接挂载到同一 router统一响应为 ApiResponse。
"""
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
async def get_kb_type_distribution(
@@ -442,7 +414,21 @@ async def search_entity_edges(
api_logger.error(f"Search edges failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
@router.get("/search/entity_graph", response_model=ApiResponse)
async def search_for_entity_graph(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
"""
搜索所有实体之间的关系网络
"""
api_logger.info(f"Search entity graph requested for end_user_id: {end_user_id}")
try:
result = await search_entity_graph(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search entity graph failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "实体图查询失败", str(e))
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
@@ -451,106 +437,39 @@ async def get_hot_memory_tags_api(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> dict:
"""
获取热门记忆标签带Redis缓存
缓存策略:
- 缓存键workspace_id + limit
- 过期时间5分钟300秒
- 缓存命中:~50ms
- 缓存未命中:~600-800ms取决于LLM速度
"""
workspace_id = current_user.current_workspace_id
# 构建缓存键
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}")
api_logger.info(f"Hot memory tags requested for current_user: {current_user.id}")
try:
# 尝试从Redis缓存获取
import json
from app.aioRedis import aio_redis_get, aio_redis_set
cached_result = await aio_redis_get(cache_key)
if cached_result:
api_logger.info(f"Cache hit for key: {cache_key}")
try:
data = json.loads(cached_result)
return success(data=data, msg="查询成功(缓存)")
except json.JSONDecodeError:
api_logger.warning(f"Failed to parse cached data, will refresh")
# 缓存未命中,执行查询
api_logger.info(f"Cache miss for key: {cache_key}, executing query")
result = await analytics_hot_memory_tags(db, current_user, limit)
# 写入缓存过期时间5分钟
# 注意result是列表需要转换为JSON字符串
try:
cache_data = json.dumps(result, ensure_ascii=False)
await aio_redis_set(cache_key, cache_data, expire=300)
api_logger.info(f"Cached result for key: {cache_key}")
except Exception as cache_error:
# 缓存写入失败不影响主流程
api_logger.warning(f"Failed to cache result: {str(cache_error)}")
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Hot memory tags failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
async def clear_hot_memory_tags_cache(
current_user: User = Depends(get_current_user),
) -> dict:
"""
清除热门标签缓存
用于:
- 手动刷新数据
- 调试和测试
- 数据更新后立即生效
"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}")
try:
from app.aioRedis import aio_redis_delete
# 清除所有limit的缓存常见的limit值
cleared_count = 0
for limit in [5, 10, 15, 20, 30, 50]:
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
result = await aio_redis_delete(cache_key)
if result:
cleared_count += 1
api_logger.info(f"Cleared cache for key: {cache_key}")
return success(
data={"cleared_count": cleared_count},
msg=f"成功清除 {cleared_count} 个缓存"
)
except Exception as e:
api_logger.error(f"Clear cache failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e))
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
async def get_recent_activity_stats_api(
current_user: User = Depends(get_current_user),
) -> dict:
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
) -> dict:
api_logger.info("Recent activity stats requested")
try:
result = await analytics_recent_activity_stats(workspace_id=workspace_id)
result = await analytics_recent_activity_stats()
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Recent activity stats failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
@router.get("/self_reflexion")
async def self_reflexion_endpoint(host_id: uuid.UUID) -> str:
"""
自我反思接口,自动对检索出的信息进行自我反思并返回自我反思结果。
Args:
None
Returns:
自我反思结果。
"""
return await self_reflexion(host_id)

View File

@@ -20,18 +20,18 @@ router = APIRouter(
)
@router.get("/{end_user_id}/count", response_model=ApiResponse)
@router.get("/{group_id}/count", response_model=ApiResponse)
def get_memory_count(
end_user_id: uuid.UUID,
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
pass
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
@router.get("/{group_id}/conversations", response_model=ApiResponse)
def get_conversations(
end_user_id: uuid.UUID,
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
@@ -39,7 +39,7 @@ def get_conversations(
Retrieve all conversations for the current user in a specific group.
Args:
end_user_id (UUID): The group identifier.
group_id (UUID): The group identifier.
current_user (User, optional): The authenticated user.
db (Session, optional): SQLAlchemy session.
@@ -53,7 +53,7 @@ def get_conversations(
"""
conversation_service = ConversationService(db)
conversations = conversation_service.get_user_conversations(
end_user_id
group_id
)
return success(data=[
{
@@ -63,7 +63,7 @@ def get_conversations(
], msg="get conversations success")
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
@router.get("/{group_id}/messages", response_model=ApiResponse)
def get_messages(
conversation_id: uuid.UUID,
current_user: User = Depends(get_current_user),
@@ -100,7 +100,7 @@ def get_messages(
return success(data=messages, msg="get conversation history success")
@router.get("/{end_user_id}/detail", response_model=ApiResponse)
@router.get("/{group_id}/detail", response_model=ApiResponse)
async def get_conversation_detail(
conversation_id: uuid.UUID,
current_user: User = Depends(get_current_user),

View File

@@ -3,17 +3,15 @@ from sqlalchemy.orm import Session
from typing import Optional
import uuid
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.db import get_db
from app.dependencies import get_current_user
from app.models.models_model import ModelProvider, ModelType, LoadBalanceStrategy
from app.models.models_model import ModelProvider, ModelType
from app.models.user_model import User
from app.repositories.model_repository import ModelConfigRepository
from app.schemas import model_schema
from app.core.response_utils import success
from app.schemas.response_schema import ApiResponse, PageData
from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService
from app.services.model_service import ModelConfigService, ModelApiKeyService
from app.core.logging_config import get_api_logger
# 获取API专用日志器
@@ -26,54 +24,44 @@ router = APIRouter(
@router.get("/type", response_model=ApiResponse)
def get_model_types():
return success(msg="获取模型类型成功", data=list(ModelType))
@router.get("/provider", response_model=ApiResponse)
def get_model_providers():
providers = [p for p in ModelProvider if p != ModelProvider.COMPOSITE]
return success(msg="获取模型提供商成功", data=providers)
@router.get("/strategy", response_model=ApiResponse)
def get_model_strategies():
return success(msg="获取模型策略成功", data=list(LoadBalanceStrategy))
return success(msg="获取模型提供商成功", data=list(ModelProvider))
@router.get("", response_model=ApiResponse)
def get_model_list(
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING"),
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
search: Optional[str] = Query(None, description="搜索关键词"),
page: int = Query(1, ge=1, description="页码"),
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
type: Optional[str] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING"),
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
search: Optional[str] = Query(None, description="搜索关键词"),
page: int = Query(1, ge=1, description="页码"),
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取模型配置列表
支持多个 type 参数:
- 单个:?type=LLM
- 多个(逗号分隔):?type=LLM,EMBEDDING
- 多个(重复参数):?type=LLM&type=EMBEDDING
"""
api_logger.info(
f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
try:
# 解析 type 参数(支持逗号分隔)
type_list = []
if type is not None:
flat_type = []
for item in type:
split_items = [t.strip() for t in item.split(',') if t.strip()]
flat_type.extend(split_items)
unique_flat_type = list(dict.fromkeys(flat_type))
type_list = [ModelType(t.lower()) for t in unique_flat_type]
type_list = None
if type:
type_values = [t.strip() for t in type.split(',')]
type_list = [model_schema.ModelType(t.lower()) for t in type_values if t]
api_logger.error(f"获取模型type_list: {type_list}")
query = model_schema.ModelConfigQuery(
type=type_list,
@@ -84,7 +72,7 @@ def get_model_list(
page=page,
pagesize=pagesize
)
api_logger.debug(f"开始获取模型配置列表: {query.dict()}")
result_orm = ModelConfigService.get_model_list(db=db, query=query, tenant_id=current_user.tenant_id)
result = PageData.model_validate(result_orm)
@@ -95,146 +83,6 @@ def get_model_list(
raise
@router.get("/new", response_model=ApiResponse)
def get_model_list_new(
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING"),
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于ModelConfig)"),
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
search: Optional[str] = Query(None, description="搜索关键词"),
is_composite: Optional[bool] = Query(None, description="组合模型筛选"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取模型配置列表
支持多个 type 参数:
- 单个:?type=LLM
- 多个(逗号分隔):?type=LLM,EMBEDDING
- 多个(重复参数):?type=LLM&type=EMBEDDING
"""
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, tenant_id={current_user.tenant_id}")
try:
# 解析 type 参数(支持逗号分隔)
type_list = []
if type is not None:
flat_type = []
for item in type:
split_items = [t.strip() for t in item.split(',') if t.strip()]
flat_type.extend(split_items)
unique_flat_type = list(dict.fromkeys(flat_type))
type_list = [ModelType(t.lower()) for t in unique_flat_type]
api_logger.info(f"获取模型type_list: {type_list}")
query = model_schema.ModelConfigQueryNew(
type=type_list,
provider=provider,
is_active=is_active,
is_public=is_public,
is_composite=is_composite,
search=search
)
api_logger.debug(f"开始获取模型配置列表: {query.model_dump()}")
result = ModelConfigService.get_model_list_new(db=db, query=query, tenant_id=current_user.tenant_id)
api_logger.info(f"模型配置列表获取成功: 分组数={len(result)}, 总模型数={sum(len(item['models']) for item in result)}")
return success(data=result, msg="模型配置列表获取成功")
except Exception as e:
api_logger.error(f"获取模型配置列表失败: {str(e)}")
raise
@router.get("/model_plaza", response_model=ApiResponse)
def get_model_plaza_list(
type: Optional[ModelType] = Query(None, description="模型类型"),
provider: Optional[ModelProvider] = Query(None, description="供应商"),
is_official: Optional[bool] = Query(None, description="是否官方模型"),
is_deprecated: Optional[bool] = Query(None, description="是否弃用"),
search: Optional[str] = Query(None, description="搜索关键词"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""模型广场查询接口(按供应商分组)"""
query = model_schema.ModelBaseQuery(
type=type,
provider=provider,
is_official=is_official,
is_deprecated=is_deprecated,
search=search
)
result = ModelBaseService.get_model_base_list(db=db, query=query, tenant_id=current_user.tenant_id)
return success(data=result, msg="模型广场列表获取成功")
@router.get("/model_plaza/{model_base_id}", response_model=ApiResponse)
def get_model_base_by_id(
model_base_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取基础模型详情"""
result = ModelBaseService.get_model_base_by_id(db=db, model_base_id=model_base_id)
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型获取成功")
@router.post("/model_plaza", response_model=ApiResponse)
def create_model_base(
data: model_schema.ModelBaseCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""创建基础模型"""
result = ModelBaseService.create_model_base(db=db, data=data)
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型创建成功")
@router.put("/model_plaza/{model_base_id}", response_model=ApiResponse)
def update_model_base(
model_base_id: uuid.UUID,
data: model_schema.ModelBaseUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""更新基础模型"""
# 不允许更改type类型
if data.type is not None or data.provider is not None:
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
result = ModelBaseService.update_model_base(db=db, model_base_id=model_base_id, data=data)
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型更新成功")
@router.delete("/model_plaza/{model_base_id}", response_model=ApiResponse)
def delete_model_base(
model_base_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""删除基础模型"""
ModelBaseService.delete_model_base(db=db, model_base_id=model_base_id)
return success(msg="基础模型删除成功")
@router.post("/model_plaza/{model_base_id}/add", response_model=ApiResponse)
def add_model_from_plaza(
model_base_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""从模型广场添加模型到模型列表"""
result = ModelBaseService.add_model_from_plaza(db=db, model_base_id=model_base_id, tenant_id=current_user.tenant_id)
return success(data=model_schema.ModelConfig.model_validate(result), msg="模型添加成功")
@router.get("/{model_id}", response_model=ApiResponse)
def get_model_by_id(
model_id: uuid.UUID,
@@ -290,73 +138,6 @@ async def create_model(
raise
@router.post("/composite", response_model=ApiResponse)
async def create_composite_model(
model_data: model_schema.CompositeModelCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
创建组合模型
- 绑定一个或多个现有的 API Key
- 所有 API Key 必须来自非组合模型
- 所有 API Key 关联的模型类型必须与组合模型类型一致
"""
api_logger.info(f"创建组合模型请求: {model_data.name}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
try:
result_orm = await ModelConfigService.create_composite_model(db=db, model_data=model_data, tenant_id=current_user.tenant_id)
api_logger.info(f"组合模型创建成功: {result_orm.name} (ID: {result_orm.id})")
result = model_schema.ModelConfig.model_validate(result_orm)
return success(data=result, msg="组合模型创建成功")
except Exception as e:
api_logger.error(f"创建组合模型失败: {model_data.name} - {str(e)}")
raise
@router.put("/composite/{model_id}", response_model=ApiResponse)
async def update_composite_model(
model_id: uuid.UUID,
model_data: model_schema.CompositeModelCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""更新组合模型"""
api_logger.info(f"更新组合模型请求: model_id={model_id}, 用户: {current_user.username}")
try:
if model_data.type is not None:
raise BusinessException("不允许更改模型类型", BizCode.INVALID_PARAMETER)
result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})")
result = model_schema.ModelConfig.model_validate(result_orm)
return success(data=result, msg="组合模型更新成功")
except Exception as e:
api_logger.error(f"更新组合模型失败: model_id={model_id} - {str(e)}")
raise
@router.delete("/composite/{model_id}", response_model=ApiResponse)
def delete_composite_model(
model_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""删除组合模型"""
api_logger.info(f"删除组合模型请求: model_id={model_id}, 用户: {current_user.username}")
try:
ModelConfigService.delete_model(db=db, model_id=model_id, tenant_id=current_user.tenant_id)
api_logger.info(f"组合模型删除成功: model_id={model_id}")
return success(msg="组合模型删除成功")
except Exception as e:
api_logger.error(f"删除组合模型失败: model_id={model_id} - {str(e)}")
raise
@router.put("/{model_id}", response_model=ApiResponse)
def update_model(
model_id: uuid.UUID,
@@ -368,14 +149,6 @@ def update_model(
更新模型配置
"""
api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
if model_data.type is not None or model_data.provider is not None:
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
if model_data.is_active:
active_keys = ModelApiKeyService.get_api_keys_by_model(db=db, model_config_id=model_id, is_active=model_data.is_active)
if not active_keys:
raise BusinessException("请先为该模型配置可用的 API Key", BizCode.INVALID_PARAMETER)
try:
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
@@ -441,55 +214,6 @@ def get_model_api_keys(
raise
@router.post("/provider/apikeys", response_model=ApiResponse)
async def create_model_api_key_by_provider(
api_key_data: model_schema.ModelApiKeyCreateByProvider,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
根据供应商为所有匹配的模型创建API Key
"""
api_logger.info(f"创建API Key请求: provider={api_key_data.provider}, 用户: {current_user.username}")
try:
# 根据tenant_id和provider筛选model_config_id列表
model_config_ids = api_key_data.model_config_ids
if not model_config_ids:
model_config_ids = ModelConfigRepository.get_model_config_ids_by_provider(
db=db,
tenant_id=current_user.tenant_id,
provider=api_key_data.provider
)
if not model_config_ids:
raise BusinessException(f"未找到供应商 {api_key_data.provider} 的模型配置", BizCode.MODEL_NOT_FOUND)
# 构造schema并调用service
create_data = model_schema.ModelApiKeyCreateByProvider(
provider=api_key_data.provider,
api_key=api_key_data.api_key,
api_base=api_key_data.api_base,
description=api_key_data.description,
config=api_key_data.config,
is_active=api_key_data.is_active,
priority=api_key_data.priority,
model_config_ids=model_config_ids,
capability=api_key_data.capability,
is_omni=api_key_data.is_omni
)
created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
api_logger.info(f"API Key创建成功: 关联{len(created_keys)}个模型")
# result_list = [model_schema.ModelApiKey.model_validate(key) for key in created_keys]
result = "API Key已存在" if len(created_keys) == 0 and len(failed_models) == 0 else \
f"成功为 {len(created_keys)} 个模型创建API Key, 失败模型列表{failed_models}"
return success(data=result, msg=f"成功为 {len(created_keys)} 个模型创建API Key")
except Exception as e:
api_logger.error(f"创建API Key失败: {str(e)}")
raise
@router.post("/{model_id}/apikeys", response_model=ApiResponse, status_code=status.HTTP_201_CREATED)
async def create_model_api_key(
model_id: uuid.UUID,
@@ -504,12 +228,11 @@ async def create_model_api_key(
try:
# 设置模型配置ID
api_key_data.model_config_ids = [model_id]
api_key_data.model_config_id = model_id
api_logger.debug(f"开始创建模型API Key: {api_key_data.model_name}")
result_orm = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
api_logger.info(f"模型API Key创建成功: {result_orm.model_name} (ID: {result_orm.id})")
result = model_schema.ModelApiKey.model_validate(result_orm)
result = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
api_logger.info(f"模型API Key创建成功: {result.model_name} (ID: {result.id})")
return success(data=result, msg="模型API Key创建成功")
except Exception as e:
api_logger.error(f"创建模型API Key失败: {api_key_data.model_name} - {str(e)}")
@@ -611,3 +334,5 @@ async def validate_model_config(
return success(data=model_schema.ModelValidateResponse(**result), msg="验证完成")

File diff suppressed because it is too large Load Diff

View File

@@ -1,663 +0,0 @@
# -*- coding: utf-8 -*-
"""本体场景和类型路由(续)
由于主Controller文件较大将剩余路由放在此文件中。
"""
from uuid import UUID
from typing import Optional
from fastapi import Depends, Header
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger, get_business_logger
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.ontology_schemas import (
SceneResponse,
SceneListResponse,
PaginationInfo,
ClassCreateRequest,
ClassUpdateRequest,
ClassResponse,
ClassListResponse,
ClassBatchCreateResponse,
)
from app.schemas.response_schema import ApiResponse
from app.services.ontology_service import OntologyService
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.models.base import RedBearModelConfig
from app.repositories.ontology_class_repository import OntologyClassRepository
api_logger = get_api_logger()
business_logger = get_business_logger()
def _get_dummy_ontology_service(db: Session) -> OntologyService:
"""获取OntologyService实例不需要LLM
场景和类型管理不需要LLM创建一个dummy配置。
"""
dummy_config = RedBearModelConfig(
model_name="dummy",
provider="openai",
api_key="dummy",
base_url="https://api.openai.com/v1"
)
llm_client = OpenAIClient(model_config=dummy_config)
return OntologyService(llm_client=llm_client, db=db)
# 这些函数将被导入到主Controller中
async def scenes_handler(
workspace_id: Optional[str] = None,
scene_name: Optional[str] = None,
page: Optional[int] = None,
pagesize: Optional[int] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取场景列表(支持模糊搜索和全量查询,全量查询支持分页)
当提供 scene_name 参数时,进行模糊搜索(不分页);
当不提供 scene_name 参数时,返回所有场景(支持分页)。
Args:
workspace_id: 工作空间ID可选默认当前用户工作空间
scene_name: 场景名称关键词(可选,支持模糊匹配)
page: 页码可选从1开始仅在全量查询时有效
pagesize: 每页数量(可选,仅在全量查询时有效)
db: 数据库会话
current_user: 当前用户
"""
operation = "search" if scene_name else "list"
api_logger.info(
f"Scene {operation} requested by user {current_user.id}, "
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, pagesize={pagesize}"
)
try:
# 确定工作空间ID
if workspace_id:
try:
ws_uuid = UUID(workspace_id)
except ValueError:
api_logger.warning(f"Invalid workspace_id format: {workspace_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的工作空间ID格式")
else:
ws_uuid = current_user.current_workspace_id
if not ws_uuid:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建Service
service = _get_dummy_ontology_service(db)
# 根据是否提供 scene_name 决定查询方式
if scene_name and scene_name.strip():
# 验证分页参数(模糊搜索也支持分页)
if page is not None and page < 1:
api_logger.warning(f"Invalid page number: {page}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
if pagesize is not None and pagesize < 1:
api_logger.warning(f"Invalid pagesize: {pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
# 如果只提供了page或pagesize中的一个返回错误
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
# 模糊搜索场景(支持分页)
scenes = service.search_scenes_by_name(scene_name.strip(), ws_uuid)
total = len(scenes)
# 如果提供了分页参数,进行分页处理
if page is not None and pagesize is not None:
start_idx = (page - 1) * pagesize
end_idx = start_idx + pagesize
scenes = scenes[start_idx:end_idx]
# 构建响应
items = []
for scene in scenes:
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
type_num = len(scene.classes) if scene.classes else 0
items.append(SceneResponse(
scene_id=scene.scene_id,
scene_name=scene.scene_name,
scene_description=scene.scene_description,
type_num=type_num,
entity_type=entity_type,
workspace_id=scene.workspace_id,
created_at=scene.created_at,
updated_at=scene.updated_at,
classes_count=type_num,
is_system_default=scene.is_system_default
))
# 构建响应(包含分页信息)
if page is not None and pagesize is not None:
hasnext = (page * pagesize) < total
pagination_info = PaginationInfo(
page=page,
pagesize=pagesize,
total=total,
hasnext=hasnext
)
response = SceneListResponse(items=items, page=pagination_info)
else:
response = SceneListResponse(items=items)
api_logger.info(
f"Scene search completed: found {len(items)} scenes matching '{scene_name}' "
f"in workspace {ws_uuid}, total={total}"
)
else:
# 获取所有场景(支持分页)
if page is not None and page < 1:
api_logger.warning(f"Invalid page number: {page}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
if pagesize is not None and pagesize < 1:
api_logger.warning(f"Invalid pagesize: {pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
# 如果只提供了page或pagesize中的一个返回错误
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
scenes, total = service.list_scenes(ws_uuid, page, pagesize)
# 构建响应
items = []
for scene in scenes:
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
type_num = len(scene.classes) if scene.classes else 0
items.append(SceneResponse(
scene_id=scene.scene_id,
scene_name=scene.scene_name,
scene_description=scene.scene_description,
type_num=type_num,
entity_type=entity_type,
workspace_id=scene.workspace_id,
created_at=scene.created_at,
updated_at=scene.updated_at,
classes_count=type_num,
is_system_default=scene.is_system_default
))
# 构建响应(包含分页信息)
if page is not None and pagesize is not None:
hasnext = (page * pagesize) < total
pagination_info = PaginationInfo(
page=page,
pagesize=pagesize,
total=total,
hasnext=hasnext
)
response = SceneListResponse(items=items, page=pagination_info)
else:
response = SceneListResponse(items=items)
api_logger.info(f"Scene list retrieved successfully, count={len(items)}, total={total}")
return success(data=response.model_dump(mode='json'), msg="查询成功")
except ValueError as e:
api_logger.warning(f"Validation error in scene {operation}: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in scene {operation}: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in scene {operation}: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
# ==================== 本体类型管理接口 ====================
async def create_class_handler(
request: ClassCreateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
x_language_type: Optional[str] = None
):
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
# 根据列表长度判断是单个还是批量
count = len(request.classes)
mode = "single" if count == 1 else "batch"
api_logger.info(
f"Class creation ({mode}) requested by user {current_user.id}, "
f"scene_id={request.scene_id}, count={count}"
)
try:
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建Service
service = _get_dummy_ontology_service(db)
# 准备类型数据
classes_data = [
{
"class_name": item.class_name,
"class_description": item.class_description
}
for item in request.classes
]
if count == 1:
# 单个创建 - 先检查重名
class_data = classes_data[0]
existing = OntologyClassRepository(db).get_by_name(class_data["class_name"], request.scene_id)
if existing:
raise ValueError(f"DUPLICATE_CLASS_NAME:{class_data['class_name']}")
ontology_class = service.create_class(
scene_id=request.scene_id,
class_name=class_data["class_name"],
class_description=class_data["class_description"],
workspace_id=workspace_id
)
# 构建单个响应
response = ClassResponse(
class_id=ontology_class.class_id,
class_name=ontology_class.class_name,
class_description=ontology_class.class_description,
scene_id=ontology_class.scene_id,
created_at=ontology_class.created_at,
updated_at=ontology_class.updated_at
)
api_logger.info(f"Class created successfully: {ontology_class.class_id}")
return success(data=response.model_dump(mode='json'), msg="类型创建成功")
else:
# 批量创建
created_classes, errors = service.create_classes_batch(
scene_id=request.scene_id,
classes=classes_data,
workspace_id=workspace_id
)
# 构建批量响应
items = []
for ontology_class in created_classes:
items.append(ClassResponse(
class_id=ontology_class.class_id,
class_name=ontology_class.class_name,
class_description=ontology_class.class_description,
scene_id=ontology_class.scene_id,
created_at=ontology_class.created_at,
updated_at=ontology_class.updated_at
))
response = ClassBatchCreateResponse(
total=len(classes_data),
success_count=len(created_classes),
failed_count=len(errors),
items=items,
errors=errors if errors else None
)
api_logger.info(
f"Batch class creation completed: "
f"success={len(created_classes)}, failed={len(errors)}"
)
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
except ValueError as e:
err_str = str(e)
if err_str.startswith("DUPLICATE_CLASS_NAME:"):
class_name = err_str.split(":", 1)[1]
api_logger.warning(f"Duplicate class name '{class_name}' in scene {request.scene_id}")
from app.core.language_utils import get_language_from_header
from fastapi.responses import JSONResponse
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.warning(f"Validation error in class creation: {err_str}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", err_str)
except RuntimeError as e:
err_str = str(e)
if "UniqueViolation" in err_str or "uq_scene_class_name" in err_str:
api_logger.warning(f"Duplicate class name in scene {request.scene_id}")
from app.core.language_utils import get_language_from_header
from fastapi.responses import JSONResponse
lang = get_language_from_header(x_language_type)
class_name = request.classes[0].class_name if request.classes else ""
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Runtime error in class creation: {err_str}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", err_str)
except Exception as e:
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
async def update_class_handler(
class_id: str,
request: ClassUpdateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""更新本体类型"""
api_logger.info(
f"Class update requested by user {current_user.id}, "
f"class_id={class_id}"
)
try:
# 验证UUID格式
try:
class_uuid = UUID(class_id)
except ValueError:
api_logger.warning(f"Invalid class_id format: {class_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 检查是否为系统默认类型
class_repo = OntologyClassRepository(db)
ontology_class = class_repo.get_by_id(class_uuid)
if ontology_class and ontology_class.is_system_default:
business_logger.warning(
f"尝试修改系统默认类型: user_id={current_user.id}, "
f"class_id={class_id}, class_name={ontology_class.class_name}"
)
return fail(
BizCode.BAD_REQUEST,
"系统默认类型不可修改",
"该类型为系统预设类型,不允许修改"
)
# 创建Service
service = _get_dummy_ontology_service(db)
# 更新类型
ontology_class = service.update_class(
class_id=class_uuid,
class_name=request.class_name,
class_description=request.class_description,
workspace_id=workspace_id
)
# 构建响应
response = ClassResponse(
class_id=ontology_class.class_id,
class_name=ontology_class.class_name,
class_description=ontology_class.class_description,
scene_id=ontology_class.scene_id,
created_at=ontology_class.created_at,
updated_at=ontology_class.updated_at
)
api_logger.info(f"Class updated successfully: {class_id}")
return success(data=response.model_dump(mode='json'), msg="类型更新成功")
except ValueError as e:
api_logger.warning(f"Validation error in class update: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in class update: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in class update: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
async def delete_class_handler(
class_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""删除本体类型"""
api_logger.info(
f"Class deletion requested by user {current_user.id}, "
f"class_id={class_id}"
)
try:
# 验证UUID格式
try:
class_uuid = UUID(class_id)
except ValueError:
api_logger.warning(f"Invalid class_id format: {class_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 检查是否为系统默认类型
class_repo = OntologyClassRepository(db)
ontology_class = class_repo.get_by_id(class_uuid)
if ontology_class and ontology_class.is_system_default:
business_logger.warning(
f"尝试删除系统默认类型: user_id={current_user.id}, "
f"class_id={class_id}, class_name={ontology_class.class_name}"
)
return fail(
BizCode.BAD_REQUEST,
"系统默认类型不可删除",
"该类型为系统预设类型,不允许删除"
)
# 创建Service
service = _get_dummy_ontology_service(db)
# 删除类型
success_flag = service.delete_class(
class_id=class_uuid,
workspace_id=workspace_id
)
api_logger.info(f"Class deleted successfully: {class_id}")
return success(data={"deleted": success_flag}, msg="类型删除成功")
except ValueError as e:
api_logger.warning(f"Validation error in class deletion: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in class deletion: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in class deletion: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
async def get_class_handler(
class_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取单个本体类型"""
api_logger.info(
f"Get class requested by user {current_user.id}, "
f"class_id={class_id}"
)
try:
# 验证UUID格式
try:
class_uuid = UUID(class_id)
except ValueError:
api_logger.warning(f"Invalid class_id format: {class_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建Service
service = _get_dummy_ontology_service(db)
# 获取类型会抛出ValueError如果不存在
ontology_class = service.get_class_by_id(class_uuid, workspace_id)
# 构建响应
response = ClassResponse(
class_id=ontology_class.class_id,
class_name=ontology_class.class_name,
class_description=ontology_class.class_description,
scene_id=ontology_class.scene_id,
created_at=ontology_class.created_at,
updated_at=ontology_class.updated_at
)
api_logger.info(f"Class retrieved successfully: {class_id}")
return success(data=response.model_dump(mode='json'), msg="查询成功")
except ValueError as e:
# 类型不存在或无权限访问
api_logger.warning(f"Validation error in get class: {str(e)}")
return fail(BizCode.NOT_FOUND, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in get class: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in get class: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
async def classes_handler(
scene_id: str,
class_name: Optional[str] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取类型列表(支持模糊搜索和全量查询)
当提供 class_name 参数时,进行模糊搜索;
当不提供 class_name 参数时,返回场景下的所有类型。
Args:
scene_id: 场景ID必填
class_name: 类型名称关键词(可选,支持模糊匹配)
db: 数据库会话
current_user: 当前用户
"""
operation = "search" if class_name else "list"
api_logger.info(
f"Class {operation} requested by user {current_user.id}, "
f"keyword={class_name}, scene_id={scene_id}"
)
try:
# 验证UUID格式
try:
scene_uuid = UUID(scene_id)
except ValueError:
api_logger.warning(f"Invalid scene_id format: {scene_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的场景ID格式")
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建Service
service = _get_dummy_ontology_service(db)
# 获取场景信息
scene = service.get_scene_by_id(scene_uuid, workspace_id)
if not scene:
api_logger.warning(f"Scene not found: {scene_id}")
return fail(BizCode.NOT_FOUND, "场景不存在", f"未找到ID为 {scene_id} 的场景")
# 根据是否提供 class_name 决定查询方式
if class_name and class_name.strip():
# 模糊搜索类型
classes = service.search_classes_by_name(class_name.strip(), scene_uuid, workspace_id)
else:
# 获取所有类型
classes = service.list_classes_by_scene(scene_uuid, workspace_id)
# 构建响应
items = []
for ontology_class in classes:
items.append(ClassResponse(
class_id=ontology_class.class_id,
class_name=ontology_class.class_name,
class_description=ontology_class.class_description,
scene_id=ontology_class.scene_id,
created_at=ontology_class.created_at,
updated_at=ontology_class.updated_at
))
response = ClassListResponse(
total=len(items),
scene_id=scene_uuid,
scene_name=scene.scene_name,
scene_description=scene.scene_description,
is_system_default=scene.is_system_default,
items=items
)
if class_name:
api_logger.info(
f"Class search completed: found {len(items)} classes matching '{class_name}' "
f"in scene {scene_id}"
)
else:
api_logger.info(f"Class list retrieved successfully, count={len(items)}")
return success(data=response.model_dump(mode='json'), msg="查询成功")
except ValueError as e:
api_logger.warning(f"Validation error in class {operation}: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in class {operation}: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in class {operation}: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))

View File

@@ -1,5 +1,5 @@
import json
import uuid
import json
from fastapi import APIRouter, Depends, Path
from sqlalchemy.orm import Session
@@ -8,13 +8,9 @@ from starlette.responses import StreamingResponse
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.dependencies import get_current_user, get_db
from app.schemas.prompt_optimizer_schema import (
PromptOptMessage,
CreateSessionResponse,
SessionHistoryResponse,
SessionMessage,
PromptSaveRequest
)
from app.models.prompt_optimizer_model import RoleType
from app.schemas.prompt_optimizer_schema import PromptOptMessage, PromptOptModelSet, CreateSessionResponse, \
OptimizePromptResponse, SessionHistoryResponse, SessionMessage
from app.schemas.response_schema import ApiResponse
from app.services.prompt_optimizer_service import PromptOptimizerService
@@ -120,8 +116,7 @@ async def get_prompt_opt(
session_id=session_id,
user_id=current_user.id,
current_prompt=data.current_prompt,
user_require=data.message,
skill=data.skill
user_require=data.message
):
# chunk 是 prompt 的增量内容
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
@@ -140,109 +135,3 @@ async def get_prompt_opt(
"X-Accel-Buffering": "no"
}
)
@router.post(
"/releases",
summary="Get prompt optimization",
response_model=ApiResponse
)
def save_prompt(
data: PromptSaveRequest,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
Save a prompt release for the current tenant.
Args:
data (PromptSaveRequest): Request body containing session_id, title, and prompt.
db (Session): SQLAlchemy database session, injected via dependency.
current_user: Currently authenticated user object, injected via dependency.
Returns:
ApiResponse: Standard API response containing the saved prompt release info:
- id: UUID of the prompt release
- session_id: associated session
- title: prompt title
- prompt: prompt content
- created_at: timestamp of creation
Raises:
Any database or service exceptions are propagated to the global exception handler.
"""
service = PromptOptimizerService(db)
prompt_info = service.save_prompt(
tenant_id=current_user.tenant_id,
session_id=data.session_id,
title=data.title,
prompt=data.prompt
)
return success(data=prompt_info)
@router.delete(
"/releases/{prompt_id}",
summary="Delete prompt (soft delete)",
response_model=ApiResponse
)
def delete_prompt(
prompt_id: uuid.UUID = Path(..., description="Prompt ID"),
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
Soft delete a prompt release.
Args:
prompt_id
db (Session): Database session
current_user: Current logged-in user
Returns:
ApiResponse: Success message confirming deletion
"""
service = PromptOptimizerService(db)
service.delete_prompt(
tenant_id=current_user.tenant_id,
prompt_id=prompt_id
)
return success(msg="Prompt deleted successfully")
@router.get(
"/releases/list",
summary="Get paginated list of released prompts with optional filter",
response_model=ApiResponse
)
def get_release_list(
page: int = 1,
page_size: int = 20,
keyword: str | None = None,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
Retrieve paginated list of released prompts for the current tenant.
Optionally filter by keyword in title.
Args:
page (int): Page number (starting from 1)
page_size (int): Number of items per page (max 100)
keyword (str | None): Optional keyword to filter prompt titles
db (Session): Database session
current_user: Current logged-in user
Returns:
ApiResponse: Contains paginated list of prompt releases with metadata
"""
service = PromptOptimizerService(db)
result = service.get_release_list(
tenant_id=current_user.tenant_id,
page=max(1, page),
page_size=min(max(1, page_size), 100),
filter_keyword=keyword
)
return success(data=result)

View File

@@ -2,33 +2,24 @@ import hashlib
import json
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, Query, Request
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.response_utils import success, fail
from app.db import get_db, get_db_read
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_share_user_id, ShareTokenData
from app.models.app_model import App
from app.models.app_model import AppType
from app.repositories import knowledge_repository
from app.repositories.end_user_repository import EndUserRepository
from app.repositories.workflow_repository import WorkflowConfigRepository
from app.schemas import release_share_schema, conversation_schema
from app.schemas.response_schema import PageData, PageMeta
from app.services import workspace_service
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.services.auth_service import create_access_token
from app.services.conversation_service import ConversationService
from app.services.release_share_service import ReleaseShareService
from app.services.shared_chat_service import SharedChatService
from app.services.workflow_service import WorkflowService
from app.utils.app_config_utils import workflow_config_4_app_release, \
agent_config_4_app_release, multi_agent_config_4_app_release
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
router = APIRouter(prefix="/public/share", tags=["Public Share"])
logger = get_business_logger()
@@ -74,10 +65,10 @@ def get_or_generate_user_id(payload_user_id: str, request: Request) -> str:
summary="获取访问 token"
)
def get_access_token(
share_token: str,
payload: release_share_schema.TokenRequest,
request: Request,
db: Session = Depends(get_db),
share_token: str,
payload: release_share_schema.TokenRequest,
request: Request,
db: Session = Depends(get_db),
):
"""获取访问 token
@@ -122,9 +113,9 @@ def get_access_token(
response_model=None
)
def get_shared_release(
password: str = Query(None, description="访问密码(如果需要)"),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
password: str = Query(None, description="访问密码(如果需要)"),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""获取公开分享的发布版本信息
@@ -146,9 +137,9 @@ def get_shared_release(
summary="验证访问密码"
)
def verify_password(
payload: release_share_schema.PasswordVerifyRequest,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
payload: release_share_schema.PasswordVerifyRequest,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""验证分享的访问密码
@@ -168,11 +159,11 @@ def verify_password(
summary="获取嵌入代码"
)
def get_embed_code(
width: str = Query("100%", description="iframe 宽度"),
height: str = Query("600px", description="iframe 高度"),
request: Request = None,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
width: str = Query("100%", description="iframe 宽度"),
height: str = Query("600px", description="iframe 高度"),
request: Request = None,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""获取嵌入代码
@@ -192,6 +183,7 @@ def get_embed_code(
return success(data=embed_code)
# ---------- 会话管理接口 ----------
@router.get(
@@ -199,11 +191,11 @@ def get_embed_code(
summary="获取会话列表"
)
def list_conversations(
password: str = Query(None, description="访问密码"),
page: int = Query(1, ge=1),
pagesize: int = Query(20, ge=1, le=100),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
password: str = Query(None, description="访问密码"),
page: int = Query(1, ge=1),
pagesize: int = Query(20, ge=1, le=100),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""获取分享应用的会话列表
@@ -213,13 +205,15 @@ def list_conversations(
logger.debug(f"share_data:{share_data.user_id}")
other_id = share_data.user_id
service = SharedChatService(db)
share, release = service.get_release_by_share_token(share_data.share_token, password)
share, release = service._get_release_by_share_token(share_data.share_token, password)
from app.repositories.end_user_repository import EndUserRepository
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id,
other_id=other_id
)
app_id=share.app_id,
other_id=other_id
)
logger.debug(new_end_user.id)
service = SharedChatService(db)
conversations, total = service.list_conversations(
share_token=share_data.share_token,
user_id=str(new_end_user.id),
@@ -239,10 +233,10 @@ def list_conversations(
summary="获取会话详情(含消息)"
)
def get_conversation(
conversation_id: uuid.UUID,
password: str = Query(None, description="访问密码"),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
conversation_id: uuid.UUID,
password: str = Query(None, description="访问密码"),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""获取会话详情和消息历史"""
chat_service = SharedChatService(db)
@@ -272,10 +266,10 @@ def get_conversation(
summary="发送消息(支持流式和非流式)"
)
async def chat(
payload: conversation_schema.ChatRequest,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
payload: conversation_schema.ChatRequest,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
):
"""发送消息并获取回复
@@ -298,15 +292,19 @@ async def chat(
# 提前验证和准备(在流式响应开始前完成)
# 这样可以确保错误能正确返回,而不是在流式响应中间出错
from app.models.app_model import AppType
try:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.services.app_service import AppService
# 验证分享链接和密码
share, release = service.get_release_by_share_token(share_token, password)
share, release = service._get_release_by_share_token(share_token, password)
# # Create end_user_id by concatenating app_id with user_id
# end_user_id = f"{share.app_id}_{user_id}"
# Store end_user_id in database with original user_id
from app.repositories.end_user_repository import EndUserRepository
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id,
@@ -315,14 +313,12 @@ async def chat(
)
end_user_id = str(new_end_user.id)
appid = share.app_id
appid=share.app_id
"""获取存储类型和工作空间的ID"""
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
app = db.query(App).filter(
App.id == appid,
App.is_active.is_(True)
).first()
# 直接通过 SQLAlchemy 查询 app
from app.models.app_model import App
app = db.query(App).filter(App.id == appid).first()
if not app:
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
@@ -359,12 +355,12 @@ async def chat(
app_type = release.app.type if release.app else None
# 根据应用类型验证配置
if app_type == AppType.AGENT:
if app_type == "agent":
# Agent 类型:验证模型配置
model_config_id = release.default_model_config_id
if not model_config_id:
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
elif app_type == AppType.MULTI_AGENT:
elif app_type == "multi_agent":
# Multi-Agent 类型:验证多 Agent 配置
config = release.config or {}
if not config.get("sub_agents"):
@@ -429,17 +425,16 @@ async def chat(
# )
async def event_generator():
async for event in app_chat_service.agnet_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串
variables=payload.variables,
web_search=payload.web_search,
config=agent_config,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id,
files=payload.files # 传递多模态文件
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id= str(new_end_user.id), # 转换为字符串
variables=payload.variables,
web_search=payload.web_search,
config=agent_config,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
):
yield event
@@ -476,8 +471,7 @@ async def chat(
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id,
files=payload.files # 传递多模态文件
workspace_id=workspace_id
)
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.MULTI_AGENT:
@@ -487,15 +481,15 @@ async def chat(
async def event_generator():
async for event in app_chat_service.multi_agent_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串
variables=payload.variables,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串
variables=payload.variables,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
@@ -567,29 +561,24 @@ async def chat(
# return success(data=conversation_schema.ChatResponse(**result))
elif app_type == AppType.WORKFLOW:
config = workflow_config_4_app_release(release)
if not config.id:
with get_db_read() as db:
source_config = WorkflowConfigRepository(db).get_by_app_id(release.app_id)
config.id = source_config.id
config.id = uuid.UUID(config.id)
if payload.stream:
async def event_generator():
async for event in app_chat_service.workflow_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
files=payload.files,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=release.app_id,
workspace_id=workspace_id,
release_id=release.id,
public=True
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=release.app_id,
workspace_id=workspace_id
):
event_type = event.get("event", "message")
event_data = event.get("data", {})
@@ -621,8 +610,7 @@ async def chat(
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=release.app_id,
workspace_id=workspace_id,
release_id=release.id
workspace_id=workspace_id
)
logger.debug(
"工作流试运行返回结果",
@@ -638,34 +626,6 @@ async def chat(
# return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
else:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
@router.get("/config", summary="获取应用启动配置")
async def config_query(
password: str = Query(None, description="访问密码"),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
share_service = SharedChatService(db)
share_token = share_data.share_token
share, release = share_service.get_release_by_share_token(share_token, password)
if release.app.type == AppType.WORKFLOW:
workflow_service = WorkflowService(db)
content = {
"app_type": release.app.type,
"variables": workflow_service.get_start_node_variables(release.config)
}
elif release.app.type == AppType.AGENT:
content = {
"app_type": release.app.type,
"variables": release.config.get("variables")
}
elif release.app.type == AppType.MULTI_AGENT:
content = {
"app_type": release.app.type,
"variables": []
}
else:
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)
return success(data=content)

View File

@@ -12,6 +12,7 @@ from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_app_or_workspace
from app.models.app_model import App
from app.models.app_model import AppType
from app.repositories import knowledge_repository
@@ -20,10 +21,9 @@ from app.schemas import AppChatRequest, conversation_schema
from app.schemas.api_key_schema import ApiKeyAuth
from app.services import workspace_service
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.services.app_service import get_app_service, AppService
from app.services.conversation_service import ConversationService, get_conversation_service
from app.utils.app_config_utils import workflow_config_4_app_release, \
agent_config_4_app_release, multi_agent_config_4_app_release
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
from app.services.app_service import get_app_service, AppService
router = APIRouter(prefix="/app", tags=["V1 - App API"])
logger = get_business_logger()
@@ -34,7 +34,6 @@ async def list_apps():
"""列出可访问的应用(占位)"""
return success(data=[], msg="App API - Coming Soon")
# /v1/app/chat
# @router.post("/chat")
@@ -74,21 +73,21 @@ def _checkAppConfig(app: App):
else:
raise BusinessException("不支持的应用类型", BizCode.AGENT_CONFIG_MISSING)
@router.post("/chat")
@require_api_key(scopes=["app"])
async def chat(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
conversation_service: Annotated[ConversationService, Depends(get_conversation_service)] = None,
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
app_service: Annotated[AppService, Depends(get_app_service)] = None,
message: str = Body(..., description="聊天消息内容"),
request:Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
conversation_service: Annotated[ConversationService, Depends(get_conversation_service)] = None,
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
app_service: Annotated[AppService, Depends(get_app_service)] = None,
message: str = Body(..., description="聊天消息内容"),
):
body = await request.json()
payload = AppChatRequest(**body)
other_id = payload.user_id
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
other_id = payload.user_id
workspace_id = app.workspace_id
@@ -99,8 +98,8 @@ async def chat(
original_user_id=other_id # Save original user_id to other_id
)
end_user_id = str(new_end_user.id)
web_search = True
memory = True
web_search=True
memory=True
# 提前验证和准备(在流式响应开始前完成)
storage_type = workspace_service.get_workspace_storage_type_without_auth(
db=db,
@@ -134,8 +133,7 @@ async def chat(
app_id=app.id,
workspace_id=workspace_id,
user_id=end_user_id,
is_draft=False,
conversation_id=payload.conversation_id
is_draft=False
)
if app_type == AppType.AGENT:
@@ -148,17 +146,16 @@ async def chat(
if payload.stream:
async def event_generator():
async for event in app_chat_service.agnet_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
web_search=web_search,
config=agent_config,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id,
files=payload.files # 传递多模态文件
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id= end_user_id, # 转换为字符串
variables=payload.variables,
web_search=web_search,
config=agent_config,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
):
yield event
@@ -178,13 +175,12 @@ async def chat(
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
config=agent_config,
config= agent_config,
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id,
files=payload.files # 传递多模态文件
workspace_id=workspace_id
)
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.MULTI_AGENT:
@@ -194,15 +190,15 @@ async def chat(
async def event_generator():
async for event in app_chat_service.multi_agent_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
@@ -236,20 +232,18 @@ async def chat(
if payload.stream:
async def event_generator():
async for event in app_chat_service.workflow_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
files=payload.files,
config=config,
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=app.id,
workspace_id=workspace_id,
release_id=app.current_release.id,
public=True
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=new_end_user.id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=app.app_id,
workspace_id=workspace_id
):
event_type = event.get("event", "message")
event_data = event.get("data", {})
@@ -273,16 +267,15 @@ async def chat(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
user_id=new_end_user.id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=web_search,
memory=memory,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=app.id,
workspace_id=workspace_id,
release_id=app.current_release.id
app_id=app.app_id,
workspace_id=workspace_id
)
logger.debug(
"工作流试运行返回结果",
@@ -299,3 +292,4 @@ async def chat(
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)

View File

@@ -39,7 +39,7 @@ async def write_memory_api_service(
Stores memory content for the specified end user using the Memory API Service.
"""
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}")
memory_api_service = MemoryAPIService(db)

View File

@@ -246,73 +246,3 @@ async def rebuild_knowledge_graph(
db=db,
current_user=current_user)
@router.get("/check/yuque/auth", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def check_yuque_auth(
yuque_user_id: str,
yuque_token: str,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
check yuque auth info
"""
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
api_logger.info(f"check yuque auth info, username: {current_user.username}")
return await knowledge_controller.check_yuque_auth(yuque_user_id=yuque_user_id,
yuque_token=yuque_token,
db=db,
current_user=current_user)
@router.get("/check/feishu/auth", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def check_feishu_auth(
feishu_app_id: str,
feishu_app_secret: str,
feishu_folder_token: str,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
check feishu auth info
"""
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
api_logger.info(f"check feishu auth info, username: {current_user.username}")
return await knowledge_controller.check_feishu_auth(feishu_app_id=feishu_app_id,
feishu_app_secret=feishu_app_secret,
feishu_folder_token=feishu_folder_token,
db=db,
current_user=current_user)
@router.post("/{knowledge_id}/sync", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def sync_knowledge(
knowledge_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
sync knowledge base information based on knowledge_id
"""
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await knowledge_controller.sync_knowledge(knowledge_id=knowledge_id,
db=db,
current_user=current_user)

View File

@@ -1,85 +0,0 @@
"""Skill Controller - 技能市场管理"""
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from typing import Optional
import uuid
from app.db import get_db
from app.dependencies import get_current_user
from app.models import User
from app.schemas import skill_schema
from app.schemas.response_schema import PageData, PageMeta
from app.services.skill_service import SkillService
from app.core.response_utils import success
router = APIRouter(prefix="/skills", tags=["Skills"])
@router.post("", summary="创建技能")
def create_skill(
data: skill_schema.SkillCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""创建技能 - 可以关联现有工具内置、MCP、自定义"""
tenant_id = current_user.tenant_id
skill = SkillService.create_skill(db, data, tenant_id)
return success(data=skill_schema.Skill.model_validate(skill), msg="技能创建成功")
@router.get("", summary="技能列表")
def list_skills(
search: Optional[str] = Query(None, description="搜索关键词"),
is_active: Optional[bool] = Query(None, description="是否激活"),
is_public: Optional[bool] = Query(None, description="是否公开"),
page: int = Query(1, ge=1, description="页码"),
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""技能市场列表 - 包含本工作空间和公开的技能"""
tenant_id = current_user.tenant_id
skills, total = SkillService.list_skills(
db, tenant_id, search, is_active, is_public, page, pagesize
)
items = [skill_schema.Skill.model_validate(s) for s in skills]
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
return success(data=PageData(page=meta, items=items), msg="技能市场列表获取成功")
@router.get("/{skill_id}", summary="获取技能详情")
def get_skill(
skill_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取技能详情"""
tenant_id = current_user.tenant_id
skill = SkillService.get_skill(db, skill_id, tenant_id)
return success(data=skill_schema.Skill.model_validate(skill), msg="获取技能详情成功")
@router.put("/{skill_id}", summary="更新技能")
def update_skill(
skill_id: uuid.UUID,
data: skill_schema.SkillUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""更新技能"""
tenant_id = current_user.tenant_id
skill = SkillService.update_skill(db, skill_id, data, tenant_id)
return success(data=skill_schema.Skill.model_validate(skill), msg="技能更新成功")
@router.delete("/{skill_id}", summary="删除技能")
def delete_skill(
skill_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""删除技能"""
tenant_id = current_user.tenant_id
SkillService.delete_skill(db, skill_id, tenant_id)
return success(msg="技能删除成功")

View File

@@ -2,23 +2,15 @@ from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
import uuid
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.db import get_db
from app.dependencies import get_current_user, get_current_superuser
from app.models.user_model import User
from app.schemas import user_schema
from app.schemas.user_schema import (
ChangePasswordRequest,
AdminChangePasswordRequest,
SendEmailCodeRequest,
VerifyEmailCodeRequest,
VerifyPasswordRequest)
from app.schemas.user_schema import ChangePasswordRequest, AdminChangePasswordRequest
from app.schemas.response_schema import ApiResponse
from app.services import user_service
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.core.security import verify_password
# 获取API专用日志器
api_logger = get_api_logger()
@@ -100,7 +92,7 @@ def get_current_user_info(
result_schema.current_workspace_name = current_workspace.name
for ws in result.workspaces:
if ws.workspace_id == current_user.current_workspace_id and ws.is_active:
if ws.workspace_id == current_user.current_workspace_id:
result_schema.role = ws.role
break
@@ -128,7 +120,6 @@ def get_tenant_superusers(
return success(data=superusers_schema, msg="租户超管列表获取成功")
@router.get("/{user_id}", response_model=ApiResponse)
def get_user_info_by_id(
user_id: uuid.UUID,
@@ -189,54 +180,4 @@ async def admin_change_password(
return success(msg="密码修改成功")
else:
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
return success(data=generated_password, msg="密码重置成功")
@router.post("/verify_pwd", response_model=ApiResponse)
def verify_pwd(
request: VerifyPasswordRequest,
current_user: User = Depends(get_current_user),
):
"""验证当前用户密码"""
api_logger.info(f"用户验证密码请求: {current_user.username}")
is_valid = verify_password(request.password, current_user.hashed_password)
api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}")
if not is_valid:
raise BusinessException("密码验证失败", code=BizCode.VALIDATION_FAILED)
return success(data={"valid": is_valid}, msg="验证完成")
@router.post("/send-email-code", response_model=ApiResponse)
async def send_email_code(
request: SendEmailCodeRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""发送邮箱验证码"""
api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}")
await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id)
api_logger.info(f"邮箱验证码已发送: {current_user.username}")
return success(msg="验证码已发送到您的邮箱,请查收")
@router.put("/change-email", response_model=ApiResponse)
async def change_email(
request: VerifyEmailCodeRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""验证验证码并修改邮箱"""
api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}")
await user_service.verify_and_change_email(
db=db,
user_id=current_user.id,
new_email=request.new_email,
code=request.code
)
api_logger.info(f"用户邮箱修改成功: {current_user.username}")
return success(msg="邮箱修改成功")
return success(data=generated_password, msg="密码重置成功")

View File

@@ -5,10 +5,9 @@
from typing import Optional
import datetime
from sqlalchemy.orm import Session
from fastapi import APIRouter, Depends,Header
from fastapi import APIRouter, Depends
from app.db import get_db
from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
@@ -21,7 +20,7 @@ from app.services.user_memory_service import (
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
from app.schemas.response_schema import ApiResponse
from app.schemas.memory_storage_schema import GenerateCacheRequest
from app.repositories.workspace_repository import WorkspaceRepository
from app.schemas.end_user_schema import (
EndUserProfileResponse,
EndUserProfileUpdate,
@@ -73,7 +72,6 @@ async def get_memory_insight_report_api(
@router.get("/analytics/user_summary", response_model=ApiResponse)
async def get_user_summary_api(
end_user_id: str,
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
@@ -82,26 +80,11 @@ async def get_user_summary_api(
此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。
如需生成新的用户摘要,请使用专门的生成接口。
语言控制:
- 使用 X-Language-Type Header 指定语言
- 如果未传 Header默认使用中文 (zh)
"""
# 使用集中化的语言校验
language = get_language_from_header(language_type)
workspace_id = current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
if workspace_models:
model_id = workspace_models.get("llm", None)
else:
model_id = None
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
try:
# 调用服务层获取缓存数据
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language)
result = await user_memory_service.get_cached_user_summary(db, end_user_id)
if result["is_cached"]:
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
@@ -117,7 +100,6 @@ async def get_user_summary_api(
@router.post("/analytics/generate_cache", response_model=ApiResponse)
async def generate_cache_api(
request: GenerateCacheRequest,
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
@@ -126,14 +108,7 @@ async def generate_cache_api(
- 如果提供 end_user_id只为该用户生成
- 如果不提供,为当前工作空间的所有用户生成
语言控制:
- 使用 X-Language-Type Header 指定语言 ("zh" 中文, "en" 英文)
- 如果未传 Header默认使用中文 (zh)
"""
# 使用集中化的语言校验
language = get_language_from_header(language_type)
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
@@ -141,27 +116,27 @@ async def generate_cache_api(
api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
end_user_id = request.end_user_id
group_id = request.end_user_id
api_logger.info(
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
f"end_user_id={end_user_id if end_user_id else '全部用户'}, language={language}"
f"end_user_id={group_id if group_id else '全部用户'}"
)
try:
if end_user_id:
if group_id:
# 为单个用户生成
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}")
# 生成记忆洞察
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language)
insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id)
# 生成用户摘要
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language)
summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id)
# 构建响应
result = {
"end_user_id": end_user_id,
"end_user_id": group_id,
"insight_success": insight_result["success"],
"summary_success": summary_result["success"],
"errors": []
@@ -181,9 +156,9 @@ async def generate_cache_api(
# 记录结果
if result["insight_success"] and result["summary_success"]:
api_logger.info(f"成功为用户 {end_user_id} 生成缓存")
api_logger.info(f"成功为用户 {group_id} 生成缓存")
else:
api_logger.warning(f"用户 {end_user_id} 的缓存生成部分失败: {result['errors']}")
api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: {result['errors']}")
return success(data=result, msg="生成完成")
@@ -191,7 +166,7 @@ async def generate_cache_api(
# 为整个工作空间生成
api_logger.info(f"开始为工作空间 {workspace_id} 批量生成缓存")
result = await user_memory_service.generate_cache_for_workspace(db, workspace_id, language=language)
result = await user_memory_service.generate_cache_for_workspace(db, workspace_id)
# 记录统计信息
api_logger.info(
@@ -278,6 +253,7 @@ async def get_graph_data_api(
depth=depth,
center_node_id=center_node_id
)
# 检查是否有错误消息
if "message" in result and result["statistics"]["total_nodes"] == 0:
api_logger.warning(f"图数据查询返回空结果: {result.get('message')}")
@@ -302,13 +278,7 @@ async def get_end_user_profile(
db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
if workspace_models:
model_id = workspace_models.get("llm", None)
else:
model_id = None
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间")
@@ -326,6 +296,7 @@ async def get_end_user_profile(
if not end_user:
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
# 构建响应数据
profile_data = EndUserProfileResponse(
id=end_user.id,
@@ -357,11 +328,12 @@ async def update_end_user_profile(
该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。
所有字段都是可选的,只更新提供的字段。
"""
workspace_id = current_user.current_workspace_id
end_user_id = profile_update.end_user_id
# 验证工作空间
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
@@ -371,44 +343,65 @@ async def update_end_user_profile(
f"workspace={workspace_id}"
)
# 调用 Service 层处理业务逻辑
result = user_memory_service.update_end_user_profile(db, end_user_id, profile_update)
try:
# 查询终端用户
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
if result["success"]:
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}")
return success(data=result["data"], msg="更新成功")
else:
error_msg = result["error"]
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
# 根据错误类型映射到合适的业务错误码
if error_msg == "终端用户不存在":
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg)
elif error_msg == "无效的用户ID格式":
return fail(BizCode.INVALID_USER_ID, "无效的用户ID格式", error_msg)
else:
# 只有未预期的错误才使用 INTERNAL_ERROR
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
if not end_user:
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
# 更新字段(只更新提供的字段,排除 end_user_id
# 允许 None 值来重置字段(如 hire_date
update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
# 特殊处理 hire_date如果提供了时间戳转换为 DateTime
if 'hire_date' in update_data:
hire_date_timestamp = update_data['hire_date']
if hire_date_timestamp is not None:
update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp)
# 如果是 None保持 None允许清空
for field, value in update_data.items():
setattr(end_user, field, value)
# 更新 updated_at 时间戳
end_user.updated_at = datetime.datetime.now()
# 更新 updatetime_profile 为当前时间
end_user.updatetime_profile = datetime.datetime.now()
# 提交更改
db.commit()
db.refresh(end_user)
# 构建响应数据
profile_data = EndUserProfileResponse(
id=end_user.id,
other_name=end_user.other_name,
position=end_user.position,
department=end_user.department,
contact=end_user.contact,
phone=end_user.phone,
hire_date=end_user.hire_date,
updatetime_profile=end_user.updatetime_profile
)
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}")
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="更新成功")
except Exception as e:
db.rollback()
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e))
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"),
async def memory_space_timeline_of_shared_memories(id: str, label: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
# 使用集中化的语言校验
language = get_language_from_header(language_type)
workspace_id=current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
if workspace_models:
model_id = workspace_models.get("llm", None)
else:
model_id = None
MemoryEntity = MemoryEntityService(id, label)
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
timeline_memories_result = await MemoryEntity.get_timeline_memories_server()
return success(data=timeline_memories_result, msg="共同记忆时间线")
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
async def memory_space_relationship_evolution(id: str, label: str,

View File

@@ -0,0 +1,610 @@
"""
工作流 API 控制器
"""
import logging
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, Path, Query
from sqlalchemy.orm import Session
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models.user_model import User
from app.models.app_model import App
from app.services.workflow_service import WorkflowService, get_workflow_service
from app.schemas.workflow_schema import (
WorkflowConfigCreate,
WorkflowConfigUpdate,
WorkflowConfig,
WorkflowValidationResponse,
WorkflowExecution,
WorkflowNodeExecution,
WorkflowExecutionRequest,
WorkflowExecutionResponse
)
from app.core.response_utils import success, fail
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/apps", tags=["workflow"])
# ==================== 工作流配置管理 ====================
@router.post("/{app_id}/workflow")
@cur_workspace_access_guard()
async def create_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
config: WorkflowConfigCreate,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""创建工作流配置
创建或更新应用的工作流配置。配置会进行基础验证,但允许保存不完整的配置(草稿)。
"""
try:
# 验证应用是否存在且属于当前工作空间
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="应用不存在或无权访问"
)
# 验证应用类型
if app.type != "workflow":
return fail(
code=BizCode.INVALID_PARAMETER,
msg=f"应用类型必须为 workflow当前为 {app.type}"
)
# 创建工作流配置
workflow_config = service.create_workflow_config(
app_id=app_id,
nodes=[node.model_dump() for node in config.nodes],
edges=[edge.model_dump() for edge in config.edges],
variables=[var.model_dump() for var in config.variables],
execution_config=config.execution_config.model_dump(),
triggers=[trigger.model_dump() for trigger in config.triggers],
validate=True # 进行基础验证
)
return success(
data=WorkflowConfig.model_validate(workflow_config),
msg="工作流配置创建成功"
)
except BusinessException as e:
logger.warning(f"创建工作流配置失败: {e.message}")
return fail(code=e.error_code, msg=e.message)
except Exception as e:
logger.error(f"创建工作流配置异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"创建工作流配置失败: {str(e)}"
)
#
# @router.get("/{app_id}/workflow")
# async def get_workflow_config(
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
# db: Annotated[Session, Depends(get_db)],
# current_user: Annotated[User, Depends(get_current_user)]
#
# ):
# """获取工作流配置
#
# 获取应用的工作流配置详情。
# """
# try:
# # 验证应用是否存在且属于当前工作空间
# app = db.query(App).filter(
# App.id == app_id,
# App.workspace_id == current_user.current_workspace_id,
# App.is_active == True
# ).first()
#
# if not app:
# return fail(
# code=BizCode.NOT_FOUND,
# msg="应用不存在或无权访问"
# )
#
# # 获取工作流配置
# service = WorkflowService(db)
# workflow_config = service.get_workflow_config(app_id)
#
# if not workflow_config:
# return fail(
# code=BizCode.NOT_FOUND,
# msg="工作流配置不存在"
# )
#
# return success(
# data=WorkflowConfig.model_validate(workflow_config)
# )
#
# except Exception as e:
# logger.error(f"获取工作流配置异常: {e}", exc_info=True)
# return fail(
# code=BizCode.INTERNAL_ERROR,
# msg=f"获取工作流配置失败: {str(e)}"
# )
# @router.put("/{app_id}/workflow")
# async def update_workflow_config(
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
# config: WorkflowConfigUpdate,
# db: Annotated[Session, Depends(get_db)],
# current_user: Annotated[User, Depends(get_current_user)],
# service: Annotated[WorkflowService, Depends(get_workflow_service)]
# ):
# """更新工作流配置
# 更新应用的工作流配置。可以部分更新,未提供的字段保持不变。
# """
# try:
# # 验证应用是否存在且属于当前工作空间
# app = db.query(App).filter(
# App.id == app_id,
# App.workspace_id == current_user.current_workspace_id,
# App.is_active == True
# ).first()
# if not app:
# return fail(
# code=BizCode.NOT_FOUND,
# msg="应用不存在或无权访问"
# )
# # 更新工作流配置
# workflow_config = service.update_workflow_config(
# app_id=app_id,
# nodes=[node.model_dump() for node in config.nodes] if config.nodes else None,
# edges=[edge.model_dump() for edge in config.edges] if config.edges else None,
# variables=[var.model_dump() for var in config.variables] if config.variables else None,
# execution_config=config.execution_config.model_dump() if config.execution_config else None,
# triggers=[trigger.model_dump() for trigger in config.triggers] if config.triggers else None,
# validate=True
# )
# return success(
# data=WorkflowConfig.model_validate(workflow_config),
# msg="工作流配置更新成功"
# )
# except BusinessException as e:
# logger.warning(f"更新工作流配置失败: {e.message}")
# return fail(code=e.error_code, msg=e.message)
# except Exception as e:
# logger.error(f"更新工作流配置异常: {e}", exc_info=True)
# return fail(
# code=BizCode.INTERNAL_ERROR,
# msg=f"更新工作流配置失败: {str(e)}"
# )
@router.delete("/{app_id}/workflow")
async def delete_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""删除工作流配置
删除应用的工作流配置。
"""
try:
# 验证应用是否存在且属于当前工作空间
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="应用不存在或无权访问"
)
# 删除工作流配置
deleted = service.delete_workflow_config(app_id)
if not deleted:
return fail(
code=BizCode.NOT_FOUND,
msg="工作流配置不存在"
)
return success(msg="工作流配置删除成功")
except Exception as e:
logger.error(f"删除工作流配置异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"删除工作流配置失败: {str(e)}"
)
@router.post("/{app_id}/workflow/validate")
async def validate_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)],
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
):
"""验证工作流配置
验证工作流配置是否有效。可以选择是否进行发布级别的严格验证。
"""
try:
# 验证应用是否存在且属于当前工作空间
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="应用不存在或无权访问"
)
# 验证工作流配置
if for_publish:
is_valid, errors = service.validate_workflow_config_for_publish(app_id)
else:
workflow_config = service.get_workflow_config(app_id)
if not workflow_config:
return fail(
code=BizCode.NOT_FOUND,
msg="工作流配置不存在"
)
from app.core.workflow.validator import validate_workflow_config as validate_config
config_dict = {
"nodes": workflow_config.nodes,
"edges": workflow_config.edges,
"variables": workflow_config.variables,
"execution_config": workflow_config.execution_config,
"triggers": workflow_config.triggers
}
is_valid, errors = validate_config(config_dict, for_publish=False)
return success(
data=WorkflowValidationResponse(
is_valid=is_valid,
errors=errors,
warnings=[]
)
)
except BusinessException as e:
logger.warning(f"验证工作流配置失败: {e.message}")
return fail(code=e.error_code, msg=e.message)
except Exception as e:
logger.error(f"验证工作流配置异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"验证工作流配置失败: {str(e)}"
)
# ==================== 工作流执行管理 ====================
@router.get("/{app_id}/workflow/executions")
async def get_workflow_executions(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)],
limit: Annotated[int, Query(ge=1, le=100)] = 50,
offset: Annotated[int, Query(ge=0)] = 0
):
"""获取工作流执行记录列表
获取应用的工作流执行历史记录。
"""
try:
# 验证应用是否存在且属于当前工作空间
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="应用不存在或无权访问"
)
# 获取执行记录
executions = service.get_executions_by_app(app_id, limit, offset)
# 获取统计信息
statistics = service.get_execution_statistics(app_id)
return success(
data={
"executions": [WorkflowExecution.model_validate(e) for e in executions],
"statistics": statistics,
"pagination": {
"limit": limit,
"offset": offset,
"total": statistics["total"]
}
}
)
except Exception as e:
logger.error(f"获取工作流执行记录异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"获取工作流执行记录失败: {str(e)}"
)
@router.get("/workflow/executions/{execution_id}")
async def get_workflow_execution(
execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""获取工作流执行详情
获取单个工作流执行的详细信息,包括所有节点的执行记录。
"""
try:
# 获取执行记录
execution = service.get_execution(execution_id)
if not execution:
return fail(
code=BizCode.NOT_FOUND,
msg="执行记录不存在"
)
# 验证应用是否属于当前工作空间
app = db.query(App).filter(
App.id == execution.app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="无权访问该执行记录"
)
# 获取节点执行记录
node_executions = service.node_execution_repo.get_by_execution_id(execution.id)
return success(
data={
"execution": WorkflowExecution.model_validate(execution),
"node_executions": [
WorkflowNodeExecution.model_validate(ne) for ne in node_executions
]
}
)
except Exception as e:
logger.error(f"获取工作流执行详情异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"获取工作流执行详情失败: {str(e)}"
)
# ==================== 工作流执行 ====================
@router.post("/{app_id}/workflow/run")
async def run_workflow(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
request: WorkflowExecutionRequest,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""执行工作流
执行工作流并返回结果。支持流式和非流式两种模式。
**非流式模式**:等待工作流执行完成后返回完整结果。
**流式模式**:实时返回执行过程中的事件(节点开始、节点完成、工作流完成等)。
"""
try:
# 验证应用是否存在且属于当前工作空间
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="应用不存在或无权访问"
)
# 验证应用类型
if app.type != "workflow":
return fail(
code=BizCode.INVALID_PARAMETER,
msg=f"应用类型必须为 workflow当前为 {app.type}"
)
# 准备输入数据
input_data = {
"message": request.message or "",
"variables": request.variables
}
# 执行工作流
if request.stream:
# 流式执行
from fastapi.responses import StreamingResponse
import json
async def event_generator():
"""生成 SSE 事件
SSE 格式:
event: <event_type>
data: <json_data>
支持的事件类型:
- workflow_start: 工作流开始
- workflow_end: 工作流结束
- node_start: 节点开始执行
- node_end: 节点执行完成
- node_chunk: 中间节点的流式输出
- message: 最终消息的流式输出End 节点及其相邻节点)
"""
try:
async for event in await service.run_workflow(
app_id=app_id,
input_data=input_data,
triggered_by=current_user.id,
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
stream=True
):
# 提取事件类型和数据
event_type = event.get("event", "message")
event_data = event.get("data", {})
# 转换为标准 SSE 格式(字符串)
# event: <type>
# data: <json>
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
yield sse_message
except Exception as e:
logger.error(f"流式执行异常: {e}", exc_info=True)
# 发送错误事件
sse_error = f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
yield sse_error
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # 禁用 nginx 缓冲
}
)
else:
# 非流式执行
result = await service.run_workflow(
app_id=app_id,
input_data=input_data,
triggered_by=current_user.id,
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
stream=False
)
return success(
data=WorkflowExecutionResponse(
execution_id=result["execution_id"],
status=result["status"],
output=result.get("output"),
output_data=result.get("output_data"),
error_message=result.get("error_message"),
elapsed_time=result.get("elapsed_time"),
token_usage=result.get("token_usage")
),
msg="工作流执行完成"
)
except BusinessException as e:
logger.warning(f"执行工作流失败: {e.message}")
return fail(code=e.error_code, msg=e.message)
except Exception as e:
logger.error(f"执行工作流异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"执行工作流失败: {str(e)}"
)
@router.post("/workflow/executions/{execution_id}/cancel")
async def cancel_workflow_execution(
execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""取消工作流执行
取消正在运行的工作流执行。
**注意**:当前版本仅更新状态为 cancelled实际的执行取消功能待实现。
"""
try:
# 获取执行记录
execution = service.get_execution(execution_id)
if not execution:
return fail(
code=BizCode.NOT_FOUND,
msg="执行记录不存在"
)
# 验证应用是否属于当前工作空间
app = db.query(App).filter(
App.id == execution.app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="无权访问该执行记录"
)
# 检查执行状态
if execution.status not in ["pending", "running"]:
return fail(
code=BizCode.INVALID_PARAMETER,
msg=f"无法取消状态为 {execution.status} 的执行"
)
# 更新状态为 cancelled
service.update_execution_status(execution_id, "cancelled")
return success(msg="工作流执行已取消")
except BusinessException as e:
logger.warning(f"取消工作流执行失败: {e.message}")
return fail(code=e.code, msg=e.message)
except Exception as e:
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"取消工作流执行失败: {str(e)}"
)

View File

@@ -1,7 +1,7 @@
import uuid
from typing import List, Optional
from fastapi import APIRouter, Depends, Header, HTTPException, Query, status
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session
from app.core.logging_config import get_api_logger
@@ -95,29 +95,16 @@ def get_workspaces(
@router.post("", response_model=ApiResponse)
def create_workspace(
workspace: WorkspaceCreate,
language_type: str = Header(default="zh", alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_superuser),
):
"""创建新的工作空间"""
from app.core.language_utils import get_language_from_header
# 验证并获取语言参数
language = get_language_from_header(language_type)
api_logger.info(
f"用户 {current_user.username} 请求创建工作空间: {workspace.name}, "
f"language={language}"
)
api_logger.info(f"用户 {current_user.username} 请求创建工作空间: {workspace.name}")
result = workspace_service.create_workspace(
db=db, workspace=workspace, user=current_user, language=language
)
db=db, workspace=workspace, user=current_user)
api_logger.info(
f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, "
f"创建者: {current_user.username}, language={language}"
)
api_logger.info(f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, 创建者: {current_user.username}")
result_schema = WorkspaceResponse.model_validate(result)
return success(data=result_schema, msg="工作空间创建成功")

View File

@@ -1,4 +0,0 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/9 16:24

View File

@@ -1,162 +0,0 @@
"""Agent Middleware - 动态技能过滤"""
import uuid
from typing import List, Dict, Any, Optional
from langchain_core.runnables import RunnablePassthrough
from app.services.skill_service import SkillService
from app.repositories.skill_repository import SkillRepository
class AgentMiddleware:
"""Agent 中间件 - 用于动态过滤和加载技能"""
def __init__(self, skills: Optional[dict] = None):
"""
初始化中间件
Args:
skills: 技能配置字典 {"enabled": bool, "all_skills": bool, "skill_ids": [...]}
"""
self.skills = skills or {}
self.enabled = self.skills.get('enabled', False)
self.all_skills = self.skills.get('all_skills', False)
self.skill_ids = self.skills.get('skill_ids', [])
@staticmethod
def filter_tools(
tools: List,
message: str = "",
skill_configs: Dict[str, Any] = None,
tool_to_skill_map: Dict[str, str] = None
) -> tuple[List, List[str]]:
"""
根据消息内容和技能配置动态过滤工具
Args:
tools: 所有可用工具列表
message: 用户消息(可用于智能过滤)
skill_configs: 技能配置字典 {skill_id: {"keywords": [...], "enabled": True, "prompt": "..."}}
tool_to_skill_map: 工具到技能的映射 {tool_name: skill_id}
Returns:
(过滤后的工具列表, 激活的技能ID列表)
"""
if not tools:
return [], []
# 如果没有技能配置,返回所有工具
if not skill_configs:
return tools, []
# 基于关键词匹配激活技能
activated_skill_ids = []
message_lower = message.lower()
for skill_id, config in skill_configs.items():
if not config.get('enabled', True):
continue
keywords = config.get('keywords', [])
# 如果没有关键词限制,或消息包含关键词,则激活该技能
if not keywords or any(kw.lower() in message_lower for kw in keywords):
activated_skill_ids.append(skill_id)
# 如果没有工具映射关系,返回所有工具
if not tool_to_skill_map:
return tools, activated_skill_ids
# 根据激活的技能过滤工具
filtered_tools = []
for tool in tools:
tool_name = getattr(tool, 'name', str(id(tool)))
# 如果工具不属于任何skillbase_tools或者工具所属的skill被激活则保留
if tool_name not in tool_to_skill_map or tool_to_skill_map[tool_name] in activated_skill_ids:
filtered_tools.append(tool)
return filtered_tools, activated_skill_ids
def load_skill_tools(self, db, tenant_id: uuid.UUID, base_tools: List = None) -> tuple[List, Dict[str, Any], Dict[str, str]]:
"""
加载技能关联的工具
Args:
db: 数据库会话
tenant_id: 租户id
base_tools: 基础工具列表
Returns:
(工具列表, 技能配置字典, 工具到技能的映射 {tool_name: skill_id})
"""
tools_dict = {}
tool_to_skill_map = {} # 工具名称到技能ID的映射
if base_tools:
for tool in base_tools:
tool_name = getattr(tool, 'name', str(id(tool)))
tools_dict[tool_name] = tool
# base_tools 不属于任何 skill不加入映射
skill_configs = {}
skill_ids_to_load = []
# 如果启用技能且 all_skills 为 True加载租户下所有激活的技能
if self.enabled and self.all_skills:
skills, _ = SkillRepository.list_skills(db, tenant_id, is_active=True, page=1, pagesize=1000)
skill_ids_to_load = [str(skill.id) for skill in skills]
elif self.enabled and self.skill_ids:
skill_ids_to_load = self.skill_ids
if skill_ids_to_load:
for skill_id in skill_ids_to_load:
try:
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
if skill and skill.is_active:
# 保存技能配置包含prompt
config = skill.config or {}
config['prompt'] = skill.prompt
config['name'] = skill.name
skill_configs[skill_id] = config
except Exception:
continue
# 加载技能工具并获取映射关系
skill_tools, skill_tool_map = SkillService.load_skill_tools(db, skill_ids_to_load, tenant_id)
# 只添加不冲突的 skill_tools
for tool in skill_tools:
tool_name = getattr(tool, 'name', str(id(tool)))
if tool_name not in tools_dict:
tools_dict[tool_name] = tool
# 复制映射关系
if tool_name in skill_tool_map:
tool_to_skill_map[tool_name] = skill_tool_map[tool_name]
return list(tools_dict.values()), skill_configs, tool_to_skill_map
@staticmethod
def get_active_prompts(activated_skill_ids: List[str], skill_configs: Dict[str, Any]) -> str:
"""
根据激活的技能ID获取对应的提示词
Args:
activated_skill_ids: 被激活的技能ID列表
skill_configs: 技能配置字典
Returns:
合并后的提示词
"""
prompts = []
for skill_id in activated_skill_ids:
config = skill_configs.get(skill_id, {})
prompt = config.get('prompt')
name = config.get('name', 'Skill')
if prompt:
prompts.append(f"# {name}\n{prompt}")
return "\n\n".join(prompts) if prompts else ""
@staticmethod
def create_runnable():
"""创建可运行的中间件"""
return RunnablePassthrough()

View File

@@ -7,18 +7,23 @@ LangChain Agent 封装
- 支持流式输出
- 使用 RedBearLLM 支持多提供商
"""
import os
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db
from app.core.logging_config import get_business_logger
from app.core.memory.agent.utils.redis_tool import store
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType, ModelProvider
from app.models.models_model import ModelType
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
@@ -29,19 +34,16 @@ logger = get_business_logger()
class LangChainAgent:
def __init__(
self,
model_name: str,
api_key: str,
provider: str = "openai",
api_base: Optional[str] = None,
is_omni: bool = False,
temperature: float = 0.7,
max_tokens: int = 2000,
system_prompt: Optional[str] = None,
tools: Optional[Sequence[BaseTool]] = None,
streaming: bool = False,
max_iterations: Optional[int] = None, # 最大迭代次数None 表示自动计算)
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
self,
model_name: str,
api_key: str,
provider: str = "openai",
api_base: Optional[str] = None,
temperature: float = 0.7,
max_tokens: int = 2000,
system_prompt: Optional[str] = None,
tools: Optional[Sequence[BaseTool]] = None,
streaming: bool = False
):
"""初始化 LangChain Agent
@@ -54,37 +56,13 @@ class LangChainAgent:
max_tokens: 最大 token 数
system_prompt: 系统提示词
tools: 工具列表(可选,框架自动走 ReAct 循环)
streaming: 是否启用流式输出
max_iterations: 最大迭代次数None 表示自动计算:基础 5 次 + 每个工具 2 次)
max_tool_consecutive_calls: 单个工具最大连续调用次数(默认 3 次)
streaming: 是否启用流式输出(默认 True
"""
self.model_name = model_name
self.provider = provider
self.system_prompt = system_prompt or "你是一个专业的AI助手"
self.tools = tools or []
self.streaming = streaming
self.is_omni = is_omni
self.max_tool_consecutive_calls = max_tool_consecutive_calls
# 工具调用计数器:记录每个工具的连续调用次数
self.tool_call_counter: Dict[str, int] = {}
self.last_tool_called: Optional[str] = None
# 根据工具数量动态调整最大迭代次数
# 基础值 + 每个工具额外的调用机会
if max_iterations is None:
# 自动计算:基础 5 次 + 每个工具 2 次额外机会
self.max_iterations = 5 + len(self.tools) * 2
else:
self.max_iterations = max_iterations
self.system_prompt = system_prompt or "你是一个专业的AI助手"
logger.debug(
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
f"tool_count={len(self.tools)}, "
f"max_tool_consecutive_calls={self.max_tool_consecutive_calls}, "
f"auto_calculated={max_iterations is None}"
)
# 创建 RedBearLLM支持多提供商
model_config = RedBearModelConfig(
@@ -92,7 +70,6 @@ class LangChainAgent:
provider=provider,
api_key=api_key,
base_url=api_base,
is_omni=is_omni,
extra_params={
"temperature": temperature,
"max_tokens": max_tokens,
@@ -109,14 +86,11 @@ class LangChainAgent:
if streaming and hasattr(self._underlying_llm, 'streaming'):
self._underlying_llm.streaming = True
# 包装工具以跟踪连续调用次数
wrapped_tools = self._wrap_tools_with_tracking(self.tools) if self.tools else None
# 使用 create_agent 创建 agent graphLangChain 1.x 标准方式)
# 无论是否有工具,都使用 agent 统一处理
self.agent = create_agent(
model=self.llm,
tools=wrapped_tools,
tools=self.tools if self.tools else None,
system_prompt=self.system_prompt
)
@@ -128,92 +102,17 @@ class LangChainAgent:
"has_api_base": bool(api_base),
"temperature": temperature,
"streaming": streaming,
"max_iterations": self.max_iterations,
"max_tool_consecutive_calls": self.max_tool_consecutive_calls,
"tool_count": len(self.tools),
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
# "tool_count": len(self.tools)
"tool_count": len(self.tools)
}
)
def _wrap_tools_with_tracking(self, tools: Sequence[BaseTool]) -> List[BaseTool]:
"""包装工具以跟踪连续调用次数
Args:
tools: 原始工具列表
Returns:
List[BaseTool]: 包装后的工具列表
"""
from langchain_core.tools import StructuredTool
from functools import wraps
wrapped_tools = []
for original_tool in tools:
tool_name = original_tool.name
original_func = original_tool.func if hasattr(original_tool, 'func') else None
if not original_func:
# 如果无法获取原始函数,直接使用原工具
wrapped_tools.append(original_tool)
continue
# 创建包装函数
def make_wrapped_func(tool_name, original_func):
"""创建包装函数的工厂函数,避免闭包问题"""
@wraps(original_func)
def wrapped_func(*args, **kwargs):
"""包装后的工具函数,跟踪连续调用次数"""
# 检查是否是连续调用同一个工具
if self.last_tool_called == tool_name:
self.tool_call_counter[tool_name] = self.tool_call_counter.get(tool_name, 0) + 1
else:
# 切换到新工具,重置计数器
self.tool_call_counter[tool_name] = 1
self.last_tool_called = tool_name
current_count = self.tool_call_counter[tool_name]
logger.debug(
f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}"
)
# 检查是否超过最大连续调用次数
if current_count > self.max_tool_consecutive_calls:
logger.warning(
f"工具 '{tool_name}' 连续调用次数已达上限 ({self.max_tool_consecutive_calls})"
f"返回提示信息"
)
return (
f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次,"
f"未找到有效结果。请尝试其他方法或直接回答用户的问题。"
)
# 调用原始工具函数
return original_func(*args, **kwargs)
return wrapped_func
# 使用 StructuredTool 创建新工具
wrapped_tool = StructuredTool(
name=original_tool.name,
description=original_tool.description,
func=make_wrapped_func(tool_name, original_func),
args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None
)
wrapped_tools.append(wrapped_tool)
return wrapped_tools
def _prepare_messages(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
files: Optional[List[Dict[str, Any]]] = None
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None
) -> List[BaseMessage]:
"""准备消息列表
@@ -221,7 +120,6 @@ class LangChainAgent:
message: 用户消息
history: 历史消息列表
context: 上下文信息
files: 多模态文件内容列表(已处理)
Returns:
List[BaseMessage]: 消息列表
@@ -244,49 +142,47 @@ class LangChainAgent:
if context:
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
# 构建用户消息(支持多模态)
if files and len(files) > 0:
content_parts = self._build_multimodal_content(user_content, files)
messages.append(HumanMessage(content=content_parts))
else:
# 纯文本消息
messages.append(HumanMessage(content=user_content))
messages.append(HumanMessage(content=user_content))
return messages
async def term_memory_save(self,messages,end_user_end,aimessages):
'''短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j'''
end_user_end=f"Term_{end_user_end}"
print(messages)
print(aimessages)
session_id = store.save_session(
userid=end_user_end,
messages=messages,
apply_id=end_user_end,
group_id=end_user_end,
aimessages=aimessages
)
store.delete_duplicate_sessions()
# logger.info(f'Redis_Agent:{end_user_end};{session_id}')
return session_id
async def term_memory_redis_read(self,end_user_end):
end_user_end = f"Term_{end_user_end}"
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
# logger.info(f'Redis_Agent:{end_user_end};{history}')
messagss_list=[]
retrieved_content=[]
for messages in history:
query = messages.get("Query")
aimessages = messages.get("Answer")
messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
retrieved_content.append({query: aimessages})
return messagss_list,retrieved_content
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
构建多模态消息内容
Args:
text: 文本内容
files: 文件列表(已由 MultimodalService 处理为对应 provider 的格式)
Returns:
List[Dict]: 消息内容列表
"""
# 根据 provider 使用不同的文本格式
# if (self.provider.lower() in [ModelProvider.BEDROCK, ModelProvider.OPENAI, ModelProvider.XINFERENCE,
# ModelProvider.GPUSTACK] or (
# self.provider.lower() == ModelProvider.DASHSCOPE and self.is_omni)):
# # Anthropic/Bedrock/Xinference/Gpustack/Openai: {"type": "text", "text": "..."}
# content_parts = [{"type": "text", "text": text}]
# else:
# # 通义千问等: {"text": "..."}
# content_parts = [{"type": "text", "text": text}]
content_parts = [{"type": "text", "text": text}]
# 添加文件内容
# MultimodalService 已经根据 provider 返回了正确格式,直接使用
content_parts.extend(files)
logger.debug(
f"构建多模态消息: provider={self.provider}, "
f"parts={len(content_parts)}, "
f"files={len(files)}"
)
return content_parts
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id):
if storage_type == "rag":
await write_rag(end_user_id, message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
else:
write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type,
user_rag_memory_id)
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'Agent:{actual_end_user_id};{write_status}')
async def chat(
self,
@@ -297,8 +193,7 @@ class LangChainAgent:
config_id: Optional[str] = None, # 添加这个参数
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
memory_flag: Optional[bool] = True
) -> Dict[str, Any]:
"""执行对话
@@ -310,7 +205,7 @@ class LangChainAgent:
Returns:
Dict: 包含 content 和元数据的字典
"""
message_chat = message
message_chat= message
start_time = time.time()
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
@@ -330,11 +225,34 @@ class LangChainAgent:
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
history_term_memory = history_term_memory_result[0]
db_for_memory = next(get_db())
if memory_flag:
if len(history_term_memory)>=4 and storage_type != "rag":
history_term_memory = ';'.join(history_term_memory)
retrieved_content = history_term_memory_result[1]
print(retrieved_content)
# 为长期记忆操作获取新的数据库连接
try:
repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
except Exception as e:
logger.error(f"Failed to write to LongTermMemory: {e}")
raise
finally:
db_for_memory.close()
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
try:
# 准备消息列表(支持多模态)
messages = self._prepare_messages(message, history, context, files)
# 准备消息列表
messages = self._prepare_messages(message, history, context)
logger.debug(
"准备调用 LangChain Agent",
@@ -342,86 +260,25 @@ class LangChainAgent:
"has_context": bool(context),
"has_history": bool(history),
"has_tools": bool(self.tools),
"has_files": bool(files),
"message_count": len(messages),
"max_iterations": self.max_iterations
"message_count": len(messages)
}
)
# 统一使用 agent.invoke 调用
# 通过 recursion_limit 限制最大迭代次数,防止工具调用死循环
try:
result = await self.agent.ainvoke(
{"messages": messages},
config={"recursion_limit": self.max_iterations}
)
except RecursionError as e:
logger.warning(
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),可能存在工具调用循环",
extra={"error": str(e)}
)
# 返回一个友好的错误提示
return {
"content": f"抱歉,我在处理您的请求时遇到了问题。已达到最大处理步骤限制({self.max_iterations}次)。请尝试简化您的问题或稍后再试。",
"model": self.model_name,
"elapsed_time": time.time() - start_time,
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
result = await self.agent.ainvoke({"messages": messages})
# 获取最后的 AI 消息
output_messages = result.get("messages", [])
content = ""
logger.debug(f"输出消息数量: {len(output_messages)}")
total_tokens = 0
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
logger.debug(f"找到 AI 消息content 类型: {type(msg.content)}")
logger.debug(f"AI 消息内容: {msg.content}")
# 处理多模态响应content 可能是字符串或列表
if isinstance(msg.content, str):
content = msg.content
logger.debug(f"提取字符串内容,长度: {len(content)}")
elif isinstance(msg.content, list):
# 多模态响应:提取文本部分
logger.debug(f"多模态响应,列表长度: {len(msg.content)}")
text_parts = []
for item in msg.content:
logger.debug(f"处理项: {item}")
if isinstance(item, dict):
# 通义千问格式: {"text": "..."}
if "text" in item:
text = item.get("text", "")
text_parts.append(text)
logger.debug(f"提取文本: {text[:100]}...")
# OpenAI 格式: {"type": "text", "text": "..."}
elif item.get("type") == "text":
text = item.get("text", "")
text_parts.append(text)
logger.debug(f"提取文本: {text[:100]}...")
elif isinstance(item, str):
text_parts.append(item)
logger.debug(f"提取字符串: {item[:100]}...")
content = "".join(text_parts)
logger.debug(f"合并后内容长度: {len(content)}")
else:
content = str(msg.content)
logger.debug(f"转换为字符串: {content[:100]}...")
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
content = msg.content
break
logger.info(f"最终提取的内容长度: {len(content)}")
elapsed_time = time.time() - start_time
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
actual_config_id)
await self.write(storage_type,end_user_id,content,user_rag_memory_id,actual_end_user_id,content,actual_config_id)
await self.term_memory_save(message_chat,end_user_id,content)
response = {
"content": content,
"model": self.model_name,
@@ -429,7 +286,7 @@ class LangChainAgent:
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": total_tokens
"total_tokens": 0
}
}
@@ -448,16 +305,15 @@ class LangChainAgent:
raise
async def chat_stream(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id: Optional[str] = None,
config_id: Optional[str] = None,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id:Optional[str] = None,
config_id: Optional[str] = None,
storage_type:Optional[str] = None,
user_rag_memory_id:Optional[str] = None,
memory_flag: Optional[bool] = True
) -> AsyncGenerator[str, None]:
"""执行流式对话
@@ -491,13 +347,32 @@ class LangChainAgent:
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
history_term_memory = history_term_memory_result[0]
if memory_flag:
if len(history_term_memory) >= 4 and storage_type != "rag":
history_term_memory = ';'.join(history_term_memory)
retrieved_content = history_term_memory_result[1]
db_for_memory = next(get_db())
try:
repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
history_term_memory, actual_config_id)
except Exception as e:
logger.error(f"Failed to write to long term memory: {e}")
finally:
db_for_memory.close()
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
try:
# 准备消息列表(支持多模态)
messages = self._prepare_messages(message, history, context, files)
# 准备消息列表
messages = self._prepare_messages(message, history, context)
logger.debug(
f"准备流式调用has_tools={bool(self.tools)}, has_files={bool(files)}, message_count={len(messages)}"
f"准备流式调用has_tools={bool(self.tools)}, message_count={len(messages)}"
)
chunk_count = 0
@@ -505,106 +380,47 @@ class LangChainAgent:
# 统一使用 agent 的 astream_events 实现流式输出
logger.debug("使用 Agent astream_events 实现流式输出")
full_content = ''
full_content=''
try:
async for event in self.agent.astream_events(
{"messages": messages},
version="v2",
config={"recursion_limit": self.max_iterations}
{"messages": messages},
version="v2"
):
chunk_count += 1
kind = event.get("event")
# 处理所有可能的流式事件
if kind == "on_chat_model_stream":
# LLM 流式输出
chunk = event.get("data", {}).get("chunk")
if chunk and hasattr(chunk, "content"):
# 处理多模态响应content 可能是字符串或列表
chunk_content = chunk.content
if isinstance(chunk_content, str) and chunk_content:
full_content += chunk_content
yield chunk_content
yielded_content = True
elif isinstance(chunk_content, list):
# 多模态响应:提取文本部分
for item in chunk_content:
if isinstance(item, dict):
# 通义千问格式: {"text": "..."}
if "text" in item:
text = item.get("text", "")
if text:
full_content += text
yield text
yielded_content = True
# OpenAI 格式: {"type": "text", "text": "..."}
elif item.get("type") == "text":
text = item.get("text", "")
if text:
full_content += text
yield text
yielded_content = True
elif isinstance(item, str):
full_content += item
yield item
yielded_content = True
full_content+=chunk.content
if chunk and hasattr(chunk, "content") and chunk.content:
yield chunk.content
yielded_content = True
elif kind == "on_llm_stream":
# 另一种 LLM 流式事件
chunk = event.get("data", {}).get("chunk")
if chunk:
if hasattr(chunk, "content"):
chunk_content = chunk.content
if isinstance(chunk_content, str) and chunk_content:
full_content += chunk_content
yield chunk_content
yielded_content = True
elif isinstance(chunk_content, list):
# 多模态响应:提取文本部分
for item in chunk_content:
if isinstance(item, dict):
# 通义千问格式: {"text": "..."}
if "text" in item:
text = item.get("text", "")
if text:
full_content += text
yield text
yielded_content = True
# OpenAI 格式: {"type": "text", "text": "..."}
elif item.get("type") == "text":
text = item.get("text", "")
if text:
full_content += text
yield text
yielded_content = True
elif isinstance(item, str):
full_content += item
yield item
yielded_content = True
if hasattr(chunk, "content") and chunk.content:
full_content+=chunk.content
yield chunk.content
yielded_content = True
elif isinstance(chunk, str):
full_content += chunk
yield chunk
yielded_content = True
# 记录工具调用(可选)
elif kind == "on_tool_start":
logger.debug(f"工具调用开始: {event.get('name')}")
elif kind == "on_tool_end":
logger.debug(f"工具调用结束: {event.get('name')}")
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
# 统计token消耗
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
0) if response_meta else 0
yield total_tokens
break
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
actual_config_id)
await self.write(storage_type, end_user_id,full_content, user_rag_memory_id, end_user_id,full_content, actual_config_id)
await self.term_memory_save(message_chat, end_user_id, full_content)
except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
raise
@@ -618,3 +434,5 @@ class LangChainAgent:
logger.info("=" * 80)
logger.info("chat_stream 方法执行结束")
logger.info("=" * 80)

View File

@@ -1,34 +1,14 @@
import json
import os
from pathlib import Path
from typing import Annotated, Any, Dict, Optional
from typing import Any, Dict, Optional
from dotenv import load_dotenv
from pydantic import Field, TypeAdapter
load_dotenv()
class Settings:
# ========================================================================
# Deployment Mode Configuration
# ========================================================================
# community: 社区版(开源,功能受限)
# cloud: SaaS 云服务版(全功能,按量计费)
# enterprise: 企业私有化版License 控制)
DEPLOYMENT_MODE: str = os.getenv("DEPLOYMENT_MODE", "community")
# License 配置(企业版)
LICENSE_FILE: str = os.getenv("LICENSE_FILE", "/etc/app/license.json")
LICENSE_SERVER_URL: str = os.getenv("LICENSE_SERVER_URL", "https://license.yourcompany.com")
# 计费服务配置SaaS 版)
BILLING_SERVICE_URL: str = os.getenv("BILLING_SERVICE_URL", "")
# 基础 URL用于 SSO 回调等)
BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000")
FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000")
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
# API Keys Configuration
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
@@ -92,29 +72,9 @@ class Settings:
# Single Sign-On configuration
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
# SSO 免登配置
SSO_TOKEN_EXPIRE_SECONDS: int = int(os.getenv("SSO_TOKEN_EXPIRE_SECONDS", "300"))
SSO_TRUSTED_SOURCES_CONFIG: str = os.getenv("SSO_TRUSTED_SOURCES_CONFIG", "{}")
# File Upload
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
# Storage Configuration
STORAGE_TYPE: str = os.getenv("STORAGE_TYPE", "local")
# Aliyun OSS Configuration
OSS_ENDPOINT: str = os.getenv("OSS_ENDPOINT", "")
OSS_ACCESS_KEY_ID: str = os.getenv("OSS_ACCESS_KEY_ID", "")
OSS_ACCESS_KEY_SECRET: str = os.getenv("OSS_ACCESS_KEY_SECRET", "")
OSS_BUCKET_NAME: str = os.getenv("OSS_BUCKET_NAME", "")
# AWS S3 Configuration
S3_REGION: str = os.getenv("S3_REGION", "")
S3_ACCESS_KEY_ID: str = os.getenv("S3_ACCESS_KEY_ID", "")
S3_SECRET_ACCESS_KEY: str = os.getenv("S3_SECRET_ACCESS_KEY", "")
S3_BUCKET_NAME: str = os.getenv("S3_BUCKET_NAME", "")
# VOLC ASR settings
VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "")
@@ -130,7 +90,6 @@ class Settings:
# Server Configuration
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
FILE_LOCAL_SERVER_URL: str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
# ========================================================================
# Internal Configuration (not in .env, used by application code)
@@ -157,11 +116,6 @@ class Settings:
if origin.strip()
]
# Language Configuration
# Supported values: "zh" (Chinese), "en" (English)
# This controls the language used for memory summary titles and other generated content
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
# Logging settings
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
@@ -190,45 +144,18 @@ class Settings:
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
# Celery configuration (internal)
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
# 详见 docs/celery-env-bug-report.md
# 默认使用 Redis DB 3 (broker) 和 DB 4 (backend),与业务缓存 (DB 1/2) 隔离
# 多人共用同一 Redis 时,每位开发者应在 .env 中配置不同的 DB 编号避免任务互相干扰
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3"))
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4"))
# SMTP Email Configuration
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")
SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
SMTP_USER: str = os.getenv("SMTP_USER", "")
SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "")
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
# Memory Cache Regeneration Configuration
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
# Celery Beat Schedule Configuration (定时任务执行频率)
MEMORY_INCREMENT_HOUR: int = TypeAdapter(
Annotated[int, Field(ge=0, le=23, description="cron hour [0, 23]")]
).validate_python(int(os.getenv("MEMORY_INCREMENT_HOUR", "2")))
MEMORY_INCREMENT_MINUTE: int = TypeAdapter(
Annotated[int, Field(ge=0, le=59, description="cron minute [0, 59]")]
).validate_python(int(os.getenv("MEMORY_INCREMENT_MINUTE", "0")))
WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = TypeAdapter(
Annotated[int, Field(ge=1, description="reflection interval in seconds, must be >= 1")]
).validate_python(int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30")))
FORGETTING_CYCLE_INTERVAL_HOURS: int = TypeAdapter(
Annotated[int, Field(ge=1, description="forgetting cycle interval in hours, must be >= 1")]
).validate_python(int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24")))
IMPLICIT_EMOTIONS_UPDATE_HOUR: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_HOUR", "2"))
# implicit_emotions_update: 每天几分执行分钟0-59
IMPLICIT_EMOTIONS_UPDATE_MINUTE: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_MINUTE", "0"))
# Memory Module Configuration (internal)
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
@@ -239,37 +166,11 @@ class Settings:
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
# official environment system version
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1")
# model square loading
LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true"
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
# workflow config
WORKFLOW_IMPORT_CACHE_TIMEOUT: int = int(os.getenv("WORKFLOW_IMPORT_CACHE_TIMEOUT", 1800))
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
# ========================================================================
# General Ontology Type Configuration
# ========================================================================
# 通用本体文件路径列表(逗号分隔)
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "api/app/core/memory/ontology_services/General_purpose_entity.ttl")
# 是否启用通用本体类型功能
ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true"
# Prompt 中最大类型数量
MAX_ONTOLOGY_TYPES_IN_PROMPT: int = int(os.getenv("MAX_ONTOLOGY_TYPES_IN_PROMPT", "50"))
# 核心通用类型列表(逗号分隔)
CORE_GENERAL_TYPES: str = os.getenv(
"CORE_GENERAL_TYPES",
"Person,Organization,Company,GovernmentAgency,Place,Location,City,Country,Building,"
"Event,SportsEvent,SocialEvent,Work,Book,Film,Software,Concept,TopicalConcept,AcademicSubject"
)
# 实验模式开关(允许通过 API 动态切换本体配置)
ONTOLOGY_EXPERIMENT_MODE: bool = os.getenv("ONTOLOGY_EXPERIMENT_MODE", "true").lower() == "true"
def get_memory_output_path(self, filename: str = "") -> str:
"""
Get the full path for memory module output files.

View File

@@ -46,7 +46,6 @@ class BizCode(IntEnum):
RESOURCE_ALREADY_EXISTS = 5002
VERSION_ALREADY_EXISTS = 5003
STATE_CONFLICT = 5004
RESOURCE_IN_USE = 5005
# 应用发布6xxx
PUBLISH_FAILED = 6001
@@ -126,7 +125,6 @@ HTTP_MAPPING = {
BizCode.RESOURCE_ALREADY_EXISTS: 409,
BizCode.VERSION_ALREADY_EXISTS: 409,
BizCode.STATE_CONFLICT: 409,
BizCode.RESOURCE_IN_USE: 409,
BizCode.PUBLISH_FAILED: 500,
BizCode.NO_DRAFT_TO_PUBLISH: 400,
BizCode.ROLLBACK_TARGET_NOT_FOUND: 400,

View File

@@ -1,82 +0,0 @@
# -*- coding: utf-8 -*-
"""语言处理工具模块
本模块提供集中化的语言校验和处理功能,确保整个应用中语言参数的一致性。
Functions:
validate_language: 校验语言参数,确保其为有效值
get_language_from_header: 从请求头获取并校验语言参数
"""
from typing import Optional
from app.core.logging_config import get_logger
logger = get_logger(__name__)
# 支持的语言列表
SUPPORTED_LANGUAGES = {"zh", "en"}
# 默认回退语言
DEFAULT_LANGUAGE = "zh"
def validate_language(language: Optional[str]) -> str:
"""
校验语言参数,确保其为有效值。
Args:
language: 待校验的语言代码,可以是 None、"zh""en" 或其他值
Returns:
有效的语言代码("zh""en"
Examples:
>>> validate_language("zh")
'zh'
>>> validate_language("en")
'en'
>>> validate_language("EN") # 大小写不敏感
'en'
>>> validate_language(None) # None 回退到默认值
'zh'
>>> validate_language("fr") # 不支持的语言回退到默认值
'zh'
"""
if language is None:
return DEFAULT_LANGUAGE
# 标准化:转小写并去除空白
lang = str(language).lower().strip()
if lang in SUPPORTED_LANGUAGES:
return lang
logger.warning(
f"无效的语言参数 '{language}',已回退到默认值 '{DEFAULT_LANGUAGE}'"
f"支持的语言: {SUPPORTED_LANGUAGES}"
)
return DEFAULT_LANGUAGE
def get_language_from_header(language_type: Optional[str]) -> str:
"""
从请求头获取并校验语言参数。
这是一个便捷函数,用于在 controller 层统一处理 X-Language-Type Header。
Args:
language_type: 从 X-Language-Type Header 获取的语言值
Returns:
有效的语言代码("zh""en"
Examples:
>>> get_language_from_header(None) # Header 未传递
'zh'
>>> get_language_from_header("en")
'en'
>>> get_language_from_header("invalid") # 无效值回退
'zh'
"""
return validate_language(language_type)

View File

@@ -38,56 +38,6 @@ class SensitiveDataLoggingFilter(logging.Filter):
return True
class Neo4jSuccessNotificationFilter(logging.Filter):
"""Neo4j 日志过滤器:过滤成功/信息性状态的通知,保留真正的警告和错误
Neo4j 驱动会以 WARNING 级别记录所有数据库通知,包括成功的操作。
这个过滤器会过滤掉以下 GQL 状态码的通知,只保留真正的警告和错误:
- 00000: 成功完成 (successful completion)
- 00N00: 无数据 (no data)
- 00NA0: 无数据,信息性通知 (no data, informational notification)
使用正则表达式进行更严格的匹配,避免误过滤无关的警告。
"""
import re
# 编译正则表达式以提高性能
# 匹配所有"成功/信息性"的 GQL 状态码:
# 00000 = 成功完成, 00N00 = 无数据, 00NA0 = 无数据信息性通知
GQL_STATUS_PATTERN = re.compile(r"gql_status=['\"](00000|00N00|00NA0)['\"]")
# 匹配 status_description 中的成功完成或信息性通知消息
SUCCESS_DESC_PATTERN = re.compile(r"status_description=['\"]note:\s*(successful\s+completion|no\s+data)['\"]", re.IGNORECASE)
def filter(self, record: logging.LogRecord) -> bool:
"""
过滤 Neo4j 成功通知
Args:
record: 日志记录
Returns:
True表示允许记录False表示拒绝过滤掉
"""
# 只处理 INFO 和 WARNING 级别的日志
# Neo4j 驱动对 severity='INFORMATION' 的通知使用 INFO 级别,
# 对 severity='WARNING' 的通知使用 WARNING 级别
if record.levelno not in (logging.INFO, logging.WARNING):
return True
# 检查是否是 Neo4j 的成功通知
message = str(record.msg)
# 使用正则表达式进行更严格的匹配
# 这样可以避免误过滤包含这些子字符串但不是 Neo4j 通知的日志
if self.GQL_STATUS_PATTERN.search(message) or self.SUCCESS_DESC_PATTERN.search(message):
return False # 过滤掉这条日志
# 保留其他所有日志(包括真正的警告和错误)
return True
class LoggingConfig:
"""全局日志配置类"""
@@ -115,22 +65,6 @@ class LoggingConfig:
# 清除现有处理器
root_logger.handlers.clear()
# Neo4j 通知过滤器 - 挂在 handler 上确保所有传播上来的日志都能被过滤
neo4j_filter = Neo4jSuccessNotificationFilter()
# 抑制 Neo4j 通知日志
# Neo4j 驱动内部会给 neo4j.notifications logger 配置自己的 handler
# 导致日志绕过根 logger 的 filter 直接输出。
# 多管齐下确保过滤生效:
# 1. 设置 neo4j.notifications 级别为 WARNING过滤 INFO 级别的 00NA0 通知)
# 2. 在所有 neo4j logger 上添加 filter过滤 WARNING 级别的成功通知)
# 3. 在根 handler 上也添加 filter兜底
neo4j_notifications_logger = logging.getLogger("neo4j.notifications")
neo4j_notifications_logger.setLevel(logging.WARNING)
for neo4j_logger_name in ["neo4j", "neo4j.io", "neo4j.pool", "neo4j.notifications"]:
neo4j_logger = logging.getLogger(neo4j_logger_name)
neo4j_logger.addFilter(neo4j_filter)
# 创建格式化器
formatter = logging.Formatter(
fmt=settings.LOG_FORMAT,
@@ -146,7 +80,6 @@ class LoggingConfig:
console_handler.setFormatter(formatter)
console_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
console_handler.addFilter(sensitive_filter)
console_handler.addFilter(neo4j_filter)
root_logger.addHandler(console_handler)
# 文件处理器(带轮转)
@@ -160,7 +93,6 @@ class LoggingConfig:
file_handler.setFormatter(formatter)
file_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
file_handler.addFilter(sensitive_filter)
file_handler.addFilter(neo4j_filter)
root_logger.addHandler(file_handler)
cls._initialized = True

View File

@@ -0,0 +1,16 @@
"""
LangGraph Graph package for memory agent.
This package provides the LangGraph workflow orchestrator with modular
node implementations, routing logic, and state management.
Package structure:
- read_graph: Main graph factory for read operations
- write_graph: Main graph factory for write operations
- nodes: LangGraph node implementations
- routing: State routing logic
- state: State management utilities
"""
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
__all__ = ['make_read_graph']

View File

@@ -4,7 +4,7 @@ LangGraph node implementations.
This module contains custom node implementations for the LangGraph workflow.
"""
# from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode
# from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message
#
# __all__ = ["ToolExecutionNode", "create_input_message"]
from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode
from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message
__all__ = ["ToolExecutionNode", "create_input_message"]

View File

@@ -1,16 +0,0 @@
from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
def content_input_node(state: ReadState) -> ReadState:
"""开始节点 - 提取内容并保持状态信息"""
content = state['messages'][0].content if state.get('messages') else ''
# 返回内容并保持所有状态信息
return {"data": content}
def content_input_write(state: WriteState) -> WriteState:
"""开始节点 - 提取内容并保持状态信息"""
content = state['messages'][0].content if state.get('messages') else ''
# 返回内容并保持所有状态信息
return {"data": content}

View File

@@ -0,0 +1,150 @@
"""
Input node for LangGraph workflow entry point.
This module provides the create_input_message function which processes initial
user input with multimodal support and creates the first tool call message.
"""
import logging
import re
import uuid
from datetime import datetime
from typing import Any, Dict
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
from app.schemas.memory_config_schema import MemoryConfig
from langchain_core.messages import AIMessage
logger = logging.getLogger(__name__)
async def create_input_message(
state: Dict[str, Any],
tool_name: str,
session_id: str,
search_switch: str,
apply_id: str,
group_id: str,
multimodal_processor: MultimodalProcessor,
memory_config: MemoryConfig,
) -> Dict[str, Any]:
"""
Create initial tool call message from user input.
This function:
1. Extracts the last message content from state
2. Processes multimodal inputs (images/audio) using the multimodal processor
3. Generates a unique message ID
4. Extracts namespace from session_id
5. Handles verified_data extraction for backward compatibility
6. Returns AIMessage with complete tool_calls structure
Args:
state: LangGraph state dictionary containing messages
tool_name: Name of the tool to invoke (typically "Split_The_Problem")
session_id: Session identifier (format: "call_id_{namespace}")
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
multimodal_processor: Processor for handling image/audio inputs
memory_config: MemoryConfig object containing all configuration
Returns:
State update with AIMessage containing tool_call
Examples:
>>> state = {"messages": [HumanMessage(content="What is AI?")]}
>>> result = await create_input_message(
... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor, config
... )
>>> result["messages"][0].tool_calls[0]["name"]
'Split_The_Problem'
"""
messages = state.get("messages", [])
# Extract last message content
if messages:
last_message = messages[-1].content if hasattr(messages[-1], 'content') else str(messages[-1])
else:
logger.warning("[create_input_message] No messages in state, using empty string")
last_message = ""
logger.debug(f"[create_input_message] Original input: {last_message[:100]}...")
# Process multimodal input (images/audio)
try:
processed_content = await multimodal_processor.process_input(last_message)
if processed_content != last_message:
logger.info(
f"[create_input_message] Multimodal processing converted input "
f"from {len(last_message)} to {len(processed_content)} chars"
)
last_message = processed_content
except Exception as e:
logger.error(
f"[create_input_message] Multimodal processing failed: {e}",
exc_info=True
)
# Continue with original content
# Generate unique message ID
uuid_str = uuid.uuid4()
time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Extract namespace from session_id
# Expected format: "call_id_{namespace}" or similar
try:
namespace = str(session_id).split('_id_')[1]
except (IndexError, AttributeError):
logger.warning(
f"[create_input_message] Could not extract namespace from session_id: {session_id}"
)
namespace = "unknown"
# Handle verified_data extraction (backward compatibility)
# This regex-based extraction is kept for compatibility with existing data formats
if 'verified_data' in str(last_message):
try:
messages_last = str(last_message).replace('\\n', '').replace('\\', '')
query_match = re.findall(r'"query": "(.*?)",', messages_last)
if query_match:
last_message = query_match[0]
logger.debug(
f"[create_input_message] Extracted query from verified_data: {last_message}"
)
except Exception as e:
logger.warning(
f"[create_input_message] Failed to extract query from verified_data: {e}"
)
# Construct tool call message
tool_call_id = f"{session_id}_{uuid_str}"
logger.info(
f"[create_input_message] Creating tool call for '{tool_name}' "
f"with ID: {tool_call_id}"
)
# Build tool arguments
tool_args = {
"sentence": last_message,
"sessionid": session_id,
"messages_id": str(uuid_str),
"search_switch": search_switch,
"apply_id": apply_id,
"group_id": group_id,
"memory_config": memory_config,
}
return {
"messages": [
AIMessage(
content="",
tool_calls=[{
"name": tool_name,
"args": tool_args,
"id": tool_call_id
}]
)
]
}

View File

@@ -1,250 +0,0 @@
import json
import os
import time
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_,
ReadState,
)
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__)
class ProblemNodeService(LLMServiceMixin):
"""问题处理节点服务类"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# 创建全局服务实例
problem_service = ProblemNodeService()
async def Split_The_Problem(state: ReadState) -> ReadState:
"""问题分解节点"""
# 从状态中获取数据
content = state.get('data', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema()
system_prompt = await problem_service.template_service.render_template(
template_name='problem_breakdown_prompt.jinja2',
operation_name='split_the_problem',
history=history,
sentence=content,
json_schema=json_schema
)
try:
# 使用优化的LLM服务
with get_db_context() as db_session:
structured = await problem_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=ProblemExtensionResponse,
fallback_value=[]
)
# 添加更详细的日志记录
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
# 验证结构化响应
if not structured or not hasattr(structured, 'root'):
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
split_result = json.dumps([], ensure_ascii=False)
elif not structured.root:
logger.warning("Split_The_Problem: 结构化响应的root为空")
split_result = json.dumps([], ensure_ascii=False)
else:
split_result = json.dumps(
[item.model_dump() for item in structured.root],
ensure_ascii=False
)
split_result_dict = []
for index, item in enumerate(json.loads(split_result)):
split_data = {
"id": f"Q{index + 1}",
"question": item['extended_question'],
"type": item['type'],
"reason": item['reason']
}
split_result_dict.append(split_data)
logger.info(f"Split_The_Problem: 成功生成 {len(structured.root) if structured.root else 0} 个分解项")
result = {
"context": split_result,
"original": content,
"_intermediate": {
"type": "problem_split",
"title": "问题拆分",
"data": split_result_dict,
"original_query": content
}
}
except Exception as e:
logger.error(
f"Split_The_Problem failed: {e}",
exc_info=True
)
# 提供更详细的错误信息
error_details = {
"error_type": type(e).__name__,
"error_message": str(e),
"content_length": len(content),
"llm_model_id": str(memory_config.llm_model_id) if memory_config else None
}
logger.error(f"Split_The_Problem error details: {error_details}")
# 创建默认的空结果
result = {
"context": json.dumps([], ensure_ascii=False),
"original": content,
"error": str(e),
"_intermediate": {
"type": "problem_split",
"title": "问题拆分",
"data": [],
"original_query": content,
"error": error_details
}
}
# 返回更新后的状态包含spit_context字段
return {"spit_data": result}
async def Problem_Extension(state: ReadState) -> ReadState:
"""问题扩展节点"""
# 获取原始数据和分解结果
start = time.time()
content = state.get('data', '')
data = state.get('spit_data', '')['context']
end_user_id = state.get('end_user_id', '')
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
memory_config = state.get('memory_config', None)
databasets = {}
try:
data = json.loads(data)
for i in data:
databasets[i['extended_question']] = i['type']
except (json.JSONDecodeError, KeyError, TypeError) as e:
logger.error(f"Problem_Extension: 数据解析失败: {e}")
# 使用空字典作为fallback
databasets = {}
data = []
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema()
system_prompt = await problem_service.template_service.render_template(
template_name='Problem_Extension_prompt.jinja2',
operation_name='problem_extension',
history=history,
questions=databasets,
json_schema=json_schema
)
try:
# 使用优化的LLM服务
with get_db_context() as db_session:
response_content = await problem_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=ProblemExtensionResponse,
fallback_value=[]
)
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
# 验证结构化响应
if not response_content or not hasattr(response_content, 'root'):
logger.warning("Problem_Extension: 结构化响应为空或格式不正确")
aggregated_dict = {}
elif not response_content.root:
logger.warning("Problem_Extension: 结构化响应的root为空")
aggregated_dict = {}
else:
# Aggregate results by original question
aggregated_dict = {}
for item in response_content.root:
try:
key = getattr(item, "original_question", None) or (
item.get("original_question") if isinstance(item, dict) else None
)
value = getattr(item, "extended_question", None) or (
item.get("extended_question") if isinstance(item, dict) else None
)
if not key or not value:
logger.warning(f"Problem_Extension: 跳过无效项: key={key}, value={value}")
continue
aggregated_dict.setdefault(key, []).append(value)
except Exception as item_error:
logger.warning(f"Problem_Extension: 处理项目时出错: {item_error}")
continue
logger.info(f"Problem_Extension: 成功生成 {len(aggregated_dict)} 个扩展问题组")
except Exception as e:
logger.error(
f"LLM call failed for Problem_Extension: {e}",
exc_info=True
)
# 提供更详细的错误信息
error_details = {
"error_type": type(e).__name__,
"error_message": str(e),
"questions_count": len(databasets),
"llm_model_id": str(memory_config.llm_model_id) if memory_config else None
}
logger.error(f"Problem_Extension error details: {error_details}")
aggregated_dict = {}
logger.info("Problem extension")
logger.info(f"Problem extension result: {aggregated_dict}")
# Emit intermediate output for frontend
print(time.time() - start)
result = {
"context": aggregated_dict,
"original": data,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "problem_extension",
"title": "问题扩展",
"data": aggregated_dict,
"original_query": content,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
return {"problem_extension": result}

View File

@@ -1,411 +0,0 @@
# ===== 标准库 =====
import asyncio
import json
import os
# ===== 第三方库 =====
from langchain.agents import create_agent
from langchain_openai import ChatOpenAI
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.tools.tool import (
create_hybrid_retrieval_tool_sync,
create_time_retrieval_tool,
extract_tool_message_content,
)
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
ReadState,
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.schemas import model_schema
from app.services.memory_config_service import MemoryConfigService
from app.services.model_service import ModelConfigService
logger = get_agent_logger(__name__)
async def rag_config(state):
user_rag_memory_id = state.get('user_rag_memory_id', '')
kb_config = {
"knowledge_bases": [
{
"kb_id": user_rag_memory_id,
"similarity_threshold": 0.7,
"vector_similarity_weight": 0.5,
"top_k": 10,
"retrieve_type": "participle"
}
],
"merge_strategy": "weight",
"reranker_id": os.getenv('reranker_id'),
"reranker_top_k": 10
}
return kb_config
async def rag_knowledge(state, question):
kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
clean_content = '\n\n'.join(retrieval_knowledge)
cleaned_query = question
raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except Exception:
retrieval_knowledge = []
clean_content = ''
raw_results = ''
cleaned_query = question
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
return retrieval_knowledge, clean_content, cleaned_query, raw_results
async def llm_infomation(state: ReadState) -> ReadState:
memory_config = state.get('memory_config', None)
model_id = memory_config.llm_model_id
tenant_id = memory_config.tenant_id
# 使用现有的 memory_config 而不是重新查询数据库
# 或者使用线程安全的数据库访问
with get_db_context() as db:
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id)
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
return result_pydantic
async def clean_databases(data) -> str:
"""
简化的数据库搜索结果清理函数
Args:
data: 搜索结果数据
Returns:
清理后的内容字符串
"""
try:
# 解析JSON字符串
if isinstance(data, str):
try:
data = json.loads(data)
except json.JSONDecodeError:
return data
if not isinstance(data, dict):
return str(data)
# 获取结果数据
# with open("搜索结果.json","w",encoding='utf-8') as f:
# f.write(json.dumps(data, indent=4, ensure_ascii=False))
results = data.get('results', data)
if not isinstance(results, dict):
return str(results)
# 收集所有内容
content_list = []
# 处理重排序结果
reranked = results.get('reranked_results', {})
if reranked:
for category in ['summaries', 'statements', 'chunks', 'entities']:
items = reranked.get(category, [])
if isinstance(items, list):
content_list.extend(items)
# 处理时间搜索结果
time_search = results.get('time_search', {})
if time_search:
if isinstance(time_search, dict):
statements = time_search.get('statements', time_search.get('time_search', []))
if isinstance(statements, list):
content_list.extend(statements)
elif isinstance(time_search, list):
content_list.extend(time_search)
# 提取文本内容
text_parts = []
for item in content_list:
if isinstance(item, dict):
text = item.get('statement') or item.get('content', '')
if text:
text_parts.append(text)
elif isinstance(item, str):
text_parts.append(item)
return '\n'.join(text_parts).strip()
except Exception as e:
logger.error(f"clean_databases failed: {e}", exc_info=True)
return str(data)
async def retrieve_nodes(state: ReadState) -> ReadState:
'''
模型信息
'''
problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
original = state.get('data', '')
problem_list = []
for key, values in problem_extension.items():
for data in values:
problem_list.append(data)
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
# 创建异步任务处理单个问题
async def process_question_nodes(idx, question):
try:
# Prepare search parameters based on storage type
search_params = {
"end_user_id": end_user_id,
"question": question,
"return_raw_results": True
}
if storage_type == "rag" and user_rag_memory_id:
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question)
else:
clean_content, cleaned_query, raw_results = await SearchService().execute_hybrid_search(
**search_params, memory_config=memory_config
)
return {
"Query_small": cleaned_query,
"Result_small": clean_content,
"_intermediate": {
"type": "search_result",
"query": cleaned_query,
"raw_results": raw_results,
"index": idx + 1,
"total": len(problem_list)
}
}
except Exception as e:
logger.error(
f"Retrieve: hybrid_search failed for question '{question}': {e}",
exc_info=True
)
# Return empty result for this question
return {
"Query_small": question,
"Result_small": "",
"_intermediate": {
"type": "search_result",
"query": question,
"raw_results": [],
"index": idx + 1,
"total": len(problem_list)
}
}
# 并发处理所有问题
tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)]
databases_anser = await asyncio.gather(*tasks)
databases_data = {
"Query": original,
"Expansion_issue": databases_anser
}
# Collect intermediate outputs before deduplication
intermediate_outputs = []
for item in databases_anser:
if '_intermediate' in item:
intermediate_outputs.append(item['_intermediate'])
# Deduplicate and merge results
deduplicated_data = deduplicate_entries(databases_data['Expansion_issue'])
deduplicated_data_merged = merge_to_key_value_pairs(
deduplicated_data,
'Query_small',
'Result_small'
)
# Restructure for Verify/Retrieve_Summary compatibility
keys, val = [], []
for item in deduplicated_data_merged:
for items_key, items_value in item.items():
keys.append(items_key)
val.append(items_value)
send_verify = []
for i, j in zip(keys, val, strict=False):
if j != ['']:
send_verify.append({
"Query_small": i,
"Answer_Small": j
})
dup_databases = {
"Query": original,
"Expansion_issue": send_verify,
"_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs
}
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
return {'retrieve': dup_databases}
async def retrieve(state: ReadState) -> ReadState:
# 从state中获取end_user_id
import time
start = time.time()
problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
original = state.get('data', '')
problem_list = []
for key, values in problem_extension.items():
for data in values:
problem_list.append(data)
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
databases_anser = []
async def get_llm_info():
with get_db_context() as db: # 使用同步数据库上下文管理器
config_service = MemoryConfigService(db)
return await llm_infomation(state)
llm_config = await get_llm_info()
api_key_obj = llm_config.api_keys[0]
api_key = api_key_obj.api_key
api_base = api_key_obj.api_base
model_name = api_key_obj.model_name
llm = ChatOpenAI(
model=model_name,
api_key=api_key,
base_url=api_base,
temperature=0.2,
)
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
search_params = {"end_user_id": end_user_id, "return_raw_results": True}
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
agent = create_agent(
llm,
tools=[time_retrieval_tool, hybrid_retrieval],
system_prompt=f"我是检索专家可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
)
# 创建异步任务处理单个问题
import asyncio
# 在模块级别定义信号量,限制最大并发数
SEMAPHORE = asyncio.Semaphore(5) # 限制最多5个并发数据库操作
async def process_question(idx, question):
async with SEMAPHORE: # 限制并发
try:
if storage_type == "rag" and user_rag_memory_id:
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
question)
else:
cleaned_query = question
# 使用 asyncio 在线程池中运行同步的 agent.invoke
import asyncio
response = await asyncio.get_event_loop().run_in_executor(
None,
lambda: agent.invoke({"messages": question})
)
tool_results = extract_tool_message_content(response)
if tool_results == None:
raw_results = []
clean_content = ''
else:
raw_results = tool_results['content']
clean_content = await clean_databases(raw_results)
try:
raw_results = raw_results['results']
except Exception:
raw_results = []
return {
"Query_small": cleaned_query,
"Result_small": clean_content,
"_intermediate": {
"type": "search_result",
"query": cleaned_query,
"raw_results": raw_results,
"index": idx + 1,
"total": len(problem_list)
}
}
except Exception as e:
logger.error(
f"Retrieve: hybrid_search failed for question '{question}': {e}",
exc_info=True
)
# Return empty result for this question
return {
"Query_small": question,
"Result_small": "",
"_intermediate": {
"type": "search_result",
"query": question,
"raw_results": [],
"index": idx + 1,
"total": len(problem_list)
}
}
# 并发处理所有问题
import asyncio
tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)]
databases_anser = await asyncio.gather(*tasks)
databases_data = {
"Query": original,
"Expansion_issue": databases_anser
}
# Collect intermediate outputs before deduplication
intermediate_outputs = []
for item in databases_anser:
if '_intermediate' in item:
intermediate_outputs.append(item['_intermediate'])
# Deduplicate and merge results
deduplicated_data = deduplicate_entries(databases_data['Expansion_issue'])
deduplicated_data_merged = merge_to_key_value_pairs(
deduplicated_data,
'Query_small',
'Result_small'
)
# Restructure for Verify/Retrieve_Summary compatibility
keys, val = [], []
for item in deduplicated_data_merged:
for items_key, items_value in item.items():
keys.append(items_key)
val.append(items_value)
send_verify = []
for i, j in zip(keys, val, strict=False):
if j != ['']:
send_verify.append({
"Query_small": i,
"Answer_Small": j
})
dup_databases = {
"Query": original,
"Expansion_issue": send_verify,
"_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs
}
# with open('retrieve_text.json', 'w') as f:
# json.dump(dup_databases, f, indent=4)
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
return {'retrieve': dup_databases}

View File

@@ -1,371 +0,0 @@
import os
import time
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse,
SummaryResponse,
)
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_,
ReadState,
)
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__)
class SummaryNodeService(LLMServiceMixin):
"""总结节点服务类"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# 创建全局服务实例
summary_service = SummaryNodeService()
async def rag_config(state):
user_rag_memory_id = state.get('user_rag_memory_id', '')
kb_config = {
"knowledge_bases": [
{
"kb_id": user_rag_memory_id,
"similarity_threshold": 0.7,
"vector_similarity_weight": 0.5,
"top_k": 10,
"retrieve_type": "participle"
}
],
"merge_strategy": "weight",
"reranker_id": os.getenv('reranker_id'),
"reranker_top_k": 10
}
return kb_config
async def rag_knowledge(state, question):
kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
clean_content = '\n\n'.join(retrieval_knowledge)
cleaned_query = question
raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except Exception:
retrieval_knowledge = []
clean_content = ''
raw_results = ''
cleaned_query = question
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
return retrieval_knowledge, clean_content, cleaned_query, raw_results
async def summary_history(state: ReadState) -> ReadState:
end_user_id = state.get("end_user_id", '')
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
return history
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
search_mode) -> str:
"""
增强的summary_llm函数包含更好的错误处理和数据验证
"""
data = state.get("data", '')
# 构建系统提示词
if str(search_mode) == "0":
system_prompt = await summary_service.template_service.render_template(
template_name=template_name,
operation_name=operation_name,
data=retrieve_info,
query=data
)
else:
system_prompt = await summary_service.template_service.render_template(
template_name=template_name,
operation_name=operation_name,
query=data,
history=history,
retrieve_info=retrieve_info
)
try:
# 使用优化的LLM服务进行结构化输出
with get_db_context() as db_session:
structured = await summary_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=response_model,
fallback_value=None
)
# 验证结构化响应
if structured is None:
logger.warning("LLM返回None使用默认回答")
return "信息不足,无法回答"
# 根据操作类型提取答案
if operation_name == "summary":
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
else:
# 处理RetrieveSummaryResponse
if hasattr(structured, 'data') and structured.data:
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
else:
logger.warning("结构化响应缺少data字段")
aimessages = "信息不足,无法回答"
# 验证答案不为空
if not aimessages or aimessages.strip() == "":
aimessages = "信息不足,无法回答"
return aimessages
except Exception as e:
logger.error(f"结构化输出失败: {e}", exc_info=True)
# 尝试非结构化输出作为fallback
try:
logger.info("尝试非结构化输出作为fallback")
response = await summary_service.call_llm_simple(
state=state,
db_session=db_session,
system_prompt=system_prompt,
fallback_message="信息不足,无法回答"
)
if response and response.strip():
# 简单清理响应
cleaned_response = response.strip()
# 移除可能的JSON标记
if cleaned_response.startswith('```'):
lines = cleaned_response.split('\n')
cleaned_response = '\n'.join(lines[1:-1])
return cleaned_response
else:
return "信息不足,无法回答"
except Exception as fallback_error:
logger.error(f"Fallback也失败: {fallback_error}")
return "信息不足,无法回答"
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
data = state.get("data", '')
end_user_id = state.get("end_user_id", '')
await SessionService(store).save_session(
user_id=end_user_id,
query=data,
apply_id=end_user_id,
end_user_id=end_user_id,
ai_response=aimessages
)
await SessionService(store).cleanup_duplicates()
logger.info(f"sessionid: {aimessages} 写入成功")
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
data = state.get("data", '')
input_summary = {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "input_summary",
"title": "快速答案",
"summary": aimessages,
"query": data,
"raw_results": raw_results,
"search_mode": "quick_search",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
retrieve = {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "retrieval_summary",
"title": "快速检索",
"summary": aimessages,
"query": data,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
return input_summary, retrieve
async def Input_Summary(state: ReadState) -> ReadState:
start = time.time()
storage_type = state.get("storage_type", '')
memory_config = state.get('memory_config', None)
user_rag_memory_id = state.get("user_rag_memory_id", '')
data = state.get("data", '')
end_user_id = state.get("end_user_id", '')
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
history = await summary_history(state)
search_params = {
"end_user_id": end_user_id,
"question": data,
"return_raw_results": True,
"include": ["summaries"] # Only search summary nodes for faster performance
}
try:
if storage_type != "rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params,
memory_config=memory_config)
else:
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
except Exception as e:
logger.error(f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True)
retrieve_info, question, raw_results = "", data, []
try:
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
# 'input_summary',RetrieveSummaryResponse)
# logger.info(f"快速答案总结==>>:{storage_type}--{user_rag_memory_id}--{aimessages}")
summary_result = await summary_prompt(state, retrieve_info, retrieve_info)
summary = summary_result[0]
except Exception as e:
logger.error(f"Input_Summary failed: {e}", exc_info=True)
summary = {
"status": "fail",
"summary_result": "信息不足,无法回答",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('检索', duration)
return {"summary": summary}
async def Retrieve_Summary(state: ReadState) -> ReadState:
retrieve = state.get("retrieve", '')
history = await summary_history(state)
import json
with open("检索.json", "w", encoding='utf-8') as f:
f.write(json.dumps(retrieve, indent=4, ensure_ascii=False))
retrieve = retrieve.get("Expansion_issue", [])
start = time.time()
retrieve_info_str = []
for data in retrieve:
if data == '':
retrieve_info_str = ''
else:
for key, value in data.items():
if key == 'Answer_Small':
for i in value:
retrieve_info_str.append(i)
retrieve_info_str = list(set(retrieve_info_str))
retrieve_info_str = '\n'.join(retrieve_info_str)
aimessages = await summary_llm(state, history, retrieve_info_str,
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages)
if aimessages == '':
aimessages = '信息不足,无法回答'
logger.info(f"Summary after retrieval: {aimessages}")
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Retrieval summary', duration)
# 修复协程调用 - 先await然后访问返回值
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1]
return {"summary": summary}
async def Summary(state: ReadState) -> ReadState:
start = time.time()
query = state.get("data", '')
verify = state.get("verify", '')
verify_expansion_issue = verify.get("verified_data", '')
retrieve_info_str = ''
for data in verify_expansion_issue:
for key, value in data.items():
if key == 'answer_small':
for i in value:
retrieve_info_str += i + '\n'
history = await summary_history(state)
data = {
"query": query,
"history": history,
"retrieve_info": retrieve_info_str
}
aimessages = await summary_llm(state, history, data,
'summary_prompt.jinja2', 'summary', SummaryResponse, 0)
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages)
if aimessages == '':
aimessages = '信息不足,无法回答'
try:
duration = time.time() - start
except Exception:
duration = 0.0
log_time('Retrieval summary', duration)
# 修复协程调用 - 先await然后访问返回值
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1]
return {"summary": summary}
async def Summary_fails(state: ReadState) -> ReadState:
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
history = await summary_history(state)
query = state.get("data", '')
verify = state.get("verify", '')
verify_expansion_issue = verify.get("verified_data", '')
retrieve_info_str = ''
for data in verify_expansion_issue:
for key, value in data.items():
if key == 'answer_small':
for i in value:
retrieve_info_str += i + '\n'
data = {
"query": query,
"history": history,
"retrieve_info": retrieve_info_str
}
aimessages = await summary_llm(state, history, data,
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
result = {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
return {"summary": result}

View File

@@ -0,0 +1,234 @@
"""
Tool execution node for LangGraph workflow.
This module provides the ToolExecutionNode class which wraps tool execution
with parameter transformation logic using the ParameterBuilder service.
"""
import logging
import time
from typing import Any, Callable, Dict
from app.core.memory.agent.langgraph_graph.state.extractors import (
extract_content_payload,
extract_tool_call_id,
)
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
from app.schemas.memory_config_schema import MemoryConfig
from langchain_core.messages import AIMessage
from langgraph.prebuilt import ToolNode
logger = logging.getLogger(__name__)
class ToolExecutionNode:
"""
Custom LangGraph node that wraps tool execution with parameter transformation.
This node extracts content from previous tool results, transforms parameters
based on tool type using ParameterBuilder, and invokes the tool with the
correct argument structure.
Attributes:
tool_node: LangGraph ToolNode wrapping the actual tool
id: Node identifier for message IDs
tool_name: Name of the tool being executed
namespace: Namespace for session management
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
parameter_builder: Service for building tool-specific arguments
memory_config: MemoryConfig object containing all configuration
"""
def __init__(
self,
tool: Callable,
node_id: str,
namespace: str,
search_switch: str,
apply_id: str,
group_id: str,
parameter_builder: ParameterBuilder,
storage_type: str,
user_rag_memory_id: str,
memory_config: MemoryConfig,
):
"""
Initialize the tool execution node.
Args:
tool: The tool function to execute
node_id: Identifier for this node (used in message IDs)
namespace: Namespace for session management
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
parameter_builder: Service for building tool-specific arguments
storage_type: Storage type for the workspace
user_rag_memory_id: User RAG memory identifier
memory_config: MemoryConfig object containing all configuration
"""
self.tool_node = ToolNode([tool])
self.id = node_id
self.tool_name = tool.name if hasattr(tool, 'name') else str(tool)
self.namespace = namespace
self.search_switch = search_switch
self.apply_id = apply_id
self.group_id = group_id
self.parameter_builder = parameter_builder
self.storage_type = storage_type
self.user_rag_memory_id = user_rag_memory_id
self.memory_config = memory_config
logger.info(
f"[ToolExecutionNode] Initialized node '{self.id}' for tool '{self.tool_name}'"
)
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
"""
Execute the tool with transformed parameters.
This method:
1. Extracts the last message from state
2. Extracts tool call ID using state extractors
3. Extracts content payload using state extractors
4. Builds tool arguments using parameter builder
5. Constructs AIMessage with tool_calls
6. Invokes the tool and returns the result
Args:
state: LangGraph state dictionary
Returns:
Updated state with tool result in messages
"""
messages = state.get("messages", [])
logger.debug( self.tool_name)
if not messages:
logger.warning(f"[ToolExecutionNode] {self.id} - No messages in state")
return {"messages": [AIMessage(content="Error: No messages in state")]}
last_message = messages[-1]
logger.debug(
f"[ToolExecutionNode] {self.id} - Processing message at {time.time()}"
)
try:
# Extract tool call ID using state extractors
tool_call_id = extract_tool_call_id(last_message)
logger.debug(f"[ToolExecutionNode] {self.id} - Extracted tool_call_id: {tool_call_id}")
except ValueError as e:
logger.error(
f"[ToolExecutionNode] {self.id} - Failed to extract tool call ID: {e}"
)
return {"messages": [AIMessage(content=f"Error: {str(e)}")]}
try:
# Extract content payload using state extractors
content = extract_content_payload(last_message)
logger.debug(
f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}, content_keys: {list(content.keys()) if isinstance(content, dict) else 'N/A'}"
)
# Log raw message content for debugging
if hasattr(last_message, 'content'):
raw = last_message.content
logger.debug(f"[ToolExecutionNode] {self.id} - Raw message content (first 500 chars): {str(raw)[:500]}")
except Exception as e:
logger.error(
f"[ToolExecutionNode] {self.id} - Failed to extract content: {e}",
exc_info=True
)
content = {}
try:
# Build tool arguments using parameter builder
tool_args = self.parameter_builder.build_tool_args(
tool_name=self.tool_name,
content=content,
tool_call_id=tool_call_id,
search_switch=self.search_switch,
apply_id=self.apply_id,
group_id=self.group_id,
memory_config=self.memory_config,
storage_type=self.storage_type,
user_rag_memory_id=self.user_rag_memory_id,
)
logger.debug(
f"[ToolExecutionNode] {self.id} - Built tool args with keys: {list(tool_args.keys())}"
)
except Exception as e:
logger.error(
f"[ToolExecutionNode] {self.id} - Failed to build tool args: {e}",
exc_info=True
)
return {"messages": [AIMessage(content=f"Error building arguments: {str(e)}")]}
# Construct tool input message
tool_input = {
"messages": [
AIMessage(
content="",
tool_calls=[{
"name": self.tool_name,
"args": tool_args,
"id": f"{self.id}_{tool_call_id}",
}]
)
]
}
try:
# Invoke the tool
result = await self.tool_node.ainvoke(tool_input)
logger.debug(
f"[ToolExecutionNode] {self.id} - Tool execution completed"
)
# Check for error in tool response
error_entry = None
if result and "messages" in result:
for msg in result["messages"]:
if hasattr(msg, 'content'):
try:
import json
content = msg.content
if isinstance(content, str):
parsed = json.loads(content)
if isinstance(parsed, dict) and "error" in parsed:
error_msg = parsed["error"]
logger.warning(
f"[ToolExecutionNode] {self.id} - Tool returned error: {error_msg}"
)
error_entry = {"tool": self.tool_name, "error": error_msg, "node_id": self.id}
except (json.JSONDecodeError, TypeError):
pass
# Return result with error tracking if error was found
if error_entry:
result["errors"] = [error_entry]
return result
except Exception as e:
logger.error(
f"[ToolExecutionNode] {self.id} - Tool execution failed: {e}",
exc_info=True
)
# Track error in state and return error message
from langchain_core.messages import ToolMessage
error_entry = {"tool": self.tool_name, "error": str(e), "node_id": self.id}
return {
"messages": [
ToolMessage(
content=f"Error executing tool: {str(e)}",
tool_call_id=f"{self.id}_{tool_call_id}"
)
],
"errors": [error_entry]
}

View File

@@ -1,162 +0,0 @@
import asyncio
import os
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.models.verification_models import VerificationResult
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_,
ReadState,
)
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__)
class VerificationNodeService(LLMServiceMixin):
"""验证节点服务类"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# 创建全局服务实例
verification_service = VerificationNodeService()
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
"""处理验证结果并生成输出格式"""
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
data = state.get('data', '')
# 将 VerificationItem 对象转换为字典列表
verified_data = []
if messages_deal.expansion_issue:
for item in messages_deal.expansion_issue:
if hasattr(item, 'model_dump'):
verified_data.append(item.model_dump())
elif isinstance(item, dict):
verified_data.append(item)
Verify_result = {
"status": messages_deal.split_result,
"verified_data": verified_data,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "verification",
"title": "Data Verification",
"result": messages_deal.split_result,
"reason": messages_deal.reason or "验证完成",
"query": messages_deal.query,
"verified_count": len(verified_data),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
return Verify_result
async def Verify(state: ReadState):
logger.info("=== Verify 节点开始执行 ===")
try:
content = state.get('data', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
logger.info(f"Verify: 获取历史记录完成history length={len(history)}")
retrieve = state.get("retrieve", {})
logger.info(
f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
messages = {
"Query": content,
"Expansion_issue": retrieve_expansion
}
logger.info("Verify: 开始渲染模板")
# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = VerificationResult.model_json_schema()
system_prompt = await verification_service.template_service.render_template(
template_name='split_verify_prompt.jinja2',
operation_name='split_verify_prompt',
history=history,
sentence=messages,
json_schema=json_schema
)
logger.info(f"Verify: 模板渲染完成prompt length={len(system_prompt)}")
# 使用优化的LLM服务添加超时保护
logger.info("Verify: 开始调用 LLM")
try:
# 添加 asyncio.wait_for 超时包裹,防止无限等待
# 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长)
with get_db_context() as db_session:
structured = await asyncio.wait_for(
verification_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=VerificationResult,
fallback_value={
"query": content,
"history": history if isinstance(history, list) else [],
"expansion_issue": [],
"split_result": "failed",
"reason": "验证失败或超时"
}
),
timeout=150.0 # 150秒超时
)
logger.info(f"Verify: LLM 调用完成result={structured}")
except asyncio.TimeoutError:
logger.error("Verify: LLM 调用超时150秒使用 fallback 值")
structured = VerificationResult(
query=content,
history=history if isinstance(history, list) else [],
expansion_issue=[],
split_result="failed",
reason="LLM调用超时"
)
result = await Verify_prompt(state, structured)
logger.info("=== Verify 节点执行完成 ===")
return {"verify": result}
except Exception as e:
logger.error(f"Verify 节点执行失败: {e}", exc_info=True)
# 返回失败的验证结果
return {
"verify": {
"status": "failed",
"verified_data": [],
"storage_type": state.get('storage_type', ''),
"user_rag_memory_id": state.get('user_rag_memory_id', ''),
"_intermediate": {
"type": "verification",
"title": "Data Verification",
"result": "failed",
"reason": f"验证过程出错: {str(e)}",
"query": state.get('data', ''),
"verified_count": 0,
"storage_type": state.get('storage_type', ''),
"user_rag_memory_id": state.get('user_rag_memory_id', '')
}
}
}

View File

@@ -1,67 +0,0 @@
from app.cache.memory.interest_memory import InterestMemoryCache
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.utils.write_tools import write
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
async def write_node(state: WriteState) -> WriteState:
"""
Write data to the database/file system.
Args:
state: WriteState containing messages, end_user_id, memory_config, and language
Returns:
dict: Contains 'write_result' with status and data fields
"""
messages = state.get('messages', [])
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', '')
language = state.get('language', 'zh') # 默认中文
# Convert LangChain messages to structured format expected by write()
structured_messages = []
for msg in messages:
if hasattr(msg, 'type') and hasattr(msg, 'content'):
# Map LangChain message types to role names
role = 'user' if msg.type == 'human' else 'assistant' if msg.type == 'ai' else msg.type
structured_messages.append({
"role": role,
"content": msg.content # content is now guaranteed to be a string
})
try:
result = await write(
messages=structured_messages,
end_user_id=end_user_id,
memory_config=memory_config,
language=language,
)
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
# 写入 neo4j 成功后,删除该用户的兴趣分布缓存,确保下次请求重新生成
for lang in ["zh", "en"]:
deleted = await InterestMemoryCache.delete_interest_distribution(
end_user_id=end_user_id,
language=lang,
)
if deleted:
logger.info(f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
write_result = {
"status": "success",
"data": structured_messages,
"config_id": memory_config.config_id,
"config_name": memory_config.config_name,
}
return {"write_result": write_result}
except Exception as e:
logger.error(f"Data_write failed: {e}", exc_info=True)
write_result = {
"status": "error",
"message": str(e),
}
return {"write_result": write_result}

View File

@@ -1,179 +1,469 @@
#!/usr/bin/env python3
import json
import os
import re
import time
import warnings
from contextlib import asynccontextmanager
from typing import Literal
from langchain_core.messages import HumanMessage
from langgraph.constants import START, END
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.nodes import (
ToolExecutionNode,
create_input_message,
)
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
from app.core.memory.agent.utils.llm_tools import COUNTState, ReadState
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
from app.schemas.memory_config_schema import MemoryConfig
from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
from app.db import get_db
from app.services.memory_config_service import MemoryConfigService
logger = get_agent_logger(__name__)
from app.core.memory.agent.utils.llm_tools import ReadState
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_node
from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
Split_The_Problem,
Problem_Extension,
)
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
retrieve,
)
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
Input_Summary,
Retrieve_Summary,
Summary_fails,
Summary,
)
from app.core.memory.agent.langgraph_graph.nodes.verification_nodes import Verify
from app.core.memory.agent.langgraph_graph.routing.routers import (
Split_continue,
Retrieve_continue,
Verify_continue,
)
warnings.filterwarnings("ignore", category=RuntimeWarning)
load_dotenv()
redishost=os.getenv("REDISHOST")
redisport=os.getenv('REDISPORT')
redisdb=os.getenv('REDISDB')
redispassword=os.getenv('REDISPASSWORD')
counter = COUNTState(limit=3)
# Update loop count in workflow
async def update_loop_count(state):
"""Update loop counter"""
current_count = state.get("loop_count", 0)
return {"loop_count": current_count + 1}
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
messages = state["messages"]
# Add boundary check
if not messages:
return END
counter.add(1) # Increment by 1
loop_count = counter.get_total()
logger.debug(f"[should_continue] Current loop count: {loop_count}")
last_message = messages[-1]
last_message_str = str(last_message).replace('\\', '')
status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str)
logger.debug(f"Status tools: {status_tools}")
if "success" in status_tools:
counter.reset()
return "Summary"
elif "failed" in status_tools:
if loop_count < 2: # Maximum loop count is 3
return "content_input"
else:
counter.reset()
return "Summary_fails"
else:
# Add default return value to avoid returning None
counter.reset()
return "Summary" # Default based on business requirements
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
"""
Determine routing based on search_switch value.
Args:
state: State dictionary containing search_switch
Returns:
Next node to execute
"""
# Direct dictionary access instead of regex parsing
search_switch = state.get("search_switch")
# Handle case where search_switch might be in messages
if search_switch is None and "messages" in state:
messages = state.get("messages", [])
if messages:
last_message = messages[-1]
# Try to extract from tool_calls args
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
for tool_call in last_message.tool_calls:
if isinstance(tool_call, dict) and "args" in tool_call:
search_switch = tool_call["args"].get("search_switch")
break
# Convert to string for comparison if needed
if search_switch is not None:
search_switch = str(search_switch)
if search_switch == '0':
return 'Verify'
elif search_switch == '1':
return 'Retrieve_Summary'
# Add default return value to avoid returning None
return 'Retrieve_Summary' # Default based on business logic
def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]:
"""
Determine routing based on search_switch value.
Args:
state: State dictionary containing search_switch
Returns:
Next node to execute
"""
logger.debug(f"Split_continue state: {state}")
# Direct dictionary access instead of regex parsing
search_switch = state.get("search_switch")
# Handle case where search_switch might be in messages
if search_switch is None and "messages" in state:
messages = state.get("messages", [])
if messages:
last_message = messages[-1]
# Try to extract from tool_calls args
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
for tool_call in last_message.tool_calls:
if isinstance(tool_call, dict) and "args" in tool_call:
search_switch = tool_call["args"].get("search_switch")
break
# Convert to string for comparison if needed
if search_switch is not None:
search_switch = str(search_switch)
if search_switch == '2':
return 'Input_Summary'
return 'Split_The_Problem' # Default case
class ProblemExtensionNode:
def __init__(self, tool, id, namespace, search_switch, apply_id, group_id, storage_type="", user_rag_memory_id=""):
self.tool_node = ToolNode([tool])
self.id = id
self.tool_name = tool.name if hasattr(tool, 'name') else str(tool)
self.namespace = namespace
self.search_switch = search_switch
self.apply_id = apply_id
self.group_id = group_id
self.storage_type = storage_type
self.user_rag_memory_id = user_rag_memory_id
async def __call__(self, state):
messages = state["messages"]
last_message = messages[-1] if messages else ""
logger.debug(f"ProblemExtensionNode {self.id} - Current time: {time.time()} - Message: {last_message}")
if self.tool_name == 'Input_Summary':
tool_call = re.findall("'id': '(.*?)'", str(last_message))[0]
else:
tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1]
# Try to extract actual content payload from previous tool result
raw_msg = last_message.content if hasattr(last_message, 'content') else str(last_message)
extracted_payload = None
# Capture ToolMessage content field (supports single/double quotes), avoid greedy matching
m = re.search(r"content=(?:\"|\')(.*?)(?:\"|\'),\s*name=", raw_msg, flags=re.S)
if m:
extracted_payload = m.group(1)
else:
# Fallback: use raw string directly
extracted_payload = raw_msg
# Try to parse content as JSON first
try:
content = json.loads(extracted_payload)
except Exception:
# Try to extract JSON fragment from text and parse
parsed = None
candidates = re.findall(r"[\[{].*[\]}]", extracted_payload, flags=re.S)
for cand in candidates:
try:
parsed = json.loads(cand)
break
except Exception:
continue
# If still fails, use raw string as content
content = parsed if parsed is not None else extracted_payload
# Build correct parameters based on tool name
tool_args = {}
if self.tool_name == "Verify":
# Verify tool requires context and usermessages parameters
if isinstance(content, dict):
tool_args["context"] = content
else:
tool_args["context"] = {"content": content}
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Retrieve":
# Retrieve tool requires context and usermessages parameters
if isinstance(content, dict):
tool_args["context"] = content
else:
tool_args["context"] = {"content": content}
tool_args["usermessages"] = str(tool_call)
tool_args["search_switch"] = str(self.search_switch)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Summary":
# Summary tool requires string type context parameter
if isinstance(content, dict):
# Convert dict to JSON string
tool_args["context"] = json.dumps(content, ensure_ascii=False)
else:
tool_args["context"] = str(content)
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Summary_fails":
# Summary_fails tool requires string type context parameter
if isinstance(content, dict):
# Convert dict to JSON string
tool_args["context"] = json.dumps(content, ensure_ascii=False)
else:
tool_args["context"] = str(content)
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == 'Input_Summary':
tool_args["context"] = str(last_message)
tool_args["usermessages"] = str(tool_call)
tool_args["search_switch"] = str(self.search_switch)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
tool_args["storage_type"] = getattr(self, 'storage_type', "")
tool_args["user_rag_memory_id"] = getattr(self, 'user_rag_memory_id', "")
elif self.tool_name == 'Retrieve_Summary':
# Retrieve_Summary expects dict directly, not JSON string
# content might be a JSON string, try to parse it
if isinstance(content, str):
try:
parsed_content = json.loads(content)
# Check if it has a "context" key
if isinstance(parsed_content, dict) and "context" in parsed_content:
tool_args["context"] = parsed_content["context"]
else:
tool_args["context"] = parsed_content
except json.JSONDecodeError:
# If parsing fails, wrap the string
tool_args["context"] = {"content": content}
elif isinstance(content, dict):
# Check if content has a "context" key that needs unwrapping
if "context" in content:
tool_args["context"] = content["context"]
else:
tool_args["context"] = content
else:
tool_args["context"] = {"content": str(content)}
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
else:
# Other tools use context parameter
if isinstance(content, dict):
tool_args["context"] = content
else:
tool_args["context"] = {"content": content}
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
tool_input = {
"messages": [
AIMessage(
content="",
tool_calls=[{
"name": self.tool_name,
"args": tool_args,
"id": self.id + f"{tool_call}",
}]
)
]
}
result = await self.tool_node.ainvoke(tool_input)
result_text = str(result)
return {"messages": [AIMessage(content=result_text)]}
@asynccontextmanager
async def make_read_graph():
"""创建并返回 LangGraph 工作流"""
try:
# Build workflow graph
workflow = StateGraph(ReadState)
workflow.add_node("content_input", content_input_node)
workflow.add_node("Split_The_Problem", Split_The_Problem)
workflow.add_node("Problem_Extension", Problem_Extension)
workflow.add_node("Input_Summary", Input_Summary)
# workflow.add_node("Retrieve", retrieve_nodes)
workflow.add_node("Retrieve", retrieve)
workflow.add_node("Verify", Verify)
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
workflow.add_node("Summary", Summary)
workflow.add_node("Summary_fails", Summary_fails)
# 添加边
workflow.add_edge(START, "content_input")
workflow.add_conditional_edges("content_input", Split_continue)
workflow.add_edge("Input_Summary", END)
workflow.add_edge("Split_The_Problem", "Problem_Extension")
workflow.add_edge("Problem_Extension", "Retrieve")
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
workflow.add_edge("Retrieve_Summary", END)
workflow.add_conditional_edges("Verify", Verify_continue)
workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END)
'''-----'''
# workflow.add_edge("Retrieve", END)
# 编译工作流
graph = workflow.compile()
yield graph
except Exception as e:
print(f"创建工作流失败: {e}")
raise
finally:
print("工作流创建完成")
async def main():
"""主函数 - 运行工作流"""
message = "昨天有什么好看的电影"
end_user_id = '88a459f5_text09' # 组ID
storage_type = 'neo4j' # 存储类型
search_switch = '1' # 搜索开关
user_rag_memory_id = 'wwwwwwww' # 用户RAG记忆ID
# 获取数据库会话
db_session = next(get_db())
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=17, # 改为整数
service_name="MemoryAgentService"
async def make_read_graph(namespace, tools, search_switch, apply_id, group_id, memory_config: MemoryConfig, storage_type=None, user_rag_memory_id=None):
"""
Create a read graph workflow for memory operations.
Args:
namespace: Namespace identifier
tools: MCP tools loaded from session
search_switch: Search mode switch ("0", "1", or "2")
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type (optional)
user_rag_memory_id: User RAG memory ID (optional)
"""
memory = InMemorySaver()
tool = [i.name for i in tools]
logger.info(f"Initializing read graph with tools: {tool}")
logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})")
# Extract tool functions
Split_The_Problem_ = next((t for t in tools if t.name == "Split_The_Problem"), None)
Problem_Extension_ = next((t for t in tools if t.name == "Problem_Extension"), None)
Retrieve_ = next((t for t in tools if t.name == "Retrieve"), None)
Verify_ = next((t for t in tools if t.name == "Verify"), None)
Summary_ = next((t for t in tools if t.name == "Summary"), None)
Summary_fails_ = next((t for t in tools if t.name == "Summary_fails"), None)
Retrieve_Summary_ = next((t for t in tools if t.name == "Retrieve_Summary"), None)
Input_Summary_ = next((t for t in tools if t.name == "Input_Summary"), None)
# Instantiate services
parameter_builder = ParameterBuilder()
multimodal_processor = MultimodalProcessor()
# Create nodes using new modular components
Split_The_Problem_node = ToolNode([Split_The_Problem_])
Problem_Extension_node = ToolExecutionNode(
tool=Problem_Extension_,
node_id="Problem_Extension_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
import time
start = time.time()
try:
async with make_read_graph() as graph:
config = {"configurable": {"thread_id": end_user_id}}
# 初始状态 - 包含所有必要字段
initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch,
"end_user_id": end_user_id
, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id,
"memory_config": memory_config}
# 获取节点更新信息
_intermediate_outputs = []
summary = ''
async for update_event in graph.astream(
initial_state,
stream_mode="updates",
config=config
):
for node_name, node_data in update_event.items():
print(f"处理节点: {node_name}")
Retrieve_node = ToolExecutionNode(
tool=Retrieve_,
node_id="Retrieve_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
# 处理不同Summary节点的返回结构
if 'Summary' in node_name:
if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']:
summary = node_data['InputSummary']['summary_result']
elif 'RetrieveSummary' in node_data and 'summary_result' in node_data['RetrieveSummary']:
summary = node_data['RetrieveSummary']['summary_result']
elif 'summary' in node_data and 'summary_result' in node_data['summary']:
summary = node_data['summary']['summary_result']
elif 'SummaryFails' in node_data and 'summary_result' in node_data['SummaryFails']:
summary = node_data['SummaryFails']['summary_result']
Verify_node = ToolExecutionNode(
tool=Verify_,
node_id="Verify_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
Summary_node = ToolExecutionNode(
tool=Summary_,
node_id="Summary_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
spit_data = node_data.get('spit_data', {}).get('_intermediate', None)
if spit_data and spit_data != [] and spit_data != {}:
_intermediate_outputs.append(spit_data)
Summary_fails_node = ToolExecutionNode(
tool=Summary_fails_,
node_id="Summary_fails_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
# Problem_Extension 节点
problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None)
if problem_extension and problem_extension != [] and problem_extension != {}:
_intermediate_outputs.append(problem_extension)
Retrieve_Summary_node = ToolExecutionNode(
tool=Retrieve_Summary_,
node_id="Retrieve_Summary_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
# Retrieve 节点
retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None)
if retrieve_node and retrieve_node != [] and retrieve_node != {}:
_intermediate_outputs.extend(retrieve_node)
Input_Summary_node = ToolExecutionNode(
tool=Input_Summary_,
node_id="Input_Summary_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
# Verify 节点
verify_n = node_data.get('verify', {}).get('_intermediate', None)
if verify_n and verify_n != [] and verify_n != {}:
_intermediate_outputs.append(verify_n)
async def content_input_node(state):
state_search_switch = state.get("search_switch", search_switch)
# Summary 节点
summary_n = node_data.get('summary', {}).get('_intermediate', None)
if summary_n and summary_n != [] and summary_n != {}:
_intermediate_outputs.append(summary_n)
tool_name = "Input_Summary" if state_search_switch == '2' else "Split_The_Problem"
session_prefix = "input_summary_call_id" if state_search_switch == '2' else "split_call_id"
# # 过滤掉空值
# _intermediate_outputs = [item for item in _intermediate_outputs if item and item != [] and item != {}]
#
# # 优化搜索结果
# print("=== 开始优化搜索结果 ===")
# optimized_outputs = merge_multiple_search_results(_intermediate_outputs)
# result=reorder_output_results(optimized_outputs)
# # 保存优化后的结果到文件
# with open('_intermediate_outputs_optimized.json', 'w', encoding='utf-8') as f:
# import json
# f.write(json.dumps(result, indent=4, ensure_ascii=False))
#
print(f"=== 最终摘要 ===")
print(summary)
return await create_input_message(
state=state,
tool_name=tool_name,
session_id=f"{session_prefix}_{namespace}",
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
multimodal_processor=multimodal_processor,
memory_config=memory_config,
)
except Exception as e:
import traceback
traceback.print_exc()
finally:
db_session.close()
# Build workflow graph
workflow = StateGraph(ReadState)
workflow.add_node("content_input", content_input_node)
workflow.add_node("Split_The_Problem", Split_The_Problem_node)
workflow.add_node("Problem_Extension", Problem_Extension_node)
workflow.add_node("Retrieve", Retrieve_node)
workflow.add_node("Verify", Verify_node)
workflow.add_node("Summary", Summary_node)
workflow.add_node("Summary_fails", Summary_fails_node)
workflow.add_node("Retrieve_Summary", Retrieve_Summary_node)
workflow.add_node("Input_Summary", Input_Summary_node)
end = time.time()
print(100 * 'y')
print(f"总耗时: {end - start}s")
print(100 * 'y')
# Add edges using imported routers
workflow.add_edge(START, "content_input")
workflow.add_conditional_edges("content_input", Split_continue)
workflow.add_edge("Input_Summary", END)
workflow.add_edge("Split_The_Problem", "Problem_Extension")
workflow.add_edge("Problem_Extension", "Retrieve")
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
workflow.add_edge("Retrieve_Summary", END)
workflow.add_conditional_edges("Verify", Verify_continue)
workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END)
if __name__ == "__main__":
import asyncio
asyncio.run(main())
graph = workflow.compile(checkpointer=memory)
yield graph

View File

@@ -0,0 +1,13 @@
"""LangGraph routing logic."""
from app.core.memory.agent.langgraph_graph.routing.routers import (
Verify_continue,
Retrieve_continue,
Split_continue,
)
__all__ = [
"Verify_continue",
"Retrieve_continue",
"Split_continue",
]

View File

@@ -1,61 +1,123 @@
"""
Routing functions for LangGraph conditional edges.
This module provides routing functions that determine the next node to execute
based on state values. All functions return Literal types for type safety.
"""
import logging
import re
from typing import Literal
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.state.extractors import extract_search_switch
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
logger = logging.getLogger(__name__)
logger = get_agent_logger(__name__)
# Global counter for Verify routing
counter = COUNTState(limit=3)
def Split_continue(state:ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
"""
Determine routing based on search_switch value.
Args:
state: State dictionary containing search_switch
Returns:
Next node to execute
"""
logger.debug(f"Split_continue state: {state}")
search_switch = state.get('search_switch', '')
if search_switch is not None:
search_switch = str(search_switch)
if search_switch == '2':
return 'Input_Summary'
return 'Split_The_Problem' # 默认情况
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
"""
Determine routing based on search_switch value.
Args:
state: State dictionary containing search_switch
Returns:
Next node to execute
"""
search_switch = state.get('search_switch', '')
if search_switch is not None:
search_switch = str(search_switch)
if search_switch == '0':
return 'Verify'
elif search_switch == '1':
return 'Retrieve_Summary'
return 'Retrieve_Summary' # Default based on business logic
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
status=state.get('verify', '')['status']
# loop_count = counter.get_total()
if "success" in status:
# counter.reset()
"""
Determine routing after Verify node based on verification result.
This function checks the verification result in the last message and routes to:
- Summary: if verification succeeded
- content_input: if verification failed and retry limit not reached
- Summary_fails: if verification failed and retry limit reached
Args:
state: LangGraph state containing messages
Returns:
Next node name as Literal type
"""
messages = state.get("messages", [])
# Boundary check
if not messages:
logger.warning("[Verify_continue] No messages in state, defaulting to Summary")
counter.reset()
return "Summary"
elif "failed" in status:
# if loop_count < 2: # Maximum loop count is 3
# return "content_input"
# else:
# counter.reset()
return "Summary_fails"
# Increment counter
counter.add(1)
loop_count = counter.get_total()
logger.debug(f"[Verify_continue] Current loop count: {loop_count}")
# Extract verification result from last message
last_message = messages[-1]
last_message_str = str(last_message).replace('\\', '')
status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str)
logger.debug(f"[Verify_continue] Status tools: {status_tools}")
# Route based on verification result
if "success" in status_tools:
counter.reset()
return "Summary"
elif "failed" in status_tools:
if loop_count < 2: # Max retry count is 2
return "content_input"
else:
counter.reset()
return "Summary_fails"
else:
# Add default return value to avoid returning None
# counter.reset()
return "Summary" # Default based on business requirements
# Default to Summary if status is unclear
counter.reset()
return "Summary"
def Retrieve_continue(state: dict) -> Literal["Verify", "Retrieve_Summary"]:
"""
Determine routing after Retrieve node based on search_switch value.
This function routes based on the search_switch parameter:
- search_switch == '0': Route to Verify (verification needed)
- search_switch == '1': Route to Retrieve_Summary (direct summary)
Args:
state: LangGraph state dictionary
Returns:
Next node name as Literal type
"""
search_switch = extract_search_switch(state)
logger.debug(f"[Retrieve_continue] search_switch: {search_switch}")
if search_switch == '0':
return 'Verify'
elif search_switch == '1':
return 'Retrieve_Summary'
# Default to Retrieve_Summary
logger.debug("[Retrieve_continue] No valid search_switch, defaulting to Retrieve_Summary")
return 'Retrieve_Summary'
def Split_continue(state: dict) -> Literal["Split_The_Problem", "Input_Summary"]:
"""
Determine routing after content_input node based on search_switch value.
This function routes based on the search_switch parameter:
- search_switch == '2': Route to Input_Summary (direct input summary)
- Otherwise: Route to Split_The_Problem (problem decomposition)
Args:
state: LangGraph state dictionary
Returns:
Next node name as Literal type
"""
logger.debug(f"[Split_continue] state keys: {state.keys()}")
search_switch = extract_search_switch(state)
logger.debug(f"[Split_continue] search_switch: {search_switch}")
if search_switch == '2':
return 'Input_Summary'
# Default to Split_The_Problem
return 'Split_The_Problem'

View File

@@ -1,238 +0,0 @@
import json
import os
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
from app.core.memory.agent.langgraph_graph.write_graph import make_write_graph, long_term_storage
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.redis_tool import write_store
from app.core.memory.agent.utils.redis_tool import count_store
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context, get_db
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id
logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
# RAG 模式:组合消息为字符串格式(保持原有逻辑)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
async def write(storage_type, end_user_id, user_message, ai_message, user_rag_memory_id, actual_end_user_id,
actual_config_id, long_term_messages=[]):
"""
写入记忆(支持结构化消息)
Args:
storage_type: 存储类型 (neo4j/rag)
end_user_id: 终端用户ID
user_message: 用户消息内容
ai_message: AI 回复内容
user_rag_memory_id: RAG 记忆ID
actual_end_user_id: 实际用户ID
actual_config_id: 配置ID
逻辑说明:
- RAG 模式:组合 user_message 和 ai_message 为字符串格式,保持原有逻辑不变
- Neo4j 模式:使用结构化消息列表
1. 如果 user_message 和 ai_message 都不为空:创建配对消息 [user, assistant]
2. 如果只有 user_message创建单条用户消息 [user](用于历史记忆场景)
3. 每条消息会被转换为独立的 Chunk保留 speaker 字段
"""
db = next(get_db())
try:
actual_config_id = resolve_config_id(actual_config_id, db)
# Neo4j 模式:使用结构化消息列表
structured_messages = []
# 始终添加用户消息(如果不为空)
if isinstance(user_message, str) and user_message.strip() != "":
structured_messages.append({"role": "user", "content": user_message})
# 只有当 AI 回复不为空时才添加 assistant 消息
if isinstance(ai_message, str) and ai_message.strip() != "":
structured_messages.append({"role": "assistant", "content": ai_message})
# 如果提供了 long_term_messages使用它替代 structured_messages
if long_term_messages and isinstance(long_term_messages, list):
structured_messages = long_term_messages
elif long_term_messages and isinstance(long_term_messages, str):
# 如果是 JSON 字符串,先解析
try:
structured_messages = json.loads(long_term_messages)
except json.JSONDecodeError:
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
# 如果没有消息,直接返回
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
logger.info(
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # end_user_id: 用户ID
structured_messages, # message: JSON 字符串格式的消息列表
str(actual_config_id), # config_id: 配置ID字符串
storage_type, # storage_type: "neo4j"
user_rag_memory_id or "" # user_rag_memory_id: RAG记忆IDNeo4j模式下不使用
)
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
finally:
db.close()
async def term_memory_save(long_term_messages,actual_config_id,end_user_id,type,scope):
with get_db_context() as db_session:
repo = LongTermMemoryRepository(db_session)
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
if type==AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
if len(chunk_data)==scope:
repo.upsert(end_user_id, chunk_data)
logger.info(f'---------写入短长期-----------')
else:
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
long_messages = await messages_parse(long_time_data)
repo.upsert(end_user_id, long_messages)
logger.info(f'写入短长期:')
'''根据窗口'''
async def window_dialogue(end_user_id,langchain_messages,memory_config,scope):
'''
根据窗口获取redis数据,写入neo4j
Args:
end_user_id: 终端用户ID
memory_config: 内存配置对象
langchain_messages原始数据LIST
scope窗口大小
'''
scope=scope
is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
redis_messages = count_store.get_sessions_count(end_user_id)[1]
if is_end_user_id and int(is_end_user_id) != int(scope):
is_end_user_id += 1
langchain_messages += redis_messages
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
elif int(is_end_user_id) == int(scope):
logger.info('写入长期记忆NEO4J')
formatted_messages = (redis_messages)
# 获取 config_id如果 memory_config 是对象,提取 config_id否则直接使用
if hasattr(memory_config, 'config_id'):
config_id = memory_config.config_id
else:
config_id = memory_config
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
config_id, formatted_messages)
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
"""根据时间"""
async def memory_long_term_storage(end_user_id,memory_config,time):
'''
根据时间获取redis数据,写入neo4j
Args:
end_user_id: 终端用户ID
memory_config: 内存配置对象
'''
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
format_messages = (long_time_data)
messages=[]
memory_config=memory_config.config_id
for i in format_messages:
message=json.loads(i['Query'])
messages+= message
if format_messages!=[]:
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
memory_config, messages)
'''聚合判断'''
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
"""
聚合判断函数:判断输入句子和历史消息是否描述同一事件
Args:
end_user_id: 终端用户ID
ori_messages: 原始消息列表,格式如 [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
memory_config: 内存配置对象
"""
try:
# 1. 获取历史会话数据(使用新方法)
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
history = await format_parsing(result)
if not result:
history = []
else:
history = await format_parsing(result)
json_schema = WriteAggregateModel.model_json_schema()
template_service = TemplateService(template_root)
system_prompt = await template_service.render_template(
template_name='write_aggregate_judgment.jinja2',
operation_name='aggregate_judgment',
history=history,
sentence=ori_messages,
json_schema=json_schema
)
with get_db_context() as db_session:
factory = MemoryClientFactory(db_session)
llm_client = factory.get_llm_client(memory_config.llm_model_id)
messages = [
{
"role": "user",
"content": system_prompt
}
]
structured = await llm_client.response_structured(
messages=messages,
response_model=WriteAggregateModel
)
output_value = structured.output
if isinstance(output_value, list):
output_value = [
{"role": msg.role, "content": msg.content}
for msg in output_value
]
result_dict = {
"is_same_event": structured.is_same_event,
"output": output_value
}
if not structured.is_same_event:
logger.info(result_dict)
await write("neo4j", end_user_id, "", "", None, end_user_id,
memory_config.config_id, output_value)
return result_dict
except Exception as e:
print(f"[aggregate_judgment] 发生错误: {e}")
import traceback
traceback.print_exc()
return {
"is_same_event": False,
"output": ori_messages,
"messages": ori_messages,
"history": history if 'history' in locals() else [],
"error": str(e)
}

View File

@@ -0,0 +1,13 @@
"""LangGraph state management utilities."""
from app.core.memory.agent.langgraph_graph.state.extractors import (
extract_search_switch,
extract_tool_call_id,
extract_content_payload,
)
__all__ = [
"extract_search_switch",
"extract_tool_call_id",
"extract_content_payload",
]

View File

@@ -0,0 +1,179 @@
"""
State extraction utilities for type-safe access to LangGraph state values.
This module provides utility functions for extracting values from LangGraph state
dictionaries with proper error handling and sensible defaults.
"""
import json
import logging
from typing import Any, Optional
logger = logging.getLogger(__name__)
def extract_search_switch(state: dict) -> Optional[str]:
"""
Extract search_switch from state or messages.
"""
search_switch = state.get("search_switch")
if search_switch is not None:
return str(search_switch)
# Try to extract from messages
messages = state.get("messages", [])
if not messages:
return None
# 从最新的消息开始查找
for message in reversed(messages):
# 尝试从 tool_calls 中提取
if hasattr(message, "tool_calls") and message.tool_calls:
for tool_call in message.tool_calls:
if isinstance(tool_call, dict):
# 从 tool_call 的 args 中提取
if "args" in tool_call and isinstance(tool_call["args"], dict):
search_switch = tool_call["args"].get("search_switch")
if search_switch is not None:
return str(search_switch)
# 直接从 tool_call 中提取
search_switch = tool_call.get("search_switch")
if search_switch is not None:
return str(search_switch)
# 尝试从 content 中提取(如果是 JSON 格式)
if hasattr(message, "content"):
try:
import json
if isinstance(message.content, str):
content_data = json.loads(message.content)
if isinstance(content_data, dict):
search_switch = content_data.get("search_switch")
if search_switch is not None:
return str(search_switch)
except (json.JSONDecodeError, ValueError):
pass
return None
def extract_tool_call_id(message: Any) -> str:
"""
Extract tool call ID from message using structured attributes.
This function extracts the tool call ID from a message object, handling both
direct attribute access and tool_calls list structures.
Args:
message: Message object (typically ToolMessage or AIMessage)
Returns:
Tool call ID as string
Raises:
ValueError: If tool call ID cannot be extracted
Examples:
>>> message = ToolMessage(content="...", tool_call_id="call_123")
>>> extract_tool_call_id(message)
'call_123'
"""
# Try direct attribute access for ToolMessage
if hasattr(message, "tool_call_id"):
tool_call_id = message.tool_call_id
if tool_call_id:
return str(tool_call_id)
# Try extracting from tool_calls list for AIMessage
if hasattr(message, "tool_calls") and message.tool_calls:
tool_call = message.tool_calls[0]
if isinstance(tool_call, dict) and "id" in tool_call:
return str(tool_call["id"])
# Try extracting from id attribute
if hasattr(message, "id"):
message_id = message.id
if message_id:
return str(message_id)
# If all else fails, raise an error
raise ValueError(f"Could not extract tool call ID from message: {type(message)}")
def extract_content_payload(message: Any) -> Any:
"""
Extract content payload from ToolMessage, parsing JSON if needed.
This function extracts the content from a message and attempts to parse it as JSON
if it appears to be a JSON string. It handles various message formats and provides
sensible fallbacks.
Args:
message: Message object (typically ToolMessage)
Returns:
Parsed content (dict, list, or str)
Examples:
>>> message = ToolMessage(content='{"key": "value"}')
>>> extract_content_payload(message)
{'key': 'value'}
>>> message = ToolMessage(content='plain text')
>>> extract_content_payload(message)
'plain text'
"""
# Extract raw content
# For ToolMessages (responses from tools), extract from content
if hasattr(message, "content"):
raw_content = message.content
logger.info(f"extract_content_payload: raw_content type={type(raw_content)}, value={str(raw_content)[:500]}")
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
if isinstance(raw_content, list):
for block in raw_content:
if isinstance(block, dict) and block.get('type') == 'text':
raw_content = block.get('text', '')
logger.info(f"extract_content_payload: extracted text from MCP format: {str(raw_content)[:300]}")
break
# If content is empty and this is an AIMessage with tool_calls,
# extract from args (this handles the initial tool call from content_input)
if not raw_content and hasattr(message, "tool_calls") and message.tool_calls:
tool_call = message.tool_calls[0]
if isinstance(tool_call, dict) and "args" in tool_call:
return tool_call["args"]
else:
raw_content = str(message)
# If content is already a dict or list, return it directly
if isinstance(raw_content, (dict, list)):
logger.info(f"extract_content_payload: returning raw dict/list with keys={list(raw_content.keys()) if isinstance(raw_content, dict) else 'list'}")
return raw_content
# Try to parse as JSON
if isinstance(raw_content, str):
# First, try direct JSON parsing
try:
parsed = json.loads(raw_content)
logger.info(f"extract_content_payload: parsed JSON, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}")
return parsed
except (json.JSONDecodeError, ValueError):
pass
# If that fails, try to extract JSON from the string
# This handles cases where the content is embedded in a larger string
import re
json_candidates = re.findall(r'[\[{].*[\]}]', raw_content, flags=re.DOTALL)
for candidate in json_candidates:
try:
parsed = json.loads(candidate)
logger.info(f"extract_content_payload: parsed JSON from candidate, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}")
return parsed
except (json.JSONDecodeError, ValueError):
continue
# If all parsing attempts fail, return the raw content
logger.info(f"extract_content_payload: returning raw content (parsing failed)")
return raw_content

View File

@@ -1,321 +0,0 @@
import asyncio
import json
from datetime import datetime, timedelta
from langchain.tools import tool
from pydantic import BaseModel, Field
from app.core.memory.src.search import (
search_by_temporal,
search_by_keyword_temporal,
)
def extract_tool_message_content(response):
"""从agent响应中提取ToolMessage内容和工具名称"""
messages = response.get('messages', [])
for message in messages:
if hasattr(message, 'tool_call_id') and hasattr(message, 'content'):
# 这是一个ToolMessage
tool_content = message.content
tool_name = None
# 尝试获取工具名称
if hasattr(message, 'name'):
tool_name = message.name
elif hasattr(message, 'tool_name'):
tool_name = message.tool_name
try:
# 解析JSON内容
parsed_content = json.loads(tool_content)
return {
'tool_name': tool_name,
'content': parsed_content
}
except json.JSONDecodeError:
# 如果不是JSON格式直接返回内容
return {
'tool_name': tool_name,
'content': tool_content
}
return None
class TimeRetrievalInput(BaseModel):
"""时间检索工具的输入模式"""
context: str = Field(description="用户输入的查询内容")
end_user_id: str = Field(default="88a459f5_text09", description="组ID用于过滤搜索结果")
def create_time_retrieval_tool(end_user_id: str):
"""
创建一个带有特定end_user_id的TimeRetrieval工具同步版本用于按时间范围搜索语句(Statements)
"""
def clean_temporal_result_fields(data):
"""
清理时间搜索结果中不需要的字段,并修改结构
Args:
data: 要清理的数据
Returns:
清理后的数据
"""
# 需要过滤的字段列表
fields_to_remove = {
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
'valid_at', 'invalid_at', 'statement_ids'
}
if isinstance(data, dict):
cleaned = {}
for key, value in data.items():
if key == 'statements' and isinstance(value, dict) and 'statements' in value:
# 将 statements: {"statements": [...]} 改为 time_search: {"statements": [...]}
cleaned_value = clean_temporal_result_fields(value)
# 进一步将内部的 statements 改为 time_search
if 'statements' in cleaned_value:
cleaned['results'] = {
'time_search': cleaned_value['statements']
}
else:
cleaned['results'] = cleaned_value
elif key not in fields_to_remove:
cleaned[key] = clean_temporal_result_fields(value)
return cleaned
elif isinstance(data, list):
return [clean_temporal_result_fields(item) for item in data]
else:
return data
@tool
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None, end_user_id_param: str = None, clean_output: bool = True) -> str:
"""
优化的时间检索工具,只结合时间范围搜索(同步版本),自动过滤不需要的元数据字段
显式接收参数:
- context: 查询上下文内容
- start_date: 开始时间可选格式YYYY-MM-DD
- end_date: 结束时间可选格式YYYY-MM-DD
- end_user_id_param: 组ID可选用于覆盖默认组ID
- clean_output: 是否清理输出中的元数据字段
-end_date 需要根据用户的描述获取结束的时间输出格式用strftime("%Y-%m-%d")
"""
async def _async_search():
# 使用传入的参数或默认值
actual_end_user_id = end_user_id_param or end_user_id
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
# 基本时间搜索
results = await search_by_temporal(
end_user_id=actual_end_user_id,
start_date=actual_start_date,
end_date=actual_end_date,
limit=10
)
# 清理结果中不需要的字段
if clean_output:
cleaned_results = clean_temporal_result_fields(results)
else:
cleaned_results = results
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
return asyncio.run(_async_search())
@tool
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None, clean_output: bool = True) -> str:
"""
优化的关键词时间检索工具,结合关键词和时间范围搜索(同步版本),自动过滤不需要的元数据字段
显式接收参数:
- context: 查询内容
- days_back: 向前搜索的天数默认7天
- start_date: 开始时间可选格式YYYY-MM-DD
- end_date: 结束时间可选格式YYYY-MM-DD
- clean_output: 是否清理输出中的元数据字段
- end_date 需要根据用户的描述获取结束的时间输出格式用strftime("%Y-%m-%d")
"""
async def _async_search():
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
# 关键词时间搜索
results = await search_by_keyword_temporal(
query_text=context,
end_user_id=end_user_id,
start_date=actual_start_date,
end_date=actual_end_date,
limit=15
)
# 清理结果中不需要的字段
if clean_output:
cleaned_results = clean_temporal_result_fields(results)
else:
cleaned_results = results
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
return asyncio.run(_async_search())
return TimeRetrievalWithGroupId
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
"""
创建混合检索工具使用run_hybrid_search进行混合检索优化输出格式并过滤不需要的字段
Args:
memory_config: 内存配置对象
**search_params: 搜索参数包含end_user_id, limit, include等
"""
def clean_result_fields(data):
"""
递归清理结果中不需要的字段
Args:
data: 要清理的数据(可能是字典、列表或其他类型)
Returns:
清理后的数据
"""
# 需要过滤的字段列表
# TODO: fact_summary 功能暂时禁用,待后续开发完善后启用
fields_to_remove = {
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'id', 'apply_id',
'user_id', 'statement_ids', 'updated_at',"chunk_ids" ,"fact_summary"
}
if isinstance(data, dict):
# 对字典进行清理
cleaned = {}
for key, value in data.items():
if key not in fields_to_remove:
cleaned[key] = clean_result_fields(value) # 递归清理嵌套数据
return cleaned
elif isinstance(data, list):
# 对列表中的每个元素进行清理
return [clean_result_fields(item) for item in data]
else:
# 其他类型直接返回
return data
@tool
async def HybridSearch(
context: str,
search_type: str = "hybrid",
limit: int = 10,
end_user_id: str = None,
rerank_alpha: float = 0.6,
use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False,
clean_output: bool = True # 新增:是否清理输出字段
) -> str:
"""
优化的混合检索工具,支持关键词、向量和混合搜索,自动过滤不需要的元数据字段
Args:
context: 查询内容
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
limit: 结果数量限制
end_user_id: 组ID用于过滤搜索结果
rerank_alpha: 重排序权重参数
use_forgetting_rerank: 是否使用遗忘重排序
use_llm_rerank: 是否使用LLM重排序
clean_output: 是否清理输出中的元数据字段
"""
try:
# 导入run_hybrid_search函数
from app.core.memory.src.search import run_hybrid_search
# 合并参数,优先使用传入的参数
final_params = {
"query_text": context,
"search_type": search_type,
"end_user_id": end_user_id or search_params.get("end_user_id"),
"limit": limit or search_params.get("limit", 10),
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities"]),
"output_path": None, # 不保存到文件
"memory_config": memory_config,
"rerank_alpha": rerank_alpha,
"use_forgetting_rerank": use_forgetting_rerank,
"use_llm_rerank": use_llm_rerank
}
# 执行混合检索
raw_results = await run_hybrid_search(**final_params)
# 清理结果中不需要的字段
if clean_output:
cleaned_results = clean_result_fields(raw_results)
else:
cleaned_results = raw_results
# 格式化返回结果
formatted_results = {
"search_query": context,
"search_type": search_type,
"results": cleaned_results
}
return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str)
except Exception as e:
error_result = {
"error": f"混合检索失败: {str(e)}",
"search_query": context,
"search_type": search_type,
"timestamp": datetime.now().isoformat()
}
return json.dumps(error_result, ensure_ascii=False, indent=2)
return HybridSearch
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
"""
创建同步版本的混合检索工具,优化输出格式并过滤不需要的字段
Args:
memory_config: 内存配置对象
**search_params: 搜索参数
"""
@tool
def HybridSearchSync(
context: str,
search_type: str = "hybrid",
limit: int = 10,
end_user_id: str = None,
clean_output: bool = True
) -> str:
"""
优化的混合检索工具(同步版本),自动过滤不需要的元数据字段
Args:
context: 查询内容
search_type: 搜索类型 ('keyword', 'embedding', 'hybrid')
limit: 结果数量限制
end_user_id: 组ID用于过滤搜索结果
clean_output: 是否清理输出中的元数据字段
"""
async def _async_search():
# 创建异步工具并执行
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
return await async_tool.ainvoke({
"context": context,
"search_type": search_type,
"limit": limit,
"end_user_id": end_user_id,
"clean_output": clean_output
})
return asyncio.run(_async_search())
return HybridSearchSync

View File

@@ -1,72 +0,0 @@
import json
from langchain_core.messages import HumanMessage, AIMessage
async def format_parsing(messages: list,type:str='string'):
"""
格式化解析消息列表
Args:
messages: 消息列表
type: 返回类型 ('string''dict')
Returns:
格式化后的消息列表
"""
result = []
user=[]
ai=[]
for message in messages:
hstory_messages = message['messages']
for history_messag in hstory_messages.strip().splitlines():
history_messag = json.loads(history_messag)
for content in history_messag:
role = content['role']
content = content['content']
if type == "string":
if role == 'human' or role=="user":
content = '用户:' + content
else:
content = 'AI:' + content
result.append(content)
if type == "dict" :
if role == 'human' or role=="user":
user.append( content)
else:
ai.append(content)
if type == "dict":
for key,values in zip(user,ai):
result.append({key:values})
return result
async def messages_parse(messages: list | dict):
user=[]
ai=[]
database=[]
for message in messages:
Query = message['Query']
Query = json.loads(Query)
for data in Query:
role = data['role']
if role == "human":
user.append(data['content'])
if role == "ai":
ai.append(data['content'])
for key, values in zip(user, ai):
database.append({key, values})
return database
async def agent_chat_messages(user_content,ai_content):
messages = [
{
"role": "user",
"content": f"{user_content}"
},
{
"role": "assistant",
"content": f"{ai_content}"
}
]
return messages

View File

@@ -3,18 +3,17 @@ import json
import sys
import warnings
from contextlib import asynccontextmanager
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from app.db import get_db, get_db_context
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_config_service import MemoryConfigService
from app.schemas.memory_config_schema import MemoryConfig
from langchain_core.messages import AIMessage
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
if sys.platform.startswith("win"):
@@ -22,83 +21,60 @@ if sys.platform.startswith("win"):
@asynccontextmanager
async def make_write_graph():
async def make_write_graph(user_id, tools, apply_id, group_id, memory_config: MemoryConfig):
"""
Create a write graph workflow for memory operations.
Args:
user_id: User identifier
tools: MCP tools loaded from session
apply_id: Application identifier
end_user_id: Group identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
"""
logger.info("Loading MCP tools: %s", [t.name for t in tools])
logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})")
data_write_tool = next((t for t in tools if t.name == "Data_write"), None)
if not data_write_tool:
logger.error("Data_write tool not found", exc_info=True)
raise ValueError("Data_write tool not found")
write_node = ToolNode([data_write_tool])
async def call_model(state):
messages = state["messages"]
last_message = messages[-1]
content = last_message[1] if isinstance(last_message, tuple) else last_message.content
# Call Data_write directly with memory_config
write_params = {
"content": content,
"apply_id": apply_id,
"group_id": group_id,
"user_id": user_id,
"memory_config": memory_config,
}
logger.debug(f"Passing memory_config to Data_write: {memory_config.config_id}")
write_result = await data_write_tool.ainvoke(write_params)
if isinstance(write_result, dict):
result_content = write_result.get("data", str(write_result))
else:
result_content = str(write_result)
logger.info("Write content: %s", result_content)
return {"messages": [AIMessage(content=result_content)]}
workflow = StateGraph(WriteState)
workflow.add_node("content_input", call_model)
workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "save_neo4j")
workflow.add_edge(START, "content_input")
workflow.add_edge("content_input", "save_neo4j")
workflow.add_edge("save_neo4j", END)
graph = workflow.compile()
yield graph
async def long_term_storage(long_term_type:str="chunk",langchain_messages:list=[],memory_config:str='',end_user_id:str='',scope:int=6):
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue,aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, (langchain_messages))
# 获取数据库会话
with get_db_context() as db_session:
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=memory_config, # 改为整数
service_name="MemoryAgentService"
)
if long_term_type=='chunk':
'''方案一:对话窗口6轮对话'''
await window_dialogue(end_user_id,langchain_messages,memory_config,scope)
if long_term_type=='time':
"""时间"""
await memory_long_term_storage(end_user_id, memory_config,5)
if long_term_type=='aggregate':
"""方案三:聚合判断"""
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
async def write_long_term(storage_type,end_user_id,message_chat,aimessages,user_rag_memory_id,actual_config_id):
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
else:
# AI 回复写入(用户消息和 AI 回复配对,一次性写入完整对话)
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
long_term_messages = await agent_chat_messages(message_chat, aimessages)
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
# async def main():
# """主函数 - 运行工作流"""
# langchain_messages = [
# {
# "role": "user",
# "content": "今天周五去爬山"
# },
# {
# "role": "assistant",
# "content": "好耶"
# }
#
# ]
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
#
#
#
# if __name__ == "__main__":
# import asyncio
# asyncio.run(main())

View File

@@ -0,0 +1,28 @@
"""
MCP Server package for memory agent.
This package provides the FastMCP server implementation with context-based
dependency injection for tool functions.
Package structure:
- server: FastMCP server initialization and context setup
- tools: MCP tool implementations
- models: Pydantic response models
- services: Business logic services
"""
# from app.core.memory.agent.mcp_server.server import (
# mcp,
# initialize_context,
# main,
# get_context_resource
# )
# # Import tools to register them (but don't export them)
# from app.core.memory.agent.mcp_server import tools
# __all__ = [
# 'mcp',
# 'initialize_context',
# 'main',
# 'get_context_resource',
# ]

View File

@@ -0,0 +1,11 @@
"""
MCP Server Instance
This module contains the FastMCP server instance that is shared across all modules.
It's in a separate file to avoid circular import issues.
"""
from mcp.server.fastmcp import FastMCP
# Initialize FastMCP server instance
# This instance is shared across all tool modules
mcp = FastMCP('data_flow')

View File

@@ -0,0 +1,14 @@
"""Pydantic models for verification operations."""
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
class VerificationResult(BaseModel):
"""Result model for verification operation."""
query: str
expansion_issue: List[Dict[str, Any]]
split_result: str
reason: Optional[str] = None
history: List[Dict[str, Any]] = Field(default_factory=list)

View File

@@ -0,0 +1,159 @@
"""
MCP Server initialization with FastMCP context setup.
This module initializes the FastMCP server and registers shared resources
in the context for dependency injection into tool functions.
"""
import os
import sys
from app.core.config import settings
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.services.search_service import SearchService
from app.core.memory.agent.mcp_server.services.session_service import SessionService
from app.core.memory.agent.mcp_server.services.template_service import TemplateService
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.redis_tool import store
logger = get_agent_logger(__name__)
def get_context_resource(ctx, resource_name: str):
"""
Helper function to retrieve a resource from the FastMCP context.
Args:
ctx: FastMCP Context object (passed to tool functions)
resource_name: Name of the resource to retrieve
Returns:
The requested resource
Raises:
AttributeError: If the resource doesn't exist
Example:
@mcp.tool()
async def my_tool(ctx: Context):
template_service = get_context_resource(ctx, 'template_service')
llm_client = get_context_resource(ctx, 'llm_client')
"""
if not hasattr(ctx, 'fastmcp') or ctx.fastmcp is None:
raise RuntimeError("Context does not have fastmcp attribute")
if not hasattr(ctx.fastmcp, resource_name):
raise AttributeError(
f"Resource '{resource_name}' not found in context. "
f"Available resources: {[k for k in dir(ctx.fastmcp) if not k.startswith('_')]}"
)
return getattr(ctx.fastmcp, resource_name)
def initialize_context():
"""
Initialize and register shared resources in FastMCP context.
This function sets up all shared resources that will be available
to tool functions via dependency injection through the context parameter.
Resources are stored as attributes on the FastMCP instance and can be
accessed via ctx.fastmcp in tool functions.
Resources registered:
- session_store: RedisSessionStore for session management
- llm_client: LLM client for structured API calls
- app_settings: Application settings (renamed to avoid conflict with FastMCP settings)
- template_service: Service for template rendering
- search_service: Service for hybrid search
- session_service: Service for session operations
"""
try:
# Register Redis session store
logger.info("Registering session_store in context")
mcp.session_store = store
# Note: LLM client is NOT loaded at server startup
# It should be loaded dynamically when needed, with config_id passed explicitly
# to make_write_graph or make_read_graph functions
logger.info("LLM client will be loaded dynamically with config_id when needed")
mcp.llm_client = None # Placeholder - actual client loaded per-request with config_id
# Register application settings (renamed to avoid conflict with FastMCP's settings)
logger.info("Registering app_settings in context")
mcp.app_settings = settings
# Register template service
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
# logger.info(f"Registering template_service in context with root: {template_root}")
template_service = TemplateService(template_root)
mcp.template_service = template_service
# Register search service
# logger.info("Registering search_service in context")
search_service = SearchService()
mcp.search_service = search_service
# Register session service
# logger.info("Registering session_service in context")
session_service = SessionService(store)
mcp.session_service = session_service
# logger.info("All context resources registered successfully")
except Exception as e:
logger.error(f"Failed to initialize context: {e}", exc_info=True)
raise
def main():
"""
Main entry point for the MCP server.
Initializes context and starts the server with SSE transport.
"""
try:
logger.info("Starting MCP server initialization")
# Initialize context resources
initialize_context()
# Import and register tools (imports trigger tool registration)
from app.core.memory.agent.mcp_server.tools import ( # noqa: F401
data_tools,
problem_tools,
retrieval_tools,
summary_tools,
verification_tools,
)
# Tools are registered via imports above
# Get MCP port from environment (default: 8081)
mcp_port = int(os.getenv("MCP_PORT", "8081"))
logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport")
# Configure DNS rebinding protection for Docker container compatibility
from mcp.server.fastmcp.server import TransportSecuritySettings
# Disable DNS rebinding protection to allow Docker container hostnames
# This allows containers to connect using service names like 'mcp-server'
mcp.settings.transport_security = TransportSecuritySettings(
enable_dns_rebinding_protection=False,
)
logger.info("DNS rebinding protection: disabled for Docker container compatibility")
# logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport")
# Run the server with SSE transport for HTTP connections
import uvicorn
app = mcp.sse_app()
uvicorn.run(app, host=settings.SERVER_IP, port=mcp_port, log_level="info")
except Exception as e:
logger.error(f"Failed to start MCP server: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -4,19 +4,22 @@ Parameter Builder for constructing tool call arguments.
This service provides tool-specific parameter transformation logic
to build correct arguments for each tool type.
"""
from typing import Any, Dict, Optional
from app.core.logging_config import get_agent_logger
from app.schemas.memory_config_schema import MemoryConfig
logger = get_agent_logger(__name__)
class ParameterBuilder:
"""Service for building tool call arguments based on tool type."""
def __init__(self):
"""Initialize the parameter builder."""
logger.info("ParameterBuilder initialized")
def build_tool_args(
self,
tool_name: str,
@@ -24,9 +27,10 @@ class ParameterBuilder:
tool_call_id: str,
search_switch: str,
apply_id: str,
end_user_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None
user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]:
"""
Build tool arguments based on tool type.
@@ -44,7 +48,8 @@ class ParameterBuilder:
tool_call_id: Extracted tool call identifier
search_switch: Search routing parameter
apply_id: Application identifier
end_user_id: Group identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
@@ -55,18 +60,19 @@ class ParameterBuilder:
base_args = {
"usermessages": tool_call_id,
"apply_id": apply_id,
"end_user_id": end_user_id
"group_id": group_id,
"memory_config": memory_config,
}
# Always add storage_type and user_rag_memory_id (with defaults if None)
base_args["storage_type"] = storage_type if storage_type is not None else ""
base_args["user_rag_memory_id"] = user_rag_memory_id if user_rag_memory_id is not None else ""
# Tool-specific argument construction
if tool_name in ["Verify","Summary", "Summary_fails",'Retrieve_Summary']:
# Verify expects dict context
if tool_name in ["Verify", "Summary", "Summary_fails", "Retrieve_Summary", "Problem_Extension"]:
# These tools expect dict context
return {
"context": content if isinstance(content, dict) else {},
"context": content if isinstance(content, dict) else {"content": content},
**base_args
}

View File

@@ -4,21 +4,31 @@ Search Service for executing hybrid search and processing results.
This service provides clean search result processing with content extraction
and deduplication.
"""
from typing import List, Tuple, Optional
from typing import TYPE_CHECKING, List, Optional, Tuple
from app.core.logging_config import get_agent_logger
from app.core.memory.src.search import run_hybrid_search
from app.core.memory.utils.data.text_utils import escape_lucene_query
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
logger = get_agent_logger(__name__)
class SearchService:
"""Service for executing hybrid search and processing results."""
def __init__(self):
"""Initialize the search service."""
def __init__(self, memory_config: "MemoryConfig" = None):
"""
Initialize the search service.
Args:
memory_config: Optional MemoryConfig for embedding model configuration.
If not provided, must be passed to execute_hybrid_search.
"""
self.memory_config = memory_config
logger.info("SearchService initialized")
def extract_content_from_result(self, result: dict) -> str:
@@ -91,51 +101,61 @@ class SearchService:
async def execute_hybrid_search(
self,
end_user_id: str,
group_id: str,
question: str,
limit: int = 5,
limit: int = 15,
search_type: str = "hybrid",
include: Optional[List[str]] = None,
rerank_alpha: float = 0.4,
rerank_alpha: float = 0.6,
activation_boost_factor: float = 0.8,
output_path: str = "search_results.json",
return_raw_results: bool = False,
memory_config = None
memory_config: "MemoryConfig" = None,
) -> Tuple[str, str, Optional[dict]]:
"""
Execute hybrid search and return clean content.
Execute hybrid search with two-stage ranking.
Stage 1: Filter by content relevance (BM25 + Embedding)
Stage 2: Rerank by activation values (ACTR)
Args:
end_user_id: Group identifier for filtering results
group_id: Group identifier for filtering
question: Search query text
limit: Maximum number of results to return (default: 5)
search_type: Type of search - "hybrid", "keyword", or "embedding" (default: "hybrid")
include: List of result types to include (default: ["statements", "chunks", "entities", "summaries"])
rerank_alpha: Weight for BM25 scores in reranking (default: 0.4)
output_path: Path to save search results (default: "search_results.json")
return_raw_results: If True, also return the raw search results as third element (default: False)
memory_config: Memory configuration object (required)
limit: Max results per category (default: 15)
search_type: "hybrid", "keyword", or "embedding" (default: "hybrid")
include: Result types (default: ["statements", "chunks", "entities", "summaries"])
rerank_alpha: BM25 weight (default: 0.6)
activation_boost_factor: Activation impact on memory strength (default: 0.8)
output_path: JSON output path (default: "search_results.json")
return_raw_results: Return full metadata (default: False)
memory_config: MemoryConfig for embedding model
Returns:
Tuple of (clean_content, cleaned_query, raw_results)
raw_results is None if return_raw_results=False
Tuple[str, str, Optional[dict]]: (clean_content, cleaned_query, raw_results)
"""
if include is None:
include = ["statements", "chunks", "entities", "summaries"]
# Use provided memory_config or fall back to instance config
config = memory_config or self.memory_config
if not config:
raise ValueError("memory_config is required for search - either pass it to __init__ or execute_hybrid_search")
# Clean query
cleaned_query = self.clean_query(question)
try:
# Execute search
# Execute search using memory_config
answer = await run_hybrid_search(
query_text=cleaned_query,
search_type=search_type,
end_user_id=end_user_id,
group_id=group_id,
limit=limit,
include=include,
output_path=output_path,
memory_config=memory_config,
rerank_alpha=rerank_alpha
memory_config=config,
rerank_alpha=rerank_alpha,
activation_boost_factor=activation_boost_factor,
)
# Extract results based on search type and include parameter
@@ -186,7 +206,7 @@ class SearchService:
except Exception as e:
logger.error(
f"Search failed for query '{question}' in group '{end_user_id}': {e}",
f"Search failed for query '{question}' in group '{group_id}': {e}",
exc_info=True
)
# Return empty results on failure

View File

@@ -59,7 +59,7 @@ class SessionService:
self,
user_id: str,
apply_id: str,
end_user_id: str
group_id: str
) -> List[dict]:
"""
Retrieve conversation history from Redis.
@@ -67,20 +67,20 @@ class SessionService:
Args:
user_id: User identifier
apply_id: Application identifier
end_user_id: Group identifier
group_id: Group identifier
Returns:
List of conversation history items with Query and Answer keys
Returns empty list if no history found or on error
"""
try:
history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
# Validate history structure
if not isinstance(history, list):
logger.warning(
f"Invalid history format for user {user_id}, "
f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
)
return []
@@ -89,7 +89,7 @@ class SessionService:
except Exception as e:
logger.error(
f"Failed to retrieve history for user {user_id}, "
f"apply {apply_id}, group {end_user_id}: {e}",
f"apply {apply_id}, group {group_id}: {e}",
exc_info=True
)
# Return empty list on error to allow execution to continue
@@ -100,7 +100,7 @@ class SessionService:
user_id: str,
query: str,
apply_id: str,
end_user_id: str,
group_id: str,
ai_response: str
) -> Optional[str]:
"""
@@ -110,7 +110,7 @@ class SessionService:
user_id: User identifier
query: User query/message
apply_id: Application identifier
end_user_id: Group identifier
group_id: Group identifier
ai_response: AI response/answer
Returns:
@@ -131,7 +131,7 @@ class SessionService:
userid=user_id,
messages=query,
apply_id=apply_id,
end_user_id=end_user_id,
group_id=group_id,
aimessages=ai_response
)
@@ -152,7 +152,7 @@ class SessionService:
Duplicates are identified by matching:
- sessionid
- user_id (id field)
- end_user_id
- group_id
- messages
- aimessages

View File

@@ -3,22 +3,12 @@ Template Service for loading and rendering Jinja2 templates.
This service provides centralized template management with caching and error handling.
"""
import os
from functools import lru_cache
from typing import Optional
from jinja2 import Environment, FileSystemLoader, Template, TemplateNotFound
from jinja2 import (
Environment,
FileSystemLoader,
Template,
TemplateNotFound,
)
from app.core.logging_config import (
get_agent_logger,
log_prompt_rendering,
)
from app.core.logging_config import get_agent_logger, log_prompt_rendering
logger = get_agent_logger(__name__)

View File

@@ -0,0 +1,27 @@
"""
MCP Tools module.
This module contains all MCP tool implementations organized by functionality.
Tools are organized into the following modules:
- problem_tools: Question segmentation and extension
- retrieval_tools: Database and context retrieval
- verification_tools: Data verification
- summary_tools: Summarization and summary retrieval
- data_tools: Data type differentiation and writing
"""
# Import all tool modules to register them with the MCP server
from . import problem_tools
from . import retrieval_tools
from . import verification_tools
from . import summary_tools
from . import data_tools
__all__ = [
'problem_tools',
'retrieval_tools',
'verification_tools',
'summary_tools',
'data_tools',
]

View File

@@ -0,0 +1,155 @@
"""
Data Tools for data type differentiation and writing.
This module contains MCP tools for distinguishing data types and writing data.
"""
import os
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.models.retrieval_models import (
DistinguishTypeResponse,
)
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.write_tools import write
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.schemas.memory_config_schema import MemoryConfig
from mcp.server.fastmcp import Context
logger = get_agent_logger(__name__)
@mcp.tool()
async def Data_type_differentiation(
ctx: Context,
context: str,
memory_config: MemoryConfig,
) -> dict:
"""
Distinguish the type of data (read or write).
Args:
ctx: FastMCP context for dependency injection
context: Text to analyze for type differentiation
memory_config: MemoryConfig object containing LLM configuration
Returns:
dict: Contains 'context' with the original text and 'type' field
"""
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
# Get LLM client from memory_config using factory pattern
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Render template
try:
system_prompt = await template_service.render_template(
template_name='distinguish_types_prompt.jinja2',
operation_name='status_typle',
user_query=context
)
except Exception as e:
logger.error(
f"Template rendering failed for Data_type_differentiation: {e}",
exc_info=True
)
return {
"type": "error",
"message": f"Prompt rendering failed: {str(e)}"
}
# Call LLM with structured response
try:
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=DistinguishTypeResponse
)
result = structured.model_dump()
# Add context to result
result["context"] = context
return result
except Exception as e:
logger.error(
f"LLM call failed for Data_type_differentiation: {e}",
exc_info=True
)
return {
"context": context,
"type": "error",
"message": f"LLM call failed: {str(e)}"
}
except Exception as e:
logger.error(
f"Data_type_differentiation failed: {e}",
exc_info=True
)
return {
"context": context,
"type": "error",
"message": str(e)
}
@mcp.tool()
async def Data_write(
ctx: Context,
content: str,
user_id: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
) -> dict:
"""
Write data to the database/file system.
Args:
ctx: FastMCP context for dependency injection
content: Data content to write
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
Returns:
dict: Contains 'status', 'saved_to', and 'data' fields
"""
try:
# Ensure output directory exists
os.makedirs("data_output", exist_ok=True)
file_path = os.path.join("data_output", "user_data.csv")
# Write data - clients are constructed inside write() from memory_config
await write(
content=content,
user_id=user_id,
apply_id=apply_id,
group_id=group_id,
memory_config=memory_config,
)
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
return {
"status": "success",
"saved_to": file_path,
"data": content,
"config_id": memory_config.config_id,
"config_name": memory_config.config_name,
}
except Exception as e:
logger.error(f"Data_write failed: {e}", exc_info=True)
return {
"status": "error",
"message": str(e),
}

View File

@@ -0,0 +1,304 @@
"""
Problem Tools for question segmentation and extension.
This module contains MCP tools for breaking down and extending user questions.
LLM clients are constructed from MemoryConfig when needed.
"""
import json
import time
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.models.problem_models import (
ProblemBreakdownResponse,
ProblemExtensionResponse,
)
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.schemas.memory_config_schema import MemoryConfig
from mcp.server.fastmcp import Context
logger = get_agent_logger(__name__)
@mcp.tool()
async def Split_The_Problem(
ctx: Context,
sentence: str,
sessionid: str,
messages_id: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
) -> dict:
"""
Segment the dialogue or sentence into sub-problems.
Args:
ctx: FastMCP context for dependency injection
sentence: Original sentence to split
sessionid: Session identifier
messages_id: Message identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
Returns:
dict: Contains 'context' (JSON string of split results) and 'original' sentence
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Extract user ID from session
user_id = session_service.resolve_user_id(sessionid)
# Get conversation history
history = await session_service.get_history(user_id, apply_id, group_id)
# Override with empty list for now (as in original)
history = []
# Render template
try:
system_prompt = await template_service.render_template(
template_name='problem_breakdown_prompt.jinja2',
operation_name='split_the_problem',
history=history,
sentence=sentence
)
except Exception as e:
logger.error(
f"Template rendering failed for Split_The_Problem: {e}",
exc_info=True
)
return {
"context": json.dumps([], ensure_ascii=False),
"original": sentence,
"error": f"Prompt rendering failed: {str(e)}"
}
# Call LLM with structured response
try:
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=ProblemBreakdownResponse
)
# Handle RootModel response with .root attribute access
if structured is None:
# LLM returned None, use empty list as fallback
split_result = json.dumps([], ensure_ascii=False)
elif hasattr(structured, 'root') and structured.root is not None:
split_result = json.dumps(
[item.model_dump() for item in structured.root],
ensure_ascii=False
)
elif isinstance(structured, list):
# Fallback: treat structured itself as the list
split_result = json.dumps(
[item.model_dump() for item in structured],
ensure_ascii=False
)
else:
# Last resort: use empty list
split_result = json.dumps([], ensure_ascii=False)
except Exception as e:
logger.error(
f"LLM call failed for Split_The_Problem: {e}",
exc_info=True
)
split_result = json.dumps([], ensure_ascii=False)
logger.info("Problem splitting")
logger.info(f"Problem split result: {split_result}")
# Emit intermediate output for frontend
result = {
"context": split_result,
"original": sentence,
"_intermediate": {
"type": "problem_split",
"data": json.loads(split_result) if split_result else [],
"original_query": sentence
}
}
return result
except Exception as e:
logger.error(
f"Split_The_Problem failed: {e}",
exc_info=True
)
return {
"context": json.dumps([], ensure_ascii=False),
"original": sentence,
"error": str(e)
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Problem splitting', duration)
@mcp.tool()
async def Problem_Extension(
ctx: Context,
context: dict,
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = "",
) -> dict:
"""
Extend the problem with additional sub-questions.
Args:
ctx: FastMCP context for dependency injection
context: Dictionary containing split problem results
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'context' (aggregated questions) and 'original' question
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID from usermessages
from app.core.memory.agent.utils.messages_tool import Resolve_username
sessionid = Resolve_username(usermessages)
# Get conversation history
history = await session_service.get_history(sessionid, apply_id, group_id)
# Override with empty list for now (as in original)
history = []
# Process context to extract questions
extent_quest, original = await Problem_Extension_messages_deal(context)
# Format questions for template rendering
questions_formatted = []
for msg in extent_quest:
if msg.get("role") == "user":
questions_formatted.append(msg.get("content", ""))
# Render template
try:
system_prompt = await template_service.render_template(
template_name='Problem_Extension_prompt.jinja2',
operation_name='problem_extension',
history=history,
questions=questions_formatted
)
except Exception as e:
logger.error(
f"Template rendering failed for Problem_Extension: {e}",
exc_info=True
)
return {
"context": {},
"original": original,
"error": f"Prompt rendering failed: {str(e)}"
}
# Call LLM with structured response
try:
response_content = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=ProblemExtensionResponse
)
# Aggregate results by original question
aggregated_dict = {}
for item in response_content.root:
key = getattr(item, "original_question", None) or (
item.get("original_question") if isinstance(item, dict) else None
)
value = getattr(item, "extended_question", None) or (
item.get("extended_question") if isinstance(item, dict) else None
)
if not key or not value:
continue
aggregated_dict.setdefault(key, []).append(value)
except Exception as e:
logger.error(
f"LLM call failed for Problem_Extension: {e}",
exc_info=True
)
aggregated_dict = {}
logger.info("Problem extension")
logger.info(f"Problem extension result: {aggregated_dict}")
# Emit intermediate output for frontend
result = {
"context": aggregated_dict,
"original": original,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "problem_extension",
"data": aggregated_dict,
"original_query": original,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
return result
except Exception as e:
logger.error(
f"Problem_Extension failed: {e}",
exc_info=True
)
return {
"context": {},
"original": context.get("original", ""),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Problem extension', duration)

View File

@@ -0,0 +1,294 @@
"""
Retrieval Tools for database and context retrieval.
This module contains MCP tools for retrieving data using hybrid search.
"""
import os
import time
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.llm_tools import (
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.memory.agent.utils.messages_tool import Retriev_messages_deal
from app.core.rag.nlp.search import knowledge_retrieval
from app.schemas.memory_config_schema import MemoryConfig
from dotenv import load_dotenv
from mcp.server.fastmcp import Context
load_dotenv()
logger = get_agent_logger(__name__)
@mcp.tool()
async def Retrieve(
ctx: Context,
context,
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = "",
) -> dict:
"""
Retrieve data from the database using hybrid search.
Args:
ctx: FastMCP context for dependency injection
context: Dictionary or string containing query information
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
user_rag_memory_id: User RAG memory identifier
Returns:
dict: Contains 'context' with Query and Expansion_issue results
"""
kb_config = {
"knowledge_bases": [
{
"kb_id": user_rag_memory_id,
"similarity_threshold": 0.7,
"vector_similarity_weight": 0.5,
"top_k": 10,
"retrieve_type": "participle"
}
],
"merge_strategy": "weight",
"reranker_id": os.getenv('reranker_id'),
"reranker_top_k": 10
}
start = time.time()
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
logger.info(f"Retrieve: context type={type(context)}, context={str(context)[:500]}")
try:
# Extract services from context
search_service = get_context_resource(ctx, 'search_service')
databases_anser = []
# Handle both dict and string context
if isinstance(context, dict):
# Process dict context with extended questions
all_items = []
logger.info(f"Retrieve: context keys={list(context.keys())}")
content, original = await Retriev_messages_deal(context)
logger.info(f"Retrieve: after Retriev_messages_deal - content_type={type(content)}, content={str(content)[:300]}")
logger.info(f"Retrieve: original='{original[:100] if original else 'EMPTY'}'")
if not original:
logger.warning(f"Retrieve: original query is empty! context={context}")
# Extract all query items from content
# content is like {original_question: [extended_questions...], ...}
for key, values in content.items():
if isinstance(values, list):
all_items.extend(values)
elif isinstance(values, str):
all_items.append(values)
elif values is not None:
# Fallback: convert non-empty non-list values to string
all_items.append(str(values))
# Execute search for each question
for idx, question in enumerate(all_items):
try:
# Prepare search parameters based on storage type
search_params = {
"group_id": group_id,
"question": question,
"return_raw_results": True
}
# Add storage-specific parameters
if storage_type == "rag" and user_rag_memory_id:
retrieve_chunks_result = knowledge_retrieval(question, kb_config,[str(group_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
clean_content = '\n\n'.join(retrieval_knowledge)
cleaned_query=question
raw_results=clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except:
clean_content = ''
raw_results=''
cleaned_query = question
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
else:
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(
**search_params, memory_config=memory_config
)
databases_anser.append({
"Query_small": cleaned_query,
"Result_small": clean_content,
"_intermediate": {
"type": "search_result",
"query": cleaned_query,
"raw_results": raw_results,
"index": idx + 1,
"total": len(all_items)
}
})
except Exception as e:
logger.error(
f"Retrieve: hybrid_search failed for question '{question}': {e}",
exc_info=True
)
# Continue with empty result for this question
databases_anser.append({
"Query_small": question,
"Result_small": ""
})
# Build initial database data structure
databases_data = {
"Query": original,
"Expansion_issue": databases_anser
}
# Collect intermediate outputs before deduplication
intermediate_outputs = []
for item in databases_anser:
if '_intermediate' in item:
intermediate_outputs.append(item['_intermediate'])
# Deduplicate and merge results
deduplicated_data = deduplicate_entries(databases_data['Expansion_issue'])
deduplicated_data_merged = merge_to_key_value_pairs(
deduplicated_data,
'Query_small',
'Result_small'
)
# Restructure for Verify/Retrieve_Summary compatibility
keys, val = [], []
for item in deduplicated_data_merged:
for items_key, items_value in item.items():
keys.append(items_key)
val.append(items_value)
send_verify = []
for i, j in zip(keys, val, strict=False):
send_verify.append({
"Query_small": i,
"Answer_Small": j
})
dup_databases = {
"Query": original,
"Expansion_issue": send_verify,
"_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs
}
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
else:
# Handle string context (simple query)
query = str(context).strip()
try:
# Prepare search parameters based on storage type
search_params = {
"group_id": group_id,
"question": query,
"return_raw_results": True
}
# Add storage-specific parameters
if storage_type == "rag" and user_rag_memory_id:
retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
clean_content = '\n\n'.join(retrieval_knowledge)
cleaned_query = query
raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except:
clean_content = ''
raw_results = ''
cleaned_query = query
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
else:
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(
**search_params, memory_config=memory_config
)
# Keep structure for Verify/Retrieve_Summary compatibility
dup_databases = {
"Query": cleaned_query,
"Expansion_issue": [{
"Query_small": cleaned_query,
"Answer_Small": clean_content,
"_intermediate": {
"type": "search_result",
"query": cleaned_query,
"raw_results": raw_results,
"index": 1,
"total": 1
}
}]
}
except Exception as e:
logger.error(
f"Retrieve: hybrid_search failed for query '{query}': {e}",
exc_info=True
)
# Return empty results on failure
dup_databases = {
"Query": query,
"Expansion_issue": []
}
logger.info(
f"Retrieval: {storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, "
f"Expansion_issue count={len(dup_databases.get('Expansion_issue', []))}"
)
# Build result with intermediate outputs
result = {
"context": dup_databases,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
# Add intermediate outputs list if they exist
intermediate_outputs = dup_databases.get('_intermediate_outputs', [])
if intermediate_outputs:
result['_intermediates'] = intermediate_outputs
logger.info(f"Adding {len(intermediate_outputs)} intermediate outputs to result")
else:
logger.warning("No intermediate outputs found in dup_databases")
return result
except Exception as e:
logger.error(
f"Retrieve failed: {e}",
exc_info=True
)
return {
"context": {
"Query": "",
"Expansion_issue": []
},
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Retrieval', duration)

View File

@@ -0,0 +1,640 @@
"""
Summary Tools for data summarization.
This module contains MCP tools for summarizing retrieved data and generating responses.
LLM clients are constructed from MemoryConfig when needed.
"""
import json
import os
import re
import time
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.models.summary_models import (
RetrieveSummaryResponse,
SummaryResponse,
)
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.messages_tool import (
Resolve_username,
Summary_messages_deal,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.schemas.memory_config_schema import MemoryConfig
from dotenv import load_dotenv
from mcp.server.fastmcp import Context
load_dotenv()
logger = get_agent_logger(__name__)
@mcp.tool()
async def Summary(
ctx: Context,
context: str,
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = "",
) -> dict:
"""
Summarize the verified data.
Args:
ctx: FastMCP context for dependency injection
context: JSON string containing verified data
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'status' and 'summary_result'
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID
sessionid = Resolve_username(usermessages)
# Process context to extract answer and query
answer_small, query = await Summary_messages_deal(context)
start_time= time.time()
history = await session_service.get_history(sessionid, apply_id, group_id)
end_time=time.time()
logger.info(f"Retrieve_Summary-REDIS搜索{end_time - start_time}")
data = {
"query": query,
"history": history,
"retrieve_info": answer_small
}
except Exception as e:
logger.error(
f"Summary: initialization failed: {e}",
exc_info=True
)
return {
"status": "error",
"summary_result": "信息不足,无法回答"
}
try:
# Render template
system_prompt = await template_service.render_template(
template_name='summary_prompt.jinja2',
operation_name='summary',
data=data,
query=query
)
except Exception as e:
logger.error(
f"Template rendering failed for Summary: {e}",
exc_info=True
)
return {
"status": "error",
"message": f"Prompt rendering failed: {str(e)}"
}
try:
# Call LLM with structured response
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=SummaryResponse
)
aimessages = structured.query_answer or ""
except Exception as e:
logger.error(
f"LLM call failed for Summary: {e}",
exc_info=True
)
aimessages = ""
try:
# Save session
if aimessages != "":
await session_service.save_session(
user_id=sessionid,
query=query,
apply_id=apply_id,
group_id=group_id,
ai_response=aimessages
)
logger.info(f"sessionid: {aimessages} 写入成功")
except Exception as e:
logger.error(
f"sessionid: {sessionid} 写入失败,错误信息:{str(e)}",
exc_info=True
)
return {
"status": "error",
"message": str(e)
}
# Cleanup duplicate sessions
await session_service.cleanup_duplicates()
# Use fallback if empty
if aimessages == '':
aimessages = '信息不足,无法回答'
logger.info(f"Summary after verification: {aimessages}")
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Summary', duration)
return {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
@mcp.tool()
async def Retrieve_Summary(
ctx: Context,
context: dict,
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = "",
) -> dict:
"""
Summarize data directly from retrieval results.
Args:
ctx: FastMCP context for dependency injection
context: Dictionary containing Query and Expansion_issue from Retrieve
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'status' and 'summary_result'
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID
sessionid = Resolve_username(usermessages)
# Handle both 'content' and 'context' keys (LangGraph uses 'content')
logger.debug(f"Retrieve_Summary: raw context type={type(context)}, keys={list(context.keys()) if isinstance(context, dict) else 'N/A'}")
if isinstance(context, dict):
if "content" in context:
inner = context["content"]
# If it's a JSON string, parse it
if isinstance(inner, str):
try:
parsed = json.loads(inner)
logger.info("Retrieve_Summary: successfully parsed JSON")
except json.JSONDecodeError:
# Try unescaping first
try:
unescaped = inner.encode('utf-8').decode('unicode_escape')
parsed = json.loads(unescaped)
logger.info("Retrieve_Summary: parsed after unescaping")
except (json.JSONDecodeError, UnicodeDecodeError) as e:
logger.error(
f"Retrieve_Summary: parsing failed even after unescape: {e}"
)
context_dict = {"Query": "", "Expansion_issue": []}
parsed = None
if parsed:
# Check if parsed has 'context' wrapper
if isinstance(parsed, dict) and "context" in parsed:
context_dict = parsed["context"]
else:
context_dict = parsed
elif isinstance(inner, dict):
context_dict = inner
else:
context_dict = {"Query": "", "Expansion_issue": []}
elif "context" in context:
context_dict = context["context"] if isinstance(context["context"], dict) else context
else:
context_dict = context
else:
context_dict = {"Query": "", "Expansion_issue": []}
query = context_dict.get("Query", "")
expansion_issue = context_dict.get("Expansion_issue", [])
logger.debug(f"Retrieve_Summary: query='{query}', expansion_issue count={len(expansion_issue)}")
logger.debug(f"Retrieve_Summary: expansion_issue={expansion_issue[:2] if expansion_issue else 'empty'}")
# Extract retrieve_info from expansion_issue
retrieve_info = []
for item in expansion_issue:
# Check for both Answer_Small and Answer_Small (typo) for backward compatibility
answer = None
if isinstance(item, dict):
if "Answer_Small" in item:
answer = item["Answer_Small"]
if answer is not None:
# Handle both string and list formats
if isinstance(answer, list):
# Join list of characters/strings into a single string
retrieve_info.append(''.join(str(x) for x in answer))
elif isinstance(answer, str):
retrieve_info.append(answer)
else:
retrieve_info.append(str(answer))
# Join all retrieve_info into a single string
retrieve_info_str = '\n\n'.join(retrieve_info) if retrieve_info else ""
start_time=time.time()
history = await session_service.get_history(sessionid, apply_id, group_id)
# Override with empty list for now (as in original)
end_time=time.time()
logger.info(f"Retrieve_Summary-REDIS搜索{end_time - start_time}")
except Exception as e:
logger.error(
f"Retrieve_Summary: initialization failed: {e}",
exc_info=True
)
return {
"status": "error",
"summary_result": "信息不足,无法回答"
}
try:
# Render template
system_prompt = await template_service.render_template(
template_name='Retrieve_Summary_prompt.jinja2',
operation_name='retrieve_summary',
query=query,
history=history,
retrieve_info=retrieve_info_str
)
except Exception as e:
logger.error(
f"Template rendering failed for Retrieve_Summary: {e}",
exc_info=True
)
return {
"status": "error",
"message": f"Prompt rendering failed: {str(e)}"
}
try:
# Call LLM with structured response
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=RetrieveSummaryResponse
)
# Handle case where structured response might be None or incomplete
if structured and hasattr(structured, 'data') and structured.data:
aimessages = structured.data.query_answer or ""
else:
logger.warning("Structured response is None or incomplete, using default message")
aimessages = "信息不足,无法回答"
# Check for insufficient information response
if '信息不足,无法回答' not in str(aimessages) or str(aimessages)!="":
# Save session
await session_service.save_session(
user_id=sessionid,
query=query,
apply_id=apply_id,
group_id=group_id,
ai_response=aimessages
)
logger.info(f"sessionid: {aimessages} 写入成功")
except Exception as e:
logger.error(
f"Retrieve_Summary: LLM call failed: {e}",
exc_info=True
)
aimessages = ""
# Cleanup duplicate sessions
await session_service.cleanup_duplicates()
# Use fallback if empty
if aimessages == '':
aimessages = '信息不足,无法回答'
logger.info(f"Summary after retrieval: {aimessages}")
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Retrieval summary', duration)
# Emit intermediate output for frontend
return {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "retrieval_summary",
"summary": aimessages,
"query": query,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
@mcp.tool()
async def Input_Summary(
ctx: Context,
context: str,
usermessages: str,
search_switch: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = "",
) -> dict:
"""
Generate a quick summary for direct input without verification.
Args:
ctx: FastMCP context for dependency injection
context: String containing the input sentence
usermessages: User messages identifier
search_switch: Search switch value for routing ('2' for summaries only)
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
user_rag_memory_id: User RAG memory identifier
Returns:
dict: Contains 'query_answer' with the summary result
"""
start = time.time()
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
try:
# Extract services from context
session_service = get_context_resource(ctx, "session_service")
search_service = get_context_resource(ctx, "search_service")
# Resolve session ID
sessionid = Resolve_username(usermessages) or ""
sessionid = sessionid.replace('call_id_', '')
start_time=time.time()
history = await session_service.get_history(
str(sessionid),
str(apply_id),
str(group_id)
)
end_time=time.time()
logger.info(f"Input_Summary-REDIS搜索{end_time - start_time}")
# Override with empty list for now (as in original)
# Log the raw context for debugging
logger.info(f"Input_Summary: Received context type={type(context)}, value={context[:200] if isinstance(context, str) else context}")
# Extract sentence from context
# Context can be a string or might contain the sentence in various formats
try:
# Try to parse as JSON first
if isinstance(context, str) and (context.startswith('{') or context.startswith('[')):
try:
import json
context_dict = json.loads(context)
if isinstance(context_dict, dict):
query = context_dict.get('sentence', context_dict.get('content', context))
else:
query = context
except json.JSONDecodeError:
# Not valid JSON, try regex
match = re.search(r"'sentence':\s*['\"]?(.*?)['\"]?\s*,", context)
query = match.group(1) if match else context
else:
query = context
except Exception as e:
logger.warning(f"Failed to extract query from context: {e}")
query = context
# Clean query
query = str(query).strip().strip("\"'")
logger.debug(f"Input_Summary: Extracted query='{query}' from context type={type(context)}")
# Execute search based on search_switch and storage_type
try:
logger.info(f"search_switch: {search_switch}, storage_type: {storage_type}")
# Prepare search parameters based on storage type
search_params = {
"group_id": group_id,
"question": query,
"return_raw_results": True
}
# Add storage-specific parameters
# Retrieval
if search_switch == '2':
search_params["include"] = ["summaries"]
if storage_type == "rag" and user_rag_memory_id:
raw_results = []
retrieve_info = ""
kb_config={
"knowledge_bases": [
{
"kb_id": user_rag_memory_id,
"similarity_threshold": 0.7,
"vector_similarity_weight": 0.5,
"top_k": 10,
"retrieve_type": "participle"
}
],
"merge_strategy": "weight",
"reranker_id":os.getenv('reranker_id'),
"reranker_top_k": 10
}
retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
retrieve_info = '\n\n'.join(retrieval_knowledge)
raw_results=[retrieve_info]
logger.info(f"Input_Summary: Using RAG storage with memory_id={user_rag_memory_id}")
except:
retrieve_info=''
raw_results=['']
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
else:
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(
**search_params, memory_config=memory_config
)
logger.info("Input_Summary: Using summary for retrieval")
else:
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(
**search_params, memory_config=memory_config
)
except Exception as e:
logger.error(
f"Input_Summary: hybrid_search failed, using empty results: {e}",
exc_info=True
)
retrieve_info, question, raw_results = "", query, []
# Return retrieved information directly without LLM processing
# Use the raw retrieved info as the answer
aimessages = retrieve_info if retrieve_info else "信息不足,无法回答"
logger.info(f"Quick answer (no LLM): {storage_type}--{user_rag_memory_id}--{aimessages[:500]}...")
# Emit intermediate output for frontend
return {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "input_summary",
"title": "快速答案",
"summary": aimessages,
"query": query,
"raw_results": raw_results,
"search_mode": "quick_search",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
except Exception as e:
logger.error(
f"Input_Summary failed: {e}",
exc_info=True
)
return {
"status": "fail",
"summary_result": "信息不足,无法回答",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Retrieval', duration)
@mcp.tool()
async def Summary_fails(
ctx: Context,
context: str,
usermessages: str,
apply_id: str,
group_id: str,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
"""
Handle workflow failure when summary cannot be generated.
Args:
ctx: FastMCP context for dependency injection
context: Failure context string
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'query_answer' with failure message
"""
try:
# Extract services from context
session_service = get_context_resource(ctx, 'session_service')
# Parse session ID from usermessages
usermessages_parts = usermessages.split('_')[1:]
sessionid = '_'.join(usermessages_parts[:-1])
# Cleanup duplicate sessions
await session_service.cleanup_duplicates()
logger.info("没有相关数据")
logger.debug(f"Summary_fails called with apply_id: {apply_id}, group_id: {group_id}")
return {
"status": "success",
"summary_result": "没有相关数据",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
except Exception as e:
logger.error(
f"Summary_fails failed: {e}",
exc_info=True
)
return {
"status": "fail",
"summary_result": "没有相关数据",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}

View File

@@ -0,0 +1,174 @@
"""
Verification Tools for data verification.
This module contains MCP tools for verifying retrieved data.
"""
import time
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.messages_tool import (
Resolve_username,
Retrieve_verify_tool_messages_deal,
Verify_messages_deal,
)
from app.core.memory.agent.utils.verify_tool import VerifyTool
from app.schemas.memory_config_schema import MemoryConfig
from jinja2 import Template
from mcp.server.fastmcp import Context
logger = get_agent_logger(__name__)
@mcp.tool()
async def Verify(
ctx: Context,
context: dict,
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
"""
Verify the retrieved data.
Args:
ctx: FastMCP context for dependency injection
context: Dictionary containing query and expansion issues
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'status' and 'verified_data' with verification results
"""
start = time.time()
try:
# Extract services from context
session_service = get_context_resource(ctx, 'session_service')
# Load verification prompt template
file_path = PROJECT_ROOT_ + '/agent/utils/prompt/split_verify_prompt.jinja2'
# Read template file directly (VerifyTool expects raw template content)
from app.core.memory.agent.utils.messages_tool import read_template_file
system_prompt = await read_template_file(file_path)
# Resolve session ID
sessionid = Resolve_username(usermessages)
# Get conversation history
history = await session_service.get_history(sessionid, apply_id, group_id)
template = Template(system_prompt)
system_prompt = template.render(history=history, sentence=context)
# Process context to extract query and results
Query_small, Result_small, query = await Verify_messages_deal(context)
# Build query list for verification
query_list = []
for query_small, anser in zip(Query_small, Result_small, strict=False):
query_list.append({
'Query_small': query_small,
'Answer_Small': anser
})
messages = {
"Query": query,
"Expansion_issue": query_list
}
# Call verification workflow with LLM model ID from memory_config
verify_tool = VerifyTool(
system_prompt=system_prompt,
verify_data=messages,
llm_model_id=str(memory_config.llm_model_id)
)
verify_result = await verify_tool.verify()
# Parse LLM verification result with error handling
try:
messages_deal = await Retrieve_verify_tool_messages_deal(
verify_result,
history,
query
)
except Exception as e:
logger.error(
f"Retrieve_verify_tool_messages_deal parsing failed: {e}",
exc_info=True
)
# Fallback to avoid 500 errors
messages_deal = {
"data": {
"query": query,
"expansion_issue": []
},
"split_result": "failed",
"reason": str(e),
"history": history,
}
logger.info(f"Verification result: {messages_deal}")
# Emit intermediate output for frontend
return {
"status": "success",
"verified_data": messages_deal,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "verification",
"title": "Data Verification",
"result": messages_deal.get("split_result", "unknown"),
"reason": messages_deal.get("reason", ""),
"query": query,
"verified_count": len(query_list),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
except Exception as e:
logger.error(
f"Verify failed: {e}",
exc_info=True
)
return {
"status": "error",
"message": str(e),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"verified_data": {
"data": {
"query": "",
"expansion_issue": []
},
"split_result": "failed",
"reason": str(e),
"history": [],
}
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Verification', duration)

View File

@@ -1,32 +0,0 @@
"""Pydantic models for verification operations."""
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
class VerificationItem(BaseModel):
"""Individual verification item for a query-answer pair."""
query_small: str = Field(..., description="子问题")
answer_small: str = Field(..., description="子问题的回答")
status: str = Field(..., description="验证状态True 或 False")
query_answer: str = Field(..., description="问题的答案(与 answer_small 相同)")
class VerificationResult(BaseModel):
"""Result model for verification operation."""
query: str = Field(..., description="原始查询问题")
history: List[Dict[str, Any]] = Field(default_factory=list, description="历史对话记录")
expansion_issue: List[VerificationItem] = Field(
default_factory=list,
description="验证后的数据列表,包含所有通过验证的问答对"
)
split_result: str = Field(
...,
description="验证结果状态successexpansion_issue 非空)或 failedexpansion_issue 为空)"
)
reason: Optional[str] = Field(
None,
description="验证结果的说明和分析"
)

View File

@@ -1,28 +0,0 @@
"""Pydantic models for write aggregate judgment operations."""
from typing import List, Union
from pydantic import BaseModel, Field
class MessageItem(BaseModel):
"""Individual message item in conversation."""
role: str = Field(..., description="角色user 或 assistant")
content: str = Field(..., description="消息内容")
class WriteAggregateResponse(BaseModel):
"""Response model for aggregate judgment containing judgment result and output."""
is_same_event: bool = Field(
...,
description="是否是同一事件。True表示是同一事件False表示不同事件"
)
output: Union[List[MessageItem], bool] = Field(
...,
description="如果is_same_event为True返回False如果is_same_event为False返回消息列表"
)
# 为了保持向后兼容,保留旧的类名作为别名
WriteAggregateModel = WriteAggregateResponse

View File

@@ -0,0 +1,114 @@
import os
import sys
import traceback
import requests
# from qcloud_cos import CosConfig, CosS3Client
# from qcloud_cos.cos_exception import CosClientError, CosServiceError
# from config.paths import BASE_DIR
BASE_DIR = os.path.dirname(os.path.realpath(sys.argv[0]))
class OSSUploader:
"""对象存储文件上传工具类"""
def __init__(self, env):
api = {
"test": "https://testlingqi.redbearai.com/api/user/file/common/upload/v2/anon",
"prod": "https://lingqi.redbearai.com/api/user/file/common/upload/v2/anon"
}
self.api = api.get(env, "https://testlingqi.redbearai.com/api/user/file/common/upload/v2/anon")
self.privacy = "false"
self.headers = {
"User-Agent": 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) '
'AppleWebKit/537.36 (KHTML, like Gecko)'
' Chrome/133.0.6833.84 Safari/537.36'
}
@staticmethod
def _generate_object_key(file_path, prefix='xhs_'):
"""
生成对象存储的Key
:param file_path: 本地文件路径
:param prefix: 存储前缀,用于分类存储
:return: 生成的对象Key
"""
# 文件md5值.后缀名
filename = os.path.basename(file_path)
filename = f"{filename}"
# 组合成完整的对象Key
return f"{prefix}{filename}"
def upload_image(self, file_name, prefix='jd_'):
"""
上传文件到COS并返回可访问的URL
:param file_url: 文件路径
:param file_name: 文件名称
:param media_type: 文件类型
:param prefix: 存储前缀,用于分类存储
:return: 文件访问URL
"""
# 检查文件是否存在
file_path = os.path.join(BASE_DIR, file_name)
# response = requests.get(url, headers=self.headers, stream=True)
# if response.status_code == 200:
# with open(file_path, "wb") as f:
# for chunk in response.iter_content(1024): # 分块写入,避免内存占用过大
# f.write(chunk)
# else:
# raise Exception(f"文件下载失败,{file_name}")
# 生成对象Key
object_key = self._generate_object_key(file_path, prefix +file_name.split('.')[-1])
try:
upload_response = requests.post(
self.api,
data={
"privacy": self.privacy,
"fileName": object_key,
}
)
if upload_response.status_code != 200:
raise Exception('上传接口请求失败')
resp = upload_response.json()
name = resp["data"]["name"]
file_url = resp["data"]["path"]
policy = resp["data"]["policy"]
with open(file_path, 'rb') as f:
oss_push_resp = requests.post(
policy["host"],
files={
"key": policy["dir"],
"OSSAccessKeyId": policy["accessid"],
"name": name,
"policy": policy["policy"],
"success_action_status": 200,
"signature": policy["signature"],
"file": f,
}
)
if oss_push_resp.status_code == 200:
return file_url
raise Exception("OSS上传失败")
except Exception:
raise Exception(f"上传失败: \n{traceback.format_exc()}")
finally:
print('success')
# os.remove(file_path)
if __name__ == '__main__':
cos_uploader = OSSUploader("prod")
url =cos_uploader.upload_image('./example01.jpg')
print(url)

View File

@@ -0,0 +1,121 @@
import asyncio
import re
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_, picture_model_requests,Picture_recognize, Voice_recognize
from app.core.memory.agent.utils.messages_tool import read_template_file
import requests
import json
import os
import time
# file_urls = [
# "https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_female2.wav",
# "https://dashscope.oss-cn-beijing.aliyuncs.com/samples/audio/paraformer/hello_world_male2.wav",
# ]
class Vico_recognition:
def __init__(self,file_urls):
self.api_key=''
self.backend_model_name=''
self.api_base=''
self.file_urls=file_urls
# 提交文件转写任务包含待转写文件url列表
async def submit_task(self) -> str:
self.api_key, self.backend_model_name, self.api_base =await Voice_recognize()
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"X-DashScope-Async": "enable",
}
data = {
"model": self.backend_model_name,
"input": {"file_urls": self.file_urls},
"parameters": {
"channel_id": [0],
"vocabulary_id": "vocab-Xxxx",
},
}
# 录音文件转写服务url
service_url = (
"https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription"
)
response = requests.post(
service_url, headers=headers, data=json.dumps(data)
)
# 打印响应内容
if response.status_code == 200:
return response.json()["output"]["task_id"]
else:
print("task failed!")
print(response.json())
return None
async def download_transcription_result(self, transcription_url):
"""
Args:
transcription_url (str): 转写结果文件URL
Returns:
dict: 转写结果内容
"""
try:
response = requests.get(transcription_url)
response.raise_for_status()
return response.json()
except Exception as e:
print(f"下载转写结果失败: {e}")
return None
# 循环查询任务状态直到成功
async def wait_for_complete(self,task_id):
self.api_key, self.backend_model_name, self.api_base = await Voice_recognize()
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"X-DashScope-Async": "enable",
}
pending = True
while pending:
# 查询任务状态服务url
service_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}"
response = requests.post(
service_url, headers=headers
)
if response.status_code == 200:
status = response.json()['output']['task_status']
if status == 'SUCCEEDED':
print("task succeeded!")
pending = False
return response.json()['output']['results']
elif status == 'RUNNING' or status == 'PENDING':
pass
else:
print("task failed!")
pending = False
else:
print("query failed!")
pending = False
time.sleep(0.1)
async def run(self):
self.api_key, self.backend_model_name, self.api_base = await Voice_recognize()
task_id=await self.submit_task()
result=await self.wait_for_complete(task_id)
result_context=[]
for i in result:
transcription_url=i['transcription_url']
print(f"转写URL: {transcription_url}")
# 下载并打印转写内容
content = await self.download_transcription_result(transcription_url)
if content:
content=json.dumps(content, indent=2, ensure_ascii=False)
context=re.findall(r'"text": "(.*?)"', content)
result_context.append(context[0])
result=''.join(result_context)
return (result)

View File

@@ -1,277 +0,0 @@
"""
优化的LLM服务类用于压缩和统一LLM调用
"""
import asyncio
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
from pydantic import BaseModel
from sqlalchemy.orm import Session
from app.core.logging_config import get_agent_logger
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.memory.llm_tools.openai_client import OpenAIClient
T = TypeVar('T', bound=BaseModel)
logger = get_agent_logger(__name__)
class OptimizedLLMService:
"""
优化的LLM服务类提供统一的LLM调用接口
特性:
1. 客户端复用 - 避免重复创建LLM客户端
2. 批量处理 - 支持并发处理多个请求
3. 错误处理 - 统一的错误处理和降级策略
4. 性能优化 - 缓存和连接池优化
"""
def __init__(self, db_session: Session):
self.db_session = db_session
self.client_factory = MemoryClientFactory(db_session)
self._client_cache: Dict[str, OpenAIClient] = {}
def _get_cached_client(self, llm_model_id: str) -> OpenAIClient:
"""获取缓存的LLM客户端避免重复创建"""
if llm_model_id not in self._client_cache:
self._client_cache[llm_model_id] = self.client_factory.get_llm_client(llm_model_id)
return self._client_cache[llm_model_id]
async def structured_response(
self,
llm_model_id: str,
system_prompt: str,
response_model: Type[T],
user_message: Optional[str] = None,
fallback_value: Optional[Any] = None
) -> T:
"""
统一的结构化响应接口
Args:
llm_model_id: LLM模型ID
system_prompt: 系统提示词
response_model: 响应模型类
user_message: 用户消息(可选)
fallback_value: 失败时的降级值
Returns:
结构化响应对象
"""
try:
llm_client = self._get_cached_client(llm_model_id)
messages = [{"role": "system", "content": system_prompt}]
if user_message:
messages.append({"role": "user", "content": user_message})
logger.debug(f"LLM调用: model={llm_model_id}, prompt_length={len(system_prompt)}")
structured = await llm_client.response_structured(
messages=messages,
response_model=response_model
)
if structured is None:
logger.warning(f"LLM返回None使用降级值")
return self._create_fallback_response(response_model, fallback_value)
return structured
except Exception as e:
logger.error(f"结构化响应失败: {e}", exc_info=True)
return self._create_fallback_response(response_model, fallback_value)
async def batch_structured_response(
self,
llm_model_id: str,
requests: List[Dict[str, Any]],
response_model: Type[T],
max_concurrent: int = 5
) -> List[T]:
"""
批量处理结构化响应
Args:
llm_model_id: LLM模型ID
requests: 请求列表每个请求包含system_prompt等参数
response_model: 响应模型类
max_concurrent: 最大并发数
Returns:
结构化响应列表
"""
semaphore = asyncio.Semaphore(max_concurrent)
async def process_single_request(request: Dict[str, Any]) -> T:
async with semaphore:
return await self.structured_response(
llm_model_id=llm_model_id,
system_prompt=request.get('system_prompt', ''),
response_model=response_model,
user_message=request.get('user_message'),
fallback_value=request.get('fallback_value')
)
tasks = [process_single_request(req) for req in requests]
return await asyncio.gather(*tasks)
async def simple_response(
self,
llm_model_id: str,
system_prompt: str,
user_message: Optional[str] = None,
fallback_message: str = "信息不足,无法回答"
) -> str:
"""
简单的文本响应接口
Args:
llm_model_id: LLM模型ID
system_prompt: 系统提示词
user_message: 用户消息(可选)
fallback_message: 失败时的降级消息
Returns:
响应文本
"""
try:
llm_client = self._get_cached_client(llm_model_id)
messages = [{"role": "system", "content": system_prompt}]
if user_message:
messages.append({"role": "user", "content": user_message})
response = await llm_client.response(messages=messages)
if not response or not response.strip():
return fallback_message
return response.strip()
except Exception as e:
logger.error(f"简单响应失败: {e}", exc_info=True)
return fallback_message
def _create_fallback_response(self, response_model: Type[T], fallback_value: Optional[Any]) -> T:
"""创建降级响应"""
try:
if fallback_value is not None:
if isinstance(fallback_value, response_model):
return fallback_value
elif isinstance(fallback_value, dict):
return response_model(**fallback_value)
# 尝试创建空的响应模型
if hasattr(response_model, 'root'):
# RootModel类型
return response_model([])
else:
# 普通BaseModel类型
return response_model()
except Exception as e:
logger.error(f"创建降级响应失败: {e}")
# 最后的降级策略
if hasattr(response_model, 'root'):
return response_model([])
else:
return response_model()
def clear_cache(self):
"""清理客户端缓存"""
self._client_cache.clear()
class LLMServiceMixin:
"""
LLM服务混入类为节点提供便捷的LLM调用方法
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._llm_service: Optional[OptimizedLLMService] = None
def get_llm_service(self, db_session: Session) -> OptimizedLLMService:
"""获取LLM服务实例"""
if self._llm_service is None:
self._llm_service = OptimizedLLMService(db_session)
return self._llm_service
async def call_llm_structured(
self,
state: Dict[str, Any],
db_session: Session,
system_prompt: str,
response_model: Type[T],
user_message: Optional[str] = None,
fallback_value: Optional[Any] = None
) -> T:
"""
便捷的结构化LLM调用方法
Args:
state: 状态字典包含memory_config
db_session: 数据库会话
system_prompt: 系统提示词
response_model: 响应模型类
user_message: 用户消息(可选)
fallback_value: 失败时的降级值
Returns:
结构化响应对象
"""
memory_config = state.get('memory_config')
if not memory_config:
raise ValueError("State中缺少memory_config")
llm_model_id = memory_config.llm_model_id
if not llm_model_id:
raise ValueError("Memory config中缺少llm_model_id")
llm_service = self.get_llm_service(db_session)
return await llm_service.structured_response(
llm_model_id=llm_model_id,
system_prompt=system_prompt,
response_model=response_model,
user_message=user_message,
fallback_value=fallback_value
)
async def call_llm_simple(
self,
state: Dict[str, Any],
db_session: Session,
system_prompt: str,
user_message: Optional[str] = None,
fallback_message: str = "信息不足,无法回答"
) -> str:
"""
便捷的简单LLM调用方法
Args:
state: 状态字典包含memory_config
db_session: 数据库会话
system_prompt: 系统提示词
user_message: 用户消息(可选)
fallback_message: 失败时的降级消息
Returns:
响应文本
"""
memory_config = state.get('memory_config')
if not memory_config:
raise ValueError("State中缺少memory_config")
llm_model_id = memory_config.llm_model_id
if not llm_model_id:
raise ValueError("Memory config中缺少llm_model_id")
llm_service = self.get_llm_service(db_session)
return await llm_service.simple_response(
llm_model_id=llm_model_id,
system_prompt=system_prompt,
user_message=user_message,
fallback_message=fallback_message
)

View File

@@ -0,0 +1,7 @@
"""Agent utilities."""
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
__all__ = [
"MultimodalProcessor",
]

View File

@@ -9,116 +9,62 @@ from app.core.memory.models.message_models import DialogData, ConversationContex
async def get_chunked_dialogs(
chunker_strategy: str = "RecursiveChunker",
end_user_id: str = "group_1",
messages: list = None,
group_id: str = "group_1",
user_id: str = "user1",
apply_id: str = "applyid",
content: str = "这是用户的输入",
ref_id: str = "wyl_20251027",
config_id: str = None
) -> List[DialogData]:
"""Generate chunks from structured messages using the specified chunker strategy.
"""Generate chunks from all test data entries using the specified chunker strategy.
Args:
chunker_strategy: The chunking strategy to use (default: RecursiveChunker)
end_user_id: Group identifier
messages: Structured message list [{"role": "user", "content": "..."}, ...]
group_id: Group identifier
user_id: User identifier
apply_id: Application identifier
content: Dialog content
ref_id: Reference identifier
config_id: Configuration ID for processing (used to load pruning config)
config_id: Configuration ID for processing
Returns:
List of DialogData objects with generated chunks
List of DialogData objects with generated chunks for each test entry
"""
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
dialog_data_list = []
messages = []
if not messages or not isinstance(messages, list) or len(messages) == 0:
raise ValueError("messages parameter must be a non-empty list")
messages.append(ConversationMessage(role="用户", msg=content))
conversation_messages = []
for idx, msg in enumerate(messages):
if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
raise ValueError(f"Message {idx} format error: must contain 'role' and 'content' fields")
role = msg['role']
content = msg['content']
if role not in ['user', 'assistant']:
raise ValueError(f"Message {idx} role must be 'user' or 'assistant', got: {role}")
if content.strip():
conversation_messages.append(ConversationMessage(role=role, msg=content.strip()))
if not conversation_messages:
raise ValueError("Message list cannot be empty after filtering")
conversation_context = ConversationContext(msgs=conversation_messages)
# Create DialogData
conversation_context = ConversationContext(msgs=messages)
# Create DialogData with group_id based on the entry's id for uniqueness
dialog_data = DialogData(
context=conversation_context,
ref_id=ref_id,
end_user_id=end_user_id,
group_id=group_id,
user_id=user_id,
apply_id=apply_id,
config_id=config_id
)
# 语义剪枝步骤(在分块之前)
try:
from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner
from app.core.memory.models.config_models import PruningConfig
from app.db import get_db_context
from app.services.memory_config_service import MemoryConfigService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
# 加载剪枝配置
pruning_config = None
if config_id:
try:
with get_db_context() as db:
# 使用 MemoryConfigService 加载完整的 MemoryConfig 对象
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
service_name="semantic_pruning"
)
if memory_config:
pruning_config = PruningConfig(
pruning_switch=memory_config.pruning_enabled,
pruning_scene=memory_config.pruning_scene or "education",
pruning_threshold=memory_config.pruning_threshold,
scene_id=str(memory_config.scene_id) if memory_config.scene_id else None,
ontology_classes=memory_config.ontology_classes,
)
logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}")
# 获取LLM客户端用于剪枝
if pruning_config.pruning_switch:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# 执行剪枝 - 使用 prune_dataset 支持消息级剪枝
pruner = SemanticPruner(config=pruning_config, llm_client=llm_client)
original_msg_count = len(dialog_data.context.msgs)
# 使用 prune_dataset 而不是 prune_dialog
# prune_dataset 会进行消息级剪枝,即使对话整体相关也会删除不重要消息
pruned_dialogs = await pruner.prune_dataset([dialog_data])
if pruned_dialogs:
dialog_data = pruned_dialogs[0]
remaining_msg_count = len(dialog_data.context.msgs)
deleted_count = original_msg_count - remaining_msg_count
logger.info(f"[剪枝] 完成: 原始{original_msg_count}条 -> 保留{remaining_msg_count}条 (删除{deleted_count}条)")
else:
logger.warning("[剪枝] prune_dataset 返回空列表")
else:
logger.info("[剪枝] 配置中剪枝开关关闭,跳过剪枝")
except Exception as e:
logger.warning(f"[剪枝] 加载配置失败,跳过剪枝: {e}", exc_info=True)
except Exception as e:
logger.warning(f"[剪枝] 执行失败,跳过剪枝: {e}", exc_info=True)
# Create DialogueChunker and process the dialogue
chunker = DialogueChunker(chunker_strategy)
extracted_chunks = await chunker.process_dialogue(dialog_data)
dialog_data.chunks = extracted_chunks
logger.info(f"DialogData created with {len(extracted_chunks)} chunks")
dialog_data_list.append(dialog_data)
return [dialog_data]
# Convert to dict with datetime serialized
def serialize_datetime(obj):
if isinstance(obj, datetime):
return obj.isoformat()
raise TypeError(f"Object of type {obj.__class__.__name__} is not JSON serializable")
combined_output = [dd.model_dump() for dd in dialog_data_list]
print(dialog_data_list)
# with open(os.path.join(os.path.dirname(__file__), "chunker_test_output.txt"), "w", encoding="utf-8") as f:
# json.dump(combined_output, f, ensure_ascii=False, indent=4, default=serialize_datetime)
return dialog_data_list

View File

@@ -1,84 +1,82 @@
import asyncio
import json
import logging
import os
from collections import defaultdict
from pathlib import Path
from typing import Annotated, TypedDict
from app.core.memory.agent.utils.messages_tool import read_template_file
from app.core.memory.utils.config.config_utils import (
get_picture_config,
get_voice_config,
)
# Removed global variable imports - use dependency injection instead
from dotenv import load_dotenv
from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages
from openai import OpenAI
PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
logger = logging.getLogger(__name__)
load_dotenv()
async def picture_model_requests(image_url):
'''
Args:
image_url:
Returns:
'''
file_path = PROJECT_ROOT_ + '/agent/utils/prompt/Template_for_image_recognition_prompt.jinja2 '
system_prompt = await read_template_file(file_path)
result = await Picture_recognize(image_url,system_prompt)
return (result)
class WriteState(TypedDict):
'''
Langgrapg Writing TypedDict
'''
messages: Annotated[list[AnyMessage], add_messages]
end_user_id: str
user_id:str
apply_id:str
group_id:str
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
memory_config: object
write_result: dict
data: str
language: str # 语言类型 ("zh" 中文, "en" 英文)
class ReadState(TypedDict):
"""
LangGraph 工作流状态定义
Attributes:
messages: 消息列表,支持自动追加
loop_count: 遍历次数
search_switch: 搜索类型开关
end_user_id: 组标识
config_id: 配置ID用于过滤结果
data: 从content_input_node传递的内容数据
spit_data: 从Split_The_Problem传递的分解结果
tool_calls: 工具调用请求列表
tool_results: 工具执行结果列表
memory_config: 内存配置对象
"""
messages: Annotated[list[AnyMessage], add_messages] # 消息追加模式
loop_count: int
'''
Langgrapg READING TypedDict
name:
id:user id
loop_count:Traverse times
search_switchtype
config_id: configuration id for filtering results
errors: list of errors that occurred during workflow execution
'''
messages: Annotated[list[AnyMessage], add_messages] #消息追加的模式增加消息
name: str
id: str
loop_count:int
search_switch: str
end_user_id: str
user_id: str
apply_id: str
group_id: str
config_id: str
data: str # 新增字段用于传递内容
spit_data: dict # 新增字段用于传递问题分解结果
problem_extension:dict
storage_type: str
user_rag_memory_id: str
llm_id: str
embedding_id: str
memory_config: object # 新增字段用于传递内存配置对象
retrieve:dict
RetrieveSummary: dict
InputSummary: dict
verify: dict
SummaryFails: dict
summary: dict
errors: list[dict] # Track errors: [{"tool": "tool_name", "error": "message"}]
class COUNTState:
"""
工作流对话检索内容计数器
用于记录工作流对话检索内容没有正确消息召回遍历的次数。
"""
'''
The number of times the workflow dialogue retrieval content has no correct message recall traversal
'''
def __init__(self, limit: int = 5):
"""
初始化计数器
Args:
limit: 最大计数限制默认为5
"""
self.total: int = 0 # 当前累加值
self.limit: int = limit # 最大上限
def add(self, value: int = 1) -> None:
"""
累加数字,如果达到上限就保持最大值
Args:
value: 要累加的值默认为1
"""
def add(self, value: int = 1):
"""累加数字,如果达到上限就保持最大值"""
self.total += value
print(f"[COUNTState] 当前值: {self.total}")
if self.total >= self.limit:
@@ -86,19 +84,21 @@ class COUNTState:
self.total = self.limit # 达到上限不再增加
def get_total(self) -> int:
"""
获取当前累加值
Returns:
当前累加值
"""
"""获取当前累加值"""
return self.total
def reset(self) -> None:
def reset(self):
"""手动重置累加值"""
self.total = 0
print("[COUNTState] 已重置为 0")
def merge_to_key_value_pairs(data, query_key, result_key):
grouped = defaultdict(list)
for item in data:
grouped[item[query_key]].append(item[result_key])
return [{key: values} for key, values in grouped.items()]
def deduplicate_entries(entries):
seen = set()
deduped = []
@@ -109,37 +109,70 @@ def deduplicate_entries(entries):
deduped.append(entry)
return deduped
def merge_to_key_value_pairs(data, query_key, result_key):
grouped = defaultdict(list)
for item in data:
grouped[item[query_key]].append(item[result_key])
return [{key: values} for key, values in grouped.items()]
def convert_extended_question_to_question(data):
async def Picture_recognize(image_path, PROMPT_TICKET_EXTRACTION, picture_model_name: str) -> str:
"""
递归地将数据中的 extended_question 字段转换为 question 字段
Updated to eliminate global variables in favor of explicit parameters.
Args:
data: 要转换的数据(可能是字典、列表或其他类型)
Returns:
转换后的数据
image_path: Path to image file
PROMPT_TICKET_EXTRACTION: Extraction prompt
picture_model_name: Picture model name (required, no longer from global variables)
"""
if isinstance(data, dict):
# 创建新字典来存储转换后的数据
converted = {}
for key, value in data.items():
if key == 'extended_question':
# 将 extended_question 转换为 question
converted['question'] = convert_extended_question_to_question(value)
else:
# 递归处理其他字段
converted[key] = convert_extended_question_to_question(value)
return converted
elif isinstance(data, list):
# 递归处理列表中的每个元素
return [convert_extended_question_to_question(item) for item in data]
else:
# 其他类型直接返回
return data
try:
model_config = get_picture_config(picture_model_name)
except Exception as e:
err = f"LLM配置不可用{str(e)}。请检查 config.json 和 runtime.json。"
logger.error(err)
return err
api_key = os.getenv(model_config["api_key"]) # 从环境变量读取对应后端的 API key
backend_model_name = model_config["llm_name"].split("/")[-1]
api_base=model_config['api_base']
logger.info(f"model_name: {backend_model_name}")
logger.info(f"api_key set: {'yes' if api_key else 'no'}")
logger.info(f"base_url: {model_config['api_base']}")
client = OpenAI(
api_key=api_key, base_url=api_base,
)
completion = client.chat.completions.create(
model=backend_model_name,
messages=[
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url":image_path,
},
{"type": "text",
"text": PROMPT_TICKET_EXTRACTION}
]
}
])
picture_text = completion.choices[0].message.content
picture_text = picture_text.replace('```json', '').replace('```', '')
picture_text = json.loads(picture_text)
return (picture_text['statement'])
async def Voice_recognize(voice_model_name: str):
"""
Updated to eliminate global variables in favor of explicit parameters.
Args:
voice_model_name: Voice model name (required, no longer from global variables)
"""
try:
model_config = get_voice_config(voice_model_name)
except Exception as e:
err = f"LLM配置不可用{str(e)}。请检查 config.json 和 runtime.json。"
logger.error(err)
return err
api_key = os.getenv(model_config["api_key"]) # 从环境变量读取对应后端的 API key
backend_model_name = model_config["llm_name"].split("/")[-1]
api_base = model_config['api_base']
return api_key,backend_model_name,api_base

Some files were not shown because too many files have changed in this diff Show More