Merge pull request #677 from SuanmoSuanyangTechnology/develop

Develop
This commit is contained in:
yingzhao
2026-03-24 15:22:33 +08:00
committed by GitHub
23 changed files with 678 additions and 305 deletions

View File

@@ -77,6 +77,7 @@ celery_app.conf.update(
# Worker settings (per-worker settings are in docker-compose command line)
worker_prefetch_multiplier=1, # Don't hoard tasks, fairer distribution
worker_redirect_stdouts_level='INFO', # stdout/print → INFO instead of WARNING
# Result expiry
result_expires=3600, # Keep results for 1 hour

View File

@@ -8,6 +8,7 @@ from fastapi import APIRouter
from . import (
api_key_controller,
app_controller,
app_log_controller,
auth_controller,
chunk_controller,
document_controller,
@@ -70,6 +71,7 @@ manager_router.include_router(chunk_controller.router)
manager_router.include_router(test_controller.router)
manager_router.include_router(knowledgeshare_controller.router)
manager_router.include_router(app_controller.router)
manager_router.include_router(app_log_controller.router)
manager_router.include_router(upload_controller.router)
manager_router.include_router(memory_agent_controller.router)
manager_router.include_router(memory_dashboard_controller.router)

View File

@@ -57,6 +57,7 @@ def list_apps(
page: int = 1,
pagesize: int = 10,
ids: Optional[str] = None,
api_key: Optional[str] = None,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
@@ -65,10 +66,25 @@ def list_apps(
- By default, includes this workspace's apps and apps shared with it
- Set include_shared=false to return only this workspace's own apps
- When ids is provided, fetch the specified apps (comma-separated) without pagination
- When api_key is provided, look up the app bound to that API key
"""
from sqlalchemy import select as sa_select
from app.models.api_key_model import ApiKey
workspace_id = current_user.current_workspace_id
service = app_service.AppService(db)
# Search by API key: exact match; inject the resource_id into ids to reuse the unified pagination flow
if api_key:
matched_id = db.execute(
sa_select(ApiKey.resource_id).where(
ApiKey.workspace_id == workspace_id,
ApiKey.api_key == api_key,
ApiKey.resource_id.isnot(None),
)
).scalar_one_or_none()
ids = str(matched_id) if matched_id else ""
# When ids is provided (not None), fetch apps by ids
if ids is not None:
app_ids = [app_id.strip() for app_id in ids.split(',') if app_id.strip()]
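A minimal client-side sketch of the new api_key filter; the /manager/apps path, host, and bearer token are assumptions inferred from the router registration, not confirmed by the diff:

import httpx

# Hypothetical call: look up the app bound to a given API key.
resp = httpx.get(
    "http://localhost:8000/manager/apps",
    params={"api_key": "sk-xxxx", "page": 1, "pagesize": 10},
    headers={"Authorization": "Bearer <token>"},
)
resp.raise_for_status()
# On a match, items holds exactly the app whose resource_id the key points to;
# on a miss, ids is forced to "" and the unified pagination flow returns no items.
print(resp.json())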

View File

@@ -0,0 +1,129 @@
"""应用日志(消息记录)接口"""
import uuid
from typing import Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy import select, desc, func
from sqlalchemy.orm import Session
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.db import get_db
from app.dependencies import get_current_user, cur_workspace_access_guard
from app.models.conversation_model import Conversation, Message
from app.schemas.app_log_schema import AppLogConversation, AppLogConversationDetail, AppLogMessage
from app.schemas.response_schema import PageData, PageMeta
from app.services.app_service import AppService
router = APIRouter(prefix="/apps", tags=["App Logs"])
logger = get_business_logger()
@router.get("/{app_id}/logs", summary="应用日志 - 会话列表")
@cur_workspace_access_guard()
def list_app_logs(
app_id: uuid.UUID,
page: int = Query(1, ge=1),
pagesize: int = Query(20, ge=1, le=100),
user_id: Optional[str] = None,
is_draft: Optional[bool] = None,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""查看应用下所有会话记录(分页)
- 支持按 user_id 筛选
- 支持按 is_draft 筛选(草稿会话 / 发布会话)
- 按最新更新时间倒序排列
"""
workspace_id = current_user.current_workspace_id
# Verify access to the app
service = AppService(db)
service.get_app(app_id, workspace_id)
stmt = select(Conversation).where(
Conversation.app_id == app_id,
Conversation.workspace_id == workspace_id,
Conversation.is_active.is_(True),
)
if user_id:
stmt = stmt.where(Conversation.user_id == user_id)
if is_draft is not None:
stmt = stmt.where(Conversation.is_draft == is_draft)
total = int(db.execute(
select(func.count()).select_from(stmt.subquery())
).scalar_one())
stmt = stmt.order_by(desc(Conversation.updated_at))
stmt = stmt.offset((page - 1) * pagesize).limit(pagesize)
conversations = list(db.scalars(stmt).all())
items = [AppLogConversation.model_validate(c) for c in conversations]
meta = PageMeta(page=page, pagesize=pagesize, total=total, hasnext=(page * pagesize) < total)
logger.info(
"查询应用日志会话列表",
extra={"app_id": str(app_id), "total": total, "page": page}
)
return success(data=PageData(page=meta, items=items))
@router.get("/{app_id}/logs/{conversation_id}", summary="应用日志 - 会话消息详情")
@cur_workspace_access_guard()
def get_app_log_detail(
app_id: uuid.UUID,
conversation_id: uuid.UUID,
db: Session = Depends(get_db),
current_user=Depends(get_current_user),
):
"""查看某会话的完整消息记录
- 返回会话基本信息 + 所有消息(按时间正序)
- 消息 meta_data 包含模型名、token 用量等信息
"""
workspace_id = current_user.current_workspace_id
# Verify access to the app
service = AppService(db)
service.get_app(app_id, workspace_id)
# Fetch the conversation (must belong to this app and workspace)
conversation = db.scalars(
select(Conversation).where(
Conversation.id == conversation_id,
Conversation.app_id == app_id,
Conversation.workspace_id == workspace_id,
Conversation.is_active.is_(True),
)
).first()
if not conversation:
from app.core.exceptions import ResourceNotFoundException
raise ResourceNotFoundException("会话", str(conversation_id))
# Fetch messages (chronological order)
messages = list(db.scalars(
select(Message)
.where(Message.conversation_id == conversation_id)
.order_by(Message.created_at)
).all())
detail = AppLogConversationDetail.model_validate(conversation)
detail.messages = [AppLogMessage.model_validate(m) for m in messages]
logger.info(
"查询应用日志会话详情",
extra={
"app_id": str(app_id),
"conversation_id": str(conversation_id),
"message_count": len(messages)
}
)
return success(data=detail)
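A hedged walkthrough of the two new endpoints; the base URL, token, and IDs below are placeholders, not values from the diff:

import httpx

base = "http://localhost:8000/manager/apps"
headers = {"Authorization": "Bearer <token>"}
app_id = "<app uuid>"

# Page through an app's conversations, newest first
logs = httpx.get(f"{base}/{app_id}/logs", params={"page": 1, "pagesize": 20}, headers=headers).json()

# Drill into the first conversation's full message history (chronological order)
conv_id = logs["data"]["items"][0]["id"]
detail = httpx.get(f"{base}/{app_id}/logs/{conv_id}", headers=headers).json()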

View File

@@ -14,6 +14,9 @@ Routes:
import os
import uuid
from typing import Any
import httpx
import mimetypes
from urllib.parse import urlparse, unquote
from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
from fastapi.responses import FileResponse, RedirectResponse
@@ -290,6 +293,101 @@ async def upload_file_with_share_token(
)
@router.get("/files/info-by-url", response_model=ApiResponse)
async def get_file_info_by_url(
url: str,
):
"""
Get file information by network URL (no authentication required).
Fetches file metadata from a remote URL via HTTP HEAD request.
Falls back to GET request if HEAD is not supported.
Returns file type, name, and size.
Args:
url: The network URL of the file.
Returns:
ApiResponse with file information.
"""
api_logger.info(f"File info by URL request: url={url}")
try:
async with httpx.AsyncClient(timeout=10.0) as client:
# Try HEAD request first
response = await client.head(url, follow_redirects=True)
# If HEAD fails, try GET request (some servers don't support HEAD)
if response.status_code != 200:
api_logger.info(f"HEAD request failed with {response.status_code}, trying GET request")
response = await client.get(url, follow_redirects=True)
if response.status_code != 200:
api_logger.error(f"Failed to fetch file info: HTTP {response.status_code}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unable to access file: HTTP {response.status_code}"
)
# Get file size from Content-Length header or actual content
file_size = response.headers.get("Content-Length")
if file_size:
file_size = int(file_size)
elif hasattr(response, 'content'):
file_size = len(response.content)
else:
file_size = None
# Get content type from Content-Type header
content_type = response.headers.get("Content-Type", "application/octet-stream")
# Remove charset and other parameters from content type
content_type = content_type.split(';')[0].strip()
# Extract filename from Content-Disposition or URL
file_name = None
content_disposition = response.headers.get("Content-Disposition")
if content_disposition and "filename=" in content_disposition:
parts = content_disposition.split("filename=")
if len(parts) > 1:
file_name = parts[1].strip('"').strip("'")
if not file_name:
parsed_url = urlparse(url)
file_name = unquote(os.path.basename(parsed_url.path)) or "unknown"
# Extract file extension from filename
_, file_ext = os.path.splitext(file_name)
# If no extension found, infer from content type
if not file_ext:
ext = mimetypes.guess_extension(content_type)
if ext:
file_ext = ext
file_name = f"{file_name}{file_ext}"
api_logger.info(f"File info retrieved: name={file_name}, size={file_size}, type={content_type}")
return success(
data={
"url": url,
"file_name": file_name,
"file_ext": file_ext.lower() if file_ext else "",
"file_size": file_size,
"content_type": content_type,
},
msg="File information retrieved successfully"
)
except HTTPException:
raise
except Exception as e:
api_logger.error(f"Unexpected error: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve file information: {str(e)}"
)
@router.get("/files/{file_id}", response_model=Any)
async def download_file(
request: Request,
@@ -697,3 +795,44 @@ async def permanent_download_file(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve file: {str(e)}"
)
@router.get("/files/{file_id}/status", response_model=ApiResponse)
async def get_file_status(
file_id: uuid.UUID,
db: Session = Depends(get_db),
):
"""
Get file upload/processing status (no authentication required).
This endpoint is used to check if a file (e.g., TTS audio) is ready.
Returns status: pending, completed, or failed.
Args:
file_id: The UUID of the file.
db: Database session.
Returns:
ApiResponse with file status and metadata.
"""
api_logger.info(f"File status request: file_id={file_id}")
# Query file metadata from database
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
if not file_metadata:
api_logger.warning(f"File not found in database: file_id={file_id}")
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="The file does not exist"
)
return success(
data={
"file_id": str(file_id),
"status": file_metadata.status,
"file_name": file_metadata.file_name,
"file_size": file_metadata.file_size,
"content_type": file_metadata.content_type,
},
msg="File status retrieved successfully"
)
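A minimal polling sketch against the new status endpoint; the /storage prefix matches the TTS docstring later in this commit, while the host, interval, and timeout are illustrative choices:

import time
import httpx

def wait_for_file(file_id: str, timeout: float = 30.0, interval: float = 1.0) -> str:
    """Poll GET /storage/files/{file_id}/status until the TTS audio is ready or we give up."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        resp = httpx.get(f"http://localhost:8000/storage/files/{file_id}/status")
        data = resp.json()["data"]
        if data["status"] in ("completed", "failed"):
            return data["status"]
        time.sleep(interval)
    return "pending"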

View File

@@ -529,8 +529,9 @@ def log_time(step_name: str, duration: float, log_file: str = "logs/time.log") -
# Fallback to console only if file write fails
print(f"Warning: Could not write to timing log: {e}")
# Always print to console (backward compatible behavior)
print(f"{step_name}: {duration:.2f}s")
# Always log at INFO level (avoids Celery treating stdout as WARNING)
_timing_logger = logging.getLogger(__name__)
_timing_logger.info(f"{step_name}: {duration:.2f}s")
def get_agent_logger(name: str = "agent_service",

View File

@@ -19,7 +19,7 @@ from app.core.memory.utils.log.logging_utils import log_time
from app.db import get_db_context
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, schedule_clustering_after_write
from app.repositories.neo4j.graph_saver import save_dialog_and_statements_to_neo4j, _trigger_clustering_sync
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
from app.schemas.memory_config_schema import MemoryConfig
@@ -169,8 +169,8 @@ async def write(
)
if success:
logger.info("Successfully saved all data to Neo4j")
# After a successful write, trigger clustering asynchronously (without blocking the write response)
schedule_clustering_after_write(
# After a successful write, wait for clustering synchronously (avoids concurrent conflicts with Memory Summary)
await _trigger_clustering_sync(
all_entity_nodes,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
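The behavioral difference in miniature (illustrative stand-ins, not the project's actual functions):

import asyncio

async def clustering():
    await asyncio.sleep(1)  # stand-in for the LLM-heavy community clustering

async def write_old():
    asyncio.create_task(clustering())  # fire-and-forget: returns at once, clustering races with later LLM work

async def write_new():
    await clustering()  # synchronous: write() only returns once clustering has finished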

View File

@@ -71,13 +71,11 @@ class LabelPropagationEngine:
connector: Neo4jConnector,
llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
):
self.connector = connector
self.repo = CommunityRepository(connector)
self.llm_model_id = llm_model_id
self.embedding_model_id = embedding_model_id
self.embedding_model_id = embedding_model_id
# ──────────────────────────────────────────────────────────────────────────
# Public interface
@@ -239,6 +237,7 @@ class LabelPropagationEngine:
await self.repo.upsert_community(new_cid, end_user_id, member_count=1)
await self.repo.assign_entity_to_community(entity_id, new_cid, end_user_id)
logger.debug(f"[Clustering] 孤立实体 {entity_id} → 新社区 {new_cid}")
await self._generate_community_metadata([new_cid], end_user_id)
return
# Tally the community distribution of neighbors
@@ -273,7 +272,8 @@ class LabelPropagationEngine:
await self._evaluate_merge(
list(community_ids_in_neighbors), end_user_id
)
await self._generate_community_metadata([target_cid], end_user_id)
# Membership changed after the new entity joined; force metadata regeneration
await self._generate_community_metadata([target_cid], end_user_id, force=True)
async def _evaluate_merge(
self, community_ids: List[str], end_user_id: str
@@ -453,7 +453,7 @@ class LabelPropagationEngine:
return lines
async def _generate_community_metadata(
self, community_ids: List[str], end_user_id: str
self, community_ids: List[str], end_user_id: str, force: bool = False
) -> None:
"""
Generate and write metadata for one or more communities.
@@ -462,69 +462,82 @@ class LabelPropagationEngine:
1. Call the LLM per community to generate name / summary (serially)
2. Collect all summaries and embed them in a single batch
3. Use update_community_metadata for one community, batch_update_community_metadata for several
"""
if not community_ids:
return
Args:
force: When True, skip the completeness check and regenerate unconditionally (used after membership changes from incremental updates)
"""
from app.db import get_db_context
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
# --- Stage 1: call the LLM concurrently to generate each community's name / summary ---
async def _build_one(cid: str):
members = await self.repo.get_community_members(cid, end_user_id)
if not members:
async def _build_one(cid: str) -> Optional[Dict]:
try:
if not force:
check_embedding = bool(self.embedding_model_id)
if await self.repo.is_community_complete(cid, end_user_id, check_embedding=check_embedding):
return None
members = await self.repo.get_community_members(cid, end_user_id)
if not members:
logger.warning(f"[Clustering] 社区 {cid} 无成员,跳过元数据生成")
return None
sorted_members = sorted(
members,
key=lambda m: m.get("activation_value") or 0,
reverse=True,
)
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
all_names = [m["name"] for m in members if m.get("name")]
name = "".join(core_entities[:3]) if core_entities else cid[:8]
summary = f"包含实体:{', '.join(all_names)}"
if self.llm_model_id:
try:
entity_list_str = "\n".join(self._build_entity_lines(members))
relationships = await self.repo.get_community_relationships(cid, end_user_id)
rel_lines = [
f"- {r['subject']}{r['predicate']}{r['object']}"
for r in relationships
if r.get("subject") and r.get("predicate") and r.get("object")
]
rel_section = (
f"\n实体间关系:\n" + "\n".join(rel_lines)
if rel_lines else ""
)
prompt = (
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
f"请为这组实体所代表的主题:\n"
f"1. 起一个简洁的中文名称不超过10个字\n"
f"2. 写一句话摘要不超过80个字\n\n"
f"严格按以下格式输出,不要有其他内容:\n"
f"名称:<名称>\n摘要:<摘要>"
)
with get_db_context() as db:
llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
response = await llm_client.chat([{"role": "user", "content": prompt}])
text = response.content if hasattr(response, "content") else str(response)
for line in text.strip().splitlines():
if line.startswith("名称:"):
name = line[3:].strip()
elif line.startswith("摘要:"):
summary = line[3:].strip()
except Exception as e:
logger.warning(f"[Clustering] 社区 {cid} LLM 生成失败,使用兜底值: {e}")
return {
"community_id": cid,
"end_user_id": end_user_id,
"name": name,
"summary": summary,
"core_entities": core_entities,
"summary_embedding": None,
}
except Exception as e:
logger.error(f"[Clustering] 社区 {cid} 元数据准备失败: {e}", exc_info=True)
return None
sorted_members = sorted(
members,
key=lambda m: m.get("activation_value") or 0,
reverse=True,
)
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
entity_list_str = "\n".join(self._build_entity_lines(members))
# Option 4: inject the community's entity-relationship triples into the prompt
relationships = await self.repo.get_community_relationships(cid, end_user_id)
rel_lines = [
f"- {r['subject']}{r['predicate']}{r['object']}"
for r in relationships
if r.get("subject") and r.get("predicate") and r.get("object")
]
rel_section = (
f"\n实体间关系:\n" + "\n".join(rel_lines)
if rel_lines else ""
)
prompt = (
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
f"请为这组实体所代表的主题:\n"
f"1. 起一个简洁的中文名称不超过10个字\n"
f"2. 写一句话摘要不超过80个字\n\n"
f"严格按以下格式输出,不要有其他内容:\n"
f"名称:<名称>\n摘要:<摘要>"
)
with get_db_context() as db:
llm_client = MemoryClientFactory(db).get_llm_client(self.llm_model_id)
response = await llm_client.chat([{"role": "user", "content": prompt}])
text = response.content if hasattr(response, "content") else str(response)
name, summary = "", ""
for line in text.strip().splitlines():
if line.startswith("名称:"):
name = line[3:].strip()
elif line.startswith("摘要:"):
summary = line[3:].strip()
return {
"community_id": cid,
"end_user_id": end_user_id,
"name": name,
"summary": summary,
"core_entities": core_entities,
"summary_embedding": None,
}
results = await asyncio.gather(
*[_build_one(cid) for cid in community_ids],
return_exceptions=True,
@@ -537,15 +550,20 @@ class LabelPropagationEngine:
metadata_list.append(res)
if not metadata_list:
logger.warning(f"[Clustering] 无有效元数据可写入community_ids={community_ids}")
return
# --- Stage 2: batch-generate summary_embedding ---
summaries = [m["summary"] for m in metadata_list]
with get_db_context() as db:
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
embeddings = await embedder.response(summaries)
for i, meta in enumerate(metadata_list):
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
if self.embedding_model_id:
try:
summaries = [m["summary"] for m in metadata_list]
with get_db_context() as db:
embedder = MemoryClientFactory(db).get_embedder_client(self.embedding_model_id)
embeddings = await embedder.response(summaries)
for i, meta in enumerate(metadata_list):
meta["summary_embedding"] = embeddings[i] if i < len(embeddings) else None
except Exception as e:
logger.error(f"[Clustering] 批量生成 summary_embedding 失败: {e}", exc_info=True)
# --- Stage 3: write (single or batch) ---
if len(metadata_list) == 1:
@@ -558,17 +576,13 @@ class LabelPropagationEngine:
core_entities=m["core_entities"],
summary_embedding=m["summary_embedding"],
)
if result:
logger.info(f"[Clustering] 社区 {m['community_id']} 元数据写入成功: name={m['name']}, summary={m['summary'][:30]}...")
else:
logger.warning(f"[Clustering] 社区 {m['community_id']} 元数据写入返回 False")
if not result:
logger.error(f"[Clustering] 社区 {m['community_id']} 元数据写入失败")
else:
ok = await self.repo.batch_update_community_metadata(metadata_list)
if ok:
logger.info(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据成功")
else:
logger.warning(f"[Clustering] 批量写入社区元数据失败")
if not ok:
logger.error(f"[Clustering] 批量写入 {len(metadata_list)} 个社区元数据失败")
@staticmethod
def _new_community_id() -> str:
return str(uuid.uuid4())
return str(uuid.uuid4())
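The new completeness gate in isolation, as a condensed sketch inferred from the diff (the helper name _should_skip is hypothetical):

async def _should_skip(repo, cid: str, end_user_id: str, embedding_model_id, force: bool) -> bool:
    """Return True when the existing metadata is complete and regeneration can be skipped."""
    if force:
        return False  # membership just changed: always regenerate
    # Only demand summary_embedding when an embedder is actually configured
    check_embedding = bool(embedding_model_id)
    return await repo.is_community_complete(cid, end_user_id, check_embedding=check_embedding)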

View File

@@ -9,6 +9,7 @@
"""
import asyncio
import logging
import os
import hashlib
import json
@@ -26,6 +27,8 @@ from app.core.memory.storage_services.extraction_engine.data_preprocessing.scene
ScenePatterns
)
logger = logging.getLogger(__name__)
class DialogExtractionResponse(BaseModel):
"""对话级一次性抽取的结构化返回,用于加速剪枝。
@@ -706,7 +709,7 @@ class SemanticPruner:
# Threshold guard: 0.9 max
proportion = float(self.config.pruning_threshold)
if proportion > 0.9:
print(f"[剪枝-数据集] 阈值{proportion}超过上限0.9已自动调整为0.9")
logger.warning(f"[剪枝-数据集] 阈值{proportion}超过上限0.9已自动调整为0.9")
proportion = 0.9
if proportion < 0.0:
proportion = 0.0
@@ -905,7 +908,7 @@ class SemanticPruner:
# Safety: avoid empty dataset
if not result:
print("警告: 语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
logger.warning("语义剪枝后数据集为空,已回退为未剪枝数据以避免流程中断")
return dialogs
return result
@@ -915,8 +918,7 @@ class SemanticPruner:
try:
self.run_logs.append(msg)
except Exception:
# An exception here must never break logging
pass
print(msg)
logger.debug(msg)

View File

@@ -5,8 +5,11 @@
"""
import asyncio
import logging
from typing import Any, Dict, List, Tuple
logger = logging.getLogger(__name__)
from app.core.memory.llm_tools.openai_embedder import OpenAIEmbedderClient
from app.core.memory.models.message_models import DialogData
from app.core.models.base import RedBearModelConfig
@@ -48,9 +51,9 @@ class EmbeddingGenerator:
return await self.embedder_client.response(texts)
# Process in parallel batches
print(f"Text count {len(texts)} exceeds batch size {batch_size}; processing in parallel batches")
logger.info(f"Text count {len(texts)} exceeds batch size {batch_size}; processing in parallel batches")
batches = [texts[i:i+batch_size] for i in range(0, len(texts), batch_size)]
print(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
logger.info(f"分成 {len(batches)} 批,每批最多 {batch_size} 个文本")
# 并行发送所有批次
batch_results = await asyncio.gather(*[
@@ -62,7 +65,7 @@ class EmbeddingGenerator:
for batch_result in batch_results:
embeddings.extend(batch_result)
print(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
logger.info(f"分批并行处理完成,共生成 {len(embeddings)} 个嵌入向量")
return embeddings
async def generate_statement_embeddings(
@@ -77,7 +80,7 @@ class EmbeddingGenerator:
Returns:
List of statement-embedding maps, one per dialog
"""
print("\n=== 生成陈述句嵌入向量 ===")
logger.debug("=== 生成陈述句嵌入向量 ===")
# 收集所有陈述句
all_statements = []
@@ -102,7 +105,7 @@ class EmbeddingGenerator:
stmt_id = chunked_dialogs[d_idx].chunks[c_idx].statements[s_idx].id
stmt_embedding_maps[d_idx][stmt_id] = embedding
print(f"{len(all_statements)} 个陈述句生成了嵌入向量")
logger.info(f"{len(all_statements)} 个陈述句生成了嵌入向量")
return stmt_embedding_maps
async def generate_chunk_embeddings(
@@ -117,7 +120,7 @@ class EmbeddingGenerator:
Returns:
List of chunk-embedding maps, one per dialog
"""
print("\n=== 生成分块嵌入向量 ===")
logger.debug("=== 生成分块嵌入向量 ===")
# 收集所有分块
all_chunks = []
@@ -138,7 +141,7 @@ class EmbeddingGenerator:
chunk_id = chunked_dialogs[d_idx].chunks[c_idx].id
chunk_embedding_maps[d_idx][chunk_id] = embedding
print(f"{len(all_chunks)} 个分块生成了嵌入向量")
logger.info(f"{len(all_chunks)} 个分块生成了嵌入向量")
return chunk_embedding_maps
async def generate_dialog_embeddings(
@@ -172,7 +175,7 @@ class EmbeddingGenerator:
Returns:
(statement-embedding maps, chunk-embedding maps, dialog embeddings)
"""
print("\n=== 生成所有嵌入向量 ===")
logger.debug("=== 生成所有嵌入向量 ===")
# 并发生成陈述句和分块嵌入向量
stmt_embedding_maps, chunk_embedding_maps = await asyncio.gather(
@@ -183,9 +186,7 @@ class EmbeddingGenerator:
# Dialog embeddings (currently skipped)
dialog_embeddings = await self.generate_dialog_embeddings(chunked_dialogs)
print(
f"Done: embeddings for {len(chunked_dialogs)} dialogs"
)
logger.info(f"Done: embeddings for {len(chunked_dialogs)} dialogs")
return stmt_embedding_maps, chunk_embedding_maps, dialog_embeddings
@@ -201,7 +202,7 @@ class EmbeddingGenerator:
Returns:
Updated triplet maps (entities now carry embeddings)
"""
print("\n=== 生成实体嵌入向量 ===")
logger.debug("=== 生成实体嵌入向量 ===")
entity_texts: List[str] = []
entity_refs: List[Any] = []
@@ -219,7 +220,7 @@ class EmbeddingGenerator:
entity_refs.append(ent)
if not entity_texts:
print("没有找到需要生成嵌入向量的实体")
logger.debug("没有找到需要生成嵌入向量的实体")
return triplet_maps
# Generate embeddings in batch
@@ -227,13 +228,13 @@ class EmbeddingGenerator:
# Log the dimensionality of the first few embeddings
for i in range(min(5, len(embeddings))):
print(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
logger.debug(f"实体 '{entity_texts[i]}' 嵌入向量维度: {len(embeddings[i])}")
# Attach the embeddings to their entities
for ent, emb in zip(entity_refs, embeddings):
setattr(ent, "name_embedding", emb)
print(f"{len(entity_refs)} 个实体生成了嵌入向量")
logger.info(f"{len(entity_refs)} 个实体生成了嵌入向量")
return triplet_maps
@@ -296,7 +297,7 @@ async def embedding_generation_all(
Returns:
(statement-embedding maps, chunk-embedding maps, dialog embeddings, updated triplet maps)
"""
print("\n=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
logger.debug("=== 综合嵌入向量生成(陈述句/分块/对话 + 实体)===")
generator = EmbeddingGenerator(embedding_id)

View File

@@ -9,7 +9,7 @@ class User(Base):
__tablename__ = "users"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4, index=True)
username = Column(String, unique=True, index=True, nullable=False)
username = Column(String, index=True, nullable=False) # Community edition: usernames are not unique; only email is unique
email = Column(String, unique=True, index=True, nullable=False)
hashed_password = Column(String, nullable=False)
is_active = Column(Boolean, default=True, nullable=False)

View File

@@ -1,10 +1,13 @@
from typing import List, Optional
import logging
from app.repositories.neo4j.cypher_queries import DIALOGUE_NODE_SAVE, STATEMENT_NODE_SAVE, CHUNK_NODE_SAVE,MEMORY_SUMMARY_NODE_SAVE
from app.core.memory.models.graph_models import DialogueNode, StatementNode, ChunkNode, MemorySummaryNode
# Use the new repository layer
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
logger = logging.getLogger(__name__)
async def delete_all_nodes(end_user_id: str, connector: Neo4jConnector):
"""Delete all nodes in the database."""
@@ -217,10 +220,10 @@ async def add_memory_summary_nodes(summaries: List[MemorySummaryNode], connector
summaries=flattened
)
created_ids = [record.get("uuid") for record in result]
print(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
logger.info(f"Successfully saved {len(created_ids)} MemorySummary nodes to Neo4j")
return created_ids
except Exception as e:
print(f"Failed to save MemorySummary nodes to Neo4j: {e}")
logger.error(f"Failed to save MemorySummary nodes to Neo4j: {e}")
return None

View File

@@ -300,7 +300,7 @@ class CommunityRepository:
)
return bool(result)
except Exception as e:
logger.error(f"update_community_metadata failed: {e}")
logger.error(f"update_community_metadata failed: {e}", exc_info=True)
return False
async def batch_update_community_metadata(

View File

@@ -1069,6 +1069,7 @@ Graph_Node_query = """
COMMUNITY_NODE_UPSERT = """
MERGE (c:Community {community_id: $community_id})
ON CREATE SET c.id = $community_id
SET c.end_user_id = $end_user_id,
c.member_count = $member_count,
c.updated_at = datetime()
@@ -1175,7 +1176,8 @@ RETURN c.community_id AS community_id, cnt AS member_count
UPDATE_COMMUNITY_METADATA = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
SET c.name = $name,
SET c.id = coalesce(c.id, $community_id),
c.name = $name,
c.summary = $summary,
c.core_entities = $core_entities,
c.summary_embedding = $summary_embedding,
@@ -1186,7 +1188,8 @@ RETURN c.community_id AS community_id
BATCH_UPDATE_COMMUNITY_METADATA = """
UNWIND $communities AS row
MATCH (c:Community {community_id: row.community_id, end_user_id: row.end_user_id})
SET c.name = row.name,
SET c.id = coalesce(c.id, row.community_id),
c.name = row.name,
c.summary = row.summary,
c.core_entities = row.core_entities,
c.summary_embedding = row.summary_embedding,
@@ -1270,6 +1273,40 @@ RETURN
startNode(r) = e AS r_from_e
"""
CHECK_COMMUNITY_IS_COMPLETE = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
RETURN (
c.name IS NOT NULL AND c.name <> '' AND
c.summary IS NOT NULL AND c.summary <> '' AND
c.core_entities IS NOT NULL
) AS is_complete
"""
CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
RETURN (
c.name IS NOT NULL AND c.name <> '' AND
c.summary IS NOT NULL AND c.summary <> '' AND
c.core_entities IS NOT NULL AND
c.summary_embedding IS NOT NULL
) AS is_complete
"""
GET_INCOMPLETE_COMMUNITIES = """
MATCH (c:Community {end_user_id: $end_user_id})
WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL
OR c.name = '' OR c.summary = ''
RETURN c.community_id AS community_id
"""
GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """
MATCH (c:Community {end_user_id: $end_user_id})
WHERE c.name IS NULL OR c.name = ''
OR c.summary IS NULL OR c.summary = ''
OR c.core_entities IS NULL
OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)')
RETURN c.community_id AS community_id
"""
# Community keyword search: matches name or summary via fulltext index
SEARCH_COMMUNITIES_BY_KEYWORD = """
@@ -1325,39 +1362,4 @@ RETURN s.statement AS statement,
c.name AS community_name
ORDER BY COALESCE(s.activation_value, 0) DESC
LIMIT $limit
"""
CHECK_COMMUNITY_IS_COMPLETE = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
RETURN (
c.name IS NOT NULL AND c.name <> '' AND
c.summary IS NOT NULL AND c.summary <> '' AND
c.core_entities IS NOT NULL
) AS is_complete
"""
CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
RETURN (
c.name IS NOT NULL AND c.name <> '' AND
c.summary IS NOT NULL AND c.summary <> '' AND
c.core_entities IS NOT NULL AND
c.summary_embedding IS NOT NULL
) AS is_complete
"""
GET_INCOMPLETE_COMMUNITIES = """
MATCH (c:Community {end_user_id: $end_user_id})
WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL
OR c.name = '' OR c.summary = ''
RETURN c.community_id AS community_id
"""
GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """
MATCH (c:Community {end_user_id: $end_user_id})
WHERE c.name IS NULL OR c.name = ''
OR c.summary IS NULL OR c.summary = ''
OR c.core_entities IS NULL
OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)')
RETURN c.community_id AS community_id
"""
"""

View File

@@ -162,7 +162,7 @@ async def save_dialog_and_statements_to_neo4j(
"""Save dialogue nodes, chunk nodes, statement nodes, entities, and all relationships to Neo4j using graph models.
Only writes data; clustering is not triggered here. After a successful write, the caller explicitly triggers it via
schedule_clustering_after_write().
_trigger_clustering_sync().
Args:
dialogue_nodes: List of DialogueNode objects to save
@@ -303,16 +303,13 @@ async def save_dialog_and_statements_to_neo4j(
return False
def schedule_clustering_after_write(
async def _trigger_clustering_sync(
entity_nodes: List,
llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
) -> None:
"""
After a successful Neo4j write, schedule the background clustering task.
Can be disabled via the CLUSTERING_ENABLED=false environment variable (for benchmark comparisons).
Triggered asynchronously via asyncio.create_task; does not block the write response.
Wait synchronously for clustering to finish, avoiding concurrent conflicts with other LLM tasks.
"""
if not entity_nodes:
return
@@ -324,8 +321,8 @@ def schedule_clustering_after_write(
end_user_id = entity_nodes[0].end_user_id
new_entity_ids = [e.id for e in entity_nodes]
logger.info(f"[Clustering] 准备触发聚类,实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id))
logger.info(f"[Clustering] 准备触发聚类(同步),实体数: {len(new_entity_ids)}, end_user_id: {end_user_id}")
await _trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)
async def _trigger_clustering(

View File

@@ -0,0 +1,53 @@
"""应用日志消息记录Schema"""
import uuid
import datetime
from typing import Optional, Dict, Any, List
from pydantic import BaseModel, Field, ConfigDict, field_serializer
class AppLogMessage(BaseModel):
"""单条消息记录"""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
conversation_id: uuid.UUID
role: str = Field(description="角色: user / assistant / system")
content: str
meta_data: Optional[Dict[str, Any]] = None
created_at: datetime.datetime
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("meta_data", when_used="json")
def _serialize_meta_data(self, data: Optional[Dict[str, Any]]):
return data or {}
class AppLogConversation(BaseModel):
"""会话摘要(用于列表)"""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
app_id: uuid.UUID
user_id: Optional[str] = None
title: Optional[str] = None
message_count: int = 0
is_draft: bool
created_at: datetime.datetime
updated_at: datetime.datetime
@field_serializer("created_at", when_used="json")
def _serialize_created_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
@field_serializer("updated_at", when_used="json")
def _serialize_updated_at(self, dt: datetime.datetime):
return int(dt.timestamp() * 1000) if dt else None
class AppLogConversationDetail(AppLogConversation):
"""会话详情(包含消息列表)"""
messages: List[AppLogMessage] = Field(default_factory=list)

View File

@@ -129,39 +129,12 @@ class AppChatService:
)
# Load conversation history
messages = self.conversation_service.get_messages(
history = await self.conversation_service.get_conversation_history(
conversation_id=conversation_id,
limit=10
max_history=10,
current_provider=api_key_obj.provider,
current_is_omni=api_key_obj.is_omni
)
history = []
for msg in messages:
content = [{"type": "text", "text": msg.content}]
# Handle files in meta_data
if msg.meta_data and msg.meta_data.get("files"):
files = msg.meta_data.get("files", [])
# Process files with MultimodalService
multimodal_service = MultimodalService(self.db, api_config=model_info)
# Convert files to FileInput format
file_inputs = []
for file in files:
from app.schemas.app_schema import FileInput, TransferMethod
file_input = FileInput(
type=file.get("type"),
transfer_method=TransferMethod.REMOTE_URL,
url=file.get("url")
)
file_inputs.append(file_input)
history_processed_files = await multimodal_service.history_process_files(files=file_inputs)
content.extend(history_processed_files)
history.append({
"role": msg.role,
"content": content
})
# Process multimodal files
processed_files = None
@@ -206,7 +179,8 @@ class AppChatService:
# Build the user message content (with multimodal files)
human_meta = {
"files": []
"files": [],
"history_files": {}
}
assistant_meta = {
"model": api_key_obj.model_name,
@@ -221,6 +195,13 @@ class AppChatService:
"url": f.url
})
if processed_files:
human_meta["history_files"] = {
"content": processed_files,
"provider": api_key_obj.provider,
"is_omni": api_key_obj.is_omni
}
# Save the messages
if audio_url:
assistant_meta["audio_url"] = audio_url
@@ -251,6 +232,7 @@ class AppChatService:
"suggested_questions": suggested_questions,
"citations": self.agent_service._filter_citations(features_config, result.get("citations", [])),
"audio_url": audio_url,
"audio_status": "pending"
}
async def agnet_chat_stream(
@@ -350,39 +332,12 @@ class AppChatService:
)
# Load conversation history
messages = self.conversation_service.get_messages(
history = await self.conversation_service.get_conversation_history(
conversation_id=conversation_id,
limit=10
max_history=10,
current_provider=api_key_obj.provider,
current_is_omni=api_key_obj.is_omni
)
history = []
for msg in messages:
content = [{"type": "text", "text": msg.content}]
# Handle files in meta_data
if msg.meta_data and msg.meta_data.get("files"):
history_files = msg.meta_data.get("files", [])
# Process files with MultimodalService
multimodal_service = MultimodalService(self.db, api_config=model_info)
# Convert files to FileInput format
file_inputs = []
for file in history_files:
from app.schemas.app_schema import FileInput, TransferMethod
file_input = FileInput(
type=file.get("type"),
transfer_method=TransferMethod.REMOTE_URL,
url=file.get("url")
)
file_inputs.append(file_input)
history_processed_files = await multimodal_service.history_process_files(files=file_inputs)
content.extend(history_processed_files)
history.append({
"role": msg.role,
"content": content
})
# Process multimodal files
processed_files = None
@@ -433,7 +388,7 @@ class AppChatService:
elapsed_time = time.time() - start_time
ModelApiKeyService.record_api_key_usage(self.db, api_key_obj.id)
# Emit the end event (includes suggested_questions, tts, citations)
# Emit the end event (includes suggested_questions, tts, audio_status, citations)
end_data: dict = {"elapsed_time": elapsed_time, "message_length": len(full_content), "error": None}
sq_config = features_config.get("suggested_questions_after_answer", {})
if isinstance(sq_config, dict) and sq_config.get("enabled"):
@@ -443,11 +398,23 @@ class AppChatService:
"api_base": api_key_obj.api_base}, {}
)
end_data["audio_url"] = stream_audio_url
# Check whether TTS has finished (non-blocking; do not cancel the task)
audio_status = "pending"
if tts_task is not None and tts_task.done():
# Task finished; check for exceptions
try:
tts_task.result()
audio_status = "completed"
except Exception as e:
logger.warning(f"TTS任务异常: {e}")
audio_status = "failed"
end_data["audio_status"] = audio_status if stream_audio_url else None
end_data["citations"] = self.agent_service._filter_citations(features_config, [])
# Save the messages
human_meta = {
"files":[]
"files":[],
"history_files": {}
}
assistant_meta = {
"model": api_key_obj.model_name,
@@ -457,11 +424,16 @@ class AppChatService:
if files:
for f in files:
# url = await MultimodalService(self.db).get_file_url(f)
human_meta["files"].append({
"type": f.type,
"url": f.url
})
if processed_files:
human_meta["history_files"] = {
"content": processed_files,
"provider": api_key_obj.provider,
"is_omni": api_key_obj.is_omni
}
if stream_audio_url:
assistant_meta["audio_url"] = stream_audio_url

View File

@@ -274,7 +274,8 @@ class ConversationService:
self,
conversation_id: uuid.UUID,
max_history: Optional[int] = None,
api_config: Optional[ModelInfo] = None
current_provider: Optional[str] = None,
current_is_omni: Optional[bool] = None
) -> List[dict]:
"""
Retrieve historical conversation messages formatted as dictionaries.
@@ -282,7 +283,8 @@ class ConversationService:
Args:
conversation_id (uuid.UUID): Conversation UUID.
max_history (Optional[int]): Maximum number of messages to retrieve.
api_config (Optional[ModelInfo]): Model API configuration for multimodal processing.
current_provider (Optional[str]): Current provider for file handling.
current_is_omni (Optional[bool]): Current omni flag for file handling.
Returns:
List[dict]: List of message dictionaries with keys 'role' and 'content'.
@@ -292,38 +294,30 @@ class ConversationService:
limit=max_history
)
# Convert to dict format
history = []
for msg in messages:
content = [{"type": "text", "text": msg.content}]
# Handle files in meta_data
if msg.meta_data and msg.meta_data.get("files"):
files = msg.meta_data.get("files", [])
if api_config:
# Process files with MultimodalService
from app.services.multimodal_service import MultimodalService
multimodal_service = MultimodalService(self.db, api_config=api_config)
# Convert files to FileInput format
file_inputs = []
for file in files:
from app.schemas.app_schema import FileInput, TransferMethod
file_input = FileInput(
type=file.get("type"),
transfer_method=TransferMethod.REMOTE_URL,
url=file.get("url")
)
file_inputs.append(file_input)
processed_files = await multimodal_service.history_process_files(files=file_inputs)
content.extend(processed_files)
history.append({
msg_dict = {
"role": msg.role,
"content": content
})
"content": [{"type": "text", "text": msg.content}]
}
# Handle multimodal files in user messages
if msg.role == "user" and msg.meta_data:
history_files = msg.meta_data.get("history_files", {})
if history_files and current_provider and current_is_omni is not None:
# Check whether the files would need re-processing
stored_provider = history_files.get("provider")
stored_is_omni = history_files.get("is_omni")
# If provider or is_omni differ, the stored content would need re-processing; skip it
if stored_provider != current_provider or stored_is_omni != current_is_omni:
continue
# provider and is_omni match: reuse the stored content directly
msg_dict["content"].extend(history_files.get("content"))
history.append(msg_dict)
return history
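An illustrative shape of a saved user message's meta_data under the new scheme (the multimodal "content" entries are an assumption; provider and flag values are examples):

human_meta = {
    "files": [{"type": "image", "url": "https://example.com/cat.png"}],
    "history_files": {
        "content": [{"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}}],
        "provider": "openai",
        "is_omni": True,
    },
}
# On reload, get_conversation_history reuses history_files["content"] only when the
# stored provider/is_omni match the current model; on a mismatch the message is
# skipped rather than re-processed.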
@@ -539,6 +533,7 @@ class ConversationService:
provider = api_config.provider
api_key = api_config.api_key
api_base = api_config.api_base
is_omni = api_config.is_omni
model_type = config.type
llm = RedBearLLM(
@@ -546,7 +541,8 @@ class ConversationService:
model_name=model_name,
provider=provider,
api_key=api_key,
base_url=api_base
base_url=api_base,
is_omni=is_omni
),
type=ModelType(model_type)
)
@@ -554,15 +550,8 @@ class ConversationService:
conversation_messages = await self.get_conversation_history(
conversation_id=conversation_id,
max_history=20,
api_config=ModelInfo(
model_name=model_name,
provider=provider,
api_key=api_key,
api_base=api_base,
capability=api_config.capability,
is_omni=api_config.is_omni,
model_type=model_type
)
current_provider=provider,
current_is_omni=is_omni
)
if len(conversation_messages) == 0:
return ConversationOut(

View File

@@ -592,8 +592,9 @@ class AgentRunService:
# 6. Load conversation history
history = await self._load_conversation_history(
conversation_id=conversation_id,
api_config=model_info,
max_history=10
max_history=10,
current_provider=api_key_config.get("provider"),
current_is_omni=api_key_config.get("is_omni", False)
)
# 6. Process multimodal files
@@ -661,7 +662,10 @@ class AgentRunService:
})
},
files=files,
audio_url=audio_url
processed_files=processed_files,
audio_url=audio_url,
provider=api_key_config.get("provider"),
is_omni=api_key_config.get("is_omni", False)
)
response = {
@@ -678,6 +682,7 @@ class AgentRunService:
) if not sub_agent else [],
"citations": self._filter_citations(features_config, result.get("citations", [])),
"audio_url": audio_url,
"audio_status": "pending"
}
logger.info(
@@ -830,8 +835,9 @@ class AgentRunService:
# 6. Load conversation history
history = await self._load_conversation_history(
conversation_id=conversation_id,
api_config=model_info,
max_history=memory_config.get("max_history", 10)
max_history=memory_config.get("max_history", 10),
current_provider=api_key_config.get("provider"),
current_is_omni=api_key_config.get("is_omni", False)
)
# 6. Process multimodal files
@@ -909,10 +915,13 @@ class AgentRunService:
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": total_tokens}
},
files=files,
audio_url=stream_audio_url
processed_files=processed_files,
audio_url=stream_audio_url,
provider=api_key_config.get("provider"),
is_omni=api_key_config.get("is_omni", False)
)
# 12. Emit the end event (includes suggested_questions and tts)
# 12. Emit the end event (includes suggested_questions, audio_url, and audio_status)
end_data: Dict[str, Any] = {
"conversation_id": conversation_id,
"elapsed_time": elapsed_time,
@@ -923,6 +932,17 @@ class AgentRunService:
features_config, full_content, api_key_config, effective_params
)
end_data["audio_url"] = stream_audio_url
# Check whether TTS has finished (non-blocking; do not cancel the task)
audio_status = "pending"
if tts_task is not None and tts_task.done():
# Task finished; check for exceptions
try:
tts_task.result()
audio_status = "completed"
except Exception as e:
logger.warning(f"TTS任务异常: {e}")
audio_status = "failed"
end_data["audio_status"] = audio_status if stream_audio_url else None
end_data["citations"] = self._filter_citations(features_config, [])
yield self._format_sse_event("end", end_data)
@@ -1119,14 +1139,17 @@ class AgentRunService:
async def _load_conversation_history(
self,
conversation_id: str,
api_config: ModelInfo | None = None,
max_history: int = 10
max_history: int = 10,
current_provider: Optional[str] = None,
current_is_omni: Optional[bool] = None
) -> List[Dict[str, str]]:
"""加载会话历史消息
"""加载会话历史消息,并根据当前模型配置处理多模态文件
Args:
conversation_id: Conversation ID
max_history: Maximum number of history messages
current_provider: The current model's provider
current_is_omni: The current model's is_omni flag
Returns:
List[Dict]: The message history list
@@ -1138,7 +1161,8 @@ class AgentRunService:
history = await conversation_service.get_conversation_history(
conversation_id=uuid.UUID(conversation_id),
max_history=max_history,
api_config=api_config
current_provider=current_provider,
current_is_omni=current_is_omni
)
logger.debug(
@@ -1166,7 +1190,10 @@ class AgentRunService:
app_id: Optional[uuid.UUID] = None,
user_id: Optional[str] = None,
files: Optional[List[FileInput]] = None,
audio_url: Optional[str] = None
processed_files: Optional[List[Dict[str, Any]]] = None,
audio_url: Optional[str] = None,
provider: Optional[str] = None,
is_omni: Optional[bool] = None
) -> None:
"""保存会话消息(会话已通过 _ensure_conversation 确保存在)
@@ -1177,6 +1204,11 @@ class AgentRunService:
app_id: App ID (unused; kept for compatibility)
user_id: User ID (unused; kept for compatibility)
meta_data: Token consumption
files: Raw file inputs
processed_files: Processed files
audio_url: Audio URL
provider: Model provider
is_omni: Whether the model is omni-modal
"""
try:
from app.services.conversation_service import ConversationService
@@ -1186,15 +1218,24 @@ class AgentRunService:
# Save messages (the conversation already exists)
human_meta = {
"files": []
"files": [],
"history_files": {}
}
if files:
for f in files:
# url = await MultimodalService(self.db).get_file_url(f)
human_meta["files"].append({
"type": f.type,
"url": f.url
})
# Save history_files (carrying provider and is_omni)
if processed_files:
human_meta["history_files"] = {
"content": processed_files,
"provider": provider,
"is_omni": is_omni
}
# Save the user message
conversation_service.add_message(
conversation_id=conv_uuid,
@@ -1420,8 +1461,9 @@ class AgentRunService:
workspace_id: Optional[uuid.UUID] = None,
) -> tuple[Optional[str], Optional[asyncio.Task]]:
"""文本流式输入并行合成音频。
返回 (audio_url, task)audio_url 立即可用task 完成后文件内容就绪。
返回 (audio_url, task)audio_url 立即可用pending状态task 完成后文件内容就绪。
调用方向 text_queue put 文本 chunk结束时 put None。
前端可通过 GET /storage/files/{file_id}/status 轮询检查音频是否就绪。
"""
tts_config = features_config.get("text_to_speech", {})
if not isinstance(tts_config, dict) or not tts_config.get("enabled"):
@@ -1808,6 +1850,7 @@ class AgentRunService:
),
"cost_estimate": self._estimate_cost(usage, model_info["model_config"]),
"audio_url": result.get("audio_url"),
"audio_status": result.get("audio_status"),
"citations": result.get("citations", []),
"suggested_questions": result.get("suggested_questions", []),
"error": None
@@ -1885,6 +1928,7 @@ class AgentRunService:
"results": [{
**r,
"audio_url": r.get("audio_url"),
"audio_status": r.get("audio_status"),
"citations": r.get("citations", []),
"suggested_questions": r.get("suggested_questions", []),
} for r in results],
@@ -2016,6 +2060,7 @@ class AgentRunService:
full_content = ""
returned_conversation_id = model_conversation_id
audio_url = None
audio_status = None
citations = []
suggested_questions = []
@@ -2074,6 +2119,7 @@ class AgentRunService:
# 从 end 事件中提取 features 输出字段
if event_type == "end" and event_data:
audio_url = event_data.get("audio_url")
audio_status = event_data.get("audio_status")
citations = event_data.get("citations", [])
suggested_questions = event_data.get("suggested_questions", [])
@@ -2103,6 +2149,7 @@ class AgentRunService:
"message": full_content,
"elapsed_time": elapsed,
"audio_url": audio_url,
"audio_status": audio_status,
"citations": citations,
"suggested_questions": suggested_questions,
"error": None
@@ -2117,6 +2164,7 @@ class AgentRunService:
"elapsed_time": elapsed,
"message_length": len(full_content),
"audio_url": audio_url,
"audio_status": audio_status,
"citations": citations,
"suggested_questions": suggested_questions,
"timestamp": time.time()
@@ -2253,6 +2301,7 @@ class AgentRunService:
"message": r.get("message"),
"elapsed_time": r.get("elapsed_time", 0),
"audio_url": r.get("audio_url"),
"audio_status": r.get("audio_status"),
"citations": r.get("citations", []),
"suggested_questions": r.get("suggested_questions", []),
"error": r.get("error")

View File

@@ -350,9 +350,6 @@ class MemoryAgentService:
langchain_messages.append(HumanMessage(content=msg['content']))
elif msg['role'] == 'assistant':
langchain_messages.append(AIMessage(content=msg['content']))
print(100 * '-')
print(langchain_messages)
print(100 * '-')
# Initial state - includes all required fields
initial_state = {
"messages": langchain_messages,

View File

@@ -78,18 +78,7 @@ def create_user(db: Session, user: UserCreate) -> User:
business_logger.info(f"创建用户: {user.username}, email: {user.email}")
try:
# Check whether the username already exists
business_logger.debug(f"Checking whether username exists: {user.username}")
db_user_by_username = user_repository.get_user_by_username(db, username=user.username)
if db_user_by_username:
business_logger.warning(f"用户名已存在: {user.username}")
raise BusinessException(
"用户名已存在",
code=BizCode.DUPLICATE_NAME,
context={"username": user.username, "email": user.email}
)
# Check whether the email is already registered
# Check whether the email is already registered (email stays unique)
business_logger.debug(f"Checking whether email is registered: {user.email}")
db_user_by_email = user_repository.get_user_by_email(db, email=user.email)
if db_user_by_email:
@@ -164,22 +153,7 @@ def create_superuser(db: Session, user: UserCreate, current_user: User) -> User:
)
try:
# Check whether the username already exists
business_logger.debug(f"Checking whether username exists: {user.username}")
db_user_by_username = user_repository.get_user_by_username(db, username=user.username)
if db_user_by_username:
business_logger.warning(f"用户名已存在: {user.username}")
raise BusinessException(
"用户名已存在",
code=BizCode.DUPLICATE_NAME,
context={
"username": user.username,
"email": user.email,
"created_by": str(current_user.id)
}
)
# Check whether the email is already registered
# Check whether the email is already registered (email stays unique)
business_logger.debug(f"Checking whether email is registered: {user.email}")
db_user_by_email = user_repository.get_user_by_email(db, email=user.email)
if db_user_by_email:

View File

@@ -2760,7 +2760,7 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace
patch_fail = 0
for cid in incomplete_ids:
try:
await engine._generate_community_metadata(cid, end_user_id)
await engine._generate_community_metadata([cid], end_user_id)
patch_ok += 1
except Exception as patch_err:
patch_fail += 1

View File

@@ -0,0 +1,32 @@
"""202603231611
Revision ID: 05a681a6ca93
Revises: 74b51dfece29
Create Date: 2026-03-23 16:12:44.110292
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '05a681a6ca93'
down_revision: Union[str, None] = '74b51dfece29'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_users_username'), table_name='users')
op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(op.f('ix_users_username'), table_name='users')
op.create_index(op.f('ix_users_username'), 'users', ['username'], unique=True)
# ### end Alembic commands ###