653 lines
24 KiB
Python
653 lines
24 KiB
Python
"""会话服务"""
|
||
import os
|
||
import uuid
|
||
from datetime import datetime, timedelta
|
||
from typing import Annotated
|
||
from typing import Optional, List, Tuple
|
||
|
||
import json_repair
|
||
from fastapi import Depends
|
||
from jinja2 import Template
|
||
from sqlalchemy.orm import Session
|
||
|
||
from app.core.error_codes import BizCode
|
||
from app.core.exceptions import BusinessException
|
||
from app.core.exceptions import ResourceNotFoundException
|
||
from app.core.logging_config import get_business_logger
|
||
from app.core.models import RedBearLLM, RedBearModelConfig
|
||
from app.db import get_db
|
||
from app.models import Conversation, Message, User, ModelType
|
||
from app.models.conversation_model import ConversationDetail
|
||
from app.models.prompt_optimizer_model import RoleType
|
||
from app.repositories.conversation_repository import ConversationRepository, MessageRepository
|
||
from app.schemas.conversation_schema import ConversationOut
|
||
from app.schemas.model_schema import ModelInfo
|
||
from app.services import workspace_service
|
||
from app.services.model_service import ModelConfigService
|
||
|
||
logger = get_business_logger()
|
||
|
||
|
||
class ConversationService:
    """
    Service layer for managing conversations and messages.

    Wraps ``ConversationRepository`` and ``MessageRepository`` and owns the
    transaction boundaries (commit / rollback) for the write operations it
    exposes. Also generates LLM-based conversation summaries.
    """

    def __init__(self, db: Session):
        """
        Bind the service and its repositories to a database session.

        Args:
            db (Session): SQLAlchemy session shared by both repositories.
        """
        self.db = db
        self.conversation_repo = ConversationRepository(db)
        self.message_repo = MessageRepository(db)

    def create_conversation(
            self,
            app_id: uuid.UUID,
            workspace_id: uuid.UUID,
            user_id: Optional[str] = None,
            title: Optional[str] = None,
            is_draft: bool = False,
            config_snapshot: Optional[dict] = None
    ) -> Conversation:
        """
        Create and commit a new conversation.

        Args:
            app_id (uuid.UUID): The application ID the conversation belongs to.
            workspace_id (uuid.UUID): Workspace ID for context.
            user_id (Optional[str]): Optional user ID for the conversation owner.
            title (Optional[str]): Conversation title. Falls back to
                'New Conversation' when empty or not provided.
            is_draft (bool): Whether the conversation is a draft.
            config_snapshot (Optional[dict]): Optional configuration snapshot.

        Returns:
            Conversation: Newly created, refreshed Conversation instance.

        Raises:
            BusinessException: With ``BizCode.DB_ERROR`` if anything fails; the
                session is rolled back and the original error is chained.
        """
        try:
            conversation = self.conversation_repo.create_conversation(
                app_id=app_id,
                workspace_id=workspace_id,
                user_id=user_id,
                title=title or "New Conversation",
                is_draft=is_draft,
                config_snapshot=config_snapshot
            )
            self.db.commit()
            self.db.refresh(conversation)

            logger.info(
                "Create Conversation Success",
                extra={
                    "conversation_id": str(conversation.id),
                    "app_id": str(app_id),
                    "workspace_id": str(workspace_id),
                    "is_draft": is_draft
                }
            )
        except Exception as e:
            logger.error(
                f"Create Conversation Failed - {str(e)}"
            )
            self.db.rollback()
            # Chain the cause so the root error stays visible in tracebacks.
            raise BusinessException("Error creating conversation", code=BizCode.DB_ERROR) from e

        return conversation

    def get_conversation(
            self,
            conversation_id: uuid.UUID,
            workspace_id: Optional[uuid.UUID] = None
    ) -> Conversation:
        """
        Retrieve a conversation by its ID.

        Args:
            conversation_id (uuid.UUID): The conversation UUID.
            workspace_id (Optional[uuid.UUID]): Optional workspace UUID passed
                through to the repository lookup.

        Returns:
            Conversation: Whatever the repository lookup returns.

        Raises:
            ResourceNotFoundException: Presumably raised by the repository when
                nothing matches (``create_or_get_conversation`` catches it) —
                confirm against the repository implementation.
        """
        conversation = self.conversation_repo.get_conversation_by_conversation_id(
            conversation_id=conversation_id,
            workspace_id=workspace_id
        )

        return conversation

    def get_user_conversations(
            self,
            user_id: uuid.UUID,
            page: int = 1,
            page_size: int = 20
    ) -> tuple[list[Conversation], int]:
        """
        Retrieve a page of conversations for a specific user.

        Args:
            user_id (uuid.UUID): Unique identifier of the user.
            page (int): Page number. Defaults to 1.
            page_size (int): Number of items per page. Defaults to 20.

        Returns:
            tuple[list[Conversation], int]: The conversations of the requested
            page and the total count, as reported by the repository.
        """
        conversations, total = self.conversation_repo.get_conversation_by_user_id(
            user_id,
            page=page,
            page_size=page_size
        )
        return conversations, total

    def list_conversations(
            self,
            app_id: uuid.UUID,
            workspace_id: uuid.UUID,
            user_id: Optional[str] = None,
            is_draft: Optional[bool] = None,
            page: int = 1,
            pagesize: int = 20
    ) -> Tuple[List[Conversation], int]:
        """
        List conversations with optional filters and pagination.

        Args:
            app_id (uuid.UUID): Application ID filter.
            workspace_id (uuid.UUID): Workspace ID filter.
            user_id (Optional[str]): Optional user ID filter.
            is_draft (Optional[bool]): Optional draft status filter.
            page (int): Page number.
            pagesize (int): Number of items per page.

        Returns:
            Tuple[List[Conversation], int]: A list of Conversation instances
            and the total count, as reported by the repository.
        """
        conversations, total = self.conversation_repo.list_conversations(
            app_id=app_id,
            workspace_id=workspace_id,
            user_id=user_id,
            is_draft=is_draft,
            page=page,
            pagesize=pagesize
        )

        return conversations, total

    def add_message(
            self,
            conversation_id: uuid.UUID,
            role: str,
            content: str,
            meta_data: Optional[dict] = None,
            message_id: Optional[uuid.UUID] = None,
    ) -> Message:
        """
        Add a message to a conversation and commit it.

        Also increments the conversation's ``message_count`` and, while the
        count is at most 2, derives the conversation title from the first 50
        characters of a user message.

        Args:
            conversation_id (uuid.UUID): Conversation UUID.
            role (str): Role of the message sender ('user' or 'assistant').
            content (str): Message content.
            meta_data (Optional[dict]): Optional metadata.
            message_id (Optional[uuid.UUID]): Optional custom message UUID;
                a random UUID is generated when omitted.

        Returns:
            Message: Newly created, refreshed Message instance.

        Raises:
            BusinessException: With ``BizCode.DB_ERROR`` if anything fails
                (including the conversation lookup); the session is rolled
                back and the original error is chained.
        """
        # Computed once, defensively, so the error handler below can never
        # raise on its own and skip the rollback.
        content_length = len(content) if content is not None else 0

        try:
            conversation = self.conversation_repo.get_conversation_by_conversation_id(
                conversation_id
            )

            message = Message(
                id=message_id if message_id else uuid.uuid4(),
                conversation_id=conversation_id,
                role=role,
                content=content,
                meta_data=meta_data,
            )

            self.message_repo.add_message(message)

            conversation.message_count += 1

            # Early user messages double as the conversation title.
            if conversation.message_count <= 2 and role == "user":
                conversation.title = (
                    content[:50] + ("..." if len(content) > 50 else "")
                )

            self.db.commit()
            self.db.refresh(message)

            logger.info(
                "Message added successfully",
                extra={
                    "conversation_id": str(conversation_id),
                    "message_id": str(message.id),
                    "role": role,
                    "content_length": content_length,
                },
            )

            return message
        except Exception as e:
            logger.error(
                f"Message added error, db roll back - {str(e)}",
                extra={
                    "conversation_id": str(conversation_id),
                    "role": role,
                    "content_length": content_length,
                },
            )
            self.db.rollback()
            raise BusinessException(
                f"Error adding message, conversation_id={conversation_id}",
                code=BizCode.DB_ERROR
            ) from e

    def get_messages(
            self,
            conversation_id: uuid.UUID,
            limit: Optional[int] = None
    ) -> List[Message]:
        """
        Retrieve messages for a conversation.

        Args:
            conversation_id (uuid.UUID): Conversation UUID.
            limit (Optional[int]): Optional maximum number of messages.

        Returns:
            List[Message]: Messages as returned by the repository.
        """
        messages = self.message_repo.get_message_by_conversation_id(
            conversation_id,
            limit
        )

        return messages

    async def get_conversation_history(
            self,
            conversation_id: uuid.UUID,
            max_history: Optional[int] = None,
            current_provider: Optional[str] = None,
            current_is_omni: Optional[bool] = None
    ) -> List[dict]:
        """
        Retrieve historical conversation messages formatted as dictionaries.

        Every stored message yields one entry whose ``content`` starts with a
        text part. For user messages, multimodal parts stored under
        ``meta_data["history_files"]`` are appended only when they were
        produced for the same provider / omni setting as the current one.

        Args:
            conversation_id (uuid.UUID): Conversation UUID.
            max_history (Optional[int]): Maximum number of messages to retrieve.
            current_provider (Optional[str]): Current provider for file handling.
            current_is_omni (Optional[bool]): Current omni flag for file handling.

        Returns:
            List[dict]: Message dictionaries with keys 'role' and 'content',
            where 'content' is a list of content parts.
        """
        messages = self.message_repo.get_message_by_conversation_id(
            conversation_id,
            limit=max_history
        )

        history = []
        for msg in messages:
            msg_dict = {
                "role": msg.role,
                "content": [{"type": "text", "text": msg.content}]
            }

            # Handle multimodal files attached to user messages.
            if msg.role == "user" and msg.meta_data:
                history_files = msg.meta_data.get("history_files", {})

                if history_files and current_provider and current_is_omni is not None:
                    same_target = (
                        history_files.get("provider") == current_provider
                        and history_files.get("is_omni") == current_is_omni
                    )
                    # Stored file parts are only valid for the provider / omni
                    # setting they were built for. On a mismatch they would
                    # need reprocessing, so they are left out, but the text
                    # part is still kept so the user turn is not lost.
                    if same_target:
                        # "content" may be absent or None; treat as no files.
                        msg_dict["content"].extend(history_files.get("content") or [])

            history.append(msg_dict)

        return history

    def save_conversation_messages(
            self,
            conversation_id: uuid.UUID,
            user_message: str,
            assistant_message: str,
            meta_data: Optional[dict] = None
    ) -> uuid.UUID:
        """
        Save a pair of user and assistant messages to the conversation.

        Each message is added (and committed) through ``add_message``, user
        message first. ``meta_data`` is attached to the assistant message only.

        Args:
            conversation_id (uuid.UUID): Conversation UUID.
            user_message (str): User's message content.
            assistant_message (str): Assistant's response content.
            meta_data (Optional[dict]): Optional metadata for the assistant message.

        Returns:
            uuid.UUID: ID of the stored assistant message.
        """
        self.add_message(
            conversation_id=conversation_id,
            role="user",
            content=user_message
        )

        ai_message = self.add_message(
            conversation_id=conversation_id,
            role="assistant",
            content=assistant_message,
            meta_data=meta_data
        )

        logger.debug(
            "Saved conversation messages successfully",
            extra={
                "conversation_id": str(conversation_id),
                "user_message_length": len(user_message),
                "assistant_message_length": len(assistant_message)
            }
        )
        return ai_message.id

    def delete_conversation(
            self,
            conversation_id: uuid.UUID,
            workspace_id: uuid.UUID
    ):
        """
        Soft delete a conversation.

        Args:
            conversation_id (uuid.UUID): Conversation UUID.
            workspace_id (uuid.UUID): Workspace UUID passed to the repository.

        Raises:
            BusinessException: With ``BizCode.DB_ERROR`` if anything fails; the
                session is rolled back and the original error is chained.
        """
        try:
            self.conversation_repo.soft_delete_conversation_by_conversation_id(
                conversation_id,
                workspace_id
            )
            self.db.commit()

            logger.info(
                "Soft deleted conversation successfully",
                extra={
                    "conversation_id": str(conversation_id),
                    "workspace_id": str(workspace_id)
                }
            )
        except Exception as e:
            self.db.rollback()
            logger.error(
                f"Error deleting conversation, conversation_id={conversation_id} - {str(e)}",
            )
            raise BusinessException("Error deleting conversation", code=BizCode.DB_ERROR) from e

    def create_or_get_conversation(
            self,
            app_id: uuid.UUID,
            workspace_id: uuid.UUID,
            is_draft: bool = False,
            conversation_id: Optional[uuid.UUID] = None,
            user_id: Optional[str] = None,
    ) -> Conversation:
        """
        Retrieve an existing conversation by ID or create a new one.

        Args:
            app_id (uuid.UUID): Application ID.
            workspace_id (uuid.UUID): Workspace ID.
            is_draft (bool): Whether a newly created conversation is a draft.
            conversation_id (Optional[uuid.UUID]): Optional conversation ID to retrieve.
            user_id (Optional[str]): Optional user ID.

        Returns:
            Conversation: Existing or newly created conversation.

        Raises:
            BusinessException: With ``BizCode.INVALID_CONVERSATION`` if the
                existing conversation belongs to a different app, or with
                ``BizCode.DB_ERROR`` if creating the new conversation fails.
        """
        if conversation_id:
            try:
                conversation = self.get_conversation(
                    conversation_id=conversation_id,
                    workspace_id=workspace_id
                )

                # Verify that the conversation belongs to this app.
                if conversation.app_id != app_id:
                    raise BusinessException(
                        "Conversation does not belong to this app",
                        BizCode.INVALID_CONVERSATION
                    )
                return conversation
            except ResourceNotFoundException:
                # Unknown ID: fall through and create a fresh conversation.
                logger.warning(
                    "Conversation not found. A new conversation will be created.",
                    extra={"conversation_id": str(conversation_id)}
                )

        # Create a new conversation (using the published version's configuration).
        conversation = self.create_conversation(
            app_id=app_id,
            workspace_id=workspace_id,
            user_id=user_id,
            is_draft=is_draft
        )

        logger.info(
            "Created a new conversation for shared link usage",
            extra={
                # Log the ID of the conversation that was actually created,
                # not the (possibly missing) ID that was requested.
                "conversation_id": str(conversation.id),
            }
        )

        return conversation

    def _build_summary_llm(self, user: User, workspace_id: uuid.UUID) -> tuple:
        """Resolve the workspace's LLM config and return ``(llm, provider, is_omni)``."""
        configs = workspace_service.get_workspace_models_configs(
            db=self.db,
            workspace_id=workspace_id,
            user=user
        )
        model_id = configs.get('llm')
        if not model_id:
            logger.error(f"Workspace model configuration not found for workspace_id={workspace_id}")
            raise BusinessException("Workspace model configuration not found. Please configure a model first.", code=BizCode.MODEL_NOT_FOUND)
        config = ModelConfigService.get_model_by_id(db=self.db, model_id=model_id)

        if not config:
            logger.error(f"Configured model not found for model_id={model_id}")
            raise BusinessException("Configured model does not exist.", BizCode.NOT_FOUND)

        if not config.api_keys:
            logger.error(f"Model API keys missing for model_id={model_id}", )
            raise BusinessException("Model configuration missing API keys.", BizCode.INVALID_PARAMETER)

        # Only the first configured API key entry is used.
        api_config = config.api_keys[0]
        provider = api_config.provider
        is_omni = api_config.is_omni

        llm = RedBearLLM(
            RedBearModelConfig(
                model_name=api_config.model_name,
                provider=provider,
                api_key=api_config.api_key,
                base_url=api_config.api_base,
                is_omni=is_omni,
                support_thinking="thinking" in (api_config.capability or []),
            ),
            type=ModelType(config.type)
        )
        return llm, provider, is_omni

    @staticmethod
    def _render_summary_prompts(language: str, conversation_messages: List[dict]) -> tuple:
        """Render the summary system and user prompts from the ``prompt`` directory next to this file."""
        prompt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'prompt')
        with open(os.path.join(prompt_path, 'conversation_summary_system.jinja2'), 'r', encoding='utf-8') as f:
            system_prompt = f.read()
        rendered_system_message = Template(system_prompt).render()

        with open(os.path.join(prompt_path, 'conversation_summary_user.jinja2'), 'r', encoding='utf-8') as f:
            user_prompt = f.read()
        rendered_user_message = Template(user_prompt).render(
            language=language,
            conversation=str(conversation_messages)
        )
        return rendered_system_message, rendered_user_message

    @staticmethod
    def _parse_summary_response(model_resp, conversation_id: uuid.UUID) -> dict:
        """Turn the LLM response into a dict, raising ``BusinessException`` if that is not possible."""
        try:
            content = model_resp.content
            if isinstance(content, str):
                result = json_repair.repair_json(content, return_objects=True)
            elif isinstance(content, list):
                result = json_repair.repair_json(content[0].get("text"), return_objects=True)
            elif isinstance(content, dict):
                result = content
            else:
                raise BusinessException("Unexpect model output", code=BizCode.LLM_ERROR)

            # json_repair may hand back a non-dict (e.g. an empty string or a
            # list) for unusable output; reject it here instead of failing
            # later on ``result.get``.
            if not isinstance(result, dict):
                raise BusinessException("Unexpect model output", code=BizCode.LLM_ERROR)
        except Exception as e:
            logger.exception(f"Failed to parse LLM response for conversation_id={conversation_id}")
            raise BusinessException("Failed to parse LLM response", code=BizCode.LLM_ERROR) from e
        return result

    async def get_conversation_detail(
            self,
            user: User,
            conversation_id: uuid.UUID,
            workspace_id: uuid.UUID,
            language: str = "zh"
    ) -> ConversationOut:
        """
        Retrieve or generate the summary and theme of a conversation.

        A stored detail is returned directly when it exists and the
        conversation has not been updated for more than one day. Otherwise a
        new summary is generated with the workspace's configured LLM from the
        most recent 20 messages, stored (created or updated), and returned.
        A conversation without messages yields an empty result and nothing is
        stored.

        Args:
            user (User): The user requesting the conversation summary.
            conversation_id (UUID): Unique identifier of the conversation.
            workspace_id (UUID): Workspace whose model configuration is used.
            language (str, optional): Language for the summary generation. Defaults to "zh".

        Returns:
            ConversationOut: An object containing the conversation's theme, questions,
            summary, takeaways, and information score.

        Raises:
            BusinessException: If the conversation is missing, the workspace model is
                not configured, the model does not exist, API keys are missing, or the
                LLM output cannot be parsed into an object.
        """
        logger.info(f"Fetching conversation detail for conversation_id={conversation_id}, workspace_id={workspace_id}")

        conversation_detail = self.conversation_repo.get_conversation_detail(
            conversation_id=conversation_id,
        )
        # NOTE(review): the lookup is not restricted to workspace_id — confirm
        # that access is validated by the caller.
        conversation = self.get_conversation(
            conversation_id=conversation_id,
        )
        if not conversation:
            raise BusinessException("Conversation not found", BizCode.INVALID_CONVERSATION)

        # "Stable" = untouched for over a day, so a stored summary is still valid.
        # NOTE(review): datetime.now() is naive; this assumes updated_at is naive
        # local time as well — confirm against the model definition.
        is_stable = (
            conversation.updated_at
            and datetime.now() - conversation.updated_at > timedelta(days=1)
        )
        if conversation_detail and is_stable:
            logger.info(f"Conversation detail found in repository for conversation_id={conversation_id}")
            return ConversationOut(
                theme=conversation_detail.theme,
                question=conversation_detail.question if conversation_detail.question else [],
                summary=conversation_detail.summary,
                takeaways=conversation_detail.takeaways,
                info_score=conversation_detail.info_score,
            )

        logger.info("Conversation detail not found, generating new summary using LLM")
        llm, provider, is_omni = self._build_summary_llm(user=user, workspace_id=workspace_id)

        conversation_messages = await self.get_conversation_history(
            conversation_id=conversation_id,
            max_history=20,
            current_provider=provider,
            current_is_omni=is_omni
        )
        if len(conversation_messages) == 0:
            # Nothing to summarize: return an empty result without calling the LLM.
            return ConversationOut(
                theme="",
                question=[],
                summary="",
                takeaways=[],
                info_score=0,
            )

        rendered_system_message, rendered_user_message = self._render_summary_prompts(
            language=language,
            conversation_messages=conversation_messages
        )
        messages = [
            (RoleType.SYSTEM, rendered_system_message),
            (RoleType.USER, rendered_user_message),
        ]
        logger.info(f"Invoking LLM for conversation_id={conversation_id}")
        model_resp = await llm.ainvoke(messages)

        result = self._parse_summary_response(model_resp, conversation_id)

        summary = result.get('summary', "")
        theme = result.get('theme', "")
        question = result.get("question") or []
        takeaways = result.get("takeaways") or []
        info_score = result.get("info_score", 50)

        # Every freshly generated summary is persisted. This includes a stable
        # conversation that had no stored detail yet; otherwise it would be
        # regenerated by the LLM on every request.
        if not conversation_detail:
            logger.info(f"Creating conversation detail in DB for conversation_id={conversation_id}")
            conversation_detail = ConversationDetail(
                conversation_id=conversation.id,
                summary=summary,
                theme=theme,
                question=question,
                takeaways=takeaways,
                info_score=info_score
            )
            self.conversation_repo.add_conversation_detail(conversation_detail)
        else:
            logger.info(f"Updating conversation detail in DB for conversation_id={conversation_id}")
            conversation_detail.summary = summary
            conversation_detail.theme = theme
            conversation_detail.question = question
            conversation_detail.takeaways = takeaways
            conversation_detail.info_score = info_score

        self.db.commit()
        self.db.refresh(conversation_detail)

        logger.info(f"Returning conversation summary for conversation_id={conversation_id}")
        conversation_out = ConversationOut(
            theme=theme,
            question=question,
            summary=summary,
            takeaways=takeaways,
            info_score=info_score
        )
        return conversation_out
|
||
|
||
|
||
# ==================== Dependency Injection ====================
|
||
|
||
def get_conversation_service(
        db: Annotated[Session, Depends(get_db)]
) -> ConversationService:
    """
    FastAPI dependency provider for ``ConversationService``.

    Args:
        db (Session): Database session resolved through the ``get_db`` dependency.

    Returns:
        ConversationService: A new service instance bound to ``db``.
    """
    service = ConversationService(db)
    return service
|