Compare commits

..

1 Commits

Author SHA1 Message Date
yingzhao
543be4d610 Merge pull request #141 from SuanmoSuanyangTechnology/fix/web_zy
Fix/web zy
2026-01-17 11:43:58 +08:00
1085 changed files with 36052 additions and 129590 deletions

6
.gitignore vendored
View File

@@ -21,17 +21,13 @@ examples/
# Temporary outputs
.DS_Store
.hypothesis/
time.log
celerybeat-schedule.db
search_results.json
redbear-mem-metrics/
pitch-deck/
api/migrations/versions
tmp
files
powers/
# Exclude dep files
huggingface.co/
@@ -39,5 +35,3 @@ nltk_data/
tika-server*.jar*
cl100k_base.tiktoken
libssl*.deb
sandbox/lib/seccomp_redbear/target

View File

@@ -226,8 +226,8 @@ REDIS_PORT=6379
REDIS_DB=1
# Celery (Using Redis as broker)
REDIS_DB_CELERY_BROKER=1
REDIS_DB_CELERY_BACKEND=2
BROKER_URL=redis://127.0.0.1:6379/0
RESULT_BACKEND=redis://127.0.0.1:6379/0
# JWT Secret Key (Formation method: openssl rand -hex 32)
SECRET_KEY=your-secret-key-here
@@ -334,13 +334,7 @@ step6: Log In to the Frontend Interface.
## License
This project is licensed under the Apache License 2.0. For details, see the LICENSE file.
## Community & Support
Join our community to ask questions, share your work, and connect with fellow developers.
- **GitHub Issues**: Report bugs, request features, or track known issues via [GitHub Issues](https://github.com/SuanmoSuanyangTechnology/MemoryBear/issues).
- **GitHub Pull Requests**: Contribute code improvements or fixes through [Pull Requests](https://github.com/SuanmoSuanyangTechnology/MemoryBear/pulls).
- **GitHub Discussions**: Ask questions, share ideas, and engage with the community in [GitHub Discussions](https://github.com/SuanmoSuanyangTechnology/MemoryBear/discussions).
- **WeChat**: Scan the QR code below to join our WeChat community group.
- ![wecom-temp-114020-47fe87a75da439f09f5dc93a01593046](https://github.com/user-attachments/assets/8c81885c-4134-40d5-96e2-7f78cc082dc6)
- **Contact**: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com
## Acknowledgements & Community
- Feedback & Issues: Please submit an Issue in the repository for bug reports or discussions.
- Contributions Welcome: When submitting a Pull Request, please create a feature branch and follow conventional commit message guidelines.
- Contact: If you are interested in contributing or collaborating, feel free to reach out at tianyou_hubm@redbearai.com

View File

@@ -201,8 +201,8 @@ REDIS_PORT=6379
REDIS_DB=1
# Celery (使用Redis作为broker)
REDIS_DB_CELERY_BROKER=1
REDIS_DB_CELERY_BACKEND=2
BROKER_URL=redis://127.0.0.1:6379/0
RESULT_BACKEND=redis://127.0.0.1:6379/0
# JWT密钥 (生成方式: openssl rand -hex 32)
SECRET_KEY=your-secret-key-here

View File

@@ -45,8 +45,7 @@ RUN --mount=type=cache,id=mem_apt,target=/var/cache/apt,sharing=locked \
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
apt install -y libjemalloc-dev && \
apt install -y python3-pip pipx nginx unzip curl wget git vim less && \
apt install -y ghostscript && \
apt install -y libmagic1
apt install -y ghostscript
RUN if [ "$NEED_MIRROR" == "1" ]; then \
pip3 config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple && \

View File

@@ -60,12 +60,7 @@ version_path_separator = os # Use os.pathsep. Default configuration used for ne
# are written from script.py.mako
# output_encoding = utf-8
# Database connection URL - DO NOT hardcode credentials here!
# Connection string is set dynamically from environment variables in migrations/env.py
# Required env vars: DB_USER, DB_PASSWORD, DB_HOST, DB_PORT, DB_NAME
# Example: postgresql://user:password@localhost:5432/dbname
; sqlalchemy.url = postgresql://user:password@host:port/dbname
sqlalchemy.url = driver://user:password@host:port/dbname
sqlalchemy.url = postgresql://user:password@localhost/dbname
[post_write_hooks]

View File

@@ -1,16 +1,16 @@
import os
import asyncio
import json
import logging
from typing import Dict, Any, Optional
import redis.asyncio as redis
from redis.asyncio import ConnectionPool
from app.core.config import settings
# 设置日志记录器
logger = logging.getLogger(__name__)
# 创建连接池
pool = ConnectionPool.from_url(
f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}",
@@ -21,7 +21,6 @@ pool = ConnectionPool.from_url(
)
aio_redis = redis.StrictRedis(connection_pool=pool)
async def get_redis_connection():
"""获取Redis连接"""
try:
@@ -30,8 +29,7 @@ async def get_redis_connection():
logger.error(f"Redis连接失败: {str(e)}")
return None
async def aio_redis_set(key: str, val: str | dict, expire: int = None):
async def aio_redis_set(key: str, val: str|dict, expire: int = None):
"""设置Redis键值
Args:
@@ -42,7 +40,7 @@ async def aio_redis_set(key: str, val: str | dict, expire: int = None):
try:
if isinstance(val, dict):
val = json.dumps(val, ensure_ascii=False)
if expire is not None:
# 设置带过期时间的键值
await aio_redis.set(key, val, ex=expire)
@@ -52,7 +50,6 @@ async def aio_redis_set(key: str, val: str | dict, expire: int = None):
except Exception as e:
logger.error(f"Redis set错误: {str(e)}")
async def aio_redis_get(key: str):
"""获取Redis键值"""
try:
@@ -61,7 +58,6 @@ async def aio_redis_get(key: str):
logger.error(f"Redis get错误: {str(e)}")
return None
async def aio_redis_delete(key: str):
"""删除Redis键"""
try:
@@ -70,7 +66,6 @@ async def aio_redis_delete(key: str):
logger.error(f"Redis delete错误: {str(e)}")
return None
async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
"""发布消息到Redis频道"""
try:
@@ -83,10 +78,9 @@ async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool:
logger.error(f"Redis发布错误: {str(e)}")
return False
class RedisSubscriber:
"""Redis订阅器"""
def __init__(self, channel: str):
self.channel = channel
self.conn = None
@@ -94,25 +88,25 @@ class RedisSubscriber:
self.is_closed = False
self._queue = asyncio.Queue()
self._task = None
async def start(self):
"""开始订阅"""
if self.is_closed or self._task:
return
self._task = asyncio.create_task(self._receive_messages())
logger.info(f"开始订阅: {self.channel}")
async def _receive_messages(self):
"""接收消息"""
try:
self.conn = await get_redis_connection()
if not self.conn:
return
self.pubsub = self.conn.pubsub()
await self.pubsub.subscribe(self.channel)
while not self.is_closed:
try:
message = await self.pubsub.get_message(ignore_subscribe_messages=True, timeout=0.01)
@@ -133,7 +127,7 @@ class RedisSubscriber:
finally:
await self._queue.put(None)
await self._cleanup()
async def _cleanup(self):
"""清理资源"""
if self.pubsub:
@@ -147,7 +141,7 @@ class RedisSubscriber:
await self.conn.close()
except Exception:
pass
async def get_message(self) -> Optional[Dict[str, Any]]:
"""获取消息"""
if self.is_closed:
@@ -159,7 +153,7 @@ class RedisSubscriber:
except Exception as e:
logger.error(f"获取消息错误: {str(e)}")
return None
async def close(self):
"""关闭订阅器"""
if self.is_closed:
@@ -169,33 +163,32 @@ class RedisSubscriber:
self._task.cancel()
await self._cleanup()
class RedisPubSubManager:
"""Redis发布订阅管理器"""
def __init__(self):
self.subscribers = {}
async def publish(self, channel: str, message: Dict[str, Any]) -> bool:
return await aio_redis_publish(channel, message)
def get_subscriber(self, channel: str) -> RedisSubscriber:
if channel in self.subscribers:
subscriber = self.subscribers[channel]
if not subscriber.is_closed:
return subscriber
subscriber = RedisSubscriber(channel)
self.subscribers[channel] = subscriber
return subscriber
def cancel_subscription(self, channel: str) -> bool:
if channel in self.subscribers:
asyncio.create_task(self.subscribers[channel].close())
del self.subscribers[channel]
return True
return False
def cancel_all_subscriptions(self) -> int:
count = len(self.subscribers)
for subscriber in self.subscribers.values():
@@ -203,6 +196,6 @@ class RedisPubSubManager:
self.subscribers.clear()
return count
# 全局实例
pubsub_manager = RedisPubSubManager()

View File

@@ -1,10 +0,0 @@
"""
Cache 缓存模块
提供各种缓存功能的统一入口
"""
from .memory import InterestMemoryCache
__all__ = [
"InterestMemoryCache",
]

View File

@@ -1,12 +0,0 @@
"""
Memory 缓存模块
提供记忆系统相关的缓存功能
"""
from .interest_memory import InterestMemoryCache
from .activity_stats_cache import ActivityStatsCache
__all__ = [
"InterestMemoryCache",
"ActivityStatsCache",
]

View File

@@ -1,124 +0,0 @@
"""
Recent Activity Stats Cache
记忆提取活动统计缓存模块
用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储24小时后释放
查询命令cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843
"""
import json
import logging
from typing import Optional, Dict, Any
from datetime import datetime
from app.aioRedis import aio_redis
logger = logging.getLogger(__name__)
# 缓存过期时间24小时
ACTIVITY_STATS_CACHE_EXPIRE = 86400
class ActivityStatsCache:
"""记忆提取活动统计缓存类"""
PREFIX = "cache:memory:activity_stats"
@classmethod
def _get_key(cls, workspace_id: str) -> str:
"""生成 Redis key
Args:
workspace_id: 工作空间ID
Returns:
完整的 Redis key
"""
return f"{cls.PREFIX}:by_workspace:{workspace_id}"
@classmethod
async def set_activity_stats(
cls,
workspace_id: str,
stats: Dict[str, Any],
expire: int = ACTIVITY_STATS_CACHE_EXPIRE,
) -> bool:
"""设置记忆提取活动统计缓存
Args:
workspace_id: 工作空间ID
stats: 统计数据,格式:
{
"chunk_count": int,
"statements_count": int,
"triplet_entities_count": int,
"triplet_relations_count": int,
"temporal_count": int,
}
expire: 过期时间默认24小时
Returns:
是否设置成功
"""
try:
key = cls._get_key(workspace_id)
payload = {
"stats": stats,
"generated_at": datetime.now().isoformat(),
"workspace_id": workspace_id,
"cached": True,
}
value = json.dumps(payload, ensure_ascii=False)
await aio_redis.set(key, value, ex=expire)
logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}")
return True
except Exception as e:
logger.error(f"设置活动统计缓存失败: {e}", exc_info=True)
return False
@classmethod
async def get_activity_stats(
cls,
workspace_id: str,
) -> Optional[Dict[str, Any]]:
"""获取记忆提取活动统计缓存
Args:
workspace_id: 工作空间ID
Returns:
统计数据字典,缓存不存在或已过期返回 None
"""
try:
key = cls._get_key(workspace_id)
value = await aio_redis.get(key)
if value:
payload = json.loads(value)
logger.info(f"命中活动统计缓存: {key}")
return payload
logger.info(f"活动统计缓存不存在或已过期: {key}")
return None
except Exception as e:
logger.error(f"获取活动统计缓存失败: {e}", exc_info=True)
return None
@classmethod
async def delete_activity_stats(
cls,
workspace_id: str,
) -> bool:
"""删除记忆提取活动统计缓存
Args:
workspace_id: 工作空间ID
Returns:
是否删除成功
"""
try:
key = cls._get_key(workspace_id)
result = await aio_redis.delete(key)
logger.info(f"删除活动统计缓存: {key}, 结果: {result}")
return result > 0
except Exception as e:
logger.error(f"删除活动统计缓存失败: {e}", exc_info=True)
return False

View File

@@ -1,122 +0,0 @@
"""
Interest Distribution Cache
兴趣分布缓存模块
用于缓存用户的兴趣分布标签数据,避免重复调用模型生成
"""
import json
import logging
from typing import Optional, List, Dict, Any
from datetime import datetime
from app.aioRedis import aio_redis
logger = logging.getLogger(__name__)
# 缓存过期时间24小时
INTEREST_CACHE_EXPIRE = 86400
class InterestMemoryCache:
"""兴趣分布缓存类"""
PREFIX = "cache:memory:interest_distribution"
@classmethod
def _get_key(cls, end_user_id: str, language: str) -> str:
"""生成 Redis key
Args:
end_user_id: 用户ID
language: 语言类型
Returns:
完整的 Redis key
"""
return f"{cls.PREFIX}:by_user:{end_user_id}:{language}"
@classmethod
async def set_interest_distribution(
cls,
end_user_id: str,
language: str,
data: List[Dict[str, Any]],
expire: int = INTEREST_CACHE_EXPIRE,
) -> bool:
"""设置用户兴趣分布缓存
Args:
end_user_id: 用户ID
language: 语言类型
data: 兴趣分布列表,格式 [{"name": "...", "frequency": ...}, ...]
expire: 过期时间默认24小时
Returns:
是否设置成功
"""
try:
key = cls._get_key(end_user_id, language)
payload = {
"data": data,
"generated_at": datetime.now().isoformat(),
"cached": True,
}
value = json.dumps(payload, ensure_ascii=False)
await aio_redis.set(key, value, ex=expire)
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}")
return True
except Exception as e:
logger.error(f"设置兴趣分布缓存失败: {e}", exc_info=True)
return False
@classmethod
async def get_interest_distribution(
cls,
end_user_id: str,
language: str,
) -> Optional[List[Dict[str, Any]]]:
"""获取用户兴趣分布缓存
Args:
end_user_id: 用户ID
language: 语言类型
Returns:
兴趣分布列表,缓存不存在或已过期返回 None
"""
try:
key = cls._get_key(end_user_id, language)
value = await aio_redis.get(key)
if value:
payload = json.loads(value)
logger.info(f"命中兴趣分布缓存: {key}")
return payload.get("data")
logger.info(f"兴趣分布缓存不存在或已过期: {key}")
return None
except Exception as e:
logger.error(f"获取兴趣分布缓存失败: {e}", exc_info=True)
return None
@classmethod
async def delete_interest_distribution(
cls,
end_user_id: str,
language: str,
) -> bool:
"""删除用户兴趣分布缓存
Args:
end_user_id: 用户ID
language: 语言类型
Returns:
是否删除成功
"""
try:
key = cls._get_key(end_user_id, language)
result = await aio_redis.delete(key)
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
return result > 0
except Exception as e:
logger.error(f"删除兴趣分布缓存失败: {e}", exc_info=True)
return False

View File

@@ -1,60 +1,40 @@
import os
import platform
from datetime import timedelta
from urllib.parse import quote
from celery import Celery
from celery.schedules import crontab
from app.core.config import settings
from app.core.logging_config import get_logger
logger = get_logger(__name__)
# macOS fork() safety - must be set before any Celery initialization
if platform.system() == 'Darwin':
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
from celery import Celery
# 创建 Celery 应用实例
# broker: 任务队列(使用 Redis DB,由 CELERY_BROKER_DB 指定
# backend: 结果存储(使用 Redis DB,由 CELERY_BACKEND_DB 指定
# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND
# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md
# Build canonical broker/backend URLs and force them into os.environ so that
# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first)
# cannot be overridden by stray env vars.
# See: https://github.com/celery/celery/issues/4284
_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}"
_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}"
os.environ["CELERY_BROKER_URL"] = _broker_url
os.environ["CELERY_RESULT_BACKEND"] = _backend_url
# Neutralize legacy Celery env vars that can be hijacked by Celery's CLI/Click
# integration and accidentally override our canonical URLs.
os.environ.pop("BROKER_URL", None)
os.environ.pop("RESULT_BACKEND", None)
os.environ.pop("CELERY_BROKER", None)
os.environ.pop("CELERY_BACKEND", None)
# broker: 任务队列(使用 Redis DB 0
# backend: 结果存储(使用 Redis DB 10
celery_app = Celery(
"redbear_tasks",
broker=_broker_url,
backend=_backend_url,
broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}",
backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}",
)
logger.info(
"Celery app initialized",
extra={
"broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"),
"backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"),
},
)
# Default queue for unrouted tasks
celery_app.conf.task_default_queue = 'memory_tasks'
# 配置使用本地队列,避免与远程 worker 冲突
celery_app.conf.task_default_queue = 'localhost_test_wyl'
celery_app.conf.task_default_exchange = 'localhost_test_wyl'
celery_app.conf.task_default_routing_key = 'localhost_test_wyl'
# macOS 兼容性配置
if platform.system() == 'Darwin':
import platform
if platform.system() == 'Darwin': # macOS
# 设置环境变量解决 fork 问题
os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES')
# 使用 solo 池避免多进程问题
celery_app.conf.worker_pool = 'solo'
# 设置唯一的节点名称
import socket
import time
hostname = socket.gethostname()
timestamp = int(time.time())
celery_app.conf.worker_name = f"celery@{hostname}-{timestamp}"
# Celery 配置
celery_app.conf.update(
@@ -63,78 +43,58 @@ celery_app.conf.update(
accept_content=['json'],
result_serializer='json',
# # 时区
# timezone='Asia/Shanghai',
# enable_utc=False,
# 时区
timezone='Asia/Shanghai',
enable_utc=True,
# 任务追踪
task_track_started=True,
task_ignore_result=False,
# 超时设置
task_time_limit=3600, # 60分钟硬超时
task_soft_time_limit=3000, # 50分钟软超时
# Worker 设置 (per-worker settings are in docker-compose command line)
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
task_time_limit=30 * 60, # 30 分钟硬超时
task_soft_time_limit=25 * 60, # 25 分钟软超时
# Worker 设置 - 针对 macOS 优化
worker_prefetch_multiplier=1, # 减少预取任务数,避免内存堆积
worker_max_tasks_per_child=10, # 大幅减少每个 worker 执行的任务数,频繁重启防止内存泄漏
worker_max_memory_per_child=200000, # 200MB 内存限制,超过后重启 worker
# 结果过期时间
result_expires=3600, # 结果保存1小时
result_expires=3600, # 结果保存 1 小时
# 任务确认设置
task_acks_late=True,
task_reject_on_worker_lost=True,
worker_disable_rate_limits=True,
# FLower setting
worker_send_task_events=True,
task_send_sent_event=True,
# task routing
task_routes={
# Memory tasks → memory_tasks queue (threads worker)
'app.core.memory.agent.read_message_priority': {'queue': 'memory_tasks'},
'app.core.memory.agent.read_message': {'queue': 'memory_tasks'},
'app.core.memory.agent.write_message': {'queue': 'memory_tasks'},
'app.tasks.write_perceptual_memory': {'queue': 'memory_tasks'},
# Long-term storage tasks → memory_tasks queue (batched write strategies)
'app.core.memory.agent.long_term_storage.window': {'queue': 'memory_tasks'},
'app.core.memory.agent.long_term_storage.time': {'queue': 'memory_tasks'},
'app.core.memory.agent.long_term_storage.aggregate': {'queue': 'memory_tasks'},
# Document tasks → document_tasks queue (prefork worker)
'app.core.rag.tasks.parse_document': {'queue': 'document_tasks'},
'app.core.rag.tasks.build_graphrag_for_kb': {'queue': 'document_tasks'},
'app.core.rag.tasks.sync_knowledge_for_kb': {'queue': 'document_tasks'},
# Beat/periodic tasks → periodic_tasks queue (dedicated periodic worker)
'app.tasks.workspace_reflection_task': {'queue': 'periodic_tasks'},
'app.tasks.regenerate_memory_cache': {'queue': 'periodic_tasks'},
'app.tasks.run_forgetting_cycle_task': {'queue': 'periodic_tasks'},
'app.tasks.write_all_workspaces_memory_task': {'queue': 'periodic_tasks'},
'app.tasks.update_implicit_emotions_storage': {'queue': 'periodic_tasks'},
'app.tasks.init_implicit_emotions_for_users': {'queue': 'periodic_tasks'},
'app.tasks.init_interest_distribution_for_users': {'queue': 'periodic_tasks'},
'app.tasks.init_community_clustering_for_users': {'queue': 'periodic_tasks'},
},
task_acks_late=True, # 任务完成后才确认,避免任务丢失
worker_disable_rate_limits=True, # 禁用速率限制
# 任务路由(可选,用于不同队列)
# task_routes={
# 'app.core.rag.tasks.parse_document': {'queue': 'document_processing'},
# 'app.core.memory.agent.read_message': {'queue': 'memory_processing'},
# 'app.core.memory.agent.write_message': {'queue': 'memory_processing'},
# 'tasks.process_item': {'queue': 'default'},
# },
)
# 自动发现任务模块
celery_app.autodiscover_tasks(['app'])
# Celery Beat schedule for periodic tasks
memory_increment_schedule = crontab(hour=settings.MEMORY_INCREMENT_HOUR, minute=settings.MEMORY_INCREMENT_MINUTE)
reflection_schedule = timedelta(seconds=settings.REFLECTION_INTERVAL_SECONDS)
health_schedule = timedelta(seconds=settings.HEALTH_CHECK_SECONDS)
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
workspace_reflection_schedule = timedelta(seconds=settings.WORKSPACE_REFLECTION_INTERVAL_SECONDS)
forgetting_cycle_schedule = timedelta(hours=settings.FORGETTING_CYCLE_INTERVAL_HOURS)
implicit_emotions_update_schedule = crontab(
hour=settings.IMPLICIT_EMOTIONS_UPDATE_HOUR,
minute=settings.IMPLICIT_EMOTIONS_UPDATE_MINUTE,
)
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
# 构建定时任务配置
beat_schedule_config = {
# "check-read-service": {
# "task": "app.core.memory.agent.health.check_read_service",
# "schedule": health_schedule,
# "args": (),
# },
"run-workspace-reflection": {
"task": "app.tasks.workspace_reflection_task",
"schedule": workspace_reflection_schedule,
@@ -152,16 +112,16 @@ beat_schedule_config = {
"config_id": None, # 使用默认配置,可以通过环境变量配置
},
},
"write-all-workspaces-memory": {
"task": "app.tasks.write_all_workspaces_memory_task",
"schedule": memory_increment_schedule,
"args": (),
},
"update-implicit-emotions-storage": {
"task": "app.tasks.update_implicit_emotions_storage",
"schedule": implicit_emotions_update_schedule,
"args": (),
},
}
# 如果配置了默认工作空间ID则添加记忆总量统计任务
if settings.DEFAULT_WORKSPACE_ID:
beat_schedule_config["write-total-memory"] = {
"task": "app.controllers.memory_storage_controller.search_all",
"schedule": memory_increment_schedule,
"kwargs": {
"workspace_id": settings.DEFAULT_WORKSPACE_ID,
},
}
celery_app.conf.beat_schedule = beat_schedule_config

View File

@@ -3,14 +3,8 @@ Celery Worker 入口点
用于启动 Celery Worker: celery -A app.celery_worker worker --loglevel=info
"""
from app.celery_app import celery_app
from app.core.logging_config import LoggingConfig, get_logger
# Initialize logging system for Celery worker
LoggingConfig.setup_logging()
logger = get_logger(__name__)
logger.info("Celery worker logging initialized")
# 导入任务模块以注册任务
import app.tasks
__all__ = ['celery_app']
__all__ = ['celery_app']

View File

@@ -1 +0,0 @@
"""Configuration module for application settings."""

View File

@@ -1,239 +0,0 @@
"""默认本体场景配置
本模块定义系统预设的本体场景和实体类型配置。
这些配置用于在工作空间创建时自动初始化默认场景。
支持中英文双语配置,根据用户语言偏好创建对应语言的场景。
"""
# 在线教育场景配置
ONLINE_EDUCATION_SCENE = {
"name_chinese": "在线教育",
"name_english": "Online Education",
"description_chinese": "适用于在线教育平台的本体建模,包含学生、教师、课程等核心实体类型",
"description_english": "Ontology modeling for online education platforms, including core entity types such as students, teachers, and courses",
"types": [
{
"name_chinese": "学生",
"name_english": "Student",
"description_chinese": "在教育系统中接受教育的个体,包含姓名、学号、年级、班级等属性",
"description_english": "Individuals receiving education in the education system, including attributes such as name, student ID, grade, and class"
},
{
"name_chinese": "教师",
"name_english": "Teacher",
"description_chinese": "在教育系统中提供教学服务的个体,包含姓名、工号、任教学科、职称等属性",
"description_english": "Individuals providing teaching services in the education system, including attributes such as name, employee ID, teaching subject, and title"
},
{
"name_chinese": "课程",
"name_english": "Course",
"description_chinese": "教育系统中的教学内容单元,包含课程名称、课程代码、学分、学时等属性",
"description_english": "Teaching content units in the education system, including attributes such as course name, course code, credits, and class hours"
},
{
"name_chinese": "作业",
"name_english": "Assignment",
"description_chinese": "课程中布置的学习任务,包含作业标题、截止日期、所属课程、提交状态等属性",
"description_english": "Learning tasks assigned in courses, including attributes such as assignment title, deadline, course, and submission status"
},
{
"name_chinese": "成绩",
"name_english": "Grade",
"description_chinese": "学生学习成果的评价结果,包含分数、评级、考试类型、所属课程等属性",
"description_english": "Evaluation results of student learning outcomes, including attributes such as score, rating, exam type, and course"
},
{
"name_chinese": "考试",
"name_english": "Exam",
"description_chinese": "评估学生学习成果的测试活动,包含考试名称、时间、地点、科目等属性",
"description_english": "Test activities to assess student learning outcomes, including attributes such as exam name, time, location, and subject"
},
{
"name_chinese": "教室",
"name_english": "Classroom",
"description_chinese": "进行教学活动的物理或虚拟空间,包含教室编号、容量、设备等属性",
"description_english": "Physical or virtual spaces for teaching activities, including attributes such as classroom number, capacity, and equipment"
},
{
"name_chinese": "学科",
"name_english": "Subject",
"description_chinese": "知识的分类领域,包含学科名称、代码、所属院系等属性",
"description_english": "Classification domains of knowledge, including attributes such as subject name, code, and department"
},
{
"name_chinese": "教材",
"name_english": "Textbook",
"description_chinese": "教学使用的书籍或资料包含书名、作者、出版社、ISBN等属性",
"description_english": "Books or materials used for teaching, including attributes such as title, author, publisher, and ISBN"
},
{
"name_chinese": "班级",
"name_english": "Class",
"description_chinese": "学生的组织单位,包含班级名称、年级、人数、班主任等属性",
"description_english": "Organizational units of students, including attributes such as class name, grade, number of students, and class teacher"
},
{
"name_chinese": "学期",
"name_english": "Semester",
"description_chinese": "教学时间的划分单位,包含学期名称、开始时间、结束时间等属性",
"description_english": "Time division units for teaching, including attributes such as semester name, start time, and end time"
},
{
"name_chinese": "课时",
"name_english": "Class Hour",
"description_chinese": "课程的时间单位,包含上课时间、地点、教师、课程等属性",
"description_english": "Time units of courses, including attributes such as class time, location, teacher, and course"
},
{
"name_chinese": "教学计划",
"name_english": "Teaching Plan",
"description_chinese": "课程的教学安排,包含教学目标、内容安排、进度计划等属性",
"description_english": "Teaching arrangements for courses, including attributes such as teaching objectives, content arrangement, and progress plan"
}
]
}
# 情感陪伴场景配置
EMOTIONAL_COMPANION_SCENE = {
"name_chinese": "情感陪伴",
"name_english": "Emotional Companion",
"description_chinese": "适用于情感陪伴应用的本体建模,包含用户、情绪、活动等核心实体类型",
"description_english": "Ontology modeling for emotional companion applications, including core entity types such as users, emotions, and activities",
"types": [
{
"name_chinese": "用户",
"name_english": "User",
"description_chinese": "使用情感陪伴服务的个体,包含姓名、昵称、性格特征、偏好等属性",
"description_english": "Individuals using emotional companion services, including attributes such as name, nickname, personality traits, and preferences"
},
{
"name_chinese": "情绪",
"name_english": "Emotion",
"description_chinese": "用户的情感状态,包含情绪类型、强度、触发原因、持续时间等属性",
"description_english": "Emotional states of users, including attributes such as emotion type, intensity, trigger cause, and duration"
},
{
"name_chinese": "活动",
"name_english": "Activity",
"description_chinese": "用户参与的各类活动,包含活动名称、类型、参与者、时间地点等属性",
"description_english": "Various activities users participate in, including attributes such as activity name, type, participants, time, and location"
},
{
"name_chinese": "对话",
"name_english": "Conversation",
"description_chinese": "用户之间的交流记录,包含对话主题、参与者、时间、关键内容等属性",
"description_english": "Communication records between users, including attributes such as conversation topic, participants, time, and key content"
},
{
"name_chinese": "兴趣爱好",
"name_english": "Hobby",
"description_chinese": "用户的兴趣和爱好,包含爱好名称、类别、熟练程度、相关活动等属性",
"description_english": "User interests and hobbies, including attributes such as hobby name, category, proficiency level, and related activities"
},
{
"name_chinese": "日常事件",
"name_english": "Daily Event",
"description_chinese": "用户日常生活中的事件,包含事件描述、时间、地点、相关人物等属性",
"description_english": "Events in users' daily lives, including attributes such as event description, time, location, and related people"
},
{
"name_chinese": "关系",
"name_english": "Relationship",
"description_chinese": "用户之间的社会关系,包含关系类型、亲密度、建立时间等属性",
"description_english": "Social relationships between users, including attributes such as relationship type, intimacy, and establishment time"
},
{
"name_chinese": "回忆",
"name_english": "Memory",
"description_chinese": "用户的重要记忆片段,包含回忆内容、时间、地点、相关人物等属性",
"description_english": "Important memory fragments of users, including attributes such as memory content, time, location, and related people"
},
{
"name_chinese": "地点",
"name_english": "Location",
"description_chinese": "用户活动的地理位置,包含地点名称、地址、类型、相关事件等属性",
"description_english": "Geographic locations of user activities, including attributes such as location name, address, type, and related events"
},
{
"name_chinese": "时间节点",
"name_english": "Time Point",
"description_chinese": "重要的时间标记,包含日期、事件、意义等属性",
"description_english": "Important time markers, including attributes such as date, event, and significance"
},
{
"name_chinese": "目标",
"name_english": "Goal",
"description_chinese": "用户设定的目标,包含目标描述、截止时间、完成状态、相关活动等属性",
"description_english": "Goals set by users, including attributes such as goal description, deadline, completion status, and related activities"
},
{
"name_chinese": "成就",
"name_english": "Achievement",
"description_chinese": "用户获得的成就,包含成就名称、获得时间、描述、相关目标等属性",
"description_english": "Achievements obtained by users, including attributes such as achievement name, acquisition time, description, and related goals"
}
]
}
# 导出默认场景列表
DEFAULT_SCENES = [ONLINE_EDUCATION_SCENE, EMOTIONAL_COMPANION_SCENE]
def get_scene_name(scene_config: dict, language: str = "zh") -> str:
"""获取场景名称(根据语言)
Args:
scene_config: 场景配置字典
language: 语言类型 ("zh""en")
Returns:
对应语言的场景名称
"""
if language == "en":
return scene_config.get("name_english", scene_config.get("name_chinese"))
return scene_config.get("name_chinese")
def get_scene_description(scene_config: dict, language: str = "zh") -> str:
"""获取场景描述(根据语言)
Args:
scene_config: 场景配置字典
language: 语言类型 ("zh""en")
Returns:
对应语言的场景描述
"""
if language == "en":
return scene_config.get("description_english", scene_config.get("description_chinese"))
return scene_config.get("description_chinese")
def get_type_name(type_config: dict, language: str = "zh") -> str:
"""获取类型名称(根据语言)
Args:
type_config: 类型配置字典
language: 语言类型 ("zh""en")
Returns:
对应语言的类型名称
"""
if language == "en":
return type_config.get("name_english", type_config.get("name_chinese"))
return type_config.get("name_chinese")
def get_type_description(type_config: dict, language: str = "zh") -> str:
"""获取类型描述(根据语言)
Args:
type_config: 类型配置字典
language: 语言类型 ("zh""en")
Returns:
对应语言的类型描述
"""
if language == "en":
return type_config.get("description_english", type_config.get("description_chinese"))
return type_config.get("description_chinese")

View File

@@ -1,249 +0,0 @@
# -*- coding: utf-8 -*-
"""默认本体场景初始化器
本模块提供默认本体场景和类型的自动初始化功能。
在工作空间创建时,自动添加预设的本体场景和实体类型。
Classes:
DefaultOntologyInitializer: 默认本体场景初始化器
"""
import logging
from typing import List, Optional, Tuple
from uuid import UUID
from sqlalchemy.orm import Session
from app.config.default_ontology_config import (
DEFAULT_SCENES,
get_scene_name,
get_scene_description,
get_type_name,
get_type_description,
)
from app.core.logging_config import get_business_logger
from app.repositories.ontology_scene_repository import OntologySceneRepository
from app.repositories.ontology_class_repository import OntologyClassRepository
class DefaultOntologyInitializer:
"""默认本体场景初始化器
负责在工作空间创建时自动初始化默认的本体场景和类型。
遵循最小侵入原则,确保初始化失败不阻止工作空间创建。
Attributes:
db: 数据库会话
scene_repo: 场景Repository
class_repo: 类型Repository
logger: 业务日志记录器
"""
def __init__(self, db: Session):
"""初始化
Args:
db: 数据库会话
"""
self.db = db
self.scene_repo = OntologySceneRepository(db)
self.class_repo = OntologyClassRepository(db)
self.logger = get_business_logger()
def initialize_default_scenes(
self,
workspace_id: UUID,
language: str = "zh"
) -> Tuple[bool, str]:
"""为工作空间初始化默认场景
创建两个默认场景(在线教育、情感陪伴)及其对应的实体类型。
如果创建失败,记录错误日志但不抛出异常。
Args:
workspace_id: 工作空间ID
language: 语言类型 ("zh""en"),默认为 "zh"
Returns:
Tuple[bool, str]: (是否成功, 错误信息)
"""
try:
self.logger.info(
f"开始初始化默认本体场景 - workspace_id={workspace_id}, language={language}"
)
scenes_created = 0
total_types_created = 0
# 遍历默认场景配置
for scene_config in DEFAULT_SCENES:
scene_name = get_scene_name(scene_config, language)
# 创建场景及其类型
scene_id = self._create_scene_with_types(workspace_id, scene_config, language)
if scene_id:
scenes_created += 1
# 统计类型数量
types_count = len(scene_config.get("types", []))
total_types_created += types_count
self.logger.info(
f"场景创建成功 - scene_name={scene_name}, "
f"scene_id={scene_id}, types_count={types_count}, language={language}"
)
else:
self.logger.warning(
f"场景创建失败 - scene_name={scene_name}, "
f"workspace_id={workspace_id}, language={language}"
)
# 记录总体结果
self.logger.info(
f"默认场景初始化完成 - workspace_id={workspace_id}, "
f"language={language}, scenes_created={scenes_created}, "
f"total_types_created={total_types_created}"
)
# 如果至少创建了一个场景,视为成功
if scenes_created > 0:
return True, ""
else:
error_msg = "所有默认场景创建失败"
self.logger.error(
f"默认场景初始化失败 - workspace_id={workspace_id}, "
f"language={language}, error={error_msg}"
)
return False, error_msg
except Exception as e:
error_msg = f"默认场景初始化异常: {str(e)}"
self.logger.error(
f"默认场景初始化异常 - workspace_id={workspace_id}, "
f"language={language}, error={str(e)}",
exc_info=True
)
return False, error_msg
def _create_scene_with_types(
self,
workspace_id: UUID,
scene_config: dict,
language: str = "zh"
) -> Optional[UUID]:
"""创建场景及其类型
Args:
workspace_id: 工作空间ID
scene_config: 场景配置字典
language: 语言类型 ("zh""en")
Returns:
Optional[UUID]: 创建的场景ID失败返回None
"""
try:
scene_name = get_scene_name(scene_config, language)
scene_description = get_scene_description(scene_config, language)
# 检查是否已存在同名场景(支持向后兼容)
existing_scene = self.scene_repo.get_by_name(scene_name, workspace_id)
if existing_scene:
self.logger.info(
f"场景已存在,跳过创建 - scene_name={scene_name}, "
f"workspace_id={workspace_id}, scene_id={existing_scene.scene_id}, "
f"language={language}"
)
return None
# 创建场景记录,设置 is_system_default=true
scene_data = {
"scene_name": scene_name,
"scene_description": scene_description
}
scene = self.scene_repo.create(scene_data, workspace_id)
# 设置系统默认标识
scene.is_system_default = True
self.db.flush()
self.logger.info(
f"场景创建成功 - scene_name={scene_name}, "
f"scene_id={scene.scene_id}, is_system_default=True, language={language}"
)
# 批量创建类型
types_config = scene_config.get("types", [])
types_created = self._batch_create_types(scene.scene_id, types_config, language)
self.logger.info(
f"场景类型创建完成 - scene_id={scene.scene_id}, "
f"types_created={types_created}/{len(types_config)}, language={language}"
)
return scene.scene_id
except Exception as e:
scene_name = get_scene_name(scene_config, language)
self.logger.error(
f"场景创建失败 - scene_name={scene_name}, "
f"workspace_id={workspace_id}, language={language}, error={str(e)}",
exc_info=True
)
return None
def _batch_create_types(
self,
scene_id: UUID,
types_config: List[dict],
language: str = "zh"
) -> int:
"""批量创建实体类型
Args:
scene_id: 场景ID
types_config: 类型配置列表
language: 语言类型 ("zh""en")
Returns:
int: 成功创建的类型数量
"""
created_count = 0
for type_config in types_config:
try:
type_name = get_type_name(type_config, language)
type_description = get_type_description(type_config, language)
# 创建类型数据
class_data = {
"class_name": type_name,
"class_description": type_description
}
# 创建类型
ontology_class = self.class_repo.create(class_data, scene_id)
# 设置系统默认标识
ontology_class.is_system_default = True
self.db.flush()
created_count += 1
self.logger.debug(
f"类型创建成功 - class_name={type_name}, "
f"class_id={ontology_class.class_id}, "
f"scene_id={scene_id}, is_system_default=True, language={language}"
)
except Exception as e:
type_name = get_type_name(type_config, language)
self.logger.warning(
f"单个类型创建失败,继续创建其他类型 - "
f"class_name={type_name}, scene_id={scene_id}, "
f"language={language}, error={str(e)}"
)
# 继续创建其他类型
continue
return created_count

View File

@@ -13,26 +13,19 @@ from . import (
document_controller,
emotion_config_controller,
emotion_controller,
end_user_controller,
file_controller,
file_storage_controller,
home_page_controller,
i18n_controller,
implicit_memory_controller,
knowledge_controller,
knowledgeshare_controller,
mcp_market_controller,
mcp_market_config_controller,
memory_agent_controller,
memory_dashboard_controller,
memory_episodic_controller,
memory_explicit_controller,
memory_forget_controller,
memory_perceptual_controller,
memory_reflection_controller,
memory_short_term_controller,
memory_storage_controller,
memory_working_controller,
model_controller,
multi_agent_controller,
prompt_optimizer_controller,
@@ -45,9 +38,12 @@ from . import (
upload_controller,
user_controller,
user_memory_controllers,
workflow_controller,
workspace_controller,
ontology_controller,
skill_controller
memory_forget_controller,
home_page_controller,
memory_perceptual_controller,
memory_working_controller,
)
# 创建管理端 API 路由器
@@ -64,8 +60,6 @@ manager_router.include_router(model_controller.router)
manager_router.include_router(file_controller.router)
manager_router.include_router(document_controller.router)
manager_router.include_router(knowledge_controller.router)
manager_router.include_router(mcp_market_controller.router)
manager_router.include_router(mcp_market_config_controller.router)
manager_router.include_router(chunk_controller.router)
manager_router.include_router(test_controller.router)
manager_router.include_router(knowledgeshare_controller.router)
@@ -82,6 +76,7 @@ manager_router.include_router(release_share_controller.router)
manager_router.include_router(public_share_controller.router) # 公开路由(无需认证)
manager_router.include_router(memory_dashboard_controller.router)
manager_router.include_router(multi_agent_controller.router)
manager_router.include_router(workflow_controller.router)
manager_router.include_router(emotion_controller.router)
manager_router.include_router(emotion_config_controller.router)
manager_router.include_router(prompt_optimizer_controller.router)
@@ -93,10 +88,5 @@ manager_router.include_router(home_page_controller.router)
manager_router.include_router(implicit_memory_controller.router)
manager_router.include_router(memory_perceptual_controller.router)
manager_router.include_router(memory_working_controller.router)
manager_router.include_router(file_storage_controller.router)
manager_router.include_router(ontology_controller.router)
manager_router.include_router(skill_controller.router)
manager_router.include_router(i18n_controller.router)
manager_router.include_router(end_user_controller.router)
__all__ = ["manager_router"]

View File

@@ -1,16 +1,13 @@
import uuid
import io
from typing import Optional, Annotated
import yaml
from fastapi import APIRouter, Depends, Path, Form, UploadFile, File
from fastapi import APIRouter, Depends, Path
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from urllib.parse import quote
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
from app.core.response_utils import success, fail
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models import User
@@ -20,14 +17,11 @@ from app.repositories.end_user_repository import EndUserRepository
from app.schemas import app_schema
from app.schemas.response_schema import PageData, PageMeta
from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema
from app.schemas.workflow_schema import WorkflowConfigUpdate, WorkflowImportSave
from app.schemas.workflow_schema import WorkflowConfigUpdate
from app.services import app_service, workspace_service
from app.services.agent_config_helper import enrich_agent_config
from app.services.app_service import AppService
from app.services.app_statistics_service import AppStatisticsService
from app.services.workflow_import_service import WorkflowImportService
from app.services.workflow_service import WorkflowService, get_workflow_service
from app.services.app_dsl_service import AppDslService
router = APIRouter(prefix="/apps", tags=["Apps"])
logger = get_business_logger()
@@ -53,7 +47,6 @@ def list_apps(
status: str | None = None,
search: str | None = None,
include_shared: bool = True,
shared_only: bool = False,
page: int = 1,
pagesize: int = 10,
ids: Optional[str] = None,
@@ -71,7 +64,7 @@ def list_apps(
# 当 ids 存在且不为 None 时,根据 ids 获取应用
if ids is not None:
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
app_ids = [id.strip() for id in ids.split(',') if id.strip()]
items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id)
items = [service._convert_to_schema(app, workspace_id) for app in items_orm]
return success(data=items)
@@ -85,7 +78,6 @@ def list_apps(
status=status,
search=search,
include_shared=include_shared,
shared_only=shared_only,
page=page,
pagesize=pagesize,
)
@@ -95,37 +87,6 @@ def list_apps(
return success(data=PageData(page=meta, items=items))
@router.get("/my-shared-out", summary="列出本工作空间主动分享出去的记录")
@cur_workspace_access_guard()
def list_my_shared_out(
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""列出本工作空间主动分享给其他工作空间的所有记录(我的共享)"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
shares = service.list_my_shared_out(workspace_id=workspace_id)
data = [app_schema.AppShare.model_validate(s) for s in shares]
return success(data=data)
@router.delete("/share/{target_workspace_id}", summary="取消对某工作空间的所有应用分享")
@cur_workspace_access_guard()
def unshare_all_apps_to_workspace(
target_workspace_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""Cancel all app shares from current workspace to a target workspace."""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
count = service.unshare_all_apps_to_workspace(
target_workspace_id=target_workspace_id,
workspace_id=workspace_id
)
return success(msg=f"已取消 {count} 个应用的分享", data={"count": count})
@router.get("/{app_id}", summary="获取应用详情")
@cur_workspace_access_guard()
def get_app(
@@ -194,7 +155,6 @@ def delete_app(
def copy_app(
app_id: uuid.UUID,
new_name: Optional[str] = None,
payload: app_schema.CopyAppRequest = None,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
@@ -206,8 +166,6 @@ def copy_app(
- 不影响原应用
"""
workspace_id = current_user.current_workspace_id
# body takes precedence over query param for backward compatibility
new_name = (payload.new_name if payload else None) or new_name
logger.info(
"用户请求复制应用",
extra={
@@ -257,27 +215,6 @@ def get_agent_config(
return success(data=app_schema.AgentConfig.model_validate(cfg))
@router.get("/{app_id}/opening", summary="获取应用开场白配置")
@cur_workspace_access_guard()
def get_opening(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""返回开场白文本和预设问题,供前端对话界面初始化时展示"""
workspace_id = current_user.current_workspace_id
cfg = app_service.get_agent_config(db, app_id=app_id, workspace_id=workspace_id)
features = cfg.features or {}
if hasattr(features, "model_dump"):
features = features.model_dump()
opening = features.get("opening_statement", {})
return success(data=app_schema.OpeningResponse(
enabled=opening.get("enabled", False),
statement=opening.get("statement"),
suggested_questions=opening.get("suggested_questions", []),
))
@router.post("/{app_id}/publish", summary="发布应用(生成不可变快照)")
@cur_workspace_access_guard()
def publish_app(
@@ -359,8 +296,7 @@ def share_app(
app_id=app_id,
target_workspace_ids=payload.target_workspace_ids,
user_id=current_user.id,
workspace_id=workspace_id,
permission=payload.permission
workspace_id=workspace_id
)
data = [app_schema.AppShare.model_validate(s) for s in shares]
@@ -391,32 +327,6 @@ def unshare_app(
return success(msg="应用分享已取消")
@router.patch("/{app_id}/share/{target_workspace_id}", summary="更新共享权限")
@cur_workspace_access_guard()
def update_share_permission(
app_id: uuid.UUID,
target_workspace_id: uuid.UUID,
payload: app_schema.UpdateSharePermissionRequest,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""更新共享权限readonly <-> editable
- 只能修改自己工作空间应用的共享权限
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
share = service.update_share_permission(
app_id=app_id,
target_workspace_id=target_workspace_id,
permission=payload.permission,
workspace_id=workspace_id
)
return success(data=app_schema.AppShare.model_validate(share))
@router.get("/{app_id}/shares", summary="列出应用的分享记录")
@cur_workspace_access_guard()
def list_app_shares(
@@ -440,46 +350,6 @@ def list_app_shares(
return success(data=data)
@router.delete("/shared/{source_workspace_id}", summary="批量移除某来源工作空间的所有共享应用")
@cur_workspace_access_guard()
def remove_all_shared_apps_from_workspace(
source_workspace_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""Remove all shared apps from a specific source workspace (recipient operation)."""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
count = service.remove_all_shared_apps_from_workspace(
source_workspace_id=source_workspace_id,
workspace_id=workspace_id
)
return success(msg=f"已移除 {count} 个共享应用", data={"count": count})
@router.delete("/{app_id}/shared", summary="移除共享给我的应用")
@cur_workspace_access_guard()
def remove_shared_app(
app_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""被共享者从自己的工作空间移除共享应用
- 不会删除源应用,只删除共享记录
- 只能移除共享给自己工作空间的应用
"""
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
service.remove_shared_app(
app_id=app_id,
workspace_id=workspace_id
)
return success(msg="已移除共享应用")
@router.post("/{app_id}/draft/run", summary="试运行 Agent使用当前草稿配置")
@cur_workspace_access_guard()
async def draft_run(
@@ -520,13 +390,13 @@ async def draft_run(
# 提前验证和准备(在流式响应开始前完成)
from app.services.app_service import AppService
from app.services.multi_agent_service import MultiAgentService
from app.models import AgentConfig, ModelConfig, AppRelease
from app.models import AgentConfig, ModelConfig
from sqlalchemy import select
from app.core.exceptions import BusinessException
from app.services.draft_run_service import AgentRunService
from app.services.draft_run_service import DraftRunService
service = AppService(db)
draft_service = AgentRunService(db)
draft_service = DraftRunService(db)
# 1. 验证应用
app = service._get_app_or_404(app_id)
@@ -537,12 +407,11 @@ async def draft_run(
service._validate_app_accessible(app, workspace_id)
if payload.user_id is None:
# 先获取 app 的 workspace_id
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=app_id,
workspace_id=app.workspace_id,
other_id=str(current_user.id),
original_user_id=str(current_user.id) # Save original user_id to other_id
)
payload.user_id = str(new_end_user.id)
@@ -559,29 +428,18 @@ async def draft_run(
service._check_agent_config(app_id)
# 2. 获取 Agent 配置
# 共享应用:从最新发布版本读配置快照,而非草稿
is_shared = app.workspace_id != workspace_id
if is_shared:
if not app.current_release_id:
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
release = db.get(AppRelease, app.current_release_id)
if not release:
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
agent_cfg = service._agent_config_from_release(release)
model_config = db.get(ModelConfig, release.default_model_config_id) if release.default_model_config_id else None
else:
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = db.scalars(stmt).first()
if not agent_cfg:
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id)
agent_cfg = db.scalars(stmt).first()
if not agent_cfg:
raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING)
# 3. 获取模型配置
model_config = None
if agent_cfg.default_model_config_id:
model_config = db.get(ModelConfig, agent_cfg.default_model_config_id)
if not model_config:
from app.core.exceptions import ResourceNotFoundException
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
# 3. 获取模型配置
model_config = None
if agent_cfg.default_model_config_id:
model_config = db.get(ModelConfig, agent_cfg.default_model_config_id)
if not model_config:
from app.core.exceptions import ResourceNotFoundException
raise ResourceNotFoundException("模型配置", str(agent_cfg.default_model_config_id))
# 流式返回
if payload.stream:
@@ -596,8 +454,7 @@ async def draft_run(
user_id=payload.user_id or str(current_user.id),
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
files=payload.files # 传递多模态文件
user_rag_memory_id=user_rag_memory_id
):
yield event
@@ -618,13 +475,12 @@ async def draft_run(
"app_id": str(app_id),
"message_length": len(payload.message),
"has_conversation_id": bool(payload.conversation_id),
"has_variables": bool(payload.variables),
"has_files": bool(payload.files)
"has_variables": bool(payload.variables)
}
)
from app.services.draft_run_service import AgentRunService
draft_service = AgentRunService(db)
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
result = await draft_service.run(
agent_config=agent_cfg,
model_config=model_config,
@@ -634,8 +490,7 @@ async def draft_run(
user_id=payload.user_id or str(current_user.id),
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
files=payload.files # 传递多模态文件
user_rag_memory_id=user_rag_memory_id
)
logger.debug(
@@ -737,17 +592,7 @@ async def draft_run(
msg="多 Agent 任务执行成功"
)
elif app.type == AppType.WORKFLOW: # 工作流
# 共享应用:从最新发布版本读配置快照,而非草稿
is_shared = app.workspace_id != workspace_id
if is_shared:
if not app.current_release_id:
raise BusinessException("该应用尚未发布,无法使用", BizCode.AGENT_CONFIG_MISSING)
release = db.get(AppRelease, app.current_release_id)
if not release:
raise BusinessException("发布版本不存在", BizCode.AGENT_CONFIG_MISSING)
config = service._workflow_config_from_release(release)
else:
config = workflow_service.check_config(app_id)
config = workflow_service.check_config(app_id)
# 3. 流式返回
if payload.stream:
logger.debug(
@@ -816,11 +661,6 @@ async def draft_run(
data=result,
msg="工作流任务执行成功"
)
else:
return fail(
msg="未知应用类型",
code=422
)
@router.post("/{app_id}/draft/run/compare", summary="多模型对比试运行")
@@ -890,16 +730,6 @@ async def draft_run_compare(
raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED)
service._validate_app_accessible(app, workspace_id)
if payload.user_id is None:
# 先获取 app 的 workspace_id
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=app_id,
workspace_id=app.workspace_id,
other_id=str(current_user.id),
)
payload.user_id = str(new_end_user.id)
# 2. 获取 Agent 配置
from sqlalchemy import select
from app.models import AgentConfig
@@ -945,33 +775,25 @@ async def draft_run_compare(
"conversation_id": model_item.conversation_id # 传递每个模型的 conversation_id
})
# 从 features 中读取功能开关(与 draft_run 保持一致)
features_config: dict = agent_cfg.features or {}
if hasattr(features_config, 'model_dump'):
features_config = features_config.model_dump()
web_search_feature = features_config.get("web_search", {})
web_search = isinstance(web_search_feature, dict) and web_search_feature.get("enabled", False)
# 流式返回
if payload.stream:
async def event_generator():
from app.services.draft_run_service import AgentRunService
draft_service = AgentRunService(db)
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
async for event in draft_service.run_compare_stream(
agent_config=agent_cfg,
models=model_configs,
message=payload.message,
workspace_id=workspace_id,
conversation_id=payload.conversation_id,
user_id=payload.user_id,
user_id=payload.user_id or str(current_user.id),
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
web_search=web_search,
web_search=True,
memory=True,
parallel=payload.parallel,
timeout=payload.timeout or 60,
files=payload.files
timeout=payload.timeout or 60
):
yield event
@@ -986,23 +808,22 @@ async def draft_run_compare(
)
# 非流式返回
from app.services.draft_run_service import AgentRunService
draft_service = AgentRunService(db)
from app.services.draft_run_service import DraftRunService
draft_service = DraftRunService(db)
result = await draft_service.run_compare(
agent_config=agent_cfg,
models=model_configs,
message=payload.message,
workspace_id=workspace_id,
conversation_id=payload.conversation_id,
user_id=payload.user_id,
user_id=payload.user_id or str(current_user.id),
variables=payload.variables,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
web_search=web_search,
web_search=True,
memory=True,
parallel=payload.parallel,
timeout=payload.timeout or 60,
files=payload.files
timeout=payload.timeout or 60
)
logger.info(
@@ -1046,187 +867,3 @@ async def update_workflow_config(
workspace_id = current_user.current_workspace_id
cfg = app_service.update_workflow_config(db, app_id=app_id, data=payload, workspace_id=workspace_id)
return success(data=WorkflowConfigSchema.model_validate(cfg))
@router.get("/{app_id}/workflow/export")
@cur_workspace_access_guard()
async def export_workflow_config(
app_id: uuid.UUID,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)]
):
"""导出工作流配置为YAML文件"""
workflow_service = WorkflowService(db)
return success(data={
"content": workflow_service.export_workflow_dsl(app_id=app_id),
})
@router.post("/workflow/import")
@cur_workspace_access_guard()
async def import_workflow_config(
file: UploadFile = File(...),
platform: str = Form(...),
app_id: str = Form(None),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""从YAML内容导入工作流配置"""
if not file.filename.lower().endswith((".yaml", ".yml")):
return fail(msg="Only yaml file is allowed", code=BizCode.BAD_REQUEST)
raw_text = (await file.read()).decode("utf-8")
import_service = WorkflowImportService(db)
config = yaml.safe_load(raw_text)
result = await import_service.upload_config(platform, config)
return success(data=result)
@router.post("/workflow/import/save")
@cur_workspace_access_guard()
async def save_workflow_import(
data: WorkflowImportSave,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
import_service = WorkflowImportService(db)
app = await import_service.save_workflow(
user_id=current_user.id,
workspace_id=current_user.current_workspace_id,
temp_id=data.temp_id,
name=data.name,
description=data.description,
)
return success(data=app_schema.App.model_validate(app))
@router.get("/{app_id}/statistics", summary="应用统计数据")
@cur_workspace_access_guard()
def get_app_statistics(
app_id: uuid.UUID,
start_date: int,
end_date: int,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""获取应用统计数据
Args:
app_id: 应用ID
start_date: 开始时间戳(毫秒)
end_date: 结束时间戳(毫秒)
db: 数据库连接
current_user: 当前用户
Returns:
- daily_conversations: 每日会话数统计
- total_conversations: 总会话数
- daily_new_users: 每日新增用户数
- total_new_users: 总新增用户数
- daily_api_calls: 每日API调用次数
- total_api_calls: 总API调用次数
- daily_tokens: 每日token消耗
- total_tokens: 总token消耗
"""
workspace_id = current_user.current_workspace_id
stats_service = AppStatisticsService(db)
result = stats_service.get_app_statistics(
app_id=app_id,
workspace_id=workspace_id,
start_date=start_date,
end_date=end_date
)
return success(data=result)
@router.get("/workspace/api-statistics", summary="工作空间API调用统计")
@cur_workspace_access_guard()
def get_workspace_api_statistics(
start_date: int,
end_date: int,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""获取工作空间API调用统计
Args:
start_date: 开始时间戳(毫秒)
end_date: 结束时间戳(毫秒)
db: 数据库连接
current_user: 当前用户
Returns:
每日统计数据列表,每项包含:
- date: 日期
- total_calls: 当日总调用次数
- app_calls: 当日应用调用次数
- service_calls: 当日服务调用次数
"""
workspace_id = current_user.current_workspace_id
stats_service = AppStatisticsService(db)
result = stats_service.get_workspace_api_statistics(
workspace_id=workspace_id,
start_date=start_date,
end_date=end_date
)
return success(data=result)
@router.get("/{app_id}/export", summary="导出应用配置为 YAML 文件")
@cur_workspace_access_guard()
async def export_app(
app_id: uuid.UUID,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
release_id: Optional[uuid.UUID] = None
):
"""导出 agent / multi_agent / workflow 应用配置为 YAML 文件流。
release_id: 指定发布版本id不传则导出当前草稿配置。
"""
yaml_str, filename = AppDslService(db).export_dsl(app_id, release_id)
encoded = quote(filename, safe=".")
yaml_bytes = yaml_str.encode("utf-8")
file_stream = io.BytesIO(yaml_bytes)
file_stream.seek(0)
return StreamingResponse(
file_stream,
media_type="application/octet-stream; charset=utf-8",
headers={"Content-Disposition": f"attachment; filename={encoded}",
"Content-Length": str(len(yaml_bytes))}
)
@router.post("/import", summary="从 YAML 文件导入应用")
@cur_workspace_access_guard()
async def import_app(
file: UploadFile = File(...),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""从 YAML 文件导入 agent / multi_agent / workflow 应用。
跨空间/跨租户导入时,模型/工具/知识库会按名称匹配,匹配不到则置空并返回 warnings。
"""
if not file.filename.lower().endswith((".yaml", ".yml")):
return fail(msg="仅支持 YAML 文件", code=BizCode.BAD_REQUEST)
raw = (await file.read()).decode("utf-8")
dsl = yaml.safe_load(raw)
if not dsl or "app" not in dsl:
return fail(msg="YAML 格式无效,缺少 app 字段", code=BizCode.BAD_REQUEST)
new_app, warnings = AppDslService(db).import_dsl(
dsl=dsl,
workspace_id=current_user.current_workspace_id,
tenant_id=current_user.tenant_id,
user_id=current_user.id,
)
return success(
data={"app": app_schema.App.model_validate(new_app), "warnings": warnings},
msg="应用导入成功" + (",但部分资源需手动配置" if warnings else "")
)

View File

@@ -1,5 +1,4 @@
from datetime import datetime, timedelta, timezone
from typing import Callable
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
@@ -17,7 +16,6 @@ from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.dependencies import get_current_user, oauth2_scheme
from app.models.user_model import User
from app.i18n.dependencies import get_translator
# 获取专用日志器
auth_logger = get_auth_logger()
@@ -28,8 +26,7 @@ router = APIRouter(tags=["Authentication"])
@router.post("/token", response_model=ApiResponse)
async def login_for_access_token(
form_data: TokenRequest,
db: Session = Depends(get_db),
t: Callable = Depends(get_translator)
db: Session = Depends(get_db)
):
"""用户登录获取token"""
auth_logger.info(f"用户登录请求: {form_data.email}")
@@ -43,10 +40,10 @@ async def login_for_access_token(
invite_info = workspace_service.validate_invite_token(db, form_data.invite)
if not invite_info.is_valid:
raise BusinessException(t("auth.invite.invalid"), code=BizCode.BAD_REQUEST)
raise BusinessException("邀请码无效或已过期", code=BizCode.BAD_REQUEST)
if invite_info.email != form_data.email:
raise BusinessException(t("auth.invite.email_mismatch"), code=BizCode.BAD_REQUEST)
raise BusinessException("邀请邮箱与登录邮箱不匹配", code=BizCode.BAD_REQUEST)
auth_logger.info(f"邀请码验证成功: workspace={invite_info.workspace_name}")
try:
# 尝试认证用户
@@ -64,7 +61,6 @@ async def login_for_access_token(
user = auth_service.register_user_with_invite(
db=db,
email=form_data.email,
username=form_data.username,
password=form_data.password,
invite_token=form_data.invite,
workspace_id=invite_info.workspace_id
@@ -72,7 +68,7 @@ async def login_for_access_token(
elif e.code == BizCode.PASSWORD_ERROR:
# 用户存在但密码错误
auth_logger.warning(f"接受邀请失败,密码验证错误: {form_data.email}")
raise BusinessException(t("auth.invite.password_verification_failed"), BizCode.LOGIN_FAILED)
raise BusinessException("接受邀请失败,密码验证错误", BizCode.LOGIN_FAILED)
else:
# 其他认证失败情况,直接抛出
raise
@@ -85,7 +81,7 @@ async def login_for_access_token(
except BusinessException as e:
# 其他认证失败情况,直接抛出
raise BusinessException(e.message, BizCode.LOGIN_FAILED)
raise BusinessException(e.message,BizCode.LOGIN_FAILED)
# 创建 tokens
access_token, access_token_id = security.create_access_token(subject=user.id)
@@ -113,15 +109,14 @@ async def login_for_access_token(
expires_at=access_expires_at,
refresh_expires_at=refresh_expires_at
),
msg=t("auth.login.success")
msg="登录成功"
)
@router.post("/refresh", response_model=ApiResponse)
async def refresh_token(
refresh_request: RefreshTokenRequest,
db: Session = Depends(get_db),
t: Callable = Depends(get_translator)
db: Session = Depends(get_db)
):
"""刷新token"""
auth_logger.info("收到token刷新请求")
@@ -129,18 +124,18 @@ async def refresh_token(
# 验证 refresh token
userId = security.verify_token(refresh_request.refresh_token, "refresh")
if not userId:
raise BusinessException(t("auth.token.invalid_refresh_token"), code=BizCode.TOKEN_INVALID)
raise BusinessException("无效的refresh token", code=BizCode.TOKEN_INVALID)
# 检查用户是否存在
user = auth_service.get_user_by_id(db, userId)
if not user:
raise BusinessException(t("auth.user.not_found"), code=BizCode.USER_NOT_FOUND)
raise BusinessException("用户不存在", code=BizCode.USER_NOT_FOUND)
# 检查 refresh token 黑名单
if settings.ENABLE_SINGLE_SESSION:
refresh_token_id = security.get_token_id(refresh_request.refresh_token)
if refresh_token_id and await SessionService.is_token_blacklisted(refresh_token_id):
raise BusinessException(t("auth.token.refresh_token_blacklisted"), code=BizCode.TOKEN_BLACKLISTED)
raise BusinessException("Refresh token已失效", code=BizCode.TOKEN_BLACKLISTED)
# 生成新 tokens
new_access_token, new_access_token_id = security.create_access_token(subject=user.id)
@@ -171,7 +166,7 @@ async def refresh_token(
expires_at=access_expires_at,
refresh_expires_at=refresh_expires_at
),
msg=t("auth.token.refresh_success")
msg="token刷新成功"
)
@@ -179,15 +174,14 @@ async def refresh_token(
async def logout(
token: str = Depends(oauth2_scheme),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
t: Callable = Depends(get_translator)
db: Session = Depends(get_db)
):
"""登出当前用户加入token黑名单并清理会话"""
auth_logger.info(f"用户 {current_user.username} 请求登出")
token_id = security.get_token_id(token)
if not token_id:
raise BusinessException(t("auth.token.invalid"), code=BizCode.TOKEN_INVALID)
raise BusinessException("无效的access token", code=BizCode.TOKEN_INVALID)
# 加入黑名单
await SessionService.blacklist_token(token_id)
@@ -197,5 +191,5 @@ async def logout(
await SessionService.clear_user_session(current_user.username)
auth_logger.info(f"用户 {current_user.username} 登出成功")
return success(msg=t("auth.logout.success"))
return success(msg="登出成功")

View File

@@ -441,14 +441,14 @@ async def retrieve_chunks(
# 1 participle search, 2 semantic search, 3 hybrid search
match retrieve_data.retrieve_type:
case chunk_schema.RetrieveType.PARTICIPLE:
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
return success(data=rs, msg="retrieval successful")
case chunk_schema.RetrieveType.SEMANTIC:
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
return success(data=rs, msg="retrieval successful")
case _:
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
# Efficient deduplication
seen_ids = set()
unique_rs = []

View File

@@ -7,13 +7,11 @@ Routes:
GET /memory/config/emotion - 获取情绪引擎配置
POST /memory/config/emotion - 更新情绪引擎配置
"""
import uuid
from fastapi import APIRouter, Depends, Query, HTTPException, status
from pydantic import BaseModel, Field
from typing import Optional, Union
from typing import Optional
from sqlalchemy.orm import Session
from uuid import UUID
from app.core.response_utils import success
from app.dependencies import get_current_user
@@ -22,7 +20,6 @@ from app.schemas.response_schema import ApiResponse
from app.services.emotion_config_service import EmotionConfigService
from app.core.logging_config import get_api_logger
from app.db import get_db
from app.utils.config_utils import resolve_config_id
# 获取API专用日志器
api_logger = get_api_logger()
@@ -35,11 +32,11 @@ router = APIRouter(
class EmotionConfigQuery(BaseModel):
"""情绪配置查询请求模型"""
config_id: UUID = Field(..., description="配置ID")
config_id: int = Field(..., description="配置ID")
class EmotionConfigUpdate(BaseModel):
"""情绪配置更新请求模型"""
config_id: Union[uuid.UUID, int, str]= Field(..., description="配置ID")
config_id: int = Field(..., description="配置ID")
emotion_enabled: bool = Field(..., description="是否启用情绪提取")
emotion_model_id: Optional[str] = Field(None, description="情绪分析专用模型ID")
emotion_extract_keywords: bool = Field(..., description="是否提取情绪关键词")
@@ -48,7 +45,7 @@ class EmotionConfigUpdate(BaseModel):
@router.get("/read_config", response_model=ApiResponse)
def get_emotion_config(
config_id: UUID|int = Query(..., description="配置ID"),
config_id: int = Query(..., description="配置ID"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
@@ -81,7 +78,7 @@ def get_emotion_config(
f"用户 {current_user.username} 请求获取情绪配置",
extra={"config_id": config_id}
)
config_id=resolve_config_id(config_id, db)
# 初始化服务
config_service = EmotionConfigService(db)
@@ -160,7 +157,6 @@ def update_emotion_config(
}
}
"""
config.config_id=resolve_config_id(config.config_id, db)
try:
api_logger.info(
f"用户 {current_user.username} 请求更新情绪配置",

View File

@@ -11,7 +11,6 @@ Routes:
"""
from app.core.error_codes import BizCode
from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger
from app.core.response_utils import fail, success
from app.dependencies import get_current_user, get_db
@@ -25,7 +24,7 @@ from app.schemas.emotion_schema import (
)
from app.schemas.response_schema import ApiResponse
from app.services.emotion_analytics_service import EmotionAnalyticsService
from fastapi import APIRouter, Depends, HTTPException, status,Header
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
# 获取API专用日志器
@@ -46,51 +45,45 @@ emotion_service = EmotionAnalyticsService()
@router.post("/tags", response_model=ApiResponse)
async def get_emotion_tags(
request: EmotionTagsRequest,
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
):
try:
# 使用集中化的语言校验
language = get_language_from_header(language_type)
api_logger.info(
f"用户 {current_user.username} 请求获取情绪标签统计",
extra={
"end_user_id": request.end_user_id,
"group_id": request.group_id,
"emotion_type": request.emotion_type,
"start_date": request.start_date,
"end_date": request.end_date,
"limit": request.limit,
"language_type": language
"limit": request.limit
}
)
# 调用服务层
data = await emotion_service.get_emotion_tags(
end_user_id=request.end_user_id,
end_user_id=request.group_id,
emotion_type=request.emotion_type,
start_date=request.start_date,
end_date=request.end_date,
limit=request.limit,
language=language
limit=request.limit
)
api_logger.info(
"情绪标签统计获取成功",
extra={
"end_user_id": request.end_user_id,
"group_id": request.group_id,
"total_count": data.get("total_count", 0),
"tags_count": len(data.get("tags", []))
}
)
return success(data=data, msg="情绪标签获取成功")
except Exception as e:
api_logger.error(
f"获取情绪标签统计失败: {str(e)}",
extra={"end_user_id": request.end_user_id},
extra={"group_id": request.group_id},
exc_info=True
)
raise HTTPException(
@@ -103,44 +96,40 @@ async def get_emotion_tags(
@router.post("/wordcloud", response_model=ApiResponse)
async def get_emotion_wordcloud(
request: EmotionWordcloudRequest,
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
):
try:
# 使用集中化的语言校验
language = get_language_from_header(language_type)
api_logger.info(
f"用户 {current_user.username} 请求获取情绪词云数据",
extra={
"end_user_id": request.end_user_id,
"group_id": request.group_id,
"emotion_type": request.emotion_type,
"limit": request.limit
}
)
# 调用服务层
data = await emotion_service.get_emotion_wordcloud(
end_user_id=request.end_user_id,
end_user_id=request.group_id,
emotion_type=request.emotion_type,
limit=request.limit
)
api_logger.info(
"情绪词云数据获取成功",
extra={
"end_user_id": request.end_user_id,
"group_id": request.group_id,
"total_keywords": data.get("total_keywords", 0)
}
)
return success(data=data, msg="情绪词云获取成功")
except Exception as e:
api_logger.error(
f"获取情绪词云数据失败: {str(e)}",
extra={"end_user_id": request.end_user_id},
extra={"group_id": request.group_id},
exc_info=True
)
raise HTTPException(
@@ -153,52 +142,48 @@ async def get_emotion_wordcloud(
@router.post("/health", response_model=ApiResponse)
async def get_emotion_health(
request: EmotionHealthRequest,
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
):
try:
# 使用集中化的语言校验
language = get_language_from_header(language_type)
# 验证时间范围参数
if request.time_range not in ["7d", "30d", "90d"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="时间范围参数无效,必须是 7d、30d 或 90d"
)
api_logger.info(
f"用户 {current_user.username} 请求获取情绪健康指数",
extra={
"end_user_id": request.end_user_id,
"group_id": request.group_id,
"time_range": request.time_range
}
)
# 调用服务层
data = await emotion_service.calculate_emotion_health_index(
end_user_id=request.end_user_id,
end_user_id=request.group_id,
time_range=request.time_range
)
api_logger.info(
"情绪健康指数获取成功",
extra={
"end_user_id": request.end_user_id,
"health_score": data.get("health_score") or 0,
"group_id": request.group_id,
"health_score": data.get("health_score", 0),
"level": data.get("level", "未知")
}
)
return success(data=data, msg="情绪健康指数获取成功")
except HTTPException:
raise
except Exception as e:
api_logger.error(
f"获取情绪健康指数失败: {str(e)}",
extra={"end_user_id": request.end_user_id},
extra={"group_id": request.group_id},
exc_info=True
)
raise HTTPException(
@@ -208,112 +193,63 @@ async def get_emotion_health(
# @router.post("/check-data", response_model=ApiResponse)
# async def check_emotion_data_exists(
# request: EmotionSuggestionsRequest,
# db: Session = Depends(get_db),
# current_user: User = Depends(get_current_user),
# ):
# """检查用户情绪建议数据是否存在
# Args:
# request: 包含 end_user_id
# db: 数据库会话
# current_user: 当前用户
# Returns:
# 数据存在状态
# """
# try:
# api_logger.info(
# f"检查用户情绪建议数据是否存在: {request.end_user_id}",
# extra={"end_user_id": request.end_user_id}
# )
# # 从数据库获取建议
# data = await emotion_service.get_cached_suggestions(
# end_user_id=request.end_user_id,
# db=db
# )
# if data is None:
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据不存在")
# return fail(
# BizCode.NOT_FOUND,
# "情绪建议数据不存在,请点击右上角刷新进行初始化",
# {"exists": False}
# )
# api_logger.info(f"用户 {request.end_user_id} 的情绪建议数据存在")
# return success(data={"exists": True}, msg="情绪建议数据已存在")
# except Exception as e:
# api_logger.error(
# f"检查情绪建议数据失败: {str(e)}",
# extra={"end_user_id": request.end_user_id},
# exc_info=True
# )
# raise HTTPException(
# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
# detail=f"检查情绪建议数据失败: {str(e)}"
# )
@router.post("/suggestions", response_model=ApiResponse)
async def get_emotion_suggestions(
request: EmotionSuggestionsRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""获取个性化情绪建议(从数据库读取)
"""获取个性化情绪建议(从缓存读取)
Args:
request: 包含 end_user_id 和可选的 config_id
request: 包含 group_id 和可选的 config_id
db: 数据库会话
current_user: 当前用户
Returns:
的个性化情绪建议响应
存的个性化情绪建议响应
"""
try:
api_logger.info(
f"用户 {current_user.username} 请求获取个性化情绪建议",
f"用户 {current_user.username} 请求获取个性化情绪建议(缓存)",
extra={
"end_user_id": request.end_user_id,
"group_id": request.group_id,
"config_id": request.config_id
}
)
# 从数据库获取建议
# 从缓存获取建议
data = await emotion_service.get_cached_suggestions(
end_user_id=request.end_user_id,
end_user_id=request.group_id,
db=db
)
if data is None:
# 缓存不存在或已过期
api_logger.info(
f"用户 {request.end_user_id} 的建议数据不存在",
extra={"end_user_id": request.end_user_id}
f"用户 {request.group_id} 的建议缓存不存在或已过期",
extra={"group_id": request.group_id}
)
return success(
data={"exists": False},
msg="情绪建议数据不存在,请点击右上角刷新进行初始化"
return fail(
BizCode.RESOURCE_NOT_FOUND,
"建议缓存不存在或已过期,请调用 /generate_suggestions 接口生成新建议",
None
)
api_logger.info(
"个性化建议获取成功",
"个性化建议获取成功(缓存)",
extra={
"end_user_id": request.end_user_id,
"group_id": request.group_id,
"suggestions_count": len(data.get("suggestions", []))
}
)
return success(data=data, msg="个性化建议获取成功")
return success(data=data, msg="个性化建议获取成功(缓存)")
except Exception as e:
api_logger.error(
f"获取个性化建议失败: {str(e)}",
extra={"end_user_id": request.end_user_id},
extra={"group_id": request.group_id},
exc_info=True
)
raise HTTPException(
@@ -325,62 +261,83 @@ async def get_emotion_suggestions(
@router.post("/generate_suggestions", response_model=ApiResponse)
async def generate_emotion_suggestions(
request: EmotionGenerateSuggestionsRequest,
language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""生成个性化情绪建议调用LLM并保存到数据库
"""生成个性化情绪建议调用LLM并缓存
Args:
request: 包含 end_user_id
request: 包含 group_id、可选的 config_id 和 force_refresh
db: 数据库会话
current_user: 当前用户
Returns:
新生成的个性化情绪建议响应
"""
try:
# 使用集中化的语言校验
language = get_language_from_header(language_type)
# 验证 config_id如果提供
# 获取终端用户关联的配置
config_id = request.config_id
if config_id is None:
# 如果没有提供 config_id尝试获取用户关联的配置
try:
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(request.group_id, db)
config_id = connected_config.get("memory_config_id")
except ValueError as e:
return fail(BizCode.INVALID_PARAMETER, "无法获取用户关联的配置", str(e))
else:
# 如果提供了 config_id验证其有效性
from app.services.memory_config_service import MemoryConfigService
try:
config_service = MemoryConfigService(db)
config = config_service.get_config_by_id(config_id)
if not config:
return fail(BizCode.INVALID_PARAMETER, "配置ID无效", f"配置 {config_id} 不存在")
except Exception as e:
return fail(BizCode.INVALID_PARAMETER, "配置ID验证失败", str(e))
api_logger.info(
f"用户 {current_user.username} 请求生成个性化情绪建议",
extra={
"end_user_id": request.end_user_id
"group_id": request.group_id,
"config_id": config_id
}
)
# 调用服务层生成建议
data = await emotion_service.generate_emotion_suggestions(
end_user_id=request.end_user_id,
db=db,
language=language
)
# 保存到数据库
await emotion_service.save_suggestions_cache(
end_user_id=request.end_user_id,
suggestions_data=data,
end_user_id=request.group_id,
db=db
)
# 保存到缓存
await emotion_service.save_suggestions_cache(
end_user_id=request.group_id,
suggestions_data=data,
db=db,
expires_hours=24
)
api_logger.info(
"个性化建议生成成功",
extra={
"end_user_id": request.end_user_id,
"group_id": request.group_id,
"suggestions_count": len(data.get("suggestions", []))
}
)
return success(data=data, msg="个性化建议生成成功")
except Exception as e:
api_logger.error(
f"生成个性化建议失败: {str(e)}",
extra={"end_user_id": request.end_user_id},
extra={"group_id": request.group_id},
exc_info=True
)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"生成个性化建议失败: {str(e)}"
)
)

View File

@@ -1,48 +0,0 @@
"""End User 管理接口 - 无需认证"""
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.db import get_db
from app.repositories.end_user_repository import EndUserRepository
from app.schemas.memory_api_schema import (
CreateEndUserRequest,
CreateEndUserResponse,
)
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
router = APIRouter(prefix="/end_users", tags=["End Users"])
logger = get_business_logger()
@router.post("")
async def create_end_user(
data: CreateEndUserRequest,
db: Session = Depends(get_db),
):
"""
Create an end user.
Creates a new end user for the given workspace.
If an end user with the same other_id already exists in the workspace,
returns the existing one.
"""
logger.info(f"Create end user request - other_id: {data.other_id}, workspace_id: {data.workspace_id}")
end_user_repo = EndUserRepository(db)
end_user = end_user_repo.get_or_create_end_user(
app_id=None,
workspace_id=data.workspace_id,
other_id=data.other_id,
)
logger.info(f"End user ready: {end_user.id}")
result = {
"id": str(end_user.id),
"other_id": end_user.other_id or "",
"other_name": end_user.other_name or "",
"workspace_id": str(end_user.workspace_id),
}
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")

View File

@@ -1,655 +0,0 @@
"""
File storage controller module.
This module provides API endpoints for file storage operations using the
configurable storage backend. It is a new controller that does not modify
the existing file_controller.py.
Routes:
POST /storage/files - Upload a file
GET /storage/files/{file_id} - Download a file
DELETE /storage/files/{file_id} - Delete a file
"""
import os
import uuid
from typing import Any
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
from fastapi.responses import FileResponse, RedirectResponse
from sqlalchemy.orm import Session
from app.core.config import settings
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.core.storage import LocalStorage
from app.core.storage.url_signer import generate_signed_url, verify_signed_url
from app.core.storage_exceptions import (
StorageDeleteError,
StorageUploadError,
)
from app.db import get_db
from app.dependencies import get_current_user, get_share_user_id, ShareTokenData
from app.models.file_metadata_model import FileMetadata
from app.models.user_model import User
from app.schemas.response_schema import ApiResponse
from app.services.file_storage_service import (
FileStorageService,
generate_file_key,
get_file_storage_service,
)
api_logger = get_api_logger()
router = APIRouter(
prefix="/storage",
tags=["storage"]
)
def _match_scheme(request: Request, url: str) -> str:
"""
将 presigned URL 的协议替换为与当前请求一致的协议http/https
解决反向代理场景下 presigned URL 协议与请求协议不匹配的问题。
"""
incoming_scheme = request.headers.get("x-forwarded-proto") or request.url.scheme
if url.startswith("http://") and incoming_scheme == "https":
return "https://" + url[7:]
if url.startswith("https://") and incoming_scheme == "http":
return "http://" + url[8:]
return url
@router.post("/files", response_model=ApiResponse)
async def upload_file(
file: UploadFile = File(...),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
):
"""
Upload a file to the configured storage backend.
"""
tenant_id = current_user.tenant_id
workspace_id = current_user.current_workspace_id
api_logger.info(
f"Storage upload request: tenant_id={tenant_id}, workspace_id={workspace_id}, "
f"filename={file.filename}, username={current_user.username}"
)
# Read file contents
contents = await file.read()
file_size = len(contents)
# Validate file size
if file_size == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The file is empty."
)
if file_size > settings.MAX_FILE_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
)
# Extract file extension
_, file_extension = os.path.splitext(file.filename)
file_ext = file_extension.lower()
# Generate file_id and file_key
file_id = uuid.uuid4()
file_key = generate_file_key(
tenant_id=tenant_id,
workspace_id=workspace_id,
file_id=file_id,
file_ext=file_ext,
)
# Create file metadata record with pending status
file_metadata = FileMetadata(
id=file_id,
tenant_id=tenant_id,
workspace_id=workspace_id,
file_key=file_key,
file_name=file.filename,
file_ext=file_ext,
file_size=file_size,
content_type=file.content_type,
status="pending",
)
db.add(file_metadata)
db.commit()
db.refresh(file_metadata)
# Upload file to storage backend
try:
await storage_service.upload_file(
tenant_id=tenant_id,
workspace_id=workspace_id,
file_id=file_id,
file_ext=file_ext,
content=contents,
content_type=file.content_type,
)
# Update status to completed
file_metadata.status = "completed"
db.commit()
api_logger.info(f"File uploaded to storage: file_key={file_key}")
except StorageUploadError as e:
# Update status to failed
file_metadata.status = "failed"
db.commit()
api_logger.error(f"Storage upload failed: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"File storage failed: {str(e)}"
)
api_logger.info(f"File upload successful: {file.filename} (file_id: {file_id})")
return success(
data={"file_id": str(file_id), "file_key": file_key},
msg="File upload successful"
)
@router.post("/share/files", response_model=ApiResponse)
async def upload_file_with_share_token(
file: UploadFile = File(...),
db: Session = Depends(get_db),
share_data: ShareTokenData = Depends(get_share_user_id),
storage_service: FileStorageService = Depends(get_file_storage_service),
):
"""
Upload a file to the configured storage backend using share_token authentication.
"""
from app.services.release_share_service import ReleaseShareService
from app.models.app_model import App
from app.models.workspace_model import Workspace
# Get share and release info from share_token
service = ReleaseShareService(db)
share_info = service.get_shared_release_info(share_token=share_data.share_token)
# Get share object to access app_id
share = service.repo.get_by_share_token(share_data.share_token)
if not share:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Shared app not found"
)
# Get app to access workspace_id
app = db.query(App).filter(
App.id == share.app_id,
App.is_active.is_(True)
).first()
if not app:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="App not found"
)
# Get workspace to access tenant_id
workspace = db.query(Workspace).filter(
Workspace.id == app.workspace_id
).first()
if not workspace:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Workspace not found"
)
tenant_id = workspace.tenant_id
workspace_id = app.workspace_id
api_logger.info(
f"Storage upload request (share): tenant_id={tenant_id}, workspace_id={workspace_id}, "
f"filename={file.filename}, share_token={share_data.share_token}"
)
# Read file contents
contents = await file.read()
file_size = len(contents)
# Validate file size
if file_size == 0:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The file is empty."
)
if file_size > settings.MAX_FILE_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
)
# Extract file extension
_, file_extension = os.path.splitext(file.filename)
file_ext = file_extension.lower()
# Generate file_id and file_key
file_id = uuid.uuid4()
file_key = generate_file_key(
tenant_id=tenant_id,
workspace_id=workspace_id,
file_id=file_id,
file_ext=file_ext,
)
# Create file metadata record with pending status
file_metadata = FileMetadata(
id=file_id,
tenant_id=tenant_id,
workspace_id=workspace_id,
file_key=file_key,
file_name=file.filename,
file_ext=file_ext,
file_size=file_size,
content_type=file.content_type,
status="pending",
)
db.add(file_metadata)
db.commit()
db.refresh(file_metadata)
# Upload file to storage backend
try:
await storage_service.upload_file(
tenant_id=tenant_id,
workspace_id=workspace_id,
file_id=file_id,
file_ext=file_ext,
content=contents,
content_type=file.content_type,
)
# Update status to completed
file_metadata.status = "completed"
db.commit()
api_logger.info(f"File uploaded to storage (share): file_key={file_key}")
except StorageUploadError as e:
# Update status to failed
file_metadata.status = "failed"
db.commit()
api_logger.error(f"Storage upload failed (share): {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"File storage failed: {str(e)}"
)
api_logger.info(f"File upload successful (share): {file.filename} (file_id: {file_id})")
return success(
data={"file_id": str(file_id), "file_key": file_key},
msg="File upload successful"
)
@router.get("/files/{file_id}", response_model=Any)
async def download_file(
request: Request,
file_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
) -> Any:
"""
Download a file from the configured storage backend.
"""
api_logger.info(f"Storage download request: file_id={file_id}")
# Query file metadata from database
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
if not file_metadata:
api_logger.warning(f"File not found in database: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist"
)
if file_metadata.status != "completed":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"File upload not completed, status: {file_metadata.status}"
)
file_key = file_metadata.file_key
storage = storage_service.storage
if isinstance(storage, LocalStorage):
full_path = storage._get_full_path(file_key)
if not full_path.exists():
api_logger.warning(f"File not found on disk: file_key={file_key}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found (possibly deleted)"
)
api_logger.info(f"Serving local file: file_key={file_key}")
return FileResponse(
path=str(full_path),
filename=file_metadata.file_name,
media_type=file_metadata.content_type or "application/octet-stream"
)
else:
try:
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
presigned_url = _match_scheme(request, presigned_url)
api_logger.info(f"Redirecting to presigned URL: file_key={file_key}")
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
except FileNotFoundError:
api_logger.warning(f"File not found in remote storage: file_key={file_key}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found in storage"
)
except Exception as e:
api_logger.error(f"Failed to get presigned URL: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve file: {str(e)}"
)
@router.delete("/files/{file_id}", response_model=ApiResponse)
async def delete_file(
file_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
storage_service: FileStorageService = Depends(get_file_storage_service),
):
"""
Delete a file from the configured storage backend.
"""
api_logger.info(
f"Storage delete request: file_id={file_id}, username={current_user.username}"
)
# Query file metadata from database
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
if not file_metadata:
api_logger.warning(f"File not found in database: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist"
)
file_key = file_metadata.file_key
# Delete file from storage
try:
deleted = await storage_service.delete_file(file_key)
if deleted:
api_logger.info(f"File deleted from storage: file_key={file_key}")
else:
api_logger.info(f"File did not exist in storage: file_key={file_key}")
except StorageDeleteError as e:
api_logger.error(f"Storage delete failed: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete file from storage: {str(e)}"
)
# Delete database record
try:
db.delete(file_metadata)
db.commit()
api_logger.info(f"File record deleted from database: file_id={file_id}")
except Exception as e:
api_logger.error(f"Database delete failed: {e}")
db.rollback()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete file record: {str(e)}"
)
return success(msg="File deleted successfully")
@router.get("/files/{file_id}/url", response_model=ApiResponse)
async def get_file_url(
request: Request,
file_id: uuid.UUID,
expires: int = None,
permanent: bool = False,
db: Session = Depends(get_db),
storage_service: FileStorageService = Depends(get_file_storage_service),
):
"""
Get an access URL for a file (no authentication required).
Args:
file_id: The UUID of the file.
expires: URL validity period in seconds (default from FILE_URL_EXPIRES env).
permanent: If True, return a permanent URL without expiration.
db: Database session.
storage_service: The file storage service.
Returns:
ApiResponse with the access URL.
"""
if expires is None:
expires = settings.FILE_URL_EXPIRES
api_logger.info(f"Get file URL request: file_id={file_id}, expires={expires}, permanent={permanent}")
# Query file metadata from database
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
if not file_metadata:
api_logger.warning(f"File not found in database: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist"
)
if file_metadata.status != "completed":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"File upload not completed, status: {file_metadata.status}"
)
file_key = file_metadata.file_key
storage = storage_service.storage
try:
if permanent:
# Generate permanent URL (no expiration check)
server_url = settings.FILE_LOCAL_SERVER_URL
url = f"{server_url}/storage/permanent/{file_id}"
return success(
data={
"url": url,
"expires_in": None,
"permanent": True,
"file_name": file_metadata.file_name,
},
msg="Permanent file URL generated successfully"
)
if isinstance(storage, LocalStorage):
# For local storage, generate signed URL with expiration
url = generate_signed_url(str(file_id), expires)
else:
# For remote storage (OSS/S3), get presigned URL
url = await storage_service.get_file_url(file_key, expires=expires)
url = _match_scheme(request, url)
api_logger.info(f"Generated file URL: file_id={file_id}")
return success(
data={
"url": url,
"expires_in": expires,
"permanent": False,
"file_name": file_metadata.file_name,
},
msg="File URL generated successfully"
)
except Exception as e:
api_logger.error(f"Failed to generate file URL: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to generate file URL: {str(e)}"
)
@router.get("/public/{file_id}", response_model=Any)
async def public_download_file(
request: Request,
file_id: uuid.UUID,
expires: int = 0,
signature: str = "",
db: Session = Depends(get_db),
storage_service: FileStorageService = Depends(get_file_storage_service),
) -> Any:
"""
Public file download endpoint with signature verification.
This endpoint allows downloading files without authentication,
but requires a valid signature and non-expired timestamp.
Args:
file_id: The UUID of the file.
expires: Expiration timestamp.
signature: HMAC signature for verification.
db: Database session.
storage_service: The file storage service.
Returns:
FileResponse for the requested file.
"""
api_logger.info(f"Public download request: file_id={file_id}")
# Verify signature
is_valid, error_msg = verify_signed_url(str(file_id), expires, signature)
if not is_valid:
api_logger.warning(f"Invalid signed URL: file_id={file_id}, error={error_msg}")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=error_msg
)
# Query file metadata from database
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
if not file_metadata:
api_logger.warning(f"File not found in database: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist"
)
if file_metadata.status != "completed":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"File upload not completed, status: {file_metadata.status}"
)
file_key = file_metadata.file_key
storage = storage_service.storage
if isinstance(storage, LocalStorage):
full_path = storage._get_full_path(file_key)
if not full_path.exists():
api_logger.warning(f"File not found on disk: file_key={file_key}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found"
)
api_logger.info(f"Serving public file: file_key={file_key}")
return FileResponse(
path=str(full_path),
filename=file_metadata.file_name,
media_type=file_metadata.content_type or "application/octet-stream"
)
else:
# For remote storage, redirect to presigned URL
try:
presigned_url = await storage_service.get_file_url(file_key, expires=3600)
presigned_url = _match_scheme(request, presigned_url)
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
except Exception as e:
api_logger.error(f"Failed to get presigned URL: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve file: {str(e)}"
)
@router.get("/permanent/{file_id}", response_model=Any)
async def permanent_download_file(
request: Request,
file_id: uuid.UUID,
db: Session = Depends(get_db),
storage_service: FileStorageService = Depends(get_file_storage_service),
) -> Any:
"""
Permanent file download endpoint (no expiration, no signature required).
This endpoint allows downloading files without authentication or expiration.
Use with caution as URLs are permanently accessible.
Args:
file_id: The UUID of the file.
db: Database session.
storage_service: The file storage service.
Returns:
FileResponse for the requested file.
"""
api_logger.info(f"Permanent download request: file_id={file_id}")
# Query file metadata from database
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
if not file_metadata:
api_logger.warning(f"File not found in database: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist"
)
if file_metadata.status != "completed":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"File upload not completed, status: {file_metadata.status}"
)
file_key = file_metadata.file_key
storage = storage_service.storage
if isinstance(storage, LocalStorage):
full_path = storage._get_full_path(file_key)
if not full_path.exists():
api_logger.warning(f"File not found on disk: file_key={file_key}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="File not found"
)
api_logger.info(f"Serving permanent file: file_key={file_key}")
return FileResponse(
path=str(full_path),
filename=file_metadata.file_name,
media_type=file_metadata.content_type or "application/octet-stream"
)
else:
# For remote storage, redirect to presigned URL with long expiration
try:
# Use a very long expiration (7 days max for most cloud providers)
presigned_url = await storage_service.get_file_url(file_key, expires=604800)
presigned_url = _match_scheme(request, presigned_url)
return RedirectResponse(url=presigned_url, status_code=status.HTTP_302_FOUND)
except Exception as e:
api_logger.error(f"Failed to get presigned URL: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve file: {str(e)}"
)

View File

@@ -1,833 +0,0 @@
"""
I18n Management API Controller
This module provides management APIs for:
- Language management (list, get, add, update languages)
- Translation management (get, update, reload translations)
"""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from typing import Callable, Optional
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user, get_current_superuser
from app.i18n.dependencies import get_translator
from app.i18n.service import get_translation_service
from app.models.user_model import User
from app.schemas.i18n_schema import (
LanguageInfo,
LanguageListResponse,
LanguageCreateRequest,
LanguageUpdateRequest,
TranslationResponse,
TranslationUpdateRequest,
MissingTranslationsResponse,
ReloadResponse
)
from app.schemas.response_schema import ApiResponse
api_logger = get_api_logger()
router = APIRouter(
prefix="/i18n",
tags=["I18n Management"],
)
# ============================================================================
# Language Management APIs
# ============================================================================
@router.get("/languages", response_model=ApiResponse)
def get_languages(
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_user)
):
"""
Get list of all supported languages.
Returns:
List of language information including code, name, and status
"""
api_logger.info(f"Get languages request from user: {current_user.username}")
from app.core.config import settings
translation_service = get_translation_service()
# Get available locales from translation service
available_locales = translation_service.get_available_locales()
# Build language info list
languages = []
for locale in available_locales:
is_default = locale == settings.I18N_DEFAULT_LANGUAGE
is_enabled = locale in settings.I18N_SUPPORTED_LANGUAGES
# Get native names
native_names = {
"zh": "中文(简体)",
"en": "English",
"ja": "日本語",
"ko": "한국어",
"fr": "Français",
"de": "Deutsch",
"es": "Español"
}
language_info = LanguageInfo(
code=locale,
name=f"{locale.upper()}",
native_name=native_names.get(locale, locale),
is_enabled=is_enabled,
is_default=is_default
)
languages.append(language_info)
response = LanguageListResponse(languages=languages)
api_logger.info(f"Returning {len(languages)} languages")
return success(data=response.dict(), msg=t("common.success.retrieved"))
@router.get("/languages/{locale}", response_model=ApiResponse)
def get_language(
locale: str,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_user)
):
"""
Get information about a specific language.
Args:
locale: Language code (e.g., 'zh', 'en')
Returns:
Language information
"""
api_logger.info(f"Get language info request: locale={locale}, user={current_user.username}")
from app.core.config import settings
translation_service = get_translation_service()
# Check if locale exists
available_locales = translation_service.get_available_locales()
if locale not in available_locales:
api_logger.warning(f"Language not found: {locale}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=t("i18n.language.not_found", locale=locale)
)
# Build language info
is_default = locale == settings.I18N_DEFAULT_LANGUAGE
is_enabled = locale in settings.I18N_SUPPORTED_LANGUAGES
native_names = {
"zh": "中文(简体)",
"en": "English",
"ja": "日本語",
"ko": "한국어",
"fr": "Français",
"de": "Deutsch",
"es": "Español"
}
language_info = LanguageInfo(
code=locale,
name=f"{locale.upper()}",
native_name=native_names.get(locale, locale),
is_enabled=is_enabled,
is_default=is_default
)
api_logger.info(f"Returning language info for: {locale}")
return success(data=language_info.dict(), msg=t("common.success.retrieved"))
@router.post("/languages", response_model=ApiResponse)
def add_language(
request: LanguageCreateRequest,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Add a new language (admin only).
Note: This endpoint validates the request but actual language addition
requires creating translation files in the locales directory.
Args:
request: Language creation request
Returns:
Success message
"""
api_logger.info(
f"Add language request: code={request.code}, admin={current_user.username}"
)
from app.core.config import settings
translation_service = get_translation_service()
# Check if language already exists
available_locales = translation_service.get_available_locales()
if request.code in available_locales:
api_logger.warning(f"Language already exists: {request.code}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=t("i18n.language.already_exists", locale=request.code)
)
# Note: Actual language addition requires creating translation files
# This endpoint serves as a validation and documentation point
api_logger.info(
f"Language addition validated: {request.code}. "
"Translation files need to be created manually."
)
return success(
msg=t(
"i18n.language.add_instructions",
locale=request.code,
dir=settings.I18N_CORE_LOCALES_DIR
)
)
@router.put("/languages/{locale}", response_model=ApiResponse)
def update_language(
locale: str,
request: LanguageUpdateRequest,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Update language configuration (admin only).
Note: This endpoint validates the request but actual configuration
changes require updating environment variables or config files.
Args:
locale: Language code
request: Language update request
Returns:
Success message
"""
api_logger.info(
f"Update language request: locale={locale}, admin={current_user.username}"
)
translation_service = get_translation_service()
# Check if language exists
available_locales = translation_service.get_available_locales()
if locale not in available_locales:
api_logger.warning(f"Language not found: {locale}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=t("i18n.language.not_found", locale=locale)
)
# Note: Actual configuration changes require updating settings
# This endpoint serves as a validation and documentation point
api_logger.info(
f"Language update validated: {locale}. "
"Configuration changes require environment variable updates."
)
return success(msg=t("i18n.language.update_instructions", locale=locale))
# ============================================================================
# Translation Management APIs
# ============================================================================
@router.get("/translations", response_model=ApiResponse)
def get_all_translations(
locale: Optional[str] = None,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_user)
):
"""
Get all translations for all or specific locale.
Args:
locale: Optional locale filter
Returns:
All translations organized by locale and namespace
"""
api_logger.info(
f"Get all translations request: locale={locale}, user={current_user.username}"
)
translation_service = get_translation_service()
if locale:
# Get translations for specific locale
available_locales = translation_service.get_available_locales()
if locale not in available_locales:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=t("i18n.language.not_found", locale=locale)
)
translations = {
locale: translation_service._cache.get(locale, {})
}
else:
# Get all translations
translations = translation_service._cache
response = TranslationResponse(translations=translations)
api_logger.info(f"Returning translations for: {locale or 'all locales'}")
return success(data=response.dict(), msg=t("common.success.retrieved"))
@router.get("/translations/{locale}", response_model=ApiResponse)
def get_locale_translations(
locale: str,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_user)
):
"""
Get all translations for a specific locale.
Args:
locale: Language code
Returns:
All translations for the locale organized by namespace
"""
api_logger.info(
f"Get locale translations request: locale={locale}, user={current_user.username}"
)
translation_service = get_translation_service()
# Check if locale exists
available_locales = translation_service.get_available_locales()
if locale not in available_locales:
api_logger.warning(f"Language not found: {locale}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=t("i18n.language.not_found", locale=locale)
)
translations = translation_service._cache.get(locale, {})
api_logger.info(f"Returning {len(translations)} namespaces for locale: {locale}")
return success(data={"locale": locale, "translations": translations}, msg=t("common.success.retrieved"))
@router.get("/translations/{locale}/{namespace}", response_model=ApiResponse)
def get_namespace_translations(
locale: str,
namespace: str,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_user)
):
"""
Get translations for a specific namespace in a locale.
Args:
locale: Language code
namespace: Translation namespace (e.g., 'common', 'auth')
Returns:
Translations for the specified namespace
"""
api_logger.info(
f"Get namespace translations request: locale={locale}, "
f"namespace={namespace}, user={current_user.username}"
)
translation_service = get_translation_service()
# Check if locale exists
available_locales = translation_service.get_available_locales()
if locale not in available_locales:
api_logger.warning(f"Language not found: {locale}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=t("i18n.language.not_found", locale=locale)
)
# Get namespace translations
locale_translations = translation_service._cache.get(locale, {})
namespace_translations = locale_translations.get(namespace, {})
if not namespace_translations:
api_logger.warning(f"Namespace not found: {namespace} in locale: {locale}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=t("i18n.namespace.not_found", namespace=namespace, locale=locale)
)
api_logger.info(
f"Returning translations for namespace: {namespace} in locale: {locale}"
)
return success(
data={
"locale": locale,
"namespace": namespace,
"translations": namespace_translations
},
msg=t("common.success.retrieved")
)
@router.put("/translations/{locale}/{key:path}", response_model=ApiResponse)
def update_translation(
locale: str,
key: str,
request: TranslationUpdateRequest,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Update a single translation (admin only).
Note: This endpoint validates the request but actual translation updates
require modifying translation files in the locales directory.
Args:
locale: Language code
key: Translation key (format: "namespace.key.subkey")
request: Translation update request
Returns:
Success message
"""
api_logger.info(
f"Update translation request: locale={locale}, key={key}, "
f"admin={current_user.username}"
)
translation_service = get_translation_service()
# Check if locale exists
available_locales = translation_service.get_available_locales()
if locale not in available_locales:
api_logger.warning(f"Language not found: {locale}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=t("i18n.language.not_found", locale=locale)
)
# Validate key format
if "." not in key:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=t("i18n.translation.invalid_key_format", key=key)
)
# Note: Actual translation updates require modifying JSON files
# This endpoint serves as a validation and documentation point
api_logger.info(
f"Translation update validated: {locale}/{key}. "
"Translation files need to be updated manually."
)
return success(
msg=t("i18n.translation.update_instructions", locale=locale, key=key)
)
@router.get("/translations/missing", response_model=ApiResponse)
def get_missing_translations(
locale: Optional[str] = None,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_user)
):
"""
Get list of missing translations.
Compares translations across locales to find missing keys.
Args:
locale: Optional locale to check (defaults to checking all non-default locales)
Returns:
List of missing translation keys
"""
api_logger.info(
f"Get missing translations request: locale={locale}, user={current_user.username}"
)
from app.core.config import settings
translation_service = get_translation_service()
default_locale = settings.I18N_DEFAULT_LANGUAGE
available_locales = translation_service.get_available_locales()
# Get default locale translations as reference
default_translations = translation_service._cache.get(default_locale, {})
# Collect all keys from default locale
def collect_keys(data, prefix=""):
keys = []
for key, value in data.items():
full_key = f"{prefix}.{key}" if prefix else key
if isinstance(value, dict):
keys.extend(collect_keys(value, full_key))
else:
keys.append(full_key)
return keys
default_keys = set()
for namespace, translations in default_translations.items():
namespace_keys = collect_keys(translations, namespace)
default_keys.update(namespace_keys)
# Find missing keys in target locale(s)
missing_by_locale = {}
target_locales = [locale] if locale else [
loc for loc in available_locales if loc != default_locale
]
for target_locale in target_locales:
if target_locale not in available_locales:
continue
target_translations = translation_service._cache.get(target_locale, {})
target_keys = set()
for namespace, translations in target_translations.items():
namespace_keys = collect_keys(translations, namespace)
target_keys.update(namespace_keys)
missing_keys = default_keys - target_keys
if missing_keys:
missing_by_locale[target_locale] = sorted(list(missing_keys))
response = MissingTranslationsResponse(missing_translations=missing_by_locale)
total_missing = sum(len(keys) for keys in missing_by_locale.values())
api_logger.info(f"Found {total_missing} missing translations across {len(missing_by_locale)} locales")
return success(data=response.dict(), msg=t("common.success.retrieved"))
@router.post("/reload", response_model=ApiResponse)
def reload_translations(
locale: Optional[str] = None,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Trigger hot reload of translation files (admin only).
Args:
locale: Optional locale to reload (defaults to reloading all locales)
Returns:
Reload status and statistics
"""
api_logger.info(
f"Reload translations request: locale={locale or 'all'}, "
f"admin={current_user.username}"
)
from app.core.config import settings
if not settings.I18N_ENABLE_HOT_RELOAD:
api_logger.warning("Hot reload is disabled in configuration")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=t("i18n.reload.disabled")
)
translation_service = get_translation_service()
try:
# Reload translations
translation_service.reload(locale)
# Get statistics
available_locales = translation_service.get_available_locales()
reloaded_locales = [locale] if locale else available_locales
response = ReloadResponse(
success=True,
reloaded_locales=reloaded_locales,
total_locales=len(available_locales)
)
api_logger.info(
f"Successfully reloaded translations for: {', '.join(reloaded_locales)}"
)
return success(data=response.dict(), msg=t("i18n.reload.success"))
except Exception as e:
api_logger.error(f"Failed to reload translations: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=t("i18n.reload.failed", error=str(e))
)
# ============================================================================
# Performance Monitoring APIs
# ============================================================================
@router.get("/metrics", response_model=ApiResponse)
def get_metrics(
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Get i18n performance metrics (admin only).
Returns:
Performance metrics including:
- Request counts
- Missing translations
- Timing statistics
- Locale usage
- Error counts
"""
api_logger.info(f"Get metrics request: admin={current_user.username}")
translation_service = get_translation_service()
metrics = translation_service.get_metrics_summary()
api_logger.info("Returning i18n metrics")
return success(data=metrics, msg=t("common.success.retrieved"))
@router.get("/metrics/cache", response_model=ApiResponse)
def get_cache_stats(
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Get cache statistics (admin only).
Returns:
Cache statistics including:
- Hit/miss rates
- LRU cache performance
- Loaded locales
- Memory usage
"""
api_logger.info(f"Get cache stats request: admin={current_user.username}")
translation_service = get_translation_service()
cache_stats = translation_service.get_cache_stats()
memory_usage = translation_service.get_memory_usage()
data = {
"cache": cache_stats,
"memory": memory_usage
}
api_logger.info("Returning cache statistics")
return success(data=data, msg=t("common.success.retrieved"))
@router.get("/metrics/prometheus")
def get_prometheus_metrics(
current_user: User = Depends(get_current_superuser)
):
"""
Get metrics in Prometheus format (admin only).
Returns:
Prometheus-formatted metrics as plain text
"""
api_logger.info(f"Get Prometheus metrics request: admin={current_user.username}")
from app.i18n.metrics import get_metrics
metrics = get_metrics()
prometheus_output = metrics.export_prometheus()
from fastapi.responses import PlainTextResponse
return PlainTextResponse(content=prometheus_output)
@router.post("/metrics/reset", response_model=ApiResponse)
def reset_metrics(
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Reset all metrics (admin only).
Returns:
Success message
"""
api_logger.info(f"Reset metrics request: admin={current_user.username}")
from app.i18n.metrics import get_metrics
metrics = get_metrics()
metrics.reset()
translation_service = get_translation_service()
translation_service.cache.reset_stats()
api_logger.info("Metrics reset completed")
return success(msg=t("i18n.metrics.reset_success"))
# ============================================================================
# Missing Translation Logging and Reporting APIs
# ============================================================================
@router.get("/logs/missing", response_model=ApiResponse)
def get_missing_translation_logs(
locale: Optional[str] = None,
limit: Optional[int] = 100,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Get missing translation logs (admin only).
Returns logged missing translations with context information.
Args:
locale: Optional locale filter
limit: Maximum number of entries to return (default: 100)
Returns:
Missing translation logs with context
"""
api_logger.info(
f"Get missing translation logs request: locale={locale}, "
f"limit={limit}, admin={current_user.username}"
)
translation_service = get_translation_service()
translation_logger = translation_service.translation_logger
# Get missing translations
missing_translations = translation_logger.get_missing_translations(locale)
# Get missing with context
missing_with_context = translation_logger.get_missing_with_context(locale, limit)
# Get statistics
statistics = translation_logger.get_statistics()
data = {
"missing_translations": missing_translations,
"recent_context": missing_with_context,
"statistics": statistics
}
api_logger.info(
f"Returning {statistics['total_missing']} missing translations"
)
return success(data=data, msg=t("common.success.retrieved"))
@router.get("/logs/missing/report", response_model=ApiResponse)
def generate_missing_translation_report(
locale: Optional[str] = None,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Generate a comprehensive missing translation report (admin only).
Args:
locale: Optional locale filter
Returns:
Comprehensive report with missing translations and statistics
"""
api_logger.info(
f"Generate missing translation report request: locale={locale}, "
f"admin={current_user.username}"
)
translation_service = get_translation_service()
translation_logger = translation_service.translation_logger
# Generate report
report = translation_logger.generate_report(locale)
api_logger.info(
f"Generated report with {report['total_missing']} missing translations"
)
return success(data=report, msg=t("common.success.retrieved"))
@router.post("/logs/missing/export", response_model=ApiResponse)
def export_missing_translations(
locale: Optional[str] = None,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Export missing translations to JSON file (admin only).
Args:
locale: Optional locale filter
Returns:
Export status and file path
"""
api_logger.info(
f"Export missing translations request: locale={locale}, "
f"admin={current_user.username}"
)
from datetime import datetime
translation_service = get_translation_service()
translation_logger = translation_service.translation_logger
# Generate filename with timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
locale_suffix = f"_{locale}" if locale else "_all"
output_file = f"logs/i18n/missing_translations{locale_suffix}_{timestamp}.json"
# Export to file
translation_logger.export_to_json(output_file)
api_logger.info(f"Missing translations exported to: {output_file}")
return success(
data={"file_path": output_file},
msg=t("i18n.logs.export_success", file=output_file)
)
@router.delete("/logs/missing", response_model=ApiResponse)
def clear_missing_translation_logs(
locale: Optional[str] = None,
t: Callable = Depends(get_translator),
current_user: User = Depends(get_current_superuser)
):
"""
Clear missing translation logs (admin only).
Args:
locale: Optional locale to clear (clears all if not specified)
Returns:
Success message
"""
api_logger.info(
f"Clear missing translation logs request: locale={locale or 'all'}, "
f"admin={current_user.username}"
)
translation_service = get_translation_service()
translation_logger = translation_service.translation_logger
# Clear logs
translation_logger.clear(locale)
api_logger.info(f"Cleared missing translation logs for: {locale or 'all locales'}")
return success(msg=t("i18n.logs.clear_success"))

View File

@@ -122,52 +122,10 @@ def validate_confidence_threshold(threshold: float) -> None:
raise ValueError("confidence_threshold must be between 0.0 and 1.0")
@router.get("/check-data/{end_user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def check_user_data_exists(
end_user_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
) -> ApiResponse:
"""
检查用户画像数据是否存在
Args:
end_user_id: 目标用户ID
Returns:
数据存在状态
"""
api_logger.info(f"检查用户画像数据是否存在: {end_user_id}")
try:
# Validate inputs
validate_user_id(end_user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
return success(
data={"exists": False},
msg="画像数据不存在,请点击右上角刷新进行初始化"
)
api_logger.info(f"用户 {end_user_id} 的画像数据存在")
return success(data={"exists": True}, msg="画像数据已存在")
except Exception as e:
return handle_implicit_memory_error(e, "检查画像数据", end_user_id)
@router.get("/preferences/{end_user_id}", response_model=ApiResponse)
@router.get("/preferences/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_preference_tags(
end_user_id: str,
user_id: str,
confidence_threshold: float = Query(0.5, ge=0.0, le=1.0, description="Minimum confidence threshold"),
tag_category: Optional[str] = Query(None, description="Filter by tag category"),
start_date: Optional[datetime] = Query(None, description="Filter start date"),
@@ -179,7 +137,7 @@ async def get_preference_tags(
Get user preference tags from cache.
Args:
end_user_id: Target end user ID
user_id: Target user ID
confidence_threshold: Minimum confidence score (0.0-1.0)
tag_category: Optional category filter
start_date: Optional start date filter
@@ -188,21 +146,25 @@ async def get_preference_tags(
Returns:
List of preference tags from cache
"""
api_logger.info(f"Preference tags requested for user: {end_user_id} (from cache)")
api_logger.info(f"Preference tags requested for user: {user_id} (from cache)")
try:
# Validate inputs
validate_user_id(end_user_id)
validate_user_id(user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
service = ImplicitMemoryService(db=db, end_user_id=user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
return fail(BizCode.NOT_FOUND, "", "")
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.RESOURCE_NOT_FOUND,
"画像缓存不存在或已过期,请调用 /generate_profile 接口生成新画像",
None
)
# Extract preferences from cache
preferences = cached_profile.get("preferences", [])
@@ -230,17 +192,17 @@ async def get_preference_tags(
filtered_preferences.append(pref)
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {end_user_id} (from cache)")
api_logger.info(f"Retrieved {len(filtered_preferences)} preference tags for user: {user_id} (from cache)")
return success(data=filtered_preferences, msg="偏好标签获取成功(缓存)")
except Exception as e:
return handle_implicit_memory_error(e, "偏好标签获取", end_user_id)
return handle_implicit_memory_error(e, "偏好标签获取", user_id)
@router.get("/portrait/{end_user_id}", response_model=ApiResponse)
@router.get("/portrait/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_dimension_portrait(
end_user_id: str,
user_id: str,
include_history: bool = Query(False, description="Include historical trends"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
@@ -249,42 +211,46 @@ async def get_dimension_portrait(
Get user's four-dimension personality portrait from cache.
Args:
end_user_id: Target end user ID
user_id: Target user ID
include_history: Whether to include historical trend data (ignored for cached data)
Returns:
Four-dimension personality portrait from cache
"""
api_logger.info(f"Dimension portrait requested for user: {end_user_id} (from cache)")
api_logger.info(f"Dimension portrait requested for user: {user_id} (from cache)")
try:
# Validate inputs
validate_user_id(end_user_id)
validate_user_id(user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
service = ImplicitMemoryService(db=db, end_user_id=user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
return fail(BizCode.NOT_FOUND, "", "")
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.RESOURCE_NOT_FOUND,
"画像缓存不存在或已过期,请调用 /generate_profile 接口生成新画像",
None
)
# Extract portrait from cache
portrait = cached_profile.get("portrait", {})
api_logger.info(f"Dimension portrait retrieved for user: {end_user_id} (from cache)")
api_logger.info(f"Dimension portrait retrieved for user: {user_id} (from cache)")
return success(data=portrait, msg="四维画像获取成功(缓存)")
except Exception as e:
return handle_implicit_memory_error(e, "四维画像获取", end_user_id)
return handle_implicit_memory_error(e, "四维画像获取", user_id)
@router.get("/interest-areas/{end_user_id}", response_model=ApiResponse)
@router.get("/interest-areas/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_interest_area_distribution(
end_user_id: str,
user_id: str,
include_trends: bool = Query(False, description="Include trend analysis"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
@@ -293,42 +259,46 @@ async def get_interest_area_distribution(
Get user's interest area distribution from cache.
Args:
end_user_id: Target end user ID
user_id: Target user ID
include_trends: Whether to include trend analysis data (ignored for cached data)
Returns:
Interest area distribution from cache
"""
api_logger.info(f"Interest area distribution requested for user: {end_user_id} (from cache)")
api_logger.info(f"Interest area distribution requested for user: {user_id} (from cache)")
try:
# Validate inputs
validate_user_id(end_user_id)
validate_user_id(user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
service = ImplicitMemoryService(db=db, end_user_id=user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
return fail(BizCode.NOT_FOUND, "", "")
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.RESOURCE_NOT_FOUND,
"画像缓存不存在或已过期,请调用 /generate_profile 接口生成新画像",
None
)
# Extract interest areas from cache
interest_areas = cached_profile.get("interest_areas", {})
api_logger.info(f"Interest area distribution retrieved for user: {end_user_id} (from cache)")
api_logger.info(f"Interest area distribution retrieved for user: {user_id} (from cache)")
return success(data=interest_areas, msg="兴趣领域分布获取成功(缓存)")
except Exception as e:
return handle_implicit_memory_error(e, "兴趣领域分布获取", end_user_id)
return handle_implicit_memory_error(e, "兴趣领域分布获取", user_id)
@router.get("/habits/{end_user_id}", response_model=ApiResponse)
@router.get("/habits/{user_id}", response_model=ApiResponse)
@cur_workspace_access_guard()
async def get_behavior_habits(
end_user_id: str,
user_id: str,
confidence_level: Optional[str] = Query(None, regex="^(high|medium|low)$", description="Filter by confidence level"),
frequency_pattern: Optional[str] = Query(None, regex="^(daily|weekly|monthly|seasonal|occasional|event_triggered)$", description="Filter by frequency pattern"),
time_period: Optional[str] = Query(None, regex="^(current|past)$", description="Filter by time period"),
@@ -339,7 +309,7 @@ async def get_behavior_habits(
Get user's behavioral habits from cache.
Args:
end_user_id: Target end user ID
user_id: Target user ID
confidence_level: Filter by confidence level (high, medium, low)
frequency_pattern: Filter by frequency pattern (daily, weekly, monthly, seasonal, occasional, event_triggered)
time_period: Filter by time period (current, past)
@@ -347,21 +317,25 @@ async def get_behavior_habits(
Returns:
List of behavioral habits from cache
"""
api_logger.info(f"Behavior habits requested for user: {end_user_id} (from cache)")
api_logger.info(f"Behavior habits requested for user: {user_id} (from cache)")
try:
# Validate inputs
validate_user_id(end_user_id)
validate_user_id(user_id)
# Create service with user-specific config
service = ImplicitMemoryService(db=db, end_user_id=end_user_id)
service = ImplicitMemoryService(db=db, end_user_id=user_id)
# Get cached profile
cached_profile = await service.get_cached_profile(end_user_id=end_user_id, db=db)
cached_profile = await service.get_cached_profile(end_user_id=user_id, db=db)
if cached_profile is None:
api_logger.info(f"用户 {end_user_id} 的画像数据不存在")
return fail(BizCode.NOT_FOUND, "", "")
api_logger.info(f"用户 {user_id} 的画像缓存不存在或已过期")
return fail(
BizCode.RESOURCE_NOT_FOUND,
"画像缓存不存在或已过期,请调用 /generate_profile 接口生成新画像",
None
)
# Extract habits from cache
habits = cached_profile.get("habits", [])
@@ -394,11 +368,11 @@ async def get_behavior_habits(
filtered_habits.append(habit)
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {end_user_id} (from cache)")
api_logger.info(f"Retrieved {len(filtered_habits)} behavior habits for user: {user_id} (from cache)")
return success(data=filtered_habits, msg="行为习惯获取成功(缓存)")
except Exception as e:
return handle_implicit_memory_error(e, "行为习惯获取", end_user_id)
return handle_implicit_memory_error(e, "行为习惯获取", user_id)

View File

@@ -9,16 +9,13 @@ from sqlalchemy import or_
from sqlalchemy.orm import Session
from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger
from app.core.rag.common import settings
from app.core.rag.integrations.feishu.client import FeishuAPIClient
from app.core.rag.integrations.yuque.client import YuqueAPIClient
from app.core.rag.llm.chat_model import Base
from app.core.rag.nlp import rag_tokenizer, search
from app.core.rag.prompts.generator import graph_entity_types
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.response_utils import success, fail
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models import knowledge_model
@@ -487,99 +484,3 @@ async def rebuild_knowledge_graph(
except Exception as e:
api_logger.error(f"Failed to rebuild knowledge graph: knowledge_id={knowledge_id} - {str(e)}")
raise
@router.get("/check/yuque/auth", response_model=ApiResponse)
async def check_yuque_auth(
yuque_user_id: str,
yuque_token: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
check yuque auth info
"""
api_logger.info(f"check yuque auth info, username: {current_user.username}")
try:
api_client = YuqueAPIClient(
user_id=yuque_user_id,
token=yuque_token
)
async with api_client as client:
repos = await client.get_user_repos()
if repos:
return success(msg="Successfully auth yuque info")
return fail(BizCode.UNAUTHORIZED, msg="auth yuque info failed", error="user_id or token is incorrect")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"auth yuque info failed: {str(e)}")
raise
@router.get("/check/feishu/auth", response_model=ApiResponse)
async def check_feishu_auth(
feishu_app_id: str,
feishu_app_secret: str,
feishu_folder_token: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
check feishu auth info
"""
api_logger.info(f"check feishu auth info, username: {current_user.username}")
try:
api_client = FeishuAPIClient(
app_id=feishu_app_id,
app_secret=feishu_app_secret
)
async with api_client as client:
files = await client.list_all_folder_files(feishu_folder_token, recursive=True)
if files:
return success(msg="Successfully auth feishu info")
return fail(BizCode.UNAUTHORIZED, msg="auth feishu info failed", error="app_id or app_secret or feishu_folder_token is incorrect")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"auth feishu info failed: {str(e)}")
raise
@router.post("/{knowledge_id}/sync", response_model=ApiResponse)
async def sync_knowledge(
knowledge_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
sync knowledge base information based on knowledge_id
"""
api_logger.info(f"Obtain details of the knowledge base: knowledge_id={knowledge_id}, username: {current_user.username}")
try:
# 1. Query knowledge base information from the database
api_logger.debug(f"Query knowledge base: {knowledge_id}")
db_knowledge = knowledge_service.get_knowledge_by_id(db, knowledge_id=knowledge_id, current_user=current_user)
if not db_knowledge:
api_logger.warning(f"The knowledge base does not exist or access is denied: knowledge_id={knowledge_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The knowledge base does not exist or access is denied"
)
# 2. sync knowledge
# from app.tasks import sync_knowledge_for_kb
# sync_knowledge_for_kb(kb_id)
task = celery_app.send_task("app.core.rag.tasks.sync_knowledge_for_kb", args=[knowledge_id])
result = {
"task_id": task.id
}
return success(data=result, msg="Task accepted. sync knowledge is being processed in the background.")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"Failed to sync knowledge: knowledge_id={knowledge_id} - {str(e)}")
raise

View File

@@ -1,465 +0,0 @@
import datetime
import json
from typing import Optional
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.encoders import jsonable_encoder
import requests
from sqlalchemy import or_
from sqlalchemy.orm import Session
from modelscope.hub.errors import raise_for_http_status
from modelscope.hub.mcp_api import MCPApi
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.db import get_db
from app.dependencies import get_current_user
from app.models import mcp_market_config_model
from app.models.user_model import User
from app.schemas import mcp_market_config_schema
from app.schemas.response_schema import ApiResponse
from app.services import mcp_market_config_service, mcp_market_service
# Obtain a dedicated API logger
api_logger = get_api_logger()
router = APIRouter(
prefix="/mcp_market_configs",
tags=["mcp_market_configs"],
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
)
@router.get("/mcp_servers", response_model=ApiResponse)
async def get_mcp_servers(
mcp_market_config_id: uuid.UUID,
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
keywords: Optional[str] = Query(None, description="Search keywords (Optional search query string,e.g. Chinese service name, English service name, author/owner username)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Query the mcp servers list in pages
- Support keyword search for name,author,owner
- Return paging metadata + mcp server list
"""
api_logger.info(
f"Query mcp server list: tenant_id={current_user.tenant_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
# 1. parameter validation
if page < 1 or pagesize < 1:
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0"
)
if page * pagesize > 100:
api_logger.warning(f"Paging parameters exceed ModelScope limit: page={page}, pagesize={pagesize}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The maximum number of MCP services can view is 100. Please visit the ModelScope MCP Plaza."
)
# 2. Query mcp market config information from the database
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
mcp_market_config_id=mcp_market_config_id,
current_user=current_user)
if not db_mcp_market_config:
api_logger.warning(
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
return success(msg='The mcp market config does not exist or access is denied')
# 3. Execute paged query
token = db_mcp_market_config.token
if not token:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="MCP market config token is not configured"
)
api = MCPApi()
api.login(token)
body = {
'filter': {},
'page_number': page,
'page_size': pagesize,
'search': keywords
}
try:
cookies = api.get_cookies(token)
r = api.session.put(
url=api.mcp_base_url,
headers=api.builder_headers(api.headers),
json=body,
cookies=cookies)
raise_for_http_status(r)
except requests.exceptions.RequestException as e:
api_logger.error(f"Failed to get MCP servers: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get MCP servers: {str(e)}"
)
data = api._handle_response(r)
total = data.get('total_count', 0)
mcp_server_list = data.get('mcp_server_list', [])
# items = [{
# 'name': item.get('name', ''),
# 'id': item.get('id', ''),
# 'description': item.get('description', '')
# } for item in mcp_server_list]
# 4. Return structured response
result = {
"items": mcp_server_list,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"has_next": True if page * pagesize < total else False
}
}
# 5. Update mck_market.mcp_count
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=db_mcp_market_config.mcp_market_id, current_user=current_user)
if not db_mcp_market:
api_logger.warning(f"The mcp market does not exist or access is denied: mcp_market_id={db_mcp_market_config.mcp_market_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market does not exist or access is denied"
)
db_mcp_market.mcp_count = total
db.commit()
db.refresh(db_mcp_market)
return success(data=result, msg="Query of mcp servers list successful")
@router.get("/operational_mcp_servers", response_model=ApiResponse)
async def get_operational_mcp_servers(
mcp_market_config_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Query the operational mcp servers list in pages
- Support keyword search for name,author,owner
- Return paging metadata + operational mcp server list
"""
api_logger.info(
f"Query operational mcp server list: tenant_id={current_user.tenant_id}, username: {current_user.username}")
# 1. Query mcp market config information from the database
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
mcp_market_config_id=mcp_market_config_id,
current_user=current_user)
if not db_mcp_market_config:
api_logger.warning(
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
return success(msg='The mcp market config does not exist or access is denied')
# 2. Execute paged query
token = db_mcp_market_config.token
if not token:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="MCP market config token is not configured"
)
api = MCPApi()
api.login(token)
url = f'{api.mcp_base_url}/operational'
headers = api.builder_headers(api.headers)
try:
cookies = api.get_cookies(access_token=token, cookies_required=True)
r = api.session.get(url, headers=headers, cookies=cookies)
raise_for_http_status(r)
except requests.exceptions.RequestException as e:
api_logger.error(f"Failed to get operational MCP servers: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get operational MCP servers: {str(e)}"
)
data = api._handle_response(r)
total = data.get('total_count', 0)
mcp_server_list = data.get('mcp_server_list', [])
# items = [{
# 'name': item.get('name', ''),
# 'id': item.get('id', ''),
# 'description': item.get('description', '')
# } for item in mcp_server_list]
# 3. Return structured response
return success(data=mcp_server_list, msg="Query of operational mcp servers list successful")
@router.get("/mcp_server", response_model=ApiResponse)
async def get_mcp_server(
mcp_market_config_id: uuid.UUID,
server_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Get detailed information for a specific MCP Server
"""
api_logger.info(
f"Query mcp server: tenant_id={current_user.tenant_id}, mcp_market_config_id={mcp_market_config_id}, server_id={server_id}, username: {current_user.username}")
# 1. Query mcp market config information from the database
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db,
mcp_market_config_id=mcp_market_config_id,
current_user=current_user)
if not db_mcp_market_config:
api_logger.warning(
f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
return success(msg='The mcp market config does not exist or access is denied')
# 2. Get detailed information for a specific MCP Server
token = db_mcp_market_config.token
if not token:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="MCP market config token is not configured"
)
api = MCPApi()
api.login(token)
result = api.get_mcp_server(server_id=server_id)
return success(data=result, msg="Query of mcp servers list successful")
@router.post("/mcp_market_config", response_model=ApiResponse)
async def create_mcp_market_config(
create_data: mcp_market_config_schema.McpMarketConfigCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
create mcp market config
"""
api_logger.info(
f"Request to create a mcp market config: mcp_market_id={create_data.mcp_market_id}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
try:
api_logger.debug(f"Start creating the mcp market config: {create_data.mcp_market_id}")
# 1. Validate token can access ModelScope MCP market
if not create_data.token:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Token is required to access ModelScope MCP market"
)
try:
api = MCPApi()
api.login(create_data.token)
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
cookies = api.get_cookies(create_data.token)
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
raise_for_http_status(r)
except Exception as e:
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unable to access ModelScope MCP market with the provided token: {str(e)}"
)
# 2. Check if the mcp market name already exists
db_mcp_market_config_exist = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=create_data.mcp_market_id, current_user=current_user)
if db_mcp_market_config_exist:
api_logger.warning(f"The mcp market id already exists: {create_data.mcp_market_id}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The mcp market id already exists: {create_data.mcp_market_id}"
)
# 2. verify token
create_data.status = 1
try:
api = MCPApi()
token = create_data.token
api.login(token)
body = {
'filter': {},
'page_number': 1,
'page_size': 20,
'search': ""
}
cookies = api.get_cookies(token)
r = api.session.put(
url=api.mcp_base_url,
headers=api.builder_headers(api.headers),
json=body,
cookies=cookies)
raise_for_http_status(r)
except requests.exceptions.RequestException as e:
api_logger.error(f"Failed to get MCP servers: {str(e)}")
create_data.status = 0
# 3. create mcp_market_config
db_mcp_market_config = mcp_market_config_service.create_mcp_market_config(db=db, mcp_market_config=create_data, current_user=current_user)
api_logger.info(
f"The mcp market config has been successfully created: (ID: {db_mcp_market_config.id})")
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
msg="The mcp market config has been successfully created")
except Exception as e:
api_logger.error(f"The creation of the mcp market config failed: {create_data.mcp_market_id} - {str(e)}")
raise
@router.get("/{mcp_market_config_id}", response_model=ApiResponse)
async def get_mcp_market_config(
mcp_market_config_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Retrieve mcp market config information based on mcp_market_config_id
"""
api_logger.info(
f"Obtain details of the mcp market config: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}")
try:
# 1. Query mcp market config information from the database
api_logger.debug(f"Query mcp market config: {mcp_market_config_id}")
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
if not db_mcp_market_config:
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}")
return success(msg='The mcp market config does not exist or access is denied')
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
msg="Successfully obtained mcp market config information")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"mcp market config query failed: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
raise
@router.get("/mcp_market_id/{mcp_market_id}", response_model=ApiResponse)
async def get_mcp_market_config_by_mcp_market_id(
mcp_market_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Retrieve mcp market config information based on mcp_market_id
"""
api_logger.info(
f"Request to create a mcp market config: mcp_market_id={mcp_market_id}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
try:
# 1. Query mcp market config information from the database
api_logger.debug(f"Query mcp market config: mcp_market_id={mcp_market_id}")
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_mcp_market_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
if not db_mcp_market_config:
api_logger.warning(f"The mcp market config does not exist or access is denied: mcp_market_id={mcp_market_id}")
return success(msg='The mcp market config does not exist or access is denied')
api_logger.info(f"mcp market config query successful: (ID: {db_mcp_market_config.id})")
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
msg="Successfully obtained mcp market config information")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"mcp market config query failed: mcp_market_id={mcp_market_id} - {str(e)}")
raise
@router.put("/{mcp_market_config_id}", response_model=ApiResponse)
async def update_mcp_market_config(
mcp_market_config_id: uuid.UUID,
update_data: mcp_market_config_schema.McpMarketConfigUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
# 1. Check if the mcp market config exists
api_logger.debug(f"Query the mcp market config to be updated: {mcp_market_config_id}")
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
if not db_mcp_market_config:
api_logger.warning(
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
return success(msg='The mcp market config does not exist or access is denied')
# 2. Validate new token if provided
if update_data.token is not None:
try:
api = MCPApi()
api.login(update_data.token)
body = {'filter': {}, 'page_number': 1, 'page_size': 1, 'search': None}
cookies = api.get_cookies(update_data.token)
r = api.session.put(url=api.mcp_base_url, headers=api.builder_headers(api.headers), json=body, cookies=cookies)
raise_for_http_status(r)
except Exception as e:
api_logger.warning(f"Token validation failed for ModelScope MCP market: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unable to access ModelScope MCP market with the provided token: {str(e)}"
)
# 3. Update fields (only update non-null fields)
api_logger.debug(f"Start updating the mcp market config fields: {mcp_market_config_id}")
update_dict = update_data.dict(exclude_unset=True)
updated_fields = []
for field, value in update_dict.items():
if hasattr(db_mcp_market_config, field):
old_value = getattr(db_mcp_market_config, field)
if old_value != value:
# update value
setattr(db_mcp_market_config, field, value)
updated_fields.append(f"{field}: {old_value} -> {value}")
if updated_fields:
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
# 4. Save to database
try:
db.commit()
db.refresh(db_mcp_market_config)
api_logger.info(f"The mcp market config has been successfully updated: (ID: {db_mcp_market_config.id})")
except Exception as e:
db.rollback()
api_logger.error(f"The mcp market config update failed: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"The mcp market config update failed: {str(e)}"
)
# 5. Return the updated mcp market config
return success(data=jsonable_encoder(mcp_market_config_schema.McpMarketConfig.model_validate(db_mcp_market_config)),
msg="The mcp market config information updated successfully")
@router.delete("/{mcp_market_config_id}", response_model=ApiResponse)
async def delete_mcp_market_config(
mcp_market_config_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
delete mcp market config
"""
api_logger.info(f"Request to delete mcp market config: mcp_market_config_id={mcp_market_config_id}, username: {current_user.username}")
try:
# 1. Check whether the mcp market config exists
api_logger.debug(f"Check whether the mcp market config exists: {mcp_market_config_id}")
db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
if not db_mcp_market_config:
api_logger.warning(
f"The mcp market config does not exist or you do not have permission to access it: mcp_market_config_id={mcp_market_config_id}")
return success(msg='The mcp market config does not exist or access is denied')
# 2. Deleting mcp market config
mcp_market_config_service.delete_mcp_market_config_by_id(db, mcp_market_config_id=mcp_market_config_id, current_user=current_user)
api_logger.info(f"The mcp market config has been successfully deleted: (ID: {mcp_market_config_id})")
return success(msg="The mcp market config has been successfully deleted")
except Exception as e:
api_logger.error(f"Failed to delete from the mcp market config: mcp_market_config_id={mcp_market_config_id} - {str(e)}")
raise

View File

@@ -1,262 +0,0 @@
import datetime
import json
from typing import Optional
import uuid
from fastapi import APIRouter, Depends, HTTPException, status, Query
from fastapi.encoders import jsonable_encoder
from sqlalchemy import or_
from sqlalchemy.orm import Session
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.db import get_db
from app.dependencies import get_current_user
from app.models import mcp_market_model
from app.models.user_model import User
from app.schemas import mcp_market_schema
from app.schemas.response_schema import ApiResponse
from app.services import mcp_market_service
# Obtain a dedicated API logger
api_logger = get_api_logger()
router = APIRouter(
prefix="/mcp_markets",
tags=["mcp_markets"],
dependencies=[Depends(get_current_user)] # Apply auth to all routes in this controller
)
@router.get("/mcp_markets", response_model=ApiResponse)
async def get_mcp_markets(
page: int = Query(1, gt=0), # Default: 1, which must be greater than 0
pagesize: int = Query(20, gt=0, le=100), # Default: 20 items per page, maximum: 100 items
orderby: Optional[str] = Query(None, description="Sort fields, such as: category, created_at"),
desc: Optional[bool] = Query(False, description="Is it descending order"),
keywords: Optional[str] = Query(None, description="Search keywords (mcp_market base name)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Query the mcp markets list in pages
- Support keyword search for name,description
- Support dynamic sorting
- Return paging metadata + mcp_market list
"""
api_logger.info(
f"Query mcp market list: tenant_id={current_user.tenant_id}, page={page}, pagesize={pagesize}, keywords={keywords}, username: {current_user.username}")
# 1. parameter validation
if page < 1 or pagesize < 1:
api_logger.warning(f"Error in paging parameters: page={page}, pagesize={pagesize}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="The paging parameter must be greater than 0"
)
# 2. Construct query conditions
filters = []
# Keyword search (fuzzy matching of mcp market name,description)
if keywords:
api_logger.debug(f"Add keyword search criteria: {keywords}")
filters.append(
or_(
mcp_market_model.McpMarket.name.ilike(f"%{keywords}%"),
mcp_market_model.McpMarket.description.ilike(f"%{keywords}%")
)
)
# 3. Execute paged query
try:
api_logger.debug("Start executing mcp market paging query")
total, items = mcp_market_service.get_mcp_markets_paginated(
db=db,
filters=filters,
page=page,
pagesize=pagesize,
orderby=orderby,
desc=desc,
current_user=current_user
)
api_logger.info(f"mcp market query successful: total={total}, returned={len(items)} records")
except Exception as e:
api_logger.error(f"mcp market query failed: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Query failed: {str(e)}"
)
# 4. Return structured response
result = {
"items": items,
"page": {
"page": page,
"pagesize": pagesize,
"total": total,
"has_next": True if page * pagesize < total else False
}
}
return success(data=jsonable_encoder(result), msg="Query of mcp market list successful")
@router.post("/mcp_market", response_model=ApiResponse)
async def create_mcp_market(
create_data: mcp_market_schema.McpMarketCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
create mcp market
"""
api_logger.info(
f"Request to create a mcp market: name={create_data.name}, tenant_id={current_user.tenant_id}, username: {current_user.username}")
try:
api_logger.debug(f"Start creating the mcp market: {create_data.name}")
# 1. Check if the mcp market name already exists
db_mcp_market_exist = mcp_market_service.get_mcp_market_by_name(db, name=create_data.name, current_user=current_user)
if db_mcp_market_exist:
api_logger.warning(f"The mcp market name already exists: {create_data.name}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The mcp market name already exists: {create_data.name}"
)
db_mcp_market = mcp_market_service.create_mcp_market(db=db, mcp_market=create_data, current_user=current_user)
api_logger.info(
f"The mcp market has been successfully created: {db_mcp_market.name} (ID: {db_mcp_market.id})")
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
msg="The mcp market has been successfully created")
except Exception as e:
api_logger.error(f"The creation of the mcp market failed: {create_data.name} - {str(e)}")
raise
@router.get("/{mcp_market_id}", response_model=ApiResponse)
async def get_mcp_market(
mcp_market_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Retrieve mcp market information based on mcp_market_id
"""
api_logger.info(
f"Obtain details of the mcp market: mcp_market_id={mcp_market_id}, username: {current_user.username}")
try:
# 1. Query mcp market information from the database
api_logger.debug(f"Query mcp market: {mcp_market_id}")
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
if not db_mcp_market:
api_logger.warning(f"The mcp market does not exist or access is denied: mcp_market_id={mcp_market_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market does not exist or access is denied"
)
api_logger.info(f"mcp market query successful: {db_mcp_market.name} (ID: {db_mcp_market.id})")
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
msg="Successfully obtained mcp market information")
except HTTPException:
raise
except Exception as e:
api_logger.error(f"mcp market query failed: mcp_market_id={mcp_market_id} - {str(e)}")
raise
@router.put("/{mcp_market_id}", response_model=ApiResponse)
async def update_mcp_market(
mcp_market_id: uuid.UUID,
update_data: mcp_market_schema.McpMarketUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
# 1. Check if the mcp market exists
api_logger.debug(f"Query the mcp market to be updated: {mcp_market_id}")
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
if not db_mcp_market:
api_logger.warning(
f"The mcp market does not exist or you do not have permission to access it: mcp_market_id={mcp_market_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market does not exist or you do not have permission to access it"
)
# 2. not updating the name (name already exists)
update_dict = update_data.dict(exclude_unset=True)
if "name" in update_dict:
name = update_dict["name"]
if name != db_mcp_market.name:
# Check if the mcp market name already exists
db_mcp_market_exist = mcp_market_service.get_mcp_market_by_name(db, name=name, current_user=current_user)
if db_mcp_market_exist:
api_logger.warning(f"The mcp market name already exists: {name}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"The mcp market name already exists: {name}"
)
# 3. Update fields (only update non-null fields)
api_logger.debug(f"Start updating the mcp market fields: {mcp_market_id}")
updated_fields = []
for field, value in update_dict.items():
if hasattr(db_mcp_market, field):
old_value = getattr(db_mcp_market, field)
if old_value != value:
# update value
setattr(db_mcp_market, field, value)
updated_fields.append(f"{field}: {old_value} -> {value}")
if updated_fields:
api_logger.debug(f"updated fields: {', '.join(updated_fields)}")
# 4. Save to database
try:
db.commit()
db.refresh(db_mcp_market)
api_logger.info(f"The mcp market has been successfully updated: {db_mcp_market.name} (ID: {db_mcp_market.id})")
except Exception as e:
db.rollback()
api_logger.error(f"The mcp market update failed: mcp_market_id={mcp_market_id} - {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"The mcp market update failed: {str(e)}"
)
# 5. Return the updated mcp market
return success(data=jsonable_encoder(mcp_market_schema.McpMarket.model_validate(db_mcp_market)),
msg="The mcp market information updated successfully")
@router.delete("/{mcp_market_id}", response_model=ApiResponse)
async def delete_mcp_market(
mcp_market_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
delete mcp market
"""
api_logger.info(f"Request to delete mcp market: mcp_market_id={mcp_market_id}, username: {current_user.username}")
try:
# 1. Check whether the mcp market exists
api_logger.debug(f"Check whether the mcp market exists: {mcp_market_id}")
db_mcp_market = mcp_market_service.get_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
if not db_mcp_market:
api_logger.warning(
f"The mcp market does not exist or you do not have permission to access it: mcp_market_id={mcp_market_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The mcp market does not exist or you do not have permission to access it"
)
# 2. Deleting mcp market
mcp_market_service.delete_mcp_market_by_id(db, mcp_market_id=mcp_market_id, current_user=current_user)
api_logger.info(f"The mcp market has been successfully deleted: (ID: {mcp_market_id})")
return success(msg="The mcp market has been successfully deleted")
except Exception as e:
api_logger.error(f"Failed to delete from the mcp market: mcp_market_id={mcp_market_id} - {str(e)}")
raise

View File

@@ -1,17 +1,8 @@
from typing import List, Optional
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
from app.cache.memory.interest_memory import InterestMemoryCache
from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.rag.llm.cv_model import QWenCV
from app.core.response_utils import fail, success
from app.db import get_db
@@ -24,6 +15,10 @@ from app.schemas.response_schema import ApiResponse
from app.services import task_service, workspace_service
from app.services.memory_agent_service import MemoryAgentService
from app.services.model_service import ModelConfigService
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, File, Form, Query, UploadFile
from sqlalchemy.orm import Session
from starlette.responses import StreamingResponse
load_dotenv()
api_logger = get_api_logger()
@@ -38,7 +33,7 @@ router = APIRouter(
@router.get("/health/status", response_model=ApiResponse)
async def get_health_status(
current_user: User = Depends(get_current_user)
current_user: User = Depends(get_current_user)
):
"""
Get latest health status written by Celery periodic task
@@ -56,9 +51,8 @@ async def get_health_status(
@router.get("/download_log")
async def download_log(
log_type: str = Query("file", regex="^(file|transmission)$",
description="日志类型: file=完整文件, transmission=实时流式传输"),
current_user: User = Depends(get_current_user)
log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"),
current_user: User = Depends(get_current_user)
):
"""
Download or stream agent service log file
@@ -77,16 +71,16 @@ async def download_log(
- transmission mode: StreamingResponse with SSE
"""
api_logger.info(f"Log download requested with log_type={log_type}")
# Validate log_type parameter (FastAPI Query regex already validates, but explicit check for clarity)
if log_type not in ["file", "transmission"]:
api_logger.warning(f"Invalid log_type parameter: {log_type}")
return fail(
BizCode.BAD_REQUEST,
"无效的log_type参数",
BizCode.BAD_REQUEST,
"无效的log_type参数",
"log_type必须是'file''transmission'"
)
# Route to appropriate mode
if log_type == "file":
# File mode: Return complete log file content
@@ -121,28 +115,23 @@ async def download_log(
@router.post("/writer_service", response_model=ApiResponse)
@cur_workspace_access_guard()
async def write_server(
user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
user_input: Write_UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Write service endpoint - processes write operations synchronously
Args:
user_input: Write request containing message and end_user_id
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
user_input: Write request containing message and group_id
Returns:
Response with write operation status
"""
# 使用集中化的语言校验
language = get_language_from_header(language_type)
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
@@ -151,7 +140,7 @@ async def write_server(
)
if storage_type is None: storage_type = 'neo4j'
user_rag_memory_id = ''
# 如果 storage_type 是 rag必须确保有有效的 user_rag_memory_id
if storage_type == 'rag':
if workspace_id:
@@ -163,27 +152,22 @@ async def write_server(
if knowledge:
user_rag_memory_id = str(knowledge.id)
else:
api_logger.warning(
f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库workspace_id: {workspace_id},将使用 neo4j 存储")
storage_type = 'neo4j'
else:
api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储")
storage_type = 'neo4j'
api_logger.info(
f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
api_logger.info(f"Write service requested for group {user_input.group_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}")
try:
messages_list = memory_agent_service.get_messages_list(user_input)
result = await memory_agent_service.write_memory(
user_input.end_user_id,
messages_list,
user_input.group_id,
user_input.message,
config_id,
db,
storage_type,
user_rag_memory_id,
language
storage_type,
user_rag_memory_id
)
return success(data=result, msg="写入成功")
except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
@@ -199,29 +183,23 @@ async def write_server(
@router.post("/writer_service_async", response_model=ApiResponse)
@cur_workspace_access_guard()
async def write_server_async(
user_input: Write_UserInput,
language_type: str = Header(default=None, alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
user_input: Write_UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Async write service endpoint - enqueues write processing to Celery
Args:
user_input: Write request containing message and end_user_id
language_type: 语言类型 ("zh" 中文, "en" 英文),通过 X-Language-Type Header 传递
user_input: Write request containing message and group_id
Returns:
Task ID for tracking async operation
Use GET /memory/write_result/{task_id} to check task status and get result
"""
# 使用集中化的语言校验
language = get_language_from_header(language_type)
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(
f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}")
api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
@@ -241,15 +219,12 @@ async def write_server_async(
if knowledge: user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Async write: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
task = celery_app.send_task(
"app.core.memory.agent.write_message",
args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language]
args=[user_input.group_id, user_input.message, config_id, storage_type, user_rag_memory_id]
)
api_logger.info(f"Write task queued: {task.id}")
return success(data={"task_id": task.id}, msg="写入任务已提交")
except Exception as e:
api_logger.error(f"Async write operation failed: {str(e)}")
@@ -259,9 +234,9 @@ async def write_server_async(
@router.post("/read_service", response_model=ApiResponse)
@cur_workspace_access_guard()
async def read_server(
user_input: UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
user_input: UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Read service endpoint - processes read operations synchronously
@@ -272,14 +247,16 @@ async def read_server(
- "2": Direct answer based on context
Args:
user_input: Read request with message, history, search_switch, and end_user_id
user_input: Read request with message, history, search_switch, and group_id
Returns:
Response with query answer
"""
config_id = user_input.config_id
workspace_id = current_user.current_workspace_id
api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
workspace_id=workspace_id,
@@ -294,14 +271,12 @@ async def read_server(
name="USER_RAG_MERORY",
workspace_id=workspace_id
)
if knowledge:
user_rag_memory_id = str(knowledge.id)
api_logger.info(
f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
if knowledge: user_rag_memory_id = str(knowledge.id)
api_logger.info(f"Read service: group={user_input.group_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}")
try:
result = await memory_agent_service.read_memory(
user_input.end_user_id,
user_input.group_id,
user_input.message,
user_input.history,
user_input.search_switch,
@@ -310,23 +285,6 @@ async def read_server(
storage_type,
user_rag_memory_id
)
if str(user_input.search_switch) == "2":
retrieve_info = result['answer']
history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id,
user_input.end_user_id)
query = user_input.message
# 调用 memory_agent_service 的方法生成最终答案
result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
end_user_id=user_input.end_user_id,
retrieve_info=retrieve_info,
history=history,
query=query,
config_id=config_id,
db=db
)
if "信息不足,无法回答" in result['answer']:
result['answer'] = retrieve_info
return success(data=result, msg="回复对话消息成功")
except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup
@@ -342,10 +300,9 @@ async def read_server(
@router.post("/file", response_model=ApiResponse)
async def file_update(
files: List[UploadFile] = File(..., description="要上传的文件"),
model_id: str = Form(..., description="模型ID"),
model_id:str = Form(..., description="模型ID"),
metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
文件上传接口 - 支持图片识别
@@ -358,6 +315,9 @@ async def file_update(
Returns:
文件处理结果
"""
db_gen = get_db() # get_db 通常是一个生成器
db = next(db_gen)
api_logger.info(f"File upload requested, file count: {len(files)}")
config = ModelConfigService.get_model_by_id(db=db, model_id=model_id)
apiConfig: ModelApiKey = config.api_keys[0]
@@ -366,7 +326,7 @@ async def file_update(
for file in files:
api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}")
content = await file.read()
if file.content_type and file.content_type.startswith("image/"):
vision_model = QWenCV(
key=apiConfig.api_key,
@@ -380,12 +340,12 @@ async def file_update(
else:
api_logger.warning(f"Unsupported file type: {file.content_type}")
file_content.append(f"[不支持的文件类型: {file.content_type}]")
result_text = ';'.join(file_content)
api_logger.info(f"File processing completed, result length: {len(result_text)}")
return success(data=result_text, msg="转换文本成功")
except Exception as e:
api_logger.error(f"File processing failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e))
@@ -422,7 +382,7 @@ async def read_server_async(
try:
task = celery_app.send_task(
"app.core.memory.agent.read_message",
args=[user_input.end_user_id, user_input.message, user_input.history, user_input.search_switch,
args=[user_input.group_id, user_input.message, user_input.history, user_input.search_switch,
config_id, storage_type, user_rag_memory_id]
)
api_logger.info(f"Read task queued: {task.id}")
@@ -435,8 +395,8 @@ async def read_server_async(
@router.get("/read_result/", response_model=ApiResponse)
async def get_read_task_result(
task_id: str,
current_user: User = Depends(get_current_user)
task_id: str,
current_user: User = Depends(get_current_user)
):
"""
Get the status and result of an async read task
@@ -457,7 +417,7 @@ async def get_read_task_result(
try:
result = task_service.get_task_memory_read_result(task_id)
status = result.get("status")
if status == "SUCCESS":
# 任务成功完成
task_result = result.get("result", {})
@@ -466,7 +426,7 @@ async def get_read_task_result(
return success(
data={
"result": task_result.get("result"),
"end_user_id": task_result.get("end_user_id"),
"group_id": task_result.get("group_id"),
"elapsed_time": task_result.get("elapsed_time"),
"task_id": task_id
},
@@ -475,7 +435,7 @@ async def get_read_task_result(
else:
# 旧格式:直接返回结果
return success(data=task_result, msg="查询任务已完成")
elif status == "FAILURE":
# 任务失败
error_info = result.get("result", "Unknown error")
@@ -484,7 +444,7 @@ async def get_read_task_result(
else:
error_msg = str(error_info)
return fail(BizCode.INTERNAL_ERROR, "查询任务失败", error_msg)
elif status in ["PENDING", "STARTED"]:
# 任务进行中
return success(
@@ -504,7 +464,7 @@ async def get_read_task_result(
},
msg=f"任务状态: {status}"
)
except Exception as e:
api_logger.error(f"Read task status check failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
@@ -512,8 +472,8 @@ async def get_read_task_result(
@router.get("/write_result/", response_model=ApiResponse)
async def get_write_task_result(
task_id: str,
current_user: User = Depends(get_current_user)
task_id: str,
current_user: User = Depends(get_current_user)
):
"""
Get the status and result of an async write task
@@ -534,7 +494,7 @@ async def get_write_task_result(
try:
result = task_service.get_task_memory_write_result(task_id)
status = result.get("status")
if status == "SUCCESS":
# 任务成功完成
task_result = result.get("result", {})
@@ -543,7 +503,7 @@ async def get_write_task_result(
return success(
data={
"result": task_result.get("result"),
"end_user_id": task_result.get("end_user_id"),
"group_id": task_result.get("group_id"),
"elapsed_time": task_result.get("elapsed_time"),
"task_id": task_id
},
@@ -552,7 +512,7 @@ async def get_write_task_result(
else:
# 旧格式:直接返回结果
return success(data=task_result, msg="写入任务已完成")
elif status == "FAILURE":
# 任务失败
error_info = result.get("result", "Unknown error")
@@ -561,7 +521,7 @@ async def get_write_task_result(
else:
error_msg = str(error_info)
return fail(BizCode.INTERNAL_ERROR, "写入任务失败", error_msg)
elif status in ["PENDING", "STARTED"]:
# 任务进行中
return success(
@@ -581,7 +541,7 @@ async def get_write_task_result(
},
msg=f"任务状态: {status}"
)
except Exception as e:
api_logger.error(f"Write task status check failed: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e))
@@ -589,38 +549,23 @@ async def get_write_task_result(
@router.post("/status_type", response_model=ApiResponse)
async def status_type(
user_input: Write_UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
user_input: Write_UserInput,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
Determine the type of user message (read or write)
Args:
user_input: Request containing user message and end_user_id
user_input: Request containing user message and group_id
Returns:
Type classification result
"""
api_logger.info(f"Status type check requested for group {user_input.end_user_id}")
api_logger.info(f"Status type check requested for group {user_input.group_id}")
try:
# 获取标准化的消息列表
messages_list = memory_agent_service.get_messages_list(user_input)
# 将消息列表转换为字符串用于分类
# 只取最后一条用户消息进行分类
last_user_message = ""
for msg in reversed(messages_list):
if msg.get('role') == 'user':
last_user_message = msg.get('content', '')
break
if not last_user_message:
# 如果没有用户消息,使用所有消息的内容
last_user_message = " ".join([msg.get('content', '') for msg in messages_list])
result = await memory_agent_service.classify_message_type(
last_user_message,
user_input.message,
user_input.config_id,
db
)
@@ -634,21 +579,26 @@ async def status_type(
@router.get("/stats/types", response_model=ApiResponse)
async def get_knowledge_type_stats_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
only_active: bool = Query(True, description="仅统计有效记录(status=1)"),
current_user: User = Depends(get_current_user)
):
"""
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。
统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder | memory
会对缺失类型补 0返回字典形式。
可选按状态过滤。
- 知识库类型根据当前用户的 current_workspace_id 过滤
- 如果用户没有当前工作空间,对应的统计返回 0
- memory 是 Neo4j 中 Chunk 的数量,根据 end_user_id (group_id) 过滤
- 如果用户没有当前工作空间或未提供 end_user_id对应的统计返回 0
"""
api_logger.info(
f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}")
try:
from app.db import get_db
# 获取数据库会话
db_gen = get_db()
db = next(db_gen)
# 调用service层函数
result = await memory_agent_service.get_knowledge_type_stats(
end_user_id=end_user_id,
@@ -656,70 +606,45 @@ async def get_knowledge_type_stats_api(
current_workspace_id=current_user.current_workspace_id,
db=db
)
return success(data=result, msg="获取知识库类型统计成功")
except Exception as e:
api_logger.error(f"Knowledge type stats failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e))
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
async def get_interest_distribution_by_user_api(
end_user_id: str = Query(..., description="用户ID必填"),
limit: int = Query(5, le=5, description="返回兴趣标签数量限制最多5个"),
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
async def get_hot_memory_tags_by_user_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
limit: int = Query(20, description="返回标签数量限制"),
current_user: User = Depends(get_current_user)
):
"""
获取指定用户的兴趣分布标签
与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习、创作等),
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
获取指定用户的热门记忆标签
返回格式:
[
{"name": "兴趣活动", "frequency": 频次},
{"name": "标签", "frequency": 频次},
...
]
"""
language = get_language_from_header(language_type)
api_logger.info(f"Interest distribution by user requested: end_user_id={end_user_id}, language={language}")
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
try:
# 优先读取缓存
cached = await InterestMemoryCache.get_interest_distribution(
result = await memory_agent_service.get_hot_memory_tags_by_user(
end_user_id=end_user_id,
language=language,
limit=limit
)
if cached is not None:
api_logger.info(f"Interest distribution cache hit: end_user_id={end_user_id}")
return success(data=cached, msg="获取兴趣分布标签成功")
# 缓存未命中,调用模型生成
result = await memory_agent_service.get_interest_distribution_by_user(
end_user_id=end_user_id,
limit=limit,
language=language
)
# 写入缓存24小时过期
await InterestMemoryCache.set_interest_distribution(
end_user_id=end_user_id,
language=language,
data=result,
)
return success(data=result, msg="获取兴趣分布标签成功")
return success(data=result, msg="获取热门记忆标签成功")
except Exception as e:
api_logger.error(f"Interest distribution by user failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取兴趣分布标签失败", str(e))
api_logger.error(f"Hot memory tags by user failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e))
@router.get("/analytics/user_profile", response_model=ApiResponse)
async def get_user_profile_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取用户详情,包含:
@@ -757,17 +682,17 @@ async def get_user_profile_api(
# ):
# """
# Get parsed API documentation (Public endpoint - no authentication required)
# Args:
# file_path: Optional path to API docs file. If None, uses default path.
# Returns:
# Parsed API documentation including title, meta info, and sections
# """
# api_logger.info(f"API docs requested, file_path: {file_path or 'default'}")
# try:
# result = await memory_agent_service.get_api_docs(file_path)
# if result.get("success"):
# return success(msg=result["msg"], data=result["data"])
# else:
@@ -783,9 +708,9 @@ async def get_user_profile_api(
@router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse)
async def get_end_user_connected_config(
end_user_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
end_user_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取终端用户关联的记忆配置
@@ -804,9 +729,9 @@ async def get_end_user_connected_config(
from app.services.memory_agent_service import (
get_end_user_connected_config as get_config,
)
api_logger.info(f"Getting connected config for end_user: {end_user_id}")
try:
result = get_config(end_user_id, db)
return success(data=result, msg="获取终端用户关联配置成功")
@@ -815,4 +740,4 @@ async def get_end_user_connected_config(
return fail(BizCode.NOT_FOUND, str(e))
except Exception as e:
api_logger.error(f"Failed to get end user connected config: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))
return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e))

View File

@@ -1,16 +1,15 @@
from fastapi import APIRouter, Depends, HTTPException, status, Query
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from typing import Optional
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.memory_agent_schema import End_User_Information
from app.schemas.response_schema import ApiResponse
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
from app.services.memory_agent_service import get_end_users_connected_configs_batch
from app.services.app_statistics_service import AppStatisticsService
from app.core.logging_config import get_api_logger
# 获取API专用日志器
@@ -41,7 +40,54 @@ def get_workspace_total_end_users(
api_logger.info(f"成功获取最新用户总数: total_num={total_end_users.get('total_num', 0)}")
return success(data=total_end_users, msg="用户数量获取成功")
@router.post("/update/end_users", response_model=ApiResponse)
async def update_workspace_end_users(
user_input: End_User_Information,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
更新工作空间的宿主信息
"""
username = user_input.end_user_name # 要更新的用户名
end_user_input_id = user_input.id # 宿主ID
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的宿主信息")
api_logger.info(f"更新参数: username={username}, end_user_id={end_user_input_id}")
try:
# 导入更新函数
from app.repositories.end_user_repository import update_end_user_other_name
import uuid
# 转换 end_user_id 为 UUID 类型
end_user_uuid = uuid.UUID(end_user_input_id)
# 直接更新数据库中的 other_name 字段
updated_count = update_end_user_other_name(
db=db,
end_user_id=end_user_uuid,
other_name=username
)
api_logger.info(f"成功更新宿主 {end_user_input_id} 的 other_name 为: {username}")
return success(
data={
"updated_count": updated_count,
"end_user_id": end_user_input_id,
"updated_other_name": username
},
msg=f"成功更新 {updated_count} 个宿主的信息"
)
except Exception as e:
api_logger.error(f"更新宿主信息失败: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"更新宿主信息失败: {str(e)}"
)
@@ -51,158 +97,63 @@ async def get_workspace_end_users(
current_user: User = Depends(get_current_user),
):
"""
获取工作空间的宿主列表(高性能优化版本 v2
获取工作空间的宿主列表
优化策略:
1. 批量查询 end_users一次查询而非循环
2. 并发查询所有用户的记忆数量Neo4j
3. RAG 模式使用批量查询(一次 SQL
4. 只返回必要字段减少数据传输
5. 添加短期缓存减少重复查询
6. 并发执行配置查询和记忆数量查询
返回格式:
{
"end_user": {"id": "uuid", "other_name": "名称"},
"memory_num": {"total": 数量},
"memory_config": {"memory_config_id": "id", "memory_config_name": "名称"}
}
返回格式与原 memory_list 接口中的 end_users 字段相同,
并包含每个用户的记忆配置信息memory_config_id 和 memory_config_name
"""
import asyncio
import json
from app.aioRedis import aio_redis_get, aio_redis_set
workspace_id = current_user.current_workspace_id
# 尝试从缓存获取30秒缓存
cache_key = f"end_users:workspace:{workspace_id}"
try:
cached_data = await aio_redis_get(cache_key)
if cached_data:
api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
return success(data=json.loads(cached_data), msg="宿主列表获取成功")
except Exception as e:
api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
# 获取当前空间类型
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
# 获取 end_users已优化为批量查询
end_users = memory_dashboard_service.get_workspace_end_users(
db=db,
workspace_id=workspace_id,
current_user=current_user
)
if not end_users:
api_logger.info("工作空间下没有宿主")
# 缓存空结果,避免重复查询
try:
await aio_redis_set(cache_key, json.dumps([]), expire=30)
except Exception as e:
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
return success(data=[], msg="宿主列表获取成功")
# 批量获取所有用户的记忆配置信息(优化:一次查询而非 N 次)
end_user_ids = [str(user.id) for user in end_users]
# 并发执行两个独立的查询任务
async def get_memory_configs():
"""获取记忆配置(在线程池中执行同步查询)"""
memory_configs_map = {}
if end_user_ids:
try:
return await asyncio.to_thread(
get_end_users_connected_configs_batch,
end_user_ids, db
)
memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db)
except Exception as e:
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
return {}
# 失败时使用空字典,不影响其他数据返回
async def get_memory_nums():
"""获取记忆数量"""
if current_workspace_type == "rag":
# RAG 模式:批量查询
try:
chunk_map = await asyncio.to_thread(
memory_dashboard_service.get_users_total_chunk_batch,
end_user_ids, db, current_user
)
return {uid: {"total": count} for uid, count in chunk_map.items()}
except Exception as e:
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
return {uid: {"total": 0} for uid in end_user_ids}
elif current_workspace_type == "neo4j":
# Neo4j 模式:并发查询(带并发限制)
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
MAX_CONCURRENT_QUERIES = 10
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
async def get_neo4j_memory_num(end_user_id: str):
async with semaphore:
try:
return await memory_storage_service.search_all(end_user_id)
except Exception as e:
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
return {"total": 0}
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
return {uid: {"total": 0} for uid in end_user_ids}
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
try:
from app.celery_app import celery_app as _celery_app
_celery_app.send_task(
"app.tasks.init_implicit_emotions_for_users",
kwargs={"end_user_ids": end_user_ids},
)
_celery_app.send_task(
"app.tasks.init_interest_distribution_for_users",
kwargs={"end_user_ids": end_user_ids},
)
api_logger.info(f"已触发按需初始化任务,候选用户数: {len(end_user_ids)}")
except Exception as e:
api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
# 并发执行配置查询和记忆数量查询
memory_configs_map, memory_nums_map = await asyncio.gather(
get_memory_configs(),
get_memory_nums()
)
# 构建结果(优化:使用列表推导式)
result = []
for end_user in end_users:
user_id = str(end_user.id)
config_info = memory_configs_map.get(user_id, {})
result.append({
'end_user': {
'id': user_id,
'other_name': end_user.other_name
},
'memory_num': memory_nums_map.get(user_id, {"total": 0}),
'memory_config': {
"memory_config_id": config_info.get("memory_config_id"),
"memory_config_name": config_info.get("memory_config_name")
memory_num = {}
if current_workspace_type == "neo4j":
# EndUser 是 Pydantic 模型,直接访问属性而不是使用 .get()
memory_num = await memory_storage_service.search_all(str(end_user.id))
elif current_workspace_type == "rag":
memory_num = {
"total":memory_dashboard_service.get_current_user_total_chunk(str(end_user.id), db, current_user)
}
# 从批量查询结果中获取配置信息
user_id = str(end_user.id)
memory_config_info = memory_configs_map.get(user_id, {
"memory_config_id": None,
"memory_config_name": None
})
# 写入缓存30秒过期
try:
await aio_redis_set(cache_key, json.dumps(result), expire=30)
except Exception as e:
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
# 触发社区聚类补全任务(异步,不阻塞接口响应)
# 对有 ExtractedEntity 但无 Community 节点的存量用户自动补跑全量聚类
try:
from app.tasks import init_community_clustering_for_users
init_community_clustering_for_users.delay(end_user_ids=end_user_ids)
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
except Exception as e:
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")
# 只保留需要的字段,移除 error 字段(如果有
memory_config = {
"memory_config_id": memory_config_info.get("memory_config_id"),
"memory_config_name": memory_config_info.get("memory_config_name")
}
result.append(
{
'end_user': end_user,
'memory_num': memory_num,
'memory_config': memory_config
}
)
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
return success(data=result, msg="宿主列表获取成功")
@@ -412,15 +363,14 @@ def get_current_user_rag_total_num(
@router.get("/rag_content", response_model=ApiResponse)
def get_rag_content(
end_user_id: str = Query(..., description="宿主ID"),
page: int = Query(1, gt=0, description="页码从1开始"),
pagesize: int = Query(15, gt=0, le=100, description="每页返回记录数"),
limit: int = Query(15, description="返回记录数"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
获取当前宿主知识库中的chunk内容(分页)
获取当前宿主知识库中的chunk内容
"""
data = memory_dashboard_service.get_rag_content(end_user_id, page, pagesize, db, current_user)
data = memory_dashboard_service.get_rag_content(end_user_id, limit, db, current_user)
return success(data=data, msg="宿主RAGchunk数据获取成功")
@@ -433,18 +383,26 @@ async def get_chunk_summary_tag(
current_user: User = Depends(get_current_user),
):
"""
读取RAG摘要、标签和人物形象纯读库不触发生成
获取chunk总结、提取的标签和人物形象
返回格式:
{
"summary": "用户摘要",
"tags": [{"tag": "标签1", "frequency": 5}, ...],
"personas": ["产品设计师", ...],
"generated": true/false // false表示尚未生产请调用 /generate_rag_profile
"summary": "chunk内容的总结",
"tags": [
{"tag": "标签1", "frequency": 5},
{"tag": "标签2", "frequency": 3},
...
],
"personas": [
"产品设计师",
"旅行爱好者",
"摄影发烧友",
...
]
}
"""
api_logger.info(f"用户 {current_user.username} 取宿主 {end_user_id}RAG摘要/标签/人物形象")
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id}chunk摘要标签人物形象")
data = await memory_dashboard_service.get_chunk_summary_and_tags(
end_user_id=end_user_id,
limit=limit,
@@ -452,8 +410,9 @@ async def get_chunk_summary_tag(
db=db,
current_user=current_user
)
return success(data=data, msg="获取成功")
api_logger.info(f"成功获取chunk摘要、{len(data.get('tags', []))} 个标签和 {len(data.get('personas', []))} 个人物形象")
return success(data=data, msg="chunk摘要、标签和人物形象获取成功")
@router.get("/chunk_insight", response_model=ApiResponse)
@@ -464,64 +423,29 @@ async def get_chunk_insight(
current_user: User = Depends(get_current_user),
):
"""
读取RAG洞察报告纯读库不触发生成
获取chunk的洞察内容
返回格式:
{
"insight": "总体概述",
"behavior_pattern": "行为模式",
"key_findings": "关键发现",
"growth_trajectory": "成长轨迹",
"generated": true/false // false表示尚未生产请调用 /generate_rag_profile
"insight": "对chunk内容的深度洞察分析"
}
"""
api_logger.info(f"用户 {current_user.username} 取宿主 {end_user_id}RAG洞察")
api_logger.info(f"用户 {current_user.username} 请求获取宿主 {end_user_id}chunk洞察")
data = await memory_dashboard_service.get_chunk_insight(
end_user_id=end_user_id,
limit=limit,
db=db,
current_user=current_user
)
return success(data=data, msg="获取成功")
class GenerateRagProfileRequest(BaseModel):
end_user_id: str = Field(..., description="宿主ID")
limit: int = Field(15, description="参与生成的chunk数量上限")
max_tags: int = Field(10, description="最大标签数量")
@router.post("/generate_rag_profile", response_model=ApiResponse)
async def generate_rag_profile(
body: GenerateRagProfileRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
"""
生产接口为RAG存储模式的宿主全量重新生成完整画像并持久化到end_user表。
每次请求都会重新生成,覆盖已有数据。
"""
api_logger.info(f"用户 {current_user.username} 触发RAG画像生产: end_user_id={body.end_user_id}")
data = await memory_dashboard_service.generate_rag_profile(
end_user_id=body.end_user_id,
limit=body.limit,
max_tags=body.max_tags,
db=db,
current_user=current_user,
)
api_logger.info(f"RAG画像生产完成: {data}")
return success(data=data, msg="RAG画像生产完成")
api_logger.info("成功获取chunk洞察")
return success(data=data, msg="chunk洞察获取成功")
@router.get("/dashboard_data", response_model=ApiResponse)
async def dashboard_data(
end_user_id: Optional[str] = Query(None, description="可选的用户ID"),
start_date: Optional[int] = Query(None, description="开始时间戳(毫秒)"),
end_date: Optional[int] = Query(None, description="结束时间戳(毫秒)"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
):
@@ -556,15 +480,6 @@ async def dashboard_data(
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的dashboard整合数据")
# 如果没有提供时间范围默认使用最近30天
if start_date is None or end_date is None:
from datetime import datetime, timedelta
end_dt = datetime.now()
start_dt = end_dt - timedelta(days=30)
end_date = int(end_dt.timestamp() * 1000)
start_date = int(start_dt.timestamp() * 1000)
api_logger.info(f"使用默认时间范围: {start_dt}{end_dt}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type(
db=db,
@@ -603,12 +518,9 @@ async def dashboard_data(
)
neo4j_data["total_memory"] = total_memory_data.get("total_memory_count", 0)
# total_app: 统计当前空间下的所有app数量
# 包含自有app + 被分享给本工作空间的app
from app.services import app_service as _app_svc
_, total_app = _app_svc.AppService(db).list_apps(
workspace_id=workspace_id, include_shared=True, pagesize=1
)
neo4j_data["total_app"] = total_app
from app.repositories import app_repository
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
neo4j_data["total_app"] = len(apps_orm)
api_logger.info(f"成功获取记忆总量: {neo4j_data['total_memory']}, 应用数量: {neo4j_data['total_app']}")
except Exception as e:
api_logger.warning(f"获取记忆总量失败: {str(e)}")
@@ -628,22 +540,17 @@ async def dashboard_data(
except Exception as e:
api_logger.warning(f"获取知识库类型统计失败: {str(e)}")
# 3. 获取API调用统计total_api_call
# 3. 获取API调用增量total_api_call,转换为整数
try:
# 使用 AppStatisticsService 获取真实的API调用统计
app_stats_service = AppStatisticsService(db)
api_stats = app_stats_service.get_workspace_api_statistics(
api_increment = memory_dashboard_service.get_workspace_api_increment(
db=db,
workspace_id=workspace_id,
start_date=start_date,
end_date=end_date
current_user=current_user
)
# 计算总调用次数
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
neo4j_data["total_api_call"] = total_api_calls
api_logger.info(f"成功获取API调用统计: {neo4j_data['total_api_call']}")
neo4j_data["total_api_call"] = api_increment
api_logger.info(f"成功获取API调用增量: {neo4j_data['total_api_call']}")
except Exception as e:
api_logger.error(f"获取API调用统计失败: {str(e)}")
neo4j_data["total_api_call"] = 0
api_logger.warning(f"获取API调用增量失败: {str(e)}")
result["neo4j_data"] = neo4j_data
api_logger.info("成功获取neo4j_data")
@@ -659,8 +566,8 @@ async def dashboard_data(
# 获取RAG相关数据
try:
# total_memory: 只统计用户知识库permission_id='Memory')的chunk数
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
# total_memory: 使用 total_chunkchunk数
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
rag_data["total_memory"] = total_chunk
# total_app: 统计当前空间下的所有app数量
@@ -672,23 +579,10 @@ async def dashboard_data(
total_kb = memory_dashboard_service.get_rag_total_kb(db, current_user)
rag_data["total_knowledge"] = total_kb
# total_api_call: 使用 AppStatisticsService 获取真实的API调用统计
try:
app_stats_service = AppStatisticsService(db)
api_stats = app_stats_service.get_workspace_api_statistics(
workspace_id=workspace_id,
start_date=start_date,
end_date=end_date
)
# 计算总调用次数
total_api_calls = sum(item.get("total_calls", 0) for item in api_stats)
rag_data["total_api_call"] = total_api_calls
api_logger.info(f"成功获取RAG模式API调用统计: {rag_data['total_api_call']}")
except Exception as e:
api_logger.warning(f"获取RAG模式API调用统计失败使用默认值: {str(e)}")
rag_data["total_api_call"] = 0
# total_api_call: 固定值
rag_data["total_api_call"] = 1024
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}, api_calls={rag_data['total_api_call']}")
api_logger.info(f"成功获取RAG相关数据: memory={total_chunk}, app={len(apps_orm)}, knowledge={total_kb}")
except Exception as e:
api_logger.warning(f"获取RAG相关数据失败: {str(e)}")

View File

@@ -3,10 +3,9 @@
包含情景记忆总览和详情查询接口
"""
from fastapi import APIRouter, Depends, Header
from fastapi import APIRouter, Depends
from app.core.error_codes import BizCode
from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger
from app.core.response_utils import fail, success
from app.dependencies import get_current_user
@@ -15,7 +14,6 @@ from app.schemas.response_schema import ApiResponse
from app.schemas.memory_episodic_schema import (
EpisodicMemoryOverviewRequest,
EpisodicMemoryDetailsRequest,
translate_episodic_type,
)
from app.services.memory_episodic_service import memory_episodic_service
@@ -86,7 +84,6 @@ async def get_episodic_memory_overview_api(
@router.post("/details", response_model=ApiResponse)
async def get_episodic_memory_details_api(
request: EpisodicMemoryDetailsRequest,
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
) -> dict:
"""
@@ -114,11 +111,6 @@ async def get_episodic_memory_details_api(
summary_id=request.summary_id
)
# 根据语言参数翻译 episodic_type
language = get_language_from_header(language_type)
if "episodic_type" in result:
result["episodic_type"] = translate_episodic_type(result["episodic_type"], language)
api_logger.info(
f"成功获取情景记忆详情: end_user_id={request.end_user_id}, summary_id={request.summary_id}"
)

View File

@@ -11,7 +11,6 @@
"""
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
@@ -34,7 +33,7 @@ from app.schemas.memory_storage_schema import (
)
from app.schemas.response_schema import ApiResponse
from app.services.memory_forget_service import MemoryForgetService
from app.utils.config_utils import resolve_config_id
# 获取API专用日志器
api_logger = get_api_logger()
@@ -84,8 +83,7 @@ async def trigger_forgetting_cycle(
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
config_id = resolve_config_id((config_id), db)
if config_id is None:
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
@@ -108,7 +106,7 @@ async def trigger_forgetting_cycle(
# 调用服务层执行遗忘周期
report = await forget_service.trigger_forgetting_cycle(
db=db,
end_user_id=end_user_id, # 服务层方法的参数名是 end_user_id
group_id=end_user_id, # 服务层方法的参数名是 group_id
max_merge_batch_size=payload.max_merge_batch_size,
min_days_since_access=payload.min_days_since_access,
config_id=config_id
@@ -130,7 +128,7 @@ async def trigger_forgetting_cycle(
@router.get("/read_config", response_model=ApiResponse)
async def read_forgetting_config(
config_id: UUID|int,
config_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
@@ -159,7 +157,6 @@ async def read_forgetting_config(
)
try:
config_id=resolve_config_id(config_id, db)
# 调用服务层读取配置
config = forget_service.read_forgetting_config(db=db, config_id=config_id)
@@ -197,8 +194,6 @@ async def update_forgetting_config(
ApiResponse: 包含更新结果的响应
"""
workspace_id = current_user.current_workspace_id
payload.config_id=resolve_config_id((payload.config_id), db)
# 检查用户是否已选择工作空间
if workspace_id is None:
@@ -241,7 +236,7 @@ async def update_forgetting_config(
@router.get("/stats", response_model=ApiResponse)
async def get_forgetting_stats(
end_user_id: Optional[str] = None,
group_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
@@ -251,7 +246,7 @@ async def get_forgetting_stats(
返回知识层节点统计、激活值分布等信息。
Args:
end_user_id: 组ID即 end_user_id可选
group_id: 组ID即 end_user_id可选
current_user: 当前用户
db: 数据库会话
@@ -259,25 +254,26 @@ async def get_forgetting_stats(
ApiResponse: 包含统计信息的响应
"""
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘引擎统计但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 如果提供了 end_user_id通过它获取 config_id
# 如果提供了 group_id通过它获取 config_id
config_id = None
if end_user_id:
if group_id:
try:
from app.services.memory_agent_service import get_end_user_connected_config
connected_config = get_end_user_connected_config(end_user_id, db)
connected_config = get_end_user_connected_config(group_id, db)
config_id = connected_config.get("memory_config_id")
config_id = resolve_config_id(config_id, db)
if config_id is None:
api_logger.warning(f"终端用户 {end_user_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {end_user_id} 未关联记忆配置", "memory_config_id is None")
api_logger.warning(f"终端用户 {group_id} 未关联记忆配置")
return fail(BizCode.INVALID_PARAMETER, f"终端用户 {group_id} 未关联记忆配置", "memory_config_id is None")
api_logger.debug(f"通过 end_user_id={end_user_id} 获取到 config_id={config_id}")
api_logger.debug(f"通过 group_id={group_id} 获取到 config_id={config_id}")
except ValueError as e:
api_logger.warning(f"获取终端用户配置失败: {str(e)}")
return fail(BizCode.INVALID_PARAMETER, str(e), "ValueError")
@@ -287,14 +283,14 @@ async def get_forgetting_stats(
api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求获取遗忘引擎统计: "
f"end_user_id={end_user_id}, config_id={config_id}"
f"group_id={group_id}, config_id={config_id}"
)
try:
# 调用服务层获取统计信息
stats = await forget_service.get_forgetting_stats(
db=db,
end_user_id=end_user_id,
group_id=group_id,
config_id=config_id
)
@@ -328,7 +324,7 @@ async def get_forgetting_curve(
ApiResponse: 包含遗忘曲线数据的响应
"""
workspace_id = current_user.current_workspace_id
request.config_id = resolve_config_id((request.config_id), db)
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试获取遗忘曲线但未选择工作空间")

View File

@@ -27,27 +27,27 @@ router = APIRouter(
)
@router.get("/{end_user_id}/count", response_model=ApiResponse)
@router.get("/{group_id}/count", response_model=ApiResponse)
def get_memory_count(
end_user_id: uuid.UUID,
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve perceptual memory statistics for a user group.
Args:
end_user_id: ID of the user group (usually end_user_id in this context)
group_id: ID of the user group (usually end_user_id in this context)
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Response containing memory count statistics
"""
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, end_user_id={end_user_id}")
api_logger.info(f"Fetching perceptual memory statistics: user={current_user.username}, group_id={group_id}")
try:
service = MemoryPerceptualService(db)
count_stats = service.get_memory_count(end_user_id)
count_stats = service.get_memory_count(group_id)
api_logger.info(f"Memory statistics fetched successfully: total={count_stats.get('total', 0)}")
@@ -57,37 +57,37 @@ def get_memory_count(
)
except Exception as e:
api_logger.error(f"Failed to fetch memory statistics: end_user_id={end_user_id}, error={str(e)}")
api_logger.error(f"Failed to fetch memory statistics: group_id={group_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch memory statistics",
)
@router.get("/{end_user_id}/last_visual", response_model=ApiResponse)
@router.get("/{group_id}/last_visual", response_model=ApiResponse)
def get_last_visual_memory(
end_user_id: uuid.UUID,
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent VISION-type memory for a user.
Args:
end_user_id: ID of the user group
group_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest visual memory
"""
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, end_user_id={end_user_id}")
api_logger.info(f"Fetching latest visual memory: user={current_user.username}, group_id={group_id}")
try:
service = MemoryPerceptualService(db)
visual_memory = service.get_latest_visual_memory(end_user_id)
visual_memory = service.get_latest_visual_memory(group_id)
if visual_memory is None:
api_logger.info(f"No visual memory found: end_user_id={end_user_id}")
api_logger.info(f"No visual memory found: group_id={group_id}")
return success(
data=None,
msg="No visual memory available"
@@ -101,37 +101,37 @@ def get_last_visual_memory(
)
except Exception as e:
api_logger.error(f"Failed to fetch latest visual memory: end_user_id={end_user_id}, error={str(e)}")
api_logger.error(f"Failed to fetch latest visual memory: group_id={group_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest visual memory",
)
@router.get("/{end_user_id}/last_listen", response_model=ApiResponse)
@router.get("/{group_id}/last_listen", response_model=ApiResponse)
def get_last_memory_listen(
end_user_id: uuid.UUID,
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent AUDIO-type memory for a user.
Args:
end_user_id: ID of the user group
group_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest audio memory
"""
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, end_user_id={end_user_id}")
api_logger.info(f"Fetching latest audio memory: user={current_user.username}, group_id={group_id}")
try:
service = MemoryPerceptualService(db)
audio_memory = service.get_latest_audio_memory(end_user_id)
audio_memory = service.get_latest_audio_memory(group_id)
if audio_memory is None:
api_logger.info(f"No audio memory found: end_user_id={end_user_id}")
api_logger.info(f"No audio memory found: group_id={group_id}")
return success(
data=None,
msg="No audio memory available"
@@ -145,38 +145,38 @@ def get_last_memory_listen(
)
except Exception as e:
api_logger.error(f"Failed to fetch latest audio memory: end_user_id={end_user_id}, error={str(e)}")
api_logger.error(f"Failed to fetch latest audio memory: group_id={group_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest audio memory",
)
@router.get("/{end_user_id}/last_text", response_model=ApiResponse)
@router.get("/{group_id}/last_text", response_model=ApiResponse)
def get_last_text_memory(
end_user_id: uuid.UUID,
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""Retrieve the most recent TEXT-type memory for a user.
Args:
end_user_id: ID of the user group
group_id: ID of the user group
current_user: Current authenticated user
db: Database session
Returns:
ApiResponse: Metadata of the latest text memory
"""
api_logger.info(f"Fetching latest text memory: user={current_user.username}, end_user_id={end_user_id}")
api_logger.info(f"Fetching latest text memory: user={current_user.username}, group_id={group_id}")
try:
# 调用服务层获取最近的文本记忆
service = MemoryPerceptualService(db)
text_memory = service.get_latest_text_memory(end_user_id)
text_memory = service.get_latest_text_memory(group_id)
if text_memory is None:
api_logger.info(f"No text memory found: end_user_id={end_user_id}")
api_logger.info(f"No text memory found: group_id={group_id}")
return success(
data=None,
msg="No text memory available"
@@ -190,16 +190,16 @@ def get_last_text_memory(
)
except Exception as e:
api_logger.error(f"Failed to fetch latest text memory: end_user_id={end_user_id}, error={str(e)}")
api_logger.error(f"Failed to fetch latest text memory: group_id={group_id}, error={str(e)}")
return fail(
code=BizCode.INTERNAL_ERROR,
msg="Failed to fetch latest text memory",
)
@router.get("/{end_user_id}/timeline", response_model=ApiResponse)
@router.get("/{group_id}/timeline", response_model=ApiResponse)
def get_memory_time_line(
end_user_id: uuid.UUID,
group_id: uuid.UUID,
perceptual_type: Optional[PerceptualType] = Query(None, description="感知类型过滤"),
page: int = Query(1, ge=1, description="页码"),
page_size: int = Query(10, ge=1, le=100, description="每页大小"),
@@ -209,7 +209,7 @@ def get_memory_time_line(
"""Retrieve a timeline of perceptual memories for a user group.
Args:
end_user_id: ID of the user group
group_id: ID of the user group
perceptual_type: Optional filter for perceptual type
page: Page number for pagination
page_size: Number of items per page
@@ -221,7 +221,7 @@ def get_memory_time_line(
"""
api_logger.info(
f"Fetching perceptual memory timeline: user={current_user.username}, "
f"end_user_id={end_user_id}, type={perceptual_type}, page={page}"
f"group_id={group_id}, type={perceptual_type}, page={page}"
)
try:
@@ -232,7 +232,7 @@ def get_memory_time_line(
)
service = MemoryPerceptualService(db)
timeline_data = service.get_time_line(end_user_id, query)
timeline_data = service.get_time_line(group_id, query)
api_logger.info(
f"Perceptual memory timeline retrieved successfully: total={timeline_data.total}, "
@@ -246,7 +246,7 @@ def get_memory_time_line(
except Exception as e:
api_logger.error(
f"Failed to fetch perceptual memory timeline: end_user_id={end_user_id}, "
f"Failed to fetch perceptual memory timeline: group_id={group_id}, "
f"error={str(e)}"
)
return fail(

View File

@@ -1,35 +1,16 @@
"""
Memory Reflection Controller
This module provides REST API endpoints for managing memory reflection configurations
and operations. It handles reflection engine setup, configuration management, and
execution of self-reflection processes across memory systems.
Key Features:
- Reflection configuration management (save, retrieve, update)
- Workspace-wide reflection execution across multiple applications
- Individual configuration-based reflection runs
- Multi-language support for reflection outputs
- Integration with Neo4j memory storage and LLM models
- Comprehensive error handling and logging
"""
import asyncio
import time
import uuid
from uuid import UUID
from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger
from app.core.memory.storage_services.reflection_engine.self_reflexion import (
ReflectionConfig,
ReflectionEngine, ReflectionRange, ReflectionBaseline,
ReflectionEngine,
)
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.repositories.memory_config_repository import MemoryConfigRepository
from app.repositories.data_config_repository import DataConfigRepository
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_reflection_schemas import Memory_Reflection
from app.services.memory_reflection_service import (
@@ -38,19 +19,13 @@ from app.services.memory_reflection_service import (
)
from app.services.model_service import ModelConfigService
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, HTTPException, status,Header
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import text
from sqlalchemy.orm import Session
from app.utils.config_utils import resolve_config_id
# Load environment variables for configuration
load_dotenv()
# Initialize API logger for request tracking and debugging
api_logger = get_api_logger()
# Configure router with prefix and tags for API organization
router = APIRouter(
prefix="/memory",
tags=["Memory"],
@@ -63,74 +38,65 @@ async def save_reflection_config(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""
Save reflection configuration to memory config table
"""Save reflection configuration to data_comfig table"""
Persists reflection engine configuration settings to the data_config table,
including reflection parameters, model settings, and evaluation criteria.
Validates configuration parameters and ensures data consistency.
Args:
request: Memory reflection configuration data including:
- config_id: Configuration identifier to update
- reflection_enabled: Whether reflection is enabled
- reflection_period_in_hours: Reflection execution interval
- reflexion_range: Scope of reflection (partial/all)
- baseline: Reflection strategy (time/fact/hybrid)
- reflection_model_id: LLM model for reflection operations
- memory_verify: Enable memory verification checks
- quality_assessment: Enable quality assessment evaluation
current_user: Authenticated user saving the configuration
db: Database session for data operations
Returns:
dict: Success response with saved reflection configuration data
Raises:
HTTPException 400: If config_id is missing or parameters are invalid
HTTPException 500: If configuration save operation fails
Database Operations:
- Updates memory_config table with reflection settings
- Commits transaction and refreshes entity
- Maintains configuration consistency
"""
try:
config_id = request.config_id
config_id = resolve_config_id(config_id, db)
if not config_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="缺少必需参数: config_id"
)
api_logger.info(f"用户 {current_user.username} 保存反思配置config_id: {config_id}")
# Update reflection configuration in database
memory_config = MemoryConfigRepository.update_reflection_config(
db,
config_id=config_id,
enable_self_reflexion=request.reflection_enabled,
iteration_period=request.reflection_period_in_hours,
reflexion_range=request.reflexion_range,
baseline=request.baseline,
reflection_model_id=request.reflection_model_id,
memory_verify=request.memory_verify,
quality_assessment=request.quality_assessment
)
update_params = {
"enable_self_reflexion": request.reflection_enabled,
"iteration_period": request.reflection_period_in_hours,
"reflexion_range": request.reflexion_range,
"baseline": request.baseline,
"reflection_model_id": request.reflection_model_id,
"memory_verify": request.memory_verify,
"quality_assessment": request.quality_assessment,
}
# Commit transaction and refresh entity
query, params = DataConfigRepository.build_update_reflection(config_id, **update_params)
result = db.execute(text(query), params)
if result.rowcount == 0:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"未找到config_id为 {config_id} 的配置"
)
db.commit()
db.refresh(memory_config)
# 查询更新后的配置
select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
result = db.execute(text(select_query), select_params).fetchone()
if not result:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"更新后未找到config_id为 {config_id} 的配置"
)
api_logger.info(f"成功保存反思配置到数据库config_id: {config_id}")
reflection_result={
"config_id": memory_config.config_id,
"enable_self_reflexion": memory_config.enable_self_reflexion,
"iteration_period": memory_config.iteration_period,
"reflexion_range": memory_config.reflexion_range,
"baseline": memory_config.baseline,
"reflection_model_id": memory_config.reflection_model_id,
"memory_verify": memory_config.memory_verify,
"quality_assessment": memory_config.quality_assessment}
"config_id": result.config_id,
"enable_self_reflexion": result.enable_self_reflexion,
"iteration_period": result.iteration_period,
"reflexion_range": result.reflexion_range,
"baseline": result.baseline,
"reflection_model_id": result.reflection_model_id,
"memory_verify": result.memory_verify,
"quality_assessment": result.quality_assessment,
"user_id": result.user_id}
return success(data=reflection_result, msg="反思配置成功")
@@ -150,119 +116,48 @@ async def save_reflection_config(
)
@router.get("/reflection")
@router.post("/reflection")
async def start_workspace_reflection(
config_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""
Start reflection functionality for all matching applications in workspace
Initiates reflection processes across all applications within the user's current
workspace that have valid memory configurations. Processes each application's
configurations and associated end users, executing reflection operations
with proper error isolation and transaction management.
This endpoint serves as a workspace-wide reflection orchestrator, ensuring
that reflection failures for individual users don't affect other operations.
Args:
current_user: Authenticated user initiating workspace reflection
db: Database session for configuration queries
Returns:
dict: Success response with reflection results for all processed applications:
- app_id: Application identifier
- config_id: Memory configuration identifier
- end_user_id: End user identifier
- reflection_result: Individual reflection operation result
Processing Logic:
1. Retrieve all applications in the current workspace
2. Filter applications with valid memory configurations
3. For each configuration, find matching releases
4. Execute reflection for each end user with isolated transactions
5. Aggregate results with error handling per user
Error Handling:
- Individual user reflection failures are isolated
- Failed operations are logged and included in results
- Database transactions are isolated per user to prevent cascading failures
- Comprehensive error reporting for debugging
Raises:
HTTPException 500: If workspace reflection initialization fails
Performance Notes:
- Uses independent database sessions for each user operation
- Prevents transaction failures from affecting other users
- Comprehensive logging for operation tracking
"""
"""Activate the reflection function for all matching applications in the workspace"""
workspace_id = current_user.current_workspace_id
reflection_service = MemoryReflectionService(db)
try:
api_logger.info(f"用户 {current_user.username} 启动workspace反思workspace_id: {workspace_id}")
# Use independent database session to get workspace app details, avoiding transaction failures
from app.db import get_db_context
with get_db_context() as query_db:
service = WorkspaceAppService(query_db)
result = service.get_workspace_apps_detailed(workspace_id)
service = WorkspaceAppService(db)
result = service.get_workspace_apps_detailed(workspace_id)
reflection_results = []
# Process each application in the workspace
for data in result['apps_detailed_info']:
# Skip applications without configurations
if not data['memory_configs']:
api_logger.debug(f"应用 {data['id']} 没有memory_configs跳过")
if data['data_configs'] == []:
continue
releases = data['releases']
memory_configs = data['memory_configs']
data_configs = data['data_configs']
end_users = data['end_users']
# Execute reflection for each configuration and user combination
for config in memory_configs:
config_id_str = str(config['config_id'])
# Find all releases matching this configuration
matching_releases = [r for r in releases if str(r['config']) == config_id_str]
if not matching_releases:
api_logger.debug(f"配置 {config_id_str} 没有匹配的release")
continue
# Execute reflection for each user - using independent database sessions
for user in end_users:
api_logger.info(f"为用户 {user['id']} 启动反思config_id: {config_id_str}")
# Create independent database session for each user to avoid transaction failure impact
with get_db_context() as user_db:
try:
reflection_service = MemoryReflectionService(user_db)
reflection_result = await reflection_service.start_text_reflection(
config_data=config,
end_user_id=user['id']
)
reflection_results.append({
"app_id": data['id'],
"config_id": config_id_str,
"end_user_id": user['id'],
"reflection_result": reflection_result
})
except Exception as e:
api_logger.error(f"用户 {user['id']} 反思失败: {str(e)}")
reflection_results.append({
"app_id": data['id'],
"config_id": config_id_str,
"end_user_id": user['id'],
"reflection_result": {
"status": "错误",
"message": f"反思失败: {str(e)}"
}
})
for base, config, user in zip(releases, data_configs, end_users):
if int(base['config']) == int(config['config_id']) and base['app_id'] == user['app_id']:
# 调用反思服务
api_logger.info(f"为用户 {user['id']} 启动反思config_id: {config['config_id']}")
reflection_result = await reflection_service.start_reflection_from_data(
config_data=config,
end_user_id=user['id']
)
reflection_results.append({
"app_id": base['app_id'],
"config_id": config['config_id'],
"end_user_id": user['id'],
"reflection_result": reflection_result
})
return success(data=reflection_results, msg="反思配置成功")
@@ -276,73 +171,42 @@ async def start_workspace_reflection(
@router.get("/reflection/configs")
async def start_reflection_configs(
config_id: uuid.UUID|int,
config_id: int,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""
Query reflection configuration information by config_id
Retrieves detailed reflection configuration settings from the memory_config
table for a specific configuration ID. Provides comprehensive reflection
parameters including model settings, evaluation criteria, and operational flags.
Args:
config_id: Configuration identifier (UUID or integer) to query
current_user: Authenticated user making the request
db: Database session for data operations
Returns:
dict: Success response with detailed reflection configuration:
- config_id: Resolved configuration identifier
- reflection_enabled: Whether reflection is enabled for this config
- reflection_period_in_hours: Reflection execution interval
- reflexion_range: Scope of reflection operations (partial/all)
- baseline: Reflection strategy (time/fact/hybrid)
- reflection_model_id: LLM model identifier for reflection
- memory_verify: Memory verification flag
- quality_assessment: Quality assessment flag
Database Operations:
- Queries memory_config table by resolved config_id
- Retrieves all reflection-related configuration fields
- Resolves configuration ID for consistent formatting
Raises:
HTTPException 404: If configuration with specified ID is not found
HTTPException 500: If configuration query operation fails
ID Resolution:
- Supports both UUID and integer config_id formats
- Automatically resolves to appropriate internal format
- Maintains consistency across different ID representations
"""
config_id = resolve_config_id(config_id, db)
"""通过config_id查询data_config表中的反思配置信息"""
try:
config_id=resolve_config_id(config_id,db)
api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
memory_config_id = resolve_config_id(result.config_id, db)
# Build response data with comprehensive configuration details
# 使用DataConfigRepository查询反思配置
select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
result = db.execute(text(select_query), select_params).fetchone()
if not result:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"未找到config_id为 {config_id} 的配置"
)
# 构建返回数据
reflection_config = {
"config_id": memory_config_id,
"config_id": result.config_id,
"reflection_enabled": result.enable_self_reflexion,
"reflection_period_in_hours": result.iteration_period,
"reflexion_range": result.reflexion_range,
"baseline": result.baseline,
"reflection_model_id": result.reflection_model_id,
"memory_verify": result.memory_verify,
"quality_assessment": result.quality_assessment
"quality_assessment": result.quality_assessment,
"user_id": result.user_id
}
api_logger.info(f"成功查询反思配置config_id: {config_id}")
return success(data=reflection_config, msg="反思配置查询成功")
api_logger.info(f"Successfully queried reflection config, config_id: {config_id}")
return success(data=reflection_config, msg="Reflection configuration query successful")
except HTTPException:
# Re-raise HTTP exceptions without modification
# 重新抛出HTTP异常
raise
except Exception as e:
api_logger.error(f"查询反思配置失败: {str(e)}")
@@ -353,72 +217,19 @@ async def start_reflection_configs(
@router.get("/reflection/run")
async def reflection_run(
config_id: UUID|int,
language_type: str = Header(default=None, alias="X-Language-Type"),
config_id: int,
language_type: str = "zh",
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""
Execute reflection engine with specified configuration
Runs the reflection engine using configuration parameters from the database.
Validates model availability, sets up the reflection engine with proper
configuration, and executes the reflection process with multi-language support.
This endpoint provides a test run capability for reflection configurations,
allowing users to validate their reflection settings and see results before
deploying to production environments.
Args:
config_id: Configuration identifier (UUID or integer) for reflection settings
language_type: Language preference header for output localization (optional)
current_user: Authenticated user executing the reflection
db: Database session for configuration queries
Returns:
dict: Success response with reflection execution results including:
- baseline: Reflection strategy used
- source_data: Input data processed
- memory_verifies: Memory verification results (if enabled)
- quality_assessments: Quality assessment results (if enabled)
- reflexion_data: Generated reflection insights and solutions
Configuration Validation:
- Verifies configuration exists in database
- Validates LLM model availability
- Falls back to default model if specified model is unavailable
- Ensures all required parameters are properly set
Reflection Engine Setup:
- Creates ReflectionConfig with database parameters
- Initializes Neo4j connector for memory access
- Sets up ReflectionEngine with validated model
- Configures language preferences for output
Error Handling:
- Model validation with fallback to default
- Configuration validation and error reporting
- Comprehensive logging for debugging
- Graceful handling of missing configurations
Raises:
HTTPException 404: If configuration is not found
HTTPException 500: If reflection execution fails
Performance Notes:
- Direct database query for configuration retrieval
- Model validation to prevent runtime failures
- Efficient reflection engine initialization
- Language-aware output processing
"""
# Use centralized language validation for consistent localization
language = get_language_from_header(language_type)
"""Activate the reflection function for all matching applications in the workspace"""
api_logger.info(f"用户 {current_user.username} 查询反思配置config_id: {config_id}")
config_id = resolve_config_id(config_id, db)
# Query reflection configuration using MemoryConfigRepository
result = MemoryConfigRepository.query_reflection_config_by_id(db, config_id)
# 使用DataConfigRepository查询反思配置
select_query, select_params = DataConfigRepository.build_select_reflection(config_id)
result = db.execute(text(select_query), select_params).fetchone()
if not result:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
@@ -427,23 +238,22 @@ async def reflection_run(
api_logger.info(f"成功查询反思配置config_id: {config_id}")
# Validate model ID existence
# 验证模型ID是否存在
model_id = result.reflection_model_id
if model_id:
try:
ModelConfigService.get_model_by_id(db=db, model_id=uuid.UUID(model_id))
ModelConfigService.get_model_by_id(db=db, model_id=model_id)
api_logger.info(f"模型ID验证成功: {model_id}")
except Exception as e:
api_logger.warning(f"模型ID '{model_id}' 不存在,将使用默认模型: {str(e)}")
# 可以设置为None让反思引擎使用默认模型
model_id = None
# Create reflection configuration with database parameters
config = ReflectionConfig(
enabled=result.enable_self_reflexion,
iteration_period=result.iteration_period,
reflexion_range=ReflectionRange(result.reflexion_range),
baseline=ReflectionBaseline(result.baseline),
reflexion_range=result.reflexion_range,
baseline=result.baseline,
output_example='',
memory_verify=result.memory_verify,
quality_assessment=result.quality_assessment,
@@ -451,13 +261,11 @@ async def reflection_run(
model_id=model_id,
language_type=language_type
)
# Initialize Neo4j connector and reflection engine
connector = Neo4jConnector()
engine = ReflectionEngine(
config=config,
neo4j_connector=connector,
llm_client=model_id # Pass validated model_id
llm_client=model_id # 传入验证后的 model_id
)
result=await (engine.reflection_run())

View File

@@ -1,40 +1,18 @@
"""
Memory Short Term Controller
This module provides REST API endpoints for managing short-term and long-term memory
data retrieval and analysis. It handles memory system statistics, data aggregation,
and provides comprehensive memory insights for end users.
Key Features:
- Short-term memory data retrieval and statistics
- Long-term memory data aggregation
- Entity count integration
- Multi-language response support
- Memory system analytics and reporting
"""
from typing import Optional
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, Header, HTTPException, status
from sqlalchemy.orm import Session
from app.core.language_utils import get_language_from_header
from fastapi import APIRouter, Depends, HTTPException, status
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.services.memory_short_service import LongService, ShortService
from app.services.memory_storage_service import search_entity
# Load environment variables for configuration
from app.services.memory_short_service import ShortService,LongService
from dotenv import load_dotenv
from sqlalchemy.orm import Session
from typing import Optional
load_dotenv()
# Initialize API logger for request tracking and debugging
api_logger = get_api_logger()
# Configure router with prefix and tags for API organization
router = APIRouter(
prefix="/memory/short",
tags=["Memory"],
@@ -42,77 +20,24 @@ router = APIRouter(
@router.get("/short_term")
async def short_term_configs(
end_user_id: str,
language_type:str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""
Retrieve comprehensive short-term and long-term memory statistics
Provides a comprehensive overview of memory system data for a specific end user,
including short-term memory entries, long-term memory aggregations, entity counts,
and retrieval statistics. Supports multi-language responses based on request headers.
This endpoint serves as a central dashboard for memory system analytics, combining
data from multiple memory subsystems to provide a holistic view of user memory state.
Args:
end_user_id: Unique identifier for the end user whose memory data to retrieve
language_type: Language preference header for response localization (optional)
current_user: Authenticated user making the request (injected by dependency)
db: Database session for data operations (injected by dependency)
Returns:
dict: Success response containing comprehensive memory statistics:
- short_term: List of short-term memory entries with detailed data
- long_term: List of long-term memory aggregations and summaries
- entity: Count of entities associated with the end user
- retrieval_number: Total count of short-term memory retrievals
- long_term_number: Total count of long-term memory entries
Response Structure:
{
"code": 200,
"msg": "Short-term memory system data retrieved successfully",
"data": {
"short_term": [...], # Short-term memory entries
"long_term": [...], # Long-term memory data
"entity": 42, # Entity count
"retrieval_number": 156, # Short-term retrieval count
"long_term_number": 23 # Long-term memory count
}
}
Raises:
HTTPException: If end_user_id is invalid or data retrieval fails
Performance Notes:
- Combines multiple service calls for comprehensive data
- Entity search is performed asynchronously for better performance
- Response time depends on memory data volume for the specified user
"""
# Use centralized language validation for consistent localization
language = get_language_from_header(language_type)
# Retrieve short-term memory data and statistics
short_term = ShortService(end_user_id, db)
short_result = short_term.get_short_databasets() # Get short-term memory entries
short_count = short_term.get_short_count() # Get short-term retrieval count
# 获取短期记忆数据
short_term=ShortService(end_user_id)
short_result=short_term.get_short_databasets()
short_count=short_term.get_short_count()
# Retrieve long-term memory data and aggregations
long_term = LongService(end_user_id, db)
long_result = long_term.get_long_databasets() # Get long-term memory entries
long_term=LongService(end_user_id)
long_result=long_term.get_long_databasets()
# Get entity count for the specified end user
entity_result = await search_entity(end_user_id)
# Compile comprehensive memory statistics response
result = {
'short_term': short_result, # Short-term memory entries
'long_term': long_result, # Long-term memory data
'entity': entity_result.get('num', 0), # Entity count (default to 0 if not found)
"retrieval_number": short_count, # Short-term retrieval statistics
"long_term_number": len(long_result) # Long-term memory entry count
'short_term': short_result,
'long_term': long_result,
'entity': entity_result.get('num', 0),
"retrieval_number":short_count,
"long_term_number":len(long_result)
}
return success(data=result, msg="短期记忆系统数据获取成功")

View File

@@ -1,13 +1,10 @@
import os
import uuid
from typing import Optional
from uuid import UUID
from fastapi import APIRouter, Depends, Query
from fastapi.responses import StreamingResponse, JSONResponse
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger
from app.core.memory.utils.self_reflexion_utils import self_reflexion
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import get_current_user
@@ -15,6 +12,7 @@ from app.models.user_model import User
from app.schemas.memory_storage_schema import (
ConfigKey,
ConfigParamsCreate,
ConfigParamsDelete,
ConfigPilotRun,
ConfigUpdate,
ConfigUpdateExtracted,
@@ -32,14 +30,13 @@ from app.services.memory_storage_service import (
search_dialogue,
search_edges,
search_entity,
search_entity_graph,
search_statement,
)
from fastapi import APIRouter, Depends, Header
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from app.utils.config_utils import resolve_config_id
# Get API logger
api_logger = get_api_logger()
@@ -75,9 +72,68 @@ async def get_storage_info(
return fail(BizCode.INTERNAL_ERROR, "存储信息获取失败", str(e))
# --- DB connection dependency ---
_CONN: Optional[object] = None
"""PostgreSQL 连接生成与管理(使用 psycopg2"""
# 这个可以转移,可能是已经有的
# PostgreSQL 数据库连接
def _make_pgsql_conn() -> Optional[object]: # 创建 PostgreSQL 数据库连接
host = os.getenv("DB_HOST")
user = os.getenv("DB_USER")
password = os.getenv("DB_PASSWORD")
database = os.getenv("DB_NAME")
port_str = os.getenv("DB_PORT")
try:
import psycopg2 # type: ignore
port = int(port_str) if port_str else 5432
conn = psycopg2.connect(
host=host or "localhost",
port=port,
user=user,
password=password,
dbname=database,
)
# 设置自动提交,避免显式事务管理
conn.autocommit = True
# 设置会话时区为中国标准时间Asia/Shanghai便于直接以本地时区展示
try:
cur = conn.cursor()
cur.execute("SET TIME ZONE 'Asia/Shanghai'")
cur.close()
except Exception:
# 时区设置失败不影响连接,仅记录但不抛出
pass
return conn
except Exception as e:
try:
print(f"[PostgreSQL] 连接失败: {e}")
except Exception:
pass
return None
def get_db_conn() -> Optional[object]: # 获取 PostgreSQL 数据库连接
global _CONN
if _CONN is None:
_CONN = _make_pgsql_conn()
return _CONN
def reset_db_conn() -> bool: # 重置 PostgreSQL 数据库连接
"""Close and recreate the global DB connection."""
global _CONN
try:
if _CONN:
try:
_CONN.close()
except Exception:
pass
_CONN = _make_pgsql_conn()
return _CONN is not None
except Exception:
_CONN = None
return False
@router.post("/create_config", response_model=ApiResponse) # 创建配置文件,其他参数默认
@@ -85,9 +141,9 @@ def create_config(
payload: ConfigParamsCreate,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
x_language_type: Optional[str] = Header(None, alias="X-Language-Type"),
) -> dict:
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试创建配置但未选择工作空间")
@@ -100,125 +156,46 @@ def create_config(
svc = DataConfigService(db)
result = svc.create(payload)
return success(data=result, msg="创建成功")
except ValueError as e:
err_str = str(e)
if err_str.startswith("DUPLICATE_CONFIG_NAME:"):
config_name = err_str.split(":", 1)[1]
api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {err_str}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str)
except Exception as e:
from sqlalchemy.exc import IntegrityError
if isinstance(e, IntegrityError) and "uq_workspace_config_name" in str(getattr(e, 'orig', '')):
api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}")
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Create config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e))
@router.delete("/delete_config", response_model=ApiResponse) # 删除数据库中的内容(按配置名称)
def delete_config(
config_id: UUID|int,
force: bool = Query(False, description="是否强制删除(即使有终端用户正在使用)"),
config_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
"""删除记忆配置(带终端用户保护)
- 检查是否为默认配置,默认配置不允许删除
- 检查是否有终端用户连接到该配置
- 如果有连接且 force=False返回警告
- 如果 force=True清除终端用户引用后删除配置
Query Parameters:
force: 设置为 true 可强制删除(即使有终端用户正在使用)
"""
) -> dict:
workspace_id = current_user.current_workspace_id
config_id=resolve_config_id(config_id, db)
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试删除配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(
f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: "
f"config_id={config_id}, force={force}"
)
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求删除配置: {config_id}")
try:
# 使用带保护的删除服务
from app.services.memory_config_service import MemoryConfigService
config_service = MemoryConfigService(db)
result = config_service.delete_config(config_id=config_id, force=force)
if result["status"] == "error":
api_logger.warning(
f"记忆配置删除被拒绝: config_id={config_id}, reason={result['message']}"
)
return fail(
code=BizCode.FORBIDDEN,
msg=result["message"],
data={"config_id": str(config_id), "is_default": result.get("is_default", False)}
)
if result["status"] == "warning":
api_logger.warning(
f"记忆配置正在使用,无法删除: config_id={config_id}, "
f"connected_count={result['connected_count']}"
)
return fail(
code=BizCode.RESOURCE_IN_USE,
msg=result["message"],
data={
"connected_count": result["connected_count"],
"force_required": result["force_required"]
}
)
api_logger.info(
f"记忆配置删除成功: config_id={config_id}, "
f"affected_users={result['affected_users']}"
)
return success(
msg=result["message"],
data={"affected_users": result["affected_users"]}
)
svc = DataConfigService(db)
result = svc.delete(ConfigParamsDelete(config_id=config_id))
return success(data=result, msg="删除成功")
except Exception as e:
api_logger.error(f"Delete config failed: {str(e)}", exc_info=True)
api_logger.error(f"Delete config failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "删除配置失败", str(e))
@router.post("/update_config", response_model=ApiResponse) # 更新配置文件中name和desc
def update_config(
payload: ConfigUpdate,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
) -> dict:
workspace_id = current_user.current_workspace_id
payload.config_id = resolve_config_id(payload.config_id, db)
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
# 校验至少有一个字段需要更新
if payload.config_name is None and payload.config_desc is None and payload.scene_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新配置但未提供任何更新字段")
return fail(BizCode.INVALID_PARAMETER, "请至少提供一个需要更新的字段", "config_name, config_desc, scene_id 均为空")
api_logger.info(f"用户 {current_user.username} 在工作空间 {workspace_id} 请求更新配置: {payload.config_id}")
try:
svc = DataConfigService(db)
@@ -234,9 +211,9 @@ def update_config_extracted(
payload: ConfigUpdateExtracted,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
) -> dict:
workspace_id = current_user.current_workspace_id
payload.config_id = resolve_config_id(payload.config_id, db)
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新提取配置但未选择工作空间")
@@ -258,12 +235,12 @@ def update_config_extracted(
@router.get("/read_config_extracted", response_model=ApiResponse) # 通过查询参数读取某条配置(固定路径) 没有意义的话就删除
def read_config_extracted(
config_id: UUID | int,
config_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
) -> dict:
workspace_id = current_user.current_workspace_id
config_id = resolve_config_id(config_id, db)
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试读取提取配置但未选择工作空间")
@@ -282,7 +259,7 @@ def read_config_extracted(
def read_all_config(
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
) -> dict:
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
@@ -304,22 +281,16 @@ def read_all_config(
@router.post("/pilot_run", response_model=None)
async def pilot_run(
payload: ConfigPilotRun,
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> StreamingResponse:
# 使用集中化的语言校验
language = get_language_from_header(language_type)
api_logger.info(
f"Pilot run requested: config_id={payload.config_id}, "
f"dialogue_text_length={len(payload.dialogue_text)}, "
f"custom_text_length={len(payload.custom_text) if payload.custom_text else 0}"
f"dialogue_text_length={len(payload.dialogue_text)}"
)
payload.config_id = resolve_config_id(payload.config_id, db)
svc = DataConfigService(db)
return StreamingResponse(
svc.pilot_run_stream(payload, language=language),
svc.pilot_run_stream(payload),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
@@ -328,8 +299,9 @@ async def pilot_run(
},
)
# ==================== Search & Analytics ====================
"""
以下为搜索与分析接口,直接挂载到同一 router统一响应为 ApiResponse。
"""
@router.get("/search/kb_type_distribution", response_model=ApiResponse)
async def get_kb_type_distribution(
@@ -442,7 +414,21 @@ async def search_entity_edges(
api_logger.error(f"Search edges failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "边查询失败", str(e))
@router.get("/search/entity_graph", response_model=ApiResponse)
async def search_for_entity_graph(
end_user_id: Optional[str] = None,
current_user: User = Depends(get_current_user),
) -> dict:
"""
搜索所有实体之间的关系网络
"""
api_logger.info(f"Search entity graph requested for end_user_id: {end_user_id}")
try:
result = await search_entity_graph(end_user_id)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Search entity graph failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "实体图查询失败", str(e))
@router.get("/analytics/hot_memory_tags", response_model=ApiResponse)
@@ -451,106 +437,39 @@ async def get_hot_memory_tags_api(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
) -> dict:
"""
获取热门记忆标签带Redis缓存
缓存策略:
- 缓存键workspace_id + limit
- 过期时间5分钟300秒
- 缓存命中:~50ms
- 缓存未命中:~600-800ms取决于LLM速度
"""
workspace_id = current_user.current_workspace_id
# 构建缓存键
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}")
api_logger.info(f"Hot memory tags requested for current_user: {current_user.id}")
try:
# 尝试从Redis缓存获取
import json
from app.aioRedis import aio_redis_get, aio_redis_set
cached_result = await aio_redis_get(cache_key)
if cached_result:
api_logger.info(f"Cache hit for key: {cache_key}")
try:
data = json.loads(cached_result)
return success(data=data, msg="查询成功(缓存)")
except json.JSONDecodeError:
api_logger.warning(f"Failed to parse cached data, will refresh")
# 缓存未命中,执行查询
api_logger.info(f"Cache miss for key: {cache_key}, executing query")
result = await analytics_hot_memory_tags(db, current_user, limit)
# 写入缓存过期时间5分钟
# 注意result是列表需要转换为JSON字符串
try:
cache_data = json.dumps(result, ensure_ascii=False)
await aio_redis_set(cache_key, cache_data, expire=300)
api_logger.info(f"Cached result for key: {cache_key}")
except Exception as cache_error:
# 缓存写入失败不影响主流程
api_logger.warning(f"Failed to cache result: {str(cache_error)}")
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Hot memory tags failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
async def clear_hot_memory_tags_cache(
current_user: User = Depends(get_current_user),
) -> dict:
"""
清除热门标签缓存
用于:
- 手动刷新数据
- 调试和测试
- 数据更新后立即生效
"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}")
try:
from app.aioRedis import aio_redis_delete
# 清除所有limit的缓存常见的limit值
cleared_count = 0
for limit in [5, 10, 15, 20, 30, 50]:
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
result = await aio_redis_delete(cache_key)
if result:
cleared_count += 1
api_logger.info(f"Cleared cache for key: {cache_key}")
return success(
data={"cleared_count": cleared_count},
msg=f"成功清除 {cleared_count} 个缓存"
)
except Exception as e:
api_logger.error(f"Clear cache failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e))
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
async def get_recent_activity_stats_api(
current_user: User = Depends(get_current_user),
) -> dict:
workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None
api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}")
) -> dict:
api_logger.info("Recent activity stats requested")
try:
result = await analytics_recent_activity_stats(workspace_id=workspace_id)
result = await analytics_recent_activity_stats()
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"Recent activity stats failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "最近活动统计失败", str(e))
@router.get("/self_reflexion")
async def self_reflexion_endpoint(host_id: uuid.UUID) -> str:
"""
自我反思接口,自动对检索出的信息进行自我反思并返回自我反思结果。
Args:
None
Returns:
自我反思结果。
"""
return await self_reflexion(host_id)

View File

@@ -8,7 +8,6 @@ from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user
from app.models import User
from app.schemas import conversation_schema
from app.schemas.response_schema import ApiResponse
from app.services.conversation_service import ConversationService
@@ -21,18 +20,18 @@ router = APIRouter(
)
@router.get("/{end_user_id}/count", response_model=ApiResponse)
@router.get("/{group_id}/count", response_model=ApiResponse)
def get_memory_count(
end_user_id: uuid.UUID,
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
pass
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
@router.get("/{group_id}/conversations", response_model=ApiResponse)
def get_conversations(
end_user_id: uuid.UUID,
group_id: uuid.UUID,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
@@ -40,7 +39,7 @@ def get_conversations(
Retrieve all conversations for the current user in a specific group.
Args:
end_user_id (UUID): The group identifier.
group_id (UUID): The group identifier.
current_user (User, optional): The authenticated user.
db (Session, optional): SQLAlchemy session.
@@ -54,7 +53,7 @@ def get_conversations(
"""
conversation_service = ConversationService(db)
conversations = conversation_service.get_user_conversations(
end_user_id
group_id
)
return success(data=[
{
@@ -64,7 +63,7 @@ def get_conversations(
], msg="get conversations success")
@router.get("/{end_user_id}/messages", response_model=ApiResponse)
@router.get("/{group_id}/messages", response_model=ApiResponse)
def get_messages(
conversation_id: uuid.UUID,
current_user: User = Depends(get_current_user),
@@ -91,13 +90,17 @@ def get_messages(
conversation_id,
)
messages = [
conversation_schema.Message.model_validate(message)
{
"role": message.role,
"content": message.content,
"created_at": int(message.created_at.timestamp() * 1000),
}
for message in messages_obj
]
return success(data=messages, msg="get conversation history success")
@router.get("/{end_user_id}/detail", response_model=ApiResponse)
@router.get("/{group_id}/detail", response_model=ApiResponse)
async def get_conversation_detail(
conversation_id: uuid.UUID,
current_user: User = Depends(get_current_user),

View File

@@ -3,17 +3,15 @@ from sqlalchemy.orm import Session
from typing import Optional
import uuid
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.db import get_db
from app.dependencies import get_current_user
from app.models.models_model import ModelProvider, ModelType, LoadBalanceStrategy
from app.models.models_model import ModelProvider, ModelType
from app.models.user_model import User
from app.repositories.model_repository import ModelConfigRepository
from app.schemas import model_schema
from app.core.response_utils import success
from app.schemas.response_schema import ApiResponse, PageData
from app.services.model_service import ModelConfigService, ModelApiKeyService, ModelBaseService
from app.services.model_service import ModelConfigService, ModelApiKeyService
from app.core.logging_config import get_api_logger
# 获取API专用日志器
@@ -26,54 +24,44 @@ router = APIRouter(
@router.get("/type", response_model=ApiResponse)
def get_model_types():
return success(msg="获取模型类型成功", data=list(ModelType))
@router.get("/provider", response_model=ApiResponse)
def get_model_providers():
providers = [p for p in ModelProvider if p != ModelProvider.COMPOSITE]
return success(msg="获取模型提供商成功", data=providers)
@router.get("/strategy", response_model=ApiResponse)
def get_model_strategies():
return success(msg="获取模型策略成功", data=list(LoadBalanceStrategy))
return success(msg="获取模型提供商成功", data=list(ModelProvider))
@router.get("", response_model=ApiResponse)
def get_model_list(
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING"),
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
search: Optional[str] = Query(None, description="搜索关键词"),
page: int = Query(1, ge=1, description="页码"),
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
type: Optional[str] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING"),
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于API Key)"),
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
search: Optional[str] = Query(None, description="搜索关键词"),
page: int = Query(1, ge=1, description="页码"),
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取模型配置列表
支持多个 type 参数:
- 单个:?type=LLM
- 多个(逗号分隔):?type=LLM,EMBEDDING
- 多个(重复参数):?type=LLM&type=EMBEDDING
"""
api_logger.info(
f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, page={page}, pagesize={pagesize}, tenant_id={current_user.tenant_id}")
try:
# 解析 type 参数(支持逗号分隔)
type_list = []
if type is not None:
flat_type = []
for item in type:
split_items = [t.strip() for t in item.split(',') if t.strip()]
flat_type.extend(split_items)
unique_flat_type = list(dict.fromkeys(flat_type))
type_list = [ModelType(t.lower()) for t in unique_flat_type]
type_list = None
if type:
type_values = [t.strip() for t in type.split(',')]
type_list = [model_schema.ModelType(t.lower()) for t in type_values if t]
api_logger.error(f"获取模型type_list: {type_list}")
query = model_schema.ModelConfigQuery(
type=type_list,
@@ -84,7 +72,7 @@ def get_model_list(
page=page,
pagesize=pagesize
)
api_logger.debug(f"开始获取模型配置列表: {query.dict()}")
result_orm = ModelConfigService.get_model_list(db=db, query=query, tenant_id=current_user.tenant_id)
result = PageData.model_validate(result_orm)
@@ -95,146 +83,6 @@ def get_model_list(
raise
@router.get("/new", response_model=ApiResponse)
def get_model_list_new(
type: Optional[list[str]] = Query(None, description="模型类型筛选(支持多个,如 ?type=LLM 或 ?type=LLM,EMBEDDING"),
provider: Optional[model_schema.ModelProvider] = Query(None, description="提供商筛选(基于ModelConfig)"),
is_active: Optional[bool] = Query(None, description="激活状态筛选"),
is_public: Optional[bool] = Query(None, description="公开状态筛选"),
search: Optional[str] = Query(None, description="搜索关键词"),
is_composite: Optional[bool] = Query(None, description="组合模型筛选"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
获取模型配置列表
支持多个 type 参数:
- 单个:?type=LLM
- 多个(逗号分隔):?type=LLM,EMBEDDING
- 多个(重复参数):?type=LLM&type=EMBEDDING
"""
api_logger.info(f"获取模型配置列表请求: type={type}, provider={provider}, tenant_id={current_user.tenant_id}")
try:
# 解析 type 参数(支持逗号分隔)
type_list = []
if type is not None:
flat_type = []
for item in type:
split_items = [t.strip() for t in item.split(',') if t.strip()]
flat_type.extend(split_items)
unique_flat_type = list(dict.fromkeys(flat_type))
type_list = [ModelType(t.lower()) for t in unique_flat_type]
api_logger.info(f"获取模型type_list: {type_list}")
query = model_schema.ModelConfigQueryNew(
type=type_list,
provider=provider,
is_active=is_active,
is_public=is_public,
is_composite=is_composite,
search=search
)
api_logger.debug(f"开始获取模型配置列表: {query.model_dump()}")
result = ModelConfigService.get_model_list_new(db=db, query=query, tenant_id=current_user.tenant_id)
api_logger.info(f"模型配置列表获取成功: 分组数={len(result)}, 总模型数={sum(len(item['models']) for item in result)}")
return success(data=result, msg="模型配置列表获取成功")
except Exception as e:
api_logger.error(f"获取模型配置列表失败: {str(e)}")
raise
@router.get("/model_plaza", response_model=ApiResponse)
def get_model_plaza_list(
type: Optional[ModelType] = Query(None, description="模型类型"),
provider: Optional[ModelProvider] = Query(None, description="供应商"),
is_official: Optional[bool] = Query(None, description="是否官方模型"),
is_deprecated: Optional[bool] = Query(None, description="是否弃用"),
search: Optional[str] = Query(None, description="搜索关键词"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""模型广场查询接口(按供应商分组)"""
query = model_schema.ModelBaseQuery(
type=type,
provider=provider,
is_official=is_official,
is_deprecated=is_deprecated,
search=search
)
result = ModelBaseService.get_model_base_list(db=db, query=query, tenant_id=current_user.tenant_id)
return success(data=result, msg="模型广场列表获取成功")
@router.get("/model_plaza/{model_base_id}", response_model=ApiResponse)
def get_model_base_by_id(
model_base_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取基础模型详情"""
result = ModelBaseService.get_model_base_by_id(db=db, model_base_id=model_base_id)
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型获取成功")
@router.post("/model_plaza", response_model=ApiResponse)
def create_model_base(
data: model_schema.ModelBaseCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""创建基础模型"""
result = ModelBaseService.create_model_base(db=db, data=data)
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型创建成功")
@router.put("/model_plaza/{model_base_id}", response_model=ApiResponse)
def update_model_base(
model_base_id: uuid.UUID,
data: model_schema.ModelBaseUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""更新基础模型"""
# 不允许更改type类型
if data.type is not None or data.provider is not None:
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
result = ModelBaseService.update_model_base(db=db, model_base_id=model_base_id, data=data)
return success(data=model_schema.ModelBase.model_validate(result), msg="基础模型更新成功")
@router.delete("/model_plaza/{model_base_id}", response_model=ApiResponse)
def delete_model_base(
model_base_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""删除基础模型"""
ModelBaseService.delete_model_base(db=db, model_base_id=model_base_id)
return success(msg="基础模型删除成功")
@router.post("/model_plaza/{model_base_id}/add", response_model=ApiResponse)
def add_model_from_plaza(
model_base_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""从模型广场添加模型到模型列表"""
result = ModelBaseService.add_model_from_plaza(db=db, model_base_id=model_base_id, tenant_id=current_user.tenant_id)
return success(data=model_schema.ModelConfig.model_validate(result), msg="模型添加成功")
@router.get("/{model_id}", response_model=ApiResponse)
def get_model_by_id(
model_id: uuid.UUID,
@@ -290,73 +138,6 @@ async def create_model(
raise
@router.post("/composite", response_model=ApiResponse)
async def create_composite_model(
model_data: model_schema.CompositeModelCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
创建组合模型
- 绑定一个或多个现有的 API Key
- 所有 API Key 必须来自非组合模型
- 所有 API Key 关联的模型类型必须与组合模型类型一致
"""
api_logger.info(f"创建组合模型请求: {model_data.name}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
try:
result_orm = await ModelConfigService.create_composite_model(db=db, model_data=model_data, tenant_id=current_user.tenant_id)
api_logger.info(f"组合模型创建成功: {result_orm.name} (ID: {result_orm.id})")
result = model_schema.ModelConfig.model_validate(result_orm)
return success(data=result, msg="组合模型创建成功")
except Exception as e:
api_logger.error(f"创建组合模型失败: {model_data.name} - {str(e)}")
raise
@router.put("/composite/{model_id}", response_model=ApiResponse)
async def update_composite_model(
model_id: uuid.UUID,
model_data: model_schema.CompositeModelCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""更新组合模型"""
api_logger.info(f"更新组合模型请求: model_id={model_id}, 用户: {current_user.username}")
try:
if model_data.type is not None:
raise BusinessException("不允许更改模型类型", BizCode.INVALID_PARAMETER)
result_orm = await ModelConfigService.update_composite_model(db=db, model_id=model_id, model_data=model_data, tenant_id=current_user.tenant_id)
api_logger.info(f"组合模型更新成功: {result_orm.name} (ID: {model_id})")
result = model_schema.ModelConfig.model_validate(result_orm)
return success(data=result, msg="组合模型更新成功")
except Exception as e:
api_logger.error(f"更新组合模型失败: model_id={model_id} - {str(e)}")
raise
@router.delete("/composite/{model_id}", response_model=ApiResponse)
def delete_composite_model(
model_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""删除组合模型"""
api_logger.info(f"删除组合模型请求: model_id={model_id}, 用户: {current_user.username}")
try:
ModelConfigService.delete_model(db=db, model_id=model_id, tenant_id=current_user.tenant_id)
api_logger.info(f"组合模型删除成功: model_id={model_id}")
return success(msg="组合模型删除成功")
except Exception as e:
api_logger.error(f"删除组合模型失败: model_id={model_id} - {str(e)}")
raise
@router.put("/{model_id}", response_model=ApiResponse)
def update_model(
model_id: uuid.UUID,
@@ -368,14 +149,6 @@ def update_model(
更新模型配置
"""
api_logger.info(f"更新模型配置请求: model_id={model_id}, 用户: {current_user.username}, tenant_id={current_user.tenant_id}")
if model_data.type is not None or model_data.provider is not None:
raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER)
if model_data.is_active:
active_keys = ModelApiKeyService.get_api_keys_by_model(db=db, model_config_id=model_id, is_active=model_data.is_active)
if not active_keys:
raise BusinessException("请先为该模型配置可用的 API Key", BizCode.INVALID_PARAMETER)
try:
api_logger.debug(f"开始更新模型配置: model_id={model_id}")
@@ -441,55 +214,6 @@ def get_model_api_keys(
raise
@router.post("/provider/apikeys", response_model=ApiResponse)
async def create_model_api_key_by_provider(
api_key_data: model_schema.ModelApiKeyCreateByProvider,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""
根据供应商为所有匹配的模型创建API Key
"""
api_logger.info(f"创建API Key请求: provider={api_key_data.provider}, 用户: {current_user.username}")
try:
# 根据tenant_id和provider筛选model_config_id列表
model_config_ids = api_key_data.model_config_ids
if not model_config_ids:
model_config_ids = ModelConfigRepository.get_model_config_ids_by_provider(
db=db,
tenant_id=current_user.tenant_id,
provider=api_key_data.provider
)
if not model_config_ids:
raise BusinessException(f"未找到供应商 {api_key_data.provider} 的模型配置", BizCode.MODEL_NOT_FOUND)
# 构造schema并调用service
create_data = model_schema.ModelApiKeyCreateByProvider(
provider=api_key_data.provider,
api_key=api_key_data.api_key,
api_base=api_key_data.api_base,
description=api_key_data.description,
config=api_key_data.config,
is_active=api_key_data.is_active,
priority=api_key_data.priority,
model_config_ids=model_config_ids,
capability=api_key_data.capability,
is_omni=api_key_data.is_omni
)
created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data)
api_logger.info(f"API Key创建成功: 关联{len(created_keys)}个模型")
# result_list = [model_schema.ModelApiKey.model_validate(key) for key in created_keys]
result = "API Key已存在" if len(created_keys) == 0 and len(failed_models) == 0 else \
f"成功为 {len(created_keys)} 个模型创建API Key, 失败模型列表{failed_models}"
return success(data=result, msg=f"成功为 {len(created_keys)} 个模型创建API Key")
except Exception as e:
api_logger.error(f"创建API Key失败: {str(e)}")
raise
@router.post("/{model_id}/apikeys", response_model=ApiResponse, status_code=status.HTTP_201_CREATED)
async def create_model_api_key(
model_id: uuid.UUID,
@@ -504,12 +228,11 @@ async def create_model_api_key(
try:
# 设置模型配置ID
api_key_data.model_config_ids = [model_id]
api_key_data.model_config_id = model_id
api_logger.debug(f"开始创建模型API Key: {api_key_data.model_name}")
result_orm = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
api_logger.info(f"模型API Key创建成功: {result_orm.model_name} (ID: {result_orm.id})")
result = model_schema.ModelApiKey.model_validate(result_orm)
result = await ModelApiKeyService.create_api_key(db=db, api_key_data=api_key_data)
api_logger.info(f"模型API Key创建成功: {result.model_name} (ID: {result.id})")
return success(data=result, msg="模型API Key创建成功")
except Exception as e:
api_logger.error(f"创建模型API Key失败: {api_key_data.model_name} - {str(e)}")
@@ -611,3 +334,5 @@ async def validate_model_config(
return success(data=model_schema.ModelValidateResponse(**result), msg="验证完成")

File diff suppressed because it is too large Load Diff

View File

@@ -1,663 +0,0 @@
# -*- coding: utf-8 -*-
"""本体场景和类型路由(续)
由于主Controller文件较大将剩余路由放在此文件中。
"""
from uuid import UUID
from typing import Optional
from fastapi import Depends, Header
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.logging_config import get_api_logger, get_business_logger
from app.core.response_utils import fail, success
from app.db import get_db
from app.dependencies import get_current_user
from app.models.user_model import User
from app.schemas.ontology_schemas import (
SceneResponse,
SceneListResponse,
PaginationInfo,
ClassCreateRequest,
ClassUpdateRequest,
ClassResponse,
ClassListResponse,
ClassBatchCreateResponse,
)
from app.schemas.response_schema import ApiResponse
from app.services.ontology_service import OntologyService
from app.core.memory.llm_tools.openai_client import OpenAIClient
from app.core.models.base import RedBearModelConfig
from app.repositories.ontology_class_repository import OntologyClassRepository
api_logger = get_api_logger()
business_logger = get_business_logger()
def _get_dummy_ontology_service(db: Session) -> OntologyService:
"""获取OntologyService实例不需要LLM
场景和类型管理不需要LLM创建一个dummy配置。
"""
dummy_config = RedBearModelConfig(
model_name="dummy",
provider="openai",
api_key="dummy",
base_url="https://api.openai.com/v1"
)
llm_client = OpenAIClient(model_config=dummy_config)
return OntologyService(llm_client=llm_client, db=db)
# 这些函数将被导入到主Controller中
async def scenes_handler(
workspace_id: Optional[str] = None,
scene_name: Optional[str] = None,
page: Optional[int] = None,
pagesize: Optional[int] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取场景列表(支持模糊搜索和全量查询,全量查询支持分页)
当提供 scene_name 参数时,进行模糊搜索(不分页);
当不提供 scene_name 参数时,返回所有场景(支持分页)。
Args:
workspace_id: 工作空间ID可选默认当前用户工作空间
scene_name: 场景名称关键词(可选,支持模糊匹配)
page: 页码可选从1开始仅在全量查询时有效
pagesize: 每页数量(可选,仅在全量查询时有效)
db: 数据库会话
current_user: 当前用户
"""
operation = "search" if scene_name else "list"
api_logger.info(
f"Scene {operation} requested by user {current_user.id}, "
f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, pagesize={pagesize}"
)
try:
# 确定工作空间ID
if workspace_id:
try:
ws_uuid = UUID(workspace_id)
except ValueError:
api_logger.warning(f"Invalid workspace_id format: {workspace_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的工作空间ID格式")
else:
ws_uuid = current_user.current_workspace_id
if not ws_uuid:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建Service
service = _get_dummy_ontology_service(db)
# 根据是否提供 scene_name 决定查询方式
if scene_name and scene_name.strip():
# 验证分页参数(模糊搜索也支持分页)
if page is not None and page < 1:
api_logger.warning(f"Invalid page number: {page}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
if pagesize is not None and pagesize < 1:
api_logger.warning(f"Invalid pagesize: {pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
# 如果只提供了page或pagesize中的一个返回错误
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
# 模糊搜索场景(支持分页)
scenes = service.search_scenes_by_name(scene_name.strip(), ws_uuid)
total = len(scenes)
# 如果提供了分页参数,进行分页处理
if page is not None and pagesize is not None:
start_idx = (page - 1) * pagesize
end_idx = start_idx + pagesize
scenes = scenes[start_idx:end_idx]
# 构建响应
items = []
for scene in scenes:
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
type_num = len(scene.classes) if scene.classes else 0
items.append(SceneResponse(
scene_id=scene.scene_id,
scene_name=scene.scene_name,
scene_description=scene.scene_description,
type_num=type_num,
entity_type=entity_type,
workspace_id=scene.workspace_id,
created_at=scene.created_at,
updated_at=scene.updated_at,
classes_count=type_num,
is_system_default=scene.is_system_default
))
# 构建响应(包含分页信息)
if page is not None and pagesize is not None:
hasnext = (page * pagesize) < total
pagination_info = PaginationInfo(
page=page,
pagesize=pagesize,
total=total,
hasnext=hasnext
)
response = SceneListResponse(items=items, page=pagination_info)
else:
response = SceneListResponse(items=items)
api_logger.info(
f"Scene search completed: found {len(items)} scenes matching '{scene_name}' "
f"in workspace {ws_uuid}, total={total}"
)
else:
# 获取所有场景(支持分页)
if page is not None and page < 1:
api_logger.warning(f"Invalid page number: {page}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0")
if pagesize is not None and pagesize < 1:
api_logger.warning(f"Invalid pagesize: {pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0")
# 如果只提供了page或pagesize中的一个返回错误
if (page is not None and pagesize is None) or (page is None and pagesize is not None):
api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供")
scenes, total = service.list_scenes(ws_uuid, page, pagesize)
# 构建响应
items = []
for scene in scenes:
entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None
type_num = len(scene.classes) if scene.classes else 0
items.append(SceneResponse(
scene_id=scene.scene_id,
scene_name=scene.scene_name,
scene_description=scene.scene_description,
type_num=type_num,
entity_type=entity_type,
workspace_id=scene.workspace_id,
created_at=scene.created_at,
updated_at=scene.updated_at,
classes_count=type_num,
is_system_default=scene.is_system_default
))
# 构建响应(包含分页信息)
if page is not None and pagesize is not None:
hasnext = (page * pagesize) < total
pagination_info = PaginationInfo(
page=page,
pagesize=pagesize,
total=total,
hasnext=hasnext
)
response = SceneListResponse(items=items, page=pagination_info)
else:
response = SceneListResponse(items=items)
api_logger.info(f"Scene list retrieved successfully, count={len(items)}, total={total}")
return success(data=response.model_dump(mode='json'), msg="查询成功")
except ValueError as e:
api_logger.warning(f"Validation error in scene {operation}: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in scene {operation}: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in scene {operation}: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
# ==================== 本体类型管理接口 ====================
async def create_class_handler(
request: ClassCreateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
x_language_type: Optional[str] = None
):
"""创建本体类型(统一使用列表形式,支持单个或批量)"""
# 根据列表长度判断是单个还是批量
count = len(request.classes)
mode = "single" if count == 1 else "batch"
api_logger.info(
f"Class creation ({mode}) requested by user {current_user.id}, "
f"scene_id={request.scene_id}, count={count}"
)
try:
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建Service
service = _get_dummy_ontology_service(db)
# 准备类型数据
classes_data = [
{
"class_name": item.class_name,
"class_description": item.class_description
}
for item in request.classes
]
if count == 1:
# 单个创建 - 先检查重名
class_data = classes_data[0]
existing = OntologyClassRepository(db).get_by_name(class_data["class_name"], request.scene_id)
if existing:
raise ValueError(f"DUPLICATE_CLASS_NAME:{class_data['class_name']}")
ontology_class = service.create_class(
scene_id=request.scene_id,
class_name=class_data["class_name"],
class_description=class_data["class_description"],
workspace_id=workspace_id
)
# 构建单个响应
response = ClassResponse(
class_id=ontology_class.class_id,
class_name=ontology_class.class_name,
class_description=ontology_class.class_description,
scene_id=ontology_class.scene_id,
created_at=ontology_class.created_at,
updated_at=ontology_class.updated_at
)
api_logger.info(f"Class created successfully: {ontology_class.class_id}")
return success(data=response.model_dump(mode='json'), msg="类型创建成功")
else:
# 批量创建
created_classes, errors = service.create_classes_batch(
scene_id=request.scene_id,
classes=classes_data,
workspace_id=workspace_id
)
# 构建批量响应
items = []
for ontology_class in created_classes:
items.append(ClassResponse(
class_id=ontology_class.class_id,
class_name=ontology_class.class_name,
class_description=ontology_class.class_description,
scene_id=ontology_class.scene_id,
created_at=ontology_class.created_at,
updated_at=ontology_class.updated_at
))
response = ClassBatchCreateResponse(
total=len(classes_data),
success_count=len(created_classes),
failed_count=len(errors),
items=items,
errors=errors if errors else None
)
api_logger.info(
f"Batch class creation completed: "
f"success={len(created_classes)}, failed={len(errors)}"
)
return success(data=response.model_dump(mode='json'), msg="批量创建完成")
except ValueError as e:
err_str = str(e)
if err_str.startswith("DUPLICATE_CLASS_NAME:"):
class_name = err_str.split(":", 1)[1]
api_logger.warning(f"Duplicate class name '{class_name}' in scene {request.scene_id}")
from app.core.language_utils import get_language_from_header
from fastapi.responses import JSONResponse
lang = get_language_from_header(x_language_type)
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.warning(f"Validation error in class creation: {err_str}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", err_str)
except RuntimeError as e:
err_str = str(e)
if "UniqueViolation" in err_str or "uq_scene_class_name" in err_str:
api_logger.warning(f"Duplicate class name in scene {request.scene_id}")
from app.core.language_utils import get_language_from_header
from fastapi.responses import JSONResponse
lang = get_language_from_header(x_language_type)
class_name = request.classes[0].class_name if request.classes else ""
if lang == "en":
msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.")
else:
msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称")
return JSONResponse(status_code=400, content=msg)
api_logger.error(f"Runtime error in class creation: {err_str}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", err_str)
except Exception as e:
api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e))
async def update_class_handler(
class_id: str,
request: ClassUpdateRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""更新本体类型"""
api_logger.info(
f"Class update requested by user {current_user.id}, "
f"class_id={class_id}"
)
try:
# 验证UUID格式
try:
class_uuid = UUID(class_id)
except ValueError:
api_logger.warning(f"Invalid class_id format: {class_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 检查是否为系统默认类型
class_repo = OntologyClassRepository(db)
ontology_class = class_repo.get_by_id(class_uuid)
if ontology_class and ontology_class.is_system_default:
business_logger.warning(
f"尝试修改系统默认类型: user_id={current_user.id}, "
f"class_id={class_id}, class_name={ontology_class.class_name}"
)
return fail(
BizCode.BAD_REQUEST,
"系统默认类型不可修改",
"该类型为系统预设类型,不允许修改"
)
# 创建Service
service = _get_dummy_ontology_service(db)
# 更新类型
ontology_class = service.update_class(
class_id=class_uuid,
class_name=request.class_name,
class_description=request.class_description,
workspace_id=workspace_id
)
# 构建响应
response = ClassResponse(
class_id=ontology_class.class_id,
class_name=ontology_class.class_name,
class_description=ontology_class.class_description,
scene_id=ontology_class.scene_id,
created_at=ontology_class.created_at,
updated_at=ontology_class.updated_at
)
api_logger.info(f"Class updated successfully: {class_id}")
return success(data=response.model_dump(mode='json'), msg="类型更新成功")
except ValueError as e:
api_logger.warning(f"Validation error in class update: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in class update: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in class update: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型更新失败", str(e))
async def delete_class_handler(
class_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""删除本体类型"""
api_logger.info(
f"Class deletion requested by user {current_user.id}, "
f"class_id={class_id}"
)
try:
# 验证UUID格式
try:
class_uuid = UUID(class_id)
except ValueError:
api_logger.warning(f"Invalid class_id format: {class_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 检查是否为系统默认类型
class_repo = OntologyClassRepository(db)
ontology_class = class_repo.get_by_id(class_uuid)
if ontology_class and ontology_class.is_system_default:
business_logger.warning(
f"尝试删除系统默认类型: user_id={current_user.id}, "
f"class_id={class_id}, class_name={ontology_class.class_name}"
)
return fail(
BizCode.BAD_REQUEST,
"系统默认类型不可删除",
"该类型为系统预设类型,不允许删除"
)
# 创建Service
service = _get_dummy_ontology_service(db)
# 删除类型
success_flag = service.delete_class(
class_id=class_uuid,
workspace_id=workspace_id
)
api_logger.info(f"Class deleted successfully: {class_id}")
return success(data={"deleted": success_flag}, msg="类型删除成功")
except ValueError as e:
api_logger.warning(f"Validation error in class deletion: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in class deletion: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in class deletion: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "类型删除失败", str(e))
async def get_class_handler(
class_id: str,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取单个本体类型"""
api_logger.info(
f"Get class requested by user {current_user.id}, "
f"class_id={class_id}"
)
try:
# 验证UUID格式
try:
class_uuid = UUID(class_id)
except ValueError:
api_logger.warning(f"Invalid class_id format: {class_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的类型ID格式")
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建Service
service = _get_dummy_ontology_service(db)
# 获取类型会抛出ValueError如果不存在
ontology_class = service.get_class_by_id(class_uuid, workspace_id)
# 构建响应
response = ClassResponse(
class_id=ontology_class.class_id,
class_name=ontology_class.class_name,
class_description=ontology_class.class_description,
scene_id=ontology_class.scene_id,
created_at=ontology_class.created_at,
updated_at=ontology_class.updated_at
)
api_logger.info(f"Class retrieved successfully: {class_id}")
return success(data=response.model_dump(mode='json'), msg="查询成功")
except ValueError as e:
# 类型不存在或无权限访问
api_logger.warning(f"Validation error in get class: {str(e)}")
return fail(BizCode.NOT_FOUND, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in get class: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in get class: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
async def classes_handler(
scene_id: str,
class_name: Optional[str] = None,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取类型列表(支持模糊搜索和全量查询)
当提供 class_name 参数时,进行模糊搜索;
当不提供 class_name 参数时,返回场景下的所有类型。
Args:
scene_id: 场景ID必填
class_name: 类型名称关键词(可选,支持模糊匹配)
db: 数据库会话
current_user: 当前用户
"""
operation = "search" if class_name else "list"
api_logger.info(
f"Class {operation} requested by user {current_user.id}, "
f"keyword={class_name}, scene_id={scene_id}"
)
try:
# 验证UUID格式
try:
scene_uuid = UUID(scene_id)
except ValueError:
api_logger.warning(f"Invalid scene_id format: {scene_id}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "无效的场景ID格式")
# 获取当前工作空间ID
workspace_id = current_user.current_workspace_id
if not workspace_id:
api_logger.warning(f"User {current_user.id} has no current workspace")
return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间")
# 创建Service
service = _get_dummy_ontology_service(db)
# 获取场景信息
scene = service.get_scene_by_id(scene_uuid, workspace_id)
if not scene:
api_logger.warning(f"Scene not found: {scene_id}")
return fail(BizCode.NOT_FOUND, "场景不存在", f"未找到ID为 {scene_id} 的场景")
# 根据是否提供 class_name 决定查询方式
if class_name and class_name.strip():
# 模糊搜索类型
classes = service.search_classes_by_name(class_name.strip(), scene_uuid, workspace_id)
else:
# 获取所有类型
classes = service.list_classes_by_scene(scene_uuid, workspace_id)
# 构建响应
items = []
for ontology_class in classes:
items.append(ClassResponse(
class_id=ontology_class.class_id,
class_name=ontology_class.class_name,
class_description=ontology_class.class_description,
scene_id=ontology_class.scene_id,
created_at=ontology_class.created_at,
updated_at=ontology_class.updated_at
))
response = ClassListResponse(
total=len(items),
scene_id=scene_uuid,
scene_name=scene.scene_name,
scene_description=scene.scene_description,
is_system_default=scene.is_system_default,
items=items
)
if class_name:
api_logger.info(
f"Class search completed: found {len(items)} classes matching '{class_name}' "
f"in scene {scene_id}"
)
else:
api_logger.info(f"Class list retrieved successfully, count={len(items)}")
return success(data=response.model_dump(mode='json'), msg="查询成功")
except ValueError as e:
api_logger.warning(f"Validation error in class {operation}: {str(e)}")
return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e))
except RuntimeError as e:
api_logger.error(f"Runtime error in class {operation}: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))
except Exception as e:
api_logger.error(f"Unexpected error in class {operation}: {str(e)}", exc_info=True)
return fail(BizCode.INTERNAL_ERROR, "查询失败", str(e))

View File

@@ -1,5 +1,5 @@
import json
import uuid
import json
from fastapi import APIRouter, Depends, Path
from sqlalchemy.orm import Session
@@ -8,13 +8,9 @@ from starlette.responses import StreamingResponse
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.dependencies import get_current_user, get_db
from app.schemas.prompt_optimizer_schema import (
PromptOptMessage,
CreateSessionResponse,
SessionHistoryResponse,
SessionMessage,
PromptSaveRequest
)
from app.models.prompt_optimizer_model import RoleType
from app.schemas.prompt_optimizer_schema import PromptOptMessage, PromptOptModelSet, CreateSessionResponse, \
OptimizePromptResponse, SessionHistoryResponse, SessionMessage
from app.schemas.response_schema import ApiResponse
from app.services.prompt_optimizer_service import PromptOptimizerService
@@ -120,8 +116,7 @@ async def get_prompt_opt(
session_id=session_id,
user_id=current_user.id,
current_prompt=data.current_prompt,
user_require=data.message,
skill=data.skill
user_require=data.message
):
# chunk 是 prompt 的增量内容
yield f"event:message\ndata: {json.dumps(chunk)}\n\n"
@@ -140,109 +135,3 @@ async def get_prompt_opt(
"X-Accel-Buffering": "no"
}
)
@router.post(
"/releases",
summary="Get prompt optimization",
response_model=ApiResponse
)
def save_prompt(
data: PromptSaveRequest,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
Save a prompt release for the current tenant.
Args:
data (PromptSaveRequest): Request body containing session_id, title, and prompt.
db (Session): SQLAlchemy database session, injected via dependency.
current_user: Currently authenticated user object, injected via dependency.
Returns:
ApiResponse: Standard API response containing the saved prompt release info:
- id: UUID of the prompt release
- session_id: associated session
- title: prompt title
- prompt: prompt content
- created_at: timestamp of creation
Raises:
Any database or service exceptions are propagated to the global exception handler.
"""
service = PromptOptimizerService(db)
prompt_info = service.save_prompt(
tenant_id=current_user.tenant_id,
session_id=data.session_id,
title=data.title,
prompt=data.prompt
)
return success(data=prompt_info)
@router.delete(
"/releases/{prompt_id}",
summary="Delete prompt (soft delete)",
response_model=ApiResponse
)
def delete_prompt(
prompt_id: uuid.UUID = Path(..., description="Prompt ID"),
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
Soft delete a prompt release.
Args:
prompt_id
db (Session): Database session
current_user: Current logged-in user
Returns:
ApiResponse: Success message confirming deletion
"""
service = PromptOptimizerService(db)
service.delete_prompt(
tenant_id=current_user.tenant_id,
prompt_id=prompt_id
)
return success(msg="Prompt deleted successfully")
@router.get(
"/releases/list",
summary="Get paginated list of released prompts with optional filter",
response_model=ApiResponse
)
def get_release_list(
page: int = 1,
page_size: int = 20,
keyword: str | None = None,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""
Retrieve paginated list of released prompts for the current tenant.
Optionally filter by keyword in title.
Args:
page (int): Page number (starting from 1)
page_size (int): Number of items per page (max 100)
keyword (str | None): Optional keyword to filter prompt titles
db (Session): Database session
current_user: Current logged-in user
Returns:
ApiResponse: Contains paginated list of prompt releases with metadata
"""
service = PromptOptimizerService(db)
result = service.get_release_list(
tenant_id=current_user.tenant_id,
page=max(1, page),
page_size=min(max(1, page_size), 100),
filter_keyword=keyword
)
return success(data=result)

View File

@@ -2,33 +2,24 @@ import hashlib
import json
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, Query, Request
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.response_utils import success, fail
from app.db import get_db, get_db_read
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_share_user_id, ShareTokenData
from app.models.app_model import AppType
from app.repositories import knowledge_repository
from app.repositories.end_user_repository import EndUserRepository
from app.repositories.workflow_repository import WorkflowConfigRepository
from app.schemas import release_share_schema, conversation_schema
from app.schemas.response_schema import PageData, PageMeta
from app.services import workspace_service
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.services.app_service import AppService
from app.services.auth_service import create_access_token
from app.services.conversation_service import ConversationService
from app.services.release_share_service import ReleaseShareService
from app.services.shared_chat_service import SharedChatService
from app.services.workflow_service import WorkflowService
from app.utils.app_config_utils import workflow_config_4_app_release, \
agent_config_4_app_release, multi_agent_config_4_app_release
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
router = APIRouter(prefix="/public/share", tags=["Public Share"])
logger = get_business_logger()
@@ -74,10 +65,10 @@ def get_or_generate_user_id(payload_user_id: str, request: Request) -> str:
summary="获取访问 token"
)
def get_access_token(
share_token: str,
payload: release_share_schema.TokenRequest,
request: Request,
db: Session = Depends(get_db),
share_token: str,
payload: release_share_schema.TokenRequest,
request: Request,
db: Session = Depends(get_db),
):
"""获取访问 token
@@ -122,9 +113,9 @@ def get_access_token(
response_model=None
)
def get_shared_release(
password: str = Query(None, description="访问密码(如果需要)"),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
password: str = Query(None, description="访问密码(如果需要)"),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""获取公开分享的发布版本信息
@@ -146,9 +137,9 @@ def get_shared_release(
summary="验证访问密码"
)
def verify_password(
payload: release_share_schema.PasswordVerifyRequest,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
payload: release_share_schema.PasswordVerifyRequest,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""验证分享的访问密码
@@ -168,11 +159,11 @@ def verify_password(
summary="获取嵌入代码"
)
def get_embed_code(
width: str = Query("100%", description="iframe 宽度"),
height: str = Query("600px", description="iframe 高度"),
request: Request = None,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
width: str = Query("100%", description="iframe 宽度"),
height: str = Query("600px", description="iframe 高度"),
request: Request = None,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""获取嵌入代码
@@ -192,6 +183,7 @@ def get_embed_code(
return success(data=embed_code)
# ---------- 会话管理接口 ----------
@router.get(
@@ -199,11 +191,11 @@ def get_embed_code(
summary="获取会话列表"
)
def list_conversations(
password: str = Query(None, description="访问密码"),
page: int = Query(1, ge=1),
pagesize: int = Query(20, ge=1, le=100),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
password: str = Query(None, description="访问密码"),
page: int = Query(1, ge=1),
pagesize: int = Query(20, ge=1, le=100),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""获取分享应用的会话列表
@@ -213,16 +205,15 @@ def list_conversations(
logger.debug(f"share_data:{share_data.user_id}")
other_id = share_data.user_id
service = SharedChatService(db)
share, release = service.get_release_by_share_token(share_data.share_token, password)
share, release = service._get_release_by_share_token(share_data.share_token, password)
from app.repositories.end_user_repository import EndUserRepository
end_user_repo = EndUserRepository(db)
app_service = AppService(db)
app = app_service._get_app_or_404(share.app_id)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id,
workspace_id=app.workspace_id,
other_id=other_id
)
app_id=share.app_id,
other_id=other_id
)
logger.debug(new_end_user.id)
service = SharedChatService(db)
conversations, total = service.list_conversations(
share_token=share_data.share_token,
user_id=str(new_end_user.id),
@@ -242,10 +233,10 @@ def list_conversations(
summary="获取会话详情(含消息)"
)
def get_conversation(
conversation_id: uuid.UUID,
password: str = Query(None, description="访问密码"),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
conversation_id: uuid.UUID,
password: str = Query(None, description="访问密码"),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
"""获取会话详情和消息历史"""
chat_service = SharedChatService(db)
@@ -275,10 +266,10 @@ def get_conversation(
summary="发送消息(支持流式和非流式)"
)
async def chat(
payload: conversation_schema.ChatRequest,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
payload: conversation_schema.ChatRequest,
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
):
"""发送消息并获取回复
@@ -301,39 +292,37 @@ async def chat(
# 提前验证和准备(在流式响应开始前完成)
# 这样可以确保错误能正确返回,而不是在流式响应中间出错
from app.models.app_model import AppType
try:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
from app.services.app_service import AppService
# 验证分享链接和密码
share, release = service.get_release_by_share_token(share_token, password)
share, release = service._get_release_by_share_token(share_token, password)
# # Create end_user_id by concatenating app_id with user_id
# end_user_id = f"{share.app_id}_{user_id}"
# Store end_user_id in database with original user_id
from app.repositories.end_user_repository import EndUserRepository
end_user_repo = EndUserRepository(db)
app_service = AppService(db)
app = app_service._get_app_or_404(share.app_id)
workspace_id = app.workspace_id
new_end_user = end_user_repo.get_or_create_end_user(
app_id=share.app_id,
workspace_id=workspace_id,
other_id=other_id,
original_user_id=user_id
original_user_id=user_id # Save original user_id to other_id
)
end_user_id = str(new_end_user.id)
# appid = share.app_id
appid=share.app_id
"""获取存储类型和工作空间的ID"""
# 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
# app = db.query(App).filter(
# App.id == appid,
# App.is_active.is_(True)
# ).first()
# if not app:
# raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
# 直接通过 SQLAlchemy 查询 app
from app.models.app_model import App
app = db.query(App).filter(App.id == appid).first()
if not app:
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)
# workspace_id = app.workspace_id
workspace_id = app.workspace_id
# 直接从 workspace 获取 storage_type公开分享场景无需权限检查
storage_type = workspace_service.get_workspace_storage_type_without_auth(
@@ -366,12 +355,12 @@ async def chat(
app_type = release.app.type if release.app else None
# 根据应用类型验证配置
if app_type == AppType.AGENT:
if app_type == "agent":
# Agent 类型:验证模型配置
model_config_id = release.default_model_config_id
if not model_config_id:
raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING)
elif app_type == AppType.MULTI_AGENT:
elif app_type == "multi_agent":
# Multi-Agent 类型:验证多 Agent 配置
config = release.config or {}
if not config.get("sub_agents"):
@@ -436,17 +425,16 @@ async def chat(
# )
async def event_generator():
async for event in app_chat_service.agnet_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串
variables=payload.variables,
web_search=payload.web_search,
config=agent_config,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id,
files=payload.files # 传递多模态文件
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id= str(new_end_user.id), # 转换为字符串
variables=payload.variables,
web_search=payload.web_search,
config=agent_config,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
):
yield event
@@ -483,8 +471,7 @@ async def chat(
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id,
files=payload.files # 传递多模态文件
workspace_id=workspace_id
)
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.MULTI_AGENT:
@@ -494,15 +481,15 @@ async def chat(
async def event_generator():
async for event in app_chat_service.multi_agent_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串
variables=payload.variables,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=str(new_end_user.id), # 转换为字符串
variables=payload.variables,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
@@ -574,29 +561,24 @@ async def chat(
# return success(data=conversation_schema.ChatResponse(**result))
elif app_type == AppType.WORKFLOW:
config = workflow_config_4_app_release(release)
if not config.id:
with get_db_read() as db:
source_config = WorkflowConfigRepository(db).get_by_app_id(release.app_id)
config.id = source_config.id
config.id = uuid.UUID(config.id)
if payload.stream:
async def event_generator():
async for event in app_chat_service.workflow_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
files=payload.files,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=release.app_id,
workspace_id=workspace_id,
release_id=release.id,
public=True
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=release.app_id,
workspace_id=workspace_id
):
event_type = event.get("event", "message")
event_data = event.get("data", {})
@@ -617,19 +599,18 @@ async def chat(
# 多 Agent 非流式返回
result = await app_chat_service.workflow_chat(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
files=payload.files,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=release.app_id,
workspace_id=workspace_id,
release_id=release.id
workspace_id=workspace_id
)
logger.debug(
"工作流试运行返回结果",
@@ -645,38 +626,6 @@ async def chat(
# return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
else:
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)
@router.get("/config", summary="获取应用启动配置")
async def config_query(
password: str = Query(None, description="访问密码"),
share_data: ShareTokenData = Depends(get_share_user_id),
db: Session = Depends(get_db),
):
share_service = SharedChatService(db)
share_token = share_data.share_token
share, release = share_service.get_release_by_share_token(share_token, password)
if release.app.type == AppType.WORKFLOW:
workflow_service = WorkflowService(db)
content = {
"app_type": release.app.type,
"variables": workflow_service.get_start_node_variables(release.config),
"memory": workflow_service.is_memory_enable(release.config),
"features": release.config.get("features")
}
elif release.app.type == AppType.AGENT:
content = {
"app_type": release.app.type,
"variables": release.config.get("variables"),
"features": release.config.get("features")
}
elif release.app.type == AppType.MULTI_AGENT:
content = {
"app_type": release.app.type,
"variables": [],
"features": release.config.get("features")
}
else:
return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED)
return success(data=content)

View File

@@ -12,6 +12,7 @@ from app.core.exceptions import BusinessException
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_app_or_workspace
from app.models.app_model import App
from app.models.app_model import AppType
from app.repositories import knowledge_repository
@@ -20,10 +21,9 @@ from app.schemas import AppChatRequest, conversation_schema
from app.schemas.api_key_schema import ApiKeyAuth
from app.services import workspace_service
from app.services.app_chat_service import AppChatService, get_app_chat_service
from app.services.app_service import get_app_service, AppService
from app.services.conversation_service import ConversationService, get_conversation_service
from app.utils.app_config_utils import workflow_config_4_app_release, \
agent_config_4_app_release, multi_agent_config_4_app_release
from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, agent_config_4_app_release, multi_agent_config_4_app_release
from app.services.app_service import get_app_service, AppService
router = APIRouter(prefix="/app", tags=["V1 - App API"])
logger = get_business_logger()
@@ -34,7 +34,6 @@ async def list_apps():
"""列出可访问的应用(占位)"""
return success(data=[], msg="App API - Coming Soon")
# /v1/app/chat
# @router.post("/chat")
@@ -74,33 +73,33 @@ def _checkAppConfig(app: App):
else:
raise BusinessException("不支持的应用类型", BizCode.AGENT_CONFIG_MISSING)
@router.post("/chat")
@require_api_key(scopes=["app"])
async def chat(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
conversation_service: Annotated[ConversationService, Depends(get_conversation_service)] = None,
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
app_service: Annotated[AppService, Depends(get_app_service)] = None,
message: str = Body(..., description="聊天消息内容"),
request:Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
conversation_service: Annotated[ConversationService, Depends(get_conversation_service)] = None,
app_chat_service: Annotated[AppChatService, Depends(get_app_chat_service)] = None,
app_service: Annotated[AppService, Depends(get_app_service)] = None,
message: str = Body(..., description="聊天消息内容"),
):
body = await request.json()
payload = AppChatRequest(**body)
other_id = payload.user_id
app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id)
other_id = payload.user_id
workspace_id = app.workspace_id
end_user_repo = EndUserRepository(db)
new_end_user = end_user_repo.get_or_create_end_user(
app_id=app.id,
workspace_id=workspace_id,
other_id=other_id,
original_user_id=other_id # Save original user_id to other_id
)
end_user_id = str(new_end_user.id)
web_search = True
memory = True
web_search=True
memory=True
# 提前验证和准备(在流式响应开始前完成)
storage_type = workspace_service.get_workspace_storage_type_without_auth(
db=db,
@@ -134,8 +133,7 @@ async def chat(
app_id=app.id,
workspace_id=workspace_id,
user_id=end_user_id,
is_draft=False,
conversation_id=payload.conversation_id
is_draft=False
)
if app_type == AppType.AGENT:
@@ -148,17 +146,16 @@ async def chat(
if payload.stream:
async def event_generator():
async for event in app_chat_service.agnet_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
web_search=web_search,
config=agent_config,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id,
files=payload.files # 传递多模态文件
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id= end_user_id, # 转换为字符串
variables=payload.variables,
web_search=web_search,
config=agent_config,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id
):
yield event
@@ -178,13 +175,12 @@ async def chat(
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
config=agent_config,
config= agent_config,
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
workspace_id=workspace_id,
files=payload.files # 传递多模态文件
workspace_id=workspace_id
)
return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json"))
elif app_type == AppType.MULTI_AGENT:
@@ -194,15 +190,15 @@ async def chat(
async def event_generator():
async for event in app_chat_service.multi_agent_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id
):
yield event
@@ -236,20 +232,18 @@ async def chat(
if payload.stream:
async def event_generator():
async for event in app_chat_service.workflow_chat_stream(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
variables=payload.variables,
files=payload.files,
config=config,
web_search=web_search,
memory=memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=app.id,
workspace_id=workspace_id,
release_id=app.current_release.id,
public=True
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=new_end_user.id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
app_id=app.app_id,
workspace_id=workspace_id
):
event_type = event.get("event", "message")
event_data = event.get("data", {})
@@ -273,17 +267,15 @@ async def chat(
message=payload.message,
conversation_id=conversation.id, # 使用已创建的会话 ID
user_id=end_user_id, # 转换为字符串
user_id=new_end_user.id, # 转换为字符串
variables=payload.variables,
config=config,
web_search=web_search,
memory=memory,
web_search=payload.web_search,
memory=payload.memory,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
files=payload.files,
app_id=app.id,
workspace_id=workspace_id,
release_id=app.current_release.id
app_id=app.app_id,
workspace_id=workspace_id
)
logger.debug(
"工作流试运行返回结果",
@@ -300,3 +292,4 @@ async def chat(
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED)

View File

@@ -6,7 +6,6 @@ from app.core.response_utils import success
from app.db import get_db
from app.schemas.api_key_schema import ApiKeyAuth
from app.schemas.memory_api_schema import (
ListConfigsResponse,
MemoryReadRequest,
MemoryReadResponse,
MemoryWriteRequest,
@@ -32,16 +31,15 @@ async def write_memory_api_service(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
message: str = Body(..., description="Message content"),
payload: MemoryWriteRequest = Body(..., embed=False),
):
"""
Write memory to storage.
Stores memory content for the specified end user using the Memory API Service.
"""
body = await request.json()
payload = MemoryWriteRequest(**body)
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}")
memory_api_service = MemoryAPIService(db)
@@ -64,15 +62,13 @@ async def read_memory_api_service(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
message: str = Body(..., description="Query message"),
payload: MemoryReadRequest = Body(..., embed=False),
):
"""
Read memory from storage.
Queries and retrieves memories for the specified end user with context-aware responses.
"""
body = await request.json()
payload = MemoryReadRequest(**body)
logger.info(f"Memory read request - end_user_id: {payload.end_user_id}")
memory_api_service = MemoryAPIService(db)
@@ -89,27 +85,3 @@ async def read_memory_api_service(
logger.info(f"Memory read successful for end_user: {payload.end_user_id}")
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully")
@router.get("/configs")
@require_api_key(scopes=["memory"])
async def list_memory_configs(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
List all memory configs for the workspace.
Returns all available memory configurations associated with the authorized workspace.
"""
logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}")
memory_api_service = MemoryAPIService(db)
result = memory_api_service.list_memory_configs(
workspace_id=api_key_auth.workspace_id,
)
logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")

View File

@@ -246,73 +246,3 @@ async def rebuild_knowledge_graph(
db=db,
current_user=current_user)
@router.get("/check/yuque/auth", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def check_yuque_auth(
yuque_user_id: str,
yuque_token: str,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
check yuque auth info
"""
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
api_logger.info(f"check yuque auth info, username: {current_user.username}")
return await knowledge_controller.check_yuque_auth(yuque_user_id=yuque_user_id,
yuque_token=yuque_token,
db=db,
current_user=current_user)
@router.get("/check/feishu/auth", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def check_feishu_auth(
feishu_app_id: str,
feishu_app_secret: str,
feishu_folder_token: str,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
check feishu auth info
"""
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
api_logger.info(f"check feishu auth info, username: {current_user.username}")
return await knowledge_controller.check_feishu_auth(feishu_app_id=feishu_app_id,
feishu_app_secret=feishu_app_secret,
feishu_folder_token=feishu_folder_token,
db=db,
current_user=current_user)
@router.post("/{knowledge_id}/sync", response_model=ApiResponse)
@require_api_key(scopes=["rag"])
async def sync_knowledge(
knowledge_id: uuid.UUID,
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
sync knowledge base information based on knowledge_id
"""
api_key = api_key_service.ApiKeyService.get_api_key(db, api_key_auth.api_key_id, api_key_auth.workspace_id)
current_user = api_key.creator
current_user.current_workspace_id = api_key_auth.workspace_id
return await knowledge_controller.sync_knowledge(knowledge_id=knowledge_id,
db=db,
current_user=current_user)

View File

@@ -1,85 +0,0 @@
"""Skill Controller - 技能市场管理"""
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from typing import Optional
import uuid
from app.db import get_db
from app.dependencies import get_current_user
from app.models import User
from app.schemas import skill_schema
from app.schemas.response_schema import PageData, PageMeta
from app.services.skill_service import SkillService
from app.core.response_utils import success
router = APIRouter(prefix="/skills", tags=["Skills"])
@router.post("", summary="创建技能")
def create_skill(
data: skill_schema.SkillCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""创建技能 - 可以关联现有工具内置、MCP、自定义"""
tenant_id = current_user.tenant_id
skill = SkillService.create_skill(db, data, tenant_id)
return success(data=skill_schema.Skill.model_validate(skill), msg="技能创建成功")
@router.get("", summary="技能列表")
def list_skills(
search: Optional[str] = Query(None, description="搜索关键词"),
is_active: Optional[bool] = Query(None, description="是否激活"),
is_public: Optional[bool] = Query(None, description="是否公开"),
page: int = Query(1, ge=1, description="页码"),
pagesize: int = Query(10, ge=1, le=100, description="每页数量"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""技能市场列表 - 包含本工作空间和公开的技能"""
tenant_id = current_user.tenant_id
skills, total = SkillService.list_skills(
db, tenant_id, search, is_active, is_public, page, pagesize
)
items = [skill_schema.Skill.model_validate(s) for s in skills]
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
return success(data=PageData(page=meta, items=items), msg="技能市场列表获取成功")
@router.get("/{skill_id}", summary="获取技能详情")
def get_skill(
skill_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""获取技能详情"""
tenant_id = current_user.tenant_id
skill = SkillService.get_skill(db, skill_id, tenant_id)
return success(data=skill_schema.Skill.model_validate(skill), msg="获取技能详情成功")
@router.put("/{skill_id}", summary="更新技能")
def update_skill(
skill_id: uuid.UUID,
data: skill_schema.SkillUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""更新技能"""
tenant_id = current_user.tenant_id
skill = SkillService.update_skill(db, skill_id, data, tenant_id)
return success(data=skill_schema.Skill.model_validate(skill), msg="技能更新成功")
@router.delete("/{skill_id}", summary="删除技能")
def delete_skill(
skill_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user)
):
"""删除技能"""
tenant_id = current_user.tenant_id
SkillService.delete_skill(db, skill_id, tenant_id)
return success(msg="技能删除成功")

View File

@@ -3,11 +3,8 @@ from typing import Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.schemas.tool_schema import (
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest,
CustomToolTestRequest, ToolActiveUpdate
ToolCreateRequest, ToolUpdateRequest, ToolExecuteRequest, ParseSchemaRequest, CustomToolTestRequest
)
from app.core.response_utils import success
@@ -17,7 +14,6 @@ from app.models import User
from app.models.tool_model import ToolType, ToolStatus, AuthType
from app.services.tool_service import ToolService
from app.schemas.response_schema import ApiResponse
from app.core.exceptions import BusinessException
router = APIRouter(prefix="/tools", tags=["Tool System"])
@@ -101,13 +97,7 @@ async def create_tool(
):
"""创建工具"""
try:
# 将 MCP 来源字段合并进 config
if request.tool_type == ToolType.MCP:
for key in ("source_channel", "market_id", "market_config_id", "mcp_service_id"):
val = getattr(request, key, None)
if val is not None:
request.config[key] = val
tool_id = await service.create_tool(
tool_id = service.create_tool(
name=request.name,
tool_type=request.tool_type,
tenant_id=current_user.tenant_id,
@@ -117,8 +107,6 @@ async def create_tool(
tags=request.tags
)
return success(data={"tool_id": tool_id}, msg="工具创建成功")
except BusinessException as e:
raise HTTPException(status_code=400, detail=e.message)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
@@ -159,7 +147,7 @@ async def delete_tool(
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""删除工具逻辑删除is_active=False"""
"""删除工具"""
try:
success_flag = service.delete_tool(tool_id, current_user.tenant_id)
if not success_flag:
@@ -171,30 +159,6 @@ async def delete_tool(
raise HTTPException(status_code=500, detail=str(e))
@router.patch("/{tool_id}/active", response_model=ApiResponse)
async def set_tool_active(
tool_id: str,
request: ToolActiveUpdate,
current_user: User = Depends(get_current_user),
service: ToolService = Depends(get_tool_service)
):
"""设置工具可用状态(启用/禁用)
- is_active=true: 启用工具
- is_active=false: 禁用工具(等同于删除,但可恢复)
"""
try:
success_flag = service.set_tool_active(tool_id, current_user.tenant_id, request.is_active)
if not success_flag:
raise HTTPException(status_code=404, detail="工具不存在")
action = "启用" if request.is_active else "禁用"
return success(msg=f"工具已{action}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/execution/execute", response_model=ApiResponse)
async def execute_tool(
request: ToolExecuteRequest,
@@ -252,10 +216,8 @@ async def sync_mcp_tools(
try:
result = await service.sync_mcp_tools(tool_id, current_user.tenant_id)
if not result.get("success", False):
raise BusinessException(result.get("message", "工具列表同步失败"), BizCode.BAD_REQUEST)
raise HTTPException(status_code=400, detail=result.get("message", "同步失败"))
return success(data=result, msg="MCP工具列表同步完成")
except BusinessException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -278,10 +240,8 @@ async def test_tool_connection(
# 普通连接测试
result = await service.test_connection(tool_id, current_user.tenant_id)
if result["success"] is False:
raise BusinessException(result["message"], BizCode.SERVICE_UNAVAILABLE)
raise HTTPException(status_code=400, detail=result["message"])
return success(data=result, msg="连接测试完成")
except BusinessException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -1,26 +1,16 @@
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
import uuid
from typing import Callable
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.db import get_db
from app.dependencies import get_current_user, get_current_superuser
from app.models.user_model import User
from app.schemas import user_schema
from app.schemas.user_schema import (
ChangePasswordRequest,
AdminChangePasswordRequest,
SendEmailCodeRequest,
VerifyEmailCodeRequest,
VerifyPasswordRequest)
from app.schemas.user_schema import ChangePasswordRequest, AdminChangePasswordRequest
from app.schemas.response_schema import ApiResponse
from app.services import user_service
from app.core.logging_config import get_api_logger
from app.core.response_utils import success
from app.core.security import verify_password
from app.i18n.dependencies import get_translator
# 获取API专用日志器
api_logger = get_api_logger()
@@ -35,8 +25,7 @@ router = APIRouter(
def create_superuser(
user: user_schema.UserCreate,
db: Session = Depends(get_db),
current_superuser: User = Depends(get_current_superuser),
t: Callable = Depends(get_translator)
current_superuser: User = Depends(get_current_superuser)
):
"""创建超级管理员(仅超级管理员可访问)"""
api_logger.info(f"超级管理员创建请求: {user.username}, email: {user.email}")
@@ -45,7 +34,7 @@ def create_superuser(
api_logger.info(f"超级管理员创建成功: {result.username} (ID: {result.id})")
result_schema = user_schema.User.model_validate(result)
return success(data=result_schema, msg=t("users.create.superuser_success"))
return success(data=result_schema, msg="超级管理员创建成功")
@router.delete("/{user_id}", response_model=ApiResponse)
@@ -53,7 +42,6 @@ def delete_user(
user_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
):
"""停用用户(软删除)"""
api_logger.info(f"用户停用请求: user_id={user_id}, 操作者: {current_user.username}")
@@ -61,14 +49,13 @@ def delete_user(
db=db, user_id_to_deactivate=user_id, current_user=current_user
)
api_logger.info(f"用户停用成功: {result.username} (ID: {result.id})")
return success(msg=t("users.delete.deactivate_success"))
return success(msg="用户停用成功")
@router.post("/{user_id}/activate", response_model=ApiResponse)
def activate_user(
user_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
):
"""激活用户"""
api_logger.info(f"用户激活请求: user_id={user_id}, 操作者: {current_user.username}")
@@ -79,14 +66,13 @@ def activate_user(
api_logger.info(f"用户激活成功: {result.username} (ID: {result.id})")
result_schema = user_schema.User.model_validate(result)
return success(data=result_schema, msg=t("users.activate.success"))
return success(data=result_schema, msg="用户激活成功")
@router.get("", response_model=ApiResponse)
def get_current_user_info(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
):
"""获取当前用户信息"""
api_logger.info(f"当前用户信息请求: {current_user.username}")
@@ -106,12 +92,12 @@ def get_current_user_info(
result_schema.current_workspace_name = current_workspace.name
for ws in result.workspaces:
if ws.workspace_id == current_user.current_workspace_id and ws.is_active:
if ws.workspace_id == current_user.current_workspace_id:
result_schema.role = ws.role
break
api_logger.info(f"当前用户信息获取成功: {result.username}, 角色: {result_schema.role}, 工作空间: {result_schema.current_workspace_name}")
return success(data=result_schema, msg=t("users.info.get_success"))
return success(data=result_schema, msg="用户信息获取成功")
@router.get("/superusers", response_model=ApiResponse)
@@ -119,7 +105,6 @@ def get_tenant_superusers(
include_inactive: bool = False,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_superuser),
t: Callable = Depends(get_translator)
):
"""获取当前租户下的超管账号列表(仅超级管理员可访问)"""
api_logger.info(f"获取租户超管列表请求: {current_user.username}")
@@ -132,8 +117,7 @@ def get_tenant_superusers(
api_logger.info(f"租户超管列表获取成功: count={len(superusers)}")
superusers_schema = [user_schema.User.model_validate(u) for u in superusers]
return success(data=superusers_schema, msg=t("users.list.superusers_success"))
return success(data=superusers_schema, msg="租户超管列表获取成功")
@router.get("/{user_id}", response_model=ApiResponse)
@@ -141,7 +125,6 @@ def get_user_info_by_id(
user_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
):
"""根据用户ID获取用户信息"""
api_logger.info(f"获取用户信息请求: user_id={user_id}, 操作者: {current_user.username}")
@@ -152,7 +135,7 @@ def get_user_info_by_id(
api_logger.info(f"用户信息获取成功: {result.username}")
result_schema = user_schema.User.model_validate(result)
return success(data=result_schema, msg=t("users.info.get_success"))
return success(data=result_schema, msg="用户信息获取成功")
@router.put("/change-password", response_model=ApiResponse)
@@ -160,7 +143,6 @@ async def change_password(
request: ChangePasswordRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
):
"""修改当前用户密码"""
api_logger.info(f"用户密码修改请求: {current_user.username}")
@@ -173,7 +155,7 @@ async def change_password(
current_user=current_user
)
api_logger.info(f"用户密码修改成功: {current_user.username}")
return success(msg=t("auth.password.change_success"))
return success(msg="密码修改成功")
@router.put("/admin/change-password", response_model=ApiResponse)
@@ -181,7 +163,6 @@ async def admin_change_password(
request: AdminChangePasswordRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_superuser),
t: Callable = Depends(get_translator)
):
"""超级管理员修改指定用户的密码"""
api_logger.info(f"管理员密码修改请求: 管理员 {current_user.username} 修改用户 {request.user_id}")
@@ -196,107 +177,7 @@ async def admin_change_password(
# 根据是否生成了随机密码来构造响应
if request.new_password:
api_logger.info(f"管理员密码修改成功: 用户 {request.user_id}")
return success(msg=t("auth.password.change_success"))
return success(msg="密码修改成功")
else:
api_logger.info(f"管理员密码重置成功: 用户 {request.user_id}, 随机密码已生成")
return success(data=generated_password, msg=t("auth.password.reset_success"))
@router.post("/verify_pwd", response_model=ApiResponse)
def verify_pwd(
request: VerifyPasswordRequest,
current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
):
"""验证当前用户密码"""
api_logger.info(f"用户验证密码请求: {current_user.username}")
is_valid = verify_password(request.password, current_user.hashed_password)
api_logger.info(f"用户密码验证结果: {current_user.username}, valid={is_valid}")
if not is_valid:
raise BusinessException(t("users.errors.password_verification_failed"), code=BizCode.VALIDATION_FAILED)
return success(data={"valid": is_valid}, msg=t("common.success.retrieved"))
@router.post("/send-email-code", response_model=ApiResponse)
async def send_email_code(
request: SendEmailCodeRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
):
"""发送邮箱验证码"""
api_logger.info(f"用户请求发送邮箱验证码: {current_user.username}, email={request.email}")
await user_service.send_email_code_method(db=db, email=request.email, user_id=current_user.id)
api_logger.info(f"邮箱验证码已发送: {current_user.username}")
return success(msg=t("users.email.code_sent"))
@router.put("/change-email", response_model=ApiResponse)
async def change_email(
request: VerifyEmailCodeRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
):
"""验证验证码并修改邮箱"""
api_logger.info(f"用户修改邮箱: {current_user.username}, new_email={request.new_email}")
await user_service.verify_and_change_email(
db=db,
user_id=current_user.id,
new_email=request.new_email,
code=request.code
)
api_logger.info(f"用户邮箱修改成功: {current_user.username}")
return success(msg=t("users.email.change_success"))
@router.get("/me/language", response_model=ApiResponse)
def get_current_user_language(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
):
"""获取当前用户的语言偏好"""
api_logger.info(f"获取用户语言偏好: {current_user.username}")
language = user_service.get_user_language_preference(
db=db,
user_id=current_user.id,
current_user=current_user
)
api_logger.info(f"用户语言偏好获取成功: {current_user.username}, language={language}")
return success(
data=user_schema.LanguagePreferenceResponse(language=language),
msg=t("users.language.get_success")
)
@router.put("/me/language", response_model=ApiResponse)
def update_current_user_language(
request: user_schema.LanguagePreferenceRequest,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: Callable = Depends(get_translator)
):
"""设置当前用户的语言偏好"""
api_logger.info(f"更新用户语言偏好: {current_user.username}, language={request.language}")
updated_user = user_service.update_user_language_preference(
db=db,
user_id=current_user.id,
language=request.language,
current_user=current_user
)
api_logger.info(f"用户语言偏好更新成功: {current_user.username}, language={request.language}")
return success(
data=user_schema.LanguagePreferenceResponse(language=updated_user.preferred_language),
msg=t("users.language.update_success")
)
return success(data=generated_password, msg="密码重置成功")

View File

@@ -5,10 +5,9 @@
from typing import Optional
import datetime
from sqlalchemy.orm import Session
from fastapi import APIRouter, Depends,Header
from fastapi import APIRouter, Depends
from app.db import get_db
from app.core.language_utils import get_language_from_header
from app.core.logging_config import get_api_logger
from app.core.response_utils import success, fail
from app.core.error_codes import BizCode
@@ -17,12 +16,11 @@ from app.services.user_memory_service import (
UserMemoryService,
analytics_memory_types,
analytics_graph_data,
analytics_community_graph_data,
)
from app.services.memory_entity_relationship_service import MemoryEntityService,MemoryEmotion,MemoryInteraction
from app.schemas.response_schema import ApiResponse
from app.schemas.memory_storage_schema import GenerateCacheRequest
from app.repositories.workspace_repository import WorkspaceRepository
from app.schemas.end_user_schema import (
EndUserProfileResponse,
EndUserProfileUpdate,
@@ -74,7 +72,6 @@ async def get_memory_insight_report_api(
@router.get("/analytics/user_summary", response_model=ApiResponse)
async def get_user_summary_api(
end_user_id: str,
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
@@ -83,26 +80,11 @@ async def get_user_summary_api(
此接口仅查询数据库中已缓存的用户摘要数据,不执行生成操作。
如需生成新的用户摘要,请使用专门的生成接口。
语言控制:
- 使用 X-Language-Type Header 指定语言
- 如果未传 Header默认使用中文 (zh)
"""
# 使用集中化的语言校验
language = get_language_from_header(language_type)
workspace_id = current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
if workspace_models:
model_id = workspace_models.get("llm", None)
else:
model_id = None
api_logger.info(f"用户摘要查询请求: end_user_id={end_user_id}, user={current_user.username}")
try:
# 调用服务层获取缓存数据
result = await user_memory_service.get_cached_user_summary(db, end_user_id,model_id,language)
result = await user_memory_service.get_cached_user_summary(db, end_user_id)
if result["is_cached"]:
api_logger.info(f"成功返回缓存的用户摘要: end_user_id={end_user_id}")
@@ -118,7 +100,6 @@ async def get_user_summary_api(
@router.post("/analytics/generate_cache", response_model=ApiResponse)
async def generate_cache_api(
request: GenerateCacheRequest,
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
@@ -127,14 +108,7 @@ async def generate_cache_api(
- 如果提供 end_user_id只为该用户生成
- 如果不提供,为当前工作空间的所有用户生成
语言控制:
- 使用 X-Language-Type Header 指定语言 ("zh" 中文, "en" 英文)
- 如果未传 Header默认使用中文 (zh)
"""
# 使用集中化的语言校验
language = get_language_from_header(language_type)
workspace_id = current_user.current_workspace_id
# 检查用户是否已选择工作空间
@@ -142,27 +116,27 @@ async def generate_cache_api(
api_logger.warning(f"用户 {current_user.username} 尝试生成缓存但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
end_user_id = request.end_user_id
group_id = request.end_user_id
api_logger.info(
f"缓存生成请求: user={current_user.username}, workspace={workspace_id}, "
f"end_user_id={end_user_id if end_user_id else '全部用户'}, language={language}"
f"end_user_id={group_id if group_id else '全部用户'}"
)
try:
if end_user_id:
if group_id:
# 为单个用户生成
api_logger.info(f"开始为单个用户生成缓存: end_user_id={end_user_id}")
api_logger.info(f"开始为单个用户生成缓存: end_user_id={group_id}")
# 生成记忆洞察
insight_result = await user_memory_service.generate_and_cache_insight(db, end_user_id, workspace_id, language=language)
insight_result = await user_memory_service.generate_and_cache_insight(db, group_id, workspace_id)
# 生成用户摘要
summary_result = await user_memory_service.generate_and_cache_summary(db, end_user_id, workspace_id, language=language)
summary_result = await user_memory_service.generate_and_cache_summary(db, group_id, workspace_id)
# 构建响应
result = {
"end_user_id": end_user_id,
"end_user_id": group_id,
"insight_success": insight_result["success"],
"summary_success": summary_result["success"],
"errors": []
@@ -182,9 +156,9 @@ async def generate_cache_api(
# 记录结果
if result["insight_success"] and result["summary_success"]:
api_logger.info(f"成功为用户 {end_user_id} 生成缓存")
api_logger.info(f"成功为用户 {group_id} 生成缓存")
else:
api_logger.warning(f"用户 {end_user_id} 的缓存生成部分失败: {result['errors']}")
api_logger.warning(f"用户 {group_id} 的缓存生成部分失败: {result['errors']}")
return success(data=result, msg="生成完成")
@@ -192,7 +166,7 @@ async def generate_cache_api(
# 为整个工作空间生成
api_logger.info(f"开始为工作空间 {workspace_id} 批量生成缓存")
result = await user_memory_service.generate_cache_for_workspace(db, workspace_id, language=language)
result = await user_memory_service.generate_cache_for_workspace(db, workspace_id)
# 记录统计信息
api_logger.info(
@@ -279,6 +253,7 @@ async def get_graph_data_api(
depth=depth,
center_node_id=center_node_id
)
# 检查是否有错误消息
if "message" in result and result["statistics"]["total_nodes"] == 0:
api_logger.warning(f"图数据查询返回空结果: {result.get('message')}")
@@ -296,42 +271,6 @@ async def get_graph_data_api(
return fail(BizCode.INTERNAL_ERROR, "图数据查询失败", str(e))
@router.get("/analytics/community_graph", response_model=ApiResponse)
async def get_community_graph_data_api(
end_user_id: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询社区图谱但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
api_logger.info(
f"社区图谱查询请求: end_user_id={end_user_id}, user={current_user.username}, "
f"workspace={workspace_id}"
)
try:
result = await analytics_community_graph_data(db=db, end_user_id=end_user_id)
if "message" in result and result["statistics"]["total_nodes"] == 0:
api_logger.warning(f"社区图谱查询返回空结果: {result.get('message')}")
return success(data=result, msg=result.get("message", "查询成功"))
api_logger.info(
f"成功获取社区图谱: end_user_id={end_user_id}, "
f"nodes={result['statistics']['total_nodes']}, "
f"edges={result['statistics']['total_edges']}"
)
return success(data=result, msg="查询成功")
except Exception as e:
api_logger.error(f"社区图谱查询失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "社区图谱查询失败", str(e))
@router.get("/read_end_user/profile", response_model=ApiResponse)
async def get_end_user_profile(
end_user_id: str,
@@ -339,13 +278,7 @@ async def get_end_user_profile(
db: Session = Depends(get_db),
) -> dict:
workspace_id = current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
if workspace_models:
model_id = workspace_models.get("llm", None)
else:
model_id = None
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试查询用户信息但未选择工作空间")
@@ -363,6 +296,7 @@ async def get_end_user_profile(
if not end_user:
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
# 构建响应数据
profile_data = EndUserProfileResponse(
id=end_user.id,
@@ -394,11 +328,12 @@ async def update_end_user_profile(
该接口可以更新用户的姓名、职位、部门、联系方式、电话和入职日期等信息。
所有字段都是可选的,只更新提供的字段。
"""
workspace_id = current_user.current_workspace_id
end_user_id = profile_update.end_user_id
# 验证工作空间
# 检查用户是否已选择工作空间
if workspace_id is None:
api_logger.warning(f"用户 {current_user.username} 尝试更新用户信息但未选择工作空间")
return fail(BizCode.INVALID_PARAMETER, "请先切换到一个工作空间", "current_workspace_id is None")
@@ -408,44 +343,65 @@ async def update_end_user_profile(
f"workspace={workspace_id}"
)
# 调用 Service 层处理业务逻辑
result = user_memory_service.update_end_user_profile(db, end_user_id, profile_update)
try:
# 查询终端用户
end_user = db.query(EndUser).filter(EndUser.id == end_user_id).first()
if result["success"]:
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}")
return success(data=result["data"], msg="更新成功")
else:
error_msg = result["error"]
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={error_msg}")
# 根据错误类型映射到合适的业务错误码
if error_msg == "终端用户不存在":
return fail(BizCode.USER_NOT_FOUND, "终端用户不存在", error_msg)
elif error_msg == "无效的用户ID格式":
return fail(BizCode.INVALID_USER_ID, "无效的用户ID格式", error_msg)
else:
# 只有未预期的错误才使用 INTERNAL_ERROR
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", error_msg)
if not end_user:
api_logger.warning(f"终端用户不存在: end_user_id={end_user_id}")
return fail(BizCode.INVALID_PARAMETER, "终端用户不存在", f"end_user_id={end_user_id}")
# 更新字段(只更新提供的字段,排除 end_user_id
# 允许 None 值来重置字段(如 hire_date
update_data = profile_update.model_dump(exclude_unset=True, exclude={'end_user_id'})
# 特殊处理 hire_date如果提供了时间戳转换为 DateTime
if 'hire_date' in update_data:
hire_date_timestamp = update_data['hire_date']
if hire_date_timestamp is not None:
update_data['hire_date'] = timestamp_to_datetime(hire_date_timestamp)
# 如果是 None保持 None允许清空
for field, value in update_data.items():
setattr(end_user, field, value)
# 更新 updated_at 时间戳
end_user.updated_at = datetime.datetime.now()
# 更新 updatetime_profile 为当前时间
end_user.updatetime_profile = datetime.datetime.now()
# 提交更改
db.commit()
db.refresh(end_user)
# 构建响应数据
profile_data = EndUserProfileResponse(
id=end_user.id,
other_name=end_user.other_name,
position=end_user.position,
department=end_user.department,
contact=end_user.contact,
phone=end_user.phone,
hire_date=end_user.hire_date,
updatetime_profile=end_user.updatetime_profile
)
api_logger.info(f"成功更新用户信息: end_user_id={end_user_id}, updated_fields={list(update_data.keys())}")
return success(data=UserMemoryService.convert_profile_to_dict_with_timestamp(profile_data), msg="更新成功")
except Exception as e:
db.rollback()
api_logger.error(f"用户信息更新失败: end_user_id={end_user_id}, error={str(e)}")
return fail(BizCode.INTERNAL_ERROR, "用户信息更新失败", str(e))
@router.get("/memory_space/timeline_memories", response_model=ApiResponse)
async def memory_space_timeline_of_shared_memories(id: str, label: str,language_type: str = Header(default=None, alias="X-Language-Type"),
async def memory_space_timeline_of_shared_memories(id: str, label: str,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db),
):
# 使用集中化的语言校验
language = get_language_from_header(language_type)
workspace_id=current_user.current_workspace_id
workspace_repo = WorkspaceRepository(db)
workspace_models = workspace_repo.get_workspace_models_configs(workspace_id)
if workspace_models:
model_id = workspace_models.get("llm", None)
else:
model_id = None
MemoryEntity = MemoryEntityService(id, label)
timeline_memories_result = await MemoryEntity.get_timeline_memories_server(model_id, language)
timeline_memories_result = await MemoryEntity.get_timeline_memories_server()
return success(data=timeline_memories_result, msg="共同记忆时间线")
@router.get("/memory_space/relationship_evolution", response_model=ApiResponse)
async def memory_space_relationship_evolution(id: str, label: str,

View File

@@ -0,0 +1,610 @@
"""
工作流 API 控制器
"""
import logging
import uuid
from typing import Annotated
from fastapi import APIRouter, Depends, Path, Query
from sqlalchemy.orm import Session
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models.user_model import User
from app.models.app_model import App
from app.services.workflow_service import WorkflowService, get_workflow_service
from app.schemas.workflow_schema import (
WorkflowConfigCreate,
WorkflowConfigUpdate,
WorkflowConfig,
WorkflowValidationResponse,
WorkflowExecution,
WorkflowNodeExecution,
WorkflowExecutionRequest,
WorkflowExecutionResponse
)
from app.core.response_utils import success, fail
from app.core.exceptions import BusinessException
from app.core.error_codes import BizCode
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/apps", tags=["workflow"])
# ==================== 工作流配置管理 ====================
@router.post("/{app_id}/workflow")
@cur_workspace_access_guard()
async def create_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
config: WorkflowConfigCreate,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""创建工作流配置
创建或更新应用的工作流配置。配置会进行基础验证,但允许保存不完整的配置(草稿)。
"""
try:
# 验证应用是否存在且属于当前工作空间
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="应用不存在或无权访问"
)
# 验证应用类型
if app.type != "workflow":
return fail(
code=BizCode.INVALID_PARAMETER,
msg=f"应用类型必须为 workflow当前为 {app.type}"
)
# 创建工作流配置
workflow_config = service.create_workflow_config(
app_id=app_id,
nodes=[node.model_dump() for node in config.nodes],
edges=[edge.model_dump() for edge in config.edges],
variables=[var.model_dump() for var in config.variables],
execution_config=config.execution_config.model_dump(),
triggers=[trigger.model_dump() for trigger in config.triggers],
validate=True # 进行基础验证
)
return success(
data=WorkflowConfig.model_validate(workflow_config),
msg="工作流配置创建成功"
)
except BusinessException as e:
logger.warning(f"创建工作流配置失败: {e.message}")
return fail(code=e.error_code, msg=e.message)
except Exception as e:
logger.error(f"创建工作流配置异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"创建工作流配置失败: {str(e)}"
)
#
# @router.get("/{app_id}/workflow")
# async def get_workflow_config(
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
# db: Annotated[Session, Depends(get_db)],
# current_user: Annotated[User, Depends(get_current_user)]
#
# ):
# """获取工作流配置
#
# 获取应用的工作流配置详情。
# """
# try:
# # 验证应用是否存在且属于当前工作空间
# app = db.query(App).filter(
# App.id == app_id,
# App.workspace_id == current_user.current_workspace_id,
# App.is_active == True
# ).first()
#
# if not app:
# return fail(
# code=BizCode.NOT_FOUND,
# msg="应用不存在或无权访问"
# )
#
# # 获取工作流配置
# service = WorkflowService(db)
# workflow_config = service.get_workflow_config(app_id)
#
# if not workflow_config:
# return fail(
# code=BizCode.NOT_FOUND,
# msg="工作流配置不存在"
# )
#
# return success(
# data=WorkflowConfig.model_validate(workflow_config)
# )
#
# except Exception as e:
# logger.error(f"获取工作流配置异常: {e}", exc_info=True)
# return fail(
# code=BizCode.INTERNAL_ERROR,
# msg=f"获取工作流配置失败: {str(e)}"
# )
# @router.put("/{app_id}/workflow")
# async def update_workflow_config(
# app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
# config: WorkflowConfigUpdate,
# db: Annotated[Session, Depends(get_db)],
# current_user: Annotated[User, Depends(get_current_user)],
# service: Annotated[WorkflowService, Depends(get_workflow_service)]
# ):
# """更新工作流配置
# 更新应用的工作流配置。可以部分更新,未提供的字段保持不变。
# """
# try:
# # 验证应用是否存在且属于当前工作空间
# app = db.query(App).filter(
# App.id == app_id,
# App.workspace_id == current_user.current_workspace_id,
# App.is_active == True
# ).first()
# if not app:
# return fail(
# code=BizCode.NOT_FOUND,
# msg="应用不存在或无权访问"
# )
# # 更新工作流配置
# workflow_config = service.update_workflow_config(
# app_id=app_id,
# nodes=[node.model_dump() for node in config.nodes] if config.nodes else None,
# edges=[edge.model_dump() for edge in config.edges] if config.edges else None,
# variables=[var.model_dump() for var in config.variables] if config.variables else None,
# execution_config=config.execution_config.model_dump() if config.execution_config else None,
# triggers=[trigger.model_dump() for trigger in config.triggers] if config.triggers else None,
# validate=True
# )
# return success(
# data=WorkflowConfig.model_validate(workflow_config),
# msg="工作流配置更新成功"
# )
# except BusinessException as e:
# logger.warning(f"更新工作流配置失败: {e.message}")
# return fail(code=e.error_code, msg=e.message)
# except Exception as e:
# logger.error(f"更新工作流配置异常: {e}", exc_info=True)
# return fail(
# code=BizCode.INTERNAL_ERROR,
# msg=f"更新工作流配置失败: {str(e)}"
# )
@router.delete("/{app_id}/workflow")
async def delete_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""删除工作流配置
删除应用的工作流配置。
"""
try:
# 验证应用是否存在且属于当前工作空间
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="应用不存在或无权访问"
)
# 删除工作流配置
deleted = service.delete_workflow_config(app_id)
if not deleted:
return fail(
code=BizCode.NOT_FOUND,
msg="工作流配置不存在"
)
return success(msg="工作流配置删除成功")
except Exception as e:
logger.error(f"删除工作流配置异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"删除工作流配置失败: {str(e)}"
)
@router.post("/{app_id}/workflow/validate")
async def validate_workflow_config(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)],
for_publish: Annotated[bool, Query(description="是否为发布验证")] = False
):
"""验证工作流配置
验证工作流配置是否有效。可以选择是否进行发布级别的严格验证。
"""
try:
# 验证应用是否存在且属于当前工作空间
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="应用不存在或无权访问"
)
# 验证工作流配置
if for_publish:
is_valid, errors = service.validate_workflow_config_for_publish(app_id)
else:
workflow_config = service.get_workflow_config(app_id)
if not workflow_config:
return fail(
code=BizCode.NOT_FOUND,
msg="工作流配置不存在"
)
from app.core.workflow.validator import validate_workflow_config as validate_config
config_dict = {
"nodes": workflow_config.nodes,
"edges": workflow_config.edges,
"variables": workflow_config.variables,
"execution_config": workflow_config.execution_config,
"triggers": workflow_config.triggers
}
is_valid, errors = validate_config(config_dict, for_publish=False)
return success(
data=WorkflowValidationResponse(
is_valid=is_valid,
errors=errors,
warnings=[]
)
)
except BusinessException as e:
logger.warning(f"验证工作流配置失败: {e.message}")
return fail(code=e.error_code, msg=e.message)
except Exception as e:
logger.error(f"验证工作流配置异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"验证工作流配置失败: {str(e)}"
)
# ==================== 工作流执行管理 ====================
@router.get("/{app_id}/workflow/executions")
async def get_workflow_executions(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)],
limit: Annotated[int, Query(ge=1, le=100)] = 50,
offset: Annotated[int, Query(ge=0)] = 0
):
"""获取工作流执行记录列表
获取应用的工作流执行历史记录。
"""
try:
# 验证应用是否存在且属于当前工作空间
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="应用不存在或无权访问"
)
# 获取执行记录
executions = service.get_executions_by_app(app_id, limit, offset)
# 获取统计信息
statistics = service.get_execution_statistics(app_id)
return success(
data={
"executions": [WorkflowExecution.model_validate(e) for e in executions],
"statistics": statistics,
"pagination": {
"limit": limit,
"offset": offset,
"total": statistics["total"]
}
}
)
except Exception as e:
logger.error(f"获取工作流执行记录异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"获取工作流执行记录失败: {str(e)}"
)
@router.get("/workflow/executions/{execution_id}")
async def get_workflow_execution(
execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""获取工作流执行详情
获取单个工作流执行的详细信息,包括所有节点的执行记录。
"""
try:
# 获取执行记录
execution = service.get_execution(execution_id)
if not execution:
return fail(
code=BizCode.NOT_FOUND,
msg="执行记录不存在"
)
# 验证应用是否属于当前工作空间
app = db.query(App).filter(
App.id == execution.app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="无权访问该执行记录"
)
# 获取节点执行记录
node_executions = service.node_execution_repo.get_by_execution_id(execution.id)
return success(
data={
"execution": WorkflowExecution.model_validate(execution),
"node_executions": [
WorkflowNodeExecution.model_validate(ne) for ne in node_executions
]
}
)
except Exception as e:
logger.error(f"获取工作流执行详情异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"获取工作流执行详情失败: {str(e)}"
)
# ==================== 工作流执行 ====================
@router.post("/{app_id}/workflow/run")
async def run_workflow(
app_id: Annotated[uuid.UUID, Path(description="应用 ID")],
request: WorkflowExecutionRequest,
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""执行工作流
执行工作流并返回结果。支持流式和非流式两种模式。
**非流式模式**:等待工作流执行完成后返回完整结果。
**流式模式**:实时返回执行过程中的事件(节点开始、节点完成、工作流完成等)。
"""
try:
# 验证应用是否存在且属于当前工作空间
app = db.query(App).filter(
App.id == app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="应用不存在或无权访问"
)
# 验证应用类型
if app.type != "workflow":
return fail(
code=BizCode.INVALID_PARAMETER,
msg=f"应用类型必须为 workflow当前为 {app.type}"
)
# 准备输入数据
input_data = {
"message": request.message or "",
"variables": request.variables
}
# 执行工作流
if request.stream:
# 流式执行
from fastapi.responses import StreamingResponse
import json
async def event_generator():
"""生成 SSE 事件
SSE 格式:
event: <event_type>
data: <json_data>
支持的事件类型:
- workflow_start: 工作流开始
- workflow_end: 工作流结束
- node_start: 节点开始执行
- node_end: 节点执行完成
- node_chunk: 中间节点的流式输出
- message: 最终消息的流式输出End 节点及其相邻节点)
"""
try:
async for event in await service.run_workflow(
app_id=app_id,
input_data=input_data,
triggered_by=current_user.id,
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
stream=True
):
# 提取事件类型和数据
event_type = event.get("event", "message")
event_data = event.get("data", {})
# 转换为标准 SSE 格式(字符串)
# event: <type>
# data: <json>
sse_message = f"event: {event_type}\ndata: {json.dumps(event_data)}\n\n"
yield sse_message
except Exception as e:
logger.error(f"流式执行异常: {e}", exc_info=True)
# 发送错误事件
sse_error = f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
yield sse_error
return StreamingResponse(
event_generator(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no" # 禁用 nginx 缓冲
}
)
else:
# 非流式执行
result = await service.run_workflow(
app_id=app_id,
input_data=input_data,
triggered_by=current_user.id,
conversation_id=uuid.UUID(request.conversation_id) if request.conversation_id else None,
stream=False
)
return success(
data=WorkflowExecutionResponse(
execution_id=result["execution_id"],
status=result["status"],
output=result.get("output"),
output_data=result.get("output_data"),
error_message=result.get("error_message"),
elapsed_time=result.get("elapsed_time"),
token_usage=result.get("token_usage")
),
msg="工作流执行完成"
)
except BusinessException as e:
logger.warning(f"执行工作流失败: {e.message}")
return fail(code=e.error_code, msg=e.message)
except Exception as e:
logger.error(f"执行工作流异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"执行工作流失败: {str(e)}"
)
@router.post("/workflow/executions/{execution_id}/cancel")
async def cancel_workflow_execution(
execution_id: Annotated[str, Path(description="执行 ID")],
db: Annotated[Session, Depends(get_db)],
current_user: Annotated[User, Depends(get_current_user)],
service: Annotated[WorkflowService, Depends(get_workflow_service)]
):
"""取消工作流执行
取消正在运行的工作流执行。
**注意**:当前版本仅更新状态为 cancelled实际的执行取消功能待实现。
"""
try:
# 获取执行记录
execution = service.get_execution(execution_id)
if not execution:
return fail(
code=BizCode.NOT_FOUND,
msg="执行记录不存在"
)
# 验证应用是否属于当前工作空间
app = db.query(App).filter(
App.id == execution.app_id,
App.workspace_id == current_user.current_workspace_id,
App.is_active == True
).first()
if not app:
return fail(
code=BizCode.NOT_FOUND,
msg="无权访问该执行记录"
)
# 检查执行状态
if execution.status not in ["pending", "running"]:
return fail(
code=BizCode.INVALID_PARAMETER,
msg=f"无法取消状态为 {execution.status} 的执行"
)
# 更新状态为 cancelled
service.update_execution_status(execution_id, "cancelled")
return success(msg="工作流执行已取消")
except BusinessException as e:
logger.warning(f"取消工作流执行失败: {e.message}")
return fail(code=e.code, msg=e.message)
except Exception as e:
logger.error(f"取消工作流执行异常: {e}", exc_info=True)
return fail(
code=BizCode.INTERNAL_ERROR,
msg=f"取消工作流执行失败: {str(e)}"
)

View File

@@ -1,7 +1,7 @@
import uuid
from typing import List, Optional
from fastapi import APIRouter, Depends, Header, HTTPException, Query, status
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session
from app.core.logging_config import get_api_logger
@@ -14,12 +14,6 @@ from app.dependencies import (
get_current_user,
workspace_access_guard,
)
from app.i18n.dependencies import get_current_language, get_translator
from app.i18n.serializers import (
WorkspaceSerializer,
WorkspaceMemberSerializer,
WorkspaceInviteSerializer
)
from app.models.tenant_model import Tenants
from app.models.user_model import User
from app.models.workspace_model import InviteStatus
@@ -71,9 +65,7 @@ def get_workspaces(
include_current: bool = Query(True, description="是否包含当前工作空间"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
current_tenant: Tenants = Depends(get_current_tenant),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
current_tenant: Tenants = Depends(get_current_tenant)
):
"""获取当前租户下用户参与的所有工作空间
@@ -96,50 +88,25 @@ def get_workspaces(
)
api_logger.info(f"成功获取 {len(workspaces)} 个工作空间")
# 使用序列化器添加国际化字段
serializer = WorkspaceSerializer()
workspaces_data = [WorkspaceResponse.model_validate(w).model_dump() for w in workspaces]
workspaces_i18n = serializer.serialize_list(workspaces_data, language)
return success(data=workspaces_i18n, msg=t("workspace.list_retrieved"))
workspaces_schema = [WorkspaceResponse.model_validate(w) for w in workspaces]
return success(data=workspaces_schema, msg="工作空间列表获取成功")
@router.post("", response_model=ApiResponse)
def create_workspace(
workspace: WorkspaceCreate,
language_type: str = Header(default="zh", alias="X-Language-Type"),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_superuser),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
):
"""创建新的工作空间"""
from app.core.language_utils import get_language_from_header
# 验证并获取语言参数
language = get_language_from_header(language_type)
api_logger.info(
f"用户 {current_user.username} 请求创建工作空间: {workspace.name}, "
f"language={language}"
)
api_logger.info(f"用户 {current_user.username} 请求创建工作空间: {workspace.name}")
result = workspace_service.create_workspace(
db=db, workspace=workspace, user=current_user, language=language
)
db=db, workspace=workspace, user=current_user)
api_logger.info(
f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, "
f"创建者: {current_user.username}, language={language}"
)
# 使用序列化器添加国际化字段
serializer = WorkspaceSerializer()
result_data = WorkspaceResponse.model_validate(result).model_dump()
result_i18n = serializer.serialize(result_data, language)
return success(data=result_i18n, msg=t("workspace.created"))
api_logger.info(f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, 创建者: {current_user.username}")
result_schema = WorkspaceResponse.model_validate(result)
return success(data=result_schema, msg="工作空间创建成功")
@router.put("", response_model=ApiResponse)
@cur_workspace_access_guard()
@@ -147,8 +114,6 @@ def update_workspace(
workspace: WorkspaceUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
):
"""更新工作空间"""
workspace_id = current_user.current_workspace_id
@@ -161,21 +126,14 @@ def update_workspace(
user=current_user,
)
api_logger.info(f"工作空间更新成功 - ID: {workspace_id}, 用户: {current_user.username}")
# 使用序列化器添加国际化字段
serializer = WorkspaceSerializer()
result_data = WorkspaceResponse.model_validate(result).model_dump()
result_i18n = serializer.serialize(result_data, language)
return success(data=result_i18n, msg=t("workspace.updated"))
result_schema = WorkspaceResponse.model_validate(result)
return success(data=result_schema, msg="工作空间更新成功")
@router.get("/members", response_model=ApiResponse)
@cur_workspace_access_guard()
def get_cur_workspace_members(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
):
"""获取工作空间成员列表(关系序列化)"""
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {current_user.current_workspace_id} 的成员列表")
@@ -186,14 +144,8 @@ def get_cur_workspace_members(
user=current_user,
)
api_logger.info(f"工作空间成员列表获取成功 - ID: {current_user.current_workspace_id}, 数量: {len(members)}")
# 转换为表格项并使用序列化器添加国际化字段
table_items = _convert_members_to_table_items(members)
serializer = WorkspaceMemberSerializer()
members_data = [item.model_dump() for item in table_items]
members_i18n = serializer.serialize_list(members_data, language)
return success(data=members_i18n, msg=t("workspace.members.list_retrieved"))
return success(data=table_items, msg="工作空间成员列表获取成功")
@router.put("/members", response_model=ApiResponse)
@@ -203,7 +155,6 @@ def update_workspace_members(
updates: List[WorkspaceMemberUpdate],
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: callable = Depends(get_translator)
):
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求更新工作空间 {workspace_id} 的成员角色")
@@ -214,7 +165,7 @@ def update_workspace_members(
user=current_user,
)
api_logger.info(f"工作空间成员角色更新成功 - ID: {workspace_id}, 数量: {len(members)}")
return success(msg=t("workspace.members.role_updated"))
return success(msg="成员角色更新成功")
@router.delete("/members/{member_id}", response_model=ApiResponse)
@@ -223,7 +174,6 @@ def delete_workspace_member(
member_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: callable = Depends(get_translator)
):
workspace_id = current_user.current_workspace_id
api_logger.info(f"用户 {current_user.username} 请求删除工作空间 {workspace_id} 的成员 {member_id}")
@@ -235,7 +185,7 @@ def delete_workspace_member(
user=current_user,
)
api_logger.info(f"工作空间成员删除成功 - ID: {workspace_id}, 成员: {member_id}")
return success(msg=t("workspace.members.deleted"))
return success(msg="成员删除成功")
# 创建空间协作邀请
@@ -245,8 +195,6 @@ def create_workspace_invite(
invite_data: WorkspaceInviteCreate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
):
"""创建工作空间邀请"""
workspace_id = current_user.current_workspace_id
@@ -259,12 +207,7 @@ def create_workspace_invite(
user=current_user
)
api_logger.info(f"工作空间邀请创建成功 - 工作空间: {workspace_id}, 邮箱: {invite_data.email}")
# 使用序列化器添加国际化字段
serializer = WorkspaceInviteSerializer()
result_i18n = serializer.serialize(result, language)
return success(data=result_i18n, msg=t("workspace.invites.created"))
return success(data=result, msg="邀请创建成功")
@router.get("/invites", response_model=ApiResponse)
@@ -276,8 +219,6 @@ def get_workspace_invites(
offset: int = Query(0, ge=0),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
):
"""获取工作空间邀请列表"""
workspace_id = current_user.current_workspace_id
@@ -292,30 +233,18 @@ def get_workspace_invites(
offset=offset
)
api_logger.info(f"成功获取 {len(invites)} 个邀请记录")
# 使用序列化器添加国际化字段
serializer = WorkspaceInviteSerializer()
invites_i18n = serializer.serialize_list(invites, language)
return success(data=invites_i18n, msg=t("workspace.invites.list_retrieved"))
return success(data=invites, msg="邀请列表获取成功")
@public_router.get("/invites/validate/{token}", response_model=ApiResponse)
def get_workspace_invite_info(
token: str,
db: Session = Depends(get_db),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
):
"""获取工作空间邀请用户信息(无需认证)"""
result = workspace_service.validate_invite_token(db=db, token=token)
api_logger.info(f"工作空间邀请验证成功 - 邀请: {token}")
# 使用序列化器添加国际化字段
serializer = WorkspaceInviteSerializer()
result_i18n = serializer.serialize(result, language)
return success(data=result_i18n, msg=t("workspace.invites.validated"))
return success(data=result, msg="邀请验证成功")
@router.delete("/invites/{invite_id}", response_model=ApiResponse)
@@ -325,8 +254,6 @@ def revoke_workspace_invite(
invite_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
):
"""撤销工作空间邀请"""
workspace_id = current_user.current_workspace_id
@@ -339,12 +266,7 @@ def revoke_workspace_invite(
user=current_user
)
api_logger.info(f"工作空间邀请撤销成功 - 邀请: {invite_id}")
# 使用序列化器添加国际化字段
serializer = WorkspaceInviteSerializer()
result_i18n = serializer.serialize(result, language)
return success(data=result_i18n, msg=t("workspace.invites.revoked"))
return success(data=result, msg="邀请撤销成功")
# ==================== 公开邀请接口(无需认证) ====================
@@ -367,7 +289,6 @@ def switch_workspace(
workspace_id: uuid.UUID,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: callable = Depends(get_translator)
):
"""切换工作空间"""
api_logger.info(f"用户 {current_user.username} 请求切换工作空间为 {workspace_id}")
@@ -378,7 +299,7 @@ def switch_workspace(
user=current_user,
)
api_logger.info(f"成功切换工作空间为 {workspace_id}")
return success(msg=t("workspace.switched"))
return success(msg="工作空间切换成功")
@router.get("/storage", response_model=ApiResponse)
@@ -386,7 +307,6 @@ def switch_workspace(
def get_workspace_storage_type(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: callable = Depends(get_translator)
):
"""获取当前工作空间的存储类型"""
workspace_id = current_user.current_workspace_id
@@ -398,7 +318,7 @@ def get_workspace_storage_type(
user=current_user
)
api_logger.info(f"成功获取工作空间 {workspace_id} 的存储类型: {storage_type}")
return success(data={"storage_type": storage_type}, msg=t("workspace.storage.type_retrieved"))
return success(data={"storage_type": storage_type}, msg="存储类型获取成功")
@router.get("/workspace_models", response_model=ApiResponse)
@@ -406,8 +326,6 @@ def get_workspace_storage_type(
def workspace_models_configs(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
language: str = Depends(get_current_language),
t: callable = Depends(get_translator)
):
"""获取当前工作空间的模型配置llm, embedding, rerank"""
workspace_id = current_user.current_workspace_id
@@ -423,14 +341,14 @@ def workspace_models_configs(
api_logger.warning(f"工作空间 {workspace_id} 不存在或无权访问")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=t("workspace.not_found")
detail="工作空间不存在或无权访问"
)
api_logger.info(
f"成功获取工作空间 {workspace_id} 的模型配置: "
f"llm={configs.get('llm')}, embedding={configs.get('embedding')}, rerank={configs.get('rerank')}"
)
return success(data=WorkspaceModelsConfig.model_validate(configs), msg=t("workspace.models.config_retrieved"))
return success(data=WorkspaceModelsConfig.model_validate(configs), msg="模型配置获取成功")
@router.put("/workspace_models", response_model=ApiResponse)
@@ -439,7 +357,6 @@ def update_workspace_models_configs(
models_update: WorkspaceModelsUpdate,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_user),
t: callable = Depends(get_translator)
):
"""更新当前工作空间的模型配置llm, embedding, rerank"""
workspace_id = current_user.current_workspace_id
@@ -456,5 +373,5 @@ def update_workspace_models_configs(
f"成功更新工作空间 {workspace_id} 的模型配置: "
f"llm={updated_workspace.llm}, embedding={updated_workspace.embedding}, rerank={updated_workspace.rerank}"
)
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg=t("workspace.models.config_updated"))
return success(data=WorkspaceModelsConfig.model_validate(updated_workspace), msg="模型配置更新成功")

View File

@@ -1,4 +0,0 @@
# -*- coding: UTF-8 -*-
# Author: Eternity
# @Email: 1533512157@qq.com
# @Time : 2026/2/9 16:24

View File

@@ -1,162 +0,0 @@
"""Agent Middleware - 动态技能过滤"""
import uuid
from typing import List, Dict, Any, Optional
from langchain_core.runnables import RunnablePassthrough
from app.services.skill_service import SkillService
from app.repositories.skill_repository import SkillRepository
class AgentMiddleware:
"""Agent 中间件 - 用于动态过滤和加载技能"""
def __init__(self, skills: Optional[dict] = None):
"""
初始化中间件
Args:
skills: 技能配置字典 {"enabled": bool, "all_skills": bool, "skill_ids": [...]}
"""
self.skills = skills or {}
self.enabled = self.skills.get('enabled', False)
self.all_skills = self.skills.get('all_skills', False)
self.skill_ids = self.skills.get('skill_ids', [])
@staticmethod
def filter_tools(
tools: List,
message: str = "",
skill_configs: Dict[str, Any] = None,
tool_to_skill_map: Dict[str, str] = None
) -> tuple[List, List[str]]:
"""
根据消息内容和技能配置动态过滤工具
Args:
tools: 所有可用工具列表
message: 用户消息(可用于智能过滤)
skill_configs: 技能配置字典 {skill_id: {"keywords": [...], "enabled": True, "prompt": "..."}}
tool_to_skill_map: 工具到技能的映射 {tool_name: skill_id}
Returns:
(过滤后的工具列表, 激活的技能ID列表)
"""
if not tools:
return [], []
# 如果没有技能配置,返回所有工具
if not skill_configs:
return tools, []
# 基于关键词匹配激活技能
activated_skill_ids = []
message_lower = message.lower()
for skill_id, config in skill_configs.items():
if not config.get('enabled', True):
continue
keywords = config.get('keywords', [])
# 如果没有关键词限制,或消息包含关键词,则激活该技能
if not keywords or any(kw.lower() in message_lower for kw in keywords):
activated_skill_ids.append(skill_id)
# 如果没有工具映射关系,返回所有工具
if not tool_to_skill_map:
return tools, activated_skill_ids
# 根据激活的技能过滤工具
filtered_tools = []
for tool in tools:
tool_name = getattr(tool, 'name', str(id(tool)))
# 如果工具不属于任何skillbase_tools或者工具所属的skill被激活则保留
if tool_name not in tool_to_skill_map or tool_to_skill_map[tool_name] in activated_skill_ids:
filtered_tools.append(tool)
return filtered_tools, activated_skill_ids
def load_skill_tools(self, db, tenant_id: uuid.UUID, base_tools: List = None) -> tuple[List, Dict[str, Any], Dict[str, str]]:
"""
加载技能关联的工具
Args:
db: 数据库会话
tenant_id: 租户id
base_tools: 基础工具列表
Returns:
(工具列表, 技能配置字典, 工具到技能的映射 {tool_name: skill_id})
"""
tools_dict = {}
tool_to_skill_map = {} # 工具名称到技能ID的映射
if base_tools:
for tool in base_tools:
tool_name = getattr(tool, 'name', str(id(tool)))
tools_dict[tool_name] = tool
# base_tools 不属于任何 skill不加入映射
skill_configs = {}
skill_ids_to_load = []
# 如果启用技能且 all_skills 为 True加载租户下所有激活的技能
if self.enabled and self.all_skills:
skills, _ = SkillRepository.list_skills(db, tenant_id, is_active=True, page=1, pagesize=1000)
skill_ids_to_load = [str(skill.id) for skill in skills]
elif self.enabled and self.skill_ids:
skill_ids_to_load = self.skill_ids
if skill_ids_to_load:
for skill_id in skill_ids_to_load:
try:
skill = SkillRepository.get_by_id(db, uuid.UUID(skill_id), tenant_id)
if skill and skill.is_active:
# 保存技能配置包含prompt
config = skill.config or {}
config['prompt'] = skill.prompt
config['name'] = skill.name
skill_configs[skill_id] = config
except Exception:
continue
# 加载技能工具并获取映射关系
skill_tools, skill_tool_map = SkillService.load_skill_tools(db, skill_ids_to_load, tenant_id)
# 只添加不冲突的 skill_tools
for tool in skill_tools:
tool_name = getattr(tool, 'name', str(id(tool)))
if tool_name not in tools_dict:
tools_dict[tool_name] = tool
# 复制映射关系
if tool_name in skill_tool_map:
tool_to_skill_map[tool_name] = skill_tool_map[tool_name]
return list(tools_dict.values()), skill_configs, tool_to_skill_map
@staticmethod
def get_active_prompts(activated_skill_ids: List[str], skill_configs: Dict[str, Any]) -> str:
"""
根据激活的技能ID获取对应的提示词
Args:
activated_skill_ids: 被激活的技能ID列表
skill_configs: 技能配置字典
Returns:
合并后的提示词
"""
prompts = []
for skill_id in activated_skill_ids:
config = skill_configs.get(skill_id, {})
prompt = config.get('prompt')
name = config.get('name', 'Skill')
if prompt:
prompts.append(f"# {name}\n{prompt}")
return "\n\n".join(prompts) if prompts else ""
@staticmethod
def create_runnable():
"""创建可运行的中间件"""
return RunnablePassthrough()

View File

@@ -7,18 +7,23 @@ LangChain Agent 封装
- 支持流式输出
- 使用 RedBearLLM 支持多提供商
"""
import os
import time
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence
from app.core.memory.agent.langgraph_graph.write_graph import write_long_term
from app.db import get_db
from app.core.logging_config import get_business_logger
from app.core.memory.agent.utils.redis_tool import store
from app.core.models import RedBearLLM, RedBearModelConfig
from app.models.models_model import ModelType, ModelProvider
from app.models.models_model import ModelType
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from langchain.agents import create_agent
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.tools import BaseTool
@@ -29,19 +34,16 @@ logger = get_business_logger()
class LangChainAgent:
def __init__(
self,
model_name: str,
api_key: str,
provider: str = "openai",
api_base: Optional[str] = None,
is_omni: bool = False,
temperature: float = 0.7,
max_tokens: int = 2000,
system_prompt: Optional[str] = None,
tools: Optional[Sequence[BaseTool]] = None,
streaming: bool = False,
max_iterations: Optional[int] = None, # 最大迭代次数None 表示自动计算)
max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数
self,
model_name: str,
api_key: str,
provider: str = "openai",
api_base: Optional[str] = None,
temperature: float = 0.7,
max_tokens: int = 2000,
system_prompt: Optional[str] = None,
tools: Optional[Sequence[BaseTool]] = None,
streaming: bool = False
):
"""初始化 LangChain Agent
@@ -54,37 +56,13 @@ class LangChainAgent:
max_tokens: 最大 token 数
system_prompt: 系统提示词
tools: 工具列表(可选,框架自动走 ReAct 循环)
streaming: 是否启用流式输出
max_iterations: 最大迭代次数None 表示自动计算:基础 5 次 + 每个工具 2 次)
max_tool_consecutive_calls: 单个工具最大连续调用次数(默认 3 次)
streaming: 是否启用流式输出(默认 True
"""
self.model_name = model_name
self.provider = provider
self.system_prompt = system_prompt or "你是一个专业的AI助手"
self.tools = tools or []
self.streaming = streaming
self.is_omni = is_omni
self.max_tool_consecutive_calls = max_tool_consecutive_calls
# 工具调用计数器:记录每个工具的连续调用次数
self.tool_call_counter: Dict[str, int] = {}
self.last_tool_called: Optional[str] = None
# 根据工具数量动态调整最大迭代次数
# 基础值 + 每个工具额外的调用机会
if max_iterations is None:
# 自动计算:基础 5 次 + 每个工具 2 次额外机会
self.max_iterations = 5 + len(self.tools) * 2
else:
self.max_iterations = max_iterations
self.system_prompt = system_prompt or "你是一个专业的AI助手"
logger.debug(
f"Agent 迭代次数配置: max_iterations={self.max_iterations}, "
f"tool_count={len(self.tools)}, "
f"max_tool_consecutive_calls={self.max_tool_consecutive_calls}, "
f"auto_calculated={max_iterations is None}"
)
# 创建 RedBearLLM支持多提供商
model_config = RedBearModelConfig(
@@ -92,7 +70,6 @@ class LangChainAgent:
provider=provider,
api_key=api_key,
base_url=api_base,
is_omni=is_omni,
extra_params={
"temperature": temperature,
"max_tokens": max_tokens,
@@ -109,14 +86,11 @@ class LangChainAgent:
if streaming and hasattr(self._underlying_llm, 'streaming'):
self._underlying_llm.streaming = True
# 包装工具以跟踪连续调用次数
wrapped_tools = self._wrap_tools_with_tracking(self.tools) if self.tools else None
# 使用 create_agent 创建 agent graphLangChain 1.x 标准方式)
# 无论是否有工具,都使用 agent 统一处理
self.agent = create_agent(
model=self.llm,
tools=wrapped_tools,
tools=self.tools if self.tools else None,
system_prompt=self.system_prompt
)
@@ -128,92 +102,17 @@ class LangChainAgent:
"has_api_base": bool(api_base),
"temperature": temperature,
"streaming": streaming,
"max_iterations": self.max_iterations,
"max_tool_consecutive_calls": self.max_tool_consecutive_calls,
"tool_count": len(self.tools),
"tool_names": [tool.name for tool in self.tools] if self.tools else [],
# "tool_count": len(self.tools)
"tool_count": len(self.tools)
}
)
def _wrap_tools_with_tracking(self, tools: Sequence[BaseTool]) -> List[BaseTool]:
"""包装工具以跟踪连续调用次数
Args:
tools: 原始工具列表
Returns:
List[BaseTool]: 包装后的工具列表
"""
from langchain_core.tools import StructuredTool
from functools import wraps
wrapped_tools = []
for original_tool in tools:
tool_name = original_tool.name
original_func = original_tool.func if hasattr(original_tool, 'func') else None
if not original_func:
# 如果无法获取原始函数,直接使用原工具
wrapped_tools.append(original_tool)
continue
# 创建包装函数
def make_wrapped_func(tool_name, original_func):
"""创建包装函数的工厂函数,避免闭包问题"""
@wraps(original_func)
def wrapped_func(*args, **kwargs):
"""包装后的工具函数,跟踪连续调用次数"""
# 检查是否是连续调用同一个工具
if self.last_tool_called == tool_name:
self.tool_call_counter[tool_name] = self.tool_call_counter.get(tool_name, 0) + 1
else:
# 切换到新工具,重置计数器
self.tool_call_counter[tool_name] = 1
self.last_tool_called = tool_name
current_count = self.tool_call_counter[tool_name]
logger.debug(
f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}"
)
# 检查是否超过最大连续调用次数
if current_count > self.max_tool_consecutive_calls:
logger.warning(
f"工具 '{tool_name}' 连续调用次数已达上限 ({self.max_tool_consecutive_calls})"
f"返回提示信息"
)
return (
f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次,"
f"未找到有效结果。请尝试其他方法或直接回答用户的问题。"
)
# 调用原始工具函数
return original_func(*args, **kwargs)
return wrapped_func
# 使用 StructuredTool 创建新工具
wrapped_tool = StructuredTool(
name=original_tool.name,
description=original_tool.description,
func=make_wrapped_func(tool_name, original_func),
args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None
)
wrapped_tools.append(wrapped_tool)
return wrapped_tools
def _prepare_messages(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
files: Optional[List[Dict[str, Any]]] = None
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None
) -> List[BaseMessage]:
"""准备消息列表
@@ -221,7 +120,6 @@ class LangChainAgent:
message: 用户消息
history: 历史消息列表
context: 上下文信息
files: 多模态文件内容列表(已处理)
Returns:
List[BaseMessage]: 消息列表
@@ -244,49 +142,47 @@ class LangChainAgent:
if context:
user_content = f"参考信息:\n{context}\n\n用户问题:\n{user_content}"
# 构建用户消息(支持多模态)
if files and len(files) > 0:
content_parts = self._build_multimodal_content(user_content, files)
messages.append(HumanMessage(content=content_parts))
else:
# 纯文本消息
messages.append(HumanMessage(content=user_content))
messages.append(HumanMessage(content=user_content))
return messages
async def term_memory_save(self,messages,end_user_end,aimessages):
'''短长期存储redis为不影响正常使用6句一段话存储用户名加一个前缀当数据存够6条返回给neo4j'''
end_user_end=f"Term_{end_user_end}"
print(messages)
print(aimessages)
session_id = store.save_session(
userid=end_user_end,
messages=messages,
apply_id=end_user_end,
group_id=end_user_end,
aimessages=aimessages
)
store.delete_duplicate_sessions()
# logger.info(f'Redis_Agent:{end_user_end};{session_id}')
return session_id
async def term_memory_redis_read(self,end_user_end):
end_user_end = f"Term_{end_user_end}"
history = store.find_user_apply_group(end_user_end, end_user_end, end_user_end)
# logger.info(f'Redis_Agent:{end_user_end};{history}')
messagss_list=[]
retrieved_content=[]
for messages in history:
query = messages.get("Query")
aimessages = messages.get("Answer")
messagss_list.append(f'用户:{query}。AI回复:{aimessages}')
retrieved_content.append({query: aimessages})
return messagss_list,retrieved_content
def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
构建多模态消息内容
Args:
text: 文本内容
files: 文件列表(已由 MultimodalService 处理为对应 provider 的格式)
Returns:
List[Dict]: 消息内容列表
"""
# 根据 provider 使用不同的文本格式
# if (self.provider.lower() in [ModelProvider.BEDROCK, ModelProvider.OPENAI, ModelProvider.XINFERENCE,
# ModelProvider.GPUSTACK] or (
# self.provider.lower() == ModelProvider.DASHSCOPE and self.is_omni)):
# # Anthropic/Bedrock/Xinference/Gpustack/Openai: {"type": "text", "text": "..."}
# content_parts = [{"type": "text", "text": text}]
# else:
# # 通义千问等: {"text": "..."}
# content_parts = [{"type": "text", "text": text}]
content_parts = [{"type": "text", "text": text}]
# 添加文件内容
# MultimodalService 已经根据 provider 返回了正确格式,直接使用
content_parts.extend(files)
logger.debug(
f"构建多模态消息: provider={self.provider}, "
f"parts={len(content_parts)}, "
f"files={len(files)}"
)
return content_parts
async def write(self,storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,content,actual_config_id):
if storage_type == "rag":
await write_rag(end_user_id, message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
else:
write_id = write_message_task.delay(actual_end_user_id, content, actual_config_id, storage_type,
user_rag_memory_id)
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'Agent:{actual_end_user_id};{write_status}')
async def chat(
self,
@@ -297,8 +193,7 @@ class LangChainAgent:
config_id: Optional[str] = None, # 添加这个参数
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
memory_flag: Optional[bool] = True
) -> Dict[str, Any]:
"""执行对话
@@ -310,7 +205,7 @@ class LangChainAgent:
Returns:
Dict: 包含 content 和元数据的字典
"""
message_chat = message
message_chat= message
start_time = time.time()
actual_config_id = config_id
# If config_id is None, try to get from end_user's connected config
@@ -330,11 +225,34 @@ class LangChainAgent:
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
actual_end_user_id = end_user_id if end_user_id is not None else "unknown"
logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}')
logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}')
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
history_term_memory = history_term_memory_result[0]
db_for_memory = next(get_db())
if memory_flag:
if len(history_term_memory)>=4 and storage_type != "rag":
history_term_memory = ';'.join(history_term_memory)
retrieved_content = history_term_memory_result[1]
print(retrieved_content)
# 为长期记忆操作获取新的数据库连接
try:
repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
except Exception as e:
logger.error(f"Failed to write to LongTermMemory: {e}")
raise
finally:
db_for_memory.close()
await self.write(storage_type,end_user_id,history_term_memory,user_rag_memory_id,actual_end_user_id,history_term_memory,actual_config_id)
await self.write(storage_type,end_user_id,message,user_rag_memory_id,actual_end_user_id,message,actual_config_id)
try:
# 准备消息列表(支持多模态)
messages = self._prepare_messages(message, history, context, files)
# 准备消息列表
messages = self._prepare_messages(message, history, context)
logger.debug(
"准备调用 LangChain Agent",
@@ -342,86 +260,25 @@ class LangChainAgent:
"has_context": bool(context),
"has_history": bool(history),
"has_tools": bool(self.tools),
"has_files": bool(files),
"message_count": len(messages),
"max_iterations": self.max_iterations
"message_count": len(messages)
}
)
# 统一使用 agent.invoke 调用
# 通过 recursion_limit 限制最大迭代次数,防止工具调用死循环
try:
result = await self.agent.ainvoke(
{"messages": messages},
config={"recursion_limit": self.max_iterations}
)
except RecursionError as e:
logger.warning(
f"Agent 达到最大迭代次数限制 ({self.max_iterations}),可能存在工具调用循环",
extra={"error": str(e)}
)
# 返回一个友好的错误提示
return {
"content": f"抱歉,我在处理您的请求时遇到了问题。已达到最大处理步骤限制({self.max_iterations}次)。请尝试简化您的问题或稍后再试。",
"model": self.model_name,
"elapsed_time": time.time() - start_time,
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
}
result = await self.agent.ainvoke({"messages": messages})
# 获取最后的 AI 消息
output_messages = result.get("messages", [])
content = ""
logger.debug(f"输出消息数量: {len(output_messages)}")
total_tokens = 0
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
logger.debug(f"找到 AI 消息content 类型: {type(msg.content)}")
logger.debug(f"AI 消息内容: {msg.content}")
# 处理多模态响应content 可能是字符串或列表
if isinstance(msg.content, str):
content = msg.content
logger.debug(f"提取字符串内容,长度: {len(content)}")
elif isinstance(msg.content, list):
# 多模态响应:提取文本部分
logger.debug(f"多模态响应,列表长度: {len(msg.content)}")
text_parts = []
for item in msg.content:
logger.debug(f"处理项: {item}")
if isinstance(item, dict):
# 通义千问格式: {"text": "..."}
if "text" in item:
text = item.get("text", "")
text_parts.append(text)
logger.debug(f"提取文本: {text[:100]}...")
# OpenAI 格式: {"type": "text", "text": "..."}
elif item.get("type") == "text":
text = item.get("text", "")
text_parts.append(text)
logger.debug(f"提取文本: {text[:100]}...")
elif isinstance(item, str):
text_parts.append(item)
logger.debug(f"提取字符串: {item[:100]}...")
content = "".join(text_parts)
logger.debug(f"合并后内容长度: {len(content)}")
else:
content = str(msg.content)
logger.debug(f"转换为字符串: {content[:100]}...")
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0
content = msg.content
break
logger.info(f"最终提取的内容长度: {len(content)}")
elapsed_time = time.time() - start_time
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id,
actual_config_id)
await self.write(storage_type,end_user_id,content,user_rag_memory_id,actual_end_user_id,content,actual_config_id)
await self.term_memory_save(message_chat,end_user_id,content)
response = {
"content": content,
"model": self.model_name,
@@ -429,7 +286,7 @@ class LangChainAgent:
"usage": {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": total_tokens
"total_tokens": 0
}
}
@@ -448,16 +305,15 @@ class LangChainAgent:
raise
async def chat_stream(
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id: Optional[str] = None,
config_id: Optional[str] = None,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None,
memory_flag: Optional[bool] = True,
files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件
self,
message: str,
history: Optional[List[Dict[str, str]]] = None,
context: Optional[str] = None,
end_user_id:Optional[str] = None,
config_id: Optional[str] = None,
storage_type:Optional[str] = None,
user_rag_memory_id:Optional[str] = None,
memory_flag: Optional[bool] = True
) -> AsyncGenerator[str, None]:
"""执行流式对话
@@ -491,13 +347,32 @@ class LangChainAgent:
except Exception as e:
logger.warning(f"Failed to get db session: {e}")
# 注意:不在这里写入用户消息,等 AI 回复后一起写入
history_term_memory_result = await self.term_memory_redis_read(end_user_id)
history_term_memory = history_term_memory_result[0]
if memory_flag:
if len(history_term_memory) >= 4 and storage_type != "rag":
history_term_memory = ';'.join(history_term_memory)
retrieved_content = history_term_memory_result[1]
db_for_memory = next(get_db())
try:
repo = LongTermMemoryRepository(db_for_memory)
repo.upsert(end_user_id, retrieved_content)
logger.info(
f'写入短长期:{storage_type, str(end_user_id), history_term_memory, str(user_rag_memory_id)}')
await self.write(storage_type, end_user_id, history_term_memory, user_rag_memory_id, end_user_id,
history_term_memory, actual_config_id)
except Exception as e:
logger.error(f"Failed to write to long term memory: {e}")
finally:
db_for_memory.close()
await self.write(storage_type, end_user_id, message, user_rag_memory_id, end_user_id, message, actual_config_id)
try:
# 准备消息列表(支持多模态)
messages = self._prepare_messages(message, history, context, files)
# 准备消息列表
messages = self._prepare_messages(message, history, context)
logger.debug(
f"准备流式调用has_tools={bool(self.tools)}, has_files={bool(files)}, message_count={len(messages)}"
f"准备流式调用has_tools={bool(self.tools)}, message_count={len(messages)}"
)
chunk_count = 0
@@ -505,106 +380,47 @@ class LangChainAgent:
# 统一使用 agent 的 astream_events 实现流式输出
logger.debug("使用 Agent astream_events 实现流式输出")
full_content = ''
full_content=''
try:
async for event in self.agent.astream_events(
{"messages": messages},
version="v2",
config={"recursion_limit": self.max_iterations}
{"messages": messages},
version="v2"
):
chunk_count += 1
kind = event.get("event")
# 处理所有可能的流式事件
if kind == "on_chat_model_stream":
# LLM 流式输出
chunk = event.get("data", {}).get("chunk")
if chunk and hasattr(chunk, "content"):
# 处理多模态响应content 可能是字符串或列表
chunk_content = chunk.content
if isinstance(chunk_content, str) and chunk_content:
full_content += chunk_content
yield chunk_content
yielded_content = True
elif isinstance(chunk_content, list):
# 多模态响应:提取文本部分
for item in chunk_content:
if isinstance(item, dict):
# 通义千问格式: {"text": "..."}
if "text" in item:
text = item.get("text", "")
if text:
full_content += text
yield text
yielded_content = True
# OpenAI 格式: {"type": "text", "text": "..."}
elif item.get("type") == "text":
text = item.get("text", "")
if text:
full_content += text
yield text
yielded_content = True
elif isinstance(item, str):
full_content += item
yield item
yielded_content = True
full_content+=chunk.content
if chunk and hasattr(chunk, "content") and chunk.content:
yield chunk.content
yielded_content = True
elif kind == "on_llm_stream":
# 另一种 LLM 流式事件
chunk = event.get("data", {}).get("chunk")
if chunk:
if hasattr(chunk, "content"):
chunk_content = chunk.content
if isinstance(chunk_content, str) and chunk_content:
full_content += chunk_content
yield chunk_content
yielded_content = True
elif isinstance(chunk_content, list):
# 多模态响应:提取文本部分
for item in chunk_content:
if isinstance(item, dict):
# 通义千问格式: {"text": "..."}
if "text" in item:
text = item.get("text", "")
if text:
full_content += text
yield text
yielded_content = True
# OpenAI 格式: {"type": "text", "text": "..."}
elif item.get("type") == "text":
text = item.get("text", "")
if text:
full_content += text
yield text
yielded_content = True
elif isinstance(item, str):
full_content += item
yield item
yielded_content = True
if hasattr(chunk, "content") and chunk.content:
full_content+=chunk.content
yield chunk.content
yielded_content = True
elif isinstance(chunk, str):
full_content += chunk
yield chunk
yielded_content = True
# 记录工具调用(可选)
elif kind == "on_tool_start":
logger.debug(f"工具调用开始: {event.get('name')}")
elif kind == "on_tool_end":
logger.debug(f"工具调用结束: {event.get('name')}")
logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件")
# 统计token消耗
output_messages = event.get("data", {}).get("output", {}).get("messages", [])
for msg in reversed(output_messages):
if isinstance(msg, AIMessage):
response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None
total_tokens = response_meta.get("token_usage", {}).get("total_tokens",
0) if response_meta else 0
yield total_tokens
break
if memory_flag:
await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id,
actual_config_id)
await self.write(storage_type, end_user_id,full_content, user_rag_memory_id, end_user_id,full_content, actual_config_id)
await self.term_memory_save(message_chat, end_user_id, full_content)
except Exception as e:
logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True)
raise
@@ -618,3 +434,5 @@ class LangChainAgent:
logger.info("=" * 80)
logger.info("chat_stream 方法执行结束")
logger.info("=" * 80)

View File

@@ -1,33 +1,14 @@
import json
import os
from pathlib import Path
from typing import Annotated, Optional
from typing import Any, Dict, Optional
from dotenv import load_dotenv
from pydantic import Field, TypeAdapter
load_dotenv()
class Settings:
# ========================================================================
# Deployment Mode Configuration
# ========================================================================
# community: 社区版(开源,功能受限)
# cloud: SaaS 云服务版(全功能,按量计费)
# enterprise: 企业私有化版License 控制)
DEPLOYMENT_MODE: str = os.getenv("DEPLOYMENT_MODE", "community")
# License 配置(企业版)
LICENSE_FILE: str = os.getenv("LICENSE_FILE", "/etc/app/license.json")
LICENSE_SERVER_URL: str = os.getenv("LICENSE_SERVER_URL", "https://license.yourcompany.com")
# 计费服务配置SaaS 版)
BILLING_SERVICE_URL: str = os.getenv("BILLING_SERVICE_URL", "")
# 基础 URL用于 SSO 回调等)
BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000")
FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000")
ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true"
# API Keys Configuration
OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "")
@@ -91,30 +72,9 @@ class Settings:
# Single Sign-On configuration
ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true"
# SSO 免登配置
SSO_TOKEN_EXPIRE_SECONDS: int = int(os.getenv("SSO_TOKEN_EXPIRE_SECONDS", "300"))
SSO_TRUSTED_SOURCES_CONFIG: str = os.getenv("SSO_TRUSTED_SOURCES_CONFIG", "{}")
# File Upload
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))
# Storage Configuration
STORAGE_TYPE: str = os.getenv("STORAGE_TYPE", "local")
# Aliyun OSS Configuration
OSS_ENDPOINT: str = os.getenv("OSS_ENDPOINT", "")
OSS_ACCESS_KEY_ID: str = os.getenv("OSS_ACCESS_KEY_ID", "")
OSS_ACCESS_KEY_SECRET: str = os.getenv("OSS_ACCESS_KEY_SECRET", "")
OSS_BUCKET_NAME: str = os.getenv("OSS_BUCKET_NAME", "")
# AWS S3 Configuration
S3_REGION: str = os.getenv("S3_REGION", "")
S3_ACCESS_KEY_ID: str = os.getenv("S3_ACCESS_KEY_ID", "")
S3_SECRET_ACCESS_KEY: str = os.getenv("S3_SECRET_ACCESS_KEY", "")
S3_BUCKET_NAME: str = os.getenv("S3_BUCKET_NAME", "")
S3_ENDPOINT_URL: str = os.getenv("S3_ENDPOINT_URL", "")
# VOLC ASR settings
VOLC_APP_KEY: str = os.getenv("VOLC_APP_KEY", "")
@@ -130,7 +90,6 @@ class Settings:
# Server Configuration
SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1")
FILE_LOCAL_SERVER_URL: str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api")
# ========================================================================
# Internal Configuration (not in .env, used by application code)
@@ -157,49 +116,6 @@ class Settings:
if origin.strip()
]
# Language Configuration
# Supported values: "zh" (Chinese), "en" (English)
# This controls the language used for memory summary titles and other generated content
DEFAULT_LANGUAGE: str = os.getenv("DEFAULT_LANGUAGE", "zh")
# ========================================================================
# Internationalization (i18n) Configuration
# ========================================================================
# Default language for API responses
I18N_DEFAULT_LANGUAGE: str = os.getenv("I18N_DEFAULT_LANGUAGE", "zh")
# Supported languages (comma-separated)
I18N_SUPPORTED_LANGUAGES: list[str] = [
lang.strip()
for lang in os.getenv("I18N_SUPPORTED_LANGUAGES", "zh,en").split(",")
if lang.strip()
]
# Core locales directory (community edition)
# Use absolute path to work from any working directory
I18N_CORE_LOCALES_DIR: str = os.getenv(
"I18N_CORE_LOCALES_DIR",
os.path.join(os.path.dirname(os.path.dirname(__file__)), "locales")
)
# Premium locales directory (enterprise edition, optional)
I18N_PREMIUM_LOCALES_DIR: Optional[str] = os.getenv("I18N_PREMIUM_LOCALES_DIR", None)
# Enable translation cache
I18N_ENABLE_TRANSLATION_CACHE: bool = os.getenv("I18N_ENABLE_TRANSLATION_CACHE", "true").lower() == "true"
# LRU cache size for hot translations
I18N_LRU_CACHE_SIZE: int = int(os.getenv("I18N_LRU_CACHE_SIZE", "1000"))
# Enable hot reload of translation files
I18N_ENABLE_HOT_RELOAD: bool = os.getenv("I18N_ENABLE_HOT_RELOAD", "false").lower() == "true"
# Fallback language when translation is missing
I18N_FALLBACK_LANGUAGE: str = os.getenv("I18N_FALLBACK_LANGUAGE", "zh")
# Log missing translations
I18N_LOG_MISSING_TRANSLATIONS: bool = os.getenv("I18N_LOG_MISSING_TRANSLATIONS", "true").lower() == "true"
# Logging settings
LOG_LEVEL: str = os.getenv("LOG_LEVEL", "INFO")
LOG_FORMAT: str = os.getenv("LOG_FORMAT", "%(asctime)s - %(name)s - %(levelname)s - %(message)s")
@@ -228,45 +144,18 @@ class Settings:
LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB
# Celery configuration (internal)
# NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持
# 详见 docs/celery-env-bug-report.md
# 默认使用 Redis DB 3 (broker) 和 DB 4 (backend),与业务缓存 (DB 1/2) 隔离
# 多人共用同一 Redis 时,每位开发者应在 .env 中配置不同的 DB 编号避免任务互相干扰
REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3"))
REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4"))
# SMTP Email Configuration
SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com")
SMTP_PORT: int = int(os.getenv("SMTP_PORT", "587"))
SMTP_USER: str = os.getenv("SMTP_USER", "")
SMTP_PASSWORD: str = os.getenv("SMTP_PASSWORD", "")
CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1"))
CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2"))
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
DEFAULT_WORKSPACE_ID: Optional[str] = os.getenv("DEFAULT_WORKSPACE_ID", None)
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
# Memory Cache Regeneration Configuration
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
# Celery Beat Schedule Configuration (定时任务执行频率)
MEMORY_INCREMENT_HOUR: int = TypeAdapter(
Annotated[int, Field(ge=0, le=23, description="cron hour [0, 23]")]
).validate_python(int(os.getenv("MEMORY_INCREMENT_HOUR", "2")))
MEMORY_INCREMENT_MINUTE: int = TypeAdapter(
Annotated[int, Field(ge=0, le=59, description="cron minute [0, 59]")]
).validate_python(int(os.getenv("MEMORY_INCREMENT_MINUTE", "0")))
WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = TypeAdapter(
Annotated[int, Field(ge=1, description="reflection interval in seconds, must be >= 1")]
).validate_python(int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30")))
FORGETTING_CYCLE_INTERVAL_HOURS: int = TypeAdapter(
Annotated[int, Field(ge=1, description="forgetting cycle interval in hours, must be >= 1")]
).validate_python(int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24")))
IMPLICIT_EMOTIONS_UPDATE_HOUR: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_HOUR", "2"))
# implicit_emotions_update: 每天几分执行分钟0-59
IMPLICIT_EMOTIONS_UPDATE_MINUTE: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_MINUTE", "0"))
# Memory Module Configuration (internal)
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
@@ -277,37 +166,11 @@ class Settings:
ENABLE_TOOL_MANAGEMENT: bool = os.getenv("ENABLE_TOOL_MANAGEMENT", "true").lower() == "true"
# official environment system version
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.1")
# model square loading
LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true"
SYSTEM_VERSION: str = os.getenv("SYSTEM_VERSION", "v0.2.0")
# workflow config
WORKFLOW_IMPORT_CACHE_TIMEOUT: int = int(os.getenv("WORKFLOW_IMPORT_CACHE_TIMEOUT", 1800))
WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600))
# ========================================================================
# General Ontology Type Configuration
# ========================================================================
# 通用本体文件路径列表(逗号分隔)
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "api/app/core/memory/ontology_services/General_purpose_entity.ttl")
# 是否启用通用本体类型功能
ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true"
# Prompt 中最大类型数量
MAX_ONTOLOGY_TYPES_IN_PROMPT: int = int(os.getenv("MAX_ONTOLOGY_TYPES_IN_PROMPT", "50"))
# 核心通用类型列表(逗号分隔)
CORE_GENERAL_TYPES: str = os.getenv(
"CORE_GENERAL_TYPES",
"Person,Organization,Company,GovernmentAgency,Place,Location,City,Country,Building,"
"Event,SportsEvent,SocialEvent,Work,Book,Film,Software,Concept,TopicalConcept,AcademicSubject"
)
# 实验模式开关(允许通过 API 动态切换本体配置)
ONTOLOGY_EXPERIMENT_MODE: bool = os.getenv("ONTOLOGY_EXPERIMENT_MODE", "true").lower() == "true"
def get_memory_output_path(self, filename: str = "") -> str:
"""
Get the full path for memory module output files.

View File

@@ -46,7 +46,6 @@ class BizCode(IntEnum):
RESOURCE_ALREADY_EXISTS = 5002
VERSION_ALREADY_EXISTS = 5003
STATE_CONFLICT = 5004
RESOURCE_IN_USE = 5005
# 应用发布6xxx
PUBLISH_FAILED = 6001
@@ -126,7 +125,6 @@ HTTP_MAPPING = {
BizCode.RESOURCE_ALREADY_EXISTS: 409,
BizCode.VERSION_ALREADY_EXISTS: 409,
BizCode.STATE_CONFLICT: 409,
BizCode.RESOURCE_IN_USE: 409,
BizCode.PUBLISH_FAILED: 500,
BizCode.NO_DRAFT_TO_PUBLISH: 400,
BizCode.ROLLBACK_TARGET_NOT_FOUND: 400,

View File

@@ -1,82 +0,0 @@
# -*- coding: utf-8 -*-
"""语言处理工具模块
本模块提供集中化的语言校验和处理功能,确保整个应用中语言参数的一致性。
Functions:
validate_language: 校验语言参数,确保其为有效值
get_language_from_header: 从请求头获取并校验语言参数
"""
from typing import Optional
from app.core.logging_config import get_logger
logger = get_logger(__name__)
# 支持的语言列表
SUPPORTED_LANGUAGES = {"zh", "en"}
# 默认回退语言
DEFAULT_LANGUAGE = "zh"
def validate_language(language: Optional[str]) -> str:
"""
校验语言参数,确保其为有效值。
Args:
language: 待校验的语言代码,可以是 None、"zh""en" 或其他值
Returns:
有效的语言代码("zh""en"
Examples:
>>> validate_language("zh")
'zh'
>>> validate_language("en")
'en'
>>> validate_language("EN") # 大小写不敏感
'en'
>>> validate_language(None) # None 回退到默认值
'zh'
>>> validate_language("fr") # 不支持的语言回退到默认值
'zh'
"""
if language is None:
return DEFAULT_LANGUAGE
# 标准化:转小写并去除空白
lang = str(language).lower().strip()
if lang in SUPPORTED_LANGUAGES:
return lang
logger.warning(
f"无效的语言参数 '{language}',已回退到默认值 '{DEFAULT_LANGUAGE}'"
f"支持的语言: {SUPPORTED_LANGUAGES}"
)
return DEFAULT_LANGUAGE
def get_language_from_header(language_type: Optional[str]) -> str:
"""
从请求头获取并校验语言参数。
这是一个便捷函数,用于在 controller 层统一处理 X-Language-Type Header。
Args:
language_type: 从 X-Language-Type Header 获取的语言值
Returns:
有效的语言代码("zh""en"
Examples:
>>> get_language_from_header(None) # Header 未传递
'zh'
>>> get_language_from_header("en")
'en'
>>> get_language_from_header("invalid") # 无效值回退
'zh'
"""
return validate_language(language_type)

View File

@@ -38,56 +38,6 @@ class SensitiveDataLoggingFilter(logging.Filter):
return True
class Neo4jSuccessNotificationFilter(logging.Filter):
"""Neo4j 日志过滤器:过滤成功/信息性状态的通知,保留真正的警告和错误
Neo4j 驱动会以 WARNING 级别记录所有数据库通知,包括成功的操作。
这个过滤器会过滤掉以下 GQL 状态码的通知,只保留真正的警告和错误:
- 00000: 成功完成 (successful completion)
- 00N00: 无数据 (no data)
- 00NA0: 无数据,信息性通知 (no data, informational notification)
使用正则表达式进行更严格的匹配,避免误过滤无关的警告。
"""
import re
# 编译正则表达式以提高性能
# 匹配所有"成功/信息性"的 GQL 状态码:
# 00000 = 成功完成, 00N00 = 无数据, 00NA0 = 无数据信息性通知
GQL_STATUS_PATTERN = re.compile(r"gql_status=['\"](00000|00N00|00NA0)['\"]")
# 匹配 status_description 中的成功完成或信息性通知消息
SUCCESS_DESC_PATTERN = re.compile(r"status_description=['\"]note:\s*(successful\s+completion|no\s+data)['\"]", re.IGNORECASE)
def filter(self, record: logging.LogRecord) -> bool:
"""
过滤 Neo4j 成功通知
Args:
record: 日志记录
Returns:
True表示允许记录False表示拒绝过滤掉
"""
# 只处理 INFO 和 WARNING 级别的日志
# Neo4j 驱动对 severity='INFORMATION' 的通知使用 INFO 级别,
# 对 severity='WARNING' 的通知使用 WARNING 级别
if record.levelno not in (logging.INFO, logging.WARNING):
return True
# 检查是否是 Neo4j 的成功通知
message = str(record.msg)
# 使用正则表达式进行更严格的匹配
# 这样可以避免误过滤包含这些子字符串但不是 Neo4j 通知的日志
if self.GQL_STATUS_PATTERN.search(message) or self.SUCCESS_DESC_PATTERN.search(message):
return False # 过滤掉这条日志
# 保留其他所有日志(包括真正的警告和错误)
return True
class LoggingConfig:
"""全局日志配置类"""
@@ -115,22 +65,6 @@ class LoggingConfig:
# 清除现有处理器
root_logger.handlers.clear()
# Neo4j 通知过滤器 - 挂在 handler 上确保所有传播上来的日志都能被过滤
neo4j_filter = Neo4jSuccessNotificationFilter()
# 抑制 Neo4j 通知日志
# Neo4j 驱动内部会给 neo4j.notifications logger 配置自己的 handler
# 导致日志绕过根 logger 的 filter 直接输出。
# 多管齐下确保过滤生效:
# 1. 设置 neo4j.notifications 级别为 WARNING过滤 INFO 级别的 00NA0 通知)
# 2. 在所有 neo4j logger 上添加 filter过滤 WARNING 级别的成功通知)
# 3. 在根 handler 上也添加 filter兜底
neo4j_notifications_logger = logging.getLogger("neo4j.notifications")
neo4j_notifications_logger.setLevel(logging.WARNING)
for neo4j_logger_name in ["neo4j", "neo4j.io", "neo4j.pool", "neo4j.notifications"]:
neo4j_logger = logging.getLogger(neo4j_logger_name)
neo4j_logger.addFilter(neo4j_filter)
# 创建格式化器
formatter = logging.Formatter(
fmt=settings.LOG_FORMAT,
@@ -146,7 +80,6 @@ class LoggingConfig:
console_handler.setFormatter(formatter)
console_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
console_handler.addFilter(sensitive_filter)
console_handler.addFilter(neo4j_filter)
root_logger.addHandler(console_handler)
# 文件处理器(带轮转)
@@ -160,7 +93,6 @@ class LoggingConfig:
file_handler.setFormatter(formatter)
file_handler.setLevel(getattr(logging, settings.LOG_LEVEL.upper()))
file_handler.addFilter(sensitive_filter)
file_handler.addFilter(neo4j_filter)
root_logger.addHandler(file_handler)
cls._initialized = True

View File

@@ -0,0 +1,16 @@
"""
LangGraph Graph package for memory agent.
This package provides the LangGraph workflow orchestrator with modular
node implementations, routing logic, and state management.
Package structure:
- read_graph: Main graph factory for read operations
- write_graph: Main graph factory for write operations
- nodes: LangGraph node implementations
- routing: State routing logic
- state: State management utilities
"""
from app.core.memory.agent.langgraph_graph.read_graph import make_read_graph
__all__ = ['make_read_graph']

View File

@@ -4,7 +4,7 @@ LangGraph node implementations.
This module contains custom node implementations for the LangGraph workflow.
"""
# from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode
# from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message
#
# __all__ = ["ToolExecutionNode", "create_input_message"]
from app.core.memory.agent.langgraph_graph.nodes.tool_node import ToolExecutionNode
from app.core.memory.agent.langgraph_graph.nodes.input_node import create_input_message
__all__ = ["ToolExecutionNode", "create_input_message"]

View File

@@ -1,45 +0,0 @@
from app.core.memory.agent.utils.llm_tools import ReadState, WriteState
from app.schemas.memory_agent_schema import AgentMemoryDataset
def content_input_node(state: ReadState) -> ReadState:
"""
Start node - Extract content and maintain state information
Extracts the content from the first message in the state and returns it
as the data field while preserving all other state information.
Args:
state: ReadState containing messages and other state data
Returns:
ReadState: Updated state with extracted content in data field
"""
content = state['messages'][0].content if state.get('messages') else ''
# Return content and maintain all state information
for pronoun in AgentMemoryDataset.PRONOUN:
content = content.replace(pronoun, AgentMemoryDataset.NAME)
return {"data": content}
def content_input_write(state: WriteState) -> WriteState:
"""
Start node - Extract content and maintain state information for write operations
Extracts the content from the first message in the state for write operations.
Args:
state: WriteState containing messages and other state data
Returns:
WriteState: Updated state with extracted content in data field
"""
content = state['messages'][0].content if state.get('messages') else ''
# Return content and maintain all state information
for pronoun in AgentMemoryDataset.PRONOUN:
content = content.replace(pronoun, AgentMemoryDataset.NAME)
return {"data": content}

View File

@@ -0,0 +1,150 @@
"""
Input node for LangGraph workflow entry point.
This module provides the create_input_message function which processes initial
user input with multimodal support and creates the first tool call message.
"""
import logging
import re
import uuid
from datetime import datetime
from typing import Any, Dict
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
from app.schemas.memory_config_schema import MemoryConfig
from langchain_core.messages import AIMessage
logger = logging.getLogger(__name__)
async def create_input_message(
state: Dict[str, Any],
tool_name: str,
session_id: str,
search_switch: str,
apply_id: str,
group_id: str,
multimodal_processor: MultimodalProcessor,
memory_config: MemoryConfig,
) -> Dict[str, Any]:
"""
Create initial tool call message from user input.
This function:
1. Extracts the last message content from state
2. Processes multimodal inputs (images/audio) using the multimodal processor
3. Generates a unique message ID
4. Extracts namespace from session_id
5. Handles verified_data extraction for backward compatibility
6. Returns AIMessage with complete tool_calls structure
Args:
state: LangGraph state dictionary containing messages
tool_name: Name of the tool to invoke (typically "Split_The_Problem")
session_id: Session identifier (format: "call_id_{namespace}")
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
multimodal_processor: Processor for handling image/audio inputs
memory_config: MemoryConfig object containing all configuration
Returns:
State update with AIMessage containing tool_call
Examples:
>>> state = {"messages": [HumanMessage(content="What is AI?")]}
>>> result = await create_input_message(
... state, "Split_The_Problem", "call_id_user123", "0", "app1", "group1", processor, config
... )
>>> result["messages"][0].tool_calls[0]["name"]
'Split_The_Problem'
"""
messages = state.get("messages", [])
# Extract last message content
if messages:
last_message = messages[-1].content if hasattr(messages[-1], 'content') else str(messages[-1])
else:
logger.warning("[create_input_message] No messages in state, using empty string")
last_message = ""
logger.debug(f"[create_input_message] Original input: {last_message[:100]}...")
# Process multimodal input (images/audio)
try:
processed_content = await multimodal_processor.process_input(last_message)
if processed_content != last_message:
logger.info(
f"[create_input_message] Multimodal processing converted input "
f"from {len(last_message)} to {len(processed_content)} chars"
)
last_message = processed_content
except Exception as e:
logger.error(
f"[create_input_message] Multimodal processing failed: {e}",
exc_info=True
)
# Continue with original content
# Generate unique message ID
uuid_str = uuid.uuid4()
time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Extract namespace from session_id
# Expected format: "call_id_{namespace}" or similar
try:
namespace = str(session_id).split('_id_')[1]
except (IndexError, AttributeError):
logger.warning(
f"[create_input_message] Could not extract namespace from session_id: {session_id}"
)
namespace = "unknown"
# Handle verified_data extraction (backward compatibility)
# This regex-based extraction is kept for compatibility with existing data formats
if 'verified_data' in str(last_message):
try:
messages_last = str(last_message).replace('\\n', '').replace('\\', '')
query_match = re.findall(r'"query": "(.*?)",', messages_last)
if query_match:
last_message = query_match[0]
logger.debug(
f"[create_input_message] Extracted query from verified_data: {last_message}"
)
except Exception as e:
logger.warning(
f"[create_input_message] Failed to extract query from verified_data: {e}"
)
# Construct tool call message
tool_call_id = f"{session_id}_{uuid_str}"
logger.info(
f"[create_input_message] Creating tool call for '{tool_name}' "
f"with ID: {tool_call_id}"
)
# Build tool arguments
tool_args = {
"sentence": last_message,
"sessionid": session_id,
"messages_id": str(uuid_str),
"search_switch": search_switch,
"apply_id": apply_id,
"group_id": group_id,
"memory_config": memory_config,
}
return {
"messages": [
AIMessage(
content="",
tool_calls=[{
"name": tool_name,
"args": tool_args,
"id": tool_call_id
}]
)
]
}

View File

@@ -1,282 +0,0 @@
import json
import os
import time
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.models.problem_models import ProblemExtensionResponse
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_,
ReadState,
)
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__)
class ProblemNodeService(LLMServiceMixin):
"""
Problem processing node service class
Handles problem decomposition and extension operations using LLM services.
Inherits from LLMServiceMixin to provide structured LLM calling capabilities.
Attributes:
template_service: Service for rendering Jinja2 templates
"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# Create global service instance
problem_service = ProblemNodeService()
async def Split_The_Problem(state: ReadState) -> ReadState:
"""
Problem decomposition node
Breaks down complex user queries into smaller, more manageable sub-problems.
Uses LLM to analyze the input and generate structured problem decomposition
with question types and reasoning.
Args:
state: ReadState containing user input and configuration
Returns:
ReadState: Updated state with problem decomposition results
"""
# 从状态中获取数据
content = state.get('data', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema()
system_prompt = await problem_service.template_service.render_template(
template_name='problem_breakdown_prompt.jinja2',
operation_name='split_the_problem',
history=history,
sentence=content,
json_schema=json_schema
)
try:
# 使用优化的LLM服务
with get_db_context() as db_session:
structured = await problem_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=ProblemExtensionResponse,
fallback_value=[]
)
# 添加更详细的日志记录
logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}")
# Validate structured response
if not structured or not hasattr(structured, 'root'):
logger.warning("Split_The_Problem: 结构化响应为空或格式不正确")
split_result = json.dumps([], ensure_ascii=False)
elif not structured.root:
logger.warning("Split_The_Problem: 结构化响应的root为空")
split_result = json.dumps([], ensure_ascii=False)
else:
split_result = json.dumps(
[item.model_dump() for item in structured.root],
ensure_ascii=False
)
split_result_dict = []
for index, item in enumerate(json.loads(split_result)):
split_data = {
"id": f"Q{index + 1}",
"question": item['extended_question'],
"type": item['type'],
"reason": item['reason']
}
split_result_dict.append(split_data)
logger.info(f"Split_The_Problem: 成功生成 {len(structured.root) if structured.root else 0} 个分解项")
result = {
"context": split_result,
"original": content,
"_intermediate": {
"type": "problem_split",
"title": "问题拆分",
"data": split_result_dict,
"original_query": content
}
}
except Exception as e:
logger.error(
f"Split_The_Problem failed: {e}",
exc_info=True
)
# Provide more detailed error information
error_details = {
"error_type": type(e).__name__,
"error_message": str(e),
"content_length": len(content),
"llm_model_id": str(memory_config.llm_model_id) if memory_config else None
}
logger.error(f"Split_The_Problem error details: {error_details}")
# Create default empty result
result = {
"context": json.dumps([], ensure_ascii=False),
"original": content,
"error": str(e),
"_intermediate": {
"type": "problem_split",
"title": "问题拆分",
"data": [],
"original_query": content,
"error": error_details
}
}
# Return updated state including spit_context field
return {"spit_data": result}
async def Problem_Extension(state: ReadState) -> ReadState:
"""
Problem extension node
Extends the decomposed problems from Split_The_Problem node by generating
additional related questions and organizing them by original question.
Uses LLM to create comprehensive question extensions for better memory retrieval.
Args:
state: ReadState containing decomposed problems and configuration
Returns:
ReadState: Updated state with extended problem results
"""
# Get original data and decomposition results
start = time.time()
content = state.get('data', '')
data = state.get('spit_data', '')['context']
end_user_id = state.get('end_user_id', '')
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
memory_config = state.get('memory_config', None)
databasets = {}
try:
data = json.loads(data)
for i in data:
databasets[i['extended_question']] = i['type']
except (json.JSONDecodeError, KeyError, TypeError) as e:
logger.error(f"Problem_Extension: 数据解析失败: {e}")
# 使用空字典作为fallback
databasets = {}
data = []
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
# 生成 JSON schema 以指导 LLM 输出正确格式
json_schema = ProblemExtensionResponse.model_json_schema()
system_prompt = await problem_service.template_service.render_template(
template_name='Problem_Extension_prompt.jinja2',
operation_name='problem_extension',
history=history,
questions=databasets,
json_schema=json_schema
)
try:
# 使用优化的LLM服务
with get_db_context() as db_session:
response_content = await problem_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=ProblemExtensionResponse,
fallback_value=[]
)
logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}")
# Validate structured response
if not response_content or not hasattr(response_content, 'root'):
logger.warning("Problem_Extension: 结构化响应为空或格式不正确")
aggregated_dict = {}
elif not response_content.root:
logger.warning("Problem_Extension: 结构化响应的root为空")
aggregated_dict = {}
else:
# Aggregate results by original question
aggregated_dict = {}
for item in response_content.root:
try:
key = getattr(item, "original_question", None) or (
item.get("original_question") if isinstance(item, dict) else None
)
value = getattr(item, "extended_question", None) or (
item.get("extended_question") if isinstance(item, dict) else None
)
if not key or not value:
logger.warning(f"Problem_Extension: 跳过无效项: key={key}, value={value}")
continue
aggregated_dict.setdefault(key, []).append(value)
except Exception as item_error:
logger.warning(f"Problem_Extension: 处理项目时出错: {item_error}")
continue
logger.info(f"Problem_Extension: 成功生成 {len(aggregated_dict)} 个扩展问题组")
except Exception as e:
logger.error(
f"LLM call failed for Problem_Extension: {e}",
exc_info=True
)
# Provide more detailed error information
error_details = {
"error_type": type(e).__name__,
"error_message": str(e),
"questions_count": len(databasets),
"llm_model_id": str(memory_config.llm_model_id) if memory_config else None
}
logger.error(f"Problem_Extension error details: {error_details}")
aggregated_dict = {}
logger.info("Problem extension")
logger.info(f"Problem extension result: {aggregated_dict}")
# Emit intermediate output for frontend
print(time.time() - start)
result = {
"context": aggregated_dict,
"original": data,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "problem_extension",
"title": "问题扩展",
"data": aggregated_dict,
"original_query": content,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
return {"problem_extension": result}

View File

@@ -1,509 +0,0 @@
# ===== 标准库 =====
import asyncio
import json
import os
# ===== 第三方库 =====
from langchain.agents import create_agent
from langchain_openai import ChatOpenAI
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.tools.tool import (
create_hybrid_retrieval_tool_sync,
create_time_retrieval_tool,
extract_tool_message_content,
)
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
ReadState,
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.schemas import model_schema
from app.services.memory_config_service import MemoryConfigService
from app.services.model_service import ModelConfigService
logger = get_agent_logger(__name__)
async def rag_config(state):
"""
Configure RAG (Retrieval-Augmented Generation) settings
Creates configuration for knowledge base retrieval including similarity thresholds,
weights, and reranker settings.
Args:
state: Current state containing user_rag_memory_id
Returns:
dict: RAG configuration dictionary
"""
user_rag_memory_id = state.get('user_rag_memory_id', '')
kb_config = {
"knowledge_bases": [
{
"kb_id": user_rag_memory_id,
"similarity_threshold": 0.7,
"vector_similarity_weight": 0.5,
"top_k": 10,
"retrieve_type": "participle"
}
],
"merge_strategy": "weight",
"reranker_id": os.getenv('reranker_id'),
"reranker_top_k": 10
}
return kb_config
async def rag_knowledge(state, question):
"""
Retrieve knowledge using RAG approach
Performs knowledge retrieval from configured knowledge bases using the
provided question and returns formatted results.
Args:
state: Current state containing configuration
question: Question to search for
Returns:
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
"""
kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
clean_content = '\n\n'.join(retrieval_knowledge)
cleaned_query = question
raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except Exception:
retrieval_knowledge = []
clean_content = ''
raw_results = ''
cleaned_query = question
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
return retrieval_knowledge, clean_content, cleaned_query, raw_results
async def llm_infomation(state: ReadState) -> ReadState:
"""
Get LLM configuration information from state
Retrieves model configuration details including model ID and tenant ID
from the memory configuration in the current state.
Args:
state: ReadState containing memory configuration
Returns:
ReadState: Model configuration as Pydantic model
"""
memory_config = state.get('memory_config', None)
model_id = memory_config.llm_model_id
tenant_id = memory_config.tenant_id
# Use existing memory_config instead of re-querying database
# or use thread-safe database access
with get_db_context() as db:
result_orm = ModelConfigService.get_model_by_id(db=db, model_id=model_id, tenant_id=tenant_id)
result_pydantic = model_schema.ModelConfig.model_validate(result_orm)
return result_pydantic
async def clean_databases(data) -> str:
"""
Simplified database search result cleaning function
Processes and cleans search results from various sources including
reranked results and time-based search results. Extracts text content
from structured data and returns as formatted string.
Args:
data: Search result data (can be string, dict, or other types)
Returns:
str: Cleaned content string
"""
try:
# Parse JSON string
if isinstance(data, str):
try:
data = json.loads(data)
except json.JSONDecodeError:
return data
if not isinstance(data, dict):
return str(data)
# Get result data
# with open("搜索结果.json","w",encoding='utf-8') as f:
# f.write(json.dumps(data, indent=4, ensure_ascii=False))
results = data.get('results', data)
if not isinstance(results, dict):
return str(results)
# Collect all content
content_list = []
# Process reranked results
reranked = results.get('reranked_results', {})
if reranked:
for category in ['summaries', 'communities', 'statements', 'chunks', 'entities']:
items = reranked.get(category, [])
if isinstance(items, list):
content_list.extend(items)
# Process time search results
time_search = results.get('time_search', {})
if time_search:
if isinstance(time_search, dict):
statements = time_search.get('statements', time_search.get('time_search', []))
if isinstance(statements, list):
content_list.extend(statements)
elif isinstance(time_search, list):
content_list.extend(time_search)
# Extract text content对 community 按 name 去重(多次 tool 调用会产生重复)
text_parts = []
seen_community_names = set()
for item in content_list:
if isinstance(item, dict):
# community 节点用 name 去重
if 'member_count' in item or 'core_entities' in item:
community_name = item.get('name') or item.get('id', '')
if community_name in seen_community_names:
continue
seen_community_names.add(community_name)
text = item.get('statement') or item.get('content') or item.get('summary', '')
if text:
text_parts.append(text)
elif isinstance(item, str):
text_parts.append(item)
return '\n'.join(text_parts).strip()
except Exception as e:
logger.error(f"clean_databases failed: {e}", exc_info=True)
return str(data)
async def retrieve_nodes(state: ReadState) -> ReadState:
"""
Retrieve information using simplified search approach
Processes extended problems from previous nodes and performs retrieval
using either RAG or hybrid search based on storage type. Handles concurrent
processing of multiple questions and deduplicates results.
Args:
state: ReadState containing problem extensions and configuration
Returns:
ReadState: Updated state with retrieval results and intermediate outputs
"""
problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
original = state.get('data', '')
problem_list = []
for key, values in problem_extension.items():
for data in values:
problem_list.append(data)
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
# Create async task to process individual questions
async def process_question_nodes(idx, question):
try:
# Prepare search parameters based on storage type
search_params = {
"end_user_id": end_user_id,
"question": question,
"return_raw_results": True
}
if storage_type == "rag" and user_rag_memory_id:
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question)
else:
clean_content, cleaned_query, raw_results = await SearchService().execute_hybrid_search(
**search_params, memory_config=memory_config
)
return {
"Query_small": cleaned_query,
"Result_small": clean_content,
"_intermediate": {
"type": "search_result",
"query": cleaned_query,
"raw_results": raw_results,
"index": idx + 1,
"total": len(problem_list)
}
}
except Exception as e:
logger.error(
f"Retrieve: hybrid_search failed for question '{question}': {e}",
exc_info=True
)
# Return empty result for this question
return {
"Query_small": question,
"Result_small": "",
"_intermediate": {
"type": "search_result",
"query": question,
"raw_results": [],
"index": idx + 1,
"total": len(problem_list)
}
}
# Process all questions concurrently
tasks = [process_question_nodes(idx, question) for idx, question in enumerate(problem_list)]
databases_anser = await asyncio.gather(*tasks)
databases_data = {
"Query": original,
"Expansion_issue": databases_anser
}
# Collect intermediate outputs before deduplication
intermediate_outputs = []
for item in databases_anser:
if '_intermediate' in item:
intermediate_outputs.append(item['_intermediate'])
# Deduplicate and merge results
deduplicated_data = deduplicate_entries(databases_data['Expansion_issue'])
deduplicated_data_merged = merge_to_key_value_pairs(
deduplicated_data,
'Query_small',
'Result_small'
)
# Restructure for Verify/Retrieve_Summary compatibility
keys, val = [], []
for item in deduplicated_data_merged:
for items_key, items_value in item.items():
keys.append(items_key)
val.append(items_value)
send_verify = []
for i, j in zip(keys, val, strict=False):
if j != ['']:
send_verify.append({
"Query_small": i,
"Answer_Small": j
})
dup_databases = {
"Query": original,
"Expansion_issue": send_verify,
"_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs
}
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
return {'retrieve': dup_databases}
async def retrieve(state: ReadState) -> ReadState:
"""
Advanced retrieve function using LangChain agents and tools
Uses LangChain agents with specialized retrieval tools (time-based and hybrid)
to perform sophisticated information retrieval. Supports both RAG and traditional
memory storage approaches with concurrent processing and result deduplication.
Args:
state: ReadState containing problem extensions and configuration
Returns:
ReadState: Updated state with retrieval results and intermediate outputs
"""
# Get end_user_id from state
import time
start = time.time()
problem_extension = state.get('problem_extension', '')['context']
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
original = state.get('data', '')
problem_list = []
for key, values in problem_extension.items():
for data in values:
problem_list.append(data)
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
databases_anser = []
async def get_llm_info():
with get_db_context() as db: # 使用同步数据库上下文管理器
config_service = MemoryConfigService(db)
return await llm_infomation(state)
llm_config = await get_llm_info()
api_key_obj = llm_config.api_keys[0]
api_key = api_key_obj.api_key
api_base = api_key_obj.api_base
model_name = api_key_obj.model_name
llm = ChatOpenAI(
model=model_name,
api_key=api_key,
base_url=api_base,
temperature=0.2,
)
time_retrieval_tool = create_time_retrieval_tool(end_user_id)
search_params = {
"end_user_id": end_user_id,
"return_raw_results": True,
"include": ["summaries", "statements", "chunks", "entities", "communities"],
}
hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params)
agent = create_agent(
llm,
tools=[time_retrieval_tool, hybrid_retrieval],
system_prompt=f"我是检索专家可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}"
)
# Create async task to process individual questions
import asyncio
# Define semaphore at module level to limit maximum concurrency
SEMAPHORE = asyncio.Semaphore(5) # Limit to maximum 5 concurrent database operations
async def process_question(idx, question):
async with SEMAPHORE: # Limit concurrency
try:
if storage_type == "rag" and user_rag_memory_id:
retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state,
question)
else:
cleaned_query = question
# Use asyncio to run synchronous agent.invoke in thread pool
import asyncio
response = await asyncio.get_event_loop().run_in_executor(
None,
lambda: agent.invoke({"messages": question})
)
tool_results = extract_tool_message_content(response)
if tool_results == None:
raw_results = []
clean_content = ''
else:
raw_results = tool_results['content']
clean_content = await clean_databases(raw_results)
# 社区展开:从 tool 返回结果中提取命中的 community
# 沿 BELONGS_TO_COMMUNITY 关系拉取关联 Statement 追加到 clean_content
_expanded_stmts_to_write = []
try:
results_dict = raw_results.get('results', {}) if isinstance(raw_results, dict) else {}
reranked = results_dict.get('reranked_results', {})
community_hits = reranked.get('communities', [])
if not community_hits:
community_hits = results_dict.get('communities', [])
if community_hits:
from app.core.memory.agent.services.search_service import expand_communities_to_statements
_expanded_stmts_to_write, new_texts = await expand_communities_to_statements(
community_results=community_hits,
end_user_id=end_user_id,
existing_content=clean_content,
)
if new_texts:
clean_content = clean_content + '\n' + '\n'.join(new_texts)
except Exception as parse_err:
logger.warning(f"[Retrieve] 解析社区命中结果失败,跳过展开: {parse_err}")
try:
raw_results = raw_results['results']
# 写回展开结果,接口返回中可见(已在 helper 中清洗过字段)
if _expanded_stmts_to_write and isinstance(raw_results, dict):
raw_results.setdefault('reranked_results', {})['expanded_statements'] = _expanded_stmts_to_write
except Exception:
raw_results = []
return {
"Query_small": cleaned_query,
"Result_small": clean_content,
"_intermediate": {
"type": "search_result",
"query": cleaned_query,
"raw_results": raw_results,
"index": idx + 1,
"total": len(problem_list)
}
}
except Exception as e:
logger.error(
f"Retrieve: hybrid_search failed for question '{question}': {e}",
exc_info=True
)
# Return empty result for this question
return {
"Query_small": question,
"Result_small": "",
"_intermediate": {
"type": "search_result",
"query": question,
"raw_results": [],
"index": idx + 1,
"total": len(problem_list)
}
}
# Process all questions concurrently
import asyncio
tasks = [process_question(idx, question) for idx, question in enumerate(problem_list)]
databases_anser = await asyncio.gather(*tasks)
databases_data = {
"Query": original,
"Expansion_issue": databases_anser
}
# Collect intermediate outputs before deduplication
intermediate_outputs = []
for item in databases_anser:
if '_intermediate' in item:
intermediate_outputs.append(item['_intermediate'])
# Deduplicate and merge results
deduplicated_data = deduplicate_entries(databases_data['Expansion_issue'])
deduplicated_data_merged = merge_to_key_value_pairs(
deduplicated_data,
'Query_small',
'Result_small'
)
# Restructure for Verify/Retrieve_Summary compatibility
keys, val = [], []
for item in deduplicated_data_merged:
for items_key, items_value in item.items():
keys.append(items_key)
val.append(items_value)
send_verify = []
for i, j in zip(keys, val, strict=False):
if j != ['']:
send_verify.append({
"Query_small": i,
"Answer_Small": j
})
dup_databases = {
"Query": original,
"Expansion_issue": send_verify,
"_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs
}
# with open('retrieve_text.json', 'w') as f:
# json.dump(dup_databases, f, indent=4)
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
return {'retrieve': dup_databases}

View File

@@ -1,524 +0,0 @@
import os
import time
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.models.summary_models import (
RetrieveSummaryResponse,
SummaryResponse,
)
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.services.search_service import SearchService
from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_,
ReadState,
)
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__)
class SummaryNodeService(LLMServiceMixin):
"""
Summary node service class
Handles summary generation operations using LLM services. Inherits from
LLMServiceMixin to provide structured LLM calling capabilities for
generating summaries from retrieved information.
Attributes:
template_service: Service for rendering Jinja2 templates
"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# Create global service instance
summary_service = SummaryNodeService()
async def rag_config(state):
"""
Configure RAG (Retrieval-Augmented Generation) settings for summary operations
Creates configuration for knowledge base retrieval including similarity thresholds,
weights, and reranker settings specifically for summary generation.
Args:
state: Current state containing user_rag_memory_id
Returns:
dict: RAG configuration dictionary with knowledge base settings
"""
user_rag_memory_id = state.get('user_rag_memory_id', '')
kb_config = {
"knowledge_bases": [
{
"kb_id": user_rag_memory_id,
"similarity_threshold": 0.7,
"vector_similarity_weight": 0.5,
"top_k": 10,
"retrieve_type": "participle"
}
],
"merge_strategy": "weight",
"reranker_id": os.getenv('reranker_id'),
"reranker_top_k": 10
}
return kb_config
async def rag_knowledge(state, question):
"""
Retrieve knowledge using RAG approach for summary generation
Performs knowledge retrieval from configured knowledge bases using the
provided question and returns formatted results for summary processing.
Args:
state: Current state containing configuration
question: Question to search for in knowledge base
Returns:
tuple: (retrieval_knowledge, clean_content, cleaned_query, raw_results)
- retrieval_knowledge: List of retrieved knowledge chunks
- clean_content: Formatted content string
- cleaned_query: Processed query string
- raw_results: Raw retrieval results
"""
kb_config = await rag_config(state)
end_user_id = state.get('end_user_id', '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
clean_content = '\n\n'.join(retrieval_knowledge)
cleaned_query = question
raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except Exception:
retrieval_knowledge = []
clean_content = ''
raw_results = ''
cleaned_query = question
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
return retrieval_knowledge, clean_content, cleaned_query, raw_results
async def summary_history(state: ReadState) -> ReadState:
"""
Retrieve conversation history for summary context
Gets the conversation history for the current user to provide context
for summary generation operations.
Args:
state: ReadState containing end_user_id
Returns:
ReadState: Conversation history data
"""
end_user_id = state.get("end_user_id", '')
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
return history
async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,
search_mode) -> str:
"""
Enhanced summary_llm function with better error handling and data validation
Generates summaries using LLM with structured output. Includes fallback mechanisms
for handling LLM failures and provides robust error recovery.
Args:
state: ReadState containing current context
history: Conversation history for context
retrieve_info: Retrieved information to summarize
template_name: Jinja2 template name for prompt generation
operation_name: Type of operation (summary, input_summary, retrieve_summary)
response_model: Pydantic model for structured output
search_mode: Search mode flag ("0" for simple, "1" for complex)
Returns:
str: Generated summary text or fallback message
"""
data = state.get("data", '')
# Build system prompt
if str(search_mode) == "0":
system_prompt = await summary_service.template_service.render_template(
template_name=template_name,
operation_name=operation_name,
data=retrieve_info,
query=data
)
else:
system_prompt = await summary_service.template_service.render_template(
template_name=template_name,
operation_name=operation_name,
query=data,
history=history,
retrieve_info=retrieve_info
)
try:
# Use optimized LLM service for structured output
with get_db_context() as db_session:
structured = await summary_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=response_model,
fallback_value=None
)
# Validate structured response
if structured is None:
logger.warning("LLM返回None使用默认回答")
return "信息不足,无法回答"
# Extract answer based on operation type
if operation_name == "summary":
aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答"
else:
# Handle RetrieveSummaryResponse
if hasattr(structured, 'data') and structured.data:
aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答"
else:
logger.warning("结构化响应缺少data字段")
aimessages = "信息不足,无法回答"
# Validate answer is not empty
if not aimessages or aimessages.strip() == "":
aimessages = "信息不足,无法回答"
return aimessages
except Exception as e:
logger.error(f"结构化输出失败: {e}", exc_info=True)
# Try unstructured output as fallback
try:
logger.info("尝试非结构化输出作为fallback")
response = await summary_service.call_llm_simple(
state=state,
db_session=db_session,
system_prompt=system_prompt,
fallback_message="信息不足,无法回答"
)
if response and response.strip():
# Simple response cleaning
cleaned_response = response.strip()
# Remove possible JSON markers
if cleaned_response.startswith('```'):
lines = cleaned_response.split('\n')
cleaned_response = '\n'.join(lines[1:-1])
return cleaned_response
else:
return "信息不足,无法回答"
except Exception as fallback_error:
logger.error(f"Fallback也失败: {fallback_error}")
return "信息不足,无法回答"
async def summary_redis_save(state: ReadState, aimessages) -> ReadState:
"""
Save summary results to Redis session storage
Stores the generated summary and user query in Redis for session management
and conversation history tracking.
Args:
state: ReadState containing user and query information
aimessages: Generated summary message to save
Returns:
ReadState: Updated state after saving to Redis
"""
data = state.get("data", '')
end_user_id = state.get("end_user_id", '')
await SessionService(store).save_session(
user_id=end_user_id,
query=data,
apply_id=end_user_id,
end_user_id=end_user_id,
ai_response=aimessages
)
await SessionService(store).cleanup_duplicates()
logger.info(f"sessionid: {aimessages} 写入成功")
async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState:
"""
Format summary results for different output types
Creates structured output formats for both input summary and retrieval summary
operations, including metadata and intermediate results for frontend display.
Args:
state: ReadState containing storage and user information
aimessages: Generated summary message
raw_results: Raw search/retrieval results
Returns:
tuple: (input_summary, retrieve_summary) formatted result dictionaries
"""
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
data = state.get("data", '')
input_summary = {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "input_summary",
"title": "快速答案",
"summary": aimessages,
"query": data,
"raw_results": raw_results,
"search_mode": "quick_search",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
retrieve = {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "retrieval_summary",
"title": "快速检索",
"summary": aimessages,
"query": data,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
return input_summary, retrieve
async def Input_Summary(state: ReadState) -> ReadState:
"""
Generate quick input summary from retrieved information
Performs fast retrieval and generates a quick summary response for user queries.
This function prioritizes speed by only searching summary nodes and provides
immediate feedback to users.
Args:
state: ReadState containing user query, storage configuration, and context
Returns:
ReadState: Dictionary containing summary results with status and metadata
"""
start = time.time()
storage_type = state.get("storage_type", '')
memory_config = state.get('memory_config', None)
user_rag_memory_id = state.get("user_rag_memory_id", '')
data = state.get("data", '')
end_user_id = state.get("end_user_id", '')
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
history = await summary_history(state)
search_params = {
"end_user_id": end_user_id,
"question": data,
"return_raw_results": True,
"include": ["summaries", "communities"] # MemorySummary 和 Community 同为高维度概括节点
}
try:
if storage_type != "rag":
retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(
**search_params,
memory_config=memory_config,
expand_communities=False, # 路径 "2" 只需要 community 的 summary 文本,不展开到 Statement
)
# 调试:打印 community 检索结果数量
if raw_results and isinstance(raw_results, dict):
reranked = raw_results.get('reranked_results', {})
community_hits = reranked.get('communities', [])
logger.debug(f"[Input_Summary] community 命中数: {len(community_hits)}, "
f"summary 命中数: {len(reranked.get('summaries', []))}")
else:
retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data)
except Exception as e:
logger.error(f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True)
retrieve_info, question, raw_results = "", data, []
try:
# aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2',
# 'input_summary',RetrieveSummaryResponse)
# logger.info(f"快速答案总结==>>:{storage_type}--{user_rag_memory_id}--{aimessages}")
summary_result = await summary_prompt(state, retrieve_info, retrieve_info)
summary = summary_result[0]
except Exception as e:
logger.error(f"Input_Summary failed: {e}", exc_info=True)
summary = {
"status": "fail",
"summary_result": "信息不足,无法回答",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('检索', duration)
return {"summary": summary}
async def Retrieve_Summary(state: ReadState) -> ReadState:
"""
Generate comprehensive summary from retrieved expansion issues
Processes retrieved expansion issues and generates a detailed summary using LLM.
This function handles complex retrieval results and provides comprehensive answers
based on expanded query results.
Args:
state: ReadState containing retrieve data with expansion issues
Returns:
ReadState: Dictionary containing comprehensive summary results
"""
retrieve = state.get("retrieve", '')
history = await summary_history(state)
import json
with open("检索.json", "w", encoding='utf-8') as f:
f.write(json.dumps(retrieve, indent=4, ensure_ascii=False))
retrieve = retrieve.get("Expansion_issue", [])
start = time.time()
retrieve_info_str = []
for data in retrieve:
if data == '':
retrieve_info_str = ''
else:
for key, value in data.items():
if key == 'Answer_Small':
for i in value:
retrieve_info_str.append(i)
retrieve_info_str = list(set(retrieve_info_str))
retrieve_info_str = '\n'.join(retrieve_info_str)
aimessages = await summary_llm(state, history, retrieve_info_str,
'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1")
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages)
if aimessages == '':
aimessages = '信息不足,无法回答'
logger.info(f"Summary after retrieval: {aimessages}")
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Retrieval summary', duration)
# Fixed coroutine call - await first, then access return value
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1]
return {"summary": summary}
async def Summary(state: ReadState) -> ReadState:
"""
Generate final comprehensive summary from verified data
Creates the final summary using verified expansion issues and conversation history.
This function processes verified data to generate the most comprehensive and
accurate response to user queries.
Args:
state: ReadState containing verified data and query information
Returns:
ReadState: Dictionary containing final summary results
"""
start = time.time()
query = state.get("data", '')
verify = state.get("verify", '')
verify_expansion_issue = verify.get("verified_data", '')
retrieve_info_str = ''
for data in verify_expansion_issue:
for key, value in data.items():
if key == 'answer_small':
for i in value:
retrieve_info_str += i + '\n'
history = await summary_history(state)
data = {
"query": query,
"history": history,
"retrieve_info": retrieve_info_str
}
aimessages = await summary_llm(state, history, data,
'summary_prompt.jinja2', 'summary', SummaryResponse, 0)
if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "":
await summary_redis_save(state, aimessages)
if aimessages == '':
aimessages = '信息不足,无法回答'
try:
duration = time.time() - start
except Exception:
duration = 0.0
log_time('Retrieval summary', duration)
# Fixed coroutine call - await first, then access return value
summary_result = await summary_prompt(state, aimessages, retrieve_info_str)
summary = summary_result[1]
return {"summary": summary}
async def Summary_fails(state: ReadState) -> ReadState:
"""
Generate fallback summary when normal summary process fails
Provides a fallback summary generation mechanism when the standard summary
process encounters errors or fails to produce satisfactory results. Uses
a specialized failure template to handle edge cases.
Args:
state: ReadState containing verified data and failure context
Returns:
ReadState: Dictionary containing fallback summary results
"""
storage_type = state.get("storage_type", '')
user_rag_memory_id = state.get("user_rag_memory_id", '')
history = await summary_history(state)
query = state.get("data", '')
verify = state.get("verify", '')
verify_expansion_issue = verify.get("verified_data", '')
retrieve_info_str = ''
for data in verify_expansion_issue:
for key, value in data.items():
if key == 'answer_small':
for i in value:
retrieve_info_str += i + '\n'
data = {
"query": query,
"history": history,
"retrieve_info": retrieve_info_str
}
aimessages = await summary_llm(state, history, data,
'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0)
result = {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
return {"summary": result}

View File

@@ -0,0 +1,234 @@
"""
Tool execution node for LangGraph workflow.
This module provides the ToolExecutionNode class which wraps tool execution
with parameter transformation logic using the ParameterBuilder service.
"""
import logging
import time
from typing import Any, Callable, Dict
from app.core.memory.agent.langgraph_graph.state.extractors import (
extract_content_payload,
extract_tool_call_id,
)
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
from app.schemas.memory_config_schema import MemoryConfig
from langchain_core.messages import AIMessage
from langgraph.prebuilt import ToolNode
logger = logging.getLogger(__name__)
class ToolExecutionNode:
"""
Custom LangGraph node that wraps tool execution with parameter transformation.
This node extracts content from previous tool results, transforms parameters
based on tool type using ParameterBuilder, and invokes the tool with the
correct argument structure.
Attributes:
tool_node: LangGraph ToolNode wrapping the actual tool
id: Node identifier for message IDs
tool_name: Name of the tool being executed
namespace: Namespace for session management
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
parameter_builder: Service for building tool-specific arguments
memory_config: MemoryConfig object containing all configuration
"""
def __init__(
self,
tool: Callable,
node_id: str,
namespace: str,
search_switch: str,
apply_id: str,
group_id: str,
parameter_builder: ParameterBuilder,
storage_type: str,
user_rag_memory_id: str,
memory_config: MemoryConfig,
):
"""
Initialize the tool execution node.
Args:
tool: The tool function to execute
node_id: Identifier for this node (used in message IDs)
namespace: Namespace for session management
search_switch: Search routing parameter
apply_id: Application identifier
group_id: Group identifier
parameter_builder: Service for building tool-specific arguments
storage_type: Storage type for the workspace
user_rag_memory_id: User RAG memory identifier
memory_config: MemoryConfig object containing all configuration
"""
self.tool_node = ToolNode([tool])
self.id = node_id
self.tool_name = tool.name if hasattr(tool, 'name') else str(tool)
self.namespace = namespace
self.search_switch = search_switch
self.apply_id = apply_id
self.group_id = group_id
self.parameter_builder = parameter_builder
self.storage_type = storage_type
self.user_rag_memory_id = user_rag_memory_id
self.memory_config = memory_config
logger.info(
f"[ToolExecutionNode] Initialized node '{self.id}' for tool '{self.tool_name}'"
)
async def __call__(self, state: Dict[str, Any]) -> Dict[str, Any]:
"""
Execute the tool with transformed parameters.
This method:
1. Extracts the last message from state
2. Extracts tool call ID using state extractors
3. Extracts content payload using state extractors
4. Builds tool arguments using parameter builder
5. Constructs AIMessage with tool_calls
6. Invokes the tool and returns the result
Args:
state: LangGraph state dictionary
Returns:
Updated state with tool result in messages
"""
messages = state.get("messages", [])
logger.debug( self.tool_name)
if not messages:
logger.warning(f"[ToolExecutionNode] {self.id} - No messages in state")
return {"messages": [AIMessage(content="Error: No messages in state")]}
last_message = messages[-1]
logger.debug(
f"[ToolExecutionNode] {self.id} - Processing message at {time.time()}"
)
try:
# Extract tool call ID using state extractors
tool_call_id = extract_tool_call_id(last_message)
logger.debug(f"[ToolExecutionNode] {self.id} - Extracted tool_call_id: {tool_call_id}")
except ValueError as e:
logger.error(
f"[ToolExecutionNode] {self.id} - Failed to extract tool call ID: {e}"
)
return {"messages": [AIMessage(content=f"Error: {str(e)}")]}
try:
# Extract content payload using state extractors
content = extract_content_payload(last_message)
logger.debug(
f"[ToolExecutionNode] {self.id} - Extracted content type: {type(content)}, content_keys: {list(content.keys()) if isinstance(content, dict) else 'N/A'}"
)
# Log raw message content for debugging
if hasattr(last_message, 'content'):
raw = last_message.content
logger.debug(f"[ToolExecutionNode] {self.id} - Raw message content (first 500 chars): {str(raw)[:500]}")
except Exception as e:
logger.error(
f"[ToolExecutionNode] {self.id} - Failed to extract content: {e}",
exc_info=True
)
content = {}
try:
# Build tool arguments using parameter builder
tool_args = self.parameter_builder.build_tool_args(
tool_name=self.tool_name,
content=content,
tool_call_id=tool_call_id,
search_switch=self.search_switch,
apply_id=self.apply_id,
group_id=self.group_id,
memory_config=self.memory_config,
storage_type=self.storage_type,
user_rag_memory_id=self.user_rag_memory_id,
)
logger.debug(
f"[ToolExecutionNode] {self.id} - Built tool args with keys: {list(tool_args.keys())}"
)
except Exception as e:
logger.error(
f"[ToolExecutionNode] {self.id} - Failed to build tool args: {e}",
exc_info=True
)
return {"messages": [AIMessage(content=f"Error building arguments: {str(e)}")]}
# Construct tool input message
tool_input = {
"messages": [
AIMessage(
content="",
tool_calls=[{
"name": self.tool_name,
"args": tool_args,
"id": f"{self.id}_{tool_call_id}",
}]
)
]
}
try:
# Invoke the tool
result = await self.tool_node.ainvoke(tool_input)
logger.debug(
f"[ToolExecutionNode] {self.id} - Tool execution completed"
)
# Check for error in tool response
error_entry = None
if result and "messages" in result:
for msg in result["messages"]:
if hasattr(msg, 'content'):
try:
import json
content = msg.content
if isinstance(content, str):
parsed = json.loads(content)
if isinstance(parsed, dict) and "error" in parsed:
error_msg = parsed["error"]
logger.warning(
f"[ToolExecutionNode] {self.id} - Tool returned error: {error_msg}"
)
error_entry = {"tool": self.tool_name, "error": error_msg, "node_id": self.id}
except (json.JSONDecodeError, TypeError):
pass
# Return result with error tracking if error was found
if error_entry:
result["errors"] = [error_entry]
return result
except Exception as e:
logger.error(
f"[ToolExecutionNode] {self.id} - Tool execution failed: {e}",
exc_info=True
)
# Track error in state and return error message
from langchain_core.messages import ToolMessage
error_entry = {"tool": self.tool_name, "error": str(e), "node_id": self.id}
return {
"messages": [
ToolMessage(
content=f"Error executing tool: {str(e)}",
tool_call_id=f"{self.id}_{tool_call_id}"
)
],
"errors": [error_entry]
}

View File

@@ -1,184 +0,0 @@
import asyncio
import os
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.models.verification_models import VerificationResult
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
from app.core.memory.agent.utils.llm_tools import (
PROJECT_ROOT_,
ReadState,
)
from app.core.memory.agent.utils.redis_tool import store
from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService
from app.db import get_db_context
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__)
class VerificationNodeService(LLMServiceMixin):
"""
Verification node service class
Handles data verification operations using LLM services. Inherits from
LLMServiceMixin to provide structured LLM calling capabilities for
verifying and validating retrieved information.
Attributes:
template_service: Service for rendering Jinja2 templates
"""
def __init__(self):
super().__init__()
self.template_service = TemplateService(template_root)
# Create global service instance
verification_service = VerificationNodeService()
async def Verify_prompt(state: ReadState, messages_deal: VerificationResult):
"""
Process verification results and generate output format
Transforms VerificationResult objects into structured output format suitable
for frontend consumption. Handles conversion of VerificationItem objects to
dictionary format and adds metadata for tracking.
Args:
state: ReadState containing storage and user configuration
messages_deal: VerificationResult containing verification outcomes
Returns:
dict: Formatted verification result with status and metadata
"""
storage_type = state.get('storage_type', '')
user_rag_memory_id = state.get('user_rag_memory_id', '')
data = state.get('data', '')
# Convert VerificationItem objects to dictionary list
verified_data = []
if messages_deal.expansion_issue:
for item in messages_deal.expansion_issue:
if hasattr(item, 'model_dump'):
verified_data.append(item.model_dump())
elif isinstance(item, dict):
verified_data.append(item)
Verify_result = {
"status": messages_deal.split_result,
"verified_data": verified_data,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "verification",
"title": "Data Verification",
"result": messages_deal.split_result,
"reason": messages_deal.reason or "验证完成",
"query": messages_deal.query,
"verified_count": len(verified_data),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
return Verify_result
async def Verify(state: ReadState):
logger.info("=== Verify 节点开始执行 ===")
try:
content = state.get('data', '')
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', None)
logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}")
history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id)
logger.info(f"Verify: 获取历史记录完成history length={len(history)}")
retrieve = state.get("retrieve", {})
logger.info(
f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}")
retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else []
logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}")
messages = {
"Query": content,
"Expansion_issue": retrieve_expansion
}
logger.info("Verify: 开始渲染模板")
# Generate JSON schema to guide LLM output format
json_schema = VerificationResult.model_json_schema()
system_prompt = await verification_service.template_service.render_template(
template_name='split_verify_prompt.jinja2',
operation_name='split_verify_prompt',
history=history,
sentence=messages,
json_schema=json_schema
)
logger.info(f"Verify: 模板渲染完成prompt length={len(system_prompt)}")
# 使用优化的LLM服务添加超时保护
logger.info("Verify: 开始调用 LLM")
try:
# Add asyncio.wait_for timeout wrapper to prevent infinite waiting
# Timeout set to 150 seconds (slightly longer than LLM config's 120 seconds)
with get_db_context() as db_session:
structured = await asyncio.wait_for(
verification_service.call_llm_structured(
state=state,
db_session=db_session,
system_prompt=system_prompt,
response_model=VerificationResult,
fallback_value={
"query": content,
"history": history if isinstance(history, list) else [],
"expansion_issue": [],
"split_result": "failed",
"reason": "验证失败或超时"
}
),
timeout=150.0 # 150 second timeout
)
logger.info(f"Verify: LLM 调用完成result={structured}")
except asyncio.TimeoutError:
logger.error("Verify: LLM 调用超时150秒使用 fallback 值")
structured = VerificationResult(
query=content,
history=history if isinstance(history, list) else [],
expansion_issue=[],
split_result="failed",
reason="LLM调用超时"
)
result = await Verify_prompt(state, structured)
logger.info("=== Verify 节点执行完成 ===")
return {"verify": result}
except Exception as e:
logger.error(f"Verify 节点执行失败: {e}", exc_info=True)
# 返回失败的验证结果
return {
"verify": {
"status": "failed",
"verified_data": [],
"storage_type": state.get('storage_type', ''),
"user_rag_memory_id": state.get('user_rag_memory_id', ''),
"_intermediate": {
"type": "verification",
"title": "Data Verification",
"result": "failed",
"reason": f"验证过程出错: {str(e)}",
"query": state.get('data', ''),
"verified_count": 0,
"storage_type": state.get('storage_type', ''),
"user_rag_memory_id": state.get('user_rag_memory_id', '')
}
}
}

View File

@@ -1,67 +0,0 @@
from app.cache.memory.interest_memory import InterestMemoryCache
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.utils.write_tools import write
from app.core.logging_config import get_agent_logger
logger = get_agent_logger(__name__)
async def write_node(state: WriteState) -> WriteState:
"""
Write data to the database/file system.
Args:
state: WriteState containing messages, end_user_id, memory_config, and language
Returns:
dict: Contains 'write_result' with status and data fields
"""
messages = state.get('messages', [])
end_user_id = state.get('end_user_id', '')
memory_config = state.get('memory_config', '')
language = state.get('language', 'zh') # 默认中文
# Convert LangChain messages to structured format expected by write()
structured_messages = []
for msg in messages:
if hasattr(msg, 'type') and hasattr(msg, 'content'):
# Map LangChain message types to role names
role = 'user' if msg.type == 'human' else 'assistant' if msg.type == 'ai' else msg.type
structured_messages.append({
"role": role,
"content": msg.content # content is now guaranteed to be a string
})
try:
result = await write(
messages=structured_messages,
end_user_id=end_user_id,
memory_config=memory_config,
language=language,
)
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
# 写入 neo4j 成功后,删除该用户的兴趣分布缓存,确保下次请求重新生成
for lang in ["zh", "en"]:
deleted = await InterestMemoryCache.delete_interest_distribution(
end_user_id=end_user_id,
language=lang,
)
if deleted:
logger.info(f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}")
write_result = {
"status": "success",
"data": structured_messages,
"config_id": memory_config.config_id,
"config_name": memory_config.config_name,
}
return {"write_result": write_result}
except Exception as e:
logger.error(f"Data_write failed: {e}", exc_info=True)
write_result = {
"status": "error",
"message": str(e),
}
return {"write_result": write_result}

View File

@@ -1,86 +1,469 @@
#!/usr/bin/env python3
import json
import os
import re
import time
import warnings
from contextlib import asynccontextmanager
from typing import Literal
from langchain_core.messages import HumanMessage
from langgraph.constants import START, END
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.nodes import (
ToolExecutionNode,
create_input_message,
)
from app.core.memory.agent.mcp_server.services.parameter_builder import ParameterBuilder
from app.core.memory.agent.utils.llm_tools import COUNTState, ReadState
from app.core.memory.agent.utils.multimodal import MultimodalProcessor
from app.schemas.memory_config_schema import MemoryConfig
from dotenv import load_dotenv
from langchain_core.messages import AIMessage
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
from app.db import get_db
from app.services.memory_config_service import MemoryConfigService
logger = get_agent_logger(__name__)
from app.core.memory.agent.utils.llm_tools import ReadState
from app.core.memory.agent.langgraph_graph.nodes.data_nodes import content_input_node
from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import (
Split_The_Problem,
Problem_Extension,
)
from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import (
retrieve,
)
from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import (
Input_Summary,
Retrieve_Summary,
Summary_fails,
Summary,
)
from app.core.memory.agent.langgraph_graph.nodes.verification_nodes import Verify
from app.core.memory.agent.langgraph_graph.routing.routers import (
Split_continue,
Retrieve_continue,
Verify_continue,
)
warnings.filterwarnings("ignore", category=RuntimeWarning)
load_dotenv()
redishost=os.getenv("REDISHOST")
redisport=os.getenv('REDISPORT')
redisdb=os.getenv('REDISDB')
redispassword=os.getenv('REDISPASSWORD')
counter = COUNTState(limit=3)
# Update loop count in workflow
async def update_loop_count(state):
"""Update loop counter"""
current_count = state.get("loop_count", 0)
return {"loop_count": current_count + 1}
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
messages = state["messages"]
# Add boundary check
if not messages:
return END
counter.add(1) # Increment by 1
loop_count = counter.get_total()
logger.debug(f"[should_continue] Current loop count: {loop_count}")
last_message = messages[-1]
last_message_str = str(last_message).replace('\\', '')
status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str)
logger.debug(f"Status tools: {status_tools}")
if "success" in status_tools:
counter.reset()
return "Summary"
elif "failed" in status_tools:
if loop_count < 2: # Maximum loop count is 3
return "content_input"
else:
counter.reset()
return "Summary_fails"
else:
# Add default return value to avoid returning None
counter.reset()
return "Summary" # Default based on business requirements
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
"""
Determine routing based on search_switch value.
Args:
state: State dictionary containing search_switch
Returns:
Next node to execute
"""
# Direct dictionary access instead of regex parsing
search_switch = state.get("search_switch")
# Handle case where search_switch might be in messages
if search_switch is None and "messages" in state:
messages = state.get("messages", [])
if messages:
last_message = messages[-1]
# Try to extract from tool_calls args
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
for tool_call in last_message.tool_calls:
if isinstance(tool_call, dict) and "args" in tool_call:
search_switch = tool_call["args"].get("search_switch")
break
# Convert to string for comparison if needed
if search_switch is not None:
search_switch = str(search_switch)
if search_switch == '0':
return 'Verify'
elif search_switch == '1':
return 'Retrieve_Summary'
# Add default return value to avoid returning None
return 'Retrieve_Summary' # Default based on business logic
def Split_continue(state) -> Literal["Split_The_Problem", "Input_Summary"]:
"""
Determine routing based on search_switch value.
Args:
state: State dictionary containing search_switch
Returns:
Next node to execute
"""
logger.debug(f"Split_continue state: {state}")
# Direct dictionary access instead of regex parsing
search_switch = state.get("search_switch")
# Handle case where search_switch might be in messages
if search_switch is None and "messages" in state:
messages = state.get("messages", [])
if messages:
last_message = messages[-1]
# Try to extract from tool_calls args
if hasattr(last_message, "tool_calls") and last_message.tool_calls:
for tool_call in last_message.tool_calls:
if isinstance(tool_call, dict) and "args" in tool_call:
search_switch = tool_call["args"].get("search_switch")
break
# Convert to string for comparison if needed
if search_switch is not None:
search_switch = str(search_switch)
if search_switch == '2':
return 'Input_Summary'
return 'Split_The_Problem' # Default case
class ProblemExtensionNode:
def __init__(self, tool, id, namespace, search_switch, apply_id, group_id, storage_type="", user_rag_memory_id=""):
self.tool_node = ToolNode([tool])
self.id = id
self.tool_name = tool.name if hasattr(tool, 'name') else str(tool)
self.namespace = namespace
self.search_switch = search_switch
self.apply_id = apply_id
self.group_id = group_id
self.storage_type = storage_type
self.user_rag_memory_id = user_rag_memory_id
async def __call__(self, state):
messages = state["messages"]
last_message = messages[-1] if messages else ""
logger.debug(f"ProblemExtensionNode {self.id} - Current time: {time.time()} - Message: {last_message}")
if self.tool_name == 'Input_Summary':
tool_call = re.findall("'id': '(.*?)'", str(last_message))[0]
else:
tool_call = str(re.findall(r"tool_call_id=.*?'(.*?)'", str(last_message))[0]).replace('\\', '').split('_id')[1]
# Try to extract actual content payload from previous tool result
raw_msg = last_message.content if hasattr(last_message, 'content') else str(last_message)
extracted_payload = None
# Capture ToolMessage content field (supports single/double quotes), avoid greedy matching
m = re.search(r"content=(?:\"|\')(.*?)(?:\"|\'),\s*name=", raw_msg, flags=re.S)
if m:
extracted_payload = m.group(1)
else:
# Fallback: use raw string directly
extracted_payload = raw_msg
# Try to parse content as JSON first
try:
content = json.loads(extracted_payload)
except Exception:
# Try to extract JSON fragment from text and parse
parsed = None
candidates = re.findall(r"[\[{].*[\]}]", extracted_payload, flags=re.S)
for cand in candidates:
try:
parsed = json.loads(cand)
break
except Exception:
continue
# If still fails, use raw string as content
content = parsed if parsed is not None else extracted_payload
# Build correct parameters based on tool name
tool_args = {}
if self.tool_name == "Verify":
# Verify tool requires context and usermessages parameters
if isinstance(content, dict):
tool_args["context"] = content
else:
tool_args["context"] = {"content": content}
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Retrieve":
# Retrieve tool requires context and usermessages parameters
if isinstance(content, dict):
tool_args["context"] = content
else:
tool_args["context"] = {"content": content}
tool_args["usermessages"] = str(tool_call)
tool_args["search_switch"] = str(self.search_switch)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Summary":
# Summary tool requires string type context parameter
if isinstance(content, dict):
# Convert dict to JSON string
tool_args["context"] = json.dumps(content, ensure_ascii=False)
else:
tool_args["context"] = str(content)
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == "Summary_fails":
# Summary_fails tool requires string type context parameter
if isinstance(content, dict):
# Convert dict to JSON string
tool_args["context"] = json.dumps(content, ensure_ascii=False)
else:
tool_args["context"] = str(content)
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
elif self.tool_name == 'Input_Summary':
tool_args["context"] = str(last_message)
tool_args["usermessages"] = str(tool_call)
tool_args["search_switch"] = str(self.search_switch)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
tool_args["storage_type"] = getattr(self, 'storage_type', "")
tool_args["user_rag_memory_id"] = getattr(self, 'user_rag_memory_id', "")
elif self.tool_name == 'Retrieve_Summary':
# Retrieve_Summary expects dict directly, not JSON string
# content might be a JSON string, try to parse it
if isinstance(content, str):
try:
parsed_content = json.loads(content)
# Check if it has a "context" key
if isinstance(parsed_content, dict) and "context" in parsed_content:
tool_args["context"] = parsed_content["context"]
else:
tool_args["context"] = parsed_content
except json.JSONDecodeError:
# If parsing fails, wrap the string
tool_args["context"] = {"content": content}
elif isinstance(content, dict):
# Check if content has a "context" key that needs unwrapping
if "context" in content:
tool_args["context"] = content["context"]
else:
tool_args["context"] = content
else:
tool_args["context"] = {"content": str(content)}
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
else:
# Other tools use context parameter
if isinstance(content, dict):
tool_args["context"] = content
else:
tool_args["context"] = {"content": content}
tool_args["usermessages"] = str(tool_call)
tool_args["apply_id"] = str(self.apply_id)
tool_args["group_id"] = str(self.group_id)
tool_input = {
"messages": [
AIMessage(
content="",
tool_calls=[{
"name": self.tool_name,
"args": tool_args,
"id": self.id + f"{tool_call}",
}]
)
]
}
result = await self.tool_node.ainvoke(tool_input)
result_text = str(result)
return {"messages": [AIMessage(content=result_text)]}
@asynccontextmanager
async def make_read_graph():
async def make_read_graph(namespace, tools, search_switch, apply_id, group_id, memory_config: MemoryConfig, storage_type=None, user_rag_memory_id=None):
"""
Create and return a LangGraph workflow for memory reading operations
Create a read graph workflow for memory operations.
Builds a state graph workflow that handles memory retrieval, problem analysis,
verification, and summarization. The workflow includes nodes for content input,
problem splitting, retrieval, verification, and various summary operations.
Yields:
StateGraph: Compiled LangGraph workflow for memory reading
Raises:
Exception: If workflow creation fails
Args:
namespace: Namespace identifier
tools: MCP tools loaded from session
search_switch: Search mode switch ("0", "1", or "2")
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type (optional)
user_rag_memory_id: User RAG memory ID (optional)
"""
try:
# Build workflow graph
workflow = StateGraph(ReadState)
workflow.add_node("content_input", content_input_node)
workflow.add_node("Split_The_Problem", Split_The_Problem)
workflow.add_node("Problem_Extension", Problem_Extension)
workflow.add_node("Input_Summary", Input_Summary)
# workflow.add_node("Retrieve", retrieve_nodes)
workflow.add_node("Retrieve", retrieve)
workflow.add_node("Verify", Verify)
workflow.add_node("Retrieve_Summary", Retrieve_Summary)
workflow.add_node("Summary", Summary)
workflow.add_node("Summary_fails", Summary_fails)
memory = InMemorySaver()
tool = [i.name for i in tools]
logger.info(f"Initializing read graph with tools: {tool}")
logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})")
# Extract tool functions
Split_The_Problem_ = next((t for t in tools if t.name == "Split_The_Problem"), None)
Problem_Extension_ = next((t for t in tools if t.name == "Problem_Extension"), None)
Retrieve_ = next((t for t in tools if t.name == "Retrieve"), None)
Verify_ = next((t for t in tools if t.name == "Verify"), None)
Summary_ = next((t for t in tools if t.name == "Summary"), None)
Summary_fails_ = next((t for t in tools if t.name == "Summary_fails"), None)
Retrieve_Summary_ = next((t for t in tools if t.name == "Retrieve_Summary"), None)
Input_Summary_ = next((t for t in tools if t.name == "Input_Summary"), None)
# Instantiate services
parameter_builder = ParameterBuilder()
multimodal_processor = MultimodalProcessor()
# Create nodes using new modular components
Split_The_Problem_node = ToolNode([Split_The_Problem_])
Problem_Extension_node = ToolExecutionNode(
tool=Problem_Extension_,
node_id="Problem_Extension_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
# Add edges to define workflow flow
workflow.add_edge(START, "content_input")
workflow.add_conditional_edges("content_input", Split_continue)
workflow.add_edge("Input_Summary", END)
workflow.add_edge("Split_The_Problem", "Problem_Extension")
workflow.add_edge("Problem_Extension", "Retrieve")
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
workflow.add_edge("Retrieve_Summary", END)
workflow.add_conditional_edges("Verify", Verify_continue)
workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END)
Retrieve_node = ToolExecutionNode(
tool=Retrieve_,
node_id="Retrieve_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
'''-----'''
# workflow.add_edge("Retrieve", END)
Verify_node = ToolExecutionNode(
tool=Verify_,
node_id="Verify_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
Summary_node = ToolExecutionNode(
tool=Summary_,
node_id="Summary_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
# Compile workflow
graph = workflow.compile()
yield graph
Summary_fails_node = ToolExecutionNode(
tool=Summary_fails_,
node_id="Summary_fails_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
except Exception as e:
print(f"创建工作流失败: {e}")
raise
finally:
print("工作流创建完成")
Retrieve_Summary_node = ToolExecutionNode(
tool=Retrieve_Summary_,
node_id="Retrieve_Summary_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
Input_Summary_node = ToolExecutionNode(
tool=Input_Summary_,
node_id="Input_Summary_id",
namespace=namespace,
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
parameter_builder=parameter_builder,
storage_type=storage_type,
user_rag_memory_id=user_rag_memory_id,
memory_config=memory_config,
)
async def content_input_node(state):
state_search_switch = state.get("search_switch", search_switch)
tool_name = "Input_Summary" if state_search_switch == '2' else "Split_The_Problem"
session_prefix = "input_summary_call_id" if state_search_switch == '2' else "split_call_id"
return await create_input_message(
state=state,
tool_name=tool_name,
session_id=f"{session_prefix}_{namespace}",
search_switch=search_switch,
apply_id=apply_id,
group_id=group_id,
multimodal_processor=multimodal_processor,
memory_config=memory_config,
)
# Build workflow graph
workflow = StateGraph(ReadState)
workflow.add_node("content_input", content_input_node)
workflow.add_node("Split_The_Problem", Split_The_Problem_node)
workflow.add_node("Problem_Extension", Problem_Extension_node)
workflow.add_node("Retrieve", Retrieve_node)
workflow.add_node("Verify", Verify_node)
workflow.add_node("Summary", Summary_node)
workflow.add_node("Summary_fails", Summary_fails_node)
workflow.add_node("Retrieve_Summary", Retrieve_Summary_node)
workflow.add_node("Input_Summary", Input_Summary_node)
# Add edges using imported routers
workflow.add_edge(START, "content_input")
workflow.add_conditional_edges("content_input", Split_continue)
workflow.add_edge("Input_Summary", END)
workflow.add_edge("Split_The_Problem", "Problem_Extension")
workflow.add_edge("Problem_Extension", "Retrieve")
workflow.add_conditional_edges("Retrieve", Retrieve_continue)
workflow.add_edge("Retrieve_Summary", END)
workflow.add_conditional_edges("Verify", Verify_continue)
workflow.add_edge("Summary_fails", END)
workflow.add_edge("Summary", END)
graph = workflow.compile(checkpointer=memory)
yield graph

View File

@@ -0,0 +1,13 @@
"""LangGraph routing logic."""
from app.core.memory.agent.langgraph_graph.routing.routers import (
Verify_continue,
Retrieve_continue,
Split_continue,
)
__all__ = [
"Verify_continue",
"Retrieve_continue",
"Split_continue",
]

View File

@@ -1,64 +1,123 @@
"""
Routing functions for LangGraph conditional edges.
This module provides routing functions that determine the next node to execute
based on state values. All functions return Literal types for type safety.
"""
import logging
import re
from typing import Literal
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.state.extractors import extract_search_switch
from app.core.memory.agent.utils.llm_tools import ReadState, COUNTState
logger = get_agent_logger(__name__)
logger = logging.getLogger(__name__)
# Global counter for Verify routing
counter = COUNTState(limit=3)
def Split_continue(state: ReadState) -> Literal["Split_The_Problem", "Input_Summary"]:
"""
Determine routing based on search_switch value.
Args:
state: State dictionary containing search_switch
Returns:
Next node to execute
"""
logger.debug(f"Split_continue state: {state}")
search_switch = state.get('search_switch', '')
if search_switch is not None:
search_switch = str(search_switch)
if search_switch == '2':
return 'Input_Summary'
return 'Split_The_Problem' # 默认情况
def Retrieve_continue(state) -> Literal["Verify", "Retrieve_Summary"]:
"""
Determine routing based on search_switch value.
Args:
state: State dictionary containing search_switch
Returns:
Next node to execute
"""
search_switch = state.get('search_switch', '')
if search_switch is not None:
search_switch = str(search_switch)
if search_switch == '0':
return 'Verify'
elif search_switch == '1':
return 'Retrieve_Summary'
return 'Retrieve_Summary' # Default based on business logic
def Verify_continue(state: ReadState) -> Literal["Summary", "Summary_fails", "content_input"]:
status = state.get('verify', '')['status']
# loop_count = counter.get_total()
if "success" in status:
# counter.reset()
"""
Determine routing after Verify node based on verification result.
This function checks the verification result in the last message and routes to:
- Summary: if verification succeeded
- content_input: if verification failed and retry limit not reached
- Summary_fails: if verification failed and retry limit reached
Args:
state: LangGraph state containing messages
Returns:
Next node name as Literal type
"""
messages = state.get("messages", [])
# Boundary check
if not messages:
logger.warning("[Verify_continue] No messages in state, defaulting to Summary")
counter.reset()
return "Summary"
elif "failed" in status:
# if loop_count < 2: # Maximum loop count is 3
# return "content_input"
# else:
# counter.reset()
return "Summary_fails"
# Increment counter
counter.add(1)
loop_count = counter.get_total()
logger.debug(f"[Verify_continue] Current loop count: {loop_count}")
# Extract verification result from last message
last_message = messages[-1]
last_message_str = str(last_message).replace('\\', '')
status_tools = re.findall(r'"split_result": "(.*?)"', last_message_str)
logger.debug(f"[Verify_continue] Status tools: {status_tools}")
# Route based on verification result
if "success" in status_tools:
counter.reset()
return "Summary"
elif "failed" in status_tools:
if loop_count < 2: # Max retry count is 2
return "content_input"
else:
counter.reset()
return "Summary_fails"
else:
# Add default return value to avoid returning None
# counter.reset()
return "Summary" # Default based on business requirements
# Default to Summary if status is unclear
counter.reset()
return "Summary"
def Retrieve_continue(state: dict) -> Literal["Verify", "Retrieve_Summary"]:
"""
Determine routing after Retrieve node based on search_switch value.
This function routes based on the search_switch parameter:
- search_switch == '0': Route to Verify (verification needed)
- search_switch == '1': Route to Retrieve_Summary (direct summary)
Args:
state: LangGraph state dictionary
Returns:
Next node name as Literal type
"""
search_switch = extract_search_switch(state)
logger.debug(f"[Retrieve_continue] search_switch: {search_switch}")
if search_switch == '0':
return 'Verify'
elif search_switch == '1':
return 'Retrieve_Summary'
# Default to Retrieve_Summary
logger.debug("[Retrieve_continue] No valid search_switch, defaulting to Retrieve_Summary")
return 'Retrieve_Summary'
def Split_continue(state: dict) -> Literal["Split_The_Problem", "Input_Summary"]:
"""
Determine routing after content_input node based on search_switch value.
This function routes based on the search_switch parameter:
- search_switch == '2': Route to Input_Summary (direct input summary)
- Otherwise: Route to Split_The_Problem (problem decomposition)
Args:
state: LangGraph state dictionary
Returns:
Next node name as Literal type
"""
logger.debug(f"[Split_continue] state keys: {state.keys()}")
search_switch = extract_search_switch(state)
logger.debug(f"[Split_continue] search_switch: {search_switch}")
if search_switch == '2':
return 'Input_Summary'
# Default to Split_The_Problem
return 'Split_The_Problem'

View File

@@ -1,304 +0,0 @@
import json
import os
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.langgraph_graph.tools.write_tool import format_parsing, messages_parse
from app.core.memory.agent.models.write_aggregate_model import WriteAggregateModel
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.redis_tool import count_store
from app.core.memory.agent.utils.redis_tool import write_store
from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.memory_short_repository import LongTermMemoryRepository
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_konwledges_server import write_rag
from app.services.task_service import get_task_memory_write_result
from app.tasks import write_message_task
from app.utils.config_utils import resolve_config_id
logger = get_agent_logger(__name__)
template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id):
"""
Write messages to RAG storage system
Combines user and AI messages into a single string format and stores them
in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval.
Args:
end_user_id: User identifier for the conversation
user_message: User's input message content
ai_message: AI's response message content
user_rag_memory_id: RAG memory identifier for storage location
"""
# RAG mode: combine messages into string format (maintain original logic)
combined_message = f"user: {user_message}\nassistant: {ai_message}"
await write_rag(end_user_id, combined_message, user_rag_memory_id)
logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}')
async def write(
storage_type,
end_user_id,
user_message,
ai_message,
user_rag_memory_id,
actual_end_user_id,
actual_config_id,
long_term_messages=None
):
"""
Write memory with structured message support
Handles memory writing operations for different storage types (Neo4j/RAG).
Supports both individual message pairs and batch long-term message processing.
Args:
storage_type: Storage type identifier ("neo4j" or "rag")
end_user_id: Terminal user identifier
user_message: User message content
ai_message: AI response content
user_rag_memory_id: RAG memory identifier
actual_end_user_id: Actual user identifier for storage
actual_config_id: Configuration identifier
long_term_messages: Optional list of structured messages for batch processing
Logic explanation:
- RAG mode: Combines user_message and ai_message into string format, maintains original logic
- Neo4j mode: Uses structured message lists
1. If both user_message and ai_message are not empty: Creates paired messages [user, assistant]
2. If only user_message exists: Creates single user message [user] (for historical memory scenarios)
3. Each message is converted to independent Chunk, preserving speaker field
"""
if long_term_messages is None:
long_term_messages = []
with get_db_context() as db:
actual_config_id = resolve_config_id(actual_config_id, db)
# Neo4j mode: Use structured message lists
structured_messages = []
# Always add user message (if not empty)
if isinstance(user_message, str) and user_message.strip() != "":
structured_messages.append({"role": "user", "content": user_message})
# Only add assistant message when AI reply is not empty
if isinstance(ai_message, str) and ai_message.strip() != "":
structured_messages.append({"role": "assistant", "content": ai_message})
# If long_term_messages provided, use it to replace structured_messages
if long_term_messages and isinstance(long_term_messages, list):
structured_messages = long_term_messages
elif long_term_messages and isinstance(long_term_messages, str):
# If it's a JSON string, parse it first
try:
structured_messages = json.loads(long_term_messages)
except json.JSONDecodeError:
logger.error(f"Failed to parse long_term_messages as JSON: {long_term_messages}")
# If no messages, return directly
if not structured_messages:
logger.warning(f"No messages to write for user {actual_end_user_id}")
return
logger.info(
f"[WRITE] Submitting Celery task - user={actual_end_user_id}, messages={len(structured_messages)}, config={actual_config_id}")
write_id = write_message_task.delay(
actual_end_user_id, # end_user_id: User ID
structured_messages, # message: JSON string format message list
str(actual_config_id), # config_id: Configuration ID string
storage_type, # storage_type: "neo4j"
user_rag_memory_id or "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode)
)
logger.info(f"[WRITE] Celery task submitted - task_id={write_id}")
write_status = get_task_memory_write_result(str(write_id))
logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}')
async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope):
"""
Save long-term memory data to database
Handles the storage of long-term memory data based on different strategies
(chunk-based or aggregate-based) and manages the transition from short-term
to long-term memory storage.
Args:
long_term_messages: Long-term message data to be saved
actual_config_id: Configuration identifier for memory settings
end_user_id: User identifier for memory association
type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE)
scope: Scope/window size for memory processing
"""
with get_db_context() as db_session:
repo = LongTermMemoryRepository(db_session)
from app.core.memory.agent.utils.redis_tool import write_store
result = write_store.get_session_by_userid(end_user_id)
if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE:
data = await format_parsing(result, "dict")
chunk_data = data[:scope]
if len(chunk_data) == scope:
repo.upsert(end_user_id, chunk_data)
logger.info(f'---------写入短长期-----------')
else:
long_time_data = write_store.find_user_recent_sessions(end_user_id, 5)
long_messages = await messages_parse(long_time_data)
repo.upsert(end_user_id, long_messages)
logger.info(f'写入短长期:')
"""Window-based dialogue processing"""
async def window_dialogue(end_user_id, langchain_messages, memory_config, scope):
"""
Process dialogue based on window size and write to Neo4j
Manages conversation data based on a sliding window approach. When the window
reaches the specified scope size, it triggers long-term memory storage to Neo4j.
Args:
end_user_id: Terminal user identifier
memory_config: Memory configuration object containing settings
langchain_messages: Original message data list
scope: Window size determining when to trigger long-term storage
"""
scope = scope
is_end_user_id = count_store.get_sessions_count(end_user_id)
if is_end_user_id is not False:
is_end_user_id = count_store.get_sessions_count(end_user_id)[0]
redis_messages = count_store.get_sessions_count(end_user_id)[1]
if is_end_user_id and int(is_end_user_id) != int(scope):
is_end_user_id += 1
langchain_messages += redis_messages
count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages)
elif int(is_end_user_id) == int(scope):
logger.info('写入长期记忆NEO4J')
formatted_messages = (redis_messages)
# Get config_id (if memory_config is an object, extract config_id; otherwise use directly)
if hasattr(memory_config, 'config_id'):
config_id = memory_config.config_id
else:
config_id = memory_config
await write(
AgentMemory_Long_Term.STORAGE_NEO4J,
end_user_id,
"",
"",
None,
end_user_id,
config_id,
formatted_messages
)
count_store.update_sessions_count(end_user_id, 1, langchain_messages)
else:
count_store.save_sessions_count(end_user_id, 1, langchain_messages)
"""Time-based memory processing"""
async def memory_long_term_storage(end_user_id, memory_config, time):
"""
Process memory storage based on time intervals and write to Neo4j
Retrieves Redis data based on time intervals and writes it to Neo4j for
long-term storage. This function handles time-based memory consolidation.
Args:
end_user_id: Terminal user identifier
memory_config: Memory configuration object containing settings
time: Time interval for data retrieval
"""
long_time_data = write_store.find_user_recent_sessions(end_user_id, time)
format_messages = long_time_data
messages = []
memory_config = memory_config.config_id
for i in format_messages:
message = json.loads(i['Query'])
messages += message
if format_messages:
await write(AgentMemory_Long_Term.STORAGE_NEO4J, end_user_id, "", "", None, end_user_id,
memory_config, messages)
async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config) -> dict:
"""
Aggregation judgment function: determine if input sentence and historical messages describe the same event
Uses LLM-based analysis to determine whether new messages should be aggregated with existing
historical data or stored as separate events. This helps optimize memory storage and retrieval.
Args:
end_user_id: Terminal user identifier
ori_messages: Original message list, format like [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
memory_config: Memory configuration object containing LLM settings
Returns:
dict: Aggregation judgment result containing is_same_event flag and processed output
"""
history = None
try:
# 1. Get historical session data (using new method)
result = write_store.get_all_sessions_by_end_user_id(end_user_id)
history = await format_parsing(result)
if not result:
history = []
else:
history = await format_parsing(result)
json_schema = WriteAggregateModel.model_json_schema()
template_service = TemplateService(template_root)
system_prompt = await template_service.render_template(
template_name='write_aggregate_judgment.jinja2',
operation_name='aggregate_judgment',
history=history,
sentence=ori_messages,
json_schema=json_schema
)
with get_db_context() as db_session:
factory = MemoryClientFactory(db_session)
llm_client = factory.get_llm_client(memory_config.llm_model_id)
messages = [
{
"role": "user",
"content": system_prompt
}
]
structured = await llm_client.response_structured(
messages=messages,
response_model=WriteAggregateModel
)
output_value = structured.output
if isinstance(output_value, list):
output_value = [
{"role": msg.role, "content": msg.content}
for msg in output_value
]
result_dict = {
"is_same_event": structured.is_same_event,
"output": output_value
}
if not structured.is_same_event:
logger.info(result_dict)
await write("neo4j", end_user_id, "", "", None, end_user_id,
memory_config.config_id, output_value)
return result_dict
except Exception as e:
print(f"[aggregate_judgment] 发生错误: {e}")
import traceback
traceback.print_exc()
return {
"is_same_event": False,
"output": ori_messages,
"messages": ori_messages,
"history": history if 'history' in locals() else [],
"error": str(e)
}

View File

@@ -0,0 +1,13 @@
"""LangGraph state management utilities."""
from app.core.memory.agent.langgraph_graph.state.extractors import (
extract_search_switch,
extract_tool_call_id,
extract_content_payload,
)
__all__ = [
"extract_search_switch",
"extract_tool_call_id",
"extract_content_payload",
]

View File

@@ -0,0 +1,179 @@
"""
State extraction utilities for type-safe access to LangGraph state values.
This module provides utility functions for extracting values from LangGraph state
dictionaries with proper error handling and sensible defaults.
"""
import json
import logging
from typing import Any, Optional
logger = logging.getLogger(__name__)
def extract_search_switch(state: dict) -> Optional[str]:
"""
Extract search_switch from state or messages.
"""
search_switch = state.get("search_switch")
if search_switch is not None:
return str(search_switch)
# Try to extract from messages
messages = state.get("messages", [])
if not messages:
return None
# 从最新的消息开始查找
for message in reversed(messages):
# 尝试从 tool_calls 中提取
if hasattr(message, "tool_calls") and message.tool_calls:
for tool_call in message.tool_calls:
if isinstance(tool_call, dict):
# 从 tool_call 的 args 中提取
if "args" in tool_call and isinstance(tool_call["args"], dict):
search_switch = tool_call["args"].get("search_switch")
if search_switch is not None:
return str(search_switch)
# 直接从 tool_call 中提取
search_switch = tool_call.get("search_switch")
if search_switch is not None:
return str(search_switch)
# 尝试从 content 中提取(如果是 JSON 格式)
if hasattr(message, "content"):
try:
import json
if isinstance(message.content, str):
content_data = json.loads(message.content)
if isinstance(content_data, dict):
search_switch = content_data.get("search_switch")
if search_switch is not None:
return str(search_switch)
except (json.JSONDecodeError, ValueError):
pass
return None
def extract_tool_call_id(message: Any) -> str:
"""
Extract tool call ID from message using structured attributes.
This function extracts the tool call ID from a message object, handling both
direct attribute access and tool_calls list structures.
Args:
message: Message object (typically ToolMessage or AIMessage)
Returns:
Tool call ID as string
Raises:
ValueError: If tool call ID cannot be extracted
Examples:
>>> message = ToolMessage(content="...", tool_call_id="call_123")
>>> extract_tool_call_id(message)
'call_123'
"""
# Try direct attribute access for ToolMessage
if hasattr(message, "tool_call_id"):
tool_call_id = message.tool_call_id
if tool_call_id:
return str(tool_call_id)
# Try extracting from tool_calls list for AIMessage
if hasattr(message, "tool_calls") and message.tool_calls:
tool_call = message.tool_calls[0]
if isinstance(tool_call, dict) and "id" in tool_call:
return str(tool_call["id"])
# Try extracting from id attribute
if hasattr(message, "id"):
message_id = message.id
if message_id:
return str(message_id)
# If all else fails, raise an error
raise ValueError(f"Could not extract tool call ID from message: {type(message)}")
def extract_content_payload(message: Any) -> Any:
"""
Extract content payload from ToolMessage, parsing JSON if needed.
This function extracts the content from a message and attempts to parse it as JSON
if it appears to be a JSON string. It handles various message formats and provides
sensible fallbacks.
Args:
message: Message object (typically ToolMessage)
Returns:
Parsed content (dict, list, or str)
Examples:
>>> message = ToolMessage(content='{"key": "value"}')
>>> extract_content_payload(message)
{'key': 'value'}
>>> message = ToolMessage(content='plain text')
>>> extract_content_payload(message)
'plain text'
"""
# Extract raw content
# For ToolMessages (responses from tools), extract from content
if hasattr(message, "content"):
raw_content = message.content
logger.info(f"extract_content_payload: raw_content type={type(raw_content)}, value={str(raw_content)[:500]}")
# Handle MCP content format: [{'type': 'text', 'text': '...'}]
if isinstance(raw_content, list):
for block in raw_content:
if isinstance(block, dict) and block.get('type') == 'text':
raw_content = block.get('text', '')
logger.info(f"extract_content_payload: extracted text from MCP format: {str(raw_content)[:300]}")
break
# If content is empty and this is an AIMessage with tool_calls,
# extract from args (this handles the initial tool call from content_input)
if not raw_content and hasattr(message, "tool_calls") and message.tool_calls:
tool_call = message.tool_calls[0]
if isinstance(tool_call, dict) and "args" in tool_call:
return tool_call["args"]
else:
raw_content = str(message)
# If content is already a dict or list, return it directly
if isinstance(raw_content, (dict, list)):
logger.info(f"extract_content_payload: returning raw dict/list with keys={list(raw_content.keys()) if isinstance(raw_content, dict) else 'list'}")
return raw_content
# Try to parse as JSON
if isinstance(raw_content, str):
# First, try direct JSON parsing
try:
parsed = json.loads(raw_content)
logger.info(f"extract_content_payload: parsed JSON, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}")
return parsed
except (json.JSONDecodeError, ValueError):
pass
# If that fails, try to extract JSON from the string
# This handles cases where the content is embedded in a larger string
import re
json_candidates = re.findall(r'[\[{].*[\]}]', raw_content, flags=re.DOTALL)
for candidate in json_candidates:
try:
parsed = json.loads(candidate)
logger.info(f"extract_content_payload: parsed JSON from candidate, keys={list(parsed.keys()) if isinstance(parsed, dict) else 'list'}")
return parsed
except (json.JSONDecodeError, ValueError):
continue
# If all parsing attempts fail, return the raw content
logger.info(f"extract_content_payload: returning raw content (parsing failed)")
return raw_content

View File

@@ -1,405 +0,0 @@
import asyncio
import json
from datetime import datetime, timedelta
from langchain.tools import tool
from pydantic import BaseModel, Field
from app.core.memory.src.search import (
search_by_temporal,
search_by_keyword_temporal,
)
def extract_tool_message_content(response):
"""
Extract ToolMessage content and tool names from agent response
Parses agent response messages to extract tool execution results and metadata.
Handles JSON parsing and provides structured access to tool output data.
Args:
response: Agent response dictionary containing messages
Returns:
dict: Dictionary containing tool_name and parsed content, or None if no tool message found
- tool_name: Name of the executed tool
- content: Parsed tool execution result (JSON or raw text)
"""
messages = response.get('messages', [])
for message in messages:
if hasattr(message, 'tool_call_id') and hasattr(message, 'content'):
# This is a ToolMessage
tool_content = message.content
tool_name = None
# Try to get tool name
if hasattr(message, 'name'):
tool_name = message.name
elif hasattr(message, 'tool_name'):
tool_name = message.tool_name
try:
# Parse JSON content
parsed_content = json.loads(tool_content)
return {
'tool_name': tool_name,
'content': parsed_content
}
except json.JSONDecodeError:
# If not JSON format, return content directly
return {
'tool_name': tool_name,
'content': tool_content
}
return None
class TimeRetrievalInput(BaseModel):
"""
Input schema for time retrieval tool
Defines the expected input parameters for time-based retrieval operations.
Used for validation and documentation of tool parameters.
Attributes:
context: User input query content for search
end_user_id: Group ID for filtering search results, defaults to test user
"""
context: str = Field(description="用户输入的查询内容")
end_user_id: str = Field(default="88a459f5_text09", description="组ID用于过滤搜索结果")
def create_time_retrieval_tool(end_user_id: str):
"""
Create a TimeRetrieval tool with specific end_user_id (synchronous version) for searching statements by time range
Creates a specialized time-based retrieval tool that searches for statements within
specified time ranges. Includes field cleaning functionality to remove unnecessary
metadata from search results.
Args:
end_user_id: User identifier for scoping search results
Returns:
function: Configured TimeRetrievalWithGroupId tool function
"""
def clean_temporal_result_fields(data):
"""
Clean unnecessary fields from temporal search results and modify structure
Removes metadata fields that are not needed for end-user consumption and
restructures the response format for better usability.
Args:
data: Data to be cleaned (dict, list, or other types)
Returns:
Cleaned data with unnecessary fields removed
"""
# List of fields to filter out
fields_to_remove = {
'id', 'apply_id', 'user_id', 'chunk_id', 'created_at',
'valid_at', 'invalid_at', 'statement_ids'
}
if isinstance(data, dict):
cleaned = {}
for key, value in data.items():
if key == 'statements' and isinstance(value, dict) and 'statements' in value:
# Change statements: {"statements": [...]} to time_search: {"statements": [...]}
cleaned_value = clean_temporal_result_fields(value)
# Further change internal statements to time_search
if 'statements' in cleaned_value:
cleaned['results'] = {
'time_search': cleaned_value['statements']
}
else:
cleaned['results'] = cleaned_value
elif key not in fields_to_remove:
cleaned[key] = clean_temporal_result_fields(value)
return cleaned
elif isinstance(data, list):
return [clean_temporal_result_fields(item) for item in data]
else:
return data
@tool
def TimeRetrievalWithGroupId(context: str, start_date: str = None, end_date: str = None,
end_user_id_param: str = None, clean_output: bool = True) -> str:
"""
Optimized time retrieval tool, combines time range search only (synchronous version), automatically filters unnecessary metadata fields
Performs time-based search operations with automatic metadata filtering. Supports
flexible date range specification and provides clean, user-friendly output.
Explicit parameters:
- context: Query context content
- start_date: Start time (optional, format: YYYY-MM-DD)
- end_date: End time (optional, format: YYYY-MM-DD)
- end_user_id_param: Group ID (optional, overrides default group ID)
- clean_output: Whether to clean metadata fields from output
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
Returns:
str: JSON formatted search results with temporal data
"""
async def _async_search():
# Use passed parameters or default values
actual_end_user_id = end_user_id_param or end_user_id
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
actual_start_date = start_date or (datetime.now() - timedelta(days=7)).strftime("%Y-%m-%d")
# Basic time search
results = await search_by_temporal(
end_user_id=actual_end_user_id,
start_date=actual_start_date,
end_date=actual_end_date,
limit=10
)
# Clean unnecessary fields from results
if clean_output:
cleaned_results = clean_temporal_result_fields(results)
else:
cleaned_results = results
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
return asyncio.run(_async_search())
@tool
def KeywordTimeRetrieval(context: str, days_back: int = 7, start_date: str = None, end_date: str = None,
clean_output: bool = True) -> str:
"""
Optimized keyword time retrieval tool, combines keyword and time range search (synchronous version), automatically filters unnecessary metadata fields
Performs combined keyword and temporal search operations with automatic metadata
filtering. Provides more targeted search results by combining content relevance
with time-based filtering.
Explicit parameters:
- context: Query content for keyword matching
- days_back: Number of days to search backwards, default 7 days
- start_date: Start time (optional, format: YYYY-MM-DD)
- end_date: End time (optional, format: YYYY-MM-DD)
- clean_output: Whether to clean metadata fields from output
- end_date needs to be obtained based on user description, output format uses strftime("%Y-%m-%d")
Returns:
str: JSON formatted search results combining keyword and temporal data
"""
async def _async_search():
actual_end_date = end_date or datetime.now().strftime("%Y-%m-%d")
actual_start_date = start_date or (datetime.now() - timedelta(days=days_back)).strftime("%Y-%m-%d")
# Keyword time search
results = await search_by_keyword_temporal(
query_text=context,
end_user_id=end_user_id,
start_date=actual_start_date,
end_date=actual_end_date,
limit=15
)
# Clean unnecessary fields from results
if clean_output:
cleaned_results = clean_temporal_result_fields(results)
else:
cleaned_results = results
return json.dumps(cleaned_results, ensure_ascii=False, indent=2)
return asyncio.run(_async_search())
return TimeRetrievalWithGroupId
def create_hybrid_retrieval_tool_async(memory_config, **search_params):
"""
Create hybrid retrieval tool using run_hybrid_search for hybrid retrieval, optimize output format and filter unnecessary fields
Creates an advanced hybrid search tool that combines multiple search strategies
(keyword, vector, hybrid) with automatic result cleaning and formatting.
Args:
memory_config: Memory configuration object containing LLM and search settings
**search_params: Search parameters including end_user_id, limit, include, etc.
Returns:
function: Configured HybridSearch tool function with async capabilities
"""
def clean_result_fields(data):
"""
Recursively clean unnecessary fields from results
Removes metadata fields that are not needed for end-user consumption,
improving readability and reducing response size.
Args:
data: Data to be cleaned (can be dict, list, or other types)
Returns:
Cleaned data with unnecessary fields removed
"""
# List of fields to filter out
# TODO: fact_summary functionality temporarily disabled, will be enabled after future development
fields_to_remove = {
'invalid_at', 'valid_at', 'chunk_id_from_rel', 'entity_ids',
'expired_at', 'created_at', 'chunk_id', 'apply_id',
'user_id', 'statement_ids', 'updated_at', "chunk_ids", "fact_summary"
}
# 注意:'id' 字段保留community 展开时需要用 community id 查询成员 statements
if isinstance(data, dict):
# Clean dictionary
cleaned = {}
for key, value in data.items():
if key not in fields_to_remove:
cleaned[key] = clean_result_fields(value) # Recursively clean nested data
return cleaned
elif isinstance(data, list):
# Clean each element in list
return [clean_result_fields(item) for item in data]
else:
# Return other types directly
return data
@tool
async def HybridSearch(
context: str,
search_type: str = "hybrid",
limit: int = 10,
end_user_id: str = None,
rerank_alpha: float = 0.6,
use_forgetting_rerank: bool = False,
use_llm_rerank: bool = False,
clean_output: bool = True # New: whether to clean output fields
) -> str:
"""
Optimized hybrid retrieval tool, supports keyword, vector and hybrid search, automatically filters unnecessary metadata fields
Provides comprehensive search capabilities combining multiple search strategies
with intelligent result ranking and automatic metadata filtering for clean output.
Args:
context: Query content for search
search_type: Search type ('keyword', 'embedding', 'hybrid')
limit: Result quantity limit
end_user_id: Group ID for filtering search results
rerank_alpha: Reranking weight parameter for result scoring
use_forgetting_rerank: Whether to use forgetting-based reranking
use_llm_rerank: Whether to use LLM-based reranking
clean_output: Whether to clean metadata fields from output
Returns:
str: JSON formatted comprehensive search results
"""
try:
# Import run_hybrid_search function
from app.core.memory.src.search import run_hybrid_search
# Merge parameters, prioritize passed parameters
final_params = {
"query_text": context,
"search_type": search_type,
"end_user_id": end_user_id or search_params.get("end_user_id"),
"limit": limit or search_params.get("limit", 10),
"include": search_params.get("include", ["summaries", "statements", "chunks", "entities", "communities"]),
"output_path": None, # Don't save to file
"memory_config": memory_config,
"rerank_alpha": rerank_alpha,
"use_forgetting_rerank": use_forgetting_rerank,
"use_llm_rerank": use_llm_rerank
}
# Execute hybrid retrieval
raw_results = await run_hybrid_search(**final_params)
# Clean unnecessary fields from results
if clean_output:
cleaned_results = clean_result_fields(raw_results)
else:
cleaned_results = raw_results
# Format return results
formatted_results = {
"search_query": context,
"search_type": search_type,
"results": cleaned_results
}
return json.dumps(formatted_results, ensure_ascii=False, indent=2, default=str)
except Exception as e:
error_result = {
"error": f"混合检索失败: {str(e)}",
"search_query": context,
"search_type": search_type,
"timestamp": datetime.now().isoformat()
}
return json.dumps(error_result, ensure_ascii=False, indent=2)
return HybridSearch
def create_hybrid_retrieval_tool_sync(memory_config, **search_params):
"""
Create synchronous version of hybrid retrieval tool, optimize output format and filter unnecessary fields
Creates a synchronous wrapper around the async hybrid search functionality,
making it compatible with synchronous tool execution environments.
Args:
memory_config: Memory configuration object containing search settings
**search_params: Search parameters for configuration
Returns:
function: Configured HybridSearchSync tool function
"""
@tool
def HybridSearchSync(
context: str,
search_type: str = "hybrid",
limit: int = 10,
end_user_id: str = None,
clean_output: bool = True
) -> str:
"""
Optimized hybrid retrieval tool (synchronous version), automatically filters unnecessary metadata fields
Provides the same hybrid search capabilities as the async version but in a
synchronous execution context. Automatically handles async-to-sync conversion.
Args:
context: Query content for search
search_type: Search type ('keyword', 'embedding', 'hybrid')
limit: Result quantity limit
end_user_id: Group ID for filtering search results
clean_output: Whether to clean metadata fields from output
Returns:
str: JSON formatted search results
"""
async def _async_search():
# Create async tool and execute
async_tool = create_hybrid_retrieval_tool_async(memory_config, **search_params)
return await async_tool.ainvoke({
"context": context,
"search_type": search_type,
"limit": limit,
"end_user_id": end_user_id,
"clean_output": clean_output
})
return asyncio.run(_async_search())
return HybridSearchSync

View File

@@ -1,106 +0,0 @@
import json
from langchain_core.messages import HumanMessage, AIMessage
async def format_parsing(messages: list, type: str = 'string'):
"""
Format and parse message lists into different output types
Processes message lists from storage and converts them into either string format
or dictionary format based on the specified type parameter. Handles JSON parsing
and role-based message organization.
Args:
messages: List of message objects from storage containing message data
type: Return type specification ('string' for text format, 'dict' for key-value pairs)
Returns:
list: Formatted message list in the specified format
- 'string': List of formatted text messages with role prefixes
- 'dict': List of dictionaries mapping user messages to AI responses
"""
result = []
user = []
ai = []
for message in messages:
hstory_messages = message['messages']
for history_messag in hstory_messages.strip().splitlines():
history_messag = json.loads(history_messag)
for content in history_messag:
role = content['role']
content = content['content']
if type == "string":
if role == 'human' or role == "user":
content = '用户:' + content
else:
content = 'AI:' + content
result.append(content)
if type == "dict":
if role == 'human' or role == "user":
user.append(content)
else:
ai.append(content)
if type == "dict":
for key, values in zip(user, ai):
result.append({key: values})
return result
async def messages_parse(messages: list | dict):
"""
Parse messages from storage format into user-AI conversation pairs
Extracts and organizes conversation data from stored message format,
separating user and AI messages and pairing them for database storage.
Args:
messages: List or dictionary containing stored message data with Query fields
Returns:
list: List of dictionaries containing user-AI message pairs for database storage
"""
user = []
ai = []
database = []
for message in messages:
Query = message['Query']
Query = json.loads(Query)
for data in Query:
role = data['role']
if role == "human":
user.append(data['content'])
if role == "ai":
ai.append(data['content'])
for key, values in zip(user, ai):
database.append({key, values})
return database
async def agent_chat_messages(user_content, ai_content):
"""
Create structured chat message format for agent conversations
Formats user and AI content into a standardized message structure suitable
for agent processing and storage. Creates role-based message objects.
Args:
user_content: User's message content string
ai_content: AI's response content string
Returns:
list: List of structured message dictionaries with role and content fields
"""
messages = [
{
"role": "user",
"content": f"{user_content}"
},
{
"role": "assistant",
"content": f"{ai_content}"
}
]
return messages

View File

@@ -3,17 +3,17 @@ import json
import sys
import warnings
from contextlib import asynccontextmanager
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from app.db import get_db, get_db_context
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.utils.llm_tools import WriteState
from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node
from app.schemas.memory_agent_schema import AgentMemory_Long_Term
from app.services.memory_config_service import MemoryConfigService
from app.schemas.memory_config_schema import MemoryConfig
from langchain_core.messages import AIMessage
from langgraph.constants import END, START
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode
warnings.filterwarnings("ignore", category=RuntimeWarning)
logger = get_agent_logger(__name__)
if sys.platform.startswith("win"):
@@ -21,112 +21,60 @@ if sys.platform.startswith("win"):
@asynccontextmanager
async def make_write_graph():
async def make_write_graph(user_id, tools, apply_id, group_id, memory_config: MemoryConfig):
"""
Create a write graph workflow for memory operations.
Args:
user_id: User identifier
tools: MCP tools loaded from session
apply_id: Application identifier
end_user_id: Group identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
"""
logger.info("Loading MCP tools: %s", [t.name for t in tools])
logger.info(f"Using memory_config: {memory_config.config_name} (id={memory_config.config_id})")
data_write_tool = next((t for t in tools if t.name == "Data_write"), None)
if not data_write_tool:
logger.error("Data_write tool not found", exc_info=True)
raise ValueError("Data_write tool not found")
write_node = ToolNode([data_write_tool])
async def call_model(state):
messages = state["messages"]
last_message = messages[-1]
content = last_message[1] if isinstance(last_message, tuple) else last_message.content
# Call Data_write directly with memory_config
write_params = {
"content": content,
"apply_id": apply_id,
"group_id": group_id,
"user_id": user_id,
"memory_config": memory_config,
}
logger.debug(f"Passing memory_config to Data_write: {memory_config.config_id}")
write_result = await data_write_tool.ainvoke(write_params)
if isinstance(write_result, dict):
result_content = write_result.get("data", str(write_result))
else:
result_content = str(write_result)
logger.info("Write content: %s", result_content)
return {"messages": [AIMessage(content=result_content)]}
workflow = StateGraph(WriteState)
workflow.add_node("content_input", call_model)
workflow.add_node("save_neo4j", write_node)
workflow.add_edge(START, "save_neo4j")
workflow.add_edge(START, "content_input")
workflow.add_edge("content_input", "save_neo4j")
workflow.add_edge("save_neo4j", END)
graph = workflow.compile()
yield graph
async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '',
end_user_id: str = '', scope: int = 6):
"""
Handle long-term memory storage with different strategies
Supports multiple storage strategies including chunk-based, time-based,
and aggregate judgment approaches for long-term memory persistence.
Args:
long_term_type: Storage strategy type ('chunk', 'time', 'aggregate')
langchain_messages: List of messages to store
memory_config: Memory configuration identifier
end_user_id: User group identifier
scope: Scope parameter for chunk-based storage (default: 6)
"""
from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \
aggregate_judgment
from app.core.memory.agent.utils.redis_tool import write_store
write_store.save_session_write(end_user_id, langchain_messages)
# 获取数据库会话
with get_db_context() as db_session:
config_service = MemoryConfigService(db_session)
memory_config = config_service.load_memory_config(
config_id=memory_config, # 改为整数
service_name="MemoryAgentService"
)
if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK:
'''Strategy 1: Dialogue window with 6 rounds of conversation'''
await window_dialogue(end_user_id, langchain_messages, memory_config, scope)
if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME:
"""Time-based strategy"""
await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE)
if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE:
"""Strategy 3: Aggregate judgment"""
await aggregate_judgment(end_user_id, langchain_messages, memory_config)
async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id):
"""
Write long-term memory with different storage types
Handles both RAG-based storage and traditional memory storage approaches.
For traditional storage, uses chunk-based strategy with paired user-AI messages.
Args:
storage_type: Type of storage (RAG or traditional)
end_user_id: User group identifier
message_chat: User message content
aimessages: AI response messages
user_rag_memory_id: RAG memory identifier
actual_config_id: Actual configuration ID
"""
from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent
from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save
from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages
if storage_type == AgentMemory_Long_Term.STORAGE_RAG:
await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id)
else:
# AI reply writing (user messages and AI replies paired, written as complete dialogue at once)
CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK
SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE
long_term_messages = await agent_chat_messages(message_chat, aimessages)
await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages,
memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE)
await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE)
# async def main():
# """主函数 - 运行工作流"""
# langchain_messages = [
# {
# "role": "user",
# "content": "今天周五去爬山"
# },
# {
# "role": "assistant",
# "content": "好耶"
# }
#
# ]
# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID
# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4"
# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2)
#
#
#
# if __name__ == "__main__":
# import asyncio
# asyncio.run(main())

View File

@@ -0,0 +1,28 @@
"""
MCP Server package for memory agent.
This package provides the FastMCP server implementation with context-based
dependency injection for tool functions.
Package structure:
- server: FastMCP server initialization and context setup
- tools: MCP tool implementations
- models: Pydantic response models
- services: Business logic services
"""
# from app.core.memory.agent.mcp_server.server import (
# mcp,
# initialize_context,
# main,
# get_context_resource
# )
# # Import tools to register them (but don't export them)
# from app.core.memory.agent.mcp_server import tools
# __all__ = [
# 'mcp',
# 'initialize_context',
# 'main',
# 'get_context_resource',
# ]

View File

@@ -0,0 +1,11 @@
"""
MCP Server Instance
This module contains the FastMCP server instance that is shared across all modules.
It's in a separate file to avoid circular import issues.
"""
from mcp.server.fastmcp import FastMCP
# Initialize FastMCP server instance
# This instance is shared across all tool modules
mcp = FastMCP('data_flow')

View File

@@ -0,0 +1,14 @@
"""Pydantic models for verification operations."""
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
class VerificationResult(BaseModel):
"""Result model for verification operation."""
query: str
expansion_issue: List[Dict[str, Any]]
split_result: str
reason: Optional[str] = None
history: List[Dict[str, Any]] = Field(default_factory=list)

View File

@@ -0,0 +1,159 @@
"""
MCP Server initialization with FastMCP context setup.
This module initializes the FastMCP server and registers shared resources
in the context for dependency injection into tool functions.
"""
import os
import sys
from app.core.config import settings
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.services.search_service import SearchService
from app.core.memory.agent.mcp_server.services.session_service import SessionService
from app.core.memory.agent.mcp_server.services.template_service import TemplateService
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.redis_tool import store
logger = get_agent_logger(__name__)
def get_context_resource(ctx, resource_name: str):
"""
Helper function to retrieve a resource from the FastMCP context.
Args:
ctx: FastMCP Context object (passed to tool functions)
resource_name: Name of the resource to retrieve
Returns:
The requested resource
Raises:
AttributeError: If the resource doesn't exist
Example:
@mcp.tool()
async def my_tool(ctx: Context):
template_service = get_context_resource(ctx, 'template_service')
llm_client = get_context_resource(ctx, 'llm_client')
"""
if not hasattr(ctx, 'fastmcp') or ctx.fastmcp is None:
raise RuntimeError("Context does not have fastmcp attribute")
if not hasattr(ctx.fastmcp, resource_name):
raise AttributeError(
f"Resource '{resource_name}' not found in context. "
f"Available resources: {[k for k in dir(ctx.fastmcp) if not k.startswith('_')]}"
)
return getattr(ctx.fastmcp, resource_name)
def initialize_context():
"""
Initialize and register shared resources in FastMCP context.
This function sets up all shared resources that will be available
to tool functions via dependency injection through the context parameter.
Resources are stored as attributes on the FastMCP instance and can be
accessed via ctx.fastmcp in tool functions.
Resources registered:
- session_store: RedisSessionStore for session management
- llm_client: LLM client for structured API calls
- app_settings: Application settings (renamed to avoid conflict with FastMCP settings)
- template_service: Service for template rendering
- search_service: Service for hybrid search
- session_service: Service for session operations
"""
try:
# Register Redis session store
logger.info("Registering session_store in context")
mcp.session_store = store
# Note: LLM client is NOT loaded at server startup
# It should be loaded dynamically when needed, with config_id passed explicitly
# to make_write_graph or make_read_graph functions
logger.info("LLM client will be loaded dynamically with config_id when needed")
mcp.llm_client = None # Placeholder - actual client loaded per-request with config_id
# Register application settings (renamed to avoid conflict with FastMCP's settings)
logger.info("Registering app_settings in context")
mcp.app_settings = settings
# Register template service
template_root = PROJECT_ROOT_ + '/agent/utils/prompt'
# logger.info(f"Registering template_service in context with root: {template_root}")
template_service = TemplateService(template_root)
mcp.template_service = template_service
# Register search service
# logger.info("Registering search_service in context")
search_service = SearchService()
mcp.search_service = search_service
# Register session service
# logger.info("Registering session_service in context")
session_service = SessionService(store)
mcp.session_service = session_service
# logger.info("All context resources registered successfully")
except Exception as e:
logger.error(f"Failed to initialize context: {e}", exc_info=True)
raise
def main():
"""
Main entry point for the MCP server.
Initializes context and starts the server with SSE transport.
"""
try:
logger.info("Starting MCP server initialization")
# Initialize context resources
initialize_context()
# Import and register tools (imports trigger tool registration)
from app.core.memory.agent.mcp_server.tools import ( # noqa: F401
data_tools,
problem_tools,
retrieval_tools,
summary_tools,
verification_tools,
)
# Tools are registered via imports above
# Get MCP port from environment (default: 8081)
mcp_port = int(os.getenv("MCP_PORT", "8081"))
logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport")
# Configure DNS rebinding protection for Docker container compatibility
from mcp.server.fastmcp.server import TransportSecuritySettings
# Disable DNS rebinding protection to allow Docker container hostnames
# This allows containers to connect using service names like 'mcp-server'
mcp.settings.transport_security = TransportSecuritySettings(
enable_dns_rebinding_protection=False,
)
logger.info("DNS rebinding protection: disabled for Docker container compatibility")
# logger.info(f"Starting MCP server on {settings.SERVER_IP}:{mcp_port} with SSE transport")
# Run the server with SSE transport for HTTP connections
import uvicorn
app = mcp.sse_app()
uvicorn.run(app, host=settings.SERVER_IP, port=mcp_port, log_level="info")
except Exception as e:
logger.error(f"Failed to start MCP server: {e}", exc_info=True)
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -4,19 +4,22 @@ Parameter Builder for constructing tool call arguments.
This service provides tool-specific parameter transformation logic
to build correct arguments for each tool type.
"""
from typing import Any, Dict, Optional
from app.core.logging_config import get_agent_logger
from app.schemas.memory_config_schema import MemoryConfig
logger = get_agent_logger(__name__)
class ParameterBuilder:
"""Service for building tool call arguments based on tool type."""
def __init__(self):
"""Initialize the parameter builder."""
logger.info("ParameterBuilder initialized")
def build_tool_args(
self,
tool_name: str,
@@ -24,9 +27,10 @@ class ParameterBuilder:
tool_call_id: str,
search_switch: str,
apply_id: str,
end_user_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: Optional[str] = None,
user_rag_memory_id: Optional[str] = None
user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]:
"""
Build tool arguments based on tool type.
@@ -44,7 +48,8 @@ class ParameterBuilder:
tool_call_id: Extracted tool call identifier
search_switch: Search routing parameter
apply_id: Application identifier
end_user_id: Group identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory ID for knowledge base retrieval (optional)
@@ -55,18 +60,19 @@ class ParameterBuilder:
base_args = {
"usermessages": tool_call_id,
"apply_id": apply_id,
"end_user_id": end_user_id
"group_id": group_id,
"memory_config": memory_config,
}
# Always add storage_type and user_rag_memory_id (with defaults if None)
base_args["storage_type"] = storage_type if storage_type is not None else ""
base_args["user_rag_memory_id"] = user_rag_memory_id if user_rag_memory_id is not None else ""
# Tool-specific argument construction
if tool_name in ["Verify","Summary", "Summary_fails",'Retrieve_Summary']:
# Verify expects dict context
if tool_name in ["Verify", "Summary", "Summary_fails", "Retrieve_Summary", "Problem_Extension"]:
# These tools expect dict context
return {
"context": content if isinstance(content, dict) else {},
"context": content if isinstance(content, dict) else {"content": content},
**base_args
}

View File

@@ -0,0 +1,216 @@
"""
Search Service for executing hybrid search and processing results.
This service provides clean search result processing with content extraction
and deduplication.
"""
from typing import TYPE_CHECKING, List, Optional, Tuple
from app.core.logging_config import get_agent_logger
from app.core.memory.src.search import run_hybrid_search
from app.core.memory.utils.data.text_utils import escape_lucene_query
if TYPE_CHECKING:
from app.schemas.memory_config_schema import MemoryConfig
logger = get_agent_logger(__name__)
class SearchService:
"""Service for executing hybrid search and processing results."""
def __init__(self, memory_config: "MemoryConfig" = None):
"""
Initialize the search service.
Args:
memory_config: Optional MemoryConfig for embedding model configuration.
If not provided, must be passed to execute_hybrid_search.
"""
self.memory_config = memory_config
logger.info("SearchService initialized")
def extract_content_from_result(self, result: dict) -> str:
"""
Extract only meaningful content from search results, dropping all metadata.
Extraction rules by node type:
- Statements: extract 'statement' field
- Entities: extract 'name' and 'fact_summary' fields
- Summaries: extract 'content' field
- Chunks: extract 'content' field
Args:
result: Search result dictionary
Returns:
Clean content string without metadata
"""
if not isinstance(result, dict):
return str(result)
content_parts = []
# Statements: extract statement field
if 'statement' in result and result['statement']:
content_parts.append(result['statement'])
# Summaries/Chunks: extract content field
if 'content' in result and result['content']:
content_parts.append(result['content'])
# Entities: extract name and fact_summary (commented out in original)
# if 'name' in result and result['name']:
# content_parts.append(result['name'])
# if result.get('fact_summary'):
# content_parts.append(result['fact_summary'])
# Return concatenated content or empty string
return '\n'.join(content_parts) if content_parts else ""
def clean_query(self, query: str) -> str:
"""
Clean and escape query text for Lucene.
- Removes wrapping quotes
- Removes newlines and carriage returns
- Applies Lucene escaping
Args:
query: Raw query string
Returns:
Cleaned and escaped query string
"""
q = str(query).strip()
# Remove wrapping quotes
if (q.startswith("'") and q.endswith("'")) or (
q.startswith('"') and q.endswith('"')
):
q = q[1:-1]
# Remove newlines and carriage returns
q = q.replace('\r', ' ').replace('\n', ' ').strip()
# Apply Lucene escaping
q = escape_lucene_query(q)
return q
async def execute_hybrid_search(
self,
group_id: str,
question: str,
limit: int = 15,
search_type: str = "hybrid",
include: Optional[List[str]] = None,
rerank_alpha: float = 0.6,
activation_boost_factor: float = 0.8,
output_path: str = "search_results.json",
return_raw_results: bool = False,
memory_config: "MemoryConfig" = None,
) -> Tuple[str, str, Optional[dict]]:
"""
Execute hybrid search with two-stage ranking.
Stage 1: Filter by content relevance (BM25 + Embedding)
Stage 2: Rerank by activation values (ACTR)
Args:
group_id: Group identifier for filtering
question: Search query text
limit: Max results per category (default: 15)
search_type: "hybrid", "keyword", or "embedding" (default: "hybrid")
include: Result types (default: ["statements", "chunks", "entities", "summaries"])
rerank_alpha: BM25 weight (default: 0.6)
activation_boost_factor: Activation impact on memory strength (default: 0.8)
output_path: JSON output path (default: "search_results.json")
return_raw_results: Return full metadata (default: False)
memory_config: MemoryConfig for embedding model
Returns:
Tuple[str, str, Optional[dict]]: (clean_content, cleaned_query, raw_results)
"""
if include is None:
include = ["statements", "chunks", "entities", "summaries"]
# Use provided memory_config or fall back to instance config
config = memory_config or self.memory_config
if not config:
raise ValueError("memory_config is required for search - either pass it to __init__ or execute_hybrid_search")
# Clean query
cleaned_query = self.clean_query(question)
try:
# Execute search using memory_config
answer = await run_hybrid_search(
query_text=cleaned_query,
search_type=search_type,
group_id=group_id,
limit=limit,
include=include,
output_path=output_path,
memory_config=config,
rerank_alpha=rerank_alpha,
activation_boost_factor=activation_boost_factor,
)
# Extract results based on search type and include parameter
# Prioritize summaries as they contain synthesized contextual information
answer_list = []
# For hybrid search, use reranked_results
if search_type == "hybrid":
reranked_results = answer.get('reranked_results', {})
# Priority order: summaries first (most contextual), then statements, chunks, entities
priority_order = ['summaries', 'statements', 'chunks', 'entities']
for category in priority_order:
if category in include and category in reranked_results:
category_results = reranked_results[category]
if isinstance(category_results, list):
answer_list.extend(category_results)
else:
# For keyword or embedding search, results are directly in answer dict
# Apply same priority order
priority_order = ['summaries', 'statements', 'chunks', 'entities']
for category in priority_order:
if category in include and category in answer:
category_results = answer[category]
if isinstance(category_results, list):
answer_list.extend(category_results)
# Extract clean content from all results
content_list = [
self.extract_content_from_result(ans)
for ans in answer_list
]
# Filter out empty strings and join with newlines
clean_content = '\n'.join([c for c in content_list if c])
# Log first 200 chars
logger.info(f"检索接口搜索结果==>>:{clean_content[:200]}...")
# Return raw results if requested
if return_raw_results:
return clean_content, cleaned_query, answer
else:
return clean_content, cleaned_query, None
except Exception as e:
logger.error(
f"Search failed for query '{question}' in group '{group_id}': {e}",
exc_info=True
)
# Return empty results on failure
if return_raw_results:
return "", cleaned_query, {}
else:
return "", cleaned_query, None

View File

@@ -59,7 +59,7 @@ class SessionService:
self,
user_id: str,
apply_id: str,
end_user_id: str
group_id: str
) -> List[dict]:
"""
Retrieve conversation history from Redis.
@@ -67,20 +67,20 @@ class SessionService:
Args:
user_id: User identifier
apply_id: Application identifier
end_user_id: Group identifier
group_id: Group identifier
Returns:
List of conversation history items with Query and Answer keys
Returns empty list if no history found or on error
"""
try:
history = self.store.find_user_apply_group(user_id, apply_id, end_user_id)
history = self.store.find_user_apply_group(user_id, apply_id, group_id)
# Validate history structure
if not isinstance(history, list):
logger.warning(
f"Invalid history format for user {user_id}, "
f"apply {apply_id}, group {end_user_id}: expected list, got {type(history)}"
f"apply {apply_id}, group {group_id}: expected list, got {type(history)}"
)
return []
@@ -89,7 +89,7 @@ class SessionService:
except Exception as e:
logger.error(
f"Failed to retrieve history for user {user_id}, "
f"apply {apply_id}, group {end_user_id}: {e}",
f"apply {apply_id}, group {group_id}: {e}",
exc_info=True
)
# Return empty list on error to allow execution to continue
@@ -100,7 +100,7 @@ class SessionService:
user_id: str,
query: str,
apply_id: str,
end_user_id: str,
group_id: str,
ai_response: str
) -> Optional[str]:
"""
@@ -110,7 +110,7 @@ class SessionService:
user_id: User identifier
query: User query/message
apply_id: Application identifier
end_user_id: Group identifier
group_id: Group identifier
ai_response: AI response/answer
Returns:
@@ -131,7 +131,7 @@ class SessionService:
userid=user_id,
messages=query,
apply_id=apply_id,
end_user_id=end_user_id,
group_id=group_id,
aimessages=ai_response
)
@@ -152,7 +152,7 @@ class SessionService:
Duplicates are identified by matching:
- sessionid
- user_id (id field)
- end_user_id
- group_id
- messages
- aimessages

View File

@@ -3,22 +3,12 @@ Template Service for loading and rendering Jinja2 templates.
This service provides centralized template management with caching and error handling.
"""
import os
from functools import lru_cache
from typing import Optional
from jinja2 import Environment, FileSystemLoader, Template, TemplateNotFound
from jinja2 import (
Environment,
FileSystemLoader,
Template,
TemplateNotFound,
)
from app.core.logging_config import (
get_agent_logger,
log_prompt_rendering,
)
from app.core.logging_config import get_agent_logger, log_prompt_rendering
logger = get_agent_logger(__name__)

View File

@@ -0,0 +1,27 @@
"""
MCP Tools module.
This module contains all MCP tool implementations organized by functionality.
Tools are organized into the following modules:
- problem_tools: Question segmentation and extension
- retrieval_tools: Database and context retrieval
- verification_tools: Data verification
- summary_tools: Summarization and summary retrieval
- data_tools: Data type differentiation and writing
"""
# Import all tool modules to register them with the MCP server
from . import problem_tools
from . import retrieval_tools
from . import verification_tools
from . import summary_tools
from . import data_tools
__all__ = [
'problem_tools',
'retrieval_tools',
'verification_tools',
'summary_tools',
'data_tools',
]

View File

@@ -0,0 +1,155 @@
"""
Data Tools for data type differentiation and writing.
This module contains MCP tools for distinguishing data types and writing data.
"""
import os
from app.core.logging_config import get_agent_logger
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.models.retrieval_models import (
DistinguishTypeResponse,
)
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.write_tools import write
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.schemas.memory_config_schema import MemoryConfig
from mcp.server.fastmcp import Context
logger = get_agent_logger(__name__)
@mcp.tool()
async def Data_type_differentiation(
ctx: Context,
context: str,
memory_config: MemoryConfig,
) -> dict:
"""
Distinguish the type of data (read or write).
Args:
ctx: FastMCP context for dependency injection
context: Text to analyze for type differentiation
memory_config: MemoryConfig object containing LLM configuration
Returns:
dict: Contains 'context' with the original text and 'type' field
"""
try:
# Extract services from context
template_service = get_context_resource(ctx, 'template_service')
# Get LLM client from memory_config using factory pattern
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Render template
try:
system_prompt = await template_service.render_template(
template_name='distinguish_types_prompt.jinja2',
operation_name='status_typle',
user_query=context
)
except Exception as e:
logger.error(
f"Template rendering failed for Data_type_differentiation: {e}",
exc_info=True
)
return {
"type": "error",
"message": f"Prompt rendering failed: {str(e)}"
}
# Call LLM with structured response
try:
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=DistinguishTypeResponse
)
result = structured.model_dump()
# Add context to result
result["context"] = context
return result
except Exception as e:
logger.error(
f"LLM call failed for Data_type_differentiation: {e}",
exc_info=True
)
return {
"context": context,
"type": "error",
"message": f"LLM call failed: {str(e)}"
}
except Exception as e:
logger.error(
f"Data_type_differentiation failed: {e}",
exc_info=True
)
return {
"context": context,
"type": "error",
"message": str(e)
}
@mcp.tool()
async def Data_write(
ctx: Context,
content: str,
user_id: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
) -> dict:
"""
Write data to the database/file system.
Args:
ctx: FastMCP context for dependency injection
content: Data content to write
user_id: User identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
Returns:
dict: Contains 'status', 'saved_to', and 'data' fields
"""
try:
# Ensure output directory exists
os.makedirs("data_output", exist_ok=True)
file_path = os.path.join("data_output", "user_data.csv")
# Write data - clients are constructed inside write() from memory_config
await write(
content=content,
user_id=user_id,
apply_id=apply_id,
group_id=group_id,
memory_config=memory_config,
)
logger.info(f"Write completed successfully! Config: {memory_config.config_name}")
return {
"status": "success",
"saved_to": file_path,
"data": content,
"config_id": memory_config.config_id,
"config_name": memory_config.config_name,
}
except Exception as e:
logger.error(f"Data_write failed: {e}", exc_info=True)
return {
"status": "error",
"message": str(e),
}

View File

@@ -0,0 +1,304 @@
"""
Problem Tools for question segmentation and extension.
This module contains MCP tools for breaking down and extending user questions.
LLM clients are constructed from MemoryConfig when needed.
"""
import json
import time
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.models.problem_models import (
ProblemBreakdownResponse,
ProblemExtensionResponse,
)
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.messages_tool import Problem_Extension_messages_deal
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.schemas.memory_config_schema import MemoryConfig
from mcp.server.fastmcp import Context
logger = get_agent_logger(__name__)
@mcp.tool()
async def Split_The_Problem(
ctx: Context,
sentence: str,
sessionid: str,
messages_id: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
) -> dict:
"""
Segment the dialogue or sentence into sub-problems.
Args:
ctx: FastMCP context for dependency injection
sentence: Original sentence to split
sessionid: Session identifier
messages_id: Message identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
Returns:
dict: Contains 'context' (JSON string of split results) and 'original' sentence
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Extract user ID from session
user_id = session_service.resolve_user_id(sessionid)
# Get conversation history
history = await session_service.get_history(user_id, apply_id, group_id)
# Override with empty list for now (as in original)
history = []
# Render template
try:
system_prompt = await template_service.render_template(
template_name='problem_breakdown_prompt.jinja2',
operation_name='split_the_problem',
history=history,
sentence=sentence
)
except Exception as e:
logger.error(
f"Template rendering failed for Split_The_Problem: {e}",
exc_info=True
)
return {
"context": json.dumps([], ensure_ascii=False),
"original": sentence,
"error": f"Prompt rendering failed: {str(e)}"
}
# Call LLM with structured response
try:
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=ProblemBreakdownResponse
)
# Handle RootModel response with .root attribute access
if structured is None:
# LLM returned None, use empty list as fallback
split_result = json.dumps([], ensure_ascii=False)
elif hasattr(structured, 'root') and structured.root is not None:
split_result = json.dumps(
[item.model_dump() for item in structured.root],
ensure_ascii=False
)
elif isinstance(structured, list):
# Fallback: treat structured itself as the list
split_result = json.dumps(
[item.model_dump() for item in structured],
ensure_ascii=False
)
else:
# Last resort: use empty list
split_result = json.dumps([], ensure_ascii=False)
except Exception as e:
logger.error(
f"LLM call failed for Split_The_Problem: {e}",
exc_info=True
)
split_result = json.dumps([], ensure_ascii=False)
logger.info("Problem splitting")
logger.info(f"Problem split result: {split_result}")
# Emit intermediate output for frontend
result = {
"context": split_result,
"original": sentence,
"_intermediate": {
"type": "problem_split",
"data": json.loads(split_result) if split_result else [],
"original_query": sentence
}
}
return result
except Exception as e:
logger.error(
f"Split_The_Problem failed: {e}",
exc_info=True
)
return {
"context": json.dumps([], ensure_ascii=False),
"original": sentence,
"error": str(e)
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Problem splitting', duration)
@mcp.tool()
async def Problem_Extension(
ctx: Context,
context: dict,
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = "",
) -> dict:
"""
Extend the problem with additional sub-questions.
Args:
ctx: FastMCP context for dependency injection
context: Dictionary containing split problem results
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'context' (aggregated questions) and 'original' question
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID from usermessages
from app.core.memory.agent.utils.messages_tool import Resolve_username
sessionid = Resolve_username(usermessages)
# Get conversation history
history = await session_service.get_history(sessionid, apply_id, group_id)
# Override with empty list for now (as in original)
history = []
# Process context to extract questions
extent_quest, original = await Problem_Extension_messages_deal(context)
# Format questions for template rendering
questions_formatted = []
for msg in extent_quest:
if msg.get("role") == "user":
questions_formatted.append(msg.get("content", ""))
# Render template
try:
system_prompt = await template_service.render_template(
template_name='Problem_Extension_prompt.jinja2',
operation_name='problem_extension',
history=history,
questions=questions_formatted
)
except Exception as e:
logger.error(
f"Template rendering failed for Problem_Extension: {e}",
exc_info=True
)
return {
"context": {},
"original": original,
"error": f"Prompt rendering failed: {str(e)}"
}
# Call LLM with structured response
try:
response_content = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=ProblemExtensionResponse
)
# Aggregate results by original question
aggregated_dict = {}
for item in response_content.root:
key = getattr(item, "original_question", None) or (
item.get("original_question") if isinstance(item, dict) else None
)
value = getattr(item, "extended_question", None) or (
item.get("extended_question") if isinstance(item, dict) else None
)
if not key or not value:
continue
aggregated_dict.setdefault(key, []).append(value)
except Exception as e:
logger.error(
f"LLM call failed for Problem_Extension: {e}",
exc_info=True
)
aggregated_dict = {}
logger.info("Problem extension")
logger.info(f"Problem extension result: {aggregated_dict}")
# Emit intermediate output for frontend
result = {
"context": aggregated_dict,
"original": original,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "problem_extension",
"data": aggregated_dict,
"original_query": original,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
return result
except Exception as e:
logger.error(
f"Problem_Extension failed: {e}",
exc_info=True
)
return {
"context": {},
"original": context.get("original", ""),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Problem extension', duration)

View File

@@ -0,0 +1,294 @@
"""
Retrieval Tools for database and context retrieval.
This module contains MCP tools for retrieving data using hybrid search.
"""
import os
import time
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.llm_tools import (
deduplicate_entries,
merge_to_key_value_pairs,
)
from app.core.memory.agent.utils.messages_tool import Retriev_messages_deal
from app.core.rag.nlp.search import knowledge_retrieval
from app.schemas.memory_config_schema import MemoryConfig
from dotenv import load_dotenv
from mcp.server.fastmcp import Context
load_dotenv()
logger = get_agent_logger(__name__)
@mcp.tool()
async def Retrieve(
ctx: Context,
context,
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = "",
) -> dict:
"""
Retrieve data from the database using hybrid search.
Args:
ctx: FastMCP context for dependency injection
context: Dictionary or string containing query information
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
user_rag_memory_id: User RAG memory identifier
Returns:
dict: Contains 'context' with Query and Expansion_issue results
"""
kb_config = {
"knowledge_bases": [
{
"kb_id": user_rag_memory_id,
"similarity_threshold": 0.7,
"vector_similarity_weight": 0.5,
"top_k": 10,
"retrieve_type": "participle"
}
],
"merge_strategy": "weight",
"reranker_id": os.getenv('reranker_id'),
"reranker_top_k": 10
}
start = time.time()
logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
logger.info(f"Retrieve: context type={type(context)}, context={str(context)[:500]}")
try:
# Extract services from context
search_service = get_context_resource(ctx, 'search_service')
databases_anser = []
# Handle both dict and string context
if isinstance(context, dict):
# Process dict context with extended questions
all_items = []
logger.info(f"Retrieve: context keys={list(context.keys())}")
content, original = await Retriev_messages_deal(context)
logger.info(f"Retrieve: after Retriev_messages_deal - content_type={type(content)}, content={str(content)[:300]}")
logger.info(f"Retrieve: original='{original[:100] if original else 'EMPTY'}'")
if not original:
logger.warning(f"Retrieve: original query is empty! context={context}")
# Extract all query items from content
# content is like {original_question: [extended_questions...], ...}
for key, values in content.items():
if isinstance(values, list):
all_items.extend(values)
elif isinstance(values, str):
all_items.append(values)
elif values is not None:
# Fallback: convert non-empty non-list values to string
all_items.append(str(values))
# Execute search for each question
for idx, question in enumerate(all_items):
try:
# Prepare search parameters based on storage type
search_params = {
"group_id": group_id,
"question": question,
"return_raw_results": True
}
# Add storage-specific parameters
if storage_type == "rag" and user_rag_memory_id:
retrieve_chunks_result = knowledge_retrieval(question, kb_config,[str(group_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
clean_content = '\n\n'.join(retrieval_knowledge)
cleaned_query=question
raw_results=clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except:
clean_content = ''
raw_results=''
cleaned_query = question
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
else:
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(
**search_params, memory_config=memory_config
)
databases_anser.append({
"Query_small": cleaned_query,
"Result_small": clean_content,
"_intermediate": {
"type": "search_result",
"query": cleaned_query,
"raw_results": raw_results,
"index": idx + 1,
"total": len(all_items)
}
})
except Exception as e:
logger.error(
f"Retrieve: hybrid_search failed for question '{question}': {e}",
exc_info=True
)
# Continue with empty result for this question
databases_anser.append({
"Query_small": question,
"Result_small": ""
})
# Build initial database data structure
databases_data = {
"Query": original,
"Expansion_issue": databases_anser
}
# Collect intermediate outputs before deduplication
intermediate_outputs = []
for item in databases_anser:
if '_intermediate' in item:
intermediate_outputs.append(item['_intermediate'])
# Deduplicate and merge results
deduplicated_data = deduplicate_entries(databases_data['Expansion_issue'])
deduplicated_data_merged = merge_to_key_value_pairs(
deduplicated_data,
'Query_small',
'Result_small'
)
# Restructure for Verify/Retrieve_Summary compatibility
keys, val = [], []
for item in deduplicated_data_merged:
for items_key, items_value in item.items():
keys.append(items_key)
val.append(items_value)
send_verify = []
for i, j in zip(keys, val, strict=False):
send_verify.append({
"Query_small": i,
"Answer_Small": j
})
dup_databases = {
"Query": original,
"Expansion_issue": send_verify,
"_intermediate_outputs": intermediate_outputs # Preserve intermediate outputs
}
logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results")
else:
# Handle string context (simple query)
query = str(context).strip()
try:
# Prepare search parameters based on storage type
search_params = {
"group_id": group_id,
"question": query,
"return_raw_results": True
}
# Add storage-specific parameters
if storage_type == "rag" and user_rag_memory_id:
retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
clean_content = '\n\n'.join(retrieval_knowledge)
cleaned_query = query
raw_results = clean_content
logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}")
except:
clean_content = ''
raw_results = ''
cleaned_query = query
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
else:
clean_content, cleaned_query, raw_results = await search_service.execute_hybrid_search(
**search_params, memory_config=memory_config
)
# Keep structure for Verify/Retrieve_Summary compatibility
dup_databases = {
"Query": cleaned_query,
"Expansion_issue": [{
"Query_small": cleaned_query,
"Answer_Small": clean_content,
"_intermediate": {
"type": "search_result",
"query": cleaned_query,
"raw_results": raw_results,
"index": 1,
"total": 1
}
}]
}
except Exception as e:
logger.error(
f"Retrieve: hybrid_search failed for query '{query}': {e}",
exc_info=True
)
# Return empty results on failure
dup_databases = {
"Query": query,
"Expansion_issue": []
}
logger.info(
f"Retrieval: {storage_type}--{user_rag_memory_id}--Query={dup_databases.get('Query', '')}, "
f"Expansion_issue count={len(dup_databases.get('Expansion_issue', []))}"
)
# Build result with intermediate outputs
result = {
"context": dup_databases,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
# Add intermediate outputs list if they exist
intermediate_outputs = dup_databases.get('_intermediate_outputs', [])
if intermediate_outputs:
result['_intermediates'] = intermediate_outputs
logger.info(f"Adding {len(intermediate_outputs)} intermediate outputs to result")
else:
logger.warning("No intermediate outputs found in dup_databases")
return result
except Exception as e:
logger.error(
f"Retrieve failed: {e}",
exc_info=True
)
return {
"context": {
"Query": "",
"Expansion_issue": []
},
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Retrieval', duration)

View File

@@ -0,0 +1,640 @@
"""
Summary Tools for data summarization.
This module contains MCP tools for summarizing retrieved data and generating responses.
LLM clients are constructed from MemoryConfig when needed.
"""
import json
import os
import re
import time
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.models.summary_models import (
RetrieveSummaryResponse,
SummaryResponse,
)
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.messages_tool import (
Resolve_username,
Summary_messages_deal,
)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.core.rag.nlp.search import knowledge_retrieval
from app.db import get_db_context
from app.schemas.memory_config_schema import MemoryConfig
from dotenv import load_dotenv
from mcp.server.fastmcp import Context
load_dotenv()
logger = get_agent_logger(__name__)
@mcp.tool()
async def Summary(
ctx: Context,
context: str,
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = "",
) -> dict:
"""
Summarize the verified data.
Args:
ctx: FastMCP context for dependency injection
context: JSON string containing verified data
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'status' and 'summary_result'
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID
sessionid = Resolve_username(usermessages)
# Process context to extract answer and query
answer_small, query = await Summary_messages_deal(context)
start_time= time.time()
history = await session_service.get_history(sessionid, apply_id, group_id)
end_time=time.time()
logger.info(f"Retrieve_Summary-REDIS搜索{end_time - start_time}")
data = {
"query": query,
"history": history,
"retrieve_info": answer_small
}
except Exception as e:
logger.error(
f"Summary: initialization failed: {e}",
exc_info=True
)
return {
"status": "error",
"summary_result": "信息不足,无法回答"
}
try:
# Render template
system_prompt = await template_service.render_template(
template_name='summary_prompt.jinja2',
operation_name='summary',
data=data,
query=query
)
except Exception as e:
logger.error(
f"Template rendering failed for Summary: {e}",
exc_info=True
)
return {
"status": "error",
"message": f"Prompt rendering failed: {str(e)}"
}
try:
# Call LLM with structured response
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=SummaryResponse
)
aimessages = structured.query_answer or ""
except Exception as e:
logger.error(
f"LLM call failed for Summary: {e}",
exc_info=True
)
aimessages = ""
try:
# Save session
if aimessages != "":
await session_service.save_session(
user_id=sessionid,
query=query,
apply_id=apply_id,
group_id=group_id,
ai_response=aimessages
)
logger.info(f"sessionid: {aimessages} 写入成功")
except Exception as e:
logger.error(
f"sessionid: {sessionid} 写入失败,错误信息:{str(e)}",
exc_info=True
)
return {
"status": "error",
"message": str(e)
}
# Cleanup duplicate sessions
await session_service.cleanup_duplicates()
# Use fallback if empty
if aimessages == '':
aimessages = '信息不足,无法回答'
logger.info(f"Summary after verification: {aimessages}")
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Summary', duration)
return {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
@mcp.tool()
async def Retrieve_Summary(
ctx: Context,
context: dict,
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = "",
) -> dict:
"""
Summarize data directly from retrieval results.
Args:
ctx: FastMCP context for dependency injection
context: Dictionary containing Query and Expansion_issue from Retrieve
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'status' and 'summary_result'
"""
start = time.time()
try:
# Extract services from context
template_service = get_context_resource(ctx, "template_service")
session_service = get_context_resource(ctx, "session_service")
# Get LLM client from memory_config
with get_db_context() as db:
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client_from_config(memory_config)
# Resolve session ID
sessionid = Resolve_username(usermessages)
# Handle both 'content' and 'context' keys (LangGraph uses 'content')
logger.debug(f"Retrieve_Summary: raw context type={type(context)}, keys={list(context.keys()) if isinstance(context, dict) else 'N/A'}")
if isinstance(context, dict):
if "content" in context:
inner = context["content"]
# If it's a JSON string, parse it
if isinstance(inner, str):
try:
parsed = json.loads(inner)
logger.info("Retrieve_Summary: successfully parsed JSON")
except json.JSONDecodeError:
# Try unescaping first
try:
unescaped = inner.encode('utf-8').decode('unicode_escape')
parsed = json.loads(unescaped)
logger.info("Retrieve_Summary: parsed after unescaping")
except (json.JSONDecodeError, UnicodeDecodeError) as e:
logger.error(
f"Retrieve_Summary: parsing failed even after unescape: {e}"
)
context_dict = {"Query": "", "Expansion_issue": []}
parsed = None
if parsed:
# Check if parsed has 'context' wrapper
if isinstance(parsed, dict) and "context" in parsed:
context_dict = parsed["context"]
else:
context_dict = parsed
elif isinstance(inner, dict):
context_dict = inner
else:
context_dict = {"Query": "", "Expansion_issue": []}
elif "context" in context:
context_dict = context["context"] if isinstance(context["context"], dict) else context
else:
context_dict = context
else:
context_dict = {"Query": "", "Expansion_issue": []}
query = context_dict.get("Query", "")
expansion_issue = context_dict.get("Expansion_issue", [])
logger.debug(f"Retrieve_Summary: query='{query}', expansion_issue count={len(expansion_issue)}")
logger.debug(f"Retrieve_Summary: expansion_issue={expansion_issue[:2] if expansion_issue else 'empty'}")
# Extract retrieve_info from expansion_issue
retrieve_info = []
for item in expansion_issue:
# Check for both Answer_Small and Answer_Small (typo) for backward compatibility
answer = None
if isinstance(item, dict):
if "Answer_Small" in item:
answer = item["Answer_Small"]
if answer is not None:
# Handle both string and list formats
if isinstance(answer, list):
# Join list of characters/strings into a single string
retrieve_info.append(''.join(str(x) for x in answer))
elif isinstance(answer, str):
retrieve_info.append(answer)
else:
retrieve_info.append(str(answer))
# Join all retrieve_info into a single string
retrieve_info_str = '\n\n'.join(retrieve_info) if retrieve_info else ""
start_time=time.time()
history = await session_service.get_history(sessionid, apply_id, group_id)
# Override with empty list for now (as in original)
end_time=time.time()
logger.info(f"Retrieve_Summary-REDIS搜索{end_time - start_time}")
except Exception as e:
logger.error(
f"Retrieve_Summary: initialization failed: {e}",
exc_info=True
)
return {
"status": "error",
"summary_result": "信息不足,无法回答"
}
try:
# Render template
system_prompt = await template_service.render_template(
template_name='Retrieve_Summary_prompt.jinja2',
operation_name='retrieve_summary',
query=query,
history=history,
retrieve_info=retrieve_info_str
)
except Exception as e:
logger.error(
f"Template rendering failed for Retrieve_Summary: {e}",
exc_info=True
)
return {
"status": "error",
"message": f"Prompt rendering failed: {str(e)}"
}
try:
# Call LLM with structured response
structured = await llm_client.response_structured(
messages=[{"role": "system", "content": system_prompt}],
response_model=RetrieveSummaryResponse
)
# Handle case where structured response might be None or incomplete
if structured and hasattr(structured, 'data') and structured.data:
aimessages = structured.data.query_answer or ""
else:
logger.warning("Structured response is None or incomplete, using default message")
aimessages = "信息不足,无法回答"
# Check for insufficient information response
if '信息不足,无法回答' not in str(aimessages) or str(aimessages)!="":
# Save session
await session_service.save_session(
user_id=sessionid,
query=query,
apply_id=apply_id,
group_id=group_id,
ai_response=aimessages
)
logger.info(f"sessionid: {aimessages} 写入成功")
except Exception as e:
logger.error(
f"Retrieve_Summary: LLM call failed: {e}",
exc_info=True
)
aimessages = ""
# Cleanup duplicate sessions
await session_service.cleanup_duplicates()
# Use fallback if empty
if aimessages == '':
aimessages = '信息不足,无法回答'
logger.info(f"Summary after retrieval: {aimessages}")
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Retrieval summary', duration)
# Emit intermediate output for frontend
return {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "retrieval_summary",
"summary": aimessages,
"query": query,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
@mcp.tool()
async def Input_Summary(
ctx: Context,
context: str,
usermessages: str,
search_switch: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = "",
) -> dict:
"""
Generate a quick summary for direct input without verification.
Args:
ctx: FastMCP context for dependency injection
context: String containing the input sentence
usermessages: User messages identifier
search_switch: Search switch value for routing ('2' for summaries only)
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (e.g., 'rag', 'vector')
user_rag_memory_id: User RAG memory identifier
Returns:
dict: Contains 'query_answer' with the summary result
"""
start = time.time()
logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}")
try:
# Extract services from context
session_service = get_context_resource(ctx, "session_service")
search_service = get_context_resource(ctx, "search_service")
# Resolve session ID
sessionid = Resolve_username(usermessages) or ""
sessionid = sessionid.replace('call_id_', '')
start_time=time.time()
history = await session_service.get_history(
str(sessionid),
str(apply_id),
str(group_id)
)
end_time=time.time()
logger.info(f"Input_Summary-REDIS搜索{end_time - start_time}")
# Override with empty list for now (as in original)
# Log the raw context for debugging
logger.info(f"Input_Summary: Received context type={type(context)}, value={context[:200] if isinstance(context, str) else context}")
# Extract sentence from context
# Context can be a string or might contain the sentence in various formats
try:
# Try to parse as JSON first
if isinstance(context, str) and (context.startswith('{') or context.startswith('[')):
try:
import json
context_dict = json.loads(context)
if isinstance(context_dict, dict):
query = context_dict.get('sentence', context_dict.get('content', context))
else:
query = context
except json.JSONDecodeError:
# Not valid JSON, try regex
match = re.search(r"'sentence':\s*['\"]?(.*?)['\"]?\s*,", context)
query = match.group(1) if match else context
else:
query = context
except Exception as e:
logger.warning(f"Failed to extract query from context: {e}")
query = context
# Clean query
query = str(query).strip().strip("\"'")
logger.debug(f"Input_Summary: Extracted query='{query}' from context type={type(context)}")
# Execute search based on search_switch and storage_type
try:
logger.info(f"search_switch: {search_switch}, storage_type: {storage_type}")
# Prepare search parameters based on storage type
search_params = {
"group_id": group_id,
"question": query,
"return_raw_results": True
}
# Add storage-specific parameters
# Retrieval
if search_switch == '2':
search_params["include"] = ["summaries"]
if storage_type == "rag" and user_rag_memory_id:
raw_results = []
retrieve_info = ""
kb_config={
"knowledge_bases": [
{
"kb_id": user_rag_memory_id,
"similarity_threshold": 0.7,
"vector_similarity_weight": 0.5,
"top_k": 10,
"retrieve_type": "participle"
}
],
"merge_strategy": "weight",
"reranker_id":os.getenv('reranker_id'),
"reranker_top_k": 10
}
retrieve_chunks_result = knowledge_retrieval(query, kb_config,[str(group_id)])
try:
retrieval_knowledge = [i.page_content for i in retrieve_chunks_result]
retrieve_info = '\n\n'.join(retrieval_knowledge)
raw_results=[retrieve_info]
logger.info(f"Input_Summary: Using RAG storage with memory_id={user_rag_memory_id}")
except:
retrieve_info=''
raw_results=['']
logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}")
else:
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(
**search_params, memory_config=memory_config
)
logger.info("Input_Summary: Using summary for retrieval")
else:
retrieve_info, question, raw_results = await search_service.execute_hybrid_search(
**search_params, memory_config=memory_config
)
except Exception as e:
logger.error(
f"Input_Summary: hybrid_search failed, using empty results: {e}",
exc_info=True
)
retrieve_info, question, raw_results = "", query, []
# Return retrieved information directly without LLM processing
# Use the raw retrieved info as the answer
aimessages = retrieve_info if retrieve_info else "信息不足,无法回答"
logger.info(f"Quick answer (no LLM): {storage_type}--{user_rag_memory_id}--{aimessages[:500]}...")
# Emit intermediate output for frontend
return {
"status": "success",
"summary_result": aimessages,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "input_summary",
"title": "快速答案",
"summary": aimessages,
"query": query,
"raw_results": raw_results,
"search_mode": "quick_search",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
except Exception as e:
logger.error(
f"Input_Summary failed: {e}",
exc_info=True
)
return {
"status": "fail",
"summary_result": "信息不足,无法回答",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Retrieval', duration)
@mcp.tool()
async def Summary_fails(
ctx: Context,
context: str,
usermessages: str,
apply_id: str,
group_id: str,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
"""
Handle workflow failure when summary cannot be generated.
Args:
ctx: FastMCP context for dependency injection
context: Failure context string
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'query_answer' with failure message
"""
try:
# Extract services from context
session_service = get_context_resource(ctx, 'session_service')
# Parse session ID from usermessages
usermessages_parts = usermessages.split('_')[1:]
sessionid = '_'.join(usermessages_parts[:-1])
# Cleanup duplicate sessions
await session_service.cleanup_duplicates()
logger.info("没有相关数据")
logger.debug(f"Summary_fails called with apply_id: {apply_id}, group_id: {group_id}")
return {
"status": "success",
"summary_result": "没有相关数据",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
except Exception as e:
logger.error(
f"Summary_fails failed: {e}",
exc_info=True
)
return {
"status": "fail",
"summary_result": "没有相关数据",
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"error": str(e)
}

View File

@@ -0,0 +1,174 @@
"""
Verification Tools for data verification.
This module contains MCP tools for verifying retrieved data.
"""
import time
from app.core.logging_config import get_agent_logger, log_time
from app.core.memory.agent.mcp_server.mcp_instance import mcp
from app.core.memory.agent.mcp_server.server import get_context_resource
from app.core.memory.agent.utils.llm_tools import PROJECT_ROOT_
from app.core.memory.agent.utils.messages_tool import (
Resolve_username,
Retrieve_verify_tool_messages_deal,
Verify_messages_deal,
)
from app.core.memory.agent.utils.verify_tool import VerifyTool
from app.schemas.memory_config_schema import MemoryConfig
from jinja2 import Template
from mcp.server.fastmcp import Context
logger = get_agent_logger(__name__)
@mcp.tool()
async def Verify(
ctx: Context,
context: dict,
usermessages: str,
apply_id: str,
group_id: str,
memory_config: MemoryConfig,
storage_type: str = "",
user_rag_memory_id: str = ""
) -> dict:
"""
Verify the retrieved data.
Args:
ctx: FastMCP context for dependency injection
context: Dictionary containing query and expansion issues
usermessages: User messages identifier
apply_id: Application identifier
group_id: Group identifier
memory_config: MemoryConfig object containing all configuration
storage_type: Storage type for the workspace (optional)
user_rag_memory_id: User RAG memory identifier (optional)
Returns:
dict: Contains 'status' and 'verified_data' with verification results
"""
start = time.time()
try:
# Extract services from context
session_service = get_context_resource(ctx, 'session_service')
# Load verification prompt template
file_path = PROJECT_ROOT_ + '/agent/utils/prompt/split_verify_prompt.jinja2'
# Read template file directly (VerifyTool expects raw template content)
from app.core.memory.agent.utils.messages_tool import read_template_file
system_prompt = await read_template_file(file_path)
# Resolve session ID
sessionid = Resolve_username(usermessages)
# Get conversation history
history = await session_service.get_history(sessionid, apply_id, group_id)
template = Template(system_prompt)
system_prompt = template.render(history=history, sentence=context)
# Process context to extract query and results
Query_small, Result_small, query = await Verify_messages_deal(context)
# Build query list for verification
query_list = []
for query_small, anser in zip(Query_small, Result_small, strict=False):
query_list.append({
'Query_small': query_small,
'Answer_Small': anser
})
messages = {
"Query": query,
"Expansion_issue": query_list
}
# Call verification workflow with LLM model ID from memory_config
verify_tool = VerifyTool(
system_prompt=system_prompt,
verify_data=messages,
llm_model_id=str(memory_config.llm_model_id)
)
verify_result = await verify_tool.verify()
# Parse LLM verification result with error handling
try:
messages_deal = await Retrieve_verify_tool_messages_deal(
verify_result,
history,
query
)
except Exception as e:
logger.error(
f"Retrieve_verify_tool_messages_deal parsing failed: {e}",
exc_info=True
)
# Fallback to avoid 500 errors
messages_deal = {
"data": {
"query": query,
"expansion_issue": []
},
"split_result": "failed",
"reason": str(e),
"history": history,
}
logger.info(f"Verification result: {messages_deal}")
# Emit intermediate output for frontend
return {
"status": "success",
"verified_data": messages_deal,
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"_intermediate": {
"type": "verification",
"title": "Data Verification",
"result": messages_deal.get("split_result", "unknown"),
"reason": messages_deal.get("reason", ""),
"query": query,
"verified_count": len(query_list),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id
}
}
except Exception as e:
logger.error(
f"Verify failed: {e}",
exc_info=True
)
return {
"status": "error",
"message": str(e),
"storage_type": storage_type,
"user_rag_memory_id": user_rag_memory_id,
"verified_data": {
"data": {
"query": "",
"expansion_issue": []
},
"split_result": "failed",
"reason": str(e),
"history": [],
}
}
finally:
# Log execution time
end = time.time()
try:
duration = end - start
except Exception:
duration = 0.0
log_time('Verification', duration)

View File

@@ -1,32 +0,0 @@
"""Pydantic models for verification operations."""
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
class VerificationItem(BaseModel):
"""Individual verification item for a query-answer pair."""
query_small: str = Field(..., description="子问题")
answer_small: str = Field(..., description="子问题的回答")
status: str = Field(..., description="验证状态True 或 False")
query_answer: str = Field(..., description="问题的答案(与 answer_small 相同)")
class VerificationResult(BaseModel):
"""Result model for verification operation."""
query: str = Field(..., description="原始查询问题")
history: List[Dict[str, Any]] = Field(default_factory=list, description="历史对话记录")
expansion_issue: List[VerificationItem] = Field(
default_factory=list,
description="验证后的数据列表,包含所有通过验证的问答对"
)
split_result: str = Field(
...,
description="验证结果状态successexpansion_issue 非空)或 failedexpansion_issue 为空)"
)
reason: Optional[str] = Field(
None,
description="验证结果的说明和分析"
)

View File

@@ -1,28 +0,0 @@
"""Pydantic models for write aggregate judgment operations."""
from typing import List, Union
from pydantic import BaseModel, Field
class MessageItem(BaseModel):
"""Individual message item in conversation."""
role: str = Field(..., description="角色user 或 assistant")
content: str = Field(..., description="消息内容")
class WriteAggregateResponse(BaseModel):
"""Response model for aggregate judgment containing judgment result and output."""
is_same_event: bool = Field(
...,
description="是否是同一事件。True表示是同一事件False表示不同事件"
)
output: Union[List[MessageItem], bool] = Field(
...,
description="如果is_same_event为True返回False如果is_same_event为False返回消息列表"
)
# 为了保持向后兼容,保留旧的类名作为别名
WriteAggregateModel = WriteAggregateResponse

View File

@@ -0,0 +1,114 @@
import os
import sys
import traceback
import requests
# from qcloud_cos import CosConfig, CosS3Client
# from qcloud_cos.cos_exception import CosClientError, CosServiceError
# from config.paths import BASE_DIR
BASE_DIR = os.path.dirname(os.path.realpath(sys.argv[0]))
class OSSUploader:
"""对象存储文件上传工具类"""
def __init__(self, env):
api = {
"test": "https://testlingqi.redbearai.com/api/user/file/common/upload/v2/anon",
"prod": "https://lingqi.redbearai.com/api/user/file/common/upload/v2/anon"
}
self.api = api.get(env, "https://testlingqi.redbearai.com/api/user/file/common/upload/v2/anon")
self.privacy = "false"
self.headers = {
"User-Agent": 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) '
'AppleWebKit/537.36 (KHTML, like Gecko)'
' Chrome/133.0.6833.84 Safari/537.36'
}
@staticmethod
def _generate_object_key(file_path, prefix='xhs_'):
"""
生成对象存储的Key
:param file_path: 本地文件路径
:param prefix: 存储前缀,用于分类存储
:return: 生成的对象Key
"""
# 文件md5值.后缀名
filename = os.path.basename(file_path)
filename = f"{filename}"
# 组合成完整的对象Key
return f"{prefix}{filename}"
def upload_image(self, file_name, prefix='jd_'):
"""
上传文件到COS并返回可访问的URL
:param file_url: 文件路径
:param file_name: 文件名称
:param media_type: 文件类型
:param prefix: 存储前缀,用于分类存储
:return: 文件访问URL
"""
# 检查文件是否存在
file_path = os.path.join(BASE_DIR, file_name)
# response = requests.get(url, headers=self.headers, stream=True)
# if response.status_code == 200:
# with open(file_path, "wb") as f:
# for chunk in response.iter_content(1024): # 分块写入,避免内存占用过大
# f.write(chunk)
# else:
# raise Exception(f"文件下载失败,{file_name}")
# 生成对象Key
object_key = self._generate_object_key(file_path, prefix +file_name.split('.')[-1])
try:
upload_response = requests.post(
self.api,
data={
"privacy": self.privacy,
"fileName": object_key,
}
)
if upload_response.status_code != 200:
raise Exception('上传接口请求失败')
resp = upload_response.json()
name = resp["data"]["name"]
file_url = resp["data"]["path"]
policy = resp["data"]["policy"]
with open(file_path, 'rb') as f:
oss_push_resp = requests.post(
policy["host"],
files={
"key": policy["dir"],
"OSSAccessKeyId": policy["accessid"],
"name": name,
"policy": policy["policy"],
"success_action_status": 200,
"signature": policy["signature"],
"file": f,
}
)
if oss_push_resp.status_code == 200:
return file_url
raise Exception("OSS上传失败")
except Exception:
raise Exception(f"上传失败: \n{traceback.format_exc()}")
finally:
print('success')
# os.remove(file_path)
if __name__ == '__main__':
cos_uploader = OSSUploader("prod")
url =cos_uploader.upload_image('./example01.jpg')
print(url)

Some files were not shown because too many files have changed in this diff Show More