[changes] User summaries stored in RAG, generation of memory insights
This commit is contained in:
@@ -5,8 +5,9 @@ This module provides functionality to analyze chunk content and generate insight
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
@@ -15,12 +16,31 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
business_logger = get_business_logger()
|
business_logger = get_business_logger()
|
||||||
|
|
||||||
|
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||||
|
|
||||||
def _get_llm_client():
|
|
||||||
"""Get LLM client using db context."""
|
def _get_llm_client(end_user_id: Optional[str] = None):
|
||||||
|
"""Get LLM client, preferring user-connected config with fallback to default."""
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
|
try:
|
||||||
|
if end_user_id:
|
||||||
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
workspace_id = connected_config.get("workspace_id")
|
||||||
|
if config_id or workspace_id:
|
||||||
|
config_service = MemoryConfigService(db)
|
||||||
|
memory_config = config_service.load_memory_config(
|
||||||
|
config_id=config_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
factory = MemoryClientFactory(db)
|
||||||
|
return factory.get_llm_client(memory_config.llm_model_id)
|
||||||
|
except Exception as e:
|
||||||
|
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
|
||||||
factory = MemoryClientFactory(db)
|
factory = MemoryClientFactory(db)
|
||||||
return factory.get_llm_client(None) # Uses default LLM
|
return factory.get_llm_client(DEFAULT_LLM_ID)
|
||||||
|
|
||||||
|
|
||||||
class ChunkInsight(BaseModel):
|
class ChunkInsight(BaseModel):
|
||||||
@@ -37,7 +57,7 @@ class DomainClassification(BaseModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def classify_chunk_domain(chunk: str) -> str:
|
async def classify_chunk_domain(chunk: str, end_user_id: Optional[str] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Classify a chunk into a specific domain.
|
Classify a chunk into a specific domain.
|
||||||
|
|
||||||
@@ -48,7 +68,7 @@ async def classify_chunk_domain(chunk: str) -> str:
|
|||||||
Domain name
|
Domain name
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
llm_client = _get_llm_client()
|
llm_client = _get_llm_client(end_user_id)
|
||||||
|
|
||||||
prompt = f"""请将以下文本内容归类到最合适的领域中。
|
prompt = f"""请将以下文本内容归类到最合适的领域中。
|
||||||
|
|
||||||
@@ -82,7 +102,7 @@ async def classify_chunk_domain(chunk: str) -> str:
|
|||||||
return "其他"
|
return "其他"
|
||||||
|
|
||||||
|
|
||||||
async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20) -> Dict[str, float]:
|
async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20, end_user_id: Optional[str] = None) -> Dict[str, float]:
|
||||||
"""
|
"""
|
||||||
Analyze the domain distribution of chunks.
|
Analyze the domain distribution of chunks.
|
||||||
|
|
||||||
@@ -103,7 +123,7 @@ async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20) -
|
|||||||
# 为每个chunk分类
|
# 为每个chunk分类
|
||||||
domain_counts = Counter()
|
domain_counts = Counter()
|
||||||
for chunk in chunks_to_analyze:
|
for chunk in chunks_to_analyze:
|
||||||
domain = await classify_chunk_domain(chunk)
|
domain = await classify_chunk_domain(chunk, end_user_id)
|
||||||
domain_counts[domain] += 1
|
domain_counts[domain] += 1
|
||||||
|
|
||||||
# 计算百分比
|
# 计算百分比
|
||||||
@@ -121,7 +141,7 @@ async def analyze_domain_distribution(chunks: List[str], max_chunks: int = 20) -
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str:
|
async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15, end_user_id: Optional[str] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Generate insights from the given chunks.
|
Generate insights from the given chunks.
|
||||||
|
|
||||||
@@ -138,7 +158,7 @@ async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 1. 分析领域分布
|
# 1. 分析领域分布
|
||||||
domain_dist = await analyze_domain_distribution(chunks, max_chunks=max_chunks)
|
domain_dist = await analyze_domain_distribution(chunks, max_chunks=max_chunks, end_user_id=end_user_id)
|
||||||
|
|
||||||
# 2. 统计基本信息
|
# 2. 统计基本信息
|
||||||
total_chunks = len(chunks)
|
total_chunks = len(chunks)
|
||||||
@@ -185,7 +205,7 @@ async def generate_chunk_insight(chunks: List[str], max_chunks: int = 15) -> str
|
|||||||
]
|
]
|
||||||
|
|
||||||
# 调用LLM生成洞察
|
# 调用LLM生成洞察
|
||||||
llm_client = _get_llm_client()
|
llm_client = _get_llm_client(end_user_id)
|
||||||
response = await llm_client.chat(messages=messages)
|
response = await llm_client.chat(messages=messages)
|
||||||
|
|
||||||
insight = response.content.strip()
|
insight = response.content.strip()
|
||||||
|
|||||||
@@ -5,7 +5,8 @@ This module provides functionality to summarize chunk content using LLM.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import Any, Dict, List
|
import os
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
@@ -14,12 +15,31 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
business_logger = get_business_logger()
|
business_logger = get_business_logger()
|
||||||
|
|
||||||
|
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||||
|
|
||||||
def _get_llm_client():
|
|
||||||
"""Get LLM client using db context."""
|
def _get_llm_client(end_user_id: Optional[str] = None):
|
||||||
|
"""Get LLM client, preferring user-connected config with fallback to default."""
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
|
try:
|
||||||
|
if end_user_id:
|
||||||
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
workspace_id = connected_config.get("workspace_id")
|
||||||
|
if config_id or workspace_id:
|
||||||
|
config_service = MemoryConfigService(db)
|
||||||
|
memory_config = config_service.load_memory_config(
|
||||||
|
config_id=config_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
factory = MemoryClientFactory(db)
|
||||||
|
return factory.get_llm_client(memory_config.llm_model_id)
|
||||||
|
except Exception as e:
|
||||||
|
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
|
||||||
factory = MemoryClientFactory(db)
|
factory = MemoryClientFactory(db)
|
||||||
return factory.get_llm_client(None) # Uses default LLM
|
return factory.get_llm_client(DEFAULT_LLM_ID)
|
||||||
|
|
||||||
|
|
||||||
class ChunkSummary(BaseModel):
|
class ChunkSummary(BaseModel):
|
||||||
@@ -27,7 +47,7 @@ class ChunkSummary(BaseModel):
|
|||||||
summary: str = Field(..., description="简洁的chunk内容摘要")
|
summary: str = Field(..., description="简洁的chunk内容摘要")
|
||||||
|
|
||||||
|
|
||||||
async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str:
|
async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10, end_user_id: Optional[str] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Generate a summary for the given chunks.
|
Generate a summary for the given chunks.
|
||||||
|
|
||||||
@@ -67,7 +87,7 @@ async def generate_chunk_summary(chunks: List[str], max_chunks: int = 10) -> str
|
|||||||
]
|
]
|
||||||
|
|
||||||
# 调用LLM生成摘要
|
# 调用LLM生成摘要
|
||||||
llm_client = _get_llm_client()
|
llm_client = _get_llm_client(end_user_id)
|
||||||
response = await llm_client.chat(messages=messages)
|
response = await llm_client.chat(messages=messages)
|
||||||
|
|
||||||
summary = response.content.strip()
|
summary = response.content.strip()
|
||||||
|
|||||||
@@ -5,8 +5,9 @@ This module provides functionality to extract meaningful tags from chunk content
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from typing import List, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from app.core.logging_config import get_business_logger
|
from app.core.logging_config import get_business_logger
|
||||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||||
@@ -15,12 +16,31 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
business_logger = get_business_logger()
|
business_logger = get_business_logger()
|
||||||
|
|
||||||
|
DEFAULT_LLM_ID = os.getenv("SELECTED_LLM_ID", "openai/qwen-plus")
|
||||||
|
|
||||||
def _get_llm_client():
|
|
||||||
"""Get LLM client using db context."""
|
def _get_llm_client(end_user_id: Optional[str] = None):
|
||||||
|
"""Get LLM client, preferring user-connected config with fallback to default."""
|
||||||
with get_db_context() as db:
|
with get_db_context() as db:
|
||||||
|
try:
|
||||||
|
if end_user_id:
|
||||||
|
from app.services.memory_agent_service import get_end_user_connected_config
|
||||||
|
from app.services.memory_config_service import MemoryConfigService
|
||||||
|
connected_config = get_end_user_connected_config(end_user_id, db)
|
||||||
|
config_id = connected_config.get("memory_config_id")
|
||||||
|
workspace_id = connected_config.get("workspace_id")
|
||||||
|
if config_id or workspace_id:
|
||||||
|
config_service = MemoryConfigService(db)
|
||||||
|
memory_config = config_service.load_memory_config(
|
||||||
|
config_id=config_id,
|
||||||
|
workspace_id=workspace_id
|
||||||
|
)
|
||||||
|
factory = MemoryClientFactory(db)
|
||||||
|
return factory.get_llm_client(memory_config.llm_model_id)
|
||||||
|
except Exception as e:
|
||||||
|
business_logger.warning(f"Failed to get user connected config, using default LLM: {e}")
|
||||||
factory = MemoryClientFactory(db)
|
factory = MemoryClientFactory(db)
|
||||||
return factory.get_llm_client(None) # Uses default LLM
|
return factory.get_llm_client(DEFAULT_LLM_ID)
|
||||||
|
|
||||||
|
|
||||||
class ExtractedTags(BaseModel):
|
class ExtractedTags(BaseModel):
|
||||||
@@ -33,7 +53,7 @@ class ExtractedPersona(BaseModel):
|
|||||||
personas: List[str] = Field(..., description="从文本中提取的人物形象列表,如'产品设计师'、'旅行爱好者'等")
|
personas: List[str] = Field(..., description="从文本中提取的人物形象列表,如'产品设计师'、'旅行爱好者'等")
|
||||||
|
|
||||||
|
|
||||||
async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10) -> List[Tuple[str, int]]:
|
async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks: int = 10, end_user_id: Optional[str] = None) -> List[Tuple[str, int]]:
|
||||||
"""
|
"""
|
||||||
Extract meaningful tags from the given chunks.
|
Extract meaningful tags from the given chunks.
|
||||||
|
|
||||||
@@ -64,7 +84,7 @@ async def extract_chunk_tags(chunks: List[str], max_tags: int = 10, max_chunks:
|
|||||||
"标签应该是名词或名词短语,能够准确概括文本的核心内容。"
|
"标签应该是名词或名词短语,能够准确概括文本的核心内容。"
|
||||||
)
|
)
|
||||||
|
|
||||||
llm_client = _get_llm_client()
|
llm_client = _get_llm_client(end_user_id)
|
||||||
|
|
||||||
# 为每个chunk单独提取标签,然后统计频率
|
# 为每个chunk单独提取标签,然后统计频率
|
||||||
all_tags = []
|
all_tags = []
|
||||||
@@ -116,7 +136,7 @@ async def extract_chunk_tags_with_frequency(chunks: List[str], max_tags: int = 1
|
|||||||
return await extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=len(chunks))
|
return await extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=len(chunks))
|
||||||
|
|
||||||
|
|
||||||
async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20) -> List[str]:
|
async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_chunks: int = 20, end_user_id: Optional[str] = None) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Extract persona (人物形象) from the given chunks.
|
Extract persona (人物形象) from the given chunks.
|
||||||
|
|
||||||
@@ -159,7 +179,7 @@ async def extract_chunk_persona(chunks: List[str], max_personas: int = 5, max_ch
|
|||||||
]
|
]
|
||||||
|
|
||||||
# 调用LLM提取人物形象
|
# 调用LLM提取人物形象
|
||||||
llm_client = _get_llm_client()
|
llm_client = _get_llm_client(end_user_id)
|
||||||
structured_response = await llm_client.response_structured(
|
structured_response = await llm_client.response_structured(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
response_model=ExtractedPersona
|
response_model=ExtractedPersona
|
||||||
|
|||||||
@@ -678,9 +678,9 @@ async def get_chunk_summary_and_tags(
|
|||||||
|
|
||||||
# 3. 并发生成摘要、提取标签和人物形象
|
# 3. 并发生成摘要、提取标签和人物形象
|
||||||
import asyncio
|
import asyncio
|
||||||
summary_task = generate_chunk_summary(chunks, max_chunks=limit)
|
summary_task = generate_chunk_summary(chunks, max_chunks=limit, end_user_id=end_user_id)
|
||||||
tags_task = extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=limit)
|
tags_task = extract_chunk_tags(chunks, max_tags=max_tags, max_chunks=limit, end_user_id=end_user_id)
|
||||||
personas_task = extract_chunk_persona(chunks, max_personas=5, max_chunks=limit)
|
personas_task = extract_chunk_persona(chunks, max_personas=5, max_chunks=limit, end_user_id=end_user_id)
|
||||||
|
|
||||||
summary, tags_with_freq, personas = await asyncio.gather(summary_task, tags_task, personas_task)
|
summary, tags_with_freq, personas = await asyncio.gather(summary_task, tags_task, personas_task)
|
||||||
|
|
||||||
@@ -736,7 +736,7 @@ async def get_chunk_insight(
|
|||||||
from app.core.rag_utils import generate_chunk_insight
|
from app.core.rag_utils import generate_chunk_insight
|
||||||
|
|
||||||
# 3. 生成洞察
|
# 3. 生成洞察
|
||||||
insight = await generate_chunk_insight(chunks, max_chunks=limit)
|
insight = await generate_chunk_insight(chunks, max_chunks=limit, end_user_id=end_user_id)
|
||||||
|
|
||||||
result = {
|
result = {
|
||||||
"insight": insight
|
"insight": insight
|
||||||
|
|||||||
Reference in New Issue
Block a user