[changes] Store user summaries in RAG; generate memory insights

lanceyq
2026-03-09 18:50:32 +08:00
parent 107c676185
commit e4aaa18f61
4 changed files with 89 additions and 29 deletions

View File

@@ -5,8 +5,9 @@ This module provides functionality to analyze chunk content and generate insight
 """
 import asyncio
+import os
 from collections import Counter
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 from app.core.logging_config import get_business_logger
 from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -15,12 +16,31 @@ from pydantic import BaseModel, Field
 business_logger = get_business_logger()
 
+DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
+
 
-def _get_llm_client():
-    """Get LLM client using db context."""
+def _get_llm_client(end_user_id: Optional[str] = None):
+    """Get LLM client, preferring user-connected config with fallback to default."""
     with get_db_context() as db:
+        try:
+            if end_user_id:
+                from app.services.memory_agent_service import get_end_user_connected_config
+                from app.services.memory_config_service import MemoryConfigService
+
+                connected_config = get_end_user_connected_config(end_user_id, db)
+                config_id = connected_config.get("memory_config_id")
+                workspace_id = connected_config.get("workspace_id")
+                if config_id or workspace_id:
+                    config_service = MemoryConfigService(db)
+                    memory_config = config_service.load_memory_config(
+                        config_id=config_id,
+                        workspace_id=workspace_id
+                    )
+                    factory = MemoryClientFactory(db)
+                    return factory.get_llm_client(memory_config.llm_model_id)
+        except Exception as e:
+            business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
+
         factory = MemoryClientFactory(db)
-        return factory.get_llm_client(None)  # Uses default LLM
+        return factory.get_llm_client(DEFAULT_LLM_ID)
 
 
 class ChunkInsight(BaseModel):
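
The new helper resolves the model in a fixed order: the end user's connected memory config first, then the SELECTED_LLM_ID environment default. Below is a minimal, self-contained sketch of that fallback chain; resolve_llm_id and the load_llm_model_id callable are illustrative stand-ins, since the real code goes through MemoryConfigService and MemoryClientFactory as shown in the diff above.

    import os
    from typing import Callable, Optional

    DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")

    def resolve_llm_id(
        connected_config: Optional[dict],
        load_llm_model_id: Callable[[Optional[str], Optional[str]], str],
    ) -> str:
        """Pick a per-user LLM id, falling back to the environment default."""
        try:
            if connected_config:
                config_id = connected_config.get("memory_config_id")
                workspace_id = connected_config.get("workspace_id")
                if config_id or workspace_id:
                    # Mirrors load_memory_config(...).llm_model_id in the diff.
                    return load_llm_model_id(config_id, workspace_id)
        except Exception:
            pass  # the real helper logs a warning and falls through
        return DEFAULT_LLM_ID

    # Connected users get their configured model; everyone else the default.
    assert resolve_llm_id({"memory_config_id": "cfg-1"}, lambda c, w: "m1") == "m1"
    assert resolve_llm_id(None, lambda c, w: "m1") == DEFAULT_LLM_ID
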
@@ -37,7 +57,7 @@ class DomainClassification(BaseModel):
     )
 
 
-async def classify_chunk_domain(chunk: str) -> str:
+async def classify_chunk_domain(chunk: str, end_user_id: Optional[str] = None) -> str:
     """
     Classify a chunk into a specific domain.
 
@@ -48,7 +68,7 @@ async def classify_chunk_domain(chunk: str) -> str:
         Domain name
     """
     try:
-        llm_client = _get_llm_client()
+        llm_client = _get_llm_client(end_user_id)
 
         prompt = f"""请将以下文本内容归类到最合适的领域中。
@@ -82,7 +102,7 @@ async def classify_chunk_domain(chunk: str) -> str:
         return "其他"
 
 
-async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20) -> Dict[str, float]:
+async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20, end_user_id: Optional[str] = None) -> Dict[str, float]:
     """
     Analyze the domain distribution of chunks.
 
@@ -103,7 +123,7 @@ async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20) -
         # Classify each chunk
         domain_counts = Counter()
         for chunk in chunks_to_analyze:
-            domain = await classify_chunk_domain(chunk)
+            domain = await classify_chunk_domain(chunk, end_user_id)
             domain_counts[domain] += 1
 
         # Compute the percentage for each domain
@@ -121,7 +141,7 @@ async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20) -
         return {}
 
 
-async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str:
+async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15, end_user_id: Optional[str] = None) -> str:
     """
     Generate insights from the given chunks.
 
@@ -138,7 +158,7 @@ async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str
     try:
         # 1. Analyze the domain distribution
-        domain_dist = await analyze_domain_distribution(chunks, max_chunks=max_chunks)
+        domain_dist = await analyze_domain_distribution(chunks, max_chunks=max_chunks, end_user_id=end_user_id)
 
         # 2. Collect basic statistics
         total_chunks = len(chunks)
@@ -185,7 +205,7 @@ async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str
         ]
 
         # Call the LLM to generate the insight
-        llm_client = _get_llm_client()
+        llm_client = _get_llm_client(end_user_id)
         response = await llm_client.chat(messages=messages)
         insight = response.content.strip()
View File

@@ -5,7 +5,8 @@ This module provides functionality to summarize chunk content using LLM.
 """
 import asyncio
-from typing import Any, Dict, List
+import os
+from typing import Any, Dict, List, Optional
 
 from app.core.logging_config import get_business_logger
 from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -14,12 +15,31 @@ from pydantic import BaseModel, Field
 business_logger = get_business_logger()
 
+DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
+
 
-def _get_llm_client():
-    """Get LLM client using db context."""
+def _get_llm_client(end_user_id: Optional[str] = None):
+    """Get LLM client, preferring user-connected config with fallback to default."""
     with get_db_context() as db:
+        try:
+            if end_user_id:
+                from app.services.memory_agent_service import get_end_user_connected_config
+                from app.services.memory_config_service import MemoryConfigService
+
+                connected_config = get_end_user_connected_config(end_user_id, db)
+                config_id = connected_config.get("memory_config_id")
+                workspace_id = connected_config.get("workspace_id")
+                if config_id or workspace_id:
+                    config_service = MemoryConfigService(db)
+                    memory_config = config_service.load_memory_config(
+                        config_id=config_id,
+                        workspace_id=workspace_id
+                    )
+                    factory = MemoryClientFactory(db)
+                    return factory.get_llm_client(memory_config.llm_model_id)
+        except Exception as e:
+            business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
+
         factory = MemoryClientFactory(db)
-        return factory.get_llm_client(None)  # Uses default LLM
+        return factory.get_llm_client(DEFAULT_LLM_ID)
 
 
 class ChunkSummary(BaseModel):
@@ -27,7 +47,7 @@ class ChunkSummary(BaseModel):
     summary: str = Field(..., description="简洁的chunk内容摘要")
 
 
-async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str:
+async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10, end_user_id: Optional[str] = None) -> str:
     """
     Generate a summary for the given chunks.
 
@@ -67,7 +87,7 @@ async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str
         ]
 
         # Call the LLM to generate the summary
-        llm_client = _get_llm_client()
+        llm_client = _get_llm_client(end_user_id)
         response = await llm_client.chat(messages=messages)
         summary = response.content.strip()
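
Caller-side, the change amounts to one extra keyword argument. A hypothetical invocation (the import path and the ids are illustrative assumptions; the diff does not show this module's filename):

    import asyncio

    from app.core.rag_utils import generate_chunk_summary  # path assumed

    async def main() -> None:
        chunks = ["第一段笔记内容", "第二段笔记内容"]
        # end_user_id routes the call to that user's connected LLM config;
        # omitting it preserves the previous default-LLM behaviour.
        summary = await generate_chunk_summary(chunks, max_chunks=10, end_user_id="user-123")
        print(summary)

    asyncio.run(main())
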

View File

@@ -5,8 +5,9 @@ This module provides functionality to extract meaningful tags from chunk content
 """
 import asyncio
+import os
 from collections import Counter
-from typing import List, Tuple
+from typing import List, Optional, Tuple
 
 from app.core.logging_config import get_business_logger
 from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -15,12 +16,31 @@ from pydantic import BaseModel, Field
 business_logger = get_business_logger()
 
+DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
+
 
-def _get_llm_client():
-    """Get LLM client using db context."""
+def _get_llm_client(end_user_id: Optional[str] = None):
+    """Get LLM client, preferring user-connected config with fallback to default."""
     with get_db_context() as db:
+        try:
+            if end_user_id:
+                from app.services.memory_agent_service import get_end_user_connected_config
+                from app.services.memory_config_service import MemoryConfigService
+
+                connected_config = get_end_user_connected_config(end_user_id, db)
+                config_id = connected_config.get("memory_config_id")
+                workspace_id = connected_config.get("workspace_id")
+                if config_id or workspace_id:
+                    config_service = MemoryConfigService(db)
+                    memory_config = config_service.load_memory_config(
+                        config_id=config_id,
+                        workspace_id=workspace_id
+                    )
+                    factory = MemoryClientFactory(db)
+                    return factory.get_llm_client(memory_config.llm_model_id)
+        except Exception as e:
+            business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
+
         factory = MemoryClientFactory(db)
-        return factory.get_llm_client(None)  # Uses default LLM
+        return factory.get_llm_client(DEFAULT_LLM_ID)
 
 
 class ExtractedTags(BaseModel):
@@ -33,7 +53,7 @@ class ExtractedPersona(BaseModel):
     personas: List[str] = Field(..., description="从文本中提取的人物形象列表,如'产品设计师''旅行爱好者'")
 
 
-async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10) -> List[Tuple[str, int]]:
+async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10, end_user_id: Optional[str] = None) -> List[Tuple[str, int]]:
     """
     Extract meaningful tags from the given chunks.
 
@@ -64,7 +84,7 @@ async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks:
             "标签应该是名词或名词短语,能够准确概括文本的核心内容。"
         )
 
-        llm_client = _get_llm_client()
+        llm_client = _get_llm_client(end_user_id)
 
         # Extract tags from each chunk separately, then tally frequencies
         all_tags = []
@@ -116,7 +136,7 @@ async def extract_chunk_tags_with_frequency(chunks: List[str], max_tags: int = 1
     return await extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=len(chunks))
 
 
-async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20) -> List[str]:
+async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20, end_user_id: Optional[str] = None) -> List[str]:
     """
     Extract persona (人物形象) from the given chunks.
 
@@ -159,7 +179,7 @@ async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_ch
         ]
 
         # Call the LLM to extract personas
-        llm_client = _get_llm_client()
+        llm_client = _get_llm_client(end_user_id)
         structured_response = await llm_client.response_structured(
             messages=messages,
             response_model=ExtractedPersona
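
For context, response_structured asks the client to return a validated instance of the Pydantic response_model rather than raw text. A sketch with a stubbed client, under stated assumptions (FakeLLMClient is hypothetical; the real client comes from MemoryClientFactory, and its internals are not shown in this diff):

    import asyncio
    from typing import List

    from pydantic import BaseModel, Field

    class ExtractedPersona(BaseModel):
        personas: List[str] = Field(..., description="personas found in the text")

    class FakeLLMClient:
        async def response_structured(self, messages, response_model):
            # A real client would prompt the LLM and validate its JSON output
            # against response_model; here we return a canned instance.
            return response_model(personas=["产品设计师", "旅行爱好者"])

    async def demo() -> None:
        result = await FakeLLMClient().response_structured(
            messages=[], response_model=ExtractedPersona
        )
        print(result.personas[:5])  # cap at max_personas, as the module does

    asyncio.run(demo())
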

View File

@@ -678,9 +678,9 @@ async def get_chunk_summary_and_tags(
     # 3. Concurrently generate the summary and extract tags and personas
     import asyncio
-    summary_task = generate_chunk_summary(chunks, max_chunks=limit)
-    tags_task = extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=limit)
-    personas_task = extract_chunk_persona(chunks, max_personas=5, max_chunks=limit)
+    summary_task = generate_chunk_summary(chunks, max_chunks=limit, end_user_id=end_user_id)
+    tags_task = extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=limit, end_user_id=end_user_id)
+    personas_task = extract_chunk_persona(chunks, max_personas=5, max_chunks=limit, end_user_id=end_user_id)
 
     summary, tags_with_freq, personas = await asyncio.gather(summary_task, tags_task, personas_task)
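
The three generators are independent, so the endpoint starts them together and awaits them as a unit; total latency is roughly the slowest call rather than the sum. A self-contained illustration of the pattern with stub coroutines standing in for the LLM calls:

    import asyncio
    import time

    async def stub_call(name: str, delay: float) -> str:
        await asyncio.sleep(delay)  # stands in for one LLM round-trip
        return name

    async def main() -> None:
        start = time.perf_counter()
        summary, tags, personas = await asyncio.gather(
            stub_call("summary", 0.3),
            stub_call("tags", 0.2),
            stub_call("personas", 0.1),
        )
        # ~0.3s total, not 0.6s, because the awaits overlap.
        print(summary, tags, personas, f"{time.perf_counter() - start:.2f}s")

    asyncio.run(main())
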
@@ -736,7 +736,7 @@ async def get_chunk_insight(
     from app.core.rag_utils import generate_chunk_insight
 
     # 3. Generate the insight
-    insight = await generate_chunk_insight(chunks, max_chunks=limit)
+    insight = await generate_chunk_insight(chunks, max_chunks=limit, end_user_id=end_user_id)
 
     result = {
         "insight": insight