Merge branch 'develop' into feature/model_zy
This commit is contained in:
3
api/app/cache/__init__.py
vendored
3
api/app/cache/__init__.py
vendored
@@ -3,9 +3,10 @@ Cache 缓存模块
|
|||||||
|
|
||||||
提供各种缓存功能的统一入口
|
提供各种缓存功能的统一入口
|
||||||
"""
|
"""
|
||||||
from .memory import EmotionMemoryCache, ImplicitMemoryCache
|
from .memory import EmotionMemoryCache, ImplicitMemoryCache, InterestMemoryCache
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"EmotionMemoryCache",
|
"EmotionMemoryCache",
|
||||||
"ImplicitMemoryCache",
|
"ImplicitMemoryCache",
|
||||||
|
"InterestMemoryCache",
|
||||||
]
|
]
|
||||||
|
|||||||
2
api/app/cache/memory/__init__.py
vendored
2
api/app/cache/memory/__init__.py
vendored
@@ -5,8 +5,10 @@ Memory 缓存模块
|
|||||||
"""
|
"""
|
||||||
from .emotion_memory import EmotionMemoryCache
|
from .emotion_memory import EmotionMemoryCache
|
||||||
from .implicit_memory import ImplicitMemoryCache
|
from .implicit_memory import ImplicitMemoryCache
|
||||||
|
from .interest_memory import InterestMemoryCache
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"EmotionMemoryCache",
|
"EmotionMemoryCache",
|
||||||
"ImplicitMemoryCache",
|
"ImplicitMemoryCache",
|
||||||
|
"InterestMemoryCache",
|
||||||
]
|
]
|
||||||
|
|||||||
122
api/app/cache/memory/interest_memory.py
vendored
Normal file
122
api/app/cache/memory/interest_memory.py
vendored
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
"""
|
||||||
|
Interest Distribution Cache
|
||||||
|
|
||||||
|
兴趣分布缓存模块
|
||||||
|
用于缓存用户的兴趣分布标签数据,避免重复调用模型生成
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from app.aioRedis import aio_redis
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 缓存过期时间:24小时
|
||||||
|
INTEREST_CACHE_EXPIRE = 86400
|
||||||
|
|
||||||
|
|
||||||
|
class InterestMemoryCache:
|
||||||
|
"""兴趣分布缓存类"""
|
||||||
|
|
||||||
|
PREFIX = "cache:memory:interest_distribution"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_key(cls, end_user_id: str, language: str) -> str:
|
||||||
|
"""生成 Redis key
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 用户ID
|
||||||
|
language: 语言类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
完整的 Redis key
|
||||||
|
"""
|
||||||
|
return f"{cls.PREFIX}:by_user:{end_user_id}:{language}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def set_interest_distribution(
|
||||||
|
cls,
|
||||||
|
end_user_id: str,
|
||||||
|
language: str,
|
||||||
|
data: List[Dict[str, Any]],
|
||||||
|
expire: int = INTEREST_CACHE_EXPIRE,
|
||||||
|
) -> bool:
|
||||||
|
"""设置用户兴趣分布缓存
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 用户ID
|
||||||
|
language: 语言类型
|
||||||
|
data: 兴趣分布列表,格式 [{"name": "...", "frequency": ...}, ...]
|
||||||
|
expire: 过期时间(秒),默认24小时
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否设置成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
key = cls._get_key(end_user_id, language)
|
||||||
|
payload = {
|
||||||
|
"data": data,
|
||||||
|
"generated_at": datetime.now().isoformat(),
|
||||||
|
"cached": True,
|
||||||
|
}
|
||||||
|
value = json.dumps(payload, ensure_ascii=False)
|
||||||
|
await aio_redis.set(key, value, ex=expire)
|
||||||
|
logger.info(f"设置兴趣分布缓存成功: {key}, 过期时间: {expire}秒")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"设置兴趣分布缓存失败: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def get_interest_distribution(
|
||||||
|
cls,
|
||||||
|
end_user_id: str,
|
||||||
|
language: str,
|
||||||
|
) -> Optional[List[Dict[str, Any]]]:
|
||||||
|
"""获取用户兴趣分布缓存
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 用户ID
|
||||||
|
language: 语言类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
兴趣分布列表,缓存不存在或已过期返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
key = cls._get_key(end_user_id, language)
|
||||||
|
value = await aio_redis.get(key)
|
||||||
|
if value:
|
||||||
|
payload = json.loads(value)
|
||||||
|
logger.info(f"命中兴趣分布缓存: {key}")
|
||||||
|
return payload.get("data")
|
||||||
|
logger.info(f"兴趣分布缓存不存在或已过期: {key}")
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取兴趣分布缓存失败: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def delete_interest_distribution(
|
||||||
|
cls,
|
||||||
|
end_user_id: str,
|
||||||
|
language: str,
|
||||||
|
) -> bool:
|
||||||
|
"""删除用户兴趣分布缓存
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 用户ID
|
||||||
|
language: 语言类型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
是否删除成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
key = cls._get_key(end_user_id, language)
|
||||||
|
result = await aio_redis.delete(key)
|
||||||
|
logger.info(f"删除兴趣分布缓存: {key}, 结果: {result}")
|
||||||
|
return result > 0
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除兴趣分布缓存失败: {e}", exc_info=True)
|
||||||
|
return False
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
import os
|
import os
|
||||||
import platform
|
import platform
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from celery.schedules import crontab
|
||||||
from urllib.parse import quote
|
from urllib.parse import quote
|
||||||
|
|
||||||
from celery import Celery
|
from celery import Celery
|
||||||
@@ -43,8 +44,8 @@ celery_app.conf.update(
|
|||||||
task_ignore_result=False,
|
task_ignore_result=False,
|
||||||
|
|
||||||
# 超时设置
|
# 超时设置
|
||||||
task_time_limit=1800, # 30分钟硬超时
|
task_time_limit=3600, # 60分钟硬超时
|
||||||
task_soft_time_limit=1500, # 25分钟软超时
|
task_soft_time_limit=3000, # 50分钟软超时
|
||||||
|
|
||||||
# Worker 设置 (per-worker settings are in docker-compose command line)
|
# Worker 设置 (per-worker settings are in docker-compose command line)
|
||||||
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
|
||||||
@@ -90,11 +91,10 @@ celery_app.conf.update(
|
|||||||
celery_app.autodiscover_tasks(['app'])
|
celery_app.autodiscover_tasks(['app'])
|
||||||
|
|
||||||
# Celery Beat schedule for periodic tasks
|
# Celery Beat schedule for periodic tasks
|
||||||
memory_increment_schedule = timedelta(hours=settings.MEMORY_INCREMENT_INTERVAL_HOURS)
|
memory_increment_schedule = crontab(hour=settings.MEMORY_INCREMENT_HOUR, minute=settings.MEMORY_INCREMENT_MINUTE)
|
||||||
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
memory_cache_regeneration_schedule = timedelta(hours=settings.MEMORY_CACHE_REGENERATION_HOURS)
|
||||||
# 这个30秒的设计不合理
|
workspace_reflection_schedule = timedelta(seconds=settings.WORKSPACE_REFLECTION_INTERVAL_SECONDS)
|
||||||
workspace_reflection_schedule = timedelta(seconds=30) # 每30秒运行一次settings.REFLECTION_INTERVAL_TIME
|
forgetting_cycle_schedule = timedelta(hours=settings.FORGETTING_CYCLE_INTERVAL_HOURS)
|
||||||
forgetting_cycle_schedule = timedelta(hours=24) # 每24小时运行一次遗忘周期
|
|
||||||
|
|
||||||
#构建定时任务配置
|
#构建定时任务配置
|
||||||
beat_schedule_config = {
|
beat_schedule_config = {
|
||||||
|
|||||||
@@ -441,14 +441,14 @@ async def retrieve_chunks(
|
|||||||
# 1 participle search, 2 semantic search, 3 hybrid search
|
# 1 participle search, 2 semantic search, 3 hybrid search
|
||||||
match retrieve_data.retrieve_type:
|
match retrieve_data.retrieve_type:
|
||||||
case chunk_schema.RetrieveType.PARTICIPLE:
|
case chunk_schema.RetrieveType.PARTICIPLE:
|
||||||
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
|
rs = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||||
return success(data=rs, msg="retrieval successful")
|
return success(data=rs, msg="retrieval successful")
|
||||||
case chunk_schema.RetrieveType.SEMANTIC:
|
case chunk_schema.RetrieveType.SEMANTIC:
|
||||||
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
|
rs = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||||
return success(data=rs, msg="retrieval successful")
|
return success(data=rs, msg="retrieval successful")
|
||||||
case _:
|
case _:
|
||||||
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight)
|
rs1 = vector_service.search_by_vector(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.vector_similarity_weight, file_names_filter=retrieve_data.file_names_filter)
|
||||||
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold)
|
rs2 = vector_service.search_by_full_text(query=retrieve_data.query, top_k=retrieve_data.top_k, indices=indices, score_threshold=retrieve_data.similarity_threshold, file_names_filter=retrieve_data.file_names_filter)
|
||||||
# Efficient deduplication
|
# Efficient deduplication
|
||||||
seen_ids = set()
|
seen_ids = set()
|
||||||
unique_rs = []
|
unique_rs = []
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from app.cache.memory.interest_memory import InterestMemoryCache
|
||||||
from app.celery_app import celery_app
|
from app.celery_app import celery_app
|
||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.language_utils import get_language_from_header
|
from app.core.language_utils import get_language_from_header
|
||||||
@@ -661,34 +662,56 @@ async def get_knowledge_type_stats_api(
|
|||||||
return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "获取知识库类型统计失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/hot_memory_tags/by_user", response_model=ApiResponse)
|
@router.get("/analytics/interest_distribution/by_user", response_model=ApiResponse)
|
||||||
async def get_hot_memory_tags_by_user_api(
|
async def get_interest_distribution_by_user_api(
|
||||||
end_user_id: Optional[str] = Query(None, description="用户ID(可选)"),
|
end_user_id: str = Query(..., description="用户ID(必填)"),
|
||||||
limit: int = Query(20, description="返回标签数量限制"),
|
limit: int = Query(5, le=5, description="返回兴趣标签数量限制,最多5个"),
|
||||||
|
language_type: str = Header(default=None, alias="X-Language-Type"),
|
||||||
current_user: User = Depends(get_current_user),
|
current_user: User = Depends(get_current_user),
|
||||||
db: Session=Depends(get_db),
|
db: Session = Depends(get_db),
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
获取指定用户的热门记忆标签
|
获取指定用户的兴趣分布标签
|
||||||
|
|
||||||
注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译
|
与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习、创作等),
|
||||||
|
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
|
||||||
|
|
||||||
返回格式:
|
返回格式:
|
||||||
[
|
[
|
||||||
{"name": "标签名", "frequency": 频次},
|
{"name": "兴趣活动名", "frequency": 频次},
|
||||||
...
|
...
|
||||||
]
|
]
|
||||||
"""
|
"""
|
||||||
api_logger.info(f"Hot memory tags by user requested: end_user_id={end_user_id}")
|
language = get_language_from_header(language_type)
|
||||||
|
api_logger.info(f"Interest distribution by user requested: end_user_id={end_user_id}, language={language}")
|
||||||
try:
|
try:
|
||||||
result = await memory_agent_service.get_hot_memory_tags_by_user(
|
# 优先读取缓存
|
||||||
|
cached = await InterestMemoryCache.get_interest_distribution(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
limit=limit
|
language=language,
|
||||||
)
|
)
|
||||||
return success(data=result, msg="获取热门记忆标签成功")
|
if cached is not None:
|
||||||
|
api_logger.info(f"Interest distribution cache hit: end_user_id={end_user_id}")
|
||||||
|
return success(data=cached, msg="获取兴趣分布标签成功")
|
||||||
|
|
||||||
|
# 缓存未命中,调用模型生成
|
||||||
|
result = await memory_agent_service.get_interest_distribution_by_user(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
limit=limit,
|
||||||
|
language=language
|
||||||
|
)
|
||||||
|
|
||||||
|
# 写入缓存,24小时过期
|
||||||
|
await InterestMemoryCache.set_interest_distribution(
|
||||||
|
end_user_id=end_user_id,
|
||||||
|
language=language,
|
||||||
|
data=result,
|
||||||
|
)
|
||||||
|
|
||||||
|
return success(data=result, msg="获取兴趣分布标签成功")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
api_logger.error(f"Hot memory tags by user failed: {str(e)}")
|
api_logger.error(f"Interest distribution by user failed: {str(e)}")
|
||||||
return fail(BizCode.INTERNAL_ERROR, "获取热门记忆标签失败", str(e))
|
return fail(BizCode.INTERNAL_ERROR, "获取兴趣分布标签失败", str(e))
|
||||||
|
|
||||||
|
|
||||||
@router.get("/analytics/user_profile", response_model=ApiResponse)
|
@router.get("/analytics/user_profile", response_model=ApiResponse)
|
||||||
|
|||||||
@@ -606,8 +606,8 @@ async def dashboard_data(
|
|||||||
|
|
||||||
# 获取RAG相关数据
|
# 获取RAG相关数据
|
||||||
try:
|
try:
|
||||||
# total_memory: 使用 total_chunk(总chunk数)
|
# total_memory: 只统计用户知识库(permission_id='Memory')的chunk数
|
||||||
total_chunk = memory_dashboard_service.get_rag_total_chunk(db, current_user)
|
total_chunk = memory_dashboard_service.get_rag_user_kb_total_chunk(db, current_user)
|
||||||
rag_data["total_memory"] = total_chunk
|
rag_data["total_memory"] = total_chunk
|
||||||
|
|
||||||
# total_app: 统计当前空间下的所有app数量
|
# total_app: 统计当前空间下的所有app数量
|
||||||
|
|||||||
@@ -249,6 +249,7 @@ async def chat(
|
|||||||
app_id=app.id,
|
app_id=app.id,
|
||||||
workspace_id=workspace_id,
|
workspace_id=workspace_id,
|
||||||
release_id=app.current_release.id,
|
release_id=app.current_release.id,
|
||||||
|
public=True
|
||||||
):
|
):
|
||||||
event_type = event.get("event", "message")
|
event_type = event.get("event", "message")
|
||||||
event_data = event.get("data", {})
|
event_data = event.get("data", {})
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ async def write_memory_api_service(
|
|||||||
|
|
||||||
Stores memory content for the specified end user using the Memory API Service.
|
Stores memory content for the specified end user using the Memory API Service.
|
||||||
"""
|
"""
|
||||||
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, tenant_id: {api_key_auth.tenant_id}")
|
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
|
||||||
|
|
||||||
memory_api_service = MemoryAPIService(db)
|
memory_api_service = MemoryAPIService(db)
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional
|
from typing import Annotated, Any, Dict, Optional
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from pydantic import Field, TypeAdapter
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
@@ -200,12 +201,25 @@ class Settings:
|
|||||||
|
|
||||||
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
REFLECTION_INTERVAL_SECONDS: float = float(os.getenv("REFLECTION_INTERVAL_SECONDS", "300"))
|
||||||
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
HEALTH_CHECK_SECONDS: float = float(os.getenv("HEALTH_CHECK_SECONDS", "600"))
|
||||||
MEMORY_INCREMENT_INTERVAL_HOURS: float = float(os.getenv("MEMORY_INCREMENT_INTERVAL_HOURS", "24"))
|
|
||||||
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
REFLECTION_INTERVAL_TIME: Optional[str] = int(os.getenv("REFLECTION_INTERVAL_TIME", 30))
|
||||||
|
|
||||||
# Memory Cache Regeneration Configuration
|
# Memory Cache Regeneration Configuration
|
||||||
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
MEMORY_CACHE_REGENERATION_HOURS: int = int(os.getenv("MEMORY_CACHE_REGENERATION_HOURS", "24"))
|
||||||
|
|
||||||
|
# Celery Beat Schedule Configuration (定时任务执行频率)
|
||||||
|
MEMORY_INCREMENT_HOUR: int = TypeAdapter(
|
||||||
|
Annotated[int, Field(ge=0, le=23, description="cron hour [0, 23]")]
|
||||||
|
).validate_python(int(os.getenv("MEMORY_INCREMENT_HOUR", "2")))
|
||||||
|
MEMORY_INCREMENT_MINUTE: int = TypeAdapter(
|
||||||
|
Annotated[int, Field(ge=0, le=59, description="cron minute [0, 59]")]
|
||||||
|
).validate_python(int(os.getenv("MEMORY_INCREMENT_MINUTE", "0")))
|
||||||
|
WORKSPACE_REFLECTION_INTERVAL_SECONDS: int = TypeAdapter(
|
||||||
|
Annotated[int, Field(ge=1, description="reflection interval in seconds, must be >= 1")]
|
||||||
|
).validate_python(int(os.getenv("WORKSPACE_REFLECTION_INTERVAL_SECONDS", "30")))
|
||||||
|
FORGETTING_CYCLE_INTERVAL_HOURS: int = TypeAdapter(
|
||||||
|
Annotated[int, Field(ge=1, description="forgetting cycle interval in hours, must be >= 1")]
|
||||||
|
).validate_python(int(os.getenv("FORGETTING_CYCLE_INTERVAL_HOURS", "24")))
|
||||||
|
|
||||||
# Memory Module Configuration (internal)
|
# Memory Module Configuration (internal)
|
||||||
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
MEMORY_OUTPUT_DIR: str = os.getenv("MEMORY_OUTPUT_DIR", "logs/memory-output")
|
||||||
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
MEMORY_CONFIG_DIR: str = os.getenv("MEMORY_CONFIG_DIR", "app/core/memory")
|
||||||
@@ -230,7 +244,7 @@ class Settings:
|
|||||||
# General Ontology Type Configuration
|
# General Ontology Type Configuration
|
||||||
# ========================================================================
|
# ========================================================================
|
||||||
# 通用本体文件路径列表(逗号分隔)
|
# 通用本体文件路径列表(逗号分隔)
|
||||||
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "app/core/memory/ontology_services/General_purpose_entity.ttl")
|
GENERAL_ONTOLOGY_FILES: str = os.getenv("GENERAL_ONTOLOGY_FILES", "api/app/core/memory/ontology_services/General_purpose_entity.ttl")
|
||||||
|
|
||||||
# 是否启用通用本体类型功能
|
# 是否启用通用本体类型功能
|
||||||
ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true"
|
ENABLE_GENERAL_ONTOLOGY_TYPES: bool = os.getenv("ENABLE_GENERAL_ONTOLOGY_TYPES", "true").lower() == "true"
|
||||||
|
|||||||
@@ -1,9 +1,12 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||||
@@ -16,6 +19,10 @@ class FilteredTags(BaseModel):
|
|||||||
"""用于接收LLM筛选后的核心标签列表的模型。"""
|
"""用于接收LLM筛选后的核心标签列表的模型。"""
|
||||||
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
|
meaningful_tags: List[str] = Field(..., description="从原始列表中筛选出的具有核心代表意义的名词列表。")
|
||||||
|
|
||||||
|
class InterestTags(BaseModel):
|
||||||
|
"""用于接收LLM筛选后的兴趣活动标签列表的模型。"""
|
||||||
|
interest_tags: List[str] = Field(..., description="从原始列表中筛选出的代表用户兴趣活动的标签列表。")
|
||||||
|
|
||||||
async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
||||||
"""
|
"""
|
||||||
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
使用LLM筛选标签列表,仅保留具有代表性的核心名词。
|
||||||
@@ -85,10 +92,74 @@ async def filter_tags_with_llm(tags: List[str], end_user_id: str) -> List[str]:
|
|||||||
return structured_response.meaningful_tags
|
return structured_response.meaningful_tags
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"LLM筛选过程中发生错误: {e}")
|
logger.error(f"LLM筛选过程中发生错误: {e}", exc_info=True)
|
||||||
# 在LLM失败时返回原始标签,确保流程继续
|
# 在LLM失败时返回原始标签,确保流程继续
|
||||||
return tags
|
return tags
|
||||||
|
|
||||||
|
async def filter_interests_with_llm(tags: List[str], end_user_id: str, language: str = "zh") -> List[str]:
|
||||||
|
"""
|
||||||
|
使用LLM从标签列表中筛选出代表用户兴趣活动的标签。
|
||||||
|
|
||||||
|
与 filter_tags_with_llm 不同,此函数专注于识别"活动/行为"类兴趣,
|
||||||
|
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tags: 原始标签列表
|
||||||
|
end_user_id: 用户ID,用于获取LLM配置
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
筛选后的兴趣活动标签列表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with get_db_context() as db:
|
||||||
|
from app.services.memory_agent_service import (
|
||||||
|
get_end_user_connected_config,
|
||||||
|
)
|
||||||
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
workspace_id = connected_config.get("workspace_id")
|
||||||
|
|
||||||
|
if not config_id and not workspace_id:
|
||||||
|
raise ValueError(
|
||||||
|
f"No memory_config_id found for end_user_id: {end_user_id}."
|
||||||
|
)
|
||||||
|
|
||||||
|
config_service = MemoryConfigService(db)
|
||||||
|
memory_config = config_service.load_memory_config(
|
||||||
|
config_id=config_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not memory_config.llm_model_id:
|
||||||
|
raise ValueError(
|
||||||
|
f"No llm_model_id found in memory config {config_id}."
|
||||||
|
)
|
||||||
|
|
||||||
|
factory = MemoryClientFactory(db)
|
||||||
|
llm_client = factory.get_llm_client(memory_config.llm_model_id)
|
||||||
|
|
||||||
|
tag_list_str = ", ".join(tags)
|
||||||
|
from app.core.memory.utils.prompt.prompt_utils import render_interest_filter_prompt
|
||||||
|
rendered_prompt = render_interest_filter_prompt(tag_list_str, language=language)
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": rendered_prompt
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
structured_response = await llm_client.response_structured(
|
||||||
|
messages=messages,
|
||||||
|
response_model=InterestTags
|
||||||
|
)
|
||||||
|
|
||||||
|
return structured_response.interest_tags
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"兴趣标签LLM筛选过程中发生错误: {e}", exc_info=True)
|
||||||
|
return tags
|
||||||
|
|
||||||
|
|
||||||
async def get_raw_tags_from_db(
|
async def get_raw_tags_from_db(
|
||||||
connector: Neo4jConnector,
|
connector: Neo4jConnector,
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
@@ -183,3 +254,56 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool =
|
|||||||
finally:
|
finally:
|
||||||
# 确保关闭连接
|
# 确保关闭连接
|
||||||
await connector.close()
|
await connector.close()
|
||||||
|
|
||||||
|
async def get_interest_distribution(end_user_id: str, limit: int = 10, by_user: bool = False, language: str = "zh") -> List[Tuple[str, int]]:
|
||||||
|
"""
|
||||||
|
获取用户的兴趣分布标签。
|
||||||
|
|
||||||
|
与 get_hot_memory_tags 不同,此函数使用专门针对"活动/行为"的LLM prompt,
|
||||||
|
过滤掉纯物品、工具、地点等,只保留能代表用户兴趣爱好的活动类标签。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||||
|
limit: 最终返回的标签数量限制(默认10)
|
||||||
|
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: 如果end_user_id未提供或为空
|
||||||
|
"""
|
||||||
|
if not end_user_id or not end_user_id.strip():
|
||||||
|
raise ValueError(
|
||||||
|
"end_user_id is required. Please provide a valid end_user_id or user_id."
|
||||||
|
)
|
||||||
|
|
||||||
|
connector = Neo4jConnector()
|
||||||
|
try:
|
||||||
|
# 查询更多原始标签,给LLM提供充足上下文
|
||||||
|
query_limit = 40
|
||||||
|
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user)
|
||||||
|
if not raw_tags_with_freq:
|
||||||
|
return []
|
||||||
|
|
||||||
|
raw_tag_names = [tag for tag, freq in raw_tags_with_freq]
|
||||||
|
raw_freq_map = {tag: freq for tag, freq in raw_tags_with_freq}
|
||||||
|
|
||||||
|
# 使用兴趣活动专用prompt进行筛选(支持语义推断出新标签)
|
||||||
|
interest_tag_names = await filter_interests_with_llm(raw_tag_names, end_user_id, language=language)
|
||||||
|
|
||||||
|
# 构建最终标签列表:
|
||||||
|
# - 原始标签中存在的,保留原始频率
|
||||||
|
# - LLM推断出的新标签(不在原始列表中),赋予默认频率1
|
||||||
|
final_tags = []
|
||||||
|
seen = set()
|
||||||
|
for tag in interest_tag_names:
|
||||||
|
if tag in seen:
|
||||||
|
continue
|
||||||
|
seen.add(tag)
|
||||||
|
freq = raw_freq_map.get(tag, 1)
|
||||||
|
final_tags.append((tag, freq))
|
||||||
|
|
||||||
|
# 按频率降序排列
|
||||||
|
final_tags.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
return final_tags[:limit]
|
||||||
|
finally:
|
||||||
|
await connector.close()
|
||||||
|
|||||||
@@ -548,3 +548,20 @@ async def render_ontology_extraction_prompt(
|
|||||||
})
|
})
|
||||||
|
|
||||||
return rendered_prompt
|
return rendered_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def render_interest_filter_prompt(tag_list: str, language: str = "zh") -> str:
|
||||||
|
"""
|
||||||
|
Renders the interest filter prompt using the interest_filter.jinja2 template.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tag_list: Comma-separated string of raw tags to filter
|
||||||
|
language: Output language ("zh" for Chinese, "en" for English)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Rendered prompt content as string
|
||||||
|
"""
|
||||||
|
template = prompt_env.get_template("interest_filter.jinja2")
|
||||||
|
rendered_prompt = template.render(tag_list=tag_list, language=language)
|
||||||
|
log_prompt_rendering('interest filter', rendered_prompt)
|
||||||
|
return rendered_prompt
|
||||||
|
|||||||
@@ -0,0 +1,67 @@
|
|||||||
|
{% if language == "zh" %}
|
||||||
|
You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent.
|
||||||
|
|
||||||
|
**Step 1 - Infer the underlying interest from each tag**:
|
||||||
|
Look at each tag and ask: "What hobby or interest does this tag suggest the user has?"
|
||||||
|
|
||||||
|
Examples of inference:
|
||||||
|
- '攀岩', '室内攀岩馆', '攀岩者数据仪表盘', '路线解锁地图', '指力', '路线等级', '当日攀岩流畅度' → '攀岩'
|
||||||
|
- '风光摄影元数据增强器', 'EXIF数据', '.CR2文件', '.NEF文件', '日出拍摄点', '曝光补偿', '光圈', '太阳高度角', '云量预测图层' → '摄影'
|
||||||
|
- '晨间冥想坚持天数', '身心协同峰值' → '冥想'
|
||||||
|
- '川味可视化', '川菜' → '烹饪'
|
||||||
|
- '开源项目命名建议', 'climbviz', '可视化', '力量增长雷达图' → '编程' 或 '数据可视化'
|
||||||
|
- '吉他', '指弹', '琴谱' → '吉他'
|
||||||
|
- '跑步', '5公里', '跑鞋' → '跑步'
|
||||||
|
- '瑜伽垫', '瑜伽课' → '瑜伽'
|
||||||
|
|
||||||
|
**Step 2 - Consolidate and deduplicate**:
|
||||||
|
- Merge tags that point to the same interest into one representative label
|
||||||
|
- Use concise, standard hobby names (e.g., '攀岩', '摄影', '编程', '烹饪', '冥想', '吉他', '跑步')
|
||||||
|
- If multiple tags all point to '攀岩', output '攀岩' only once
|
||||||
|
|
||||||
|
**Step 3 - Filter out non-interest tags**:
|
||||||
|
Remove tags that do NOT suggest any hobby or interest:
|
||||||
|
- Generic system/assistant terms (e.g., '助手', '用户', 'AI')
|
||||||
|
- Pure abstract metrics with no clear hobby link (e.g., '完成时间', '日期', '自我评分')
|
||||||
|
- Location names with no clear hobby link (e.g., '青城山后山' alone — but if combined with photography context, infer '摄影')
|
||||||
|
|
||||||
|
**Output format**: Return a list of concise interest activity names in Chinese.
|
||||||
|
|
||||||
|
**Example**:
|
||||||
|
Input: ['攀岩', '攀岩者数据仪表盘', '路线解锁地图', '指力', '风光摄影元数据增强器', 'EXIF数据', '晨间冥想坚持天数', '川味可视化', '可视化', '助手', '完成时间']
|
||||||
|
Output: ['攀岩', '摄影', '冥想', '烹饪', '编程']
|
||||||
|
|
||||||
|
Now process the following tag list and return the inferred interest activities in Chinese: {{ tag_list }}
|
||||||
|
{% else %}
|
||||||
|
You are a user interest analysis expert. Your task is to infer and extract the user's core hobby/interest activities from a tag list. The tags may be specific project names, tool names, or compound nouns — your job is to identify the underlying interest they represent.
|
||||||
|
|
||||||
|
**Step 1 - Infer the underlying interest from each tag**:
|
||||||
|
Look at each tag and ask: "What hobby or interest does this tag suggest the user has?"
|
||||||
|
|
||||||
|
Examples of inference:
|
||||||
|
- 'rock climbing', 'indoor climbing gym', 'climber dashboard', 'route map', 'finger strength' → 'rock climbing'
|
||||||
|
- 'landscape photography metadata enhancer', 'EXIF data', 'sunrise shooting spot', 'exposure compensation' → 'photography'
|
||||||
|
- 'morning meditation streak', 'mind-body peak' → 'meditation'
|
||||||
|
- 'Sichuan cuisine visualization', 'Sichuan food' → 'cooking'
|
||||||
|
- 'open source project', 'data visualization tool', 'Python' → 'programming'
|
||||||
|
- 'guitar', 'fingerpicking', 'sheet music' → 'guitar'
|
||||||
|
- 'running', '5km', 'running shoes' → 'running'
|
||||||
|
|
||||||
|
**Step 2 - Consolidate and deduplicate**:
|
||||||
|
- Merge tags that point to the same interest into one representative label
|
||||||
|
- Use concise, standard hobby names (e.g., 'rock climbing', 'photography', 'programming', 'cooking', 'meditation')
|
||||||
|
- If multiple tags all point to 'rock climbing', output 'rock climbing' only once
|
||||||
|
|
||||||
|
**Step 3 - Filter out non-interest tags**:
|
||||||
|
Remove tags that do NOT suggest any hobby or interest:
|
||||||
|
- Generic system/assistant terms (e.g., 'assistant', 'user', 'AI')
|
||||||
|
- Pure abstract metrics with no clear hobby link (e.g., 'completion time', 'date', 'self-rating')
|
||||||
|
|
||||||
|
**Output format**: Return a list of concise interest activity names in English.
|
||||||
|
|
||||||
|
**Example**:
|
||||||
|
Input: ['rock climbing', 'climber dashboard', 'route map', 'finger strength', 'landscape photography metadata enhancer', 'EXIF data', 'morning meditation streak', 'Sichuan cuisine visualization', 'visualization', 'assistant', 'completion time']
|
||||||
|
Output: ['rock climbing', 'photography', 'meditation', 'cooking', 'programming']
|
||||||
|
|
||||||
|
Now process the following tag list and return the inferred interest activities in English: {{ tag_list }}
|
||||||
|
{% endif %}
|
||||||
@@ -127,7 +127,7 @@ class EventStreamHandler:
|
|||||||
yield {
|
yield {
|
||||||
"event": "message",
|
"event": "message",
|
||||||
"data": {
|
"data": {
|
||||||
"chunk": data.get("chunk")
|
"content": data.get("chunk")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -274,7 +274,7 @@ class StreamOutputCoordinator:
|
|||||||
yield {
|
yield {
|
||||||
"event": "message",
|
"event": "message",
|
||||||
"data": {
|
"data": {
|
||||||
"chunk": final_chunk
|
"content": final_chunk
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -272,7 +272,7 @@ class WorkflowExecutor:
|
|||||||
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
|
event_type = data.get("type", "node_chunk") # "message" or "node_chunk"
|
||||||
if event_type == "node_chunk":
|
if event_type == "node_chunk":
|
||||||
async for msg_event in self.event_handler.handle_node_chunk_event(data):
|
async for msg_event in self.event_handler.handle_node_chunk_event(data):
|
||||||
full_content += msg_event["data"]["chunk"]
|
full_content += msg_event["data"]["content"]
|
||||||
yield msg_event
|
yield msg_event
|
||||||
|
|
||||||
elif event_type == "node_error":
|
elif event_type == "node_error":
|
||||||
@@ -295,12 +295,12 @@ class WorkflowExecutor:
|
|||||||
self.graph,
|
self.graph,
|
||||||
self.execution_context.checkpoint_config
|
self.execution_context.checkpoint_config
|
||||||
):
|
):
|
||||||
full_content += msg_event["data"]['chunk']
|
full_content += msg_event["data"]['content']
|
||||||
yield msg_event
|
yield msg_event
|
||||||
|
|
||||||
# Flush any remaining chunks
|
# Flush any remaining chunks
|
||||||
async for msg_event in self.stream_coordinator.flush_remaining_chunk(self.variable_pool):
|
async for msg_event in self.stream_coordinator.flush_remaining_chunk(self.variable_pool):
|
||||||
full_content += msg_event["data"]['chunk']
|
full_content += msg_event["data"]['content']
|
||||||
yield msg_event
|
yield msg_event
|
||||||
|
|
||||||
result = graph.get_state(self.execution_context.checkpoint_config).values
|
result = graph.get_state(self.execution_context.checkpoint_config).values
|
||||||
|
|||||||
@@ -211,3 +211,46 @@ def get_total_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
db_logger.error(f"Failed to query total knowledge base count: workspace_id={workspace_id} - {str(e)}")
|
db_logger.error(f"Failed to query total knowledge base count: workspace_id={workspace_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_kb_chunk_num_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
|
||||||
|
"""
|
||||||
|
根据workspace_id查询knowledges表中permission_id='Memory'(用户知识库)的chunk_num总和
|
||||||
|
"""
|
||||||
|
db_logger.debug(f"Query user KB chunk_num by workspace_id: workspace_id={workspace_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from sqlalchemy import func
|
||||||
|
result = db.query(func.sum(Knowledge.chunk_num)).filter(
|
||||||
|
Knowledge.workspace_id == workspace_id,
|
||||||
|
Knowledge.status == 1,
|
||||||
|
Knowledge.permission_id == "Memory"
|
||||||
|
).scalar()
|
||||||
|
|
||||||
|
total = result if result is not None else 0
|
||||||
|
db_logger.info(f"User KB chunk_num query successful: workspace_id={workspace_id}, total={total}")
|
||||||
|
return total
|
||||||
|
except Exception as e:
|
||||||
|
db_logger.error(f"Failed to query user KB chunk_num: workspace_id={workspace_id} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def get_non_user_kb_count_by_workspace(db: Session, workspace_id: uuid.UUID) -> int:
|
||||||
|
"""
|
||||||
|
根据workspace_id查询knowledges表中排除用户知识库(permission_id!='Memory')的数量
|
||||||
|
"""
|
||||||
|
db_logger.debug(f"Query non-user KB count by workspace_id: workspace_id={workspace_id}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
count = db.query(Knowledge).filter(
|
||||||
|
Knowledge.workspace_id == workspace_id,
|
||||||
|
Knowledge.status == 1,
|
||||||
|
Knowledge.permission_id != "Memory"
|
||||||
|
).count()
|
||||||
|
|
||||||
|
db_logger.info(f"Non-user KB count query successful: workspace_id={workspace_id}, count={count}")
|
||||||
|
return count
|
||||||
|
except Exception as e:
|
||||||
|
db_logger.error(f"Failed to query non-user KB count: workspace_id={workspace_id} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ class ChunkUpdate(BaseModel):
|
|||||||
class ChunkRetrieve(BaseModel):
|
class ChunkRetrieve(BaseModel):
|
||||||
query: str
|
query: str
|
||||||
kb_ids: list[uuid.UUID]
|
kb_ids: list[uuid.UUID]
|
||||||
|
file_names_filter: list[str] | None = Field(None)
|
||||||
similarity_threshold: float | None = Field(None)
|
similarity_threshold: float | None = Field(None)
|
||||||
vector_similarity_weight: float | None = Field(None)
|
vector_similarity_weight: float | None = Field(None)
|
||||||
top_k: int | None = Field(None)
|
top_k: int | None = Field(None)
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ from app.core.memory.agent.utils.messages_tools import (
|
|||||||
)
|
)
|
||||||
from app.core.memory.agent.utils.type_classifier import status_typle
|
from app.core.memory.agent.utils.type_classifier import status_typle
|
||||||
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
|
from app.core.memory.agent.utils.write_tools import write # 新增:直接导入 write 函数
|
||||||
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags
|
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags, get_interest_distribution
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
from app.db import get_db_context
|
from app.db import get_db_context
|
||||||
from app.models.knowledge_model import Knowledge, KnowledgeType
|
from app.models.knowledge_model import Knowledge, KnowledgeType
|
||||||
@@ -890,36 +890,36 @@ class MemoryAgentService:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
async def get_hot_memory_tags_by_user(
|
|
||||||
|
async def get_interest_distribution_by_user(
|
||||||
self,
|
self,
|
||||||
end_user_id: Optional[str] = None,
|
end_user_id: Optional[str] = None,
|
||||||
limit: int = 20
|
limit: int = 5,
|
||||||
|
language: str = "zh"
|
||||||
) -> List[Dict[str, Any]]:
|
) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取指定用户的热门记忆标签
|
获取指定用户的兴趣分布标签。
|
||||||
|
|
||||||
|
与热门标签不同,此接口专注于识别用户的兴趣活动(运动、爱好、学习等),
|
||||||
|
过滤掉纯物品、工具、地点等不代表用户主动参与活动的名词。
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- end_user_id: 用户ID(可选),对应Neo4j中的end_user_id字段
|
- end_user_id: 用户ID(必填)
|
||||||
- limit: 返回标签数量限制
|
- limit: 返回标签数量限制
|
||||||
|
- language: 输出语言("zh" 中文, "en" 英文)
|
||||||
|
|
||||||
返回格式:
|
返回格式:
|
||||||
[
|
[
|
||||||
{"name": "标签名", "frequency": 频次},
|
{"name": "兴趣活动名", "frequency": 频次},
|
||||||
...
|
...
|
||||||
]
|
]
|
||||||
|
|
||||||
注意:标签语言由写入时的 X-Language-Type 决定,查询时不进行翻译
|
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# by_user=False 表示按 end_user_id 查询(在Neo4j中,end_user_id就是用户维度)
|
tags = await get_interest_distribution(end_user_id, limit=limit, by_user=False, language=language)
|
||||||
tags = await get_hot_memory_tags(end_user_id, limit=limit, by_user=False)
|
return [{"name": tag, "frequency": freq} for tag, freq in tags]
|
||||||
payload = []
|
|
||||||
for tag, freq in tags:
|
|
||||||
payload.append({"name": tag, "frequency": freq})
|
|
||||||
return payload
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"热门记忆标签查询失败: {e}")
|
logger.error(f"兴趣分布标签查询失败: {e}")
|
||||||
raise Exception(f"热门记忆标签查询失败: {e}")
|
raise Exception(f"兴趣分布标签查询失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
async def get_user_profile(
|
async def get_user_profile(
|
||||||
|
|||||||
@@ -140,9 +140,11 @@ class MemoryAPIService:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Delegate to MemoryAgentService
|
# Delegate to MemoryAgentService
|
||||||
|
# Convert string message to list[dict] format expected by MemoryAgentService
|
||||||
|
messages = message if isinstance(message, list) else [{"role": "user", "content": message}]
|
||||||
result = await MemoryAgentService().write_memory(
|
result = await MemoryAgentService().write_memory(
|
||||||
end_user_id=end_user_id,
|
end_user_id=end_user_id,
|
||||||
messages=message,
|
messages=messages,
|
||||||
config_id=config_id,
|
config_id=config_id,
|
||||||
db=self.db,
|
db=self.db,
|
||||||
storage_type=storage_type,
|
storage_type=storage_type,
|
||||||
@@ -151,9 +153,18 @@ class MemoryAPIService:
|
|||||||
|
|
||||||
logger.info(f"Memory write successful for end_user: {end_user_id}")
|
logger.info(f"Memory write successful for end_user: {end_user_id}")
|
||||||
|
|
||||||
|
# result may be a string "success" or a dict with a "status" key
|
||||||
|
# Preserve the full dict so callers don't silently lose extra fields
|
||||||
|
# (e.g. error codes, metadata) returned by MemoryAgentService.
|
||||||
|
if isinstance(result, dict):
|
||||||
|
return {
|
||||||
|
**result,
|
||||||
|
"status": result.get("status", "unknown"),
|
||||||
|
"end_user_id": end_user_id,
|
||||||
|
}
|
||||||
return {
|
return {
|
||||||
"status": "success" if result == "success" else result,
|
"status": result if isinstance(result, str) else "success",
|
||||||
"end_user_id": end_user_id
|
"end_user_id": end_user_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
except ConfigurationError as e:
|
except ConfigurationError as e:
|
||||||
|
|||||||
@@ -390,19 +390,59 @@ def get_rag_total_kb(
|
|||||||
current_user: User
|
current_user: User
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
根据当前用户所在的workspace_id查询konwledges表所有不同id的数量
|
根据当前用户所在的workspace_id查询konwledges表中排除用户知识库(permission_id!='Memory')的数量
|
||||||
"""
|
"""
|
||||||
workspace_id = current_user.current_workspace_id
|
workspace_id = current_user.current_workspace_id
|
||||||
business_logger.info(f"获取RAG总知识库数: workspace_id={workspace_id}, 操作者: {current_user.username}")
|
business_logger.info(f"获取RAG总知识库数(排除用户知识库): workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
total_kb = knowledge_repository.get_total_kb_count_by_workspace(db, workspace_id)
|
total_kb = knowledge_repository.get_non_user_kb_count_by_workspace(db, workspace_id)
|
||||||
business_logger.info(f"成功获取RAG总知识库数: {total_kb}")
|
business_logger.info(f"成功获取RAG总知识库数: {total_kb}")
|
||||||
return total_kb
|
return total_kb
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
business_logger.error(f"获取RAG总知识库数失败: workspace_id={workspace_id} - {str(e)}")
|
business_logger.error(f"获取RAG总知识库数失败: workspace_id={workspace_id} - {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def get_rag_user_kb_total_chunk(
|
||||||
|
db: Session,
|
||||||
|
current_user: User
|
||||||
|
) -> int:
|
||||||
|
"""
|
||||||
|
根据当前用户所在的workspace_id,从documents表统计所有用户知识库的chunk总数。
|
||||||
|
与 /end_users 接口保持同源:查询 file_name 匹配 end_user_id.txt 的文档 chunk_num 之和。
|
||||||
|
"""
|
||||||
|
workspace_id = current_user.current_workspace_id
|
||||||
|
business_logger.info(f"获取用户知识库总chunk数(documents表): workspace_id={workspace_id}, 操作者: {current_user.username}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from app.models.document_model import Document
|
||||||
|
from app.models.end_user_model import EndUser
|
||||||
|
from app.models.app_model import App
|
||||||
|
from sqlalchemy import func
|
||||||
|
|
||||||
|
# 通过 App 关联取该 workspace 下所有 end_user_id
|
||||||
|
end_user_ids = [
|
||||||
|
str(eid) for (eid,) in db.query(EndUser.id)
|
||||||
|
.join(App, EndUser.app_id == App.id)
|
||||||
|
.filter(App.workspace_id == workspace_id)
|
||||||
|
.all()
|
||||||
|
]
|
||||||
|
if not end_user_ids:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
file_names = [f"{uid}.txt" for uid in end_user_ids]
|
||||||
|
result = db.query(func.sum(Document.chunk_num)).filter(
|
||||||
|
Document.file_name.in_(file_names)
|
||||||
|
).scalar()
|
||||||
|
|
||||||
|
total_chunk = int(result or 0)
|
||||||
|
business_logger.info(f"成功获取用户知识库总chunk数: {total_chunk}")
|
||||||
|
return total_chunk
|
||||||
|
except Exception as e:
|
||||||
|
business_logger.error(f"获取用户知识库总chunk数失败: workspace_id={workspace_id} - {str(e)}")
|
||||||
|
raise
|
||||||
|
|
||||||
def get_current_user_total_chunk(
|
def get_current_user_total_chunk(
|
||||||
end_user_id: str,
|
end_user_id: str,
|
||||||
db: Session,
|
db: Session,
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from sqlalchemy.orm import Session
|
|||||||
from app.core.error_codes import BizCode
|
from app.core.error_codes import BizCode
|
||||||
from app.core.exceptions import BusinessException
|
from app.core.exceptions import BusinessException
|
||||||
from app.core.workflow.adapters.registry import PlatformAdapterRegistry
|
from app.core.workflow.adapters.registry import PlatformAdapterRegistry
|
||||||
|
from app.core.workflow.executor import execute_workflow, execute_workflow_stream
|
||||||
from app.core.workflow.nodes.enums import NodeType
|
from app.core.workflow.nodes.enums import NodeType
|
||||||
from app.core.workflow.validator import validate_workflow_config
|
from app.core.workflow.validator import validate_workflow_config
|
||||||
from app.db import get_db
|
from app.db import get_db
|
||||||
@@ -23,7 +24,7 @@ from app.repositories.workflow_repository import (
|
|||||||
WorkflowExecutionRepository,
|
WorkflowExecutionRepository,
|
||||||
WorkflowNodeExecutionRepository
|
WorkflowNodeExecutionRepository
|
||||||
)
|
)
|
||||||
from app.schemas import DraftRunRequest
|
from app.schemas import DraftRunRequest, FileInput
|
||||||
from app.services.conversation_service import ConversationService
|
from app.services.conversation_service import ConversationService
|
||||||
from app.services.multi_agent_service import convert_uuids_to_str
|
from app.services.multi_agent_service import convert_uuids_to_str
|
||||||
from app.services.multimodal_service import MultimodalService
|
from app.services.multimodal_service import MultimodalService
|
||||||
@@ -445,6 +446,91 @@ class WorkflowService:
|
|||||||
"success_rate": completed / total if total > 0 else 0
|
"success_rate": completed / total if total > 0 else 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def _handle_file_input(self, files: list[FileInput]):
|
||||||
|
if not files:
|
||||||
|
return []
|
||||||
|
|
||||||
|
files_struct = []
|
||||||
|
for file in files:
|
||||||
|
files_struct.append(
|
||||||
|
{
|
||||||
|
"type": file.type,
|
||||||
|
"url": await self.multimodal_service.get_file_url(file),
|
||||||
|
"__file": True
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return files_struct
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _map_public_event(event: dict) -> dict | None:
|
||||||
|
"""
|
||||||
|
Map internal workflow events to public-facing event formats.
|
||||||
|
|
||||||
|
Purpose:
|
||||||
|
- Hide internal execution details
|
||||||
|
- Expose a stable and simplified public event schema
|
||||||
|
- Filter out non-public events
|
||||||
|
- Maintain backward compatibility when possible
|
||||||
|
|
||||||
|
Args:
|
||||||
|
event (dict): Internal event object, e.g.:
|
||||||
|
{
|
||||||
|
"event": "workflow_start",
|
||||||
|
"data": {...}
|
||||||
|
}
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict | None:
|
||||||
|
- Returns the mapped public event
|
||||||
|
- Returns None if the event should not be exposed
|
||||||
|
"""
|
||||||
|
event_type = event.get("event")
|
||||||
|
payload = event.get("data")
|
||||||
|
match event_type:
|
||||||
|
case "workflow_start":
|
||||||
|
return {
|
||||||
|
"event": "start",
|
||||||
|
"data": {
|
||||||
|
"conversation_id": payload.get("conversation_id"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "workflow_end":
|
||||||
|
return {
|
||||||
|
"event": "end",
|
||||||
|
"data": {
|
||||||
|
"elapsed_time": payload.get("elapsed_time"),
|
||||||
|
"message_length": len(payload.get("output", "")),
|
||||||
|
"error": payload.get("error", "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "node_start" | "node_end" | "node_error" | "cycle_item":
|
||||||
|
return None
|
||||||
|
case _:
|
||||||
|
return event
|
||||||
|
|
||||||
|
def _emit(self, public: bool, internal_event: dict):
|
||||||
|
"""
|
||||||
|
Unified event emission entry.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
public (bool):
|
||||||
|
- True -> Emit mapped public event
|
||||||
|
- False -> Emit raw internal event
|
||||||
|
|
||||||
|
internal_event (dict):
|
||||||
|
The original internal event object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict | None:
|
||||||
|
- The mapped event
|
||||||
|
- Or None if the event is filtered out
|
||||||
|
"""
|
||||||
|
if public:
|
||||||
|
mapped = self._map_public_event(internal_event)
|
||||||
|
else:
|
||||||
|
mapped = internal_event
|
||||||
|
return mapped
|
||||||
|
|
||||||
# ==================== 工作流执行 ====================
|
# ==================== 工作流执行 ====================
|
||||||
|
|
||||||
async def run(
|
async def run(
|
||||||
@@ -479,10 +565,11 @@ class WorkflowService:
|
|||||||
message=f"工作流配置不存在: app_id={app_id}"
|
message=f"工作流配置不存在: app_id={app_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
input_data = {"message": payload.message, "variables": payload.variables,
|
input_data = {
|
||||||
"conversation_id": payload.conversation_id,
|
"message": payload.message, "variables": payload.variables,
|
||||||
"files": [file.model_dump(mode='json') for file in payload.files]
|
"conversation_id": payload.conversation_id,
|
||||||
}
|
"files": [file.model_dump(mode='json') for file in payload.files]
|
||||||
|
}
|
||||||
|
|
||||||
# 转换 conversation_id 为 UUID
|
# 转换 conversation_id 为 UUID
|
||||||
conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None
|
conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None
|
||||||
@@ -506,22 +593,8 @@ class WorkflowService:
|
|||||||
"execution_config": config.execution_config
|
"execution_config": config.execution_config
|
||||||
}
|
}
|
||||||
|
|
||||||
# 4. 获取工作空间 ID(从 app 获取)
|
|
||||||
|
|
||||||
# 5. 执行工作流
|
|
||||||
from app.core.workflow.executor import execute_workflow
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
files = []
|
files = await self._handle_file_input(payload.files)
|
||||||
if payload.files:
|
|
||||||
for file in payload.files:
|
|
||||||
files.append(
|
|
||||||
{
|
|
||||||
"type": file.type,
|
|
||||||
"url": await self.multimodal_service.get_file_url(file),
|
|
||||||
"__file": True
|
|
||||||
}
|
|
||||||
)
|
|
||||||
input_data["files"] = files
|
input_data["files"] = files
|
||||||
# 更新状态为运行中
|
# 更新状态为运行中
|
||||||
self.update_execution_status(execution.execution_id, "running")
|
self.update_execution_status(execution.execution_id, "running")
|
||||||
@@ -601,42 +674,6 @@ class WorkflowService:
|
|||||||
message=f"工作流执行失败: {str(e)}"
|
message=f"工作流执行失败: {str(e)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _map_public_event(event: dict) -> dict | None:
|
|
||||||
event_type = event.get("event")
|
|
||||||
payload = event.get("data")
|
|
||||||
match event_type:
|
|
||||||
case "workflow_start":
|
|
||||||
return {
|
|
||||||
"event": "start",
|
|
||||||
"data": {
|
|
||||||
"conversation_id": payload.get("conversation_id"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case "workflow_end":
|
|
||||||
return {
|
|
||||||
"event": "end",
|
|
||||||
"data": {
|
|
||||||
"elapsed_time": payload.get("elapsed_time"),
|
|
||||||
"message_length": len(payload.get("output", "")),
|
|
||||||
"error": payload.get("error", "")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case "node_start" | "node_end" | "node_error" | "cycle_item":
|
|
||||||
return None
|
|
||||||
case _:
|
|
||||||
return event
|
|
||||||
|
|
||||||
def _emit(self, public: bool, internal_event: dict):
|
|
||||||
"""
|
|
||||||
decide
|
|
||||||
"""
|
|
||||||
if public:
|
|
||||||
mapped = self._map_public_event(internal_event)
|
|
||||||
else:
|
|
||||||
mapped = internal_event
|
|
||||||
return mapped
|
|
||||||
|
|
||||||
async def run_stream(
|
async def run_stream(
|
||||||
self,
|
self,
|
||||||
app_id: uuid.UUID,
|
app_id: uuid.UUID,
|
||||||
@@ -671,10 +708,11 @@ class WorkflowService:
|
|||||||
message=f"工作流配置不存在: app_id={app_id}"
|
message=f"工作流配置不存在: app_id={app_id}"
|
||||||
)
|
)
|
||||||
|
|
||||||
input_data = {"message": payload.message, "variables": payload.variables,
|
input_data = {
|
||||||
"conversation_id": payload.conversation_id,
|
"message": payload.message, "variables": payload.variables,
|
||||||
"files": [file.model_dump(mode='json') for file in payload.files]
|
"conversation_id": payload.conversation_id,
|
||||||
}
|
"files": [file.model_dump(mode='json') for file in payload.files]
|
||||||
|
}
|
||||||
|
|
||||||
# 转换 conversation_id 为 UUID
|
# 转换 conversation_id 为 UUID
|
||||||
conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None
|
conversation_id_uuid = uuid.UUID(payload.conversation_id) if payload.conversation_id else None
|
||||||
@@ -699,16 +737,7 @@ class WorkflowService:
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
files = []
|
files = await self._handle_file_input(payload.files)
|
||||||
if payload.files:
|
|
||||||
for file in payload.files:
|
|
||||||
files.append(
|
|
||||||
{
|
|
||||||
"type": file.type,
|
|
||||||
"url": await self.multimodal_service.get_file_url(file),
|
|
||||||
"__file": True
|
|
||||||
}
|
|
||||||
)
|
|
||||||
input_data["files"] = files
|
input_data["files"] = files
|
||||||
self.update_execution_status(execution.execution_id, "running")
|
self.update_execution_status(execution.execution_id, "running")
|
||||||
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid)
|
||||||
@@ -723,7 +752,6 @@ class WorkflowService:
|
|||||||
input_data["conv_messages"] = last_state.get("messages") or []
|
input_data["conv_messages"] = last_state.get("messages") or []
|
||||||
break
|
break
|
||||||
init_message_length = len(input_data.get("conv_messages", []))
|
init_message_length = len(input_data.get("conv_messages", []))
|
||||||
from app.core.workflow.executor import execute_workflow_stream
|
|
||||||
|
|
||||||
async for event in execute_workflow_stream(
|
async for event in execute_workflow_stream(
|
||||||
workflow_config=workflow_config_dict,
|
workflow_config=workflow_config_dict,
|
||||||
@@ -789,37 +817,6 @@ class WorkflowService:
|
|||||||
return node.get("config", {}).get("variables", [])
|
return node.get("config", {}).get("variables", [])
|
||||||
raise BusinessException("workflow config error - start node not found")
|
raise BusinessException("workflow config error - start node not found")
|
||||||
|
|
||||||
def _clean_event_for_json(self, event: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""清理事件数据,移除不可序列化的对象
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event: 原始事件数据
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
可序列化的事件数据
|
|
||||||
"""
|
|
||||||
from langchain_core.messages import BaseMessage
|
|
||||||
|
|
||||||
def clean_value(value):
|
|
||||||
"""递归清理值"""
|
|
||||||
if isinstance(value, BaseMessage):
|
|
||||||
# 将 Message 对象转换为字典
|
|
||||||
return {
|
|
||||||
"type": value.__class__.__name__,
|
|
||||||
"content": value.content,
|
|
||||||
}
|
|
||||||
elif isinstance(value, dict):
|
|
||||||
return {k: clean_value(v) for k, v in value.items()}
|
|
||||||
elif isinstance(value, list):
|
|
||||||
return [clean_value(item) for item in value]
|
|
||||||
elif isinstance(value, (str, int, float, bool, type(None))):
|
|
||||||
return value
|
|
||||||
else:
|
|
||||||
# 其他不可序列化的对象转换为字符串
|
|
||||||
return str(value)
|
|
||||||
|
|
||||||
return clean_value(event)
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== 依赖注入函数 ====================
|
# ==================== 依赖注入函数 ====================
|
||||||
|
|
||||||
|
|||||||
@@ -257,7 +257,7 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
|||||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n"
|
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task result for task {task}:\n{result}\n"
|
||||||
return result
|
return result
|
||||||
|
|
||||||
try:
|
def sync_task():
|
||||||
trio.run(
|
trio.run(
|
||||||
lambda: _run(
|
lambda: _run(
|
||||||
row=task,
|
row=task,
|
||||||
@@ -272,6 +272,10 @@ def parse_document(file_path: str, document_id: uuid.UUID):
|
|||||||
with_community=with_community,
|
with_community=with_community,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
try:
|
||||||
|
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||||
|
future = executor.submit(sync_task)
|
||||||
|
future.result() # Blocks until the task completes
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n"
|
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} GraphRAG task failed for task {task}:\n{str(e)}\n"
|
||||||
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)"
|
progress_msg += f"{datetime.now().strftime('%H:%M:%S')} Knowledge Graph done ({time.time() - start_time}s)"
|
||||||
|
|||||||
@@ -139,7 +139,7 @@ SMTP_USER=
|
|||||||
SMTP_PASSWORD=
|
SMTP_PASSWORD=
|
||||||
|
|
||||||
# 本体类型融合配置 (记得写入env_example)
|
# 本体类型融合配置 (记得写入env_example)
|
||||||
GENERAL_ONTOLOGY_FILES=app/core/memory/ontology_services/General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔
|
GENERAL_ONTOLOGY_FILES=api/app/core/memory/ontology_services/General_purpose_entity.ttl # 指定要加载的本体文件路径,多个文件用逗号分隔
|
||||||
ENABLE_GENERAL_ONTOLOGY_TYPES=true # 总开关,控制是否启用通用本体类型融合功能(false = 不使用任何本体类型指导)
|
ENABLE_GENERAL_ONTOLOGY_TYPES=true # 总开关,控制是否启用通用本体类型融合功能(false = 不使用任何本体类型指导)
|
||||||
MAX_ONTOLOGY_TYPES_IN_PROMPT=100 # 限制传给 LLM 的类型数量,防止 Prompt 过长
|
MAX_ONTOLOGY_TYPES_IN_PROMPT=100 # 限制传给 LLM 的类型数量,防止 Prompt 过长
|
||||||
CORE_GENERAL_TYPES=Person,Organization,Place,Event,Work,Concept # 定义核心类型列表,这些类型会优先包含在合并结果中
|
CORE_GENERAL_TYPES=Person,Organization,Place,Event,Work,Concept # 定义核心类型列表,这些类型会优先包含在合并结果中
|
||||||
|
|||||||
@@ -456,6 +456,7 @@ export const en = {
|
|||||||
logoTip: `Supported image formats: JPG, PNG \n Suggested size: square ratio \n Maximum size: ≤ 2MB`,
|
logoTip: `Supported image formats: JPG, PNG \n Suggested size: square ratio \n Maximum size: ≤ 2MB`,
|
||||||
imageSquareRequired: 'Please upload a square image',
|
imageSquareRequired: 'Please upload a square image',
|
||||||
nameInvalid: 'Name cannot start or end with a space',
|
nameInvalid: 'Name cannot start or end with a space',
|
||||||
|
notAllSpaces: 'Cannot be all spaces',
|
||||||
},
|
},
|
||||||
model: {
|
model: {
|
||||||
searchPlaceholder: 'search model…',
|
searchPlaceholder: 'search model…',
|
||||||
@@ -1782,6 +1783,7 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
|||||||
mcp: 'MCP Services',
|
mcp: 'MCP Services',
|
||||||
inner: 'Built-in Tools',
|
inner: 'Built-in Tools',
|
||||||
custom: 'Custom Tools',
|
custom: 'Custom Tools',
|
||||||
|
market: 'Tool Market',
|
||||||
mcpSearchPlaceholder: 'Search MCP Services...',
|
mcpSearchPlaceholder: 'Search MCP Services...',
|
||||||
innerSearchPlaceholder: 'Search Tools...',
|
innerSearchPlaceholder: 'Search Tools...',
|
||||||
customSearchPlaceholder: 'Search Custom Tools...',
|
customSearchPlaceholder: 'Search Custom Tools...',
|
||||||
@@ -1955,7 +1957,9 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re
|
|||||||
path: 'Path',
|
path: 'Path',
|
||||||
viewDetail: 'View Details',
|
viewDetail: 'View Details',
|
||||||
textLink: 'Test Connection',
|
textLink: 'Test Connection',
|
||||||
noResult: 'Processing results will be displayed here'
|
noResult: 'Processing results will be displayed here',
|
||||||
|
serverUrlInvalid: 'Must start with http:// or https://, and cannot have leading or trailing spaces',
|
||||||
|
requestHeaderKeyInvalid: 'Only English letters, numbers, hyphens (-), and underscores (_) are allowed, and cannot start or end with a hyphen or underscore',
|
||||||
},
|
},
|
||||||
workflow: {
|
workflow: {
|
||||||
coreNode: 'Core Nodes',
|
coreNode: 'Core Nodes',
|
||||||
|
|||||||
@@ -1036,6 +1036,7 @@ export const zh = {
|
|||||||
logoTip: `支持图片格式(JPG、PNG)\n 尺寸:正方形比例 \n 文件大小限制:≤ 2MB`,
|
logoTip: `支持图片格式(JPG、PNG)\n 尺寸:正方形比例 \n 文件大小限制:≤ 2MB`,
|
||||||
imageSquareRequired: '请上传正方形比例图片',
|
imageSquareRequired: '请上传正方形比例图片',
|
||||||
nameInvalid: '不能是空格开头或结尾',
|
nameInvalid: '不能是空格开头或结尾',
|
||||||
|
notAllSpaces: '不能是纯空格',
|
||||||
},
|
},
|
||||||
model: {
|
model: {
|
||||||
searchPlaceholder: '搜索模型…',
|
searchPlaceholder: '搜索模型…',
|
||||||
@@ -1779,6 +1780,7 @@ export const zh = {
|
|||||||
mcp: 'MCP 服务',
|
mcp: 'MCP 服务',
|
||||||
inner: '内置工具',
|
inner: '内置工具',
|
||||||
custom: '自定义工具',
|
custom: '自定义工具',
|
||||||
|
market: '工具市场',
|
||||||
mcpSearchPlaceholder: '搜索MCP服务...',
|
mcpSearchPlaceholder: '搜索MCP服务...',
|
||||||
innerSearchPlaceholder: '搜索工具...',
|
innerSearchPlaceholder: '搜索工具...',
|
||||||
customSearchPlaceholder: '搜索自定义工具...',
|
customSearchPlaceholder: '搜索自定义工具...',
|
||||||
@@ -1952,7 +1954,9 @@ export const zh = {
|
|||||||
path: '路径',
|
path: '路径',
|
||||||
viewDetail: '查看详情',
|
viewDetail: '查看详情',
|
||||||
textLink: '测试连接',
|
textLink: '测试连接',
|
||||||
noResult: '处理结果将显示在这里'
|
noResult: '处理结果将显示在这里',
|
||||||
|
serverUrlInvalid: '必须以 http:// 或 https:// 开头,且不能有前后空格',
|
||||||
|
requestHeaderKeyInvalid: '只支持英文、数字、连字符(-)、下划线(_),不能以连字符或下划线开头结尾',
|
||||||
},
|
},
|
||||||
workflow: {
|
workflow: {
|
||||||
coreNode: '核心节点',
|
coreNode: '核心节点',
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
* @Author: ZhaoYing
|
* @Author: ZhaoYing
|
||||||
* @Date: 2026-02-03 16:58:03
|
* @Date: 2026-02-03 16:58:03
|
||||||
* @Last Modified by: ZhaoYing
|
* @Last Modified by: ZhaoYing
|
||||||
* @Last Modified time: 2026-03-03 13:46:22
|
* @Last Modified time: 2026-03-04 12:10:44
|
||||||
*/
|
*/
|
||||||
/**
|
/**
|
||||||
* Conversation Page
|
* Conversation Page
|
||||||
@@ -267,8 +267,8 @@ const Conversation: FC = () => {
|
|||||||
currentConversationId = newId
|
currentConversationId = newId
|
||||||
break
|
break
|
||||||
case 'message':
|
case 'message':
|
||||||
const { content, chunk, conversation_id: curId } = item.data as { content: string; chunk: string; conversation_id: string; }
|
const { content, conversation_id: curId } = item.data as { content: string; conversation_id: string; }
|
||||||
updateAssistantMessage(content ?? chunk)
|
updateAssistantMessage(content)
|
||||||
|
|
||||||
if (curId) {
|
if (curId) {
|
||||||
currentConversationId = curId;
|
currentConversationId = curId;
|
||||||
|
|||||||
@@ -185,6 +185,7 @@ const OntologyClassExtractModal = forwardRef<OntologyClassExtractModalRef, Ontol
|
|||||||
rules={[
|
rules={[
|
||||||
{ required: true, message: t('common.pleaseEnter') },
|
{ required: true, message: t('common.pleaseEnter') },
|
||||||
{ max: 2000 },
|
{ max: 2000 },
|
||||||
|
{ pattern: /^(?!\s*$).+$/, message: t('common.notAllSpaces') },
|
||||||
]}
|
]}
|
||||||
>
|
>
|
||||||
<Input.TextArea placeholder={t('ontology.scenarioPlaceholder')} />
|
<Input.TextArea placeholder={t('ontology.scenarioPlaceholder')} />
|
||||||
|
|||||||
315
web/src/views/ToolManagement/Market.tsx
Normal file
315
web/src/views/ToolManagement/Market.tsx
Normal file
@@ -0,0 +1,315 @@
|
|||||||
|
import React, { useState, useRef, type ReactNode } from 'react';
|
||||||
|
import { Input, Button, Spin, App } from 'antd';
|
||||||
|
import { SearchOutlined, SettingOutlined, GlobalOutlined, SyncOutlined } from '@ant-design/icons';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import MarketConfigModal, { type MarketConfigModalRef } from './components/MarketConfigModal';
|
||||||
|
|
||||||
|
interface MarketSource {
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
category: string;
|
||||||
|
icon: string;
|
||||||
|
url: string;
|
||||||
|
desc: string;
|
||||||
|
apiKey: string;
|
||||||
|
connected: boolean;
|
||||||
|
mcpCount: number;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface MarketMcp {
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
provider: string;
|
||||||
|
type: string;
|
||||||
|
desc: string;
|
||||||
|
downloads?: string;
|
||||||
|
stars?: string;
|
||||||
|
icon: string;
|
||||||
|
configTemplate: any;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface MarketCategory {
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
icon: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const { message } = App.useApp();
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
const [selectedSource, setSelectedSource] = useState<string | null>(null);
|
||||||
|
const marketConfigModalRef = useRef<MarketConfigModalRef>(null);
|
||||||
|
const [marketSources, setMarketSources] = useState<MarketSource[]>([
|
||||||
|
{ id: 'smithery', name: 'Smithery', category: 'official', icon: '🔧', url: 'https://mcp.smithery.ai', desc: '官方 MCP 服务市场,提供丰富的 MCP 服务', apiKey: '', connected: false, mcpCount: 2847 },
|
||||||
|
{ id: 'mcpmarket', name: 'MCP Market', category: 'official', icon: '🏪', url: 'https://mcpmarket.com', desc: '综合性 MCP 市场平台', apiKey: '', connected: false, mcpCount: 1523 },
|
||||||
|
{ id: 'glama', name: 'Glama.ai MCP', category: 'official', icon: '✨', url: 'https://glama.ai/mcp', desc: 'Glama AI 提供的 MCP 服务集合', apiKey: '', connected: false, mcpCount: 892 },
|
||||||
|
{ id: 'github-mcp', name: 'modelcontextprotocol/servers', category: 'official', icon: '🐙', url: 'https://github.com/modelcontextprotocol/servers', desc: 'GitHub 官方 MCP 服务器仓库', apiKey: '', connected: true, mcpCount: 156 },
|
||||||
|
{ id: 'aliyun-bailian', name: '阿里云百炼 MCP', category: 'china-cloud', icon: '☁️', url: 'https://bailian.console.aliyun.com/mcp', desc: '阿里云百炼平台 MCP 市场', apiKey: '', connected: false, mcpCount: 423 },
|
||||||
|
{ id: 'modelscope', name: '魔搭社区 MCP', category: 'china-cloud', icon: '🎭', url: 'https://modelscope.cn/mcp', desc: '阿里达摩院魔搭社区 MCP 市场', apiKey: '', connected: false, mcpCount: 312 },
|
||||||
|
]);
|
||||||
|
|
||||||
|
const [categories] = useState<MarketCategory[]>([
|
||||||
|
{ id: 'official', name: '官方/综合', icon: '🌐' },
|
||||||
|
{ id: 'china-cloud', name: '国内云', icon: '☁️' },
|
||||||
|
{ id: 'community', name: '社区/垂直', icon: '👥' }
|
||||||
|
]);
|
||||||
|
|
||||||
|
const [mcpCache, setMcpCache] = useState<Record<string, MarketMcp[]>>({
|
||||||
|
'github-mcp': [
|
||||||
|
{ id: 'gh-1', name: 'Fetch', provider: 'modelcontextprotocol', type: 'Hosted', desc: '使用浏览器模拟大型语言模型检索和处理网页内容', downloads: '203.7m', stars: '308.2k', icon: '🌐', configTemplate: {} },
|
||||||
|
{ id: 'gh-2', name: 'Filesystem', provider: 'modelcontextprotocol', type: 'Local', desc: '安全的文件系统操作,支持读写文件和目录管理', downloads: '156.2m', stars: '245.1k', icon: '📁', configTemplate: {} },
|
||||||
|
{ id: 'gh-3', name: 'GitHub', provider: 'modelcontextprotocol', type: 'Hosted', desc: 'GitHub API 集成,支持仓库、Issue、PR 等操作', downloads: '89.4m', stars: '178.3k', icon: '🐙', configTemplate: {} },
|
||||||
|
]
|
||||||
|
});
|
||||||
|
|
||||||
|
const [searchKeyword, setSearchKeyword] = useState('');
|
||||||
|
|
||||||
|
const handleSelectSource = (sourceId: string) => {
|
||||||
|
setSelectedSource(sourceId);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleRefresh = (sourceId: string) => {
|
||||||
|
setLoading(true);
|
||||||
|
setTimeout(() => {
|
||||||
|
// 模拟刷新数据
|
||||||
|
const source = marketSources.find(s => s.id === sourceId);
|
||||||
|
if (source) {
|
||||||
|
message.success(`${source.name} 列表已刷新`);
|
||||||
|
}
|
||||||
|
setLoading(false);
|
||||||
|
}, 600);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleOpenConfig = (sourceId: string) => {
|
||||||
|
const source = marketSources.find(s => s.id === sourceId);
|
||||||
|
if (source) {
|
||||||
|
marketConfigModalRef.current?.handleOpen(source);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleConnect = (sourceId: string, apiKey: string) => {
|
||||||
|
// 更新市场源状态
|
||||||
|
setMarketSources(prev => prev.map(source => {
|
||||||
|
if (source.id === sourceId) {
|
||||||
|
return {
|
||||||
|
...source,
|
||||||
|
apiKey,
|
||||||
|
connected: true
|
||||||
|
};
|
||||||
|
}
|
||||||
|
return source;
|
||||||
|
}));
|
||||||
|
|
||||||
|
// 模拟获取MCP列表
|
||||||
|
setTimeout(() => {
|
||||||
|
const source = marketSources.find(s => s.id === sourceId);
|
||||||
|
if (source && !mcpCache[sourceId]) {
|
||||||
|
// 生成模拟数据
|
||||||
|
const mockData: MarketMcp[] = [
|
||||||
|
{ id: `${sourceId}-1`, name: `${source.name} 服务 1`, provider: source.name, type: 'Hosted', desc: `来自 ${source.name} 的 MCP 服务`, downloads: '10.2m', stars: '23.4k', icon: '🔧', configTemplate: {} },
|
||||||
|
{ id: `${sourceId}-2`, name: `${source.name} 服务 2`, provider: source.name, type: 'Local', desc: `来自 ${source.name} 的本地 MCP 服务`, downloads: '8.5m', stars: '18.7k', icon: '⚙️', configTemplate: {} }
|
||||||
|
];
|
||||||
|
setMcpCache(prev => ({
|
||||||
|
...prev,
|
||||||
|
[sourceId]: mockData
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
message.success(`已连接 ${source?.name}`);
|
||||||
|
}, 800);
|
||||||
|
};
|
||||||
|
|
||||||
|
const renderSourceDetail = () => {
|
||||||
|
if (!selectedSource) {
|
||||||
|
return (
|
||||||
|
<div className="rb:flex rb:flex-col rb:items-center rb:justify-center rb:h-full rb:text-center">
|
||||||
|
<div className="rb:text-6xl rb:mb-4">🏪</div>
|
||||||
|
<h3 className="rb:text-lg rb:font-semibold rb:text-gray-900 rb:mb-2">选择一个 MCP 市场</h3>
|
||||||
|
<p className="rb:text-sm rb:text-gray-600 rb:max-w-md">从左侧选择一个市场源,配置连接后即可浏览该市场的 MCP 服务</p>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
const source = marketSources.find(s => s.id === selectedSource);
|
||||||
|
if (!source) return null;
|
||||||
|
|
||||||
|
const mcpList = mcpCache[selectedSource] || [];
|
||||||
|
const filteredList = mcpList.filter(mcp =>
|
||||||
|
mcp.name.toLowerCase().includes(searchKeyword.toLowerCase()) ||
|
||||||
|
mcp.desc.toLowerCase().includes(searchKeyword.toLowerCase())
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<div className="rb:flex rb:justify-between rb:items-start rb:pb-6 rb:border-b rb:border-gray-200 rb:mb-6">
|
||||||
|
<div className="rb:flex rb:gap-4">
|
||||||
|
<div className="rb:text-5xl rb:w-16 rb:h-16 rb:flex rb:items-center rb:justify-center rb:bg-gray-50 rb:rounded-xl rb:flex-shrink-0">
|
||||||
|
{source.icon}
|
||||||
|
</div>
|
||||||
|
<div className="rb:flex-1">
|
||||||
|
<h2 className="rb:text-xl rb:font-semibold rb:text-gray-900 rb:mb-2">{source.name}</h2>
|
||||||
|
<p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed">{source.desc}</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<div className="rb:flex rb:gap-3">
|
||||||
|
<Button icon={<SettingOutlined />} onClick={() => handleOpenConfig(selectedSource)}>
|
||||||
|
配置
|
||||||
|
</Button>
|
||||||
|
<Button type="primary" icon={<GlobalOutlined />} onClick={() => window.open(source.url, '_blank')}>
|
||||||
|
前往市场
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="rb:mt-6">
|
||||||
|
<div className="rb:flex rb:justify-between rb:items-center rb:mb-5">
|
||||||
|
<h3 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:m-0">
|
||||||
|
可用 MCP 服务 <span className="rb:text-gray-600 rb:font-normal">({mcpList.length})</span>
|
||||||
|
</h3>
|
||||||
|
<div className="rb:flex rb:gap-3 rb:items-center">
|
||||||
|
{source.connected && (
|
||||||
|
<Button size="small" icon={<SyncOutlined />} onClick={() => handleRefresh(selectedSource)}>
|
||||||
|
刷新
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
{mcpList.length > 0 && (
|
||||||
|
<Input
|
||||||
|
prefix={<SearchOutlined />}
|
||||||
|
placeholder="搜索服务..."
|
||||||
|
value={searchKeyword}
|
||||||
|
onChange={(e) => setSearchKeyword(e.target.value)}
|
||||||
|
style={{ width: 200 }}
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{mcpList.length > 0 ? (
|
||||||
|
<Spin spinning={loading}>
|
||||||
|
<div className="rb:grid rb:grid-cols-1 md:rb:grid-cols-2 lg:rb:grid-cols-3 rb:gap-4">
|
||||||
|
{filteredList.map(mcp => (
|
||||||
|
<div
|
||||||
|
key={mcp.id}
|
||||||
|
className="rb:bg-white rb:border rb:border-gray-200 rb:rounded-lg rb:p-4 rb:transition-all rb:duration-200 hover:rb:shadow-lg hover:rb:border-gray-300"
|
||||||
|
>
|
||||||
|
<div className="rb:flex rb:justify-between rb:items-center rb:mb-3">
|
||||||
|
<div className="rb:text-3xl rb:w-12 rb:h-12 rb:flex rb:items-center rb:justify-center rb:bg-gray-50 rb:rounded-lg">
|
||||||
|
{mcp.icon}
|
||||||
|
</div>
|
||||||
|
<span className={`rb:px-2 rb:py-1 rb:rounded rb:text-xs rb:font-medium ${
|
||||||
|
mcp.type === 'Hosted'
|
||||||
|
? 'rb:bg-blue-50 rb:text-blue-700'
|
||||||
|
: 'rb:bg-gray-100 rb:text-gray-600'
|
||||||
|
}`}>
|
||||||
|
{mcp.type}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
<h3 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:mb-1">{mcp.name}</h3>
|
||||||
|
{mcp.provider && (
|
||||||
|
<div className="rb:mb-2">
|
||||||
|
<span className="rb:text-xs rb:text-gray-500">@ {mcp.provider}</span>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
<p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed rb:mb-3 rb:min-h-[42px]">{mcp.desc}</p>
|
||||||
|
<div className="rb:flex rb:gap-4 rb:mb-3 rb:pt-3 rb:border-t rb:border-gray-100">
|
||||||
|
{mcp.downloads && (
|
||||||
|
<span className="rb:flex rb:items-center rb:gap-1 rb:text-xs rb:text-gray-500">
|
||||||
|
<GlobalOutlined /> {mcp.downloads}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
{mcp.stars && (
|
||||||
|
<span className="rb:flex rb:items-center rb:gap-1 rb:text-xs rb:text-gray-500">
|
||||||
|
⭐ {mcp.stars}
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
<div className="rb:flex rb:justify-end">
|
||||||
|
<Button type="primary" size="small">
|
||||||
|
+ 添加
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</Spin>
|
||||||
|
) : (
|
||||||
|
<div className="rb:flex rb:flex-col rb:items-center rb:justify-center rb:py-16 rb:text-center">
|
||||||
|
<div className="rb:text-6xl rb:mb-4">{source.connected ? '📭' : '🔌'}</div>
|
||||||
|
<h4 className="rb:text-base rb:font-semibold rb:text-gray-900 rb:mb-2">
|
||||||
|
{source.connected ? '暂无可用的 MCP 服务' : '尚未连接此市场'}
|
||||||
|
</h4>
|
||||||
|
<p className="rb:text-sm rb:text-gray-600 rb:mb-4">
|
||||||
|
{source.connected ? '该市场暂时没有可用的服务' : '点击右上角"配置"按钮设置连接信息'}
|
||||||
|
</p>
|
||||||
|
{!source.connected && (
|
||||||
|
<Button type="primary" onClick={() => handleOpenConfig(selectedSource)}>
|
||||||
|
配置连接
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="rb:flex rb:gap-4 rb:h-[calc(100vh-178px)]">
|
||||||
|
{/* 左侧市场源列表 */}
|
||||||
|
<div className="rb:w-70 rb:bg-white rb:rounded-lg rb:border rb:border-gray-200 rb:overflow-y-auto rb:flex-shrink-0">
|
||||||
|
<div className="rb:p-4 rb:border-b rb:border-gray-200">
|
||||||
|
<span className="rb:text-base rb:font-semibold rb:text-gray-900">MCP 市场</span>
|
||||||
|
</div>
|
||||||
|
{categories.map(cat => (
|
||||||
|
<div key={cat.id} className="rb:py-3 rb:border-b rb:border-gray-100 last:rb:border-b-0">
|
||||||
|
<div className="rb:flex rb:items-center rb:gap-2 rb:px-4 rb:py-2 rb:text-xs rb:font-medium rb:text-gray-500 rb:uppercase">
|
||||||
|
<span className="rb:text-sm">{cat.icon}</span>
|
||||||
|
<span>{cat.name}</span>
|
||||||
|
</div>
|
||||||
|
<div className="rb:px-2 rb:py-1">
|
||||||
|
{marketSources
|
||||||
|
.filter(s => s.category === cat.id)
|
||||||
|
.map(source => (
|
||||||
|
<div
|
||||||
|
key={source.id}
|
||||||
|
className={`rb:flex rb:items-center rb:gap-2 rb:px-3 rb:py-2.5 rb:rounded-md rb:cursor-pointer rb:transition-all rb:relative ${
|
||||||
|
selectedSource === source.id
|
||||||
|
? 'rb:bg-blue-50 rb:text-blue-600'
|
||||||
|
: 'hover:rb:bg-gray-50'
|
||||||
|
}`}
|
||||||
|
onClick={() => handleSelectSource(source.id)}
|
||||||
|
>
|
||||||
|
<span className="rb:text-lg rb:flex-shrink-0">{source.icon}</span>
|
||||||
|
<span className="rb:flex-1 rb:text-sm rb:font-medium rb:overflow-hidden rb:text-ellipsis rb:whitespace-nowrap">
|
||||||
|
{source.name}
|
||||||
|
</span>
|
||||||
|
<span className="rb:text-xs rb:text-gray-500 rb:px-1.5 rb:py-0.5 rb:bg-gray-100 rb:rounded-full">
|
||||||
|
{source.mcpCount}
|
||||||
|
</span>
|
||||||
|
{source.connected && (
|
||||||
|
<span className="rb:text-green-500 rb:text-[8px] rb:ml-1">●</span>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* 右侧内容区 */}
|
||||||
|
<div className="rb:flex-1 rb:bg-white rb:rounded-lg rb:border rb:border-gray-200 rb:overflow-hidden">
|
||||||
|
<div className="rb:h-full rb:overflow-y-auto rb:p-6">
|
||||||
|
{renderSourceDetail()}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{/* 配置弹窗 */}
|
||||||
|
<MarketConfigModal
|
||||||
|
ref={marketConfigModalRef}
|
||||||
|
onConnect={handleConnect}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default Market;
|
||||||
@@ -6,6 +6,7 @@ import type { CustomToolItem, CustomToolModalRef, ToolItem } from '../types'
|
|||||||
import RbModal from '@/components/RbModal';
|
import RbModal from '@/components/RbModal';
|
||||||
import { parseSchema, addTool, updateTool } from '@/api/tools';
|
import { parseSchema, addTool, updateTool } from '@/api/tools';
|
||||||
import Table from '@/components/Table';
|
import Table from '@/components/Table';
|
||||||
|
import { stringRegExp } from '@/utils/validator';
|
||||||
const FormItem = Form.Item;
|
const FormItem = Form.Item;
|
||||||
|
|
||||||
interface CustomToolModalProps {
|
interface CustomToolModalProps {
|
||||||
@@ -134,7 +135,11 @@ const CustomToolModal = forwardRef<CustomToolModalRef, CustomToolModalProps>(({
|
|||||||
<Form.Item
|
<Form.Item
|
||||||
name="name"
|
name="name"
|
||||||
label={t('tool.name')}
|
label={t('tool.name')}
|
||||||
rules={[{ required: true, message: t('common.enterNamePlaceholder') }]}
|
rules={[
|
||||||
|
{ required: true, message: t('tool.enterNamePlaceholder') },
|
||||||
|
{ max: 50 },
|
||||||
|
{ pattern: stringRegExp, message: t('common.nameInvalid') },
|
||||||
|
]}
|
||||||
>
|
>
|
||||||
<Input placeholder={t('tool.enterNamePlaceholder')} />
|
<Input placeholder={t('tool.enterNamePlaceholder')} />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
|||||||
173
web/src/views/ToolManagement/components/MarketConfigModal.tsx
Normal file
173
web/src/views/ToolManagement/components/MarketConfigModal.tsx
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
import { forwardRef, useImperativeHandle, useState } from 'react';
|
||||||
|
import { Form, Input, Button, App, Space } from 'antd';
|
||||||
|
import { useTranslation } from 'react-i18next';
|
||||||
|
import { CopyOutlined, EyeInvisibleOutlined, EyeOutlined } from '@ant-design/icons';
|
||||||
|
import RbModal from '@/components/RbModal';
|
||||||
|
|
||||||
|
const FormItem = Form.Item;
|
||||||
|
|
||||||
|
interface MarketSource {
|
||||||
|
id: string;
|
||||||
|
name: string;
|
||||||
|
icon: string;
|
||||||
|
url: string;
|
||||||
|
desc: string;
|
||||||
|
apiKey: string;
|
||||||
|
connected: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
interface MarketConfigModalProps {
|
||||||
|
onConnect: (sourceId: string, apiKey: string) => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface MarketConfigModalRef {
|
||||||
|
handleOpen: (source: MarketSource) => void;
|
||||||
|
handleClose: () => void;
|
||||||
|
}
|
||||||
|
|
||||||
|
const MarketConfigModal = forwardRef<MarketConfigModalRef, MarketConfigModalProps>(({
|
||||||
|
onConnect
|
||||||
|
}, ref) => {
|
||||||
|
const { t } = useTranslation();
|
||||||
|
const { message } = App.useApp();
|
||||||
|
const [visible, setVisible] = useState(false);
|
||||||
|
const [form] = Form.useForm();
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
const [currentSource, setCurrentSource] = useState<MarketSource | null>(null);
|
||||||
|
const [showApiKey, setShowApiKey] = useState(false);
|
||||||
|
|
||||||
|
const handleClose = () => {
|
||||||
|
setVisible(false);
|
||||||
|
form.resetFields();
|
||||||
|
setLoading(false);
|
||||||
|
setCurrentSource(null);
|
||||||
|
setShowApiKey(false);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleOpen = (source: MarketSource) => {
|
||||||
|
setCurrentSource(source);
|
||||||
|
form.setFieldsValue({
|
||||||
|
url: source.url,
|
||||||
|
apiKey: source.apiKey,
|
||||||
|
});
|
||||||
|
setVisible(true);
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleSave = () => {
|
||||||
|
form
|
||||||
|
.validateFields()
|
||||||
|
.then((values) => {
|
||||||
|
if (!currentSource) return;
|
||||||
|
|
||||||
|
setLoading(true);
|
||||||
|
|
||||||
|
// 模拟连接延迟
|
||||||
|
setTimeout(() => {
|
||||||
|
onConnect(currentSource.id, values.apiKey || '');
|
||||||
|
message.success(`正在连接 ${currentSource.name}...`);
|
||||||
|
setLoading(false);
|
||||||
|
handleClose();
|
||||||
|
}, 500);
|
||||||
|
})
|
||||||
|
.catch((err) => {
|
||||||
|
console.log('表单验证失败:', err);
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
const handleCopyUrl = () => {
|
||||||
|
if (currentSource?.url) {
|
||||||
|
navigator.clipboard.writeText(currentSource.url).then(() => {
|
||||||
|
message.success(t('common.copySuccess'));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
useImperativeHandle(ref, () => ({
|
||||||
|
handleOpen,
|
||||||
|
handleClose
|
||||||
|
}));
|
||||||
|
|
||||||
|
if (!currentSource) return null;
|
||||||
|
|
||||||
|
return (
|
||||||
|
<RbModal
|
||||||
|
title={`配置 ${currentSource.name}`}
|
||||||
|
open={visible}
|
||||||
|
onCancel={handleClose}
|
||||||
|
okText="保存并连接"
|
||||||
|
onOk={handleSave}
|
||||||
|
confirmLoading={loading}
|
||||||
|
width={600}
|
||||||
|
>
|
||||||
|
<div>
|
||||||
|
{/* 市场源信息头部 */}
|
||||||
|
<div className="rb:flex rb:gap-4 rb:mb-6 rb:p-4 rb:bg-gray-50 rb:rounded-lg">
|
||||||
|
<div className="rb:text-4xl rb:w-16 rb:h-16 rb:flex rb:items-center rb:justify-center rb:bg-white rb:rounded-lg rb:flex-shrink-0">
|
||||||
|
{currentSource.icon}
|
||||||
|
</div>
|
||||||
|
<div className="rb:flex-1">
|
||||||
|
<h3 className="rb:text-base rb:font-semibold rb:mb-1 rb:text-gray-900">{currentSource.name}</h3>
|
||||||
|
<p className="rb:text-sm rb:text-gray-600 rb:leading-relaxed">{currentSource.desc}</p>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<Form
|
||||||
|
form={form}
|
||||||
|
layout="vertical"
|
||||||
|
>
|
||||||
|
{/* 市场地址 */}
|
||||||
|
<FormItem
|
||||||
|
name="url"
|
||||||
|
label="市场地址"
|
||||||
|
>
|
||||||
|
<Space.Compact style={{ width: '100%' }}>
|
||||||
|
<Input
|
||||||
|
readOnly
|
||||||
|
placeholder="市场地址"
|
||||||
|
/>
|
||||||
|
<Button
|
||||||
|
icon={<CopyOutlined />}
|
||||||
|
onClick={handleCopyUrl}
|
||||||
|
>
|
||||||
|
复制
|
||||||
|
</Button>
|
||||||
|
</Space.Compact>
|
||||||
|
</FormItem>
|
||||||
|
|
||||||
|
{/* API Key */}
|
||||||
|
<FormItem
|
||||||
|
name="apiKey"
|
||||||
|
label={
|
||||||
|
<span>
|
||||||
|
API Key <span className="rb:text-gray-400 rb:font-normal">(可选)</span>
|
||||||
|
</span>
|
||||||
|
}
|
||||||
|
extra="部分市场需要 API Key 才能获取完整的服务列表"
|
||||||
|
>
|
||||||
|
<Space.Compact style={{ width: '100%' }}>
|
||||||
|
<Input
|
||||||
|
type={showApiKey ? 'text' : 'password'}
|
||||||
|
placeholder="输入 API Key 以获取更多服务"
|
||||||
|
autoComplete="off"
|
||||||
|
/>
|
||||||
|
<Button
|
||||||
|
icon={showApiKey ? <EyeInvisibleOutlined /> : <EyeOutlined />}
|
||||||
|
onClick={() => setShowApiKey(!showApiKey)}
|
||||||
|
/>
|
||||||
|
</Space.Compact>
|
||||||
|
</FormItem>
|
||||||
|
|
||||||
|
{/* 连接状态 */}
|
||||||
|
<div className="rb:flex rb:items-center rb:gap-2 rb:p-3 rb:bg-gray-50 rb:rounded rb:text-sm">
|
||||||
|
<span className="rb:text-gray-600">连接状态:</span>
|
||||||
|
<span className={`rb:font-medium ${currentSource.connected ? 'rb:text-green-600' : 'rb:text-gray-400'}`}>
|
||||||
|
{currentSource.connected ? '● 已连接' : '○ 未连接'}
|
||||||
|
</span>
|
||||||
|
</div>
|
||||||
|
</Form>
|
||||||
|
</div>
|
||||||
|
</RbModal>
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
export default MarketConfigModal;
|
||||||
@@ -9,6 +9,7 @@ import RequestHeaderModal from './RequestHeaderModal';
|
|||||||
import Table from '@/components/Table';
|
import Table from '@/components/Table';
|
||||||
import { addTool, updateTool, testConnection } from '@/api/tools'
|
import { addTool, updateTool, testConnection } from '@/api/tools'
|
||||||
import type { McpServiceModalRef } from '../types'
|
import type { McpServiceModalRef } from '../types'
|
||||||
|
import { stringRegExp } from '@/utils/validator';
|
||||||
|
|
||||||
const FormItem = Form.Item;
|
const FormItem = Form.Item;
|
||||||
|
|
||||||
@@ -168,14 +169,22 @@ const McpServiceModal = forwardRef<McpServiceModalRef, McpServiceModalProps>(({
|
|||||||
name={['config', "server_url"]}
|
name={['config', "server_url"]}
|
||||||
label={t('tool.serviceEndpoint')}
|
label={t('tool.serviceEndpoint')}
|
||||||
extra={t('tool.serviceEndpointExtra')}
|
extra={t('tool.serviceEndpointExtra')}
|
||||||
rules={[{ required: true, message: t('common.pleaseEnter') }]}
|
rules={[
|
||||||
|
{ required: true, message: t('common.pleaseEnter') },
|
||||||
|
{ max: 500 },
|
||||||
|
{ pattern: /^https?:\/\/\S+$/, message: t('tool.serverUrlInvalid') },
|
||||||
|
]}
|
||||||
>
|
>
|
||||||
<Input placeholder={t('tool.serviceEndpointPlaceholder')} />
|
<Input placeholder={t('tool.serviceEndpointPlaceholder')} />
|
||||||
</FormItem>
|
</FormItem>
|
||||||
<Form.Item
|
<Form.Item
|
||||||
name="name"
|
name="name"
|
||||||
label={t('tool.name')}
|
label={t('tool.name')}
|
||||||
rules={[{ required: true, message: t('common.pleaseEnter') }]}
|
rules={[
|
||||||
|
{ required: true, message: t('common.pleaseEnter') },
|
||||||
|
{ max: 50 },
|
||||||
|
{ pattern: stringRegExp, message: t('common.nameInvalid') },
|
||||||
|
]}
|
||||||
>
|
>
|
||||||
<Input placeholder={t('tool.namePlaceholder')} />
|
<Input placeholder={t('tool.namePlaceholder')} />
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
@@ -201,6 +210,7 @@ const McpServiceModal = forwardRef<McpServiceModalRef, McpServiceModalProps>(({
|
|||||||
<FormItem
|
<FormItem
|
||||||
name="description"
|
name="description"
|
||||||
label={t('tool.description')}
|
label={t('tool.description')}
|
||||||
|
rules={[{ max: 500 }]}
|
||||||
>
|
>
|
||||||
<Input.TextArea rows={3} placeholder={t('common.inputPlaceholder', { title: t('tool.description') })}/>
|
<Input.TextArea rows={3} placeholder={t('common.inputPlaceholder', { title: t('tool.description') })}/>
|
||||||
</FormItem>
|
</FormItem>
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import { useTranslation } from 'react-i18next';
|
|||||||
|
|
||||||
import type { RequestHeader, RequestHeaderModalRef } from './McpServiceModal'
|
import type { RequestHeader, RequestHeaderModalRef } from './McpServiceModal'
|
||||||
import RbModal from '@/components/RbModal'
|
import RbModal from '@/components/RbModal'
|
||||||
|
import { stringRegExp } from '@/utils/validator';
|
||||||
|
|
||||||
const FormItem = Form.Item;
|
const FormItem = Form.Item;
|
||||||
|
|
||||||
@@ -82,7 +83,11 @@ const RequestHeaderModal = forwardRef<RequestHeaderModalRef, RequestHeaderModalP
|
|||||||
<FormItem
|
<FormItem
|
||||||
name="key"
|
name="key"
|
||||||
label={t('tool.requestHeaderName')}
|
label={t('tool.requestHeaderName')}
|
||||||
rules={[{ required: true, message: t('common.pleaseEnter') }]}
|
rules={[
|
||||||
|
{ required: true, message: t('common.pleaseEnter') },
|
||||||
|
{ pattern: /^[a-zA-Z0-9][a-zA-Z0-9\-_]*[a-zA-Z0-9]$|^[a-zA-Z0-9]$/, message: t('tool.requestHeaderKeyInvalid') },
|
||||||
|
{ max: 100 }
|
||||||
|
]}
|
||||||
>
|
>
|
||||||
<Input placeholder={t('common.enter')} />
|
<Input placeholder={t('common.enter')} />
|
||||||
</FormItem>
|
</FormItem>
|
||||||
@@ -90,7 +95,11 @@ const RequestHeaderModal = forwardRef<RequestHeaderModalRef, RequestHeaderModalP
|
|||||||
<FormItem
|
<FormItem
|
||||||
name="value"
|
name="value"
|
||||||
label={t('tool.requestHeaderValue')}
|
label={t('tool.requestHeaderValue')}
|
||||||
rules={[{ required: true, message: t('common.pleaseEnter') }]}
|
rules={[
|
||||||
|
{ required: true, message: t('common.pleaseEnter') },
|
||||||
|
{ pattern: stringRegExp, message: t('common.nameInvalid') },
|
||||||
|
{ max: 2000 }
|
||||||
|
]}
|
||||||
>
|
>
|
||||||
<Input placeholder={t('common.enter',)} />
|
<Input placeholder={t('common.enter',)} />
|
||||||
</FormItem>
|
</FormItem>
|
||||||
|
|||||||
@@ -1,3 +1,11 @@
|
|||||||
|
/*
|
||||||
|
* @Description:
|
||||||
|
* @Version: 0.0.1
|
||||||
|
* @Author: yujiangping
|
||||||
|
* @Date: 2026-01-05 17:22:23
|
||||||
|
* @LastEditors: yujiangping
|
||||||
|
* @LastEditTime: 2026-03-04 15:12:48
|
||||||
|
*/
|
||||||
import React, { useState } from 'react';
|
import React, { useState } from 'react';
|
||||||
import { Tabs } from 'antd';
|
import { Tabs } from 'antd';
|
||||||
import { useTranslation } from 'react-i18next';
|
import { useTranslation } from 'react-i18next';
|
||||||
@@ -5,9 +13,10 @@ import { useTranslation } from 'react-i18next';
|
|||||||
import Mcp from './Mcp';
|
import Mcp from './Mcp';
|
||||||
import Inner from './Inner';
|
import Inner from './Inner';
|
||||||
import Custom from './Custom';
|
import Custom from './Custom';
|
||||||
|
import Market from './Market';
|
||||||
import Tag from '@/components/Tag'
|
import Tag from '@/components/Tag'
|
||||||
|
|
||||||
const tabKeys = ['mcp', 'inner', 'custom']
|
const tabKeys = ['mcp', 'inner', 'custom', 'market']
|
||||||
const ToolManagement: React.FC = () => {
|
const ToolManagement: React.FC = () => {
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const [activeTab, setActiveTab] = useState('mcp');
|
const [activeTab, setActiveTab] = useState('mcp');
|
||||||
@@ -45,6 +54,7 @@ const ToolManagement: React.FC = () => {
|
|||||||
{activeTab === 'mcp' && <Mcp getStatusTag={getStatusTag} />}
|
{activeTab === 'mcp' && <Mcp getStatusTag={getStatusTag} />}
|
||||||
{activeTab === 'inner' && <Inner getStatusTag={getStatusTag} />}
|
{activeTab === 'inner' && <Inner getStatusTag={getStatusTag} />}
|
||||||
{activeTab === 'custom' && <Custom getStatusTag={getStatusTag} />}
|
{activeTab === 'custom' && <Custom getStatusTag={getStatusTag} />}
|
||||||
|
{/* {activeTab === 'market' && <Market getStatusTag={getStatusTag} />} */}
|
||||||
</div>
|
</div>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import {
|
|||||||
getShortTerm,
|
getShortTerm,
|
||||||
} from '@/api/memory'
|
} from '@/api/memory'
|
||||||
import Empty from '@/components/Empty'
|
import Empty from '@/components/Empty'
|
||||||
|
import Markdown from '@/components/Markdown'
|
||||||
|
|
||||||
interface ShortTermItem {
|
interface ShortTermItem {
|
||||||
retrieval: Array<{ query: string; retrieval: string[]; }>;
|
retrieval: Array<{ query: string; retrieval: string[]; }>;
|
||||||
@@ -85,7 +86,9 @@ const ShortTermDetail: FC = () => {
|
|||||||
))}
|
))}
|
||||||
<div>
|
<div>
|
||||||
<div className="rb:font-medium rb:leading-5 rb:mb-1">{t('shortTermDetail.answer')}</div>
|
<div className="rb:font-medium rb:leading-5 rb:mb-1">{t('shortTermDetail.answer')}</div>
|
||||||
<div className="rb:bg-[#FFFFFF] rb:border rb:border-[#DFE4ED] rb:rounded-md rb:px-3 rb:py-2.5 rb:leading-5">{vo.answer}</div>
|
<div className="rb:bg-[#FFFFFF] rb:border rb:border-[#DFE4ED] rb:rounded-md rb:px-3 rb:py-2.5 rb:leading-5">
|
||||||
|
<Markdown content={vo.answer} />
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</Space>
|
</Space>
|
||||||
</div>
|
</div>
|
||||||
@@ -103,7 +106,9 @@ const ShortTermDetail: FC = () => {
|
|||||||
: data.long_term?.map((vo, voIdx) => (
|
: data.long_term?.map((vo, voIdx) => (
|
||||||
<div key={voIdx} className="rb:leading-5 rb:shadow-[inset_3px_0px_0px_0px_#155EEF] rb:bg-[#FBFDFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg rb:px-6 rb:py-3">
|
<div key={voIdx} className="rb:leading-5 rb:shadow-[inset_3px_0px_0px_0px_#155EEF] rb:bg-[#FBFDFF] rb:border rb:border-[#DFE4ED] rb:rounded-lg rb:px-6 rb:py-3">
|
||||||
<div className="rb:mb-1 rb:font-medium rb:leading-5.5">{vo.query}</div>
|
<div className="rb:mb-1 rb:font-medium rb:leading-5.5">{vo.query}</div>
|
||||||
<div className="rb:mt-1 rb:leading-5 rb:text-[#5B6167] rb:text-[12px]">{vo.retrieval}</div>
|
<div className="rb:mt-1 rb:leading-5 rb:text-[#5B6167] rb:text-[12px]">
|
||||||
|
<Markdown content={vo.retrieval} />
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -174,8 +174,8 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
|
|||||||
*/
|
*/
|
||||||
const handleStreamMessage = (data: SSEMessage[]) => {
|
const handleStreamMessage = (data: SSEMessage[]) => {
|
||||||
data.forEach(item => {
|
data.forEach(item => {
|
||||||
const { chunk, conversation_id, node_id, cycle_id, cycle_idx, input, output, error, elapsed_time, status } = item.data as {
|
const { content, conversation_id, node_id, cycle_id, cycle_idx, input, output, error, elapsed_time, status } = item.data as {
|
||||||
chunk: string;
|
content: string;
|
||||||
conversation_id: string | null;
|
conversation_id: string | null;
|
||||||
cycle_id: string;
|
cycle_id: string;
|
||||||
cycle_idx: number;
|
cycle_idx: number;
|
||||||
@@ -202,7 +202,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef }>(({ appId
|
|||||||
if (lastIndex >= 0) {
|
if (lastIndex >= 0) {
|
||||||
newList[lastIndex] = {
|
newList[lastIndex] = {
|
||||||
...newList[lastIndex],
|
...newList[lastIndex],
|
||||||
content: newList[lastIndex].content + chunk
|
content: newList[lastIndex].content + content
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return newList
|
return newList
|
||||||
|
|||||||
Reference in New Issue
Block a user