# Files
# MemoryBear/api/app/services/conversation_service.py
# 2025-12-15 14:09:43 +08:00
#
# 230 lines
# 6.5 KiB
# Python

"""会话服务"""
import uuid
from typing import Optional, List, Tuple

from sqlalchemy.orm import Session
from sqlalchemy import select, desc, func

from app.models import Conversation, Message
from app.core.exceptions import ResourceNotFoundException, BusinessException
from app.core.error_codes import BizCode
from app.core.logging_config import get_business_logger
# Module-level logger obtained from the project's logging configuration;
# the service methods in this module log through it.
logger = get_business_logger()
class ConversationService:
    """Service layer for conversations and the messages stored in them.

    Every method works through the SQLAlchemy session handed to the
    constructor, and each mutating method commits its own changes.
    """

    # Number of leading characters of the first user message that are reused
    # as the conversation title.
    _TITLE_MAX_LENGTH = 50

    def __init__(self, db: Session):
        """Bind the service to a database session.

        Args:
            db: Session used for every query and commit issued by the service.
        """
        self.db = db

    @classmethod
    def _derive_title(cls, content: str) -> str:
        """Build a conversation title from the leading characters of a message."""
        if len(content) > cls._TITLE_MAX_LENGTH:
            return content[: cls._TITLE_MAX_LENGTH] + "..."
        return content

    @staticmethod
    def _page_window(page: int, pagesize: int) -> Tuple[int, int]:
        """Turn 1-based paging arguments into a non-negative (offset, limit) pair."""
        # Clamp instead of forwarding nonsense to the database: a page below 1
        # would yield a negative OFFSET and a negative size a negative LIMIT.
        limit = max(pagesize, 0)
        offset = (max(page, 1) - 1) * limit
        return offset, limit

    def create_conversation(
        self,
        app_id: uuid.UUID,
        workspace_id: uuid.UUID,
        user_id: Optional[str] = None,
        title: Optional[str] = None,
        is_draft: bool = False,
        config_snapshot: Optional[dict] = None
    ) -> Conversation:
        """Create a conversation, commit it, and return the refreshed row.

        Args:
            app_id: Application the conversation belongs to.
            workspace_id: Workspace the conversation belongs to.
            user_id: Optional identifier of the conversation's user.
            title: Display title; a default title is used when empty or omitted.
            is_draft: Value stored in the conversation's ``is_draft`` field.
            config_snapshot: Optional configuration dict stored with the row.

        Returns:
            The newly persisted conversation.
        """
        conversation = Conversation(
            app_id=app_id,
            workspace_id=workspace_id,
            user_id=user_id,
            title=title or "新会话",
            is_draft=is_draft,
            config_snapshot=config_snapshot
        )
        self.db.add(conversation)
        self.db.commit()
        # Reload so database-generated fields (e.g. the id logged below) are set.
        self.db.refresh(conversation)

        logger.info(
            "创建会话成功",
            extra={
                "conversation_id": str(conversation.id),
                "app_id": str(app_id),
                "workspace_id": str(workspace_id),
                "is_draft": is_draft
            }
        )
        return conversation

    def get_conversation(
        self,
        conversation_id: uuid.UUID,
        workspace_id: Optional[uuid.UUID] = None
    ) -> Conversation:
        """Fetch one conversation by id, optionally scoped to a workspace.

        Args:
            conversation_id: Primary key of the conversation.
            workspace_id: When given, the conversation must also belong to
                this workspace.

        Returns:
            The matching conversation.

        Raises:
            ResourceNotFoundException: If no conversation matches.
        """
        stmt = select(Conversation).where(Conversation.id == conversation_id)
        # Explicit None check so that only an omitted workspace disables the
        # workspace scoping, never an accidentally falsy value.
        if workspace_id is not None:
            stmt = stmt.where(Conversation.workspace_id == workspace_id)

        conversation = self.db.scalars(stmt).first()
        if conversation is None:
            raise ResourceNotFoundException("会话", str(conversation_id))
        return conversation

    def list_conversations(
        self,
        app_id: uuid.UUID,
        workspace_id: uuid.UUID,
        user_id: Optional[str] = None,
        is_draft: Optional[bool] = None,
        page: int = 1,
        pagesize: int = 20
    ) -> Tuple[List[Conversation], int]:
        """List the active conversations of an app, most recently updated first.

        Args:
            app_id: Application whose conversations are listed.
            workspace_id: Workspace the conversations must belong to.
            user_id: When non-empty, only conversations of this user.
            is_draft: When not None, only conversations with this draft flag.
            page: 1-based page number; values below 1 are treated as 1.
            pagesize: Page size; negative values are treated as 0.

        Returns:
            A tuple ``(conversations, total)`` where ``conversations`` is the
            requested page and ``total`` is the number of rows matching the
            filters across all pages.
        """
        filters = [
            Conversation.app_id == app_id,
            Conversation.workspace_id == workspace_id,
            Conversation.is_active.is_(True),
        ]
        if user_id:
            filters.append(Conversation.user_id == user_id)
        if is_draft is not None:
            filters.append(Conversation.is_draft == is_draft)

        # Let the database count instead of loading every matching id into
        # Python just to take len() of the result.
        count_stmt = select(func.count()).select_from(Conversation).where(*filters)
        total = self.db.scalar(count_stmt) or 0

        offset, limit = self._page_window(page, pagesize)
        page_stmt = (
            select(Conversation)
            .where(*filters)
            .order_by(desc(Conversation.updated_at))
            .offset(offset)
            .limit(limit)
        )
        conversations = list(self.db.scalars(page_stmt).all())
        return conversations, total

    def add_message(
        self,
        conversation_id: uuid.UUID,
        role: str,
        content: str,
        meta_data: Optional[dict] = None
    ) -> Message:
        """Append a message to a conversation and commit.

        The conversation's ``message_count`` is incremented, and when the
        first message of the conversation has role ``"user"`` its leading
        characters become the conversation title.

        Args:
            conversation_id: Conversation receiving the message.
            role: Role stored on the message (e.g. ``"user"``, ``"assistant"``).
            content: Message text.
            meta_data: Optional metadata dict stored with the message.

        Returns:
            The persisted, refreshed message.

        Raises:
            ResourceNotFoundException: If the conversation does not exist.
        """
        # Resolve the conversation BEFORE staging the message. Otherwise a
        # missing conversation leaves an orphan Message pending in the
        # session, and the lookup's autoflush may try to INSERT it first.
        conversation = self.get_conversation(conversation_id)

        message = Message(
            conversation_id=conversation_id,
            role=role,
            content=content,
            meta_data=meta_data
        )
        self.db.add(message)

        # Treat a NULL counter as zero so the increment cannot fail.
        conversation.message_count = (conversation.message_count or 0) + 1

        # The first user message doubles as the conversation title.
        if conversation.message_count == 1 and role == "user":
            conversation.title = self._derive_title(content)

        self.db.commit()
        self.db.refresh(message)
        return message

    def get_messages(
        self,
        conversation_id: uuid.UUID,
        limit: Optional[int] = None
    ) -> List[Message]:
        """Return a conversation's messages ordered by creation time.

        Args:
            conversation_id: Conversation whose messages are fetched.
            limit: When truthy, return only the first (oldest) ``limit``
                messages; ``None`` or ``0`` returns all of them.

        Returns:
            Messages in ascending ``created_at`` order.
        """
        stmt = (
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(Message.created_at)
        )
        if limit:
            stmt = stmt.limit(limit)
        return list(self.db.scalars(stmt).all())

    def get_conversation_history(
        self,
        conversation_id: uuid.UUID,
        max_history: Optional[int] = None
    ) -> List[dict]:
        """Return the conversation history as role/content dicts.

        Args:
            conversation_id: Conversation whose history is fetched.
            max_history: Maximum number of messages to return. When truthy,
                the most recent ``max_history`` messages are returned;
                ``None`` or ``0`` returns the full history.

        Returns:
            Messages in chronological order, shaped as
            ``[{"role": "user", "content": "..."}, ...]``.
        """
        if max_history:
            # Select newest-first so LIMIT keeps the latest messages, then
            # flip back to chronological order. Limiting an ascending query
            # would keep the oldest messages and drop the recent context.
            newest_first = (
                select(Message)
                .where(Message.conversation_id == conversation_id)
                .order_by(desc(Message.created_at))
                .limit(max_history)
            )
            messages = list(self.db.scalars(newest_first).all())
            messages.reverse()
        else:
            messages = self.get_messages(conversation_id)

        return [{"role": msg.role, "content": msg.content} for msg in messages]

    def save_conversation_messages(
        self,
        conversation_id: uuid.UUID,
        user_message: str,
        assistant_message: str
    ) -> None:
        """Store one user message followed by the assistant's reply.

        Each message is written through ``add_message`` and is therefore
        committed separately, user message first.

        Args:
            conversation_id: Conversation receiving both messages.
            user_message: Text stored with role ``"user"``.
            assistant_message: Text stored with role ``"assistant"``.

        Raises:
            ResourceNotFoundException: If the conversation does not exist.
        """
        self.add_message(
            conversation_id=conversation_id,
            role="user",
            content=user_message
        )
        self.add_message(
            conversation_id=conversation_id,
            role="assistant",
            content=assistant_message
        )

        logger.debug(
            "保存会话消息成功",
            extra={
                "conversation_id": str(conversation_id),
                "user_message_length": len(user_message),
                "assistant_message_length": len(assistant_message)
            }
        )

    def delete_conversation(
        self,
        conversation_id: uuid.UUID,
        workspace_id: uuid.UUID
    ) -> None:
        """Soft-delete a conversation by clearing its ``is_active`` flag.

        Args:
            conversation_id: Conversation to deactivate.
            workspace_id: Workspace the conversation must belong to.

        Raises:
            ResourceNotFoundException: If no conversation matches.
        """
        conversation = self.get_conversation(conversation_id, workspace_id)
        conversation.is_active = False
        self.db.commit()

        logger.info(
            "删除会话成功",
            extra={
                "conversation_id": str(conversation_id),
                "workspace_id": str(workspace_id)
            }
        )