Compare commits
13 Commits
feat/wxy-d
...
feat/memor
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
80902eb79a | ||
|
|
f86c023477 | ||
|
|
1d73c9e5a8 | ||
|
|
89bdb9f4b5 | ||
|
|
c57490a063 | ||
|
|
a7d3930f4d | ||
|
|
d30b9224ab | ||
|
|
8f6aad333f | ||
|
|
72c71c1000 | ||
|
|
2c02c67e9e | ||
|
|
03d2228d87 | ||
|
|
9598bd5905 | ||
|
|
d85a1cb131 |
@@ -1,4 +1,4 @@
|
||||
import asyncio
|
||||
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -10,7 +10,7 @@ from app.dependencies import get_current_user
|
||||
from app.models.user_model import User
|
||||
from app.schemas.response_schema import ApiResponse
|
||||
|
||||
from app.services import memory_dashboard_service, memory_storage_service, workspace_service
|
||||
from app.services import memory_dashboard_service, workspace_service
|
||||
from app.services.memory_agent_service import get_end_users_connected_configs_batch
|
||||
from app.services.app_statistics_service import AppStatisticsService
|
||||
from app.core.logging_config import get_api_logger
|
||||
@@ -48,7 +48,7 @@ def get_workspace_total_end_users(
|
||||
|
||||
|
||||
@router.get("/end_users", response_model=ApiResponse)
|
||||
async def get_workspace_end_users(
|
||||
def get_workspace_end_users(
|
||||
workspace_id: Optional[uuid.UUID] = Query(None, description="工作空间ID(可选,默认当前用户工作空间)"),
|
||||
keyword: Optional[str] = Query(None, description="搜索关键词(同时模糊匹配 other_name 和 id)"),
|
||||
page: int = Query(1, ge=1, description="页码,从1开始"),
|
||||
@@ -58,6 +58,15 @@ async def get_workspace_end_users(
|
||||
):
|
||||
"""
|
||||
获取工作空间的宿主列表(分页查询,支持模糊搜索)
|
||||
|
||||
新增:记忆数量过滤:
|
||||
Neo4j 模式:
|
||||
- 使用 end_users.memory_count 过滤 memory_count > 0 的宿主
|
||||
- memory_num.total 直接取 end_user.memory_count
|
||||
|
||||
RAG 模式:
|
||||
- 使用 documents.chunk_num 聚合过滤 chunk 总数 > 0 的宿主
|
||||
- memory_num.total 取聚合后的 chunk 总数
|
||||
|
||||
返回工作空间下的宿主列表,支持分页查询和模糊搜索。
|
||||
通过 keyword 参数同时模糊匹配 other_name 和 id 字段。
|
||||
@@ -80,17 +89,29 @@ async def get_workspace_end_users(
|
||||
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
|
||||
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表, 类型: {current_workspace_type}")
|
||||
|
||||
# 获取分页的 end_users
|
||||
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
keyword=keyword
|
||||
)
|
||||
if current_workspace_type == "rag":
|
||||
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated_rag(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
keyword=keyword,
|
||||
)
|
||||
raw_items = end_users_result.get("items", [])
|
||||
end_users = [item["end_user"] for item in raw_items]
|
||||
else:
|
||||
end_users_result = memory_dashboard_service.get_workspace_end_users_paginated(
|
||||
db=db,
|
||||
workspace_id=workspace_id,
|
||||
current_user=current_user,
|
||||
page=page,
|
||||
pagesize=pagesize,
|
||||
keyword=keyword,
|
||||
)
|
||||
raw_items = end_users_result.get("items", [])
|
||||
end_users = raw_items
|
||||
|
||||
end_users = end_users_result.get("items", [])
|
||||
total = end_users_result.get("total", 0)
|
||||
|
||||
if not end_users:
|
||||
@@ -101,50 +122,19 @@ async def get_workspace_end_users(
|
||||
"page": page,
|
||||
"pagesize": pagesize,
|
||||
"total": total,
|
||||
"hasnext": (page * pagesize) < total
|
||||
}
|
||||
"hasnext": (page * pagesize) < total,
|
||||
},
|
||||
}, msg="宿主列表获取成功")
|
||||
|
||||
end_user_ids = [str(user.id) for user in end_users]
|
||||
|
||||
# 并发执行两个独立的查询任务
|
||||
async def get_memory_configs():
|
||||
"""获取记忆配置(在线程池中执行同步查询)"""
|
||||
try:
|
||||
return await asyncio.to_thread(
|
||||
get_end_users_connected_configs_batch,
|
||||
end_user_ids, db
|
||||
)
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
|
||||
return {}
|
||||
try:
|
||||
memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db)
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取记忆配置失败: {str(e)}")
|
||||
memory_configs_map = {}
|
||||
|
||||
async def get_memory_nums():
|
||||
"""获取记忆数量"""
|
||||
if current_workspace_type == "rag":
|
||||
# RAG 模式:批量查询
|
||||
try:
|
||||
chunk_map = await asyncio.to_thread(
|
||||
memory_dashboard_service.get_users_total_chunk_batch,
|
||||
end_user_ids, db, current_user
|
||||
)
|
||||
return {uid: {"total": count} for uid, count in chunk_map.items()}
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
elif current_workspace_type == "neo4j":
|
||||
# Neo4j 模式:批量查询(简化版本,只返回total)
|
||||
try:
|
||||
batch_result = await memory_storage_service.search_all_batch(end_user_ids)
|
||||
return {uid: {"total": count} for uid, count in batch_result.items()}
|
||||
except Exception as e:
|
||||
api_logger.error(f"批量获取 Neo4j 记忆数量失败: {str(e)}")
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
return {uid: {"total": 0} for uid in end_user_ids}
|
||||
|
||||
# 触发按需初始化:为 implicit_emotions_storage 中没有记录的用户异步生成数据
|
||||
# 触发按需初始化:为 implicit_emotions_storage / interest_distribution 中没有记录的用户异步生成数据
|
||||
try:
|
||||
from app.celery_app import celery_app as _celery_app
|
||||
_celery_app.send_task(
|
||||
@@ -159,27 +149,26 @@ async def get_workspace_end_users(
|
||||
except Exception as e:
|
||||
api_logger.warning(f"触发按需初始化任务失败(不影响主流程): {e}")
|
||||
|
||||
# 并发执行配置查询和记忆数量查询
|
||||
memory_configs_map, memory_nums_map = await asyncio.gather(
|
||||
get_memory_configs(),
|
||||
get_memory_nums()
|
||||
)
|
||||
|
||||
# 构建结果列表
|
||||
items = []
|
||||
for end_user in end_users:
|
||||
for index, end_user in enumerate(end_users):
|
||||
user_id = str(end_user.id)
|
||||
config_info = memory_configs_map.get(user_id, {})
|
||||
|
||||
if current_workspace_type == "rag":
|
||||
memory_total = int(raw_items[index].get("memory_count", 0) or 0)
|
||||
else:
|
||||
memory_total = int(getattr(end_user, "memory_count", 0) or 0)
|
||||
|
||||
items.append({
|
||||
'end_user': {
|
||||
'id': user_id,
|
||||
'other_name': end_user.other_name
|
||||
"end_user": {
|
||||
"id": user_id,
|
||||
"other_name": end_user.other_name,
|
||||
},
|
||||
'memory_num': memory_nums_map.get(user_id, {"total": 0}),
|
||||
'memory_config': {
|
||||
"memory_num": {"total": memory_total},
|
||||
"memory_config": {
|
||||
"memory_config_id": config_info.get("memory_config_id"),
|
||||
"memory_config_name": config_info.get("memory_config_name")
|
||||
}
|
||||
"memory_config_name": config_info.get("memory_config_name"),
|
||||
},
|
||||
})
|
||||
|
||||
# 触发社区聚类补全任务(异步,不阻塞接口响应)
|
||||
@@ -407,6 +396,7 @@ def get_current_user_rag_total_num(
|
||||
total_chunk = memory_dashboard_service.get_current_user_total_chunk(end_user_id, db, current_user)
|
||||
return success(data=total_chunk, msg="宿主RAG知识数据获取成功")
|
||||
|
||||
|
||||
@router.get("/rag_content", response_model=ApiResponse)
|
||||
def get_rag_content(
|
||||
end_user_id: str = Query(..., description="宿主ID"),
|
||||
|
||||
@@ -20,6 +20,7 @@ from app.core.memory.storage_services.extraction_engine.knowledge_extraction.mem
|
||||
memory_summary_generation
|
||||
from app.core.memory.utils.llm.llm_utils import MemoryClientFactory
|
||||
from app.core.memory.utils.log.logging_utils import log_time
|
||||
from app.core.memory.utils.memory_count_utils import sync_end_user_memory_count_from_neo4j
|
||||
from app.db import get_db_context
|
||||
from app.repositories.neo4j.add_edges import add_memory_summary_statement_edges
|
||||
from app.repositories.neo4j.add_nodes import add_memory_summary_nodes
|
||||
@@ -313,6 +314,28 @@ async def write(
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"[WRITE] 写入活动统计缓存失败(不影响主流程): {cache_err}", exc_info=True)
|
||||
|
||||
# 同步 Neo4j 记忆节点总数到 PostgreSQL end_users.memory_count
|
||||
if end_user_id:
|
||||
try:
|
||||
memory_count_connector = Neo4jConnector()
|
||||
try:
|
||||
node_count = await sync_end_user_memory_count_from_neo4j(
|
||||
end_user_id,
|
||||
memory_count_connector,
|
||||
)
|
||||
finally:
|
||||
await memory_count_connector.close()
|
||||
|
||||
logger.info(
|
||||
f"[MemoryCount] 写入后同步 memory_count: "
|
||||
f"end_user_id={end_user_id}, count={node_count}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[MemoryCount] 写入后同步 memory_count 失败(不影响主流程): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Close LLM/Embedder underlying httpx clients to prevent
|
||||
# 'RuntimeError: Event loop is closed' during garbage collection
|
||||
for client_obj in (llm_client, embedder_client):
|
||||
@@ -331,3 +354,4 @@ async def write(
|
||||
|
||||
logger.info("=== Pipeline Complete ===")
|
||||
logger.info(f"Total execution time: {total_time:.2f} seconds")
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from uuid import UUID
|
||||
from datetime import datetime
|
||||
|
||||
from app.core.memory.storage_services.forgetting_engine.forgetting_strategy import ForgettingStrategy
|
||||
from app.core.memory.utils.memory_count_utils import sync_end_user_memory_count_from_neo4j
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
@@ -145,7 +146,22 @@ class ForgettingScheduler:
|
||||
}
|
||||
|
||||
logger.info("没有可遗忘的节点对,遗忘周期结束")
|
||||
|
||||
# 同步 Neo4j 记忆节点总数到 PostgreSQL 的 end_users.memory_count
|
||||
if end_user_id:
|
||||
try:
|
||||
node_count = await sync_end_user_memory_count_from_neo4j(
|
||||
end_user_id,
|
||||
self.connector,
|
||||
)
|
||||
logger.info(
|
||||
f"[MemoryCount] 遗忘后同步 memory_count: "
|
||||
f"end_user_id={end_user_id}, count={node_count}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[MemoryCount] 遗忘后同步 memory_count 失败(不影响主流程): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return report
|
||||
|
||||
# 步骤3:按激活值排序(激活值最低的优先)
|
||||
@@ -302,7 +318,22 @@ class ForgettingScheduler:
|
||||
f"({reduction_rate:.2%}), "
|
||||
f"耗时 {duration:.2f} 秒"
|
||||
)
|
||||
|
||||
# 同步 Neo4j 记忆节点总数到 PostgreSQL 的 end_users.memory_count
|
||||
if end_user_id:
|
||||
try:
|
||||
node_count = await sync_end_user_memory_count_from_neo4j(
|
||||
end_user_id,
|
||||
self.connector,
|
||||
)
|
||||
logger.info(
|
||||
f"[MemoryCount] 遗忘后同步 memory_count: "
|
||||
f"end_user_id={end_user_id}, count={node_count}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[MemoryCount] 遗忘后同步 memory_count 失败(不影响主流程): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
return report
|
||||
|
||||
except Exception as e:
|
||||
|
||||
36
api/app/core/memory/utils/memory_count_utils.py
Normal file
36
api/app/core/memory/utils/memory_count_utils.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from uuid import UUID
|
||||
|
||||
from app.db import get_db_context
|
||||
from app.models.end_user_model import EndUser
|
||||
from app.repositories.memory_config_repository import MemoryConfigRepository
|
||||
from app.repositories.neo4j.neo4j_connector import Neo4jConnector
|
||||
|
||||
|
||||
async def sync_end_user_memory_count_from_neo4j(
|
||||
end_user_id: str,
|
||||
connector: Neo4jConnector,
|
||||
) -> int:
|
||||
"""
|
||||
Sync one end user's Neo4j memory node count to PostgreSQL.
|
||||
|
||||
The caller owns the Neo4j connector lifecycle.
|
||||
"""
|
||||
if not end_user_id:
|
||||
return 0
|
||||
|
||||
result = await connector.execute_query(
|
||||
MemoryConfigRepository.SEARCH_FOR_ALL_BATCH,
|
||||
end_user_ids=[end_user_id],
|
||||
)
|
||||
node_count = int(result[0]["total"]) if result else 0
|
||||
|
||||
with get_db_context() as db:
|
||||
db.query(EndUser).filter(
|
||||
EndUser.id == UUID(end_user_id)
|
||||
).update(
|
||||
{"memory_count": node_count},
|
||||
synchronize_session=False,
|
||||
)
|
||||
db.commit()
|
||||
|
||||
return node_count
|
||||
@@ -2,7 +2,6 @@
|
||||
# Author: Eternity
|
||||
# @Email: 1533512157@qq.com
|
||||
# @Time : 2026/2/10 13:33
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
@@ -142,10 +141,9 @@ class GraphBuilder:
|
||||
|
||||
for node_info in source_nodes:
|
||||
if self.get_node_type(node_info["id"]) in BRANCH_NODES:
|
||||
if node_info.get("branch") is not None:
|
||||
branch_nodes.append(
|
||||
(node_info["id"], node_info["branch"])
|
||||
)
|
||||
branch_nodes.append(
|
||||
(node_info["id"], node_info["branch"])
|
||||
)
|
||||
else:
|
||||
if self.get_node_type(node_info["id"]) in (NodeType.END, NodeType.OUTPUT):
|
||||
output_nodes.append(node_info["id"])
|
||||
@@ -316,12 +314,9 @@ class GraphBuilder:
|
||||
for idx in range(len(related_edge)):
|
||||
# Generate a condition expression for each edge
|
||||
# Used later to determine which branch to take based on the node's output
|
||||
# For LLM nodes, use branch_signal field for routing (output is dynamic text)
|
||||
# For other branch nodes (e.g. HTTP), use output field
|
||||
route_field = "branch_signal" if node_type == NodeType.LLM else "output"
|
||||
related_edge[idx]['condition'] = (
|
||||
f"node[{json.dumps(node_id)}][{json.dumps(route_field)}] == {json.dumps(related_edge[idx]['label'])}"
|
||||
)
|
||||
# Assumes node output `node.<node_id>.output` matches the edge's label
|
||||
# For example, if node.123.output == 'CASE1', take the branch labeled 'CASE1'
|
||||
related_edge[idx]['condition'] = f"node['{node_id}']['output'] == '{related_edge[idx]['label']}'"
|
||||
|
||||
if node_instance:
|
||||
# Wrap node's run method to avoid closure issues
|
||||
|
||||
@@ -18,17 +18,10 @@ class AssignerNode(BaseNode):
|
||||
super().__init__(node_config, workflow_config, down_stream_nodes)
|
||||
self.variable_updater = True
|
||||
self.typed_config: AssignerNodeConfig | None = None
|
||||
self._input_data: dict[str, Any] | None = None
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {}
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""提取节点输入,如果有缓存的执行前数据则使用缓存"""
|
||||
if self._input_data is not None:
|
||||
return self._input_data
|
||||
return {"config": self._resolve_config(self.config, variable_pool)}
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
"""
|
||||
Execute the assignment operation defined by this node.
|
||||
@@ -41,9 +34,6 @@ class AssignerNode(BaseNode):
|
||||
Returns:
|
||||
None or the result of the assignment operation.
|
||||
"""
|
||||
# 在执行前提取并缓存输入数据(捕获执行前的变量值)
|
||||
self._input_data = {"config": self._resolve_config(self.config, variable_pool)}
|
||||
|
||||
# Initialize a variable pool for accessing conversation, node, and system variables
|
||||
self.typed_config = AssignerNodeConfig(**self.config)
|
||||
logger.info(f"节点 {self.node_id} 开始执行")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
@@ -23,9 +22,6 @@ from app.services.multimodal_service import MultimodalService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# 匹配模板变量 {{xxx}} 的正则
|
||||
_TEMPLATE_PATTERN = re.compile(r"\{\{.*?\}\}")
|
||||
|
||||
|
||||
class NodeExecutionError(Exception):
|
||||
"""节点执行失败异常。
|
||||
@@ -507,29 +503,10 @@ class BaseNode(ABC):
|
||||
variable_pool: The variable pool used for reading and writing variables.
|
||||
|
||||
Returns:
|
||||
A dictionary containing the node's input data with all template
|
||||
variables resolved to their actual runtime values.
|
||||
A dictionary containing the node's input data.
|
||||
"""
|
||||
return {"config": self._resolve_config(self.config, variable_pool)}
|
||||
|
||||
@staticmethod
|
||||
def _resolve_config(config: Any, variable_pool: VariablePool) -> Any:
|
||||
"""递归解析 config 中的模板变量,将 {{xxx}} 替换为实际值。
|
||||
|
||||
Args:
|
||||
config: 节点的原始配置(可能包含模板变量)。
|
||||
variable_pool: 变量池,用于解析模板变量。
|
||||
|
||||
Returns:
|
||||
解析后的配置,所有字符串中的 {{变量}} 已被替换为真实值。
|
||||
"""
|
||||
if isinstance(config, str) and _TEMPLATE_PATTERN.search(config):
|
||||
return BaseNode._render_template(config, variable_pool, strict=False)
|
||||
elif isinstance(config, dict):
|
||||
return {k: BaseNode._resolve_config(v, variable_pool) for k, v in config.items()}
|
||||
elif isinstance(config, list):
|
||||
return [BaseNode._resolve_config(item, variable_pool) for item in config]
|
||||
return config
|
||||
# Default implementation returns the node configuration
|
||||
return {"config": self.config}
|
||||
|
||||
def _extract_output(self, business_result: Any) -> Any:
|
||||
"""Extracts the actual output from the business result.
|
||||
|
||||
@@ -121,10 +121,7 @@ class DocExtractorNode(BaseNode):
|
||||
return business_result
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
file_selector = self.config.get("file_selector", "")
|
||||
# 将变量选择器(如 sys.files)解析为实际值
|
||||
resolved = self.get_variable(file_selector, variable_pool, strict=False, default=file_selector)
|
||||
return {"file_selector": resolved}
|
||||
return {"file_selector": self.config.get("file_selector")}
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
|
||||
config = DocExtractorNodeConfig(**self.config)
|
||||
|
||||
@@ -31,7 +31,7 @@ class NodeType(StrEnum):
|
||||
NOTES = "notes"
|
||||
|
||||
|
||||
BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER, NodeType.LLM})
|
||||
BRANCH_NODES = frozenset({NodeType.IF_ELSE, NodeType.HTTP_REQUEST, NodeType.QUESTION_CLASSIFIER})
|
||||
|
||||
|
||||
class ComparisonOperator(StrEnum):
|
||||
|
||||
@@ -6,7 +6,6 @@ import uuid
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.core.workflow.nodes.base_config import BaseNodeConfig, VariableDefinition
|
||||
from app.core.workflow.nodes.enums import HttpErrorHandle
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
|
||||
|
||||
@@ -50,20 +49,6 @@ class MemoryWindowSetting(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class LLMErrorHandleConfig(BaseModel):
|
||||
"""LLM 异常处理配置"""
|
||||
|
||||
method: HttpErrorHandle = Field(
|
||||
default=HttpErrorHandle.NONE,
|
||||
description="异常处理策略:'none' 抛出异常, 'default' 返回默认值, 'branch' 走异常分支",
|
||||
)
|
||||
|
||||
output: str = Field(
|
||||
default="",
|
||||
description="LLM 异常时返回的默认输出文本(method=default 时生效)",
|
||||
)
|
||||
|
||||
|
||||
class LLMNodeConfig(BaseNodeConfig):
|
||||
"""LLM 节点配置
|
||||
|
||||
@@ -167,11 +152,6 @@ class LLMNodeConfig(BaseNodeConfig):
|
||||
description="输出变量定义(自动生成,通常不需要修改)"
|
||||
)
|
||||
|
||||
error_handle: LLMErrorHandleConfig = Field(
|
||||
default_factory=LLMErrorHandleConfig,
|
||||
description="LLM 异常处理配置",
|
||||
)
|
||||
|
||||
@field_validator("messages", "prompt")
|
||||
@classmethod
|
||||
def validate_input_mode(cls, v):
|
||||
|
||||
@@ -15,7 +15,6 @@ from app.core.models import RedBearLLM, RedBearModelConfig
|
||||
from app.core.workflow.engine.state_manager import WorkflowState
|
||||
from app.core.workflow.engine.variable_pool import VariablePool
|
||||
from app.core.workflow.nodes.base_node import BaseNode
|
||||
from app.core.workflow.nodes.enums import HttpErrorHandle
|
||||
from app.core.workflow.nodes.llm.config import LLMNodeConfig
|
||||
from app.core.workflow.variable.base_variable import VariableType
|
||||
from app.db import get_db_context
|
||||
@@ -77,7 +76,7 @@ class LLMNode(BaseNode):
|
||||
self.messages = []
|
||||
|
||||
def _output_types(self) -> dict[str, VariableType]:
|
||||
return {"output": VariableType.STRING, "branch_signal": VariableType.STRING}
|
||||
return {"output": VariableType.STRING}
|
||||
|
||||
def _render_context(self, message: str, variable_pool: VariablePool):
|
||||
context = f"<context>{self._render_template(self.typed_config.context, variable_pool)}</context>"
|
||||
@@ -240,7 +239,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
return llm
|
||||
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool):
|
||||
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> AIMessage:
|
||||
"""非流式执行 LLM 调用
|
||||
|
||||
Args:
|
||||
@@ -248,36 +247,28 @@ class LLMNode(BaseNode):
|
||||
variable_pool: 变量池
|
||||
|
||||
Returns:
|
||||
dict: {"llm_result": AIMessage, "branch_signal": "SUCCESS"} on success,
|
||||
{"llm_result": None, "branch_signal": "ERROR"} on branch error
|
||||
LLM 响应消息
|
||||
"""
|
||||
try:
|
||||
# self.typed_config = LLMNodeConfig(**self.config)
|
||||
llm = await self._prepare_llm(state, variable_pool, False)
|
||||
# self.typed_config = LLMNodeConfig(**self.config)
|
||||
llm = await self._prepare_llm(state, variable_pool, False)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(非流式)")
|
||||
|
||||
# 调用 LLM(支持字符串或消息列表)
|
||||
response = await llm.ainvoke(self.messages)
|
||||
# 提取内容
|
||||
if hasattr(response, 'content'):
|
||||
content = self.process_model_output(response.content)
|
||||
else:
|
||||
content = str(response)
|
||||
# 调用 LLM(支持字符串或消息列表)
|
||||
response = await llm.ainvoke(self.messages)
|
||||
# 提取内容
|
||||
if hasattr(response, 'content'):
|
||||
content = self.process_model_output(response.content)
|
||||
else:
|
||||
content = str(response)
|
||||
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(content)}")
|
||||
|
||||
# 返回 AIMessage(包含响应元数据)
|
||||
return {
|
||||
"llm_result": AIMessage(content=content, response_metadata={
|
||||
**response.response_metadata,
|
||||
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
|
||||
}),
|
||||
"branch_signal": "SUCCESS",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"节点 {self.node_id} LLM 调用失败: {e}")
|
||||
return self._handle_llm_error(e)
|
||||
# 返回 AIMessage(包含响应元数据)
|
||||
return AIMessage(content=content, response_metadata={
|
||||
**response.response_metadata,
|
||||
"token_usage": getattr(response, 'usage_metadata', None) or response.response_metadata.get('token_usage')
|
||||
})
|
||||
|
||||
def _extract_input(self, state: WorkflowState, variable_pool: VariablePool) -> dict[str, Any]:
|
||||
"""提取输入数据(用于记录)"""
|
||||
@@ -295,36 +286,16 @@ class LLMNode(BaseNode):
|
||||
}
|
||||
}
|
||||
|
||||
def _extract_output(self, business_result: Any) -> dict:
|
||||
"""从业务结果中提取输出变量
|
||||
|
||||
支持新旧两种格式:
|
||||
- 新格式:{"llm_result": AIMessage, "branch_signal": "SUCCESS"}
|
||||
- 旧格式:AIMessage(向后兼容)
|
||||
"""
|
||||
if isinstance(business_result, dict) and "branch_signal" in business_result:
|
||||
llm_result = business_result.get("llm_result")
|
||||
if isinstance(llm_result, AIMessage):
|
||||
return {
|
||||
"output": llm_result.content,
|
||||
"branch_signal": business_result["branch_signal"],
|
||||
}
|
||||
return {
|
||||
"output": str(llm_result) if llm_result else "",
|
||||
"branch_signal": business_result["branch_signal"],
|
||||
}
|
||||
# 旧格式向后兼容
|
||||
def _extract_output(self, business_result: Any) -> str:
|
||||
"""从 AIMessage 中提取文本内容"""
|
||||
if isinstance(business_result, AIMessage):
|
||||
return {"output": business_result.content, "branch_signal": "SUCCESS"}
|
||||
return {"output": str(business_result), "branch_signal": "SUCCESS"}
|
||||
return business_result.content
|
||||
return str(business_result)
|
||||
|
||||
def _extract_token_usage(self, business_result: Any) -> dict[str, int] | None:
|
||||
"""从业务结果中提取 token 使用情况"""
|
||||
llm_result = business_result
|
||||
if isinstance(business_result, dict):
|
||||
llm_result = business_result.get("llm_result", business_result)
|
||||
if isinstance(llm_result, AIMessage) and hasattr(llm_result, 'response_metadata'):
|
||||
usage = llm_result.response_metadata.get('token_usage')
|
||||
"""从 AIMessage 中提取 token 使用情况"""
|
||||
if isinstance(business_result, AIMessage) and hasattr(business_result, 'response_metadata'):
|
||||
usage = business_result.response_metadata.get('token_usage')
|
||||
if usage:
|
||||
return {
|
||||
"prompt_tokens": usage.get('input_tokens', 0),
|
||||
@@ -333,44 +304,6 @@ class LLMNode(BaseNode):
|
||||
}
|
||||
return None
|
||||
|
||||
def _handle_llm_error(self, error: Exception) -> dict:
|
||||
"""处理 LLM 调用异常,根据 error_handle 配置决定行为
|
||||
|
||||
Args:
|
||||
error: LLM 调用中捕获的异常
|
||||
|
||||
Returns:
|
||||
dict: {"llm_result": None, "branch_signal": "ERROR"} for branch mode,
|
||||
or default output for default mode
|
||||
|
||||
Raises:
|
||||
原异常(当 error_handle.method 为 NONE 时)
|
||||
"""
|
||||
if self.typed_config is None:
|
||||
raise error
|
||||
|
||||
match self.typed_config.error_handle.method:
|
||||
case HttpErrorHandle.NONE:
|
||||
raise error
|
||||
case HttpErrorHandle.DEFAULT:
|
||||
logger.warning(
|
||||
f"节点 {self.node_id}: LLM 调用失败,返回默认输出"
|
||||
)
|
||||
default_output = self.typed_config.error_handle.output or ""
|
||||
return {
|
||||
"llm_result": AIMessage(content=default_output, response_metadata={}),
|
||||
"branch_signal": "SUCCESS",
|
||||
}
|
||||
case HttpErrorHandle.BRANCH:
|
||||
logger.warning(
|
||||
f"节点 {self.node_id}: LLM 调用失败,切换到异常处理分支"
|
||||
)
|
||||
return {
|
||||
"llm_result": None,
|
||||
"branch_signal": "ERROR",
|
||||
}
|
||||
raise error
|
||||
|
||||
async def execute_stream(self, state: WorkflowState, variable_pool: VariablePool):
|
||||
"""流式执行 LLM 调用
|
||||
|
||||
@@ -383,58 +316,54 @@ class LLMNode(BaseNode):
|
||||
"""
|
||||
self.typed_config = LLMNodeConfig(**self.config)
|
||||
|
||||
try:
|
||||
llm = await self._prepare_llm(state, variable_pool, True)
|
||||
llm = await self._prepare_llm(state, variable_pool, True)
|
||||
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||
logger.info(f"节点 {self.node_id} 开始执行 LLM 调用(流式)")
|
||||
# logger.debug(f"LLM 配置: streaming={getattr(llm._model, 'streaming', 'unknown')}")
|
||||
|
||||
# 累积完整响应
|
||||
full_response = ""
|
||||
chunk_count = 0
|
||||
# 累积完整响应
|
||||
full_response = ""
|
||||
chunk_count = 0
|
||||
|
||||
# 调用 LLM(流式,支持字符串或消息列表)
|
||||
last_meta_data = {}
|
||||
last_usage_metadata = {}
|
||||
async for chunk in llm.astream(self.messages):
|
||||
if hasattr(chunk, 'content'):
|
||||
content = self.process_model_output(chunk.content)
|
||||
else:
|
||||
content = str(chunk)
|
||||
if hasattr(chunk, 'response_metadata') and chunk.response_metadata:
|
||||
last_meta_data = chunk.response_metadata
|
||||
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
|
||||
last_usage_metadata = chunk.usage_metadata
|
||||
# 调用 LLM(流式,支持字符串或消息列表)
|
||||
last_meta_data = {}
|
||||
last_usage_metadata = {}
|
||||
async for chunk in llm.astream(self.messages):
|
||||
if hasattr(chunk, 'content'):
|
||||
content = self.process_model_output(chunk.content)
|
||||
else:
|
||||
content = str(chunk)
|
||||
if hasattr(chunk, 'response_metadata') and chunk.response_metadata:
|
||||
last_meta_data = chunk.response_metadata
|
||||
if hasattr(chunk, 'usage_metadata') and chunk.usage_metadata:
|
||||
last_usage_metadata = chunk.usage_metadata
|
||||
|
||||
# 只有当内容不为空时才处理
|
||||
if content:
|
||||
full_response += content
|
||||
chunk_count += 1
|
||||
# 只有当内容不为空时才处理
|
||||
if content:
|
||||
full_response += content
|
||||
chunk_count += 1
|
||||
|
||||
# 流式返回每个文本片段
|
||||
yield {
|
||||
"__final__": False,
|
||||
"chunk": content
|
||||
}
|
||||
|
||||
yield {
|
||||
"__final__": False,
|
||||
"chunk": "",
|
||||
"done": True
|
||||
}
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
||||
|
||||
# 构建完整的 AIMessage(包含元数据)
|
||||
final_message = AIMessage(
|
||||
content=full_response,
|
||||
response_metadata={
|
||||
**last_meta_data,
|
||||
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
|
||||
# 流式返回每个文本片段
|
||||
yield {
|
||||
"__final__": False,
|
||||
"chunk": content
|
||||
}
|
||||
)
|
||||
|
||||
# yield 完成标记
|
||||
yield {"__final__": True, "result": {"llm_result": final_message, "branch_signal": "SUCCESS"}}
|
||||
except Exception as e:
|
||||
logger.error(f"节点 {self.node_id} LLM 流式调用失败: {e}")
|
||||
error_result = self._handle_llm_error(e)
|
||||
yield {"__final__": True, "result": error_result}
|
||||
yield {
|
||||
"__final__": False,
|
||||
"chunk": "",
|
||||
"done": True
|
||||
}
|
||||
logger.info(f"节点 {self.node_id} LLM 调用完成,输出长度: {len(full_response)}, 总 chunks: {chunk_count}")
|
||||
|
||||
# 构建完整的 AIMessage(包含元数据)
|
||||
final_message = AIMessage(
|
||||
content=full_response,
|
||||
response_metadata={
|
||||
**last_meta_data,
|
||||
"token_usage": last_usage_metadata or last_meta_data.get('token_usage')
|
||||
}
|
||||
)
|
||||
|
||||
# yield 完成标记
|
||||
yield {"__final__": True, "result": final_message}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, String, Text
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
@@ -38,6 +38,15 @@ class EndUser(Base):
|
||||
comment="关联的记忆配置ID"
|
||||
)
|
||||
|
||||
memory_count = Column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=0,
|
||||
server_default="0",
|
||||
index=True,
|
||||
comment="记忆节点总数",
|
||||
)
|
||||
|
||||
# 用户摘要四个维度 - User Summary Four Dimensions
|
||||
user_summary = Column(Text, nullable=True, comment="缓存的用户摘要(基本介绍)")
|
||||
personality_traits = Column(Text, nullable=True, comment="性格特点")
|
||||
|
||||
@@ -19,4 +19,6 @@ class EndUser(BaseModel):
|
||||
|
||||
# 用户摘要和洞察更新时间
|
||||
user_summary_updated_at: Optional[datetime.datetime] = Field(description="用户摘要最后更新时间", default=None)
|
||||
memory_insight_updated_at: Optional[datetime.datetime] = Field(description="洞察报告最后更新时间", default=None)
|
||||
memory_insight_updated_at: Optional[datetime.datetime] = Field(description="洞察报告最后更新时间", default=None)
|
||||
#用户记忆节点总数(Neo4j模式)
|
||||
memory_count: int = Field(description="记忆节点总数", default=0)
|
||||
@@ -102,11 +102,6 @@ class AppDslService:
|
||||
{**r, "_ref": self._agent_ref(r.get("target_agent_id"))} for r in (cfg["routing_rules"] or [])
|
||||
]
|
||||
return enriched
|
||||
if app_type == AppType.WORKFLOW:
|
||||
enriched = {**cfg}
|
||||
if "nodes" in cfg:
|
||||
enriched["nodes"] = self._enrich_workflow_nodes(cfg["nodes"])
|
||||
return enriched
|
||||
return cfg
|
||||
|
||||
def _export_draft(self, app: App, meta: dict, app_meta: dict) -> tuple[str, str]:
|
||||
@@ -115,7 +110,7 @@ class AppDslService:
|
||||
config_data = {
|
||||
"variables": config.variables if config else [],
|
||||
"edges": config.edges if config else [],
|
||||
"nodes": self._enrich_workflow_nodes(config.nodes) if config else [],
|
||||
"nodes": config.nodes if config else [],
|
||||
"features": config.features if config else {},
|
||||
"execution_config": config.execution_config if config else {},
|
||||
"triggers": config.triggers if config else [],
|
||||
@@ -195,23 +190,6 @@ class AppDslService:
|
||||
def _enrich_tools(self, tools: list) -> list:
|
||||
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
|
||||
|
||||
def _enrich_workflow_nodes(self, nodes: list) -> list:
|
||||
"""enrich 工作流节点中的模型引用,添加 name、provider、type 信息"""
|
||||
from app.core.workflow.nodes.enums import NodeType
|
||||
enriched_nodes = []
|
||||
for node in (nodes or []):
|
||||
node_type = node.get("type")
|
||||
config = dict(node.get("config") or {})
|
||||
|
||||
if node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
|
||||
model_id = config.get("model_id")
|
||||
if model_id:
|
||||
config["model_ref"] = self._model_ref(model_id)
|
||||
del config["model_id"]
|
||||
|
||||
enriched_nodes.append({**node, "config": config})
|
||||
return enriched_nodes
|
||||
|
||||
def _skill_ref(self, skill_id) -> Optional[dict]:
|
||||
if not skill_id:
|
||||
return None
|
||||
@@ -642,16 +620,16 @@ class AppDslService:
|
||||
warnings.append(f"[{node_label}] 知识库 '{kb_id}' 未匹配,已移除,请导入后手动配置")
|
||||
config["knowledge_bases"] = resolved_kbs
|
||||
elif node_type in (NodeType.LLM.value, NodeType.QUESTION_CLASSIFIER.value, NodeType.PARAMETER_EXTRACTOR.value):
|
||||
model_ref = config.get("model_ref") or config.get("model_id")
|
||||
model_ref = config.get("model_id")
|
||||
if model_ref:
|
||||
ref_dict = None
|
||||
if isinstance(model_ref, dict):
|
||||
ref_dict = {
|
||||
"id": model_ref.get("id"),
|
||||
"name": model_ref.get("name"),
|
||||
"provider": model_ref.get("provider"),
|
||||
"type": model_ref.get("type")
|
||||
}
|
||||
ref_id = model_ref.get("id")
|
||||
ref_name = model_ref.get("name")
|
||||
if ref_id:
|
||||
ref_dict = {"id": ref_id}
|
||||
elif ref_name is not None:
|
||||
ref_dict = {"name": ref_name, "provider": model_ref.get("provider"), "type": model_ref.get("type")}
|
||||
elif isinstance(model_ref, str):
|
||||
try:
|
||||
uuid.UUID(model_ref)
|
||||
@@ -662,18 +640,12 @@ class AppDslService:
|
||||
resolved_model_id = self._resolve_model(ref_dict, tenant_id, warnings)
|
||||
if resolved_model_id:
|
||||
config["model_id"] = resolved_model_id
|
||||
if "model_ref" in config:
|
||||
del config["model_ref"]
|
||||
else:
|
||||
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
|
||||
config["model_id"] = None
|
||||
if "model_ref" in config:
|
||||
del config["model_ref"]
|
||||
else:
|
||||
warnings.append(f"[{node_label}] 模型未匹配,已置空,请导入后手动配置")
|
||||
config["model_id"] = None
|
||||
if "model_ref" in config:
|
||||
del config["model_ref"]
|
||||
resolved_nodes.append({**node, "config": config})
|
||||
return resolved_nodes
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc, nullslast, or_, and_, cast, String
|
||||
from sqlalchemy import desc, nullslast, or_, cast, String, func
|
||||
from typing import List, Optional, Dict, Any
|
||||
import uuid
|
||||
from fastapi import HTTPException
|
||||
@@ -102,6 +102,7 @@ def get_workspace_end_users_paginated(
|
||||
"""获取工作空间的宿主列表(分页版本,支持模糊搜索)
|
||||
|
||||
返回结果按 created_at 从新到旧排序(NULL 值排在最后)
|
||||
固定过滤 memory_count > 0 的宿主,保证分页基于“有记忆宿主”集合计算。
|
||||
支持通过 keyword 参数同时模糊搜索 other_name 和 id 字段
|
||||
|
||||
Args:
|
||||
@@ -120,7 +121,8 @@ def get_workspace_end_users_paginated(
|
||||
try:
|
||||
# 构建基础查询
|
||||
base_query = db.query(EndUserModel).filter(
|
||||
EndUserModel.workspace_id == workspace_id
|
||||
EndUserModel.workspace_id == workspace_id,
|
||||
EndUserModel.memory_count > 0 , # 只查询有记忆的宿主
|
||||
)
|
||||
|
||||
# 构建搜索条件(过滤空字符串和None)
|
||||
@@ -128,20 +130,13 @@ def get_workspace_end_users_paginated(
|
||||
|
||||
if keyword:
|
||||
keyword_pattern = f"%{keyword}%"
|
||||
# other_name 匹配始终生效;id 匹配仅对 other_name 为空的记录生效
|
||||
base_query = base_query.filter(
|
||||
or_(
|
||||
EndUserModel.other_name.ilike(keyword_pattern),
|
||||
and_(
|
||||
or_(
|
||||
EndUserModel.other_name.is_(None),
|
||||
EndUserModel.other_name == "",
|
||||
),
|
||||
cast(EndUserModel.id, String).ilike(keyword_pattern),
|
||||
),
|
||||
cast(EndUserModel.id, String).ilike(keyword_pattern),
|
||||
)
|
||||
)
|
||||
business_logger.info(f"应用模糊搜索: keyword={keyword}(匹配 other_name;other_name 为空时匹配 id)")
|
||||
business_logger.info(f"应用模糊搜索: keyword={keyword}(匹配 other_name 或 id)")
|
||||
|
||||
# 获取总记录数
|
||||
total = base_query.count()
|
||||
@@ -169,6 +164,98 @@ def get_workspace_end_users_paginated(
|
||||
business_logger.error(f"获取工作空间宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}")
|
||||
raise
|
||||
|
||||
def get_workspace_end_users_paginated_rag(
|
||||
db: Session,
|
||||
workspace_id: uuid.UUID,
|
||||
current_user: User,
|
||||
page: int,
|
||||
pagesize: int,
|
||||
keyword: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""RAG 模式宿主列表分页。
|
||||
|
||||
RAG 记忆数量以 documents.chunk_num 为准:
|
||||
- file_name = end_user_id + ".txt"
|
||||
- 只统计当前 workspace 下 permission_id="Memory" 的用户记忆知识库
|
||||
- 在 SQL 层过滤 chunk 总数为 0 的宿主,保证分页准确
|
||||
"""
|
||||
business_logger.info(
|
||||
f"获取 RAG 宿主列表(分页): workspace_id={workspace_id}, "
|
||||
f"keyword={keyword}, page={page}, pagesize={pagesize}, 操作者: {current_user.username}"
|
||||
)
|
||||
|
||||
try:
|
||||
from app.models.document_model import Document
|
||||
from app.models.knowledge_model import Knowledge
|
||||
|
||||
chunk_subquery = (
|
||||
db.query(
|
||||
Document.file_name.label("file_name"),
|
||||
func.coalesce(func.sum(Document.chunk_num), 0).label("memory_count"),
|
||||
)
|
||||
.join(Knowledge, Document.kb_id == Knowledge.id)
|
||||
.filter(
|
||||
Knowledge.workspace_id == workspace_id,
|
||||
Knowledge.status == 1,
|
||||
Knowledge.permission_id == "Memory",
|
||||
Document.status == 1,
|
||||
)
|
||||
.group_by(Document.file_name)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
base_query = (
|
||||
db.query(
|
||||
EndUserModel,
|
||||
chunk_subquery.c.memory_count.label("memory_count"),
|
||||
)
|
||||
.join(
|
||||
chunk_subquery,
|
||||
chunk_subquery.c.file_name == func.concat(cast(EndUserModel.id, String), ".txt"),
|
||||
)
|
||||
.filter(
|
||||
EndUserModel.workspace_id == workspace_id,
|
||||
chunk_subquery.c.memory_count > 0,
|
||||
)
|
||||
)
|
||||
|
||||
keyword = keyword.strip() if keyword else None
|
||||
if keyword:
|
||||
keyword_pattern = f"%{keyword}%"
|
||||
base_query = base_query.filter(
|
||||
or_(
|
||||
EndUserModel.other_name.ilike(keyword_pattern),
|
||||
cast(EndUserModel.id, String).ilike(keyword_pattern),
|
||||
)
|
||||
)
|
||||
|
||||
total = base_query.count()
|
||||
if total == 0:
|
||||
business_logger.info("RAG 模式下没有符合条件的宿主")
|
||||
return {"items": [], "total": 0}
|
||||
|
||||
rows = base_query.order_by(
|
||||
nullslast(desc(EndUserModel.created_at)),
|
||||
desc(EndUserModel.id),
|
||||
).offset((page - 1) * pagesize).limit(pagesize).all()
|
||||
|
||||
items = []
|
||||
for end_user_orm, memory_count in rows:
|
||||
items.append({
|
||||
"end_user": EndUserSchema.model_validate(end_user_orm),
|
||||
"memory_count": int(memory_count or 0),
|
||||
})
|
||||
|
||||
business_logger.info(f"成功获取 RAG 宿主记录 {len(items)} 条,总计 {total} 条")
|
||||
return {"items": items, "total": total}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
business_logger.error(
|
||||
f"获取 RAG 宿主列表(分页)失败: workspace_id={workspace_id} - {str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
def get_workspace_memory_increment(
|
||||
db: Session,
|
||||
|
||||
47
api/migrations/versions/1f85dce125e5_202604271530.py
Normal file
47
api/migrations/versions/1f85dce125e5_202604271530.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""202604271530
|
||||
|
||||
Revision ID: 1f85dce125e5
|
||||
Revises: 4e89970f9e7c
|
||||
Create Date: 2026-04-27 15:30:35.614679
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '1f85dce125e5'
|
||||
down_revision: Union[str, None] = '4e89970f9e7c'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('files', sa.Column('file_key', sa.String(length=512), nullable=True, comment='storage file key for FileStorageService'))
|
||||
op.create_index(op.f('ix_files_file_key'), 'files', ['file_key'], unique=False)
|
||||
op.alter_column('model_configs', 'capability',
|
||||
existing_type=postgresql.ARRAY(sa.VARCHAR()),
|
||||
comment="模型能力列表(如['vision', 'audio', 'video', 'thinking'])",
|
||||
existing_comment="模型能力列表(如['vision', 'audio', 'video'])",
|
||||
existing_nullable=False)
|
||||
# ### end Alembic commands ###
|
||||
op.execute("""
|
||||
UPDATE files
|
||||
SET file_key = 'kb/' || kb_id::text || '/' || parent_id::text || '/' || id::text || file_ext
|
||||
WHERE file_ext != 'folder' AND file_key IS NULL
|
||||
""")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column('model_configs', 'capability',
|
||||
existing_type=postgresql.ARRAY(sa.VARCHAR()),
|
||||
comment="模型能力列表(如['vision', 'audio', 'video'])",
|
||||
existing_comment="模型能力列表(如['vision', 'audio', 'video', 'thinking'])",
|
||||
existing_nullable=False)
|
||||
op.drop_index(op.f('ix_files_file_key'), table_name='files')
|
||||
op.drop_column('files', 'file_key')
|
||||
# ### end Alembic commands ###
|
||||
34
api/migrations/versions/e2d60c6d1a1a_202604281230.py
Normal file
34
api/migrations/versions/e2d60c6d1a1a_202604281230.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""202604281230
|
||||
|
||||
Revision ID: e2d60c6d1a1a
|
||||
Revises: 1f85dce125e5
|
||||
Create Date: 2026-04-28 12:32:01.643954
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'e2d60c6d1a1a'
|
||||
down_revision: Union[str, None] = '1f85dce125e5'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column('tenants', 'api_ops_rate_limit')
|
||||
op.drop_column('tenants', 'plan')
|
||||
op.drop_column('tenants', 'plan_expired_at')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('tenants', sa.Column('plan_expired_at', postgresql.TIMESTAMP(), autoincrement=False, nullable=True))
|
||||
op.add_column('tenants', sa.Column('plan', sa.VARCHAR(length=50), autoincrement=False, nullable=True))
|
||||
op.add_column('tenants', sa.Column('api_ops_rate_limit', sa.VARCHAR(length=100), autoincrement=False, nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
BIN
web/src/assets/images/index/index_bg.png
Normal file
BIN
web/src/assets/images/index/index_bg.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 108 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 336 KiB |
BIN
web/src/assets/images/login/bg.mp4
Normal file
BIN
web/src/assets/images/login/bg.mp4
Normal file
Binary file not shown.
Binary file not shown.
|
Before Width: | Height: | Size: 387 B |
13
web/src/assets/images/login/check.svg
Normal file
13
web/src/assets/images/login/check.svg
Normal file
@@ -0,0 +1,13 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<svg width="16px" height="16px" viewBox="0 0 16 16" version="1.1" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink">
|
||||
<title>勾选</title>
|
||||
<g id="空间外层页面优化" stroke="none" stroke-width="1" fill="none" fill-rule="evenodd">
|
||||
<g id="登录页面" transform="translate(-64, -611)" fill="#FFFFFF" fill-rule="nonzero">
|
||||
<g id="编组-8" transform="translate(64, 608)">
|
||||
<g id="勾选" transform="translate(0, 3)">
|
||||
<path d="M12,0 C14.209139,0 16,1.790861 16,4 L16,12 C16,14.209139 14.209139,16 12,16 L4,16 C1.790861,16 0,14.209139 0,12 L0,4 C0,1.790861 1.790861,4.4408921e-16 4,0 L12,0 Z M11.9182266,4.80024782 C11.7273831,4.80024782 11.5444062,4.87629473 11.4097812,5.0115625 L6.552,9.86932813 L4.4284375,7.74489063 C4.29381317,7.60962766 4.11083967,7.53358379 3.92,7.53358379 C3.72916033,7.53358379 3.54618683,7.60962766 3.4115625,7.74489063 C3.27602096,7.87955071 3.19979999,8.06271883 3.19979999,8.25378125 C3.19979999,8.44484367 3.27602096,8.62801179 3.4115625,8.76267188 L6.0453125,11.3946719 C6.17993745,11.5299396 6.3629143,11.6059866 6.55375781,11.6059866 C6.74460132,11.6059866 6.92757818,11.5299396 7.06220312,11.3946719 L12.4311094,6.02667188 C12.5659036,5.89187668 12.6412595,5.70881589 12.6404302,5.51818919 C12.639587,5.3275625 12.5626279,5.14516989 12.4266562,5.0115625 C12.2920469,4.87629473 12.1090701,4.80024782 11.9182266,4.80024782 Z" id="形状结合"></path>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.5 KiB |
BIN
web/src/assets/images/login/title_en.png
Normal file
BIN
web/src/assets/images/login/title_en.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 5.3 KiB |
BIN
web/src/assets/images/login/title_zh.png
Normal file
BIN
web/src/assets/images/login/title_zh.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 3.8 KiB |
@@ -467,4 +467,29 @@ input:-webkit-autofill:active {
|
||||
animation-name: onAutoFillStart;
|
||||
animation-duration: 1ms;
|
||||
}
|
||||
@keyframes onAutoFillStart { from {} to {} }
|
||||
@keyframes onAutoFillStart { from {} to {} }
|
||||
/* Login input placeholder */
|
||||
.login-input input::placeholder {
|
||||
color: #A8A9AA !important;
|
||||
}
|
||||
|
||||
.login-input {
|
||||
border-color: #A8A9AA;
|
||||
}
|
||||
|
||||
/* Login input hover/focus border */
|
||||
.login-input:hover,
|
||||
.login-input:focus-within {
|
||||
border-color: #FFFFFF !important;
|
||||
box-shadow: none !important;
|
||||
}
|
||||
|
||||
/* Override browser autofill styles */
|
||||
.login-input input:-webkit-autofill,
|
||||
.login-input input:-webkit-autofill:hover,
|
||||
.login-input input:-webkit-autofill:focus,
|
||||
.login-input input:-webkit-autofill:active {
|
||||
-webkit-box-shadow: 0 0 0px 1000px #0A0A0A inset !important;
|
||||
-webkit-text-fill-color: #FFFFFF !important;
|
||||
transition: background-color 5000s ease-in-out 0s !important;
|
||||
}
|
||||
@@ -102,7 +102,7 @@ const Index = () => {
|
||||
<Flex gap={12} wrap="nowrap" className="rb:w-full! rb:h-full! rb:overflow-y-auto">
|
||||
<div className="rb:flex-1 rb:min-w-0">
|
||||
<Flex vertical>
|
||||
<div className='rb:w-full rb:h-26 rb:p-4 rb:bg-cover rb:bg-[url("@/assets/images/index/index_bg@2x.png")] rb:rounded-xl rb:overflow-hidden'>
|
||||
<div className='rb:w-full rb:h-26 rb:p-4 rb:bg-cover rb:bg-[url("@/assets/images/index/index_bg.png")] rb:rounded-xl rb:overflow-hidden'>
|
||||
<div className="rb:font-[MiSans-Bold] rb:font-bold rb:text-white rb:text-[18px] rb:leading-7">
|
||||
{t('index.spaceTitle')}
|
||||
</div>
|
||||
|
||||
@@ -14,27 +14,33 @@ import React, { useState, useEffect } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { Button, Input, Form, App } from 'antd';
|
||||
import type { FormProps } from 'antd';
|
||||
import clsx from 'clsx';
|
||||
|
||||
import { useUser, type LoginInfo } from '@/store/user';
|
||||
import { login } from '@/api/user'
|
||||
import loginBg from '@/assets/images/login/loginBg.png'
|
||||
import check from '@/assets/images/login/check.png'
|
||||
import loginBg from '@/assets/images/login/bg.mp4'
|
||||
import check from '@/assets/images/login/check.svg'
|
||||
import email from '@/assets/images/login/email.svg'
|
||||
import lock from '@/assets/images/login/lock.svg'
|
||||
import type { LoginForm } from './types';
|
||||
import { useI18n } from '@/store/locale'
|
||||
|
||||
/**
|
||||
* Input field styling
|
||||
*/
|
||||
const inputClassName = "rb:rounded-[8px]! rb:p-[12px]! rb:h-[44px]!"
|
||||
const inputClassName = "login-input rb:rounded-[8px]! rb:p-[12px]! rb:h-[44px]! rb:bg-transparent! rb:text-[#FFFFFF]! [&_input]:rb:text-[#FFFFFF]! [&_input]:rb:caret-[#FFFFFF]!"
|
||||
|
||||
/**
|
||||
* Login page component
|
||||
*/const LoginPage: React.FC = () => {
|
||||
const { t } = useTranslation();
|
||||
const { clearUserInfo, updateLoginInfo, getUserInfo } = useUser();
|
||||
const { language } = useI18n()
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [form] = Form.useForm<LoginForm>();
|
||||
const emailVal = Form.useWatch('email', form);
|
||||
const passwordVal = Form.useWatch('password', form);
|
||||
const canLogin = !!(emailVal && passwordVal);
|
||||
const { message } = App.useApp();
|
||||
|
||||
useEffect(() => {
|
||||
@@ -43,6 +49,7 @@ const inputClassName = "rb:rounded-[8px]! rb:p-[12px]! rb:h-[44px]!"
|
||||
|
||||
/** Handle login form submission */
|
||||
const handleLogin: FormProps<LoginForm>['onFinish'] = async (values) => {
|
||||
if (!canLogin) return;
|
||||
if (!values.email) {
|
||||
message.warning(t('login.emailPlaceholder'));
|
||||
return;
|
||||
@@ -64,42 +71,45 @@ const inputClassName = "rb:rounded-[8px]! rb:p-[12px]! rb:h-[44px]!"
|
||||
|
||||
|
||||
return (
|
||||
<div className="rb:min-h-screen rb:flex rb:h-screen">
|
||||
<div className="rb:min-h-screen rb:flex rb:h-screen rb:bg-[#0A0A0A] rb:text-[#FFFFFF]">
|
||||
<div className="rb:relative rb:w-1/2 rb:h-screen rb:overflow-hidden">
|
||||
<img src={loginBg} alt="loginBg" className="rb:w-full rb:h-full rb:object-cover rb:absolute rb:top-1/2 rb:-translate-y-1/2 rb:left-0" />
|
||||
<div className="rb:absolute rb:top-14 rb:left-16">
|
||||
<div className="rb:text-[28px] rb:leading-8.25 rb:font-bold rb:font-[AlimamaShuHeiTi,AlimamaShuHeiTi] rb:mb-4">{t('login.title')}</div>
|
||||
<div className="rb:text-[18px] rb:leading-6.25 rb:font-regular">{t('login.subTitle')}</div>
|
||||
<video src={loginBg} loop autoPlay playsInline muted className="rb:w-full rb:h-full rb:object-cover"></video>
|
||||
<div className="rb:absolute rb:top-10 rb:left-12">
|
||||
<div className={clsx("rb:h-8.25 rb:bg-cover", {
|
||||
"rb:w-89 rb:bg-[url('@/assets/images/login/title_en.png')]": language !== 'zh',
|
||||
"rb:w-42 rb:bg-[url('@/assets/images/login/title_zh.png')]": language === 'zh'
|
||||
})}></div>
|
||||
<div className="rb:text-[18px] rb:text-[rgba(255,255,255,0.7)] rb:leading-6.25 rb:font-regular rb:mt-3">{t('login.subTitle')}</div>
|
||||
</div>
|
||||
|
||||
<div className="rb:absolute rb:bottom-20.25 rb:left-16 rb:grid rb:grid-cols-2 rb:gap-x-30 rb:gap-y-10.75">
|
||||
{['intelligentMemory', 'instantRecall', 'knowledgeAssociation'].map(key => (
|
||||
<div key={key} className="rb:flex">
|
||||
<div className="rb:absolute rb:bottom-14 rb:left-12 rb:right-12 rb:grid rb:grid-cols-2 rb:gap-x-30 rb:gap-y-10.75">
|
||||
{['intelligentMemory', 'instantRecall', 'knowledgeAssociation'].map((key, index) => (
|
||||
<div key={key} className={`rb:flex${index === 0 ? ' rb:col-span-2' : ''}`}>
|
||||
<img src={check} className="rb:w-4 rb:h-4 rb:mr-2 rb:mt-0.75" />
|
||||
<div className="rb:text-[16px] rb:leading-5.5">
|
||||
<div className="rb:font-medium">{t(`login.${key}`)}</div>
|
||||
<div className="rb:text-[#5B6167] rb:text-[14px] rb:leading-5 rb:font-regular! rb:mt-2">{t(`login.${key}Desc`)}</div>
|
||||
<div className="rb:text-[14px] rb:text-[rgba(255,255,255,0.7)] rb:leading-5 rb:font-regular! rb:mt-2">{t(`login.${key}Desc`)}</div>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="rb:bg-[#FFFFFF] rb:flex rb:items-center rb:justify-center rb:flex-[1_1_auto]">
|
||||
<div className="rb:w-100 rb:mx-auto">
|
||||
<div className="rb:text-center rb:text-[28px] rb:font-semibold rb:leading-8 rb:mb-12">{t('login.welcome')}</div>
|
||||
<div className="rb:flex rb:items-center rb:justify-center rb:flex-[1_1_auto]">
|
||||
<div className="rb:w-110 rb:mx-auto">
|
||||
<div className="rb:text-center rb:text-[24px] rb:font-[MiSans-Bold] rb:font-bold rb:leading-8 rb:mb-12">{t('login.welcome')}</div>
|
||||
<Form
|
||||
form={form}
|
||||
onFinish={handleLogin}
|
||||
>
|
||||
<Form.Item name="email" className="rb:mb-5!">
|
||||
<Form.Item name="email" className="rb:mb-6!">
|
||||
<Input
|
||||
prefix={<img src={email} className="rb:w-5 rb:h-5 rb:mr-2" />}
|
||||
placeholder={t('login.emailPlaceholder')}
|
||||
className={inputClassName}
|
||||
/>
|
||||
</Form.Item>
|
||||
<Form.Item name="password">
|
||||
<Form.Item name="password" className="rb:mb-0!">
|
||||
<Input.Password
|
||||
prefix={<img src={lock} className="rb:w-5 rb:h-5 rb:mr-2" />}
|
||||
placeholder={t('login.passwordPlaceholder')}
|
||||
@@ -111,7 +121,11 @@ const inputClassName = "rb:rounded-[8px]! rb:p-[12px]! rb:h-[44px]!"
|
||||
block
|
||||
loading={loading}
|
||||
htmlType="submit"
|
||||
className="rb:h-10! rb:rounded-lg! rb:mt-4"
|
||||
disabled={!canLogin}
|
||||
className={clsx("rb:h-11.5! rb:rounded-lg! rb:mt-12", {
|
||||
'rb:hover:bg-[#2d6ef1]! rb:bg-[#155EEF]! rb:border-[#155EEF]!': canLogin,
|
||||
'rb:bg-[#171719]! rb:border-[#171719]!': !canLogin
|
||||
})}
|
||||
>
|
||||
{t('login.loginIn')}
|
||||
</Button>
|
||||
|
||||
@@ -361,7 +361,7 @@ const Market: React.FC<{ getStatusTag?: (status: string) => ReactNode }> = () =>
|
||||
)}
|
||||
</Flex>
|
||||
<div>
|
||||
<div className="rb:font-[MiSans Bold] rb:font-bold rb:text-[16px] rb:leading-5.5">{source.name}</div>
|
||||
<div className="rb:font-[MiSans-Bold] rb:font-bold rb:text-[16px] rb:leading-5.5">{source.name}</div>
|
||||
<div className="rb:text-[#5B6167] rb:text-[12px] rb:leading-4.5">{t('tool.availableMcp')} ({mcpTotal})</div>
|
||||
</div>
|
||||
</Flex>
|
||||
|
||||
Reference in New Issue
Block a user