diff --git a/api/app/core/rag_utils/chunk_insight.py b/api/app/core/rag_utils/chunk_insight.py index e904e53d..9fbdbbb2 100644 --- a/api/app/core/rag_utils/chunk_insight.py +++ b/api/app/core/rag_utils/chunk_insight.py @@ -5,8 +5,9 @@ This module provides functionality to analyze chunk content and generate insight """ import asyncio +import os from collections import Counter -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from app.core.logging_config import get_business_logger from app.core.memory.utils.llm.llm_utils import MemoryClientFactory @@ -15,12 +16,31 @@ from pydantic import BaseModel, Field business_logger = get_business_logger() +DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus") -def _get_llm_client(): - """Get LLM client using db context.""" + +def _get_llm_client(end_user_id: Optional[str] = None): + """Get LLM client, preferring user-connected config with fallback to default.""" with get_db_context() as db: + try: + if end_user_id: + from app.services.memory_agent_service import get_end_user_connected_config + from app.services.memory_config_service import MemoryConfigService + connected_config = get_end_user_connected_config(end_user_id, db) + config_id = connected_config.get("memory_config_id") + workspace_id = connected_config.get("workspace_id") + if config_id or workspace_id: + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + workspace_id=workspace_id + ) + factory = MemoryClientFactory(db) + return factory.get_llm_client(memory_config.llm_model_id) + except Exception as e: + business_logger.warning(f"Failed to get user connected config, using default LLM: {e}") factory = MemoryClientFactory(db) - return factory.get_llm_client(None) # Uses default LLM + return factory.get_llm_client(DEFAULT_LLM_ID) class ChunkInsight(BaseModel): @@ -37,7 +57,7 @@ class DomainClassification(BaseModel): ) -async def classify_chunk_domain(chunk: str) -> str: +async def classify_chunk_domain(chunk: str, end_user_id: Optional[str] = None) -> str: """ Classify a chunk into a specific domain. @@ -48,7 +68,7 @@ async def classify_chunk_domain(chunk: str) -> str: Domain name """ try: - llm_client = _get_llm_client() + llm_client = _get_llm_client(end_user_id) prompt = f"""请将以下文本内容归类到最合适的领域中。 @@ -82,7 +102,7 @@ async def classify_chunk_domain(chunk: str) -> str: return "其他" -async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20) -> Dict[str, float]: +async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20, end_user_id: Optional[str] = None) -> Dict[str, float]: """ Analyze the domain distribution of chunks. @@ -103,7 +123,7 @@ async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20) - # 为每个chunk分类 domain_counts = Counter() for chunk in chunks_to_analyze: - domain = await classify_chunk_domain(chunk) + domain = await classify_chunk_domain(chunk, end_user_id) domain_counts[domain] += 1 # 计算百分比 @@ -121,7 +141,7 @@ async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20) - return {} -async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str: +async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15, end_user_id: Optional[str] = None) -> str: """ Generate insights from the given chunks. @@ -138,7 +158,7 @@ async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str try: # 1. 分析领域分布 - domain_dist = await analyze_domain_distribution(chunks, max_chunks=max_chunks) + domain_dist = await analyze_domain_distribution(chunks, max_chunks=max_chunks, end_user_id=end_user_id) # 2. 统计基本信息 total_chunks = len(chunks) @@ -185,7 +205,7 @@ async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str ] # 调用LLM生成洞察 - llm_client = _get_llm_client() + llm_client = _get_llm_client(end_user_id) response = await llm_client.chat(messages=messages) insight = response.content.strip() diff --git a/api/app/core/rag_utils/chunk_summary.py b/api/app/core/rag_utils/chunk_summary.py index 7f69af88..53df2ab3 100644 --- a/api/app/core/rag_utils/chunk_summary.py +++ b/api/app/core/rag_utils/chunk_summary.py @@ -5,7 +5,8 @@ This module provides functionality to summarize chunk content using LLM. """ import asyncio -from typing import Any, Dict, List +import os +from typing import Any, Dict, List, Optional from app.core.logging_config import get_business_logger from app.core.memory.utils.llm.llm_utils import MemoryClientFactory @@ -14,12 +15,31 @@ from pydantic import BaseModel, Field business_logger = get_business_logger() +DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus") -def _get_llm_client(): - """Get LLM client using db context.""" + +def _get_llm_client(end_user_id: Optional[str] = None): + """Get LLM client, preferring user-connected config with fallback to default.""" with get_db_context() as db: + try: + if end_user_id: + from app.services.memory_agent_service import get_end_user_connected_config + from app.services.memory_config_service import MemoryConfigService + connected_config = get_end_user_connected_config(end_user_id, db) + config_id = connected_config.get("memory_config_id") + workspace_id = connected_config.get("workspace_id") + if config_id or workspace_id: + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + workspace_id=workspace_id + ) + factory = MemoryClientFactory(db) + return factory.get_llm_client(memory_config.llm_model_id) + except Exception as e: + business_logger.warning(f"Failed to get user connected config, using default LLM: {e}") factory = MemoryClientFactory(db) - return factory.get_llm_client(None) # Uses default LLM + return factory.get_llm_client(DEFAULT_LLM_ID) class ChunkSummary(BaseModel): @@ -27,7 +47,7 @@ class ChunkSummary(BaseModel): summary: str = Field(..., description="简洁的chunk内容摘要") -async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str: +async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10, end_user_id: Optional[str] = None) -> str: """ Generate a summary for the given chunks. @@ -67,7 +87,7 @@ async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str ] # 调用LLM生成摘要 - llm_client = _get_llm_client() + llm_client = _get_llm_client(end_user_id) response = await llm_client.chat(messages=messages) summary = response.content.strip() diff --git a/api/app/core/rag_utils/chunk_tags.py b/api/app/core/rag_utils/chunk_tags.py index 2057f8ac..98ab4a33 100644 --- a/api/app/core/rag_utils/chunk_tags.py +++ b/api/app/core/rag_utils/chunk_tags.py @@ -5,8 +5,9 @@ This module provides functionality to extract meaningful tags from chunk content """ import asyncio +import os from collections import Counter -from typing import List, Tuple +from typing import List, Optional, Tuple from app.core.logging_config import get_business_logger from app.core.memory.utils.llm.llm_utils import MemoryClientFactory @@ -15,12 +16,31 @@ from pydantic import BaseModel, Field business_logger = get_business_logger() +DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus") -def _get_llm_client(): - """Get LLM client using db context.""" + +def _get_llm_client(end_user_id: Optional[str] = None): + """Get LLM client, preferring user-connected config with fallback to default.""" with get_db_context() as db: + try: + if end_user_id: + from app.services.memory_agent_service import get_end_user_connected_config + from app.services.memory_config_service import MemoryConfigService + connected_config = get_end_user_connected_config(end_user_id, db) + config_id = connected_config.get("memory_config_id") + workspace_id = connected_config.get("workspace_id") + if config_id or workspace_id: + config_service = MemoryConfigService(db) + memory_config = config_service.load_memory_config( + config_id=config_id, + workspace_id=workspace_id + ) + factory = MemoryClientFactory(db) + return factory.get_llm_client(memory_config.llm_model_id) + except Exception as e: + business_logger.warning(f"Failed to get user connected config, using default LLM: {e}") factory = MemoryClientFactory(db) - return factory.get_llm_client(None) # Uses default LLM + return factory.get_llm_client(DEFAULT_LLM_ID) class ExtractedTags(BaseModel): @@ -33,7 +53,7 @@ class ExtractedPersona(BaseModel): personas: List[str] = Field(..., description="从文本中提取的人物形象列表,如'产品设计师'、'旅行爱好者'等") -async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10) -> List[Tuple[str, int]]: +async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10, end_user_id: Optional[str] = None) -> List[Tuple[str, int]]: """ Extract meaningful tags from the given chunks. @@ -64,7 +84,7 @@ async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: "标签应该是名词或名词短语,能够准确概括文本的核心内容。" ) - llm_client = _get_llm_client() + llm_client = _get_llm_client(end_user_id) # 为每个chunk单独提取标签,然后统计频率 all_tags = [] @@ -116,7 +136,7 @@ async def extract_chunk_tags_with_frequency(chunks: List[str], max_tags: int = 1 return await extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=len(chunks)) -async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20) -> List[str]: +async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20, end_user_id: Optional[str] = None) -> List[str]: """ Extract persona (人物形象) from the given chunks. @@ -159,7 +179,7 @@ async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_ch ] # 调用LLM提取人物形象 - llm_client = _get_llm_client() + llm_client = _get_llm_client(end_user_id) structured_response = await llm_client.response_structured( messages=messages, response_model=ExtractedPersona diff --git a/api/app/services/memory_dashboard_service.py b/api/app/services/memory_dashboard_service.py index 05aed57e..63a9c361 100644 --- a/api/app/services/memory_dashboard_service.py +++ b/api/app/services/memory_dashboard_service.py @@ -678,9 +678,9 @@ async def get_chunk_summary_and_tags( # 3. 并发生成摘要、提取标签和人物形象 import asyncio - summary_task = generate_chunk_summary(chunks, max_chunks=limit) - tags_task = extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=limit) - personas_task = extract_chunk_persona(chunks, max_personas=5, max_chunks=limit) + summary_task = generate_chunk_summary(chunks, max_chunks=limit, end_user_id=end_user_id) + tags_task = extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=limit, end_user_id=end_user_id) + personas_task = extract_chunk_persona(chunks, max_personas=5, max_chunks=limit, end_user_id=end_user_id) summary, tags_with_freq, personas = await asyncio.gather(summary_task, tags_task, personas_task) @@ -736,7 +736,7 @@ async def get_chunk_insight( from app.core.rag_utils import generate_chunk_insight # 3. 生成洞察 - insight = await generate_chunk_insight(chunks, max_chunks=limit) + insight = await generate_chunk_insight(chunks, max_chunks=limit, end_user_id=end_user_id) result = { "insight": insight