Feature/memory work (#61)
* refactor(conversation): separate service and repository layers for conversation module - Split ConversationService and repository/UnitOfWork layers - Service layer now only handles business logic and orchestration - Repository layer handles all direct database operations - UnitOfWork encapsulates transactional operations for messages - Ensured all public methods have clear English docstrings with arguments, return values, and exceptions * feat(memory): implement work memory endpoints and services - Added API routes for conversation count, conversation list, messages, and detail. - Integrated ConversationService for database queries and LLM-based summary generation. * feat(memory): implement work memory endpoints and services - Added API routes for conversation count, conversation list, messages, and detail. - Integrated ConversationService for database queries and LLM-based summary generation. * feat(workflow): fix issues causing workflow failures if-else None value error knowledge empty list rerank end node output none node value assigner input none value * feat(memory): convert memory file creation time to timestamp and include title and first-line fields in file type * fix(memory): fix serialization output and default value issues * fix(workflow): fix issue with hybrid search logic in knowledge retrieval node
This commit is contained in:
@@ -1,177 +1,290 @@
|
||||
"""会话服务"""
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Annotated
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
import json_repair
|
||||
from fastapi import Depends
|
||||
from sqlalchemy import select, desc
|
||||
from jinja2 import Template
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.error_codes import BizCode
|
||||
from app.core.exceptions import BusinessException
|
||||
from app.core.exceptions import ResourceNotFoundException
|
||||
from app.core.logging_config import get_business_logger
|
||||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.db import get_db
|
||||
from app.models import Conversation, Message
|
||||
from app.models import Conversation, Message, User, ModelType
|
||||
from app.models.conversation_model import ConversationDetail
|
||||
from app.models.prompt_optimizer_model import RoleType
|
||||
from app.repositories.conversation_repository import ConversationRepository, MessageRepository
|
||||
from app.schemas.conversation_schema import ConversationOut
|
||||
from app.services import workspace_service
|
||||
from app.services.model_service import ModelConfigService
|
||||
|
||||
logger = get_business_logger()
|
||||
|
||||
|
||||
class ConversationService:
|
||||
"""会话服务"""
|
||||
"""
|
||||
Service layer for managing conversations and messages.
|
||||
Provides methods to create, retrieve, list, and manipulate conversations and messages.
|
||||
Delegates database operations to repositories.
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.conversation_repo = ConversationRepository(db)
|
||||
self.message_repo = MessageRepository(db)
|
||||
|
||||
def create_conversation(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
user_id: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
is_draft: bool = False,
|
||||
config_snapshot: Optional[dict] = None
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
user_id: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
is_draft: bool = False,
|
||||
config_snapshot: Optional[dict] = None
|
||||
) -> Conversation:
|
||||
"""创建会话"""
|
||||
conversation = Conversation(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
title=title or "新会话",
|
||||
is_draft=is_draft,
|
||||
config_snapshot=config_snapshot
|
||||
)
|
||||
"""
|
||||
Create a new conversation in the system.
|
||||
|
||||
self.db.add(conversation)
|
||||
self.db.commit()
|
||||
self.db.refresh(conversation)
|
||||
Args:
|
||||
app_id (uuid.UUID): The application ID the conversation belongs to.
|
||||
workspace_id (uuid.UUID): Workspace ID for context.
|
||||
user_id (Optional[str]): Optional user ID for the conversation owner.
|
||||
title (Optional[str]): Conversation title. Defaults to 'New Conversation' if not provided.
|
||||
is_draft (bool): Whether the conversation is a draft.
|
||||
config_snapshot (Optional[dict]): Optional configuration snapshot.
|
||||
|
||||
logger.info(
|
||||
"创建会话成功",
|
||||
extra={
|
||||
"conversation_id": str(conversation.id),
|
||||
"app_id": str(app_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"is_draft": is_draft
|
||||
}
|
||||
)
|
||||
Returns:
|
||||
Conversation: Newly created Conversation instance.
|
||||
"""
|
||||
try:
|
||||
conversation = self.conversation_repo.create_conversation(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
title=title or "New Conversation",
|
||||
is_draft=is_draft,
|
||||
config_snapshot=config_snapshot
|
||||
)
|
||||
self.db.commit()
|
||||
self.db.refresh(conversation)
|
||||
|
||||
logger.info(
|
||||
"Create Conversation Success",
|
||||
extra={
|
||||
"conversation_id": str(conversation.id),
|
||||
"app_id": str(app_id),
|
||||
"workspace_id": str(workspace_id),
|
||||
"is_draft": is_draft
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Create Conversation Failed - {str(e)}"
|
||||
)
|
||||
self.db.rollback()
|
||||
raise BusinessException(f"Error create Convsersation", code=BizCode.DB_ERROR)
|
||||
|
||||
return conversation
|
||||
|
||||
def get_conversation(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
workspace_id: Optional[uuid.UUID] = None
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
workspace_id: Optional[uuid.UUID] = None
|
||||
) -> Conversation:
|
||||
"""获取会话"""
|
||||
stmt = select(Conversation).where(Conversation.id == conversation_id)
|
||||
"""
|
||||
Retrieve a conversation by its ID.
|
||||
|
||||
if workspace_id:
|
||||
stmt = stmt.where(Conversation.workspace_id == workspace_id)
|
||||
Args:
|
||||
conversation_id (uuid.UUID): The conversation UUID.
|
||||
workspace_id (Optional[uuid.UUID]): Optional workspace UUID to restrict the query.
|
||||
|
||||
conversation = self.db.scalars(stmt).first()
|
||||
Raises:
|
||||
ResourceNotFoundException: If the conversation does not exist.
|
||||
|
||||
if not conversation:
|
||||
raise ResourceNotFoundException("会话", str(conversation_id))
|
||||
Returns:
|
||||
Conversation: The requested Conversation instance.
|
||||
"""
|
||||
conversation = self.conversation_repo.get_conversation_by_conversation_id(
|
||||
conversation_id=conversation_id,
|
||||
workspace_id=workspace_id
|
||||
)
|
||||
|
||||
return conversation
|
||||
|
||||
def list_conversations(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
user_id: Optional[str] = None,
|
||||
is_draft: Optional[bool] = None,
|
||||
page: int = 1,
|
||||
pagesize: int = 20
|
||||
) -> Tuple[List[Conversation], int]:
|
||||
"""列出会话"""
|
||||
stmt = select(Conversation).where(
|
||||
Conversation.app_id == app_id,
|
||||
Conversation.workspace_id == workspace_id,
|
||||
Conversation.is_active == True
|
||||
def get_user_conversations(
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
) -> list[Conversation]:
|
||||
"""
|
||||
Retrieve recent conversations for a specific user within a workspace.
|
||||
|
||||
This method delegates persistence logic to the repository layer and
|
||||
applies service-level defaults (e.g. recent conversation limit).
|
||||
|
||||
Args:
|
||||
user_id (uuid.UUID): Unique identifier of the user.
|
||||
workspace_id (uuid.UUID): Workspace scope for the query.
|
||||
|
||||
Returns:
|
||||
list[Conversation]: A list of recent conversation entities.
|
||||
"""
|
||||
conversations = self.conversation_repo.get_conversation_by_user_id(
|
||||
user_id,
|
||||
workspace_id,
|
||||
limit=10
|
||||
)
|
||||
return conversations
|
||||
|
||||
if user_id:
|
||||
stmt = stmt.where(Conversation.user_id == user_id)
|
||||
def list_conversations(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
user_id: Optional[str] = None,
|
||||
is_draft: Optional[bool] = None,
|
||||
page: int = 1,
|
||||
pagesize: int = 20
|
||||
) -> Tuple[List[Conversation], int]:
|
||||
"""
|
||||
List conversations with optional filters and pagination.
|
||||
|
||||
if is_draft is not None:
|
||||
stmt = stmt.where(Conversation.is_draft == is_draft)
|
||||
Args:
|
||||
app_id (uuid.UUID): Application ID filter.
|
||||
workspace_id (uuid.UUID): Workspace ID filter.
|
||||
user_id (Optional[str]): Optional user ID filter.
|
||||
is_draft (Optional[bool]): Optional draft status filter.
|
||||
page (int): Page number, 1-based.
|
||||
pagesize (int): Number of items per page.
|
||||
|
||||
# 总数
|
||||
count_stmt = stmt.with_only_columns(Conversation.id)
|
||||
total = len(self.db.execute(count_stmt).all())
|
||||
|
||||
# 分页
|
||||
stmt = stmt.order_by(desc(Conversation.updated_at))
|
||||
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
|
||||
|
||||
conversations = list(self.db.scalars(stmt).all())
|
||||
Returns:
|
||||
Tuple[List[Conversation], int]: A list of Conversation instances and the total count.
|
||||
"""
|
||||
conversations, total = self.conversation_repo.list_conversations(
|
||||
app_id=app_id,
|
||||
workspace_id=workspace_id,
|
||||
user_id=user_id,
|
||||
is_draft=is_draft,
|
||||
page=page,
|
||||
pagesize=pagesize
|
||||
)
|
||||
|
||||
return conversations, total
|
||||
|
||||
def add_message(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
role: str,
|
||||
content: str,
|
||||
meta_data: Optional[dict] = None
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
role: str,
|
||||
content: str,
|
||||
meta_data: Optional[dict] = None
|
||||
) -> Message:
|
||||
"""添加消息"""
|
||||
message = Message(
|
||||
conversation_id=conversation_id,
|
||||
role=role,
|
||||
content=content,
|
||||
meta_data=meta_data
|
||||
)
|
||||
"""
|
||||
Add a message to a conversation using UnitOfWork.
|
||||
|
||||
self.db.add(message)
|
||||
Args:
|
||||
conversation_id (uuid.UUID): Conversation UUID.
|
||||
role (str): Role of the message sender ('user' or 'assistant').
|
||||
content (str): Message content.
|
||||
meta_data (Optional[dict]): Optional metadata.
|
||||
|
||||
# 更新会话的消息计数和更新时间
|
||||
conversation = self.get_conversation(conversation_id)
|
||||
conversation.message_count += 1
|
||||
Returns:
|
||||
Message: Newly created Message instance.
|
||||
"""
|
||||
try:
|
||||
conversation = self.conversation_repo.get_conversation_by_conversation_id(
|
||||
conversation_id
|
||||
)
|
||||
|
||||
# 如果是第一条用户消息,可以用它作为标题
|
||||
if conversation.message_count == 1 and role == "user":
|
||||
conversation.title = content[:50] + ("..." if len(content) > 50 else "")
|
||||
message = Message(
|
||||
conversation_id=conversation_id,
|
||||
role=role,
|
||||
content=content,
|
||||
meta_data=meta_data,
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(message)
|
||||
self.message_repo.add_message(message)
|
||||
|
||||
return message
|
||||
conversation.message_count += 1
|
||||
|
||||
if conversation.message_count == 1 and role == "user":
|
||||
conversation.title = (
|
||||
content[:50] + ("..." if len(content) > 50 else "")
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(message)
|
||||
|
||||
logger.info(
|
||||
"Message added successfully",
|
||||
extra={
|
||||
"conversation_id": str(conversation_id),
|
||||
"message_id": str(message.id),
|
||||
"role": role,
|
||||
"content_length": len(content),
|
||||
},
|
||||
)
|
||||
|
||||
return message
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Message added error, db roll back - {str(e)}",
|
||||
extra={
|
||||
"conversation_id": str(conversation_id),
|
||||
"role": role,
|
||||
"content_length": len(content),
|
||||
},
|
||||
)
|
||||
self.db.rollback()
|
||||
raise BusinessException(
|
||||
f"Error adding message, conversation_id={conversation_id}",
|
||||
code=BizCode.DB_ERROR
|
||||
)
|
||||
|
||||
def get_messages(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
limit: Optional[int] = None
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
limit: Optional[int] = None
|
||||
) -> List[Message]:
|
||||
"""获取会话消息"""
|
||||
stmt = select(Message).where(
|
||||
Message.conversation_id == conversation_id
|
||||
).order_by(Message.created_at)
|
||||
"""
|
||||
Retrieve messages for a conversation.
|
||||
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
Args:
|
||||
conversation_id (uuid.UUID): Conversation UUID.
|
||||
limit (Optional[int]): Optional maximum number of messages.
|
||||
|
||||
messages = list(self.db.scalars(stmt).all())
|
||||
Returns:
|
||||
List[Message]: List of messages ordered by creation time.
|
||||
"""
|
||||
messages = self.message_repo.get_message_by_conversation_id(
|
||||
conversation_id,
|
||||
limit
|
||||
)
|
||||
|
||||
return messages
|
||||
|
||||
def get_conversation_history(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
max_history: Optional[int] = None
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
max_history: Optional[int] = None
|
||||
) -> List[dict]:
|
||||
"""获取会话历史消息
|
||||
"""
|
||||
Retrieve historical conversation messages formatted as dictionaries.
|
||||
|
||||
Args:
|
||||
conversation_id: 会话ID
|
||||
max_history: 最大历史消息数量
|
||||
conversation_id (uuid.UUID): Conversation UUID.
|
||||
max_history (Optional[int]): Maximum number of messages to retrieve.
|
||||
|
||||
Returns:
|
||||
List[dict]: 历史消息列表,格式为 [{"role": "user", "content": "..."}, ...]
|
||||
List[dict]: List of message dictionaries with keys 'role' and 'content'.
|
||||
"""
|
||||
messages = self.get_messages(conversation_id, limit=max_history)
|
||||
messages = self.message_repo.get_message_by_conversation_id(
|
||||
conversation_id,
|
||||
limit=max_history
|
||||
)
|
||||
|
||||
# 转换为字典格式
|
||||
history = [
|
||||
@@ -185,20 +298,25 @@ class ConversationService:
|
||||
return history
|
||||
|
||||
def save_conversation_messages(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
user_message: str,
|
||||
assistant_message: str
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
user_message: str,
|
||||
assistant_message: str
|
||||
):
|
||||
"""保存会话消息(用户消息和助手回复)"""
|
||||
# 添加用户消息
|
||||
"""
|
||||
Save a pair of user and assistant messages to the conversation.
|
||||
|
||||
Args:
|
||||
conversation_id (uuid.UUID): Conversation UUID.
|
||||
user_message (str): User's message content.
|
||||
assistant_message (str): Assistant's response content.
|
||||
"""
|
||||
self.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="user",
|
||||
content=user_message
|
||||
)
|
||||
|
||||
# 添加助手消息
|
||||
self.add_message(
|
||||
conversation_id=conversation_id,
|
||||
role="assistant",
|
||||
@@ -206,7 +324,7 @@ class ConversationService:
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"保存会话消息成功",
|
||||
"Saved conversation messages successfully",
|
||||
extra={
|
||||
"conversation_id": str(conversation_id),
|
||||
"user_message_length": len(user_message),
|
||||
@@ -215,35 +333,59 @@ class ConversationService:
|
||||
)
|
||||
|
||||
def delete_conversation(
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
self,
|
||||
conversation_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID
|
||||
):
|
||||
"""删除会话(软删除)"""
|
||||
conversation = self.get_conversation(conversation_id, workspace_id)
|
||||
conversation.is_active = False
|
||||
"""
|
||||
Soft delete a conversation.
|
||||
|
||||
self.db.commit()
|
||||
Args:
|
||||
conversation_id (uuid.UUID): Conversation UUID.
|
||||
workspace_id (uuid.UUID): Workspace UUID for validation.
|
||||
"""
|
||||
try:
|
||||
self.conversation_repo.soft_delete_conversation_by_conversation_id(
|
||||
conversation_id,
|
||||
workspace_id
|
||||
)
|
||||
self.db.commit()
|
||||
|
||||
logger.info(
|
||||
"删除会话成功",
|
||||
extra={
|
||||
"conversation_id": str(conversation_id),
|
||||
"workspace_id": str(workspace_id)
|
||||
}
|
||||
)
|
||||
logger.info(
|
||||
"Soft deleted conversation successfully",
|
||||
extra={
|
||||
"conversation_id": str(conversation_id),
|
||||
"workspace_id": str(workspace_id)
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(
|
||||
f"Error deleting conversation, conversation_id={conversation_id} - {str(e)}",
|
||||
)
|
||||
raise BusinessException("Error deleting conversation", code=BizCode.DB_ERROR)
|
||||
|
||||
def create_or_get_conversation(
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
is_draft: bool = False,
|
||||
conversation_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None,
|
||||
self,
|
||||
app_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
is_draft: bool = False,
|
||||
conversation_id: Optional[uuid.UUID] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Conversation:
|
||||
"""创建或获取会话"""
|
||||
"""
|
||||
Retrieve an existing conversation by ID or create a new one.
|
||||
|
||||
# 如果提供了 conversation_id,尝试获取现有会话
|
||||
Args:
|
||||
app_id (uuid.UUID): Application ID.
|
||||
workspace_id (uuid.UUID): Workspace ID.
|
||||
is_draft (bool): Whether the conversation should be a draft.
|
||||
conversation_id (Optional[uuid.UUID]): Optional conversation ID to retrieve.
|
||||
user_id (Optional[str]): Optional user ID.
|
||||
|
||||
Returns:
|
||||
Conversation: Existing or newly created conversation.
|
||||
"""
|
||||
if conversation_id:
|
||||
try:
|
||||
conversation = self.get_conversation(
|
||||
@@ -253,11 +395,14 @@ class ConversationService:
|
||||
|
||||
# 验证会话是否属于该应用
|
||||
if conversation.app_id != app_id:
|
||||
raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION)
|
||||
raise BusinessException(
|
||||
"Conversation does not belong to this app",
|
||||
BizCode.INVALID_CONVERSATION
|
||||
)
|
||||
return conversation
|
||||
except ResourceNotFoundException:
|
||||
logger.warning(
|
||||
"会话不存在,将创建新会话",
|
||||
"Conversation not found. A new conversation will be created.",
|
||||
extra={"conversation_id": str(conversation_id)}
|
||||
)
|
||||
|
||||
@@ -270,15 +415,179 @@ class ConversationService:
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"为分享链接创建新会话"
|
||||
"Created a new conversation for shared link usage",
|
||||
extra={
|
||||
"conversation_id": str(conversation_id),
|
||||
}
|
||||
)
|
||||
|
||||
return conversation
|
||||
|
||||
# ==================== 依赖注入函数 ====================
|
||||
async def get_conversation_detail(
|
||||
self,
|
||||
user: User,
|
||||
conversation_id: uuid.UUID,
|
||||
workspace_id: uuid.UUID,
|
||||
language: str = "zh"
|
||||
) -> ConversationOut:
|
||||
"""
|
||||
Retrieve or generate the summary and theme of a conversation.
|
||||
|
||||
This method first attempts to fetch the conversation detail from the repository.
|
||||
If no detail exists or the conversation is outdated (>1 day), it generates a new
|
||||
summary using the configured LLM model, stores it, and returns it.
|
||||
|
||||
Args:
|
||||
user (User): The user requesting the conversation summary.
|
||||
conversation_id (UUID): Unique identifier of the conversation.
|
||||
workspace_id (UUID): Identifier of the workspace where the conversation belongs.
|
||||
language (str, optional): Language for the summary generation. Defaults to "zh".
|
||||
|
||||
Returns:
|
||||
ConversationOut: An object containing the conversation's theme, summary,
|
||||
takeaways, and information score.
|
||||
|
||||
Raises:
|
||||
BusinessException: If the workspace model is not configured, the model does
|
||||
not exist, API keys are missing, or the LLM output is invalid.
|
||||
|
||||
Notes:
|
||||
- If conversation details exist and are recent, they are returned directly.
|
||||
- LLM generation uses system and user prompt templates from the filesystem.
|
||||
- JSON repair is applied to ensure model outputs can be safely parsed.
|
||||
- Commits the new conversation detail only if it is generated or outdated.
|
||||
"""
|
||||
logger.info(f"Fetching conversation detail for conversation_id={conversation_id}, workspace_id={workspace_id}")
|
||||
|
||||
conversation_detail = self.conversation_repo.get_conversation_detail(
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
conversation = self.get_conversation(
|
||||
conversation_id=conversation_id,
|
||||
)
|
||||
if conversation_detail:
|
||||
logger.info(f"Conversation detail found in repository for conversation_id={conversation_id}")
|
||||
return ConversationOut(
|
||||
theme=conversation_detail.theme,
|
||||
theme_intro=conversation_detail.theme_intro,
|
||||
summary=conversation_detail.summary,
|
||||
takeaways=conversation_detail.takeaways,
|
||||
info_score=conversation_detail.info_score,
|
||||
)
|
||||
logger.info("Conversation detail not found, generating new summary using LLM")
|
||||
configs = workspace_service.get_workspace_models_configs(
|
||||
db=self.db,
|
||||
workspace_id=workspace_id,
|
||||
user=user
|
||||
)
|
||||
model_id = configs.get('llm')
|
||||
if not model_id:
|
||||
logger.error(f"Workspace model configuration not found for workspace_id={workspace_id}")
|
||||
raise BusinessException("Workspace model configuration not found. Please configure a model first.", code=BizCode.MODEL_NOT_FOUND)
|
||||
config = ModelConfigService.get_model_by_id(db=self.db, model_id=model_id)
|
||||
|
||||
if not config:
|
||||
logger.error("Configured model not found for model_id={model_id}")
|
||||
raise BusinessException("Configured model does not exist.", BizCode.NOT_FOUND)
|
||||
|
||||
if not config.api_keys or len(config.api_keys) == 0:
|
||||
logger.error(f"Model API keys missing for model_id={model_id}", )
|
||||
raise BusinessException("Model configuration missing API keys.", BizCode.INVALID_PARAMETER)
|
||||
|
||||
api_config = config.api_keys[0]
|
||||
model_name = api_config.model_name
|
||||
provider = api_config.provider
|
||||
api_key = api_config.api_key
|
||||
api_base = api_config.api_base
|
||||
model_type = config.type
|
||||
|
||||
llm = RedBearLLM(
|
||||
RedBearModelConfig(
|
||||
model_name=model_name,
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=api_base
|
||||
),
|
||||
type=ModelType(model_type)
|
||||
)
|
||||
|
||||
conversation_messages = self.get_conversation_history(
|
||||
conversation_id=conversation_id,
|
||||
max_history=30
|
||||
)
|
||||
|
||||
with open('app/services/prompt/conversation_summary_system.jinja2', 'r', encoding='utf-8') as f:
|
||||
system_prompt = f.read()
|
||||
rendered_system_message = Template(system_prompt).render()
|
||||
|
||||
with open('app/services/prompt/conversation_summary_user.jinja2', 'r', encoding='utf-8') as f:
|
||||
user_prompt = f.read()
|
||||
rendered_user_message = Template(user_prompt).render(
|
||||
language=language,
|
||||
conversation=str(conversation_messages)
|
||||
)
|
||||
messages = [
|
||||
(RoleType.SYSTEM, rendered_system_message),
|
||||
(RoleType.USER, rendered_user_message),
|
||||
]
|
||||
logger.info(f"Invoking LLM for conversation_id={conversation_id}")
|
||||
model_resp = await llm.ainvoke(messages)
|
||||
try:
|
||||
if isinstance(model_resp.content, str):
|
||||
result = json_repair.repair_json(model_resp.content, return_objects=True)
|
||||
elif isinstance(model_resp.content, list):
|
||||
result = json_repair.repair_json(model_resp.content[0].get("text"), return_objects=True)
|
||||
elif isinstance(model_resp.content, dict):
|
||||
result = model_resp.content
|
||||
else:
|
||||
raise BusinessException("Unexpect model output", code=BizCode.LLM_ERROR)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to parse LLM response for conversation_id={conversation_id}")
|
||||
raise BusinessException("Failed to parse LLM response", code=BizCode.LLM_ERROR) from e
|
||||
|
||||
summary = result.get('summary', "")
|
||||
theme = result.get('theme', "")
|
||||
theme_intro = result.get("theme_intro", "")
|
||||
takeaways = result.get("takeaways") or []
|
||||
info_score = result.get("info_score", 50)
|
||||
|
||||
if datetime.now() - conversation.updated_at > timedelta(days=1):
|
||||
logger.info(f"Updating conversation detail in DB for conversation_id={conversation_id}")
|
||||
conversation_detail = ConversationDetail(
|
||||
conversation_id=conversation.id,
|
||||
summary=summary,
|
||||
theme=theme,
|
||||
theme_intro=theme_intro,
|
||||
takeaways=takeaways,
|
||||
info_score=info_score
|
||||
)
|
||||
self.conversation_repo.add_conversation_detail(conversation_detail)
|
||||
|
||||
self.db.commit()
|
||||
self.db.refresh(conversation_detail)
|
||||
logger.info(f"Returning conversation summary for conversation_id={conversation_id}")
|
||||
conversation_out = ConversationOut(
|
||||
theme=theme,
|
||||
theme_intro=theme_intro,
|
||||
summary=summary,
|
||||
takeaways=takeaways,
|
||||
info_score=info_score
|
||||
)
|
||||
return conversation_out
|
||||
|
||||
|
||||
# ==================== Dependency Injection ====================
|
||||
|
||||
def get_conversation_service(
|
||||
db: Annotated[Session, Depends(get_db)]
|
||||
) -> ConversationService:
|
||||
"""获取工作流服务(依赖注入)"""
|
||||
"""
|
||||
Dependency injection function to provide ConversationService instance.
|
||||
|
||||
Args:
|
||||
db (Session): Database session provided by FastAPI dependency.
|
||||
|
||||
Returns:
|
||||
ConversationService: Service instance.
|
||||
"""
|
||||
return ConversationService(db)
|
||||
|
||||
Reference in New Issue
Block a user