diff --git a/.gitignore b/.gitignore index 2fb41537..66d1beb2 100644 --- a/.gitignore +++ b/.gitignore @@ -29,6 +29,7 @@ search_results.json api/migrations/versions tmp files +powers/ # Exclude dep files huggingface.co/ diff --git a/README.md b/README.md index 2f53a996..95d8d737 100644 --- a/README.md +++ b/README.md @@ -226,8 +226,8 @@ REDIS_PORT=6379 REDIS_DB=1 # Celery (Using Redis as broker) -BROKER_URL=redis://127.0.0.1:6379/0 -RESULT_BACKEND=redis://127.0.0.1:6379/0 +REDIS_DB_CELERY_BROKER=1 +REDIS_DB_CELERY_BACKEND=2 # JWT Secret Key (Formation method: openssl rand -hex 32) SECRET_KEY=your-secret-key-here diff --git a/README_CN.md b/README_CN.md index aed69b03..1472acac 100644 --- a/README_CN.md +++ b/README_CN.md @@ -201,8 +201,8 @@ REDIS_PORT=6379 REDIS_DB=1 # Celery (使用Redis作为broker) -BROKER_URL=redis://127.0.0.1:6379/0 -RESULT_BACKEND=redis://127.0.0.1:6379/0 +REDIS_DB_CELERY_BROKER=1 +REDIS_DB_CELERY_BACKEND=2 # JWT密钥 (生成方式: openssl rand -hex 32) SECRET_KEY=your-secret-key-here diff --git a/api/app/aioRedis.py b/api/app/aioRedis.py index c729a3dc..f758dd15 100644 --- a/api/app/aioRedis.py +++ b/api/app/aioRedis.py @@ -10,7 +10,6 @@ from app.core.config import settings # 设置日志记录器 logger = logging.getLogger(__name__) - # 创建连接池 pool = ConnectionPool.from_url( f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}", @@ -21,6 +20,7 @@ pool = ConnectionPool.from_url( ) aio_redis = redis.StrictRedis(connection_pool=pool) + async def get_redis_connection(): """获取Redis连接""" try: @@ -29,7 +29,8 @@ async def get_redis_connection(): logger.error(f"Redis连接失败: {str(e)}") return None -async def aio_redis_set(key: str, val: str|dict, expire: int = None): + +async def aio_redis_set(key: str, val: str | dict, expire: int = None): """设置Redis键值 Args: @@ -40,7 +41,7 @@ async def aio_redis_set(key: str, val: str|dict, expire: int = None): try: if isinstance(val, dict): val = json.dumps(val, ensure_ascii=False) - + if expire is not None: # 设置带过期时间的键值 await aio_redis.set(key, val, ex=expire) @@ -50,6 +51,7 @@ async def aio_redis_set(key: str, val: str|dict, expire: int = None): except Exception as e: logger.error(f"Redis set错误: {str(e)}") + async def aio_redis_get(key: str): """获取Redis键值""" try: @@ -58,6 +60,7 @@ async def aio_redis_get(key: str): logger.error(f"Redis get错误: {str(e)}") return None + async def aio_redis_delete(key: str): """删除Redis键""" try: @@ -66,6 +69,7 @@ async def aio_redis_delete(key: str): logger.error(f"Redis delete错误: {str(e)}") return None + async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool: """发布消息到Redis频道""" try: @@ -78,9 +82,10 @@ async def aio_redis_publish(channel: str, message: Dict[str, Any]) -> bool: logger.error(f"Redis发布错误: {str(e)}") return False + class RedisSubscriber: """Redis订阅器""" - + def __init__(self, channel: str): self.channel = channel self.conn = None @@ -88,25 +93,25 @@ class RedisSubscriber: self.is_closed = False self._queue = asyncio.Queue() self._task = None - + async def start(self): """开始订阅""" if self.is_closed or self._task: return - + self._task = asyncio.create_task(self._receive_messages()) logger.info(f"开始订阅: {self.channel}") - + async def _receive_messages(self): """接收消息""" try: self.conn = await get_redis_connection() if not self.conn: return - + self.pubsub = self.conn.pubsub() await self.pubsub.subscribe(self.channel) - + while not self.is_closed: try: message = await self.pubsub.get_message(ignore_subscribe_messages=True, timeout=0.01) @@ -127,7 +132,7 @@ class RedisSubscriber: finally: await self._queue.put(None) await self._cleanup() - + async def _cleanup(self): """清理资源""" if self.pubsub: @@ -141,7 +146,7 @@ class RedisSubscriber: await self.conn.close() except Exception: pass - + async def get_message(self) -> Optional[Dict[str, Any]]: """获取消息""" if self.is_closed: @@ -153,7 +158,7 @@ class RedisSubscriber: except Exception as e: logger.error(f"获取消息错误: {str(e)}") return None - + async def close(self): """关闭订阅器""" if self.is_closed: @@ -163,32 +168,33 @@ class RedisSubscriber: self._task.cancel() await self._cleanup() + class RedisPubSubManager: """Redis发布订阅管理器""" - + def __init__(self): self.subscribers = {} - + async def publish(self, channel: str, message: Dict[str, Any]) -> bool: return await aio_redis_publish(channel, message) - + def get_subscriber(self, channel: str) -> RedisSubscriber: if channel in self.subscribers: subscriber = self.subscribers[channel] if not subscriber.is_closed: return subscriber - + subscriber = RedisSubscriber(channel) self.subscribers[channel] = subscriber return subscriber - + def cancel_subscription(self, channel: str) -> bool: if channel in self.subscribers: asyncio.create_task(self.subscribers[channel].close()) del self.subscribers[channel] return True return False - + def cancel_all_subscriptions(self) -> int: count = len(self.subscribers) for subscriber in self.subscribers.values(): @@ -196,6 +202,6 @@ class RedisPubSubManager: self.subscribers.clear() return count + # 全局实例 pubsub_manager = RedisPubSubManager() - diff --git a/api/app/cache/__init__.py b/api/app/cache/__init__.py index 748ce8ae..ca7aa91a 100644 --- a/api/app/cache/__init__.py +++ b/api/app/cache/__init__.py @@ -2,7 +2,9 @@ Cache 缓存模块 提供各种缓存功能的统一入口 -注意:隐性记忆和情绪建议已迁移到数据库存储,不再使用Redis缓存 """ +from .memory import InterestMemoryCache -__all__ = [] +__all__ = [ + "InterestMemoryCache", +] diff --git a/api/app/cache/memory/__init__.py b/api/app/cache/memory/__init__.py index 35f45aad..551062ac 100644 --- a/api/app/cache/memory/__init__.py +++ b/api/app/cache/memory/__init__.py @@ -2,7 +2,11 @@ Memory 缓存模块 提供记忆系统相关的缓存功能 -注意:隐性记忆和情绪建议已迁移到数据库存储,不再使用Redis缓存 """ +from .interest_memory import InterestMemoryCache +from .activity_stats_cache import ActivityStatsCache -__all__ = [] +__all__ = [ + "InterestMemoryCache", + "ActivityStatsCache", +] diff --git a/api/app/cache/memory/activity_stats_cache.py b/api/app/cache/memory/activity_stats_cache.py new file mode 100644 index 00000000..6b162cdd --- /dev/null +++ b/api/app/cache/memory/activity_stats_cache.py @@ -0,0 +1,124 @@ +""" +Recent Activity Stats Cache + +记忆提取活动统计缓存模块 +用于缓存每次记忆提取流程的统计数据,按 workspace_id 存储,24小时后释放 +查询命令:cache:memory:activity_stats:by_workspace:7de31a97-40a6-4fc0-b8d3-15c89f523843 +""" +import json +import logging +from typing import Optional, Dict, Any +from datetime import datetime + +from app.aioRedis import aio_redis + +logger = logging.getLogger(__name__) + +# 缓存过期时间:24小时 +ACTIVITY_STATS_CACHE_EXPIRE = 86400 + + +class ActivityStatsCache: + """记忆提取活动统计缓存类""" + + PREFIX = "cache:memory:activity_stats" + + @classmethod + def _get_key(cls, workspace_id: str) -> str: + """生成 Redis key + + Args: + workspace_id: 工作空间ID + + Returns: + 完整的 Redis key + """ + return f"{cls.PREFIX}:by_workspace:{workspace_id}" + + @classmethod + async def set_activity_stats( + cls, + workspace_id: str, + stats: Dict[str, Any], + expire: int = ACTIVITY_STATS_CACHE_EXPIRE, + ) -> bool: + """设置记忆提取活动统计缓存 + + Args: + workspace_id: 工作空间ID + stats: 统计数据,格式: + { + "chunk_count": int, + "statements_count": int, + "triplet_entities_count": int, + "triplet_relations_count": int, + "temporal_count": int, + } + expire: 过期时间(秒),默认24小时 + + Returns: + 是否设置成功 + """ + try: + key = cls._get_key(workspace_id) + payload = { + "stats": stats, + "generated_at": datetime.now().isoformat(), + "workspace_id": workspace_id, + "cached": True, + } + value = json.dumps(payload, ensure_ascii=False) + await aio_redis.set(key, value, ex=expire) + logger.info(f"设置活动统计缓存成功: {key}, 过期时间: {expire}秒") + return True + except Exception as e: + logger.error(f"设置活动统计缓存失败: {e}", exc_info=True) + return False + + @classmethod + async def get_activity_stats( + cls, + workspace_id: str, + ) -> Optional[Dict[str, Any]]: + """获取记忆提取活动统计缓存 + + Args: + workspace_id: 工作空间ID + + Returns: + 统计数据字典,缓存不存在或已过期返回 None + """ + try: + key = cls._get_key(workspace_id) + value = await aio_redis.get(key) + if value: + payload = json.loads(value) + logger.info(f"命中活动统计缓存: {key}") + return payload + logger.info(f"活动统计缓存不存在或已过期: {key}") + return None + except Exception as e: + logger.error(f"获取活动统计缓存失败: {e}", exc_info=True) + return None + + @classmethod + async def delete_activity_stats( + cls, + workspace_id: str, + ) -> bool: + """删除记忆提取活动统计缓存 + + Args: + workspace_id: 工作空间ID + + Returns: + 是否删除成功 + """ + try: + key = cls._get_key(workspace_id) + result = await aio_redis.delete(key) + logger.info(f"删除活动统计缓存: {key}, 结果: {result}") + return result > 0 + except Exception as e: + logger.error(f"删除活动统计缓存失败: {e}", exc_info=True) + return False diff --git a/api/app/cache/memory/interest_memory.py b/api/app/cache/memory/interest_memory.py new file mode 100644 index 00000000..108e2a37 --- /dev/null +++ b/api/app/cache/memory/interest_memory.py @@ -0,0 +1,122 @@ +""" +Interest Distribution Cache + +兴趣分布缓存模块 +用于缓存用户的兴趣分布标签数据,避免重复调用模型生成 +""" +import json +import logging +from typing import Optional, List, Dict, Any +from datetime import datetime + +from app.aioRedis import aio_redis + +logger = logging.getLogger(__name__) + +# 缓存过期时间:24小时 +INTEREST_CACHE_EXPIRE = 86400 + + +class InterestMemoryCache: + """兴趣分布缓存类""" + + PREFIX = "cache:memory:interest_distribution" + + @classmethod + def _get_key(cls, end_user_id: str, language: str) -> str: + """生成 Redis key + + Args: + end_user_id: 用户ID + language: 语言类型 + + Returns: + 完整的 Redis key + """ + return f"{cls.PREFIX}:by_user:{end_user_id}:{language}" + + @classmethod + async def set_interest_distribution( + cls, + end_user_id: str, + language: str, + data: List[Dict[str, Any]], + expire: int = INTEREST_CACHE_EXPIRE, + ) -> bool: + """设置用户兴趣分布缓存 + + Args: + end_user_id: 用户ID + language: 语言类型 + data: 兴趣分布列表,格式 [{"name": "...", "frequency": ...}, ...] + expire: 过期时间(秒),默认24小时 + + Returns: + 是否设置成功 + """ + try: + key = cls._get_key(end_user_id, language) + payload = { + "data": data, + "generated_at": datetime.now().isoformat(), + "cached": True, + } + value = json.dumps(payload, ensure_ascii=False) + await aio_redis.set(key, value, ex=expire) + logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒") + return True + except Exception as e: + logger.error(f"设置兴趣分布缓存失败: {e}", exc_info=True) + return False + + @classmethod + async def get_interest_distribution( + cls, + end_user_id: str, + language: str, + ) -> Optional[List[Dict[str, Any]]]: + """获取用户兴趣分布缓存 + + Args: + end_user_id: 用户ID + language: 语言类型 + + Returns: + 兴趣分布列表,缓存不存在或已过期返回 None + """ + try: + key = cls._get_key(end_user_id, language) + value = await aio_redis.get(key) + if value: + payload = json.loads(value) + logger.info(f"命中兴趣分布缓存: {key}") + return payload.get("data") + logger.info(f"兴趣分布缓存不存在或已过期: {key}") + return None + except Exception as e: + logger.error(f"获取兴趣分布缓存失败: {e}", exc_info=True) + return None + + @classmethod + async def delete_interest_distribution( + cls, + end_user_id: str, + language: str, + ) -> bool: + """删除用户兴趣分布缓存 + + Args: + end_user_id: 用户ID + language: 语言类型 + + Returns: + 是否删除成功 + """ + try: + key = cls._get_key(end_user_id, language) + result = await aio_redis.delete(key) + logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}") + return result > 0 + except Exception as e: + logger.error(f"删除兴趣分布缓存失败: {e}", exc_info=True) + return False diff --git a/api/app/celery_app.py b/api/app/celery_app.py index ba294651..0319e079 100644 --- a/api/app/celery_app.py +++ b/api/app/celery_app.py @@ -7,20 +7,48 @@ from celery import Celery from celery.schedules import crontab from app.core.config import settings +from app.core.logging_config import get_logger + +logger = get_logger(__name__) # macOS fork() safety - must be set before any Celery initialization if platform.system() == 'Darwin': os.environ.setdefault('OBJC_DISABLE_INITIALIZE_FORK_SAFETY', 'YES') # 创建 Celery 应用实例 -# broker: 任务队列(使用 Redis DB 0) -# backend: 结果存储(使用 Redis DB 10) +# broker: 任务队列(使用 Redis DB,由 CELERY_BROKER_DB 指定) +# backend: 结果存储(使用 Redis DB,由 CELERY_BACKEND_DB 指定) +# NOTE: 不要在 .env 中设置 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND, +# 这些名称会被 Celery CLI 的 Click 框架劫持,详见 docs/celery-env-bug-report.md + +# Build canonical broker/backend URLs and force them into os.environ so that +# Celery's Settings.broker_url property (which checks CELERY_BROKER_URL first) +# cannot be overridden by stray env vars. +# See: https://github.com/celery/celery/issues/4284 +_broker_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BROKER}" +_backend_url = f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB_CELERY_BACKEND}" +os.environ["CELERY_BROKER_URL"] = _broker_url +os.environ["CELERY_RESULT_BACKEND"] = _backend_url +# Neutralize legacy Celery env vars that can be hijacked by Celery's CLI/Click +# integration and accidentally override our canonical URLs. +os.environ.pop("BROKER_URL", None) +os.environ.pop("RESULT_BACKEND", None) +os.environ.pop("CELERY_BROKER", None) +os.environ.pop("CELERY_BACKEND", None) + celery_app = Celery( "redbear_tasks", - broker=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BROKER}", - backend=f"redis://:{quote(settings.REDIS_PASSWORD)}@{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.CELERY_BACKEND}", + broker=_broker_url, + backend=_backend_url, ) +logger.info( + "Celery app initialized", + extra={ + "broker": _broker_url.replace(quote(settings.REDIS_PASSWORD), "***"), + "backend": _backend_url.replace(quote(settings.REDIS_PASSWORD), "***"), + }, +) # Default queue for unrouted tasks celery_app.conf.task_default_queue = 'memory_tasks' @@ -44,8 +72,8 @@ celery_app.conf.update( task_ignore_result=False, # 超时设置 - task_time_limit=1800, # 30分钟硬超时 - task_soft_time_limit=1500, # 25分钟软超时 + task_time_limit=3600, # 60分钟硬超时 + task_soft_time_limit=3000, # 50分钟软超时 # Worker 设置 (per-worker settings are in docker-compose command line) worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution @@ -92,7 +120,7 @@ celery_app.conf.update( celery_app.autodiscover_tasks(['app']) # Celery Beat schedule for periodic tasks -memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS) +memory_increment_schedule = crontab(hour=settings.MEMORY_INCREMENT_HOUR, minute=settings.MEMORY_INCREMENT_MINUTE) memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS) workspace_reflection_schedule = timedelta(seconds=settings.WORKSPACE_REFLECTION_INTERVAL_SECONDS) forgetting_cycle_schedule = timedelta(hours=settings.FORGETTING_CYCLE_INTERVAL_HOURS) diff --git a/api/app/config/__init__.py b/api/app/config/__init__.py new file mode 100644 index 00000000..df675a16 --- /dev/null +++ b/api/app/config/__init__.py @@ -0,0 +1 @@ +"""Configuration module for application settings.""" diff --git a/api/app/config/default_ontology_config.py b/api/app/config/default_ontology_config.py new file mode 100644 index 00000000..157aa73e --- /dev/null +++ b/api/app/config/default_ontology_config.py @@ -0,0 +1,239 @@ +"""默认本体场景配置 + +本模块定义系统预设的本体场景和实体类型配置。 +这些配置用于在工作空间创建时自动初始化默认场景。 +支持中英文双语配置,根据用户语言偏好创建对应语言的场景。 +""" + +# 在线教育场景配置 +ONLINE_EDUCATION_SCENE = { + "name_chinese": "在线教育", + "name_english": "Online Education", + "description_chinese": "适用于在线教育平台的本体建模,包含学生、教师、课程等核心实体类型", + "description_english": "Ontology modeling for online education platforms, including core entity types such as students, teachers, and courses", + "types": [ + { + "name_chinese": "学生", + "name_english": "Student", + "description_chinese": "在教育系统中接受教育的个体,包含姓名、学号、年级、班级等属性", + "description_english": "Individuals receiving education in the education system, including attributes such as name, student ID, grade, and class" + }, + { + "name_chinese": "教师", + "name_english": "Teacher", + "description_chinese": "在教育系统中提供教学服务的个体,包含姓名、工号、任教学科、职称等属性", + "description_english": "Individuals providing teaching services in the education system, including attributes such as name, employee ID, teaching subject, and title" + }, + { + "name_chinese": "课程", + "name_english": "Course", + "description_chinese": "教育系统中的教学内容单元,包含课程名称、课程代码、学分、学时等属性", + "description_english": "Teaching content units in the education system, including attributes such as course name, course code, credits, and class hours" + }, + { + "name_chinese": "作业", + "name_english": "Assignment", + "description_chinese": "课程中布置的学习任务,包含作业标题、截止日期、所属课程、提交状态等属性", + "description_english": "Learning tasks assigned in courses, including attributes such as assignment title, deadline, course, and submission status" + }, + { + "name_chinese": "成绩", + "name_english": "Grade", + "description_chinese": "学生学习成果的评价结果,包含分数、评级、考试类型、所属课程等属性", + "description_english": "Evaluation results of student learning outcomes, including attributes such as score, rating, exam type, and course" + }, + { + "name_chinese": "考试", + "name_english": "Exam", + "description_chinese": "评估学生学习成果的测试活动,包含考试名称、时间、地点、科目等属性", + "description_english": "Test activities to assess student learning outcomes, including attributes such as exam name, time, location, and subject" + }, + { + "name_chinese": "教室", + "name_english": "Classroom", + "description_chinese": "进行教学活动的物理或虚拟空间,包含教室编号、容量、设备等属性", + "description_english": "Physical or virtual spaces for teaching activities, including attributes such as classroom number, capacity, and equipment" + }, + { + "name_chinese": "学科", + "name_english": "Subject", + "description_chinese": "知识的分类领域,包含学科名称、代码、所属院系等属性", + "description_english": "Classification domains of knowledge, including attributes such as subject name, code, and department" + }, + { + "name_chinese": "教材", + "name_english": "Textbook", + "description_chinese": "教学使用的书籍或资料,包含书名、作者、出版社、ISBN等属性", + "description_english": "Books or materials used for teaching, including attributes such as title, author, publisher, and ISBN" + }, + { + "name_chinese": "班级", + "name_english": "Class", + "description_chinese": "学生的组织单位,包含班级名称、年级、人数、班主任等属性", + "description_english": "Organizational units of students, including attributes such as class name, grade, number of students, and class teacher" + }, + { + "name_chinese": "学期", + "name_english": "Semester", + "description_chinese": "教学时间的划分单位,包含学期名称、开始时间、结束时间等属性", + "description_english": "Time division units for teaching, including attributes such as semester name, start time, and end time" + }, + { + "name_chinese": "课时", + "name_english": "Class Hour", + "description_chinese": "课程的时间单位,包含上课时间、地点、教师、课程等属性", + "description_english": "Time units of courses, including attributes such as class time, location, teacher, and course" + }, + { + "name_chinese": "教学计划", + "name_english": "Teaching Plan", + "description_chinese": "课程的教学安排,包含教学目标、内容安排、进度计划等属性", + "description_english": "Teaching arrangements for courses, including attributes such as teaching objectives, content arrangement, and progress plan" + } + ] +} + +# 情感陪伴场景配置 +EMOTIONAL_COMPANION_SCENE = { + "name_chinese": "情感陪伴", + "name_english": "Emotional Companion", + "description_chinese": "适用于情感陪伴应用的本体建模,包含用户、情绪、活动等核心实体类型", + "description_english": "Ontology modeling for emotional companion applications, including core entity types such as users, emotions, and activities", + "types": [ + { + "name_chinese": "用户", + "name_english": "User", + "description_chinese": "使用情感陪伴服务的个体,包含姓名、昵称、性格特征、偏好等属性", + "description_english": "Individuals using emotional companion services, including attributes such as name, nickname, personality traits, and preferences" + }, + { + "name_chinese": "情绪", + "name_english": "Emotion", + "description_chinese": "用户的情感状态,包含情绪类型、强度、触发原因、持续时间等属性", + "description_english": "Emotional states of users, including attributes such as emotion type, intensity, trigger cause, and duration" + }, + { + "name_chinese": "活动", + "name_english": "Activity", + "description_chinese": "用户参与的各类活动,包含活动名称、类型、参与者、时间地点等属性", + "description_english": "Various activities users participate in, including attributes such as activity name, type, participants, time, and location" + }, + { + "name_chinese": "对话", + "name_english": "Conversation", + "description_chinese": "用户之间的交流记录,包含对话主题、参与者、时间、关键内容等属性", + "description_english": "Communication records between users, including attributes such as conversation topic, participants, time, and key content" + }, + { + "name_chinese": "兴趣爱好", + "name_english": "Hobby", + "description_chinese": "用户的兴趣和爱好,包含爱好名称、类别、熟练程度、相关活动等属性", + "description_english": "User interests and hobbies, including attributes such as hobby name, category, proficiency level, and related activities" + }, + { + "name_chinese": "日常事件", + "name_english": "Daily Event", + "description_chinese": "用户日常生活中的事件,包含事件描述、时间、地点、相关人物等属性", + "description_english": "Events in users' daily lives, including attributes such as event description, time, location, and related people" + }, + { + "name_chinese": "关系", + "name_english": "Relationship", + "description_chinese": "用户之间的社会关系,包含关系类型、亲密度、建立时间等属性", + "description_english": "Social relationships between users, including attributes such as relationship type, intimacy, and establishment time" + }, + { + "name_chinese": "回忆", + "name_english": "Memory", + "description_chinese": "用户的重要记忆片段,包含回忆内容、时间、地点、相关人物等属性", + "description_english": "Important memory fragments of users, including attributes such as memory content, time, location, and related people" + }, + { + "name_chinese": "地点", + "name_english": "Location", + "description_chinese": "用户活动的地理位置,包含地点名称、地址、类型、相关事件等属性", + "description_english": "Geographic locations of user activities, including attributes such as location name, address, type, and related events" + }, + { + "name_chinese": "时间节点", + "name_english": "Time Point", + "description_chinese": "重要的时间标记,包含日期、事件、意义等属性", + "description_english": "Important time markers, including attributes such as date, event, and significance" + }, + { + "name_chinese": "目标", + "name_english": "Goal", + "description_chinese": "用户设定的目标,包含目标描述、截止时间、完成状态、相关活动等属性", + "description_english": "Goals set by users, including attributes such as goal description, deadline, completion status, and related activities" + }, + { + "name_chinese": "成就", + "name_english": "Achievement", + "description_chinese": "用户获得的成就,包含成就名称、获得时间、描述、相关目标等属性", + "description_english": "Achievements obtained by users, including attributes such as achievement name, acquisition time, description, and related goals" + } + ] +} + +# 导出默认场景列表 +DEFAULT_SCENES = [ONLINE_EDUCATION_SCENE, EMOTIONAL_COMPANION_SCENE] + + +def get_scene_name(scene_config: dict, language: str = "zh") -> str: + """获取场景名称(根据语言) + + Args: + scene_config: 场景配置字典 + language: 语言类型 ("zh" 或 "en") + + Returns: + 对应语言的场景名称 + """ + if language == "en": + return scene_config.get("name_english", scene_config.get("name_chinese")) + return scene_config.get("name_chinese") + + +def get_scene_description(scene_config: dict, language: str = "zh") -> str: + """获取场景描述(根据语言) + + Args: + scene_config: 场景配置字典 + language: 语言类型 ("zh" 或 "en") + + Returns: + 对应语言的场景描述 + """ + if language == "en": + return scene_config.get("description_english", scene_config.get("description_chinese")) + return scene_config.get("description_chinese") + + +def get_type_name(type_config: dict, language: str = "zh") -> str: + """获取类型名称(根据语言) + + Args: + type_config: 类型配置字典 + language: 语言类型 ("zh" 或 "en") + + Returns: + 对应语言的类型名称 + """ + if language == "en": + return type_config.get("name_english", type_config.get("name_chinese")) + return type_config.get("name_chinese") + + +def get_type_description(type_config: dict, language: str = "zh") -> str: + """获取类型描述(根据语言) + + Args: + type_config: 类型配置字典 + language: 语言类型 ("zh" 或 "en") + + Returns: + 对应语言的类型描述 + """ + if language == "en": + return type_config.get("description_english", type_config.get("description_chinese")) + return type_config.get("description_chinese") diff --git a/api/app/config/default_ontology_initializer.py b/api/app/config/default_ontology_initializer.py new file mode 100644 index 00000000..3d06a352 --- /dev/null +++ b/api/app/config/default_ontology_initializer.py @@ -0,0 +1,249 @@ +# -*- coding: utf-8 -*- +"""默认本体场景初始化器 + +本模块提供默认本体场景和类型的自动初始化功能。 +在工作空间创建时,自动添加预设的本体场景和实体类型。 + +Classes: + DefaultOntologyInitializer: 默认本体场景初始化器 +""" + +import logging +from typing import List, Optional, Tuple +from uuid import UUID + +from sqlalchemy.orm import Session + +from app.config.default_ontology_config import ( + DEFAULT_SCENES, + get_scene_name, + get_scene_description, + get_type_name, + get_type_description, +) +from app.core.logging_config import get_business_logger +from app.repositories.ontology_scene_repository import OntologySceneRepository +from app.repositories.ontology_class_repository import OntologyClassRepository + + +class DefaultOntologyInitializer: + """默认本体场景初始化器 + + 负责在工作空间创建时自动初始化默认的本体场景和类型。 + 遵循最小侵入原则,确保初始化失败不阻止工作空间创建。 + + Attributes: + db: 数据库会话 + scene_repo: 场景Repository + class_repo: 类型Repository + logger: 业务日志记录器 + """ + + def __init__(self, db: Session): + """初始化 + + Args: + db: 数据库会话 + """ + self.db = db + self.scene_repo = OntologySceneRepository(db) + self.class_repo = OntologyClassRepository(db) + self.logger = get_business_logger() + + def initialize_default_scenes( + self, + workspace_id: UUID, + language: str = "zh" + ) -> Tuple[bool, str]: + """为工作空间初始化默认场景 + + 创建两个默认场景(在线教育、情感陪伴)及其对应的实体类型。 + 如果创建失败,记录错误日志但不抛出异常。 + + Args: + workspace_id: 工作空间ID + language: 语言类型 ("zh" 或 "en"),默认为 "zh" + + Returns: + Tuple[bool, str]: (是否成功, 错误信息) + """ + try: + self.logger.info( + f"开始初始化默认本体场景 - workspace_id={workspace_id}, language={language}" + ) + + scenes_created = 0 + total_types_created = 0 + + # 遍历默认场景配置 + for scene_config in DEFAULT_SCENES: + scene_name = get_scene_name(scene_config, language) + + # 创建场景及其类型 + scene_id = self._create_scene_with_types(workspace_id, scene_config, language) + + if scene_id: + scenes_created += 1 + # 统计类型数量 + types_count = len(scene_config.get("types", [])) + total_types_created += types_count + + self.logger.info( + f"场景创建成功 - scene_name={scene_name}, " + f"scene_id={scene_id}, types_count={types_count}, language={language}" + ) + else: + self.logger.warning( + f"场景创建失败 - scene_name={scene_name}, " + f"workspace_id={workspace_id}, language={language}" + ) + + # 记录总体结果 + self.logger.info( + f"默认场景初始化完成 - workspace_id={workspace_id}, " + f"language={language}, scenes_created={scenes_created}, " + f"total_types_created={total_types_created}" + ) + + # 如果至少创建了一个场景,视为成功 + if scenes_created > 0: + return True, "" + else: + error_msg = "所有默认场景创建失败" + self.logger.error( + f"默认场景初始化失败 - workspace_id={workspace_id}, " + f"language={language}, error={error_msg}" + ) + return False, error_msg + + except Exception as e: + error_msg = f"默认场景初始化异常: {str(e)}" + self.logger.error( + f"默认场景初始化异常 - workspace_id={workspace_id}, " + f"language={language}, error={str(e)}", + exc_info=True + ) + return False, error_msg + + def _create_scene_with_types( + self, + workspace_id: UUID, + scene_config: dict, + language: str = "zh" + ) -> Optional[UUID]: + """创建场景及其类型 + + Args: + workspace_id: 工作空间ID + scene_config: 场景配置字典 + language: 语言类型 ("zh" 或 "en") + + Returns: + Optional[UUID]: 创建的场景ID,失败返回None + """ + try: + scene_name = get_scene_name(scene_config, language) + scene_description = get_scene_description(scene_config, language) + + # 检查是否已存在同名场景(支持向后兼容) + existing_scene = self.scene_repo.get_by_name(scene_name, workspace_id) + if existing_scene: + self.logger.info( + f"场景已存在,跳过创建 - scene_name={scene_name}, " + f"workspace_id={workspace_id}, scene_id={existing_scene.scene_id}, " + f"language={language}" + ) + return None + + # 创建场景记录,设置 is_system_default=true + scene_data = { + "scene_name": scene_name, + "scene_description": scene_description + } + + scene = self.scene_repo.create(scene_data, workspace_id) + + # 设置系统默认标识 + scene.is_system_default = True + self.db.flush() + + self.logger.info( + f"场景创建成功 - scene_name={scene_name}, " + f"scene_id={scene.scene_id}, is_system_default=True, language={language}" + ) + + # 批量创建类型 + types_config = scene_config.get("types", []) + types_created = self._batch_create_types(scene.scene_id, types_config, language) + + self.logger.info( + f"场景类型创建完成 - scene_id={scene.scene_id}, " + f"types_created={types_created}/{len(types_config)}, language={language}" + ) + + return scene.scene_id + + except Exception as e: + scene_name = get_scene_name(scene_config, language) + self.logger.error( + f"场景创建失败 - scene_name={scene_name}, " + f"workspace_id={workspace_id}, language={language}, error={str(e)}", + exc_info=True + ) + return None + + def _batch_create_types( + self, + scene_id: UUID, + types_config: List[dict], + language: str = "zh" + ) -> int: + """批量创建实体类型 + + Args: + scene_id: 场景ID + types_config: 类型配置列表 + language: 语言类型 ("zh" 或 "en") + + Returns: + int: 成功创建的类型数量 + """ + created_count = 0 + + for type_config in types_config: + try: + type_name = get_type_name(type_config, language) + type_description = get_type_description(type_config, language) + + # 创建类型数据 + class_data = { + "class_name": type_name, + "class_description": type_description + } + + # 创建类型 + ontology_class = self.class_repo.create(class_data, scene_id) + + # 设置系统默认标识 + ontology_class.is_system_default = True + self.db.flush() + + created_count += 1 + + self.logger.debug( + f"类型创建成功 - class_name={type_name}, " + f"class_id={ontology_class.class_id}, " + f"scene_id={scene_id}, is_system_default=True, language={language}" + ) + + except Exception as e: + type_name = get_type_name(type_config, language) + self.logger.warning( + f"单个类型创建失败,继续创建其他类型 - " + f"class_name={type_name}, scene_id={scene_id}, " + f"language={language}, error={str(e)}" + ) + # 继续创建其他类型 + continue + + return created_count diff --git a/api/app/controllers/app_controller.py b/api/app/controllers/app_controller.py index f1508114..cdf94345 100644 --- a/api/app/controllers/app_controller.py +++ b/api/app/controllers/app_controller.py @@ -1,7 +1,8 @@ import uuid from typing import Optional, Annotated -from fastapi import APIRouter, Depends, Path +import yaml +from fastapi import APIRouter, Depends, Path, Form, UploadFile, File from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session @@ -17,12 +18,13 @@ from app.repositories.end_user_repository import EndUserRepository from app.schemas import app_schema from app.schemas.response_schema import PageData, PageMeta from app.schemas.workflow_schema import WorkflowConfig as WorkflowConfigSchema -from app.schemas.workflow_schema import WorkflowConfigUpdate +from app.schemas.workflow_schema import WorkflowConfigUpdate, WorkflowImportSave from app.services import app_service, workspace_service from app.services.agent_config_helper import enrich_agent_config from app.services.app_service import AppService -from app.services.workflow_service import WorkflowService, get_workflow_service from app.services.app_statistics_service import AppStatisticsService +from app.services.workflow_import_service import WorkflowImportService +from app.services.workflow_service import WorkflowService, get_workflow_service router = APIRouter(prefix="/apps", tags=["Apps"]) logger = get_business_logger() @@ -65,7 +67,7 @@ def list_apps( # 当 ids 存在且不为 None 时,根据 ids 获取应用 if ids is not None: - app_ids = [id.strip() for id in ids.split(',') if id.strip()] + app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()] items_orm = app_service.get_apps_by_ids(db, app_ids, workspace_id) items = [service._convert_to_schema(app, workspace_id) for app in items_orm] return success(data=items) @@ -394,10 +396,10 @@ async def draft_run( from app.models import AgentConfig, ModelConfig from sqlalchemy import select from app.core.exceptions import BusinessException - from app.services.draft_run_service import DraftRunService + from app.services.draft_run_service import AgentRunService service = AppService(db) - draft_service = DraftRunService(db) + draft_service = AgentRunService(db) # 1. 验证应用 app = service._get_app_or_404(app_id) @@ -482,8 +484,8 @@ async def draft_run( } ) - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) + from app.services.draft_run_service import AgentRunService + draft_service = AgentRunService(db) result = await draft_service.run( agent_config=agent_cfg, model_config=model_config, @@ -787,8 +789,8 @@ async def draft_run_compare( # 流式返回 if payload.stream: async def event_generator(): - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) + from app.services.draft_run_service import AgentRunService + draft_service = AgentRunService(db) async for event in draft_service.run_compare_stream( agent_config=agent_cfg, models=model_configs, @@ -818,8 +820,8 @@ async def draft_run_compare( ) # 非流式返回 - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) + from app.services.draft_run_service import AgentRunService + draft_service = AgentRunService(db) result = await draft_service.run_compare( agent_config=agent_cfg, models=model_configs, @@ -833,7 +835,8 @@ async def draft_run_compare( web_search=True, memory=True, parallel=payload.parallel, - timeout=payload.timeout or 60 + timeout=payload.timeout or 60, + files=payload.files ) logger.info( @@ -879,6 +882,60 @@ async def update_workflow_config( return success(data=WorkflowConfigSchema.model_validate(cfg)) +@router.get("/{app_id}/workflow/export") +@cur_workspace_access_guard() +async def export_workflow_config( + app_id: uuid.UUID, + db: Annotated[Session, Depends(get_db)], + current_user: Annotated[User, Depends(get_current_user)] +): + """导出工作流配置为YAML文件""" + workflow_service = WorkflowService(db) + + return success(data={ + "content": workflow_service.export_workflow_dsl(app_id=app_id), + }) + + +@router.post("/workflow/import") +@cur_workspace_access_guard() +async def import_workflow_config( + file: UploadFile = File(...), + platform: str = Form(...), + app_id: str = Form(None), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) + +): + """从YAML内容导入工作流配置""" + if not file.filename.lower().endswith((".yaml", ".yml")): + return fail(msg="Only yaml file is allowed", code=BizCode.BAD_REQUEST) + + raw_text = (await file.read()).decode("utf-8") + import_service = WorkflowImportService(db) + config = yaml.safe_load(raw_text) + result = await import_service.upload_config(platform, config) + return success(data=result) + + +@router.post("/workflow/import/save") +@cur_workspace_access_guard() +async def save_workflow_import( + data: WorkflowImportSave, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + import_service = WorkflowImportService(db) + app = await import_service.save_workflow( + user_id=current_user.id, + workspace_id=current_user.current_workspace_id, + temp_id=data.temp_id, + name=data.name, + description=data.description, + ) + return success(data=app_schema.App.model_validate(app)) + + @router.get("/{app_id}/statistics", summary="应用统计数据") @cur_workspace_access_guard() def get_app_statistics( @@ -889,12 +946,14 @@ def get_app_statistics( current_user=Depends(get_current_user), ): """获取应用统计数据 - + Args: app_id: 应用ID start_date: 开始时间戳(毫秒) end_date: 结束时间戳(毫秒) - + db: 数据库连接 + current_user: 当前用户 + Returns: - daily_conversations: 每日会话数统计 - total_conversations: 总会话数 @@ -931,6 +990,8 @@ def get_workspace_api_statistics( Args: start_date: 开始时间戳(毫秒) end_date: 结束时间戳(毫秒) + db: 数据库连接 + current_user: 当前用户 Returns: 每日统计数据列表,每项包含: diff --git a/api/app/controllers/chunk_controller.py b/api/app/controllers/chunk_controller.py index 620d8a1a..988aa706 100644 --- a/api/app/controllers/chunk_controller.py +++ b/api/app/controllers/chunk_controller.py @@ -441,14 +441,14 @@ async def retrieve_chunks( # 1 participle search, 2 semantic search, 3 hybrid search match retrieve_data.retrieve_type: case chunk_schema.RetrieveType.PARTICIPLE: - rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold) + rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter) return success(data=rs, msg="retrieval successful") case chunk_schema.RetrieveType.SEMANTIC: - rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight) + rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter) return success(data=rs, msg="retrieval successful") case _: - rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight) - rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold) + rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter) + rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter) # Efficient deduplication seen_ids = set() unique_rs = [] diff --git a/api/app/controllers/mcp_market_config_controller.py b/api/app/controllers/mcp_market_config_controller.py index 98012568..7f73663e 100644 --- a/api/app/controllers/mcp_market_config_controller.py +++ b/api/app/controllers/mcp_market_config_controller.py @@ -90,7 +90,7 @@ async def get_mcp_servers( cookies=cookies) raise_for_http_status(r) except requests.exceptions.RequestException as e: - api_logger.error(f"mFailed to get MCP servers: {str(e)}") + api_logger.error(f"Failed to get MCP servers: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to get MCP servers: {str(e)}" @@ -118,6 +118,65 @@ async def get_mcp_servers( return success(data=result, msg="Query of mcp servers list successful") +@router.get("/operational_mcp_servers", response_model=ApiResponse) +async def get_operational_mcp_servers( + mcp_market_config_id: uuid.UUID, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) +): + """ + Query the operational mcp servers list in pages + - Support keyword search for name,author,owner + - Return paging metadata + operational mcp server list + """ + api_logger.info( + f"Query operational mcp server list: tenant_id={current_user.tenant_id}, username: {current_user.username}") + + # 1. Query mcp market config information from the database + api_logger.debug(f"Query mcp market config: {mcp_market_config_id}") + db_mcp_market_config = mcp_market_config_service.get_mcp_market_config_by_id(db, + mcp_market_config_id=mcp_market_config_id, + current_user=current_user) + if not db_mcp_market_config: + api_logger.warning( + f"The mcp market config does not exist or access is denied: mcp_market_config_id={mcp_market_config_id}") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="The mcp market config does not exist or access is denied" + ) + + # 2. Execute paged query + api = MCPApi() + token = db_mcp_market_config.token + api.login(token) + + url = f'{api.mcp_base_url}/operational' + headers = api.builder_headers(api.headers) + + try: + cookies = api.get_cookies(access_token=token, cookies_required=True) + r = api.session.get(url, headers=headers, cookies=cookies) + raise_for_http_status(r) + except requests.exceptions.RequestException as e: + api_logger.error(f"Failed to get operational MCP servers: {str(e)}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to get operational MCP servers: {str(e)}" + ) + + data = api._handle_response(r) + total = data.get('total_count', 0) + mcp_server_list = data.get('mcp_server_list', []) + # items = [{ + # 'name': item.get('name', ''), + # 'id': item.get('id', ''), + # 'description': item.get('description', '') + # } for item in mcp_server_list] + + # 3. Return structured response + return success(data=mcp_server_list, msg="Query of operational mcp servers list successful") + + @router.get("/mcp_server", response_model=ApiResponse) async def get_mcp_server( mcp_market_config_id: uuid.UUID, diff --git a/api/app/controllers/memory_agent_controller.py b/api/app/controllers/memory_agent_controller.py index b88e65ff..e3d2bf92 100644 --- a/api/app/controllers/memory_agent_controller.py +++ b/api/app/controllers/memory_agent_controller.py @@ -1,27 +1,29 @@ from typing import List, Optional +from dotenv import load_dotenv +from fastapi import APIRouter, Depends, File, Form, Query, UploadFile, Header +from sqlalchemy.orm import Session +from starlette.responses import StreamingResponse + +from app.cache.memory.interest_memory import InterestMemoryCache from app.celery_app import celery_app from app.core.error_codes import BizCode from app.core.language_utils import get_language_from_header from app.core.logging_config import get_api_logger +from app.core.memory.agent.utils.redis_tool import store +from app.core.memory.agent.utils.session_tools import SessionService from app.core.rag.llm.cv_model import QWenCV from app.core.response_utils import fail, success from app.db import get_db from app.dependencies import cur_workspace_access_guard, get_current_user from app.models import ModelApiKey from app.models.user_model import User -from app.core.memory.agent.utils.session_tools import SessionService -from app.core.memory.agent.utils.redis_tool import store -from app.repositories import knowledge_repository, WorkspaceRepository +from app.repositories import knowledge_repository from app.schemas.memory_agent_schema import UserInput, Write_UserInput from app.schemas.response_schema import ApiResponse from app.services import task_service, workspace_service from app.services.memory_agent_service import MemoryAgentService from app.services.model_service import ModelConfigService -from dotenv import load_dotenv -from fastapi import APIRouter, Depends, File, Form, Query, UploadFile,Header -from sqlalchemy.orm import Session -from starlette.responses import StreamingResponse load_dotenv() api_logger = get_api_logger() @@ -36,7 +38,7 @@ router = APIRouter( @router.get("/health/status", response_model=ApiResponse) async def get_health_status( - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user) ): """ Get latest health status written by Celery periodic task @@ -54,8 +56,9 @@ async def get_health_status( @router.get("/download_log") async def download_log( - log_type: str = Query("file", regex="^(file|transmission)$", description="日志类型: file=完整文件, transmission=实时流式传输"), - current_user: User = Depends(get_current_user) + log_type: str = Query("file", regex="^(file|transmission)$", + description="日志类型: file=完整文件, transmission=实时流式传输"), + current_user: User = Depends(get_current_user) ): """ Download or stream agent service log file @@ -74,16 +77,16 @@ async def download_log( - transmission mode: StreamingResponse with SSE """ api_logger.info(f"Log download requested with log_type={log_type}") - + # Validate log_type parameter (FastAPI Query regex already validates, but explicit check for clarity) if log_type not in ["file", "transmission"]: api_logger.warning(f"Invalid log_type parameter: {log_type}") return fail( - BizCode.BAD_REQUEST, - "无效的log_type参数", + BizCode.BAD_REQUEST, + "无效的log_type参数", "log_type必须是'file'或'transmission'" ) - + # Route to appropriate mode if log_type == "file": # File mode: Return complete log file content @@ -118,10 +121,10 @@ async def download_log( @router.post("/writer_service", response_model=ApiResponse) @cur_workspace_access_guard() async def write_server( - user_input: Write_UserInput, - language_type: str = Header(default=None, alias="X-Language-Type"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + user_input: Write_UserInput, + language_type: str = Header(default=None, alias="X-Language-Type"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) ): """ Write service endpoint - processes write operations synchronously @@ -135,11 +138,11 @@ async def write_server( """ # 使用集中化的语言校验 language = get_language_from_header(language_type) - + config_id = user_input.config_id workspace_id = current_user.current_workspace_id api_logger.info(f"Write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") - + # 获取 storage_type,如果为 None 则使用默认值 storage_type = workspace_service.get_workspace_storage_type( db=db, @@ -148,7 +151,7 @@ async def write_server( ) if storage_type is None: storage_type = 'neo4j' user_rag_memory_id = '' - + # 如果 storage_type 是 rag,必须确保有有效的 user_rag_memory_id if storage_type == 'rag': if workspace_id: @@ -160,13 +163,15 @@ async def write_server( if knowledge: user_rag_memory_id = str(knowledge.id) else: - api_logger.warning(f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储") + api_logger.warning( + f"未找到名为 'USER_RAG_MERORY' 的知识库,workspace_id: {workspace_id},将使用 neo4j 存储") storage_type = 'neo4j' else: api_logger.warning("workspace_id 为空,无法使用 rag 存储,将使用 neo4j 存储") storage_type = 'neo4j' - - api_logger.info(f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") + + api_logger.info( + f"Write service requested for group {user_input.end_user_id}, storage_type: {storage_type}, user_rag_memory_id: {user_rag_memory_id}") try: messages_list = memory_agent_service.get_messages_list(user_input) result = await memory_agent_service.write_memory( @@ -174,7 +179,7 @@ async def write_server( messages_list, config_id, db, - storage_type, + storage_type, user_rag_memory_id, language ) @@ -194,10 +199,10 @@ async def write_server( @router.post("/writer_service_async", response_model=ApiResponse) @cur_workspace_access_guard() async def write_server_async( - user_input: Write_UserInput, - language_type: str = Header(default=None, alias="X-Language-Type"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + user_input: Write_UserInput, + language_type: str = Header(default=None, alias="X-Language-Type"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) ): """ Async write service endpoint - enqueues write processing to Celery @@ -212,10 +217,11 @@ async def write_server_async( """ # 使用集中化的语言校验 language = get_language_from_header(language_type) - + config_id = user_input.config_id workspace_id = current_user.current_workspace_id - api_logger.info(f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") + api_logger.info( + f"Async write service: workspace_id={workspace_id}, config_id={config_id}, language_type={language}") # 获取 storage_type,如果为 None 则使用默认值 storage_type = workspace_service.get_workspace_storage_type( @@ -243,7 +249,7 @@ async def write_server_async( args=[user_input.end_user_id, messages_list, config_id, storage_type, user_rag_memory_id, language] ) api_logger.info(f"Write task queued: {task.id}") - + return success(data={"task_id": task.id}, msg="写入任务已提交") except Exception as e: api_logger.error(f"Async write operation failed: {str(e)}") @@ -253,9 +259,9 @@ async def write_server_async( @router.post("/read_service", response_model=ApiResponse) @cur_workspace_access_guard() async def read_server( - user_input: UserInput, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + user_input: UserInput, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) ): """ Read service endpoint - processes read operations synchronously @@ -290,8 +296,9 @@ async def read_server( ) if knowledge: user_rag_memory_id = str(knowledge.id) - - api_logger.info(f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}") + + api_logger.info( + f"Read service: group={user_input.end_user_id}, storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}, workspace_id={workspace_id}") try: result = await memory_agent_service.read_memory( user_input.end_user_id, @@ -305,7 +312,8 @@ async def read_server( ) if str(user_input.search_switch) == "2": retrieve_info = result['answer'] - history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, user_input.end_user_id) + history = await SessionService(store).get_history(user_input.end_user_id, user_input.end_user_id, + user_input.end_user_id) query = user_input.message # 调用 memory_agent_service 的方法生成最终答案 @@ -318,7 +326,7 @@ async def read_server( db=db ) if "信息不足,无法回答" in result['answer']: - result['answer']=retrieve_info + result['answer'] = retrieve_info return success(data=result, msg="回复对话消息成功") except BaseException as e: # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup @@ -334,9 +342,10 @@ async def read_server( @router.post("/file", response_model=ApiResponse) async def file_update( files: List[UploadFile] = File(..., description="要上传的文件"), - model_id:str = Form(..., description="模型ID"), + model_id: str = Form(..., description="模型ID"), metadata: Optional[str] = Form(None, description="文件元数据 (JSON格式)"), - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ): """ 文件上传接口 - 支持图片识别 @@ -349,9 +358,6 @@ async def file_update( Returns: 文件处理结果 """ - - db_gen = get_db() # get_db 通常是一个生成器 - db = next(db_gen) api_logger.info(f"File upload requested, file count: {len(files)}") config = ModelConfigService.get_model_by_id(db=db, model_id=model_id) apiConfig: ModelApiKey = config.api_keys[0] @@ -360,7 +366,7 @@ async def file_update( for file in files: api_logger.debug(f"Processing file: {file.filename}, content_type: {file.content_type}") content = await file.read() - + if file.content_type and file.content_type.startswith("image/"): vision_model = QWenCV( key=apiConfig.api_key, @@ -374,12 +380,12 @@ async def file_update( else: api_logger.warning(f"Unsupported file type: {file.content_type}") file_content.append(f"[不支持的文件类型: {file.content_type}]") - + result_text = ';'.join(file_content) api_logger.info(f"File processing completed, result length: {len(result_text)}") - + return success(data=result_text, msg="转换文本成功") - + except Exception as e: api_logger.error(f"File processing failed: {str(e)}", exc_info=True) return fail(BizCode.INTERNAL_ERROR, "转换文本失败", str(e)) @@ -429,8 +435,8 @@ async def read_server_async( @router.get("/read_result/", response_model=ApiResponse) async def get_read_task_result( - task_id: str, - current_user: User = Depends(get_current_user) + task_id: str, + current_user: User = Depends(get_current_user) ): """ Get the status and result of an async read task @@ -451,7 +457,7 @@ async def get_read_task_result( try: result = task_service.get_task_memory_read_result(task_id) status = result.get("status") - + if status == "SUCCESS": # 任务成功完成 task_result = result.get("result", {}) @@ -469,7 +475,7 @@ async def get_read_task_result( else: # 旧格式:直接返回结果 return success(data=task_result, msg="查询任务已完成") - + elif status == "FAILURE": # 任务失败 error_info = result.get("result", "Unknown error") @@ -478,7 +484,7 @@ async def get_read_task_result( else: error_msg = str(error_info) return fail(BizCode.INTERNAL_ERROR, "查询任务失败", error_msg) - + elif status in ["PENDING", "STARTED"]: # 任务进行中 return success( @@ -498,7 +504,7 @@ async def get_read_task_result( }, msg=f"任务状态: {status}" ) - + except Exception as e: api_logger.error(f"Read task status check failed: {str(e)}", exc_info=True) return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e)) @@ -506,8 +512,8 @@ async def get_read_task_result( @router.get("/write_result/", response_model=ApiResponse) async def get_write_task_result( - task_id: str, - current_user: User = Depends(get_current_user) + task_id: str, + current_user: User = Depends(get_current_user) ): """ Get the status and result of an async write task @@ -528,7 +534,7 @@ async def get_write_task_result( try: result = task_service.get_task_memory_write_result(task_id) status = result.get("status") - + if status == "SUCCESS": # 任务成功完成 task_result = result.get("result", {}) @@ -546,7 +552,7 @@ async def get_write_task_result( else: # 旧格式:直接返回结果 return success(data=task_result, msg="写入任务已完成") - + elif status == "FAILURE": # 任务失败 error_info = result.get("result", "Unknown error") @@ -555,7 +561,7 @@ async def get_write_task_result( else: error_msg = str(error_info) return fail(BizCode.INTERNAL_ERROR, "写入任务失败", error_msg) - + elif status in ["PENDING", "STARTED"]: # 任务进行中 return success( @@ -575,7 +581,7 @@ async def get_write_task_result( }, msg=f"任务状态: {status}" ) - + except Exception as e: api_logger.error(f"Write task status check failed: {str(e)}", exc_info=True) return fail(BizCode.INTERNAL_ERROR, "任务状态查询失败", str(e)) @@ -583,9 +589,9 @@ async def get_write_task_result( @router.post("/status_type", response_model=ApiResponse) async def status_type( - user_input: Write_UserInput, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + user_input: Write_UserInput, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) ): """ Determine the type of user message (read or write) @@ -628,9 +634,10 @@ async def status_type( @router.get("/stats/types", response_model=ApiResponse) async def get_knowledge_type_stats_api( - end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), - only_active: bool = Query(True, description="仅统计有效记录(status=1)"), - current_user: User = Depends(get_current_user) + end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), + only_active: bool = Query(True, description="仅统计有效记录(status=1)"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ): """ 统计当前空间下各知识库类型的数量,包含 General | Web | Third-party | Folder。 @@ -639,14 +646,9 @@ async def get_knowledge_type_stats_api( - 知识库类型根据当前用户的 current_workspace_id 过滤 - 如果用户没有当前工作空间,对应的统计返回 0 """ - api_logger.info(f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}") + api_logger.info( + f"Knowledge type stats requested for workspace_id: {current_user.current_workspace_id}, end_user_id: {end_user_id}") try: - from app.db import get_db - - # 获取数据库会话 - db_gen = get_db() - db = next(db_gen) - # 调用service层函数 result = await memory_agent_service.get_knowledge_type_stats( end_user_id=end_user_id, @@ -654,48 +656,70 @@ async def get_knowledge_type_stats_api( current_workspace_id=current_user.current_workspace_id, db=db ) - + return success(data=result, msg="获取知识库类型统计成功") except Exception as e: api_logger.error(f"Knowledge type stats failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e)) -@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse) -async def get_hot_memory_tags_by_user_api( - end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), - limit: int = Query(20, description="返回标签数量限制"), - current_user: User = Depends(get_current_user), - db: Session=Depends(get_db), +@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse) +async def get_interest_distribution_by_user_api( + end_user_id: str = Query(..., description="用户ID(必填)"), + limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"), + language_type: str = Header(default=None, alias="X-Language-Type"), + current_user: User = Depends(get_current_user), + db: Session = Depends(get_db), ): """ - 获取指定用户的热门记忆标签 + 获取指定用户的兴趣分布标签 - 注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译 + 与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习、创作等), + 过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。 返回格式: [ - {"name": "标签名", "frequency": 频次}, + {"name": "兴趣活动名", "frequency": 频次}, ... ] """ - api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}") + language = get_language_from_header(language_type) + api_logger.info(f"Interest distribution by user requested: end_user_id={end_user_id}, language={language}") try: - result = await memory_agent_service.get_hot_memory_tags_by_user( + # 优先读取缓存 + cached = await InterestMemoryCache.get_interest_distribution( end_user_id=end_user_id, - limit=limit + language=language, ) - return success(data=result, msg="获取热门记忆标签成功") + if cached is not None: + api_logger.info(f"Interest distribution cache hit: end_user_id={end_user_id}") + return success(data=cached, msg="获取兴趣分布标签成功") + + # 缓存未命中,调用模型生成 + result = await memory_agent_service.get_interest_distribution_by_user( + end_user_id=end_user_id, + limit=limit, + language=language + ) + + # 写入缓存,24小时过期 + await InterestMemoryCache.set_interest_distribution( + end_user_id=end_user_id, + language=language, + data=result, + ) + + return success(data=result, msg="获取兴趣分布标签成功") except Exception as e: - api_logger.error(f"Hot memory tags by user failed: {str(e)}") - return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e)) + api_logger.error(f"Interest distribution by user failed: {str(e)}") + return fail(BizCode.INTERNAL_ERROR, "获取兴趣分布标签失败", str(e)) @router.get("/analytics/user_profile", response_model=ApiResponse) async def get_user_profile_api( - end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + end_user_id: Optional[str] = Query(None, description="用户ID(可选)"), + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) ): """ 获取用户详情,包含: @@ -733,17 +757,17 @@ async def get_user_profile_api( # ): # """ # Get parsed API documentation (Public endpoint - no authentication required) - + # Args: # file_path: Optional path to API docs file. If None, uses default path. - + # Returns: # Parsed API documentation including title, meta info, and sections # """ # api_logger.info(f"API docs requested, file_path: {file_path or 'default'}") # try: # result = await memory_agent_service.get_api_docs(file_path) - + # if result.get("success"): # return success(msg=result["msg"], data=result["data"]) # else: @@ -759,9 +783,9 @@ async def get_user_profile_api( @router.get("/end_user/{end_user_id}/connected_config", response_model=ApiResponse) async def get_end_user_connected_config( - end_user_id: str, - db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + end_user_id: str, + db: Session = Depends(get_db), + current_user: User = Depends(get_current_user) ): """ 获取终端用户关联的记忆配置 @@ -780,9 +804,9 @@ async def get_end_user_connected_config( from app.services.memory_agent_service import ( get_end_user_connected_config as get_config, ) - + api_logger.info(f"Getting connected config for end_user: {end_user_id}") - + try: result = get_config(end_user_id, db) return success(data=result, msg="获取终端用户关联配置成功") @@ -791,4 +815,4 @@ async def get_end_user_connected_config( return fail(BizCode.NOT_FOUND, str(e)) except Exception as e: api_logger.error(f"Failed to get end user connected config: {str(e)}", exc_info=True) - return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e)) \ No newline at end of file + return fail(BizCode.INTERNAL_ERROR, "获取终端用户关联配置失败", str(e)) diff --git a/api/app/controllers/memory_dashboard_controller.py b/api/app/controllers/memory_dashboard_controller.py index 475d184e..1b5b45fb 100644 --- a/api/app/controllers/memory_dashboard_controller.py +++ b/api/app/controllers/memory_dashboard_controller.py @@ -606,8 +606,8 @@ async def dashboard_data( # 获取RAG相关数据 try: - # total_memory: 使用 total_chunk(总chunk数) - total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user) + # total_memory: 只统计用户知识库(permission_id='Memory')的chunk数 + total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user) rag_data["total_memory"] = total_chunk # total_app: 统计当前空间下的所有app数量 diff --git a/api/app/controllers/memory_short_term_controller.py b/api/app/controllers/memory_short_term_controller.py index 1cca266e..0acac6ce 100644 --- a/api/app/controllers/memory_short_term_controller.py +++ b/api/app/controllers/memory_short_term_controller.py @@ -1,16 +1,18 @@ -from fastapi import APIRouter, Depends, HTTPException, status,Header +from typing import Optional + +from dotenv import load_dotenv +from fastapi import APIRouter, Depends, Header, HTTPException, status +from sqlalchemy.orm import Session + from app.core.language_utils import get_language_from_header from app.core.logging_config import get_api_logger from app.core.response_utils import success from app.db import get_db from app.dependencies import get_current_user from app.models.user_model import User - +from app.services.memory_short_service import LongService, ShortService from app.services.memory_storage_service import search_entity -from app.services.memory_short_service import ShortService,LongService -from dotenv import load_dotenv -from sqlalchemy.orm import Session -from typing import Optional + load_dotenv() api_logger = get_api_logger() @@ -29,11 +31,11 @@ async def short_term_configs( language = get_language_from_header(language_type) # 获取短期记忆数据 - short_term=ShortService(end_user_id) + short_term=ShortService(end_user_id, db) short_result=short_term.get_short_databasets() short_count=short_term.get_short_count() - long_term=LongService(end_user_id) + long_term=LongService(end_user_id, db) long_result=long_term.get_long_databasets() entity_result = await search_entity(end_user_id) diff --git a/api/app/controllers/memory_storage_controller.py b/api/app/controllers/memory_storage_controller.py index 826724c9..d91dfc36 100644 --- a/api/app/controllers/memory_storage_controller.py +++ b/api/app/controllers/memory_storage_controller.py @@ -2,7 +2,7 @@ from typing import Optional from uuid import UUID from fastapi import APIRouter, Depends, Query -from fastapi.responses import StreamingResponse +from fastapi.responses import StreamingResponse, JSONResponse from sqlalchemy.orm import Session from app.core.error_codes import BizCode @@ -85,6 +85,7 @@ def create_config( payload: ConfigParamsCreate, current_user: User = Depends(get_current_user), db: Session = Depends(get_db), + x_language_type: Optional[str] = Header(None, alias="X-Language-Type"), ) -> dict: workspace_id = current_user.current_workspace_id # 检查用户是否已选择工作空间 @@ -99,7 +100,29 @@ def create_config( svc = DataConfigService(db) result = svc.create(payload) return success(data=result, msg="创建成功") + except ValueError as e: + err_str = str(e) + if err_str.startswith("DUPLICATE_CONFIG_NAME:"): + config_name = err_str.split(":", 1)[1] + api_logger.warning(f"重复的配置名称 '{config_name}' 在工作空间 {workspace_id}") + lang = get_language_from_header(x_language_type) + if lang == "en": + msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{config_name}\" already exists in the current workspace. Please use a different name.") + else: + msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{config_name}」的记忆配置,请使用其他名称") + return JSONResponse(status_code=400, content=msg) + api_logger.error(f"Create config failed: {err_str}") + return fail(BizCode.INTERNAL_ERROR, "创建配置失败", err_str) except Exception as e: + from sqlalchemy.exc import IntegrityError + if isinstance(e, IntegrityError) and "uq_workspace_config_name" in str(getattr(e, 'orig', '')): + api_logger.warning(f"重复的配置名称 '{payload.config_name}' 在工作空间 {workspace_id}") + lang = get_language_from_header(x_language_type) + if lang == "en": + msg = fail(BizCode.BAD_REQUEST, "Config name already exists", f"A config named \"{payload.config_name}\" already exists in the current workspace. Please use a different name.") + else: + msg = fail(BizCode.BAD_REQUEST, "配置名称已存在", f"当前工作空间下已存在名为「{payload.config_name}」的记忆配置,请使用其他名称") + return JSONResponse(status_code=400, content=msg) api_logger.error(f"Create config failed: {str(e)}") return fail(BizCode.INTERNAL_ERROR, "创建配置失败", str(e)) @@ -521,10 +544,11 @@ async def clear_hot_memory_tags_cache( @router.get("/analytics/recent_activity_stats", response_model=ApiResponse) async def get_recent_activity_stats_api( current_user: User = Depends(get_current_user), - ) -> dict: - api_logger.info("Recent activity stats requested") +) -> dict: + workspace_id = str(current_user.current_workspace_id) if current_user.current_workspace_id else None + api_logger.info(f"Recent activity stats requested: workspace_id={workspace_id}") try: - result = await analytics_recent_activity_stats() + result = await analytics_recent_activity_stats(workspace_id=workspace_id) return success(data=result, msg="查询成功") except Exception as e: api_logger.error(f"Recent activity stats failed: {str(e)}") diff --git a/api/app/controllers/model_controller.py b/api/app/controllers/model_controller.py index bb1ba526..6204a745 100644 --- a/api/app/controllers/model_controller.py +++ b/api/app/controllers/model_controller.py @@ -371,6 +371,11 @@ def update_model( if model_data.type is not None or model_data.provider is not None: raise BusinessException("不允许更改模型类型和供应商", BizCode.INVALID_PARAMETER) + + if model_data.is_active: + active_keys = ModelApiKeyService.get_api_keys_by_model(db=db, model_config_id=model_id, is_active=model_data.is_active) + if not active_keys: + raise BusinessException("请先为该模型配置可用的 API Key", BizCode.INVALID_PARAMETER) try: api_logger.debug(f"开始更新模型配置: model_id={model_id}") @@ -469,7 +474,9 @@ async def create_model_api_key_by_provider( config=api_key_data.config, is_active=api_key_data.is_active, priority=api_key_data.priority, - model_config_ids=model_config_ids + model_config_ids=model_config_ids, + capability=api_key_data.capability, + is_omni=api_key_data.is_omni ) created_keys, failed_models = await ModelApiKeyService.create_api_key_by_provider(db=db, data=create_data) diff --git a/api/app/controllers/ontology_controller.py b/api/app/controllers/ontology_controller.py index 49a2fb3a..3d2a1bdb 100644 --- a/api/app/controllers/ontology_controller.py +++ b/api/app/controllers/ontology_controller.py @@ -25,13 +25,13 @@ from typing import Dict, Optional, List from urllib.parse import quote from fastapi import APIRouter, Depends, HTTPException, File, UploadFile, Form, Header -from fastapi.responses import StreamingResponse +from fastapi.responses import StreamingResponse, JSONResponse from sqlalchemy.orm import Session from app.core.config import settings from app.core.error_codes import BizCode from app.core.language_utils import get_language_from_header -from app.core.logging_config import get_api_logger +from app.core.logging_config import get_api_logger, get_business_logger from app.core.response_utils import fail, success from app.db import get_db from app.dependencies import get_current_user @@ -61,6 +61,7 @@ from app.repositories.ontology_scene_repository import OntologySceneRepository api_logger = get_api_logger() +business_logger = get_business_logger() logger = logging.getLogger(__name__) router = APIRouter( @@ -123,15 +124,23 @@ def _get_ontology_service( ) # 通过 Repository 获取可用的 API Key(负载均衡逻辑由 Repository 处理) - from app.repositories.model_repository import ModelApiKeyRepository - api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id) - if not api_keys: + # from app.repositories.model_repository import ModelApiKeyRepository + from app.services.model_service import ModelApiKeyService + api_key_config = ModelApiKeyService.get_available_api_key(db, model_config.id) + if not api_key_config: logger.error(f"Model {llm_id} has no active API key") raise HTTPException( status_code=400, detail="指定的LLM模型没有可用的API密钥" ) - api_key_config = api_keys[0] + # api_keys = ModelApiKeyRepository.get_by_model_config(db, model_config.id) + # if not api_keys: + # logger.error(f"Model {llm_id} has no active API key") + # raise HTTPException( + # status_code=400, + # detail="指定的LLM模型没有可用的API密钥" + # ) + # api_key_config = api_keys[0] is_composite = getattr(model_config, 'is_composite', False) logger.info( @@ -153,6 +162,7 @@ def _get_ontology_service( provider=actual_provider, api_key=api_key_config.api_key, base_url=api_key_config.api_base, + is_omni=api_key_config.is_omni, max_retries=3, timeout=60.0 ) @@ -279,7 +289,8 @@ async def extract_ontology( async def create_scene( request: SceneCreateRequest, db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), + x_language_type: Optional[str] = Header(None, alias="X-Language-Type") ): """创建本体场景 @@ -350,8 +361,18 @@ async def create_scene( return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) except RuntimeError as e: - api_logger.error(f"Runtime error in scene creation: {str(e)}", exc_info=True) - return fail(BizCode.INTERNAL_ERROR, "场景创建失败", str(e)) + err_str = str(e) + if "UniqueViolation" in err_str or "uq_workspace_scene_name" in err_str: + api_logger.warning(f"Duplicate scene name '{request.scene_name}' in workspace {current_user.current_workspace_id}") + from app.core.language_utils import get_language_from_header + lang = get_language_from_header(x_language_type) + if lang == "en": + msg = fail(BizCode.BAD_REQUEST, "Scene name already exists", f"A scene named \"{request.scene_name}\" already exists in the current workspace. Please use a different name.") + else: + msg = fail(BizCode.BAD_REQUEST, "场景名称已存在", f"当前工作空间下已存在名为「{request.scene_name}」的场景,请使用其他名称") + return JSONResponse(status_code=400, content=msg) + api_logger.error(f"Runtime error in scene creation: {err_str}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "场景创建失败", err_str) except Exception as e: api_logger.error(f"Unexpected error in scene creation: {str(e)}", exc_info=True) @@ -399,6 +420,20 @@ async def update_scene( api_logger.warning(f"User {current_user.id} has no current workspace") return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + # 检查是否为系统默认场景 + scene_repo = OntologySceneRepository(db) + scene = scene_repo.get_by_id(scene_uuid) + if scene and scene.is_system_default: + business_logger.warning( + f"尝试修改系统默认场景: user_id={current_user.id}, " + f"scene_id={scene_id}, scene_name={scene.scene_name}" + ) + return fail( + BizCode.BAD_REQUEST, + "系统默认场景不可修改", + "该场景为系统预设场景,不允许修改" + ) + # 创建OntologyService实例 from app.core.memory.llm_tools.openai_client import OpenAIClient from app.core.models.base import RedBearModelConfig @@ -491,6 +526,19 @@ async def delete_scene( api_logger.warning(f"User {current_user.id} has no current workspace") return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + # 检查是否为系统默认场景 + scene_repo = OntologySceneRepository(db) + scene = scene_repo.get_by_id(scene_uuid) + if scene and scene.is_system_default: + business_logger.warning( + f"尝试删除系统默认场景: user_id={current_user.id}, " + f"scene_id={scene_id}, scene_name={scene.scene_name}" + ) + raise HTTPException( + status_code=400, + detail="SYSTEM_DEFAULT_SCENE_CANNOT_DELETE" + ) + # 创建OntologyService实例 from app.core.memory.llm_tools.openai_client import OpenAIClient from app.core.models.base import RedBearModelConfig @@ -514,6 +562,9 @@ async def delete_scene( return success(data={"deleted": success_flag}, msg="场景删除成功") + except HTTPException: + raise + except ValueError as e: api_logger.warning(f"Validation error in scene deletion: {str(e)}") return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) @@ -621,7 +672,8 @@ async def get_scenes( async def create_class( request: ClassCreateRequest, db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), + x_language_type: Optional[str] = Header(None, alias="X-Language-Type") ): """创建本体类型 @@ -636,7 +688,7 @@ async def create_class( ApiResponse: 包含创建的类型信息 """ from app.controllers.ontology_secondary_routes import create_class_handler - return await create_class_handler(request, db, current_user) + return await create_class_handler(request, db, current_user, x_language_type) @router.put("/class/{class_id}", response_model=ApiResponse) diff --git a/api/app/controllers/ontology_secondary_routes.py b/api/app/controllers/ontology_secondary_routes.py index 99017eea..8720065b 100644 --- a/api/app/controllers/ontology_secondary_routes.py +++ b/api/app/controllers/ontology_secondary_routes.py @@ -7,11 +7,11 @@ from uuid import UUID from typing import Optional -from fastapi import Depends +from fastapi import Depends, Header from sqlalchemy.orm import Session from app.core.error_codes import BizCode -from app.core.logging_config import get_api_logger +from app.core.logging_config import get_api_logger, get_business_logger from app.core.response_utils import fail, success from app.db import get_db from app.dependencies import get_current_user @@ -30,9 +30,11 @@ from app.schemas.response_schema import ApiResponse from app.services.ontology_service import OntologyService from app.core.memory.llm_tools.openai_client import OpenAIClient from app.core.models.base import RedBearModelConfig +from app.repositories.ontology_class_repository import OntologyClassRepository api_logger = get_api_logger() +business_logger = get_business_logger() def _get_dummy_ontology_service(db: Session) -> OntologyService: @@ -56,7 +58,7 @@ async def scenes_handler( workspace_id: Optional[str] = None, scene_name: Optional[str] = None, page: Optional[int] = None, - page_size: Optional[int] = None, + pagesize: Optional[int] = None, db: Session = Depends(get_db), current_user: User = Depends(get_current_user) ): @@ -69,14 +71,14 @@ async def scenes_handler( workspace_id: 工作空间ID(可选,默认当前用户工作空间) scene_name: 场景名称关键词(可选,支持模糊匹配) page: 页码(可选,从1开始,仅在全量查询时有效) - page_size: 每页数量(可选,仅在全量查询时有效) + pagesize: 每页数量(可选,仅在全量查询时有效) db: 数据库会话 current_user: 当前用户 """ operation = "search" if scene_name else "list" api_logger.info( f"Scene {operation} requested by user {current_user.id}, " - f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, page_size={page_size}" + f"workspace_id={workspace_id}, keyword={scene_name}, page={page}, pagesize={pagesize}" ) try: @@ -103,13 +105,13 @@ async def scenes_handler( api_logger.warning(f"Invalid page number: {page}") return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0") - if page_size is not None and page_size < 1: - api_logger.warning(f"Invalid page_size: {page_size}") + if pagesize is not None and pagesize < 1: + api_logger.warning(f"Invalid pagesize: {pagesize}") return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0") - # 如果只提供了page或page_size中的一个,返回错误 - if (page is not None and page_size is None) or (page is None and page_size is not None): - api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}") + # 如果只提供了page或pagesize中的一个,返回错误 + if (page is not None and pagesize is None) or (page is None and pagesize is not None): + api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}") return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供") # 模糊搜索场景(支持分页) @@ -117,17 +119,15 @@ async def scenes_handler( total = len(scenes) # 如果提供了分页参数,进行分页处理 - if page is not None and page_size is not None: - start_idx = (page - 1) * page_size - end_idx = start_idx + page_size + if page is not None and pagesize is not None: + start_idx = (page - 1) * pagesize + end_idx = start_idx + pagesize scenes = scenes[start_idx:end_idx] # 构建响应 items = [] for scene in scenes: - # 获取前3个class_name作为entity_type entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None - # 动态计算 type_num type_num = len(scene.classes) if scene.classes else 0 items.append(SceneResponse( @@ -139,17 +139,16 @@ async def scenes_handler( workspace_id=scene.workspace_id, created_at=scene.created_at, updated_at=scene.updated_at, - classes_count=type_num + classes_count=type_num, + is_system_default=scene.is_system_default )) # 构建响应(包含分页信息) - if page is not None and page_size is not None: - # 计算是否有下一页 - hasnext = (page * page_size) < total - + if page is not None and pagesize is not None: + hasnext = (page * pagesize) < total pagination_info = PaginationInfo( page=page, - pagesize=page_size, + pagesize=pagesize, total=total, hasnext=hasnext ) @@ -163,28 +162,25 @@ async def scenes_handler( ) else: # 获取所有场景(支持分页) - # 验证分页参数 if page is not None and page < 1: api_logger.warning(f"Invalid page number: {page}") return fail(BizCode.BAD_REQUEST, "请求参数无效", "页码必须大于0") - if page_size is not None and page_size < 1: - api_logger.warning(f"Invalid page_size: {page_size}") + if pagesize is not None and pagesize < 1: + api_logger.warning(f"Invalid pagesize: {pagesize}") return fail(BizCode.BAD_REQUEST, "请求参数无效", "每页数量必须大于0") - # 如果只提供了page或page_size中的一个,返回错误 - if (page is not None and page_size is None) or (page is None and page_size is not None): - api_logger.warning(f"Incomplete pagination params: page={page}, page_size={page_size}") + # 如果只提供了page或pagesize中的一个,返回错误 + if (page is not None and pagesize is None) or (page is None and pagesize is not None): + api_logger.warning(f"Incomplete pagination params: page={page}, pagesize={pagesize}") return fail(BizCode.BAD_REQUEST, "请求参数无效", "分页参数page和pagesize必须同时提供") - scenes, total = service.list_scenes(ws_uuid, page, page_size) + scenes, total = service.list_scenes(ws_uuid, page, pagesize) # 构建响应 items = [] for scene in scenes: - # 获取前3个class_name作为entity_type entity_type = [cls.class_name for cls in scene.classes[:3]] if scene.classes else None - # 动态计算 type_num type_num = len(scene.classes) if scene.classes else 0 items.append(SceneResponse( @@ -196,17 +192,16 @@ async def scenes_handler( workspace_id=scene.workspace_id, created_at=scene.created_at, updated_at=scene.updated_at, - classes_count=type_num + classes_count=type_num, + is_system_default=scene.is_system_default )) # 构建响应(包含分页信息) - if page is not None and page_size is not None: - # 计算是否有下一页 - hasnext = (page * page_size) < total - + if page is not None and pagesize is not None: + hasnext = (page * pagesize) < total pagination_info = PaginationInfo( page=page, - pagesize=page_size, + pagesize=pagesize, total=total, hasnext=hasnext ) @@ -236,7 +231,8 @@ async def scenes_handler( async def create_class_handler( request: ClassCreateRequest, db: Session = Depends(get_db), - current_user: User = Depends(get_current_user) + current_user: User = Depends(get_current_user), + x_language_type: Optional[str] = None ): """创建本体类型(统一使用列表形式,支持单个或批量)""" @@ -269,8 +265,11 @@ async def create_class_handler( ] if count == 1: - # 单个创建 + # 单个创建 - 先检查重名 class_data = classes_data[0] + existing = OntologyClassRepository(db).get_by_name(class_data["class_name"], request.scene_id) + if existing: + raise ValueError(f"DUPLICATE_CLASS_NAME:{class_data['class_name']}") ontology_class = service.create_class( scene_id=request.scene_id, class_name=class_data["class_name"], @@ -328,12 +327,36 @@ async def create_class_handler( return success(data=response.model_dump(mode='json'), msg="批量创建完成") except ValueError as e: - api_logger.warning(f"Validation error in class creation: {str(e)}") - return fail(BizCode.BAD_REQUEST, "请求参数无效", str(e)) - + err_str = str(e) + if err_str.startswith("DUPLICATE_CLASS_NAME:"): + class_name = err_str.split(":", 1)[1] + api_logger.warning(f"Duplicate class name '{class_name}' in scene {request.scene_id}") + from app.core.language_utils import get_language_from_header + from fastapi.responses import JSONResponse + lang = get_language_from_header(x_language_type) + if lang == "en": + msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.") + else: + msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称") + return JSONResponse(status_code=400, content=msg) + api_logger.warning(f"Validation error in class creation: {err_str}") + return fail(BizCode.BAD_REQUEST, "请求参数无效", err_str) + except RuntimeError as e: - api_logger.error(f"Runtime error in class creation: {str(e)}", exc_info=True) - return fail(BizCode.INTERNAL_ERROR, "类型创建失败", str(e)) + err_str = str(e) + if "UniqueViolation" in err_str or "uq_scene_class_name" in err_str: + api_logger.warning(f"Duplicate class name in scene {request.scene_id}") + from app.core.language_utils import get_language_from_header + from fastapi.responses import JSONResponse + lang = get_language_from_header(x_language_type) + class_name = request.classes[0].class_name if request.classes else "" + if lang == "en": + msg = fail(BizCode.BAD_REQUEST, "Class name already exists", f"A class named \"{class_name}\" already exists in this scene. Please use a different name.") + else: + msg = fail(BizCode.BAD_REQUEST, "类型名称已存在", f"当前场景下已存在名为「{class_name}」的类型,请使用其他名称") + return JSONResponse(status_code=400, content=msg) + api_logger.error(f"Runtime error in class creation: {err_str}", exc_info=True) + return fail(BizCode.INTERNAL_ERROR, "类型创建失败", err_str) except Exception as e: api_logger.error(f"Unexpected error in class creation: {str(e)}", exc_info=True) @@ -366,6 +389,20 @@ async def update_class_handler( api_logger.warning(f"User {current_user.id} has no current workspace") return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + # 检查是否为系统默认类型 + class_repo = OntologyClassRepository(db) + ontology_class = class_repo.get_by_id(class_uuid) + if ontology_class and ontology_class.is_system_default: + business_logger.warning( + f"尝试修改系统默认类型: user_id={current_user.id}, " + f"class_id={class_id}, class_name={ontology_class.class_name}" + ) + return fail( + BizCode.BAD_REQUEST, + "系统默认类型不可修改", + "该类型为系统预设类型,不允许修改" + ) + # 创建Service service = _get_dummy_ontology_service(db) @@ -429,6 +466,20 @@ async def delete_class_handler( api_logger.warning(f"User {current_user.id} has no current workspace") return fail(BizCode.BAD_REQUEST, "请求参数无效", "当前用户没有工作空间") + # 检查是否为系统默认类型 + class_repo = OntologyClassRepository(db) + ontology_class = class_repo.get_by_id(class_uuid) + if ontology_class and ontology_class.is_system_default: + business_logger.warning( + f"尝试删除系统默认类型: user_id={current_user.id}, " + f"class_id={class_id}, class_name={ontology_class.class_name}" + ) + return fail( + BizCode.BAD_REQUEST, + "系统默认类型不可删除", + "该类型为系统预设类型,不允许删除" + ) + # 创建Service service = _get_dummy_ontology_service(db) @@ -585,6 +636,7 @@ async def classes_handler( scene_id=scene_uuid, scene_name=scene.scene_name, scene_description=scene.scene_description, + is_system_default=scene.is_system_default, items=items ) diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index 9f5f8075..3c634ae0 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -2,25 +2,32 @@ import hashlib import json import uuid from typing import Annotated + from fastapi import APIRouter, Depends, Query, Request from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger -from app.core.response_utils import success +from app.core.response_utils import success, fail from app.db import get_db, get_db_read from app.dependencies import get_share_user_id, ShareTokenData +from app.models.app_model import App +from app.models.app_model import AppType from app.repositories import knowledge_repository +from app.repositories.end_user_repository import EndUserRepository from app.repositories.workflow_repository import WorkflowConfigRepository from app.schemas import release_share_schema, conversation_schema from app.schemas.response_schema import PageData, PageMeta from app.services import workspace_service +from app.services.app_chat_service import AppChatService, get_app_chat_service from app.services.auth_service import create_access_token from app.services.conversation_service import ConversationService from app.services.release_share_service import ReleaseShareService from app.services.shared_chat_service import SharedChatService -from app.services.app_chat_service import AppChatService, get_app_chat_service -from app.utils.app_config_utils import dict_to_multi_agent_config, workflow_config_4_app_release, \ +from app.services.workflow_service import WorkflowService +from app.utils.app_config_utils import workflow_config_4_app_release, \ agent_config_4_app_release, multi_agent_config_4_app_release router = APIRouter(prefix="/public/share", tags=["Public Share"]) @@ -206,15 +213,13 @@ def list_conversations( logger.debug(f"share_data:{share_data.user_id}") other_id = share_data.user_id service = SharedChatService(db) - share, release = service._get_release_by_share_token(share_data.share_token, password) - from app.repositories.end_user_repository import EndUserRepository + share, release = service.get_release_by_share_token(share_data.share_token, password) end_user_repo = EndUserRepository(db) new_end_user = end_user_repo.get_or_create_end_user( app_id=share.app_id, other_id=other_id ) logger.debug(new_end_user.id) - service = SharedChatService(db) conversations, total = service.list_conversations( share_token=share_data.share_token, user_id=str(new_end_user.id), @@ -293,19 +298,15 @@ async def chat( # 提前验证和准备(在流式响应开始前完成) # 这样可以确保错误能正确返回,而不是在流式响应中间出错 - from app.models.app_model import AppType + try: - from app.core.exceptions import BusinessException - from app.core.error_codes import BizCode - from app.services.app_service import AppService # 验证分享链接和密码 - share, release = service._get_release_by_share_token(share_token, password) + share, release = service.get_release_by_share_token(share_token, password) # # Create end_user_id by concatenating app_id with user_id # end_user_id = f"{share.app_id}_{user_id}" # Store end_user_id in database with original user_id - from app.repositories.end_user_repository import EndUserRepository end_user_repo = EndUserRepository(db) new_end_user = end_user_repo.get_or_create_end_user( app_id=share.app_id, @@ -318,7 +319,6 @@ async def chat( """获取存储类型和工作空间的ID""" # 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用) - from app.models.app_model import App app = db.query(App).filter( App.id == appid, App.is_active.is_(True) @@ -359,12 +359,12 @@ async def chat( app_type = release.app.type if release.app else None # 根据应用类型验证配置 - if app_type == "agent": + if app_type == AppType.AGENT: # Agent 类型:验证模型配置 model_config_id = release.default_model_config_id if not model_config_id: raise BusinessException("Agent 应用未配置模型", BizCode.AGENT_CONFIG_MISSING) - elif app_type == "multi_agent": + elif app_type == AppType.MULTI_AGENT: # Multi-Agent 类型:验证多 Agent 配置 config = release.config or {} if not config.get("sub_agents"): @@ -638,6 +638,34 @@ async def chat( # return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) else: - from app.core.exceptions import BusinessException - from app.core.error_codes import BizCode raise BusinessException(f"不支持的应用类型: {app_type}", BizCode.APP_TYPE_NOT_SUPPORTED) + + +@router.get("/config", summary="获取应用启动配置") +async def config_query( + password: str = Query(None, description="访问密码"), + share_data: ShareTokenData = Depends(get_share_user_id), + db: Session = Depends(get_db), +): + share_service = SharedChatService(db) + share_token = share_data.share_token + share, release = share_service.get_release_by_share_token(share_token, password) + if release.app.type == AppType.WORKFLOW: + workflow_service = WorkflowService(db) + content = { + "app_type": release.app.type, + "variables": workflow_service.get_start_node_variables(release.config) + } + elif release.app.type == AppType.AGENT: + content = { + "app_type": release.app.type, + "variables": release.config.get("variables") + } + elif release.app.type == AppType.MULTI_AGENT: + content = { + "app_type": release.app.type, + "variables": [] + } + else: + return fail(msg="Unsupported app type", code=BizCode.APP_TYPE_NOT_SUPPORTED) + return success(data=content) diff --git a/api/app/controllers/service/app_api_controller.py b/api/app/controllers/service/app_api_controller.py index bb71d831..64143f57 100644 --- a/api/app/controllers/service/app_api_controller.py +++ b/api/app/controllers/service/app_api_controller.py @@ -89,7 +89,6 @@ async def chat( body = await request.json() payload = AppChatRequest(**body) - other_id = payload.user_id app = app_service.get_app(api_key_auth.resource_id, api_key_auth.workspace_id) other_id = payload.user_id workspace_id = app.workspace_id @@ -135,7 +134,8 @@ async def chat( app_id=app.id, workspace_id=workspace_id, user_id=end_user_id, - is_draft=False + is_draft=False, + conversation_id=payload.conversation_id ) if app_type == AppType.AGENT: @@ -249,6 +249,7 @@ async def chat( app_id=app.id, workspace_id=workspace_id, release_id=app.current_release.id, + public=True ): event_type = event.get("event", "message") event_data = event.get("data", {}) diff --git a/api/app/controllers/service/memory_api_controller.py b/api/app/controllers/service/memory_api_controller.py index accd749e..34489e8a 100644 --- a/api/app/controllers/service/memory_api_controller.py +++ b/api/app/controllers/service/memory_api_controller.py @@ -39,7 +39,7 @@ async def write_memory_api_service( Stores memory content for the specified end user using the Memory API Service. """ - logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}") + logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}") memory_api_service = MemoryAPIService(db) diff --git a/api/app/controllers/workspace_controller.py b/api/app/controllers/workspace_controller.py index d2afb10f..9bcd8571 100644 --- a/api/app/controllers/workspace_controller.py +++ b/api/app/controllers/workspace_controller.py @@ -1,7 +1,7 @@ import uuid from typing import List, Optional -from fastapi import APIRouter, Depends, HTTPException, Query, status +from fastapi import APIRouter, Depends, Header, HTTPException, Query, status from sqlalchemy.orm import Session from app.core.logging_config import get_api_logger @@ -95,16 +95,29 @@ def get_workspaces( @router.post("", response_model=ApiResponse) def create_workspace( workspace: WorkspaceCreate, + language_type: str = Header(default="zh", alias="X-Language-Type"), db: Session = Depends(get_db), current_user: User = Depends(get_current_superuser), ): """创建新的工作空间""" - api_logger.info(f"用户 {current_user.username} 请求创建工作空间: {workspace.name}") + from app.core.language_utils import get_language_from_header + + # 验证并获取语言参数 + language = get_language_from_header(language_type) + + api_logger.info( + f"用户 {current_user.username} 请求创建工作空间: {workspace.name}, " + f"language={language}" + ) result = workspace_service.create_workspace( - db=db, workspace=workspace, user=current_user) + db=db, workspace=workspace, user=current_user, language=language + ) - api_logger.info(f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, 创建者: {current_user.username}") + api_logger.info( + f"工作空间创建成功 - 名称: {workspace.name}, ID: {result.id}, " + f"创建者: {current_user.username}, language={language}" + ) result_schema = WorkspaceResponse.model_validate(result) return success(data=result_schema, msg="工作空间创建成功") diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index fae20ea2..88b6371c 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -11,35 +11,37 @@ LangChain Agent 封装 import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence -from app.core.memory.agent.langgraph_graph.write_graph import write_long_term +from app.core.memory.agent.langgraph_graph.write_graph import write_long_term from app.db import get_db from app.core.logging_config import get_business_logger from app.core.models import RedBearLLM, RedBearModelConfig -from app.models.models_model import ModelType +from app.models.models_model import ModelType, ModelProvider from app.services.memory_agent_service import ( get_end_user_connected_config, ) from langchain.agents import create_agent from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.tools import BaseTool + logger = get_business_logger() class LangChainAgent: def __init__( - self, - model_name: str, - api_key: str, - provider: str = "openai", - api_base: Optional[str] = None, - temperature: float = 0.7, - max_tokens: int = 2000, - system_prompt: Optional[str] = None, - tools: Optional[Sequence[BaseTool]] = None, - streaming: bool = False, - max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算) - max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数 + self, + model_name: str, + api_key: str, + provider: str = "openai", + api_base: Optional[str] = None, + is_omni: bool = False, + temperature: float = 0.7, + max_tokens: int = 2000, + system_prompt: Optional[str] = None, + tools: Optional[Sequence[BaseTool]] = None, + streaming: bool = False, + max_iterations: Optional[int] = None, # 最大迭代次数(None 表示自动计算) + max_tool_consecutive_calls: int = 3 # 单个工具最大连续调用次数 ): """初始化 LangChain Agent @@ -60,12 +62,13 @@ class LangChainAgent: self.provider = provider self.tools = tools or [] self.streaming = streaming + self.is_omni = is_omni self.max_tool_consecutive_calls = max_tool_consecutive_calls - + # 工具调用计数器:记录每个工具的连续调用次数 self.tool_call_counter: Dict[str, int] = {} self.last_tool_called: Optional[str] = None - + # 根据工具数量动态调整最大迭代次数 # 基础值 + 每个工具额外的调用机会 if max_iterations is None: @@ -73,9 +76,9 @@ class LangChainAgent: self.max_iterations = 5 + len(self.tools) * 2 else: self.max_iterations = max_iterations - + self.system_prompt = system_prompt or "你是一个专业的AI助手" - + logger.debug( f"Agent 迭代次数配置: max_iterations={self.max_iterations}, " f"tool_count={len(self.tools)}, " @@ -89,6 +92,7 @@ class LangChainAgent: provider=provider, api_key=api_key, base_url=api_base, + is_omni=is_omni, extra_params={ "temperature": temperature, "max_tokens": max_tokens, @@ -143,21 +147,22 @@ class LangChainAgent: """ from langchain_core.tools import StructuredTool from functools import wraps - + wrapped_tools = [] - + for original_tool in tools: tool_name = original_tool.name original_func = original_tool.func if hasattr(original_tool, 'func') else None - + if not original_func: # 如果无法获取原始函数,直接使用原工具 wrapped_tools.append(original_tool) continue - + # 创建包装函数 def make_wrapped_func(tool_name, original_func): """创建包装函数的工厂函数,避免闭包问题""" + @wraps(original_func) def wrapped_func(*args, **kwargs): """包装后的工具函数,跟踪连续调用次数""" @@ -168,13 +173,13 @@ class LangChainAgent: # 切换到新工具,重置计数器 self.tool_call_counter[tool_name] = 1 self.last_tool_called = tool_name - + current_count = self.tool_call_counter[tool_name] - + logger.debug( f"工具调用: {tool_name}, 连续调用次数: {current_count}/{self.max_tool_consecutive_calls}" ) - + # 检查是否超过最大连续调用次数 if current_count > self.max_tool_consecutive_calls: logger.warning( @@ -185,12 +190,12 @@ class LangChainAgent: f"工具 '{tool_name}' 已连续调用 {self.max_tool_consecutive_calls} 次," f"未找到有效结果。请尝试其他方法或直接回答用户的问题。" ) - + # 调用原始工具函数 return original_func(*args, **kwargs) - + return wrapped_func - + # 使用 StructuredTool 创建新工具 wrapped_tool = StructuredTool( name=original_tool.name, @@ -198,17 +203,17 @@ class LangChainAgent: func=make_wrapped_func(tool_name, original_func), args_schema=original_tool.args_schema if hasattr(original_tool, 'args_schema') else None ) - + wrapped_tools.append(wrapped_tool) - + return wrapped_tools def _prepare_messages( - self, - message: str, - history: Optional[List[Dict[str, str]]] = None, - context: Optional[str] = None, - files: Optional[List[Dict[str, Any]]] = None + self, + message: str, + history: Optional[List[Dict[str, str]]] = None, + context: Optional[str] = None, + files: Optional[List[Dict[str, Any]]] = None ) -> List[BaseMessage]: """准备消息列表 @@ -248,7 +253,7 @@ class LangChainAgent: messages.append(HumanMessage(content=user_content)) return messages - + def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ 构建多模态消息内容 @@ -261,23 +266,26 @@ class LangChainAgent: List[Dict]: 消息内容列表 """ # 根据 provider 使用不同的文本格式 - if self.provider.lower() in ["bedrock", "anthropic"]: - # Anthropic/Bedrock: {"type": "text", "text": "..."} - content_parts = [{"type": "text", "text": text}] - else: - # 通义千问等: {"text": "..."} - content_parts = [{"text": text}] - + # if (self.provider.lower() in [ModelProvider.BEDROCK, ModelProvider.OPENAI, ModelProvider.XINFERENCE, + # ModelProvider.GPUSTACK] or ( + # self.provider.lower() == ModelProvider.DASHSCOPE and self.is_omni)): + # # Anthropic/Bedrock/Xinference/Gpustack/Openai: {"type": "text", "text": "..."} + # content_parts = [{"type": "text", "text": text}] + # else: + # # 通义千问等: {"text": "..."} + # content_parts = [{"type": "text", "text": text}] + content_parts = [{"type": "text", "text": text}] + # 添加文件内容 # MultimodalService 已经根据 provider 返回了正确格式,直接使用 content_parts.extend(files) - + logger.debug( f"构建多模态消息: provider={self.provider}, " f"parts={len(content_parts)}, " f"files={len(files)}" ) - + return content_parts async def chat( @@ -302,7 +310,7 @@ class LangChainAgent: Returns: Dict: 包含 content 和元数据的字典 """ - message_chat= message + message_chat = message start_time = time.time() actual_config_id = config_id # If config_id is None, try to get from end_user's connected config @@ -322,8 +330,8 @@ class LangChainAgent: except Exception as e: logger.warning(f"Failed to get db session: {e}") actual_end_user_id = end_user_id if end_user_id is not None else "unknown" - logger.info(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') - print(f'写入类型{storage_type,str(end_user_id), message, str(user_rag_memory_id)}') + logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}') + print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}') try: # 准备消息列表(支持多模态) messages = self._prepare_messages(message, history, context, files) @@ -367,14 +375,14 @@ class LangChainAgent: # 获取最后的 AI 消息 output_messages = result.get("messages", []) content = "" - + logger.debug(f"输出消息数量: {len(output_messages)}") total_tokens = 0 for msg in reversed(output_messages): if isinstance(msg, AIMessage): logger.debug(f"找到 AI 消息,content 类型: {type(msg.content)}") logger.debug(f"AI 消息内容: {msg.content}") - + # 处理多模态响应:content 可能是字符串或列表 if isinstance(msg.content, str): content = msg.content @@ -407,12 +415,13 @@ class LangChainAgent: response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0 break - + logger.info(f"最终提取的内容长度: {len(content)}") elapsed_time = time.time() - start_time if memory_flag: - await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, actual_config_id) + await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, + actual_config_id) response = { "content": content, "model": self.model_name, @@ -439,16 +448,16 @@ class LangChainAgent: raise async def chat_stream( - self, - message: str, - history: Optional[List[Dict[str, str]]] = None, - context: Optional[str] = None, - end_user_id:Optional[str] = None, - config_id: Optional[str] = None, - storage_type:Optional[str] = None, - user_rag_memory_id:Optional[str] = None, - memory_flag: Optional[bool] = True, - files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件 + self, + message: str, + history: Optional[List[Dict[str, str]]] = None, + context: Optional[str] = None, + end_user_id: Optional[str] = None, + config_id: Optional[str] = None, + storage_type: Optional[str] = None, + user_rag_memory_id: Optional[str] = None, + memory_flag: Optional[bool] = True, + files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件 ) -> AsyncGenerator[str, None]: """执行流式对话 @@ -482,7 +491,6 @@ class LangChainAgent: except Exception as e: logger.warning(f"Failed to get db session: {e}") - # 注意:不在这里写入用户消息,等 AI 回复后一起写入 try: # 准备消息列表(支持多模态) @@ -500,13 +508,13 @@ class LangChainAgent: full_content = '' try: async for event in self.agent.astream_events( - {"messages": messages}, - version="v2", - config={"recursion_limit": self.max_iterations} + {"messages": messages}, + version="v2", + config={"recursion_limit": self.max_iterations} ): chunk_count += 1 kind = event.get("event") - + # 处理所有可能的流式事件 if kind == "on_chat_model_stream": # LLM 流式输出 @@ -540,7 +548,7 @@ class LangChainAgent: full_content += item yield item yielded_content = True - + elif kind == "on_llm_stream": # 另一种 LLM 流式事件 chunk = event.get("data", {}).get("chunk") @@ -577,13 +585,13 @@ class LangChainAgent: full_content += chunk yield chunk yielded_content = True - + # 记录工具调用(可选) elif kind == "on_tool_start": logger.debug(f"工具调用开始: {event.get('name')}") elif kind == "on_tool_end": logger.debug(f"工具调用结束: {event.get('name')}") - + logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件") # 统计token消耗 output_messages = event.get("data", {}).get("output", {}).get("messages", []) @@ -595,7 +603,8 @@ class LangChainAgent: yield total_tokens break if memory_flag: - await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, actual_config_id) + await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, + actual_config_id) except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) raise @@ -609,5 +618,3 @@ class LangChainAgent: logger.info("=" * 80) logger.info("chat_stream 方法执行结束") logger.info("=" * 80) - - diff --git a/api/app/core/config.py b/api/app/core/config.py index 7392d29a..bbe327b6 100644 --- a/api/app/core/config.py +++ b/api/app/core/config.py @@ -1,9 +1,10 @@ import json import os from pathlib import Path -from typing import Any, Dict, Optional +from typing import Annotated, Any, Dict, Optional from dotenv import load_dotenv +from pydantic import Field, TypeAdapter load_dotenv() @@ -16,18 +17,18 @@ class Settings: # cloud: SaaS 云服务版(全功能,按量计费) # enterprise: 企业私有化版(License 控制) DEPLOYMENT_MODE: str = os.getenv("DEPLOYMENT_MODE", "community") - + # License 配置(企业版) LICENSE_FILE: str = os.getenv("LICENSE_FILE", "/etc/app/license.json") LICENSE_SERVER_URL: str = os.getenv("LICENSE_SERVER_URL", "https://license.yourcompany.com") - + # 计费服务配置(SaaS 版) BILLING_SERVICE_URL: str = os.getenv("BILLING_SERVICE_URL", "") - + # 基础 URL(用于 SSO 回调等) BASE_URL: str = os.getenv("BASE_URL", "http://localhost:8000") FRONTEND_URL: str = os.getenv("FRONTEND_URL", "http://localhost:3000") - + ENABLE_SINGLE_WORKSPACE: bool = os.getenv("ENABLE_SINGLE_WORKSPACE", "true").lower() == "true" # API Keys Configuration OPENAI_API_KEY: str = os.getenv("OPENAI_API_KEY", "") @@ -57,7 +58,6 @@ class Settings: REDIS_PORT: int = int(os.getenv("REDIS_PORT", "6379")) REDIS_DB: int = int(os.getenv("REDIS_DB", "1")) REDIS_PASSWORD: str = os.getenv("REDIS_PASSWORD", "") - # ElasticSearch configuration ELASTICSEARCH_HOST: str = os.getenv("ELASTICSEARCH_HOST", "https://127.0.0.1") @@ -91,7 +91,7 @@ class Settings: # Single Sign-On configuration ENABLE_SINGLE_SESSION: bool = os.getenv("ENABLE_SINGLE_SESSION", "false").lower() == "true" - + # SSO 免登配置 SSO_TOKEN_EXPIRE_SECONDS: int = int(os.getenv("SSO_TOKEN_EXPIRE_SECONDS", "300")) SSO_TRUSTED_SOURCES_CONFIG: str = os.getenv("SSO_TRUSTED_SOURCES_CONFIG", "{}") @@ -130,7 +130,7 @@ class Settings: # Server Configuration SERVER_IP: str = os.getenv("SERVER_IP", "127.0.0.1") - FILE_LOCAL_SERVER_URL : str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api") + FILE_LOCAL_SERVER_URL: str = os.getenv("FILE_LOCAL_SERVER_URL", "http://localhost:8000/api") # ======================================================================== # Internal Configuration (not in .env, used by application code) @@ -190,8 +190,12 @@ class Settings: LOG_FILE_MAX_SIZE_MB: int = int(os.getenv("LOG_FILE_MAX_SIZE_MB", "10")) # 10MB # Celery configuration (internal) - CELERY_BROKER: int = int(os.getenv("CELERY_BROKER", "1")) - CELERY_BACKEND: int = int(os.getenv("CELERY_BACKEND", "2")) + # NOTE: 变量名不以 CELERY_ 开头,避免被 Celery CLI 的前缀匹配机制劫持 + # 详见 docs/celery-env-bug-report.md + # 默认使用 Redis DB 3 (broker) 和 DB 4 (backend),与业务缓存 (DB 1/2) 隔离 + # 多人共用同一 Redis 时,每位开发者应在 .env 中配置不同的 DB 编号避免任务互相干扰 + REDIS_DB_CELERY_BROKER: int = int(os.getenv("REDIS_DB_CELERY_BROKER", "3")) + REDIS_DB_CELERY_BACKEND: int = int(os.getenv("REDIS_DB_CELERY_BACKEND", "4")) # SMTP Email Configuration SMTP_SERVER: str = os.getenv("SMTP_SERVER", "smtp.gmail.com") @@ -201,21 +205,30 @@ class Settings: REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300")) HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600")) - MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24")) REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30)) # Memory Cache Regeneration Configuration MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24")) - # Periodic Task Schedule Configuration - # workspace_reflection: 每隔多少秒执行一次 - WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30")) - # forgetting_cycle: 每隔多少小时执行一次 - FORGETTING_CYCLE_INTERVAL_HOURS: int = int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24")) - # implicit_emotions_update: 每天几点执行(小时,0-23) + # Celery Beat Schedule Configuration (定时任务执行频率) + MEMORY_INCREMENT_HOUR: int = TypeAdapter( + Annotated[int, Field(ge=0, le=23, description="cron hour [0, 23]")] + ).validate_python(int(os.getenv("MEMORY_INCREMENT_HOUR", "2"))) + MEMORY_INCREMENT_MINUTE: int = TypeAdapter( + Annotated[int, Field(ge=0, le=59, description="cron minute [0, 59]")] + ).validate_python(int(os.getenv("MEMORY_INCREMENT_MINUTE", "0"))) + WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = TypeAdapter( + Annotated[int, Field(ge=1, description="reflection interval in seconds, must be >= 1")] + ).validate_python(int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30"))) + FORGETTING_CYCLE_INTERVAL_HOURS: int = TypeAdapter( + Annotated[int, Field(ge=1, description="forgetting cycle interval in hours, must be >= 1")] + ).validate_python(int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24"))) + IMPLICIT_EMOTIONS_UPDATE_HOUR: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_HOUR", "2")) # implicit_emotions_update: 每天几分执行(分钟,0-59) - IMPLICIT_EMOTIONS_UPDATE_MINUTE: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_MINUTE", "0")) # Memory Module Configuration (internal) + IMPLICIT_EMOTIONS_UPDATE_MINUTE: int = int(os.getenv("IMPLICIT_EMOTIONS_UPDATE_MINUTE", "0")) + # Memory Module Configuration (internal) + MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output") MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory") @@ -232,27 +245,28 @@ class Settings: LOAD_MODEL: bool = os.getenv("LOAD_MODEL", "false").lower() == "true" # workflow config + WORKFLOW_IMPORT_CACHE_TIMEOUT: int = int(os.getenv("WORKFLOW_IMPORT_CACHE_TIMEOUT", 1800)) WORKFLOW_NODE_TIMEOUT: int = int(os.getenv("WORKFLOW_NODE_TIMEOUT", 600)) # ======================================================================== # General Ontology Type Configuration # ======================================================================== # 通用本体文件路径列表(逗号分隔) - GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "General_purpose_entity.ttl") - + GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "api/app/core/memory/ontology_services/General_purpose_entity.ttl") + # 是否启用通用本体类型功能 ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true" - + # Prompt 中最大类型数量 MAX_ONTOLOGY_TYPES_IN_PROMPT: int = int(os.getenv("MAX_ONTOLOGY_TYPES_IN_PROMPT", "50")) - + # 核心通用类型列表(逗号分隔) CORE_GENERAL_TYPES: str = os.getenv( "CORE_GENERAL_TYPES", "Person,Organization,Company,GovernmentAgency,Place,Location,City,Country,Building," "Event,SportsEvent,SocialEvent,Work,Book,Film,Software,Concept,TopicalConcept,AcademicSubject" ) - + # 实验模式开关(允许通过 API 动态切换本体配置) ONTOLOGY_EXPERIMENT_MODE: bool = os.getenv("ONTOLOGY_EXPERIMENT_MODE", "true").lower() == "true" diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py index ac1fb9a6..784e5802 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/problem_nodes.py @@ -1,10 +1,10 @@ -import os import json +import os import time -from app.core.logging_config import get_agent_logger -from app.db import get_db +from app.core.logging_config import get_agent_logger from app.core.memory.agent.models.problem_models import ProblemExtensionResponse +from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.core.memory.agent.utils.llm_tools import ( PROJECT_ROOT_, ReadState, @@ -12,10 +12,9 @@ from app.core.memory.agent.utils.llm_tools import ( from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService -from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin +from app.db import get_db_context template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') -db_session = next(get_db()) logger = get_agent_logger(__name__) @@ -53,13 +52,14 @@ async def Split_The_Problem(state: ReadState) -> ReadState: try: # 使用优化的LLM服务 - structured = await problem_service.call_llm_structured( - state=state, - db_session=db_session, - system_prompt=system_prompt, - response_model=ProblemExtensionResponse, - fallback_value=[] - ) + with get_db_context() as db_session: + structured = await problem_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=ProblemExtensionResponse, + fallback_value=[] + ) # 添加更详细的日志记录 logger.info(f"Split_The_Problem: 开始处理问题分解,内容长度: {len(content)}") @@ -111,7 +111,7 @@ async def Split_The_Problem(state: ReadState) -> ReadState: "error_type": type(e).__name__, "error_message": str(e), "content_length": len(content), - "llm_model_id": memory_config.llm_model_id if memory_config else None + "llm_model_id": str(memory_config.llm_model_id) if memory_config else None } logger.error(f"Split_The_Problem error details: {error_details}") @@ -171,13 +171,14 @@ async def Problem_Extension(state: ReadState) -> ReadState: try: # 使用优化的LLM服务 - response_content = await problem_service.call_llm_structured( - state=state, - db_session=db_session, - system_prompt=system_prompt, - response_model=ProblemExtensionResponse, - fallback_value=[] - ) + with get_db_context() as db_session: + response_content = await problem_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=ProblemExtensionResponse, + fallback_value=[] + ) logger.info(f"Problem_Extension: 开始处理问题扩展,问题数量: {len(databasets)}") @@ -220,7 +221,7 @@ async def Problem_Extension(state: ReadState) -> ReadState: "error_type": type(e).__name__, "error_message": str(e), "questions_count": len(databasets), - "llm_model_id": memory_config.llm_model_id if memory_config else None + "llm_model_id": str(memory_config.llm_model_id) if memory_config else None } logger.error(f"Problem_Extension error details: {error_details}") diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py index 1880357c..06539ad1 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/retrieve_nodes.py @@ -6,31 +6,26 @@ import os # ===== 第三方库 ===== from langchain.agents import create_agent from langchain_openai import ChatOpenAI + from app.core.logging_config import get_agent_logger -from app.db import get_db, get_db_context - -from app.schemas import model_schema -from app.services.memory_config_service import MemoryConfigService -from app.services.model_service import ModelConfigService - -from app.core.memory.agent.services.search_service import SearchService -from app.core.memory.agent.utils.llm_tools import ( - COUNTState, - ReadState, - deduplicate_entries, - merge_to_key_value_pairs, -) from app.core.memory.agent.langgraph_graph.tools.tool import ( create_hybrid_retrieval_tool_sync, create_time_retrieval_tool, extract_tool_message_content, ) - +from app.core.memory.agent.services.search_service import SearchService +from app.core.memory.agent.utils.llm_tools import ( + ReadState, + deduplicate_entries, + merge_to_key_value_pairs, +) from app.core.rag.nlp.search import knowledge_retrieval +from app.db import get_db_context +from app.schemas import model_schema +from app.services.memory_config_service import MemoryConfigService +from app.services.model_service import ModelConfigService logger = get_agent_logger(__name__) -db = next(get_db()) - async def rag_config(state): @@ -50,10 +45,12 @@ async def rag_config(state): "reranker_top_k": 10 } return kb_config -async def rag_knowledge(state,question): + + +async def rag_knowledge(state, question): kb_config = await rag_config(state) end_user_id = state.get('end_user_id', '') - user_rag_memory_id=state.get("user_rag_memory_id",'') + user_rag_memory_id = state.get("user_rag_memory_id", '') retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)]) try: retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] @@ -61,13 +58,13 @@ async def rag_knowledge(state,question): cleaned_query = question raw_results = clean_content logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") - except Exception : - retrieval_knowledge=[] + except Exception: + retrieval_knowledge = [] clean_content = '' raw_results = '' cleaned_query = question logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") - return retrieval_knowledge,clean_content,cleaned_query,raw_results + return retrieval_knowledge, clean_content, cleaned_query, raw_results async def llm_infomation(state: ReadState) -> ReadState: @@ -113,7 +110,7 @@ async def clean_databases(data) -> str: # 收集所有内容 content_list = [] - + # 处理重排序结果 reranked = results.get('reranked_results', {}) if reranked: @@ -141,7 +138,6 @@ async def clean_databases(data) -> str: elif isinstance(item, str): text_parts.append(item) - return '\n'.join(text_parts).strip() except Exception as e: @@ -150,23 +146,23 @@ async def clean_databases(data) -> str: async def retrieve_nodes(state: ReadState) -> ReadState: - ''' 模型信息 ''' - problem_extension=state.get('problem_extension', '')['context'] - storage_type=state.get('storage_type', '') - user_rag_memory_id=state.get('user_rag_memory_id', '') - end_user_id=state.get('end_user_id', '') + problem_extension = state.get('problem_extension', '')['context'] + storage_type = state.get('storage_type', '') + user_rag_memory_id = state.get('user_rag_memory_id', '') + end_user_id = state.get('end_user_id', '') memory_config = state.get('memory_config', None) - original=state.get('data', '') - problem_list=[] - for key,values in problem_extension.items(): + original = state.get('data', '') + problem_list = [] + for key, values in problem_extension.items(): for data in values: problem_list.append(data) logger.info(f"Retrieve: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") + # 创建异步任务处理单个问题 async def process_question_nodes(idx, question): try: @@ -244,7 +240,7 @@ async def retrieve_nodes(state: ReadState) -> ReadState: send_verify = [] for i, j in zip(keys, val, strict=False): - if j!=['']: + if j != ['']: send_verify.append({ "Query_small": i, "Answer_Small": j @@ -257,15 +253,13 @@ async def retrieve_nodes(state: ReadState) -> ReadState: } logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") - return {'retrieve':dup_databases} - - + return {'retrieve': dup_databases} async def retrieve(state: ReadState) -> ReadState: # 从state中获取end_user_id import time - start=time.time() + start = time.time() problem_extension = state.get('problem_extension', '')['context'] storage_type = state.get('storage_type', '') user_rag_memory_id = state.get('user_rag_memory_id', '') @@ -283,6 +277,7 @@ async def retrieve(state: ReadState) -> ReadState: with get_db_context() as db: # 使用同步数据库上下文管理器 config_service = MemoryConfigService(db) return await llm_infomation(state) + llm_config = await get_llm_info() api_key_obj = llm_config.api_keys[0] api_key = api_key_obj.api_key @@ -296,11 +291,11 @@ async def retrieve(state: ReadState) -> ReadState: ) time_retrieval_tool = create_time_retrieval_tool(end_user_id) - search_params = { "end_user_id": end_user_id, "return_raw_results": True } - hybrid_retrieval=create_hybrid_retrieval_tool_sync(memory_config, **search_params) + search_params = {"end_user_id": end_user_id, "return_raw_results": True} + hybrid_retrieval = create_hybrid_retrieval_tool_sync(memory_config, **search_params) agent = create_agent( llm, - tools=[time_retrieval_tool,hybrid_retrieval], + tools=[time_retrieval_tool, hybrid_retrieval], system_prompt=f"我是检索专家,可以根据适合的工具进行检索。当前使用的end_user_id是: {end_user_id}" ) @@ -314,7 +309,8 @@ async def retrieve(state: ReadState) -> ReadState: async with SEMAPHORE: # 限制并发 try: if storage_type == "rag" and user_rag_memory_id: - retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, question) + retrieval_knowledge, clean_content, cleaned_query, raw_results = await rag_knowledge(state, + question) else: cleaned_query = question # 使用 asyncio 在线程池中运行同步的 agent.invoke @@ -413,5 +409,3 @@ async def retrieve(state: ReadState) -> ReadState: # json.dump(dup_databases, f, indent=4) logger.info(f"Collected {len(intermediate_outputs)} intermediate outputs from search results") return {'retrieve': dup_databases} - - diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py index 0144c0e9..87606bf8 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/summary_nodes.py @@ -1,5 +1,3 @@ - - import os import time @@ -17,33 +15,77 @@ from app.core.memory.agent.utils.llm_tools import ( from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService -from app.db import get_db +from app.core.rag.nlp.search import knowledge_retrieval +from app.db import get_db_context template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') logger = get_agent_logger(__name__) -db_session = next(get_db()) + class SummaryNodeService(LLMServiceMixin): """总结节点服务类""" - + def __init__(self): super().__init__() self.template_service = TemplateService(template_root) + # 创建全局服务实例 summary_service = SummaryNodeService() + +async def rag_config(state): + user_rag_memory_id = state.get('user_rag_memory_id', '') + kb_config = { + "knowledge_bases": [ + { + "kb_id": user_rag_memory_id, + "similarity_threshold": 0.7, + "vector_similarity_weight": 0.5, + "top_k": 10, + "retrieve_type": "participle" + } + ], + "merge_strategy": "weight", + "reranker_id": os.getenv('reranker_id'), + "reranker_top_k": 10 + } + return kb_config + + +async def rag_knowledge(state, question): + kb_config = await rag_config(state) + end_user_id = state.get('end_user_id', '') + user_rag_memory_id = state.get("user_rag_memory_id", '') + retrieve_chunks_result = knowledge_retrieval(question, kb_config, [str(end_user_id)]) + try: + retrieval_knowledge = [i.page_content for i in retrieve_chunks_result] + clean_content = '\n\n'.join(retrieval_knowledge) + cleaned_query = question + raw_results = clean_content + logger.info(f" Using RAG storage with memory_id={user_rag_memory_id}") + except Exception: + retrieval_knowledge = [] + clean_content = '' + raw_results = '' + cleaned_query = question + logger.info(f"No content retrieved from knowledge base: {user_rag_memory_id}") + return retrieval_knowledge, clean_content, cleaned_query, raw_results + + async def summary_history(state: ReadState) -> ReadState: end_user_id = state.get("end_user_id", '') history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id) return history -async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model,search_mode) -> str: + +async def summary_llm(state: ReadState, history, retrieve_info, template_name, operation_name, response_model, + search_mode) -> str: """ 增强的summary_llm函数,包含更好的错误处理和数据验证 """ data = state.get("data", '') - + # 构建系统提示词 if str(search_mode) == "0": system_prompt = await summary_service.template_service.render_template( @@ -62,18 +104,19 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o ) try: # 使用优化的LLM服务进行结构化输出 - structured = await summary_service.call_llm_structured( - state=state, - db_session=db_session, - system_prompt=system_prompt, - response_model=response_model, - fallback_value=None - ) + with get_db_context() as db_session: + structured = await summary_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=response_model, + fallback_value=None + ) # 验证结构化响应 if structured is None: - logger.warning(f"LLM返回None,使用默认回答") + logger.warning("LLM返回None,使用默认回答") return "信息不足,无法回答" - + # 根据操作类型提取答案 if operation_name == "summary": aimessages = getattr(structured, 'query_answer', None) or "信息不足,无法回答" @@ -82,18 +125,18 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o if hasattr(structured, 'data') and structured.data: aimessages = getattr(structured.data, 'query_answer', None) or "信息不足,无法回答" else: - logger.warning(f"结构化响应缺少data字段") + logger.warning("结构化响应缺少data字段") aimessages = "信息不足,无法回答" - + # 验证答案不为空 if not aimessages or aimessages.strip() == "": aimessages = "信息不足,无法回答" - + return aimessages - + except Exception as e: logger.error(f"结构化输出失败: {e}", exc_info=True) - + # 尝试非结构化输出作为fallback try: logger.info("尝试非结构化输出作为fallback") @@ -103,7 +146,7 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o system_prompt=system_prompt, fallback_message="信息不足,无法回答" ) - + if response and response.strip(): # 简单清理响应 cleaned_response = response.strip() @@ -111,16 +154,17 @@ async def summary_llm(state: ReadState, history, retrieve_info, template_name, o if cleaned_response.startswith('```'): lines = cleaned_response.split('\n') cleaned_response = '\n'.join(lines[1:-1]) - + return cleaned_response else: return "信息不足,无法回答" - + except Exception as fallback_error: logger.error(f"Fallback也失败: {fallback_error}") return "信息不足,无法回答" -async def summary_redis_save(state: ReadState,aimessages) -> ReadState: + +async def summary_redis_save(state: ReadState, aimessages) -> ReadState: data = state.get("data", '') end_user_id = state.get("end_user_id", '') await SessionService(store).save_session( @@ -132,10 +176,12 @@ async def summary_redis_save(state: ReadState,aimessages) -> ReadState: ) await SessionService(store).cleanup_duplicates() logger.info(f"sessionid: {aimessages} 写入成功") -async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState: - storage_type=state.get("storage_type",'') - user_rag_memory_id=state.get("user_rag_memory_id",'') - data=state.get("data", '') + + +async def summary_prompt(state: ReadState, aimessages, raw_results) -> ReadState: + storage_type = state.get("storage_type", '') + user_rag_memory_id = state.get("user_rag_memory_id", '') + data = state.get("data", '') input_summary = { "status": "success", "summary_result": aimessages, @@ -152,14 +198,14 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState: "user_rag_memory_id": user_rag_memory_id } } - retrieve={ + retrieve = { "status": "success", "summary_result": aimessages, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id, "_intermediate": { "type": "retrieval_summary", - "title":"快速检索", + "title": "快速检索", "summary": aimessages, "query": data, "storage_type": storage_type, @@ -167,17 +213,18 @@ async def summary_prompt(state: ReadState,aimessages,raw_results) -> ReadState: } } - return input_summary,retrieve + return input_summary, retrieve + async def Input_Summary(state: ReadState) -> ReadState: - start=time.time() - storage_type=state.get("storage_type",'') + start = time.time() + storage_type = state.get("storage_type", '') memory_config = state.get('memory_config', None) - user_rag_memory_id=state.get("user_rag_memory_id",'') - data=state.get("data", '') - end_user_id=state.get("end_user_id", '') + user_rag_memory_id = state.get("user_rag_memory_id", '') + data = state.get("data", '') + end_user_id = state.get("end_user_id", '') logger.info(f"Input_Summary: storage_type={storage_type}, user_rag_memory_id={user_rag_memory_id}") - history = await summary_history( state) + history = await summary_history(state) search_params = { "end_user_id": end_user_id, "question": data, @@ -186,12 +233,14 @@ async def Input_Summary(state: ReadState) -> ReadState: } try: - retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, memory_config=memory_config) + if storage_type != "rag": + retrieve_info, question, raw_results = await SearchService().execute_hybrid_search(**search_params, + memory_config=memory_config) + else: + retrieval_knowledge, retrieve_info, question, raw_results = await rag_knowledge(state, data) except Exception as e: - logger.error( f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True ) + logger.error(f"Input_Summary: hybrid_search failed, using empty results: {e}", exc_info=True) retrieve_info, question, raw_results = "", data, [] - - try: # aimessages=await summary_llm(state,history,retrieve_info,'Retrieve_Summary_prompt.jinja2', # 'input_summary',RetrieveSummaryResponse) @@ -199,8 +248,8 @@ async def Input_Summary(state: ReadState) -> ReadState: summary_result = await summary_prompt(state, retrieve_info, retrieve_info) summary = summary_result[0] except Exception as e: - logger.error( f"Input_Summary failed: {e}", exc_info=True ) - summary= { + logger.error(f"Input_Summary failed: {e}", exc_info=True) + summary = { "status": "fail", "summary_result": "信息不足,无法回答", "storage_type": storage_type, @@ -213,30 +262,31 @@ async def Input_Summary(state: ReadState) -> ReadState: except Exception: duration = 0.0 log_time('检索', duration) - return {"summary":summary} + return {"summary": summary} -async def Retrieve_Summary(state: ReadState)-> ReadState: - retrieve=state.get("retrieve", '') - history = await summary_history( state) + +async def Retrieve_Summary(state: ReadState) -> ReadState: + retrieve = state.get("retrieve", '') + history = await summary_history(state) import json - with open("检索.json","w",encoding='utf-8') as f: + with open("检索.json", "w", encoding='utf-8') as f: f.write(json.dumps(retrieve, indent=4, ensure_ascii=False)) - retrieve=retrieve.get("Expansion_issue", []) - start=time.time() - retrieve_info_str=[] + retrieve = retrieve.get("Expansion_issue", []) + start = time.time() + retrieve_info_str = [] for data in retrieve: - if data=='': - retrieve_info_str='' + if data == '': + retrieve_info_str = '' else: for key, value in data.items(): - if key=='Answer_Small': + if key == 'Answer_Small': for i in value: retrieve_info_str.append(i) - retrieve_info_str=list(set(retrieve_info_str)) - retrieve_info_str='\n'.join(retrieve_info_str) + retrieve_info_str = list(set(retrieve_info_str)) + retrieve_info_str = '\n'.join(retrieve_info_str) - aimessages=await summary_llm(state,history,retrieve_info_str, - 'direct_summary_prompt.jinja2','retrieve_summary',RetrieveSummaryResponse,"1") + aimessages = await summary_llm(state, history, retrieve_info_str, + 'direct_summary_prompt.jinja2', 'retrieve_summary', RetrieveSummaryResponse, "1") if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": await summary_redis_save(state, aimessages) if aimessages == '': @@ -248,33 +298,33 @@ async def Retrieve_Summary(state: ReadState)-> ReadState: except Exception: duration = 0.0 log_time('Retrieval summary', duration) - + # 修复协程调用 - 先await,然后访问返回值 summary_result = await summary_prompt(state, aimessages, retrieve_info_str) summary = summary_result[1] - return {"summary":summary} + return {"summary": summary} -async def Summary(state: ReadState)-> ReadState: - start=time.time() +async def Summary(state: ReadState) -> ReadState: + start = time.time() query = state.get("data", '') - verify=state.get("verify", '') - verify_expansion_issue=verify.get("verified_data", '') - retrieve_info_str='' + verify = state.get("verify", '') + verify_expansion_issue = verify.get("verified_data", '') + retrieve_info_str = '' for data in verify_expansion_issue: for key, value in data.items(): - if key=='answer_small': + if key == 'answer_small': for i in value: - retrieve_info_str+=i+'\n' - history=await summary_history(state) + retrieve_info_str += i + '\n' + history = await summary_history(state) data = { "query": query, "history": history, "retrieve_info": retrieve_info_str } - aimessages=await summary_llm(state,history,data, - 'summary_prompt.jinja2','summary',SummaryResponse,0) + aimessages = await summary_llm(state, history, data, + 'summary_prompt.jinja2', 'summary', SummaryResponse, 0) if '信息不足,无法回答' not in str(aimessages) or str(aimessages) != "": await summary_redis_save(state, aimessages) @@ -289,11 +339,12 @@ async def Summary(state: ReadState)-> ReadState: # 修复协程调用 - 先await,然后访问返回值 summary_result = await summary_prompt(state, aimessages, retrieve_info_str) summary = summary_result[1] - return {"summary":summary} + return {"summary": summary} -async def Summary_fails(state: ReadState)-> ReadState: - storage_type=state.get("storage_type", '') - user_rag_memory_id=state.get("user_rag_memory_id", '') + +async def Summary_fails(state: ReadState) -> ReadState: + storage_type = state.get("storage_type", '') + user_rag_memory_id = state.get("user_rag_memory_id", '') history = await summary_history(state) query = state.get("data", '') verify = state.get("verify", '') @@ -309,12 +360,12 @@ async def Summary_fails(state: ReadState)-> ReadState: "history": history, "retrieve_info": retrieve_info_str } - aimessages = await summary_llm(state, history, data, - 'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0) - result= { + aimessages = await summary_llm(state, history, data, + 'fail_summary_prompt.jinja2', 'summary', SummaryResponse, 0) + result = { "status": "success", "summary_result": aimessages, "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id } - return {"summary":result} \ No newline at end of file + return {"summary": result} diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py index b809faf2..3f7b491e 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/verification_nodes.py @@ -1,8 +1,9 @@ +import asyncio import os -from app.core.logging_config import get_agent_logger -from app.db import get_db +from app.core.logging_config import get_agent_logger from app.core.memory.agent.models.verification_models import VerificationResult +from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.core.memory.agent.utils.llm_tools import ( PROJECT_ROOT_, ReadState, @@ -10,28 +11,30 @@ from app.core.memory.agent.utils.llm_tools import ( from app.core.memory.agent.utils.redis_tool import store from app.core.memory.agent.utils.session_tools import SessionService from app.core.memory.agent.utils.template_tools import TemplateService -from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin +from app.db import get_db_context template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') -db_session = next(get_db()) logger = get_agent_logger(__name__) + class VerificationNodeService(LLMServiceMixin): """验证节点服务类""" - + def __init__(self): super().__init__() self.template_service = TemplateService(template_root) + # 创建全局服务实例 verification_service = VerificationNodeService() + async def Verify_prompt(state: ReadState, messages_deal: VerificationResult): """处理验证结果并生成输出格式""" storage_type = state.get('storage_type', '') user_rag_memory_id = state.get('user_rag_memory_id', '') data = state.get('data', '') - + # 将 VerificationItem 对象转换为字典列表 verified_data = [] if messages_deal.expansion_issue: @@ -40,7 +43,7 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult): verified_data.append(item.model_dump()) elif isinstance(item, dict): verified_data.append(item) - + Verify_result = { "status": messages_deal.split_result, "verified_data": verified_data, @@ -58,34 +61,37 @@ async def Verify_prompt(state: ReadState, messages_deal: VerificationResult): } } return Verify_result + + async def Verify(state: ReadState): logger.info("=== Verify 节点开始执行 ===") try: content = state.get('data', '') end_user_id = state.get('end_user_id', '') memory_config = state.get('memory_config', None) - + logger.info(f"Verify: content={content[:50] if content else 'empty'}..., end_user_id={end_user_id}") history = await SessionService(store).get_history(end_user_id, end_user_id, end_user_id) logger.info(f"Verify: 获取历史记录完成,history length={len(history)}") retrieve = state.get("retrieve", {}) - logger.info(f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}") - + logger.info( + f"Verify: retrieve data type={type(retrieve)}, keys={retrieve.keys() if isinstance(retrieve, dict) else 'N/A'}") + retrieve_expansion = retrieve.get("Expansion_issue", []) if isinstance(retrieve, dict) else [] logger.info(f"Verify: Expansion_issue length={len(retrieve_expansion)}") - + messages = { "Query": content, "Expansion_issue": retrieve_expansion } logger.info("Verify: 开始渲染模板") - + # 生成 JSON schema 以指导 LLM 输出正确格式 json_schema = VerificationResult.model_json_schema() - + system_prompt = await verification_service.template_service.render_template( template_name='split_verify_prompt.jinja2', operation_name='split_verify_prompt', @@ -94,29 +100,30 @@ async def Verify(state: ReadState): json_schema=json_schema ) logger.info(f"Verify: 模板渲染完成,prompt length={len(system_prompt)}") - + # 使用优化的LLM服务,添加超时保护 logger.info("Verify: 开始调用 LLM") try: # 添加 asyncio.wait_for 超时包裹,防止无限等待 # 超时时间设置为 150 秒(比 LLM 配置的 120 秒稍长) - import asyncio - structured = await asyncio.wait_for( - verification_service.call_llm_structured( - state=state, - db_session=db_session, - system_prompt=system_prompt, - response_model=VerificationResult, - fallback_value={ - "query": content, - "history": history if isinstance(history, list) else [], - "expansion_issue": [], - "split_result": "failed", - "reason": "验证失败或超时" - } - ), - timeout=150.0 # 150秒超时 - ) + + with get_db_context() as db_session: + structured = await asyncio.wait_for( + verification_service.call_llm_structured( + state=state, + db_session=db_session, + system_prompt=system_prompt, + response_model=VerificationResult, + fallback_value={ + "query": content, + "history": history if isinstance(history, list) else [], + "expansion_issue": [], + "split_result": "failed", + "reason": "验证失败或超时" + } + ), + timeout=150.0 # 150秒超时 + ) logger.info(f"Verify: LLM 调用完成,result={structured}") except asyncio.TimeoutError: logger.error("Verify: LLM 调用超时(150秒),使用 fallback 值") @@ -127,11 +134,11 @@ async def Verify(state: ReadState): split_result="failed", reason="LLM调用超时" ) - + result = await Verify_prompt(state, structured) logger.info("=== Verify 节点执行完成 ===") return {"verify": result} - + except Exception as e: logger.error(f"Verify 节点执行失败: {e}", exc_info=True) # 返回失败的验证结果 @@ -152,4 +159,4 @@ async def Verify(state: ReadState): "user_rag_memory_id": state.get('user_rag_memory_id', '') } } - } \ No newline at end of file + } diff --git a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py index ad0473fc..10fe96ba 100644 --- a/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py +++ b/api/app/core/memory/agent/langgraph_graph/nodes/write_nodes.py @@ -1,3 +1,4 @@ +from app.cache.memory.interest_memory import InterestMemoryCache from app.core.memory.agent.utils.llm_tools import WriteState from app.core.memory.agent.utils.write_tools import write from app.core.logging_config import get_agent_logger @@ -40,6 +41,15 @@ async def write_node(state: WriteState) -> WriteState: ) logger.info(f"Write completed successfully! Config: {memory_config.config_name}") + # 写入 neo4j 成功后,删除该用户的兴趣分布缓存,确保下次请求重新生成 + for lang in ["zh", "en"]: + deleted = await InterestMemoryCache.delete_interest_distribution( + end_user_id=end_user_id, + language=lang, + ) + if deleted: + logger.info(f"Invalidated interest distribution cache: end_user_id={end_user_id}, language={lang}") + write_result = { "status": "success", "data": structured_messages, diff --git a/api/app/core/memory/agent/langgraph_graph/read_graph.py b/api/app/core/memory/agent/langgraph_graph/read_graph.py index 3476d0ec..cba1b230 100644 --- a/api/app/core/memory/agent/langgraph_graph/read_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/read_graph.py @@ -5,7 +5,6 @@ from langchain_core.messages import HumanMessage from langgraph.constants import START, END from langgraph.graph import StateGraph - from app.db import get_db from app.services.memory_config_service import MemoryConfigService @@ -32,7 +31,6 @@ from app.core.memory.agent.langgraph_graph.routing.routers import ( ) - @asynccontextmanager async def make_read_graph(): """创建并返回 LangGraph 工作流""" @@ -49,7 +47,7 @@ async def make_read_graph(): workflow.add_node("Retrieve_Summary", Retrieve_Summary) workflow.add_node("Summary", Summary) workflow.add_node("Summary_fails", Summary_fails) - + # 添加边 workflow.add_edge(START, "content_input") workflow.add_conditional_edges("content_input", Split_continue) @@ -62,20 +60,20 @@ async def make_read_graph(): workflow.add_edge("Summary_fails", END) workflow.add_edge("Summary", END) - '''-----''' # workflow.add_edge("Retrieve", END) - + # 编译工作流 graph = workflow.compile() yield graph - + except Exception as e: print(f"创建工作流失败: {e}") raise finally: print("工作流创建完成") + async def main(): """主函数 - 运行工作流""" message = "昨天有什么好看的电影" @@ -92,17 +90,19 @@ async def main(): service_name="MemoryAgentService" ) import time - start=time.time() + start = time.time() try: async with make_read_graph() as graph: config = {"configurable": {"thread_id": end_user_id}} # 初始状态 - 包含所有必要字段 - initial_state = {"messages": [HumanMessage(content=message)] ,"search_switch":search_switch,"end_user_id":end_user_id - ,"storage_type":storage_type,"user_rag_memory_id":user_rag_memory_id,"memory_config":memory_config} + initial_state = {"messages": [HumanMessage(content=message)], "search_switch": search_switch, + "end_user_id": end_user_id + , "storage_type": storage_type, "user_rag_memory_id": user_rag_memory_id, + "memory_config": memory_config} # 获取节点更新信息 _intermediate_outputs = [] summary = '' - + async for update_event in graph.astream( initial_state, stream_mode="updates", @@ -110,7 +110,7 @@ async def main(): ): for node_name, node_data in update_event.items(): print(f"处理节点: {node_name}") - + # 处理不同Summary节点的返回结构 if 'Summary' in node_name: if 'InputSummary' in node_data and 'summary_result' in node_data['InputSummary']: @@ -125,23 +125,22 @@ async def main(): spit_data = node_data.get('spit_data', {}).get('_intermediate', None) if spit_data and spit_data != [] and spit_data != {}: _intermediate_outputs.append(spit_data) - + # Problem_Extension 节点 problem_extension = node_data.get('problem_extension', {}).get('_intermediate', None) if problem_extension and problem_extension != [] and problem_extension != {}: _intermediate_outputs.append(problem_extension) - + # Retrieve 节点 retrieve_node = node_data.get('retrieve', {}).get('_intermediate_outputs', None) if retrieve_node and retrieve_node != [] and retrieve_node != {}: _intermediate_outputs.extend(retrieve_node) - + # Verify 节点 verify_n = node_data.get('verify', {}).get('_intermediate', None) if verify_n and verify_n != [] and verify_n != {}: _intermediate_outputs.append(verify_n) - # Summary 节点 summary_n = node_data.get('summary', {}).get('_intermediate', None) if summary_n and summary_n != [] and summary_n != {}: @@ -161,17 +160,20 @@ async def main(): # print(f"=== 最终摘要 ===") print(summary) - + except Exception as e: import traceback traceback.print_exc() + finally: + db_session.close() - end=time.time() - print(100*'y') - print(f"总耗时: {end-start}s") - print(100*'y') + end = time.time() + print(100 * 'y') + print(f"总耗时: {end - start}s") + print(100 * 'y') if __name__ == "__main__": import asyncio + asyncio.run(main()) diff --git a/api/app/core/memory/agent/utils/get_dialogs.py b/api/app/core/memory/agent/utils/get_dialogs.py index bfb0f675..ea44d0a5 100644 --- a/api/app/core/memory/agent/utils/get_dialogs.py +++ b/api/app/core/memory/agent/utils/get_dialogs.py @@ -21,7 +21,7 @@ async def get_chunked_dialogs( end_user_id: Group identifier messages: Structured message list [{"role": "user", "content": "..."}, ...] ref_id: Reference identifier - config_id: Configuration ID for processing + config_id: Configuration ID for processing (used to load pruning config) Returns: List of DialogData objects with generated chunks @@ -57,6 +57,63 @@ async def get_chunked_dialogs( end_user_id=end_user_id, config_id=config_id ) + + # 语义剪枝步骤(在分块之前) + try: + from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import SemanticPruner + from app.core.memory.models.config_models import PruningConfig + from app.db import get_db_context + from app.services.memory_config_service import MemoryConfigService + from app.core.memory.utils.llm.llm_utils import MemoryClientFactory + + # 加载剪枝配置 + pruning_config = None + if config_id: + try: + with get_db_context() as db: + # 使用 MemoryConfigService 加载完整的 MemoryConfig 对象 + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + service_name="semantic_pruning" + ) + + if memory_config: + pruning_config = PruningConfig( + pruning_switch=memory_config.pruning_enabled, + pruning_scene=memory_config.pruning_scene or "education", + pruning_threshold=memory_config.pruning_threshold, + scene_id=str(memory_config.scene_id) if memory_config.scene_id else None, + ontology_classes=memory_config.ontology_classes, + ) + logger.info(f"[剪枝] 加载配置: switch={pruning_config.pruning_switch}, scene={pruning_config.pruning_scene}, threshold={pruning_config.pruning_threshold}") + + # 获取LLM客户端用于剪枝 + if pruning_config.pruning_switch: + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client_from_config(memory_config) + + # 执行剪枝 - 使用 prune_dataset 支持消息级剪枝 + pruner = SemanticPruner(config=pruning_config, llm_client=llm_client) + original_msg_count = len(dialog_data.context.msgs) + + # 使用 prune_dataset 而不是 prune_dialog + # prune_dataset 会进行消息级剪枝,即使对话整体相关也会删除不重要消息 + pruned_dialogs = await pruner.prune_dataset([dialog_data]) + + if pruned_dialogs: + dialog_data = pruned_dialogs[0] + remaining_msg_count = len(dialog_data.context.msgs) + deleted_count = original_msg_count - remaining_msg_count + logger.info(f"[剪枝] 完成: 原始{original_msg_count}条 -> 保留{remaining_msg_count}条 (删除{deleted_count}条)") + else: + logger.warning("[剪枝] prune_dataset 返回空列表") + else: + logger.info("[剪枝] 配置中剪枝开关关闭,跳过剪枝") + except Exception as e: + logger.warning(f"[剪枝] 加载配置失败,跳过剪枝: {e}", exc_info=True) + except Exception as e: + logger.warning(f"[剪枝] 执行失败,跳过剪枝: {e}", exc_info=True) chunker = DialogueChunker(chunker_strategy) extracted_chunks = await chunker.process_dialogue(dialog_data) diff --git a/api/app/core/memory/agent/utils/llm_client_pool.py b/api/app/core/memory/agent/utils/llm_client_pool.py deleted file mode 100644 index fddd54f6..00000000 --- a/api/app/core/memory/agent/utils/llm_client_pool.py +++ /dev/null @@ -1,56 +0,0 @@ - -import asyncio -from typing import Dict, Optional -from app.core.memory.utils.llm.llm_utils import get_llm_client_fast -from app.db import get_db -from app.core.logging_config import get_agent_logger - -logger = get_agent_logger(__name__) - -class LLMClientPool: - """LLM客户端连接池""" - - def __init__(self, max_size: int = 5): - self.max_size = max_size - self.pools: Dict[str, asyncio.Queue] = {} - self.active_clients: Dict[str, int] = {} - - async def get_client(self, llm_model_id: str): - """获取LLM客户端""" - if llm_model_id not in self.pools: - self.pools[llm_model_id] = asyncio.Queue(maxsize=self.max_size) - self.active_clients[llm_model_id] = 0 - - pool = self.pools[llm_model_id] - - try: - # 尝试从池中获取客户端 - client = pool.get_nowait() - logger.debug(f"从池中获取LLM客户端: {llm_model_id}") - return client - except asyncio.QueueEmpty: - # 池为空,创建新客户端 - if self.active_clients[llm_model_id] < self.max_size: - db_session = next(get_db()) - client = get_llm_client_fast(llm_model_id, db_session) - self.active_clients[llm_model_id] += 1 - logger.debug(f"创建新LLM客户端: {llm_model_id}") - return client - else: - # 等待可用客户端 - logger.debug(f"等待LLM客户端可用: {llm_model_id}") - return await pool.get() - - async def return_client(self, llm_model_id: str, client): - """归还LLM客户端到池中""" - if llm_model_id in self.pools: - try: - self.pools[llm_model_id].put_nowait(client) - logger.debug(f"归还LLM客户端到池: {llm_model_id}") - except asyncio.QueueFull: - # 池已满,丢弃客户端 - self.active_clients[llm_model_id] -= 1 - logger.debug(f"池已满,丢弃LLM客户端: {llm_model_id}") - -# 全局客户端池 -llm_client_pool = LLMClientPool() diff --git a/api/app/core/memory/agent/utils/write_tools.py b/api/app/core/memory/agent/utils/write_tools.py index 93c6ef6f..22030278 100644 --- a/api/app/core/memory/agent/utils/write_tools.py +++ b/api/app/core/memory/agent/utils/write_tools.py @@ -225,5 +225,24 @@ async def write( with open(log_file, "a", encoding="utf-8") as f: f.write(f"=== Pipeline Run Completed: {timestamp} ===\n\n") + # 将提取统计写入 Redis,按 workspace_id 存储 + try: + from app.cache.memory.activity_stats_cache import ActivityStatsCache + + stats_to_cache = { + "chunk_count": len(all_chunk_nodes) if all_chunk_nodes else 0, + "statements_count": len(all_statement_nodes) if all_statement_nodes else 0, + "triplet_entities_count": len(all_entity_nodes) if all_entity_nodes else 0, + "triplet_relations_count": len(all_entity_entity_edges) if all_entity_entity_edges else 0, + "temporal_count": 0, + } + await ActivityStatsCache.set_activity_stats( + workspace_id=str(memory_config.workspace_id), + stats=stats_to_cache, + ) + logger.info(f"[WRITE] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}") + except Exception as cache_err: + logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) + logger.info("=== Pipeline Complete ===") logger.info(f"Total execution time: {total_time:.2f} seconds") \ No newline at end of file diff --git a/api/app/core/memory/analytics/hot_memory_tags.py b/api/app/core/memory/analytics/hot_memory_tags.py index f99b811e..6afcec6d 100644 --- a/api/app/core/memory/analytics/hot_memory_tags.py +++ b/api/app/core/memory/analytics/hot_memory_tags.py @@ -1,9 +1,12 @@ import asyncio import json +import logging import os from typing import List, Tuple from app.core.config import settings + +logger = logging.getLogger(__name__) from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.repositories.neo4j.neo4j_connector import Neo4jConnector @@ -16,6 +19,10 @@ class FilteredTags(BaseModel): """用于接收LLM筛选后的核心标签列表的模型。""" meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。") +class InterestTags(BaseModel): + """用于接收LLM筛选后的兴趣活动标签列表的模型。""" + interest_tags: List[str] = Field(..., description="从原始列表中筛选出的代表用户兴趣活动的标签列表。") + async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]: """ 使用LLM筛选标签列表,仅保留具有代表性的核心名词。 @@ -85,10 +92,74 @@ async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]: return structured_response.meaningful_tags except Exception as e: - print(f"LLM筛选过程中发生错误: {e}") + logger.error(f"LLM筛选过程中发生错误: {e}", exc_info=True) # 在LLM失败时返回原始标签,确保流程继续 return tags +async def filter_interests_with_llm(tags: List[str], end_user_id: str, language: str = "zh") -> List[str]: + """ + 使用LLM从标签列表中筛选出代表用户兴趣活动的标签。 + + 与 filter_tags_with_llm 不同,此函数专注于识别"活动/行为"类兴趣, + 过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。 + + Args: + tags: 原始标签列表 + end_user_id: 用户ID,用于获取LLM配置 + + Returns: + 筛选后的兴趣活动标签列表 + """ + try: + with get_db_context() as db: + from app.services.memory_agent_service import ( + get_end_user_connected_config, + ) + connected_config = get_end_user_connected_config(end_user_id, db) + config_id = connected_config.get("memory_config_id") + workspace_id = connected_config.get("workspace_id") + + if not config_id and not workspace_id: + raise ValueError( + f"No memory_config_id found for end_user_id: {end_user_id}." + ) + + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + workspace_id=workspace_id + ) + + if not memory_config.llm_model_id: + raise ValueError( + f"No llm_model_id found in memory config {config_id}." + ) + + factory = MemoryClientFactory(db) + llm_client = factory.get_llm_client(memory_config.llm_model_id) + + tag_list_str = ", ".join(tags) + from app.core.memory.utils.prompt.prompt_utils import render_interest_filter_prompt + rendered_prompt = render_interest_filter_prompt(tag_list_str, language=language) + messages = [ + { + "role": "user", + "content": rendered_prompt + } + ] + + structured_response = await llm_client.response_structured( + messages=messages, + response_model=InterestTags + ) + + return structured_response.interest_tags + + except Exception as e: + logger.error(f"兴趣标签LLM筛选过程中发生错误: {e}", exc_info=True) + return tags + + async def get_raw_tags_from_db( connector: Neo4jConnector, end_user_id: str, @@ -139,14 +210,14 @@ async def get_raw_tags_from_db( return [(record["name"], record["frequency"]) for record in results] -async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = False) -> List[Tuple[str, int]]: +async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = False) -> List[Tuple[str, int]]: """ 获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。 - 查询更多的标签(limit=40)给LLM提供更丰富的上下文进行筛选。 + 查询更多的标签(40条)给LLM提供更丰富的上下文进行筛选,但最终返回数量由limit参数控制。 Args: end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id - limit: 返回的标签数量限制 + limit: 最终返回的标签数量限制(默认10) by_user: 是否按user_id查询(默认False,按end_user_id查询) Raises: @@ -161,8 +232,9 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = # 使用项目的Neo4jConnector connector = Neo4jConnector() try: - # 1. 从数据库获取原始排名靠前的标签 - raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user) + # 1. 从数据库获取原始排名靠前的标签(查询40条给LLM提供更丰富的上下文) + query_limit = 40 + raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user) if not raw_tags_with_freq: return [] @@ -177,7 +249,61 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 40, by_user: bool = if tag in meaningful_tag_names: final_tags.append((tag, freq)) - return final_tags + # 4. 限制返回的标签数量 + return final_tags[:limit] finally: # 确保关闭连接 await connector.close() + +async def get_interest_distribution(end_user_id: str, limit: int = 10, by_user: bool = False, language: str = "zh") -> List[Tuple[str, int]]: + """ + 获取用户的兴趣分布标签。 + + 与 get_hot_memory_tags 不同,此函数使用专门针对"活动/行为"的LLM prompt, + 过滤掉纯物品、工具、地点等,只保留能代表用户兴趣爱好的活动类标签。 + + Args: + end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id + limit: 最终返回的标签数量限制(默认10) + by_user: 是否按user_id查询(默认False,按end_user_id查询) + + Raises: + ValueError: 如果end_user_id未提供或为空 + """ + if not end_user_id or not end_user_id.strip(): + raise ValueError( + "end_user_id is required. Please provide a valid end_user_id or user_id." + ) + + connector = Neo4jConnector() + try: + # 查询更多原始标签,给LLM提供充足上下文 + query_limit = 40 + raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user) + if not raw_tags_with_freq: + return [] + + raw_tag_names = [tag for tag, freq in raw_tags_with_freq] + raw_freq_map = {tag: freq for tag, freq in raw_tags_with_freq} + + # 使用兴趣活动专用prompt进行筛选(支持语义推断出新标签) + interest_tag_names = await filter_interests_with_llm(raw_tag_names, end_user_id, language=language) + + # 构建最终标签列表: + # - 原始标签中存在的,保留原始频率 + # - LLM推断出的新标签(不在原始列表中),赋予默认频率1 + final_tags = [] + seen = set() + for tag in interest_tag_names: + if tag in seen: + continue + seen.add(tag) + freq = raw_freq_map.get(tag, 1) + final_tags.append((tag, freq)) + + # 按频率降序排列 + final_tags.sort(key=lambda x: x[1], reverse=True) + + return final_tags[:limit] + finally: + await connector.close() diff --git a/api/app/core/memory/models/config_models.py b/api/app/core/memory/models/config_models.py index ca1780aa..c2d62ac1 100644 --- a/api/app/core/memory/models/config_models.py +++ b/api/app/core/memory/models/config_models.py @@ -10,7 +10,7 @@ Classes: TemporalSearchParams: Parameters for temporal search queries """ -from typing import Optional +from typing import Optional, List from pydantic import BaseModel, Field @@ -55,17 +55,26 @@ class PruningConfig(BaseModel): Attributes: pruning_switch: Enable or disable semantic pruning - pruning_scene: Scene type for pruning ('education', 'online_service', 'outbound') + pruning_scene: Scene name for pruning, either a built-in key + ('education', 'online_service', 'outbound') or a custom scene_name + from ontology_scene table pruning_threshold: Pruning ratio (0-0.9, max 0.9 to avoid complete removal) + scene_id: Optional ontology scene UUID, used to load custom ontology classes + ontology_classes: List of class_name strings from ontology_class table, + injected into the prompt when pruning_scene is not a built-in scene """ pruning_switch: bool = Field(False, description="Enable semantic pruning when True.") pruning_scene: str = Field( "education", - description="Scene for pruning: one of 'education', 'online_service', 'outbound'.", + description="Scene for pruning: built-in key or custom scene_name from ontology_scene.", ) pruning_threshold: float = Field( 0.5, ge=0.0, le=0.9, description="Pruning ratio within 0-0.9 (max 0.9 to avoid termination).") + scene_id: Optional[str] = Field(None, description="Ontology scene UUID (optional).") + ontology_classes: Optional[List[str]] = Field( + None, description="Class names from ontology_class table for custom scenes." + ) class TemporalSearchParams(BaseModel): diff --git a/api/General_purpose_entity.ttl b/api/app/core/memory/ontology_services/General_purpose_entity.ttl similarity index 100% rename from api/General_purpose_entity.ttl rename to api/app/core/memory/ontology_services/General_purpose_entity.ttl diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py index d19e511b..904b238f 100644 --- a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py +++ b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/data_pruning.py @@ -5,20 +5,27 @@ - 对话级一次性抽取判定相关性 - 仅对"不相关对话"的消息按比例删除 - 重要信息(时间、编号、金额、联系方式、地址等)优先保留 +- 改进版:增强重要性判断、智能填充消息识别、问答对保护、并发优化 """ +import asyncio import os import hashlib import json import re +from collections import OrderedDict from datetime import datetime -from typing import List, Optional +from typing import List, Optional, Dict, Tuple, Set from pydantic import BaseModel, Field from app.core.memory.models.message_models import DialogData, ConversationMessage, ConversationContext from app.core.memory.models.config_models import PruningConfig from app.core.memory.utils.config.config_utils import get_pruning_config from app.core.memory.utils.prompt.prompt_utils import prompt_env, log_prompt_rendering, log_template_rendering +from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import ( + SceneConfigRegistry, + ScenePatterns +) class DialogExtractionResponse(BaseModel): @@ -36,6 +43,23 @@ class DialogExtractionResponse(BaseModel): keywords: List[str] = Field(default_factory=list) +class MessageImportanceResponse(BaseModel): + """消息重要性批量判断的结构化返回(用于LLM语义判断)。 + + - importance_scores: 消息索引到重要性分数的映射 (0-10分) + - reasons: 可选的判断理由 + """ + importance_scores: Dict[int, int] = Field(default_factory=dict, description="消息索引到重要性分数(0-10)的映射") + reasons: Optional[Dict[int, str]] = Field(default_factory=dict, description="可选的判断理由") + + +class QAPair(BaseModel): + """问答对模型,用于识别和保护对话中的问答结构。""" + question_idx: int = Field(..., description="问题消息的索引") + answer_idx: int = Field(..., description="答案消息的索引") + confidence: float = Field(default=1.0, description="问答对的置信度(0-1)") + + class SemanticPruner: """语义剪枝:在预处理与分块之间过滤与场景不相关内容。 @@ -43,109 +67,385 @@ class SemanticPruner: 重要信息(时间、编号、金额、联系方式、地址等)优先保留。 """ - def __init__(self, config: Optional[PruningConfig] = None, llm_client=None): - cfg_dict = get_pruning_config() if config is None else config.model_dump() - self.config = PruningConfig.model_validate(cfg_dict) + def __init__(self, config: Optional[PruningConfig] = None, llm_client=None, language: str = "zh", max_concurrent: int = 5): + # 如果没有提供config,使用默认配置 + if config is None: + # 使用默认的剪枝配置 + config = PruningConfig( + pruning_switch=False, # 默认关闭剪枝,保持向后兼容 + pruning_scene="education", + pruning_threshold=0.5 + ) + + self.config = config self.llm_client = llm_client + self.language = language # 保存语言配置 + self.max_concurrent = max_concurrent # 新增:最大并发数 + + # 详细日志配置:限制逐条消息日志的数量 + self._detailed_prune_logging = True # 是否启用详细日志 + self._max_debug_msgs_per_dialog = 20 # 每个对话最多记录前N条消息的详细日志 + + # 加载场景特定配置(内置场景走专门规则,自定义场景 fallback 到通用规则) + self.scene_config: ScenePatterns = SceneConfigRegistry.get_config( + self.config.pruning_scene, + fallback_to_generic=True + ) + + # 判断是否为内置专门场景 + self._is_builtin_scene = SceneConfigRegistry.is_scene_supported(self.config.pruning_scene) + + # 自定义场景的本体类型列表(用于注入提示词) + self._ontology_classes = getattr(self.config, "ontology_classes", None) or [] + + if self._is_builtin_scene: + self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 使用内置专门配置") + else: + self._log(f"[剪枝-初始化] 场景={self.config.pruning_scene} 为自定义场景,使用通用规则 + 本体类型提示词注入") + if self._ontology_classes: + self._log(f"[剪枝-初始化] 注入本体类型: {self._ontology_classes}") + else: + self._log(f"[剪枝-初始化] 未找到本体类型,将使用通用提示词") + # Load Jinja2 template self.template = prompt_env.get_template("extracat_Pruning.jinja2") - # 对话抽取缓存:避免同一对话重复调用 LLM / 重复渲染 - self._dialog_extract_cache: dict[str, DialogExtractionResponse] = {} + + # 对话抽取缓存:使用 OrderedDict 实现 LRU 缓存 + self._dialog_extract_cache: OrderedDict[str, DialogExtractionResponse] = OrderedDict() + self._cache_max_size = 1000 # 缓存大小限制 + # 运行日志:收集关键终端输出,便于写入 JSON self.run_logs: List[str] = [] - # 采用顺序处理,移除并发配置以简化与稳定执行 def _is_important_message(self, message: ConversationMessage) -> bool: """基于启发式规则识别重要信息消息,优先保留。 - - 含日期/时间(如YYYY-MM-DD、HH:MM、2024年11月10日、上午/下午)。 - - 含编号/ID/订单号/申请号/账号/电话/金额等关键字段。 - - 关键词:"时间"、"日期"、"编号"、"订单"、"流水"、"金额"、"¥"、"元"、"电话"、"手机号"、"邮箱"、"地址"。 + 改进版:使用场景特定的模式进行识别 + - 根据 pruning_scene 动态加载对应的识别规则 + - 支持教育、在线服务、外呼三个场景的特定模式 """ - import re text = message.msg.strip() if not text: return False - patterns = [ - r"\b\d{4}-\d{1,2}-\d{1,2}\b", - r"\b\d{1,2}:\d{2}\b", - r"\d{4}年\d{1,2}月\d{1,2}日", - r"上午|下午|AM|PM", - r"订单号|工单|申请号|编号|ID|账号|账户", - r"电话|手机号|微信|QQ|邮箱", - r"地址|地点", - r"金额|费用|价格|¥|¥|\d+元", - r"时间|日期|有效期|截止", - ] - for p in patterns: - if re.search(p, text, flags=re.IGNORECASE): + + # 使用场景特定的模式 + all_patterns = ( + self.scene_config.high_priority_patterns + + self.scene_config.medium_priority_patterns + + self.scene_config.low_priority_patterns + ) + + for pattern, _ in all_patterns: + if re.search(pattern, text, flags=re.IGNORECASE): return True + + # 检查是否为问句(以问号结尾或包含疑问词) + if text.endswith("?") or text.endswith("?"): + return True + + # 检查是否包含问句关键词 + if any(keyword in text for keyword in self.scene_config.question_keywords): + return True + + # 检查是否包含决策性关键词 + if any(keyword in text for keyword in self.scene_config.decision_keywords): + return True + return False + def _importance_score(self, message: ConversationMessage) -> int: """为重要消息打分,用于在保留比例内优先保留更关键的内容。 - 简单启发:匹配到的类别越多、越关键分值越高。 + 改进版:使用场景特定的权重体系(0-10分) + - 根据场景动态调整不同信息类型的权重 + - 高优先级模式:4-6分 + - 中优先级模式:2-3分 + - 低优先级模式:1分 """ - import re text = message.msg.strip() score = 0 - weights = [ - (r"\b\d{4}-\d{1,2}-\d{1,2}\b", 3), - (r"\b\d{1,2}:\d{2}\b", 2), - (r"\d{4}年\d{1,2}月\d{1,2}日", 3), - (r"订单号|工单|申请号|编号|ID|账号|账户", 4), - (r"电话|手机号|微信|QQ|邮箱", 3), - (r"地址|地点", 2), - (r"金额|费用|价格|¥|¥|\d+元", 4), - (r"时间|日期|有效期|截止", 2), - ] - for p, w in weights: - if re.search(p, text, flags=re.IGNORECASE): - score += w - return score + + # 使用场景特定的权重 + for pattern, weight in self.scene_config.high_priority_patterns: + if re.search(pattern, text, flags=re.IGNORECASE): + score += weight + + for pattern, weight in self.scene_config.medium_priority_patterns: + if re.search(pattern, text, flags=re.IGNORECASE): + score += weight + + for pattern, weight in self.scene_config.low_priority_patterns: + if re.search(pattern, text, flags=re.IGNORECASE): + score += weight + + # 问句加分 + if text.endswith("?") or text.endswith("?"): + score += 2 + + # 包含问句关键词加分 + if any(keyword in text for keyword in self.scene_config.question_keywords): + score += 1 + + # 包含决策性关键词加分 + if any(keyword in text for keyword in self.scene_config.decision_keywords): + score += 2 + + # 长度加分(较长的消息通常包含更多信息) + if len(text) > 50: + score += 1 + if len(text) > 100: + score += 1 + + return min(score, 10) # 最高10分 def _is_filler_message(self, message: ConversationMessage) -> bool: - """检测典型寒暄/口头禅/确认类短消息,用于跳过LLM分类以加速。 + """检测典型寒暄/口头禅/确认类短消息。 + 改进版:更严格的填充消息判断,避免误删场景相关内容 满足以下之一视为填充消息: - - 纯标点或长度很短(<= 4 个汉字或 <= 8 个字符)且不包含数字或关键实体; - - 常见词:你好/您好/在吗/嗯/嗯嗯/哦/好的/好/行/可以/不可以/谢谢/拜拜/再见/哈哈/呵呵/哈哈哈/。。。/??。 + - 纯标点或空白 + - 在场景特定填充词库中(精确匹配) + - 纯表情符号 + - 常见寒暄(精确匹配短语) + + 注意:不再使用长度判断,避免误删短但重要的消息 """ - import re t = message.msg.strip() if not t: return True - # 常见填充语 - fillers = [ - "你好", "您好", "在吗", "嗯", "嗯嗯", "哦", "好的", "好", "行", "可以", "不可以", "谢谢", - "拜拜", "再见", "哈哈", "呵呵", "哈哈哈", "。。。", "??", "??" - ] - if t in fillers: + + # 检查是否在场景特定填充词库中(精确匹配) + if t in self.scene_config.filler_phrases: return True - # 长度与字符类型判断 - if len(t) <= 8: - # 非数字、无关键实体的短文本 - if not re.search(r"[0-9]", t) and not self._is_important_message(message): - # 主要是标点或简单确认词 - if re.fullmatch(r"[。!?,.!?…·\s]+", t) or t in fillers: - return True + + # 常见寒暄和问候(精确匹配,避免误删) + common_greetings = { + "在吗", "在不在", "在呢", "在的", + "你好", "您好", "hello", "hi", + "拜拜", "再见", "拜", "88", "bye", + "好的", "好", "行", "可以", "嗯", "哦", "啊", + "是的", "对", "对的", "没错", "是啊", + "哈哈", "呵呵", "嘿嘿", "嗯嗯" + } + if t in common_greetings: + return True + + # 检查是否为纯表情符号(方括号包裹) + if re.fullmatch(r"(\[[^\]]+\])+", t): + return True + + # 检查是否为纯emoji(Unicode表情) + emoji_pattern = re.compile( + "[" + "\U0001F600-\U0001F64F" # 表情符号 + "\U0001F300-\U0001F5FF" # 符号和象形文字 + "\U0001F680-\U0001F6FF" # 交通和地图符号 + "\U0001F1E0-\U0001F1FF" # 旗帜 + "\U00002702-\U000027B0" + "\U000024C2-\U0001F251" + "]+", flags=re.UNICODE + ) + if emoji_pattern.fullmatch(t): + return True + + # 纯标点符号 + if re.fullmatch(r"[。!?,.!?…·\s]+", t): + return True + return False + + async def _batch_evaluate_importance_with_llm( + self, + messages: List[ConversationMessage], + context: str = "" + ) -> Dict[int, int]: + """使用LLM批量评估消息的重要性(语义层面)。 + + Args: + messages: 消息列表 + context: 对话上下文(可选) + + Returns: + 消息索引到重要性分数(0-10)的映射 + """ + if not self.llm_client or not messages: + return {} + + # 构建批量评估的提示词 + msg_list = [] + for idx, msg in enumerate(messages): + msg_list.append(f"{idx}. {msg.msg}") + + msg_text = "\n".join(msg_list) + + prompt = f"""请评估以下消息的重要性,给每条消息打分(0-10分): +- 0-2分:无意义的寒暄、口头禅、纯表情 +- 3-5分:一般性对话,有一定信息量但不关键 +- 6-8分:包含重要信息(时间、地点、人物、事件等) +- 9-10分:关键决策、承诺、重要数据 + +对话上下文: +{context if context else "无"} + +待评估的消息: +{msg_text} + +请以JSON格式返回,格式为: +{{ + "importance_scores": {{ + "0": 分数, + "1": 分数, + ... + }} +}} +""" + + try: + messages_for_llm = [ + {"role": "system", "content": "你是一个专业的对话分析助手,擅长评估消息的重要性。"}, + {"role": "user", "content": prompt} + ] + + response = await self.llm_client.response_structured( + messages_for_llm, + MessageImportanceResponse + ) + + # 转换字符串键为整数键 + return {int(k): v for k, v in response.importance_scores.items()} + except Exception as e: + self._log(f"[剪枝-LLM] 批量重要性评估失败: {str(e)[:100]}") + return {} + + def _identify_qa_pairs(self, messages: List[ConversationMessage]) -> List[QAPair]: + """识别对话中的问答对,用于保护问答结构的完整性。 + + 改进版:使用场景特定的问句关键词,并排除寒暄类问句 + + Args: + messages: 消息列表 + + Returns: + 问答对列表 + """ + qa_pairs = [] + + # 寒暄类问句,不应该被保护(这些不是真正的问答) + greeting_questions = { + "在吗", "在不在", "你好吗", "怎么样", "好吗", + "有空吗", "忙吗", "睡了吗", "起床了吗" + } + + for i in range(len(messages) - 1): + current_msg = messages[i].msg.strip() + next_msg = messages[i + 1].msg.strip() + + # 排除寒暄类问句 + if current_msg in greeting_questions: + continue + + # 使用场景特定的问句关键词,但要求更严格 + is_question = False + + # 1. 以问号结尾 + if current_msg.endswith("?") or current_msg.endswith("?"): + is_question = True + # 2. 包含实质性问句关键词(排除"吗"这种太宽泛的) + elif any(word in current_msg for word in ["什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时"]): + is_question = True + + if is_question and next_msg: + # 检查下一条消息是否像答案(不是另一个问句,也不是寒暄) + is_answer = not (next_msg.endswith("?") or next_msg.endswith("?")) + + # 排除寒暄类回复 + greeting_answers = {"你好", "您好", "在呢", "在的", "嗯", "哦", "好的"} + if next_msg in greeting_answers: + is_answer = False + + if is_answer: + qa_pairs.append(QAPair( + question_idx=i, + answer_idx=i + 1, + confidence=0.8 # 基于规则的置信度 + )) + + return qa_pairs + + def _get_protected_indices( + self, + messages: List[ConversationMessage], + qa_pairs: List[QAPair], + window_size: int = 2 + ) -> Set[int]: + """获取需要保护的消息索引集合(问答对+上下文窗口)。 + + Args: + messages: 消息列表 + qa_pairs: 问答对列表 + window_size: 上下文窗口大小(前后各保留几条消息) + + Returns: + 需要保护的消息索引集合 + """ + protected = set() + + for qa_pair in qa_pairs: + # 保护问答对本身 + protected.add(qa_pair.question_idx) + protected.add(qa_pair.answer_idx) + + # 保护上下文窗口 + for offset in range(-window_size, window_size + 1): + q_idx = qa_pair.question_idx + offset + a_idx = qa_pair.answer_idx + offset + + if 0 <= q_idx < len(messages): + protected.add(q_idx) + if 0 <= a_idx < len(messages): + protected.add(a_idx) + + return protected async def _extract_dialog_important(self, dialog_text: str) -> DialogExtractionResponse: """对话级一次性抽取:从整段对话中提取重要信息并判定相关性。 - - 仅使用 LLM 结构化输出; + 改进版: + - LRU缓存管理 + - 重试机制 + - 降级策略 """ # 缓存命中则直接返回(场景+内容作为键) cache_key = f"{self.config.pruning_scene}:" + hashlib.sha1(dialog_text.encode("utf-8")).hexdigest() + + # LRU缓存:如果命中,移到末尾(最近使用) if cache_key in self._dialog_extract_cache: + self._dialog_extract_cache.move_to_end(cache_key) return self._dialog_extract_cache[cache_key] - rendered = self.template.render(pruning_scene=self.config.pruning_scene, dialog_text=dialog_text) - log_template_rendering("extracat_Pruning.jinja2", {"pruning_scene": self.config.pruning_scene}) + # LRU缓存大小限制:超过限制时删除最旧的条目 + if len(self._dialog_extract_cache) >= self._cache_max_size: + # 删除最旧的条目(OrderedDict的第一个) + oldest_key = next(iter(self._dialog_extract_cache)) + del self._dialog_extract_cache[oldest_key] + self._log(f"[剪枝-缓存] LRU缓存已满,删除最旧条目") + + rendered = self.template.render( + pruning_scene=self.config.pruning_scene, + is_builtin_scene=self._is_builtin_scene, + ontology_classes=self._ontology_classes, + dialog_text=dialog_text, + language=self.language + ) + log_template_rendering("extracat_Pruning.jinja2", { + "pruning_scene": self.config.pruning_scene, + "is_builtin_scene": self._is_builtin_scene, + "ontology_classes_count": len(self._ontology_classes), + "language": self.language + }) log_prompt_rendering("pruning-extract", rendered) - # 强制使用 LLM;移除正则回退 + # 强制使用 LLM if not self.llm_client: raise RuntimeError("llm_client 未配置;请配置 LLM 以进行结构化抽取。") @@ -153,12 +453,32 @@ class SemanticPruner: {"role": "system", "content": "你是一个严谨的场景抽取助手,只输出严格 JSON。"}, {"role": "user", "content": rendered}, ] - try: - ex = await self.llm_client.response_structured(messages, DialogExtractionResponse) - self._dialog_extract_cache[cache_key] = ex - return ex - except Exception as e: - raise RuntimeError("LLM 结构化抽取失败;请检查 LLM 配置或重试。") from e + + # 重试机制 + max_retries = 3 + for attempt in range(max_retries): + try: + ex = await self.llm_client.response_structured(messages, DialogExtractionResponse) + self._dialog_extract_cache[cache_key] = ex + return ex + except Exception as e: + if attempt < max_retries - 1: + self._log(f"[剪枝-LLM] 第 {attempt + 1} 次尝试失败,重试中... 错误: {str(e)[:100]}") + await asyncio.sleep(0.5 * (attempt + 1)) # 指数退避 + continue + else: + # 降级策略:标记为相关,避免误删 + self._log(f"[剪枝-LLM] LLM 调用失败 {max_retries} 次,使用降级策略(标记为相关)") + fallback_response = DialogExtractionResponse( + is_related=True, + times=[], + ids=[], + amounts=[], + contacts=[], + addresses=[], + keywords=[] + ) + return fallback_response def _msg_matches_tokens(self, message: ConversationMessage, tokens: List[str]) -> bool: """判断消息是否包含任意抽取到的重要片段。""" @@ -248,12 +568,14 @@ class SemanticPruner: async def prune_dataset(self, dialogs: List[DialogData]) -> List[DialogData]: """数据集层面:全局消息级剪枝,保留所有对话。 - - 仅在"不相关对话"的范围内执行消息剪枝;相关对话不动。 - - 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留。 - - 删除总量 = 阈值 * 全部不相关可删消息数,按可删容量比例分配;顺序删除。 - - 保证每段对话至少保留1条消息,不会删除整段对话。 + 改进版: + - 消息级独立判断,每条消息根据场景规则独立评估 + - 问答对保护已注释(暂不启用,留作观察) + - 优化删除策略:填充消息 → 不重要消息 → 低分重要消息 + - 只删除"不重要的不相关消息",重要信息(时间、编号等)强制保留 + - 保证每段对话至少保留1条消息,不会删除整段对话 """ - # 如果剪枝功能关闭,直接返回原始数据集。 + # 如果剪枝功能关闭,直接返回原始数据集 if not self.config.pruning_switch: return dialogs @@ -264,179 +586,140 @@ class SemanticPruner: proportion = 0.9 if proportion < 0.0: proportion = 0.0 - evaluated_dialogs = [] # list of dicts: {dialog, is_related} self._log( - f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch}" + f"[剪枝-数据集] 对话总数={len(dialogs)} 场景={self.config.pruning_scene} 删除比例={proportion} 开关={self.config.pruning_switch} 模式=消息级独立判断" ) - # 对话级相关性分类(一次性对整段对话文本进行判断,顺序执行并复用缓存) - evaluated_dialogs = [] - for idx, dd in enumerate(dialogs): - try: - ex = await self._extract_dialog_important(dd.content) - evaluated_dialogs.append({ - "dialog": dd, - "is_related": bool(ex.is_related), - "index": idx, - "extraction": ex - }) - except Exception: - evaluated_dialogs.append({ - "dialog": dd, - "is_related": True, - "index": idx, - "extraction": None - }) - - # 统计相关 / 不相关对话 - not_related_dialogs = [d for d in evaluated_dialogs if not d["is_related"]] - related_dialogs = [d for d in evaluated_dialogs if d["is_related"]] - self._log( - f"[剪枝-数据集] 相关对话数={len(related_dialogs)} 不相关对话数={len(not_related_dialogs)}" - ) - - # 简洁打印第几段对话相关/不相关(索引基于1) - def _fmt_indices(items, cap: int = 10): - inds = [i["index"] + 1 for i in items] - if len(inds) <= cap: - return inds - # 超过上限时只打印前cap个,并标注总数 - return inds[:cap] + ["...", f"共{len(inds)}个"] - - rel_inds = _fmt_indices(related_dialogs) - nrel_inds = _fmt_indices(not_related_dialogs) - self._log(f"[剪枝-数据集] 相关对话:第{rel_inds}段;不相关对话:第{nrel_inds}段") - + result: List[DialogData] = [] - if not_related_dialogs: - # 为每个不相关对话进行一次性抽取,识别重要/不重要(避免逐条 LLM) - per_dialog_info = {} - total_unrelated = 0 - total_capacity = 0 - for d in not_related_dialogs: - dd = d["dialog"] - extraction = d.get("extraction") - if extraction is None: - extraction = await self._extract_dialog_important(dd.content) - # 合并所有重要标记 - tokens = extraction.times + extraction.ids + extraction.amounts + extraction.contacts + extraction.addresses + extraction.keywords - msgs = dd.context.msgs - # 分类消息 - imp_unrel_msgs = [m for m in msgs if self._msg_matches_tokens(m, tokens) or self._is_important_message(m)] - unimp_unrel_msgs = [m for m in msgs if m not in imp_unrel_msgs] - # 重要消息按重要性排序 - imp_sorted_ids = [id(m) for m in sorted(imp_unrel_msgs, key=lambda m: self._importance_score(m))] - info = { - "dialog": dd, - "total_msgs": len(msgs), - "unrelated_count": len(msgs), - "imp_ids_sorted": imp_sorted_ids, - "unimp_ids": [id(m) for m in unimp_unrel_msgs], - } - per_dialog_info[d["index"]] = info - total_unrelated += info["unrelated_count"] - # 全局删除配额:比例作用于全部不相关消息(重要+不重要) - global_delete = int(total_unrelated * proportion) - if proportion > 0 and total_unrelated > 0 and global_delete == 0: - global_delete = 1 - # 每段的最大可删容量:不重要全部 + 重要最多删除 floor(len(重要)*比例),且至少保留1条消息 - capacities = [] - for d in not_related_dialogs: - idx = d["index"] - info = per_dialog_info[idx] - # 统计重要数量 - imp_count = len(info["imp_ids_sorted"]) - unimp_count = len(info["unimp_ids"]) - imp_cap = int(imp_count * proportion) - cap = min(unimp_count + imp_cap, max(0, info["total_msgs"] - 1)) - capacities.append(cap) - total_capacity = sum(capacities) - if global_delete > total_capacity: - print(f"[剪枝-数据集] 不相关消息总数={total_unrelated},目标删除={global_delete},最大可删={total_capacity}(重要消息按比例保留)。将按最大可删执行。") - global_delete = total_capacity - - # 配额分配:按不相关消息占比分配到各对话,但不超过各自容量 - alloc = [] - for i, d in enumerate(not_related_dialogs): - idx = d["index"] - info = per_dialog_info[idx] - share = int(global_delete * (info["unrelated_count"] / total_unrelated)) if total_unrelated > 0 else 0 - alloc.append(min(share, capacities[i])) - allocated = sum(alloc) - rem = global_delete - allocated - turn = 0 - while rem > 0 and turn < 100000: - progressed = False - for i in range(len(not_related_dialogs)): - if rem <= 0: - break - if alloc[i] < capacities[i]: - alloc[i] += 1 - rem -= 1 - progressed = True - if not progressed: - break - turn += 1 - - # 应用删除:相关对话不动;不相关按分配先删不重要,再删重要(低分优先) - total_deleted_confirm = 0 - for d in evaluated_dialogs: - dd = d["dialog"] - msgs = dd.context.msgs - original = len(msgs) - if d["is_related"]: - result.append(dd) - continue - idx_in_unrel = next((k for k, x in enumerate(not_related_dialogs) if x["index"] == d["index"]), None) - if idx_in_unrel is None: - result.append(dd) - continue - quota = alloc[idx_in_unrel] - info = per_dialog_info[d["index"]] - # 计算本对话重要最多可删数量 - imp_count = len(info["imp_ids_sorted"]) - imp_del_cap = int(imp_count * proportion) - # 先构造顺序删除的"不重要ID集合"(按出现顺序前 quota 条) - unimp_delete_ids = set(info["unimp_ids"][:min(quota, len(info["unimp_ids"]))]) - del_unimp = min(quota, len(unimp_delete_ids)) - rem_quota = quota - del_unimp - # 再从重要里选低分优先的删除ID(不超过 imp_del_cap) - imp_delete_ids = set(info["imp_ids_sorted"][:min(rem_quota, imp_del_cap)]) - deleted_here = 0 - actual_unimp_deleted = 0 - actual_imp_deleted = 0 - kept = [] - for m in msgs: - mid = id(m) - if mid in unimp_delete_ids and actual_unimp_deleted < del_unimp: - actual_unimp_deleted += 1 - deleted_here += 1 - continue - if mid in imp_delete_ids and actual_imp_deleted < len(imp_delete_ids): - actual_imp_deleted += 1 - deleted_here += 1 - continue - kept.append(m) - if not kept and msgs: - kept = [msgs[0]] - dd.context.msgs = kept - total_deleted_confirm += deleted_here - self._log( - f"[剪枝-对话] 对话 {d['index']+1} 总消息={original} 分配删除={quota} 实删={deleted_here} 保留={len(kept)}" - ) - result.append(dd) - self._log(f"[剪枝-数据集] 全局消息级顺序剪枝完成,总删除 {total_deleted_confirm} 条(不相关消息,重要按比例保留)。") - else: - # 全部相关:不执行剪枝 - result = [d["dialog"] for d in evaluated_dialogs] + total_original_msgs = 0 + total_deleted_msgs = 0 + + for d_idx, dd in enumerate(dialogs): + msgs = dd.context.msgs + original_count = len(msgs) + total_original_msgs += original_count + + # ========== 问答对保护(已注释,暂不启用,留作观察) ========== + # qa_pairs = self._identify_qa_pairs(msgs) + # protected_indices = self._get_protected_indices(msgs, qa_pairs, window_size=0) + # ======================================================== + + # 消息级分类:每条消息独立判断 + important_msgs = [] # 重要消息(保留) + unimportant_msgs = [] # 不重要消息(可删除) + filler_msgs = [] # 填充消息(优先删除) + + # 判断是否需要详细日志(仅对前N条消息记录) + should_log_details = self._detailed_prune_logging and original_count <= self._max_debug_msgs_per_dialog + if self._detailed_prune_logging and original_count > self._max_debug_msgs_per_dialog: + self._log(f" 对话[{d_idx}]消息数={original_count},仅采样前{self._max_debug_msgs_per_dialog}条进行详细日志") + + for idx, m in enumerate(msgs): + msg_text = m.msg.strip() + + # ========== 问答对保护判断(已注释) ========== + # if idx in protected_indices: + # important_msgs.append((idx, m)) + # self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(问答对保护)") + # ========================================== + + # 填充消息(寒暄、表情等) + if self._is_filler_message(m): + filler_msgs.append((idx, m)) + if should_log_details or idx < self._max_debug_msgs_per_dialog: + self._log(f" [{idx}] '{msg_text[:30]}...' → 填充") + # 重要信息(学号、成绩、时间、金额等) + elif self._is_important_message(m): + important_msgs.append((idx, m)) + if should_log_details or idx < self._max_debug_msgs_per_dialog: + self._log(f" [{idx}] '{msg_text[:30]}...' → 重要(场景规则)") + # 其他消息 + else: + unimportant_msgs.append((idx, m)) + if should_log_details or idx < self._max_debug_msgs_per_dialog: + self._log(f" [{idx}] '{msg_text[:30]}...' → 不重要") + + # 计算删除配额 + delete_target = int(original_count * proportion) + if proportion > 0 and original_count > 0 and delete_target == 0: + delete_target = 1 + + # 确保至少保留1条消息 + max_deletable = max(0, original_count - 1) + delete_target = min(delete_target, max_deletable) + + # 删除策略:优先删除填充消息,再删除不重要消息 + to_delete_indices = set() + deleted_details = [] # 记录删除的消息详情 + + # 第一步:删除填充消息 + filler_to_delete = min(len(filler_msgs), delete_target) + for i in range(filler_to_delete): + idx, msg = filler_msgs[i] + to_delete_indices.add(idx) + deleted_details.append(f"[{idx}] 填充: '{msg.msg[:50]}'") + + # 第二步:如果还需要删除,删除不重要消息 + remaining_quota = delete_target - len(to_delete_indices) + if remaining_quota > 0: + unimp_to_delete = min(len(unimportant_msgs), remaining_quota) + for i in range(unimp_to_delete): + idx, msg = unimportant_msgs[i] + to_delete_indices.add(idx) + deleted_details.append(f"[{idx}] 不重要: '{msg.msg[:50]}'") + + # 第三步:如果还需要删除,按重要性分数删除重要消息 + remaining_quota = delete_target - len(to_delete_indices) + if remaining_quota > 0 and important_msgs: + # 按重要性分数排序(分数低的优先删除) + imp_sorted = sorted(important_msgs, key=lambda x: self._importance_score(x[1])) + imp_to_delete = min(len(imp_sorted), remaining_quota) + for i in range(imp_to_delete): + idx, msg = imp_sorted[i] + to_delete_indices.add(idx) + score = self._importance_score(msg) + deleted_details.append(f"[{idx}] 重要(分数{score}): '{msg.msg[:50]}'") + + # 执行删除 + kept_msgs = [] + for idx, m in enumerate(msgs): + if idx not in to_delete_indices: + kept_msgs.append(m) + + # 确保至少保留1条 + if not kept_msgs and msgs: + kept_msgs = [msgs[0]] + + dd.context.msgs = kept_msgs + deleted_count = original_count - len(kept_msgs) + total_deleted_msgs += deleted_count + + # 输出删除详情 + if deleted_details: + self._log(f"[剪枝-删除详情] 对话 {d_idx+1} 删除了以下消息:") + for detail in deleted_details: + self._log(f" {detail}") + + # ========== 问答对统计(已注释) ========== + # qa_info = f",问答对={len(qa_pairs)}" if qa_pairs else "" + # ======================================== + + self._log( + f"[剪枝-对话] 对话 {d_idx+1} 总消息={original_count} " + f"(重要={len(important_msgs)} 不重要={len(unimportant_msgs)} 填充={len(filler_msgs)}) " + f"删除={deleted_count} 保留={len(kept_msgs)}" + ) + + result.append(dd) + self._log(f"[剪枝-数据集] 剩余对话数={len(result)}") - # 将本次剪枝阶段的终端输出保存为 JSON 文件(仅在剪枝器内部完成) + # 保存日志 try: from app.core.config import settings settings.ensure_memory_output_dir() log_output_path = settings.get_memory_output_path("pruned_terminal.json") - # 去除日志前缀标签(如 [剪枝-数据集]、[剪枝-对话])后再解析为结构化字段保存 sanitized_logs = [self._sanitize_log_line(l) for l in self.run_logs] payload = self._parse_logs_to_structured(sanitized_logs) with open(log_output_path, "w", encoding="utf-8") as f: @@ -448,6 +731,7 @@ class SemanticPruner: if not result: print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断") return dialogs + return result def _log(self, msg: str) -> None: diff --git a/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/scene_config.py b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/scene_config.py new file mode 100644 index 00000000..ed9592af --- /dev/null +++ b/api/app/core/memory/storage_services/extraction_engine/data_preprocessing/scene_config.py @@ -0,0 +1,326 @@ +""" +场景特定配置 - 为不同场景提供定制化的剪枝规则 + +功能: +- 场景特定的重要信息识别模式 +- 场景特定的重要性评分权重 +- 场景特定的填充词库 +- 场景特定的问答对识别规则 +""" + +from typing import Dict, List, Set, Tuple +from dataclasses import dataclass, field + + +@dataclass +class ScenePatterns: + """场景特定的识别模式""" + + # 重要信息的正则模式(优先级从高到低) + high_priority_patterns: List[Tuple[str, int]] = field(default_factory=list) # (pattern, weight) + medium_priority_patterns: List[Tuple[str, int]] = field(default_factory=list) + low_priority_patterns: List[Tuple[str, int]] = field(default_factory=list) + + # 填充词库(无意义对话) + filler_phrases: Set[str] = field(default_factory=set) + + # 问句关键词(用于识别问答对) + question_keywords: Set[str] = field(default_factory=set) + + # 决策性/承诺性关键词 + decision_keywords: Set[str] = field(default_factory=set) + + +class SceneConfigRegistry: + """场景配置注册表 - 管理所有场景的特定配置""" + + # 基础通用模式(所有场景共享) + BASE_HIGH_PRIORITY = [ + (r"订单号|工单|申请号|编号|ID|账号|账户", 5), + (r"金额|费用|价格|¥|¥|\d+元", 5), + (r"\d{11}", 4), # 手机号 + (r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", 4), # 邮箱 + ] + + BASE_MEDIUM_PRIORITY = [ + (r"\d{4}-\d{1,2}-\d{1,2}", 3), # 日期 + (r"\d{4}年\d{1,2}月\d{1,2}日", 3), + (r"电话|手机号|微信|QQ|联系方式", 3), + (r"地址|地点|位置", 2), + (r"时间|日期|有效期|截止", 2), + (r"今天|明天|后天|昨天|前天", 3), # 相对时间(提高权重) + (r"下周|下月|下年|上周|上月|上年|本周|本月|本年", 3), + (r"今年|去年|明年", 3), + ] + + BASE_LOW_PRIORITY = [ + (r"\d{1,2}:\d{2}", 2), # 时间点 HH:MM + (r"\d{1,2}点\d{0,2}分?", 2), # 时间点 X点Y分 或 X点 + (r"上午|下午|中午|晚上|早上|傍晚|凌晨", 2), # 时段(提高权重并扩充) + (r"AM|PM|am|pm", 1), + ] + + BASE_FILLERS = { + # 基础寒暄 + "你好", "您好", "在吗", "在的", "在呢", "嗯", "嗯嗯", "哦", "哦哦", + "好的", "好", "行", "可以", "不可以", "谢谢", "多谢", "感谢", + "拜拜", "再见", "88", "拜", "回见", + # 口头禅 + "哈哈", "呵呵", "哈哈哈", "嘿嘿", "嘻嘻", "hiahia", + "额", "呃", "啊", "诶", "唉", "哎", "嗯哼", + # 确认词 + "是的", "对", "对的", "没错", "嗯嗯", "好嘞", "收到", "明白", "了解", "知道了", + # 标点和符号 + "。。。", "...", "???", "???", "!!!", "!!!", + # 表情符号 + "[微笑]", "[呲牙]", "[发呆]", "[得意]", "[流泪]", "[害羞]", "[闭嘴]", + "[睡]", "[大哭]", "[尴尬]", "[发怒]", "[调皮]", "[龇牙]", "[惊讶]", + "[难过]", "[酷]", "[冷汗]", "[抓狂]", "[吐]", "[偷笑]", "[可爱]", + "[白眼]", "[傲慢]", "[饥饿]", "[困]", "[惊恐]", "[流汗]", "[憨笑]", + # 网络用语 + "hhh", "hhhh", "2333", "666", "gg", "ok", "OK", "okok", + "emmm", "emm", "em", "mmp", "wtf", "omg", + } + + BASE_QUESTION_KEYWORDS = { + "什么", "为什么", "怎么", "如何", "哪里", "哪个", "谁", "多少", "几点", "何时", "吗" + } + + BASE_DECISION_KEYWORDS = { + "必须", "一定", "务必", "需要", "要求", "规定", "应该", + "承诺", "保证", "确保", "负责", "同意", "答应" + } + + @classmethod + def get_education_config(cls) -> ScenePatterns: + """教育场景配置""" + return ScenePatterns( + high_priority_patterns=cls.BASE_HIGH_PRIORITY + [ + # 成绩相关(最高优先级) + (r"成绩|分数|得分|满分|及格|不及格", 6), + (r"GPA|绩点|学分|平均分", 6), + (r"\d+分|\d+\.?\d*分", 5), # 具体分数 + (r"排名|名次|第.{1,3}名", 5), # 支持"第三名"、"第1名"等 + + # 学籍信息 + (r"学号|学生证|教师工号|工号", 5), + (r"班级|年级|专业|院系", 4), + + # 课程相关 + (r"课程|科目|学科|必修|选修", 4), + (r"教材|课本|教科书|参考书", 4), + (r"章节|第.{1,3}章|第.{1,3}节", 3), # 支持"第三章"、"第1章"等 + + # 学科内容(新增) + (r"微积分|导数|积分|函数|极限|微分", 4), + (r"代数|几何|三角|概率|统计", 4), + (r"物理|化学|生物|历史|地理", 4), + (r"英语|语文|数学|政治|哲学", 4), + (r"定义|定理|公式|概念|原理|法则", 3), + (r"例题|解题|证明|推导|计算", 3), + ], + medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [ + # 教学活动 + (r"作业|练习|习题|题目", 3), + (r"考试|测验|测试|考核|期中|期末", 3), + (r"上课|下课|课堂|讲课", 2), + (r"提问|回答|发言|讨论", 2), + (r"问一下|请教|咨询|询问", 2), # 新增:问询相关 + (r"理解|明白|懂|掌握|学会", 2), # 新增:学习状态 + + # 时间安排 + (r"课表|课程表|时间表", 3), + (r"第.{1,3}节课|第.{1,3}周", 2), # 支持"第三节课"、"第1周"等 + ], + low_priority_patterns=cls.BASE_LOW_PRIORITY + [ + (r"老师|教师|同学|学生", 1), + (r"教室|实验室|图书馆", 1), + ], + filler_phrases=cls.BASE_FILLERS | { + # 教育场景特有填充词(移除了"明白了"、"懂了"、"不懂"等,这些在教育场景中有意义) + "老师好", "同学们好", "上课", "下课", "起立", "坐下", + "举手", "请坐", "很好", "不错", "继续", + "下一个", "下一题", "下一位", "还有吗", "还有问题吗", + }, + question_keywords=cls.BASE_QUESTION_KEYWORDS | { + "为啥", "咋", "咋办", "怎样", "如何做", + "能不能", "可不可以", "行不行", "对不对", "是不是", + }, + decision_keywords=cls.BASE_DECISION_KEYWORDS | { + "必考", "重点", "考点", "难点", "关键", + "记住", "背诵", "掌握", "理解", "复习", + } + ) + + @classmethod + def get_online_service_config(cls) -> ScenePatterns: + """在线服务场景配置""" + return ScenePatterns( + high_priority_patterns=cls.BASE_HIGH_PRIORITY + [ + # 工单相关(最高优先级) + (r"工单号|工单编号|ticket|TK\d+", 6), + (r"工单状态|处理中|已解决|已关闭|待处理", 5), + (r"优先级|紧急|高优先级|P0|P1|P2", 5), + + # 产品信息 + (r"产品型号|型号|SKU|产品编号", 5), + (r"序列号|SN|设备号", 5), + (r"版本号|软件版本|固件版本", 4), + + # 问题描述 + (r"故障|错误|异常|bug|问题", 4), + (r"错误代码|故障代码|error code", 5), + (r"无法|不能|失败|报错", 3), + ], + medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [ + # 服务相关 + (r"退款|退货|换货|补发", 4), + (r"发票|收据|凭证", 3), + (r"物流|快递|运单号", 3), + (r"保修|质保|售后", 3), + + # 时效相关 + (r"SLA|响应时间|处理时长", 4), + (r"超时|延迟|等待", 2), + ], + low_priority_patterns=cls.BASE_LOW_PRIORITY + [ + (r"客服|工程师|技术支持", 1), + (r"用户|客户|会员", 1), + ], + filler_phrases=cls.BASE_FILLERS | { + # 在线服务特有填充词 + "您好", "请问", "请稍等", "稍等", "马上", "立即", + "正在查询", "正在处理", "正在为您", "帮您查一下", + "还有其他问题吗", "还需要什么帮助", "很高兴为您服务", + "感谢您的耐心等待", "抱歉让您久等了", + "已记录", "已反馈", "已转接", "已升级", + "祝您生活愉快", "再见", "欢迎下次咨询", + }, + question_keywords=cls.BASE_QUESTION_KEYWORDS | { + "能否", "可否", "是否", "有没有", "能不能", + "怎么办", "如何处理", "怎么解决", + }, + decision_keywords=cls.BASE_DECISION_KEYWORDS | { + "立即处理", "马上解决", "尽快", "优先", + "升级", "转接", "派单", "跟进", + "补偿", "赔偿", "退款", "换货", + } + ) + + @classmethod + def get_outbound_config(cls) -> ScenePatterns: + """外呼场景配置""" + return ScenePatterns( + high_priority_patterns=cls.BASE_HIGH_PRIORITY + [ + # 意向相关(最高优先级) + (r"意向|意愿|兴趣|感兴趣", 6), + (r"A类|B类|C类|D类|高意向|低意向", 6), + (r"成交|签约|下单|购买|确认", 6), + + # 联系信息(外呼场景中更重要) + (r"预约|约定|安排|确定时间", 5), + (r"下次联系|回访|跟进", 5), + (r"方便|有空|可以|时间", 4), + + # 通话状态 + (r"接通|未接通|占线|关机|停机", 4), + (r"通话时长|通话时间", 3), + ], + medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY + [ + # 客户信息 + (r"姓名|称呼|先生|女士", 3), + (r"公司|单位|职位|职务", 3), + (r"需求|要求|期望", 3), + + # 跟进状态 + (r"跟进状态|进展|进度", 3), + (r"已联系|待联系|联系中", 2), + (r"拒绝|不感兴趣|考虑|再说", 3), + ], + low_priority_patterns=cls.BASE_LOW_PRIORITY + [ + (r"销售|客户经理|业务员", 1), + (r"产品|服务|方案", 1), + ], + filler_phrases=cls.BASE_FILLERS | { + # 外呼场景特有填充词 + "您好", "喂", "hello", "打扰了", "不好意思", + "方便接电话吗", "现在方便吗", "占用您一点时间", + "我是", "我们是", "我们公司", "我们这边", + "了解一下", "介绍一下", "简单说一下", + "考虑考虑", "想一想", "再说", "再看看", + "不需要", "不感兴趣", "没兴趣", "不用了", + "好的", "行", "可以", "没问题", "那就这样", + "再联系", "回头聊", "有需要再说", + }, + question_keywords=cls.BASE_QUESTION_KEYWORDS | { + "有没有", "需不需要", "要不要", "考虑不考虑", + "了解吗", "知道吗", "听说过吗", + "方便吗", "有空吗", "在吗", + }, + decision_keywords=cls.BASE_DECISION_KEYWORDS | { + "确定", "决定", "选择", "购买", "下单", + "预约", "安排", "约定", "确认", + "跟进", "回访", "联系", "沟通", + } + ) + + @classmethod + def get_config(cls, scene: str, fallback_to_generic: bool = True) -> ScenePatterns: + """根据场景名称获取配置 + + Args: + scene: 场景名称 ('education', 'online_service', 'outbound' 或其他) + fallback_to_generic: 如果场景不存在,是否降级到通用配置 + + Returns: + 对应场景的配置,如果场景不存在: + - fallback_to_generic=True: 返回通用配置(仅基础规则) + - fallback_to_generic=False: 抛出异常 + """ + scene_map = { + 'education': cls.get_education_config, + 'online_service': cls.get_online_service_config, + 'outbound': cls.get_outbound_config, + } + + if scene in scene_map: + return scene_map[scene]() + + if fallback_to_generic: + # 返回通用配置(仅包含基础规则,不包含场景特定规则) + return cls.get_generic_config() + else: + raise ValueError(f"不支持的场景: {scene},支持的场景: {list(scene_map.keys())}") + + @classmethod + def get_generic_config(cls) -> ScenePatterns: + """通用场景配置 - 仅包含基础规则,适用于未定义的场景 + + 这是一个保守的配置,只使用最通用的规则,避免误删重要信息 + """ + return ScenePatterns( + high_priority_patterns=cls.BASE_HIGH_PRIORITY, + medium_priority_patterns=cls.BASE_MEDIUM_PRIORITY, + low_priority_patterns=cls.BASE_LOW_PRIORITY, + filler_phrases=cls.BASE_FILLERS, + question_keywords=cls.BASE_QUESTION_KEYWORDS, + decision_keywords=cls.BASE_DECISION_KEYWORDS + ) + + @classmethod + def get_all_scenes(cls) -> List[str]: + """获取所有预定义场景的列表""" + return ['education', 'online_service', 'outbound'] + + @classmethod + def is_scene_supported(cls, scene: str) -> bool: + """检查场景是否有专门的配置支持 + + Args: + scene: 场景名称 + + Returns: + True: 有专门配置 + False: 将使用通用配置 + """ + return scene in cls.get_all_scenes() diff --git a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py index a47497da..1242e4e6 100644 --- a/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py +++ b/api/app/core/memory/storage_services/extraction_engine/extraction_orchestrator.py @@ -1932,17 +1932,17 @@ def preprocess_data( Returns: 经过清洗转换后的 DialogData 列表 """ - print("\n=== 数据预处理 ===") + logger.debug("=== 数据预处理 ===") from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_preprocessor import ( DataPreprocessor, ) preprocessor = DataPreprocessor() try: cleaned_data = preprocessor.preprocess(input_path=input_path, output_path=output_path, skip_cleaning=skip_cleaning, indices=indices) - print(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据") + logger.debug(f"数据预处理完成!共处理了 {len(cleaned_data)} 条对话数据") return cleaned_data except Exception as e: - print(f"数据预处理过程中出现错误: {e}") + logger.error(f"数据预处理过程中出现错误: {e}") raise @@ -1961,7 +1961,7 @@ async def get_chunked_dialogs_from_preprocessed( Returns: 带 chunks 的 DialogData 列表 """ - print(f"\n=== 批量对话分块处理 (使用 {chunker_strategy}) ===") + logger.debug(f"=== 批量对话分块处理 (使用 {chunker_strategy}) ===") if not data: raise ValueError("预处理数据为空,无法进行分块") @@ -1988,6 +1988,7 @@ async def get_chunked_dialogs_with_preprocessing( input_data_path: Optional[str] = None, llm_client: Optional[Any] = None, skip_cleaning: bool = True, + pruning_config: Optional[Dict] = None, ) -> List[DialogData]: """包含数据预处理步骤的完整分块流程 @@ -2000,11 +2001,12 @@ async def get_chunked_dialogs_with_preprocessing( input_data_path: 输入数据路径 llm_client: LLM 客户端 skip_cleaning: 是否跳过数据清洗步骤(默认False) + pruning_config: 剪枝配置字典,包含 pruning_switch, pruning_scene, pruning_threshold Returns: 带 chunks 的 DialogData 列表 """ - print("\n=== 完整数据处理流程(包含预处理)===") + logger.debug("=== 完整数据处理流程(包含预处理)===") if input_data_path is None: input_data_path = os.path.join( @@ -2030,7 +2032,19 @@ async def get_chunked_dialogs_with_preprocessing( from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import ( SemanticPruner, ) - pruner = SemanticPruner(llm_client=llm_client) + from app.core.memory.models.config_models import PruningConfig + + # 构建剪枝配置 + if pruning_config: + # 使用传入的配置 + config = PruningConfig(**pruning_config) + logger.debug(f"[剪枝] 使用传入配置: switch={config.pruning_switch}, scene={config.pruning_scene}, threshold={config.pruning_threshold}") + else: + # 使用默认配置(关闭剪枝) + config = None + logger.debug("[剪枝] 未提供配置,使用默认配置(剪枝关闭)") + + pruner = SemanticPruner(config=config, llm_client=llm_client) # 记录单对话场景下剪枝前的消息数量 single_dialog_original_msgs = None @@ -2043,12 +2057,12 @@ async def get_chunked_dialogs_with_preprocessing( if len(preprocessed_data) == 1 and single_dialog_original_msgs is not None: remaining_msgs = len(preprocessed_data[0].context.msgs) if preprocessed_data[0].context else 0 deleted_msgs = max(0, single_dialog_original_msgs - remaining_msgs) - print( + logger.debug( f"语义剪枝完成!剩余 1 条对话!原始消息数:{single_dialog_original_msgs}," f"保留消息数:{remaining_msgs},删除 {deleted_msgs} 条。" ) else: - print(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话") + logger.debug(f"语义剪枝完成!剩余 {len(preprocessed_data)} 条对话") # 保存剪枝后的数据 try: @@ -2059,9 +2073,9 @@ async def get_chunked_dialogs_with_preprocessing( dp = DataPreprocessor(output_file_path=pruned_output_path) dp.save_data(preprocessed_data, output_path=pruned_output_path) except Exception as se: - print(f"保存剪枝结果失败:{se}") + logger.error(f"保存剪枝结果失败:{se}") except Exception as e: - print(f"语义剪枝过程中出现错误,跳过剪枝: {e}") + logger.error(f"语义剪枝过程中出现错误,跳过剪枝: {e}") # 步骤3: 对话分块 return await get_chunked_dialogs_from_preprocessed( diff --git a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py index 40e98507..bbbf1c51 100644 --- a/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py +++ b/api/app/core/memory/storage_services/extraction_engine/knowledge_extraction/chunk_extraction.py @@ -1,5 +1,7 @@ import os -from typing import Optional +from typing import Optional, List, Any +from enum import Enum +from pathlib import Path from app.core.logging_config import get_memory_logger from app.core.memory.models.message_models import DialogData, Chunk @@ -10,6 +12,20 @@ from app.core.memory.utils.config.config_utils import get_chunker_config logger = get_memory_logger(__name__) +class ChunkerStrategy(Enum): + """Supported chunking strategies.""" + RECURSIVE = "RecursiveChunker" + SEMANTIC = "SemanticChunker" + LATE = "LateChunker" + NEURAL = "NeuralChunker" + LLM = "LLMChunker" + + @classmethod + def get_valid_strategies(cls) -> List[str]: + """Get list of valid strategy names.""" + return [strategy.value for strategy in cls] + + class DialogueChunker: """A class that processes dialogues and fills them with chunks based on a specified strategy. @@ -17,23 +33,51 @@ class DialogueChunker: of different chunking strategies to dialogue data. """ - def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client=None): + def __init__(self, chunker_strategy: str = "RecursiveChunker", llm_client: Optional[Any] = None): """Initialize the DialogueChunker with a specific chunking strategy. Args: chunker_strategy: The chunking strategy to use (default: RecursiveChunker) - Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker + Options: SemanticChunker, RecursiveChunker, LateChunker, NeuralChunker, LLMChunker + llm_client: LLM client instance (required for LLMChunker strategy) + + Raises: + ValueError: If chunker_strategy is invalid or required parameters are missing """ - self.chunker_strategy = chunker_strategy - chunker_config_dict = get_chunker_config(chunker_strategy) - self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict) + # Validate strategy + valid_strategies = ChunkerStrategy.get_valid_strategies() + if chunker_strategy not in valid_strategies: + raise ValueError( + f"Invalid chunker_strategy: '{chunker_strategy}'. " + f"Must be one of {valid_strategies}" + ) - if self.chunker_config.chunker_strategy == "LLMChunker": - self.chunker_client = ChunkerClient(self.chunker_config, llm_client) - else: - self.chunker_client = ChunkerClient(self.chunker_config) + self.chunker_strategy = chunker_strategy + logger.info(f"Initializing DialogueChunker with strategy: {chunker_strategy}") + + try: + # Load and validate configuration + chunker_config_dict = get_chunker_config(chunker_strategy) + if not chunker_config_dict: + raise ValueError(f"Failed to load configuration for strategy: {chunker_strategy}") + + self.chunker_config = ChunkerConfig.model_validate(chunker_config_dict) + + # Initialize chunker client + if self.chunker_config.chunker_strategy == "LLMChunker": + if not llm_client: + raise ValueError("llm_client is required for LLMChunker strategy") + self.chunker_client = ChunkerClient(self.chunker_config, llm_client) + else: + self.chunker_client = ChunkerClient(self.chunker_config) + + logger.info(f"DialogueChunker initialized successfully with strategy: {chunker_strategy}") + + except Exception as e: + logger.error(f"Failed to initialize DialogueChunker: {e}", exc_info=True) + raise - async def process_dialogue(self, dialogue: DialogData) -> list[Chunk]: + async def process_dialogue(self, dialogue: DialogData) -> List[Chunk]: """Process a dialogue by generating chunks and adding them to the DialogData object. Args: @@ -43,54 +87,125 @@ class DialogueChunker: A list of Chunk objects Raises: - ValueError: If chunking fails or returns empty chunks + ValueError: If dialogue is invalid or chunking fails + Exception: If chunking process encounters an error """ - result_dialogue = await self.chunker_client.generate_chunks(dialogue) - chunks = result_dialogue.chunks - - if not chunks or len(chunks) == 0: + # Validate input + if not dialogue: + raise ValueError("dialogue cannot be None") + + if not dialogue.context or not dialogue.context.msgs: raise ValueError( - f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. " - f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}, " - f"Strategy: {self.chunker_config.chunker_strategy}" + f"Dialogue {dialogue.ref_id} has no messages to chunk. " + f"Context: {dialogue.context is not None}, " + f"Messages: {len(dialogue.context.msgs) if dialogue.context else 0}" ) + + logger.info( + f"Processing dialogue {dialogue.ref_id} with {len(dialogue.context.msgs)} messages " + f"using strategy: {self.chunker_strategy}" + ) + + try: + # Generate chunks + result_dialogue = await self.chunker_client.generate_chunks(dialogue) + chunks = result_dialogue.chunks - return chunks + # Validate results + if not chunks or len(chunks) == 0: + raise ValueError( + f"Chunking failed: No chunks generated for dialogue {dialogue.ref_id}. " + f"Messages: {len(dialogue.context.msgs)}, " + f"Content length: {len(dialogue.content) if dialogue.content else 0}, " + f"Strategy: {self.chunker_config.chunker_strategy}" + ) - def save_chunking_results(self, dialogue: DialogData, output_path: Optional[str] = None) -> str: + logger.info( + f"Successfully generated {len(chunks)} chunks for dialogue {dialogue.ref_id}. " + f"Total characters processed: {len(dialogue.content) if dialogue.content else 0}" + ) + + return chunks + + except ValueError: + # Re-raise validation errors + raise + except Exception as e: + logger.error( + f"Error processing dialogue {dialogue.ref_id} with strategy {self.chunker_strategy}: {e}", + exc_info=True + ) + raise + + def save_chunking_results( + self, + chunks: List[Chunk], + dialogue: DialogData, + output_path: Optional[str] = None, + preview_length: int = 100 + ) -> str: """Save the chunking results to a file and return the output path. Args: - dialogue: The processed DialogData object with chunks - output_path: Optional path to save the output + chunks: List of Chunk objects to save + dialogue: The DialogData object that was processed + output_path: Optional path to save the output (defaults to current directory) + preview_length: Maximum length of content preview (default: 100) Returns: The path where the output was saved + + Raises: + ValueError: If chunks or dialogue is invalid + IOError: If file writing fails """ - if not output_path: - output_path = os.path.join( - os.path.dirname(__file__), "..", "..", - f"chunker_output_{self.chunker_strategy.lower()}.txt" - ) - - output_lines = [ - f"=== Chunking Results ({self.chunker_strategy}) ===", - f"Dialogue ID: {dialogue.ref_id}", - f"Original conversation has {len(dialogue.context.msgs)} messages", - f"Total characters: {len(dialogue.content)}", - f"Generated {len(dialogue.chunks)} chunks:" - ] + # Validate input + if not chunks: + raise ValueError("chunks list cannot be empty") + if not dialogue: + raise ValueError("dialogue cannot be None") - for i, chunk in enumerate(dialogue.chunks): - output_lines.append(f" Chunk {i+1}: {len(chunk.content)} characters") - output_lines.append(f" Content preview: {chunk.content}...") - if chunk.metadata: - output_lines.append(f" Metadata: {chunk.metadata}") + # Generate default output path if not provided + if not output_path: + output_dir = Path(__file__).parent.parent.parent + output_path = str(output_dir / f"chunker_output_{self.chunker_strategy.lower()}.txt") + + logger.info(f"Saving chunking results to: {output_path}") + + try: + # Prepare output content + output_lines = [ + f"=== Chunking Results ({self.chunker_strategy}) ===", + f"Dialogue ID: {dialogue.ref_id}", + f"Original conversation has {len(dialogue.context.msgs) if dialogue.context else 0} messages", + f"Total characters: {len(dialogue.content) if dialogue.content else 0}", + f"Generated {len(chunks)} chunks:", + "" + ] + + for i, chunk in enumerate(chunks, 1): + content_preview = chunk.content[:preview_length] if chunk.content else "" + if len(chunk.content) > preview_length: + content_preview += "..." + + output_lines.append(f" Chunk {i}: {len(chunk.content)} characters") + output_lines.append(f" Content preview: {content_preview}") + if chunk.metadata: + output_lines.append(f" Metadata: {chunk.metadata}") + output_lines.append("") - with open(output_path, "w", encoding="utf-8") as f: - f.write("\n".join(output_lines)) + # Write to file + with open(output_path, "w", encoding="utf-8") as f: + f.write("\n".join(output_lines)) - logger.info(f"Chunking results saved to: {output_path}") - return output_path + logger.info(f"Successfully saved chunking results to: {output_path}") + return output_path + + except IOError as e: + logger.error(f"Failed to write chunking results to {output_path}: {e}", exc_info=True) + raise + except Exception as e: + logger.error(f"Unexpected error saving chunking results: {e}", exc_info=True) + raise diff --git a/api/app/core/memory/utils/ontology/ontology_parser.py b/api/app/core/memory/utils/ontology/ontology_parser.py index a8bd054c..d75a8905 100644 --- a/api/app/core/memory/utils/ontology/ontology_parser.py +++ b/api/app/core/memory/utils/ontology/ontology_parser.py @@ -327,7 +327,7 @@ class MultiOntologyParser: Example: >>> parser = MultiOntologyParser([ - ... "General_purpose_entity.ttl", + ... "app/core/memory/ontology_services/General_purpose_entity.ttl", ... "domain_specific.owl" ... ]) >>> registry = parser.parse_all() diff --git a/api/app/core/memory/utils/prompt/prompt_utils.py b/api/app/core/memory/utils/prompt/prompt_utils.py index 50d31f2a..0cea98f2 100644 --- a/api/app/core/memory/utils/prompt/prompt_utils.py +++ b/api/app/core/memory/utils/prompt/prompt_utils.py @@ -400,7 +400,8 @@ async def render_user_summary_prompt( user_id: str, entities: str, statements: str, - language: str = "zh" + language: str = "zh", + user_display_name: str = None ) -> str: """ Renders the user summary prompt using the user_summary.jinja2 template. @@ -410,16 +411,22 @@ async def render_user_summary_prompt( entities: Core entities with frequency information statements: Representative statement samples language: The language to use for summary generation ("zh" for Chinese, "en" for English) + user_display_name: Display name for the user (e.g., other_name or "该用户"/"the user") Returns: Rendered prompt content as string """ + # 如果没有提供 user_display_name,使用默认值 + if user_display_name is None: + user_display_name = "该用户" if language == "zh" else "the user" + template = prompt_env.get_template("user_summary.jinja2") rendered_prompt = template.render( user_id=user_id, entities=entities, statements=statements, - language=language + language=language, + user_display_name=user_display_name ) # 记录渲染结果到提示日志 @@ -429,7 +436,8 @@ async def render_user_summary_prompt( 'user_id': user_id, 'entities_len': len(entities), 'statements_len': len(statements), - 'language': language + 'language': language, + 'user_display_name': user_display_name }) return rendered_prompt @@ -540,3 +548,20 @@ async def render_ontology_extraction_prompt( }) return rendered_prompt + + +def render_interest_filter_prompt(tag_list: str, language: str = "zh") -> str: + """ + Renders the interest filter prompt using the interest_filter.jinja2 template. + + Args: + tag_list: Comma-separated string of raw tags to filter + language: Output language ("zh" for Chinese, "en" for English) + + Returns: + Rendered prompt content as string + """ + template = prompt_env.get_template("interest_filter.jinja2") + rendered_prompt = template.render(tag_list=tag_list, language=language) + log_prompt_rendering('interest filter', rendered_prompt) + return rendered_prompt diff --git a/api/app/core/memory/utils/prompt/prompts/extracat_Pruning.jinja2 b/api/app/core/memory/utils/prompt/prompts/extracat_Pruning.jinja2 index 8253924b..6b620df9 100644 --- a/api/app/core/memory/utils/prompt/prompts/extracat_Pruning.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/extracat_Pruning.jinja2 @@ -1,6 +1,6 @@ {# 对话级抽取与相关性判定模板(用于剪枝加速) - 输入:pruning_scene, dialog_text + 输入:pruning_scene, is_builtin_scene, ontology_classes, dialog_text, language 输出:严格 JSON(不要包含任何多余文本),字段: - is_related: bool,是否与所选场景相关 - times: [string],从对话中抽取的时间相关文本(日期、时间、时间段、有效期等) @@ -16,7 +16,8 @@ - 仅输出上述键;避免多余解释或字段。 #} -{% set scene_instructions = { +{# ── 内置场景的固定说明 ── #} +{% set builtin_scene_instructions = { 'education': { 'zh': '教育场景:教学、课程、考试、作业、老师/学生互动、学习资源、学校管理等。', 'en': 'Education Scenario: Teaching, courses, exams, homework, teacher/student interaction, learning resources, school management, etc.' @@ -31,16 +32,40 @@ } } %} -{% set scene_key = pruning_scene %} -{% if scene_key not in scene_instructions %} -{% set scene_key = 'education' %} +{# ── 确定最终使用的场景说明 ── #} +{% if is_builtin_scene %} + {# 内置专门场景:使用固定说明 #} + {% set scene_key = pruning_scene %} + {% if scene_key not in builtin_scene_instructions %}{% set scene_key = 'education' %}{% endif %} + {% set instruction = builtin_scene_instructions[scene_key][language] if language in ['zh', 'en'] else builtin_scene_instructions[scene_key]['zh'] %} + {% set custom_types_str = '' %} +{% else %} + {# 自定义场景:使用场景名称 + 本体类型列表构建说明 #} + {% if ontology_classes and ontology_classes | length > 0 %} + {% if language == 'en' %} + {% set custom_types_str = ontology_classes | join(', ') %} + {% set instruction = 'Custom scene "' ~ pruning_scene ~ '": The dialogue is related to this scene if it involves any of the following entity types: ' ~ custom_types_str ~ '.' %} + {% else %} + {% set custom_types_str = ontology_classes | join('、') %} + {% set instruction = '自定义场景「' ~ pruning_scene ~ '」:对话涉及以下任意实体类型时视为相关:' ~ custom_types_str ~ '。' %} + {% endif %} + {% else %} + {# 无本体类型时退化为通用说明 #} + {% if language == 'en' %} + {% set instruction = 'Custom scene "' ~ pruning_scene ~ '": Determine whether the dialogue content is relevant to this scene based on overall context.' %} + {% else %} + {% set instruction = '自定义场景「' ~ pruning_scene ~ '」:根据对话整体内容判断是否与该场景相关。' %} + {% endif %} + {% set custom_types_str = '' %} + {% endif %} {% endif %} -{% set instruction = scene_instructions[scene_key][language] if language in ['zh', 'en'] else scene_instructions[scene_key]['zh'] %} - {% if language == "zh" %} 请在下方对话全文基础上,按该场景进行一次性抽取并判定相关性: 场景说明:{{ instruction }} +{% if not is_builtin_scene and custom_types_str %} +重要提示:只要对话中出现与上述实体类型({{ custom_types_str }})相关的内容,即判定为相关(is_related=true)。 +{% endif %} 对话全文: """ @@ -60,6 +85,9 @@ {% else %} Based on the full dialogue below, perform one-time extraction and relevance determination according to this scenario: Scenario Description: {{ instruction }} +{% if not is_builtin_scene and custom_types_str %} +Important: If the dialogue contains content related to any of the entity types above ({{ custom_types_str }}), mark it as relevant (is_related=true). +{% endif %} Full Dialogue: """ diff --git a/api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 b/api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 new file mode 100644 index 00000000..7957bf1c --- /dev/null +++ b/api/app/core/memory/utils/prompt/prompts/interest_filter.jinja2 @@ -0,0 +1,67 @@ +{% if language == "zh" %} +You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent. + +**Step 1 - Infer the underlying interest from each tag**: +Look at each tag and ask: "What hobby or interest does this tag suggest the user has?" + +Examples of inference: +- '攀岩', '室内攀岩馆', '攀岩者数据仪表盘', '路线解锁地图', '指力', '路线等级', '当日攀岩流畅度' → '攀岩' +- '风光摄影元数据增强器', 'EXIF数据', '.CR2文件', '.NEF文件', '日出拍摄点', '曝光补偿', '光圈', '太阳高度角', '云量预测图层' → '摄影' +- '晨间冥想坚持天数', '身心协同峰值' → '冥想' +- '川味可视化', '川菜' → '烹饪' +- '开源项目命名建议', 'climbviz', '可视化', '力量增长雷达图' → '编程' 或 '数据可视化' +- '吉他', '指弹', '琴谱' → '吉他' +- '跑步', '5公里', '跑鞋' → '跑步' +- '瑜伽垫', '瑜伽课' → '瑜伽' + +**Step 2 - Consolidate and deduplicate**: +- Merge tags that point to the same interest into one representative label +- Use concise, standard hobby names (e.g., '攀岩', '摄影', '编程', '烹饪', '冥想', '吉他', '跑步') +- If multiple tags all point to '攀岩', output '攀岩' only once + +**Step 3 - Filter out non-interest tags**: +Remove tags that do NOT suggest any hobby or interest: +- Generic system/assistant terms (e.g., '助手', '用户', 'AI') +- Pure abstract metrics with no clear hobby link (e.g., '完成时间', '日期', '自我评分') +- Location names with no clear hobby link (e.g., '青城山后山' alone — but if combined with photography context, infer '摄影') + +**Output format**: Return a list of concise interest activity names in Chinese. + +**Example**: +Input: ['攀岩', '攀岩者数据仪表盘', '路线解锁地图', '指力', '风光摄影元数据增强器', 'EXIF数据', '晨间冥想坚持天数', '川味可视化', '可视化', '助手', '完成时间'] +Output: ['攀岩', '摄影', '冥想', '烹饪', '编程'] + +Now process the following tag list and return the inferred interest activities in Chinese: {{ tag_list }} +{% else %} +You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent. + +**Step 1 - Infer the underlying interest from each tag**: +Look at each tag and ask: "What hobby or interest does this tag suggest the user has?" + +Examples of inference: +- 'rock climbing', 'indoor climbing gym', 'climber dashboard', 'route map', 'finger strength' → 'rock climbing' +- 'landscape photography metadata enhancer', 'EXIF data', 'sunrise shooting spot', 'exposure compensation' → 'photography' +- 'morning meditation streak', 'mind-body peak' → 'meditation' +- 'Sichuan cuisine visualization', 'Sichuan food' → 'cooking' +- 'open source project', 'data visualization tool', 'Python' → 'programming' +- 'guitar', 'fingerpicking', 'sheet music' → 'guitar' +- 'running', '5km', 'running shoes' → 'running' + +**Step 2 - Consolidate and deduplicate**: +- Merge tags that point to the same interest into one representative label +- Use concise, standard hobby names (e.g., 'rock climbing', 'photography', 'programming', 'cooking', 'meditation') +- If multiple tags all point to 'rock climbing', output 'rock climbing' only once + +**Step 3 - Filter out non-interest tags**: +Remove tags that do NOT suggest any hobby or interest: +- Generic system/assistant terms (e.g., 'assistant', 'user', 'AI') +- Pure abstract metrics with no clear hobby link (e.g., 'completion time', 'date', 'self-rating') + +**Output format**: Return a list of concise interest activity names in English. + +**Example**: +Input: ['rock climbing', 'climber dashboard', 'route map', 'finger strength', 'landscape photography metadata enhancer', 'EXIF data', 'morning meditation streak', 'Sichuan cuisine visualization', 'visualization', 'assistant', 'completion time'] +Output: ['rock climbing', 'photography', 'meditation', 'cooking', 'programming'] + +Now process the following tag list and return the inferred interest activities in English: {{ tag_list }} +{% endif %} diff --git a/api/app/core/memory/utils/prompt/prompts/user_summary.jinja2 b/api/app/core/memory/utils/prompt/prompts/user_summary.jinja2 index 35619112..30b48719 100644 --- a/api/app/core/memory/utils/prompt/prompts/user_summary.jinja2 +++ b/api/app/core/memory/utils/prompt/prompts/user_summary.jinja2 @@ -14,8 +14,8 @@ Your task is to generate a comprehensive user profile based on the provided enti {% endif %} ===Inputs=== -{% if user_id %} -- User ID: {{ user_id }} +{% if user_display_name %} +- User Display Name: {{ user_display_name }} {% endif %} {% if entities %} - Core Entities & Frequency: {{ entities }} @@ -33,6 +33,20 @@ Your task is to generate a comprehensive user profile based on the provided enti 3. Avoid excessive adjectives and empty phrases 4. Strictly follow the output format specified below +{% if language == "zh" %} +**【严格人称规定】** +- 在描述用户时,必须使用"{{ user_display_name }}"作为人称 +- 绝对禁止使用用户ID(如 {{ user_id }})来称呼用户 +- 绝对禁止在摘要中出现任何形式的UUID或ID字符串 +- 如果需要指代用户,只能使用"{{ user_display_name }}"或相应的代词(他/她/TA) +{% else %} +**【STRICT PRONOUN RULES】** +- When describing the user, you MUST use "{{ user_display_name }}" as the reference +- It is ABSOLUTELY FORBIDDEN to use the user ID (such as {{ user_id }}) to refer to the user +- It is ABSOLUTELY FORBIDDEN to include any form of UUID or ID string in the summary +- If you need to refer to the user, you can ONLY use "{{ user_display_name }}" or appropriate pronouns (he/she/they) +{% endif %} + **Section-Specific Requirements:** {% if language == "zh" %} @@ -103,13 +117,13 @@ Your task is to generate a comprehensive user profile based on the provided enti {% if language == "zh" %} Example Input: -- User ID: user_12345 +- User Display Name: 张三 - Core Entities & Frequency: 产品经理 (15), AI (12), 深圳 (10), 数据分析 (8), 团队协作 (7) - Representative Statement Samples: 我在深圳从事产品经理工作已经5年了 | 我相信好的产品源于对用户需求的深刻理解 | 我喜欢在团队中起到协调作用 | 数据驱动决策是我的工作原则 Example Output: 【基本介绍】 -我是张三,一名充满热情的高级产品经理。在过去的5年里,我专注于AI和数据驱动的产品设计,致力于创造能够真正改善用户生活的产品。我相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。 +张三是一名充满热情的高级产品经理,在深圳工作。在过去的5年里,张三专注于AI和数据驱动的产品设计,致力于创造能够真正改善用户生活的产品。张三相信好的产品源于对用户需求的深刻理解和对技术可能性的不断探索。 【性格特点】 性格开朗,善于沟通,注重细节。喜欢在团队中起到协调作用,帮助大家达成共识。面对挑战时保持乐观,相信每个问题都有解决方案。 @@ -121,13 +135,13 @@ Example Output: "让每一个产品决策都充满温度。" {% else %} Example Input: -- User ID: user_12345 +- User Display Name: John - Core Entities & Frequency: Product Manager (15), AI (12), San Francisco (10), Data Analysis (8), Team Collaboration (7) - Representative Statement Samples: I have been working as a product manager in San Francisco for 5 years | I believe good products come from deep understanding of user needs | I enjoy playing a coordinating role in teams | Data-driven decision making is my work principle Example Output: 【Basic Introduction】 -This is a passionate senior product manager based in San Francisco. Over the past 5 years, they have focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. They believe good products stem from deep understanding of user needs and continuous exploration of technological possibilities. +John is a passionate senior product manager based in San Francisco. Over the past 5 years, John has focused on AI and data-driven product design, dedicated to creating products that truly improve users' lives. John believes good products stem from deep understanding of user needs and continuous exploration of technological possibilities. 【Personality Traits】 Outgoing personality with excellent communication skills and attention to detail. Enjoys playing a coordinating role in teams, helping everyone reach consensus. Maintains optimism when facing challenges, believing every problem has a solution. diff --git a/api/app/core/models/base.py b/api/app/core/models/base.py index f5f49af0..dba6717d 100644 --- a/api/app/core/models/base.py +++ b/api/app/core/models/base.py @@ -21,31 +21,55 @@ from pydantic import BaseModel, Field T = TypeVar("T") + class RedBearModelConfig(BaseModel): """模型配置基类""" model_name: str provider: str api_key: str base_url: Optional[str] = None + is_omni: bool = False # 是否为 Omni 模型 # 请求超时时间(秒)- 默认120秒以支持复杂的LLM调用,可通过环境变量 LLM_TIMEOUT 配置 timeout: float = Field(default_factory=lambda: float(os.getenv("LLM_TIMEOUT", "120.0"))) # 最大重试次数 - 默认2次以避免过长等待,可通过环境变量 LLM_MAX_RETRIES 配置 max_retries: int = Field(default_factory=lambda: int(os.getenv("LLM_MAX_RETRIES", "2"))) - concurrency: int = 5 # 并发限流 + concurrency: int = 5 # 并发限流 extra_params: Dict[str, Any] = {} + class RedBearModelFactory: """模型工厂类""" - + @classmethod def get_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]: """根据提供商获取模型参数""" provider = config.provider.lower() - + # 打印供应商信息用于调试 from app.core.logging_config import get_business_logger logger = get_business_logger() - logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}") + logger.debug(f"获取模型参数 - Provider: {provider}, Model: {config.model_name}, is_omni: {config.is_omni}") + + # dashscope 的 omni 模型使用 OpenAI 兼容模式 + if provider == ModelProvider.DASHSCOPE and config.is_omni: + import httpx + if not config.base_url: + config.base_url = "https://dashscope.aliyuncs.com/compatible-mode/v1" + timeout_config = httpx.Timeout( + timeout=config.timeout, + connect=60.0, + read=config.timeout, + write=60.0, + pool=10.0, + ) + return { + "model": config.model_name, + "base_url": config.base_url, + "api_key": config.api_key, + "timeout": timeout_config, + "max_retries": config.max_retries, + **config.extra_params + } if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]: # 使用 httpx.Timeout 对象来设置详细的超时配置 @@ -65,7 +89,7 @@ class RedBearModelFactory: "timeout": timeout_config, "max_retries": config.max_retries, **config.extra_params - } + } elif provider == ModelProvider.DASHSCOPE: # DashScope (通义千问) 使用自己的参数格式 # 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数 @@ -82,7 +106,7 @@ class RedBearModelFactory: # region 从 base_url 或 extra_params 获取 from botocore.config import Config as BotoConfig from app.core.models.bedrock_model_mapper import normalize_bedrock_model_id - + max_pool_connections = int(os.getenv("BEDROCK_MAX_POOL_CONNECTIONS", "50")) max_retries = int(os.getenv("BEDROCK_MAX_RETRIES", "2")) # Configure with increased connection pool @@ -90,16 +114,16 @@ class RedBearModelFactory: max_pool_connections=max_pool_connections, retries={'max_attempts': max_retries, 'mode': 'adaptive'} ) - + # 标准化模型 ID(自动转换简化名称为完整 Bedrock Model ID) model_id = normalize_bedrock_model_id(config.model_name) - + params = { "model_id": model_id, "config": boto_config, **config.extra_params } - + # 解析 API key (格式: access_key_id:secret_access_key) if config.api_key and ":" in config.api_key: access_key_id, secret_access_key = config.api_key.split(":", 1) @@ -107,45 +131,52 @@ class RedBearModelFactory: params["aws_secret_access_key"] = secret_access_key elif config.api_key: params["aws_access_key_id"] = config.api_key - + # 设置 region if config.base_url: params["region_name"] = config.base_url elif "region_name" not in params: params["region_name"] = "us-east-1" # 默认区域 - + return params else: raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) - + @classmethod def get_rerank_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]: """根据提供商获取模型参数""" provider = config.provider.lower() if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: - return { + return { "model": config.model_name, # "base_url": config.base_url, "jina_api_key": config.api_key, **config.extra_params - } + } else: raise BusinessException(f"不支持的提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) -def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType.LLM) -> type[BaseLLM]: + +def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelType.LLM) -> type[BaseLLM]: """根据模型提供商获取对应的模型类""" provider = config.provider.lower() - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : + + # dashscope 的 omni 模型使用 OpenAI 兼容模式 + if provider == ModelProvider.DASHSCOPE and config.is_omni: + from langchain_openai import ChatOpenAI + return ChatOpenAI + + if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : if type == ModelType.LLM: from langchain_openai import OpenAI - return OpenAI + return OpenAI elif type == ModelType.CHAT: from langchain_openai import ChatOpenAI return ChatOpenAI elif provider == ModelProvider.DASHSCOPE: from langchain_community.chat_models import ChatTongyi return ChatTongyi - elif provider == ModelProvider.OLLAMA: + elif provider == ModelProvider.OLLAMA: from langchain_ollama import OllamaLLM return OllamaLLM elif provider == ModelProvider.BEDROCK: @@ -155,15 +186,16 @@ def get_provider_llm_class(config:RedBearModelConfig, type: ModelType=ModelType. else: raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) + def get_provider_embedding_class(provider: str) -> type[Embeddings]: """根据模型提供商获取对应的模型类""" provider = provider.lower() - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : + if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: from langchain_openai import OpenAIEmbeddings - return OpenAIEmbeddings + return OpenAIEmbeddings elif provider == ModelProvider.DASHSCOPE: from langchain_community.embeddings import DashScopeEmbeddings - return DashScopeEmbeddings + return DashScopeEmbeddings elif provider == ModelProvider.OLLAMA: from langchain_ollama import OllamaEmbeddings return OllamaEmbeddings @@ -173,14 +205,15 @@ def get_provider_embedding_class(provider: str) -> type[Embeddings]: else: raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) + def get_provider_rerank_class(provider: str): """根据模型提供商获取对应的模型类""" - provider = provider.lower() - if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK] : + provider = provider.lower() + if provider in [ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: from langchain_community.document_compressors import JinaRerank - return JinaRerank - # elif provider == ModelProvider.OLLAMA: + return JinaRerank + # elif provider == ModelProvider.OLLAMA: # from langchain_ollama import OllamaEmbeddings # return OllamaEmbeddings else: - raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) \ No newline at end of file + raise BusinessException(f"不支持的模型提供商: {provider}", code=BizCode.PROVIDER_NOT_SUPPORTED) diff --git a/api/app/core/models/scripts/bedrock_models.yaml b/api/app/core/models/scripts/bedrock_models.yaml index e5b91d1c..2c0ab757 100644 --- a/api/app/core/models/scripts/bedrock_models.yaml +++ b/api/app/core/models/scripts/bedrock_models.yaml @@ -6,6 +6,8 @@ models: description: AI21 Labs大语言模型,completion生成模式,256000上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 logo: bedrock @@ -15,6 +17,9 @@ models: description: Amazon Nova大语言模型,支持智能体思考、工具调用、流式工具调用、视觉能力,300000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -28,6 +33,9 @@ models: description: Anthropic Claude大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用、文档处理,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -42,6 +50,8 @@ models: description: Cohere大语言模型,支持智能体思考、工具调用、流式工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -54,6 +64,9 @@ models: description: DeepSeek大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用,32768上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -67,6 +80,8 @@ models: description: Meta Llama大语言模型,支持智能体思考、工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -78,6 +93,8 @@ models: description: Mistral AI大语言模型,支持智能体思考、工具调用,32000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -89,6 +106,8 @@ models: description: OpenAI大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -101,6 +120,8 @@ models: description: Qwen大语言模型,支持智能体思考、工具调用、流式工具调用,32768上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -113,6 +134,8 @@ models: description: amazon.rerank-v1:0重排序模型,5120上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 重排序模型 logo: bedrock @@ -122,6 +145,8 @@ models: description: cohere.rerank-v3-5:0重排序模型,5120上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 重排序模型 logo: bedrock @@ -131,6 +156,9 @@ models: description: amazon.nova-2-multimodal-embeddings-v1:0文本嵌入模型,支持视觉能力,8192上下文窗口 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 文本嵌入模型 - vision @@ -141,6 +169,8 @@ models: description: amazon.titan-embed-text-v1文本嵌入模型,8192上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本嵌入模型 logo: bedrock @@ -150,6 +180,8 @@ models: description: amazon.titan-embed-text-v2:0文本嵌入模型,8192上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本嵌入模型 logo: bedrock @@ -159,6 +191,8 @@ models: description: Cohere Embed 3 English文本嵌入模型,512上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本嵌入模型 logo: bedrock @@ -168,6 +202,8 @@ models: description: Cohere Embed 3 Multilingual文本嵌入模型,512上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本嵌入模型 - logo: bedrock + logo: bedrock \ No newline at end of file diff --git a/api/app/core/models/scripts/dashscope_models.yaml b/api/app/core/models/scripts/dashscope_models.yaml index af1c3619..89a16966 100644 --- a/api/app/core/models/scripts/dashscope_models.yaml +++ b/api/app/core/models/scripts/dashscope_models.yaml @@ -6,6 +6,8 @@ models: description: DeepSeek-R1-Distill-Qwen-14B大语言模型,支持智能体思考,32000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -16,6 +18,8 @@ models: description: DeepSeek-R1-Distill-Qwen-32B大语言模型,支持智能体思考,32000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -26,6 +30,8 @@ models: description: DeepSeek-R1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -36,6 +42,8 @@ models: description: DeepSeek-V3.1大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -46,6 +54,8 @@ models: description: DeepSeek-V3.2-exp实验版大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -56,6 +66,8 @@ models: description: DeepSeek-V3.2大语言模型,支持智能体思考,131072超大上下文窗口,对话模式,支持丰富生成参数调节 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -66,6 +78,8 @@ models: description: DeepSeek-V3大语言模型,支持智能体思考,64000上下文窗口,对话模式,支持文本与JSON格式输出 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -76,6 +90,8 @@ models: description: farui-plus大语言模型,支持多工具调用、智能体思考、流式工具调用,12288上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -88,6 +104,8 @@ models: description: GLM-4.7大语言模型,支持多工具调用、智能体思考、流式工具调用,202752超大上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -100,6 +118,9 @@ models: description: qvq-max-latest大语言模型,支持视觉、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - vision @@ -112,6 +133,9 @@ models: description: qvq-max大语言模型,支持视觉、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - vision @@ -124,6 +148,8 @@ models: description: qwen-coder-turbo-0919代码专用大语言模型,支持智能体思考,131072上下文窗口,对话模式,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 代码模型 @@ -135,6 +161,8 @@ models: description: qwen-max-latest大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -147,6 +175,8 @@ models: description: qwen-max-longcontext长上下文大语言模型,支持多工具调用、智能体思考、流式工具调用,32000上下文窗口,对话模式,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -159,6 +189,8 @@ models: description: qwen-max大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,支持联网搜索 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -171,6 +203,8 @@ models: description: qwen-mt-plus多语言翻译大语言模型,支持智能体思考,16384上下文窗口,对话模式,支持多语种互译与领域翻译适配 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 翻译模型 @@ -182,6 +216,8 @@ models: description: qwen-mt-turbo轻量化多语言翻译大语言模型,支持智能体思考,16384上下文窗口,对话模式,支持多语种互译与领域翻译适配 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 翻译模型 @@ -193,6 +229,8 @@ models: description: qwen-plus-0112大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -205,6 +243,8 @@ models: description: qwen-plus-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -217,6 +257,8 @@ models: description: qwen-plus-0723大语言模型,支持多工具调用、智能体思考、流式工具调用,32000上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -229,6 +271,8 @@ models: description: qwen-plus-0806大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -241,6 +285,8 @@ models: description: qwen-plus-0919大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -253,6 +299,8 @@ models: description: qwen-plus-1125大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -265,6 +313,8 @@ models: description: qwen-plus-1127大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,支持联网搜索,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -277,6 +327,8 @@ models: description: qwen-plus-1220大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -289,6 +341,10 @@ models: description: qwen-vl-max多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -302,6 +358,10 @@ models: description: qwen-vl-plus-0809多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,已废弃 is_deprecated: true is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -315,6 +375,10 @@ models: description: qwen-vl-plus-2025-01-02多模态大模型,支持视觉理解、智能体思考、视频理解,32768上下文窗口,对话模式,未废弃 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -328,6 +392,10 @@ models: description: qwen-vl-plus-2025-01-25多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -341,6 +409,10 @@ models: description: qwen-vl-plus-latest多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -354,6 +426,10 @@ models: description: qwen-vl-plus多模态大模型,支持视觉理解、智能体思考、视频理解,131072上下文窗口,对话模式,未废弃 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -367,6 +443,8 @@ models: description: qwen2.5-0.5b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,32768上下文窗口,对话模式,未废弃 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -379,6 +457,8 @@ models: description: qwen3-14b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -391,6 +471,8 @@ models: description: qwen3-235b-a22b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -403,6 +485,8 @@ models: description: qwen3-235b-a22b-thinking-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -415,6 +499,8 @@ models: description: qwen3-235b-a22b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -427,6 +513,8 @@ models: description: qwen3-30b-a3b-instruct-2507大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -439,6 +527,8 @@ models: description: qwen3-30b-a3b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -451,6 +541,8 @@ models: description: qwen3-32b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -463,6 +555,8 @@ models: description: qwen3-4b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -475,6 +569,8 @@ models: description: qwen3-8b大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -487,6 +583,8 @@ models: description: qwen3-coder-30b-a3b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 代码模型 @@ -498,6 +596,8 @@ models: description: qwen3-coder-480b-a35b-instruct大语言模型,支持智能体思考,262144上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 代码模型 @@ -509,6 +609,8 @@ models: description: qwen3-coder-plus-2025-09-23大语言模型,支持智能体思考,1000000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 代码模型 @@ -520,6 +622,8 @@ models: description: qwen3-coder-plus大语言模型,支持智能体思考,1000000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - 代码模型 @@ -531,6 +635,8 @@ models: description: qwen3-max-2025-09-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -544,6 +650,8 @@ models: description: qwen3-max-2026-01-23大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -557,6 +665,8 @@ models: description: qwen3-max-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -569,6 +679,8 @@ models: description: qwen3-max大语言模型,支持多工具调用、智能体思考、流式工具调用,262144上下文窗口,对话模式,支持联网搜索 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -582,6 +694,8 @@ models: description: qwen3-next-80b-a3b-instruct大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -594,6 +708,8 @@ models: description: qwen3-next-80b-a3b-thinking大语言模型,支持多工具调用、智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -606,6 +722,11 @@ models: description: qwen3-omni-flash-2025-12-01多模态大语言模型,支持视觉、智能体思考、视频、音频能力,65536上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + - audio + is_omni: true tags: - 大语言模型 - 多模态模型 @@ -620,6 +741,10 @@ models: description: qwen3-vl-235b-a22b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -635,6 +760,10 @@ models: description: qwen3-vl-235b-a22b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -650,6 +779,10 @@ models: description: qwen3-vl-30b-a3b-instruct多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -665,6 +798,10 @@ models: description: qwen3-vl-30b-a3b-thinking多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -680,6 +817,10 @@ models: description: qwen3-vl-flash多模态大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉、视频能力,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -695,6 +836,10 @@ models: description: qwen3-vl-plus-2025-09-23多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -708,6 +853,10 @@ models: description: qwen3-vl-plus多模态大语言模型,支持视觉、智能体思考、视频能力,262144上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - video + is_omni: false tags: - 大语言模型 - 多模态模型 @@ -721,6 +870,8 @@ models: description: qwq-32b大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -732,6 +883,8 @@ models: description: qwq-plus-0305大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -743,6 +896,8 @@ models: description: qwq-plus大语言模型,支持智能体思考、流式工具调用,131072上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -754,6 +909,8 @@ models: description: gte-rerank-v2重排序模型,4000上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 重排序模型 logo: dashscope @@ -763,6 +920,8 @@ models: description: gte-rerank重排序模型,4000上下文窗口 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 重排序模型 logo: dashscope @@ -772,6 +931,9 @@ models: description: multimodal-embedding-v1多模态嵌入模型,支持视觉能力,8192上下文窗口,最大分块数10 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 嵌入模型 - 多模态模型 @@ -783,6 +945,8 @@ models: description: text-embedding-v1文本嵌入模型,2048上下文窗口,最大分块数25 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 嵌入模型 - 文本嵌入 @@ -793,6 +957,8 @@ models: description: text-embedding-v2文本嵌入模型,2048上下文窗口,最大分块数25 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 嵌入模型 - 文本嵌入 @@ -803,6 +969,8 @@ models: description: text-embedding-v3文本嵌入模型,8192上下文窗口,最大分块数10 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 嵌入模型 - 文本嵌入 @@ -813,7 +981,9 @@ models: description: text-embedding-v4文本嵌入模型,8192上下文窗口,最大分块数10 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 嵌入模型 - 文本嵌入 - logo: dashscope + logo: dashscope \ No newline at end of file diff --git a/api/app/core/models/scripts/loader.py b/api/app/core/models/scripts/loader.py index a14d3268..e4462efa 100644 --- a/api/app/core/models/scripts/loader.py +++ b/api/app/core/models/scripts/loader.py @@ -6,7 +6,7 @@ from typing import Callable import yaml from sqlalchemy.orm import Session -from app.models.models_model import ModelBase, ModelProvider +from app.models.models_model import ModelBase, ModelProvider, ModelConfig def _load_yaml_config(provider: ModelProvider) -> list[dict]: @@ -55,6 +55,15 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False) print(f"\n正在加载 {provider.value} 的 {len(models)} 个模型...") for model_data in models: + config_sync_fields = { + "logo": None, + "capability": None, + "is_omni": None, + "name": None, + "provider": None, + "type": None, + "description": None + } try: # 检查模型是否已存在 existing = db.query(ModelBase).filter( @@ -66,6 +75,40 @@ def load_models(db: Session, providers: list[str] = None, silent: bool = False) # 更新现有模型配置 for key, value in model_data.items(): setattr(existing, key, value) + + # 更新绑定了该 model_id 的 ModelConfig 和 ModelApiKey + sync_fields = [k for k in config_sync_fields.keys() if k in model_data] + if sync_fields: + # 批量更新 ModelConfig + update_kwargs = {k: model_data[k] for k in sync_fields} + db.query(ModelConfig).filter(ModelConfig.model_id == existing.id).update( + update_kwargs, + synchronize_session=False + ) + + # 更新 ModelApiKey 的 capability 和 is_omni + if 'capability' in model_data or 'is_omni' in model_data: + from app.models.models_model import ModelApiKey, model_config_api_key_association + api_key_update = {} + if 'capability' in model_data: + api_key_update['capability'] = model_data['capability'] + if 'is_omni' in model_data: + api_key_update['is_omni'] = model_data['is_omni'] + + if api_key_update: + # 查找所有关联的 API Key + api_key_ids = db.query(model_config_api_key_association.c.api_key_id).join( + ModelConfig, + ModelConfig.id == model_config_api_key_association.c.model_config_id + ).filter(ModelConfig.model_id == existing.id).distinct().all() + + if api_key_ids: + api_key_ids = [aid[0] for aid in api_key_ids] + db.query(ModelApiKey).filter(ModelApiKey.id.in_(api_key_ids)).update( + api_key_update, + synchronize_session=False + ) + db.commit() if not silent: print(f"更新成功: {model_data['name']}") diff --git a/api/app/core/models/scripts/openai_models.yaml b/api/app/core/models/scripts/openai_models.yaml index 68c63ee2..7f6d3a51 100644 --- a/api/app/core/models/scripts/openai_models.yaml +++ b/api/app/core/models/scripts/openai_models.yaml @@ -6,12 +6,19 @@ models: description: chatgpt-4o-latest大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + - audio + - video + is_omni: true tags: - 大语言模型 - multi-tool-call - agent-thought - stream-tool-call - vision + - audio + - video logo: openai - name: gpt-3.5-turbo-0125 type: llm @@ -19,6 +26,8 @@ models: description: gpt-3.5-turbo-0125大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -31,6 +40,8 @@ models: description: gpt-3.5-turbo-1106大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -43,6 +54,8 @@ models: description: gpt-3.5-turbo-16k大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -55,6 +68,8 @@ models: description: gpt-3.5-turbo-instruct大语言模型,4096上下文窗口,文本补全模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 logo: openai @@ -64,6 +79,8 @@ models: description: gpt-3.5-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用,16385上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -76,6 +93,8 @@ models: description: gpt-4-0125-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -88,6 +107,8 @@ models: description: gpt-4-1106-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -100,6 +121,9 @@ models: description: gpt-4-turbo-2024-04-09大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -113,6 +137,8 @@ models: description: gpt-4-turbo-preview大语言模型,支持多工具调用、智能体思考、流式工具调用,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -125,6 +151,9 @@ models: description: gpt-4-turbo大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力,128000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -138,6 +167,8 @@ models: description: o1-preview大语言模型,支持智能体思考,128000上下文窗口,对话模式,已废弃 is_deprecated: true is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -148,6 +179,9 @@ models: description: o1大语言模型,支持多工具调用、智能体思考、流式工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - multi-tool-call @@ -162,6 +196,9 @@ models: description: o3-2025-04-16大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -176,6 +213,8 @@ models: description: o3-mini-2025-01-31大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -189,6 +228,8 @@ models: description: o3-mini大语言模型,支持智能体思考、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 大语言模型 - agent-thought @@ -202,6 +243,9 @@ models: description: o3-pro-2025-06-10大语言模型,支持智能体思考、工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -215,6 +259,9 @@ models: description: o3-pro大语言模型,支持智能体思考、工具调用、视觉能力、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -228,6 +275,9 @@ models: description: o3大语言模型,支持智能体思考、视觉能力、工具调用、流式工具调用、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -242,6 +292,9 @@ models: description: o4-mini-2025-04-16大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -256,6 +309,9 @@ models: description: o4-mini大语言模型,支持智能体思考、工具调用、视觉能力、流式工具调用、结构化输出,200000上下文窗口,对话模式 is_deprecated: false is_official: true + capability: + - vision + is_omni: false tags: - 大语言模型 - agent-thought @@ -270,6 +326,8 @@ models: description: text-embedding-3-large文本向量模型,8191上下文窗口,最大分块数32 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本向量模型 logo: openai @@ -279,6 +337,8 @@ models: description: text-embedding-3-small文本向量模型,8191上下文窗口,最大分块数32 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本向量模型 logo: openai @@ -288,6 +348,8 @@ models: description: text-embedding-ada-002文本向量模型,8097上下文窗口,最大分块数32 is_deprecated: false is_official: true + capability: [] + is_omni: false tags: - 文本向量模型 - logo: openai + logo: openai \ No newline at end of file diff --git a/api/app/core/workflow/adapters/__init__.py b/api/app/core/workflow/adapters/__init__.py new file mode 100644 index 00000000..141aa4ab --- /dev/null +++ b/api/app/core/workflow/adapters/__init__.py @@ -0,0 +1,8 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/24 15:54 +from app.core.workflow.adapters.dify.dify_adapter import DifyAdapter +from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter + +__all__ = ["DifyAdapter", "MemoryBearAdapter"] diff --git a/api/app/core/workflow/adapters/base_adapter.py b/api/app/core/workflow/adapters/base_adapter.py new file mode 100644 index 00000000..49321b89 --- /dev/null +++ b/api/app/core/workflow/adapters/base_adapter.py @@ -0,0 +1,90 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/24 15:58 +from abc import ABC, abstractmethod +from collections import defaultdict +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, Field + +from app.core.workflow.adapters.errors import ExceptionDefineition +from app.schemas.workflow_schema import ( + EdgeDefinition, + NodeDefinition, + VariableDefinition, + ExecutionConfig, + TriggerConfig +) + + +class PlatformType(StrEnum): + MEMORY_BEAR = "memory_bear" + DIFY = "dify" + COZE = "coze" + + +class PlatformMetadata(BaseModel): + platform_name: str + version: str + support_node_types: list[str] + + +class WorkflowParserResult(BaseModel): + success: bool + platform: PlatformMetadata + execution_config: ExecutionConfig + origin_config: dict[str, Any] + trigger: TriggerConfig | None + edges: list[EdgeDefinition] = Field(default_factory=list) + nodes: list[NodeDefinition] = Field(default_factory=list) + variables: list[VariableDefinition] = Field(default_factory=list) + warnings: list[ExceptionDefineition] = Field(default_factory=list) + errors: list[ExceptionDefineition] = Field(default_factory=list) + + +class WorkflowImportResult(BaseModel): + success: bool + temp_id: str | None = Field(..., description="cache id") + workflow_id: str | None = Field(..., description="workflow id") + edges: list[EdgeDefinition] = Field(default_factory=list) + nodes: list[NodeDefinition] = Field(default_factory=list) + variables: list[VariableDefinition] = Field(default_factory=list) + warnings: list[ExceptionDefineition] = Field(default_factory=list) + errors: list[ExceptionDefineition] = Field(default_factory=list) + + +class BasePlatformAdapter(ABC): + def __init__(self, config: dict[str, Any]): + self.config = config + self.nodes: list[NodeDefinition] = [] + self.edges: list[EdgeDefinition] = [] + self.conv_variables: list[VariableDefinition] = [] + + self.errors = [] + self.warnings = [] + + self.branch_node_cache = defaultdict(list) + self.error_branch_node_cache = [] + + self.node_output_map = {} + + @abstractmethod + def get_metadata(self) -> PlatformMetadata: + """get platform metadata""" + pass + + @abstractmethod + def validate_config(self) -> bool: + """platform configuration validate""" + pass + + @abstractmethod + def parse_workflow(self) -> WorkflowParserResult: + """parse platform configuration to local config""" + pass + + @abstractmethod + def map_node_type(self, platform_node_type: str) -> str: + pass diff --git a/api/app/core/workflow/adapters/base_converter.py b/api/app/core/workflow/adapters/base_converter.py new file mode 100644 index 00000000..eebde971 --- /dev/null +++ b/api/app/core/workflow/adapters/base_converter.py @@ -0,0 +1,75 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/26 14:32 +from abc import ABC, abstractmethod + +from app.core.workflow.variable.base_variable import DEFAULT_VALUE, VariableType + + +class BaseConverter(ABC): + @staticmethod + def _convert_string(var): + try: + return str(var) + except: + return DEFAULT_VALUE(VariableType.STRING) + + @staticmethod + def _convert_boolean(var): + try: + return bool(var) + except: + return DEFAULT_VALUE(VariableType.BOOLEAN) + + @staticmethod + def _convert_number(var): + try: + return float(var) + except: + return DEFAULT_VALUE(VariableType.NUMBER) + + @staticmethod + def _convert_object(var): + try: + return dict(var) + except: + return DEFAULT_VALUE(VariableType.OBJECT) + + @staticmethod + @abstractmethod + def _convert_file(var): + pass + + @staticmethod + def _convert_array_string(var): + try: + return list(var) + except: + return DEFAULT_VALUE(VariableType.ARRAY_STRING) + + @staticmethod + def _convert_array_number(var): + try: + return list(var) + except: + return DEFAULT_VALUE(VariableType.ARRAY_NUMBER) + + @staticmethod + def _convert_array_boolean(var): + try: + return list(var) + except: + return DEFAULT_VALUE(VariableType.ARRAY_BOOLEAN) + + @staticmethod + def _convert_array_object(var): + try: + return list(var) + except: + return DEFAULT_VALUE(VariableType.ARRAY_OBJECT) + + @staticmethod + @abstractmethod + def _convert_array_file(var): + pass diff --git a/api/app/core/workflow/adapters/dify/__init__.py b/api/app/core/workflow/adapters/dify/__init__.py new file mode 100644 index 00000000..7774dcaa --- /dev/null +++ b/api/app/core/workflow/adapters/dify/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/25 18:20 diff --git a/api/app/core/workflow/adapters/dify/converter.py b/api/app/core/workflow/adapters/dify/converter.py new file mode 100644 index 00000000..32d420b5 --- /dev/null +++ b/api/app/core/workflow/adapters/dify/converter.py @@ -0,0 +1,734 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/25 18:21 +import base64 +import re +from typing import Any +from urllib.parse import quote + +from app.core.workflow.adapters.base_converter import BaseConverter +from app.core.workflow.adapters.errors import UnsupportVariableType, UnknowModelWarning, ExceptionDefineition, \ + ExceptionType +from app.core.workflow.nodes.assigner import AssignerNodeConfig +from app.core.workflow.nodes.assigner.config import AssignmentItem +from app.core.workflow.nodes.base_config import VariableDefinition, BaseNodeConfig +from app.core.workflow.nodes.code import CodeNodeConfig +from app.core.workflow.nodes.code.config import InputVariable, OutputVariable +from app.core.workflow.nodes.configs import StartNodeConfig, LLMNodeConfig +from app.core.workflow.nodes.cycle_graph import LoopNodeConfig, IterationNodeConfig +from app.core.workflow.nodes.cycle_graph.config import ConditionDetail as LoopConditionDetail, ConditionsConfig, \ + CycleVariable +from app.core.workflow.nodes.end import EndNodeConfig +from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, AssignmentOperator, HttpAuthType, \ + HttpContentType, HttpErrorHandle +from app.core.workflow.nodes.http_request import HttpRequestNodeConfig +from app.core.workflow.nodes.http_request.config import HttpAuthConfig, HttpContentTypeConfig, HttpFormData, \ + HttpTimeOutConfig, HttpRetryConfig, HttpErrorDefaultTamplete, HttpErrorHandleConfig +from app.core.workflow.nodes.if_else import IfElseNodeConfig +from app.core.workflow.nodes.if_else.config import ConditionDetail, ConditionBranchConfig +from app.core.workflow.nodes.jinja_render import JinjaRenderNodeConfig +from app.core.workflow.nodes.jinja_render.config import VariablesMappingConfig +from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig +from app.core.workflow.nodes.llm.config import MemoryWindowSetting, MessageConfig +from app.core.workflow.nodes.parameter_extractor import ParameterExtractorNodeConfig +from app.core.workflow.nodes.parameter_extractor.config import ParamsConfig +from app.core.workflow.nodes.question_classifier import QuestionClassifierNodeConfig +from app.core.workflow.nodes.question_classifier.config import ClassifierConfig +from app.core.workflow.nodes.variable_aggregator import VariableAggregatorNodeConfig +from app.core.workflow.variable.base_variable import VariableType, DEFAULT_VALUE + + +class DifyConverter(BaseConverter): + errors: list + warnings: list + branch_node_cache: dict + error_branch_node_cache: list + node_output_map: dict + + def __init__(self): + self.CONFIG_CONVERT_MAP = { + "start": self.convert_start_node_config, + "llm": self.convert_llm_node_config, + "answer": self.convert_end_node_config, + "if-else": self.convert_if_else_node_config, + "loop": self.convert_loop_node_config, + "iteration": self.convert_iteration_node_config, + "assigner": self.convert_assigner_node_config, + "code": self.convert_code_node_config, + "http-request": self.convert_http_node_config, + "template-transform": self.convert_jinja_render_node_config, + "knowledge-retrieval": self.convert_knowledge_node_config, + "parameter-extractor": self.convert_parameter_extractor_node_config, + "question-classifier": self.convert_question_classifier_node_config, + "variable-aggregator": self.convert_variable_aggregator_node_config, + "tool": self.convert_tool_node_config, + "loop-start": lambda x: {}, + "iteration-start": lambda x: {}, + "loop-end": lambda x: {}, + } + + def get_node_convert(self, node_type): + func = self.CONFIG_CONVERT_MAP.get(node_type, lambda x: {}) + return func + + def config_validate( + self, + node_id: str, + node_name: str, + config: type[BaseNodeConfig], + value: dict + ): + try: + return config.model_validate(value) + except Exception as e: + self.errors.append(ExceptionDefineition( + type=ExceptionType.CONFIG, + node_id=node_id, + node_name=node_name, + detail=str(e) + )) + return None + + @staticmethod + def is_variable(expression) -> bool: + return bool(re.match(r"\{\{#(.*?)#}}", expression)) + + def process_var_selector(self, var_selector): + if not var_selector: + return "" + selector = var_selector.split('.') + if len(selector) not in [2, 3] and var_selector != "context": + raise Exception(f"invalid variable selector: {var_selector}") + if len(selector) == 3: + selector = selector[1:] + if selector[0] == "conversation": + selector[0] = "conv" + var_selector = ".".join(selector) + mapping = { + "sys.query": "sys.message" + } | self.node_output_map + + var_selector = mapping.get(var_selector, var_selector) + return var_selector + + def _process_list_variable_litearl(self, variable_selector: list) -> str | None: + if not self.process_var_selector(".".join(variable_selector)): + return None + return "{{" + self.process_var_selector(".".join(variable_selector)) + "}}" + + def trans_variable_format(self, content): + pattern = re.compile(r"\{\{#(.*?)#}}") + + def replacer(match: re.Match) -> str: + raw_name = match.group(1) + new_name = self.process_var_selector(raw_name) + return f"{{{{{new_name}}}}}" + + return pattern.sub(replacer, content) + + @staticmethod + def _convert_file(var): + return None + + @staticmethod + def _convert_array_file(var): + return [] + + @staticmethod + def variable_type_map(source_type) -> VariableType | None: + type_map = { + "file": VariableType.FILE, + "paragraph": VariableType.STRING, + "text-input": VariableType.STRING, + "number": VariableType.NUMBER, + "checkbox": VariableType.BOOLEAN, + "file-list": VariableType.ARRAY_FILE, + "select": VariableType.STRING, + "integer": VariableType.NUMBER, + "float": VariableType.NUMBER, + } + var_type = type_map.get(source_type, source_type) + return var_type + + def convert_variable_type(self, target_type: VariableType, origin_value: Any): + if not origin_value: + return DEFAULT_VALUE(target_type) + try: + match target_type: + case VariableType.STRING: + return self._convert_string(origin_value) + case VariableType.NUMBER: + return self._convert_number(origin_value) + case VariableType.BOOLEAN: + return self._convert_boolean(origin_value) + case VariableType.FILE: + return self._convert_file(origin_value) + case VariableType.ARRAY_FILE: + return self._convert_array_file(origin_value) + case _: + return origin_value + except: + raise Exception(f"convert variable failed: {target_type}") + + @staticmethod + def convert_compare_operator(operator): + operator_map = { + "is": ComparisonOperator.EQ, + "is not": ComparisonOperator.NE, + "=": ComparisonOperator.EQ, + "≠": ComparisonOperator.NE, + ">": ComparisonOperator.GT, + "<": ComparisonOperator.LT, + "≥": ComparisonOperator.GE, + "≤": ComparisonOperator.LE, + "not empty": ComparisonOperator.NOT_EMPTY, + "start with": ComparisonOperator.START_WITH, + "end with": ComparisonOperator.END_WITH, + "not contains": ComparisonOperator.NOT_CONTAINS, + "exists": ComparisonOperator.NOT_EMPTY, + "not exists": ComparisonOperator.EMPTY + } + return operator_map.get(operator, operator) + + @staticmethod + def convert_assignment_operator(operator): + operator_map = { + "+=": AssignmentOperator.ADD, + "-=": AssignmentOperator.SUBTRACT, + "*=": AssignmentOperator.MULTIPLY, + "/=": AssignmentOperator.DIVIDE, + "over-write": AssignmentOperator.COVER, + "remove-last": AssignmentOperator.REMOVE_LAST, + "remove-first": AssignmentOperator.REMOVE_FIRST, + "set": AssignmentOperator.ASSIGN, + } + return operator_map.get(operator, operator) + + @staticmethod + def convert_http_auth_type(auth_type): + auth_type_map = { + "no-auth": HttpAuthType.NONE, + "bearer": HttpAuthType.BEARER, + "basic": HttpAuthType.BASIC, + "custom": HttpAuthType.CUSTOM, + } + return auth_type_map.get(auth_type, auth_type) + + @staticmethod + def convert_http_content_type(content_type): + content_type_map = { + "none": HttpContentType.NONE, + "form-data": HttpContentType.FROM_DATA, + "x-www-form-urlencoded": HttpContentType.WWW_FORM, + "json": HttpContentType.JSON, + "raw-text": HttpContentType.RAW, + "binary": HttpContentType.BINARY, + } + return content_type_map.get(content_type, content_type) + + @staticmethod + def convert_http_error_handle_type(handle_type): + handle_type_map = { + "none": HttpErrorHandle.NONE, + "fail-branch": HttpErrorHandle.BRANCH, + "default-value": HttpErrorHandle.DEFAULT, + } + return handle_type_map.get(handle_type, handle_type) + + def convert_start_node_config(self, node: dict) -> dict: + node_data = node["data"] + start_vars = [] + for var in node_data["variables"]: + var_type = self.variable_type_map(var["type"]) + if not var_type: + self.errors.append( + UnsupportVariableType( + scope=node["id"], + name=var["variable"], + var_type=var["type"], + node_id=node["id"], + node_name=node_data["title"] + ) + ) + continue + + if var_type in ["file", "array[file]"]: + self.errors.append( + ExceptionDefineition( + type=ExceptionType.VARIABLE, + node_id=node["id"], + node_name=node_data["title"], + name=var["variable"], + detail=f"Unsupported Variable type for start node: {var_type}" + ) + ) + continue + + var_def = VariableDefinition( + name=var["variable"], + type=var_type, + required=var["required"], + default=self.convert_variable_type( + var_type, var.get("default") + ), + description=var["label"], + max_length=var.get("max_length", 50), + ) + start_vars.append(var_def) + result = StartNodeConfig.model_construct( + variables=start_vars + ).model_dump() + self.config_validate(node["id"], node["data"]["title"], StartNodeConfig, result) + return result + + def convert_question_classifier_node_config(self, node: dict) -> dict: + node_data = node["data"] + self.warnings.append( + UnknowModelWarning( + node_id=node["id"], + node_name=node_data["title"], + model_name=node_data["model"].get("name") + ) + ) + categories = [] + for category in node_data["classes"]: + self.branch_node_cache[node["id"]].append(category["id"]) + categories.append( + ClassifierConfig.model_construct( + class_name=category["name"], + ) + ) + + result = QuestionClassifierNodeConfig.model_construct( + input_variable=self._process_list_variable_litearl(node_data.get("query_variable_selector")), + user_supplement_prompt=self.trans_variable_format(node_data.get("instructions", "")), + categories=categories, + ).model_dump() + self.config_validate(node["id"], node["data"]["title"], QuestionClassifierNodeConfig, result) + return result + + def convert_llm_node_config(self, node: dict) -> dict: + node_data = node["data"] + self.warnings.append( + UnknowModelWarning( + node_id=node["id"], + node_name=node_data["title"], + model_name=node_data["model"].get("name") + ) + ) + context = self._process_list_variable_litearl(node_data["context"]["variable_selector"]) + memory = MemoryWindowSetting( + enable=bool(node_data.get("memory")), + enable_window=bool(node_data.get("memory", {}).get("window", {}).get("enabled", False)), + window_size=int(node_data.get("memory", {}).get("window", {}).get("size", 20)) + ) + messages = [] + for message in node_data["prompt_template"]: + messages.append( + MessageConfig( + role=message["role"], + content=self.trans_variable_format(message["text"]) + ) + ) + if memory.enable: + messages.append( + MessageConfig( + role="user", + content=self.trans_variable_format( + node_data["memory"].get("query_prompt_template") or "{{#sys.query#}}" + ) + ) + ) + vision = node_data["vision"]["enabled"] + vision_input = self._process_list_variable_litearl( + node_data["vision"]["configs"]["variable_selector"] + ) if vision else None + result = LLMNodeConfig.model_construct( + model_id=None, + context=context, + memory=memory, + vision=vision, + vision_input=vision_input, + messages=messages + ).model_dump() + self.config_validate(node["id"], node["data"]["title"], LLMNodeConfig, result) + return result + + def convert_end_node_config(self, node: dict) -> dict: + node_data = node["data"] + result = EndNodeConfig.model_construct( + output=self.trans_variable_format(node_data.get("answer", "")), + ).model_dump() + self.config_validate(node["id"], node["data"]["title"], EndNodeConfig, result) + return result + + def convert_if_else_node_config(self, node: dict) -> dict: + node_data = node["data"] + cases = [] + for case in node_data["cases"]: + case_id = case.get("id") or case.get("case_id") + logical_operator = case["logical_operator"] + conditions = [] + for condition in case["conditions"]: + right_value = condition["value"] + condition_detail = ConditionDetail( + operator=self.convert_compare_operator(condition["comparison_operator"]), + left="{{" + self.process_var_selector(".".join(condition["variable_selector"])) + "}}", + right=self.trans_variable_format( + right_value + ) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type( + self.variable_type_map(condition["varType"]), + condition["value"] + ), + input_type=ValueInputType.VARIABLE + if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT, + ) + conditions.append(condition_detail) + cases.append( + ConditionBranchConfig( + logical_operator=logical_operator, + expressions=conditions + ) + ) + self.branch_node_cache[node["id"]].append(case_id) + result = IfElseNodeConfig.model_construct( + cases=cases + ).model_dump() + self.config_validate(node["id"], node["data"]["title"], IfElseNodeConfig, result) + return result + + def convert_loop_node_config(self, node: dict) -> dict: + node_data = node["data"] + logical_operator = node_data["logical_operator"] + conditions = [] + for condition in node_data["break_conditions"]: + right_value = condition["value"] + conditions.append( + LoopConditionDetail.model_construct( + operator=self.convert_compare_operator(condition["comparison_operator"]), + left=self._process_list_variable_litearl(condition["variable_selector"]), + right=self.trans_variable_format( + right_value + ) if isinstance(right_value, str) and self.is_variable(right_value) else self.convert_variable_type( + self.variable_type_map(condition["varType"]), + condition["value"] + ), + input_type=ValueInputType.VARIABLE + if isinstance(right_value, str) and self.is_variable(right_value) else ValueInputType.CONSTANT, + ) + ) + condition_config = ConditionsConfig.model_construct( + logical_operator=logical_operator, + expressions=conditions + ) + loop_variables = [] + for variable in node_data["loop_variables"]: + right_input_type = variable["value_type"] + right_value_type = self.variable_type_map(variable["var_type"]) + if right_input_type == ValueInputType.VARIABLE: + right_value = self._process_list_variable_litearl(variable.get("value", "")) + else: + right_value = self.convert_variable_type(right_value_type, variable.get("value", "")) + loop_variables.append( + CycleVariable( + name=variable["label"], + type=right_value_type, + value=right_value, + input_type=right_input_type + ) + ) + result = LoopNodeConfig.model_construct( + condition=condition_config, + cycle_vars=loop_variables, + max_loop=node_data.get("loop_count", 10) + ).model_dump() + self.config_validate(node["id"], node["data"]["title"], LoopNodeConfig, result) + return result + + def convert_iteration_node_config(self, node: dict) -> dict: + node_data = node["data"] + result = IterationNodeConfig.model_construct( + input=self._process_list_variable_litearl(node_data["iterator_selector"]), + parallel=node_data["is_parallel"], + parallel_count=node_data["parallel_nums"], + output=self._process_list_variable_litearl(node_data["output_selector"]), + output_type=self.variable_type_map(node_data.get("output_type")), + flatten=node_data["flatten_output"], + ).model_dump() + + self.config_validate(node["id"], node["data"]["title"], IterationNodeConfig, result) + return result + + def convert_assigner_node_config(self, node: dict) -> dict: + node_data = node["data"] + assignments = [] + for assignment in node_data["items"]: + if assignment.get("operation") is None or assignment.get("value") is None: + continue + assignments.append( + AssignmentItem( + variable_selector=self._process_list_variable_litearl(assignment["variable_selector"]), + value=self._process_list_variable_litearl( + assignment["value"] + ) if assignment["input_type"] == ValueInputType.VARIABLE else assignment["value"], + operation=self.convert_assignment_operator(assignment["operation"]) + ) + ) + result = AssignerNodeConfig.model_construct( + assignments=assignments + ).model_dump() + self.config_validate(node["id"], node["data"]["title"], AssignerNodeConfig, result) + return result + + def convert_code_node_config(self, node: dict) -> dict: + node_data = node["data"] + input_variables = [] + for input_variable in node_data["variables"]: + input_variables.append( + InputVariable.model_construct( + name=input_variable["variable"], + variable=self._process_list_variable_litearl(input_variable["value_selector"]), + ) + ) + + output_variables = [] + for output_variable in node_data["outputs"]: + output_variables.append( + OutputVariable.model_construct( + name=output_variable, + type=node_data["outputs"][output_variable]["type"], + ) + ) + + code = base64.b64encode(quote(node_data["code"]).encode("utf-8")).decode("utf-8") + + result = CodeNodeConfig.model_construct( + input_variables=input_variables, + language=node_data["code_language"], + output_variables=output_variables, + code=code + ).model_dump() + self.config_validate(node["id"], node["data"]["title"], CodeNodeConfig, result) + return result + + def convert_http_node_config(self, node: dict) -> dict: + node_data = node["data"] + if node_data["authorization"]["type"] != 'no-auth': + auth_type = self.convert_http_auth_type(node_data["authorization"]["config"]["type"]) + auth_config = HttpAuthConfig.model_construct( + auth_type=auth_type, + header=node_data["authorization"]["config"].get("header"), + api_key=node_data["authorization"]["config"].get("api_key"), + ) + else: + auth_config = HttpAuthConfig() + + content_type = self.convert_http_content_type(node_data["body"]["type"]) + if content_type == HttpContentType.FROM_DATA: + body_content = [] + for content in node_data["body"]["data"]: + body_content.append( + HttpFormData( + key=self.trans_variable_format(content["key"]), + type=content["type"], + value=self.trans_variable_format(content["value"]), + ) + ) + elif content_type == HttpContentType.WWW_FORM: + body_content = {} + for content in node_data["body"]["data"]: + body_content[ + self.trans_variable_format(content["key"]) + ] = self.trans_variable_format(content["value"]) + else: + if node_data["body"]["data"]: + body_content = (node_data["body"]["data"][0].get("value") or + self._process_list_variable_litearl(node_data["body"]["data"][0].get("file"))) + else: + body_content = "" + + headers = {} + for header in node_data.get("headers", "").split("\n"): + if not header: + continue + + key_value = header.split(":") + if len(key_value) == 2: + headers[ + self.trans_variable_format(key_value[0]) + ] = self.trans_variable_format(key_value[1]) + else: + self.warnings.append(ExceptionDefineition( + type=ExceptionType.CONFIG, + node_id=node["id"], + node_name=node_data["title"], + detail=f"Invalid header/param - {header}", + )) + + params = {} + for param in node_data.get("params", "").split("\n"): + if not param: + continue + + key_value = param.split(":") + if len(key_value) == 2: + params[ + self.trans_variable_format(key_value[0]) + ] = self.trans_variable_format(key_value[1]) + else: + self.warnings.append(ExceptionDefineition( + type=ExceptionType.CONFIG, + node_id=node["id"], + node_name=node_data["title"], + detail=f"Invalid header/param - {param}", + )) + + error_handle_type = self.convert_http_error_handle_type( + node_data.get("error_strategy", "none") + ) + default_value = None + if error_handle_type == HttpErrorHandle.DEFAULT: + default_body = "" + default_header = {} + default_status_code = 0 + for var in node_data.get("default_value") or []: + if var["key"] == "body": + default_body = var["value"] + elif var["key"] == "header": + default_header = var["value"] + elif var["key"] == "status_code": + default_status_code = var["value"] + default_value = HttpErrorDefaultTamplete( + body=default_body, + headers=default_header, + status_code=default_status_code, + ) + + self.error_branch_node_cache.append(node['id']) + result = HttpRequestNodeConfig.model_construct( + method=node_data["method"].upper(), + url=node_data["url"], + auth=auth_config, + body=HttpContentTypeConfig.model_construct( + content_type=self.convert_http_content_type(node_data["body"]["type"]), + data=body_content, + ), + headers=headers, + params=params, + verify_ssl=node_data.get("ssl_verify", False), + timeouts=HttpTimeOutConfig.model_construct( + connect_timeout=node_data["timeout"]["max_connect_timeout"] or 5, + read_timeout=node_data["timeout"]["max_read_timeout"] or 5, + write_timeout=node_data["timeout"]["max_write_timeout"] or 5, + ), + retry=HttpRetryConfig.model_construct( + enable=node_data["retry_config"]["retry_enabled"], + max_attempts=node_data["retry_config"]["max_retries"], + retry_interval=node_data["retry_config"]["retry_interval"], + ), + error_handle=HttpErrorHandleConfig.model_construct( + method=error_handle_type, + default=default_value, + ) + ).model_dump() + + self.config_validate(node["id"], node["data"]["title"], HttpRequestNodeConfig, result) + return result + + def convert_jinja_render_node_config(self, node: dict) -> dict: + node_data = node["data"] + mapping = [] + for variable in node_data["variables"]: + mapping.append(VariablesMappingConfig.model_construct( + name=variable["variable"], + value=self._process_list_variable_litearl(variable["value_selector"]) + )) + result = JinjaRenderNodeConfig.model_construct( + template=node_data["template"], + mapping=mapping, + ).model_dump() + self.config_validate(node["id"], node["data"]["title"], JinjaRenderNodeConfig, result) + return result + + def convert_knowledge_node_config(self, node: dict) -> dict: + node_data = node["data"] + self.warnings.append(ExceptionDefineition( + node_id=node["id"], + node_name=node_data["title"], + type=ExceptionType.CONFIG, + detail=f"Please reconfigure the Knowledge Retrieval node.", + )) + result = KnowledgeRetrievalNodeConfig.model_construct( + query=self._process_list_variable_litearl(node_data["query_variable_selector"]), + ).model_dump() + + self.config_validate(node["id"], node["data"]["title"], KnowledgeRetrievalNodeConfig, result) + return result + + def convert_parameter_extractor_node_config(self, node: dict) -> dict: + node_data = node["data"] + self.warnings.append( + UnknowModelWarning( + node_id=node["id"], + node_name=node_data["title"], + model_name=node_data["model"].get("name") + ) + ) + params = [] + for param in node_data.get("parameters", []): + params.append( + ParamsConfig.model_construct( + name=param["name"], + desc=param["description"], + required=param["required"], + type=param["type"], + ) + ) + result = ParameterExtractorNodeConfig.model_construct( + text=self._process_list_variable_litearl(node_data["query"]), + params=params, + prompt=node_data.get("instruction") + ).model_dump() + + self.config_validate(node["id"], node["data"]["title"], ParameterExtractorNodeConfig, result) + return result + + def convert_variable_aggregator_node_config(self, node: dict) -> dict: + node_data = node["data"] + advanced_settings = node_data.get("advanced_settings", {}) + group_variables = {} + group_type = {} + if not advanced_settings or not advanced_settings["group_enabled"]: + group_variables = [ + self._process_list_variable_litearl(variable) + for variable in node_data["variables"] + ] + group_type["output"] = node_data["output_type"] + else: + for group in advanced_settings["groups"]: + group_variables[group["group_name"]] = [ + self._process_list_variable_litearl(variable) + for variable in group["variables"] + ] + group_type[group["group_name"]] = group["output_type"] + + result = VariableAggregatorNodeConfig.model_construct( + group=advanced_settings.get("group_enabled", False), + group_variables=group_variables, + group_type=group_type, + ).model_dump() + + self.config_validate(node["id"], node["data"]["title"], VariableAggregatorNodeConfig, result) + + return result + + def convert_tool_node_config(self, node: dict) -> dict: + node_data = node["data"] + self.warnings.append(ExceptionDefineition( + node_id=node["id"], + node_name=node_data["title"], + type=ExceptionType.CONFIG, + detail=f"Please reconfigure the tool node.", + )) + return {} diff --git a/api/app/core/workflow/adapters/dify/dify_adapter.py b/api/app/core/workflow/adapters/dify/dify_adapter.py new file mode 100644 index 00000000..895b3d37 --- /dev/null +++ b/api/app/core/workflow/adapters/dify/dify_adapter.py @@ -0,0 +1,261 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/24 16:05 +from typing import Any + +from app.core.logging_config import get_logger +from app.core.workflow.adapters.base_adapter import ( + BasePlatformAdapter, + PlatformMetadata, + PlatformType, + WorkflowParserResult +) +from app.core.workflow.adapters.dify.converter import DifyConverter +from app.core.workflow.adapters.errors import ExceptionDefineition, ExceptionType +from app.core.workflow.nodes.enums import NodeType +from app.schemas.workflow_schema import ( + NodeDefinition, + EdgeDefinition, + VariableDefinition, + TriggerConfig, + ExecutionConfig +) + +logger = get_logger() + + +class DifyAdapter(BasePlatformAdapter, DifyConverter): + NODE_TYPE_MAPPING = { + "start": NodeType.START, + "llm": NodeType.LLM, + "answer": NodeType.END, + "if-else": NodeType.IF_ELSE, + "loop-start": NodeType.CYCLE_START, + "iteration-start": NodeType.CYCLE_START, + "assigner": NodeType.ASSIGNER, + "loop": NodeType.LOOP, + "iteration": NodeType.ITERATION, + "loop-end": NodeType.BREAK, + "code": NodeType.CODE, + "http-request": NodeType.HTTP_REQUEST, + "template-transform": NodeType.JINJARENDER, + "knowledge-retrieval": NodeType.KNOWLEDGE_RETRIEVAL, + "parameter-extractor": NodeType.PARAMETER_EXTRACTOR, + "question-classifier": NodeType.QUESTION_CLASSIFIER, + "variable-aggregator": NodeType.VAR_AGGREGATOR, + "tool": NodeType.TOOL, + "": NodeType.NOTES + } + + def __init__(self, config: dict[str, Any]): + DifyConverter.__init__(self) + BasePlatformAdapter.__init__(self, config) + + def get_metadata(self) -> PlatformMetadata: + return PlatformMetadata( + platform_name=PlatformType.DIFY, + version="0.5.0", + support_node_types=list(self.NODE_TYPE_MAPPING.keys()) + ) + + def map_node_type(self, platform_node_type) -> str: + return self.NODE_TYPE_MAPPING.get(platform_node_type, NodeType.UNKNOWN) + + @property + def origin_nodes(self): + return self.config.get("workflow").get("graph").get("nodes") + + @property + def origin_edges(self): + return self.config.get("workflow").get("graph").get("edges") + + @staticmethod + def _valid_nodes(node: dict[str, Any]): + if "data" not in node: + return False + if "type" not in node["data"]: + return False + if "id" not in node or "type" not in node: + return False + return True + + def validate_config(self) -> bool: + require_fields = frozenset({'app', 'kind', 'version', 'workflow'}) + if not all(field in self.config for field in require_fields): + return False + if self.config.get("app",{}).get("mode") == "workflow": + self.errors.append(ExceptionDefineition( + type=ExceptionType.PLATFORM, + detail="workflow mode is not supported" + )) + return False + + for node in self.origin_nodes: + if not self._valid_nodes(node): + return False + return True + + def parse_workflow(self) -> WorkflowParserResult: + self._init_node_output_map() + for node in self.origin_nodes: + node = self._convert_node(node) + if node: + self.nodes.append(node) + nodes_id = [node.id for node in self.nodes] + for edge in self.origin_edges: + source = edge["source"] + target = edge["target"] + if source not in nodes_id or target not in nodes_id: + continue + edge = self._convert_edge(edge) + if edge: + self.edges.append(edge) + # + for variable in self.config.get("workflow").get("conversation_variables"): + con_var = self._convert_variable(variable) + if variable: + self.conv_variables.append(con_var) + # + # for variables in config.get("workflow").get("environment_variables"): + # variable = self._convert_variable(variables) + # conv_variables.append(variable) + + trigger = self._convert_trigger({}) + execution_config = self._convert_execution({}) + + return WorkflowParserResult( + success=not self.errors and not self.warnings, + platform=self.get_metadata(), + execution_config=execution_config, + origin_config=self.config, + trigger=trigger, + edges=self.edges, + nodes=self.nodes, + variables=self.conv_variables, + warnings=self.warnings, + errors=self.errors + ) + + def _init_node_output_map(self): + for node in self.origin_nodes: + if self.map_node_type(node["data"]["type"]) == NodeType.LLM: + self.node_output_map[f"{node['id']}.text"] = f"{node['id']}.output" + elif self.map_node_type(node["data"]["type"]) == NodeType.KNOWLEDGE_RETRIEVAL: + self.node_output_map[f"{node['id']}.result"] = f"{node['id']}.output" + + def _convert_cycle_node_position(self, node_id: str, position: dict): + for node in self.origin_nodes: + if node["id"] == node_id: + return { + "x": node["position"]["x"] + position["x"], + "y": node["position"]["y"] + position["y"] + } + self.errors.append( + ExceptionDefineition( + type=ExceptionType.NODE, + node_id=node_id, + detail="parent cycle node not found" + ) + ) + raise Exception("parent cycle node not found") + + def _convert_node(self, node: dict[str, Any]) -> NodeDefinition | None: + node_data = node["data"] + try: + return NodeDefinition( + id=node["id"], + type=self.map_node_type(node_data["type"]), + name=node_data.get("title") or "notes", + cycle=node.get("parentId"), + description=None, + config=self._convert_node_config(node), + position={ + "x": node["position"]["x"], + "y": node["position"]["y"] + } if node.get("parentId") is None else self._convert_cycle_node_position( + node["parentId"], + node["position"] + ), + error_handling=None, + cache=None + ) + except Exception as e: + logger.debug(f"convert node error - {e}", exc_info=True) + + def _convert_node_config(self, node: dict): + node_data = node["data"] + node_type = node_data["type"] + try: + converter = self.get_node_convert(node_type) + if node_type not in self.CONFIG_CONVERT_MAP: + self.errors.append(ExceptionDefineition( + type=ExceptionType.NODE, + node_id=node["id"], + node_name=node["data"]["title"], + detail=f"node type {node_type if node_type else 'notes'} is unsupported", + )) + return converter(node) + except Exception as e: + self.errors.append(ExceptionDefineition( + type=ExceptionType.NODE, + node_id=node["id"], + node_name=node["data"]["title"], + detail=f"convert node error - {e}", + )) + raise e + + def _convert_edge(self, edge: dict[str, Any]) -> EdgeDefinition | None: + try: + + source = edge["source"] + target = edge["target"] + label = None + if source in self.branch_node_cache: + case_id = edge["sourceHandle"] + if case_id == "false": + label = f'CASE{len(self.branch_node_cache[source])+1}' + else: + label = f'CASE{self.branch_node_cache[source].index(case_id) + 1}' + if source in self.error_branch_node_cache: + case_id = edge["sourceHandle"] + if case_id == "source": + label = "SUCCESS" + else: + label = "ERROR" + return EdgeDefinition( + id=edge["id"], + source=source, + target=target, + label=label, + ) + except Exception as e: + self.errors.append(ExceptionDefineition( + type=ExceptionType.EDGE, + detail=f"convert edge error - {e}", + )) + logger.debug(f"convert edge error - {e}", exc_info=True) + return None + + def _convert_variable(self, variable) -> VariableDefinition | None: + try: + return VariableDefinition( + name=variable["name"], + default=variable["value"], + type=self.variable_type_map(variable["value_type"]), + description=variable.get("description") + ) + except Exception as e: + self.errors.append(ExceptionDefineition( + type=ExceptionType.VARIABLE, + name=variable.get("name"), + detail=f"convert variable error - {e}", + )) + + def _convert_trigger(self, trigger: dict[str, Any]) -> TriggerConfig | None: + pass + + def _convert_execution(self, execution: dict[str, Any]) -> ExecutionConfig: + return ExecutionConfig() + + diff --git a/api/app/core/workflow/adapters/errors.py b/api/app/core/workflow/adapters/errors.py new file mode 100644 index 00000000..c0340a5e --- /dev/null +++ b/api/app/core/workflow/adapters/errors.py @@ -0,0 +1,75 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/26 11:29 +from enum import StrEnum + +from pydantic import BaseModel + + +class ExceptionType(StrEnum): + NODE = "node" + EDGE = "edge" + VARIABLE = "variable" + TRIGGER = "trigger" + EXECUTION = "execution" + CONFIG = "config" + PLATFORM = "platform" + UNKNOWN = "unknown" + + +class ExceptionDefineition(BaseModel): + type: ExceptionType + detail: str + + node_id: str | None = None + node_name: str | None = None + + scope: str | None = None + name: str | None = None + + +class UnknowModelWarning(ExceptionDefineition): + type: ExceptionType = ExceptionType.NODE + + def __init__(self, node_id, node_name, model_name): + super().__init__( + detail=f"Please specify the model mapping manually for model: {model_name}", + node_id=node_id, + node_name=node_name + ) + + +class UnknowError(ExceptionDefineition): + type: ExceptionType = ExceptionType.UNKNOWN + + def __init__(self, detail: str, **kwargs): + super().__init__(detail=detail, **kwargs) + + +class UnsupportPlatform(ExceptionDefineition): + type: ExceptionType = ExceptionType.PLATFORM + + def __init__(self, platform: str): + super().__init__(detail=f"Unsupport platform {platform}") + + +class UnsupportVariableType(ExceptionDefineition): + type: ExceptionType = ExceptionType.VARIABLE + + def __init__(self, scope, name, var_type: str, **kwargs): + super().__init__(scope=scope, name=name, detail=f"Unsupport variable type:[{var_type}]", **kwargs) + + +class InvalidConfiguration(ExceptionDefineition): + type: ExceptionType = ExceptionType.CONFIG + + def __init__(self): + super().__init__(detail="Invalid workflow configuration format") + + +class UnsupportNodeType(ExceptionDefineition): + type: ExceptionType = ExceptionType.NODE + + def __init__(self, node_id: str, node_type: str): + super().__init__(node_id=node_id, detail=f"Unsupport node Type {node_type}") diff --git a/api/app/core/workflow/adapters/memory_bear/__init__.py b/api/app/core/workflow/adapters/memory_bear/__init__.py new file mode 100644 index 00000000..f314662f --- /dev/null +++ b/api/app/core/workflow/adapters/memory_bear/__init__.py @@ -0,0 +1,4 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/26 11:30 diff --git a/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py b/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py new file mode 100644 index 00000000..0e3f459f --- /dev/null +++ b/api/app/core/workflow/adapters/memory_bear/memory_bear_adapter.py @@ -0,0 +1,76 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/25 14:11 +from typing import Any + +from app.core.workflow.adapters.base_adapter import ( + PlatformMetadata, + PlatformType, + BasePlatformAdapter, + WorkflowParserResult +) +from app.schemas.workflow_schema import ExecutionConfig + + +class MemoryBearAdapter(BasePlatformAdapter): + NODE_TYPE_MAPPING = {} + + @property + def origin_nodes(self): + return self.config.get("workflow").get("nodes") + + @property + def origin_edges(self): + return self.config.get("workflow").get("edges") + + @property + def origin_variables(self): + return self.config.get("workflow").get("variables") + + def get_metadata(self) -> PlatformMetadata: + return PlatformMetadata( + platform_name=PlatformType.MEMORY_BEAR, + version="0.2.5", + support_node_types=list(self.NODE_TYPE_MAPPING.keys()) + ) + + def map_node_type(self, platform_node_type) -> str: + return platform_node_type + + @staticmethod + def _valid_nodes(node: dict[str, Any]): + if "type" not in node["data"]: + return False + if "id" not in node or "type" not in node: + return False + return True + + def validate_config(self) -> bool: + require_fields = frozenset({'app', 'workflow'}) + if not all(field in self.config for field in require_fields): + return False + + for node in self.origin_nodes: + if not self._valid_nodes(node): + return False + return True + + def parse_workflow(self) -> WorkflowParserResult: + self.nodes = self.origin_nodes + self.edges = self.origin_edges + self.conv_variables = self.origin_variables + + return WorkflowParserResult( + success=True, + platform=self.get_metadata(), + execution_config=ExecutionConfig(), + origin_config=self.config, + trigger=None, + edges=self.edges, + nodes=self.nodes, + variables=self.conv_variables, + warnings=self.warnings, + errors=self.errors, + + ) diff --git a/api/app/core/workflow/adapters/registry.py b/api/app/core/workflow/adapters/registry.py new file mode 100644 index 00000000..10012676 --- /dev/null +++ b/api/app/core/workflow/adapters/registry.py @@ -0,0 +1,34 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/25 14:19 +from typing import Any + +from app.core.workflow.adapters import DifyAdapter, MemoryBearAdapter +from app.core.workflow.adapters.base_adapter import BasePlatformAdapter, PlatformType + + +class PlatformAdapterRegistry: + _adapters: dict[str, type[BasePlatformAdapter]] = {} + + @classmethod + def register(cls, platform: str, adapter: type[BasePlatformAdapter]): + cls._adapters[platform] = adapter + + @classmethod + def get_adapter(cls, platform: str, config: dict[str, Any]) -> BasePlatformAdapter: + if platform not in cls._adapters: + raise ValueError(f"Unsupported platform: {platform}") + return cls._adapters.get(platform)(config) + + @classmethod + def list_platforms(cls) -> list[str]: + return list(cls._adapters.keys()) + + @classmethod + def is_supported(cls, platform: str) -> bool: + return platform in cls._adapters + + +PlatformAdapterRegistry.register(PlatformType.MEMORY_BEAR, MemoryBearAdapter) +PlatformAdapterRegistry.register(PlatformType.DIFY, DifyAdapter) diff --git a/api/app/core/workflow/engine/event_stream_handler.py b/api/app/core/workflow/engine/event_stream_handler.py index 5b7d8de2..dc3cd04d 100644 --- a/api/app/core/workflow/engine/event_stream_handler.py +++ b/api/app/core/workflow/engine/event_stream_handler.py @@ -127,7 +127,7 @@ class EventStreamHandler: yield { "event": "message", "data": { - "chunk": data.get("chunk") + "content": data.get("chunk") } } diff --git a/api/app/core/workflow/engine/graph_builder.py b/api/app/core/workflow/engine/graph_builder.py index 7b5c059c..90668ad9 100644 --- a/api/app/core/workflow/engine/graph_builder.py +++ b/api/app/core/workflow/engine/graph_builder.py @@ -292,6 +292,8 @@ class GraphBuilder: """ for node in self.nodes: node_type = node.get("type") + if node_type == NodeType.NOTES: + continue node_id = node.get("id") cycle_node = node.get("cycle") if cycle_node: @@ -320,7 +322,7 @@ class GraphBuilder: # Used later to determine which branch to take based on the node's output # Assumes node output `node..output` matches the edge's label # For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1' - related_edge[idx]['condition'] = f"node.{node_id}.output == '{related_edge[idx]['label']}'" + related_edge[idx]['condition'] = f"node['{node_id}']['output'] == '{related_edge[idx]['label']}'" if node_instance: # Wrap node's run method to avoid closure issues diff --git a/api/app/core/workflow/engine/stream_output_coordinator.py b/api/app/core/workflow/engine/stream_output_coordinator.py index 5155a76f..c2885ab0 100644 --- a/api/app/core/workflow/engine/stream_output_coordinator.py +++ b/api/app/core/workflow/engine/stream_output_coordinator.py @@ -13,7 +13,7 @@ from app.core.workflow.engine.variable_pool import VariablePool logger = get_logger(__name__) SCOPE_PATTERN = re.compile( - r"\{\{\s*([a-zA-Z_][a-zA-Z0-9_]*)\.[a-zA-Z0-9_]+\s*}}" + r"\{\{\s*([a-zA-Z0-9_]+)\.[a-zA-Z0-9_]+\s*}}" ) @@ -274,7 +274,7 @@ class StreamOutputCoordinator: yield { "event": "message", "data": { - "chunk": final_chunk + "content": final_chunk } } diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py index 22be08c8..bc88df19 100644 --- a/api/app/core/workflow/engine/variable_pool.py +++ b/api/app/core/workflow/engine/variable_pool.py @@ -73,7 +73,7 @@ class VariableStruct(BaseModel, Generic[T]): instance: The concrete variable object. The actual Python type is represented by the generic parameter ``T`` (e.g. StringVariable, - NumberVariable, ArrayObject[StringVariable]). + NumberVariable, ArrayVariable[StringVariable]). mut: Whether the variable is mutable. """ @@ -152,6 +152,36 @@ class VariablePool: return None return var_instance + def get_instance( + self, + selector: str, + default: Any = None, + strict: bool = True + ): + """Retrieve a variable instance from the variable pool. + + Args: + selector: + Variable selector as a string variable literal (e.g. "{{ sys.message }}"). + default: + The value to return if the variable does not exist. + strict: + If True, raises KeyError when the variable does not exist. + + Returns: + The variable instance object if it exists; otherwise returns `default`. + + Raises: + KeyError: If strict is True and the variable does not exist. + """ + variable_struct = self._get_variable_struct(selector) + if variable_struct is None: + if strict: + raise KeyError(f"{selector} not exist") + return default + + return variable_struct.instance + def get_value( self, selector: str, @@ -273,38 +303,52 @@ class VariablePool: """ return self._get_variable_struct(selector) is not None - def get_all_system_vars(self) -> dict[str, Any]: + def get_all_system_vars(self, literal=False) -> dict[str, Any]: """获取所有系统变量 Returns: 系统变量字典 """ sys_namespace = self.variables.get("sys", {}) + if literal: + return {k: v.instance.to_literal() for k, v in sys_namespace.items()} return {k: v.instance.get_value() for k, v in sys_namespace.items()} - def get_all_conversation_vars(self) -> dict[str, Any]: + def get_all_conversation_vars(self, literal=False) -> dict[str, Any]: """获取所有会话变量 Returns: 会话变量字典 """ conv_namespace = self.variables.get("conv", {}) + if literal: + return {k: v.instance.to_literal() for k, v in conv_namespace.items()} return {k: v.instance.get_value() for k, v in conv_namespace.items()} - def get_all_node_outputs(self) -> dict[str, Any]: + def get_all_node_outputs(self, literal=False) -> dict[str, Any]: """获取所有节点输出(运行时变量) Returns: 节点输出字典,键为节点 ID """ - runtime_vars = { - namespace: { - k: v.instance.get_value() - for k, v in vars_dict.items() + if literal: + runtime_vars = { + namespace: { + k: v.instance.to_literal() + for k, v in vars_dict.items() + } + for namespace, vars_dict in self.variables.items() + if namespace not in ("sys", "conv") + } + else: + runtime_vars = { + namespace: { + k: v.instance.get_value() + for k, v in vars_dict.items() + } + for namespace, vars_dict in self.variables.items() + if namespace not in ("sys", "conv") } - for namespace, vars_dict in self.variables.items() - if namespace not in ("sys", "conv") - } return runtime_vars def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None: diff --git a/api/app/core/workflow/executor.py b/api/app/core/workflow/executor.py index 2b554a60..ff979f2b 100644 --- a/api/app/core/workflow/executor.py +++ b/api/app/core/workflow/executor.py @@ -132,24 +132,24 @@ class WorkflowExecutor: start_time = datetime.datetime.now() - # Build the workflow graph - graph = self.build_graph() - - # Initialize the variable pool with input data - await self.variable_initializer.initialize( - variable_pool=self.variable_pool, - input_data=input_data, - execution_context=self.execution_context - ) - initial_state = self.state_manager.create_initial_state( - workflow_config=self.workflow_config, - input_data=input_data, - execution_context=self.execution_context, - start_node_id=self.start_node_id - ) - # Execute the workflow try: + # Build the workflow graph + graph = self.build_graph() + + # Initialize the variable pool with input data + await self.variable_initializer.initialize( + variable_pool=self.variable_pool, + input_data=input_data, + execution_context=self.execution_context + ) + initial_state = self.state_manager.create_initial_state( + workflow_config=self.workflow_config, + input_data=input_data, + execution_context=self.execution_context, + start_node_id=self.start_node_id + ) + result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config) # Aggregate output from all End nodes @@ -158,24 +158,42 @@ class WorkflowExecutor: full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False) # Append messages for user and assistant - result["messages"].extend( - [ - { - "role": "user", - "content": input_data.get("message", '') - }, - { - "role": "assistant", - "content": full_content - } - ] - ) + if input_data.get("files"): + result["messages"].extend( + [ + { + "role": "user", + "content": input_data.get("message", '') + }, + { + "role": "user", + "content": input_data.get("files") + }, + { + "role": "assistant", + "content": full_content + } + ] + ) + else: + result["messages"].extend( + [ + { + "role": "user", + "content": input_data.get("message", '') + }, + { + "role": "assistant", + "content": full_content + } + ] + ) # Calculate elapsed time end_time = datetime.datetime.now() elapsed_time = (end_time - start_time).total_seconds() logger.info( - f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}s") + f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms") return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content) @@ -231,23 +249,23 @@ class WorkflowExecutor: } } - # Build the workflow graph in streaming mode - graph = self.build_graph(stream=True) - - # Initialize the variable pool and system variables - await self.variable_initializer.initialize( - variable_pool=self.variable_pool, - input_data=input_data, - execution_context=self.execution_context - ) - initial_state = self.state_manager.create_initial_state( - workflow_config=self.workflow_config, - input_data=input_data, - execution_context=self.execution_context, - start_node_id=self.start_node_id - ) - try: + # Build the workflow graph in streaming mode + graph = self.build_graph(stream=True) + + # Initialize the variable pool and system variables + await self.variable_initializer.initialize( + variable_pool=self.variable_pool, + input_data=input_data, + execution_context=self.execution_context + ) + initial_state = self.state_manager.create_initial_state( + workflow_config=self.workflow_config, + input_data=input_data, + execution_context=self.execution_context, + start_node_id=self.start_node_id + ) + full_content = '' self.stream_coordinator.update_scope_activation("sys") @@ -272,7 +290,7 @@ class WorkflowExecutor: event_type = data.get("type", "node_chunk") # "message" or "node_chunk" if event_type == "node_chunk": async for msg_event in self.event_handler.handle_node_chunk_event(data): - full_content += msg_event["data"]["chunk"] + full_content += msg_event["data"]["content"] yield msg_event elif event_type == "node_error": @@ -295,12 +313,12 @@ class WorkflowExecutor: self.graph, self.execution_context.checkpoint_config ): - full_content += msg_event["data"]['chunk'] + full_content += msg_event["data"]['content'] yield msg_event # Flush any remaining chunks async for msg_event in self.stream_coordinator.flush_remaining_chunk(self.variable_pool): - full_content += msg_event["data"]['chunk'] + full_content += msg_event["data"]['content'] yield msg_event result = graph.get_state(self.execution_context.checkpoint_config).values @@ -308,21 +326,39 @@ class WorkflowExecutor: elapsed_time = (end_time - start_time).total_seconds() # Append messages for user and assistant - result["messages"].extend( - [ - { - "role": "user", - "content": input_data.get("message", '') - }, - { - "role": "assistant", - "content": full_content - } - ] - ) + if input_data.get("files"): + result["messages"].extend( + [ + { + "role": "user", + "content": input_data.get("message", '') + }, + { + "role": "user", + "content": input_data.get("files") + }, + { + "role": "assistant", + "content": full_content + } + ] + ) + else: + result["messages"].extend( + [ + { + "role": "user", + "content": input_data.get("message", '') + }, + { + "role": "assistant", + "content": full_content + } + ] + ) logger.info( f"Workflow execution completed (streaming), " - f"elapsed: {elapsed_time:.2f}s, execution_id: {self.execution_context.execution_id}" + f"elapsed: {elapsed_time:.2f}ms, execution_id: {self.execution_context.execution_id}" ) yield { diff --git a/api/app/core/workflow/nodes/agent/node.py b/api/app/core/workflow/nodes/agent/node.py index 98d8bb75..8959e27c 100644 --- a/api/app/core/workflow/nodes/agent/node.py +++ b/api/app/core/workflow/nodes/agent/node.py @@ -14,9 +14,9 @@ from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.variable.base_variable import VariableType -from app.db import get_db +from app.db import get_db_context from app.models import AppRelease -from app.services.draft_run_service import DraftRunService +from app.services.draft_run_service import AgentRunService logger = logging.getLogger(__name__) @@ -39,7 +39,7 @@ class AgentNode(BaseNode): def _output_types(self) -> dict[str, VariableType]: return {"output": VariableType.STRING} - def _prepare_agent(self, variable_pool: VariablePool) -> tuple[DraftRunService, AppRelease, str]: + def _prepare_agent(self, variable_pool: VariablePool) -> tuple[AppRelease, str]: """准备 Agent(公共逻辑) Args: @@ -57,17 +57,17 @@ class AgentNode(BaseNode): if not agent_id: raise ValueError(f"节点 {self.node_id} 缺少 agent_id 配置") - db = next(get_db()) - release = db.query(AppRelease).filter( - AppRelease.id == agent_id - ).first() + with get_db_context() as db: + release = db.query(AppRelease).filter( + AppRelease.id == agent_id + ).first() if not release: raise ValueError(f"Agent 不存在: {agent_id}") - draft_service = DraftRunService(db) + - return draft_service, release, message + return release, message async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]: """非流式执行 @@ -79,19 +79,21 @@ class AgentNode(BaseNode): Returns: 状态更新字典 """ - draft_service, release, message = self._prepare_agent(variable_pool) + release, message = self._prepare_agent(variable_pool) logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(非流式)") - - # 执行 Agent(非流式) - result = await draft_service.run( - agent_config=release.config, - model_config=None, - message=message, - workspace_id=variable_pool.get_value("sys.workspace_id"), - user_id=state.get("user_id"), - variables=variable_pool.get_all_conversation_vars() - ) + with get_db_context() as db: + draft_service = AgentRunService(db) + + # 执行 Agent(非流式) + result = await draft_service.run( + agent_config=release.config, + model_config=None, + message=message, + workspace_id=variable_pool.get_value("sys.workspace_id"), + user_id=state.get("user_id"), + variables=variable_pool.get_all_conversation_vars() + ) response = result.get("response", "") @@ -118,34 +120,35 @@ class AgentNode(BaseNode): Yields: 流式事件字典 """ - draft_service, release, message = self._prepare_agent(variable_pool) + release, message = self._prepare_agent(variable_pool) logger.info(f"节点 {self.node_id} 开始执行 Agent 调用(流式)") # 累积完整响应 full_response = "" - + with get_db_context() as db: + draft_service = AgentRunService(db) # 执行 Agent(流式) - async for chunk in draft_service.run_stream( - agent_config=release.config, - model_config=None, - message=message, - workspace_id=variable_pool.get_value("sys.workspace_id"), - user_id=state.get("user_id"), - variables=variable_pool.get_all_conversation_vars() - ): - # 提取内容 - content = chunk.get("content", "") - full_response += content - - # 流式返回每个 chunk - yield { - "type": "chunk", - "node_id": self.node_id, - "content": content, - "full_content": full_response, - "meta_data": chunk.get("meta_data", {}) - } + async for chunk in draft_service.run_stream( + agent_config=release.config, + model_config=None, + message=message, + workspace_id=variable_pool.get_value("sys.workspace_id"), + user_id=state.get("user_id"), + variables=variable_pool.get_all_conversation_vars() + ): + # 提取内容 + content = chunk.get("content", "") + full_response += content + + # 流式返回每个 chunk + yield { + "type": "chunk", + "node_id": self.node_id, + "content": content, + "full_content": full_response, + "meta_data": chunk.get("meta_data", {}) + } logger.info(f"节点 {self.node_id} Agent 调用完成,输出长度: {len(full_response)}") diff --git a/api/app/core/workflow/nodes/assigner/node.py b/api/app/core/workflow/nodes/assigner/node.py index be51f81d..4c897d5a 100644 --- a/api/app/core/workflow/nodes/assigner/node.py +++ b/api/app/core/workflow/nodes/assigner/node.py @@ -88,6 +88,8 @@ class AssignerNode(BaseNode): await operator.remove_first() case AssignmentOperator.REMOVE_LAST: await operator.remove_last() + case AssignmentOperator.EXTEND: + await operator.extend() case _: raise ValueError(f"Invalid Operator: {assignment.operation}") logger.info(f"Node {self.node_id}: execution completed") diff --git a/api/app/core/workflow/nodes/base_config.py b/api/app/core/workflow/nodes/base_config.py index 973e120d..4ae89376 100644 --- a/api/app/core/workflow/nodes/base_config.py +++ b/api/app/core/workflow/nodes/base_config.py @@ -85,20 +85,20 @@ class BaseNodeConfig(BaseModel): - tags: 节点标签(用于分类和搜索) """ - name: str | None = Field( - default=None, - description="节点名称(显示名称),如果不设置则使用节点 ID" - ) - - description: str | None = Field( - default=None, - description="节点描述,说明节点的作用" - ) - - tags: list[str] = Field( - default_factory=list, - description="节点标签,用于分类和搜索" - ) + # name: str | None = Field( + # default=None, + # description="节点名称(显示名称),如果不设置则使用节点 ID" + # ) + # + # description: str | None = Field( + # default=None, + # description="节点描述,说明节点的作用" + # ) + # + # tags: list[str] = Field( + # default_factory=list, + # description="节点标签,用于分类和搜索" + # ) class Config: """Pydantic 配置""" diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index a01ffbe3..496454ba 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -1,6 +1,7 @@ import asyncio import logging from abc import ABC, abstractmethod +from datetime import datetime from functools import cached_property from typing import Any, AsyncGenerator @@ -10,8 +11,11 @@ from app.core.config import settings from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.enums import BRANCH_NODES -from app.core.workflow.variable.base_variable import VariableType -from app.services.multimodal_service import PROVIDER_STRATEGIES +from app.core.workflow.variable.base_variable import VariableType, FileObject +from app.db import get_db_read +from app.models import ModelConfig, ModelApiKey, LoadBalanceStrategy +from app.schemas import FileInput +from app.services.multimodal_service import MultimodalService logger = logging.getLogger(__name__) @@ -196,7 +200,7 @@ class BaseNode(ABC): timeout=timeout ) - elapsed_time = time.time() - start_time + elapsed_time = (time.time() - start_time) * 1000 # Extract processed outputs using subclass-defined logic. extracted_output = self._extract_output(business_result) @@ -219,7 +223,7 @@ class BaseNode(ABC): } | self.trans_activate(state) except TimeoutError: - elapsed_time = time.time() - start_time + elapsed_time = (time.time() - start_time) * 1000 logger.error( f"Node {self.node_id} execution timed out ({timeout} seconds)." ) @@ -230,7 +234,7 @@ class BaseNode(ABC): variable_pool, ) except Exception as e: - elapsed_time = time.time() - start_time + elapsed_time = (time.time() - start_time) * 1000 logger.error( f"Node {self.node_id} execution failed: {e}", exc_info=True, @@ -307,10 +311,10 @@ class BaseNode(ABC): "done": done }) - elapsed_time = time.time() - start_time + elapsed_time = (time.time() - start_time) * 1000 logger.info(f"Node {self.node_id} streaming execution finished, " - f"time elapsed: {elapsed_time:.2f}s, chunks: {chunk_count}") + f"time elapsed: {elapsed_time:.2f}ms, chunks: {chunk_count}") # Extract processed output (call subclass's _extract_output) extracted_output = self._extract_output(final_result) @@ -337,7 +341,7 @@ class BaseNode(ABC): yield state_update | self.trans_activate(state) except TimeoutError: - elapsed_time = time.time() - start_time + elapsed_time = (time.time() - start_time) * 1000 logger.error(f"Node {self.node_id} execution timed out ({timeout}s)") error_output = self._wrap_error( f"Node execution timed out ({timeout}s)", @@ -347,7 +351,7 @@ class BaseNode(ABC): ) yield error_output except Exception as e: - elapsed_time = time.time() - start_time + elapsed_time = (time.time() - start_time) * 1000 logger.error(f"Node {self.node_id} execution failed: {e}", exc_info=True) error_output = self._wrap_error(str(e), elapsed_time, state, variable_pool) yield error_output @@ -548,9 +552,9 @@ class BaseNode(ABC): return render_template( template=template, - conv_vars=variable_pool.get_all_conversation_vars(), - node_outputs=variable_pool.get_all_node_outputs(), - system_vars=variable_pool.get_all_system_vars(), + conv_vars=variable_pool.get_all_conversation_vars(literal=True), + node_outputs=variable_pool.get_all_node_outputs(literal=True), + system_vars=variable_pool.get_all_system_vars(literal=True), strict=strict ) @@ -614,16 +618,45 @@ class BaseNode(ABC): return variable_pool.has(selector) @staticmethod - async def process_message(provider, content, enable_file=False) -> dict | str | None: + async def process_message( + provider: str, + is_omni: bool, + content: str | dict | FileObject, + enable_file=False + ) -> list | str | None: + if isinstance(content, dict): + content = FileObject( + type=content.get("type"), + url=content.get("url"), + transfer_method=content.get("transfer_method"), + origin_file_type=content.get("origin_file_type"), + file_id=content.get("file_id"), + is_file=True + ) if isinstance(content, str): if enable_file: - return {"text": content} + return [{"type": "text", "text": content}] return content - elif isinstance(content, dict): - trans_tool = PROVIDER_STRATEGIES[provider]() - result = await trans_tool.format_image(content["url"]) - return result - raise TypeError('Unexpect input value type') + + elif isinstance(content, FileObject): + if content.content_cache.get(provider): + return content.content_cache[provider] + with get_db_read() as db: + multimodel_service = MultimodalService(db, provider, is_omni=is_omni) + message = await multimodel_service.process_files( + [FileInput.model_construct( + type=content.type, + url=content.url, + transfer_method=content.transfer_method, + file_type=content.origin_file_type, + upload_file_id=content.file_id + )] + ) + if message: + content.content_cache[provider] = message + return message + return None + raise TypeError(f'Unexpect input value type - {type(content)}') @staticmethod def process_model_output(content) -> str: @@ -639,3 +672,12 @@ class BaseNode(ABC): elif isinstance(content, str): return content return result + + @staticmethod + def model_balance(model_config: ModelConfig) -> ModelApiKey: + api_keys = [key for key in model_config.api_keys if key.is_active] + if not api_keys: + raise ValueError("No active API keys available for model") + if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: + return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min)) + return api_keys[0] diff --git a/api/app/core/workflow/nodes/cycle_graph/iteration.py b/api/app/core/workflow/nodes/cycle_graph/iteration.py index e4026f2d..cf7ac976 100644 --- a/api/app/core/workflow/nodes/cycle_graph/iteration.py +++ b/api/app/core/workflow/nodes/cycle_graph/iteration.py @@ -91,8 +91,8 @@ class IterationRuntime: return loopstate def merge_conv_vars(self): - self.variable_pool.get_all_conversation_vars().update( - self.child_variable_pool.get_all_conversation_vars() + self.variable_pool.variables["conv"].update( + self.child_variable_pool.variables["conv"] ) async def run_task(self, item, idx): diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py index cebadfdc..d3ada1ec 100644 --- a/api/app/core/workflow/nodes/cycle_graph/loop.py +++ b/api/app/core/workflow/nodes/cycle_graph/loop.py @@ -156,7 +156,7 @@ class LoopRuntime: def merge_conv_vars(self, loopstate): self.variable_pool.variables["conv"].update( - self.child_variable_pool.variables.get("conv", {}) + self.child_variable_pool.variables["conv"] ) loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False) loopstate["node_outputs"][self.node_id] = loop_vars diff --git a/api/app/core/workflow/nodes/cycle_graph/node.py b/api/app/core/workflow/nodes/cycle_graph/node.py index f2912e2c..71e0dbdb 100644 --- a/api/app/core/workflow/nodes/cycle_graph/node.py +++ b/api/app/core/workflow/nodes/cycle_graph/node.py @@ -66,7 +66,7 @@ class CycleGraphNode(BaseNode): if config.flatten: outputs['output'] = config.output_type else: - outputs['output'] = VariableType.ARRAY_STRING + outputs['output'] = VariableType.NESTED_ARRAY else: outputs['output'] = VariableType(f"array[{config.output_type}]") return outputs diff --git a/api/app/core/workflow/nodes/end/config.py b/api/app/core/workflow/nodes/end/config.py index f534dfb5..5c2a6c2a 100644 --- a/api/app/core/workflow/nodes/end/config.py +++ b/api/app/core/workflow/nodes/end/config.py @@ -17,17 +17,17 @@ class EndNodeConfig(BaseNodeConfig): description="输出模板,支持引用前置节点的输出,如:{{ llm_qa.output }}" ) - # 输出变量定义 - output_variables: list[VariableDefinition] = Field( - default_factory=lambda: [ - VariableDefinition( - name="output", - type=VariableType.STRING, - description="工作流的最终输出" - ) - ], - description="输出变量定义(自动生成,通常不需要修改)" - ) + # # 输出变量定义 + # output_variables: list[VariableDefinition] = Field( + # default_factory=lambda: [ + # VariableDefinition( + # name="output", + # type=VariableType.STRING, + # description="工作流的最终输出" + # ) + # ], + # description="输出变量定义(自动生成,通常不需要修改)" + # ) class Config: json_schema_extra = { diff --git a/api/app/core/workflow/nodes/enums.py b/api/app/core/workflow/nodes/enums.py index 6ad1c6a8..43ab593b 100644 --- a/api/app/core/workflow/nodes/enums.py +++ b/api/app/core/workflow/nodes/enums.py @@ -24,6 +24,9 @@ class NodeType(StrEnum): MEMORY_READ = "memory-read" MEMORY_WRITE = "memory-write" + UNKNOWN = "unknown" + NOTES = "notes" + BRANCH_NODES = [NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER] @@ -61,6 +64,7 @@ class AssignmentOperator(StrEnum): APPEND = "append" REMOVE_LAST = "remove_last" REMOVE_FIRST = "remove_first" + EXTEND = "extend" class HttpRequestMethod(StrEnum): diff --git a/api/app/core/workflow/nodes/http_request/node.py b/api/app/core/workflow/nodes/http_request/node.py index cdb34b57..e6c00eff 100644 --- a/api/app/core/workflow/nodes/http_request/node.py +++ b/api/app/core/workflow/nodes/http_request/node.py @@ -1,6 +1,7 @@ import asyncio import json import logging +import uuid from typing import Any, Callable, Coroutine import httpx @@ -13,6 +14,7 @@ from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.enums import HttpRequestMethod, HttpErrorHandle, HttpAuthType, HttpContentType from app.core.workflow.nodes.http_request.config import HttpRequestNodeConfig, HttpRequestNodeOutput from app.core.workflow.variable.base_variable import VariableType +from app.core.workflow.variable.variable_objects import FileVariable, ArrayVariable logger = logging.getLogger(__file__) @@ -115,7 +117,7 @@ class HttpRequestNode(BaseNode): params[self._render_template(key, variable_pool)] = self._render_template(value, variable_pool) return params - def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]: + async def _build_content(self, variable_pool: VariablePool) -> dict[str, Any]: """ Build HTTP request body arguments for httpx request methods. @@ -135,16 +137,35 @@ class HttpRequestNode(BaseNode): )) case HttpContentType.FROM_DATA: data = {} + content["files"] = {} for item in self.typed_config.body.data: if item.type == "text": - data[self._render_template(item.key, variable_pool)] = self._render_template(item.value, variable_pool) + data[self._render_template(item.key, variable_pool)] = self._render_template(item.value, + variable_pool) elif item.type == "file": - # TODO: File support (Feature) - pass + content["files"][self._render_template(item.key, variable_pool)] = ( + uuid.uuid4().hex, + await variable_pool.get_instance(item.value).get_content() + ) content["data"] = data case HttpContentType.BINARY: - # TODO: File support (Feature) - pass + content["files"] = [] + file_instence = variable_pool.get_instance(self.typed_config.body.data) + if isinstance(file_instence, ArrayVariable): + for v in file_instence.value: + if isinstance(v, FileVariable): + content["files"].append( + ( + "files", (uuid.uuid4().hex, await v.get_content()) + ) + ) + elif isinstance(file_instence, FileVariable): + content["files"].append( + ( + "file", (uuid.uuid4().hex, await file_instence.get_content()) + ) + ) + case HttpContentType.WWW_FORM: content["data"] = json.loads(self._render_template( json.dumps(self.typed_config.body.data), variable_pool @@ -207,7 +228,7 @@ class HttpRequestNode(BaseNode): request_func = self._get_client_method(client) resp = await request_func( url=self._render_template(self.typed_config.url, variable_pool), - **self._build_content(variable_pool) + **(await self._build_content(variable_pool)) ) resp.raise_for_status() logger.info(f"Node {self.node_id}: HTTP request succeeded") @@ -236,5 +257,5 @@ class HttpRequestNode(BaseNode): logger.warning( f"Node {self.node_id}: HTTP request failed, switching to error handling branch" ) - return "ERROR" + return {"output": "ERROR"} raise RuntimeError("http request failed") diff --git a/api/app/core/workflow/nodes/knowledge/config.py b/api/app/core/workflow/nodes/knowledge/config.py index 5475636e..56afe004 100644 --- a/api/app/core/workflow/nodes/knowledge/config.py +++ b/api/app/core/workflow/nodes/knowledge/config.py @@ -40,7 +40,7 @@ class KnowledgeRetrievalNodeConfig(BaseNodeConfig): ) knowledge_bases: list[KnowledgeBaseConfig] = Field( - ..., + default_factory=list, description="Knowledge base config" ) diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 17f55319..696298eb 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -180,6 +180,8 @@ class KnowledgeRetrievalNode(BaseNode): RuntimeError: If no valid knowledge base is found or access is denied. """ self.typed_config = KnowledgeRetrievalNodeConfig(**self.config) + if not self.typed_config.knowledge_bases: + return [] query = self._render_template(self.typed_config.query, variable_pool) with get_db_read() as db: knowledge_bases = self.typed_config.knowledge_bases diff --git a/api/app/core/workflow/nodes/llm/node.py b/api/app/core/workflow/nodes/llm/node.py index fdd5df58..186c204f 100644 --- a/api/app/core/workflow/nodes/llm/node.py +++ b/api/app/core/workflow/nodes/llm/node.py @@ -112,11 +112,12 @@ class LLMNode(BaseNode): raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) # 在 Session 关闭前提取所有需要的数据 - api_config = config.api_keys[0] + api_config = self.model_balance(config) model_name = api_config.model_name provider = api_config.provider api_key = api_config.api_key api_base = api_config.api_base + is_omni = api_config.is_omni model_type = config.type # 4. 创建 LLM 实例(使用已提取的数据) @@ -129,7 +130,8 @@ class LLMNode(BaseNode): provider=provider, api_key=api_key, base_url=api_base, - extra_params=extra_params + extra_params=extra_params, + is_omni=is_omni ), type=ModelType(model_type) ) @@ -151,39 +153,53 @@ class LLMNode(BaseNode): if role == "system": messages.append({ "role": "system", - "content": content + "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) }) elif role in ["user", "human"]: messages.append({ "role": "user", - "content": content + "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) }) elif role in ["ai", "assistant"]: messages.append({ "role": "assistant", - "content": content + "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) }) else: logger.warning(f"未知的消息角色: {role},默认使用 user") messages.append({ "role": "user", - "content": content + "content": await self.process_message(provider, is_omni, content, self.typed_config.vision) }) if self.typed_config.vision_input and self.typed_config.vision: file_content = [] - files = variable_pool.get_value(self.typed_config.vision_input) - for file in files: - content = await self.process_message(provider, file, self.typed_config.vision) + files = variable_pool.get_instance(self.typed_config.vision_input) + for file in files.value: + content = await self.process_message(provider, is_omni, file.value, self.typed_config.vision) if content: - file_content.append(content) + file_content.extend(content) if messages and messages[-1]["role"] == 'user': - messages[-1]['content'] = [messages[-1]["content"]] + file_content + messages[-1]['content'] = messages[-1]["content"] + file_content else: messages.append({"role": "user", "content": file_content}) if self.typed_config.memory.enable: - messages = messages[:-1] + state["messages"][-self.typed_config.memory.window_size:] + messages[-1:] + history_message = [] + for message in state["messages"][-self.typed_config.memory.window_size:]: + if isinstance(message["content"], list): + file_content = [] + for file in message["content"]: + content = await self.process_message(provider, is_omni, file, self.typed_config.vision) + if content: + file_content.extend(content) + history_message.append( + {"role": message["role"], "content": file_content} + ) + else: + message["content"] = await self.process_message(provider, is_omni, message["content"], self.typed_config.vision) + history_message.append(message) + messages = messages[:-1] + history_message + messages[-1:] self.messages = messages else: # 使用简单的 prompt 格式(向后兼容) diff --git a/api/app/core/workflow/nodes/node_factory.py b/api/app/core/workflow/nodes/node_factory.py index 00120ca0..864e3251 100644 --- a/api/app/core/workflow/nodes/node_factory.py +++ b/api/app/core/workflow/nodes/node_factory.py @@ -123,10 +123,10 @@ class NodeFactory: # 获取节点类 node_class = cls._node_types.get(node_type) if not node_class: - raise ValueError(f"不支持的节点类型: {node_type}") + raise ValueError(f"Unsupported node type: {node_type}") # 创建节点实例 - logger.debug(f"创建节点: {node_config.get('id')} (type={node_type})") + logger.debug(f"create node instance: {node_config.get('id')} (type={node_type})") return node_class(node_config, workflow_config) @classmethod diff --git a/api/app/core/workflow/nodes/parameter_extractor/node.py b/api/app/core/workflow/nodes/parameter_extractor/node.py index 4811c118..700ed85f 100644 --- a/api/app/core/workflow/nodes/parameter_extractor/node.py +++ b/api/app/core/workflow/nodes/parameter_extractor/node.py @@ -95,11 +95,12 @@ class ParameterExtractorNode(BaseNode): if not config.api_keys or len(config.api_keys) == 0: raise BusinessException("Model configuration is missing API Key", BizCode.INVALID_PARAMETER) - api_config = config.api_keys[0] + api_config = self.model_balance(config) model_name = api_config.model_name provider = api_config.provider api_key = api_config.api_key api_base = api_config.api_base + is_omni = api_config.is_omni model_type = config.type llm = RedBearLLM( @@ -108,6 +109,7 @@ class ParameterExtractorNode(BaseNode): provider=provider, api_key=api_key, base_url=api_base, + is_omni=is_omni ), type=ModelType(model_type) ) diff --git a/api/app/core/workflow/nodes/question_classifier/node.py b/api/app/core/workflow/nodes/question_classifier/node.py index e2fd97ae..5cebd886 100644 --- a/api/app/core/workflow/nodes/question_classifier/node.py +++ b/api/app/core/workflow/nodes/question_classifier/node.py @@ -56,11 +56,12 @@ class QuestionClassifierNode(BaseNode): if not config.api_keys or len(config.api_keys) == 0: raise BusinessException("模型配置缺少 API Key", BizCode.INVALID_PARAMETER) - api_config = config.api_keys[0] + api_config = self.model_balance(config) model_name = api_config.model_name provider = api_config.provider api_key = api_config.api_key base_url = api_config.api_base + is_omni = api_config.is_omni model_type = config.type return RedBearLLM( @@ -69,6 +70,7 @@ class QuestionClassifierNode(BaseNode): provider=provider, api_key=api_key, base_url=base_url, + is_omni=is_omni ), type=ModelType(model_type) ) diff --git a/api/app/core/workflow/nodes/start/config.py b/api/app/core/workflow/nodes/start/config.py index 98390bf7..3f795f1e 100644 --- a/api/app/core/workflow/nodes/start/config.py +++ b/api/app/core/workflow/nodes/start/config.py @@ -3,7 +3,6 @@ from pydantic import Field from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition -from app.core.workflow.variable.base_variable import VariableType class StartNodeConfig(BaseNodeConfig): @@ -21,42 +20,42 @@ class StartNodeConfig(BaseNodeConfig): description="自定义输入变量列表,这些变量会作为 Start 节点的输出" ) - # 输出变量定义 - output_variables: list[VariableDefinition] = Field( - default_factory=lambda: [ - VariableDefinition( - name="message", - type=VariableType.STRING, - description="用户输入的消息" - ), - VariableDefinition( - name="conversation_vars", - type=VariableType.OBJECT, - description="会话级变量" - ), - VariableDefinition( - name="execution_id", - type=VariableType.STRING, - description="执行 ID" - ), - VariableDefinition( - name="conversation_id", - type=VariableType.STRING, - description="会话 ID" - ), - VariableDefinition( - name="workspace_id", - type=VariableType.STRING, - description="工作空间 ID" - ), - VariableDefinition( - name="user_id", - type=VariableType.STRING, - description="用户 ID" - ) - ], - description="输出变量定义(自动生成,通常不需要修改)" - ) + # # 输出变量定义 + # output_variables: list[VariableDefinition] = Field( + # default_factory=lambda: [ + # VariableDefinition( + # name="message", + # type=VariableType.STRING, + # description="用户输入的消息" + # ), + # VariableDefinition( + # name="conversation_vars", + # type=VariableType.OBJECT, + # description="会话级变量" + # ), + # VariableDefinition( + # name="execution_id", + # type=VariableType.STRING, + # description="执行 ID" + # ), + # VariableDefinition( + # name="conversation_id", + # type=VariableType.STRING, + # description="会话 ID" + # ), + # VariableDefinition( + # name="workspace_id", + # type=VariableType.STRING, + # description="工作空间 ID" + # ), + # VariableDefinition( + # name="user_id", + # type=VariableType.STRING, + # description="用户 ID" + # ) + # ], + # description="输出变量定义(自动生成,通常不需要修改)" + # ) class Config: json_schema_extra = { diff --git a/api/app/core/workflow/utils/expression_evaluator.py b/api/app/core/workflow/utils/expression_evaluator.py index 26f0c41c..4bc5fc4c 100644 --- a/api/app/core/workflow/utils/expression_evaluator.py +++ b/api/app/core/workflow/utils/expression_evaluator.py @@ -12,9 +12,20 @@ class ExpressionEvaluator: # Reserved namespaces RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"} - - @staticmethod + + @classmethod + def normalize_template(cls, template: str) -> str: + pattern = re.compile( + r"\{\{\s*(\d+)\.(\w+)\s*}}" + ) + return pattern.sub( + r'{{ node["\1"].\2 }}', + template + ) + + @classmethod def evaluate( + cls, expression: str, conv_vars: dict[str, Any], node_outputs: dict[str, Any], @@ -37,6 +48,7 @@ class ExpressionEvaluator: """ # Remove Jinja2-style brackets if present expression = expression.strip() + expression = cls.normalize_template(expression) pattern = r"\{\{\s*(.*?)\s*\}\}" expression = re.sub(pattern, r"\1", expression).strip() diff --git a/api/app/core/workflow/utils/template_renderer.py b/api/app/core/workflow/utils/template_renderer.py index 236e0840..424fdf20 100644 --- a/api/app/core/workflow/utils/template_renderer.py +++ b/api/app/core/workflow/utils/template_renderer.py @@ -5,6 +5,7 @@ """ import logging +import re from typing import Any from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined @@ -39,6 +40,16 @@ class TemplateRenderer: autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML ) + @staticmethod + def normalize_template(template: str) -> str: + pattern = re.compile( + r"\{\{\s*(\d+)\.(\w+)\s*}}" + ) + return pattern.sub( + r'{{ node["\1"].\2 }}', + template + ) + def render( self, template: str, @@ -95,7 +106,7 @@ class TemplateRenderer: context.update(conv_vars) context["nodes"] = node_outputs or {} # 旧语法兼容 - + template = self.normalize_template(template) try: tmpl = self.env.from_string(template) return tmpl.render(**context) diff --git a/api/app/core/workflow/validator.py b/api/app/core/workflow/validator.py index 47256b75..3b6e9036 100644 --- a/api/app/core/workflow/validator.py +++ b/api/app/core/workflow/validator.py @@ -138,7 +138,7 @@ class WorkflowValidator: errors.append("工作流必须至少有一个 end 节点") # 3. 验证节点 ID 唯一性 - node_ids = [n.get("id") for n in nodes] + node_ids = [n.get("id") for n in nodes if n.get("type") != NodeType.NOTES] if len(node_ids) != len(set(node_ids)): duplicates = [nid for nid in node_ids if node_ids.count(nid) > 1] errors.append(f"节点 ID 必须唯一,重复的 ID: {set(duplicates)}") diff --git a/api/app/core/workflow/variable/base_variable.py b/api/app/core/workflow/variable/base_variable.py index 19cbdc74..dd821ea7 100644 --- a/api/app/core/workflow/variable/base_variable.py +++ b/api/app/core/workflow/variable/base_variable.py @@ -2,7 +2,7 @@ from enum import StrEnum from abc import abstractmethod, ABC from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, Field from app.schemas import FileType @@ -45,7 +45,7 @@ class VariableType(StrEnum): return cls.NUMBER elif isinstance(var, bool): return cls.BOOLEAN - elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('__file')): + elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')): return cls.FILE elif isinstance(var, dict): return cls.OBJECT @@ -109,7 +109,13 @@ def DEFAULT_VALUE(var_type: VariableType) -> Any: class FileObject(BaseModel): type: FileType url: str - __file: bool + transfer_method: str + origin_file_type: str + file_id: str | None + + content_cache: dict = Field(default_factory=dict) + + is_file: bool class BaseVariable(ABC): diff --git a/api/app/core/workflow/variable/variable_objects.py b/api/app/core/workflow/variable/variable_objects.py index 7a39835c..63437fd9 100644 --- a/api/app/core/workflow/variable/variable_objects.py +++ b/api/app/core/workflow/variable/variable_objects.py @@ -1,8 +1,10 @@ from typing import Any, TypeVar, Type, Generic +import httpx from deprecated import deprecated from app.core.workflow.variable.base_variable import BaseVariable, VariableType, FileObject, FileType +from app.core.config import settings T = TypeVar("T", bound=BaseVariable) @@ -61,13 +63,16 @@ class FileVariable(BaseVariable): def valid_value(self, value) -> FileObject: if isinstance(value, dict): - if not value.get("__file"): + if not value.get("is_file"): raise TypeError(f"Value must be a FileObject - {type(value)}:{value}") return FileObject( **{ "type": str(value.get('type')), + "transfer_method": value.get("transfer_method"), "url": value.get('url'), - "__file": True + "file_id": value.get("file_id"), + "origin_file_type": value.get("origin_file_type"), + "is_file": True } ) if isinstance(value, FileObject): @@ -80,8 +85,23 @@ class FileVariable(BaseVariable): def get_value(self) -> Any: return self.value.model_dump() + async def get_content(self): + total_bytes = 0 + chunks = [] -class ArrayObject(BaseVariable, Generic[T]): + async with httpx.AsyncClient() as client: + async with client.stream("GET", self.value.url) as resp: + resp.raise_for_status() + async for chunk in resp.aiter_bytes(8192): + total_bytes += len(chunk) + if total_bytes > settings.MAX_FILE_SIZE: + raise ValueError(f"File too large: {total_bytes} bytes") + chunks.append(chunk) + + return b"".join(chunks) + + +class ArrayVariable(BaseVariable, Generic[T]): type = 'array' def __init__(self, child_type: Type[T], value: list[Any]): @@ -108,7 +128,7 @@ class ArrayObject(BaseVariable, Generic[T]): return [v.get_value() for v in self.value] -class NestedArrayObject(BaseVariable): +class NestedArrayVariable(BaseVariable): type = 'array_nest' def valid_value(self, value: list[T]) -> list[T]: @@ -116,23 +136,23 @@ class NestedArrayObject(BaseVariable): raise TypeError(f"Value must be a list - {type(value)}:{value}") final_value = [] for v in value: - if not isinstance(v, ArrayObject): + if not isinstance(v, list): raise TypeError("All elements must be of type list") - final_value.append(v) + final_value.append(make_array(AnyVariable, v)) return final_value def to_literal(self) -> str: - return "\n".join(["\n".join([item.to_literal() for item in row]) for row in self.value]) + return "\n".join(["\n".join([str(item) for item in row.get_value()]) for row in self.value]) def get_value(self) -> Any: - return [[item.get_value() for item in row] for row in self.value] + return [[item for item in row.get_value()] for row in self.value] @deprecated( reason="Using arbitrary-type values may cause unexpected errors; please switch to strongly-typed values.", category=RuntimeWarning ) -class AnyObject(BaseVariable): +class AnyVariable(BaseVariable): type = 'any' def valid_value(self, value: Any) -> Any: @@ -142,10 +162,10 @@ class AnyObject(BaseVariable): return str(self.value) -def make_array(child_type: Type[T], value: list[Any]) -> ArrayObject[T]: - """简化 ArrayObject 创建,不需要重复写类型""" +def make_array(child_type: Type[T], value: list[Any]) -> ArrayVariable[T]: + """简化 ArrayVariable 创建,不需要重复写类型""" - return ArrayObject(child_type, value) + return ArrayVariable(child_type, value) def create_variable_instance(var_type: VariableType, value: Any) -> T: @@ -168,7 +188,9 @@ def create_variable_instance(var_type: VariableType, value: Any) -> T: return make_array(DictVariable, value) case VariableType.ARRAY_FILE: return make_array(FileVariable, value) + case VariableType.NESTED_ARRAY: + return NestedArrayVariable(value) case VariableType.ANY: - return AnyObject(value) + return AnyVariable(value) case _: raise TypeError(f"Invalid type - {var_type}") diff --git a/api/app/models/models_model.py b/api/app/models/models_model.py index 3e378f17..23fafcef 100644 --- a/api/app/models/models_model.py +++ b/api/app/models/models_model.py @@ -2,7 +2,7 @@ import datetime import uuid from enum import StrEnum -from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table +from sqlalchemy import Column, String, Boolean, DateTime, Text, ForeignKey, Enum as SQLEnum, UniqueConstraint, Integer, ARRAY, Table, text from sqlalchemy.dialects.postgresql import UUID, JSON from sqlalchemy.orm import relationship from sqlalchemy.sql import func @@ -78,6 +78,9 @@ class ModelConfig(BaseModel): description = Column(String, comment="模型描述") # 模型配置参数 + capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"), + comment="模型能力列表(如['vision', 'audio', 'video'])") + is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型(使用特殊API调用)") config = Column(JSON, comment="模型配置参数") # - temperature : 控制生成文本的随机性。值越高,输出越随机、越有创造性;值越低,输出越确定、越保守。 # - top_p : 一种替代 temperature 的采样方法,控制模型从概率最高的词中选择的范围。 @@ -118,6 +121,11 @@ class ModelApiKey(BaseModel): api_key = Column(String, nullable=False, comment="API密钥") api_base = Column(String, comment="API基础URL") + # 模型能力参数 + capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"), + comment="模型能力列表(如['vision', 'audio', 'video'])") + is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型(使用特殊API调用)") + # 配置参数 config = Column(JSON, comment="API Key特定配置") @@ -155,6 +163,9 @@ class ModelBase(Base): tags = Column(ARRAY(String), default=list, nullable=False, comment="模型标签(如['聊天', '创作'])") add_count = Column(Integer, default=0, nullable=False, comment="模型被用户添加的次数") created_at = Column(DateTime, default=datetime.datetime.now, comment="创建时间", server_default=func.now()) + capability = Column(ARRAY(String), default=list, nullable=False, server_default=text("'{}'::varchar[]"), + comment="模型能力列表(如['vision', 'audio', 'video'])") + is_omni = Column(Boolean, default=False, nullable=False, server_default="false", comment="是否为Omni模型(使用特殊API调用)") # 关联关系 configs = relationship("ModelConfig", back_populates="model_base", cascade="all, delete-orphan") diff --git a/api/app/models/ontology_class.py b/api/app/models/ontology_class.py index 528d934e..a8468090 100644 --- a/api/app/models/ontology_class.py +++ b/api/app/models/ontology_class.py @@ -9,7 +9,7 @@ Classes: import datetime import uuid -from sqlalchemy import Column, String, DateTime, Text, ForeignKey +from sqlalchemy import Column, String, DateTime, Text, ForeignKey, Boolean from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import relationship from app.db import Base @@ -25,6 +25,9 @@ class OntologyClass(Base): # 类型信息 class_name = Column(String(200), nullable=False, comment="类型名称") class_description = Column(Text, nullable=True, comment="类型描述") + + # 系统默认标识 + is_system_default = Column(Boolean, default=False, nullable=False, comment="是否为系统默认类型") # 外键:关联到本体场景 scene_id = Column(UUID(as_uuid=True), ForeignKey("ontology_scene.scene_id", ondelete="CASCADE"), nullable=False, index=True, comment="所属场景ID") diff --git a/api/app/models/ontology_scene.py b/api/app/models/ontology_scene.py index 350bfdd6..3ce42cad 100644 --- a/api/app/models/ontology_scene.py +++ b/api/app/models/ontology_scene.py @@ -9,7 +9,7 @@ Classes: import datetime import uuid -from sqlalchemy import Column, String, DateTime, Integer, Text, ForeignKey, UniqueConstraint +from sqlalchemy import Column, String, DateTime, Integer, Text, ForeignKey, UniqueConstraint, Boolean from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import relationship from app.db import Base @@ -28,6 +28,9 @@ class OntologyScene(Base): # 场景信息 scene_name = Column(String(200), nullable=False, comment="场景名称") scene_description = Column(Text, nullable=True, comment="场景描述") + + # 系统默认标识 + is_system_default = Column(Boolean, default=False, nullable=False, index=True, comment="是否为系统默认场景") # 外键:关联到工作空间 workspace_id = Column(UUID(as_uuid=True), ForeignKey("workspaces.id", ondelete="CASCADE"), nullable=False, index=True, comment="所属工作空间ID") diff --git a/api/app/repositories/knowledge_repository.py b/api/app/repositories/knowledge_repository.py index 681d1c10..e3832214 100644 --- a/api/app/repositories/knowledge_repository.py +++ b/api/app/repositories/knowledge_repository.py @@ -211,3 +211,46 @@ def get_total_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int except Exception as e: db_logger.error(f"Failed to query total knowledge base count: workspace_id={workspace_id} - {str(e)}") raise + + +def get_user_kb_chunk_num_by_workspace(db: Session, workspace_id: uuid.UUID) -> int: + """ + 根据workspace_id查询knowledges表中permission_id='Memory'(用户知识库)的chunk_num总和 + """ + db_logger.debug(f"Query user KB chunk_num by workspace_id: workspace_id={workspace_id}") + + try: + from sqlalchemy import func + result = db.query(func.sum(Knowledge.chunk_num)).filter( + Knowledge.workspace_id == workspace_id, + Knowledge.status == 1, + Knowledge.permission_id == "Memory" + ).scalar() + + total = result if result is not None else 0 + db_logger.info(f"User KB chunk_num query successful: workspace_id={workspace_id}, total={total}") + return total + except Exception as e: + db_logger.error(f"Failed to query user KB chunk_num: workspace_id={workspace_id} - {str(e)}") + raise + + +def get_non_user_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int: + """ + 根据workspace_id查询knowledges表中排除用户知识库(permission_id!='Memory')的数量 + """ + db_logger.debug(f"Query non-user KB count by workspace_id: workspace_id={workspace_id}") + + try: + count = db.query(Knowledge).filter( + Knowledge.workspace_id == workspace_id, + Knowledge.status == 1, + Knowledge.permission_id != "Memory" + ).count() + + db_logger.info(f"Non-user KB count query successful: workspace_id={workspace_id}, count={count}") + return count + except Exception as e: + db_logger.error(f"Failed to query non-user KB count: workspace_id={workspace_id} - {str(e)}") + raise + diff --git a/api/app/repositories/memory_config_repository.py b/api/app/repositories/memory_config_repository.py index 2dae51ef..22f13449 100644 --- a/api/app/repositories/memory_config_repository.py +++ b/api/app/repositories/memory_config_repository.py @@ -233,6 +233,7 @@ class MemoryConfigRepository: config_desc=params.config_desc, workspace_id=params.workspace_id, scene_id=params.scene_id, + pruning_scene=params.pruning_scene, llm_id=params.llm_id, embedding_id=params.embedding_id, rerank_id=params.rerank_id, diff --git a/api/app/repositories/model_repository.py b/api/app/repositories/model_repository.py index 2c513e82..f49227d3 100644 --- a/api/app/repositories/model_repository.py +++ b/api/app/repositories/model_repository.py @@ -428,19 +428,17 @@ class ModelConfigRepository: try: # 查询ModelConfig关联的ModelApiKey,筛选出匹配的model_config_id - model_config_ids = db.query(ModelConfig.id).join( - ModelBase, ModelConfig.model_id == ModelBase.id - ).filter( + model_config_ids = db.query(ModelConfig.id).filter( and_( or_( ModelConfig.tenant_id == tenant_id, ModelConfig.is_public ), - ModelBase.provider == provider, + ModelConfig.provider == provider, ModelConfig.is_active, ~ModelConfig.is_composite ) - ).distinct().all() + ).all() db_logger.debug(f"查询成功: 数量={len(model_config_ids)}") return [row[0] for row in model_config_ids] diff --git a/api/app/repositories/ontology_scene_repository.py b/api/app/repositories/ontology_scene_repository.py index 141b5d1c..0b357e41 100644 --- a/api/app/repositories/ontology_scene_repository.py +++ b/api/app/repositories/ontology_scene_repository.py @@ -374,7 +374,7 @@ class OntologySceneRepository: count = self.db.query(OntologyScene).filter( OntologyScene.scene_id == scene_id, - OntologyScene.workspace_id == workspace_id + (OntologyScene.workspace_id == workspace_id) | (OntologyScene.is_system_default == True) ).count() is_owner = count > 0 diff --git a/api/app/schemas/api_key_schema.py b/api/app/schemas/api_key_schema.py index d19cf061..c7ca1e55 100644 --- a/api/app/schemas/api_key_schema.py +++ b/api/app/schemas/api_key_schema.py @@ -15,7 +15,7 @@ class ApiKeyCreate(BaseModel): type: ApiKeyType = Field(..., description="API Key 类型") scopes: List[str] = Field(default_factory=list, description="权限范围列表") resource_id: Optional[uuid.UUID] = Field(None, description="关联资源ID") - rate_limit: Optional[int] = Field(10, ge=1, le=1000, description="QPS限制(请求/秒)") + rate_limit: Optional[int] = Field(100, ge=1, le=1000, description="QPS限制(请求/秒)") daily_request_limit: Optional[int] = Field(10000, description="日请求限制", ge=1) quota_limit: Optional[int] = Field(None, description="配额限制(总请求数)", ge=1) expires_at: Optional[datetime.datetime] = Field(None, description="过期时间") @@ -155,8 +155,7 @@ class ApiKey(BaseModel): return datetime.datetime.now() > self.expires_at @field_serializer('expires_at', 'last_used_at', 'created_at', 'updated_at') - @classmethod - def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]: + def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]: """将datetime转换为时间戳""" return datetime_to_timestamp(v) @@ -171,8 +170,7 @@ class ApiKeyStats(BaseModel): avg_response_time: Optional[float] = Field(None, description="平均响应时间(毫秒)") @field_serializer('last_used_at') - @classmethod - def serialize_datetime(cls, v: Optional[datetime.datetime]) -> Optional[int]: + def serialize_datetime(self, v: Optional[datetime.datetime]) -> Optional[int]: """将datetime转换为时间戳""" return datetime_to_timestamp(v) @@ -219,7 +217,6 @@ class ApiKeyLog(BaseModel): created_at: datetime.datetime @field_serializer('created_at') - @classmethod - def serialize_datetime(cls, v: datetime.datetime) -> int: + def serialize_datetime(self, v: datetime.datetime) -> int: """将datetime转换为时间戳""" return datetime_to_timestamp(v) diff --git a/api/app/schemas/app_schema.py b/api/app/schemas/app_schema.py index 8cf81b92..f073a200 100644 --- a/api/app/schemas/app_schema.py +++ b/api/app/schemas/app_schema.py @@ -5,6 +5,8 @@ from enum import Enum, StrEnum from pydantic import BaseModel, Field, ConfigDict, field_serializer, field_validator +from app.schemas.workflow_schema import WorkflowConfigCreate + # ---------- Multimodal File Support ---------- @@ -19,8 +21,14 @@ class FileType(StrEnum): def trans(cls, value: str) -> 'FileType': if value.startswith("image"): return cls.IMAGE - # TODO: other file type support - raise RuntimeError("Unsupport file type") + elif value.startswith("document"): + return cls.DOCUMENT + elif value.startswith("audio"): + return cls.AUDIO + elif value.startswith("video"): + return cls.VIDEO + else: + raise RuntimeError("Unsupport file type") class TransferMethod(str, Enum): @@ -35,6 +43,12 @@ class FileInput(BaseModel): transfer_method: TransferMethod = Field(..., description="传输方式: local_file/remote_url") upload_file_id: Optional[uuid.UUID] = Field(None, description="已上传文件ID(local_file时必填)") url: Optional[str] = Field(None, description="远程URL(remote_url时必填)") + file_type: Optional[str] = Field(None, description="具体文件格式(如image/jpg、audio/wav、document/docx、video/mp4)") + + def __init__(self, **data): + if "type" in data: + data['file_type'] = data['type'] + super().__init__(**data) @field_validator("type", mode="before") @classmethod @@ -196,6 +210,8 @@ class AppCreate(BaseModel): # only for type=multi_agent multi_agent_config: Optional[Dict[str, Any]] = None + workflow_config: Optional[WorkflowConfigCreate] = None + class AppUpdate(BaseModel): name: Optional[str] = None @@ -429,7 +445,7 @@ class AppChatRequest(BaseModel): user_id: Optional[str] = Field(default=None, description="用户ID(用于会话管理)") variables: Optional[Dict[str, Any]] = Field(default=None, description="自定义变量参数值") stream: bool = Field(default=False, description="是否流式返回") - files: Optional[List[FileInput]] = Field(default=None, description="附件列表(支持多文件)") + files: List[FileInput] = Field(default_factory=list, description="附件列表(支持多文件)") class DraftRunRequest(BaseModel): diff --git a/api/app/schemas/chunk_schema.py b/api/app/schemas/chunk_schema.py index cef9b9cb..ce8f70f2 100644 --- a/api/app/schemas/chunk_schema.py +++ b/api/app/schemas/chunk_schema.py @@ -46,6 +46,7 @@ class ChunkUpdate(BaseModel): class ChunkRetrieve(BaseModel): query: str kb_ids: list[uuid.UUID] + file_names_filter: list[str] | None = Field(None) similarity_threshold: float | None = Field(None) vector_similarity_weight: float | None = Field(None) top_k: int | None = Field(None) diff --git a/api/app/schemas/conversation_schema.py b/api/app/schemas/conversation_schema.py index 0fcbc718..13766ef6 100644 --- a/api/app/schemas/conversation_schema.py +++ b/api/app/schemas/conversation_schema.py @@ -86,6 +86,7 @@ class ChatResponse(BaseModel): """聊天响应(非流式)""" conversation_id: uuid.UUID message: str + message_id: str usage: Optional[Dict[str, Any]] = None elapsed_time: Optional[float] = None diff --git a/api/app/schemas/memory_config_schema.py b/api/app/schemas/memory_config_schema.py index 0b63844b..0c359d70 100644 --- a/api/app/schemas/memory_config_schema.py +++ b/api/app/schemas/memory_config_schema.py @@ -417,6 +417,7 @@ class MemoryConfig: # Ontology scene association scene_id: Optional[UUID] = None + ontology_classes: Optional[list] = field(default=None) def __post_init__(self): """Validate configuration after initialization.""" diff --git a/api/app/schemas/memory_storage_schema.py b/api/app/schemas/memory_storage_schema.py index 776d2783..046b79e7 100644 --- a/api/app/schemas/memory_storage_schema.py +++ b/api/app/schemas/memory_storage_schema.py @@ -232,14 +232,15 @@ class ConfigParamsCreate(BaseModel): # 创建配置参数模型(仅 body, # 本体场景关联(可选) scene_id: Optional[uuid.UUID] = Field(None, description="本体场景ID(UUID),关联ontology_scene表") + # 语义剪枝场景(由 service 层根据 scene_id 自动推导,值为关联场景的 scene_name,前端无需传入) + pruning_scene: Optional[str] = Field(None, description="语义剪枝场景,由 scene_id 对应的 scene_name 自动填充") + # 模型配置字段(可选,用于手动指定或自动填充) llm_id: Optional[str] = Field(None, description="LLM模型配置ID") embedding_id: Optional[str] = Field(None, description="嵌入模型配置ID") rerank_id: Optional[str] = Field(None, description="重排序模型配置ID") reflection_model_id: Optional[str] = Field(None, description="反思模型ID,默认与llm_id一致") emotion_model_id: Optional[str] = Field(None, description="情绪分析模型ID,默认与llm_id一致") - - class ConfigParamsDelete(BaseModel): # 删除配置参数模型(请求体) model_config = ConfigDict(populate_by_name=True, extra="forbid") # config_name: str = Field("配置名称", description="配置名称(字符串)") @@ -274,8 +275,8 @@ class ConfigUpdateExtracted(BaseModel): # 更新记忆萃取引擎配置参数 # 剪枝配置:与 runtime.json 中 pruning 段对应 pruning_enabled: Optional[bool] = Field(None, description="是否启动智能语义剪枝") - pruning_scene: Optional[Literal["education", "online_service", "outbound"]] = Field( - None, description="智能剪枝场景:education/online_service/outbound" + pruning_scene: Optional[str] = Field( + None, description="智能剪枝场景:education/online_service/outbound 或本体工程自定义场景" ) pruning_threshold: Optional[float] = Field( None, ge=0.0, le=0.9, description="智能语义剪枝阈值(0-0.9)" diff --git a/api/app/schemas/model_schema.py b/api/app/schemas/model_schema.py index 0c0bbeed..4f3878ce 100644 --- a/api/app/schemas/model_schema.py +++ b/api/app/schemas/model_schema.py @@ -21,6 +21,9 @@ class ModelConfigBase(BaseModel): is_active: bool = Field(True, description="是否激活") is_public: bool = Field(False, description="是否公开") load_balance_strategy: Optional[str] = Field(LoadBalanceStrategy.NONE.value, description="负载均衡策略") + capability: List[str] = Field(default_factory=list, description="模型能力列表") + is_omni: bool = Field(False, description="是否为Omni模型") + model_id: Optional[uuid.UUID] = Field(None, description="基础模型ID") class ApiKeyCreateNested(BaseModel): @@ -30,6 +33,8 @@ class ApiKeyCreateNested(BaseModel): provider: Optional[str] = Field(None, description="API Key提供商") api_key: str = Field(..., description="API密钥", max_length=500) api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) + capability: Optional[List[str]] = Field(None, description="模型能力列表") + is_omni: Optional[bool] = Field(None, description="是否为Omni模型") config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置") priority: str = Field("1", description="优先级", max_length=10) @@ -63,6 +68,8 @@ class ModelConfigUpdate(BaseModel): config: Optional[Dict[str, Any]] = Field(None, description="模型配置参数") is_active: Optional[bool] = Field(None, description="是否激活") is_public: Optional[bool] = Field(None, description="是否公开") + capability: Optional[List[str]] = Field(None, description="模型能力列表") + is_omni: Optional[bool] = Field(None, description="是否为Omni模型") class ModelConfig(ModelConfigBase): @@ -95,6 +102,8 @@ class ModelApiKeyCreateByProvider(BaseModel): api_key: str = Field(..., description="API密钥", max_length=500) api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) description: Optional[str] = Field(None, description="备注") + capability: Optional[List[str]] = Field(None, description="模型能力列表") + is_omni: Optional[bool] = Field(None, description="是否为Omni模型") config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置") is_active: bool = Field(True, description="是否激活") priority: str = Field("1", description="优先级", max_length=10) @@ -108,6 +117,8 @@ class ModelApiKeyBase(BaseModel): provider: ModelProvider = Field(..., description="API Key提供商") api_key: str = Field(..., description="API密钥", max_length=500) api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) + capability: Optional[List[str]] = Field(None, description="模型能力列表") + is_omni: Optional[bool] = Field(None, description="是否为Omni模型") config: Optional[Dict[str, Any]] = Field({}, description="API Key特定配置") is_active: bool = Field(True, description="是否激活") priority: str = Field("1", description="优先级", max_length=10) @@ -124,6 +135,8 @@ class ModelApiKeyUpdate(BaseModel): provider: Optional[ModelProvider] = Field(None, description="API Key提供商") api_key: Optional[str] = Field(None, description="API密钥", max_length=500) api_base: Optional[str] = Field(None, description="API基础URL", max_length=500) + capability: Optional[List[str]] = Field(None, description="模型能力列表") + is_omni: Optional[bool] = Field(None, description="是否为Omni模型") config: Optional[Dict[str, Any]] = Field(None, description="API Key特定配置") is_active: Optional[bool] = Field(None, description="是否激活") priority: Optional[str] = Field(None, description="优先级", max_length=10) @@ -270,6 +283,8 @@ class ModelBaseCreate(BaseModel): description: Optional[str] = Field(None, description="模型描述") is_official: bool = Field(True, description="是否供应商官方模型") tags: List[str] = Field(default_factory=list, description="模型标签") + capability: List[str] = Field(default_factory=list, description="模型能力列表(如['vision', 'audio', 'video'])") + is_omni: bool = Field(False, description="是否为Omni模型") class ModelBaseUpdate(BaseModel): @@ -282,6 +297,8 @@ class ModelBaseUpdate(BaseModel): is_deprecated: Optional[bool] = Field(None, description="是否弃用") is_official: Optional[bool] = Field(None, description="是否供应商官方模型") tags: Optional[List[str]] = Field(None, description="模型标签") + capability: Optional[List[str]] = Field(None, description="模型能力列表") + is_omni: Optional[bool] = Field(None, description="是否为Omni模型") class ModelBase(BaseModel): @@ -298,6 +315,8 @@ class ModelBase(BaseModel): is_official: bool tags: List[str] add_count: int + capability: List[str] = [] + is_omni: bool = False class ModelBaseQuery(BaseModel): diff --git a/api/app/schemas/multi_agent_schema.py b/api/app/schemas/multi_agent_schema.py index 8fba2929..3573e87c 100644 --- a/api/app/schemas/multi_agent_schema.py +++ b/api/app/schemas/multi_agent_schema.py @@ -64,14 +64,14 @@ class ExecutionConfig(BaseModel): class MultiAgentConfigCreate(BaseModel): """创建多 Agent 配置""" master_agent_id: uuid.UUID = Field(..., description="主 Agent ID") - master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称") + master_agent_name: Optional[str] = Field(default=None, max_length=100, description="主 Agent 名称") orchestration_mode: str = Field( default="collaboration", pattern="^(collaboration|supervisor)$", description="协作模式:collaboration(协作)| supervisor(监督)" ) sub_agents: List[SubAgentConfig] = Field(..., description="子 Agent 列表") - routing_rules: Optional[List[RoutingRule]] = Field(None, description="路由规则") + routing_rules: Optional[List[RoutingRule]] = Field(default=None, description="路由规则") execution_config: ExecutionConfig = Field(default_factory=ExecutionConfig, description="执行配置") aggregation_strategy: str = Field( default="merge", @@ -83,7 +83,7 @@ class MultiAgentConfigCreate(BaseModel): class MultiAgentConfigUpdate(BaseModel): """更新多 Agent 配置""" master_agent_id: Optional[uuid.UUID] = None - master_agent_name: Optional[str] = Field(None, max_length=100, description="主 Agent 名称") + master_agent_name: Optional[str] = Field(default=None, max_length=100, description="主 Agent 名称") default_model_config_id: Optional[uuid.UUID] = Field(None, description="默认模型配置ID") model_parameters: Optional[ModelParameters] = Field( None, diff --git a/api/app/schemas/ontology_schemas.py b/api/app/schemas/ontology_schemas.py index 88ecd712..905e65fe 100644 --- a/api/app/schemas/ontology_schemas.py +++ b/api/app/schemas/ontology_schemas.py @@ -241,6 +241,7 @@ class SceneResponse(BaseModel): created_at: datetime.datetime = Field(..., description="创建时间(毫秒时间戳)") updated_at: datetime.datetime = Field(..., description="更新时间(毫秒时间戳)") classes_count: int = Field(0, description="类型数量") + is_system_default: bool = Field(False, description="是否为系统默认场景") @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): @@ -462,6 +463,7 @@ class ClassListResponse(BaseModel): scene_id: UUID = Field(..., description="所属场景ID") scene_name: str = Field(..., description="场景名称") scene_description: Optional[str] = Field(None, description="场景描述") + is_system_default: bool = Field(False, description="是否为系统默认场景") items: List[ClassResponse] = Field(..., description="类型列表") diff --git a/api/app/schemas/workflow_schema.py b/api/app/schemas/workflow_schema.py index bdef825e..e580833f 100644 --- a/api/app/schemas/workflow_schema.py +++ b/api/app/schemas/workflow_schema.py @@ -18,7 +18,10 @@ class NodeConfig(BaseModel): class NodeDefinition(BaseModel): """节点定义""" id: str = Field(..., description="节点唯一标识") - type: str = Field(..., description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code") + type: str = Field( + ..., + description="节点类型: start, end, llm, agent, tool, condition, loop, transform, human, code" + ) name: str | None = Field(None, description="节点名称") cycle: str | None = Field(None, description="父循环节点id") description: str | None = Field(None, description="节点描述") @@ -30,12 +33,12 @@ class NodeDefinition(BaseModel): class EdgeDefinition(BaseModel): """边定义""" - id: str | None = Field(None, description="边唯一标识(可选)") + id: str | None = Field(default=None, description="边唯一标识(可选)") source: str = Field(..., description="源节点 ID") target: str = Field(..., description="目标节点 ID") - type: str | None = Field(None, description="边类型: normal, error") - condition: str | None = Field(None, description="条件表达式(条件边)") - label: str | None = Field(None, description="边标签") + type: str | None = Field(default=None, description="边类型: normal, error") + condition: str | None = Field(default=None, description="条件表达式(条件边)") + label: str | None = Field(default=None, description="边标签") class VariableDefinition(BaseModel): @@ -44,7 +47,7 @@ class VariableDefinition(BaseModel): type: str = Field(default="string", description="变量类型: string, number, boolean, object, array") required: bool = Field(default=False, description="是否必填") default: Any = Field(None, description="默认值") - description: str | None = Field(None, description="变量描述") + description: str | None = Field(default=None, description="变量描述") class ExecutionConfig(BaseModel): @@ -61,6 +64,13 @@ class TriggerConfig(BaseModel): config: dict[str, Any] = Field(default_factory=dict, description="触发器配置") +class WorkflowImportSave(BaseModel): + """工作流导入请求""" + temp_id: str + name: str + description: str | None = Field(default=None) + + # ==================== 工作流配置 ==================== class WorkflowConfigCreate(BaseModel): @@ -84,7 +94,7 @@ class WorkflowConfigUpdate(BaseModel): class WorkflowConfig(BaseModel): """工作流配置输出""" model_config = ConfigDict(from_attributes=True) - + id: uuid.UUID app_id: uuid.UUID nodes: list[dict[str, Any]] @@ -95,11 +105,11 @@ class WorkflowConfig(BaseModel): is_active: bool created_at: datetime.datetime updated_at: datetime.datetime - + @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("updated_at", when_used="json") def _serialize_updated_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -123,7 +133,8 @@ class WorkflowExecutionResponse(BaseModel): output_data: dict[str, Any] | None = Field(None, description="所有节点的详细输出数据") error_message: str | None = Field(None, description="错误信息") elapsed_time: float | None = Field(None, description="耗时(秒)") - token_usage: dict[str, Any] | None = Field(None, description="Token 使用情况 {prompt_tokens, completion_tokens, total_tokens}") + token_usage: dict[str, Any] | None = Field(None, + description="Token 使用情况 {prompt_tokens, completion_tokens, total_tokens}") class WorkflowExecutionStreamChunk(BaseModel): @@ -136,7 +147,7 @@ class WorkflowExecutionStreamChunk(BaseModel): class WorkflowExecution(BaseModel): """工作流执行记录输出""" model_config = ConfigDict(from_attributes=True) - + id: uuid.UUID workflow_config_id: uuid.UUID app_id: uuid.UUID @@ -156,15 +167,15 @@ class WorkflowExecution(BaseModel): token_usage: dict[str, Any] | None meta_data: dict[str, Any] created_at: datetime.datetime - + @field_serializer("started_at", when_used="json") def _serialize_started_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("completed_at", when_used="json") def _serialize_completed_at(self, dt: datetime.datetime | None): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None @@ -173,7 +184,7 @@ class WorkflowExecution(BaseModel): class WorkflowNodeExecution(BaseModel): """工作流节点执行记录输出""" model_config = ConfigDict(from_attributes=True) - + id: uuid.UUID execution_id: uuid.UUID node_id: str @@ -193,15 +204,15 @@ class WorkflowNodeExecution(BaseModel): cache_key: str | None meta_data: dict[str, Any] created_at: datetime.datetime - + @field_serializer("started_at", when_used="json") def _serialize_started_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("completed_at", when_used="json") def _serialize_completed_at(self, dt: datetime.datetime | None): return int(dt.timestamp() * 1000) if dt else None - + @field_serializer("created_at", when_used="json") def _serialize_created_at(self, dt: datetime.datetime): return int(dt.timestamp() * 1000) if dt else None diff --git a/api/app/services/agent_tools.py b/api/app/services/agent_tools.py index 3ca7bddd..a4768b51 100644 --- a/api/app/services/agent_tools.py +++ b/api/app/services/agent_tools.py @@ -263,8 +263,8 @@ def create_agent_invocation_tool( try: # 9. 调用 Agent - from app.services.draft_run_service import DraftRunService - draft_service = DraftRunService(db) + from app.services.draft_run_service import AgentRunService + draft_service = AgentRunService(db) result = await draft_service.run( agent_config=agent_config, diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 9723121d..f3cdde2a 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -10,25 +10,24 @@ from sqlalchemy.orm import Session from app.core.agent.agent_middleware import AgentMiddleware from app.core.agent.langchain_agent import LangChainAgent -from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger -from app.db import get_db, get_db_context -from app.models import MultiAgentConfig, AgentConfig, WorkflowConfig -from app.schemas import DraftRunRequest -from app.schemas.app_schema import FileInput -from app.services.tool_service import ToolService -from app.repositories.tool_repository import ToolRepository from app.db import get_db from app.models import MultiAgentConfig, AgentConfig +from app.models import WorkflowConfig +from app.repositories.tool_repository import ToolRepository +from app.schemas import DraftRunRequest +from app.schemas.app_schema import FileInput from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.services.conversation_service import ConversationService -from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool +from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool, \ + AgentRunService from app.services.draft_run_service import create_web_search_tool from app.services.model_service import ModelApiKeyService from app.services.multi_agent_orchestrator import MultiAgentOrchestrator -from app.services.workflow_service import WorkflowService from app.services.multimodal_service import MultimodalService +from app.services.tool_service import ToolService +from app.services.workflow_service import WorkflowService logger = get_business_logger() @@ -39,6 +38,8 @@ class AppChatService: def __init__(self, db: Session): self.db = db self.conversation_service = ConversationService(db) + self.agent_service = AgentRunService(db) + self.workflow_service = WorkflowService(db) async def agnet_chat( self, @@ -55,12 +56,10 @@ class AppChatService: files: Optional[List[FileInput]] = None # 新增:多模态文件 ) -> Dict[str, Any]: """聊天(非流式)""" - start_time = time.time() config_id = None - if variables is None: - variables = {} + variables = self.agent_service.prepare_variables(variables, config.variables) # 获取模型配置ID model_config_id = config.default_model_config_id @@ -79,74 +78,20 @@ class AppChatService: tools = [] # 获取工具服务 - tool_service = ToolService(self.db) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) - # 从配置中获取启用的工具 - if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list): - for tool_config in config.tools: - if tool_config.get("enabled", False): - # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) - if tool_instance: - if tool_instance.name == "baidu_search_tool" and not web_search: - continue - # 转换为LangChain工具 - langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) - tools.append(langchain_tool) - elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict): - web_tools = config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search: - if web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 加载技能关联的工具 - if hasattr(config, 'skills') and config.skills: - skills = config.skills - skill_enable = skills.get("enabled", False) - if skill_enable: - middleware = AgentMiddleware(skills=skills) - skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) - tools.extend(skill_tools) - logger.debug(f"已加载 {len(skill_tools)} 个技能工具") - - # 应用动态过滤 - if skill_configs: - tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, - tool_to_skill_map) - logger.debug(f"过滤后剩余 {len(tools)} 个工具") - active_prompts = AgentMiddleware.get_active_prompts( - activated_skill_ids, skill_configs - ) - system_prompt = f"{system_prompt}\n\n{active_prompts}" - - # 添加知识库检索工具 - knowledge_retrieval = config.knowledge_retrieval - if knowledge_retrieval: - knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) - kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] - if kb_ids: - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id) - tools.append(kb_tool) - - # 添加长期记忆工具 + tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id)) + skill_tools, skill_prompts = self.agent_service.load_skill_config(config.skills, message, tenant_id) + tools.extend(skill_tools) + if skill_prompts: + system_prompt = f"{system_prompt}\n\n{skill_prompts}" + tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)) memory_flag = False - if memory == True: - memory_config = config.memory - if memory_config.get("enabled") and user_id: - memory_flag = True - memory_tool = create_long_term_memory_tool(memory_config, user_id) - tools.append(memory_tool) + if memory: + memory_tools, memory_flag = self.agent_service.load_memory_config( + config.memory, user_id, storage_type, user_rag_memory_id + ) + tools.extend(memory_tools) # 获取模型参数 model_parameters = config.model_parameters @@ -157,6 +102,7 @@ class AppChatService: api_key=api_key_obj.api_key, provider=api_key_obj.provider, api_base=api_key_obj.api_base, + is_omni=api_key_obj.is_omni, temperature=model_parameters.get("temperature", 0.7), max_tokens=model_parameters.get("max_tokens", 2000), system_prompt=system_prompt, @@ -180,7 +126,7 @@ class AppChatService: # 处理多模态文件 processed_files = None if files: - multimodal_service = MultimodalService(self.db) + multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") @@ -198,7 +144,7 @@ class AppChatService: ) # 保存消息 - self.conversation_service.save_conversation_messages( + message_id = self.conversation_service.save_conversation_messages( conversation_id=conversation_id, user_message=message, assistant_message=result["content"], @@ -217,6 +163,7 @@ class AppChatService: return { "conversation_id": conversation_id, + "message_id": str(message_id), "message": result["content"], "usage": result.get("usage", { "prompt_tokens": 0, @@ -245,10 +192,13 @@ class AppChatService: try: start_time = time.time() config_id = None + message_id = uuid.uuid4() + yield f"event: start\ndata: {json.dumps({ + 'conversation_id': str(conversation_id), + "message_id": str(message_id) + }, ensure_ascii=False)}\n\n" - if variables is None: - variables = {} - + variables = self.agent_service.prepare_variables(variables, config.variables) # 获取模型配置ID model_config_id = config.default_model_config_id api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id) @@ -266,73 +216,22 @@ class AppChatService: tools = [] # 获取工具服务 - tool_service = ToolService(self.db) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) - if hasattr(config, 'tools') and config.tools and isinstance(config.tools, list): - for tool_config in config.tools: - if tool_config.get("enabled", False): - # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) - if tool_instance: - if tool_instance.name == "baidu_search_tool" and not web_search: - continue - # 转换为LangChain工具 - langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) - tools.append(langchain_tool) - elif hasattr(config, 'tools') and config.tools and isinstance(config.tools, dict): - web_tools = config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search: - if web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 加载技能关联的工具 - if hasattr(config, 'skills') and config.skills: - skills = config.skills - skill_enable = skills.get("enabled", False) - if skill_enable: - middleware = AgentMiddleware(skills=skills) - skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) - tools.extend(skill_tools) - logger.debug(f"已加载 {len(skill_tools)} 个技能工具") - - # 应用动态过滤 - if skill_configs: - tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, - tool_to_skill_map) - logger.debug(f"过滤后剩余 {len(tools)} 个工具") - active_prompts = AgentMiddleware.get_active_prompts( - activated_skill_ids, skill_configs - ) - system_prompt = f"{system_prompt}\n\n{active_prompts}" - - # 添加知识库检索工具 - knowledge_retrieval = config.knowledge_retrieval - if knowledge_retrieval: - knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) - kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] - if kb_ids: - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id) - tools.append(kb_tool) + tools.extend(self.agent_service.load_tools_config(config.tools, web_search, tenant_id)) + skill_tools, skill_prompts = self.agent_service.load_skill_config(config.skills, message, tenant_id) + tools.extend(skill_tools) + if skill_prompts: + system_prompt = f"{system_prompt}\n\n{skill_prompts}" + tools.extend(self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id)) # 添加长期记忆工具 memory_flag = False if memory: - memory_config = config.memory - if memory_config.get("enabled") and user_id: - memory_flag = True - memory_tool = create_long_term_memory_tool(memory_config, user_id) - tools.append(memory_tool) + memory_tools, memory_flag = self.agent_service.load_memory_config( + config.memory, user_id, storage_type, user_rag_memory_id + ) + tools.extend(memory_tools) # 获取模型参数 model_parameters = config.model_parameters @@ -343,6 +242,7 @@ class AppChatService: api_key=api_key_obj.api_key, provider=api_key_obj.provider, api_base=api_key_obj.api_base, + is_omni=api_key_obj.is_omni, temperature=model_parameters.get("temperature", 0.7), max_tokens=model_parameters.get("max_tokens", 2000), system_prompt=system_prompt, @@ -366,13 +266,10 @@ class AppChatService: # 处理多模态文件 processed_files = None if files: - multimodal_service = MultimodalService(self.db) + multimodal_service = MultimodalService(self.db, api_key_obj.provider, is_omni=api_key_obj.is_omni) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件") - # 发送开始事件 - yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n" - # 流式调用 Agent(支持多模态) full_content = "" total_tokens = 0 @@ -404,6 +301,7 @@ class AppChatService: ) self.conversation_service.add_message( + message_id=message_id, conversation_id=conversation_id, role="assistant", content=full_content, @@ -416,7 +314,7 @@ class AppChatService: ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id) # 发送结束事件 - end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)} + end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None} yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n" logger.info( @@ -435,7 +333,7 @@ class AppChatService: except Exception as e: logger.error(f"流式聊天失败: {str(e)}", exc_info=True) # 发送错误事件 - yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" + yield f"event: end\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" async def multi_agent_chat( self, @@ -481,7 +379,7 @@ class AppChatService: content=message ) - self.conversation_service.add_message( + ai_message = self.conversation_service.add_message( conversation_id=conversation_id, role="assistant", content=result.get("message", ""), @@ -489,16 +387,17 @@ class AppChatService: "mode": result.get("mode"), "elapsed_time": result.get("elapsed_time"), "usage": result.get("usage", { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0 - }) + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + }) } ) return { "conversation_id": conversation_id, "message": result.get("message", ""), + "message_id": str(ai_message.id), "usage": { "prompt_tokens": 0, "completion_tokens": 0, @@ -522,16 +421,14 @@ class AppChatService: """多 Agent 聊天(流式)""" start_time = time.time() - actual_config_id = None - config_id = actual_config_id if variables is None: variables = {} try: - + message_id = uuid.uuid4() # 发送开始事件 - yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id)}, ensure_ascii=False)}\n\n" + yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation_id), "message_id": str(message_id)}, ensure_ascii=False)}\n\n" full_content = "" total_tokens = 0 @@ -539,6 +436,7 @@ class AppChatService: # 2. 创建编排器 orchestrator = MultiAgentOrchestrator(self.db, config) + # 3. 流式执行任务 async for event in orchestrator.execute_stream( message=message, @@ -582,6 +480,7 @@ class AppChatService: ) self.conversation_service.add_message( + message_id=message_id, conversation_id=conversation_id, role="assistant", content=full_content, @@ -629,7 +528,6 @@ class AppChatService: user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: """聊天(非流式)""" - workflow_service = WorkflowService(self.db) payload = DraftRunRequest( message=message, variables=variables, @@ -637,7 +535,7 @@ class AppChatService: stream=True, user_id=user_id ) - return await workflow_service.run( + return await self.workflow_service.run( app_id=app_id, payload=payload, config=config, @@ -664,7 +562,6 @@ class AppChatService: ) -> AsyncGenerator[dict, None]: """聊天(流式)""" - workflow_service = WorkflowService(self.db) payload = DraftRunRequest( message=message, variables=variables, @@ -673,7 +570,7 @@ class AppChatService: user_id=user_id, files=files ) - async for event in workflow_service.run_stream( + async for event in self.workflow_service.run_stream( app_id=app_id, payload=payload, config=config, diff --git a/api/app/services/app_service.py b/api/app/services/app_service.py index f3c6260a..5a799937 100644 --- a/api/app/services/app_service.py +++ b/api/app/services/app_service.py @@ -232,7 +232,7 @@ class AppService: # 检查主 Agent 的模型配置 multi_agent_config.default_model_config_id = master_agent_release.default_model_config_id - model_api_key = ModelApiKeyService.get_a_api_key(self.db, multi_agent_config.default_model_config_id) + model_api_key = ModelApiKeyService.get_available_api_key(self.db, multi_agent_config.default_model_config_id) if not model_api_key: raise ResourceNotFoundException("模型配置", str(multi_agent_config.default_model_config_id)) @@ -321,6 +321,26 @@ class AppService: self.db.add(agent_cfg) logger.debug("Agent 配置已创建", extra={"app_id": str(app_id)}) + def _create_workflow_config( + self, + app_id: uuid.UUID, + data: app_schema.WorkflowConfigCreate, + now: datetime.datetime + ): + workflow_cfg = WorkflowConfig( + id=uuid.uuid4(), + app_id=app_id, + nodes=[node.model_dump() for node in data.nodes] if data.nodes else [], + edges=[edge.model_dump() for edge in data.edges] if data.edges else [], + variables=[var.model_dump() for var in data.variables] if data.variables else [], + execution_config=data.execution_config.model_dump() if data.execution_config else {}, + triggers=[trigger.model_dump() for trigger in data.triggers] if data.triggers else [], + is_active=True, + created_at=now, + updated_at=now + ) + self.db.add(workflow_cfg) + def _create_multi_agent_config( self, app_id: uuid.UUID, @@ -532,6 +552,9 @@ class AppService: if app.type == "multi_agent" and data.multi_agent_config: self._create_multi_agent_config(app.id, data.multi_agent_config, now) + if app.type == "workflow" and data.workflow_config: + self._create_workflow_config(app.id, data.workflow_config, now) + self.db.commit() self.db.refresh(app) @@ -680,7 +703,7 @@ class AppService: self.db.flush() # 如果是 agent 类型,复制 AgentConfig - if source_app.type == "agent": + if source_app.type == AppType.AGENT: source_config = self.db.query(AgentConfig).filter( AgentConfig.app_id == source_app.id ).first() @@ -702,6 +725,50 @@ class AppService: ) self.db.add(new_config) + elif source_app.type == AppType.WORKFLOW: + source_config = self.db.query(WorkflowConfig).filter( + WorkflowConfig.app_id == source_app.id + ).first() + + if source_config: + new_config = WorkflowConfig( + id=uuid.uuid4(), + app_id=new_app.id, + nodes=source_config.nodes.copy() if source_config.nodes else [], + edges=source_config.edges.copy() if source_config.edges else [], + variables=source_config.variables.copy() if source_config.variables else [], + execution_config=source_config.execution_config.copy() if source_config.execution_config else {}, + triggers=source_config.triggers.copy() if source_config.triggers else [], + is_active=True, + created_at=now, + updated_at=now, + ) + self.db.add(new_config) + + elif source_app.type == AppType.MULTI_AGENT: + source_config = self.db.query(MultiAgentConfig).filter( + MultiAgentConfig.app_id == source_app.id + ).first() + + if source_config: + new_config = MultiAgentConfig( + id=uuid.uuid4(), + app_id=new_app.id, + master_agent_id=source_config.master_agent_id, + master_agent_name=source_config.master_agent_name, + default_model_config_id=source_config.default_model_config_id, + model_parameters=source_config.model_parameters, + orchestration_mode=source_config.orchestration_mode, + sub_agents=source_config.sub_agents.copy() if source_config.sub_agents else [], + routing_rules=source_config.routing_rules.copy() if source_config.routing_rules else None, + execution_config=source_config.execution_config.copy() if source_config.execution_config else {}, + aggregation_strategy=source_config.aggregation_strategy, + is_active=True, + created_at=now, + updated_at=now, + ) + self.db.add(new_config) + self.db.commit() self.db.refresh(new_app) @@ -968,7 +1035,7 @@ class AppService: config = self.db.scalars(stmt).first() try: - config_memory=config.memory + config_memory = config.memory if 'memory_content' in config_memory: config.memory['memory_config_id'] = config.memory.pop('memory_content') except: @@ -1189,9 +1256,9 @@ class AppService: # ==================== 记忆配置提取方法 ==================== def _extract_memory_config_id( - self, - app_type: str, - config: Dict[str, Any] + self, + app_type: str, + config: Dict[str, Any] ) -> Tuple[Optional[uuid.UUID], bool]: """从发布配置中提取 memory_config_id(委托给 MemoryConfigService) @@ -1205,13 +1272,13 @@ class AppService: - is_legacy_int: 是否检测到旧格式 int 数据,需要回退到工作空间默认配置 """ from app.services.memory_config_service import MemoryConfigService - + service = MemoryConfigService(self.db) return service.extract_memory_config_id(app_type, config) def _get_workspace_default_memory_config_id( - self, - workspace_id: uuid.UUID + self, + workspace_id: uuid.UUID ) -> Optional[uuid.UUID]: """获取工作空间的默认记忆配置ID @@ -1222,22 +1289,22 @@ class AppService: Optional[uuid.UUID]: 默认记忆配置ID,如果不存在则返回 None """ from app.services.memory_config_service import MemoryConfigService - + service = MemoryConfigService(self.db) config = service.get_workspace_default_config(workspace_id) - + if not config: logger.warning( f"工作空间没有可用的记忆配置: workspace_id={workspace_id}" ) return None - + return config.config_id def _update_endusers_memory_config( - self, - app_id: uuid.UUID, - memory_config_id: uuid.UUID + self, + app_id: uuid.UUID, + memory_config_id: uuid.UUID ) -> int: """批量更新应用下所有终端用户的 memory_config_id @@ -1249,13 +1316,13 @@ class AppService: int: 更新的终端用户数量 """ from app.repositories.end_user_repository import EndUserRepository - + repo = EndUserRepository(self.db) updated_count = repo.batch_update_memory_config_id( app_id=app_id, memory_config_id=memory_config_id ) - + return updated_count # ==================== 应用发布管理 ==================== @@ -1403,7 +1470,7 @@ class AppService: # 提取记忆配置ID并更新终端用户 memory_config_id, is_legacy_int = self._extract_memory_config_id(app.type, config) - + # 如果检测到旧格式 int 数据,回退到工作空间默认配置 if is_legacy_int and not memory_config_id: memory_config_id = self._get_workspace_default_memory_config_id(app.workspace_id) @@ -1412,7 +1479,7 @@ class AppService: f"发布时使用工作空间默认记忆配置(旧数据兼容): app_id={app_id}, " f"workspace_id={app.workspace_id}, memory_config_id={memory_config_id}" ) - + if memory_config_id: updated_count = self._update_endusers_memory_config(app_id, memory_config_id) logger.info( @@ -1537,7 +1604,7 @@ class AppService: # 提取记忆配置ID并更新终端用户 memory_config_id, is_legacy_int = self._extract_memory_config_id(release.type, release.config) - + # 如果检测到旧格式 int 数据,回退到工作空间默认配置 if is_legacy_int and not memory_config_id: memory_config_id = self._get_workspace_default_memory_config_id(app.workspace_id) @@ -1546,7 +1613,7 @@ class AppService: f"回滚时使用工作空间默认记忆配置(旧数据兼容): app_id={app_id}, " f"workspace_id={app.workspace_id}, memory_config_id={memory_config_id}" ) - + if memory_config_id: updated_count = self._update_endusers_memory_config(app_id, memory_config_id) logger.info( @@ -1768,372 +1835,6 @@ class AppService: return shares - # ==================== 试运行功能 ==================== - - async def draft_run( - self, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None - ) -> Dict[str, Any]: - """试运行 Agent(使用当前草稿配置) - - Args: - app_id: 应用ID - message: 用户消息 - conversation_id: 会话ID(用于多轮对话) - user_id: 用户ID(用于会话管理) - variables: 自定义变量参数值 - workspace_id: 工作空间ID(用于权限验证) - - Returns: - Dict: 包含 AI 回复和元数据的字典 - - Raises: - ResourceNotFoundException: 当应用不存在时 - BusinessException: 当应用类型不支持或配置缺失时 - """ - from app.services.draft_run_service import DraftRunService - - logger.info("试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]}) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - - if not agent_cfg: - raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 3. 获取模型配置 - model_config = None - if agent_cfg.default_model_config_id: - from app.models import ModelConfig - model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id) - - if not model_config: - raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 4. 调用试运行服务 - logger.debug( - "准备调用试运行服务", - extra={ - "app_id": str(app_id), - "model": model_config.name, - "has_conversation_id": bool(conversation_id), - "has_variables": bool(variables) - } - ) - - draft_service = DraftRunService(self.db) - result = await draft_service.run( - agent_config=agent_cfg, - model_config=model_config, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables - ) - - logger.debug( - "试运行服务返回结果", - extra={ - "result_type": str(type(result)), - "result_keys": list(result.keys()) if isinstance(result, dict) else "not_dict", - "has_message": "message" in result if isinstance(result, dict) else False, - "has_conversation_id": "conversation_id" in result if isinstance(result, dict) else False - } - ) - - logger.info( - "试运行完成", - extra={ - "app_id": str(app_id), - "elapsed_time": result.get("elapsed_time"), - "model": model_config.name - } - ) - - return result - - async def draft_run_stream( - self, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None - ): - """试运行 Agent(流式返回) - - Args: - app_id: 应用ID - message: 用户消息 - conversation_id: 会话ID(用于多轮对话) - user_id: 用户ID(用于会话管理) - variables: 自定义变量参数值 - workspace_id: 工作空间ID(用于权限验证) - - Yields: - str: SSE 格式的事件数据 - - Raises: - ResourceNotFoundException: 当应用不存在时 - BusinessException: 当应用类型不支持或配置缺失时 - """ - from app.services.draft_run_service import DraftRunService - - logger.info("流式试运行 Agent", extra={"app_id": str(app_id), "user_message": message[:50]}) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - - if not agent_cfg: - raise BusinessException("Agent 配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 3. 获取模型配置 - model_config = None - if agent_cfg.default_model_config_id: - from app.models import ModelConfig - model_config = self.db.get(ModelConfig, agent_cfg.default_model_config_id) - - if not model_config: - raise BusinessException("模型配置不存在,无法试运行", BizCode.AGENT_CONFIG_MISSING) - - # 4. 调用流式试运行服务 - draft_service = DraftRunService(self.db) - async for event in draft_service.run_stream( - agent_config=agent_cfg, - model_config=model_config, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables - ): - yield event - - # ==================== 多模型对比试运行 ==================== - - async def draft_run_compare( - self, - *, - app_id: uuid.UUID, - message: str, - models: List[app_schema.ModelCompareItem], - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None, - parallel: bool = True, - timeout: int = 60 - ) -> Dict[str, Any]: - """多模型对比试运行 - - Args: - app_id: 应用ID - message: 用户消息 - models: 要对比的模型列表 - conversation_id: 会话ID - user_id: 用户ID - variables: 变量参数 - workspace_id: 工作空间ID - parallel: 是否并行执行 - timeout: 超时时间(秒) - - Returns: - Dict: 对比结果 - """ - from app.models import ModelConfig - from app.services.draft_run_service import DraftRunService - - logger.info( - "多模型对比试运行", - extra={ - "app_id": str(app_id), - "model_count": len(models), - "parallel": parallel - } - ) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - if not agent_cfg: - raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - - # 3. 准备所有模型配置 - model_configs = [] - for model_item in models: - model_config = self.db.get(ModelConfig, model_item.model_config_id) - if not model_config: - raise ResourceNotFoundException("模型配置", str(model_item.model_config_id)) - - # 合并参数:agent配置参数 + 请求覆盖参数 - merged_parameters = { - **(agent_cfg.model_parameters or {}), - **(model_item.model_parameters or {}) - } - - model_configs.append({ - "model_config": model_config, - "parameters": merged_parameters, - "label": model_item.label or model_config.name, - "model_config_id": model_item.model_config_id - }) - - # 4. 调用 DraftRunService 的对比方法 - draft_service = DraftRunService(self.db) - result = await draft_service.run_compare( - agent_config=agent_cfg, - models=model_configs, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - parallel=parallel, - timeout=timeout - ) - - logger.info( - "多模型对比完成", - extra={ - "app_id": str(app_id), - "successful": result["successful_count"], - "failed": result["failed_count"] - } - ) - - return result - - async def draft_run_compare_stream( - self, - *, - app_id: uuid.UUID, - message: str, - models: List[app_schema.ModelCompareItem], - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None, - parallel: bool = True, - timeout: int = 60 - ): - """多模型对比试运行(流式返回) - - Args: - app_id: 应用ID - message: 用户消息 - models: 要对比的模型列表 - conversation_id: 会话ID - user_id: 用户ID - variables: 变量参数 - workspace_id: 工作空间ID - timeout: 超时时间(秒) - - Yields: - str: SSE 格式的事件数据 - """ - from app.models import ModelConfig - from app.services.draft_run_service import DraftRunService - - logger.info( - "多模型对比流式试运行", - extra={ - "app_id": str(app_id), - "model_count": len(models) - } - ) - - # 1. 验证应用 - app = self._get_app_or_404(app_id) - if app.type != "agent": - raise BusinessException("只有 Agent 类型应用支持试运行", BizCode.APP_TYPE_NOT_SUPPORTED) - - # 只读操作,允许访问共享应用 - self._validate_app_accessible(app, workspace_id) - - # 2. 获取 Agent 配置 - stmt = select(AgentConfig).where(AgentConfig.app_id == app_id) - agent_cfg = self.db.scalars(stmt).first() - if not agent_cfg: - raise BusinessException("Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - - # 3. 准备所有模型配置 - model_configs = [] - for model_item in models: - model_config = self.db.get(ModelConfig, model_item.model_config_id) - if not model_config: - raise ResourceNotFoundException("模型配置", str(model_item.model_config_id)) - - # 合并参数:agent配置参数 + 请求覆盖参数 - merged_parameters = { - **(agent_cfg.model_parameters or {}), - **(model_item.model_parameters or {}) - } - - model_configs.append({ - "model_config": model_config, - "parameters": merged_parameters, - "label": model_item.label or model_config.name, - "model_config_id": model_item.model_config_id - }) - - # 4. 调用 DraftRunService 的流式对比方法 - draft_service = DraftRunService(self.db) - async for event in draft_service.run_compare_stream( - agent_config=agent_cfg, - models=model_configs, - message=message, - workspace_id=workspace_id, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - parallel=parallel, - timeout=timeout - ): - yield event - - logger.info( - "多模型对比流式完成", - extra={"app_id": str(app_id)} - ) - - # ==================== 向后兼容的函数接口 ==================== # 保留函数接口以兼容现有代码,但内部使用服务类 @@ -2255,53 +1956,6 @@ def get_apps_by_ids( return service.get_apps_by_ids(app_ids, workspace_id) -# ==================== 向后兼容的函数接口 ==================== - -async def draft_run( - db: Session, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None -) -> Dict[str, Any]: - """试运行 Agent(向后兼容接口)""" - service = AppService(db) - return await service.draft_run( - app_id=app_id, - message=message, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - workspace_id=workspace_id - ) - - -async def draft_run_stream( - db: Session, - *, - app_id: uuid.UUID, - message: str, - conversation_id: Optional[str] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - workspace_id: Optional[uuid.UUID] = None -): - """试运行 Agent 流式返回(向后兼容接口)""" - service = AppService(db) - async for event in service.draft_run_stream( - app_id=app_id, - message=message, - conversation_id=conversation_id, - user_id=user_id, - variables=variables, - workspace_id=workspace_id - ): - yield event - - # ==================== 依赖注入函数 ==================== def get_app_service( diff --git a/api/app/services/audio_transcription_service.py b/api/app/services/audio_transcription_service.py new file mode 100644 index 00000000..11d13f38 --- /dev/null +++ b/api/app/services/audio_transcription_service.py @@ -0,0 +1,101 @@ +""" +音频转文本服务 + +支持的服务商: +- DashScope (阿里云通义千问) +- OpenAI Whisper +""" +import httpx + +from app.core.logging_config import get_business_logger + +logger = get_business_logger() + + +class AudioTranscriptionService: + """音频转文本服务""" + + @staticmethod + async def transcribe_dashscope(audio_url: str, api_key: str) -> str: + """ + 使用阿里云通义千问语音识别服务转换音频为文本 + + Args: + audio_url: 音频文件 URL + api_key: DashScope API Key + + Returns: + str: 转录的文本 + """ + try: + async with httpx.AsyncClient(timeout=60.0) as client: + response = await client.post( + "https://dashscope.aliyuncs.com/api/v1/services/audio/asr/transcription", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "X-DashScope-Async": "enable", + }, + json={ + "model": "paraformer-v2", + "input": { + "file_urls": [audio_url] + }, + "parameters": { + "language_hints": ["zh", "en", "ja", "yue", "ko", "de", "fr", "ru"] + } + } + ) + response.raise_for_status() + result = response.json() + + if result.get("output", {}).get("results"): + text = result["output"]["results"][0].get("transcription_text", "") + logger.info(f"音频转文本成功: {len(text)} 字符") + return text + + return "[音频转文本失败]" + + except Exception as e: + logger.error(f"DashScope 音频转文本失败: {e}") + return f"[音频转文本失败: {str(e)}]" + + @staticmethod + async def transcribe_openai(audio_url: str, api_key: str) -> str: + """ + 使用 OpenAI Whisper 转换音频为文本 + + Args: + audio_url: 音频文件 URL + api_key: OpenAI API Key + + Returns: + str: 转录的文本 + """ + try: + # 下载音频文件 + async with httpx.AsyncClient(timeout=60.0) as client: + audio_response = await client.get(audio_url) + audio_response.raise_for_status() + audio_data = audio_response.content + + # 调用 Whisper API + files = {"file": ("audio.mp3", audio_data, "audio/mpeg")} + data = {"model": "whisper-1"} + + response = await client.post( + "https://api.openai.com/v1/audio/transcriptions", + headers={"Authorization": f"Bearer {api_key}"}, + files=files, + data=data + ) + response.raise_for_status() + result = response.json() + + text = result.get("text", "") + logger.info(f"音频转文本成功: {len(text)} 字符") + return text + + except Exception as e: + logger.error(f"OpenAI Whisper 音频转文本失败: {e}") + return f"[音频转文本失败: {str(e)}]" diff --git a/api/app/services/collaborative_orchestrator.py b/api/app/services/collaborative_orchestrator.py index 00a731de..68181cd1 100644 --- a/api/app/services/collaborative_orchestrator.py +++ b/api/app/services/collaborative_orchestrator.py @@ -445,6 +445,7 @@ class CollaborativeOrchestrator: "provider": api_key_config.provider, "api_key": api_key_config.api_key, "api_base": api_key_config.api_base, + "is_omni": api_key_config.is_omni, "model_parameters": config_data.get("model_parameters", {}), "api_key_id": api_key_config.id } @@ -511,6 +512,7 @@ class CollaborativeOrchestrator: provider=agent_config["provider"], api_key=agent_config["api_key"], base_url=agent_config.get("api_base"), + is_omni=agent_config.get("is_omni", False), extra_params=extra_params ) diff --git a/api/app/services/conversation_service.py b/api/app/services/conversation_service.py index 553aefc4..aff5f533 100644 --- a/api/app/services/conversation_service.py +++ b/api/app/services/conversation_service.py @@ -178,7 +178,8 @@ class ConversationService: conversation_id: uuid.UUID, role: str, content: str, - meta_data: Optional[dict] = None + meta_data: Optional[dict] = None, + message_id: Optional[uuid.UUID] = None, ) -> Message: """ Add a message to a conversation using UnitOfWork. @@ -188,6 +189,7 @@ class ConversationService: role (str): Role of the message sender ('user' or 'assistant'). content (str): Message content. meta_data (Optional[dict]): Optional metadata. + message_id (Optional[uuid.UUID]): Optional custom message UUID. Returns: Message: Newly created Message instance. @@ -198,6 +200,7 @@ class ConversationService: ) message = Message( + id=message_id if message_id else uuid.uuid4(), conversation_id=conversation_id, role=role, content=content, @@ -317,7 +320,7 @@ class ConversationService: content=user_message ) - self.add_message( + ai_message = self.add_message( conversation_id=conversation_id, role="assistant", content=assistant_message, @@ -332,6 +335,7 @@ class ConversationService: "assistant_message_length": len(assistant_message) } ) + return ai_message.id def delete_conversation( self, diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index 8977710b..5026bf27 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -17,15 +17,18 @@ from sqlalchemy.orm import Session from app.celery_app import celery_app from app.core.agent.agent_middleware import AgentMiddleware +from app.core.agent.langchain_agent import LangChainAgent from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.core.rag.nlp.search import knowledge_retrieval +from app.db import get_db_context from app.models import AgentConfig, ModelConfig from app.repositories.tool_repository import ToolRepository from app.schemas.app_schema import FileInput from app.schemas.prompt_schema import PromptMessageRole, render_prompt_message from app.services import task_service +from app.services.conversation_service import ConversationService from app.services.langchain_tool_server import Search from app.services.memory_agent_service import MemoryAgentService from app.services.model_parameter_merger import ModelParameterMerger @@ -52,8 +55,12 @@ class LongTermMemoryInput(BaseModel): description="经过优化重写的查询问题。请将用户的原始问题重写为更合适的检索形式,包含关键词,上下文和具体描述,注意错词检查并且改写") -def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str, storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None): +def create_long_term_memory_tool( + memory_config: Dict[str, Any], + end_user_id: str, + storage_type: Optional[str] = None, + user_rag_memory_id: Optional[str] = None +): """创建记忆工具, @@ -61,6 +68,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str memory_config: 记忆配置 end_user_id: 用户ID storage_type: 存储类型(可选) + user_rag_memory_id: 用户RAG记忆ID(可选) Returns: 长期记忆工具 @@ -96,9 +104,7 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str """ logger.info(f" 长期记忆工具被调用!question={question}, user={end_user_id}") try: - from app.db import get_db - db = next(get_db()) - try: + with get_db_context() as db: memory_content = asyncio.run( MemoryAgentService().read_memory( end_user_id=end_user_id, @@ -120,9 +126,6 @@ def create_long_term_memory_tool(memory_config: Dict[str, Any], end_user_id: str logger.info(f"读取任务状态:{status}") if memory_content: memory_content = memory_content['answer'] - - finally: - db.close() logger.info(f'用户ID:Agent:{end_user_id}') logger.debug("调用长期记忆 API", extra={"question": question, "end_user_id": end_user_id}) @@ -188,7 +191,9 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id): """从知识库中检索相关信息。当用户的问题需要参考知识库、文档或历史记录时,使用此工具进行检索。 Args: - query: 需要检索的问题或关键词 + kb_config: 知识库配置 + kb_ids: 知识库ID列表 + user_id: 用户ID Returns: 检索到的相关知识内容 @@ -232,17 +237,141 @@ def create_knowledge_retrieval_tool(kb_config, kb_ids, user_id): return knowledge_retrieval_tool -class DraftRunService: - """试运行服务类""" +class AgentRunService: + """Agent运行服务类""" def __init__(self, db: Session): - """初始化试运行服务 + """Agent运行服务 Args: db: 数据库会话 """ self.db = db + @staticmethod + def prepare_variables( + input_vars: dict | None, + variables_config: dict + ) -> dict: + input_vars = input_vars or {} + for variable in variables_config: + if variable.get("required") and variable.get("name") not in input_vars: + raise ValueError(f"The required parameter '{variable.get('name')}' was not provided") + return input_vars + + def load_tools_config(self, tools_config, web_search, tenant_id) -> list: + """加载工具配置""" + if not tools_config: + return [] + tools = [] + tool_service = ToolService(self.db) + + if tools_config and isinstance(tools_config, list): + for tool_config in tools_config: + if tool_config.get("enabled", False): + # 根据工具名称查找工具实例 + tool_instance = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id) + if tool_instance: + if tool_instance.name == "baidu_search_tool" and not web_search: + continue + # 转换为LangChain工具 + langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) + tools.append(langchain_tool) + elif tools_config and isinstance(tools_config, dict): + web_search_choice = tools_config.get("web_search", {}) + web_search_enable = web_search_choice.get("enabled", False) + if web_search and web_search_enable: + search_tool = create_web_search_tool({}) + tools.append(search_tool) + + logger.debug( + "已添加网络搜索工具", + extra={ + "tool_count": len(tools) + } + ) + return tools + + def load_skill_config( + self, + skills_config: dict | None, + message: str, tenant_id + ) -> tuple[list, str]: + if not skills_config: + return [], "" + + tools = [] + skill_prompts = "" + skill_enable = skills_config.get("enabled", False) + if skill_enable: + middleware = AgentMiddleware(skills=skills_config) + skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) + tools.extend(skill_tools) + logger.debug(f"已加载 {len(skill_tools)} 个技能工具") + + if skill_configs: + tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, + tool_to_skill_map) + logger.debug(f"过滤后剩余 {len(tools)} 个工具") + skill_prompts = AgentMiddleware.get_active_prompts( + activated_skill_ids, skill_configs + ) + + return tools, skill_prompts + + def load_knowledge_retrieval_config( + self, + knowledge_retrieval_config: dict | None, + user_id + ) -> list: + if not knowledge_retrieval_config: + return [] + + tools = [] + knowledge_bases = knowledge_retrieval_config.get("knowledge_bases", []) + kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) + if kb_ids: + # 创建知识库检索工具 + kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval_config, kb_ids, user_id) + tools.append(kb_tool) + + logger.debug( + "已添加知识库检索工具", + extra={ + "kb_ids": kb_ids, + "tool_count": len(tools) + } + ) + return tools + + def load_memory_config( + self, + memory_config: dict | None, + user_id, + storage_type, + user_rag_memory_id + ) -> tuple[list, bool]: + """加载长期记忆配置""" + if not memory_config: + return [], False + + tools = [] + if memory_config.get("enabled"): + if user_id: + # 创建长期记忆工具 + memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type, + user_rag_memory_id) + tools.append(memory_tool) + + logger.debug( + "已添加长期记忆工具", + extra={ + "user_id": user_id, + "tool_count": len(tools) + } + ) + return tools, bool(memory_config.get("enabled")) + async def run( self, *, @@ -270,19 +399,21 @@ class DraftRunService: conversation_id: 会话ID(用于多轮对话) user_id: 用户ID variables: 自定义变量参数值 + storage_type: 存储类型(可选) + user_rag_memory_id: 用户RAG记忆ID(可选) + web_search: 是否启用网络搜索(默认True) + memory: 是否启用长期记忆(默认True) + sub_agent: 是否为子代理调用(默认False) + files: 多模态文件列表(可选) Returns: Dict: 包含 AI 回复和元数据的字典 """ - memory_flag = False - - print('===========', storage_type) - - print(user_id) - if variables == None: variables = {} - from app.core.agent.langchain_agent import LangChainAgent - start_time = time.time() + tools_config: dict | list | None = agent_config.tools + skills_config: dict | None = agent_config.skills + knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval + memory_config: dict | None = agent_config.memory try: # 1. 获取 API Key 配置 @@ -302,112 +433,40 @@ class DraftRunService: agent_config=agent_config ) - items_params = variables + if sub_agent: + variables = self.prepare_variables(variables, agent_config.variables) + else: + # FIXME: subagent input valid + variables = variables or {} + system_prompt = render_prompt_message( - agent_config.system_prompt, # 修正拼写错误 + agent_config.system_prompt, PromptMessageRole.USER, - items_params + variables ) # 3. 处理系统提示词(支持变量替换) system_prompt = system_prompt.get_text_content() or "你是一个专业的AI助手" - print('系统提示词:', system_prompt) # 4. 准备工具列表 tools = [] - tool_service = ToolService(self.db) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) # 从配置中获取启用的工具 - if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list): - if hasattr(agent_config, 'tools') and agent_config.tools: - for tool_config in agent_config.tools: - print("+" * 50) - print(f"agent_config:{agent_config}") - print(f"tool_config:{tool_config}") - if tool_config.get("enabled", False): - # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) - if tool_instance: - if tool_instance.name == "baidu_search_tool" and not web_search: - continue - # 转换为LangChain工具 - langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) - tools.append(langchain_tool) - elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict): - web_tools = agent_config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search: - if web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) - - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 加载技能关联的工具 - if hasattr(agent_config, 'skills') and agent_config.skills: - skills = agent_config.skills - skill_enable = skills.get("enabled", False) - if skill_enable: - middleware = AgentMiddleware(skills=skills) - skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) - tools.extend(skill_tools) - logger.debug(f"已加载 {len(skill_tools)} 个技能工具") - - # 应用动态过滤 - if skill_configs: - tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, - tool_to_skill_map) - logger.debug(f"过滤后剩余 {len(tools)} 个工具") - active_prompts = AgentMiddleware.get_active_prompts( - activated_skill_ids, skill_configs - ) - system_prompt = f"{system_prompt}\n\n{active_prompts}" - - # 添加知识库检索工具 - if agent_config.knowledge_retrieval: - kb_config = agent_config.knowledge_retrieval - knowledge_bases = kb_config.get("knowledge_bases", []) - kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) - if kb_ids: - # 创建知识库检索工具 - kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id) - tools.append(kb_tool) - - logger.debug( - "已添加知识库检索工具", - extra={ - "kb_ids": kb_ids, - "tool_count": len(tools) - } - ) - + tools.extend(self.load_tools_config(tools_config, web_search, tenant_id)) + skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id) + tools.extend(skill_tools) + if skill_prompts: + system_prompt = f"{system_prompt}\n\n{skill_prompts}" + tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)) # 添加长期记忆工具 + memory_flag = False if memory: - if agent_config.memory and agent_config.memory.get("enabled"): - memory_flag = True - - memory_config = agent_config.memory - if user_id: - # 创建长期记忆工具 - memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type, - user_rag_memory_id) - tools.append(memory_tool) - - logger.debug( - "已添加长期记忆工具", - extra={ - "user_id": user_id, - "tool_count": len(tools) - } - ) + memory_tools, memory_flag = self.load_memory_config( + memory_config, user_id, storage_type, user_rag_memory_id + ) + tools.extend(memory_tools) # 4. 创建 LangChain Agent agent = LangChainAgent( @@ -415,6 +474,7 @@ class DraftRunService: api_key=api_key_config["api_key"], provider=api_key_config.get("provider", "openai"), api_base=api_key_config.get("api_base"), + is_omni=api_key_config.get("is_omni", False), temperature=effective_params.get("temperature", 0.7), max_tokens=effective_params.get("max_tokens", 2000), system_prompt=system_prompt, @@ -431,7 +491,7 @@ class DraftRunService: # 6. 加载历史消息 history = [] - if agent_config.memory and agent_config.memory.get("enabled"): + if memory_config and memory_config.get("enabled"): history = await self._load_conversation_history( conversation_id=conversation_id, max_history=agent_config.memory.get("max_history", 10) @@ -442,7 +502,7 @@ class DraftRunService: if files: # 获取 provider 信息 provider = api_key_config.get("provider", "openai") - multimodal_service = MultimodalService(self.db, provider=provider) + multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False)) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") @@ -481,7 +541,7 @@ class DraftRunService: ModelApiKeyService.record_api_key_usage(self.db, api_key_config.get("api_key_id")) # 9. 保存会话消息 - if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"): + if not sub_agent and memory_config and memory_config.get("enabled"): await self._save_conversation_message( conversation_id=conversation_id, user_message=message, @@ -556,16 +616,21 @@ class DraftRunService: Yields: str: SSE 格式的事件数据 """ - memory_flag = False - if variables == None: variables = {} - - from app.core.agent.langchain_agent import LangChainAgent + tools_config: dict | list | None = agent_config.tools + skills_config: dict | None = agent_config.skills + knowledge_retrieval_config: dict | None = agent_config.knowledge_retrieval + memory_config: dict | None = agent_config.memory start_time = time.time() try: # 1. 获取 API Key 配置 api_key_config = await self._get_api_key(model_config.id) + if not sub_agent: + variables = self.prepare_variables(variables, agent_config.variables) + else: + # FIXME: subagent input valid + variables = variables or {} # 2. 合并模型参数 effective_params = ModelParameterMerger.get_effective_parameters( @@ -587,95 +652,22 @@ class DraftRunService: # 4. 准备工具列表 tools = [] - tool_service = ToolService(self.db) tenant_id = ToolRepository.get_tenant_id_by_workspace_id(self.db, str(workspace_id)) # 从配置中获取启用的工具 - if hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, list): - for tool_config in agent_config.tools: - # print("+"*50) - # print(f"agent_config:{agent_config}") - # print(f"tool_config:{tool_config}") - if tool_config.get("enabled", False): - # 根据工具名称查找工具实例 - tool_instance = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) - if tool_instance: - if tool_instance.name == "baidu_search_tool" and not web_search: - continue - # 转换为LangChain工具 - langchain_tool = tool_instance.to_langchain_tool(tool_config.get("operation", None)) - tools.append(langchain_tool) - elif hasattr(agent_config, 'tools') and agent_config.tools and isinstance(agent_config.tools, dict): - web_tools = agent_config.tools - web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled", False) - if web_search: - if web_search_enable: - search_tool = create_web_search_tool({}) - tools.append(search_tool) + tools.extend(self.load_tools_config(tools_config, web_search, tenant_id)) + skill_tools, skill_prompts = self.load_skill_config(skills_config, message, tenant_id) + tools.extend(skill_tools) + if skill_prompts: + system_prompt = f"{system_prompt}\n\n{skill_prompts}" + tools.extend(self.load_knowledge_retrieval_config(knowledge_retrieval_config, user_id)) - logger.debug( - "已添加网络搜索工具", - extra={ - "tool_count": len(tools) - } - ) - - # 加载技能关联的工具 - if hasattr(agent_config, 'skills') and agent_config.skills: - skills = agent_config.skills - skill_enable = skills.get("enabled", False) - if skill_enable: - middleware = AgentMiddleware(skills=skills) - skill_tools, skill_configs, tool_to_skill_map = middleware.load_skill_tools(self.db, tenant_id) - tools.extend(skill_tools) - logger.debug(f"已加载 {len(skill_tools)} 个技能工具") - - # 应用动态过滤 - if skill_configs: - tools, activated_skill_ids = middleware.filter_tools(tools, message, skill_configs, - tool_to_skill_map) - logger.debug(f"过滤后剩余 {len(tools)} 个工具") - active_prompts = AgentMiddleware.get_active_prompts( - activated_skill_ids, skill_configs - ) - system_prompt = f"{system_prompt}\n\n{active_prompts}" - - # 添加知识库检索工具 - if agent_config.knowledge_retrieval: - kb_config = agent_config.knowledge_retrieval - knowledge_bases = kb_config.get("knowledge_bases", []) - kb_ids = bool(knowledge_bases and knowledge_bases[0].get("kb_id")) - if kb_ids: - # 创建知识库检索工具 - kb_tool = create_knowledge_retrieval_tool(kb_config, kb_ids, user_id) - tools.append(kb_tool) - - logger.debug( - "已添加知识库检索工具", - extra={ - "kb_ids": kb_ids, - "tool_count": len(tools) - } - ) # 添加长期记忆工具 + memory_flag = False if memory: - if agent_config.memory and agent_config.memory.get("enabled"): - memory_flag = True - memory_config = agent_config.memory - if user_id: - # 创建长期记忆工具 - memory_tool = create_long_term_memory_tool(memory_config, user_id, storage_type, - user_rag_memory_id) - tools.append(memory_tool) - - logger.debug( - "已添加长期记忆工具", - extra={ - "user_id": user_id, - "tool_count": len(tools) - } - ) + memory_tools, memory_flag = self.load_memory_config(memory_config, user_id, storage_type, + user_rag_memory_id) + tools.extend(memory_tools) # 4. 创建 LangChain Agent agent = LangChainAgent( @@ -683,6 +675,7 @@ class DraftRunService: api_key=api_key_config["api_key"], provider=api_key_config.get("provider", "openai"), api_base=api_key_config.get("api_base"), + is_omni=api_key_config.get("is_omni", False), temperature=effective_params.get("temperature", 0.7), max_tokens=effective_params.get("max_tokens", 2000), system_prompt=system_prompt, @@ -700,10 +693,10 @@ class DraftRunService: # 6. 加载历史消息 history = [] - if agent_config.memory and agent_config.memory.get("enabled"): + if memory_config and memory_config.get("enabled"): history = await self._load_conversation_history( conversation_id=conversation_id, - max_history=agent_config.memory.get("max_history", 10) + max_history=memory_config.get("max_history", 10) ) # 6. 处理多模态文件 @@ -711,7 +704,7 @@ class DraftRunService: if files: # 获取 provider 信息 provider = api_key_config.get("provider", "openai") - multimodal_service = MultimodalService(self.db, provider=provider) + multimodal_service = MultimodalService(self.db, provider=provider, is_omni=api_key_config.get("is_omni", False)) processed_files = await multimodal_service.process_files(files) logger.info(f"处理了 {len(processed_files)} 个文件,provider={provider}") @@ -761,7 +754,7 @@ class DraftRunService: }) # 10. 保存会话消息 - if not sub_agent and agent_config.memory and agent_config.memory.get("enabled"): + if not sub_agent and memory_config and memory_config.get("enabled"): await self._save_conversation_message( conversation_id=conversation_id, user_message=message, @@ -809,7 +802,7 @@ class DraftRunService: """ return f"event: {event}\ndata: {json.dumps(data, ensure_ascii=False)}\n\n" - async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict[str, str]: + async def _get_api_key(self, model_config_id: uuid.UUID) -> Dict: """获取模型的 API Key Args: @@ -846,7 +839,8 @@ class DraftRunService: "provider": api_key.provider, "api_key": api_key.api_key, "api_base": api_key.api_base, - "api_key_id": api_key.id + "api_key_id": api_key.id, + "is_omni": api_key.is_omni } async def _ensure_conversation( @@ -966,7 +960,6 @@ class DraftRunService: List[Dict]: 历史消息列表 """ try: - from app.services.conversation_service import ConversationService conversation_service = ConversationService(self.db) history = conversation_service.get_conversation_history( @@ -1486,6 +1479,15 @@ class DraftRunService: "conversation_id": returned_conversation_id, "content": chunk })) + + if event_type == "error" and event_data: + await event_queue.put(self._format_sse_event("model_error", { + "model_index": idx, + "model_config_id": model_config_id, + "label": model_label, + "conversation_id": returned_conversation_id, + "error": event_data.get("error", "未知错误") + })) except Exception as e: logger.warning(f"解析流式事件失败: {e}") finally: @@ -1670,41 +1672,3 @@ class DraftRunService: "total_time": sum(r.get("elapsed_time", 0) for r in results) } ) - - -async def draft_run( - db: Session, - *, - agent_config: AgentConfig, - model_config: ModelConfig, - message: str, - user_id: Optional[str] = None, - kb_ids: Optional[List[str]] = None, - similarity_threshold: float = 0.7, - top_k: int = 3 -) -> Dict[str, Any]: - """试运行 Agent(便捷函数) - - Args: - db: 数据库会话 - agent_config: Agent 配置 - model_config: 模型配置 - message: 用户消息 - user_id: 用户ID - kb_ids: 知识库ID列表 - similarity_threshold: 相似度阈值 - top_k: 检索返回的文档数量 - - Returns: - Dict: 包含 AI 回复和元数据的字典 - """ - service = DraftRunService(db) - return await service.run( - agent_config=agent_config, - model_config=model_config, - message=message, - user_id=user_id, - kb_ids=kb_ids, - similarity_threshold=similarity_threshold, - top_k=top_k - ) diff --git a/api/app/services/handoffs_service.py b/api/app/services/handoffs_service.py index e490eea4..8418fe31 100644 --- a/api/app/services/handoffs_service.py +++ b/api/app/services/handoffs_service.py @@ -544,6 +544,7 @@ def convert_multi_agent_config_to_handoffs( provider=model_api_key.provider, api_key=model_api_key.api_key, base_url=model_api_key.api_base, + is_omni=model_api_key.is_omni, extra_params={ "temperature": 0.7, "max_tokens": 2000, diff --git a/api/app/services/langchain_tool_server.py b/api/app/services/langchain_tool_server.py index f44e4cdc..2c151956 100644 --- a/api/app/services/langchain_tool_server.py +++ b/api/app/services/langchain_tool_server.py @@ -9,6 +9,8 @@ load_dotenv() # 读取web_search环境变量 web_search_value = os.getenv('web_search') + + def Search(query): url = "https://qianfan.baidubce.com/v2/ai_search/chat/completions" api_key = web_search_value @@ -18,23 +20,24 @@ def Search(query): "role": "user", "content": query } - ], #搜索输入 - "edition":"standard", #搜索版本。默认为standard。可选值:standard:完整版本。lite:标准版本,对召回规模和精排条数简化后的版本,时延表现更好,效果略弱于完整版。 - "search_source": "baidu_search_v2", #使用的搜索引擎版本 - "resource_type_filter": [{"type": "web","top_k": 20}], #支持设置网页、视频、图片、阿拉丁搜索模态,网页top_k最大取值为50,视频top_k最大为10,图片top_k最大为30,阿拉丁top_k最大为5 + ], # 搜索输入 + "edition": "standard", # 搜索版本。默认为standard。可选值:standard:完整版本。lite:标准版本,对召回规模和精排条数简化后的版本,时延表现更好,效果略弱于完整版。 + "search_source": "baidu_search_v2", # 使用的搜索引擎版本 + "resource_type_filter": [{"type": "web", "top_k": 20}], + # 支持设置网页、视频、图片、阿拉丁搜索模态,网页top_k最大取值为50,视频top_k最大为10,图片top_k最大为30,阿拉丁top_k最大为5 "search_filter": { "range": { "page_time": { - "gte": "now-1w/d", #时间查询参数,大于或等于 - "lt": "now/d", #时间查询参数,小于 - "gt": "", #时间查询参数,大于 - "lte": "" #时间查询参数,小于或等于 + "gte": "now-1w/d", # 时间查询参数,大于或等于 + "lt": "now/d", # 时间查询参数,小于 + "gt": "", # 时间查询参数,大于 + "lte": "" # 时间查询参数,小于或等于 } } }, - "block_websites":["tieba.baidu.com"], #需要屏蔽的站点列表 - "search_recency_filter":"week", #根据网页发布时间进行筛选,可填值为:week,month,semiyear,year - "enable_full_content":True #是否输出网页完整原文 + "block_websites": ["tieba.baidu.com"], # 需要屏蔽的站点列表 + "search_recency_filter": "week", # 根据网页发布时间进行筛选,可填值为:week,month,semiyear,year + "enable_full_content": True # 是否输出网页完整原文 }, ensure_ascii=False) headers = { 'Content-Type': 'application/json', @@ -42,10 +45,10 @@ def Search(query): } response = requests.request("POST", url, headers=headers, data=payload.encode("utf-8")).json() - content=[] + content = [] for i in response['references']: - title=i['title'] - snippet=i['snippet'] - content.append(title+';'+snippet) - content='。'.join(content) - return content \ No newline at end of file + title = i['title'] + snippet = i['snippet'] + content.append(title + ';' + snippet) + content = '。'.join(content) + return content diff --git a/api/app/services/llm_router.py b/api/app/services/llm_router.py index e56ad5aa..02895d6b 100644 --- a/api/app/services/llm_router.py +++ b/api/app/services/llm_router.py @@ -414,6 +414,7 @@ class LLMRouter: provider=api_key_config.provider, api_key=api_key_config.api_key, base_url=api_key_config.api_base, + is_omni=api_key_config.is_omni, temperature=0.3, max_tokens=500 ) diff --git a/api/app/services/master_agent_router.py b/api/app/services/master_agent_router.py index 3cf3ecc3..b0f43b51 100644 --- a/api/app/services/master_agent_router.py +++ b/api/app/services/master_agent_router.py @@ -392,6 +392,7 @@ class MasterAgentRouter: provider=api_key_config.provider, api_key=api_key_config.api_key, base_url=api_key_config.api_base, + is_omni=api_key_config.is_omni, extra_params = extra_params ) diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 1f3667a6..f272c541 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -13,7 +13,6 @@ TODO: Refactor get_end_user_connected_config """ import json import os -import re import time import uuid from typing import Any, AsyncGenerator, Dict, List, Optional @@ -35,12 +34,10 @@ from app.core.memory.agent.utils.messages_tools import ( reorder_output_results, ) from app.core.memory.agent.utils.type_classifier import status_typle -from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数 -from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags +from app.core.memory.analytics.hot_memory_tags import get_interest_distribution from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType -from app.repositories.memory_short_repository import ShortTermMemoryRepository from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_agent_schema import Write_UserInput from app.schemas.memory_config_schema import ConfigurationError @@ -69,7 +66,8 @@ class MemoryAgentService: logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}") # 记录成功的操作 if audit_logger: - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=True, + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, + success=True, duration=duration, details={"message_length": len(message)}) return context else: @@ -88,8 +86,6 @@ class MemoryAgentService: raise ValueError(f"写入失败: {messages}") - - def extract_tool_call_info(self, event: Dict) -> bool: """Extract tool call information from event""" last_message = event["messages"][-1] @@ -271,7 +267,8 @@ class MemoryAgentService: logger.info("Log streaming completed, cleaning up resources") # LogStreamer uses context manager for file handling, so cleanup is automatic - async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID]|int, db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str: + async def write_memory(self, end_user_id: str, messages: list[dict], config_id: Optional[uuid.UUID] | int, + db: Session, storage_type: str, user_rag_memory_id: str, language: str = "zh") -> str: """ Process write operation with config_id @@ -300,7 +297,8 @@ class MemoryAgentService: config_id = connected_config.get("memory_config_id") logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}") if config_id is None and workspace_id is None: - raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") + raise ValueError( + f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") except Exception as e: if "No memory configuration found" in str(e): raise # Re-raise our specific error @@ -331,7 +329,8 @@ class MemoryAgentService: # Log failed operation if audit_logger: duration = time.time() - start_time - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, + success=False, duration=duration, error=error_msg) raise ValueError(error_msg) @@ -351,9 +350,9 @@ class MemoryAgentService: langchain_messages.append(HumanMessage(content=msg['content'])) elif msg['role'] == 'assistant': langchain_messages.append(AIMessage(content=msg['content'])) - print(100*'-') + print(100 * '-') print(langchain_messages) - print(100*'-') + print(100 * '-') # 初始状态 - 包含所有必要字段 initial_state = { "messages": langchain_messages, @@ -375,29 +374,28 @@ class MemoryAgentService: contents = massages.get('write_result') # Convert messages back to string for logging message_text = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) - return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, contents) + return self.writer_messages_deal(massagesstatus, start_time, end_user_id, config_id, message_text, + contents) except Exception as e: # Ensure proper error handling and logging error_msg = f"Write operation failed: {str(e)}" logger.error(error_msg) if audit_logger: duration = time.time() - start_time - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, success=False, duration=duration, error=error_msg) + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, + success=False, duration=duration, error=error_msg) raise ValueError(error_msg) - - - async def read_memory( - self, - end_user_id: str, - message: str, - history: List[Dict], - search_switch: str, - config_id: Optional[uuid.UUID]|int, - db: Session, - storage_type: str, - user_rag_memory_id: str) -> Dict: + self, + end_user_id: str, + message: str, + history: List[Dict], + search_switch: str, + config_id: Optional[uuid.UUID] | int, + db: Session, + storage_type: str, + user_rag_memory_id: str) -> Dict: """ Process read operation with config_id @@ -425,7 +423,7 @@ class MemoryAgentService: import time start_time = time.time() - ori_message= message + ori_message = message # Resolve config_id and workspace_id # Always get workspace_id from end_user for fallback, even if config_id is provided @@ -437,7 +435,8 @@ class MemoryAgentService: config_id = connected_config.get("memory_config_id") logger.info(f"Resolved config from end_user: config_id={config_id}, workspace_id={workspace_id}") if config_id is None and workspace_id is None: - raise ValueError(f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") + raise ValueError( + f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.") except Exception as e: if "No memory configuration found" in str(e): raise # Re-raise our specific error @@ -454,7 +453,6 @@ class MemoryAgentService: except ImportError: audit_logger = None - config_load_start = time.time() try: # Use a separate database session to avoid transaction failures @@ -562,34 +560,35 @@ class MemoryAgentService: from app.repositories.memory_short_repository import ( ShortTermMemoryRepository, ) - + retrieved_content = [] repo = ShortTermMemoryRepository(db) - + if str(search_switch) != "2": for intermediate in _intermediate_outputs: logger.debug(f"处理中间结果: {intermediate}") intermediate_type = intermediate.get('type', '') - + if intermediate_type == "search_result": query = intermediate.get('query', '') raw_results = intermediate.get('raw_results', {}) try: reranked_results = raw_results.get('reranked_results', []) - statements = [statement['statement'] for statement in reranked_results.get('statements', [])] + statements = [statement['statement'] for statement in + reranked_results.get('statements', [])] except Exception: statements = [] - + # 去重 statements = list(set(statements)) - + if query and statements: retrieved_content.append({query: statements}) - + # 如果 retrieved_content 为空,设置为空字符串 if retrieved_content == []: retrieved_content = '' - + # 只有当回答不是"信息不足"且不是快速检索时才保存 if '信息不足,无法回答。' != str(summary) and str(search_switch).strip() != "2": # 使用 upsert 方法 @@ -602,15 +601,17 @@ class MemoryAgentService: ) logger.info(f"成功保存短期记忆: end_user_id={end_user_id}, search_switch={search_switch}") else: - logger.debug(f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}") - + logger.debug( + f"跳过保存短期记忆: summary={summary[:50] if summary else 'None'}, search_switch={search_switch}") + except Exception as save_error: # 保存失败不应该影响主流程,只记录错误 logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True) # Log successful operation total_time = time.time() - start_time - logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)") + logger.info( + f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)") if audit_logger: duration = time.time() - start_time audit_logger.log_operation( @@ -641,7 +642,6 @@ class MemoryAgentService: ) raise ValueError(error_msg) - def get_messages_list(self, user_input: Write_UserInput) -> list[dict]: """ Get standardized message list from user input. @@ -657,41 +657,43 @@ class MemoryAgentService: """ from app.core.logging_config import get_api_logger logger = get_api_logger() - + if len(user_input.messages) == 0: logger.error("Validation failed: Message list cannot be empty") raise ValueError("Message list cannot be empty") - + for idx, msg in enumerate(user_input.messages): if not isinstance(msg, dict): logger.error(f"Validation failed: Message {idx} is not a dict: {type(msg)}") - raise ValueError(f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}") - + raise ValueError( + f"Message format error: Message must be a dictionary. Error message index: {idx}, type: {type(msg)}") + if 'role' not in msg: logger.error(f"Validation failed: Message {idx} missing 'role' field: {msg}") raise ValueError(f"Message format error: Message must contain 'role' field. Error message index: {idx}") - + if 'content' not in msg: logger.error(f"Validation failed: Message {idx} missing 'content' field: {msg}") - raise ValueError(f"Message format error: Message must contain 'content' field. Error message index: {idx}") - + raise ValueError( + f"Message format error: Message must contain 'content' field. Error message index: {idx}") + if msg['role'] not in ['user', 'assistant']: logger.error(f"Validation failed: Message {idx} invalid role: {msg['role']}") raise ValueError(f"Role must be 'user' or 'assistant', got: {msg['role']}. Message index: {idx}") - + if not msg['content'] or not msg['content'].strip(): logger.error(f"Validation failed: Message {idx} content is empty") raise ValueError(f"Message content cannot be empty. Message index: {idx}, role: {msg['role']}") - + logger.info(f"Validation successful: Structured message list, count: {len(user_input.messages)}") return user_input.messages async def classify_message_type( - self, - message: str, - config_id: UUID, - db: Session, - workspace_id: Optional[UUID] = None + self, + message: str, + config_id: UUID, + db: Session, + workspace_id: Optional[UUID] = None ) -> Dict: """ Determine the type of user message (read or write) @@ -719,14 +721,15 @@ class MemoryAgentService: status = await status_typle(message, memory_config.llm_model_id) logger.debug(f"Message type: {status}") return status + async def generate_summary_from_retrieve( - self, - end_user_id: str, - retrieve_info: str, - history: List[Dict], - query: str, - config_id: str, - db: Session + self, + end_user_id: str, + retrieve_info: str, + history: List[Dict], + query: str, + config_id: str, + db: Session ) -> str: """ 基于检索信息、历史对话和查询生成最终答案 @@ -761,9 +764,9 @@ class MemoryAgentService: if config_id is None: raise ValueError(f"Unable to determine memory configuration for end_user {end_user_id}: {e}") # If config_id was provided, continue without workspace_id fallback - + logger.info(f"Generating summary from retrieve info for query: {query[:50]}...") - + try: # 加载配置 config_service = MemoryConfigService(db) @@ -772,7 +775,7 @@ class MemoryAgentService: workspace_id=workspace_id, service_name="MemoryAgentService" ) - + # 导入必要的模块 from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import ( summary_llm, @@ -780,13 +783,13 @@ class MemoryAgentService: from app.core.memory.agent.models.summary_models import ( RetrieveSummaryResponse, ) - + # 构建状态对象 state = { "data": query, "memory_config": memory_config } - + # 直接调用 summary_llm 函数 answer = await summary_llm( state=state, @@ -797,21 +800,20 @@ class MemoryAgentService: response_model=RetrieveSummaryResponse, search_mode="1" ) - + logger.info(f"Successfully generated summary: {answer[:100] if answer else 'None'}...") return answer if answer else "信息不足,无法回答。" - + except Exception as e: logger.error(f"生成摘要失败: {str(e)}", exc_info=True) return "信息不足,无法回答。" - async def get_knowledge_type_stats( - self, - end_user_id: Optional[str] = None, - only_active: bool = True, - current_workspace_id: Optional[uuid.UUID] = None, - db: Session = None + self, + db: Session, + end_user_id: Optional[str] = None, + only_active: bool = True, + current_workspace_id: Optional[uuid.UUID] = None ) -> Dict[str, Any]: """ 统计知识库类型分布,包含: @@ -837,11 +839,6 @@ class MemoryAgentService: # 1. 统计 PostgreSQL 中的知识库类型 try: - if db is None: - from app.db import get_db - db_gen = get_db() - db = next(db_gen) - # 初始化所有标准类型为 0 for kb_type in KnowledgeType: result[kb_type.value] = 0 @@ -881,53 +878,50 @@ class MemoryAgentService: # 3. 计算知识库类型总和(不包括 memory) result["total"] = ( - result.get("General", 0) + - result.get("Web", 0) + - result.get("Third-party", 0) + - result.get("Folder", 0) + result.get("General", 0) + + result.get("Web", 0) + + result.get("Third-party", 0) + + result.get("Folder", 0) ) return result - - async def get_hot_memory_tags_by_user( - self, - end_user_id: Optional[str] = None, - limit: int = 20 + async def get_interest_distribution_by_user( + self, + end_user_id: Optional[str] = None, + limit: int = 5, + language: str = "zh" ) -> List[Dict[str, Any]]: """ - 获取指定用户的热门记忆标签 + 获取指定用户的兴趣分布标签。 + + 与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习等), + 过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。 参数: - - end_user_id: 用户ID(可选),对应Neo4j中的end_user_id字段 + - end_user_id: 用户ID(必填) - limit: 返回标签数量限制 + - language: 输出语言("zh" 中文, "en" 英文) 返回格式: [ - {"name": "标签名", "frequency": 频次}, + {"name": "兴趣活动名", "frequency": 频次}, ... ] - - 注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译 """ try: - # by_user=False 表示按 end_user_id 查询(在Neo4j中,end_user_id就是用户维度) - tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False) - payload = [] - for tag, freq in tags: - payload.append({"name": tag, "frequency": freq}) - return payload + tags = await get_interest_distribution(end_user_id, limit=limit, by_user=False, language=language) + return [{"name": tag, "frequency": freq} for tag, freq in tags] except Exception as e: - logger.error(f"热门记忆标签查询失败: {e}") - raise Exception(f"热门记忆标签查询失败: {e}") - + logger.error(f"兴趣分布标签查询失败: {e}") + raise Exception(f"兴趣分布标签查询失败: {e}") async def get_user_profile( - self, - end_user_id: Optional[str] = None, - current_user_id: Optional[str] = None, - llm_id: Optional[str] = None, - db: Session = None + self, + end_user_id: Optional[str] = None, + current_user_id: Optional[str] = None, + llm_id: Optional[str] = None, + db: Session = None ) -> Dict[str, Any]: """ 获取用户详情,包含: @@ -1017,7 +1011,8 @@ class MemoryAgentService: # 定义标签提取的结构 class UserTags(BaseModel): - tags: list[str] = Field(..., description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友") + tags: list[str] = Field(..., + description="3个描述用户特征的标签,如:产品设计师、旅行爱好者、摄影发烧友") messages = [ { @@ -1160,7 +1155,6 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An ValueError: 当终端用户不存在或应用未发布时 """ import json as json_module - import uuid from sqlalchemy import select @@ -1192,14 +1186,14 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An # 3. 兼容旧数据:如果 memory_config_id 为空,从 AppRelease.config 获取并回填 memory_config_id_to_use = end_user.memory_config_id - + # 如果已有 memory_config_id,直接使用 # 如果新创建enduser,enduser.memory_config_id 必定为none # 那么使用从release中获取memory_config_id为预期行为,并且回填到 # end_user.memory_config_id if not memory_config_id_to_use: logger.info(f"end_user.memory_config_id is None, migrating from AppRelease.config") - + # 获取最新发布版本 stmt = ( select(AppRelease) @@ -1208,10 +1202,10 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An ) # TODO: change to current_release_id latest_release = db.scalars(stmt).first() - + if latest_release: config = latest_release.config or {} - + # 如果 config 是字符串,解析为字典 if isinstance(config, str): try: @@ -1219,22 +1213,22 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An except json_module.JSONDecodeError: logger.warning(f"Failed to parse config JSON for release {latest_release.id}") config = {} - + # 使用 MemoryConfigService 的提取方法 memory_config_service = MemoryConfigService(db) legacy_config_id, is_legacy_int = memory_config_service.extract_memory_config_id( app_type=app.type, config=config ) - + if legacy_config_id: # 验证提取的 config_id 是否存在于数据库中 from app.models.memory_config_model import MemoryConfig as MemoryConfigModel existing_config = db.get(MemoryConfigModel, legacy_config_id) - + if existing_config: memory_config_id_to_use = legacy_config_id - + # 回填到 end_user 表(lazy update) end_user.memory_config_id = memory_config_id_to_use db.commit() @@ -1268,7 +1262,8 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An "workspace_id": str(app.workspace_id) } - logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}") + logger.info( + f"Successfully retrieved connected config: memory_config_id={memory_config_id}, workspace_id={app.workspace_id}") return result @@ -1312,7 +1307,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 1. 批量查询所有 end_user 及其 app_id 和 memory_config_id end_users = db.query(EndUser).filter(EndUser.id.in_(end_user_ids)).all() - + # 创建映射 - 保留 EndUser 对象引用以便回填 end_user_map = {str(eu.id): eu for eu in end_users} user_data = {str(eu.id): {"app_id": eu.app_id, "memory_config_id": eu.memory_config_id} for eu in end_users} @@ -1336,15 +1331,15 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 3. 对于没有 memory_config_id 的用户,尝试从 AppRelease.config 提取 users_needing_migration = [ - (end_user_id, data["app_id"]) - for end_user_id, data in user_data.items() + (end_user_id, data["app_id"]) + for end_user_id, data in user_data.items() if not data["memory_config_id"] ] - + if users_needing_migration: # 批量获取相关应用的最新发布版本 migration_app_ids = list(set(app_id for _, app_id in users_needing_migration)) - + # 查询每个应用的最新活跃发布版本 app_latest_releases = {} for app_id in migration_app_ids: @@ -1357,18 +1352,18 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) latest_release = db.scalars(stmt).first() if latest_release: app_latest_releases[app_id] = latest_release - + # 为每个需要迁移的用户提取 memory_config_id config_service = MemoryConfigService(db) users_to_backfill = [] # [(end_user, memory_config_id), ...] - + for end_user_id, app_id in users_needing_migration: latest_release = app_latest_releases.get(app_id) if not latest_release: continue - + config = latest_release.config or {} - + # 如果 config 是字符串,解析为字典 if isinstance(config, str): try: @@ -1376,21 +1371,21 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) except json_module.JSONDecodeError: logger.warning(f"Failed to parse config JSON for release {latest_release.id}") continue - + # 使用 MemoryConfigService 的提取方法 app = app_map.get(app_id) if not app: continue - + legacy_config_id, is_legacy_int = config_service.extract_memory_config_id( app_type=app.type, config=config ) - + if legacy_config_id: # 更新 user_data 中的 memory_config_id user_data[end_user_id]["memory_config_id"] = legacy_config_id - + # 记录需要回填的用户(稍后验证配置存在后再回填) end_user = end_user_map.get(end_user_id) if end_user: @@ -1399,7 +1394,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) logger.info( f"Legacy int config detected for end_user {end_user_id}, will use workspace default" ) - + # 验证提取的 config_id 是否存在于数据库中 if users_to_backfill: config_ids_to_validate = list(set(cid for _, cid in users_to_backfill)) @@ -1407,17 +1402,17 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) MemoryConfig.config_id.in_(config_ids_to_validate) ).all() valid_config_ids = {mc.config_id for mc in existing_configs} - + # 只回填存在的配置 valid_backfills = [ - (eu, cid) for eu, cid in users_to_backfill + (eu, cid) for eu, cid in users_to_backfill if cid in valid_config_ids ] invalid_backfills = [ - (eu, cid) for eu, cid in users_to_backfill + (eu, cid) for eu, cid in users_to_backfill if cid not in valid_config_ids ] - + if invalid_backfills: invalid_ids = [str(cid) for _, cid in invalid_backfills] logger.warning( @@ -1426,7 +1421,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 清除 user_data 中无效的 config_id for eu, cid in invalid_backfills: user_data[str(eu.id)]["memory_config_id"] = None - + # 批量回填 end_user.memory_config_id if valid_backfills: for end_user, memory_config_id in valid_backfills: @@ -1437,7 +1432,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 4. 收集需要查询的 memory_config_id 和需要回退的 workspace_id direct_config_ids = [] workspace_fallback_users = [] # [(end_user_id, workspace_id), ...] - + for end_user_id, data in user_data.items(): if data["memory_config_id"]: direct_config_ids.append(data["memory_config_id"]) @@ -1455,7 +1450,7 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 6. 获取工作空间默认配置(需要逐个查询,因为 get_workspace_default_config 有复杂逻辑) workspace_default_configs = {} unique_workspace_ids = list(set(ws_id for _, ws_id in workspace_fallback_users)) - + if unique_workspace_ids: config_service = MemoryConfigService(db) for workspace_id in unique_workspace_ids: @@ -1466,11 +1461,11 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) # 7. 构建最终结果 for end_user_id, data in user_data.items(): memory_config = None - + # 优先使用 end_user 直接分配的配置 if data["memory_config_id"]: memory_config = config_id_to_config.get(data["memory_config_id"]) - + # 回退到工作空间默认配置 if not memory_config: workspace_id = app_to_workspace.get(data["app_id"]) @@ -1486,4 +1481,4 @@ def get_end_users_connected_configs_batch(end_user_ids: List[str], db: Session) result[end_user_id] = {"memory_config_id": None, "memory_config_name": None} logger.info(f"Successfully retrieved {len(result)} connected configs") - return result \ No newline at end of file + return result diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index a8c39a5a..f86fbed8 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -140,9 +140,11 @@ class MemoryAPIService: try: # Delegate to MemoryAgentService + # Convert string message to list[dict] format expected by MemoryAgentService + messages = message if isinstance(message, list) else [{"role": "user", "content": message}] result = await MemoryAgentService().write_memory( end_user_id=end_user_id, - messages=message, + messages=messages, config_id=config_id, db=self.db, storage_type=storage_type, @@ -151,9 +153,18 @@ class MemoryAPIService: logger.info(f"Memory write successful for end_user: {end_user_id}") + # result may be a string "success" or a dict with a "status" key + # Preserve the full dict so callers don't silently lose extra fields + # (e.g. error codes, metadata) returned by MemoryAgentService. + if isinstance(result, dict): + return { + **result, + "status": result.get("status", "unknown"), + "end_user_id": end_user_id, + } return { - "status": "success" if result == "success" else result, - "end_user_id": end_user_id + "status": result if isinstance(result, str) else "success", + "end_user_id": end_user_id, } except ConfigurationError as e: diff --git a/api/app/services/memory_config_service.py b/api/app/services/memory_config_service.py index ccfd5482..00757f8c 100644 --- a/api/app/services/memory_config_service.py +++ b/api/app/services/memory_config_service.py @@ -107,6 +107,40 @@ def _validate_config_id(config_id, db: Session = None): ) +# 专门场景的内置 key 集合,直接从 SceneConfigRegistry 派生,避免重复维护 +# 使用懒加载函数避免模块级循环导入 +def _get_builtin_pruning_scenes() -> set: + from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene_config import SceneConfigRegistry + return set(SceneConfigRegistry.get_all_scenes()) + + +def _load_ontology_classes(db: Session, scene_id, pruning_scene: Optional[str]) -> Optional[list]: + """当 pruning_scene 不是内置场景时,从 ontology_class 表加载类型名称列表。 + + Args: + db: 数据库会话 + scene_id: 本体场景 UUID + pruning_scene: 语义剪枝场景名称 + + Returns: + class_name 字符串列表,或 None(内置场景 / 无数据时) + """ + if not scene_id: + return None + # 内置场景走 SceneConfigRegistry,不需要注入类型列表 + if pruning_scene in _get_builtin_pruning_scenes(): + return None + try: + from app.repositories.ontology_class_repository import OntologyClassRepository + repo = OntologyClassRepository(db) + classes = repo.get_classes_by_scene(scene_id) + names = [c.class_name for c in classes if c.class_name] + return names if names else None + except Exception as e: + logger.warning(f"Failed to load ontology classes for scene_id={scene_id}: {e}") + return None + + class MemoryConfigService: """ Centralized service for memory configuration loading and validation. @@ -359,6 +393,7 @@ class MemoryConfigService: pruning_threshold=float(memory_config.pruning_threshold) if memory_config.pruning_threshold is not None else 0.5, # Ontology scene association scene_id=memory_config.scene_id, + ontology_classes=_load_ontology_classes(self.db, memory_config.scene_id, memory_config.pruning_scene), ) elapsed_ms = (time.time() - start_time) * 1000 diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index 8d6071cc..05aed57e 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -390,19 +390,59 @@ def get_rag_total_kb( current_user: User ) -> int: """ - 根据当前用户所在的workspace_id查询konwledges表所有不同id的数量 + 根据当前用户所在的workspace_id查询konwledges表中排除用户知识库(permission_id!='Memory')的数量 """ workspace_id = current_user.current_workspace_id - business_logger.info(f"获取RAG总知识库数: workspace_id={workspace_id}, 操作者: {current_user.username}") + business_logger.info(f"获取RAG总知识库数(排除用户知识库): workspace_id={workspace_id}, 操作者: {current_user.username}") try: - total_kb = knowledge_repository.get_total_kb_count_by_workspace(db, workspace_id) + total_kb = knowledge_repository.get_non_user_kb_count_by_workspace(db, workspace_id) business_logger.info(f"成功获取RAG总知识库数: {total_kb}") return total_kb except Exception as e: business_logger.error(f"获取RAG总知识库数失败: workspace_id={workspace_id} - {str(e)}") raise + +def get_rag_user_kb_total_chunk( + db: Session, + current_user: User +) -> int: + """ + 根据当前用户所在的workspace_id,从documents表统计所有用户知识库的chunk总数。 + 与 /end_users 接口保持同源:查询 file_name 匹配 end_user_id.txt 的文档 chunk_num 之和。 + """ + workspace_id = current_user.current_workspace_id + business_logger.info(f"获取用户知识库总chunk数(documents表): workspace_id={workspace_id}, 操作者: {current_user.username}") + + try: + from app.models.document_model import Document + from app.models.end_user_model import EndUser + from app.models.app_model import App + from sqlalchemy import func + + # 通过 App 关联取该 workspace 下所有 end_user_id + end_user_ids = [ + str(eid) for (eid,) in db.query(EndUser.id) + .join(App, EndUser.app_id == App.id) + .filter(App.workspace_id == workspace_id) + .all() + ] + if not end_user_ids: + return 0 + + file_names = [f"{uid}.txt" for uid in end_user_ids] + result = db.query(func.sum(Document.chunk_num)).filter( + Document.file_name.in_(file_names) + ).scalar() + + total_chunk = int(result or 0) + business_logger.info(f"成功获取用户知识库总chunk数: {total_chunk}") + return total_chunk + except Exception as e: + business_logger.error(f"获取用户知识库总chunk数失败: workspace_id={workspace_id} - {str(e)}") + raise + def get_current_user_total_chunk( end_user_id: str, db: Session, diff --git a/api/app/services/memory_konwledges_server.py b/api/app/services/memory_konwledges_server.py index 420f7ca1..b8961d33 100644 --- a/api/app/services/memory_konwledges_server.py +++ b/api/app/services/memory_konwledges_server.py @@ -1,45 +1,42 @@ # 修改 memory_konwledges_server.py 文件 -import asyncio import os -import re import uuid from pathlib import Path from typing import Optional -from pydantic import BaseModel, Field +from fastapi import HTTPException, status +from pydantic import BaseModel +from sqlalchemy.orm import Session +from app.celery_app import celery_app +from app.core.config import settings +from app.core.logging_config import get_api_logger from app.core.rag.models.chunk import DocumentChunk from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.response_utils import success -from app.db import get_db -from app.schemas import file_schema, document_schema -from fastapi import APIRouter, Depends, HTTPException, status, File, UploadFile, Query +from app.db import get_db_context from app.models.document_model import Document -import uuid -from sqlalchemy.orm import Session -from fastapi import HTTPException, status - -from app.core.config import settings from app.models.user_model import User +from app.schemas import file_schema, document_schema from app.schemas.file_schema import CustomTextFileCreate from app.services import document_service, file_service, knowledge_service -from app.celery_app import celery_app -from app.core.logging_config import get_api_logger -from app.schemas.file_schema import CustomTextFileCreate -from app.db import get_db + # 创建一个简单的用户类用于测试 api_logger = get_api_logger() + class ChunkCreate(BaseModel): content: str + + class SimpleUser: def __init__(self, user_id: str): # 确保ID是UUID类型 self.id = user_id self.username = user_id -'''解析''' + async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user: User): """ 解析指定文档 @@ -120,7 +117,7 @@ async def parse_document_by_id(document_id: uuid.UUID, db: Session, current_user api_logger.error(f"文档解析失败: document_id={document_id} - {str(e)}") raise -'''获取块ID''' + async def get_document_chunks( kb_id: uuid.UUID, document_id: uuid.UUID, @@ -198,7 +195,7 @@ async def get_document_chunks( return success(data=result, msg="文档块列表查询成功") -'''查找文档ID''' + def find_document_id_by_kb_and_filename( db: Session, kb_id: str, @@ -231,7 +228,7 @@ def find_document_id_by_kb_and_filename( except Exception as e: return None -'''获取知识库ID''' + def find_documents_by_kb_id( db: Session, kb_id: str, @@ -268,18 +265,14 @@ def find_documents_by_kb_id( except Exception as e: return [] -''''上传文件''' + async def memory_konwledges_up( kb_id: str, parent_id: str, create_data: file_schema.CustomTextFileCreate, - db: Session = Depends(get_db), - current_user: SimpleUser = None, # 修改为SimpleUser + db: Session, + current_user: SimpleUser, ): - # 如果没有提供current_user,则创建一个默认的 - if current_user is None: - current_user = SimpleUser("5d27df0b-7eec-4fa6-9f8b-0f9b7e852f60") - content_bytes = create_data.content.encode('utf-8') file_size = len(content_bytes) print(f"file size: {file_size} byte") @@ -350,8 +343,6 @@ async def memory_konwledges_up( return success(data=document_schema.Document.model_validate(db_document), msg="custom text upload successful") -'''添加新块''' - async def create_document_chunk( kb_id: uuid.UUID, @@ -417,7 +408,7 @@ async def create_document_chunk( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"查询文档块失败: {error_msg}" ) - + sort_id = sort_id + 1 # 5. 创建文档块 @@ -450,6 +441,7 @@ async def create_document_chunk( return success(data=chunk, msg="文档块创建成功") + async def write_rag(end_user_id, message, user_rag_memory_id): """ 将消息写入 RAG 知识库 @@ -483,15 +475,12 @@ async def write_rag(end_user_id, message, user_rag_memory_id): detail=f"知识库ID格式无效: {user_rag_memory_id}" ) - db_gen = get_db() - db = next(db_gen) - - try: + with get_db_context() as db: create_data = CustomTextFileCreate(title=end_user_id, content=message) current_user = SimpleUser(user_rag_memory_id) # 检查文档是否已存在 document = find_document_id_by_kb_and_filename(db=db, kb_id=user_rag_memory_id, file_name=f"{end_user_id}.txt") - print('======',document) + print('======', document) api_logger.info(f"查找文档结果: document_id={document}") if document is not None: # 文档已存在,直接添加新块 @@ -528,6 +517,3 @@ async def write_rag(end_user_id, message, user_rag_memory_id): else: api_logger.error(f"创建文档后无法找到文档ID: end_user_id={end_user_id}") return result - finally: - # 确保数据库会话被关闭 - db.close() \ No newline at end of file diff --git a/api/app/services/memory_short_service.py b/api/app/services/memory_short_service.py index fa3870f0..fa9623e0 100644 --- a/api/app/services/memory_short_service.py +++ b/api/app/services/memory_short_service.py @@ -1,22 +1,37 @@ +from typing import Dict, List + +from sqlalchemy.orm import Session from app.core.logging_config import get_api_logger -from app.db import get_db -from app.repositories.memory_short_repository import LongTermMemoryRepository -from app.repositories.memory_short_repository import ShortTermMemoryRepository - +from app.repositories.memory_short_repository import ( + LongTermMemoryRepository, + ShortTermMemoryRepository, +) api_logger = get_api_logger() -db=next(get_db()) + + class ShortService: - def __init__(self, end_user_id): + def __init__(self, end_user_id: str, db: Session) -> None: + """Service for short-term memory queries. + + Args: + end_user_id: The end user identifier to query memories for. + db: SQLAlchemy database session (caller-managed lifecycle). + """ self.short_repo = ShortTermMemoryRepository(db) self.end_user_id = end_user_id - def get_short_databasets(self): + def get_short_databasets(self) -> List[Dict]: + """Retrieve the latest short-term memory entries for the user. + + Returns: + List[Dict]: List of memory dicts with retrieval, message, and answer keys. + """ short_memories = self.short_repo.get_latest_by_user_id(self.end_user_id, 3) short_result = [] for memory in short_memories: - deep_expanded = {} # Create a new dictionary for each memory + deep_expanded = {} messages = memory.messages aimessages = memory.aimessages retrieved_content = memory.retrieved_content or [] @@ -27,23 +42,41 @@ class ShortService: for item in retrieved_content: if isinstance(item, dict): for key, values in item.items(): - retrieval_source.append({"query": key, "retrieval": values,"source":"上下文记忆"}) + retrieval_source.append({"query": key, "retrieval": values, "source": "上下文记忆"}) deep_expanded['retrieval'] = retrieval_source - deep_expanded['message'] = messages # 修正拼写错误 + deep_expanded['message'] = messages deep_expanded['answer'] = aimessages short_result.append(deep_expanded) return short_result - def get_short_count(self): + + def get_short_count(self) -> int: + """Count total short-term memory entries for the user. + + Returns: + int: Number of short-term memory records. + """ short_count = self.short_repo.count_by_user_id(self.end_user_id) return short_count + class LongService: - def __init__(self, end_user_id): + def __init__(self, end_user_id: str, db: Session) -> None: + """Service for long-term memory queries. + + Args: + end_user_id: The end user identifier to query memories for. + db: SQLAlchemy database session (caller-managed lifecycle). + """ self.long_repo = LongTermMemoryRepository(db) self.end_user_id = end_user_id - def get_long_databasets(self): - # 获取长期记忆数据 + + def get_long_databasets(self) -> List[Dict]: + """Retrieve long-term memory retrieval data for the user. + + Returns: + List[Dict]: List of dicts with query and retrieval keys. + """ long_memories = self.long_repo.get_by_user_id(self.end_user_id, 1) long_result = [] diff --git a/api/app/services/memory_storage_service.py b/api/app/services/memory_storage_service.py index 1083f750..6e7c1ad4 100644 --- a/api/app/services/memory_storage_service.py +++ b/api/app/services/memory_storage_service.py @@ -115,6 +115,17 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) # --- Create --- def create(self, params: ConfigParamsCreate) -> Dict[str, Any]: # 创建配置参数(仅名称与描述) + # 业务层检查同一工作空间下是否已存在同名配置 + if params.workspace_id and params.config_name: + from app.models.memory_config_model import MemoryConfig + existing = ( + self.db.query(MemoryConfig) + .filter_by(workspace_id=params.workspace_id, config_name=params.config_name) + .first() + ) + if existing: + raise ValueError(f"DUPLICATE_CONFIG_NAME:{params.config_name}") + # 如果workspace_id存在且模型字段未全部指定,则自动获取 if params.workspace_id and not all([params.llm_id, params.embedding_id, params.rerank_id]): configs = self._get_workspace_configs(params.workspace_id) @@ -135,6 +146,10 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) if not params.emotion_model_id: params.emotion_model_id = params.llm_id + # 根据关联的本体场景推导 pruning_scene(语义剪枝场景与本体工程场景保持一致) + if params.scene_id and not getattr(params, 'pruning_scene', None): + params.pruning_scene = self._resolve_pruning_scene_from_scene_id(params.scene_id) + config = MemoryConfigRepository.create(self.db, params) self.db.commit() return {"affected": 1, "config_id": config.config_id} @@ -150,6 +165,23 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) finally: db_session.close() + def _resolve_pruning_scene_from_scene_id(self, scene_id) -> Optional[str]: + """根据本体场景ID获取对应的 scene_name,作为语义剪枝场景值 + + Args: + scene_id: 本体场景UUID + + Returns: + scene_name 字符串,查询失败时返回 None + """ + try: + from app.models.ontology_scene import OntologyScene + scene = self.db.query(OntologyScene).filter_by(scene_id=scene_id).first() + return scene.scene_name if scene else None + except Exception as e: + logger.warning(f"_resolve_pruning_scene_from_scene_id failed for scene_id={scene_id}: {e}", exc_info=True) + return None + # --- Delete --- def delete(self, key: ConfigParamsDelete) -> Dict[str, Any]: # 删除配置参数(按配置ID) success = MemoryConfigRepository.delete(self.db, key.config_id) @@ -185,6 +217,19 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) def get_all(self, workspace_id = None) -> List[Dict[str, Any]]: # 获取所有配置参数 results = MemoryConfigRepository.get_all(self.db, workspace_id) + # 检查并修正 pruning_scene 与 scene_name 不一致的记录 + needs_commit = False + for config, scene_name in results: + if scene_name and config.pruning_scene != scene_name: + logger.info( + f"修正 pruning_scene: config_id={config.config_id} " + f"'{config.pruning_scene}' -> '{scene_name}'" + ) + config.pruning_scene = scene_name + needs_commit = True + if needs_commit: + self.db.commit() + # 将 ORM 对象转换为字典列表 data_list = [] for config, scene_name in results: @@ -211,6 +256,7 @@ class DataConfigService: # 数据配置服务类(PostgreSQL) "apply_id": config.apply_id, "scene_id": str(config.scene_id) if config.scene_id else None, "scene_name": scene_name, # 新增:场景名称 + "is_system_default": config.is_default, # 是否为系统默认配置 "llm_id": config.llm_id, "embedding_id": config.embedding_id, "rerank_id": config.rerank_id, @@ -737,8 +783,37 @@ async def analytics_hot_memory_tags( await connector.close() -async def analytics_recent_activity_stats() -> Dict[str, Any]: - stats, _msg = get_recent_activity_stats() +async def analytics_recent_activity_stats(workspace_id: Optional[str] = None) -> Dict[str, Any]: + """获取最近记忆提取活动统计。 + + 优先从 Redis 缓存读取(按 workspace_id),缓存不存在时降级到日志文件解析。 + + Args: + workspace_id: 工作空间ID,用于从 Redis 读取对应缓存 + + Returns: + 包含 total、stats、latest_relative、source 的统计字典 + """ + stats = None + source = "log" + + # 优先从 Redis 读取 + if workspace_id: + try: + from app.cache.memory.activity_stats_cache import ActivityStatsCache + cached = await ActivityStatsCache.get_activity_stats(workspace_id) + if cached: + stats = cached.get("stats", {}) + source = "redis" + logger.info(f"[ANALYTICS] 从 Redis 读取活动统计: workspace_id={workspace_id}") + except Exception as e: + logger.warning(f"[ANALYTICS] 读取 Redis 活动统计失败,降级到日志: {e}") + + # 降级:从日志文件解析 + if stats is None: + stats, _msg = get_recent_activity_stats() + source = "log" + total = ( stats.get("chunk_count", 0) + stats.get("statements_count", 0) @@ -746,26 +821,29 @@ async def analytics_recent_activity_stats() -> Dict[str, Any]: + stats.get("triplet_relations_count", 0) + stats.get("temporal_count", 0) ) - # 精简:仅提供“最新一次活动多久前” - latest_relative = None - try: - info = stats.get("log_path", "") - idx = info.rfind("最新:") - if idx != -1: - latest_path = info[idx + 3 :].strip() - if latest_path and os.path.exists(latest_path): - import time - diff = max(0.0, time.time() - os.path.getmtime(latest_path)) - m = int(diff // 60) - if m < 1: - latest_relative = "刚刚" - elif m < 60: - latest_relative = "一会前" - else: - latest_relative = "较早前" - except Exception: - pass - data = {"total": total, "stats": stats, "latest_relative": latest_relative} + # 计算"最新一次活动多久前"(仅日志来源时有效) + latest_relative = None + if source == "log": + try: + info = stats.get("log_path", "") + idx = info.rfind("最新:") + if idx != -1: + latest_path = info[idx + 3:].strip() + if latest_path and os.path.exists(latest_path): + import time + diff = max(0.0, time.time() - os.path.getmtime(latest_path)) + m = int(diff // 60) + if m < 1: + latest_relative = "刚刚" + elif m < 60: + latest_relative = "一会前" + else: + latest_relative = "较早前" + except Exception: + pass + + data = {"total": total, "stats": stats, "latest_relative": latest_relative, "source": source} return data + diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index aa8cfbac..a7398504 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -90,7 +90,8 @@ class ModelConfigService: api_key: str, api_base: Optional[str] = None, model_type: str = "llm", - test_message: str = "Hello" + test_message: str = "Hello", + is_omni: bool = False ) -> Dict[str, Any]: """验证模型配置是否有效 @@ -102,6 +103,7 @@ class ModelConfigService: api_base: API基础URL model_type: 模型类型 (llm/chat/embedding/rerank) test_message: 测试消息 + is_omni: 是否为Omni模型 Returns: Dict: 验证结果 @@ -119,6 +121,7 @@ class ModelConfigService: provider=provider, api_key=api_key, base_url=api_base, + is_omni=is_omni, temperature=0.7, max_tokens=100 ) @@ -257,8 +260,9 @@ class ModelConfigService: provider=model_data.provider, api_key=api_key_data.api_key, api_base=api_key_data.api_base, - model_type=model_data.type, # 传递模型类型 - test_message="Hello" + model_type=model_data.type, + test_message="Hello", + is_omni=model_data.is_omni ) if not validation_result["valid"]: raise BusinessException( @@ -279,6 +283,9 @@ class ModelConfigService: for api_key_data in api_key_datas: api_key_data.model_name = model_data.name api_key_data.provider = model_data.provider + # 同步capability和is_omni + api_key_data.capability = model_data.capability + api_key_data.is_omni = model_data.is_omni api_key_create_schema = ModelApiKeyCreate( model_config_ids=[model.id], **api_key_data.model_dump() @@ -473,6 +480,9 @@ class ModelApiKeyService: model_config = ModelConfigRepository.get_by_id(db, model_config_id) if not model_config: continue + + data.is_omni = model_config.is_omni + data.capability = model_config.capability # 从ModelBase获取model_name model_name = model_config.model_base.name if model_config.model_base else model_config.name @@ -497,6 +507,8 @@ class ModelApiKeyService: existing_key.config = data.config existing_key.priority = data.priority existing_key.model_name = model_name + existing_key.capability = data.capability + existing_key.is_omni = data.is_omni # 检查是否已关联该模型配置 if model_config not in existing_key.model_configs: @@ -513,7 +525,8 @@ class ModelApiKeyService: api_key=data.api_key, api_base=data.api_base, model_type=model_config.type, - test_message="Hello" + test_message="Hello", + is_omni=data.is_omni ) if not validation_result["valid"]: # 记录验证失败的模型,但不抛出异常 @@ -528,6 +541,8 @@ class ModelApiKeyService: provider=data.provider, api_key=data.api_key, api_base=data.api_base, + capability=data.capability, + is_omni=data.is_omni, config=data.config, is_active=data.is_active, priority=data.priority @@ -550,6 +565,10 @@ class ModelApiKeyService: model_config = ModelConfigRepository.get_by_id(db, model_config_id) if not model_config: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) + if api_key_data.is_omni is None: + api_key_data.is_omni = model_config.is_omni + if api_key_data.capability is None: + api_key_data.capability = model_config.capability # 检查API Key是否已存在(包括软删除),需要考虑tenant_id existing_key = db.query(ModelApiKey).join( @@ -572,6 +591,8 @@ class ModelApiKeyService: existing_key.config = api_key_data.config existing_key.priority = api_key_data.priority existing_key.model_name = api_key_data.model_name + existing_key.capability = api_key_data.capability + existing_key.is_omni = api_key_data.is_omni # 检查是否已关联该模型配置 if model_config not in existing_key.model_configs: @@ -589,7 +610,8 @@ class ModelApiKeyService: api_key=api_key_data.api_key, api_base=api_key_data.api_base, model_type=model_config.type, - test_message="Hello" + test_message="Hello", + is_omni=api_key_data.is_omni ) if not validation_result["valid"]: raise BusinessException( @@ -620,7 +642,8 @@ class ModelApiKeyService: api_key=api_key_data.api_key or existing_api_key.api_key, api_base=api_key_data.api_base or existing_api_key.api_base, model_type=model_config.type, - test_message="Hello" + test_message="Hello", + is_omni=model_config.is_omni ) if not validation_result["valid"]: raise BusinessException( @@ -755,6 +778,9 @@ class ModelBaseService: "type": model_base.type, "logo": model_base.logo, "description": model_base.description, + "capability": model_base.capability, + "is_omni": model_base.is_omni, + "is_active": False, "is_composite": False } model_config = ModelConfigRepository.create(db, model_config_data) diff --git a/api/app/services/multi_agent_orchestrator.py b/api/app/services/multi_agent_orchestrator.py index d1aa46d1..f42ee95a 100644 --- a/api/app/services/multi_agent_orchestrator.py +++ b/api/app/services/multi_agent_orchestrator.py @@ -123,11 +123,14 @@ class MultiAgentOrchestrator: user_id: 用户 ID variables: 变量参数 use_llm_routing: 是否使用 LLM 路由 + web_search: 是否启用网络搜索 + memory: 是否启用记忆功能 + storage_type: 存储类型 + user_rag_memory_id: 用户 RAG 记忆 ID Yields: SSE 格式的事件流 """ - import json start_time = time.time() @@ -200,7 +203,8 @@ class MultiAgentOrchestrator: except Exception as e: logger.error( "多 Agent 任务执行失败(流式)", - extra={"error": str(e), "mode": self._normalized_mode} + extra={"error": str(e), "mode": self._normalized_mode}, + exc_info=True ) # 发送错误事件 yield self._format_sse_event("error", { @@ -1267,7 +1271,7 @@ class MultiAgentOrchestrator: Yields: SSE 格式的事件流 """ - from app.services.draft_run_service import DraftRunService + from app.services.draft_run_service import AgentRunService # 获取模型配置 model_config = self.db.get(ModelConfig, agent_config.default_model_config_id) @@ -1278,7 +1282,7 @@ class MultiAgentOrchestrator: ) # 流式执行 Agent - draft_service = DraftRunService(self.db) + draft_service = AgentRunService(self.db) async for event in draft_service.run_stream( agent_config=agent_config, model_config=model_config, @@ -1320,7 +1324,7 @@ class MultiAgentOrchestrator: Returns: 执行结果 """ - from app.services.draft_run_service import DraftRunService + from app.services.draft_run_service import AgentRunService # 获取模型配置 model_config = self.db.get(ModelConfig, agent_config.default_model_config_id) @@ -1331,7 +1335,7 @@ class MultiAgentOrchestrator: ) # 执行 Agent - draft_service = DraftRunService(self.db) + draft_service = AgentRunService(self.db) result = await draft_service.run( agent_config=agent_config, model_config=model_config, @@ -1633,6 +1637,7 @@ class MultiAgentOrchestrator: self.memory = config_data.get("memory") self.variables = config_data.get("variables", []) self.tools = config_data.get("tools", {}) + self.skills = config_data.get("skills", {}) self.default_model_config_id = release.default_model_config_id return AgentConfigProxy(release, app, config_data) @@ -2593,6 +2598,7 @@ class MultiAgentOrchestrator: provider=api_key_config.provider, api_key=api_key_config.api_key, base_url=api_key_config.api_base, + is_omni=api_key_config.is_omni, temperature=0.7, # 整合任务使用中等温度 max_tokens=2000 ) @@ -2758,6 +2764,7 @@ class MultiAgentOrchestrator: provider=api_key_config.provider, api_key=api_key_config.api_key, base_url=api_key_config.api_base, + is_omni=api_key_config.is_omni, temperature=0.7, max_tokens=2000, extra_params={"streaming": True} # 启用流式输出 diff --git a/api/app/services/multi_agent_service.py b/api/app/services/multi_agent_service.py index c52814ed..751099d5 100644 --- a/api/app/services/multi_agent_service.py +++ b/api/app/services/multi_agent_service.py @@ -267,7 +267,7 @@ class MultiAgentService: # 2. 验证模型配置(如果提供了) if data.default_model_config_id: - model_api_key = ModelApiKeyService.get_a_api_key(self.db, data.default_model_config_id) + model_api_key = ModelApiKeyService.get_available_api_key(self.db, data.default_model_config_id) if not model_api_key: raise ResourceNotFoundException("模型配置", str(data.default_model_config_id)) diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index bfb23a56..9b06c287 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -9,47 +9,100 @@ - OpenAI: 支持 URL 和 base64 格式 """ import uuid -from typing import List, Dict, Any, Optional, Protocol +import httpx +import base64 +from typing import List, Dict, Any, Optional +from abc import ABC, abstractmethod from sqlalchemy.orm import Session +from docx import Document +import io +import PyPDF2 from app.core.logging_config import get_business_logger from app.core.exceptions import BusinessException from app.core.error_codes import BizCode from app.schemas.app_schema import FileInput, FileType, TransferMethod -from app.models.generic_file_model import GenericFile +from app.models.file_metadata_model import FileMetadata +from app.core.config import settings +from app.services.audio_transcription_service import AudioTranscriptionService logger = get_business_logger() -class ImageFormatStrategy(Protocol): - """图片格式策略接口""" +class MultimodalFormatStrategy(ABC): + """多模态格式策略基类""" + + @abstractmethod + async def format_image(self, url: str) -> Dict[str, Any]: + """格式化图片""" + pass + + @abstractmethod + async def format_document(self, file_name: str, text: str) -> Dict[str, Any]: + """格式化文档""" + pass + + @abstractmethod + async def format_audio(self, file_type: str, url: str) -> Dict[str, Any]: + """格式化音频""" + pass + + @abstractmethod + async def format_video(self, url: str) -> Dict[str, Any]: + """格式化视频""" + pass + + +class DashScopeFormatStrategy(MultimodalFormatStrategy): + """通义千问策略""" async def format_image(self, url: str) -> Dict[str, Any]: - """将图片 URL 转换为特定 provider 的格式""" - ... - - -class DashScopeImageStrategy: - """通义千问图片格式策略""" - - async def format_image(self, url: str) -> Dict[str, Any]: - """通义千问格式: {"type": "image", "image": "url"}""" + """通义千问图片格式:{"type": "image", "image": "url"}""" return { "type": "image", "image": url } + async def format_document(self, file_name: str, text: str) -> Dict[str, Any]: + """通义千问文档格式""" + return { + "type": "text", + "text": f"\n{text}\n" + } -class BedrockImageStrategy: - """Bedrock/Anthropic 图片格式策略""" + async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]: + """ + 通义千问音频格式 + - 原生支持: qwen-audio 系列 + - 其他模型: 需要转录为文本 + """ + if transcription: + return { + "type": "text", + "text": f"" + } + # 通义千问音频格式:{"type": "audio", "audio": "url"} + return { + "type": "audio", + "audio": url + } + + async def format_video(self, url: str) -> Dict[str, Any]: + """通义千问视频格式(qwen-vl 系列原生支持)""" + return { + "type": "video", + "video": url + } + + +class BedrockFormatStrategy(MultimodalFormatStrategy): + """Bedrock/Anthropic 策略""" async def format_image(self, url: str) -> Dict[str, Any]: """ Bedrock/Anthropic 格式: base64 编码 {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}} """ - import httpx - import base64 from mimetypes import guess_type logger.info(f"下载并编码图片: {url}") @@ -84,9 +137,46 @@ class BedrockImageStrategy: } } + async def format_document(self, file_name: str, text: str) -> Dict[str, Any]: + """Bedrock/Anthropic 文档格式(需要 base64 编码)""" + # Bedrock 文档需要 base64 编码 + text_bytes = text.encode('utf-8') + base64_text = base64.b64encode(text_bytes).decode('utf-8') -class OpenAIImageStrategy: - """OpenAI 图片格式策略""" + return { + "type": "document", + "source": { + "type": "base64", + "media_type": "text/plain", + "data": base64_text + } + } + + async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]: + """ + Bedrock/Anthropic 音频格式 + 不支持原生音频,必须转录为文本 + """ + if transcription: + return { + "type": "text", + "text": f"[音频转录]\n{transcription}" + } + return { + "type": "text", + "text": "[音频文件:Bedrock 不支持原生音频,请启用音频转文本功能]" + } + + async def format_video(self, url: str) -> Dict[str, Any]: + """Bedrock/Anthropic 视频格式""" + return { + "type": "text", + "text": f"" + } + + +class OpenAIFormatStrategy(MultimodalFormatStrategy): + """OpenAI 策略""" async def format_image(self, url: str) -> Dict[str, Any]: """OpenAI 格式: {"type": "image_url", "image_url": {"url": "..."}}""" @@ -97,29 +187,97 @@ class OpenAIImageStrategy: } } + async def format_document(self, file_name: str, text: str) -> Dict[str, Any]: + """OpenAI 文档格式""" + return { + "type": "text", + "text": f"\n{text}\n" + } + + async def format_audio(self, file_type: str, url: str, transcription: Optional[str] = None) -> Dict[str, Any]: + """ + OpenAI 音频格式 + - gpt-4o-audio 系列支持原生音频(需要 base64 编码) + - 其他模型使用转录文本 + """ + if transcription: + return { + "type": "text", + "text": f"" + } + + # OpenAI 音频需要 base64 编码 + try: + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(url) + response.raise_for_status() + audio_data = response.content + base64_audio = base64.b64encode(audio_data).decode('utf-8') + # 1. 优先从 file_type (MIME) 取扩展名 + file_ext = file_type.split('/')[-1] if file_type and '/' in file_type else None + # 2. 从响应头 content-type 取 + if not file_ext: + ct = response.headers.get("content-type", "") + file_ext = ct.split('/')[-1].split(';')[0].strip() if '/' in ct else None + # 3. 从 URL 路径取扩展名 + if not file_ext: + file_ext = url.split('?')[0].rsplit('.', 1)[-1].lower() or None + # 4. 默认 wav + # supported_ext = {"wav", "mp3", "mp4", "ogg", "flac", "webm", "m4a", "wave", "x-m4a"} + file_ext = "wav" if not file_ext else file_ext + + return { + "type": "input_audio", + "input_audio": { + "data": f"data:;base64,{base64_audio}", + "format": file_ext + } + } + except Exception as e: + logger.error(f"下载音频失败: {e}") + return { + "type": "text", + "text": f"[音频处理失败: {str(e)}]" + } + + async def format_video(self, url: str) -> Dict[str, Any]: + """OpenAI 视频格式""" + return { + "type": "video_url", + "video_url": { + "url": url + } + } + # Provider 到策略的映射 PROVIDER_STRATEGIES = { - "dashscope": DashScopeImageStrategy, - "bedrock": BedrockImageStrategy, - "anthropic": BedrockImageStrategy, - "openai": OpenAIImageStrategy, + "dashscope": DashScopeFormatStrategy, + "bedrock": BedrockFormatStrategy, + "anthropic": BedrockFormatStrategy, + "openai": OpenAIFormatStrategy, } class MultimodalService: """多模态文件处理服务""" - def __init__(self, db: Session, provider: str = "dashscope"): + def __init__(self, db: Session, provider: str = "dashscope", api_key: Optional[str] = None, enable_audio_transcription: bool = False, is_omni: bool = False): """ 初始化多模态服务 Args: db: 数据库会话 - provider: 模型提供商(dashscope, bedrock, anthropic 等) + provider: 模型提供商(dashscope, bedrock, anthropic, openai 等) + api_key: API 密钥(用于音频转文本) + enable_audio_transcription: 是否启用音频转文本 + is_omni: 是否为 Omni 模型(dashscope 的 omni 模型需要使用 OpenAI 兼容格式) """ self.db = db self.provider = provider.lower() + self.api_key = api_key + self.enable_audio_transcription = enable_audio_transcription + self.is_omni = is_omni async def process_files( self, @@ -137,20 +295,32 @@ class MultimodalService: if not files: return [] + # 获取对应的策略 + # dashscope 的 omni 模型使用 OpenAI 兼容格式 + if self.provider == "dashscope" and self.is_omni: + strategy_class = OpenAIFormatStrategy + else: + strategy_class = PROVIDER_STRATEGIES.get(self.provider) + if not strategy_class: + logger.warning(f"未找到 provider '{self.provider}' 的策略,使用默认策略") + strategy_class = DashScopeFormatStrategy + + strategy = strategy_class() + result = [] for idx, file in enumerate(files): try: if file.type == FileType.IMAGE: - content = await self._process_image(file) + content = await self._process_image(file, strategy) result.append(content) elif file.type == FileType.DOCUMENT: - content = await self._process_document(file) + content = await self._process_document(file, strategy) result.append(content) elif file.type == FileType.AUDIO: - content = await self._process_audio(file) + content = await self._process_audio(file, strategy) result.append(content) elif file.type == FileType.VIDEO: - content = await self._process_video(file) + content = await self._process_video(file, strategy) result.append(content) else: logger.warning(f"不支持的文件类型: {file.type}") @@ -172,55 +342,29 @@ class MultimodalService: logger.info(f"成功处理 {len(result)}/{len(files)} 个文件,provider={self.provider}") return result - async def _process_image(self, file: FileInput) -> Dict[str, Any]: + async def _process_image(self, file: FileInput, strategy) -> Dict[str, Any]: """ 处理图片文件 Args: file: 图片文件输入 + strategy: 格式化策略 Returns: - Dict: 根据 provider 返回不同格式 - - Anthropic/Bedrock: {"type": "image", "source": {"type": "base64", "media_type": "...", "data": "..."}} - - 通义千问: {"type": "image", "image": "url"} + Dict: 根据 provider 返回不同格式的图片内容 """ - url = await self.get_file_url(file) - - logger.debug(f"处理图片: {url}, provider={self.provider}") - - # 根据 provider 返回不同格式 - if self.provider in ["bedrock", "anthropic"]: - # Anthropic/Bedrock 只支持 base64 格式,需要下载并转换 - try: - logger.info(f"开始下载并编码图片: {url}") - base64_data, media_type = await self._download_and_encode_image(url) - result = { - "type": "image", - "source": { - "type": "base64", - "media_type": media_type, - "data": base64_data[:100] + "..." # 只记录前100个字符 - } - } - logger.info(f"图片编码完成: media_type={media_type}, data_length={len(base64_data)}") - # 返回完整数据 - result["source"]["data"] = base64_data - return result - except Exception as e: - logger.error(f"下载并编码图片失败: {e}", exc_info=True) - # 返回错误提示 - return { - "type": "text", - "text": f"[图片加载失败: {str(e)}]" - } - else: - # 通义千问等其他格式支持 URL + try: + url = await self.get_file_url(file) + return await strategy.format_image(url) + except Exception as e: + logger.error(f"处理图片失败: {e}", exc_info=True) return { - "type": "image", - "image": url + "type": "text", + "text": f"[图片处理失败: {str(e)}]" } - async def _download_and_encode_image(self, url: str) -> tuple[str, str]: + @staticmethod + async def _download_and_encode_image(url: str) -> tuple[str, str]: """ 下载图片并转换为 base64 @@ -230,8 +374,6 @@ class MultimodalService: Returns: tuple: (base64_data, media_type) """ - import httpx - import base64 from mimetypes import guess_type # 下载图片 @@ -258,15 +400,16 @@ class MultimodalService: return base64_data, media_type - async def _process_document(self, file: FileInput) -> Dict[str, Any]: + async def _process_document(self, file: FileInput, strategy) -> Dict[str, Any]: """ 处理文档文件(PDF、Word 等) Args: file: 文档文件输入 + strategy: 格式化策略 Returns: - Dict: text 格式的内容(包含提取的文本) + Dict: 根据 provider 返回不同格式的文档内容 """ if file.transfer_method == TransferMethod.REMOTE_URL: # 远程文档暂不支持提取 @@ -277,48 +420,68 @@ class MultimodalService: else: # 本地文件,提取文本内容 text = await self._extract_document_text(file.upload_file_id) - generic_file = self.db.query(GenericFile).filter( - GenericFile.id == file.upload_file_id + file_metadata = self.db.query(FileMetadata).filter( + FileMetadata.id == file.upload_file_id ).first() - file_name = generic_file.file_name if generic_file else "unknown" + file_name = file_metadata.file_name if file_metadata else "unknown" - return { - "type": "text", - "text": f"\n{text}\n" - } + # 使用策略格式化文档 + return await strategy.format_document(file_name, text) - async def _process_audio(self, file: FileInput) -> Dict[str, Any]: + async def _process_audio(self, file: FileInput, strategy) -> Dict[str, Any]: """ 处理音频文件 Args: file: 音频文件输入 + strategy: 格式化策略 Returns: - Dict: 音频内容(暂时返回占位符) + Dict: 根据 provider 返回不同格式的音频内容 """ - # TODO: 实现音频转文字功能 - return { - "type": "text", - "text": "[音频文件,暂不支持处理]" - } + try: + url = await self.get_file_url(file) - async def _process_video(self, file: FileInput) -> Dict[str, Any]: + # 如果启用音频转文本且有 API Key + transcription = None + if self.enable_audio_transcription and self.api_key: + logger.info(f"开始音频转文本: {url}") + if self.provider == "dashscope": + transcription = await AudioTranscriptionService.transcribe_dashscope(url, self.api_key) + elif self.provider == "openai": + transcription = await AudioTranscriptionService.transcribe_openai(url, self.api_key) + else: + logger.warning(f"Provider {self.provider} 不支持音频转文本") + + return await strategy.format_audio(file.file_type, url, transcription) + except Exception as e: + logger.error(f"处理音频失败: {e}", exc_info=True) + return { + "type": "text", + "text": f"[音频处理失败: {str(e)}]" + } + + async def _process_video(self, file: FileInput, strategy) -> Dict[str, Any]: """ 处理视频文件 Args: file: 视频文件输入 + strategy: 格式化策略 Returns: - Dict: 视频内容(暂时返回占位符) + Dict: 根据 provider 返回不同格式的视频内容 """ - # TODO: 实现视频处理功能 - return { - "type": "text", - "text": "[视频文件,暂不支持处理]" - } + try: + url = await self.get_file_url(file) + return await strategy.format_video(url) + except Exception as e: + logger.error(f"处理视频失败: {e}", exc_info=True) + return { + "type": "text", + "text": f"[视频处理失败: {str(e)}]" + } async def get_file_url(self, file: FileInput) -> str: """ @@ -336,26 +499,22 @@ class MultimodalService: if file.transfer_method == TransferMethod.REMOTE_URL: return file.url else: - # 本地文件,通过 file_storage 系统获取永久访问 URL - from app.models.file_metadata_model import FileMetadata - from app.core.config import settings - file_id = file.upload_file_id print("="*50) print("file_id",file_id) - + # 查询 FileMetadata file_metadata = self.db.query(FileMetadata).filter( FileMetadata.id == file_id, FileMetadata.status == "completed" ).first() - + if not file_metadata: raise BusinessException( f"文件不存在或已删除: {file_id}", BizCode.NOT_FOUND ) - + # 返回永久URL server_url = settings.FILE_LOCAL_SERVER_URL return f"{server_url}/storage/permanent/{file_id}" @@ -370,58 +529,79 @@ class MultimodalService: Returns: str: 提取的文本内容 """ - generic_file = self.db.query(GenericFile).filter( - GenericFile.id == file_id, - GenericFile.status == "active" + file_metadata = self.db.query(FileMetadata).filter( + FileMetadata.id == file_id, + FileMetadata.status == "completed" ).first() - if not generic_file: + if not file_metadata: raise BusinessException( f"文件不存在或已删除: {file_id}", BizCode.NOT_FOUND ) - # TODO: 根据文件类型提取文本 - # - PDF: 使用 PyPDF2 或 pdfplumber - # - Word: 使用 python-docx - # - TXT/MD: 直接读取 - - file_ext = generic_file.file_ext.lower() + file_ext = file_metadata.file_ext.lower() + server_url = settings.FILE_LOCAL_SERVER_URL + file_url = f"{server_url}/storage/permanent/{file_id}" if file_ext in ['.txt', '.md', '.markdown']: - return await self._read_text_file(generic_file.storage_path) + return await self._read_text_file(file_url) elif file_ext == '.pdf': - return await self._extract_pdf_text(generic_file.storage_path) + return await self._extract_pdf_text(file_url) elif file_ext in ['.doc', '.docx']: - return await self._extract_word_text(generic_file.storage_path) + return await self._extract_word_text(file_url) else: return f"[不支持的文档格式: {file_ext}]" - async def _read_text_file(self, storage_path: str) -> str: + @staticmethod + async def _read_text_file(file_url: str) -> str: """读取纯文本文件""" try: - with open(storage_path, 'r', encoding='utf-8') as f: - return f.read() + # 下载文件 + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(file_url) + response.raise_for_status() + return response.text except Exception as e: logger.error(f"读取文本文件失败: {e}") return f"[文件读取失败: {str(e)}]" - async def _extract_pdf_text(self, storage_path: str) -> str: + @staticmethod + async def _extract_pdf_text(file_url: str) -> str: """提取 PDF 文本""" try: - # TODO: 实现 PDF 文本提取 - # import PyPDF2 或 pdfplumber - return "[PDF 文本提取功能待实现]" + # 下载 PDF 文件 + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(file_url) + response.raise_for_status() + pdf_data = response.content + + # 使用 BytesIO 读取 PDF + text_parts = [] + pdf_file = io.BytesIO(pdf_data) + pdf_reader = PyPDF2.PdfReader(pdf_file) + for page in pdf_reader.pages: + text_parts.append(page.extract_text()) + return '\n'.join(text_parts) except Exception as e: logger.error(f"提取 PDF 文本失败: {e}") return f"[PDF 提取失败: {str(e)}]" - async def _extract_word_text(self, storage_path: str) -> str: + @staticmethod + async def _extract_word_text(file_url: str) -> str: """提取 Word 文档文本""" try: - # TODO: 实现 Word 文本提取 - # import docx - return "[Word 文本提取功能待实现]" + # 下载 Word 文件 + async with httpx.AsyncClient(timeout=30.0) as client: + response = await client.get(file_url) + response.raise_for_status() + word_data = response.content + + # 使用 BytesIO 读取 Word 文档 + word_file = io.BytesIO(word_data) + doc = Document(word_file) + text_parts = [paragraph.text for paragraph in doc.paragraphs] + return '\n'.join(text_parts) except Exception as e: logger.error(f"提取 Word 文本失败: {e}") return f"[Word 提取失败: {str(e)}]" diff --git a/api/app/services/pilot_run_service.py b/api/app/services/pilot_run_service.py index 34b8867e..5d00d8a5 100644 --- a/api/app/services/pilot_run_service.py +++ b/api/app/services/pilot_run_service.py @@ -101,34 +101,141 @@ async def run_pilot_extraction( ) if progress_callback: - await progress_callback("text_preprocessing", "开始预处理文本...") + await progress_callback("text_preprocessing", "开始预处理文本(语义剪枝 + 语义分块)...") + # ========== 步骤 2.1: 语义剪枝 ========== + pruned_dialogs = [dialog] + deleted_messages = [] # 记录被删除的消息 + pruning_stats = None # 保存剪枝统计信息,用于最终汇总 + + if memory_config.pruning_enabled: + try: + from app.core.memory.storage_services.extraction_engine.data_preprocessing.data_pruning import ( + SemanticPruner, + ) + from app.core.memory.models.config_models import PruningConfig + + # 构建剪枝配置 + pruning_config_dict = { + "pruning_switch": memory_config.pruning_enabled, + "pruning_scene": memory_config.pruning_scene, + "pruning_threshold": memory_config.pruning_threshold, + "llm_model_id": str(memory_config.llm_model_id), + } + config = PruningConfig(**pruning_config_dict) + + logger.info(f"[PILOT_RUN] 开始语义剪枝: scene={config.pruning_scene}, threshold={config.pruning_threshold}") + + # 记录剪枝前的消息(用于对比) + original_messages = [{"role": msg.role, "content": msg.msg} for msg in dialog.context.msgs] + original_msg_count = len(original_messages) + + # 执行剪枝 + pruner = SemanticPruner(config=config, llm_client=llm_client) + pruned_dialogs = await pruner.prune_dataset([dialog]) + + # 计算剪枝结果并找出被删除的消息 + if pruned_dialogs and pruned_dialogs[0].context: + remaining_messages = [{"role": msg.role, "content": msg.msg} for msg in pruned_dialogs[0].context.msgs] + remaining_msg_count = len(remaining_messages) + deleted_msg_count = original_msg_count - remaining_msg_count + + # 找出被删除的消息(基于索引精确匹配) + # 为剩余消息创建带索引的列表,用于精确追踪 + remaining_with_index = [] + remaining_idx = 0 + for orig_idx, orig_msg in enumerate(original_messages): + if remaining_idx < len(remaining_messages) and \ + orig_msg["role"] == remaining_messages[remaining_idx]["role"] and \ + orig_msg["content"] == remaining_messages[remaining_idx]["content"]: + remaining_with_index.append(orig_idx) + remaining_idx += 1 + + # 找出未在保留列表中的消息索引 + deleted_messages = [ + {"index": idx, "role": msg["role"], "content": msg["content"]} + for idx, msg in enumerate(original_messages) + if idx not in remaining_with_index + ] + + # 保存剪枝统计信息(用于最终汇总,只保留deleted_count) + pruning_stats = { + "enabled": True, + "scene": config.pruning_scene, + "threshold": config.pruning_threshold, + "deleted_count": deleted_msg_count, + } + + # 输出剪枝结果(显示删除的消息详情) + pruning_result = { + "type": "pruning", + "deleted_messages": deleted_messages, + } + + logger.info( + f"[PILOT_RUN] 语义剪枝完成: 原始{original_msg_count}条 -> " + f"保留{remaining_msg_count}条 (删除{deleted_msg_count}条)" + ) + + if progress_callback: + await progress_callback("text_preprocessing_result", "语义剪枝完成", pruning_result) + else: + logger.warning("[PILOT_RUN] 剪枝后对话为空,使用原始对话") + pruned_dialogs = [dialog] + + except Exception as e: + logger.error(f"[PILOT_RUN] 语义剪枝失败,使用原始对话: {e}", exc_info=True) + pruned_dialogs = [dialog] + if progress_callback: + error_result = { + "type": "pruning", + "error": str(e), + "fallback": "使用原始对话" + } + await progress_callback("text_preprocessing_result", "语义剪枝失败", error_result) + else: + logger.info("[PILOT_RUN] 语义剪枝已关闭,跳过") + pruning_stats = { + "enabled": False, + } + + # ========== 步骤 2.2: 语义分块 ========== chunked_dialogs = await get_chunked_dialogs_from_preprocessed( - data=[dialog], + data=pruned_dialogs, chunker_strategy=memory_config.chunker_strategy, llm_client=llm_client, ) - logger.info(f"Processed dialogue text: {len(messages)} messages") + + remaining_msg_count = len(pruned_dialogs[0].context.msgs) if pruned_dialogs and pruned_dialogs[0].context else 0 + logger.info(f"Processed dialogue text: {remaining_msg_count} messages after pruning") # 进度回调:输出每个分块的结果 if progress_callback: for dlg in chunked_dialogs: - for i, chunk in enumerate(dlg.chunks): - chunk_result = { - "chunk_index": i + 1, - "content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content, - "full_length": len(chunk.content), - "dialog_id": dlg.id, - "chunker_strategy": memory_config.chunker_strategy, - } - await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result) + if hasattr(dlg, 'chunks') and dlg.chunks: + for i, chunk in enumerate(dlg.chunks): + chunk_result = { + "type": "chunking", + "chunk_index": i + 1, + "content": chunk.content[:200] + "..." if len(chunk.content) > 200 else chunk.content, + "full_length": len(chunk.content), + "dialog_id": dlg.id, + "chunker_strategy": memory_config.chunker_strategy, + } + await progress_callback("text_preprocessing_result", f"分块 {i + 1} 处理完成", chunk_result) + # 构建预处理完成总结(包含剪枝统计) preprocessing_summary = { - "total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs), + "total_chunks": sum(len(dlg.chunks) for dlg in chunked_dialogs if hasattr(dlg, 'chunks') and dlg.chunks), "total_dialogs": len(chunked_dialogs), "chunker_strategy": memory_config.chunker_strategy, } - await progress_callback("text_preprocessing_complete", "预处理文本完成", preprocessing_summary) + + # 添加剪枝统计信息 + if pruning_stats: + preprocessing_summary["pruning"] = pruning_stats + + await progress_callback("text_preprocessing_complete", "预处理文本完成(剪枝 + 分块)", preprocessing_summary) log_time("Data Loading & Chunking", time.time() - step_start, log_file) @@ -219,6 +326,25 @@ async def run_pilot_extraction( logger.info("Pilot run completed: Skipping Neo4j save") + # 将提取统计写入 Redis,按 workspace_id 存储 + try: + from app.cache.memory.activity_stats_cache import ActivityStatsCache + + stats_to_cache = { + "chunk_count": len(chunk_nodes) if chunk_nodes else 0, + "statements_count": len(statement_nodes) if statement_nodes else 0, + "triplet_entities_count": len(entity_nodes) if entity_nodes else 0, + "triplet_relations_count": len(entity_edges) if entity_edges else 0, + "temporal_count": 0, # temporal 数据在日志中,此处暂置0 + } + await ActivityStatsCache.set_activity_stats( + workspace_id=str(memory_config.workspace_id), + stats=stats_to_cache, + ) + logger.info(f"[PILOT_RUN] 活动统计已写入 Redis: workspace_id={memory_config.workspace_id}") + except Exception as cache_err: + logger.warning(f"[PILOT_RUN] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True) + except Exception as e: logger.error(f"Pilot run failed: {e}", exc_info=True) raise diff --git a/api/app/services/prompt_optimizer_service.py b/api/app/services/prompt_optimizer_service.py index 99edcc0e..184220a8 100644 --- a/api/app/services/prompt_optimizer_service.py +++ b/api/app/services/prompt_optimizer_service.py @@ -184,7 +184,8 @@ class PromptOptimizerService: model_name=api_config.model_name, provider=api_config.provider, api_key=api_config.api_key, - base_url=api_config.api_base + base_url=api_config.api_base, + is_omni=api_config.is_omni ), type=ModelType(model_config.type)) try: prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt') diff --git a/api/app/services/shared_chat_service.py b/api/app/services/shared_chat_service.py index c7b81999..0d659832 100644 --- a/api/app/services/shared_chat_service.py +++ b/api/app/services/shared_chat_service.py @@ -21,63 +21,64 @@ from app.repositories import knowledge_repository import json from app.services.task_service import get_task_memory_write_result from app.tasks import write_message_task + logger = get_business_logger() class SharedChatService: """基于分享链接的聊天服务""" - + def __init__(self, db: Session): self.db = db self.conversation_service = ConversationService(db) self.share_service = ReleaseShareService(db) - - def _get_release_by_share_token( - self, - share_token: str, - password: Optional[str] = None + + def get_release_by_share_token( + self, + share_token: str, + password: Optional[str] = None ) -> tuple[ReleaseShare, AppRelease]: """通过 share_token 获取发布版本""" # 获取分享配置 share = self.share_service.repo.get_by_share_token(share_token) if not share: raise ResourceNotFoundException("分享链接", share_token) - + # 验证分享是否启用 if not share.is_enabled: raise BusinessException("该分享链接已被禁用", BizCode.SHARE_DISABLED) - + # 验证密码 if share.require_password: if not password: raise BusinessException("需要提供访问密码", BizCode.PASSWORD_REQUIRED) - + if not self.share_service.verify_password(share_token, password): raise BusinessException("访问密码错误", BizCode.INVALID_PASSWORD) - + # 获取发布版本 release = self.db.get(AppRelease, share.release_id) if not release: raise ResourceNotFoundException("发布版本", str(share.release_id)) - + # 更新访问统计 try: self.share_service.repo.increment_view_count(share.id) except Exception as e: logger.warning(f"更新访问统计失败: {str(e)}") - + return share, release - + def create_or_get_conversation( - self, - share_token: str, - conversation_id: Optional[uuid.UUID] = None, - user_id: Optional[str] = None, - password: Optional[str] = None + self, + share_token: str, + conversation_id: Optional[uuid.UUID] = None, + user_id: Optional[str] = None, + password: Optional[str] = None ) -> Conversation: """创建或获取会话""" - share, release = self._get_release_by_share_token(share_token, password) - + share, release = self.get_release_by_share_token(share_token, password) + # 如果提供了 conversation_id,尝试获取现有会话 if conversation_id: try: @@ -85,18 +86,18 @@ class SharedChatService: conversation_id=conversation_id, workspace_id=release.app.workspace_id ) - + # 验证会话是否属于该应用 if conversation.app_id != release.app_id: raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION) - + return conversation except ResourceNotFoundException: logger.warning( "会话不存在,将创建新会话", extra={"conversation_id": str(conversation_id)} ) - + # 创建新会话(使用发布版本的配置) conversation = self.conversation_service.create_conversation( app_id=release.app_id, @@ -105,7 +106,7 @@ class SharedChatService: is_draft=False, # 分享链接使用发布版本 config_snapshot=release.config ) - + logger.info( "为分享链接创建新会话", extra={ @@ -114,25 +115,25 @@ class SharedChatService: "release_id": str(release.id) } ) - + return conversation - + async def chat( - self, - share_token: str, - message: str, - conversation_id: Optional[uuid.UUID] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - password: Optional[str] = None, - web_search: bool = False, - memory: bool = True, + self, + share_token: str, + message: str, + conversation_id: Optional[uuid.UUID] = None, + user_id: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + password: Optional[str] = None, + web_search: bool = False, + memory: bool = True, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, ) -> Dict[str, Any]: """聊天(非流式)""" actual_config_id = None - config_id=actual_config_id + config_id = actual_config_id from app.core.agent.langchain_agent import LangChainAgent from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool from app.services.model_parameter_merger import ModelParameterMerger @@ -140,32 +141,30 @@ class SharedChatService: from sqlalchemy import select from app.models import ModelApiKey - start_time = time.time() - actual_config_id=None - config_id=actual_config_id - + actual_config_id = None + config_id = actual_config_id + if variables is None: variables = {} - + # 获取发布版本和配置 - share, release = self._get_release_by_share_token(share_token, password) - + share, release = self.get_release_by_share_token(share_token, password) + # 获取 Agent 配置 config = release.config or {} - # 获取模型配置ID model_config_id = release.default_model_config_id if not model_config_id: raise BusinessException("发布版本未配置模型", BizCode.AGENT_CONFIG_MISSING) - + # 获取模型配置 from app.models import ModelConfig model_config = self.db.get(ModelConfig, model_config_id) if not model_config: raise ResourceNotFoundException("模型配置", str(model_config_id)) - + # 获取 API Key # stmt = ( # select(ModelApiKey).join( @@ -184,7 +183,7 @@ class SharedChatService: api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id) if not api_key_obj: raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) - + # 获取或创建会话 conversation = self.create_or_get_conversation( share_token=share_token, @@ -192,7 +191,7 @@ class SharedChatService: user_id=user_id, password=password ) - + # 处理系统提示词(支持变量替换) system_prompt = config.get("system_prompt", "你是一个专业的AI助手") if variables: @@ -202,31 +201,31 @@ class SharedChatService: variables ) system_prompt = system_prompt_rendered.get_text_content() or system_prompt - + # 准备工具列表 tools = [] - + # 添加知识库检索工具 knowledge_retrieval = config.get("knowledge_retrieval") if knowledge_retrieval: knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] if kb_ids: - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids,user_id) + kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id) tools.append(kb_tool) # 添加长期记忆工具 - memory_flag=False + memory_flag = False if memory: memory_config = config.get("memory", {}) if memory_config.get("enabled") and user_id: - memory_flag=True + memory_flag = True memory_tool = create_long_term_memory_tool(memory_config, user_id) tools.append(memory_tool) - web_tools=config.get("tools") + web_tools = config.get("tools") web_search_choice = web_tools.get("web_search", {}) - web_search_enable = web_search_choice.get("enabled",False) + web_search_enable = web_search_choice.get("enabled", False) if web_search: if web_search_enable: search_tool = create_web_search_tool({}) @@ -238,26 +237,27 @@ class SharedChatService: "tool_count": len(tools) } ) - + # 获取模型参数 model_parameters = config.get("model_parameters", {}) - + # 创建 LangChain Agent agent = LangChainAgent( model_name=api_key_obj.model_name, api_key=api_key_obj.api_key, provider=api_key_obj.provider, api_base=api_key_obj.api_base, + is_omni=api_key_obj.is_omni, temperature=model_parameters.get("temperature", 0.7), max_tokens=model_parameters.get("max_tokens", 2000), system_prompt=system_prompt, tools=tools, ) - + # 加载历史消息 history = [] - memory_config={"enabled":True,'max_history':10} + memory_config = {"enabled": True, 'max_history': 10} if memory_config.get("enabled"): messages = self.conversation_service.get_messages( conversation_id=conversation.id, @@ -267,7 +267,7 @@ class SharedChatService: {"role": msg.role, "content": msg.content} for msg in messages ] - + # 调用 Agent result = await agent.chat( message=message, @@ -279,7 +279,7 @@ class SharedChatService: config_id=config_id, memory_flag=memory_flag ) - + # 保存消息 self.conversation_service.save_conversation_messages( conversation_id=conversation.id, @@ -298,7 +298,7 @@ class SharedChatService: # role="user", # content=message # ) - + # self.conversation_service.add_message( # conversation_id=conversation.id, # role="assistant", @@ -308,12 +308,11 @@ class SharedChatService: # "usage": result.get("usage", {}) # } # ) - + elapsed_time = time.time() - start_time ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id) - return { "conversation_id": conversation.id, "message": result["content"], @@ -324,19 +323,19 @@ class SharedChatService: }), "elapsed_time": elapsed_time } - + async def chat_stream( - self, - share_token: str, - message: str, - conversation_id: Optional[uuid.UUID] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - password: Optional[str] = None, - web_search: bool = False, - memory: bool = True, - storage_type:Optional[str] = None, - user_rag_memory_id: Optional[str] = None, + self, + share_token: str, + message: str, + conversation_id: Optional[uuid.UUID] = None, + user_id: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + password: Optional[str] = None, + web_search: bool = False, + memory: bool = True, + storage_type: Optional[str] = None, + user_rag_memory_id: Optional[str] = None, ) -> AsyncGenerator[str, None]: """聊天(流式)""" from app.core.agent.langchain_agent import LangChainAgent @@ -345,36 +344,35 @@ class SharedChatService: from sqlalchemy import select from app.models import ModelApiKey import json - - start_time = time.time() - actual_config_id=None - config_id=actual_config_id - + start_time = time.time() + actual_config_id = None + config_id = actual_config_id + if variables is None: variables = {} # 兼容新旧字段名:使用 memory_config_id memory_config = {"enabled": memory, "memory_config_id": "17", "max_history": 10} - + try: # 获取发布版本和配置 - share, release = self._get_release_by_share_token(share_token, password) - + share, release = self.get_release_by_share_token(share_token, password) + # 获取 Agent 配置 config = release.config or {} agent_config_data = config.get("agent_config", {}) - + # 获取模型配置ID model_config_id = release.default_model_config_id if not model_config_id: raise BusinessException("发布版本未配置模型", BizCode.AGENT_CONFIG_MISSING) - + # 获取模型配置 from app.models import ModelConfig model_config = self.db.get(ModelConfig, model_config_id) if not model_config: raise ResourceNotFoundException("模型配置", str(model_config_id)) - + # 获取 API Key # stmt = ( # select(ModelApiKey).join( @@ -393,7 +391,7 @@ class SharedChatService: api_key_obj = ModelApiKeyService.get_available_api_key(self.db, model_config_id) if not api_key_obj: raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) - + # 获取或创建会话 conversation = self.create_or_get_conversation( share_token=share_token, @@ -401,7 +399,7 @@ class SharedChatService: user_id=user_id, password=password ) - + # 处理系统提示词(支持变量替换) system_prompt = config.get("system_prompt", "你是一个专业的AI助手") if variables: @@ -411,21 +409,21 @@ class SharedChatService: variables ) system_prompt = system_prompt_rendered.get_text_content() or system_prompt - + # 准备工具列表 tools = [] - + # 添加知识库检索工具 knowledge_retrieval = config.get("knowledge_retrieval") if knowledge_retrieval: knowledge_bases = knowledge_retrieval.get("knowledge_bases", []) kb_ids = [kb.get("kb_id") for kb in knowledge_bases if kb.get("kb_id")] if kb_ids: - kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids,user_id) + kb_tool = create_knowledge_retrieval_tool(knowledge_retrieval, kb_ids, user_id) tools.append(kb_tool) - + # 添加长期记忆工具 - memory_flag=False + memory_flag = False if memory: memory_config = config.get("memory", {}) if memory_config.get("enabled") and user_id: @@ -450,20 +448,21 @@ class SharedChatService: # 获取模型参数 model_parameters = config.get("model_parameters", {}) - + # 创建 LangChain Agent agent = LangChainAgent( model_name=api_key_obj.model_name, api_key=api_key_obj.api_key, provider=api_key_obj.provider, api_base=api_key_obj.api_base, + is_omni=api_key_obj.is_omni, temperature=model_parameters.get("temperature", 0.7), max_tokens=model_parameters.get("max_tokens", 2000), system_prompt=system_prompt, tools=tools, streaming=True ) - + # 加载历史消息 history = [] memory_config = {"enabled": True, 'max_history': 10} @@ -476,22 +475,22 @@ class SharedChatService: {"role": msg.role, "content": msg.content} for msg in messages ] - + # 发送开始事件 yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation.id)}, ensure_ascii=False)}\n\n" - + # 流式调用 Agent full_content = "" total_tokens = 0 async for chunk in agent.chat_stream( - message=message, - history=history, - context=None, - end_user_id=user_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - config_id=config_id, - memory_flag=memory_flag + message=message, + history=history, + context=None, + end_user_id=user_id, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id, + config_id=config_id, + memory_flag=memory_flag ): if isinstance(chunk, int): total_tokens = chunk @@ -499,16 +498,16 @@ class SharedChatService: full_content += chunk # 发送消息块事件 yield f"event: message\ndata: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n" - + elapsed_time = time.time() - start_time - + # 保存消息 self.conversation_service.add_message( conversation_id=conversation.id, role="user", content=message ) - + self.conversation_service.add_message( conversation_id=conversation.id, role="assistant", @@ -524,7 +523,7 @@ class SharedChatService: # 发送结束事件 end_data = {"elapsed_time": elapsed_time, "message_length": len(full_content)} yield f"event: end\ndata: {json.dumps(end_data, ensure_ascii=False)}\n\n" - + logger.info( "流式聊天完成", extra={ @@ -533,7 +532,7 @@ class SharedChatService: "message_length": len(full_content) } ) - + except (GeneratorExit, asyncio.CancelledError): # 生成器被关闭或任务被取消,正常退出 logger.debug("流式聊天被中断") @@ -542,39 +541,39 @@ class SharedChatService: logger.error(f"流式聊天失败: {str(e)}", exc_info=True) # 发送错误事件 yield f"event: error\ndata: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n" - + def get_conversation_messages( - self, - share_token: str, - conversation_id: uuid.UUID, - password: Optional[str] = None + self, + share_token: str, + conversation_id: uuid.UUID, + password: Optional[str] = None ) -> Conversation: """获取会话消息""" - share, release = self._get_release_by_share_token(share_token, password) - + share, release = self.get_release_by_share_token(share_token, password) + # 获取会话 conversation = self.conversation_service.get_conversation( conversation_id=conversation_id, workspace_id=release.app.workspace_id ) - + # 验证会话是否属于该应用 if conversation.app_id != release.app_id: raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION) - + return conversation - + def list_conversations( - self, - share_token: str, - user_id: Optional[str] = None, - password: Optional[str] = None, - page: int = 1, - pagesize: int = 20 + self, + share_token: str, + user_id: Optional[str] = None, + password: Optional[str] = None, + page: int = 1, + pagesize: int = 20 ) -> tuple[list[Conversation], int]: """列出会话""" - share, release = self._get_release_by_share_token(share_token, password) - + share, release = self.get_release_by_share_token(share_token, password) + conversations, total = self.conversation_service.list_conversations( app_id=release.app_id, workspace_id=release.app.workspace_id, @@ -583,19 +582,19 @@ class SharedChatService: page=page, pagesize=pagesize ) - + return conversations, total - + async def multi_agent_chat( - self, - share_token: str, - message: str, - conversation_id: Optional[uuid.UUID] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - password: Optional[str] = None, - web_search: bool = False, - memory: bool = True, + self, + share_token: str, + message: str, + conversation_id: Optional[uuid.UUID] = None, + user_id: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + password: Optional[str] = None, + web_search: bool = False, + memory: bool = True, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None ) -> Dict[str, Any]: @@ -603,18 +602,16 @@ class SharedChatService: from app.services.multi_agent_service import MultiAgentService from app.models import MultiAgentConfig - - start_time = time.time() - actual_config_id=None - config_id=actual_config_id - + actual_config_id = None + config_id = actual_config_id + if variables is None: variables = {} - + # 获取发布版本和配置 - share, release = self._get_release_by_share_token(share_token, password) - + share, release = self.get_release_by_share_token(share_token, password) + # 获取或创建会话 conversation = self.create_or_get_conversation( share_token=share_token, @@ -622,19 +619,19 @@ class SharedChatService: user_id=user_id, password=password ) - + # 获取多 Agent 配置 multi_agent_config = self.db.query(MultiAgentConfig).filter( MultiAgentConfig.app_id == release.app_id, MultiAgentConfig.is_active.is_(True) ).first() - + if not multi_agent_config: raise BusinessException("多 Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - + # 构建多 Agent 运行请求 from app.schemas.multi_agent_schema import MultiAgentRunRequest - + multi_agent_request = MultiAgentRunRequest( message=message, conversation_id=conversation.id, @@ -644,23 +641,23 @@ class SharedChatService: web_search=web_search, memory=memory ) - + # 使用多 Agent 服务执行 multi_agent_service = MultiAgentService(self.db) result = await multi_agent_service.run( app_id=release.app_id, request=multi_agent_request ) - + elapsed_time = time.time() - start_time - + # 保存消息 self.conversation_service.add_message( conversation_id=conversation.id, role="user", content=message ) - + self.conversation_service.add_message( conversation_id=conversation.id, role="assistant", @@ -672,8 +669,6 @@ class SharedChatService: } ) - - return { "conversation_id": conversation.id, "message": result.get("message", ""), @@ -684,34 +679,33 @@ class SharedChatService: }, "elapsed_time": elapsed_time } - + async def multi_agent_chat_stream( - self, - share_token: str, - message: str, - conversation_id: Optional[uuid.UUID] = None, - user_id: Optional[str] = None, - variables: Optional[Dict[str, Any]] = None, - password: Optional[str] = None, - web_search: bool = False, - memory: bool = True, + self, + share_token: str, + message: str, + conversation_id: Optional[uuid.UUID] = None, + user_id: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + password: Optional[str] = None, + web_search: bool = False, + memory: bool = True, storage_type: Optional[str] = None, - user_rag_memory_id:Optional[str] = None + user_rag_memory_id: Optional[str] = None ) -> AsyncGenerator[str, None]: """多 Agent 聊天(流式)""" - start_time = time.time() - actual_config_id=None - config_id=actual_config_id - + actual_config_id = None + config_id = actual_config_id + if variables is None: variables = {} - + try: # 获取发布版本和配置 - share, release = self._get_release_by_share_token(share_token, password) - + share, release = self.get_release_by_share_token(share_token, password) + # 获取或创建会话 conversation = self.create_or_get_conversation( share_token=share_token, @@ -719,28 +713,28 @@ class SharedChatService: user_id=user_id, password=password ) - + # 获取多 Agent 配置 multi_agent_config = self.db.query(MultiAgentConfig).filter( MultiAgentConfig.app_id == release.app_id, MultiAgentConfig.is_active.is_(True) ).first() - + if not multi_agent_config: raise BusinessException("多 Agent 配置不存在", BizCode.AGENT_CONFIG_MISSING) - + # 获取 storage_type 和 user_rag_memory_id workspace_id = release.app.workspace_id storage_type = 'neo4j' # 默认值 user_rag_memory_id = '' - + try: # 获取工作空间的存储类型(不需要用户权限检查,因为是公开分享) from app.models import Workspace workspace = self.db.get(Workspace, workspace_id) if workspace and workspace.storage_type: storage_type = workspace.storage_type - + # 获取 USER_RAG_MERORY 知识库 ID knowledge = knowledge_repository.get_knowledge_by_name( db=self.db, @@ -751,13 +745,13 @@ class SharedChatService: user_rag_memory_id = str(knowledge.id) except Exception as e: logger.warning(f"获取 storage_type 或 user_rag_memory_id 失败,使用默认值: {str(e)}") - + # 发送开始事件 yield f"event: start\ndata: {json.dumps({'conversation_id': str(conversation.id)}, ensure_ascii=False)}\n\n" - + # 构建多 Agent 运行请求 from app.schemas.multi_agent_schema import MultiAgentRunRequest - + multi_agent_request = MultiAgentRunRequest( message=message, conversation_id=conversation.id, @@ -767,20 +761,20 @@ class SharedChatService: web_search=web_search, memory=memory ) - + # 使用多 Agent 服务流式执行 multi_agent_service = MultiAgentService(self.db) full_content = "" - + async for event in multi_agent_service.run_stream( - app_id=release.app_id, - request=multi_agent_request, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id + app_id=release.app_id, + request=multi_agent_request, + storage_type=storage_type, + user_rag_memory_id=user_rag_memory_id ): # 直接转发事件 yield event - + # 尝试提取内容(用于保存) if "data:" in event: try: @@ -790,16 +784,16 @@ class SharedChatService: full_content += data["content"] except: pass - + elapsed_time = time.time() - start_time - + # 保存消息 self.conversation_service.add_message( conversation_id=conversation.id, role="user", content=message ) - + self.conversation_service.add_message( conversation_id=conversation.id, role="assistant", @@ -808,7 +802,7 @@ class SharedChatService: "elapsed_time": elapsed_time } ) - + logger.info( "多 Agent 流式聊天完成", extra={ @@ -818,7 +812,6 @@ class SharedChatService: } ) - except (GeneratorExit, asyncio.CancelledError): # 生成器被关闭或任务被取消,正常退出 logger.debug("多 Agent 流式聊天被中断") diff --git a/api/app/services/skill_service.py b/api/app/services/skill_service.py index 5eb80795..0b7de6cf 100644 --- a/api/app/services/skill_service.py +++ b/api/app/services/skill_service.py @@ -121,7 +121,7 @@ class SkillService: if skill and skill.is_active: # 加载技能关联的工具 for tool_config in skill.tools: - tool = tool_service._get_tool_instance(tool_config.get("tool_id", ""), tenant_id) + tool = tool_service.get_tool_instance(tool_config.get("tool_id", ""), tenant_id) if tool: langchain_tool = tool.to_langchain_tool(tool_config.get("operation", None)) tools.append(langchain_tool) diff --git a/api/app/services/tool_service.py b/api/app/services/tool_service.py index 2bb96e53..f6e2ccce 100644 --- a/api/app/services/tool_service.py +++ b/api/app/services/tool_service.py @@ -8,6 +8,8 @@ from datetime import datetime from sqlalchemy.orm import Session +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException from app.core.tools.mcp import MCPToolManager, SimpleMCPClient from app.repositories.tool_repository import ( ToolRepository, BuiltinToolRepository, CustomToolRepository, @@ -79,6 +81,18 @@ class ToolService: config = self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id) return self._config_to_info(config) if config else None + def _check_name_duplicate(self, name: str, tool_type: ToolType, tenant_id: uuid.UUID, exclude_id: Optional[uuid.UUID] = None): + """检查工具名称是否重复""" + query = self.db.query(ToolConfig).filter( + ToolConfig.name == name, + ToolConfig.tool_type == tool_type.value, + ToolConfig.tenant_id == tenant_id + ) + if exclude_id: + query = query.filter(ToolConfig.id != exclude_id) + if query.first(): + raise BusinessException(f"工具名称 '{name}' 已存在", BizCode.DUPLICATE_NAME) + def create_tool( self, name: str, @@ -92,6 +106,7 @@ class ToolService: """创建工具""" if tool_type == ToolType.BUILTIN: raise ValueError("内置工具不允许创建") + self._check_name_duplicate(name, tool_type, tenant_id) try: # 创建基础配置 @@ -141,6 +156,7 @@ class ToolService: raise ValueError("内置工具不允许修改名称、描述和图标") try: if name: + self._check_name_duplicate(name, config_obj.tool_type, tenant_id, exclude_id=config_obj.id) config_obj.name = name if description: config_obj.description = description @@ -209,7 +225,7 @@ class ToolService: try: # 获取工具实例 - tool = self._get_tool_instance(tool_id, tenant_id) + tool = self.get_tool_instance(tool_id, tenant_id) if not tool: return ToolResult.error_result( error=f"工具不存在: {tool_id}", @@ -335,7 +351,7 @@ class ToolService: return [] # 获取工具实例 - tool_instance = self._get_tool_instance(str(config.id), config.tenant_id) + tool_instance = self.get_tool_instance(str(config.id), config.tenant_id) if not tool_instance: return [] @@ -792,7 +808,7 @@ class ToolService: """获取工具配置""" return self.tool_repo.find_by_id_and_tenant(self.db, uuid.UUID(tool_id), tenant_id) - def _get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]: + def get_tool_instance(self, tool_id: str, tenant_id: uuid.UUID) -> Optional[BaseTool]: """获取工具实例""" if tool_id in self._tool_cache: return self._tool_cache[tool_id] @@ -1416,7 +1432,7 @@ class ToolService: """测试内置工具连接""" try: # 获取工具实例 - tool_instance = self._get_tool_instance(str(config.id), config.tenant_id) + tool_instance = self.get_tool_instance(str(config.id), config.tenant_id) if not tool_instance: return {"success": False, "message": "无法创建工具实例"} diff --git a/api/app/services/user_memory_service.py b/api/app/services/user_memory_service.py index 80413c12..8bacc112 100644 --- a/api/app/services/user_memory_service.py +++ b/api/app/services/user_memory_service.py @@ -10,6 +10,9 @@ from collections import Counter from datetime import datetime from typing import Any, Dict, List, Optional, Tuple +from pydantic import BaseModel, Field +from sqlalchemy.orm import Session + from app.core.logging_config import get_logger from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context @@ -18,13 +21,10 @@ from app.repositories.end_user_repository import EndUserRepository from app.repositories.neo4j.cypher_queries import Graph_Node_query from app.repositories.neo4j.neo4j_connector import Neo4jConnector from app.schemas.memory_episodic_schema import EmotionSubject, EmotionType, type_mapping -from app.services.implicit_memory_service import ImplicitMemoryService -from app.services.memory_base_service import MemoryBaseService, MemoryTransService +from app.services.memory_base_service import MemoryBaseService from app.services.memory_config_service import MemoryConfigService from app.services.memory_perceptual_service import MemoryPerceptualService from app.services.memory_short_service import ShortService -from pydantic import BaseModel, Field -from sqlalchemy.orm import Session logger = get_logger(__name__) @@ -1035,9 +1035,10 @@ async def analytics_memory_insight_report(end_user_id: Optional[str] = None, lan "growth_trajectory": str # 成长轨迹 } """ - from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt - from app.core.language_utils import validate_language import re + + from app.core.language_utils import validate_language + from app.core.memory.utils.prompt.prompt_utils import render_memory_insight_prompt # 验证语言参数 language = validate_language(language) @@ -1161,13 +1162,32 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st "one_sentence": str } """ - from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt - from app.core.language_utils import validate_language import re + + from app.core.language_utils import validate_language + from app.core.memory.utils.prompt.prompt_utils import render_user_summary_prompt + from app.repositories.end_user_repository import EndUserRepository # 验证语言参数 language = validate_language(language) + # 获取用户的 other_name 字段 + user_display_name = "该用户" if language == "zh" else "the user" + if end_user_id: + try: + # 获取数据库会话并查询用户信息 + with get_db_context() as db: + repo = EndUserRepository(db) + end_user = repo.get_by_id(uuid.UUID(end_user_id)) + if end_user and end_user.other_name: + user_display_name = end_user.other_name + logger.info(f"使用 other_name 作为用户显示名称: {user_display_name}") + else: + logger.info(f"用户 {end_user_id} 的 other_name 为空,使用默认称呼: {user_display_name}") + + except Exception as e: + logger.warning(f"获取用户 other_name 失败,使用默认称呼: {str(e)}") + # 创建 UserSummaryHelper 实例 user_summary_tool = UserSummaryHelper(end_user_id or os.getenv("SELECTED_end_user_id", "group_123")) @@ -1184,7 +1204,8 @@ async def analytics_user_summary(end_user_id: Optional[str] = None, language: st user_id=user_summary_tool.user_id, entities=", ".join(entity_lines) if entity_lines else "(空)" if language == "zh" else "(empty)", statements=" | ".join(statement_samples) if statement_samples else "(空)" if language == "zh" else "(empty)", - language=language + language=language, + user_display_name=user_display_name ) messages = [ @@ -1435,7 +1456,7 @@ async def analytics_memory_types( short_term_count = 0 if end_user_id: try: - short_term_service = ShortService(end_user_id) + short_term_service = ShortService(end_user_id, db) short_term_data = short_term_service.get_short_databasets() # 统计 short_term 数组的长度 if short_term_data: @@ -1449,8 +1470,10 @@ async def analytics_memory_types( forgetting_threshold = 0.3 # 默认值 if end_user_id: try: + from app.core.memory.storage_services.forgetting_engine.config_utils import ( + load_actr_config_from_db, + ) from app.services.memory_agent_service import get_end_user_connected_config - from app.core.memory.storage_services.forgetting_engine.config_utils import load_actr_config_from_db # 获取用户关联的 config_id connected_config = get_end_user_connected_config(end_user_id, db) diff --git a/api/app/services/workflow_import_service.py b/api/app/services/workflow_import_service.py new file mode 100644 index 00000000..2b36c5ea --- /dev/null +++ b/api/app/services/workflow_import_service.py @@ -0,0 +1,102 @@ +# -*- coding: UTF-8 -*- +# Author: Eternity +# @Email: 1533512157@qq.com +# @Time : 2026/2/25 14:39 +import json +import uuid +from typing import Any + +from sqlalchemy.orm import Session + +from app.aioRedis import aio_redis_set, aio_redis_get +from app.core.config import settings +from app.core.exceptions import BusinessException +from app.core.workflow.adapters.base_adapter import WorkflowImportResult, WorkflowParserResult +from app.core.workflow.adapters.errors import UnsupportPlatform, InvalidConfiguration +from app.core.workflow.adapters.registry import PlatformAdapterRegistry +from app.schemas import AppCreate +from app.schemas.workflow_schema import WorkflowConfigCreate +from app.services.app_service import AppService +from app.services.workflow_service import WorkflowService + + +class WorkflowImportService: + def __init__(self, db: Session): + self.db = db + self.registry = PlatformAdapterRegistry + self.cache_timeout = settings.WORKFLOW_IMPORT_CACHE_TIMEOUT + + self.app_service = AppService(db) + self.workflow_service = WorkflowService(db) + + async def flush_config(self, temp_id: str, config: WorkflowParserResult): + config_cache = await aio_redis_get(temp_id) + if not config_cache: + raise BusinessException("Workflow configuration has expired. Please re-upload it.") + await aio_redis_set(temp_id, config.model_dump_json(), expire=self.cache_timeout) + + async def upload_config( + self, + platform: str, + config: dict[str, Any], + ): + + if not self.registry.is_supported(platform): + return WorkflowImportResult( + success=False, + temp_id=None, + workflow_id=None, + errors=[UnsupportPlatform(platform=platform)] + ) + + adapter = self.registry.get_adapter(platform, config) + + if not adapter.validate_config(): + return WorkflowImportResult( + success=False, + temp_id=None, + workflow_id=None, + errors=[InvalidConfiguration()] + adapter.errors + ) + + workflow_config = adapter.parse_workflow() + temp_id = uuid.uuid4().hex + await aio_redis_set(temp_id, workflow_config.model_dump(), expire=self.cache_timeout) + return WorkflowImportResult( + success=True, + temp_id=temp_id, + workflow_id=None, + edges=workflow_config.edges, + nodes=workflow_config.nodes, + variables=workflow_config.variables, + warnings=workflow_config.warnings, + errors=workflow_config.errors + ) + + async def save_workflow( + self, + user_id: uuid.UUID, + workspace_id: uuid.UUID, + temp_id: str, + name: str, + description: str | None, + ): + config = await aio_redis_get(temp_id) + if config is None: + raise BusinessException("Configuration import timed out. Please try again.") + config = json.loads(config) + app = self.app_service.create_app( + user_id=user_id, + workspace_id=workspace_id, + data=AppCreate( + name=name, + description=description, + type="workflow", + workflow_config=WorkflowConfigCreate( + nodes=config["nodes"], + edges=config["edges"], + variables=config["variables"] + ) + ) + ) + return app diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index d06a05d7..eaf78b90 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -6,20 +6,26 @@ import logging import uuid from typing import Any, Annotated, Optional +import yaml from fastapi import Depends from sqlalchemy.orm import Session from app.core.error_codes import BizCode from app.core.exceptions import BusinessException +from app.core.workflow.adapters.registry import PlatformAdapterRegistry +from app.core.workflow.executor import execute_workflow, execute_workflow_stream +from app.core.workflow.nodes.enums import NodeType from app.core.workflow.validator import validate_workflow_config +from app.core.workflow.variable.base_variable import FileObject from app.db import get_db +from app.models import App from app.models.workflow_model import WorkflowConfig, WorkflowExecution from app.repositories.workflow_repository import ( WorkflowConfigRepository, WorkflowExecutionRepository, WorkflowNodeExecutionRepository ) -from app.schemas import DraftRunRequest +from app.schemas import DraftRunRequest, FileInput, FileType from app.services.conversation_service import ConversationService from app.services.multi_agent_service import convert_uuids_to_str from app.services.multimodal_service import MultimodalService @@ -38,6 +44,8 @@ class WorkflowService: self.conversation_service = ConversationService(db) self.multimodal_service = MultimodalService(db) + self.registry = PlatformAdapterRegistry + # ==================== 配置管理 ==================== def create_workflow_config( @@ -200,6 +208,32 @@ class WorkflowService: logger.info(f"删除工作流配置成功: app_id={app_id}, config_id={config.id}") return True + def export_workflow_dsl(self, app_id: uuid.UUID): + config = self.get_workflow_config(app_id) + if not config: + raise BusinessException( + code=BizCode.NOT_FOUND, + message=f"工作流配置不存在: app_id={app_id}" + ) + + app: App = config.app + dsl_info = { + "app": { + "name": app.name, + "description": app.description, + "icon": app.icon, + "icon_type": app.icon_type + }, + "workflow": { + "variables": config.variables, + "edges": config.edges, + "nodes": config.nodes, + "execution_config": config.execution_config, + "triggers": config.triggers + } + } + return yaml.dump(dsl_info, default_flow_style=False, allow_unicode=True) + def check_config(self, app_id: uuid.UUID) -> WorkflowConfig: """检查工作流配置的完整性 @@ -413,6 +447,95 @@ class WorkflowService: "success_rate": completed / total if total > 0 else 0 } + async def _handle_file_input(self, files: list[FileInput]): + if not files: + return [] + + files_struct = [] + for file in files: + files_struct.append( + FileObject( + type=file.type, + url=await self.multimodal_service.get_file_url(file), + transfer_method=file.transfer_method, + file_id=str(file.upload_file_id), + origin_file_type=file.file_type, + is_file=True + ).model_dump() + ) + return files_struct + + @staticmethod + def _map_public_event(event: dict) -> dict | None: + """ + Map internal workflow events to public-facing event formats. + + Purpose: + - Hide internal execution details + - Expose a stable and simplified public event schema + - Filter out non-public events + - Maintain backward compatibility when possible + + Args: + event (dict): Internal event object, e.g.: + { + "event": "workflow_start", + "data": {...} + } + + Returns: + dict | None: + - Returns the mapped public event + - Returns None if the event should not be exposed + """ + event_type = event.get("event") + payload = event.get("data") + match event_type: + case "workflow_start": + return { + "event": "start", + "data": { + "conversation_id": payload.get("conversation_id"), + "message_id": payload.get("message_id") + } + } + case "workflow_end": + return { + "event": "end", + "data": { + "elapsed_time": payload.get("elapsed_time"), + "message_length": len(payload.get("output", "")), + "error": payload.get("error", "") + } + } + case "node_start" | "node_end" | "node_error" | "cycle_item": + return None + case _: + return event + + def _emit(self, public: bool, internal_event: dict): + """ + Unified event emission entry. + + Args: + public (bool): + - True -> Emit mapped public event + - False -> Emit raw internal event + + internal_event (dict): + The original internal event object + + Returns: + dict | None: + - The mapped event + - Or None if the event is filtered out + """ + if public: + mapped = self._map_public_event(internal_event) + else: + mapped = internal_event + return mapped + # ==================== 工作流执行 ==================== async def run( @@ -447,10 +570,11 @@ class WorkflowService: message=f"工作流配置不存在: app_id={app_id}" ) - input_data = {"message": payload.message, "variables": payload.variables, - "conversation_id": payload.conversation_id, - "files": [file.model_dump(mode='json') for file in payload.files] - } + input_data = { + "message": payload.message, "variables": payload.variables, + "conversation_id": payload.conversation_id, + "files": [file.model_dump(mode='json') for file in payload.files] + } # 转换 conversation_id 为 UUID conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None @@ -474,23 +598,10 @@ class WorkflowService: "execution_config": config.execution_config } - # 4. 获取工作空间 ID(从 app 获取) - - # 5. 执行工作流 - from app.core.workflow.executor import execute_workflow - try: - files = [] - if payload.files: - for file in payload.files: - files.append( - { - "type": file.type, - "url": await self.multimodal_service.get_file_url(file), - "__file": True - } - ) + files = await self._handle_file_input(payload.files) input_data["files"] = files + message_id = uuid.uuid4() # 更新状态为运行中 self.update_execution_status(execution.execution_id, "running") @@ -515,24 +626,45 @@ class WorkflowService: workspace_id=str(workspace_id), user_id=payload.user_id ) - # 更新执行结果 if result.get("status") == "completed": token_usage = result.get("token_usage", {}) or {} + + final_messages = result.get("messages", [])[init_message_length:] + human_message = "" + assistant_message = "" + for message in final_messages: + if message["role"] == "user": + if isinstance(message["content"], str): + human_message += message["content"] + elif isinstance(message["content"], list): + for file in message["content"]: + if file.get("type") == FileType.IMAGE: + human_message += f"![image]({file.get('url', '')})" + else: + human_message += f"[{file.get('type')}]({file.get('url', '')})" + if message["role"] == "assistant": + assistant_message = message["content"] + self.conversation_service.add_message( + conversation_id=conversation_id_uuid, + role="user", + content=human_message, + meta_data=None + ) + self.conversation_service.add_message( + message_id=message_id, + conversation_id=conversation_id_uuid, + role="assistant", + content=assistant_message, + meta_data={"usage": token_usage} + ) self.update_execution_status( execution.execution_id, "completed", output_data=result, token_usage=token_usage.get("total_tokens", None) ) - final_messages = result.get("messages", [])[init_message_length:] - for message in final_messages: - self.conversation_service.add_message( - conversation_id=conversation_id_uuid, - role=message["role"], - content=message["content"], - meta_data=None if message["role"] == "user" else {"usage": token_usage} - ) + logger.info(f"Workflow Run Success, " f"execution_id: {execution.execution_id}, message count: {len(final_messages)}") else: @@ -541,6 +673,8 @@ class WorkflowService: "failed", error_message=result.get("error") ) + logger.error(f"Workflow Run Failed, execution_id: {execution.execution_id}," + f" error: {result.get('error')}") # 返回增强的响应结构 return { @@ -549,6 +683,8 @@ class WorkflowService: # "variables": result.get("variables"), # "messages": result.get("messages"), "output": result.get("output"), # 最终输出(字符串) + "message": result.get("output"), # 最终输出(字符串) + "message_id": str(message_id), # "output_data": result.get("node_outputs", {}), # 所有节点输出(详细数据) "conversation_id": result.get("conversation_id"), # 所有节点输出(详细数据)payload., # 会话 ID "error_message": result.get("error"), @@ -568,41 +704,6 @@ class WorkflowService: message=f"工作流执行失败: {str(e)}" ) - @staticmethod - def _map_public_event(event: dict) -> dict | None: - event_type = event.get("event") - payload = event.get("data") - match event_type: - case "workflow_start": - return { - "event": "start", - "data": { - "conversation_id": payload.get("conversation_id"), - } - } - case "workflow_end": - return { - "event": "end", - "data": { - "elapsed_time": payload.get("elapsed_time"), - "message_length": len(payload.get("output", "")) - } - } - case "node_start" | "node_end" | "node_error" | "cycle_item": - return None - case _: - return event - - def _emit(self, public: bool, internal_event: dict): - """ - decide - """ - if public: - mapped = self._map_public_event(internal_event) - else: - mapped = internal_event - return mapped - async def run_stream( self, app_id: uuid.UUID, @@ -637,10 +738,11 @@ class WorkflowService: message=f"工作流配置不存在: app_id={app_id}" ) - input_data = {"message": payload.message, "variables": payload.variables, - "conversation_id": payload.conversation_id, - "files": [file.model_dump(mode='json') for file in payload.files] - } + input_data = { + "message": payload.message, "variables": payload.variables, + "conversation_id": payload.conversation_id, + "files": [file.model_dump(mode='json') for file in payload.files] + } # 转换 conversation_id 为 UUID conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None @@ -665,16 +767,7 @@ class WorkflowService: } try: - files = [] - if payload.files: - for file in payload.files: - files.append( - { - "type": file.type, - "url": await self.multimodal_service.get_file_url(file), - "__file": True - } - ) + files = await self._handle_file_input(payload.files) input_data["files"] = files self.update_execution_status(execution.execution_id, "running") executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) @@ -689,8 +782,7 @@ class WorkflowService: input_data["conv_messages"] = last_state.get("messages") or [] break init_message_length = len(input_data.get("conv_messages", [])) - from app.core.workflow.executor import execute_workflow_stream - + message_id = uuid.uuid4() async for event in execute_workflow_stream( workflow_config=workflow_config_dict, input_data=input_data, @@ -699,24 +791,43 @@ class WorkflowService: user_id=payload.user_id, ): if event.get("event") == "workflow_end": - status = event.get("data", {}).get("status") token_usage = event.get("data", {}).get("token_usage", {}) or {} if status == "completed": + final_messages = event.get("data", {}).get("messages", [])[init_message_length:] + human_message = "" + assistant_message = "" + for message in final_messages: + if message["role"] == "user": + if isinstance(message["content"], str): + human_message += message["content"] + elif isinstance(message["content"], list): + for file in message["content"]: + if file.get("type") == FileType.IMAGE: + human_message += f"![image]({file.get('url', '')})" + else: + human_message += f"[{file.get('type')}]({file.get('url', '')})" + if message["role"] == "assistant": + assistant_message = message["content"] + self.conversation_service.add_message( + conversation_id=conversation_id_uuid, + role="user", + content=human_message, + meta_data=None + ) + self.conversation_service.add_message( + message_id=message_id, + conversation_id=conversation_id_uuid, + role="assistant", + content=assistant_message, + meta_data={"usage": token_usage} + ) self.update_execution_status( execution.execution_id, "completed", output_data=event.get("data"), token_usage=token_usage.get("total_tokens", None) ) - final_messages = event.get("data", {}).get("messages", [])[init_message_length:] - for message in final_messages: - self.conversation_service.add_message( - conversation_id=conversation_id_uuid, - role=message["role"], - content=message["content"], - meta_data=None if message["role"] == "user" else {"usage": token_usage} - ) logger.info(f"Workflow Run Success, " f"execution_id: {execution.execution_id}, message count: {len(final_messages)}") elif status == "failed": @@ -727,6 +838,8 @@ class WorkflowService: ) else: logger.error(f"unexpect workflow run status, status: {status}") + elif event.get("event") == "workflow_start": + event["data"]["message_id"] = str(message_id) event = self._emit(public, event) if event: yield event @@ -747,36 +860,13 @@ class WorkflowService: } } - def _clean_event_for_json(self, event: dict[str, Any]) -> dict[str, Any]: - """清理事件数据,移除不可序列化的对象 - - Args: - event: 原始事件数据 - - Returns: - 可序列化的事件数据 - """ - from langchain_core.messages import BaseMessage - - def clean_value(value): - """递归清理值""" - if isinstance(value, BaseMessage): - # 将 Message 对象转换为字典 - return { - "type": value.__class__.__name__, - "content": value.content, - } - elif isinstance(value, dict): - return {k: clean_value(v) for k, v in value.items()} - elif isinstance(value, list): - return [clean_value(item) for item in value] - elif isinstance(value, (str, int, float, bool, type(None))): - return value - else: - # 其他不可序列化的对象转换为字符串 - return str(value) - - return clean_value(event) + @staticmethod + def get_start_node_variables(config: dict) -> list: + nodes = config.get("nodes", []) + for node in nodes: + if node.get("type") == NodeType.START: + return node.get("config", {}).get("variables", []) + raise BusinessException("workflow config error - start node not found") # ==================== 依赖注入函数 ==================== diff --git a/api/app/services/workspace_service.py b/api/app/services/workspace_service.py index 6f102695..74880410 100644 --- a/api/app/services/workspace_service.py +++ b/api/app/services/workspace_service.py @@ -30,6 +30,7 @@ from app.schemas.workspace_schema import ( WorkspaceModelsUpdate, WorkspaceUpdate, ) +from app.config.default_ontology_initializer import DefaultOntologyInitializer # 获取业务逻辑专用日志器 business_logger = get_business_logger() @@ -106,6 +107,7 @@ def get_user_workspaces(db: Session, user: User) -> List[Workspace]: for workspace in workspaces: if workspace.storage_type == 'neo4j': _ensure_default_memory_config(db, workspace) + _ensure_default_ontology_scenes(db, workspace) business_logger.info(f"用户 {user.username} 的工作空间数量: {len(workspaces)}") return workspaces @@ -129,7 +131,7 @@ def _create_workspace_only( raise def create_workspace( - db: Session, workspace: WorkspaceCreate, user: User + db: Session, workspace: WorkspaceCreate, user: User, language: str = "zh" ) -> Workspace: business_logger.info( f"创建工作空间: {workspace.name}, 创建者: {user.username}, " @@ -145,10 +147,71 @@ def create_workspace( db=db, workspace=workspace, tenant_id=user.tenant_id ) business_logger.info(f"工作空间创建成功: {db_workspace.name} (ID: {db_workspace.id}), 创建者: {user.username}") - db.commit() + db.flush() # 使用 flush 而不是 commit,获取 ID 但不提交事务 db.refresh(db_workspace) + # Initialize default ontology scenes for the workspace (先创建本体场景) + default_scene_id = None + default_scene_name = None + try: + initializer = DefaultOntologyInitializer(db) + success, error_msg = initializer.initialize_default_scenes( + db_workspace.id, language=language + ) + + if success: + business_logger.info( + f"为工作空间 {db_workspace.id} 创建默认本体场景成功 (language={language})" + ) + + # 获取默认场景ID,优先使用"在线教育"场景,如果不存在则使用"情感陪伴"场景 + from app.repositories.ontology_scene_repository import OntologySceneRepository + from app.config.default_ontology_config import ( + ONLINE_EDUCATION_SCENE, + EMOTIONAL_COMPANION_SCENE, + get_scene_name + ) + + scene_repo = OntologySceneRepository(db) + + # 优先尝试获取教育场景 + education_scene_name = get_scene_name(ONLINE_EDUCATION_SCENE, language) + education_scene = scene_repo.get_by_name(education_scene_name, db_workspace.id) + + if education_scene: + default_scene_id = education_scene.scene_id + default_scene_name = education_scene.scene_name + business_logger.info( + f"获取到教育场景ID用于默认记忆配置: {default_scene_id} (scene_name={education_scene_name})" + ) + else: + # 如果教育场景不存在,尝试获取情感陪伴场景 + companion_scene_name = get_scene_name(EMOTIONAL_COMPANION_SCENE, language) + companion_scene = scene_repo.get_by_name(companion_scene_name, db_workspace.id) + + if companion_scene: + default_scene_id = companion_scene.scene_id + default_scene_name = companion_scene.scene_name + business_logger.info( + f"教育场景不存在,使用情感陪伴场景ID用于默认记忆配置: {default_scene_id} (scene_name={companion_scene_name})" + ) + else: + business_logger.warning( + f"未找到任何默认场景 (education={education_scene_name}, companion={companion_scene_name})" + ) + else: + business_logger.warning( + f"为工作空间 {db_workspace.id} 创建默认本体场景失败: {error_msg} (language={language})" + ) + except Exception as ontology_error: + business_logger.error( + f"为工作空间 {db_workspace.id} 创建默认本体场景异常: {str(ontology_error)} (language={language})" + ) + # Don't fail workspace creation if default ontology initialization fails + # The workspace can still function without default ontology scenes + # Create default memory config for the workspace (only for neo4j storage types) + # 将默认场景ID(教育场景或情感陪伴场景)关联到记忆配置 if workspace.storage_type == 'neo4j': try: _create_default_memory_config( @@ -158,9 +221,11 @@ def create_workspace( llm_id=llm, embedding_id=embedding, rerank_id=rerank, + scene_id=default_scene_id, # 传入默认场景ID(优先教育场景,其次情感陪伴场景) + pruning_scene_name=default_scene_name, # 传入场景名称作为语义剪枝场景值 ) business_logger.info( - f"为工作空间 {db_workspace.id} 创建默认记忆配置成功" + f"为工作空间 {db_workspace.id} 创建默认记忆配置成功 (scene_id={default_scene_id})" ) except Exception as mc_error: business_logger.error( @@ -209,7 +274,6 @@ def create_workspace( db=db, knowledge=knowledge_data ) - db.commit() business_logger.info( f"为工作空间 {db_workspace.id} 自动创建知识库成功: " f"{db_knowledge.name} (ID: {db_knowledge.id})" @@ -224,6 +288,12 @@ def create_workspace( BizCode.INTERNAL_ERROR ) + # 统一提交所有更改 + db.commit() + business_logger.info( + f"工作空间 {db_workspace.id} 及相关资源创建完成并已提交" + ) + return db_workspace except Exception as e: @@ -919,6 +989,43 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None: f"Workspace {workspace.id} missing default memory config, creating one" ) + # 尝试获取默认场景ID,优先教育场景,其次情感陪伴场景 + default_scene_id = None + try: + from app.repositories.ontology_scene_repository import OntologySceneRepository + from app.config.default_ontology_config import ( + ONLINE_EDUCATION_SCENE, + EMOTIONAL_COMPANION_SCENE, + get_scene_name + ) + + scene_repo = OntologySceneRepository(db) + # 尝试中文和英文场景名称 + for language in ["zh", "en"]: + # 优先尝试教育场景 + education_scene_name = get_scene_name(ONLINE_EDUCATION_SCENE, language) + education_scene = scene_repo.get_by_name(education_scene_name, workspace.id) + if education_scene: + default_scene_id = education_scene.scene_id + business_logger.info( + f"找到教育场景用于默认记忆配置: scene_id={default_scene_id}, scene_name={education_scene_name}" + ) + break + + # 如果教育场景不存在,尝试情感陪伴场景 + companion_scene_name = get_scene_name(EMOTIONAL_COMPANION_SCENE, language) + companion_scene = scene_repo.get_by_name(companion_scene_name, workspace.id) + if companion_scene: + default_scene_id = companion_scene.scene_id + business_logger.info( + f"教育场景不存在,找到情感陪伴场景用于默认记忆配置: scene_id={default_scene_id}, scene_name={companion_scene_name}" + ) + break + except Exception as scene_error: + business_logger.warning( + f"获取默认场景失败,将创建不关联场景的记忆配置: {str(scene_error)}" + ) + try: _create_default_memory_config( db=db, @@ -927,6 +1034,7 @@ def _ensure_default_memory_config(db: Session, workspace: Workspace) -> None: llm_id=uuid.UUID(workspace.llm) if workspace.llm else None, embedding_id=uuid.UUID(workspace.embedding) if workspace.embedding else None, rerank_id=uuid.UUID(workspace.rerank) if workspace.rerank else None, + scene_id=default_scene_id, # 传入默认场景ID(优先教育场景,其次情感陪伴场景) ) except Exception as e: business_logger.error( @@ -1001,6 +1109,52 @@ def _fill_workspace_configs_model_defaults( ) +def _ensure_default_ontology_scenes(db: Session, workspace: Workspace) -> None: + """Ensure a workspace has default ontology scenes, creating them if missing. + + Checks whether any is_system_default scene exists for the workspace. + If not, runs the DefaultOntologyInitializer to create them. + + Args: + db: Database session + workspace: The workspace to check + """ + from app.models.ontology_scene import OntologyScene + + # 幂等检查:是否已存在系统默认场景 + existing = db.query(OntologyScene).filter( + OntologyScene.workspace_id == workspace.id, + OntologyScene.is_system_default.is_(True) + ).first() + + if existing: + return + + business_logger.info( + f"Workspace {workspace.id} missing default ontology scenes, creating them" + ) + + try: + initializer = DefaultOntologyInitializer(db) + success, error_msg = initializer.initialize_default_scenes( + workspace.id, language="zh" + ) + if success: + db.commit() + business_logger.info( + f"为工作空间 {workspace.id} 补建默认本体场景成功" + ) + else: + business_logger.warning( + f"为工作空间 {workspace.id} 补建默认本体场景失败: {error_msg}" + ) + except Exception as e: + db.rollback() + business_logger.error( + f"为工作空间 {workspace.id} 补建默认本体场景异常: {str(e)}" + ) + + def _create_default_memory_config( db: Session, workspace_id: uuid.UUID, @@ -1008,6 +1162,8 @@ def _create_default_memory_config( llm_id: Optional[uuid.UUID] = None, embedding_id: Optional[uuid.UUID] = None, rerank_id: Optional[uuid.UUID] = None, + scene_id: Optional[uuid.UUID] = None, + pruning_scene_name: Optional[str] = None, ) -> None: """Create a default memory config for a newly created workspace. @@ -1018,6 +1174,8 @@ def _create_default_memory_config( llm_id: Optional LLM model ID embedding_id: Optional embedding model ID rerank_id: Optional rerank model ID + scene_id: Optional ontology scene ID (默认关联教育场景) + pruning_scene_name: Optional pruning scene name,取自 ontology_scene.scene_name """ from app.models.memory_config_model import MemoryConfig @@ -1031,12 +1189,14 @@ def _create_default_memory_config( llm_id=str(llm_id) if llm_id else None, embedding_id=str(embedding_id) if embedding_id else None, rerank_id=str(rerank_id) if rerank_id else None, + scene_id=scene_id, # 关联本体场景ID(默认为"在线教育"场景) + pruning_scene=pruning_scene_name, # 语义剪枝场景直接使用 scene_name state=True, # Active by default is_default=True, # Mark as workspace default ) db.add(default_config) - db.commit() + db.flush() # 使用 flush 而不是 commit,让调用者统一提交 business_logger.info( "Created default memory config for workspace", @@ -1044,5 +1204,6 @@ def _create_default_memory_config( "workspace_id": str(workspace_id), "config_id": str(config_id), "config_name": default_config.config_name, + "scene_id": str(scene_id) if scene_id else None, } ) diff --git a/api/app/tasks.py b/api/app/tasks.py index ae533489..a6ebbb8e 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1,16 +1,16 @@ import asyncio -from concurrent.futures import ThreadPoolExecutor import json import os import re +import shutil import time import uuid -from uuid import UUID +from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timezone from math import ceil from pathlib import Path -import shutil from typing import Any, Dict, List, Optional +from uuid import UUID import redis import requests @@ -38,7 +38,7 @@ from app.db import get_db, get_db_context from app.models.document_model import Document from app.models.file_model import File from app.models.knowledge_model import Knowledge -from app.schemas import file_schema, document_schema +from app.schemas import document_schema, file_schema from app.services.memory_agent_service import MemoryAgentService from app.utils.config_utils import resolve_config_id @@ -67,8 +67,9 @@ def parse_document(file_path: str, document_id: uuid.UUID): Document parsing, vectorization, and storage """ # Force re-importing Trio in child processes (to avoid inheriting the state of the parent process) - import trio import importlib + + import trio importlib.reload(trio) db = next(get_db()) # Manually call the generator db_document = None @@ -256,7 +257,7 @@ def parse_document(file_path: str, document_id: uuid.UUID): progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n" return result - try: + def sync_task(): trio.run( lambda: _run( row=task, @@ -271,6 +272,10 @@ def parse_document(file_path: str, document_id: uuid.UUID): with_community=with_community, ) ) + try: + with ThreadPoolExecutor(max_workers=1) as executor: + future = executor.submit(sync_task) + future.result() # Blocks until the task completes except Exception as e: progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n" progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)" @@ -297,8 +302,9 @@ def build_graphrag_for_kb(kb_id: uuid.UUID): build knowledge graph """ # Force re-importing Trio in child processes (to avoid inheriting the state of the parent process) - import trio import importlib + + import trio importlib.reload(trio) db = next(get_db()) # Manually call the generator db_documents = None @@ -932,24 +938,18 @@ def read_message_task(self, end_user_id: str, message: str, history: List[Dict[s if actual_config_id is None: try: from app.services.memory_agent_service import get_end_user_connected_config - db = next(get_db()) - try: + with get_db_context() as db: connected_config = get_end_user_connected_config(end_user_id, db) actual_config_id = connected_config.get("memory_config_id") - finally: - db.close() except Exception: # Log but continue - will fail later with proper error pass async def _run() -> str: - db = next(get_db()) - try: + with get_db_context() as db: service = MemoryAgentService() return await service.read_memory(end_user_id, message, history, search_switch, actual_config_id, db, storage_type, user_rag_memory_id) - finally: - db.close() try: # 使用 nest_asyncio 来避免事件循环冲突 @@ -1049,19 +1049,15 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s if actual_config_id is None: try: from app.services.memory_agent_service import get_end_user_connected_config - db = next(get_db()) - try: + with get_db_context() as db: connected_config = get_end_user_connected_config(end_user_id, db) actual_config_id = connected_config.get("memory_config_id") - finally: - db.close() except Exception: # Log but continue - will fail later with proper error pass async def _run() -> str: - db = next(get_db()) - try: + with get_db_context() as db: logger.info( f"[CELERY WRITE] Executing MemoryAgentService.write_memory with config_id={actual_config_id} (type: {type(actual_config_id).__name__}), language={language}") service = MemoryAgentService() @@ -1069,11 +1065,6 @@ def write_message_task(self, end_user_id: str, message: list[dict], config_id: s user_rag_memory_id, language) logger.info(f"[CELERY WRITE] Write completed successfully: {result}") return result - except Exception as e: - logger.error(f"[CELERY WRITE] Write failed: {e}", exc_info=True) - raise - finally: - db.close() try: # 使用 nest_asyncio 来避免事件循环冲突 @@ -1328,9 +1319,9 @@ def write_all_workspaces_memory_task(self) -> Dict[str, Any]: async def _run() -> Dict[str, Any]: from app.core.logging_config import get_api_logger - from app.models.workspace_model import Workspace from app.models.app_model import App from app.models.end_user_model import EndUser + from app.models.workspace_model import Workspace from app.repositories.memory_increment_repository import write_memory_increment from app.services.memory_storage_service import search_all diff --git a/api/app/version_info.json b/api/app/version_info.json index 7d82eabc..bbaffc17 100644 --- a/api/app/version_info.json +++ b/api/app/version_info.json @@ -1,4 +1,36 @@ { + "v0.2.6": { + "introduction": { + "codeName": "听剑", + "releaseDate": "2026-3-6", + "upgradePosition": "🐻 多模态交互全面升级,记忆剪枝与工作流迁移双线并进,锋芒初露,兼收并蓄", + "coreUpgrades": [ + "1. 工作流与应用框架
* 工作流导入适配(Dify):支持 Dify 工作流定义无缝迁移
* 字段字数限制与校验规则:可配置字符限制与产品级校验
* 应用复制(Agent、工作流、集群):一键复制完整应用配置
* 对话变量(调试+分享):支持有状态多轮交互
* Chat 接口流式输出 message_id:流式响应包含消息追踪标识", + "2. 多模态与交互 💬
* 音频输入与输出:应用支持音频模态
* 文件类型输入支持:扩展支持语音、文件、视频上传", + "3. 模型与智能 🧠
* 模型视觉与 Omni 区分:精确区分视觉与 Omni 模型能力
* 教育记忆与陪伴玩具场景预设:垂直领域本体配置开箱即用
* 本体配置默认标识:支持基线配置标记
* 记忆配置默认标识:自动应用默认记忆设置", + "4. 记忆智能 🔬
* 记忆剪枝模块:智能裁剪冗余低价值记忆
* RAG 快速检索集成记忆:深度思考与正常回复双模式检索", + "5. 稳健性与缺陷修复 🔧
* 模型管理:修复自定义模型 API Key 批量配置错误
* 知识库管理:修复非源文档下载原始内容接口错误,更新分享停用提示文案
* 用户记忆:优化档案提取准确性(姓名、职业、兴趣分布)
* 长期记忆:修复情景记忆卡片重复和用户归属错误
* 工作空间首页:修复知识库数量、应用数量、总记忆容量、API 调用次数、知识库类型分布等数据不一致问题
* 基础设施:修正 Celery 环境变量配置,修复数据库连接池 idle-in-transaction 泄漏", + "
", + "v0.2.6 标志着 MemoryBear 在多模态交互、跨平台工作流迁移和智能记忆管理方面的重要突破。下一版本将聚焦 A2A 协议支持实现多智能体协作、多模态记忆能力扩展至语音与视觉领域,以及应用导入导出功能支持跨环境便携部署。", + "MemoryBear,让记忆有熊力 🐻✨" + ] + }, + "introduction_en": { + "codeName": "TingJian", + "releaseDate": "2026-3-6", + "upgradePosition": "🐻 Full multimodal interaction upgrade with memory pruning and workflow migration — sharpened edge, broader reach", + "coreUpgrades": [ + "1. Workflow & Application Framework
* Workflow Import Adaptation (Dify): Seamless Dify workflow migration
* Field Character Limits & Validation: Configurable limits with product-defined rules
* Application Cloning (Agent, Workflow, Cluster): One-click full config duplication
* Conversation Variables (Debug + Share): Stateful multi-turn interactions
* Streaming message_id in Chat API: Message tracking in streaming responses", + "2. Multimodal & Interaction 💬
* Audio Input & Output: Audio modality support for applications
* File Type Input Support: Voice, file, and video upload support", + "3. Model & Intelligence 🧠
* Model Vision & Omni Differentiation: Precise capability routing
* Education Memory & Companion Toy Presets: Domain-specific ontology configs
* Ontology Default Identifier: Baseline configuration flagging
* Memory Configuration Default Identifier: Auto-apply default settings", + "4. Memory Intelligence 🔬
* Memory Pruning Module: Intelligent trimming of redundant memories
* RAG Quick Retrieval with Memory: Deep think and normal reply dual-mode retrieval", + "5. Robustness & Bug Fixes 🔧
* Model Management: Fixed custom model API key batch configuration error
* Knowledge Base: Fixed download original content API error for non-source documents, updated share disable prompt text
* User Memory: Improved profile extraction accuracy (name, occupation, interests)
* Long-Term Memory: Fixed duplicate episodic memory cards and wrong user attribution
* Dashboard: Fixed data inconsistencies in knowledge count, app count, memory capacity, API calls, and knowledge type distribution
* Infrastructure: Corrected Celery environment variables, fixed database connection pool idle-in-transaction leak", + "
", + "v0.2.6 marks a significant milestone for MemoryBear in multimodal interaction, cross-platform workflow migration, and intelligent memory management. The next release will focus on A2A protocol support for multi-agent collaboration, multimodal memory extending extraction to voice and visual domains, and application import/export for portable cross-environment deployment.", + "MemoryBear, Memory with Bear Power 🐻✨" + ] + } + }, "v0.2.5": { "introduction": { "codeName": "行云", diff --git a/api/env.example b/api/env.example index e8074f82..bd7f3dae 100644 --- a/api/env.example +++ b/api/env.example @@ -29,10 +29,10 @@ REDIS_DB= REDIS_PASSWORD=password #celery -BROKER_URL= -RESULT_BACKEND= -CELERY_BROKER= -CELERY_BACKEND= +# NOTE: 不要使用 BROKER_URL / RESULT_BACKEND / CELERY_BROKER / CELERY_BACKEND, +# 这些名称会被 Celery CLI 劫持,详见 docs/celery-env-bug-report.md +REDIS_DB_CELERY_BROKER= +REDIS_DB_CELERY_BACKEND= # Memory Cache Regeneration Configuration # Interval in hours for regenerating memory insight and user summary cache @@ -139,7 +139,7 @@ SMTP_USER= SMTP_PASSWORD= # 本体类型融合配置 (记得写入env_example) -GENERAL_ONTOLOGY_FILES=General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔 +GENERAL_ONTOLOGY_FILES=api/app/core/memory/ontology_services/General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔 ENABLE_GENERAL_ONTOLOGY_TYPES=true # 总开关,控制是否启用通用本体类型融合功能(false = 不使用任何本体类型指导) MAX_ONTOLOGY_TYPES_IN_PROMPT=100 # 限制传给 LLM 的类型数量,防止 Prompt 过长 CORE_GENERAL_TYPES=Person,Organization,Place,Event,Work,Concept # 定义核心类型列表,这些类型会优先包含在合并结果中 diff --git a/api/migrations/versions/4bf27c66ae63_202602281918.py b/api/migrations/versions/4bf27c66ae63_202602281918.py new file mode 100644 index 00000000..78b13435 --- /dev/null +++ b/api/migrations/versions/4bf27c66ae63_202602281918.py @@ -0,0 +1,44 @@ +"""202602281918 + +Revision ID: 4bf27c66ae63 +Revises: 7672d8f0f939 +Create Date: 2026-02-28 19:18:38.332468 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '4bf27c66ae63' +down_revision: Union[str, None] = '7672d8f0f939' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + # Add columns as nullable first + op.add_column('ontology_class', sa.Column('is_system_default', sa.Boolean(), nullable=True, comment='是否为系统默认类型')) + op.add_column('ontology_scene', sa.Column('is_system_default', sa.Boolean(), nullable=True, comment='是否为系统默认场景')) + + # Set default value for existing rows + op.execute("UPDATE ontology_class SET is_system_default = false WHERE is_system_default IS NULL") + op.execute("UPDATE ontology_scene SET is_system_default = false WHERE is_system_default IS NULL") + + # Now make columns NOT NULL + op.alter_column('ontology_class', 'is_system_default', nullable=False) + op.alter_column('ontology_scene', 'is_system_default', nullable=False) + + op.create_index(op.f('ix_ontology_scene_is_system_default'), 'ontology_scene', ['is_system_default'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index(op.f('ix_ontology_scene_is_system_default'), table_name='ontology_scene') + op.drop_column('ontology_scene', 'is_system_default') + op.drop_column('ontology_class', 'is_system_default') + # ### end Alembic commands ### diff --git a/api/migrations/versions/6a4641cf192b_202603051440.py b/api/migrations/versions/6a4641cf192b_202603051440.py new file mode 100644 index 00000000..0322c9e2 --- /dev/null +++ b/api/migrations/versions/6a4641cf192b_202603051440.py @@ -0,0 +1,43 @@ +"""202603051440 + +Revision ID: 6a4641cf192b +Revises: b4af97639217 +Create Date: 2026-03-05 14:41:03.371557 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = '6a4641cf192b' +down_revision: Union[str, None] = 'b4af97639217' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('implicit_emotions_storage', + sa.Column('id', sa.UUID(), nullable=False, comment='主键ID'), + sa.Column('end_user_id', sa.String(length=255), nullable=False, comment='终端用户ID'), + sa.Column('implicit_profile', postgresql.JSONB(astext_type=sa.Text()), nullable=True, comment='隐性记忆用户画像数据'), + sa.Column('emotion_suggestions', postgresql.JSONB(astext_type=sa.Text()), nullable=True, comment='情绪个性化建议数据'), + sa.Column('created_at', sa.DateTime(), nullable=False, comment='创建时间'), + sa.Column('updated_at', sa.DateTime(), nullable=False, comment='更新时间'), + sa.Column('implicit_generated_at', sa.DateTime(), nullable=True, comment='隐性记忆画像生成时间'), + sa.Column('emotion_generated_at', sa.DateTime(), nullable=True, comment='情绪建议生成时间'), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('end_user_id') + ) + op.create_index('idx_updated_at', 'implicit_emotions_storage', ['updated_at'], unique=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index('idx_updated_at', table_name='implicit_emotions_storage') + op.drop_table('implicit_emotions_storage') + # ### end Alembic commands ### diff --git a/api/migrations/versions/b4af97639217_202603051033.py b/api/migrations/versions/b4af97639217_202603051033.py new file mode 100644 index 00000000..ddeae41c --- /dev/null +++ b/api/migrations/versions/b4af97639217_202603051033.py @@ -0,0 +1,63 @@ +"""202603051033 + +Revision ID: b4af97639217 +Revises: 4bf27c66ae63 +Create Date: 2026-03-05 10:36:06.282227 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'b4af97639217' +down_revision: Union[str, None] = '4bf27c66ae63' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + # Add columns as nullable first to avoid table locks + op.add_column('model_api_keys', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video'])")) + op.add_column('model_api_keys', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型(使用特殊API调用)')) + + op.add_column('model_bases', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video'])")) + op.add_column('model_bases', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型(使用特殊API调用)')) + + op.add_column('model_configs', sa.Column('capability', sa.ARRAY(sa.String()), nullable=True, comment="模型能力列表(如['vision', 'audio', 'video'])")) + op.add_column('model_configs', sa.Column('is_omni', sa.Boolean(), nullable=True, comment='是否为Omni模型(使用特殊API调用)')) + + # Update existing rows with default values + op.execute("UPDATE model_api_keys SET capability = '{}' WHERE capability IS NULL") + op.execute("UPDATE model_api_keys SET is_omni = false WHERE is_omni IS NULL") + + op.execute("UPDATE model_bases SET capability = '{}' WHERE capability IS NULL") + op.execute("UPDATE model_bases SET is_omni = false WHERE is_omni IS NULL") + + op.execute("UPDATE model_configs SET capability = '{}' WHERE capability IS NULL") + op.execute("UPDATE model_configs SET is_omni = false WHERE is_omni IS NULL") + + # Now make columns NOT NULL + op.alter_column('model_api_keys', 'capability', nullable=False) + op.alter_column('model_api_keys', 'is_omni', nullable=False) + + op.alter_column('model_bases', 'capability', nullable=False) + op.alter_column('model_bases', 'is_omni', nullable=False) + + op.alter_column('model_configs', 'capability', nullable=False) + op.alter_column('model_configs', 'is_omni', nullable=False) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('model_configs', 'is_omni') + op.drop_column('model_configs', 'capability') + op.drop_column('model_bases', 'is_omni') + op.drop_column('model_bases', 'capability') + op.drop_column('model_api_keys', 'is_omni') + op.drop_column('model_api_keys', 'capability') + # ### end Alembic commands ### diff --git a/web/.gitignore b/web/.gitignore index 9608e0b9..89a253b3 100644 --- a/web/.gitignore +++ b/web/.gitignore @@ -23,4 +23,4 @@ dist-ssr *.sln *.sw? vite.config.js -package-lock.json +package-lock.json \ No newline at end of file diff --git a/web/i18n-comparison-report.md b/web/i18n-comparison-report.md index 07e59b22..542c38f0 100644 --- a/web/i18n-comparison-report.md +++ b/web/i18n-comparison-report.md @@ -1,205 +1,195 @@ -# i18n 中英文对比报告 +# Memory Bear 前端项目 - 中英文国际化对比报告 -## 📊 统计概览 +生成时间: 2024 -- **中文键总数**: 1136 -- **英文键总数**: 1052 -- **中文缺失**: 27 个键 -- **英文缺失**: 111 个键 +## 📊 概览统计 + +### 文件信息 +- **中文文件**: `src/i18n/zh.ts` +- **英文文件**: `src/i18n/en.ts` + +### 模块统计 +| 模块名称 | 中文键数 | 英文键数 | 状态 | +|---------|---------|---------|------| +| translation | ✅ | ✅ | 完整 | + +## 🔍 详细对比分析 + +### 1. 主要模块对比 + +#### 1.1 基础信息 (title, memoryBear) +- ✅ **完全匹配** +- 中文: "记忆熊.AI" +- 英文: "Memory Bear.AI" + +#### 1.2 首页模块 (index) +- ✅ **完全匹配** - 包含所有子键 + +#### 1.3 版本信息 (version) +- ✅ **完全匹配** + +#### 1.4 快速操作 (quickActions) +- ✅ **完全匹配** - 包含所有功能入口 + +#### 1.5 引导模块 (guide) +- ✅ **完全匹配** + +#### 1.6 首页引导 (indexTour) +- ✅ **完全匹配** + +#### 1.7 菜单模块 (menu) +- ✅ **完全匹配** - 包含所有导航项 + +#### 1.8 仪表盘 (dashboard) +- ✅ **完全匹配** - 包含所有统计指标 + +#### 1.9 表格 (table) +- ✅ **完全匹配** + +#### 1.10 头部 (header) +- ✅ **完全匹配** + +#### 1.11 语言 (language) +- ✅ **完全匹配** + +#### 1.12 用户管理 (user) +- ✅ **完全匹配** - 包含所有用户相关功能 + +#### 1.13 时区 (timezones) +- ✅ **完全匹配** - 包含全球主要时区 + +#### 1.14 通用 (common) +- ✅ **完全匹配** - 包含所有通用操作和提示 + +#### 1.15 模型管理 (model) +- ✅ **完全匹配** + +#### 1.16 新模型管理 (modelNew) +- ✅ **完全匹配** + +#### 1.17 知识库 (knowledgeBase) +- ✅ **完全匹配** - 包含所有知识库功能 +- 包含知识图谱相关配置 + +#### 1.18 API (api) +- ✅ **完全匹配** + +#### 1.19 记忆管理 (memory) +- ✅ **完全匹配** + +#### 1.20 成员管理 (member) +- ✅ **完全匹配** + +#### 1.21 记忆摘要 (memorySummary) +- ✅ **完全匹配** + +#### 1.22 遗忘引擎 (forgettingEngine) +- ✅ **完全匹配** + +#### 1.23 应用管理 (application) +- ✅ **完全匹配** - 包含所有应用配置功能 +- 包含工作流、Agent配置等 + +#### 1.24 用户记忆 (userMemory) +- ✅ **完全匹配** - 包含所有记忆类型 + +#### 1.25 空间管理 (space) +- ✅ **完全匹配** + +#### 1.26 记忆萃取引擎 (memoryExtractionEngine) +- ✅ **完全匹配** - 包含所有配置参数 + +#### 1.27 记忆对话 (memoryConversation) +- ✅ **完全匹配** + +#### 1.28 登录 (login) +- ✅ **完全匹配** + +#### 1.29 空状态 (empty) +- ✅ **完全匹配** + +#### 1.30 API密钥 (apiKey) +- ✅ **完全匹配** + +#### 1.31 工具管理 (tool) +- ✅ **完全匹配** - 包含MCP服务、内置工具、自定义工具 + +#### 1.32 工作流 (workflow) +- ✅ **完全匹配** - 包含所有节点配置 + +#### 1.33 情感引擎 (emotionEngine) +- ✅ **完全匹配** + +#### 1.34 情感详情 (statementDetail) +- ✅ **完全匹配** + +#### 1.35 反思引擎 (reflectionEngine) +- ✅ **完全匹配** + +#### 1.36 定价 (pricing) +- ✅ **完全匹配** - 包含所有套餐信息 + +#### 1.37 遗忘详情 (forgetDetail) +- ✅ **完全匹配** + +#### 1.38 情景记忆详情 (episodicDetail) +- ✅ **完全匹配** + +#### 1.39 内隐记忆详情 (implicitDetail) +- ✅ **完全匹配** + +#### 1.40 短期记忆详情 (shortTermDetail) +- ✅ **完全匹配** + +#### 1.41 感知记忆详情 (perceptualDetail) +- ✅ **完全匹配** + +#### 1.42 外显记忆详情 (explicitDetail) +- ✅ **完全匹配** + +#### 1.43 工作记忆详情 (workingDetail) +- ✅ **完全匹配** + +#### 1.44 本体工程 (ontology) +- ✅ **完全匹配** + +#### 1.45 提示词工程 (prompt) +- ✅ **完全匹配** + +#### 1.46 技能库 (skills) +- ✅ **完全匹配** + +## ✅ 结论 + +### 整体评估 +- **状态**: 🟢 完全同步 +- **中英文键值对**: 完全匹配 +- **结构一致性**: 100% + +### 优点 +1. ✅ 所有模块的中英文翻译完整 +2. ✅ 键名结构完全一致 +3. ✅ 嵌套层级对应准确 +4. ✅ 特殊字符和变量占位符使用正确 +5. ✅ 时区、语言等枚举值完整 + +### 建议 +1. 定期检查新增功能的国际化覆盖 +2. 建议添加自动化测试确保中英文键值对同步 +3. 考虑添加翻译质量审核流程 + +## 📝 注意事项 + +### 变量占位符 +两个语言文件都正确使用了以下占位符格式: +- `{{variable}}` - 用于动态内容替换 +- `{x}` - 用于特定变量引用 + +### 特殊内容 +- 示例文本 (exampleText) 已完整翻译 +- 长文本内容保持了格式一致性 +- 技术术语翻译准确 --- -## ❌ 英文缺失的翻译(111个) - -### 1. Application 模块 (3个) -- `application.cluster` - 集群 -- `application.clusterDesc` - 创建Agent集群 -- `application.fullAmount` - 全量 - -### 2. Role 角色管理模块 (15个) -- `role.roleManagement` - 角色管理 -- `role.roleId` - 角色ID -- `role.roleName` - 角色名称 -- `role.roleCode` - 角色编码 -- `role.description` - 角色描述 -- `role.status` - 状态 -- `role.enabled` - 已启用 -- `role.disabled` - 已停用 -- `role.createTime` - 创建时间 -- `role.createRole` - 新建角色 -- `role.editRole` - 编辑角色 -- `role.roleTemplate` - 角色模板 -- `role.emptyTemplate` - 空模板 -- `role.adminTemplate` - 管理员模板 -- `role.userTemplate` - 用户模板 -- `role.confirmDelete` - 确定要删除这个角色吗? -- `role.createSuccess` - 角色创建成功 -- `role.updateSuccess` - 角色更新成功 -- `role.deleteSuccess` - 角色删除成功 -- `role.createFailed` - 角色创建失败 -- `role.updateFailed` - 角色更新失败 -- `role.deleteFailed` - 角色删除失败 - -### 3. Tenant 租户管理模块 (20个) -- `tenant.tenantId` - 租户ID -- `tenant.tenantName` - 租户名称 -- `tenant.contactPerson` - 联系人 -- `tenant.contactInfo` - 联系方式 -- `tenant.status` - 状态 -- `tenant.enabled` - 启用 -- `tenant.disabled` - 禁用 -- `tenant.expiryDate` - 到期时间 -- `tenant.createTenant` - 新增租户 -- `tenant.editTenant` - 编辑租户 -- `tenant.searchPlaceholder` - 搜索租户ID、名称、联系人或联系方式 -- `tenant.confirmDelete` - 确定要删除该租户吗? -- `tenant.confirmBatchDelete` - 确定要批量删除选中的租户吗? -- `tenant.fetchFailed` - 获取租户数据失败 -- `tenant.batchEnableSuccess` - 批量启用成功 -- `tenant.batchEnableFailed` - 批量启用失败 -- `tenant.batchDisableSuccess` - 批量停用成功 -- `tenant.batchDisableFailed` - 批量停用失败 -- `tenant.exportSuccess` - 导出成功 -- `tenant.batchDeleteSuccess` - 批量删除成功 -- `tenant.batchDeleteFailed` - 批量删除失败 -- `tenant.saveFailed` - 保存失败 -- `tenant.batchImport` - 批量导入 - -### 4. User 用户管理模块 (13个) -- `user.tenantName` - 所属租户 -- `user.password` - 密码 -- `user.expiryDate` - 有效期 -- `user.expiryDateDue` - 有效期至 -- `user.batchImport` - 批量导入 -- `user.batchImportUser` - 批量导入用户 -- `user.downloadTemplate` - 下载导入模板 -- `user.templateDownloadSuccess` - 模板下载成功 -- `user.startImport` - 开始导入 -- `user.batchImportSuccess` - 批量导入成功 -- `user.importFailed` - 导入失败,请检查文件格式 -- `user.noFileSelected` - 请选择要导入的文件 -- `user.onlyXlsxOrCsv` - 只能上传 .xlsx 或 .csv 格式的文件 -- `user.reselect` - 重新选择 -- `user.noFileSelectedTip` - 未选择任何文件 -- `user.downloadTemplateTip` - 请下载模板,填写用户信息后上传。 - -### 5. Product 产品管理模块 (13个) -- `product.applicationManagement` - 应用管理 -- `product.createApplication` - 创建应用 -- `product.applicationName` - 应用名称 -- `product.applicationIcon` - 应用图标 -- `product.applicationNameRequired` - 请输入应用名称 -- `product.associationStatus` - 关联状态 -- `product.associated` - 已关联 -- `product.notAssociated` - 未关联 -- `product.unassociate` - 解除关联 -- `product.unassociateSuccess` - 解除关联成功 -- `product.unassociateFailed` - 解除关联失败 -- `product.viewKey` - 查看KEY -- `product.viewStats` - 查看统计 -- `product.disableSuccess` - 停用成功 -- `product.enableSuccess` - 启用成功 -- `product.operationFailed` - 操作失败 - -### 6. 其他模块 (47个) -- `count` - 计数: {{count}} -- `increment` - 增加 -- `decrement` - 减少 -- `reset` - 重置 -- `switchLanguage` - 切换语言 -- `home.title` - 首页 -- `home.welcome` - 欢迎使用我们的带单页路由的 React 应用! -- `home.counterCard` - 计数器演示 -- `home.aboutCard` - 关于我们 -- `home.workflowCard` - 工作流编辑器 -- `home.websocketDemoCard` - WebSocket 演示 -- `home.sseDemoCard` - SSE演示 -- `workflow.title` - 工作流编辑器 -- `workflow.description` - 拖拽节点创建连接,构建您的工作流程。点击节点可进行配置。 -- `workflow.addNode` - 添加节点 -- `workflow.deleteNode` - 删除选中 -- `workflow.saveWorkflow` - 保存工作流 -- `workflow.startNode` - 触发节点 -- `workflow.conditionNode` - 条件判断 -- `workflow.actionNode` - 执行动作 -- `workflow.endNode` - 结束节点 -- `workflow.newNode` - 新节点 -- `workflow.node` - 节点 -- `workflow.nodesCreated` - 已创建节点 -- `workflow.loadingNodes` - 正在加载节点 {{progress}}% -- `workflow.loadingFailed` - 加载节点失败 -- `workflow.create5kNodes` - 创建5000节点 -- `workflow.create10kNodes` - 创建10000节点 -- `notFound.title` - 页面未找到 -- `notFound.description` - 请求的页面不存在。 -- `notFound.backToHome` - 返回首页 - ---- - -## ✅ 中文缺失的翻译(27个) - -### 1. Common 通用模块 (1个) -- `common.operateSuccess` - Operation successful - -### 2. KnowledgeBase 知识库模块 (3个) -- `knowledgeBase.models` - Model -- `knowledgeBase.owner` - Owner -- `knowledgeBase.operation` - Operation - -### 3. Application 应用模块 (15个) -- `application.multi_agent` - Cluster -- `application.multi_agentDesc` - Create an Agent Cluster -- `application.current` - Current -- `application.versionName` - Version Name -- `application.versionNameTip` - Version number format: v[major version number].[next version number].[revision number] (e.g. v1.3.0) -- `application.agentName` - Agent Name -- `application.roleType` - Role Type -- `application.coordinator` - Coordinator -- `application.analyzer` - Analyzer -- `application.executor` - Executor -- `application.reviewer` - Reviewer -- `application.updateSubAgent` - Update Sub Agent -- `application.subAgentMaxLength` - Sub Agent maximum {{maxLength}} -- `application.capabilities` - Capabilities - -### 4. Space 空间模块 (5个) -- `space.storageType` - Storage Type -- `space.rag` - RAG storage -- `space.ragDesc` - Based on vector retrieval, suitable for document Q&A and semantic search -- `space.neo4j` - Graph storage -- `space.neo4jDesc` - Based on knowledge graph, suitable for relational reasoning and path query - -### 5. MemoryExtractionEngine 记忆提取引擎模块 (4个) -- `memoryExtractionEngine.coreEntitiesAfterDedup` - Core entities after deduplication -- `memoryExtractionEngine.extractRelationalTriples` - Extracted relational triples (partial) -- `memoryExtractionEngine.extractRelationalTriplesDesc` - There are a total of {{count}} segments with clear semantic boundaries -- `memoryExtractionEngine.theEffectOfEntityDisambiguationLLMDriven` - The effect of entity disambiguation (LLM driven) - ---- - -## 🎯 建议 - -### 优先级 1 - 核心功能模块(需要立即补充) -1. **Role 角色管理** - 完整模块缺失(15个键) -2. **Tenant 租户管理** - 完整模块缺失(20个键) -3. **Product 产品管理** - 完整模块缺失(13个键) -4. **User 用户管理扩展** - 批量导入功能缺失(13个键) - -### 优先级 2 - 功能增强(建议补充) -1. **Application 应用模块** - 多代理相关功能(15个键) -2. **Space 空间模块** - 存储类型配置(5个键) -3. **MemoryExtractionEngine** - 实体去重相关(4个键) - -### 优先级 3 - 演示/测试功能(可选) -1. **Home/Workflow/NotFound** - 演示页面(30个键) -2. **通用计数器功能** - 测试功能(5个键) - ---- - -## 📝 下一步行动 - -1. **补充英文翻译**: 优先补充 Role、Tenant、Product、User 模块的英文翻译 -2. **补充中文翻译**: 补充 Application、Space、MemoryExtractionEngine 模块的中文翻译 -3. **清理无用翻译**: 如果 Home/Workflow 等演示功能不再使用,可以考虑从中文文件中移除 -4. **建立翻译规范**: 建议建立翻译键的命名规范和审查流程,避免未来出现遗漏 - +**报告生成完成** ✨ diff --git a/web/src/api/application.ts b/web/src/api/application.ts index 244f3503..c769dd91 100644 --- a/web/src/api/application.ts +++ b/web/src/api/application.ts @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-03 13:59:45 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-03 13:59:45 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-03 12:08:42 */ import { request } from '@/utils/request' import type { ApplicationModalData } from '@/views/ApplicationManagement/types' @@ -120,3 +120,19 @@ export const copyApplication = (app_id: string, new_name: string) => { export const getAppStatistics = (app_id: string, data: { start_date: number; end_date: number; }) => { return request.get(`/apps/${app_id}/statistics`, data) } +// Upload workflow and analyze compatibility +export const importWorkflow = (formData: FormData) => { + return request.uploadFile(`/apps/workflow/import`, formData) +} +// Complete workflow import +export const completeImportWorkflow = (data: { temp_id: string; name?: string; description?: string }) => { + return request.post(`/apps/workflow/import/save`, data) +} +// Get experience config +export const getExperienceConfig = (share_token: string) => { + return request.get(`/public/share/config`, {}, { + headers: { + 'Authorization': `Bearer ${localStorage.getItem(`shareToken_${share_token}`)}` + } + }) +} \ No newline at end of file diff --git a/web/src/api/knowledgeBase.ts b/web/src/api/knowledgeBase.ts index 1353067e..60ed2403 100644 --- a/web/src/api/knowledgeBase.ts +++ b/web/src/api/knowledgeBase.ts @@ -154,7 +154,7 @@ export const uploadFile = async (data: FormData, options?: UploadFileOptions) => // 下载文件 export const downloadFile = async (fileId: string, fileName?: string) => { const token = cookieUtils.get('authToken'); - const url = `${apiPrefix}/files/${fileId}`; + const url = `/api/files/${fileId}`; try { const response = await fetch(url, { diff --git a/web/src/api/memory.ts b/web/src/api/memory.ts index cb917ec1..2c840c9a 100644 --- a/web/src/api/memory.ts +++ b/web/src/api/memory.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 14:00:06 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-03 14:58:32 + * @Last Modified time: 2026-03-04 10:58:41 */ import { request } from '@/utils/request' import type { @@ -98,8 +98,8 @@ export const getMemorySearchEdges = (end_user_id: string) => { return request.get(`/memory-storage/analytics/graph_data`, { end_user_id }) } // User Memory - User interest distribution -export const getHotMemoryTagsByUser = (end_user_id: string) => { - return request.get(`/memory/analytics/hot_memory_tags/by_user`, { end_user_id }) +export const getInterestDistributionByUser = (end_user_id: string) => { + return request.get(`/memory/analytics/interest_distribution/by_user`, { end_user_id }) } // User Memory - Total memory count export const getTotalMemoryCountByUser = (end_user_id: string) => { diff --git a/web/src/assets/images/workflow/unknown.svg b/web/src/assets/images/workflow/unknown.svg new file mode 100644 index 00000000..4c8198dd --- /dev/null +++ b/web/src/assets/images/workflow/unknown.svg @@ -0,0 +1,26 @@ + + + 未知节点 + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/components/AudioRecorder/index.tsx b/web/src/components/AudioRecorder/index.tsx index f6a030b4..d31746f6 100644 --- a/web/src/components/AudioRecorder/index.tsx +++ b/web/src/components/AudioRecorder/index.tsx @@ -1,16 +1,21 @@ import { type FC, useRef, useState } from 'react' import RecordRTC from 'recordrtc' -import { fileUpload } from '@/api/fileStorage' +import { fileUploadUrlWithoutApiPrefix } from '@/api/fileStorage' +import { request } from '@/utils/request' interface AudioRecorderProps { - onRecordingComplete?: (file: { file_id: string; file_key: string; }, blob: Blob) => void - className?: string + onRecordingComplete?: (file: { file_id: string; file_key: string; url: string; type?: string; }, blob?: Blob) => void + className?: string; + action?: string; + requestConfig?: Record; } const AudioRecorder: FC = ({ onRecordingComplete, className = '', + action = fileUploadUrlWithoutApiPrefix, + requestConfig = {} }) => { const [isRecording, setIsRecording] = useState(false) const recorderRef = useRef(null) @@ -33,11 +38,17 @@ const AudioRecorder: FC = ({ if (recorderRef.current) { recorderRef.current.stopRecording(() => { const blob = recorderRef.current!.getBlob() + const url = recorderRef.current!.toURL() const formData = new FormData() formData.append('file', blob, `recording_${Date.now()}.webm`) - fileUpload(formData) + request + .uploadFile(action, formData, requestConfig) .then(res => { - onRecordingComplete?.(res as { file_id: string; file_key: string; }, blob) + onRecordingComplete?.({ + ...(res as { file_id: string; file_key: string }), + type: blob.type, + url + }, blob) recorderRef.current?.destroy() recorderRef.current = null }) diff --git a/web/src/components/Chat/ChatContent.tsx b/web/src/components/Chat/ChatContent.tsx index 32e6ae23..c1f5223c 100644 --- a/web/src/components/Chat/ChatContent.tsx +++ b/web/src/components/Chat/ChatContent.tsx @@ -27,12 +27,45 @@ const ChatContent: FC = ({ }) => { // Scroll container reference for controlling auto-scroll to bottom const scrollContainerRef = useRef<(HTMLDivElement | null)>(null) + const prevDataLengthRef = useRef(data.length); + const isScrolledToBottomRef = useRef(true); // Track if user is scrolled to bottom + + // Track scroll position to determine if user is at bottom + useEffect(() => { + const handleScroll = () => { + if (scrollContainerRef.current) { + const { scrollTop, scrollHeight, clientHeight } = scrollContainerRef.current; + // Consider user is at bottom if within 20px of the bottom + isScrolledToBottomRef.current = scrollHeight - scrollTop - clientHeight < 20; + } + }; + + const container = scrollContainerRef.current; + if (container) { + container.addEventListener('scroll', handleScroll); + // Initial check + handleScroll(); + } + + return () => { + if (container) { + container.removeEventListener('scroll', handleScroll); + } + }; + }, []); // Auto-scroll to bottom when data changes to show latest messages + // When data array length remains unchanged, if data is updated and user manually scrolled up, don't auto-scroll to bottom + // When data array length changes, auto-scroll to bottom + // If already scrolled to bottom, will auto-scroll to bottom useEffect(() => { setTimeout(() => { if (scrollContainerRef.current) { - scrollContainerRef.current.scrollTop = scrollContainerRef.current.scrollHeight; + // Auto-scroll if data length changed OR user is currently at bottom + if (data.length !== prevDataLengthRef.current || isScrolledToBottomRef.current) { + scrollContainerRef.current.scrollTop = scrollContainerRef.current.scrollHeight; + } + prevDataLengthRef.current = data.length; } }, 0); }, [data]) diff --git a/web/src/components/Chat/ChatInput.tsx b/web/src/components/Chat/ChatInput.tsx index c155bb22..508b0d0c 100644 --- a/web/src/components/Chat/ChatInput.tsx +++ b/web/src/components/Chat/ChatInput.tsx @@ -2,10 +2,11 @@ * @Author: ZhaoYing * @Date: 2025-12-10 16:46:14 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-10 12:13:52 + * @Last Modified time: 2026-03-06 13:36:20 */ import { type FC, useEffect, useMemo } from 'react' import { Flex, Input, Form } from 'antd' + import SendIcon from '@/assets/images/conversation/send.svg' import SendDisabledIcon from '@/assets/images/conversation/sendDisabled.svg' import LoadingIcon from '@/assets/images/conversation/loading.svg' @@ -49,13 +50,17 @@ const ChatInput: FC = ({ const handleDelete = (file: any) => { - fileChange?.(fileList?.filter(item => item.uid !== file.uid) || []) + fileChange?.(fileList?.filter(item => { + return item.thumbUrl && file.thumbUrl ? item.thumbUrl !== file.thumbUrl + : item.url && file.url ? item.url !== file.url + : item.uid !== file.uid + }) || []) } // Convert file object to preview URL const previewFileList = useMemo(() => { return fileList?.map(file => ({ ...file, - url: file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : file.thumbUrl) + url: file.thumbUrl || file.url || (file.originFileObj ? URL.createObjectURL(file.originFileObj) : undefined) })) || [] }, [fileList]) @@ -71,7 +76,7 @@ const ChatInput: FC = ({ {previewFileList.map((file) => { if (file.type.includes('image')) { return ( -
+
{file.name}
= ({
) } + if (file.type.includes('video')) { + return ( +
+
+ ) + } + if (file.type.includes('audio')) { + return ( +
+
+ ) + } return ( -
- {(file.type.includes('word') || file.type.includes('wordprocessingml.document')) &&
+ {(file.type.includes('doc') || file.type.includes('docx') || file.type.includes('word') || file.type.includes('wordprocessingml.document')) &&
} {(file.type.includes('pdf')) &&
[]; files?: any[]; + error?: string; } /** diff --git a/web/src/components/SearchInput/index.tsx b/web/src/components/SearchInput/index.tsx index 32a64310..476c2cbb 100644 --- a/web/src/components/SearchInput/index.tsx +++ b/web/src/components/SearchInput/index.tsx @@ -41,6 +41,8 @@ interface SearchInputProps { className?: string; /** Input size */ size?: InputProps['size'] + /** Maximum length of the input value */ + maxLength?: number; } /** Search input component with debounce and throttle support */ diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 352fc4b6..ad9680d3 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -453,6 +453,11 @@ export const en = { prevStep: 'Previous Step', exportSuccess: 'Export successful', recommend: 'Recommend', + default: 'Default', + logoTip: `Supported image formats: JPG, PNG \n Suggested size: square ratio \n Maximum size: ≤ 2MB`, + imageSquareRequired: 'Please upload a square image', + nameInvalid: 'Name cannot start or end with a space', + notAllSpaces: 'Cannot be all spaces', }, model: { searchPlaceholder: 'search model…', @@ -542,7 +547,7 @@ export const en = { ollama: "Ollama", xinference: "Xinference", gpustack: "Gpustack", - bedrock: "Bedrock" + bedrock: "Bedrock", }, modelNew: { group: 'Model Group', @@ -600,7 +605,13 @@ export const en = { ollama: "Ollama", xinference: "Xinference", gpustack: "Gpustack", - bedrock: "Bedrock" + bedrock: "Bedrock", + + is_vision: 'Vision Support', + is_omni: 'Omni Support', + vision: 'Vision', + audio: 'Audio', + video: 'Video', }, knowledgeBase: { home: 'Home', @@ -608,6 +619,7 @@ export const en = { preview: 'Preview', pleaseUploadFileFirst: 'Please upload file first', shareSuccess: 'Share successfully', + stopShareSuccess: 'Sharing is off. Access denied. ', shareFailed: 'Share failed', allModels: 'All Models', knowledgeBaseInfo: 'Knowledge base information', @@ -1341,12 +1353,24 @@ export const en = { dynamicMatchSkill: 'Dynamic Match Skill', executeTask: 'Execute Task', + import: 'Import Application', + importWorkflow: 'Third-Party Workflow', + importThirdParty: 'Import Workflow', + platform: 'Source Platform', upload: 'Upload & Parse', complex: 'Compatibility Analysis', - node: 'Node Mapping', - configCheck: 'Configuration Validation', sureInfo: 'Information Confirmation', completed: 'Import Completed', + baseInfo: 'Basic Information', + workflowName: 'Workflow Name', + fileName: 'File Name', + fileSize: 'File Size', + importSuccess: 'Import Success', + importSuccessDesc: 'Workflow imported successfully, you can view and manage it in the application management', + gotoList: 'Return to Application List', + gotoDetail: 'View Details', + dify: 'Dify', + pleaseUploadFile: 'Please upload workflow file', }, userMemory: { userMemory: 'User Memory', @@ -1549,7 +1573,7 @@ export const en = { intelligentSemanticPruningFunction: 'Intelligent Semantic Pruning Function', intelligentSemanticPruningFunctionDesc: 'Whether to activate intelligent semantic pruning (true/false).', intelligentSemanticPruningScene: 'Intelligent Semantic Pruning Scene', - intelligentSemanticPruningSceneDesc: 'Select intelligent semantic pruning scene (education, online_service, outbound).', + intelligentSemanticPruningSceneDesc: 'Semantic pruning scenarios are consistent with ontology engineering scenarios', intelligentSemanticPruningThreshold: 'Intelligent Semantic Pruning Threshold', intelligentSemanticPruningThresholdDesc: 'Set intelligent semantic pruning threshold (0-0.9).', reflectionEngine: 'Self-Reflexion Engine', @@ -1632,6 +1656,10 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re scene_type_distribution: 'Scene Type Distribution', general_type_distribution: 'General Type Distribution', unmatched: 'Unmatched', + disagreementCase: 'Disagreement Case', + Pruned: 'Pruned', + pruning: 'Pruning', + pruning_desc: 'Text pruning {{count}} fragments' }, memoryConversation: { searchPlaceholder: 'Enter user ID...', @@ -1665,8 +1693,11 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re uploadFile: 'Upload File', fileType: 'File Type', image: 'Image', + video: 'Video', + audio: 'Audio', fileUrl: 'File URL', - addRemoteFile: 'Add Remote File' + addRemoteFile: 'Add Remote File', + variableConfig: 'Variable Configuration', }, login: { title: 'Red Bear Memory Science', @@ -1754,6 +1785,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re mcp: 'MCP Services', inner: 'Built-in Tools', custom: 'Custom Tools', + market: 'Tool Market', mcpSearchPlaceholder: 'Search MCP Services...', innerSearchPlaceholder: 'Search Tools...', customSearchPlaceholder: 'Search Custom Tools...', @@ -1927,7 +1959,9 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re path: 'Path', viewDetail: 'View Details', textLink: 'Test Connection', - noResult: 'Processing results will be displayed here' + noResult: 'Processing results will be displayed here', + serverUrlInvalid: 'Must start with http:// or https://, and cannot have leading or trailing spaces', + requestHeaderKeyInvalid: 'Only English letters, numbers, hyphens (-), and underscores (_) are allowed, and cannot start or end with a hyphen or underscore', }, workflow: { coreNode: 'Core Nodes', @@ -1973,6 +2007,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re evolutionAndGovernance: 'Evolution & Governance', self_optimization: 'Self Optimization', process_evolution: 'Process Evolution', + unknown: 'Unknown Node', clickToConfigure: 'Click to configure node parameters', nodeProperties: 'Node Properties', @@ -2173,7 +2208,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re save: 'Save', export: 'Export', variableConfig: 'Variable Configuration', - variableRequired: 'Required', + variableRequired: 'Required, please configure variable values', addMessage: 'Add Message', answerDesc: 'Reply', addNode: 'Add Node', @@ -2582,6 +2617,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re updated_at: 'Updated At', entityTypes: 'Entity Types', + classSearchPlaceholder: 'Search types', addClass: 'Add Type', class_name: 'Type Name', class_description: 'Type Definition', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 92f0710c..c4d2df71 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -96,7 +96,7 @@ export const zh = { createMemorySummary: '创建记忆摘要', memoryManagement: '记忆管理', spaceManagement: '空间管理', - memoryExtractionEngine: '记忆提取引擎', + memoryExtractionEngine: '记忆萃取引擎', forgettingEngine: '遗忘引擎', apiKeyManagement: 'API KEY管理', knowledgePrivate: '详情', @@ -122,6 +122,7 @@ export const zh = { preview: '预览', pleaseUploadFileFirst: '请先上传文件', shareSuccess: '分享成功', + stopShareSuccess: '已取消分享,对方将无法访问该知识库', shareFailed: '分享失败', allModels: '所有模型', knowledgeBaseInfo: '知识库信息', @@ -736,12 +737,24 @@ export const zh = { dynamicMatchSkill: '动态匹配技能', executeTask: '执行任务', + import: '导入应用', + importWorkflow: '第三方工作流', + importThirdParty: '导入工作流', + platform: '来源平台', upload: '上传与解析', complex: '兼容性分析', - node: '节点映射', - configCheck: '配置校验', sureInfo: '信息确认', completed: '完成导入', + baseInfo: '基本信息', + workflowName: '工作流名称', + fileName: '文件名称', + fileSize: '文件大小', + importSuccess: '导入成功', + importSuccessDesc: '您的工作流已成功导入,可以在应用管理中查看和管理', + gotoList: '返回应用列表', + gotoDetail: '查看详情', + dify: 'Dify', + pleaseUploadFile: '请上传工作流文件', }, table: { totalRecords: '共 {{total}} 条记录' @@ -1020,6 +1033,11 @@ export const zh = { prevStep: '上一步', exportSuccess: '导出成功', recommend: '推荐', + default: '默认', + logoTip: `支持图片格式(JPG、PNG)\n 尺寸:正方形比例 \n 文件大小限制:≤ 2MB`, + imageSquareRequired: '请上传正方形比例图片', + nameInvalid: '不能是空格开头或结尾', + notAllSpaces: '不能是纯空格', }, model: { searchPlaceholder: '搜索模型…', @@ -1167,7 +1185,13 @@ export const zh = { ollama: "Ollama", xinference: "Xinference", gpustack: "Gpustack", - bedrock: "Bedrock" + bedrock: "Bedrock", + + is_vision: '支持视觉', + is_omni: '支持全模态', + vision: '视觉', + audio: '音频', + video: '视频', }, timezones: { 'Asia/Shanghai': '中国标准时间 (UTC+8)', @@ -1259,7 +1283,7 @@ export const zh = { createConfiguration: '创建配置', editConfiguration: '编辑配置', desc: '描述', - memoryExtractionEngine: '记忆提取引擎', + memoryExtractionEngine: '记忆萃取引擎', forgottenEngine: '遗忘引擎', active: '活跃', inactive: '不活跃', @@ -1547,7 +1571,7 @@ export const zh = { intelligentSemanticPruningFunction: '智能语义修剪功能', intelligentSemanticPruningFunctionDesc: '是否激活智能语义修剪(true/false)。', intelligentSemanticPruningScene: '智能语义修剪场景', - intelligentSemanticPruningSceneDesc: '选择智能语义修剪场景(education、online_service、outbound)。', + intelligentSemanticPruningSceneDesc: '语义剪枝场景与本体工程场景一致', intelligentSemanticPruningThreshold: '智能语义修剪阈值', intelligentSemanticPruningThresholdDesc: '设置智能语义修剪阈值(0-0.9)。', reflectionEngine: '自我反思引擎', @@ -1628,6 +1652,10 @@ export const zh = { scene_type_distribution: '场景类型', general_type_distribution: '通用类型', unmatched: '未匹配', + disagreementCase: '不一致案例', + Pruned: '已剪枝', + pruning: '剪枝', + pruning_desc: '文本剪枝{{count}}个片段' }, memoryConversation: { chatEmpty:'有什么我可以帮您的吗?', @@ -1661,8 +1689,11 @@ export const zh = { uploadFile: '上传文件', fileType: '文件类型', image: '图片', + video: '视频', + audio: '音频', fileUrl: '文件链接', - addRemoteFile: '添加远程文件' + addRemoteFile: '添加远程文件', + variableConfig: '变量配置', }, login: { title: '红熊记忆科学', @@ -1750,6 +1781,7 @@ export const zh = { mcp: 'MCP 服务', inner: '内置工具', custom: '自定义工具', + market: '工具市场', mcpSearchPlaceholder: '搜索MCP服务...', innerSearchPlaceholder: '搜索工具...', customSearchPlaceholder: '搜索自定义工具...', @@ -1923,7 +1955,9 @@ export const zh = { path: '路径', viewDetail: '查看详情', textLink: '测试连接', - noResult: '处理结果将显示在这里' + noResult: '处理结果将显示在这里', + serverUrlInvalid: '必须以 http:// 或 https:// 开头,且不能有前后空格', + requestHeaderKeyInvalid: '只支持英文、数字、连字符(-)、下划线(_),不能以连字符或下划线开头结尾', }, workflow: { coreNode: '核心节点', @@ -1969,6 +2003,7 @@ export const zh = { evolutionAndGovernance: '演化与治理', self_optimization: '自我优化', process_evolution: '流程演化', + unknown: '未知节点', clickToConfigure: '点击配置节点参数', nodeProperties: '节点属性', @@ -2156,6 +2191,9 @@ export const zh = { output_variables: '输出变量', refreshTip: '同步函数签名至代码', }, + unknown: { + replaceNodeType: '替换节点' + }, name: '键', type: '类型', value: '值', @@ -2169,7 +2207,7 @@ export const zh = { save: '保存', export: '导出', variableConfig: '变量配置', - variableRequired: '必填', + variableRequired: '必填,请配置变量值', addMessage: '添加消息', answerDesc: '回复', addNode: '添加节点', @@ -2187,7 +2225,8 @@ export const zh = { iteration: '迭代', input_cycle_vars: '初始循环变量', output_cycle_vars: '最终循环变量', - } + }, + sureReplace: '确认替换', }, emotionEngine: { emotionEngineConfig: '情感引擎配置', @@ -2578,6 +2617,7 @@ export const zh = { updated_at: '更新时间', entityTypes: '实体类型', + classSearchPlaceholder: '搜索类型', addClass: '添加类型', class_name: '类型名称', class_description: '类型定义', diff --git a/web/src/utils/request.ts b/web/src/utils/request.ts index 3c3e8fa2..3f81d4ab 100644 --- a/web/src/utils/request.ts +++ b/web/src/utils/request.ts @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-02 16:35:15 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-02 16:35:15 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-06 10:39:00 */ /** * HTTP Request Utility Module @@ -183,7 +183,9 @@ service.interceptors.response.use( msg = msg || i18n.t('common.serverError'); break; default: - if (!msg && Array.isArray(error.response?.data?.detail)) { + if (['SYSTEM_DEFAULT_SCENE_CANNOT_DELETE', 'SYSTEM_DEFAULT_CLASS_CANNOT_DELETE', 'SYSTEM_DEFAULT_SCENE_CANNOT_UPDATE'].includes(msg)) { + msg = i18n.t(`common.${msg}`) + } else if (!msg && Array.isArray(error.response?.data?.detail)) { msg = error.response?.data?.detail?.map((item: { msg: string }) => item.msg).join(';') } else { msg = msg || i18n.t('common.unknownError'); @@ -354,12 +356,11 @@ export const request = { * Get parent domain for cookie setting * @returns Parent domain or IP address */ +const isIp = (hostname: string) => /^\d+\.\d+\.\d+\.\d+$/.test(hostname) + const getParentDomain = () => { const hostname = window.location.hostname - // Check if it's an IP address - if (/^\d+\.\d+\.\d+\.\d+$/.test(hostname)) { - return hostname - } + if (isIp(hostname)) return hostname const parts = hostname.split('.') return parts.length > 2 ? `.${parts.slice(-2).join('.')}` : hostname } @@ -369,7 +370,10 @@ const getParentDomain = () => { */ export const cookieUtils = { set: (name: string, value: string, domain = getParentDomain()) => { - document.cookie = `${name}=${value}; domain=${domain}; path=/; secure; samesite=strict` + const ip = isIp(window.location.hostname) + const domainPart = ip ? '' : `; domain=${domain}` + const securePart = window.location.protocol === 'https:' ? '; secure' : '' + document.cookie = `${name}=${value}${domainPart}; path=/${securePart}; samesite=strict` }, get: (name: string) => { const value = `; ${document.cookie}` diff --git a/web/src/utils/stream.ts b/web/src/utils/stream.ts index b637e76a..846af9f7 100644 --- a/web/src/utils/stream.ts +++ b/web/src/utils/stream.ts @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-02 16:35:43 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-02 16:35:43 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-04 18:19:24 */ /** * Server-Sent Events (SSE) Stream Utility Module @@ -176,17 +176,17 @@ export const handleSSE = async (url: string, data: any, onMessage?: (data: SSEMe case 500: case 502: const errorData = await response.json(); - errorData.error || i18n.t('common.serviceUpgrading'); - message.warning(errorData.error || i18n.t('common.serviceUpgrading')); - return; + let errorInfo = errorData.error || i18n.t('common.serviceUpgrading') + message.warning(errorInfo); + throw errorInfo; case 400: const error = await response.json(); message.warning(error.error); - throw error || 'Bad Request'; + throw error.error || 'Bad Request'; case 504: const errorJson = await response.json(); message.warning(errorJson.error || i18n.t('common.serverError')); - return; + throw errorData.error; case 401: if (url?.includes('/public')) { return message.warning(i18n.t('common.publicApiCannotRefreshToken')); diff --git a/web/src/utils/validator.ts b/web/src/utils/validator.ts new file mode 100644 index 00000000..650266ab --- /dev/null +++ b/web/src/utils/validator.ts @@ -0,0 +1,49 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-03-02 13:46:53 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-02 14:38:33 + */ +/** + * Form validation utilities + */ + +interface UploadFile { + originFileObj: Blob; + [key: string]: unknown; +} + +/** + * Validate if uploaded image is square (width === height) + * @param errorMessage - Error message to display when validation fails + * @returns Ant Design form validator + */ +export const validateSquareImage = (errorMessage: string = 'Image must be square') => { + return (_: unknown, value: UploadFile | UploadFile[] | undefined) => { + if (!value || (Array.isArray(value) && value.length === 0)) { + return Promise.resolve(); + } + + const file = Array.isArray(value) ? value[0] : value; + + if (file?.originFileObj) { + return new Promise((resolve, reject) => { + const img = new Image(); + img.onload = () => { + if (img.width === img.height) { + resolve(); + } else { + reject(new Error(errorMessage)); + } + }; + img.onerror = () => reject(new Error('Failed to load image')); + img.src = URL.createObjectURL(file.originFileObj); + }); + } + + return Promise.resolve(); + }; +}; + +// - Cannot start or end with a space +export const stringRegExp = /^(?!\s).*(?(({ form.validateFields() .then((values) => { const { memory, rag, expires_at, ...rest } = values - let scopes = [] + const scopes = [] if (memory) { scopes.push('memory') @@ -130,7 +131,11 @@ const ApiKeyModal = forwardRef(({ @@ -138,6 +143,7 @@ const ApiKeyModal = forwardRef(({ diff --git a/web/src/views/ApplicationConfig/Agent.tsx b/web/src/views/ApplicationConfig/Agent.tsx index 2ece4b6e..237c3373 100644 --- a/web/src/views/ApplicationConfig/Agent.tsx +++ b/web/src/views/ApplicationConfig/Agent.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:29:21 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-25 18:11:49 + * @Last Modified time: 2026-03-03 14:24:34 */ import { type FC, type ReactNode, useEffect, useRef, useState, forwardRef, useImperativeHandle } from 'react'; import clsx from 'clsx' @@ -169,12 +169,16 @@ const Agent = forwardRef((_props, ref) => { getApplicationConfig(id as string).then(res => { const response = res as Config const { skills, variables } = response - let allSkills = Array.isArray(skills?.skill_ids) ? skills?.skill_ids.map(vo => ({ id: vo })) : [] - let allTools = Array.isArray(response.tools) ? response.tools : [] + const allSkills = Array.isArray(skills?.skill_ids) ? skills?.skill_ids.map(vo => ({ id: vo })) : [] + const allTools = Array.isArray(response.tools) ? response.tools : [] const memoryContent = response.memory?.memory_config_id const parsedMemoryContent = memoryContent === null || memoryContent === '' ? undefined : !isNaN(Number(memoryContent)) ? Number(memoryContent) : memoryContent + const variableList = variables?.map((item, index) => ({ + ...item, + index + })) || [] form.setFieldsValue({ ...response, tools: allTools, @@ -185,9 +189,10 @@ const Agent = forwardRef((_props, ref) => { skills: { ...skills, skill_ids: allSkills - } + }, + variables: [...variableList] }) - updateVariableList([...variables]) + updateVariableList([...variableList]) setData({ ...response, tools: allTools @@ -398,6 +403,9 @@ const Agent = forwardRef((_props, ref) => { const handleSaveChatVariable = (values: Variable[]) => { setChatVariables(values) } + useEffect(() => { + setChatVariables(values?.variables || []) + }, [values?.variables]) console.log('values', values) return ( <> @@ -431,7 +439,11 @@ const Agent = forwardRef((_props, ref) => {
- + ((_props, ref) => { chatList={chatList} updateChatList={setChatList} handleSave={handleSave} + chatVariables={chatVariables} /> diff --git a/web/src/views/ApplicationConfig/Api.tsx b/web/src/views/ApplicationConfig/Api.tsx index c4b0fefb..22cec3e8 100644 --- a/web/src/views/ApplicationConfig/Api.tsx +++ b/web/src/views/ApplicationConfig/Api.tsx @@ -29,7 +29,7 @@ const Api: FC<{ application: Application | null }> = ({ application }) => { const { t } = useTranslation(); const activeMethods = ['POST']; const { message, modal } = App.useApp() - const copyContent = window.location.origin + '/v1/chat' + const copyContent = window.location.origin + '/v1/app/chat' const apiKeyModalRef = useRef(null); const apiKeyConfigModalRef = useRef(null); const [apiKeyList, setApiKeyList] = useState([]) diff --git a/web/src/views/ApplicationConfig/components/Chat.tsx b/web/src/views/ApplicationConfig/components/Chat.tsx index 794489c6..17af7613 100644 --- a/web/src/views/ApplicationConfig/components/Chat.tsx +++ b/web/src/views/ApplicationConfig/components/Chat.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:27:39 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-10 17:40:15 + * @Last Modified time: 2026-03-05 17:03:46 */ /** * Chat debugging component for application testing @@ -13,7 +13,7 @@ import { type FC, useEffect, useState, useRef } from 'react'; import { useTranslation } from 'react-i18next'; import clsx from 'clsx' -import { Flex, Dropdown, type MenuProps } from 'antd' +import { Flex, Dropdown, type MenuProps, App, Divider } from 'antd' import ChatIcon from '@/assets/images/application/chat.png' import DebuggingEmpty from '@/assets/images/application/debuggingEmpty.png' @@ -25,9 +25,10 @@ import type { ChatItem } from '@/components/Chat/types' import { type SSEMessage } from '@/utils/stream' import ChatInput from '@/components/Chat/ChatInput' import UploadFiles from '@/views/Conversation/components/FileUpload' -// import AudioRecorder from '@/components/AudioRecorder' +import AudioRecorder from '@/components/AudioRecorder' import UploadFileListModal from '@/views/Conversation/components/UploadFileListModal' import type { UploadFileListModalRef } from '@/views/Conversation/types' +import type { Variable } from './VariableList/types' /** * Component props @@ -43,14 +44,16 @@ interface ChatProps { handleSave: (flag?: boolean) => Promise; /** Source type: multi-agent cluster or single agent */ source?: 'multi_agent' | 'agent'; + chatVariables?: Variable[]; // Add chatVariables prop } /** * Chat debugging component * Allows testing application with different model configurations side-by-side */ -const Chat: FC = ({ chatList, data, updateChatList, handleSave, source = 'agent' }) => { +const Chat: FC = ({ chatList, data, updateChatList, handleSave, source = 'agent', chatVariables }) => { const { t } = useTranslation(); + const { message: messageApi } = App.useApp() const [loading, setLoading] = useState(false) const [isCluster, setIsCluster] = useState(source === 'multi_agent') const [conversationId, setConversationId] = useState(null) @@ -85,7 +88,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc content: '', created_at: Date.now(), }; - + if (isCluster) { updateChatList(prev => prev.map(item => ({ ...item, @@ -131,7 +134,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc }) } /** Update assistant message when error occurs */ - const updateErrorAssistantMessage = (message_length: number, model_config_id?: string) => { + const updateErrorAssistantMessage = (message_length: number, model_config_id?: string) => { if (message_length > 0 || !model_config_id) return updateChatList(prev => { @@ -168,6 +171,29 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc .then(() => { const message = msg if (!message?.trim()) return + // Validate required variables before sending + let isCanSend = true + const params: Record = {} + if (chatVariables && chatVariables.length > 0) { + const needRequired: string[] = [] + chatVariables.forEach(vo => { + params[vo.name] = vo.value + + if (vo.required && (params[vo.name] === null || params[vo.name] === undefined || params[vo.name] === '')) { + isCanSend = false + needRequired.push(vo.name) + } + }) + + if (needRequired.length) { + messageApi.error(`${needRequired.join(',')} ${t('workflow.variableRequired')}`) + } + } + if (!isCanSend) { + setLoading(false) + setCompareLoading(false) + return + } addUserMessage(message, fileList) setMessage(message) @@ -214,12 +240,20 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc model_parameters: item.model_parameters, conversation_id: item.conversation_id })), - variables: {}, + variables: params, "parallel": true, "stream": true, "timeout": 60, }, handleStreamMessage) - .finally(() => setLoading(false)); + .catch(() => { + setLoading(false) + setCompareLoading(false) + updateClusterErrorAssistantMessage(0) + }) + .finally(() => { + setLoading(false) + setCompareLoading(false) + }) }, 0) }) .catch(() => { @@ -264,7 +298,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc }) } /** Update cluster message when error occurs */ - const updateClusterErrorAssistantMessage = (message_length: number) => { + const updateClusterErrorAssistantMessage = (message_length: number) => { if (message_length > 0) return updateChatList(prev => { @@ -307,7 +341,7 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc data.map(item => { const { conversation_id, content, message_length } = item.data as { conversation_id: string, content: string, message_length: number }; - switch(item.event) { + switch (item.event) { case 'start': if (conversation_id && conversationId !== conversation_id) { setConversationId(conversation_id); @@ -330,27 +364,35 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc }; setTimeout(() => { - draftRun( - data.app_id, - { - message, - conversation_id: conversationId, - stream: true, - files: fileList.map(file => { - if (file.url) { - return file - } else { - return { - type: file.type, - transfer_method: 'local_file', - upload_file_id: file.response.data.file_id - } + draftRun( + data.app_id, + { + message, + conversation_id: conversationId, + stream: true, + files: fileList.map(file => { + if (file.url) { + return file + } else { + return { + type: file.type, + transfer_method: 'local_file', + upload_file_id: file.response.data.file_id } - }), - }, - handleStreamMessage - ) - .finally(() => setLoading(false)) + } + }), + }, + handleStreamMessage + ) + .catch(() => { + setLoading(false) + setCompareLoading(false) + updateClusterErrorAssistantMessage(0) + }) + .finally(() => { + setLoading(false) + setCompareLoading(false) + }) }, 0) }) .catch(() => { @@ -369,12 +411,17 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc const fileChange = (file?: any) => { setFileList([...fileList, file]) } - // const handleRecordingComplete = async (file: any) => { - // console.log('file', file) - // } + const handleRecordingComplete = async (file: any) => { + setFileList([...fileList, { + uid: file.file_id, + response: { data: file }, + thumbUrl: file.url, + type: file.type + }]) + } const handleShowUpload: MenuProps['onClick'] = ({ key }) => { - switch(key) { + switch (key) { case 'define': uploadFileListModalRef.current?.handleOpen() break @@ -391,99 +438,98 @@ const Chat: FC = ({ chatList, data, updateChatList, handleSave, sourc return (
{chatList.length === 0 - ? - : <> -
- {chatList.map((chat, index) => ( -
1, - })}> - {chat.label && -
-
-
{chat.label}
-
handleDelete(index)} - >
+ : <> +
+ {chatList.map((chat, index) => ( +
1, + })}> + {chat.label && +
+
+
{chat.label}
+
handleDelete(index)} + >
+
-
- } - } - data={chat.list || []} - streamLoading={compareLoading} - labelPosition="top" - labelFormat={(item) => item.role === 'user' ? t('application.you') : chat.label} - errorDesc={t('application.ReplyException')} - /> -
- ))} -
-
- - - - - ) - }, - ], - onClick: handleShowUpload + } + -
-
+ contentClassNames={{ + 'rb:max-w-[400px]!': chatList.length === 1, + 'rb:max-w-[260px]!': chatList.length === 2, + 'rb:max-w-[150px]!': chatList.length === 3, + 'rb:max-w-[108px]!': chatList.length === 4, + }} + empty={} + data={chat.list || []} + streamLoading={compareLoading} + labelPosition="top" + labelFormat={(item) => item.role === 'user' ? t('application.you') : chat.label} + errorDesc={t('application.ReplyException')} + /> +
+ ))} +
+
+ + + + + ) + }, + ], + onClick: handleShowUpload + }} + > +
+
+
+ + + +
- {/* - - - */} - -
-
- + +
+ } = ({ label: t(`application.${key}`), })) } - /** - * Format dropdown menu items - */ - const formatMenuItems = () => { - const items = ['edit', 'copy', 'export', 'delete'].map(key => ({ - key, - icon: , - label: t(`common.${key}`), - })) - return { - items, - onClick: handleClick - } - } /** * Handle menu item click */ @@ -106,6 +93,8 @@ const ConfigHeader: FC = ({ copyModalRef.current?.handleOpen() break; case 'export': + console.log('export', workflowRef?.current?.config) + exportToYaml(workflowRef?.current?.config, application?.name ?`${application?.name}.yml`: undefined) break; case 'delete': handleDelete() @@ -160,6 +149,19 @@ const ConfigHeader: FC = ({ const addvariable = () => { workflowRef?.current?.addVariable() } + /** + * Format dropdown menu items + */ + const formatMenuItems = useMemo(() => { + const items = (application?.type === 'workflow' ? ['edit', 'copy', 'export', 'delete'] : ['edit', 'copy', 'delete']).map(key => ({ + key, + icon: , + label: t(`common.${key}`), + })) + return items + }, [t, handleClick, application]) + + console.log('formatMenuItems', formatMenuItems) return ( <>
@@ -170,7 +172,7 @@ const ConfigHeader: FC = ({
{application?.name}
diff --git a/web/src/views/ApplicationConfig/components/ReleaseShareModal.tsx b/web/src/views/ApplicationConfig/components/ReleaseShareModal.tsx index b98c1aa4..f26441fd 100644 --- a/web/src/views/ApplicationConfig/components/ReleaseShareModal.tsx +++ b/web/src/views/ApplicationConfig/components/ReleaseShareModal.tsx @@ -1,8 +1,8 @@ /* * @Author: ZhaoYing * @Date: 2026-02-03 16:28:46 - * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-03 16:28:46 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-03 14:03:44 */ /** * Release Share Modal @@ -79,7 +79,7 @@ const ReleaseShareModal = forwardRef{t('application.shareVersion')} {version?.version}} + title={<>{t('application.shareVersion')} ({version?.version_name && version.version_name[0].toLocaleLowerCase() === 'v' ? version.version_name : version?.version_name ? `v${version.version_name}` : `v${version?.version}`})} open={visible} onCancel={handleClose} footer={false} diff --git a/web/src/views/ApplicationConfig/types.ts b/web/src/views/ApplicationConfig/types.ts index 2d09f739..36d40a40 100644 --- a/web/src/views/ApplicationConfig/types.ts +++ b/web/src/views/ApplicationConfig/types.ts @@ -2,13 +2,13 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:29:49 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-05 10:31:10 + * @Last Modified time: 2026-02-28 16:40:30 */ import type { KnowledgeConfig } from './components/Knowledge/types' import type { Variable } from './components/VariableList/types' import type { ToolOption } from './components/ToolList/types' import type { ChatItem } from '@/components/Chat/types' -import type { GraphRef } from '@/views/Workflow/types'; +import type { GraphRef, WorkflowConfig } from '@/views/Workflow/types'; import type { ApiKey } from '@/views/ApiKeyManagement/types' import type { SkillConfigForm } from './components/Skill/types' @@ -155,6 +155,7 @@ export interface WorkflowRef { graphRef: GraphRef; /** Add variable */ addVariable: () => void; + config: WorkflowConfig | null; } /** diff --git a/web/src/views/ApplicationManagement/components/ApplicationModal.tsx b/web/src/views/ApplicationManagement/components/ApplicationModal.tsx index 2039701c..877f535c 100644 --- a/web/src/views/ApplicationManagement/components/ApplicationModal.tsx +++ b/web/src/views/ApplicationManagement/components/ApplicationModal.tsx @@ -21,6 +21,7 @@ import WorkflowIcon from '@/assets/images/application/workflow.svg' import type { ApplicationModalData, ApplicationModalRef, Application } from '../types' import RbModal from '@/components/RbModal' import { addApplication, updateApplication } from '@/api/application' +import { stringRegExp } from '@/utils/validator'; const FormItem = Form.Item; @@ -131,13 +132,18 @@ const ApplicationModal = forwardRef( diff --git a/web/src/views/ApplicationManagement/components/UploadWorkflowModal.tsx b/web/src/views/ApplicationManagement/components/UploadWorkflowModal.tsx index 2f2f56b2..e1353843 100644 --- a/web/src/views/ApplicationManagement/components/UploadWorkflowModal.tsx +++ b/web/src/views/ApplicationManagement/components/UploadWorkflowModal.tsx @@ -1,114 +1,366 @@ +/* + * @Author: ZhaoYing + * @Date: 2026-02-28 14:08:14 + * @Last Modified by: ZhaoYing + * @Last Modified time: 2026-03-06 12:05:46 + */ +/** + * UploadWorkflowModal Component + * + * This component provides a modal for uploading workflow files with a multi-step process: + * 1. Upload - Select platform and file + * 2. Complex - Show warnings and errors if any + * 3. SureInfo - Confirm and edit workflow information + * 4. Completed - Show success message and options + */ import { forwardRef, useImperativeHandle, useState, useMemo } from 'react'; -import { Form, Select, Steps, Flex, Alert, Row, Col, Statistic, Input, Button } from 'antd'; +import { Form, Select, Steps, Flex, Alert, Input, Button, Result, message } from 'antd'; import { useTranslation } from 'react-i18next'; -import type { UploadWorkflowModalData, UploadWorkflowModalRef } from '../types' +import type { UploadWorkflowModalData, UploadData, UploadWorkflowModalRef } from '../types' import RbModal from '@/components/RbModal' import UploadFiles from '@/components/Upload/UploadFiles' -import { fileUploadUrl } from '@/api/fileStorage' -import RbCard from '@/components/RbCard/Card' +import { importWorkflow, completeImportWorkflow } from '@/api/application' +/** + * Props for UploadWorkflowModal component + */ interface UploadWorkflowModalProps { + /** Function to refresh the parent component after workflow import */ refresh: () => void; } + +/** + * Steps definition for the upload process + */ const steps = [ - 'upload', - 'complex', - 'node', - 'configCheck', - 'sureInfo', - 'completed' + 'upload', // Step 1: File upload + 'complex', // Step 2: Error/warning display + 'sureInfo', // Step 3: Information confirmation + 'completed' // Step 4: Success message ] + +/** + * UploadWorkflowModal component + * + * @param {UploadWorkflowModalProps} props - Component props + * @param {React.Ref} ref - Ref for imperative methods + */ const UploadWorkflowModal = forwardRef(({ refresh }, ref) => { const { t } = useTranslation(); - const [visible, setVisible] = useState(false); - const [form] = Form.useForm(); - const [loading, setLoading] = useState(false) - const [current, setCurrent] = useState(5); + + // State management + const [visible, setVisible] = useState(false); // Modal visibility + const [form] = Form.useForm(); // Form instance + const [loading, setLoading] = useState(false); // Loading state + const [current, setCurrent] = useState(0); // Current step + const [data, setData] = useState(null); // Upload response data + const [firstFormData, setFirstFormData] = useState(null); // First step form data + const [appId, setAppId] = useState(null); // Imported application ID - // 封装取消方法,添加关闭弹窗逻辑 + /** + * Handle modal close + * Resets all states and form fields + */ const handleClose = () => { setVisible(false); form.resetFields(); - setLoading(false) + setData(null); + setCurrent(0); + setFirstFormData(null); + setAppId(null); + setLoading(false); }; + /** + * Handle modal open + * Resets form fields and shows modal + */ const handleOpen = () => { form.resetFields(); setVisible(true); }; - // 封装保存方法,添加提交逻辑 + + /** + * Handle save/submit action + * Processes different logic based on current step + */ const handleSave = () => { + const values = form.getFieldsValue(); + switch(current) { - case 0: - setCurrent(1) + case 0: // Step 1: Upload file + if (!values.file || values.file.length === 0) { + message.warning(t('application.pleaseUploadFile')); + return; + } + const formData = new FormData(); + setFirstFormData(values); + formData.append('platform', values.platform); + formData.append('file', values.file[0]); + + setLoading(true) + // Call import workflow API + importWorkflow(formData) + .then(res => { + const response = res as UploadData; + const { errors, warnings } = response; + setData(response); + + // Navigate to error/warning step if any, otherwise go to confirmation + if (errors.length || warnings.length) { + setCurrent(1); + } else { + setCurrent(2); + // Pre-fill form with file information + const fileNameSplit = values.file[0].name.split('.') + form.setFieldsValue({ + name: fileNameSplit.slice(0, fileNameSplit.length - 1).join('.'), + platform: values.platform, + fileName: values.file[0].name, + fileSize: values.file[0].size, + }); + } + }) + .finally(() => setLoading(false)); break; - case 1: - setCurrent(2) + case 1: // Step 2: Error/warning display + if (firstFormData) { + const { file, platform } = firstFormData; + const fileNameSplit = firstFormData.file[0].name.split('.') + // Pre-fill form with file information + form.setFieldsValue({ + name: fileNameSplit.slice(0, fileNameSplit.length - 1).join('.'), + platform: platform, + fileName: file[0].name, + fileSize: file[0].size, + }); + } + setCurrent(2); break; - case 2: - setCurrent(3) - break; - case 3: - setCurrent(4) - break; - case 4: - setCurrent(5) - break; - case 5: + case 2: // Step 3: Confirm information + if (data) { + setLoading(true); + // Complete import workflow + completeImportWorkflow({ + temp_id: data.temp_id, + name: values.name, + description: values.description, + }) + .then((res) => { + const response = res as { id: string }; + setCurrent(3); + setAppId(response.id); + }) + .finally(() => setLoading(false)); + } break; default: - setCurrent(prev => prev + 1) + setCurrent(prev => prev + 1); break; } - // form - // .validateFields() - // .then(() => { - // }) - // .catch((err) => { - // console.log('err', err) - // }); - } + }; - // 暴露给父组件的方法 + // Expose methods to parent component via ref useImperativeHandle(ref, () => ({ handleOpen, handleClose })); + /** + * Handle navigation to previous step + * Adjusts step based on whether there were errors/warnings + */ const handleLastStep = () => { - setCurrent(prev => prev - 1) - } - const handleJump = (type: string) => { - switch(type) { - case 'detail': - break; - default: - break; + let newStep = current - 1; + // If no errors or warnings, skip the error/warning step + if (!data?.warnings?.length && !data?.errors?.length) { + newStep = current - 2; } - } + // Reset form if not going back to error/warning step + if (newStep === 0) { + form.setFieldsValue(firstFormData || {}) + } else if (newStep !== 1) { + form.resetFields(); + } + setCurrent(newStep); + }; + + /** + * Handle navigation after successful import + * @param {string} type - Navigation type ('detail' or 'list') + */ + const handleJump = (type: string) => { + handleClose(); + refresh(); + setTimeout(() => { + switch (type) { + case 'detail': + // Open application detail page in new tab + window.open(`/#/application/config/${appId}`, '_blank'); + break; + } + }, 100) + }; + + /** + * Generate modal footer based on current step + */ const getFooter = useMemo(() => { switch(current) { - case 0: + case 0: // Step 1: Upload return [ , + + ]; + case 3: // Step 4: Completed + return null; + default: // Steps 1-2 + return [ + , + , - ] - case 5: - return [ - , - ] - default: - return [ - , - , - - ] - } - }, [current]) - - return ( - -
- ({ title: t(`application.${key}`) }))} + ]} /> -
- {current === 0 && -
- - -
- - - - } - {current === 3 && - - - - } - {current === 4 && - -
{t('application.baseInfo')}
- - - - - source - - - fileName - - - fileSize - - - - - -
{t('application.importStatistic')}
- - {['complex', 'nodes', 'task'].map(key => ( - - - - ))} - - - } - {current === 5 && - -
导入成功
-
您的工作流已成功导入,可以在应用管理中查看和管理
-
} ); diff --git a/web/src/views/ApplicationManagement/index.tsx b/web/src/views/ApplicationManagement/index.tsx index 74dcef05..055c0c8f 100644 --- a/web/src/views/ApplicationManagement/index.tsx +++ b/web/src/views/ApplicationManagement/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:34:12 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-09 13:52:22 + * @Last Modified time: 2026-03-02 17:48:51 */ /** * Application Management Page @@ -12,7 +12,7 @@ import React, { useState, useRef, useEffect } from 'react'; import { useTranslation } from 'react-i18next'; -import { Button, Row, Col, App, Select, Space } from 'antd'; +import { Button, Row, Col, App, Select, Space, Dropdown } from 'antd'; import clsx from 'clsx'; import { DeleteOutlined } from '@ant-design/icons'; import { useSearchParams } from 'react-router-dom' @@ -83,9 +83,16 @@ const ApplicationManagement: React.FC = () => { setQuery(prev => ({...prev, type: value})) } - // const handleImport = () => { - // uploadWorkflowModalRef.current?.handleOpen() - // } + const handleImport = () => { + uploadWorkflowModalRef.current?.handleOpen() + } + const handleClick = ({ key }: { key: string } ) => { + switch (key) { + case 'thirdParty': + handleImport() + break; + } + } return ( <> @@ -111,9 +118,16 @@ const ApplicationManagement: React.FC = () => { - {/* */} + + + diff --git a/web/src/views/ApplicationManagement/types.ts b/web/src/views/ApplicationManagement/types.ts index ccc4f114..696b828a 100644 --- a/web/src/views/ApplicationManagement/types.ts +++ b/web/src/views/ApplicationManagement/types.ts @@ -2,12 +2,12 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:34:15 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-06 11:08:37 + * @Last Modified time: 2026-02-28 16:16:03 */ /** * Type definitions for Application Management */ - +import type { WorkflowConfig } from '@/views/Workflow/types'; /** * Search query parameters */ @@ -174,9 +174,63 @@ export interface ApiExtensionModalRef { handleOpen: () => void; } - +/** + * Upload workflow modal form data + */ export interface UploadWorkflowModalData { + /** Platform type (e.g., 'dify') */ + platform: string; + /** Array of uploaded files */ + file: any[]; + /** Optional workflow name */ + name?: string; + /** Optional original file name */ + fileName?: string; + /** Optional file size in bytes */ + fileSize?: number; + /** Optional workflow description */ + description?: string; } + +/** + * Complex item for errors and warnings + */ +interface ComplexItem { + /** Error/warning type */ + type: string; + /** Detailed error/warning message */ + detail: string; + /** Node identifier where the error/warning occurred */ + node_id: string; + /** Node name where the error/warning occurred */ + node_name: string; + /** Optional scope of the error/warning */ + scope: string | null; + /** Optional name associated with the error/warning */ + name: string | null; +} + +/** + * Upload data response + * @extends WorkflowConfig + */ +export interface UploadData extends WorkflowConfig { + /** Whether the upload was successful */ + success: boolean; + /** Temporary identifier for the uploaded workflow */ + temp_id: string; + /** Optional workflow identifier if already exists */ + workflow_id?: string; + /** Array of error items */ + errors: ComplexItem[]; + /** Array of warning items */ + warnings: ComplexItem[]; +} + +/** + * Upload workflow modal ref interface + */ export interface UploadWorkflowModalRef { + /** Open the upload workflow modal */ handleOpen: () => void; } \ No newline at end of file diff --git a/web/src/views/Conversation/components/FileUpload.tsx b/web/src/views/Conversation/components/FileUpload.tsx index 70ee9cf2..b4f11b1b 100644 --- a/web/src/views/Conversation/components/FileUpload.tsx +++ b/web/src/views/Conversation/components/FileUpload.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-06 21:09:42 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-11 11:32:48 + * @Last Modified time: 2026-03-06 12:20:43 */ /** * File Upload Component @@ -25,6 +25,7 @@ import { Upload, Progress, App } from 'antd'; import type { UploadProps, UploadFile } from 'antd'; import type { UploadProps as RcUploadProps } from 'antd/es/upload/interface'; import { useTranslation } from 'react-i18next'; + import { request } from '@/utils/request' import { fileUploadUrlWithoutApiPrefix } from '@/api/fileStorage' @@ -56,27 +57,36 @@ interface UploadFilesProps extends Omit { /** Custom file removal callback */ onRemove?: (file: UploadFile) => boolean | void | Promise; } + +const transform_file_type = { + 'text/plain': 'document/text', + 'text/markdown': 'document/markdown', + 'text/x-markdown': 'document/x-markdown', + + 'application/pdf': 'document/pdf', + + 'application/msword': 'document/doc', + 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': 'document/docx', + + 'application/vnd.ms-powerpoint': 'document/ppt', + 'application/vnd.openxmlformats-officedocument.presentationml.presentation': 'document/pptx', +} // Mapping of file extensions to MIME types const ALL_FILE_TYPE: { [key: string]: string; } = { - // txt: 'text/plain', + txt: 'text/plain', + md: 'text/markdown', + xmd: 'text/x-markdown', + pdf: 'application/pdf', doc: 'application/msword', docx: 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', - - xls: 'application/vnd.ms-excel', - xlsx: 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet', - csv: 'text/csv', ppt: 'application/vnd.ms-powerpoint', pptx: 'application/vnd.openxmlformats-officedocument.presentationml.presentation', - - // md: 'text/markdown', - // htm: 'text/html', - // html: 'text/html', - // json: 'application/json', + jpg: 'image/jpeg', jpeg: 'image/jpeg', png: 'image/png', @@ -84,6 +94,23 @@ const ALL_FILE_TYPE: { bmp: 'image/bmp', webp: 'image/webp', svg: 'image/svg+xml', + + mp4: 'video/mp4', + mov: 'video/quicktime', + avi: 'video/x-msvideo', + mkv: 'video/x-matroska', + webm: 'video/webm', + flv: 'video/x-flv', + wmv: 'video/x-ms-wmv', + + mp3: 'audio/mpeg', + wav: 'audio/wav', + ogg: 'audio/ogg', + aac: 'audio/aac', + flac: 'audio/flac', + m4a: 'audio/mp4', + wma: 'audio/x-ms-wma', + xm4a: 'audio/x-m4a', } export interface UploadFilesRef { /** Current file list */ @@ -178,6 +205,11 @@ const UploadFiles = forwardRef(({ * Handles upload state changes */ const handleChange: UploadProps['onChange'] = ({ fileList: newFileList }) => { + newFileList.map(file => { + const type = (file.type && transform_file_type[file.type as keyof typeof transform_file_type]) || file.type || 'document' + file.type = type + file.thumbUrl = file.thumbUrl || URL.createObjectURL(file.originFileObj as Blob) + }) setFileList(newFileList); if (onChange) { onChange(maxCount === 1 ? newFileList[newFileList.length - 1] : newFileList); diff --git a/web/src/views/Conversation/components/UploadFileListModal.tsx b/web/src/views/Conversation/components/UploadFileListModal.tsx index c5110701..a43b9dd4 100644 --- a/web/src/views/Conversation/components/UploadFileListModal.tsx +++ b/web/src/views/Conversation/components/UploadFileListModal.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-06 21:09:47 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-09 10:17:54 + * @Last Modified time: 2026-03-04 17:47:09 */ /** * Upload File List Modal Component @@ -104,7 +104,9 @@ const UploadFileListModal = forwardRef diff --git a/web/src/views/Conversation/index.tsx b/web/src/views/Conversation/index.tsx index 825ea834..8a67b3ae 100644 --- a/web/src/views/Conversation/index.tsx +++ b/web/src/views/Conversation/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:58:03 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-10 17:41:05 + * @Last Modified time: 2026-03-04 12:10:44 */ /** * Conversation Page @@ -14,11 +14,12 @@ import { type FC, useState, useEffect, useRef } from 'react' import { useParams, useLocation } from 'react-router-dom' import { useTranslation } from 'react-i18next' import InfiniteScroll from 'react-infinite-scroll-component'; -import { Flex, Skeleton, Form, Dropdown, type MenuProps } from 'antd' +import { Flex, Skeleton, Form, Dropdown, type MenuProps, App, Divider } from 'antd' +import { SettingOutlined } from '@ant-design/icons' import clsx from 'clsx' import dayjs from 'dayjs' -import { getConversationHistory, sendConversation, getConversationDetail, getShareToken } from '@/api/application' +import { getConversationHistory, sendConversation, getConversationDetail, getShareToken, getExperienceConfig } from '@/api/application' import type { HistoryItem, QueryParams, UploadFileListModalRef } from './types' import Empty from '@/components/Empty' import { formatDateTime } from '@/utils/format'; @@ -34,15 +35,19 @@ import OnlineCheckedIcon from '@/assets/images/conversation/onlineChecked.svg' import MemoryFunctionCheckedIcon from '@/assets/images/conversation/memoryFunctionChecked.svg' import { type SSEMessage } from '@/utils/stream' import UploadFiles from './components/FileUpload' -// import AudioRecorder from '@/components/AudioRecorder' +import AudioRecorder from '@/components/AudioRecorder' import { shareFileUploadUrlWithoutApiPrefix } from '@/api/fileStorage' import UploadFileListModal from './components/UploadFileListModal' +import type { VariableConfigModalRef } from '@/views/Workflow/types' +import type { Variable } from '@/views/Workflow/components/Properties/VariableList/types' +import VariableConfigModal from '@/views/Workflow/components/Chat/VariableConfigModal'; /** * Conversation component for shared applications */ const Conversation: FC = () => { const { t } = useTranslation() + const { message: messageApi } = App.useApp() const { token } = useParams() const location = useLocation() const searchParams = new URLSearchParams(location.search) @@ -64,6 +69,22 @@ const Conversation: FC = () => { const queryValues = Form.useWatch([], form) const uploadFileListModalRef = useRef(null) + + const variableConfigModalRef = useRef(null) + const [variables, setVariables] = useState([]) // Workflow input variables + + /** + * Opens the variable configuration modal + */ + const handleEditVariables = () => { + variableConfigModalRef.current?.handleOpen(variables) + } + /** + * Saves updated variable values from the modal + */ + const handleSave = (values: Variable[]) => { + setVariables([...values]) + } useEffect(() => { const shareToken = localStorage.getItem(`shareToken_${token}`) setShareToken(shareToken) @@ -81,6 +102,17 @@ const Conversation: FC = () => { getHistory() } }, [token, shareToken, page, hasMore, historyList]) + useEffect(() => { + if (shareToken && token) { + getExperienceConfig(token) + .then(res => { + const response = res as { variables: Variable[] } + setVariables(response.variables || []) + }) + } else { + setChatList([]) + } + }, [shareToken, token]) /** Group conversation history by date */ const groupHistoryByDate = (items: HistoryItem[]): Record => { @@ -191,12 +223,35 @@ const Conversation: FC = () => { }) } + const isNeedVariableConfig = variables.some(vo => vo.required && (vo.value === null || vo.value === undefined || vo.value === '')) + /** Send message and handle streaming response */ const handleSend = () => { if (!token || !shareToken) { return } const { files = [], ...rest } = queryValues || {} + // Validate required variables before sending + let isCanSend = true + const params: Record = {} + if (variables.length > 0) { + const needRequired: string[] = [] + variables.forEach(vo => { + params[vo.name] = vo.value ?? vo.defaultValue + + if (vo.required && (params[vo.name] === null || params[vo.name] === undefined || params[vo.name] === '')) { + isCanSend = false + needRequired.push(vo.name) + } + }) + + if (needRequired.length) { + messageApi.error(`${needRequired.join(',')} ${t('workflow.variableRequired')}`) + } + } + if (!isCanSend) { + return + } setLoading(true) setStreamLoading(true) addUserMessage(message, files) @@ -212,8 +267,8 @@ const Conversation: FC = () => { currentConversationId = newId break case 'message': - const { content, chunk, conversation_id: curId } = item.data as { content: string; chunk: string; conversation_id: string; } - updateAssistantMessage(content ?? chunk) + const { content, conversation_id: curId } = item.data as { content: string; conversation_id: string; } + updateAssistantMessage(content) if (curId) { currentConversationId = curId; @@ -247,19 +302,30 @@ const Conversation: FC = () => { upload_file_id: file.response.data.file_id } } - }) + }), + variables: params }, handleStreamMessage, shareToken) + .catch(() => { + setLoading(false) + setStreamLoading(false) + }) .finally(() => { setLoading(false) + setStreamLoading(false) }) } const fileChange = (file?: any) => { form.setFieldValue('files', [...(queryValues.files || []), file]) } - // const handleRecordingComplete = async (file: any) => { - // console.log('file', file) - // } + const handleRecordingComplete = async (file: any) => { + form.setFieldValue('files', [...(queryValues.files || []), { + uid: file.file_id, + response: { data: file }, + thumbUrl: file.url, + type: file.type + }]) + } const handleShowUpload: MenuProps['onClick'] = ({ key }) => { switch(key) { @@ -273,6 +339,7 @@ const Conversation: FC = () => { form.setFieldValue('files', [...(queryValues.files || []), ...fileList]) } const updateFileList = (fileList?: any[]) => { + console.log('fileList', fileList) form.setFieldValue('files', [...(fileList || [])]) } @@ -327,7 +394,7 @@ const Conversation: FC = () => {
} - contentClassName="rb:h-[calc(100%-180px)]" + contentClassName={!queryValues?.files?.length ? "rb:h-[calc(100%-144px)]" : "rb:h-[calc(100%-208px)]"} data={chatList} streamLoading={streamLoading} loading={loading} @@ -349,13 +416,12 @@ const Conversation: FC = () => { key: 'upload', label: ( ) }, @@ -384,11 +450,34 @@ const Conversation: FC = () => { {t(`memoryConversation.memory`)} + {variables.length > 0 && ( + +
+ + {t(`memoryConversation.variableConfig`)} +
+
+ )} - {/* - + + - */} +
@@ -399,6 +488,11 @@ const Conversation: FC = () => { ref={uploadFileListModalRef} refresh={addFileList} /> + ) } diff --git a/web/src/views/Conversation/types.ts b/web/src/views/Conversation/types.ts index deb14d1f..cc074c1b 100644 --- a/web/src/views/Conversation/types.ts +++ b/web/src/views/Conversation/types.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:57:46 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-02-06 21:11:19 + * @Last Modified time: 2026-03-03 13:46:55 */ /** * Type definitions for Conversation @@ -51,6 +51,7 @@ export interface QueryParams { /** Current conversation ID */ conversation_id?: string | null; files?: any[]; + variables?: Record; } export interface UploadFileListModalRef { diff --git a/web/src/views/KnowledgeBase/components/CreateModal.tsx b/web/src/views/KnowledgeBase/components/CreateModal.tsx index 76640058..35eb52d0 100644 --- a/web/src/views/KnowledgeBase/components/CreateModal.tsx +++ b/web/src/views/KnowledgeBase/components/CreateModal.tsx @@ -15,6 +15,7 @@ import { } from '@/api/knowledgeBase' import RbModal from '@/components/RbModal' import SliderInput from '@/components/SliderInput' +import { stringRegExp } from '@/utils/validator' const { TextArea } = Input; const { confirm } = Modal @@ -519,12 +520,16 @@ const CreateModal = forwardRef(({ )} - +