Merge branch 'develop' into feature/model_zy

This commit is contained in:
zhaoying
2026-03-05 10:32:02 +08:00
36 changed files with 1182 additions and 179 deletions

View File

@@ -3,9 +3,10 @@ Cache 缓存模块
提供各种缓存功能的统一入口
"""
from .memory import EmotionMemoryCache, ImplicitMemoryCache
from .memory import EmotionMemoryCache, ImplicitMemoryCache, InterestMemoryCache
__all__ = [
"EmotionMemoryCache",
"ImplicitMemoryCache",
"InterestMemoryCache",
]

View File

@@ -5,8 +5,10 @@ Memory 缓存模块
"""
from .emotion_memory import EmotionMemoryCache
from .implicit_memory import ImplicitMemoryCache
from .interest_memory import InterestMemoryCache
__all__ = [
"EmotionMemoryCache",
"ImplicitMemoryCache",
"InterestMemoryCache",
]

122
api/app/cache/memory/interest_memory.py vendored Normal file
View File

@@ -0,0 +1,122 @@
"""
Interest Distribution Cache
兴趣分布缓存模块
用于缓存用户的兴趣分布标签数据,避免重复调用模型生成
"""
import json
import logging
from typing import Optional, List, Dict, Any
from datetime import datetime
from app.aioRedis import aio_redis
logger = logging.getLogger(__name__)

# Default lifetime of a cached interest distribution: 24 hours, in seconds.
INTEREST_CACHE_EXPIRE = 86400


class InterestMemoryCache:
    """Cache of per-user interest-distribution tags, stored through ``aio_redis``.

    Entries are keyed by end user and language so that the (model-generated)
    distribution does not have to be recomputed on every request. All three
    operations are best-effort: any exception is logged and turned into a
    falsy return value instead of propagating to the caller.
    """

    PREFIX = "cache:memory:interest_distribution"

    @classmethod
    def _get_key(cls, end_user_id: str, language: str) -> str:
        """Build the cache key for one user/language pair.

        Args:
            end_user_id: ID of the end user.
            language: Language code the distribution was generated in.

        Returns:
            The full key, ``<PREFIX>:by_user:<end_user_id>:<language>``.
        """
        return "{}:by_user:{}:{}".format(cls.PREFIX, end_user_id, language)

    @classmethod
    async def set_interest_distribution(
        cls,
        end_user_id: str,
        language: str,
        data: List[Dict[str, Any]],
        expire: int = INTEREST_CACHE_EXPIRE,
    ) -> bool:
        """Store a user's interest distribution.

        The list is wrapped in an envelope that also records the generation
        time (local ISO timestamp) and a ``cached`` marker, then serialised
        as JSON.

        Args:
            end_user_id: ID of the end user.
            language: Language code of the distribution.
            data: Interest list, e.g. ``[{"name": "...", "frequency": ...}, ...]``.
            expire: Expiry passed to ``aio_redis.set`` as ``ex``; defaults to 24 hours.

        Returns:
            True when the value was written, False if anything raised.
        """
        try:
            redis_key = cls._get_key(end_user_id, language)
            serialized = json.dumps(
                {
                    "data": data,
                    "generated_at": datetime.now().isoformat(),
                    "cached": True,
                },
                ensure_ascii=False,
            )
            await aio_redis.set(redis_key, serialized, ex=expire)
            logger.info(f"设置兴趣分布缓存成功: {redis_key}, 过期时间: {expire}")
        except Exception as e:
            logger.error(f"设置兴趣分布缓存失败: {e}", exc_info=True)
            return False
        return True

    @classmethod
    async def get_interest_distribution(
        cls,
        end_user_id: str,
        language: str,
    ) -> Optional[List[Dict[str, Any]]]:
        """Fetch a user's cached interest distribution.

        Args:
            end_user_id: ID of the end user.
            language: Language code of the distribution.

        Returns:
            The ``data`` member of the stored envelope, or None when nothing
            is stored under the key or any step raises.
        """
        try:
            redis_key = cls._get_key(end_user_id, language)
            raw = await aio_redis.get(redis_key)
            if not raw:
                logger.info(f"兴趣分布缓存不存在或已过期: {redis_key}")
                return None
            envelope = json.loads(raw)
            logger.info(f"命中兴趣分布缓存: {redis_key}")
            return envelope.get("data")
        except Exception as e:
            logger.error(f"获取兴趣分布缓存失败: {e}", exc_info=True)
            return None

    @classmethod
    async def delete_interest_distribution(
        cls,
        end_user_id: str,
        language: str,
    ) -> bool:
        """Remove a user's cached interest distribution.

        Args:
            end_user_id: ID of the end user.
            language: Language code of the distribution.

        Returns:
            True when ``aio_redis.delete`` reports a positive result,
            False otherwise (including when any step raises).
        """
        try:
            redis_key = cls._get_key(end_user_id, language)
            removed = await aio_redis.delete(redis_key)
            logger.info(f"删除兴趣分布缓存: {redis_key}, 结果: {removed}")
            return removed > 0
        except Exception as e:
            logger.error(f"删除兴趣分布缓存失败: {e}", exc_info=True)
            return False

View File

