"""会话服务""" import uuid from typing import Optional, List, Tuple from sqlalchemy.orm import Session from sqlalchemy import select, desc from app.models import Conversation, Message from app.core.exceptions import ResourceNotFoundException, BusinessException from app.core.error_codes import BizCode from app.core.logging_config import get_business_logger logger = get_business_logger() class ConversationService: """会话服务""" def __init__(self, db: Session): self.db = db def create_conversation( self, app_id: uuid.UUID, workspace_id: uuid.UUID, user_id: Optional[str] = None, title: Optional[str] = None, is_draft: bool = False, config_snapshot: Optional[dict] = None ) -> Conversation: """创建会话""" conversation = Conversation( app_id=app_id, workspace_id=workspace_id, user_id=user_id, title=title or "新会话", is_draft=is_draft, config_snapshot=config_snapshot ) self.db.add(conversation) self.db.commit() self.db.refresh(conversation) logger.info( "创建会话成功", extra={ "conversation_id": str(conversation.id), "app_id": str(app_id), "workspace_id": str(workspace_id), "is_draft": is_draft } ) return conversation def get_conversation( self, conversation_id: uuid.UUID, workspace_id: Optional[uuid.UUID] = None ) -> Conversation: """获取会话""" stmt = select(Conversation).where(Conversation.id == conversation_id) if workspace_id: stmt = stmt.where(Conversation.workspace_id == workspace_id) conversation = self.db.scalars(stmt).first() if not conversation: raise ResourceNotFoundException("会话", str(conversation_id)) return conversation def list_conversations( self, app_id: uuid.UUID, workspace_id: uuid.UUID, user_id: Optional[str] = None, is_draft: Optional[bool] = None, page: int = 1, pagesize: int = 20 ) -> Tuple[List[Conversation], int]: """列出会话""" stmt = select(Conversation).where( Conversation.app_id == app_id, Conversation.workspace_id == workspace_id, Conversation.is_active == True ) if user_id: stmt = stmt.where(Conversation.user_id == user_id) if is_draft is not None: stmt = stmt.where(Conversation.is_draft == is_draft) # 总数 count_stmt = stmt.with_only_columns(Conversation.id) total = len(self.db.execute(count_stmt).all()) # 分页 stmt = stmt.order_by(desc(Conversation.updated_at)) stmt = stmt.offset((page - 1) * pagesize).limit(pagesize) conversations = list(self.db.scalars(stmt).all()) return conversations, total def add_message( self, conversation_id: uuid.UUID, role: str, content: str, meta_data: Optional[dict] = None ) -> Message: """添加消息""" message = Message( conversation_id=conversation_id, role=role, content=content, meta_data=meta_data ) self.db.add(message) # 更新会话的消息计数和更新时间 conversation = self.get_conversation(conversation_id) conversation.message_count += 1 # 如果是第一条用户消息,可以用它作为标题 if conversation.message_count == 1 and role == "user": conversation.title = content[:50] + ("..." if len(content) > 50 else "") self.db.commit() self.db.refresh(message) return message def get_messages( self, conversation_id: uuid.UUID, limit: Optional[int] = None ) -> List[Message]: """获取会话消息""" stmt = select(Message).where( Message.conversation_id == conversation_id ).order_by(Message.created_at) if limit: stmt = stmt.limit(limit) messages = list(self.db.scalars(stmt).all()) return messages def get_conversation_history( self, conversation_id: uuid.UUID, max_history: Optional[int] = None ) -> List[dict]: """获取会话历史消息 Args: conversation_id: 会话ID max_history: 最大历史消息数量 Returns: List[dict]: 历史消息列表,格式为 [{"role": "user", "content": "..."}, ...] """ messages = self.get_messages(conversation_id, limit=max_history) # 转换为字典格式 history = [ { "role": msg.role, "content": msg.content } for msg in messages ] return history def save_conversation_messages( self, conversation_id: uuid.UUID, user_message: str, assistant_message: str ): """保存会话消息(用户消息和助手回复)""" # 添加用户消息 self.add_message( conversation_id=conversation_id, role="user", content=user_message ) # 添加助手消息 self.add_message( conversation_id=conversation_id, role="assistant", content=assistant_message ) logger.debug( "保存会话消息成功", extra={ "conversation_id": str(conversation_id), "user_message_length": len(user_message), "assistant_message_length": len(assistant_message) } ) def delete_conversation( self, conversation_id: uuid.UUID, workspace_id: uuid.UUID ): """删除会话(软删除)""" conversation = self.get_conversation(conversation_id, workspace_id) conversation.is_active = False self.db.commit() logger.info( "删除会话成功", extra={ "conversation_id": str(conversation_id), "workspace_id": str(workspace_id) } )