Merge branch 'develop' into feature/model_zy
This commit is contained in:
3
api/app/cache/__init__.py
vendored
3
api/app/cache/__init__.py
vendored
@@ -3,9 +3,10 @@ Cache 缓存模块
|
||||
|
||||
提供各种缓存功能的统一入口
|
||||
"""
|
||||
from .memory import EmotionMemoryCache, ImplicitMemoryCache
|
||||
from .memory import EmotionMemoryCache, ImplicitMemoryCache, InterestMemoryCache
|
||||
|
||||
__all__ = [
|
||||
"EmotionMemoryCache",
|
||||
"ImplicitMemoryCache",
|
||||
"InterestMemoryCache",
|
||||
]
|
||||
|
||||
2
api/app/cache/memory/__init__.py
vendored
2
api/app/cache/memory/__init__.py
vendored
@@ -5,8 +5,10 @@ Memory 缓存模块
|
||||
"""
|
||||
from .emotion_memory import EmotionMemoryCache
|
||||
from .implicit_memory import ImplicitMemoryCache
|
||||
from .interest_memory import InterestMemoryCache
|
||||
|
||||
__all__ = [
|
||||
"EmotionMemoryCache",
|
||||
"ImplicitMemoryCache",
|
||||
"InterestMemoryCache",
|
||||
]
|
||||
|
||||
122
api/app/cache/memory/interest_memory.py
vendored
Normal file
122
api/app/cache/memory/interest_memory.py
vendored
Normal file
@@ -0,0 +1,122 @@
|
||||
"""
|
||||
Interest Distribution Cache
|
||||
|
||||
兴趣分布缓存模块
|
||||
用于缓存用户的兴趣分布标签数据,避免重复调用模型生成
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
|
||||
from app.aioRedis import aio_redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 缓存过期时间:24小时
|
||||
INTEREST_CACHE_EXPIRE = 86400
|
||||
|
||||
|
||||
class InterestMemoryCache:
|
||||
"""兴趣分布缓存类"""
|
||||
|
||||
PREFIX = "cache:memory:interest_distribution"
|
||||
|
||||
@classmethod
|
||||
def _get_key(cls, end_user_id: str, language: str) -> str:
|
||||
"""生成 Redis key
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
|
||||
Returns:
|
||||
完整的 Redis key
|
||||
"""
|
||||
return f"{cls.PREFIX}:by_user:{end_user_id}:{language}"
|
||||
|
||||
@classmethod
|
||||
async def set_interest_distribution(
|
||||
cls,
|
||||
end_user_id: str,
|
||||
language: str,
|
||||
data: List[Dict[str, Any]],
|
||||
expire: int = INTEREST_CACHE_EXPIRE,
|
||||
) -> bool:
|
||||
"""设置用户兴趣分布缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
data: 兴趣分布列表,格式 [{"name": "...", "frequency": ...}, ...]
|
||||
expire: 过期时间(秒),默认24小时
|
||||
|
||||
Returns:
|
||||
是否设置成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
payload = {
|
||||
"data": data,
|
||||
"generated_at": datetime.now().isoformat(),
|
||||
"cached": True,
|
||||
}
|
||||
value = json.dumps(payload, ensure_ascii=False)
|
||||
await aio_redis.set(key, value, ex=expire)
|
||||
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"设置兴趣分布缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
async def get_interest_distribution(
|
||||
cls,
|
||||
end_user_id: str,
|
||||
language: str,
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""获取用户兴趣分布缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
|
||||
Returns:
|
||||
兴趣分布列表,缓存不存在或已过期返回 None
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
value = await aio_redis.get(key)
|
||||
if value:
|
||||
payload = json.loads(value)
|
||||
logger.info(f"命中兴趣分布缓存: {key}")
|
||||
return payload.get("data")
|
||||
logger.info(f"兴趣分布缓存不存在或已过期: {key}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"获取兴趣分布缓存失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
async def delete_interest_distribution(
|
||||
cls,
|
||||
end_user_id: str,
|
||||
language: str,
|
||||
) -> bool:
|
||||
"""删除用户兴趣分布缓存
|
||||
|
||||
Args:
|
||||
end_user_id: 用户ID
|
||||
language: 语言类型
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
"""
|
||||
try:
|
||||
key = cls._get_key(end_user_id, language)
|
||||
result = await aio_redis.delete(key)
|
||||
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
|
||||
return result > 0
|
||||
except Exception as e:
|
||||
logger.error(f"删除兴趣分布缓存失败: {e}", exc_info=True)
|
||||
return False
|
||||
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import platform
|
||||
from datetime import timedelta
|
||||
from celery.schedules import crontab
|
||||
from urllib.parse import quote
|
||||
|
||||
from celery import Celery
|
||||
@@ -43,8 +44,8 @@ celery_app.conf.update(
|
||||
task_ignore_result=False,
|
||||
|
||||
# 超时设置
|
||||
task_time_limit=1800, # 30分钟硬超时
|
||||
task_soft_time_limit=1500, # 25分钟软超时
|
||||
task_time_limit=3600, # 60分钟硬超时
|
||||
task_soft_time_limit=3000, # 50分钟软超时
|
||||
|
||||
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||
@@ -90,11 +91,10 @@ celery_app.conf.update(
|
||||
celery_app.autodiscover_tasks(['app'])
|
||||
|
||||
# Celery Beat schedule for periodic tasks
|
||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
||||
memory_increment_schedule = crontab(hour=settings.MEMORY_INCREMENT_HOUR, minute=settings.MEMORY_INCREMENT_MINUTE)
|
||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||
# 这个30秒的设计不合理
|
||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
||||
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
||||
workspace_reflection_schedule = timedelta(seconds=settings.WORKSPACE_REFLECTION_INTERVAL_SECONDS)
|
||||
forgetting_cycle_schedule = timedelta(hours=settings.FORGETTING_CYCLE_INTERVAL_HOURS)
|
||||
|
||||
#构建定时任务配置
|
||||
beat_schedule_config = {
|
||||
|
||||
@@ -441,14 +441,14 @@ async def retrieve_chunks(
|
||||
# 1 participle search, 2 semantic search, 3 hybrid search
|
||||
match retrieve_data.retrieve_type:
|
||||
case chunk_schema.RetrieveType.PARTICIPLE:
|
||||
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
|
||||
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
case chunk_schema.RetrieveType.SEMANTIC:
|
||||
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
|
||||
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||
return success(data=rs, msg="retrieval successful")
|
||||
case _:
|
||||
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
|
||||
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
|
||||
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||
# Efficient deduplication
|
||||
seen_ids = set()
|
||||
unique_rs = []
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from app.cache.memory.interest_memory import InterestMemoryCache
|
||||
from app.celery_app import celery_app
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.language_utils import get_language_from_header
|
||||
@@ -661,34 +662,56 @@ async def get_knowledge_type_stats_api(
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
|
||||
async def get_hot_memory_tags_by_user_api(
|
||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
||||
limit: int = Query(20, description="返回标签数量限制"),
|
||||
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
|
||||
async def get_interest_distribution_by_user_api(
|
||||
end_user_id: str = Query(..., description="用户ID(必填)"),
|
||||
limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"),
|
||||
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: Session=Depends(get_db),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
获取指定用户的热门记忆标签
|
||||
获取指定用户的兴趣分布标签
|
||||
|
||||
注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译
|
||||
与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习、创作等),
|
||||
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
|
||||
|
||||
返回格式:
|
||||
[
|
||||
{"name": "标签名", "frequency": 频次},
|
||||
{"name": "兴趣活动名", "frequency": 频次},
|
||||
...
|
||||
]
|
||||
"""
|
||||
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
|
||||
language = get_language_from_header(language_type)
|
||||
api_logger.info(f"Interest distribution by user requested: end_user_id={end_user_id}, language={language}")
|
||||
try:
|
||||
result = await memory_agent_service.get_hot_memory_tags_by_user(
|
||||
# 优先读取缓存
|
||||
cached = await InterestMemoryCache.get_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit
|
||||
language=language,
|
||||
)
|
||||
return success(data=result, msg="获取热门记忆标签成功")
|
||||
if cached is not None:
|
||||
api_logger.info(f"Interest distribution cache hit: end_user_id={end_user_id}")
|
||||
return success(data=cached, msg="获取兴趣分布标签成功")
|
||||
|
||||
# 缓存未命中,调用模型生成
|
||||
result = await memory_agent_service.get_interest_distribution_by_user(
|
||||
end_user_id=end_user_id,
|
||||
limit=limit,
|
||||
language=language
|
||||
)
|
||||
|
||||
# 写入缓存,24小时过期
|
||||
await InterestMemoryCache.set_interest_distribution(
|
||||
end_user_id=end_user_id,
|
||||
language=language,
|
||||
data=result,
|
||||
)
|
||||
|
||||
return success(data=result, msg="获取兴趣分布标签成功")
|
||||
except Exception as e:
|
||||
api_logger.error(f"Hot memory tags by user failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e))
|
||||
api_logger.error(f"Interest distribution by user failed: {str(e)}")
|
||||
return fail(BizCode.INTERNAL_ERROR, "获取兴趣分布标签失败", str(e))
|
||||
|
||||
|
||||
@router.get("/analytics/user_profile", response_model=ApiResponse)
|
||||
|
||||
@@ -606,8 +606,8 @@ async def dashboard_data(
|
||||
|
||||
# 获取RAG相关数据
|
||||
try:
|
||||
# total_memory: 使用 total_chunk(总chunk数)
|
||||
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
|
||||
# total_memory: 只统计用户知识库(permission_id='Memory')的chunk数
|
||||
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
|
||||
rag_data["total_memory"] = total_chunk
|
||||
|
||||
# total_app: 统计当前空间下的所有app数量
|
||||
|
||||
@@ -249,6 +249,7 @@ async def chat(
|
||||
app_id=app.id,
|
||||
workspace_id=workspace_id,
|
||||
release_id=app.current_release.id,
|
||||
public=True
|
||||
):
|
||||
event_type = event.get("event", "message")
|
||||
event_data = event.get("data", {})
|
||||
|
||||
@@ -39,7 +39,7 @@ async def write_memory_api_service(
|
||||
|
||||
Stores memory content for the specified end user using the Memory API Service.
|
||||
"""
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
|
||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
|
||||
|
||||
memory_api_service = MemoryAPIService(db)
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Annotated, Any, Dict, Optional
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field, TypeAdapter
|
||||
|
||||
load_dotenv()
|
||||
|
||||
@@ -200,12 +201,25 @@ class Settings:
|
||||
|
||||
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
||||
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||
|
||||
# Memory Cache Regeneration Configuration
|
||||
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
||||
|
||||
# Celery Beat Schedule Configuration (定时任务执行频率)
|
||||
MEMORY_INCREMENT_HOUR: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=0, le=23, description="cron hour [0, 23]")]
|
||||
).validate_python(int(os.getenv("MEMORY_INCREMENT_HOUR", "2")))
|
||||
MEMORY_INCREMENT_MINUTE: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=0, le=59, description="cron minute [0, 59]")]
|
||||
).validate_python(int(os.getenv("MEMORY_INCREMENT_MINUTE", "0")))
|
||||
WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=1, description="reflection interval in seconds, must be >= 1")]
|
||||
).validate_python(int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30")))
|
||||
FORGETTING_CYCLE_INTERVAL_HOURS: int = TypeAdapter(
|
||||
Annotated[int, Field(ge=1, description="forgetting cycle interval in hours, must be >= 1")]
|
||||
).validate_python(int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24")))
|
||||
|
||||
# Memory Module Configuration (internal)
|
||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
||||
@@ -230,7 +244,7 @@ class Settings:
|
||||
# General Ontology Type Configuration
|
||||
# ========================================================================
|
||||
# 通用本体文件路径列表(逗号分隔)
|
||||
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "app/core/memory/ontology_services/General_purpose_entity.ttl")
|
||||
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "api/app/core/memory/ontology_services/General_purpose_entity.ttl")
|
||||
|
||||
# 是否启用通用本体类型功能
|
||||
ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true"
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
@@ -16,6 +19,10 @@ class FilteredTags(BaseModel):
|
||||
"""用于接收LLM筛选后的核心标签列表的模型。"""
|
||||
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
|
||||
|
||||
class InterestTags(BaseModel):
|
||||
"""用于接收LLM筛选后的兴趣活动标签列表的模型。"""
|
||||
interest_tags: List[str] = Field(..., description="从原始列表中筛选出的代表用户兴趣活动的标签列表。")
|
||||
|
||||
async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
||||
"""
|
||||
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
||||
@@ -85,10 +92,74 @@ async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
||||
return structured_response.meaningful_tags
|
||||
|
||||
except Exception as e:
|
||||
print(f"LLM筛选过程中发生错误: {e}")
|
||||
logger.error(f"LLM筛选过程中发生错误: {e}", exc_info=True)
|
||||
# 在LLM失败时返回原始标签,确保流程继续
|
||||
return tags
|
||||
|
||||
async def filter_interests_with_llm(tags: List[str], end_user_id: str, language: str = "zh") -> List[str]:
|
||||
"""
|
||||
使用LLM从标签列表中筛选出代表用户兴趣活动的标签。
|
||||
|
||||
与 filter_tags_with_llm 不同,此函数专注于识别"活动/行为"类兴趣,
|
||||
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
|
||||
|
||||
Args:
|
||||
tags: 原始标签列表
|
||||
end_user_id: 用户ID,用于获取LLM配置
|
||||
|
||||
Returns:
|
||||
筛选后的兴趣活动标签列表
|
||||
"""
|
||||
try:
|
||||
with get_db_context() as db:
|
||||
from app.services.memory_agent_service import (
|
||||
get_end_user_connected_config,
|
||||
)
|
||||
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||
config_id = connected_config.get("memory_config_id")
|
||||
workspace_id = connected_config.get("workspace_id")
|
||||
|
||||
if not config_id and not workspace_id:
|
||||
raise ValueError(
|
||||
f"No memory_config_id found for end_user_id: {end_user_id}."
|
||||
)
|
||||
|
||||
config_service = MemoryConfigService(db)
|
||||
memory_config = config_service.load_memory_config(
|
||||
config_id=config_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
if not memory_config.llm_model_id:
|
||||
raise ValueError(
|
||||
f"No llm_model_id found in memory config {config_id}."
|
||||
)
|
||||
|
||||
factory = MemoryClientFactory(db)
|
||||
llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||
|
||||
tag_list_str = ", ".join(tags)
|
||||
from app.core.memory.utils.prompt.prompt_utils import render_interest_filter_prompt
|
||||
rendered_prompt = render_interest_filter_prompt(tag_list_str, language=language)
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": rendered_prompt
|
||||
}
|
||||
]
|
||||
|
||||
structured_response = await llm_client.response_structured(
|
||||
messages=messages,
|
||||
response_model=InterestTags
|
||||
)
|
||||
|
||||
return structured_response.interest_tags
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"兴趣标签LLM筛选过程中发生错误: {e}", exc_info=True)
|
||||
return tags
|
||||
|
||||
|
||||
async def get_raw_tags_from_db(
|
||||
connector: Neo4jConnector,
|
||||
end_user_id: str,
|
||||
@@ -183,3 +254,56 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool =
|
||||
finally:
|
||||
# 确保关闭连接
|
||||
await connector.close()
|
||||
|
||||
async def get_interest_distribution(end_user_id: str, limit: int = 10, by_user: bool = False, language: str = "zh") -> List[Tuple[str, int]]:
|
||||
"""
|
||||
获取用户的兴趣分布标签。
|
||||
|
||||
与 get_hot_memory_tags 不同,此函数使用专门针对"活动/行为"的LLM prompt,
|
||||
过滤掉纯物品、工具、地点等,只保留能代表用户兴趣爱好的活动类标签。
|
||||
|
||||
Args:
|
||||
end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||
limit: 最终返回的标签数量限制(默认10)
|
||||
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||
|
||||
Raises:
|
||||
ValueError: 如果end_user_id未提供或为空
|
||||
"""
|
||||
if not end_user_id or not end_user_id.strip():
|
||||
raise ValueError(
|
||||
"end_user_id is required. Please provide a valid end_user_id or user_id."
|
||||
)
|
||||
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
# 查询更多原始标签,给LLM提供充足上下文
|
||||
query_limit = 40
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user)
|
||||
if not raw_tags_with_freq:
|
||||
return []
|
||||
|
||||
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
||||
raw_freq_map = {tag: freq for tag, freq in raw_tags_with_freq}
|
||||
|
||||
# 使用兴趣活动专用prompt进行筛选(支持语义推断出新标签)
|
||||
interest_tag_names = await filter_interests_with_llm(raw_tag_names, end_user_id, language=language)
|
||||
|
||||
# 构建最终标签列表:
|
||||
# - 原始标签中存在的,保留原始频率
|
||||
# - LLM推断出的新标签(不在原始列表中),赋予默认频率1
|
||||
final_tags = []
|
||||
seen = set()
|
||||
for tag in interest_tag_names:
|
||||
if tag in seen:
|
||||
continue
|
||||
seen.add(tag)
|
||||
freq = raw_freq_map.get(tag, 1)
|
||||
final_tags.append((tag, freq))
|
||||
|
||||
# 按频率降序排列
|
||||
final_tags.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
return final_tags[:limit]
|
||||
finally:
|
||||
await connector.close()
|
||||
|
||||
@@ -548,3 +548,20 @@ async def render_ontology_extraction_prompt(
|
||||
})
|
||||
|
||||
return rendered_prompt
|
||||
|
||||
|
||||
def render_interest_filter_prompt(tag_list: str, language: str = "zh") -> str:
|
||||
"""
|
||||
Renders the interest filter prompt using the interest_filter.jinja2 template.
|
||||
|
||||
Args:
|
||||
tag_list: Comma-separated string of raw tags to filter
|
||||
language: Output language ("zh" for Chinese, "en" for English)
|
||||
|
||||
Returns:
|
||||
Rendered prompt content as string
|
||||
"""
|
||||
template = prompt_env.get_template("interest_filter.jinja2")
|
||||
rendered_prompt = template.render(tag_list=tag_list, language=language)
|
||||
log_prompt_rendering('interest filter', rendered_prompt)
|
||||
return rendered_prompt
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
{% if language == "zh" %}
|
||||
You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent.
|
||||
|
||||
**Step 1 - Infer the underlying interest from each tag**:
|
||||
Look at each tag and ask: "What hobby or interest does this tag suggest the user has?"
|
||||
|
||||
Examples of inference:
|
||||
- '攀岩', '室内攀岩馆', '攀岩者数据仪表盘', '路线解锁地图', '指力', '路线等级', '当日攀岩流畅度' → '攀岩'
|
||||
- '风光摄影元数据增强器', 'EXIF数据', '.CR2文件', '.NEF文件', '日出拍摄点', '曝光补偿', '光圈', '太阳高度角', '云量预测图层' → '摄影'
|
||||
- '晨间冥想坚持天数', '身心协同峰值' → '冥想'
|
||||
- '川味可视化', '川菜' → '烹饪'
|
||||
- '开源项目命名建议', 'climbviz', '可视化', '力量增长雷达图' → '编程' 或 '数据可视化'
|
||||
- '吉他', '指弹', '琴谱' → '吉他'
|
||||
- '跑步', '5公里', '跑鞋' → '跑步'
|
||||
- '瑜伽垫', '瑜伽课' → '瑜伽'
|
||||
|
||||
**Step 2 - Consolidate and deduplicate**:
|
||||
- Merge tags that point to the same interest into one representative label
|
||||
- Use concise, standard hobby names (e.g., '攀岩', '摄影', '编程', '烹饪', '冥想', '吉他', '跑步')
|
||||
- If multiple tags all point to '攀岩', output '攀岩' only once
|
||||
|
||||
**Step 3 - Filter out non-interest tags**:
|
||||
Remove tags that do NOT suggest any hobby or interest:
|
||||
- Generic system/assistant terms (e.g., '助手', '用户', 'AI')
|
||||
- Pure abstract metrics with no clear hobby link (e.g., '完成时间', '日期', '自我评分')
|
||||
- Location names with no clear hobby link (e.g., '青城山后山' alone — but if combined with photography context, infer '摄影')
|
||||
|
||||
**Output format**: Return a list of concise interest activity names in Chinese.
|
||||
|
||||
**Example**:
|
||||
Input: ['攀岩', '攀岩者数据仪表盘', '路线解锁地图', '指力', '风光摄影元数据增强器', 'EXIF数据', '晨间冥想坚持天数', '川味可视化', '可视化', '助手', '完成时间']
|
||||
Output: ['攀岩', '摄影', '冥想', '烹饪', '编程']
|
||||
|
||||
Now process the following tag list and return the inferred interest activities in Chinese: {{ tag_list }}
|
||||
{% else %}
|
||||
You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent.
|
||||
|
||||
**Step 1 - Infer the underlying interest from each tag**:
|
||||
Look at each tag and ask: "What hobby or interest does this tag suggest the user has?"
|
||||
|
||||
Examples of inference:
|
||||
- 'rock climbing', 'indoor climbing gym', 'climber dashboard', 'route map', 'finger strength' → 'rock climbing'
|
||||
- 'landscape photography metadata enhancer', 'EXIF data', 'sunrise shooting spot', 'exposure compensation' → 'photography'
|
||||
- 'morning meditation streak', 'mind-body peak' → 'meditation'
|
||||
- 'Sichuan cuisine visualization', 'Sichuan food' → 'cooking'
|
||||
- 'open source project', 'data visualization tool', 'Python' → 'programming'
|
||||
- 'guitar', 'fingerpicking', 'sheet music' → 'guitar'
|
||||
- 'running', '5km', 'running shoes' → 'running'
|
||||
|
||||
**Step 2 - Consolidate and deduplicate**:
|
||||
- Merge tags that point to the same interest into one representative label
|
||||
- Use concise, standard hobby names (e.g., 'rock climbing', 'photography', 'programming', 'cooking', 'meditation')
|
||||
- If multiple tags all point to 'rock climbing', output 'rock climbing' only once
|
||||
|
||||
**Step 3 - Filter out non-interest tags**:
|
||||
Remove tags that do NOT suggest any hobby or interest:
|
||||
- Generic system/assistant terms (e.g., 'assistant', 'user', 'AI')
|
||||
- Pure abstract metrics with no clear hobby link (e.g., 'completion time', 'date', 'self-rating')
|
||||
|
||||
**Output format**: Return a list of concise interest activity names in English.
|
||||
|
||||
**Example**:
|
||||
Input: ['rock climbing', 'climber dashboard', 'route map', 'finger strength', 'landscape photography metadata enhancer', 'EXIF data', 'morning meditation streak', 'Sichuan cuisine visualization', 'visualization', 'assistant', 'completion time']
|
||||
Output: ['rock climbing', 'photography', 'meditation', 'cooking', 'programming']
|
||||
|
||||
Now process the following tag list and return the inferred interest activities in English: {{ tag_list }}
|
||||
{% endif %}
|
||||
@@ -127,7 +127,7 @@ class EventStreamHandler:
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": data.get("chunk")
|
||||
"content": data.get("chunk")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -274,7 +274,7 @@ class StreamOutputCoordinator:
|
||||
yield {
|
||||
"event": "message",
|
||||
"data": {
|
||||
"chunk": final_chunk
|
||||
"content": final_chunk
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -272,7 +272,7 @@ class WorkflowExecutor:
|
||||
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
|
||||
if event_type == "node_chunk":
|
||||
async for msg_event in self.event_handler.handle_node_chunk_event(data):
|
||||
full_content += msg_event["data"]["chunk"]
|
||||
full_content += msg_event["data"]["content"]
|
||||
yield msg_event
|
||||
|
||||
elif event_type == "node_error":
|
||||
@@ -295,12 +295,12 @@ class WorkflowExecutor:
|
||||
self.graph,
|
||||
self.execution_context.checkpoint_config
|
||||
):
|
||||
full_content += msg_event["data"]['chunk']
|
||||
full_content += msg_event["data"]['content']
|
||||
yield msg_event
|
||||
|
||||
# Flush any remaining chunks
|
||||
async for msg_event in self.stream_coordinator.flush_remaining_chunk(self.variable_pool):
|
||||
full_content += msg_event["data"]['chunk']
|
||||
full_content += msg_event["data"]['content']
|
||||
yield msg_event
|
||||
|
||||
result = graph.get_state(self.execution_context.checkpoint_config).values
|
||||
|
||||
@@ -211,3 +211,46 @@ def get_total_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to query total knowledge base count: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_user_kb_chunk_num_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
|
||||
"""
|
||||
根据workspace_id查询knowledges表中permission_id='Memory'(用户知识库)的chunk_num总和
|
||||
"""
|
||||
db_logger.debug(f"Query user KB chunk_num by workspace_id: workspace_id={workspace_id}")
|
||||
|
||||
try:
|
||||
from sqlalchemy import func
|
||||
result = db.query(func.sum(Knowledge.chunk_num)).filter(
|
||||
Knowledge.workspace_id == workspace_id,
|
||||
Knowledge.status == 1,
|
||||
Knowledge.permission_id == "Memory"
|
||||
).scalar()
|
||||
|
||||
total = result if result is not None else 0
|
||||
db_logger.info(f"User KB chunk_num query successful: workspace_id={workspace_id}, total={total}")
|
||||
return total
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to query user KB chunk_num: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_non_user_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
|
||||
"""
|
||||
根据workspace_id查询knowledges表中排除用户知识库(permission_id!='Memory')的数量
|
||||
"""
|
||||
db_logger.debug(f"Query non-user KB count by workspace_id: workspace_id={workspace_id}")
|
||||
|
||||
try:
|
||||
count = db.query(Knowledge).filter(
|
||||
Knowledge.workspace_id == workspace_id,
|
||||
Knowledge.status == 1,
|
||||
Knowledge.permission_id != "Memory"
|
||||
).count()
|
||||
|
||||
db_logger.info(f"Non-user KB count query successful: workspace_id={workspace_id}, count={count}")
|
||||
return count
|
||||
except Exception as e:
|
||||
db_logger.error(f"Failed to query non-user KB count: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -46,6 +46,7 @@ class ChunkUpdate(BaseModel):
|
||||
class ChunkRetrieve(BaseModel):
|
||||
query: str
|
||||
kb_ids: list[uuid.UUID]
|
||||
file_names_filter: list[str] | None = Field(None)
|
||||
similarity_threshold: float | None = Field(None)
|
||||
vector_similarity_weight: float | None = Field(None)
|
||||
top_k: int | None = Field(None)
|
||||
|
||||
@@ -36,7 +36,7 @@ from app.core.memory.agent.utils.messages_tools import (
|
||||
)
|
||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.db import get_db_context
|
||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||
@@ -890,36 +890,36 @@ class MemoryAgentService:
|
||||
return result
|
||||
|
||||
|
||||
async def get_hot_memory_tags_by_user(
|
||||
|
||||
async def get_interest_distribution_by_user(
|
||||
self,
|
||||
end_user_id: Optional[str] = None,
|
||||
limit: int = 20
|
||||
limit: int = 5,
|
||||
language: str = "zh"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定用户的热门记忆标签
|
||||
获取指定用户的兴趣分布标签。
|
||||
|
||||
与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习等),
|
||||
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
|
||||
|
||||
参数:
|
||||
- end_user_id: 用户ID(可选),对应Neo4j中的end_user_id字段
|
||||
- end_user_id: 用户ID(必填)
|
||||
- limit: 返回标签数量限制
|
||||
- language: 输出语言("zh" 中文, "en" 英文)
|
||||
|
||||
返回格式:
|
||||
[
|
||||
{"name": "标签名", "frequency": 频次},
|
||||
{"name": "兴趣活动名", "frequency": 频次},
|
||||
...
|
||||
]
|
||||
|
||||
注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译
|
||||
"""
|
||||
try:
|
||||
# by_user=False 表示按 end_user_id 查询(在Neo4j中,end_user_id就是用户维度)
|
||||
tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False)
|
||||
payload = []
|
||||
for tag, freq in tags:
|
||||
payload.append({"name": tag, "frequency": freq})
|
||||
return payload
|
||||
tags = await get_interest_distribution(end_user_id, limit=limit, by_user=False, language=language)
|
||||
return [{"name": tag, "frequency": freq} for tag, freq in tags]
|
||||
except Exception as e:
|
||||
logger.error(f"热门记忆标签查询失败: {e}")
|
||||
raise Exception(f"热门记忆标签查询失败: {e}")
|
||||
logger.error(f"兴趣分布标签查询失败: {e}")
|
||||
raise Exception(f"兴趣分布标签查询失败: {e}")
|
||||
|
||||
|
||||
async def get_user_profile(
|
||||
|
||||
@@ -140,9 +140,11 @@ class MemoryAPIService:
|
||||
|
||||
try:
|
||||
# Delegate to MemoryAgentService
|
||||
# Convert string message to list[dict] format expected by MemoryAgentService
|
||||
messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
|
||||
result = await MemoryAgentService().write_memory(
|
||||
end_user_id=end_user_id,
|
||||
messages=message,
|
||||
messages=messages,
|
||||
config_id=config_id,
|
||||
db=self.db,
|
||||
storage_type=storage_type,
|
||||
@@ -151,9 +153,18 @@ class MemoryAPIService:
|
||||
|
||||
logger.info(f"Memory write successful for end_user: {end_user_id}")
|
||||
|
||||
# result may be a string "success" or a dict with a "status" key
|
||||
# Preserve the full dict so callers don't silently lose extra fields
|
||||
# (e.g. error codes, metadata) returned by MemoryAgentService.
|
||||
if isinstance(result, dict):
|
||||
return {
|
||||
**result,
|
||||
"status": result.get("status", "unknown"),
|
||||
"end_user_id": end_user_id,
|
||||
}
|
||||
return {
|
||||
"status": "success" if result == "success" else result,
|
||||
"end_user_id": end_user_id
|
||||
"status": result if isinstance(result, str) else "success",
|
||||
"end_user_id": end_user_id,
|
||||
}
|
||||
|
||||
except ConfigurationError as e:
|
||||
|
||||
@@ -390,19 +390,59 @@ def get_rag_total_kb(
|
||||
current_user: User
|
||||
) -> int:
|
||||
"""
|
||||
根据当前用户所在的workspace_id查询konwledges表所有不同id的数量
|
||||
根据当前用户所在的workspace_id查询konwledges表中排除用户知识库(permission_id!='Memory')的数量
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
business_logger.info(f"获取RAG总知识库数: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
business_logger.info(f"获取RAG总知识库数(排除用户知识库): workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
total_kb = knowledge_repository.get_total_kb_count_by_workspace(db, workspace_id)
|
||||
total_kb = knowledge_repository.get_non_user_kb_count_by_workspace(db, workspace_id)
|
||||
business_logger.info(f"成功获取RAG总知识库数: {total_kb}")
|
||||
return total_kb
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取RAG总知识库数失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def get_rag_user_kb_total_chunk(
|
||||
db: Session,
|
||||
current_user: User
|
||||
) -> int:
|
||||
"""
|
||||
根据当前用户所在的workspace_id,从documents表统计所有用户知识库的chunk总数。
|
||||
与 /end_users 接口保持同源:查询 file_name 匹配 end_user_id.txt 的文档 chunk_num 之和。
|
||||
"""
|
||||
workspace_id = current_user.current_workspace_id
|
||||
business_logger.info(f"获取用户知识库总chunk数(documents表): workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||
|
||||
try:
|
||||
from app.models.document_model import Document
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.models.app_model import App
|
||||
from sqlalchemy import func
|
||||
|
||||
# 通过 App 关联取该 workspace 下所有 end_user_id
|
||||
end_user_ids = [
|
||||
str(eid) for (eid,) in db.query(EndUser.id)
|
||||
.join(App, EndUser.app_id == App.id)
|
||||
.filter(App.workspace_id == workspace_id)
|
||||
.all()
|
||||
]
|
||||
if not end_user_ids:
|
||||
return 0
|
||||
|
||||
file_names = [f"{uid}.txt" for uid in end_user_ids]
|
||||
result = db.query(func.sum(Document.chunk_num)).filter(
|
||||
Document.file_name.in_(file_names)
|
||||
).scalar()
|
||||
|
||||
total_chunk = int(result or 0)
|
||||
business_logger.info(f"成功获取用户知识库总chunk数: {total_chunk}")
|
||||
return total_chunk
|
||||
except Exception as e:
|
||||
business_logger.error(f"获取用户知识库总chunk数失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def get_current_user_total_chunk(
|
||||
end_user_id: str,
|
||||
db: Session,
|
||||
|
||||
@@ -13,6 +13,7 @@ from sqlalchemy.orm import Session
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.workflow.adapters.registry import PlatformAdapterRegistry
|
||||
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
from app.core.workflow.validator import validate_workflow_config
|
||||
from app.db import get_db
|
||||
@@ -23,7 +24,7 @@ from app.repositories.workflow_repository import (
|
||||
WorkflowExecutionRepository,
|
||||
WorkflowNodeExecutionRepository
|
||||
)
|
||||
from app.schemas import DraftRunRequest
|
||||
from app.schemas import DraftRunRequest, FileInput
|
||||
from app.services.conversation_service import ConversationService
|
||||
from app.services.multi_agent_service import convert_uuids_to_str
|
||||
from app.services.multimodal_service import MultimodalService
|
||||
@@ -445,6 +446,91 @@ class WorkflowService:
|
||||
"success_rate": completed / total if total > 0 else 0
|
||||
}
|
||||
|
||||
async def _handle_file_input(self, files: list[FileInput]):
|
||||
if not files:
|
||||
return []
|
||||
|
||||
files_struct = []
|
||||
for file in files:
|
||||
files_struct.append(
|
||||
{
|
||||
"type": file.type,
|
||||
"url": await self.multimodal_service.get_file_url(file),
|
||||
"__file": True
|
||||
}
|
||||
)
|
||||
return files_struct
|
||||
|
||||
@staticmethod
|
||||
def _map_public_event(event: dict) -> dict | None:
|
||||
"""
|
||||
Map internal workflow events to public-facing event formats.
|
||||
|
||||
Purpose:
|
||||
- Hide internal execution details
|
||||
- Expose a stable and simplified public event schema
|
||||
- Filter out non-public events
|
||||
- Maintain backward compatibility when possible
|
||||
|
||||
Args:
|
||||
event (dict): Internal event object, e.g.:
|
||||
{
|
||||
"event": "workflow_start",
|
||||
"data": {...}
|
||||
}
|
||||
|
||||
Returns:
|
||||
dict | None:
|
||||
- Returns the mapped public event
|
||||
- Returns None if the event should not be exposed
|
||||
"""
|
||||
event_type = event.get("event")
|
||||
payload = event.get("data")
|
||||
match event_type:
|
||||
case "workflow_start":
|
||||
return {
|
||||
"event": "start",
|
||||
"data": {
|
||||
"conversation_id": payload.get("conversation_id"),
|
||||
}
|
||||
}
|
||||
case "workflow_end":
|
||||
return {
|
||||
"event": "end",
|
||||
"data": {
|
||||
"elapsed_time": payload.get("elapsed_time"),
|
||||
"message_length": len(payload.get("output", "")),
|
||||
"error": payload.get("error", "")
|
||||
}
|
||||
}
|
||||
case "node_start" | "node_end" | "node_error" | "cycle_item":
|
||||
return None
|
||||
case _:
|
||||
return event
|
||||
|
||||
def _emit(self, public: bool, internal_event: dict):
|
||||
"""
|
||||
Unified event emission entry.
|
||||
|
||||
Args:
|
||||
public (bool):
|
||||
- True -> Emit mapped public event
|
||||
- False -> Emit raw internal event
|
||||
|
||||
internal_event (dict):
|
||||
The original internal event object
|
||||
|
||||
Returns:
|
||||
dict | None:
|
||||
- The mapped event
|
||||
- Or None if the event is filtered out
|
||||
"""
|
||||
if public:
|
||||
mapped = self._map_public_event(internal_event)
|
||||
else:
|
||||
mapped = internal_event
|
||||
return mapped
|
||||
|
||||
# ==================== 工作流执行 ====================
|
||||
|
||||
async def run(
|
||||
@@ -479,10 +565,11 @@ class WorkflowService:
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
|
||||
input_data = {"message": payload.message, "variables": payload.variables,
|
||||
"conversation_id": payload.conversation_id,
|
||||
"files": [file.model_dump(mode='json') for file in payload.files]
|
||||
}
|
||||
input_data = {
|
||||
"message": payload.message, "variables": payload.variables,
|
||||
"conversation_id": payload.conversation_id,
|
||||
"files": [file.model_dump(mode='json') for file in payload.files]
|
||||
}
|
||||
|
||||
# 转换 conversation_id 为 UUID
|
||||
conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None
|
||||
@@ -506,22 +593,8 @@ class WorkflowService:
|
||||
"execution_config": config.execution_config
|
||||
}
|
||||
|
||||
# 4. 获取工作空间 ID(从 app 获取)
|
||||
|
||||
# 5. 执行工作流
|
||||
from app.core.workflow.executor import execute_workflow
|
||||
|
||||
try:
|
||||
files = []
|
||||
if payload.files:
|
||||
for file in payload.files:
|
||||
files.append(
|
||||
{
|
||||
"type": file.type,
|
||||
"url": await self.multimodal_service.get_file_url(file),
|
||||
"__file": True
|
||||
}
|
||||
)
|
||||
files = await self._handle_file_input(payload.files)
|
||||
input_data["files"] = files
|
||||
# 更新状态为运行中
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
@@ -601,42 +674,6 @@ class WorkflowService:
|
||||
message=f"工作流执行失败: {str(e)}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _map_public_event(event: dict) -> dict | None:
|
||||
event_type = event.get("event")
|
||||
payload = event.get("data")
|
||||
match event_type:
|
||||
case "workflow_start":
|
||||
return {
|
||||
"event": "start",
|
||||
"data": {
|
||||
"conversation_id": payload.get("conversation_id"),
|
||||
}
|
||||
}
|
||||
case "workflow_end":
|
||||
return {
|
||||
"event": "end",
|
||||
"data": {
|
||||
"elapsed_time": payload.get("elapsed_time"),
|
||||
"message_length": len(payload.get("output", "")),
|
||||
"error": payload.get("error", "")
|
||||
}
|
||||
}
|
||||
case "node_start" | "node_end" | "node_error" | "cycle_item":
|
||||
return None
|
||||
case _:
|
||||
return event
|
||||
|
||||
def _emit(self, public: bool, internal_event: dict):
|
||||
"""
|
||||
decide
|
||||
"""
|
||||
if public:
|
||||
mapped = self._map_public_event(internal_event)
|
||||
else:
|
||||
mapped = internal_event
|
||||
return mapped
|
||||
|
||||
async def run_stream(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
@@ -671,10 +708,11 @@ class WorkflowService:
|
||||
message=f"工作流配置不存在: app_id={app_id}"
|
||||
)
|
||||
|
||||
input_data = {"message": payload.message, "variables": payload.variables,
|
||||
"conversation_id": payload.conversation_id,
|
||||
"files": [file.model_dump(mode='json') for file in payload.files]
|
||||
}
|
||||
input_data = {
|
||||
"message": payload.message, "variables": payload.variables,
|
||||
"conversation_id": payload.conversation_id,
|
||||
"files": [file.model_dump(mode='json') for file in payload.files]
|
||||
}
|
||||
|
||||
# 转换 conversation_id 为 UUID
|
||||
conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None
|
||||
@@ -699,16 +737,7 @@ class WorkflowService:
|
||||
}
|
||||
|
||||
try:
|
||||
files = []
|
||||
if payload.files:
|
||||
for file in payload.files:
|
||||
files.append(
|
||||
{
|
||||
"type": file.type,
|
||||
"url": await self.multimodal_service.get_file_url(file),
|
||||
"__file": True
|
||||
}
|
||||
)
|
||||
files = await self._handle_file_input(payload.files)
|
||||
input_data["files"] = files
|
||||
self.update_execution_status(execution.execution_id, "running")
|
||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
||||
@@ -723,7 +752,6 @@ class WorkflowService:
|
||||
input_data["conv_messages"] = last_state.get("messages") or []
|
||||
break
|
||||
init_message_length = len(input_data.get("conv_messages", []))
|
||||
from app.core.workflow.executor import execute_workflow_stream
|
||||
|
||||
async for event in execute_workflow_stream(
|
||||
workflow_config=workflow_config_dict,
|
||||
@@ -789,37 +817,6 @@ class WorkflowService:
|
||||
return node.get("config", {}).get("variables", [])
|
||||
raise BusinessException("workflow config error - start node not found")
|
||||
|
||||
def _clean_event_for_json(self, event: dict[str, Any]) -> dict[str, Any]:
|
||||
"""清理事件数据,移除不可序列化的对象
|
||||
|
||||
Args:
|
||||
event: 原始事件数据
|
||||
|
||||
Returns:
|
||||
可序列化的事件数据
|
||||
"""
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
def clean_value(value):
|
||||
"""递归清理值"""
|
||||
if isinstance(value, BaseMessage):
|
||||
# 将 Message 对象转换为字典
|
||||
return {
|
||||
"type": value.__class__.__name__,
|
||||
"content": value.content,
|
||||
}
|
||||
elif isinstance(value, dict):
|
||||
return {k: clean_value(v) for k, v in value.items()}
|
||||
elif isinstance(value, list):
|
||||
return [clean_value(item) for item in value]
|
||||
elif isinstance(value, (str, int, float, bool, type(None))):
|
||||
return value
|
||||
else:
|
||||
# 其他不可序列化的对象转换为字符串
|
||||
return str(value)
|
||||
|
||||
return clean_value(event)
|
||||
|
||||
|
||||
# ==================== 依赖注入函数 ====================
|
||||
|
||||
|
||||
@@ -257,7 +257,7 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n"
|
||||
return result
|
||||
|
||||
try:
|
||||
def sync_task():
|
||||
trio.run(
|
||||
lambda: _run(
|
||||
row=task,
|
||||
@@ -272,6 +272,10 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
||||
with_community=with_community,
|
||||
)
|
||||
)
|
||||
try:
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(sync_task)
|
||||
future.result() # Blocks until the task completes
|
||||
except Exception as e:
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n"
|
||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)"
|
||||
|
||||
@@ -139,7 +139,7 @@ SMTP_USER=
|
||||
SMTP_PASSWORD=
|
||||
|
||||
# 本体类型融合配置 (记得写入env_example)
|
||||
GENERAL_ONTOLOGY_FILES=app/core/memory/ontology_services/General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔
|
||||
GENERAL_ONTOLOGY_FILES=api/app/core/memory/ontology_services/General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔
|
||||
ENABLE_GENERAL_ONTOLOGY_TYPES=true # 总开关,控制是否启用通用本体类型融合功能(false = 不使用任何本体类型指导)
|
||||
MAX_ONTOLOGY_TYPES_IN_PROMPT=100 # 限制传给 LLM 的类型数量,防止 Prompt 过长
|
||||
CORE_GENERAL_TYPES=Person,Organization,Place,Event,Work,Concept # 定义核心类型列表,这些类型会优先包含在合并结果中
|
||||
|
||||
@@ -456,6 +456,7 @@ export const en = {
|
||||
logoTip: `Supported image formats: JPG, PNG \n Suggested size: square ratio \n Maximum size: ≤ 2MB`,
|
||||
imageSquareRequired: 'Please upload a square image',
|
||||
nameInvalid: 'Name cannot start or end with a space',
|
||||
notAllSpaces: 'Cannot be all spaces',
|
||||
},
|
||||
model: {
|
||||
searchPlaceholder: 'search model…',
|
||||
@@ -1782,6 +1783,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
mcp: 'MCP Services',
|
||||
inner: 'Built-in Tools',
|
||||
custom: 'Custom Tools',
|
||||
market: 'Tool Market',
|
||||
mcpSearchPlaceholder: 'Search MCP Services...',
|
||||
innerSearchPlaceholder: 'Search Tools...',
|
||||
customSearchPlaceholder: 'Search Custom Tools...',
|
||||
@@ -1955,7 +1957,9 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
||||
path: 'Path',
|
||||
viewDetail: 'View Details',
|
||||
textLink: 'Test Connection',
|
||||
noResult: 'Processing results will be displayed here'
|
||||
noResult: 'Processing results will be displayed here',
|
||||
serverUrlInvalid: 'Must start with http:// or https://, and cannot have leading or trailing spaces',
|
||||
requestHeaderKeyInvalid: 'Only English letters, numbers, hyphens (-), and underscores (_) are allowed, and cannot start or end with a hyphen or underscore',
|
||||
},
|
||||
workflow: {
|
||||
coreNode: 'Core Nodes',
|
||||
|
||||
@@ -1036,6 +1036,7 @@ export const zh = {
|
||||
logoTip: `支持图片格式(JPG、PNG)\n 尺寸:正方形比例 \n 文件大小限制:≤ 2MB`,
|
||||
imageSquareRequired: '请上传正方形比例图片',
|
||||
nameInvalid: '不能是空格开头或结尾',
|
||||
notAllSpaces: '不能是纯空格',
|
||||
},
|
||||
model: {
|
||||
searchPlaceholder: '搜索模型…',
|
||||
@@ -1779,6 +1780,7 @@ export const zh = {
|
||||
mcp: 'MCP 服务',
|
||||
inner: '内置工具',
|
||||
custom: '自定义工具',
|
||||
market: '工具市场',
|
||||
mcpSearchPlaceholder: '搜索MCP服务...',
|
||||
innerSearchPlaceholder: '搜索工具...',
|
||||
customSearchPlaceholder: '搜索自定义工具...',
|
||||
@@ -1952,7 +1954,9 @@ export const zh = {
|
||||
path: '路径',
|
||||
viewDetail: '查看详情',
|
||||
textLink: '测试连接',
|
||||
noResult: '处理结果将显示在这里'
|
||||
noResult: '处理结果将显示在这里',
|
||||
serverUrlInvalid: '必须以 http:// 或 https:// 开头,且不能有前后空格',
|
||||
requestHeaderKeyInvalid: '只支持英文、数字、连字符(-)、下划线(_),不能以连字符或下划线开头结尾',
|
||||
},
|
||||
workflow: {
|
||||
coreNode: '核心节点',
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* @Author: ZhaoYing
|
||||
* @Date: 2026-02-03 16:58:03
|
||||
* @Last Modified by: ZhaoYing
|
||||
* @Last Modified time: 2026-03-03 13:46:22
|
||||
* @Last Modified time: 2026-03-04 12:10:44
|
||||
*/
|
||||
/**
|
||||
* Conversation Page
|
||||
@@ -267,8 +267,8 @@ const Conversation: FC = () => {
|
||||
currentConversationId = newId
|
||||
break
|
||||
case 'message':
|
||||
const { content, chunk, conversation_id: curId } = item.data as { content: string; chunk: string; conversation_id: string; }
|
||||
updateAssistantMessage(content ?? chunk)
|
||||
const { content, conversation_id: curId } = item.data as { content: string; conversation_id: string; }
|
||||
updateAssistantMessage(content)
|
||||
|
||||
if (curId) {
|
||||
currentConversationId = curId;
|
||||
|
||||
@@ -185,6 +185,7 @@ const OntologyClassExtractModal = forwardRef<OntologyClassExtractModalRef, Ontol
|
||||
rules={[
|
||||
{ required: true, message: t('common.pleaseEnter') },
|
||||
{ max: 2000 },
|
||||
{ pattern: /^(?!\s*$).+$/, message: t('common.notAllSpaces') },
|
||||
]}
|
||||
>
|
||||
<Input.TextArea placeholder={t('ontology.scenarioPlaceholder')} />
|
||||
|
||||
315
web/src/views/ToolManagement/Market.tsx
Normal file
315
web/src/views/ToolManagement/Market.tsx
Normal file
@@ -0,0 +1,315 @@
|
||||
import React, { useState, useRef, type ReactNode } from 'react';
|
||||
import { Input, Button, Spin, App } from 'antd';
|
||||
import { SearchOutlined, SettingOutlined, GlobalOutlined, SyncOutlined } from '@ant-design/icons';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import MarketConfigModal, { type MarketConfigModalRef } from './components/MarketConfigModal';
|
||||
|
||||
interface MarketSource {
|
||||
id: string;
|
||||
name: string;
|
||||
category: string;
|
||||
icon: string;
|
||||
url: string;
|
||||
desc: string;
|
||||
apiKey: string;
|
||||
connected: boolean;
|
||||
mcpCount: number;
|
||||
}
|
||||
|
||||
interface MarketMcp {
|
||||
id: string;
|
||||
name: string;
|
||||
provider: string;
|
||||
type: string;
|
||||
desc: string;
|
||||
downloads?: string;
|
||||
stars?: string;
|
||||
icon: string;
|
||||
configTemplate: any;
|
||||
}
|
||||
|
||||
interface MarketCategory {
|
||||
id: string;
|
||||
name: string;
|
||||
icon: string;
|
||||
}
|
||||
|
||||
const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () => {
|
||||
const { t } = useTranslation();
|
||||
const { message } = App.useApp();
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [selectedSource, setSelectedSource] = useState<string | null>(null);
|
||||
const marketConfigModalRef = useRef<MarketConfigModalRef>(null);
|
||||
const [marketSources, setMarketSources] = useState<MarketSource[]>([
|
||||
{ id: 'smithery', name: 'Smithery', category: 'official', icon: '🔧', url: 'https://mcp.smithery.ai', desc: '官方 MCP 服务市场,提供丰富的 MCP 服务', apiKey: '', connected: false, mcpCount: 2847 },
|
||||
{ id: 'mcpmarket', name: 'MCP Market', category: 'official', icon: '🏪', url: 'https://mcpmarket.com', desc: '综合性 MCP 市场平台', apiKey: '', connected: false, mcpCount: 1523 },
|
||||
{ id: 'glama', name: 'Glama.ai MCP', category: 'official', icon: '✨', url: 'https://glama.ai/mcp', desc: 'Glama AI 提供的 MCP 服务集合', apiKey: '', connected: false, mcpCount: 892 },
|
||||
{ id: 'github-mcp', name: 'modelcontextprotocol/servers', category: 'official', icon: '🐙', url: 'https://github.com/modelcontextprotocol/servers', desc: 'GitHub 官方 MCP 服务器仓库', apiKey: '', connected: true, mcpCount: 156 },
|
||||
{ id: 'aliyun-bailian', name: '阿里云百炼 MCP', category: 'china-cloud', icon: '☁️', url: 'https://bailian.console.aliyun.com/mcp', desc: '阿里云百炼平台 MCP 市场', apiKey: '', connected: false, mcpCount: 423 },
|
||||
{ id: 'modelscope', name: '魔搭社区 MCP', category: 'china-cloud', icon: '🎭', url: 'https://modelscope.cn/mcp', desc: '阿里达摩院魔搭社区 MCP 市场', apiKey: '', connected: false, mcpCount: 312 },
|
||||
]);
|
||||
|
||||
const [categories] = useState<MarketCategory[]>([
|
||||
{ id: 'official', name: '官方/综合', icon: '🌐' },
|
||||
{ id: 'china-cloud', name: '国内云', icon: '☁️' },
|
||||
{ id: 'community', name: '社区/垂直', icon: '👥' }
|
||||
]);
|
||||
|
||||
const [mcpCache, setMcpCache] = useState<Record<string, MarketMcp[]>>({
|
||||
'github-mcp': [
|
||||
{ id: 'gh-1', name: 'Fetch', provider: 'modelcontextprotocol', type: 'Hosted', desc: '使用浏览器模拟大型语言模型检索和处理网页内容', downloads: '203.7m', stars: '308.2k', icon: '🌐', configTemplate: {} },
|
||||
{ id: 'gh-2', name: 'Filesystem', provider: 'modelcontextprotocol', type: 'Local', desc: '安全的文件系统操作,支持读写文件和目录管理', downloads: '156.2m', stars: '245.1k', icon: '📁', configTemplate: {} },
|
||||
{ id: 'gh-3', name: 'GitHub', provider: 'modelcontextprotocol', type: 'Hosted', desc: 'GitHub API 集成,支持仓库、Issue、PR 等操作', downloads: '89.4m', stars: '178.3k', icon: '🐙', configTemplate: {} },
|
||||
]
|
||||
});
|
||||
|
||||
const [searchKeyword, setSearchKeyword] = useState('');
|
||||
|
||||
const handleSelectSource = (sourceId: string) => {
|
||||
setSelectedSource(sourceId);
|
||||
};
|
||||
|
||||
const handleRefresh = (sourceId: string) => {
|
||||
setLoading(true);
|
||||
setTimeout(() => {
|
||||
// 模拟刷新数据
|
||||
const source = marketSources.find(s => s.id === sourceId);
|
||||
if (source) {
|
||||
message.success(`${source.name} 列表已刷新`);
|
||||
}
|
||||
setLoading(false);
|
||||
}, 600);
|
||||
};
|
||||
|
||||
const handleOpenConfig = (sourceId: string) => {
|
||||
const source = marketSources.find(s => s.id === sourceId);
|
||||
if (source) {
|
||||
marketConfigModalRef.current?.handleOpen(source);
|
||||
}
|
||||
};
|
||||
|
||||
const handleConnect = (sourceId: string, apiKey: string) => {
|
||||
// 更新市场源状态
|
||||
setMarketSources(prev => prev.map(source => {
|
||||
if (source.id === sourceId) {
|
||||
return {
|
||||
...source,
|
||||
apiKey,
|
||||
connected: true
|
||||
};
|
||||
}
|
||||
return source;
|
||||
}));
|
||||
|
||||
// 模拟获取MCP列表
|
||||
setTimeout(() => {
|
||||
const source = marketSources.find(s => s.id === sourceId);
|
||||
if (source && !mcpCache[sourceId]) {
|
||||
// 生成模拟数据
|
||||
const mockData: MarketMcp[] = [
|
||||
{ id: `${sourceId}-1`, name: `${source.name} 服务 1`, provider: source.name, type: 'Hosted', desc: `来自 ${source.name} 的 MCP 服务`, downloads: '10.2m', stars: '23.4k', icon: '🔧', configTemplate: {} },
|
||||
{ id: `${sourceId}-2`, name: `${source.name} 服务 2`, provider: source.name, type: 'Local', desc: `来自 ${source.name} 的本地 MCP 服务`, downloads: '8.5m', stars: '18.7k', icon: '⚙️', configTemplate: {} }
|
||||
];
|
||||
setMcpCache(prev => ({
|
||||
...prev,
|
||||
[sourceId]: mockData
|
||||
}));
|
||||
}
|
||||
message.success(`已连接 ${source?.name}`);
|
||||
}, 800);
|
||||
};
|
||||
|
||||
const renderSourceDetail = () => {
|
||||
if (!selectedSource) {
|
||||
return (
|
||||
<div className="rb:flex rb:flex-col rb:items-center rb:justify-center rb:h-full rb:text-center">
|
||||
<div className="rb:text-6xl rb:mb-4">🏪</div>
|
||||
<h3 className="rb:text-lg rb:font-semibold rb:text-gray-900 rb:mb-2">选择一个 MCP 市场</h3>
|
||||
<p className="rb:text-sm rb:text-gray-600 rb:max-w-md">从左侧选择一个市场源,配置连接后即可浏览该市场的 MCP 服务</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
const source = marketSources.find(s => s.id === selectedSource);
|
||||
if (!source) return null;
|
||||
|
||||
const mcpList = mcpCache[selectedSource] || [];
|
||||
const filteredList = mcpList.filter(mcp =>
|
||||
mcp.name.toLowerCase().includes(searchKeyword.toLowerCase()) ||
|
||||
mcp.desc.toLowerCase().includes(searchKeyword.toLowerCase())
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="rb:flex rb:justify-between rb:items-start rb:pb-6 rb:border-b rb:border-gray-200 rb:mb-6">
|
||||
<div className="rb:flex rb:gap-4">
|
||||
<div className="rb:text-5xl rb:w-16 rb:h-16 rb:flex rb:items-center rb:justify-center rb:bg-gray-50 rb:rounded-xl rb:flex-shrink-0">
|
||||
{source.icon}
|
||||
</div>
|
||||
<div className="rb:flex-1">
|
||||
<h2 className="rb:text-xl rb:font-semibold rb:text-gray-900 rb:mb-2">{source.name}</h2>
|
||||
<p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed">{source.desc}</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className="rb:flex rb:gap-3">
|
||||
<Button icon={<SettingOutlined />} onClick={() => handleOpenConfig(selectedSource)}>
|
||||
配置
|
||||
</Button>
|
||||
<Button type="primary" icon={<GlobalOutlined />} onClick={() => window.open(source.url, '_blank')}>
|
||||
前往市场
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="rb:mt-6">
|
||||
<div className="rb:flex rb:justify-between rb:items-center rb:mb-5">
|
||||
<h3 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:m-0">
|
||||
可用 MCP 服务 <span className="rb:text-gray-600 rb:font-normal">({mcpList.length})</span>
|
||||
</h3>
|
||||
<div className="rb:flex rb:gap-3 rb:items-center">
|
||||
{source.connected && (
|
||||
<Button size="small" icon={<SyncOutlined />} onClick={() => handleRefresh(selectedSource)}>
|
||||
刷新
|
||||
</Button>
|
||||
)}
|
||||
{mcpList.length > 0 && (
|
||||
<Input
|
||||
prefix={<SearchOutlined />}
|
||||
placeholder="搜索服务..."
|
||||
value={searchKeyword}
|
||||
onChange={(e) => setSearchKeyword(e.target.value)}
|
||||
style={{ width: 200 }}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{mcpList.length > 0 ? (
|
||||
<Spin spinning={loading}>
|
||||
<div className="rb:grid rb:grid-cols-1 md:rb:grid-cols-2 lg:rb:grid-cols-3 rb:gap-4">
|
||||
{filteredList.map(mcp => (
|
||||
<div
|
||||
key={mcp.id}
|
||||
className="rb:bg-white rb:border rb:border-gray-200 rb:rounded-lg rb:p-4 rb:transition-all rb:duration-200 hover:rb:shadow-lg hover:rb:border-gray-300"
|
||||
>
|
||||
<div className="rb:flex rb:justify-between rb:items-center rb:mb-3">
|
||||
<div className="rb:text-3xl rb:w-12 rb:h-12 rb:flex rb:items-center rb:justify-center rb:bg-gray-50 rb:rounded-lg">
|
||||
{mcp.icon}
|
||||
</div>
|
||||
<span className={`rb:px-2 rb:py-1 rb:rounded rb:text-xs rb:font-medium ${
|
||||
mcp.type === 'Hosted'
|
||||
? 'rb:bg-blue-50 rb:text-blue-700'
|
||||
: 'rb:bg-gray-100 rb:text-gray-600'
|
||||
}`}>
|
||||
{mcp.type}
|
||||
</span>
|
||||
</div>
|
||||
<h3 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:mb-1">{mcp.name}</h3>
|
||||
{mcp.provider && (
|
||||
<div className="rb:mb-2">
|
||||
<span className="rb:text-xs rb:text-gray-500">@ {mcp.provider}</span>
|
||||
</div>
|
||||
)}
|
||||
<p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed rb:mb-3 rb:min-h-[42px]">{mcp.desc}</p>
|
||||
<div className="rb:flex rb:gap-4 rb:mb-3 rb:pt-3 rb:border-t rb:border-gray-100">
|
||||
{mcp.downloads && (
|
||||
<span className="rb:flex rb:items-center rb:gap-1 rb:text-xs rb:text-gray-500">
|
||||
<GlobalOutlined /> {mcp.downloads}
|
||||
</span>
|
||||
)}
|
||||
{mcp.stars && (
|
||||
<span className="rb:flex rb:items-center rb:gap-1 rb:text-xs rb:text-gray-500">
|
||||
⭐ {mcp.stars}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<div className="rb:flex rb:justify-end">
|
||||
<Button type="primary" size="small">
|
||||
+ 添加
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</Spin>
|
||||
) : (
|
||||
<div className="rb:flex rb:flex-col rb:items-center rb:justify-center rb:py-16 rb:text-center">
|
||||
<div className="rb:text-6xl rb:mb-4">{source.connected ? '📭' : '🔌'}</div>
|
||||
<h4 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:mb-2">
|
||||
{source.connected ? '暂无可用的 MCP 服务' : '尚未连接此市场'}
|
||||
</h4>
|
||||
<p className="rb:text-sm rb:text-gray-600 rb:mb-4">
|
||||
{source.connected ? '该市场暂时没有可用的服务' : '点击右上角"配置"按钮设置连接信息'}
|
||||
</p>
|
||||
{!source.connected && (
|
||||
<Button type="primary" onClick={() => handleOpenConfig(selectedSource)}>
|
||||
配置连接
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="rb:flex rb:gap-4 rb:h-[calc(100vh-178px)]">
|
||||
{/* 左侧市场源列表 */}
|
||||
<div className="rb:w-70 rb:bg-white rb:rounded-lg rb:border rb:border-gray-200 rb:overflow-y-auto rb:flex-shrink-0">
|
||||
<div className="rb:p-4 rb:border-b rb:border-gray-200">
|
||||
<span className="rb:text-base rb:font-semibold rb:text-gray-900">MCP 市场</span>
|
||||
</div>
|
||||
{categories.map(cat => (
|
||||
<div key={cat.id} className="rb:py-3 rb:border-b rb:border-gray-100 last:rb:border-b-0">
|
||||
<div className="rb:flex rb:items-center rb:gap-2 rb:px-4 rb:py-2 rb:text-xs rb:font-medium rb:text-gray-500 rb:uppercase">
|
||||
<span className="rb:text-sm">{cat.icon}</span>
|
||||
<span>{cat.name}</span>
|
||||
</div>
|
||||
<div className="rb:px-2 rb:py-1">
|
||||
{marketSources
|
||||
.filter(s => s.category === cat.id)
|
||||
.map(source => (
|
||||
<div
|
||||
key={source.id}
|
||||
className={`rb:flex rb:items-center rb:gap-2 rb:px-3 rb:py-2.5 rb:rounded-md rb:cursor-pointer rb:transition-all rb:relative ${
|
||||
selectedSource === source.id
|
||||
? 'rb:bg-blue-50 rb:text-blue-600'
|
||||
: 'hover:rb:bg-gray-50'
|
||||
}`}
|
||||
onClick={() => handleSelectSource(source.id)}
|
||||
>
|
||||
<span className="rb:text-lg rb:flex-shrink-0">{source.icon}</span>
|
||||
<span className="rb:flex-1 rb:text-sm rb:font-medium rb:overflow-hidden rb:text-ellipsis rb:whitespace-nowrap">
|
||||
{source.name}
|
||||
</span>
|
||||
<span className="rb:text-xs rb:text-gray-500 rb:px-1.5 rb:py-0.5 rb:bg-gray-100 rb:rounded-full">
|
||||
{source.mcpCount}
|
||||
</span>
|
||||
{source.connected && (
|
||||
<span className="rb:text-green-500 rb:text-[8px] rb:ml-1">●</span>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
{/* 右侧内容区 */}
|
||||
<div className="rb:flex-1 rb:bg-white rb:rounded-lg rb:border rb:border-gray-200 rb:overflow-hidden">
|
||||
<div className="rb:h-full rb:overflow-y-auto rb:p-6">
|
||||
{renderSourceDetail()}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* 配置弹窗 */}
|
||||
<MarketConfigModal
|
||||
ref={marketConfigModalRef}
|
||||
onConnect={handleConnect}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default Market;
|
||||
@@ -6,6 +6,7 @@ import type { CustomToolItem, CustomToolModalRef, ToolItem } from '../types'
|
||||
import RbModal from '@/components/RbModal';
|
||||
import { parseSchema, addTool, updateTool } from '@/api/tools';
|
||||
import Table from '@/components/Table';
|
||||
import { stringRegExp } from '@/utils/validator';
|
||||
const FormItem = Form.Item;
|
||||
|
||||
interface CustomToolModalProps {
|
||||
@@ -134,7 +135,11 @@ const CustomToolModal = forwardRef<CustomToolModalRef, CustomToolModalProps>(({
|
||||
<Form.Item
|
||||
name="name"
|
||||
label={t('tool.name')}
|
||||
rules={[{ required: true, message: t('common.enterNamePlaceholder') }]}
|
||||
rules={[
|
||||
{ required: true, message: t('tool.enterNamePlaceholder') },
|
||||
{ max: 50 },
|
||||
{ pattern: stringRegExp, message: t('common.nameInvalid') },
|
||||
]}
|
||||
>
|
||||
<Input placeholder={t('tool.enterNamePlaceholder')} />
|
||||
</Form.Item>
|
||||
|
||||
173
web/src/views/ToolManagement/components/MarketConfigModal.tsx
Normal file
173
web/src/views/ToolManagement/components/MarketConfigModal.tsx
Normal file
@@ -0,0 +1,173 @@
|
||||
import { forwardRef, useImperativeHandle, useState } from 'react';
|
||||
import { Form, Input, Button, App, Space } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { CopyOutlined, EyeInvisibleOutlined, EyeOutlined } from '@ant-design/icons';
|
||||
import RbModal from '@/components/RbModal';
|
||||
|
||||
const FormItem = Form.Item;
|
||||
|
||||
interface MarketSource {
|
||||
id: string;
|
||||
name: string;
|
||||
icon: string;
|
||||
url: string;
|
||||
desc: string;
|
||||
apiKey: string;
|
||||
connected: boolean;
|
||||
}
|
||||
|
||||
interface MarketConfigModalProps {
|
||||
onConnect: (sourceId: string, apiKey: string) => void;
|
||||
}
|
||||
|
||||
export interface MarketConfigModalRef {
|
||||
handleOpen: (source: MarketSource) => void;
|
||||
handleClose: () => void;
|
||||
}
|
||||
|
||||
const MarketConfigModal = forwardRef<MarketConfigModalRef, MarketConfigModalProps>(({
|
||||
onConnect
|
||||
}, ref) => {
|
||||
const { t } = useTranslation();
|
||||
const { message } = App.useApp();
|
||||
const [visible, setVisible] = useState(false);
|
||||
const [form] = Form.useForm();
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [currentSource, setCurrentSource] = useState<MarketSource | null>(null);
|
||||
const [showApiKey, setShowApiKey] = useState(false);
|
||||
|
||||
const handleClose = () => {
|
||||
setVisible(false);
|
||||
form.resetFields();
|
||||
setLoading(false);
|
||||
setCurrentSource(null);
|
||||
setShowApiKey(false);
|
||||
};
|
||||
|
||||
const handleOpen = (source: MarketSource) => {
|
||||
setCurrentSource(source);
|
||||
form.setFieldsValue({
|
||||
url: source.url,
|
||||
apiKey: source.apiKey,
|
||||
});
|
||||
setVisible(true);
|
||||
};
|
||||
|
||||
const handleSave = () => {
|
||||
form
|
||||
.validateFields()
|
||||
.then((values) => {
|
||||
if (!currentSource) return;
|
||||
|
||||
setLoading(true);
|
||||
|
||||
// 模拟连接延迟
|
||||
setTimeout(() => {
|
||||
onConnect(currentSource.id, values.apiKey || '');
|
||||
message.success(`正在连接 ${currentSource.name}...`);
|
||||
setLoading(false);
|
||||
handleClose();
|
||||
}, 500);
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log('表单验证失败:', err);
|
||||
});
|
||||
};
|
||||
|
||||
const handleCopyUrl = () => {
|
||||
if (currentSource?.url) {
|
||||
navigator.clipboard.writeText(currentSource.url).then(() => {
|
||||
message.success(t('common.copySuccess'));
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
useImperativeHandle(ref, () => ({
|
||||
handleOpen,
|
||||
handleClose
|
||||
}));
|
||||
|
||||
if (!currentSource) return null;
|
||||
|
||||
return (
|
||||
<RbModal
|
||||
title={`配置 ${currentSource.name}`}
|
||||
open={visible}
|
||||
onCancel={handleClose}
|
||||
okText="保存并连接"
|
||||
onOk={handleSave}
|
||||
confirmLoading={loading}
|
||||
width={600}
|
||||
>
|
||||
<div>
|
||||
{/* 市场源信息头部 */}
|
||||
<div className="rb:flex rb:gap-4 rb:mb-6 rb:p-4 rb:bg-gray-50 rb:rounded-lg">
|
||||
<div className="rb:text-4xl rb:w-16 rb:h-16 rb:flex rb:items-center rb:justify-center rb:bg-white rb:rounded-lg rb:flex-shrink-0">
|
||||
{currentSource.icon}
|
||||
</div>
|
||||
<div className="rb:flex-1">
|
||||
<h3 className="rb:text-base rb:font-semibold rb:mb-1 rb:text-gray-900">{currentSource.name}</h3>
|
||||
<p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed">{currentSource.desc}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Form
|
||||
form={form}
|
||||
layout="vertical"
|
||||
>
|
||||
{/* 市场地址 */}
|
||||
<FormItem
|
||||
name="url"
|
||||
label="市场地址"
|
||||
>
|
||||
<Space.Compact style={{ width: '100%' }}>
|
||||
<Input
|
||||
readOnly
|
||||
placeholder="市场地址"
|
||||
/>
|
||||
<Button
|
||||
icon={<CopyOutlined />}
|
||||
onClick={handleCopyUrl}
|
||||
>
|
||||
复制
|
||||
</Button>
|
||||
</Space.Compact>
|
||||
</FormItem>
|
||||
|
||||
{/* API Key */}
|
||||
<FormItem
|
||||
name="apiKey"
|
||||
label={
|
||||
<span>
|
||||
API Key <span className="rb:text-gray-400 rb:font-normal">(可选)</span>
|
||||
</span>
|
||||
}
|
||||
extra="部分市场需要 API Key 才能获取完整的服务列表"
|
||||
>
|
||||
<Space.Compact style={{ width: '100%' }}>
|
||||
<Input
|
||||
type={showApiKey ? 'text' : 'password'}
|
||||
placeholder="输入 API Key 以获取更多服务"
|
||||
autoComplete="off"
|
||||
/>
|
||||
<Button
|
||||
icon={showApiKey ? <EyeInvisibleOutlined /> : <EyeOutlined />}
|
||||
onClick={() => setShowApiKey(!showApiKey)}
|
||||
/>
|
||||
</Space.Compact>
|
||||
</FormItem>
|
||||
|
||||
{/* 连接状态 */}
|
||||
<div className="rb:flex rb:items-center rb:gap-2 rb:p-3 rb:bg-gray-50 rb:rounded rb:text-sm">
|
||||
<span className="rb:text-gray-600">连接状态:</span>
|
||||
<span className={`rb:font-medium ${currentSource.connected ? 'rb:text-green-600' : 'rb:text-gray-400'}`}>
|
||||
{currentSource.connected ? '● 已连接' : '○ 未连接'}
|
||||
</span>
|
||||
</div>
|
||||
</Form>
|
||||
</div>
|
||||
</RbModal>
|
||||
);
|
||||
});
|
||||
|
||||
export default MarketConfigModal;
|
||||
@@ -9,6 +9,7 @@ import RequestHeaderModal from './RequestHeaderModal';
|
||||
import Table from '@/components/Table';
|
||||
import { addTool, updateTool, testConnection } from '@/api/tools'
|
||||
import type { McpServiceModalRef } from '../types'
|
||||
import { stringRegExp } from '@/utils/validator';
|
||||
|
||||
const FormItem = Form.Item;
|
||||
|
||||
@@ -168,14 +169,22 @@ const McpServiceModal = forwardRef<McpServiceModalRef, McpServiceModalProps>(({
|
||||
name={['config', "server_url"]}
|
||||
label={t('tool.serviceEndpoint')}
|
||||
extra={t('tool.serviceEndpointExtra')}
|
||||
rules={[{ required: true, message: t('common.pleaseEnter') }]}
|
||||
rules={[
|
||||
{ required: true, message: t('common.pleaseEnter') },
|
||||
{ max: 500 },
|
||||
{ pattern: /^https?:\/\/\S+$/, message: t('tool.serverUrlInvalid') },
|
||||
]}
|
||||
>
|
||||
<Input placeholder={t('tool.serviceEndpointPlaceholder')} />
|
||||
</FormItem>
|
||||
<Form.Item
|
||||
name="name"
|
||||
label={t('tool.name')}
|
||||
rules={[{ required: true, message: t('common.pleaseEnter') }]}
|
||||
rules={[
|
||||
{ required: true, message: t('common.pleaseEnter') },
|
||||
{ max: 50 },
|
||||
{ pattern: stringRegExp, message: t('common.nameInvalid') },
|
||||
]}
|
||||
>
|
||||
<Input placeholder={t('tool.namePlaceholder')} />
|
||||
</Form.Item>
|
||||
@@ -201,6 +210,7 @@ const McpServiceModal = forwardRef<McpServiceModalRef, McpServiceModalProps>(({
|
||||
<FormItem
|
||||
name="description"
|
||||
label={t('tool.description')}
|
||||
rules={[{ max: 500 }]}
|
||||
>
|
||||
<Input.TextArea rows={3} placeholder={t('common.inputPlaceholder', { title: t('tool.description') })}/>
|
||||
</FormItem>
|
||||
|
||||
@@ -4,6 +4,7 @@ import { useTranslation } from 'react-i18next';
|
||||
|
||||
import type { RequestHeader, RequestHeaderModalRef } from './McpServiceModal'
|
||||
import RbModal from '@/components/RbModal'
|
||||
import { stringRegExp } from '@/utils/validator';
|
||||
|
||||
const FormItem = Form.Item;
|
||||
|
||||
@@ -82,7 +83,11 @@ const RequestHeaderModal = forwardRef<RequestHeaderModalRef, RequestHeaderModalP
|
||||
<FormItem
|
||||
name="key"
|
||||
label={t('tool.requestHeaderName')}
|
||||
rules={[{ required: true, message: t('common.pleaseEnter') }]}
|
||||
rules={[
|
||||
{ required: true, message: t('common.pleaseEnter') },
|
||||
{ pattern: /^[a-zA-Z0-9][a-zA-Z0-9\-_]*[a-zA-Z0-9]$|^[a-zA-Z0-9]$/, message: t('tool.requestHeaderKeyInvalid') },
|
||||
{ max: 100 }
|
||||
]}
|
||||
>
|
||||
<Input placeholder={t('common.enter')} />
|
||||
</FormItem>
|
||||
@@ -90,7 +95,11 @@ const RequestHeaderModal = forwardRef<RequestHeaderModalRef, RequestHeaderModalP
|
||||
<FormItem
|
||||
name="value"
|
||||
label={t('tool.requestHeaderValue')}
|
||||
rules={[{ required: true, message: t('common.pleaseEnter') }]}
|
||||
rules={[
|
||||
{ required: true, message: t('common.pleaseEnter') },
|
||||
{ pattern: stringRegExp, message: t('common.nameInvalid') },
|
||||
{ max: 2000 }
|
||||
]}
|
||||
>
|
||||
<Input placeholder={t('common.enter',)} />
|
||||
</FormItem>
|
||||
|
||||
@@ -1,3 +1,11 @@
|
||||
/*
|
||||
* @Description:
|
||||
* @Version: 0.0.1
|
||||
* @Author: yujiangping
|
||||
* @Date: 2026-01-05 17:22:23
|
||||
* @LastEditors: yujiangping
|
||||
* @LastEditTime: 2026-03-04 15:12:48
|
||||
*/
|
||||
import React, { useState } from 'react';
|
||||
import { Tabs } from 'antd';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
@@ -5,9 +13,10 @@ import { useTranslation } from 'react-i18next';
|
||||
import Mcp from './Mcp';
|
||||
import Inner from './Inner';
|
||||
import Custom from './Custom';
|
||||
import Market from './Market';
|
||||
import Tag from '@/components/Tag'
|
||||
|
||||
const tabKeys = ['mcp', 'inner', 'custom']
|
||||
const tabKeys = ['mcp', 'inner', 'custom', 'market']
|
||||
const ToolManagement: React.FC = () => {
|
||||
const { t } = useTranslation();
|
||||
const [activeTab, setActiveTab] = useState('mcp');
|
||||
@@ -45,6 +54,7 @@ const ToolManagement: React.FC = () => {
|
||||
{activeTab === 'mcp' && <Mcp getStatusTag={getStatusTag} />}
|
||||
{activeTab === 'inner' && <Inner getStatusTag={getStatusTag} />}
|
||||
{activeTab === 'custom' && <Custom getStatusTag={getStatusTag} />}
|
||||
{/* {activeTab === 'market' && <Market getStatusTag={getStatusTag} />} */}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -6,6 +6,7 @@ import {
|
||||
getShortTerm,
|
||||
} from '@/api/memory'
|
||||
import Empty from '@/components/Empty'
|
||||
import Markdown from '@/components/Markdown'
|
||||
|
||||
interface ShortTermItem {
|
||||
retrieval: Array<{ query: string; retrieval: string[]; }>;
|
||||
@@ -85,7 +86,9 @@ const ShortTermDetail: FC = () => {
|
||||
))}
|
||||
<div>
|
||||
<div className="rb:font-medium rb:leading-5 rb:mb-1">{t('shortTermDetail.answer')}</div>
|
||||
<div className="rb:bg-[#FFFFFF] rb:border rb:border-[#DFE4ED] rb:rounded-md rb:px-3 rb:py-2.5 rb:leading-5">{vo.answer}</div>
|
||||
<div className="rb:bg-[#FFFFFF] rb:border rb:border-[#DFE4ED] rb:rounded-md rb:px-3 rb:py-2.5 rb:leading-5">
|
||||
<Markdown content={vo.answer} />
|
||||
</div>
|
||||
</div>
|
||||
</Space>
|
||||
</div>
|
||||
@@ -103,7 +106,9 @@ const ShortTermDetail: FC = () => {
|
||||
: data.long_term?.map((vo, voIdx) => (
|
||||
<div key={voIdx} className="rb:leading-5 rb:shadow-[inset_3px_0px_0px_0px_#155EEF] rb:bg-[#FBFDFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg rb:px-6 rb:py-3">
|
||||
<div className="rb:mb-1 rb:font-medium rb:leading-5.5">{vo.query}</div>
|
||||
<div className="rb:mt-1 rb:leading-5 rb:text-[#5B6167] rb:text-[12px]">{vo.retrieval}</div>
|
||||
<div className="rb:mt-1 rb:leading-5 rb:text-[#5B6167] rb:text-[12px]">
|
||||
<Markdown content={vo.retrieval} />
|
||||
</div>
|
||||
</div>
|
||||
))
|
||||
}
|
||||
|
||||
@@ -174,8 +174,8 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
|
||||
*/
|
||||
const handleStreamMessage = (data: SSEMessage[]) => {
|
||||
data.forEach(item => {
|
||||
const { chunk, conversation_id, node_id, cycle_id, cycle_idx, input, output, error, elapsed_time, status } = item.data as {
|
||||
chunk: string;
|
||||
const { content, conversation_id, node_id, cycle_id, cycle_idx, input, output, error, elapsed_time, status } = item.data as {
|
||||
content: string;
|
||||
conversation_id: string | null;
|
||||
cycle_id: string;
|
||||
cycle_idx: number;
|
||||
@@ -202,7 +202,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
|
||||
if (lastIndex >= 0) {
|
||||
newList[lastIndex] = {
|
||||
...newList[lastIndex],
|
||||
content: newList[lastIndex].content + chunk
|
||||
content: newList[lastIndex].content + content
|
||||
}
|
||||
}
|
||||
return newList
|
||||
|
||||
Reference in New Issue
Block a user