@@ -1,6 +1,7 @@
import os
import platform
from datetime import timedelta
from celery.schedules import crontab
from urllib.parse import quote
from celery import Celery
@@ -43,8 +44,8 @@ celery_app.conf.update(
task_ignore_result=False,
# 超时设置
task_time_limit=1800, # 30分钟硬超时
task_soft_time_limit=1500, # 25分钟软超时
task_time_limit=3600, # 60分钟硬超时
task_soft_time_limit=3000, # 50分钟软超时
# Worker 设置 (per-worker settings are in docker-compose command line)
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
@@ -90,11 +91,10 @@ celery_app.conf.update(
celery_app.autodiscover_tasks(['app'])
# Celery Beat schedule for periodic tasks
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
memory_increment_schedule = crontab(hour=settings.MEMORY_INCREMENT_HOUR, minute=settings.MEMORY_INCREMENT_MINUTE)
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
# 这个30秒的设计不合理
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
workspace_reflection_schedule = timedelta(seconds=settings.WORKSPACE_REFLECTION_INTERVAL_SECONDS)
forgetting_cycle_schedule = timedelta(hours=settings.FORGETTING_CYCLE_INTERVAL_HOURS)
#构建定时任务配置
beat_schedule_config = {

View File

@@ -441,14 +441,14 @@ async def retrieve_chunks(
# 1 participle search, 2 semantic search, 3 hybrid search
match retrieve_data.retrieve_type:
case chunk_schema.RetrieveType.PARTICIPLE:
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
return success(data=rs, msg="retrieval successful")
case chunk_schema.RetrieveType.SEMANTIC:
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
return success(data=rs, msg="retrieval successful")
case _:
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
# Efficient deduplication
seen_ids = set()
unique_rs = []

View File

@@ -1,5 +1,6 @@
from typing import List, Optional
from app.cache.memory.interest_memory import InterestMemoryCache
from app.celery_app import celery_app
from app.core.error_codes import BizCode
from app.core.language_utils import get_language_from_header
@@ -661,34 +662,56 @@ async def get_knowledge_type_stats_api(
return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e))
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
async def get_hot_memory_tags_by_user_api(
end_user_id: Optional[str] = Query(None, description="用户ID可选"),
limit: int = Query(20, description="返回标签数量限制"),
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
async def get_interest_distribution_by_user_api(
end_user_id: str = Query(..., description="用户ID必填"),
limit: int = Query(5, le=5, description="返回兴趣标签数量限制最多5个"),
language_type: str = Header(default=None, alias="X-Language-Type"),
current_user: User = Depends(get_current_user),
db: Session=Depends(get_db),
db: Session = Depends(get_db),
):
"""
获取指定用户的热门记忆标签
获取指定用户的兴趣分布标签
注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译
与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习、创作等),
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
返回格式:
[
{"name": "标签", "frequency": 频次},
{"name": "兴趣活动", "frequency": 频次},
...
]
"""
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
language = get_language_from_header(language_type)
api_logger.info(f"Interest distribution by user requested: end_user_id={end_user_id}, language={language}")
try:
result = await memory_agent_service.get_hot_memory_tags_by_user(
# 优先读取缓存
cached = await InterestMemoryCache.get_interest_distribution(
end_user_id=end_user_id,
limit=limit
language=language,
)
return success(data=result, msg="获取热门记忆标签成功")
if cached is not None:
api_logger.info(f"Interest distribution cache hit: end_user_id={end_user_id}")
return success(data=cached, msg="获取兴趣分布标签成功")
# 缓存未命中,调用模型生成
result = await memory_agent_service.get_interest_distribution_by_user(
end_user_id=end_user_id,
limit=limit,
language=language
)
# 写入缓存24小时过期
await InterestMemoryCache.set_interest_distribution(
end_user_id=end_user_id,
language=language,
data=result,
)
return success(data=result, msg="获取兴趣分布标签成功")
except Exception as e:
api_logger.error(f"Hot memory tags by user failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e))
api_logger.error(f"Interest distribution by user failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "获取兴趣分布标签失败", str(e))
@router.get("/analytics/user_profile", response_model=ApiResponse)

View File

@@ -606,8 +606,8 @@ async def dashboard_data(
# 获取RAG相关数据
try:
# total_memory: 使用 total_chunkchunk数
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
# total_memory: 只统计用户知识库permission_id='Memory')的chunk数
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
rag_data["total_memory"] = total_chunk
# total_app: 统计当前空间下的所有app数量

View File

@@ -249,6 +249,7 @@ async def chat(
app_id=app.id,
workspace_id=workspace_id,
release_id=app.current_release.id,
public=True
):
event_type = event.get("event", "message")
event_data = event.get("data", {})

View File

@@ -39,7 +39,7 @@ async def write_memory_api_service(
Stores memory content for the specified end user using the Memory API Service.
"""
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
memory_api_service = MemoryAPIService(db)

View File

@@ -1,9 +1,10 @@
import json
import os
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Annotated, Any, Dict, Optional
from dotenv import load_dotenv
from pydantic import Field, TypeAdapter
load_dotenv()
@@ -200,12 +201,25 @@ class Settings:
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
# Memory Cache Regeneration Configuration
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
# Celery Beat Schedule Configuration (定时任务执行频率)
MEMORY_INCREMENT_HOUR: int = TypeAdapter(
Annotated[int, Field(ge=0, le=23, description="cron hour [0, 23]")]
).validate_python(int(os.getenv("MEMORY_INCREMENT_HOUR", "2")))
MEMORY_INCREMENT_MINUTE: int = TypeAdapter(
Annotated[int, Field(ge=0, le=59, description="cron minute [0, 59]")]
).validate_python(int(os.getenv("MEMORY_INCREMENT_MINUTE", "0")))
WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = TypeAdapter(
Annotated[int, Field(ge=1, description="reflection interval in seconds, must be >= 1")]
).validate_python(int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30")))
FORGETTING_CYCLE_INTERVAL_HOURS: int = TypeAdapter(
Annotated[int, Field(ge=1, description="forgetting cycle interval in hours, must be >= 1")]
).validate_python(int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24")))
# Memory Module Configuration (internal)
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
@@ -230,7 +244,7 @@ class Settings:
# General Ontology Type Configuration
# ========================================================================
# 通用本体文件路径列表(逗号分隔)
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "app/core/memory/ontology_services/General_purpose_entity.ttl")
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "api/app/core/memory/ontology_services/General_purpose_entity.ttl")
# 是否启用通用本体类型功能
ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true"

View File

@@ -1,9 +1,12 @@
import asyncio
import json
import logging
import os
from typing import List, Tuple
from app.core.config import settings
logger = logging.getLogger(__name__)
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
@@ -16,6 +19,10 @@ class FilteredTags(BaseModel):
"""用于接收LLM筛选后的核心标签列表的模型。"""
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
# Structured-output model for the interest filter: it is passed as `response_model`
# to `llm_client.response_structured`, and its `interest_tags` field is what the caller reads back.
# NOTE(review): the class docstring and the Field description below may be exported into the
# schema shown to the LLM, so they are deliberately left untranslated — confirm before changing them.
class InterestTags(BaseModel):
"""用于接收LLM筛选后的兴趣活动标签列表的模型。"""
interest_tags: List[str] = Field(..., description="从原始列表中筛选出的代表用户兴趣活动的标签列表。")
async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
"""
使用LLM筛选标签列表仅保留具有代表性的核心名词。
@@ -85,10 +92,74 @@ async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
return structured_response.meaningful_tags
except Exception as e:
print(f"LLM筛选过程中发生错误: {e}")
logger.error(f"LLM筛选过程中发生错误: {e}", exc_info=True)
# 在LLM失败时返回原始标签确保流程继续
return tags
async def filter_interests_with_llm(tags: List[str], end_user_id: str, language: str = "zh") -> List[str]:
"""
Use the LLM to pick, from a tag list, the tags that represent the user's interest activities.
Unlike filter_tags_with_llm, this function focuses on "activity/behaviour" style interests and
filters out nouns (plain objects, tools, places, ...) that do not represent an activity the user takes part in.
Args:
tags: Raw tag list.
end_user_id: User ID, used to look up the LLM configuration.
language: Forwarded to render_interest_filter_prompt (default "zh").
Returns:
The `interest_tags` list from the LLM's structured response. If any step raises,
the error is logged and the input `tags` list is returned unchanged.
"""
# NOTE(review): indentation is not visible in this view, so it is unclear whether the DB
# context stays open during the LLM call below — confirm against the real file.
try:
with get_db_context() as db:
# Imported inside the function rather than at module top — presumably to avoid a circular import; confirm.
from app.services.memory_agent_service import (
get_end_user_connected_config,
)
connected_config = get_end_user_connected_config(end_user_id, db)
config_id = connected_config.get("memory_config_id")
workspace_id = connected_config.get("workspace_id")
# Only fails when BOTH identifiers are missing, although the message mentions memory_config_id alone.
if not config_id and not workspace_id:
raise ValueError(
f"No memory_config_id found for end_user_id: {end_user_id}."
)
config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config(
config_id=config_id,
workspace_id=workspace_id
)
if not memory_config.llm_model_id:
raise ValueError(
f"No llm_model_id found in memory config {config_id}."
)
factory = MemoryClientFactory(db)
llm_client = factory.get_llm_client(memory_config.llm_model_id)
# The tags are handed to the prompt template as one comma-separated string.
tag_list_str = ", ".join(tags)
from app.core.memory.utils.prompt.prompt_utils import render_interest_filter_prompt
rendered_prompt = render_interest_filter_prompt(tag_list_str, language=language)
messages = [
{
"role": "user",
"content": rendered_prompt
}
]
structured_response = await llm_client.response_structured(
messages=messages,
response_model=InterestTags
)
return structured_response.interest_tags
# Best-effort fallback: on any failure the unfiltered input tags are returned so the caller can continue.
except Exception as e:
logger.error(f"兴趣标签LLM筛选过程中发生错误: {e}", exc_info=True)
return tags
async def get_raw_tags_from_db(
connector: Neo4jConnector,
end_user_id: str,
@@ -183,3 +254,56 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool =
finally:
# 确保关闭连接
await connector.close()
async def get_interest_distribution(end_user_id: str, limit: int = 10, by_user: bool = False, language: str = "zh") -> List[Tuple[str, int]]:
    """
    Return the user's interest-distribution tags as (tag, frequency) pairs.

    Unlike get_hot_memory_tags, the raw tags are passed through
    filter_interests_with_llm, which keeps activity-style interests rather
    than plain objects, tools or places.

    Args:
        end_user_id: Required. Interpreted as end_user_id when by_user is False,
            and as user_id when by_user is True.
        limit: Maximum number of pairs returned (default 10).
        by_user: Forwarded to get_raw_tags_from_db to select the lookup mode.
        language: Forwarded to filter_interests_with_llm (default "zh").

    Returns:
        Pairs sorted by frequency, highest first, truncated to `limit`.
        An empty list when no raw tags are found.

    Raises:
        ValueError: If end_user_id is empty or whitespace only.
    """
    if not (end_user_id and end_user_id.strip()):
        raise ValueError(
            "end_user_id is required. Please provide a valid end_user_id or user_id."
        )

    connector = Neo4jConnector()
    try:
        # Fetch a fixed, larger pool of raw tags so the LLM has enough context.
        query_limit = 40
        raw_pairs = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user)
        if not raw_pairs:
            return []

        candidate_names = []
        frequency_by_tag = {}
        for name, count in raw_pairs:
            candidate_names.append(name)
            frequency_by_tag[name] = count

        # The filter may return labels that were not in the raw list.
        interest_names = await filter_interests_with_llm(candidate_names, end_user_id, language=language)

        # Keep the first occurrence of each label. Known labels keep their raw
        # frequency; labels absent from the raw list get a frequency of 1.
        scored = [
            (label, frequency_by_tag.get(label, 1))
            for label in dict.fromkeys(interest_names)
        ]

        # Stable sort, highest frequency first.
        ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
        return ranked[:limit]
    finally:
        # Always release the connector, whatever happened above.
        await connector.close()

View File

@@ -548,3 +548,20 @@ async def render_ontology_extraction_prompt(
})
return rendered_prompt
def render_interest_filter_prompt(tag_list: str, language: str = "zh") -> str:
    """
    Render the interest-filter prompt from the ``interest_filter.jinja2`` template.

    Args:
        tag_list: Comma-separated string of raw tags to filter.
        language: Output language ("zh" for Chinese, "en" for English).

    Returns:
        The rendered prompt text.
    """
    interest_template = prompt_env.get_template("interest_filter.jinja2")
    template_vars = {"tag_list": tag_list, "language": language}
    prompt_text = interest_template.render(**template_vars)
    # Record the rendered prompt through the shared prompt-logging helper.
    log_prompt_rendering('interest filter', prompt_text)
    return prompt_text

View File

@@ -0,0 +1,67 @@
{% if language == "zh" %}
You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent.
**Step 1 - Infer the underlying interest from each tag**:
Look at each tag and ask: "What hobby or interest does this tag suggest the user has?"
Examples of inference:
- '攀岩', '室内攀岩馆', '攀岩者数据仪表盘', '路线解锁地图', '指力', '路线等级', '当日攀岩流畅度' → '攀岩'
- '风光摄影元数据增强器', 'EXIF数据', '.CR2文件', '.NEF文件', '日出拍摄点', '曝光补偿', '光圈', '太阳高度角', '云量预测图层' → '摄影'
- '晨间冥想坚持天数', '身心协同峰值' → '冥想'
- '川味可视化', '川菜' → '烹饪'
- '开源项目命名建议', 'climbviz', '可视化', '力量增长雷达图' → '编程' 或 '数据可视化'
- '吉他', '指弹', '琴谱' → '吉他'
- '跑步', '5公里', '跑鞋' → '跑步'
- '瑜伽垫', '瑜伽课' → '瑜伽'
**Step 2 - Consolidate and deduplicate**:
- Merge tags that point to the same interest into one representative label
- Use concise, standard hobby names (e.g., '攀岩', '摄影', '编程', '烹饪', '冥想', '吉他', '跑步')
- If multiple tags all point to '攀岩', output '攀岩' only once
**Step 3 - Filter out non-interest tags**:
Remove tags that do NOT suggest any hobby or interest:
- Generic system/assistant terms (e.g., '助手', '用户', 'AI')
- Pure abstract metrics with no clear hobby link (e.g., '完成时间', '日期', '自我评分')
- Location names with no clear hobby link (e.g., '青城山后山' alone — but if combined with photography context, infer '摄影')
**Output format**: Return a list of concise interest activity names in Chinese.
**Example**:
Input: ['攀岩', '攀岩者数据仪表盘', '路线解锁地图', '指力', '风光摄影元数据增强器', 'EXIF数据', '晨间冥想坚持天数', '川味可视化', '可视化', '助手', '完成时间']
Output: ['攀岩', '摄影', '冥想', '烹饪', '编程']
Now process the following tag list and return the inferred interest activities in Chinese: {{ tag_list }}
{% else %}
You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent.
**Step 1 - Infer the underlying interest from each tag**:
Look at each tag and ask: "What hobby or interest does this tag suggest the user has?"
Examples of inference:
- 'rock climbing', 'indoor climbing gym', 'climber dashboard', 'route map', 'finger strength' → 'rock climbing'
- 'landscape photography metadata enhancer', 'EXIF data', 'sunrise shooting spot', 'exposure compensation' → 'photography'
- 'morning meditation streak', 'mind-body peak' → 'meditation'
- 'Sichuan cuisine visualization', 'Sichuan food' → 'cooking'
- 'open source project', 'data visualization tool', 'Python' → 'programming'
- 'guitar', 'fingerpicking', 'sheet music' → 'guitar'
- 'running', '5km', 'running shoes' → 'running'
**Step 2 - Consolidate and deduplicate**:
- Merge tags that point to the same interest into one representative label
- Use concise, standard hobby names (e.g., 'rock climbing', 'photography', 'programming', 'cooking', 'meditation')
- If multiple tags all point to 'rock climbing', output 'rock climbing' only once
**Step 3 - Filter out non-interest tags**:
Remove tags that do NOT suggest any hobby or interest:
- Generic system/assistant terms (e.g., 'assistant', 'user', 'AI')
- Pure abstract metrics with no clear hobby link (e.g., 'completion time', 'date', 'self-rating')
**Output format**: Return a list of concise interest activity names in English.
**Example**:
Input: ['rock climbing', 'climber dashboard', 'route map', 'finger strength', 'landscape photography metadata enhancer', 'EXIF data', 'morning meditation streak', 'Sichuan cuisine visualization', 'visualization', 'assistant', 'completion time']
Output: ['rock climbing', 'photography', 'meditation', 'cooking', 'programming']
Now process the following tag list and return the inferred interest activities in English: {{ tag_list }}
{% endif %}

View File

@@ -127,7 +127,7 @@ class EventStreamHandler:
yield {
"event": "message",
"data": {
"chunk": data.get("chunk")
"content": data.get("chunk")
}
}

View File

@@ -274,7 +274,7 @@ class StreamOutputCoordinator:
yield {
"event": "message",
"data": {
"chunk": final_chunk
"content": final_chunk
}
}

View File

@@ -272,7 +272,7 @@ class WorkflowExecutor:
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
if event_type == "node_chunk":
async for msg_event in self.event_handler.handle_node_chunk_event(data):
full_content += msg_event["data"]["chunk"]
full_content += msg_event["data"]["content"]
yield msg_event
elif event_type == "node_error":
@@ -295,12 +295,12 @@ class WorkflowExecutor:
self.graph,
self.execution_context.checkpoint_config
):
full_content += msg_event["data"]['chunk']
full_content += msg_event["data"]['content']
yield msg_event
# Flush any remaining chunks
async for msg_event in self.stream_coordinator.flush_remaining_chunk(self.variable_pool):
full_content += msg_event["data"]['chunk']
full_content += msg_event["data"]['content']
yield msg_event
result = graph.get_state(self.execution_context.checkpoint_config).values

View File

@@ -211,3 +211,46 @@ def get_total_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int
except Exception as e:
db_logger.error(f"Failed to query total knowledge base count: workspace_id={workspace_id} - {str(e)}")
raise
def get_user_kb_chunk_num_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
    """
    Sum ``chunk_num`` over the workspace's user knowledge bases.

    Only ``Knowledge`` rows with the given workspace_id, ``status == 1`` and
    ``permission_id == "Memory"`` are included.

    Args:
        db: Database session used for the query.
        workspace_id: Workspace whose knowledge bases are summed.

    Returns:
        The summed chunk_num, or 0 when the query yields no value.

    Raises:
        Exception: Any query error is logged and re-raised unchanged.
    """
    db_logger.debug(f"Query user KB chunk_num by workspace_id: workspace_id={workspace_id}")
    try:
        from sqlalchemy import func

        chunk_sum_query = db.query(func.sum(Knowledge.chunk_num)).filter(
            Knowledge.workspace_id == workspace_id,
            Knowledge.status == 1,
            Knowledge.permission_id == "Memory",
        )
        summed = chunk_sum_query.scalar()
        # SUM over zero rows comes back as None; report that as 0.
        total = 0 if summed is None else summed
        db_logger.info(f"User KB chunk_num query successful: workspace_id={workspace_id}, total={total}")
        return total
    except Exception as e:
        db_logger.error(f"Failed to query user KB chunk_num: workspace_id={workspace_id} - {str(e)}")
        raise
def get_non_user_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
    """
    Count the workspace's knowledge bases that are not user knowledge bases.

    Only ``Knowledge`` rows with the given workspace_id, ``status == 1`` and
    ``permission_id != "Memory"`` are counted.

    Args:
        db: Database session used for the query.
        workspace_id: Workspace whose knowledge bases are counted.

    Returns:
        The number of matching rows.

    Raises:
        Exception: Any query error is logged and re-raised unchanged.
    """
    db_logger.debug(f"Query non-user KB count by workspace_id: workspace_id={workspace_id}")
    try:
        non_user_kb_query = db.query(Knowledge).filter(
            Knowledge.workspace_id == workspace_id,
            Knowledge.status == 1,
            Knowledge.permission_id != "Memory",
        )
        count = non_user_kb_query.count()
        db_logger.info(f"Non-user KB count query successful: workspace_id={workspace_id}, count={count}")
        return count
    except Exception as e:
        db_logger.error(f"Failed to query non-user KB count: workspace_id={workspace_id} - {str(e)}")
        raise

View File

@@ -46,6 +46,7 @@ class ChunkUpdate(BaseModel):
class ChunkRetrieve(BaseModel):
query: str
kb_ids: list[uuid.UUID]
file_names_filter: list[str] | None = Field(None)
similarity_threshold: float | None = Field(None)
vector_similarity_weight: float | None = Field(None)
top_k: int | None = Field(None)

View File

@@ -36,7 +36,7 @@ from app.core.memory.agent.utils.messages_tools import (
)
from app.core.memory.agent.utils.type_classifier import status_typle
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
from app.db import get_db_context
from app.models.knowledge_model import Knowledge, KnowledgeType
@@ -890,36 +890,36 @@ class MemoryAgentService:
return result
async def get_hot_memory_tags_by_user(
async def get_interest_distribution_by_user(
self,
end_user_id: Optional[str] = None,
limit: int = 20
limit: int = 5,
language: str = "zh"
) -> List[Dict[str, Any]]:
"""
获取指定用户的热门记忆标签
获取指定用户的兴趣分布标签
与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习等),
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
参数:
- end_user_id: 用户ID可选对应Neo4j中的end_user_id字段
- end_user_id: 用户ID必填)
- limit: 返回标签数量限制
- language: 输出语言("zh" 中文, "en" 英文)
返回格式:
[
{"name": "标签", "frequency": 频次},
{"name": "兴趣活动", "frequency": 频次},
...
]
注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译
"""
try:
# by_user=False 表示按 end_user_id 查询在Neo4j中end_user_id就是用户维度
tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False)
payload = []
for tag, freq in tags:
payload.append({"name": tag, "frequency": freq})
return payload
tags = await get_interest_distribution(end_user_id, limit=limit, by_user=False, language=language)
return [{"name": tag, "frequency": freq} for tag, freq in tags]
except Exception as e:
logger.error(f"热门记忆标签查询失败: {e}")
raise Exception(f"热门记忆标签查询失败: {e}")
logger.error(f"兴趣分布标签查询失败: {e}")
raise Exception(f"兴趣分布标签查询失败: {e}")
async def get_user_profile(

View File

@@ -140,9 +140,11 @@ class MemoryAPIService:
try:
# Delegate to MemoryAgentService
# Convert string message to list[dict] format expected by MemoryAgentService
messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
result = await MemoryAgentService().write_memory(
end_user_id=end_user_id,
messages=message,
messages=messages,
config_id=config_id,
db=self.db,
storage_type=storage_type,
@@ -151,9 +153,18 @@ class MemoryAPIService:
logger.info(f"Memory write successful for end_user: {end_user_id}")
# result may be a string "success" or a dict with a "status" key
# Preserve the full dict so callers don't silently lose extra fields
# (e.g. error codes, metadata) returned by MemoryAgentService.
if isinstance(result, dict):
return {
**result,
"status": result.get("status", "unknown"),
"end_user_id": end_user_id,
}
return {
"status": "success" if result == "success" else result,
"end_user_id": end_user_id
"status": result if isinstance(result, str) else "success",
"end_user_id": end_user_id,
}
except ConfigurationError as e:

View File

@@ -390,19 +390,59 @@ def get_rag_total_kb(
current_user: User
) -> int:
"""
根据当前用户所在的workspace_id查询konwledges表所有不同id的数量
根据当前用户所在的workspace_id查询konwledges表中排除用户知识库permission_id!='Memory'的数量
"""
workspace_id = current_user.current_workspace_id
business_logger.info(f"获取RAG总知识库数: workspace_id={workspace_id}, 操作者: {current_user.username}")
business_logger.info(f"获取RAG总知识库数(排除用户知识库): workspace_id={workspace_id}, 操作者: {current_user.username}")
try:
total_kb = knowledge_repository.get_total_kb_count_by_workspace(db, workspace_id)
total_kb = knowledge_repository.get_non_user_kb_count_by_workspace(db, workspace_id)
business_logger.info(f"成功获取RAG总知识库数: {total_kb}")
return total_kb
except Exception as e:
business_logger.error(f"获取RAG总知识库数失败: workspace_id={workspace_id} - {str(e)}")
raise
def get_rag_user_kb_total_chunk(
    db: Session,
    current_user: User
) -> int:
    """
    Total chunk count of all user knowledge bases in the current user's workspace.

    End users are found by joining EndUser to App on the workspace. The result
    is the sum of ``Document.chunk_num`` over documents whose file_name is
    ``<end_user_id>.txt``. According to the original note, this keeps the figure
    consistent with the /end_users endpoint.

    Args:
        db: Database session used for both queries.
        current_user: User whose ``current_workspace_id`` selects the workspace.

    Returns:
        The summed chunk count as an int; 0 when the workspace has no end
        users or the sum is empty.

    Raises:
        Exception: Any error is logged and re-raised unchanged.
    """
    workspace_id = current_user.current_workspace_id
    business_logger.info(f"获取用户知识库总chunk数(documents表): workspace_id={workspace_id}, 操作者: {current_user.username}")

    try:
        from app.models.document_model import Document
        from app.models.end_user_model import EndUser
        from app.models.app_model import App
        from sqlalchemy import func

        # All end-user ids that belong to an app of this workspace.
        end_user_rows = (
            db.query(EndUser.id)
            .join(App, EndUser.app_id == App.id)
            .filter(App.workspace_id == workspace_id)
            .all()
        )
        if not end_user_rows:
            return 0

        # Each end user's documents are looked up by the file name "<id>.txt".
        file_names = [str(eid) + ".txt" for (eid,) in end_user_rows]

        chunk_sum = (
            db.query(func.sum(Document.chunk_num))
            .filter(Document.file_name.in_(file_names))
            .scalar()
        )
        total_chunk = int(chunk_sum or 0)
        business_logger.info(f"成功获取用户知识库总chunk数: {total_chunk}")
        return total_chunk
    except Exception as e:
        business_logger.error(f"获取用户知识库总chunk数失败: workspace_id={workspace_id} - {str(e)}")
        raise
def get_current_user_total_chunk(
end_user_id: str,
db: Session,

View File

@@ -13,6 +13,7 @@ from sqlalchemy.orm import Session
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.workflow.adapters.registry import PlatformAdapterRegistry
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
from app.core.workflow.nodes.enums import NodeType
from app.core.workflow.validator import validate_workflow_config
from app.db import get_db
@@ -23,7 +24,7 @@ from app.repositories.workflow_repository import (
WorkflowExecutionRepository,
WorkflowNodeExecutionRepository
)
from app.schemas import DraftRunRequest
from app.schemas import DraftRunRequest, FileInput
from app.services.conversation_service import ConversationService
from app.services.multi_agent_service import convert_uuids_to_str
from app.services.multimodal_service import MultimodalService
@@ -445,6 +446,91 @@ class WorkflowService:
"success_rate": completed / total if total > 0 else 0
}
async def _handle_file_input(self, files: list[FileInput]):
    """
    Convert request file inputs into the dict form placed in the workflow input.

    Args:
        files: File inputs from the request payload; may be empty or None.

    Returns:
        One dict per file with its ``type``, the URL resolved through
        ``self.multimodal_service.get_file_url`` and the marker
        ``"__file": True``. An empty list when `files` is falsy.
    """
    if not files:
        return []

    # URLs are awaited one at a time, in input order.
    return [
        {
            "type": item.type,
            "url": await self.multimodal_service.get_file_url(item),
            "__file": True,
        }
        for item in files
    ]
@staticmethod
def _map_public_event(event: dict) -> dict | None:
"""
Map internal workflow events to public-facing event formats.
Purpose:
- Hide internal execution details
- Expose a stable and simplified public event schema
- Filter out non-public events
- Maintain backward compatibility when possible
Args:
event (dict): Internal event object, e.g.:
{
"event": "workflow_start",
"data": {...}
}
Returns:
dict | None:
- Returns the mapped public event
- Returns None if the event should not be exposed
"""
event_type = event.get("event")
payload = event.get("data")
match event_type:
case "workflow_start":
return {
"event": "start",
"data": {
"conversation_id": payload.get("conversation_id"),
}
}
case "workflow_end":
return {
"event": "end",
"data": {
"elapsed_time": payload.get("elapsed_time"),
"message_length": len(payload.get("output", "")),
"error": payload.get("error", "")
}
}
case "node_start" | "node_end" | "node_error" | "cycle_item":
return None
case _:
return event
def _emit(self, public: bool, internal_event: dict):
    """
    Single entry point that decides which form of an event is emitted.

    Args:
        public (bool): True to emit the public mapping of the event,
            False to emit the raw internal event.
        internal_event (dict): The original internal event object.

    Returns:
        dict | None: The internal event itself when `public` is False;
        otherwise the result of ``_map_public_event``, which is None for
        events that are filtered out.
    """
    if not public:
        return internal_event
    return self._map_public_event(internal_event)
# ==================== 工作流执行 ====================
async def run(
@@ -479,10 +565,11 @@ class WorkflowService:
message=f"工作流配置不存在: app_id={app_id}"
)
input_data = {"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id,
"files": [file.model_dump(mode='json') for file in payload.files]
}
input_data = {
"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id,
"files": [file.model_dump(mode='json') for file in payload.files]
}
# 转换 conversation_id 为 UUID
conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None
@@ -506,22 +593,8 @@ class WorkflowService:
"execution_config": config.execution_config
}
# 4. 获取工作空间 ID从 app 获取)
# 5. 执行工作流
from app.core.workflow.executor import execute_workflow
try:
files = []
if payload.files:
for file in payload.files:
files.append(
{
"type": file.type,
"url": await self.multimodal_service.get_file_url(file),
"__file": True
}
)
files = await self._handle_file_input(payload.files)
input_data["files"] = files
# 更新状态为运行中
self.update_execution_status(execution.execution_id, "running")
@@ -601,42 +674,6 @@ class WorkflowService:
message=f"工作流执行失败: {str(e)}"
)
@staticmethod
def _map_public_event(event: dict) -> dict | None:
event_type = event.get("event")
payload = event.get("data")
match event_type:
case "workflow_start":
return {
"event": "start",
"data": {
"conversation_id": payload.get("conversation_id"),
}
}
case "workflow_end":
return {
"event": "end",
"data": {
"elapsed_time": payload.get("elapsed_time"),
"message_length": len(payload.get("output", "")),
"error": payload.get("error", "")
}
}
case "node_start" | "node_end" | "node_error" | "cycle_item":
return None
case _:
return event
def _emit(self, public: bool, internal_event: dict):
"""
decide
"""
if public:
mapped = self._map_public_event(internal_event)
else:
mapped = internal_event
return mapped
async def run_stream(
self,
app_id: uuid.UUID,
@@ -671,10 +708,11 @@ class WorkflowService:
message=f"工作流配置不存在: app_id={app_id}"
)
input_data = {"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id,
"files": [file.model_dump(mode='json') for file in payload.files]
}
input_data = {
"message": payload.message, "variables": payload.variables,
"conversation_id": payload.conversation_id,
"files": [file.model_dump(mode='json') for file in payload.files]
}
# 转换 conversation_id 为 UUID
conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None
@@ -699,16 +737,7 @@ class WorkflowService:
}
try:
files = []
if payload.files:
for file in payload.files:
files.append(
{
"type": file.type,
"url": await self.multimodal_service.get_file_url(file),
"__file": True
}
)
files = await self._handle_file_input(payload.files)
input_data["files"] = files
self.update_execution_status(execution.execution_id, "running")
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
@@ -723,7 +752,6 @@ class WorkflowService:
input_data["conv_messages"] = last_state.get("messages") or []
break
init_message_length = len(input_data.get("conv_messages", []))
from app.core.workflow.executor import execute_workflow_stream
async for event in execute_workflow_stream(
workflow_config=workflow_config_dict,
@@ -789,37 +817,6 @@ class WorkflowService:
return node.get("config", {}).get("variables", [])
raise BusinessException("workflow config error - start node not found")
def _clean_event_for_json(self, event: dict[str, Any]) -> dict[str, Any]:
    """Clean an event payload by replacing objects that cannot be serialized.

    Args:
        event: Raw event data.

    Returns:
        The event data with every value converted to a serializable form.
    """
    from langchain_core.messages import BaseMessage

    passthrough_types = (str, int, float, bool, type(None))

    def clean_value(value):
        """Recursively convert a single value."""
        # Message objects are flattened into a plain dict.
        if isinstance(value, BaseMessage):
            return {
                "type": value.__class__.__name__,
                "content": value.content,
            }
        if isinstance(value, dict):
            return {key: clean_value(item) for key, item in value.items()}
        if isinstance(value, list):
            return [clean_value(entry) for entry in value]
        if isinstance(value, passthrough_types):
            return value
        # Any other object is converted to its string form.
        return str(value)

    return clean_value(event)
# ==================== 依赖注入函数 ====================

View File

@@ -257,7 +257,7 @@ def parse_document(file_path: str, document_id: uuid.UUID):
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n"
return result
try:
def sync_task():
trio.run(
lambda: _run(
row=task,
@@ -272,6 +272,10 @@ def parse_document(file_path: str, document_id: uuid.UUID):
with_community=with_community,
)
)
try:
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(sync_task)
future.result() # Blocks until the task completes
except Exception as e:
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n"
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)"

View File

@@ -139,7 +139,7 @@ SMTP_USER=
SMTP_PASSWORD=
# 本体类型融合配置 (记得写入env_example)
GENERAL_ONTOLOGY_FILES=app/core/memory/ontology_services/General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔
GENERAL_ONTOLOGY_FILES=api/app/core/memory/ontology_services/General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔
ENABLE_GENERAL_ONTOLOGY_TYPES=true # 总开关,控制是否启用通用本体类型融合功能(false = 不使用任何本体类型指导)
MAX_ONTOLOGY_TYPES_IN_PROMPT=100 # 限制传给 LLM 的类型数量,防止 Prompt 过长
CORE_GENERAL_TYPES=Person,Organization,Place,Event,Work,Concept # 定义核心类型列表,这些类型会优先包含在合并结果中

View File

@@ -456,6 +456,7 @@ export const en = {
logoTip: `Supported image formats: JPG, PNG \n Suggested size: square ratio \n Maximum size: ≤ 2MB`,
imageSquareRequired: 'Please upload a square image',
nameInvalid: 'Name cannot start or end with a space',
notAllSpaces: 'Cannot be all spaces',
},
model: {
searchPlaceholder: 'search model…',
@@ -1782,6 +1783,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
mcp: 'MCP Services',
inner: 'Built-in Tools',
custom: 'Custom Tools',
market: 'Tool Market',
mcpSearchPlaceholder: 'Search MCP Services...',
innerSearchPlaceholder: 'Search Tools...',
customSearchPlaceholder: 'Search Custom Tools...',
@@ -1955,7 +1957,9 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
path: 'Path',
viewDetail: 'View Details',
textLink: 'Test Connection',
noResult: 'Processing results will be displayed here'
noResult: 'Processing results will be displayed here',
serverUrlInvalid: 'Must start with http:// or https://, and cannot have leading or trailing spaces',
requestHeaderKeyInvalid: 'Only English letters, numbers, hyphens (-), and underscores (_) are allowed, and cannot start or end with a hyphen or underscore',
},
workflow: {
coreNode: 'Core Nodes',

View File

@@ -1036,6 +1036,7 @@ export const zh = {
logoTip: `支持图片格式JPG、PNG\n 尺寸:正方形比例 \n 文件大小限制:≤ 2MB`,
imageSquareRequired: '请上传正方形比例图片',
nameInvalid: '不能是空格开头或结尾',
notAllSpaces: '不能是纯空格',
},
model: {
searchPlaceholder: '搜索模型…',
@@ -1779,6 +1780,7 @@ export const zh = {
mcp: 'MCP 服务',
inner: '内置工具',
custom: '自定义工具',
market: '工具市场',
mcpSearchPlaceholder: '搜索MCP服务...',
innerSearchPlaceholder: '搜索工具...',
customSearchPlaceholder: '搜索自定义工具...',
@@ -1952,7 +1954,9 @@ export const zh = {
path: '路径',
viewDetail: '查看详情',
textLink: '测试连接',
noResult: '处理结果将显示在这里'
noResult: '处理结果将显示在这里',
serverUrlInvalid: '必须以 http:// 或 https:// 开头,且不能有前后空格',
requestHeaderKeyInvalid: '只支持英文、数字、连字符(-)、下划线(_),不能以连字符或下划线开头结尾',
},
workflow: {
coreNode: '核心节点',

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 16:58:03
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-03 13:46:22
* @Last Modified time: 2026-03-04 12:10:44
*/
/**
* Conversation Page
@@ -267,8 +267,8 @@ const Conversation: FC = () => {
currentConversationId = newId
break
case 'message':
const { content, chunk, conversation_id: curId } = item.data as { content: string; chunk: string; conversation_id: string; }
updateAssistantMessage(content ?? chunk)
const { content, conversation_id: curId } = item.data as { content: string; conversation_id: string; }
updateAssistantMessage(content)
if (curId) {
currentConversationId = curId;

View File

@@ -185,6 +185,7 @@ const OntologyClassExtractModal = forwardRef<OntologyClassExtractModalRef, Ontol
rules={[
{ required: true, message: t('common.pleaseEnter') },
{ max: 2000 },
{ pattern: /^(?!\s*$).+$/, message: t('common.notAllSpaces') },
]}
>
<Input.TextArea placeholder={t('ontology.scenarioPlaceholder')} />

View File

@@ -0,0 +1,315 @@
import React, { useState, useRef, type ReactNode } from 'react';
import { Input, Button, Spin, App } from 'antd';
import { SearchOutlined, SettingOutlined, GlobalOutlined, SyncOutlined } from '@ant-design/icons';
import { useTranslation } from 'react-i18next';
import MarketConfigModal, { type MarketConfigModalRef } from './components/MarketConfigModal';
// A marketplace (source) listed in the left-hand sidebar of the Market page.
interface MarketSource {
// Unique identifier; used as React key, selection value and mcpCache key.
id: string;
// Display name shown in the sidebar and the detail header.
name: string;
// Matches MarketCategory.id; decides which sidebar group lists this source.
category: string;
// Rendered as plain text (the seeded data uses emoji).
icon: string;
// External address opened by the detail header button.
url: string;
// Short description shown under the name in the detail header.
desc: string;
// API key stored by handleConnect after the config modal is saved.
apiKey: string;
// Set to true by handleConnect; gates the refresh button and the status dot.
connected: boolean;
// Number shown in the sidebar badge next to the source name.
mcpCount: number;
}
// One MCP service card shown in the detail grid of a selected market source.
interface MarketMcp {
// Unique identifier; used as React key for the card.
id: string;
// Card title; also matched against the search keyword.
name: string;
// Shown as "@ provider" under the title when non-empty.
provider: string;
// Badge text; the value 'Hosted' gets the blue badge style, anything else gray.
type: string;
// Card description; also matched against the search keyword.
desc: string;
// Optional statistic rendered in the card footer when present.
downloads?: string;
// Optional statistic rendered in the card footer when present.
stars?: string;
// Rendered as plain text (the seeded data uses emoji).
icon: string;
// NOTE(review): not read anywhere in this file; every seeded value is {} — confirm intended shape.
configTemplate: any;
}
// A sidebar group; market sources are listed under the category whose id they reference.
interface MarketCategory {
// Compared with MarketSource.category when filtering the sidebar list.
id: string;
// Group heading text.
name: string;
// Rendered as plain text before the heading (the seeded data uses emoji).
icon: string;
}
/**
 * Tool market page.
 *
 * Left column: market sources grouped by category. Right column: the detail
 * view of the selected source with its (currently mocked) MCP service list.
 * The `getStatusTag` prop is accepted for parity with the sibling tabs but is
 * not used by this component.
 */
const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () => {
  const { t } = useTranslation();
  const { message } = App.useApp();
  const [loading, setLoading] = useState(false);
  const [selectedSource, setSelectedSource] = useState<string | null>(null);
  const marketConfigModalRef = useRef<MarketConfigModalRef>(null);
  const [marketSources, setMarketSources] = useState<MarketSource[]>([
    { id: 'smithery', name: 'Smithery', category: 'official', icon: '🔧', url: 'https://mcp.smithery.ai', desc: '官方 MCP 服务市场,提供丰富的 MCP 服务', apiKey: '', connected: false, mcpCount: 2847 },
    { id: 'mcpmarket', name: 'MCP Market', category: 'official', icon: '🏪', url: 'https://mcpmarket.com', desc: '综合性 MCP 市场平台', apiKey: '', connected: false, mcpCount: 1523 },
    { id: 'glama', name: 'Glama.ai MCP', category: 'official', icon: '✨', url: 'https://glama.ai/mcp', desc: 'Glama AI 提供的 MCP 服务集合', apiKey: '', connected: false, mcpCount: 892 },
    { id: 'github-mcp', name: 'modelcontextprotocol/servers', category: 'official', icon: '🐙', url: 'https://github.com/modelcontextprotocol/servers', desc: 'GitHub 官方 MCP 服务器仓库', apiKey: '', connected: true, mcpCount: 156 },
    { id: 'aliyun-bailian', name: '阿里云百炼 MCP', category: 'china-cloud', icon: '☁️', url: 'https://bailian.console.aliyun.com/mcp', desc: '阿里云百炼平台 MCP 市场', apiKey: '', connected: false, mcpCount: 423 },
    { id: 'modelscope', name: '魔搭社区 MCP', category: 'china-cloud', icon: '🎭', url: 'https://modelscope.cn/mcp', desc: '阿里达摩院魔搭社区 MCP 市场', apiKey: '', connected: false, mcpCount: 312 },
  ]);
  const [categories] = useState<MarketCategory[]>([
    { id: 'official', name: '官方/综合', icon: '🌐' },
    { id: 'china-cloud', name: '国内云', icon: '☁️' },
    { id: 'community', name: '社区/垂直', icon: '👥' }
  ]);
  // MCP lists keyed by market source id; only sources present here show a grid.
  const [mcpCache, setMcpCache] = useState<Record<string, MarketMcp[]>>({
    'github-mcp': [
      { id: 'gh-1', name: 'Fetch', provider: 'modelcontextprotocol', type: 'Hosted', desc: '使用浏览器模拟大型语言模型检索和处理网页内容', downloads: '203.7m', stars: '308.2k', icon: '🌐', configTemplate: {} },
      { id: 'gh-2', name: 'Filesystem', provider: 'modelcontextprotocol', type: 'Local', desc: '安全的文件系统操作,支持读写文件和目录管理', downloads: '156.2m', stars: '245.1k', icon: '📁', configTemplate: {} },
      { id: 'gh-3', name: 'GitHub', provider: 'modelcontextprotocol', type: 'Hosted', desc: 'GitHub API 集成支持仓库、Issue、PR 等操作', downloads: '89.4m', stars: '178.3k', icon: '🐙', configTemplate: {} },
    ]
  });
  const [searchKeyword, setSearchKeyword] = useState('');

  /** Select the market source whose detail is shown on the right. */
  const handleSelectSource = (sourceId: string) => {
    setSelectedSource(sourceId);
  };

  /** Simulated refresh: shows the spinner briefly, then reports success. */
  const handleRefresh = (sourceId: string) => {
    setLoading(true);
    setTimeout(() => {
      // Simulated data refresh
      const source = marketSources.find(s => s.id === sourceId);
      if (source) {
        message.success(`${source.name} 列表已刷新`);
      }
      setLoading(false);
    }, 600);
  };

  /** Open the configuration modal for the given source, if it exists. */
  const handleOpenConfig = (sourceId: string) => {
    const source = marketSources.find(s => s.id === sourceId);
    if (source) {
      marketConfigModalRef.current?.handleOpen(source);
    }
  };

  /**
   * Mark a source as connected with the given API key and, after a simulated
   * delay, populate its MCP list with mock data if none is cached yet.
   */
  const handleConnect = (sourceId: string, apiKey: string) => {
    // Resolve the source up front so an unknown id never produces an
    // "已连接 undefined" toast.
    const source = marketSources.find(s => s.id === sourceId);
    if (!source) return;

    // Update the market source state
    setMarketSources(prev => prev.map(item => {
      if (item.id === sourceId) {
        return {
          ...item,
          apiKey,
          connected: true
        };
      }
      return item;
    }));

    // Simulate fetching the MCP list
    setTimeout(() => {
      // Check the cache inside the functional update: the `mcpCache` value
      // captured by this closure may be stale by the time the timer fires.
      setMcpCache(prev => {
        if (prev[sourceId]) return prev;
        // Generate mock data
        const mockData: MarketMcp[] = [
          { id: `${sourceId}-1`, name: `${source.name} 服务 1`, provider: source.name, type: 'Hosted', desc: `来自 ${source.name} 的 MCP 服务`, downloads: '10.2m', stars: '23.4k', icon: '🔧', configTemplate: {} },
          { id: `${sourceId}-2`, name: `${source.name} 服务 2`, provider: source.name, type: 'Local', desc: `来自 ${source.name} 的本地 MCP 服务`, downloads: '8.5m', stars: '18.7k', icon: '⚙️', configTemplate: {} }
        ];
        return {
          ...prev,
          [sourceId]: mockData
        };
      });
      message.success(`已连接 ${source.name}`);
    }, 800);
  };

  /** Render the right-hand panel: placeholder, or the selected source detail. */
  const renderSourceDetail = () => {
    if (!selectedSource) {
      return (
        <div className="rb:flex rb:flex-col rb:items-center rb:justify-center rb:h-full rb:text-center">
          <div className="rb:text-6xl rb:mb-4">🏪</div>
          <h3 className="rb:text-lg rb:font-semibold rb:text-gray-900 rb:mb-2"> MCP </h3>
          <p className="rb:text-sm rb:text-gray-600 rb:max-w-md"> MCP </p>
        </div>
      );
    }
    const source = marketSources.find(s => s.id === selectedSource);
    if (!source) return null;
    const mcpList = mcpCache[selectedSource] || [];
    // Lower-case the keyword once instead of twice per list item.
    const keyword = searchKeyword.toLowerCase();
    const filteredList = mcpList.filter(mcp =>
      mcp.name.toLowerCase().includes(keyword) ||
      mcp.desc.toLowerCase().includes(keyword)
    );
    return (
      <>
        <div className="rb:flex rb:justify-between rb:items-start rb:pb-6 rb:border-b rb:border-gray-200 rb:mb-6">
          <div className="rb:flex rb:gap-4">
            <div className="rb:text-5xl rb:w-16 rb:h-16 rb:flex rb:items-center rb:justify-center rb:bg-gray-50 rb:rounded-xl rb:flex-shrink-0">
              {source.icon}
            </div>
            <div className="rb:flex-1">
              <h2 className="rb:text-xl rb:font-semibold rb:text-gray-900 rb:mb-2">{source.name}</h2>
              <p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed">{source.desc}</p>
            </div>
          </div>
          <div className="rb:flex rb:gap-3">
            <Button icon={<SettingOutlined />} onClick={() => handleOpenConfig(selectedSource)}>
            </Button>
            {/* noopener/noreferrer: the external market page must not get a window.opener handle */}
            <Button type="primary" icon={<GlobalOutlined />} onClick={() => window.open(source.url, '_blank', 'noopener,noreferrer')}>
            </Button>
          </div>
        </div>
        <div className="rb:mt-6">
          <div className="rb:flex rb:justify-between rb:items-center rb:mb-5">
            <h3 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:m-0">
              MCP <span className="rb:text-gray-600 rb:font-normal">({mcpList.length})</span>
            </h3>
            <div className="rb:flex rb:gap-3 rb:items-center">
              {source.connected && (
                <Button size="small" icon={<SyncOutlined />} onClick={() => handleRefresh(selectedSource)}>
                </Button>
              )}
              {mcpList.length > 0 && (
                <Input
                  prefix={<SearchOutlined />}
                  placeholder="搜索服务..."
                  value={searchKeyword}
                  onChange={(e) => setSearchKeyword(e.target.value)}
                  style={{ width: 200 }}
                />
              )}
            </div>
          </div>
          {mcpList.length > 0 ? (
            <Spin spinning={loading}>
              <div className="rb:grid rb:grid-cols-1 md:rb:grid-cols-2 lg:rb:grid-cols-3 rb:gap-4">
                {filteredList.map(mcp => (
                  <div
                    key={mcp.id}
                    className="rb:bg-white rb:border rb:border-gray-200 rb:rounded-lg rb:p-4 rb:transition-all rb:duration-200 hover:rb:shadow-lg hover:rb:border-gray-300"
                  >
                    <div className="rb:flex rb:justify-between rb:items-center rb:mb-3">
                      <div className="rb:text-3xl rb:w-12 rb:h-12 rb:flex rb:items-center rb:justify-center rb:bg-gray-50 rb:rounded-lg">
                        {mcp.icon}
                      </div>
                      <span className={`rb:px-2 rb:py-1 rb:rounded rb:text-xs rb:font-medium ${
                        mcp.type === 'Hosted'
                          ? 'rb:bg-blue-50 rb:text-blue-700'
                          : 'rb:bg-gray-100 rb:text-gray-600'
                      }`}>
                        {mcp.type}
                      </span>
                    </div>
                    <h3 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:mb-1">{mcp.name}</h3>
                    {mcp.provider && (
                      <div className="rb:mb-2">
                        <span className="rb:text-xs rb:text-gray-500">@ {mcp.provider}</span>
                      </div>
                    )}
                    <p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed rb:mb-3 rb:min-h-[42px]">{mcp.desc}</p>
                    <div className="rb:flex rb:gap-4 rb:mb-3 rb:pt-3 rb:border-t rb:border-gray-100">
                      {mcp.downloads && (
                        <span className="rb:flex rb:items-center rb:gap-1 rb:text-xs rb:text-gray-500">
                          <GlobalOutlined /> {mcp.downloads}
                        </span>
                      )}
                      {mcp.stars && (
                        <span className="rb:flex rb:items-center rb:gap-1 rb:text-xs rb:text-gray-500">
                          {mcp.stars}
                        </span>
                      )}
                    </div>
                    <div className="rb:flex rb:justify-end">
                      <Button type="primary" size="small">
                        +
                      </Button>
                    </div>
                  </div>
                ))}
              </div>
            </Spin>
          ) : (
            <div className="rb:flex rb:flex-col rb:items-center rb:justify-center rb:py-16 rb:text-center">
              <div className="rb:text-6xl rb:mb-4">{source.connected ? '📭' : '🔌'}</div>
              <h4 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:mb-2">
                {source.connected ? '暂无可用的 MCP 服务' : '尚未连接此市场'}
              </h4>
              <p className="rb:text-sm rb:text-gray-600 rb:mb-4">
                {source.connected ? '该市场暂时没有可用的服务' : '点击右上角"配置"按钮设置连接信息'}
              </p>
              {!source.connected && (
                <Button type="primary" onClick={() => handleOpenConfig(selectedSource)}>
                </Button>
              )}
            </div>
          )}
        </div>
      </>
    );
  };

  return (
    <div className="rb:flex rb:gap-4 rb:h-[calc(100vh-178px)]">
      {/* Left: market source list */}
      <div className="rb:w-70 rb:bg-white rb:rounded-lg rb:border rb:border-gray-200 rb:overflow-y-auto rb:flex-shrink-0">
        <div className="rb:p-4 rb:border-b rb:border-gray-200">
          <span className="rb:text-base rb:font-semibold rb:text-gray-900">MCP </span>
        </div>
        {categories.map(cat => (
          <div key={cat.id} className="rb:py-3 rb:border-b rb:border-gray-100 last:rb:border-b-0">
            <div className="rb:flex rb:items-center rb:gap-2 rb:px-4 rb:py-2 rb:text-xs rb:font-medium rb:text-gray-500 rb:uppercase">
              <span className="rb:text-sm">{cat.icon}</span>
              <span>{cat.name}</span>
            </div>
            <div className="rb:px-2 rb:py-1">
              {marketSources
                .filter(s => s.category === cat.id)
                .map(source => (
                  <div
                    key={source.id}
                    className={`rb:flex rb:items-center rb:gap-2 rb:px-3 rb:py-2.5 rb:rounded-md rb:cursor-pointer rb:transition-all rb:relative ${
                      selectedSource === source.id
                        ? 'rb:bg-blue-50 rb:text-blue-600'
                        : 'hover:rb:bg-gray-50'
                    }`}
                    onClick={() => handleSelectSource(source.id)}
                  >
                    <span className="rb:text-lg rb:flex-shrink-0">{source.icon}</span>
                    <span className="rb:flex-1 rb:text-sm rb:font-medium rb:overflow-hidden rb:text-ellipsis rb:whitespace-nowrap">
                      {source.name}
                    </span>
                    <span className="rb:text-xs rb:text-gray-500 rb:px-1.5 rb:py-0.5 rb:bg-gray-100 rb:rounded-full">
                      {source.mcpCount}
                    </span>
                    {source.connected && (
                      <span className="rb:text-green-500 rb:text-[8px] rb:ml-1"></span>
                    )}
                  </div>
                ))}
            </div>
          </div>
        ))}
      </div>
      {/* Right: content area */}
      <div className="rb:flex-1 rb:bg-white rb:rounded-lg rb:border rb:border-gray-200 rb:overflow-hidden">
        <div className="rb:h-full rb:overflow-y-auto rb:p-6">
          {renderSourceDetail()}
        </div>
      </div>
      {/* Configuration modal */}
      <MarketConfigModal
        ref={marketConfigModalRef}
        onConnect={handleConnect}
      />
    </div>
  );
};
export default Market;

View File

@@ -6,6 +6,7 @@ import type { CustomToolItem, CustomToolModalRef, ToolItem } from '../types'
import RbModal from '@/components/RbModal';
import { parseSchema, addTool, updateTool } from '@/api/tools';
import Table from '@/components/Table';
import { stringRegExp } from '@/utils/validator';
const FormItem = Form.Item;
interface CustomToolModalProps {
@@ -134,7 +135,11 @@ const CustomToolModal = forwardRef<CustomToolModalRef, CustomToolModalProps>(({
<Form.Item
name="name"
label={t('tool.name')}
rules={[{ required: true, message: t('common.enterNamePlaceholder') }]}
rules={[
{ required: true, message: t('tool.enterNamePlaceholder') },
{ max: 50 },
{ pattern: stringRegExp, message: t('common.nameInvalid') },
]}
>
<Input placeholder={t('tool.enterNamePlaceholder')} />
</Form.Item>

View File

@@ -0,0 +1,173 @@
import { forwardRef, useImperativeHandle, useState } from 'react';
import { Form, Input, Button, App, Space } from 'antd';
import { useTranslation } from 'react-i18next';
import { CopyOutlined, EyeInvisibleOutlined, EyeOutlined } from '@ant-design/icons';
import RbModal from '@/components/RbModal';
const FormItem = Form.Item;
// The market source fields this modal reads from the object passed to handleOpen.
interface MarketSource {
// Passed back to onConnect when the form is saved.
id: string;
// Shown in the modal title, the header and the toast message.
name: string;
// Rendered as plain text in the header box.
icon: string;
// Pre-filled into the read-only address field and copied by the copy button.
url: string;
// Shown under the name in the header.
desc: string;
// Pre-filled into the API key field.
apiKey: string;
// Drives the connection status line at the bottom of the form.
connected: boolean;
}
// Props of MarketConfigModal.
interface MarketConfigModalProps {
// Called by handleSave with the current source id and the entered API key ('' when empty).
onConnect: (sourceId: string, apiKey: string) => void;
}
// Imperative handle exposed through the component ref (see useImperativeHandle).
export interface MarketConfigModalRef {
// Stores the source, pre-fills the form and shows the modal.
handleOpen: (source: MarketSource) => void;
// Hides the modal and resets the form and local state.
handleClose: () => void;
}
/**
 * Configuration modal for a single market source.
 *
 * Opened imperatively through the ref (`handleOpen(source)`); on save it
 * validates the form and, after a simulated delay, calls `onConnect` with the
 * source id and the entered API key. Renders nothing until a source is set.
 */
const MarketConfigModal = forwardRef<MarketConfigModalRef, MarketConfigModalProps>(({
  onConnect
}, ref) => {
  const { t } = useTranslation();
  const { message } = App.useApp();
  const [visible, setVisible] = useState(false);
  const [form] = Form.useForm();
  const [loading, setLoading] = useState(false);
  const [currentSource, setCurrentSource] = useState<MarketSource | null>(null);
  const [showApiKey, setShowApiKey] = useState(false);

  /** Hide the modal and reset the form plus all local state. */
  const handleClose = () => {
    setVisible(false);
    form.resetFields();
    setLoading(false);
    setCurrentSource(null);
    setShowApiKey(false);
  };

  /** Remember the source, pre-fill the form from it and show the modal. */
  const handleOpen = (source: MarketSource) => {
    setCurrentSource(source);
    form.setFieldsValue({
      url: source.url,
      apiKey: source.apiKey,
    });
    setVisible(true);
  };

  /** Validate the form, then report the connection to the parent. */
  const handleSave = () => {
    form
      .validateFields()
      .then((values) => {
        if (!currentSource) return;
        setLoading(true);
        // Simulated connection delay
        setTimeout(() => {
          onConnect(currentSource.id, values.apiKey || '');
          message.success(`正在连接 ${currentSource.name}...`);
          setLoading(false);
          handleClose();
        }, 500);
      })
      .catch((err) => {
        console.log('表单验证失败:', err);
      });
  };

  /** Copy the market address to the clipboard. */
  const handleCopyUrl = () => {
    if (!currentSource?.url) return;
    // navigator.clipboard is unavailable in non-secure contexts and writeText
    // can reject (e.g. permission denied); neither case may throw unhandled.
    navigator.clipboard?.writeText(currentSource.url)
      .then(() => {
        message.success(t('common.copySuccess'));
      })
      .catch((err) => {
        console.error('Failed to copy market URL:', err);
      });
  };

  useImperativeHandle(ref, () => ({
    handleOpen,
    handleClose
  }));

  if (!currentSource) return null;

  return (
    <RbModal
      title={`配置 ${currentSource.name}`}
      open={visible}
      onCancel={handleClose}
      okText="保存并连接"
      onOk={handleSave}
      confirmLoading={loading}
      width={600}
    >
      <div>
        {/* Market source header */}
        <div className="rb:flex rb:gap-4 rb:mb-6 rb:p-4 rb:bg-gray-50 rb:rounded-lg">
          <div className="rb:text-4xl rb:w-16 rb:h-16 rb:flex rb:items-center rb:justify-center rb:bg-white rb:rounded-lg rb:flex-shrink-0">
            {currentSource.icon}
          </div>
          <div className="rb:flex-1">
            <h3 className="rb:text-base rb:font-semibold rb:mb-1 rb:text-gray-900">{currentSource.name}</h3>
            <p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed">{currentSource.desc}</p>
          </div>
        </div>
        <Form
          form={form}
          layout="vertical"
        >
          {/* Market address.
              Form.Item only binds value/onChange to its direct child, so the
              named item must wrap the Input itself (noStyle), not Space.Compact. */}
          <FormItem label="市场地址">
            <Space.Compact style={{ width: '100%' }}>
              <FormItem name="url" noStyle>
                <Input
                  readOnly
                  placeholder="市场地址"
                />
              </FormItem>
              <Button
                icon={<CopyOutlined />}
                onClick={handleCopyUrl}
              >
              </Button>
            </Space.Compact>
          </FormItem>
          {/* API Key — same nesting so the typed key reaches validateFields() */}
          <FormItem
            label={
              <span>
                API Key <span className="rb:text-gray-400 rb:font-normal">()</span>
              </span>
            }
            extra="部分市场需要 API Key 才能获取完整的服务列表"
          >
            <Space.Compact style={{ width: '100%' }}>
              <FormItem name="apiKey" noStyle>
                <Input
                  type={showApiKey ? 'text' : 'password'}
                  placeholder="输入 API Key 以获取更多服务"
                  autoComplete="off"
                />
              </FormItem>
              <Button
                icon={showApiKey ? <EyeInvisibleOutlined /> : <EyeOutlined />}
                onClick={() => setShowApiKey(!showApiKey)}
              />
            </Space.Compact>
          </FormItem>
          {/* Connection status */}
          <div className="rb:flex rb:items-center rb:gap-2 rb:p-3 rb:bg-gray-50 rb:rounded rb:text-sm">
            <span className="rb:text-gray-600"></span>
            <span className={`rb:font-medium ${currentSource.connected ? 'rb:text-green-600' : 'rb:text-gray-400'}`}>
              {currentSource.connected ? '● 已连接' : '○ 未连接'}
            </span>
          </div>
        </Form>
      </div>
    </RbModal>
  );
});
export default MarketConfigModal;

View File

@@ -9,6 +9,7 @@ import RequestHeaderModal from './RequestHeaderModal';
import Table from '@/components/Table';
import { addTool, updateTool, testConnection } from '@/api/tools'
import type { McpServiceModalRef } from '../types'
import { stringRegExp } from '@/utils/validator';
const FormItem = Form.Item;
@@ -168,14 +169,22 @@ const McpServiceModal = forwardRef<McpServiceModalRef, McpServiceModalProps>(({
name={['config', "server_url"]}
label={t('tool.serviceEndpoint')}
extra={t('tool.serviceEndpointExtra')}
rules={[{ required: true, message: t('common.pleaseEnter') }]}
rules={[
{ required: true, message: t('common.pleaseEnter') },
{ max: 500 },
{ pattern: /^https?:\/\/\S+$/, message: t('tool.serverUrlInvalid') },
]}
>
<Input placeholder={t('tool.serviceEndpointPlaceholder')} />
</FormItem>
<Form.Item
name="name"
label={t('tool.name')}
rules={[{ required: true, message: t('common.pleaseEnter') }]}
rules={[
{ required: true, message: t('common.pleaseEnter') },
{ max: 50 },
{ pattern: stringRegExp, message: t('common.nameInvalid') },
]}
>
<Input placeholder={t('tool.namePlaceholder')} />
</Form.Item>
@@ -201,6 +210,7 @@ const McpServiceModal = forwardRef<McpServiceModalRef, McpServiceModalProps>(({
<FormItem
name="description"
label={t('tool.description')}
rules={[{ max: 500 }]}
>
<Input.TextArea rows={3} placeholder={t('common.inputPlaceholder', { title: t('tool.description') })}/>
</FormItem>

View File

@@ -4,6 +4,7 @@ import { useTranslation } from 'react-i18next';
import type { RequestHeader, RequestHeaderModalRef } from './McpServiceModal'
import RbModal from '@/components/RbModal'
import { stringRegExp } from '@/utils/validator';
const FormItem = Form.Item;
@@ -82,7 +83,11 @@ const RequestHeaderModal = forwardRef<RequestHeaderModalRef, RequestHeaderModalP
<FormItem
name="key"
label={t('tool.requestHeaderName')}
rules={[{ required: true, message: t('common.pleaseEnter') }]}
rules={[
{ required: true, message: t('common.pleaseEnter') },
{ pattern: /^[a-zA-Z0-9][a-zA-Z0-9\-_]*[a-zA-Z0-9]$|^[a-zA-Z0-9]$/, message: t('tool.requestHeaderKeyInvalid') },
{ max: 100 }
]}
>
<Input placeholder={t('common.enter')} />
</FormItem>
@@ -90,7 +95,11 @@ const RequestHeaderModal = forwardRef<RequestHeaderModalRef, RequestHeaderModalP
<FormItem
name="value"
label={t('tool.requestHeaderValue')}
rules={[{ required: true, message: t('common.pleaseEnter') }]}
rules={[
{ required: true, message: t('common.pleaseEnter') },
{ pattern: stringRegExp, message: t('common.nameInvalid') },
{ max: 2000 }
]}
>
<Input placeholder={t('common.enter',)} />
</FormItem>

View File

@@ -1,3 +1,11 @@
/*
* @Description:
* @Version: 0.0.1
* @Author: yujiangping
* @Date: 2026-01-05 17:22:23
* @LastEditors: yujiangping
* @LastEditTime: 2026-03-04 15:12:48
*/
import React, { useState } from 'react';
import { Tabs } from 'antd';
import { useTranslation } from 'react-i18next';
@@ -5,9 +13,10 @@ import { useTranslation } from 'react-i18next';
import Mcp from './Mcp';
import Inner from './Inner';
import Custom from './Custom';
import Market from './Market';
import Tag from '@/components/Tag'
const tabKeys = ['mcp', 'inner', 'custom']
const tabKeys = ['mcp', 'inner', 'custom', 'market']
const ToolManagement: React.FC = () => {
const { t } = useTranslation();
const [activeTab, setActiveTab] = useState('mcp');
@@ -45,6 +54,7 @@ const ToolManagement: React.FC = () => {
{activeTab === 'mcp' && <Mcp getStatusTag={getStatusTag} />}
{activeTab === 'inner' && <Inner getStatusTag={getStatusTag} />}
{activeTab === 'custom' && <Custom getStatusTag={getStatusTag} />}
{/* {activeTab === 'market' && <Market getStatusTag={getStatusTag} />} */}
</div>
);
};

View File

@@ -6,6 +6,7 @@ import {
getShortTerm,
} from '@/api/memory'
import Empty from '@/components/Empty'
import Markdown from '@/components/Markdown'
interface ShortTermItem {
retrieval: Array<{ query: string; retrieval: string[]; }>;
@@ -85,7 +86,9 @@ const ShortTermDetail: FC = () => {
))}
<div>
<div className="rb:font-medium rb:leading-5 rb:mb-1">{t('shortTermDetail.answer')}</div>
<div className="rb:bg-[#FFFFFF] rb:border rb:border-[#DFE4ED] rb:rounded-md rb:px-3 rb:py-2.5 rb:leading-5">{vo.answer}</div>
<div className="rb:bg-[#FFFFFF] rb:border rb:border-[#DFE4ED] rb:rounded-md rb:px-3 rb:py-2.5 rb:leading-5">
<Markdown content={vo.answer} />
</div>
</div>
</Space>
</div>
@@ -103,7 +106,9 @@ const ShortTermDetail: FC = () => {
: data.long_term?.map((vo, voIdx) => (
<div key={voIdx} className="rb:leading-5 rb:shadow-[inset_3px_0px_0px_0px_#155EEF] rb:bg-[#FBFDFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg rb:px-6 rb:py-3">
<div className="rb:mb-1 rb:font-medium rb:leading-5.5">{vo.query}</div>
<div className="rb:mt-1 rb:leading-5 rb:text-[#5B6167] rb:text-[12px]">{vo.retrieval}</div>
<div className="rb:mt-1 rb:leading-5 rb:text-[#5B6167] rb:text-[12px]">
<Markdown content={vo.retrieval} />
</div>
</div>
))
}

View File

@@ -174,8 +174,8 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
*/
const handleStreamMessage = (data: SSEMessage[]) => {
data.forEach(item => {
const { chunk, conversation_id, node_id, cycle_id, cycle_idx, input, output, error, elapsed_time, status } = item.data as {
chunk: string;
const { content, conversation_id, node_id, cycle_id, cycle_idx, input, output, error, elapsed_time, status } = item.data as {
content: string;
conversation_id: string | null;
cycle_id: string;
cycle_idx: number;
@@ -202,7 +202,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
if (lastIndex >= 0) {
newList[lastIndex] = {
...newList[lastIndex],
content: newList[lastIndex].content + chunk
content: newList[lastIndex].content + content
}
}
return newList