[add] app chat v1

This commit is contained in:
Mark
2025-12-24 20:35:04 +08:00
parent 63d5047d21
commit bbd73d5e95
14 changed files with 1497 additions and 264 deletions

View File

@@ -1,9 +1,12 @@
"""会话服务"""
import uuid
from typing import Optional, List, Tuple
from typing import Optional, List, Tuple, Annotated
from fastapi import Depends
from sqlalchemy.orm import Session
from sqlalchemy import select, desc
from app.db import get_db
from app.models import Conversation, Message
from app.core.exceptions import ResourceNotFoundException, BusinessException
from app.core.error_codes import BizCode
@@ -14,10 +17,10 @@ logger = get_business_logger()
class ConversationService:
"""会话服务"""
def __init__(self, db: Session):
self.db = db
def create_conversation(
self,
app_id: uuid.UUID,
@@ -36,11 +39,11 @@ class ConversationService:
is_draft=is_draft,
config_snapshot=config_snapshot
)
self.db.add(conversation)
self.db.commit()
self.db.refresh(conversation)
logger.info(
"创建会话成功",
extra={
@@ -50,9 +53,9 @@ class ConversationService:
"is_draft": is_draft
}
)
return conversation
def get_conversation(
self,
conversation_id: uuid.UUID,
@@ -60,17 +63,17 @@ class ConversationService:
) -> Conversation:
"""获取会话"""
stmt = select(Conversation).where(Conversation.id == conversation_id)
if workspace_id:
stmt = stmt.where(Conversation.workspace_id == workspace_id)
conversation = self.db.scalars(stmt).first()
if not conversation:
raise ResourceNotFoundException("会话", str(conversation_id))
return conversation
def list_conversations(
self,
app_id: uuid.UUID,
@@ -86,25 +89,25 @@ class ConversationService:
Conversation.workspace_id == workspace_id,
Conversation.is_active == True
)
if user_id:
stmt = stmt.where(Conversation.user_id == user_id)
if is_draft is not None:
stmt = stmt.where(Conversation.is_draft == is_draft)
# 总数
count_stmt = stmt.with_only_columns(Conversation.id)
total = len(self.db.execute(count_stmt).all())
# 分页
stmt = stmt.order_by(desc(Conversation.updated_at))
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
conversations = list(self.db.scalars(stmt).all())
return conversations, total
def add_message(
self,
conversation_id: uuid.UUID,
@@ -119,22 +122,22 @@ class ConversationService:
content=content,
meta_data=meta_data
)
self.db.add(message)
# 更新会话的消息计数和更新时间
conversation = self.get_conversation(conversation_id)
conversation.message_count += 1
# 如果是第一条用户消息,可以用它作为标题
if conversation.message_count == 1 and role == "user":
conversation.title = content[:50] + ("..." if len(content) > 50 else "")
self.db.commit()
self.db.refresh(message)
return message
def get_messages(
self,
conversation_id: uuid.UUID,
@@ -144,30 +147,30 @@ class ConversationService:
stmt = select(Message).where(
Message.conversation_id == conversation_id
).order_by(Message.created_at)
if limit:
stmt = stmt.limit(limit)
messages = list(self.db.scalars(stmt).all())
return messages
def get_conversation_history(
self,
conversation_id: uuid.UUID,
max_history: Optional[int] = None
) -> List[dict]:
"""获取会话历史消息
Args:
conversation_id: 会话ID
max_history: 最大历史消息数量
Returns:
List[dict]: 历史消息列表,格式为 [{"role": "user", "content": "..."}, ...]
"""
messages = self.get_messages(conversation_id, limit=max_history)
# 转换为字典格式
history = [
{
@@ -176,9 +179,9 @@ class ConversationService:
}
for msg in messages
]
return history
def save_conversation_messages(
self,
conversation_id: uuid.UUID,
@@ -192,14 +195,14 @@ class ConversationService:
role="user",
content=user_message
)
# 添加助手消息
self.add_message(
conversation_id=conversation_id,
role="assistant",
content=assistant_message
)
logger.debug(
"保存会话消息成功",
extra={
@@ -208,7 +211,7 @@ class ConversationService:
"assistant_message_length": len(assistant_message)
}
)
def delete_conversation(
self,
conversation_id: uuid.UUID,
@@ -217,9 +220,9 @@ class ConversationService:
"""删除会话(软删除)"""
conversation = self.get_conversation(conversation_id, workspace_id)
conversation.is_active = False
self.db.commit()
logger.info(
"删除会话成功",
extra={
@@ -227,3 +230,53 @@ class ConversationService:
"workspace_id": str(workspace_id)
}
)
def create_or_get_conversation(
self,
app_id: uuid.UUID,
workspace_id: uuid.UUID,
is_draft: bool = False,
conversation_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None,
) -> Conversation:
"""创建或获取会话"""
# 如果提供了 conversation_id尝试获取现有会话
if conversation_id:
try:
conversation = self.get_conversation(
conversation_id=conversation_id,
workspace_id=workspace_id
)
# 验证会话是否属于该应用
if conversation.app_id != app_id:
raise BusinessException("会话不属于该应用", BizCode.INVALID_CONVERSATION)
return conversation
except ResourceNotFoundException:
logger.warning(
"会话不存在,将创建新会话",
extra={"conversation_id": str(conversation_id)}
)
# 创建新会话(使用发布版本的配置)
conversation = self.create_conversation(
app_id=app_id,
workspace_id=workspace_id,
user_id=user_id,
is_draft=is_draft
)
logger.info(
"为分享链接创建新会话"
)
return conversation
# ==================== 依赖注入函数 ====================
def get_conversation_service(
db: Annotated[Session, Depends(get_db)]
) -> ConversationService:
"""获取工作流服务(依赖注入)"""
return ConversationService(db)