[changes] User summaries stored in RAG, generation of memory insights
@@ -5,8 +5,9 @@ This module provides functionality to analyze chunk content and generate insight
 """

 import asyncio
+import os
 from collections import Counter
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

 from app.core.logging_config import get_business_logger
 from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -15,12 +16,31 @@ from pydantic import BaseModel, Field

 business_logger = get_business_logger()

+DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
+

-def _get_llm_client():
-    """Get LLM client using db context."""
+def _get_llm_client(end_user_id: Optional[str] = None):
+    """Get LLM client, preferring user-connected config with fallback to default."""
     with get_db_context() as db:
+        try:
+            if end_user_id:
+                from app.services.memory_agent_service import get_end_user_connected_config
+                from app.services.memory_config_service import MemoryConfigService
+                connected_config = get_end_user_connected_config(end_user_id, db)
+                config_id = connected_config.get("memory_config_id")
+                workspace_id = connected_config.get("workspace_id")
+                if config_id or workspace_id:
+                    config_service = MemoryConfigService(db)
+                    memory_config = config_service.load_memory_config(
+                        config_id=config_id,
+                        workspace_id=workspace_id
+                    )
+                    factory = MemoryClientFactory(db)
+                    return factory.get_llm_client(memory_config.llm_model_id)
+        except Exception as e:
+            business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
         factory = MemoryClientFactory(db)
-        return factory.get_llm_client(None)  # Uses default LLM
+        return factory.get_llm_client(DEFAULT_LLM_ID)


 class ChunkInsight(BaseModel):
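Note: the new `_get_llm_client` resolves the model in two steps: the end user's connected memory config first, then the `SELECTED_LLM_ID` environment default. A minimal, self-contained sketch of that fallback order (the stub lookup table and model IDs below are hypothetical, not the app's real services):

    import os
    from typing import Optional

    # Hypothetical stand-in for the end-user -> model lookup; the real code goes
    # through get_end_user_connected_config and MemoryConfigService instead.
    _CONNECTED_MODELS = {"user-42": "openai/gpt-4o-mini"}

    DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")

    def resolve_llm_id(end_user_id: Optional[str] = None) -> str:
        """Prefer the user's connected model ID; fall back to the env default."""
        try:
            if end_user_id and end_user_id in _CONNECTED_MODELS:
                return _CONNECTED_MODELS[end_user_id]
        except Exception:
            pass  # any lookup failure falls through to the default, as in the diff
        return DEFAULT_LLM_ID

    assert resolve_llm_id("user-42") == "openai/gpt-4o-mini"
    assert resolve_llm_id("stranger") == DEFAULT_LLM_ID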
@@ -37,7 +57,7 @@ class DomainClassification(BaseModel):
     )


-async def classify_chunk_domain(chunk: str) -> str:
+async def classify_chunk_domain(chunk: str, end_user_id: Optional[str] = None) -> str:
     """
     Classify a chunk into a specific domain.

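Note: because `end_user_id` is added with a default of `None`, existing call sites keep working unchanged. A runnable toy illustrating that (the stub body replaces the real LLM call):

    import asyncio
    from typing import Optional

    async def classify_chunk_domain(chunk: str, end_user_id: Optional[str] = None) -> str:
        """Stub with the new signature; the real function prompts an LLM."""
        return "技术" if "Python" in chunk else "其他"

    async def main() -> None:
        # Old-style call (no user) and new-style call both run unchanged.
        print(await classify_chunk_domain("调试一个Python爬虫"))
        print(await classify_chunk_domain("周末去爬山", end_user_id="user-42"))

    asyncio.run(main())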
@@ -48,7 +68,7 @@ async def classify_chunk_domain(chunk: str) -> str:
         Domain name
     """
     try:
-        llm_client = _get_llm_client()
+        llm_client = _get_llm_client(end_user_id)

         prompt = f"""请将以下文本内容归类到最合适的领域中。

@@ -82,7 +102,7 @@ async def classify_chunk_domain(chunk: str) -> str:
         return "其他"


-async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20) -> Dict[str, float]:
+async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20, end_user_id: Optional[str] = None) -> Dict[str, float]:
     """
     Analyze the domain distribution of chunks.

@@ -103,7 +123,7 @@ async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20) -
     # Classify each chunk
     domain_counts = Counter()
     for chunk in chunks_to_analyze:
-        domain = await classify_chunk_domain(chunk)
+        domain = await classify_chunk_domain(chunk, end_user_id)
         domain_counts[domain] += 1

     # Calculate the percentages
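Note: the hunk shows only the counting half; the `# Calculate the percentages` step falls outside the diff context. A sketch of what that conversion presumably looks like, given the function's `Dict[str, float]` return type (the helper name is ours, not the file's):

    from collections import Counter
    from typing import Dict

    def to_percentages(domain_counts: Counter) -> Dict[str, float]:
        """Convert raw domain counts into a percentage distribution."""
        total = sum(domain_counts.values())
        if total == 0:
            return {}
        return {domain: round(count / total * 100, 1)
                for domain, count in domain_counts.most_common()}

    print(to_percentages(Counter({"技术": 3, "旅行": 1})))  # {'技术': 75.0, '旅行': 25.0}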
@@ -121,7 +141,7 @@ async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20) -
         return {}


-async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str:
+async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15, end_user_id: Optional[str] = None) -> str:
     """
     Generate insights from the given chunks.

@@ -138,7 +158,7 @@ async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str

     try:
         # 1. Analyze the domain distribution
-        domain_dist = await analyze_domain_distribution(chunks, max_chunks=max_chunks)
+        domain_dist = await analyze_domain_distribution(chunks, max_chunks=max_chunks, end_user_id=end_user_id)

         # 2. Collect basic statistics
         total_chunks = len(chunks)
@@ -185,7 +205,7 @@ async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str
         ]

         # Call the LLM to generate the insight
-        llm_client = _get_llm_client()
+        llm_client = _get_llm_client(end_user_id)
         response = await llm_client.chat(messages=messages)

         insight = response.content.strip()
@@ -5,7 +5,8 @@ This module provides functionality to summarize chunk content using LLM.
 """

 import asyncio
-from typing import Any, Dict, List
+import os
+from typing import Any, Dict, List, Optional

 from app.core.logging_config import get_business_logger
 from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -14,12 +15,31 @@ from pydantic import BaseModel, Field

 business_logger = get_business_logger()

+DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
+

-def _get_llm_client():
-    """Get LLM client using db context."""
+def _get_llm_client(end_user_id: Optional[str] = None):
+    """Get LLM client, preferring user-connected config with fallback to default."""
     with get_db_context() as db:
+        try:
+            if end_user_id:
+                from app.services.memory_agent_service import get_end_user_connected_config
+                from app.services.memory_config_service import MemoryConfigService
+                connected_config = get_end_user_connected_config(end_user_id, db)
+                config_id = connected_config.get("memory_config_id")
+                workspace_id = connected_config.get("workspace_id")
+                if config_id or workspace_id:
+                    config_service = MemoryConfigService(db)
+                    memory_config = config_service.load_memory_config(
+                        config_id=config_id,
+                        workspace_id=workspace_id
+                    )
+                    factory = MemoryClientFactory(db)
+                    return factory.get_llm_client(memory_config.llm_model_id)
+        except Exception as e:
+            business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
         factory = MemoryClientFactory(db)
-        return factory.get_llm_client(None)  # Uses default LLM
+        return factory.get_llm_client(DEFAULT_LLM_ID)


 class ChunkSummary(BaseModel):
@@ -27,7 +47,7 @@ class ChunkSummary(BaseModel):
     summary: str = Field(..., description="简洁的chunk内容摘要")


-async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str:
+async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10, end_user_id: Optional[str] = None) -> str:
     """
     Generate a summary for the given chunks.

@@ -67,7 +87,7 @@ async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str
         ]

         # Call the LLM to generate the summary
-        llm_client = _get_llm_client()
+        llm_client = _get_llm_client(end_user_id)
         response = await llm_client.chat(messages=messages)

         summary = response.content.strip()
@@ -5,8 +5,9 @@ This module provides functionality to extract meaningful tags from chunk content
 """

 import asyncio
+import os
 from collections import Counter
-from typing import List, Tuple
+from typing import List, Optional, Tuple

 from app.core.logging_config import get_business_logger
 from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
@@ -15,12 +16,31 @@ from pydantic import BaseModel, Field

 business_logger = get_business_logger()

+DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
+

-def _get_llm_client():
-    """Get LLM client using db context."""
+def _get_llm_client(end_user_id: Optional[str] = None):
+    """Get LLM client, preferring user-connected config with fallback to default."""
     with get_db_context() as db:
+        try:
+            if end_user_id:
+                from app.services.memory_agent_service import get_end_user_connected_config
+                from app.services.memory_config_service import MemoryConfigService
+                connected_config = get_end_user_connected_config(end_user_id, db)
+                config_id = connected_config.get("memory_config_id")
+                workspace_id = connected_config.get("workspace_id")
+                if config_id or workspace_id:
+                    config_service = MemoryConfigService(db)
+                    memory_config = config_service.load_memory_config(
+                        config_id=config_id,
+                        workspace_id=workspace_id
+                    )
+                    factory = MemoryClientFactory(db)
+                    return factory.get_llm_client(memory_config.llm_model_id)
+        except Exception as e:
+            business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
         factory = MemoryClientFactory(db)
-        return factory.get_llm_client(None)  # Uses default LLM
+        return factory.get_llm_client(DEFAULT_LLM_ID)


 class ExtractedTags(BaseModel):
@@ -33,7 +53,7 @@ class ExtractedPersona(BaseModel):
     personas: List[str] = Field(..., description="从文本中提取的人物形象列表,如'产品设计师'、'旅行爱好者'等")


-async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10) -> List[Tuple[str, int]]:
+async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10, end_user_id: Optional[str] = None) -> List[Tuple[str, int]]:
     """
     Extract meaningful tags from the given chunks.

@@ -64,7 +84,7 @@ async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks:
        "标签应该是名词或名词短语,能够准确概括文本的核心内容。"
     )

-    llm_client = _get_llm_client()
+    llm_client = _get_llm_client(end_user_id)

     # Extract tags from each chunk separately, then count frequencies
     all_tags = []
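Note: per the comment above, tags are extracted chunk by chunk and then merged by frequency, which matches the `List[Tuple[str, int]]` return type. A self-contained sketch of that aggregation step (the sample tag lists are made up):

    from collections import Counter
    from typing import List, Tuple

    def rank_tags(per_chunk_tags: List[List[str]], max_tags: int = 10) -> List[Tuple[str, int]]:
        """Flatten per-chunk tag lists and keep the most frequent entries."""
        counts = Counter(tag for tags in per_chunk_tags for tag in tags)
        return counts.most_common(max_tags)

    print(rank_tags([["旅行", "摄影"], ["旅行", "美食"]], max_tags=2))  # [('旅行', 2), ('摄影', 1)]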
@@ -116,7 +136,7 @@ async def extract_chunk_tags_with_frequency(chunks: List[str], max_tags: int = 1
     return await extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=len(chunks))


-async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20) -> List[str]:
+async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20, end_user_id: Optional[str] = None) -> List[str]:
     """
     Extract persona (人物形象) from the given chunks.

@@ -159,7 +179,7 @@ async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_ch
         ]

         # Call the LLM to extract personas
-        llm_client = _get_llm_client()
+        llm_client = _get_llm_client(end_user_id)
         structured_response = await llm_client.response_structured(
             messages=messages,
             response_model=ExtractedPersona
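Note: `response_structured` is this project's own wrapper, so its exact contract isn't shown here; the general pattern is validating the LLM's JSON reply against a Pydantic schema. An equivalent illustration using plain pydantic v2 (the raw reply is invented):

    from typing import List
    from pydantic import BaseModel, Field

    class ExtractedPersona(BaseModel):
        personas: List[str] = Field(..., description="Personas extracted from the text")

    # Pretend this JSON string came back from the LLM.
    raw_reply = '{"personas": ["产品设计师", "旅行爱好者"]}'
    parsed = ExtractedPersona.model_validate_json(raw_reply)
    print(parsed.personas)  # ['产品设计师', '旅行爱好者']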
@@ -678,9 +678,9 @@ async def get_chunk_summary_and_tags(

     # 3. Concurrently generate the summary and extract tags and personas
     import asyncio
-    summary_task = generate_chunk_summary(chunks, max_chunks=limit)
-    tags_task = extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=limit)
-    personas_task = extract_chunk_persona(chunks, max_personas=5, max_chunks=limit)
+    summary_task = generate_chunk_summary(chunks, max_chunks=limit, end_user_id=end_user_id)
+    tags_task = extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=limit, end_user_id=end_user_id)
+    personas_task = extract_chunk_persona(chunks, max_personas=5, max_chunks=limit, end_user_id=end_user_id)

     summary, tags_with_freq, personas = await asyncio.gather(summary_task, tags_task, personas_task)

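Note: the three coroutines are created first and awaited together with `asyncio.gather`, so the LLM calls overlap instead of running back to back. A toy demonstration of the same pattern (sleeps stand in for network latency):

    import asyncio

    async def fake_llm_call(name: str, delay: float) -> str:
        await asyncio.sleep(delay)  # stands in for an LLM round-trip
        return name

    async def main() -> None:
        # Wall time is roughly the slowest task (~0.3s), not the 0.6s sum.
        results = await asyncio.gather(
            fake_llm_call("summary", 0.3),
            fake_llm_call("tags", 0.2),
            fake_llm_call("personas", 0.1),
        )
        print(results)  # ['summary', 'tags', 'personas']; gather preserves order

    asyncio.run(main())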
@@ -736,7 +736,7 @@ async def get_chunk_insight(
     from app.core.rag_utils import generate_chunk_insight

     # 3. Generate the insight
-    insight = await generate_chunk_insight(chunks, max_chunks=limit)
+    insight = await generate_chunk_insight(chunks, max_chunks=limit, end_user_id=end_user_id)

     result = {
         "insight": insight