From a5bce221bd0ea74aa6c1200f9280f319c7d2c2c7 Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Thu, 26 Mar 2026 20:12:11 +0800 Subject: [PATCH 001/117] refactor(memory-api): migrate end user creation to authenticated API endpoint - Remove unauthenticated end_user_controller and its router registration - Move end user creation logic to authenticated memory_api_controller endpoint - Add create_end_user method to MemoryAPIService with workspace authorization - Fix retrieve_nodes import in read_graph to use correct function reference - Consolidate end user management under authenticated memory API with API key scoping --- api/app/controllers/__init__.py | 2 - api/app/controllers/end_user_controller.py | 48 ------------------- .../service/memory_api_controller.py | 30 ++++++++++++ .../agent/langgraph_graph/read_graph.py | 6 +-- api/app/services/memory_api_service.py | 47 ++++++++++++++++++ 5 files changed, 80 insertions(+), 53 deletions(-) delete mode 100644 api/app/controllers/end_user_controller.py diff --git a/api/app/controllers/__init__.py b/api/app/controllers/__init__.py index 869eb039..50e9e0b0 100644 --- a/api/app/controllers/__init__.py +++ b/api/app/controllers/__init__.py @@ -14,7 +14,6 @@ from . import ( document_controller, emotion_config_controller, emotion_controller, - end_user_controller, file_controller, file_storage_controller, home_page_controller, @@ -99,6 +98,5 @@ manager_router.include_router(file_storage_controller.router) manager_router.include_router(ontology_controller.router) manager_router.include_router(skill_controller.router) manager_router.include_router(i18n_controller.router) -manager_router.include_router(end_user_controller.router) __all__ = ["manager_router"] diff --git a/api/app/controllers/end_user_controller.py b/api/app/controllers/end_user_controller.py deleted file mode 100644 index b9d54fea..00000000 --- a/api/app/controllers/end_user_controller.py +++ /dev/null @@ -1,48 +0,0 @@ -"""End User 管理接口 - 无需认证""" - -from app.core.logging_config import get_business_logger -from app.core.response_utils import success -from app.db import get_db -from app.repositories.end_user_repository import EndUserRepository -from app.schemas.memory_api_schema import ( - CreateEndUserRequest, - CreateEndUserResponse, -) -from fastapi import APIRouter, Depends -from sqlalchemy.orm import Session - -router = APIRouter(prefix="/end_users", tags=["End Users"]) -logger = get_business_logger() - - -@router.post("") -async def create_end_user( - data: CreateEndUserRequest, - db: Session = Depends(get_db), -): - """ - Create an end user. - - Creates a new end user for the given workspace. - If an end user with the same other_id already exists in the workspace, - returns the existing one. - """ - logger.info(f"Create end user request - other_id: {data.other_id}, workspace_id: {data.workspace_id}") - - end_user_repo = EndUserRepository(db) - end_user = end_user_repo.get_or_create_end_user( - app_id=None, - workspace_id=data.workspace_id, - other_id=data.other_id, - ) - - logger.info(f"End user ready: {end_user.id}") - - result = { - "id": str(end_user.id), - "other_id": end_user.other_id or "", - "other_name": end_user.other_name or "", - "workspace_id": str(end_user.workspace_id), - } - - return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully") diff --git a/api/app/controllers/service/memory_api_controller.py b/api/app/controllers/service/memory_api_controller.py index 08a94a89..dc5e0408 100644 --- a/api/app/controllers/service/memory_api_controller.py +++ b/api/app/controllers/service/memory_api_controller.py @@ -6,6 +6,8 @@ from app.core.response_utils import success from app.db import get_db from app.schemas.api_key_schema import ApiKeyAuth from app.schemas.memory_api_schema import ( + CreateEndUserRequest, + CreateEndUserResponse, ListConfigsResponse, MemoryReadRequest, MemoryReadResponse, @@ -113,3 +115,31 @@ async def list_memory_configs( logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}") return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully") + + +@router.post("/end_users") +@require_api_key(scopes=["memory"]) +async def create_end_user( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), +): + """ + Create an end user. + + Creates a new end user for the authorized workspace. + If an end user with the same other_id already exists, returns the existing one. + """ + body = await request.json() + payload = CreateEndUserRequest(**body) + logger.info(f"Create end user request - other_id: {payload.other_id}, workspace_id: {api_key_auth.workspace_id}") + + memory_api_service = MemoryAPIService(db) + + result = memory_api_service.create_end_user( + workspace_id=api_key_auth.workspace_id, + other_id=payload.other_id, + ) + + logger.info(f"End user ready: {result['id']}") + return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully") diff --git a/api/app/core/memory/agent/langgraph_graph/read_graph.py b/api/app/core/memory/agent/langgraph_graph/read_graph.py index bddae618..e698e6ad 100644 --- a/api/app/core/memory/agent/langgraph_graph/read_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/read_graph.py @@ -15,7 +15,7 @@ from app.core.memory.agent.langgraph_graph.nodes.problem_nodes import ( Problem_Extension, ) from app.core.memory.agent.langgraph_graph.nodes.retrieve_nodes import ( - retrieve, + retrieve_nodes, ) from app.core.memory.agent.langgraph_graph.nodes.summary_nodes import ( Input_Summary, @@ -53,8 +53,8 @@ async def make_read_graph(): workflow.add_node("Split_The_Problem", Split_The_Problem) workflow.add_node("Problem_Extension", Problem_Extension) workflow.add_node("Input_Summary", Input_Summary) - # workflow.add_node("Retrieve", retrieve_nodes) - workflow.add_node("Retrieve", retrieve) + workflow.add_node("Retrieve", retrieve_nodes) + # workflow.add_node("Retrieve", retrieve) workflow.add_node("Verify", Verify) workflow.add_node("Retrieve_Summary", Retrieve_Summary) workflow.add_node("Summary", Summary) diff --git a/api/app/services/memory_api_service.py b/api/app/services/memory_api_service.py index 9282fc28..f62f526c 100644 --- a/api/app/services/memory_api_service.py +++ b/api/app/services/memory_api_service.py @@ -280,6 +280,53 @@ class MemoryAPIService: code=BizCode.MEMORY_READ_FAILED ) + def create_end_user( + self, + workspace_id: uuid.UUID, + other_id: str, + ) -> Dict[str, Any]: + """Create or retrieve an end user for the workspace. + + Uses get_or_create semantics: if an end user with the same other_id + already exists in the workspace, returns the existing one. + + Args: + workspace_id: Workspace ID from API key authorization + other_id: External user identifier + + Returns: + Dict with id, other_id, other_name, and workspace_id + + Raises: + BusinessException: If creation fails + """ + logger.info(f"Creating end user - other_id: {other_id}, workspace_id: {workspace_id}") + + try: + from app.repositories.end_user_repository import EndUserRepository + + end_user_repo = EndUserRepository(self.db) + end_user = end_user_repo.get_or_create_end_user( + app_id=None, + workspace_id=workspace_id, + other_id=other_id, + ) + + logger.info(f"End user ready: {end_user.id}") + return { + "id": str(end_user.id), + "other_id": end_user.other_id or "", + "other_name": end_user.other_name or "", + "workspace_id": str(end_user.workspace_id), + } + + except Exception as e: + logger.error(f"Failed to create end user for workspace {workspace_id}: {e}") + raise BusinessException( + message=f"Failed to create end user: {str(e)}", + code=BizCode.INTERNAL_ERROR + ) + def list_memory_configs( self, workspace_id: uuid.UUID, From 4534b65d6a67392c1a29ce0aac0452ec89e9fd8d Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Fri, 27 Mar 2026 11:56:22 +0800 Subject: [PATCH 002/117] refactor(workflow): optimize workflow history queries and migrate ORM to SQLAlchemy 2.0 - Migrate historical workflow queries from legacy ORM Query API to SQLAlchemy 2.0 select() + execute() - Limit query fields and use pagination to reduce returned data, improving performance - Preserve original ordering and filtering logic --- api/app/repositories/workflow_repository.py | 35 ++++++++------ api/app/services/workflow_service.py | 51 +++++++++++---------- 2 files changed, 50 insertions(+), 36 deletions(-) diff --git a/api/app/repositories/workflow_repository.py b/api/app/repositories/workflow_repository.py index 4e24faa0..a783fe3f 100644 --- a/api/app/repositories/workflow_repository.py +++ b/api/app/repositories/workflow_repository.py @@ -3,9 +3,9 @@ """ import uuid -from typing import Any, Annotated +from typing import Any, Annotated, Literal from sqlalchemy.orm import Session -from sqlalchemy import desc +from sqlalchemy import desc, select from fastapi import Depends from app.models.workflow_model import ( @@ -128,29 +128,36 @@ class WorkflowExecutionRepository: Returns: 执行记录列表 """ - return self.db.query(WorkflowExecution).filter( + stmt = select(WorkflowExecution).filter( WorkflowExecution.app_id == app_id ).order_by( desc(WorkflowExecution.started_at) - ).limit(limit).offset(offset).all() + ).limit(limit).offset(offset) + return list(self.db.execute(stmt).scalars()) def get_by_conversation_id( self, - conversation_id: uuid.UUID + conversation_id: uuid.UUID, + status: Literal["running", "completed", "failed"] = None, + limit_count: int = 50 ) -> list[WorkflowExecution]: """根据会话 ID 获取执行记录列表 Args: + limit_count: conversation_id: 会话 ID + status: 状态(可选) Returns: 执行记录列表 """ - return self.db.query(WorkflowExecution).filter( + stmt = select(WorkflowExecution).filter( WorkflowExecution.conversation_id == conversation_id - ).order_by( - desc(WorkflowExecution.started_at) - ).all() + ) + if status: + stmt = stmt.filter(WorkflowExecution.status == status) + stmt = stmt.order_by(desc(WorkflowExecution.started_at)).limit(limit_count) + return list(self.db.execute(stmt).scalars()) def count_by_app_id(self, app_id: uuid.UUID) -> int: """统计应用的执行次数 @@ -199,11 +206,12 @@ class WorkflowNodeExecutionRepository: Returns: 节点执行记录列表(按执行顺序排序) """ - return self.db.query(WorkflowNodeExecution).filter( + stmt = select(WorkflowNodeExecution).filter( WorkflowNodeExecution.execution_id == execution_id ).order_by( WorkflowNodeExecution.execution_order - ).all() + ) + return list(self.db.execute(stmt).scalars()) def get_by_node_id( self, @@ -219,12 +227,13 @@ class WorkflowNodeExecutionRepository: Returns: 节点执行记录列表 """ - return self.db.query(WorkflowNodeExecution).filter( + stmt = select(WorkflowNodeExecution).filter( WorkflowNodeExecution.execution_id == execution_id, WorkflowNodeExecution.node_id == node_id ).order_by( WorkflowNodeExecution.retry_count - ).all() + ) + return list(self.db.execute(stmt).scalars()) # ==================== 依赖注入函数 ==================== diff --git a/api/app/services/workflow_service.py b/api/app/services/workflow_service.py index c7d7f2b1..13267078 100644 --- a/api/app/services/workflow_service.py +++ b/api/app/services/workflow_service.py @@ -561,6 +561,24 @@ class WorkflowService: storage_type = 'neo4j' return storage_type, user_rag_memory_id + def _get_history_info(self, conversation_id: uuid.UUID) -> tuple[dict, list] | None: + executions = self.execution_repo.get_by_conversation_id( + conversation_id=conversation_id, + status="completed", + limit_count=1 + ) + + if executions: + last_state = executions[0].output_data + if isinstance(last_state, dict): + variables = last_state.get("variables", {}) + conv_vars = variables.get("conv", {}) + # input_data["conv"] = conv_vars + # input_data["conv_messages"] = last_state.get("messages") or [] + conv_messages = last_state.get("messages") or [] + return conv_vars, conv_messages + return None + # ==================== 工作流执行 ==================== async def run( @@ -634,18 +652,11 @@ class WorkflowService: # 更新状态为运行中 self.update_execution_status(execution.execution_id, "running") - executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) - - for exec_res in executions: - if exec_res.status == "completed": - last_state = exec_res.output_data - if isinstance(last_state, dict): - variables = last_state.get("variables", {}) - conv_vars = variables.get("conv", {}) - input_data["conv"] = conv_vars - input_data["conv_messages"] = last_state.get("messages") or [] - break - + history = self._get_history_info(conversation_id_uuid) + if history: + conv_vars, conv_messages = history + input_data["conv"] = conv_vars + input_data["conv_messages"] = conv_messages init_message_length = len(input_data.get("conv_messages", [])) result = await execute_workflow( @@ -807,17 +818,11 @@ class WorkflowService: storage_type, user_rag_memory_id = self._get_memory_store_info(workspace_id) input_data["files"] = files self.update_execution_status(execution.execution_id, "running") - executions = self.execution_repo.get_by_conversation_id(conversation_id=conversation_id_uuid) - - for exec_res in executions: - if exec_res.status == "completed": - last_state = exec_res.output_data - if isinstance(last_state, dict): - variables = last_state.get("variables", {}) - conv_vars = variables.get("conv", {}) - input_data["conv"] = conv_vars - input_data["conv_messages"] = last_state.get("messages") or [] - break + history = self._get_history_info(conversation_id_uuid) + if history: + conv_vars, conv_messages = history + input_data["conv"] = conv_vars + input_data["conv_messages"] = conv_messages init_message_length = len(input_data.get("conv_messages", [])) message_id = uuid.uuid4() async for event in execute_workflow_stream( From 7fd00009a21ed3334f6c72b9068fd3a018ffd637 Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Fri, 27 Mar 2026 12:00:30 +0800 Subject: [PATCH 003/117] perf(workflow): introduce LazyDict to reduce variable serialization, optimize regex to reduce compilation - Use LazyDict for deferred serialization, improving performance - Reuse regex patterns to avoid repeated compilation --- api/app/core/workflow/engine/state_manager.py | 8 +- api/app/core/workflow/engine/variable_pool.py | 60 ++++++- api/app/core/workflow/nodes/base_node.py | 12 +- .../core/workflow/nodes/cycle_graph/loop.py | 15 +- .../workflow/utils/expression_evaluator.py | 78 ++++----- .../core/workflow/utils/template_renderer.py | 164 +++++++++--------- 6 files changed, 188 insertions(+), 149 deletions(-) diff --git a/api/app/core/workflow/engine/state_manager.py b/api/app/core/workflow/engine/state_manager.py index 2da0d3a8..eed44278 100644 --- a/api/app/core/workflow/engine/state_manager.py +++ b/api/app/core/workflow/engine/state_manager.py @@ -9,10 +9,10 @@ from app.core.workflow.nodes.enums import NodeType def merge_activate_state(x, y): - return { - k: x.get(k, False) or y.get(k, False) - for k in set(x) | set(y) - } + merged = dict(x) + for k, v in y.items(): + merged[k] = merged.get(k, False) or v + return merged def merge_looping_state(x, y): diff --git a/api/app/core/workflow/engine/variable_pool.py b/api/app/core/workflow/engine/variable_pool.py index 60f1257e..7faca82d 100644 --- a/api/app/core/workflow/engine/variable_pool.py +++ b/api/app/core/workflow/engine/variable_pool.py @@ -17,6 +17,51 @@ from app.core.workflow.variable.variable_objects import T, create_variable_insta logger = logging.getLogger(__name__) +VARIABLE_PATTERN = re.compile(r"\{\{\s*(.*?)\s*}}") + + +class LazyVariableDict: + def __init__(self, source, literal): + self._source: dict[str, VariableStruct[Any]] = source + self._literal: bool = literal + self._cache = {} + + def keys(self): + return self._source.keys() + + def _resolve(self, key): + if key in self._cache: + return self._cache[key] + var_struct = self._source.get(key) + if var_struct is None: + raise KeyError(key) + value = var_struct.instance.to_literal() if self._literal else var_struct.instance.get_value() + self._cache[key] = value + return value + + def get(self, key, default=None): + try: + return self._resolve(key) + except KeyError: + return default + + def __getitem__(self, key): + return self._resolve(key) + + def __getattr__(self, key): + if key.startswith('_'): + raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{key}'") + return self._resolve(key) + + def __contains__(self, key): + return key in self._source + + def __iter__(self): + return iter(self._source) + + def __len__(self): + return len(self._source) + class VariableSelector: """变量选择器 @@ -117,8 +162,7 @@ class VariablePool: @staticmethod def transform_selector(selector): - pattern = r"\{\{\s*(.*?)\s*\}\}" - variable_literal = re.sub(pattern, r"\1", selector).strip() + variable_literal = VARIABLE_PATTERN.sub(r"\1", selector).strip() selector = VariableSelector.from_string(variable_literal).path if len(selector) != 2: raise ValueError(f"Selector not valid - {selector}") @@ -303,6 +347,16 @@ class VariablePool: """ return self._get_variable_struct(selector) is not None + def lazy_namespace(self, namespace: str, literal: bool = False) -> LazyVariableDict: + return LazyVariableDict(self.variables.get(namespace, {}), literal) + + def lazy_all_node_outputs(self, literal: bool = False) -> dict[str, LazyVariableDict]: + return { + ns: LazyVariableDict(vars_dict, literal) + for ns, vars_dict in self.variables.items() + if ns not in ("sys", "conv") + } + def get_all_system_vars(self, literal=False) -> dict[str, Any]: """获取所有系统变量 @@ -479,5 +533,3 @@ class VariablePoolInitializer: var_type=var_type, mut=False ) - - diff --git a/api/app/core/workflow/nodes/base_node.py b/api/app/core/workflow/nodes/base_node.py index 8567ebbe..bedf6165 100644 --- a/api/app/core/workflow/nodes/base_node.py +++ b/api/app/core/workflow/nodes/base_node.py @@ -552,9 +552,9 @@ class BaseNode(ABC): return render_template( template=template, - conv_vars=variable_pool.get_all_conversation_vars(literal=True), - node_outputs=variable_pool.get_all_node_outputs(literal=True), - system_vars=variable_pool.get_all_system_vars(literal=True), + conv_vars=variable_pool.lazy_namespace("conv", literal=True), + node_outputs=variable_pool.lazy_all_node_outputs(literal=True), + system_vars=variable_pool.lazy_namespace("sys", literal=True), strict=strict ) @@ -579,9 +579,9 @@ class BaseNode(ABC): return evaluate_condition( expression=expression, - conv_var=variable_pool.get_all_conversation_vars(), - node_outputs=variable_pool.get_all_node_outputs(), - system_vars=variable_pool.get_all_system_vars() + conv_var=variable_pool.lazy_namespace("conv"), + node_outputs=variable_pool.lazy_all_node_outputs(), + system_vars=variable_pool.lazy_namespace("sys") ) @staticmethod diff --git a/api/app/core/workflow/nodes/cycle_graph/loop.py b/api/app/core/workflow/nodes/cycle_graph/loop.py index 84901bad..e555a228 100644 --- a/api/app/core/workflow/nodes/cycle_graph/loop.py +++ b/api/app/core/workflow/nodes/cycle_graph/loop.py @@ -11,7 +11,6 @@ from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.cycle_graph import LoopNodeConfig from app.core.workflow.nodes.enums import ValueInputType, ComparisonOperator, LogicOperator, NodeType from app.core.workflow.nodes.operators import TypeTransformer, ConditionExpressionResolver, CompareOperatorInstance -from app.core.workflow.utils.expression_evaluator import evaluate_expression logger = logging.getLogger(__name__) @@ -85,12 +84,7 @@ class LoopRuntime: for variable in self.typed_config.cycle_vars: if variable.input_type == ValueInputType.VARIABLE: - value = evaluate_expression( - expression=variable.value, - conv_var=self.variable_pool.get_all_conversation_vars(), - node_outputs=self.variable_pool.get_all_node_outputs(), - system_vars=self.variable_pool.get_all_system_vars(), - ) + value = self.variable_pool.get_value(variable.value) else: value = TypeTransformer.transform(variable.value, variable.type) await self.child_variable_pool.new(self.node_id, variable.name, value, variable.type, mut=True) @@ -98,12 +92,7 @@ class LoopRuntime: **self.state ) loopstate["node_outputs"][self.node_id] = { - variable.name: evaluate_expression( - expression=variable.value, - conv_var=self.variable_pool.get_all_conversation_vars(), - node_outputs=self.variable_pool.get_all_node_outputs(), - system_vars=self.variable_pool.get_all_system_vars(), - ) + variable.name: self.variable_pool.get_value(variable.value) if variable.input_type == ValueInputType.VARIABLE else TypeTransformer.transform(variable.value, variable.type) for variable in self.typed_config.cycle_vars diff --git a/api/app/core/workflow/utils/expression_evaluator.py b/api/app/core/workflow/utils/expression_evaluator.py index 4bc5fc4c..05a3294b 100644 --- a/api/app/core/workflow/utils/expression_evaluator.py +++ b/api/app/core/workflow/utils/expression_evaluator.py @@ -4,32 +4,33 @@ from typing import Any from simpleeval import simple_eval, NameNotDefined, InvalidExpression +from app.core.workflow.engine.variable_pool import LazyVariableDict, VARIABLE_PATTERN + logger = logging.getLogger(__name__) +_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}") + class ExpressionEvaluator: """Safe expression evaluator for workflow variables and node outputs.""" - + # Reserved namespaces RESERVED_NAMESPACES = {"var", "node", "sys", "nodes"} @classmethod def normalize_template(cls, template: str) -> str: - pattern = re.compile( - r"\{\{\s*(\d+)\.(\w+)\s*}}" - ) - return pattern.sub( + return _NORMALIZE_PATTERN.sub( r'{{ node["\1"].\2 }}', template ) @classmethod def evaluate( - cls, - expression: str, - conv_vars: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None + cls, + expression: str, + conv_vars: dict[str, Any], + node_outputs: dict[str, Any], + system_vars: dict[str, Any] | None = None ) -> Any: """ Safely evaluate an expression using workflow variables. @@ -49,48 +50,47 @@ class ExpressionEvaluator: # Remove Jinja2-style brackets if present expression = expression.strip() expression = cls.normalize_template(expression) - pattern = r"\{\{\s*(.*?)\s*\}\}" - expression = re.sub(pattern, r"\1", expression).strip() + expression = VARIABLE_PATTERN.sub(r"\1", expression).strip() # Build context for evaluation context = { - "conv": conv_vars, # conversation variables - "node": node_outputs, # node outputs - "sys": system_vars or {}, # system variables + "conv": conv_vars, # conversation variables + "node": node_outputs, # node outputs + "sys": system_vars or {}, # system variables } - context.update(conv_vars) - context["nodes"] = node_outputs + # context.update(conv_vars) + # context["nodes"] = node_outputs context.update(node_outputs) - + try: # simpleeval supports safe operations: # arithmetic, comparisons, logical ops, attribute/dict/list access result = simple_eval(expression, names=context) return result - + except NameNotDefined as e: logger.error(f"Undefined variable in expression: {expression}, error: {e}") raise ValueError(f"Undefined variable: {e}") - + except InvalidExpression as e: logger.error(f"Invalid expression syntax: {expression}, error: {e}") raise ValueError(f"Invalid expression syntax: {e}") - + except SyntaxError as e: logger.error(f"Syntax error in expression: {expression}, error: {e}") raise ValueError(f"Syntax error: {e}") - + except Exception as e: logger.error(f"Expression evaluation failed: {expression}, error: {e}") raise ValueError(f"Expression evaluation failed: {e}") - + @staticmethod def evaluate_bool( - expression: str, - conv_var: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None + expression: str, + conv_var: dict[str, Any], + node_outputs: dict[str, Any], + system_vars: dict[str, Any] | None = None ) -> bool: """ Evaluate a boolean expression (for conditions). @@ -108,7 +108,7 @@ class ExpressionEvaluator: expression, conv_var, node_outputs, system_vars ) return bool(result) - + @staticmethod def validate_variable_names(variables: list[dict]) -> list[str]: """ @@ -121,7 +121,7 @@ class ExpressionEvaluator: list[str]: List of error messages. Empty if all names are valid. """ errors = [] - + for var in variables: var_name = var.get("name", "") @@ -134,16 +134,16 @@ class ExpressionEvaluator: errors.append( f"Variable name '{var_name}' is not a valid Python identifier" ) - + return errors # 便捷函数 def evaluate_expression( - expression: str, - conv_var: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] + expression: str, + conv_var: dict[str, Any] | LazyVariableDict, + node_outputs: dict[str, dict[str, Any] | LazyVariableDict], + system_vars: dict[str, Any] | LazyVariableDict ) -> Any: """Evaluate an expression (convenience function).""" return ExpressionEvaluator.evaluate( @@ -152,11 +152,11 @@ def evaluate_expression( def evaluate_condition( - expression: str, - conv_var: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None -) -> bool: + expression: str, + conv_var: dict[str, Any] | LazyVariableDict, + node_outputs: dict[str, dict[str, Any] | LazyVariableDict], + system_vars: dict[str, Any] | LazyVariableDict +) -> Any: """Evaluate a boolean condition expression (convenience function).""" return ExpressionEvaluator.evaluate_bool( expression, conv_var, node_outputs, system_vars diff --git a/api/app/core/workflow/utils/template_renderer.py b/api/app/core/workflow/utils/template_renderer.py index 6a73efc4..bb1e18bf 100644 --- a/api/app/core/workflow/utils/template_renderer.py +++ b/api/app/core/workflow/utils/template_renderer.py @@ -1,7 +1,8 @@ """ -模板渲染器 +Template Renderer -使用 Jinja2 提供安全的模板渲染功能,支持变量引用和表达式。 +Provides safe template rendering using Jinja2, supporting variable references +and expressions. """ import logging @@ -10,11 +11,15 @@ from typing import Any from jinja2 import TemplateSyntaxError, UndefinedError, Environment, StrictUndefined, Undefined +from app.core.workflow.engine.variable_pool import LazyVariableDict + logger = logging.getLogger(__name__) +_NORMALIZE_PATTERN = re.compile(r"\{\{\s*(\d+)\.(\w+)\s*}}") + class SafeUndefined(Undefined): - """访问未定义属性不会报错,返回空字符串""" + """Return empty string instead of raising error when accessing undefined variables""" __slots__ = () def _fail_with_undefined_error(self, *args, **kwargs): @@ -26,26 +31,22 @@ class SafeUndefined(Undefined): class TemplateRenderer: - """模板渲染器""" - def __init__(self, strict: bool = True): - """初始化渲染器 - + """Initialize renderer + Args: - strict: 是否使用严格模式(未定义变量会抛出异常) + strict: Whether to enable strict mode (raise error on undefined variables) """ self.strict = strict self.env = Environment( undefined=StrictUndefined if strict else SafeUndefined, - autoescape=False # 不自动转义,因为我们处理的是文本而非 HTML + autoescape=False # Disable auto-escaping since we handle plain text instead of HTML ) @staticmethod def normalize_template(template: str) -> str: - pattern = re.compile( - r"\{\{\s*(\d+)\.(\w+)\s*}}" - ) - return pattern.sub( + """Normalize template syntax (convert numeric node reference to dict access)""" + return _NORMALIZE_PATTERN.sub( r'{{ node["\1"].\2 }}', template ) @@ -53,24 +54,24 @@ class TemplateRenderer: def render( self, template: str, - conv_vars: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any] | None = None + conv_vars: dict[str, Any] | LazyVariableDict, + node_outputs: dict[str, Any] | dict[str, LazyVariableDict], + system_vars: dict[str, Any] | LazyVariableDict | None = None ) -> str: - """渲染模板 - + """Render template + Args: - template: 模板字符串 - conv_vars: 会话变量 - node_outputs: 节点输出结果 - system_vars: 系统变量 - + template: Template string + conv_vars: Conversation variables + node_outputs: Node outputs + system_vars: System variables + Returns: - 渲染后的字符串 - + Rendered string + Raises: - ValueError: 模板语法错误或变量未定义 - + ValueError: If template syntax is invalid or variables are undefined + Examples: >>> renderer = TemplateRenderer() >>> renderer.render( @@ -80,122 +81,119 @@ class TemplateRenderer: ... {} ... ) 'Hello World!' - + >>> renderer.render( - ... "分析结果: {{node.analyze.output}}", + ... "Analysis result: {{node.analyze.output}}", ... {}, - ... {"analyze": {"output": "正面情绪"}}, + ... {"analyze": {"output": "positive sentiment"}}, ... {} ... ) - '分析结果: 正面情绪' + 'Analysis result: positive sentiment' """ - # 构建命名空间上下文 + # Build namespace context context = { - "conv": conv_vars, # 会话变量:{{conv.user_name}} - "node": node_outputs, # 节点输出:{{node.node_1.output}} - "sys": system_vars, # 系统变量:{{sys.execution_id}} + "conv": conv_vars, # Conversation variables: {{conv.user_name}} + "node": node_outputs, # Node outputs: {{node.node_1.output}} + "sys": system_vars, # System variables: {{sys.execution_id}} } - # 支持直接通过节点ID访问节点输出:{{llm_qa.output}} - # 将所有节点输出添加到顶层上下文 + # Allow direct access to node outputs by node ID: {{llm_qa.output}} if node_outputs: context.update(node_outputs) - # 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}} - if conv_vars: - context.update(conv_vars) - - context["nodes"] = node_outputs or {} # 旧语法兼容 + # # 支持直接访问会话变量(不需要 conv. 前缀):{{user_name}} + # if conv_vars: + # context.update(conv_vars) + # + # context["nodes"] = node_outputs or {} # 旧语法兼容 template = self.normalize_template(template) try: tmpl = self.env.from_string(template) return tmpl.render(**context) except TemplateSyntaxError as e: - logger.error(f"模板语法错误: {template}, 错误: {e}") - raise ValueError(f"模板语法错误: {e}") - + logger.error(f"Template syntax error: {template}, error: {e}") + raise ValueError(f"Template syntax error: {e}") except UndefinedError as e: - logger.error(f"模板中引用了未定义的变量: {template}, 错误: {e}") - raise ValueError(f"未定义的变量: {e}") - + logger.error(f"Undefined variable in template: {template}, error: {e}") + raise ValueError(f"Undefined variable: {e}") except Exception as e: - logger.error(f"模板渲染异常: {template}, 错误: {e}") - raise ValueError(f"模板渲染失败: {e}") + logger.error(f"Template rendering error: {template}, error: {e}") + raise ValueError(f"Template rendering failed: {e}") def validate(self, template: str) -> list[str]: - """验证模板语法 - + """Validate template syntax + Args: - template: 模板字符串 - + template: Template string + Returns: - 错误列表,如果为空则验证通过 - + List of errors (empty if valid) + Examples: >>> renderer = TemplateRenderer() >>> renderer.validate("Hello {{var.name}}!") [] - - >>> renderer.validate("Hello {{var.name") # 缺少结束标记 - ['模板语法错误: ...'] + + >>> renderer.validate("Hello {{var.name") # Missing closing tag + ['Template syntax error: ...'] """ errors = [] try: self.env.from_string(template) except TemplateSyntaxError as e: - errors.append(f"模板语法错误: {e}") + errors.append(f"Template syntax error: {e}") except Exception as e: - errors.append(f"模板验证失败: {e}") + errors.append(f"Template validation failed: {e}") return errors -# 全局渲染器实例(严格模式) +# Global renderer instances (strict / lenient) _strict_renderer = TemplateRenderer(strict=True) _lenient_renderer = TemplateRenderer(strict=False) def render_template( template: str, - conv_vars: dict[str, Any], - node_outputs: dict[str, Any], - system_vars: dict[str, Any], + conv_vars: dict[str, Any] | LazyVariableDict, + node_outputs: dict[str, Any] | LazyVariableDict, + system_vars: dict[str, Any] | LazyVariableDict, strict: bool = True ) -> str: - """渲染模板(便捷函数) - + """Render template (convenience function) + Args: - strict: 严格模式 - template: 模板字符串 - conv_vars: 会话变量 - node_outputs: 节点输出 - system_vars: 系统变量 - + strict: Whether to use strict mode + template: Template string + conv_vars: Conversation variables + node_outputs: Node outputs + system_vars: System variables + Returns: - 渲染后的字符串 - + Rendered string + Examples: >>> render_template( - ... "请分析: {{var.text}}", - ... {"text": "这是一段文本"}, + ... "Analyze: {{var.text}}", + ... {"text": "This is a text"}, ... {}, ... {} ... ) - '请分析: 这是一段文本' + 'Analyze: This is a text' """ renderer = _strict_renderer if strict else _lenient_renderer return renderer.render(template, conv_vars, node_outputs, system_vars) def validate_template(template: str) -> list[str]: - """验证模板语法(便捷函数) - + """Validate template syntax (convenience function) + Args: - template: 模板字符串 - + template: Template string + Returns: - 错误列表 + List of errors """ return _strict_renderer.validate(template) From bca43fcc75e41efa68f4fd997ee8acb0587c334d Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Fri, 27 Mar 2026 12:02:36 +0800 Subject: [PATCH 004/117] perf(workflow): expose extract_document_text as instance method, optimize knowledge base parallel search - Change extract_document_text from private to instance method in multimodal service for external access - Optimize knowledge base search logic to improve parallel retrieval performance --- .../workflow/nodes/document_extractor/node.py | 2 +- api/app/core/workflow/nodes/knowledge/node.py | 175 ++++++++++++------ .../core/workflow/utils/template_renderer.py | 2 +- api/app/services/multimodal_service.py | 6 +- 4 files changed, 120 insertions(+), 65 deletions(-) diff --git a/api/app/core/workflow/nodes/document_extractor/node.py b/api/app/core/workflow/nodes/document_extractor/node.py index 40641f3c..bd828760 100644 --- a/api/app/core/workflow/nodes/document_extractor/node.py +++ b/api/app/core/workflow/nodes/document_extractor/node.py @@ -89,7 +89,7 @@ class DocExtractorNode(BaseNode): # Reuse cached bytes if already fetched if f.get_content(): file_input.set_content(f.get_content()) - text = await svc._extract_document_text(file_input) + text = await svc.extract_document_text(file_input) chunks.append(text) except Exception as e: logger.error( diff --git a/api/app/core/workflow/nodes/knowledge/node.py b/api/app/core/workflow/nodes/knowledge/node.py index 92699cb4..d0b6d098 100644 --- a/api/app/core/workflow/nodes/knowledge/node.py +++ b/api/app/core/workflow/nodes/knowledge/node.py @@ -1,19 +1,23 @@ +import asyncio import logging import uuid from typing import Any +from langchain_core.documents import Document + from app.core.error_codes import BizCode from app.core.exceptions import BusinessException from app.core.models import RedBearRerank, RedBearModelConfig -from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector +from app.core.rag.models.chunk import DocumentChunk +from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory from app.core.workflow.engine.state_manager import WorkflowState from app.core.workflow.engine.variable_pool import VariablePool from app.core.workflow.nodes.base_node import BaseNode from app.core.workflow.nodes.knowledge import KnowledgeRetrievalNodeConfig from app.core.workflow.variable.base_variable import VariableType from app.db import get_db_read -from app.models import knowledge_model, knowledgeshare_model, ModelType -from app.repositories import knowledge_repository, knowledgeshare_repository +from app.models import knowledge_model, ModelType +from app.repositories import knowledge_repository from app.schemas.chunk_schema import RetrieveType from app.services.model_service import ModelConfigService @@ -24,7 +28,6 @@ class KnowledgeRetrievalNode(BaseNode): def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any], down_stream_nodes: list[str]): super().__init__(node_config, workflow_config, down_stream_nodes) self.typed_config: KnowledgeRetrievalNodeConfig | None = None - self.vector_service: ElasticSearchVector | None = None def _output_types(self) -> dict[str, VariableType]: return { @@ -85,46 +88,54 @@ class KnowledgeRetrievalNode(BaseNode): unique.append(doc) return unique - def _get_existing_kb_ids(self, db, kb_ids): + def rerank(self, query: str, docs: list[DocumentChunk], top_k: int) -> list[DocumentChunk]: """ - Resolve all accessible and valid knowledge base IDs for retrieval. - - This includes: - - Private knowledge bases owned by the user - - Shared knowledge bases - - Source knowledge bases mapped via knowledge sharing relationships - + Reorder the list of document blocks and return the top_k results most relevant to the query Args: - db: Database session. - kb_ids (list[UUID]): Knowledge base IDs from node configuration. + query: query string + docs: List of document chunk to be rearranged + top_k: The number of top-level documents returned Returns: - list[UUID]: Final list of valid knowledge base IDs. + Rearranged document chunk list (sorted in descending order of relevance) + + Raises: + ValueError: If the input document list is empty or top_k is invalid """ - filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Private) - - existing_ids = knowledge_repository.get_chunked_knowledgeids( - db=db, - filters=filters - ) - - filters = self._build_kb_filter(kb_ids, knowledge_model.PermissionType.Share) - - share_ids = knowledge_repository.get_chunked_knowledgeids( - db=db, - filters=filters - ) - - if share_ids: - filters = [ - knowledgeshare_model.KnowledgeShare.target_kb_id.in_(kb_ids) + reranker = self.get_reranker_model() + # parameter validation + if not docs: + raise ValueError("retrieval chunks be empty") + if top_k <= 0: + raise ValueError("top_k must be a positive integer") + try: + # Convert to LangChain Document object + documents = [ + Document( + page_content=doc.page_content, # Ensure that DocumentChunk possesses this attribute + metadata=doc.metadata or {} # Deal with possible None metadata + ) + for doc in docs ] - items = knowledgeshare_repository.get_source_kb_ids_by_target_kb_id( - db=db, - filters=filters + + # Perform reordering (compress_documents will automatically handle relevance scores and indexing) + reranked_docs = list(reranker.compress_documents(documents, query)) + + # Sort in descending order based on relevance score + reranked_docs.sort( + key=lambda x: x.metadata.get("relevance_score", 0), + reverse=True ) - existing_ids.extend(items) - return existing_ids + # Convert back to a list of DocumentChunk, and save the relevance_score to metadata["score"] + result = [] + for item in reranked_docs[:top_k]: + for doc in docs: + if doc.page_content == item.page_content: + doc.metadata["score"] = item.metadata["relevance_score"] + result.append(doc) + return result + except Exception as e: + raise RuntimeError(f"Failed to rerank documents: {str(e)}") from e def get_reranker_model(self) -> RedBearRerank: """ @@ -164,41 +175,77 @@ class KnowledgeRetrievalNode(BaseNode): ) return reranker - def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config): + async def knowledge_retrieval(self, db, query, db_knowledge, kb_config): + rs = [] if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER: children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id) + tasks = [] for child in children: if not (child and child.chunk_num > 0 and child.status == 1): continue - kb_config.kb_id = child.id - self.knowledge_retrieval(db, query, rs, child, kb_config) - return - self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) + child_kb_config = kb_config.model_copy() + child_kb_config.kb_id = child.id + tasks.append(self.knowledge_retrieval(db, query, child, child_kb_config)) + if tasks: + result = await asyncio.gather(*tasks) + for _ in result: + rs.extend(_) + return rs + vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge) indices = f"Vector_index_{kb_config.kb_id}_Node".lower() match kb_config.retrieve_type: case RetrieveType.PARTICIPLE: - rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.similarity_threshold)) + rs.extend( + await asyncio.to_thread( + vector_service.search_by_full_text, **{ + "query": query, + "top_k": kb_config.top_k, + "indices": indices, + "score_threshold": kb_config.similarity_threshold + } + ) + ) case RetrieveType.SEMANTIC: - rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.vector_similarity_weight)) + rs.extend( + await asyncio.to_thread( + vector_service.search_by_vector, **{ + "query": query, + "top_k": kb_config.top_k, + "indices": indices, + "score_threshold": kb_config.vector_similarity_weight + } + ) + ) case RetrieveType.HYBRID: - rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.vector_similarity_weight) - rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k, - indices=indices, - score_threshold=kb_config.similarity_threshold) + rs1_task = asyncio.to_thread( + vector_service.search_by_vector, **{ + "query": query, + "top_k": kb_config.top_k, + "indices": indices, + "score_threshold": kb_config.vector_similarity_weight + } + ) + rs2_task = asyncio.to_thread( + vector_service.search_by_full_text, **{ + "query": query, + "top_k": kb_config.top_k, + "indices": indices, + "score_threshold": kb_config.similarity_threshold + } + ) + rs1, rs2 = await asyncio.gather(rs1_task, rs2_task) # Deduplicate hybrid retrieval results unique_rs = self._deduplicate_docs(rs1, rs2) if not unique_rs: - return + return [] if self.typed_config.reranker_id: - self.vector_service.reranker = self.get_reranker_model() - rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k)) + rs.extend( + await asyncio.to_thread( + self.rerank, + **{"query": query, "docs": unique_rs, "top_k": kb_config.top_k} + ) + ) else: rs.extend(sorted( unique_rs, @@ -207,6 +254,7 @@ class KnowledgeRetrievalNode(BaseNode): )[:kb_config.top_k]) case _: raise RuntimeError("Unknown retrieval type") + return rs async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any: """ @@ -238,17 +286,24 @@ class KnowledgeRetrievalNode(BaseNode): knowledge_bases = self.typed_config.knowledge_bases rs = [] + tasks = [] for kb_config in knowledge_bases: db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id) if not db_knowledge: raise RuntimeError("The knowledge base does not exist or access is denied.") - self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config) + tasks.append(self.knowledge_retrieval(db, query, db_knowledge, kb_config)) + if tasks: + result = await asyncio.gather(*tasks) + for _ in result: + rs.extend(_) if not rs: return [] if self.typed_config.reranker_id: - self.vector_service.reranker = self.get_reranker_model() - final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k) + final_rs = await asyncio.to_thread( + self.rerank, + **{"query": query, "docs": rs, "top_k": self.typed_config.reranker_top_k} + ) else: final_rs = sorted( rs, diff --git a/api/app/core/workflow/utils/template_renderer.py b/api/app/core/workflow/utils/template_renderer.py index bb1e18bf..2c2d0f67 100644 --- a/api/app/core/workflow/utils/template_renderer.py +++ b/api/app/core/workflow/utils/template_renderer.py @@ -158,7 +158,7 @@ _lenient_renderer = TemplateRenderer(strict=False) def render_template( template: str, conv_vars: dict[str, Any] | LazyVariableDict, - node_outputs: dict[str, Any] | LazyVariableDict, + node_outputs: dict[str, Any] | dict[str, LazyVariableDict], system_vars: dict[str, Any] | LazyVariableDict, strict: bool = True ) -> str: diff --git a/api/app/services/multimodal_service.py b/api/app/services/multimodal_service.py index 4cf3d89d..120cccb7 100644 --- a/api/app/services/multimodal_service.py +++ b/api/app/services/multimodal_service.py @@ -438,13 +438,13 @@ class MultimodalService: if file.transfer_method == TransferMethod.REMOTE_URL: return True, { "type": "text", - "text": f"\n{await self._extract_document_text(file)}\n" + "text": f"\n{await self.extract_document_text(file)}\n" } else: # 本地文件,提取文本内容 server_url = settings.FILE_LOCAL_SERVER_URL file.url = f"{server_url}/storage/permanent/{file.upload_file_id}" - text = await self._extract_document_text(file) + text = await self.extract_document_text(file) file_metadata = self.db.query(FileMetadata).filter( FileMetadata.id == file.upload_file_id ).first() @@ -542,7 +542,7 @@ class MultimodalService: server_url = settings.FILE_LOCAL_SERVER_URL return f"{server_url}/storage/permanent/{file_id}" - async def _extract_document_text(self, file: FileInput) -> str: + async def extract_document_text(self, file: FileInput) -> str: """ 提取文档文本内容 From f9f302dd2a599cb3ca38514712321b545bdb2afc Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 30 Mar 2026 11:41:18 +0800 Subject: [PATCH 005/117] feat(web): api key & space config ui upgrade --- web/src/assets/images/common/eye.svg | 16 +-- web/src/assets/images/menuNew/arrow_t_r.svg | 16 +++ web/src/assets/images/menuNew/logout_red.svg | 17 +++ web/src/assets/images/menuNew/settings.svg | 19 +++ web/src/assets/images/menuNew/userInfo.svg | 13 +++ web/src/components/Header/index.module.css | 22 ++++ web/src/components/Header/index.tsx | 64 +++++++--- web/src/i18n/en.ts | 4 +- web/src/i18n/zh.ts | 4 +- web/src/styles/index.css | 5 + web/src/views/ApiKeyManagement/index.tsx | 116 ++++++++++++------- web/src/views/MemberManagement/index.tsx | 7 +- web/src/views/Prompt/index.tsx | 4 +- web/src/views/SpaceConfig/index.tsx | 61 ++++------ web/src/views/UserManagement/index.tsx | 2 +- 15 files changed, 256 insertions(+), 114 deletions(-) create mode 100644 web/src/assets/images/menuNew/arrow_t_r.svg create mode 100644 web/src/assets/images/menuNew/logout_red.svg create mode 100644 web/src/assets/images/menuNew/settings.svg create mode 100644 web/src/assets/images/menuNew/userInfo.svg diff --git a/web/src/assets/images/common/eye.svg b/web/src/assets/images/common/eye.svg index df2af1cf..c7b531b8 100644 --- a/web/src/assets/images/common/eye.svg +++ b/web/src/assets/images/common/eye.svg @@ -1,13 +1,13 @@ - - 编辑 + + link-outlined - - - - - - + + + + + + diff --git a/web/src/assets/images/menuNew/arrow_t_r.svg b/web/src/assets/images/menuNew/arrow_t_r.svg new file mode 100644 index 00000000..884e46c1 --- /dev/null +++ b/web/src/assets/images/menuNew/arrow_t_r.svg @@ -0,0 +1,16 @@ + + + 编组 51 + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/logout_red.svg b/web/src/assets/images/menuNew/logout_red.svg new file mode 100644 index 00000000..7057b974 --- /dev/null +++ b/web/src/assets/images/menuNew/logout_red.svg @@ -0,0 +1,17 @@ + + + 退出 + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/settings.svg b/web/src/assets/images/menuNew/settings.svg new file mode 100644 index 00000000..9a64bb29 --- /dev/null +++ b/web/src/assets/images/menuNew/settings.svg @@ -0,0 +1,19 @@ + + + 设置-界面设置 + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/web/src/assets/images/menuNew/userInfo.svg b/web/src/assets/images/menuNew/userInfo.svg new file mode 100644 index 00000000..0e67a919 --- /dev/null +++ b/web/src/assets/images/menuNew/userInfo.svg @@ -0,0 +1,13 @@ + + + 账户 + + + + + + + + + + \ No newline at end of file diff --git a/web/src/components/Header/index.module.css b/web/src/components/Header/index.module.css index ca7c79cf..d39c91ec 100644 --- a/web/src/components/Header/index.module.css +++ b/web/src/components/Header/index.module.css @@ -24,4 +24,26 @@ .header :global(.ant-breadcrumb .ant-breadcrumb-item:last-child a) { color: #212332; font-weight: 600; +} +.userDropdown:global(.ant-dropdown .ant-dropdown-menu), +.userDropdown:global(.ant-dropdown-menu-submenu .ant-dropdown-menu) { + padding: 12px 8px; +} +.userDropdown:global(.ant-dropdown .ant-dropdown-menu .ant-dropdown-menu-item), +.userDropdown:global(.ant-dropdown-menu-submenu .ant-dropdown-menu .ant-dropdown-menu-item), +.userDropdown:global(.ant-dropdown .ant-dropdown-menu .ant-dropdown-menu-submenu-title), +.userDropdown:global(.ant-dropdown-menu-submenu .ant-dropdown-menu .ant-dropdown-menu-submenu-title) { + padding-left: 8px; + padding-right: 4px; +} +.userDropdown:global(.ant-dropdown .ant-dropdown-menu .ant-dropdown-menu-item.ant-dropdown-menu-item-danger:not(.ant-dropdown-menu-item-disabled):hover), +.userDropdown:global(.ant-dropdown-menu-submenu .ant-dropdown-menu .ant-dropdown-menu-item.ant-dropdown-menu-item-danger:not(.ant-dropdown-menu-item-disabled):hover) { + background-color: #F6F6F6; + color: #FF5D34; +} +.userDropdown:global(.ant-dropdown .ant-dropdown-menu .ant-dropdown-menu-item-divider), +.userDropdown:global(.ant-dropdown-menu-submenu .ant-dropdown-menu .ant-dropdown-menu-item-divider), +.userDropdown:global(.ant-dropdown .ant-dropdown-menu .ant-dropdown-menu-submenu-title-divider), +.userDropdown:global(.ant-dropdown-menu-submenu .ant-dropdown-menu .ant-dropdown-menu-submenu-title-divider) { + margin: 10px 8px; } \ No newline at end of file diff --git a/web/src/components/Header/index.tsx b/web/src/components/Header/index.tsx index 7a5d17b9..ac59fcc2 100644 --- a/web/src/components/Header/index.tsx +++ b/web/src/components/Header/index.tsx @@ -13,12 +13,12 @@ * @component */ -import { type FC, useRef } from 'react'; -import { Layout, Dropdown, Breadcrumb } from 'antd'; +import { type FC, useRef, useState } from 'react'; +import { Layout, Dropdown, Breadcrumb, Flex } from 'antd'; import type { MenuProps, BreadcrumbProps } from 'antd'; -import { UserOutlined, LogoutOutlined, SettingOutlined } from '@ant-design/icons'; import { useTranslation } from 'react-i18next'; import { useLocation } from 'react-router-dom'; +import clsx from 'clsx'; import { useUser } from '@/store/user'; import { useMenu } from '@/store/menu'; @@ -76,27 +76,39 @@ const AppHeader: FC<{source?: 'space' | 'manage';}> = ({source = 'manage'}) => { const userMenuItems: MenuProps['items'] = [ { key: '1', + icon: + {/[\u4e00-\u9fa5]/.test(user.username) ? user.username.slice(0, 2) : user.username[0]} + , label: (<> -
{user.username}
-
{user.email}
+
{user.username}
+
{user.email}
), }, { key: '2', type: 'divider', + className: 'rb:bg-[#EBEBEB]!' }, { key: '3', - icon: , - label: t('header.userInfo'), + icon:
, + label: + {t('header.userInfo')} +
+
, + className: 'rb:text-[#212332]!', onClick: () => { userInfoModalRef.current?.handleOpen() }, }, { key: '4', - icon: , - label: t('header.settings'), + icon:
, + label: + {t('header.settings')} +
+
, + className: 'rb:text-[#212332]!', onClick: () => { settingModalRef.current?.handleOpen() }, @@ -104,12 +116,14 @@ const AppHeader: FC<{source?: 'space' | 'manage';}> = ({source = 'manage'}) => { { key: '5', type: 'divider', + className: 'rb:bg-[#EBEBEB]!' }, { key: '6', - icon: , + icon:
, label: t('header.logout'), danger: true, + className: 'rb:hover:rb:bg-transparent rb:hover:text-[#FF5D34]!', onClick: handleLogout, }, ]; @@ -147,18 +161,34 @@ const AppHeader: FC<{source?: 'space' | 'manage';}> = ({source = 'manage'}) => { }); } + const [open, setOpen] = useState(false); + const handleOpenChange = (open: boolean) => { + setOpen(open); + } return (
{/* Breadcrumb navigation */} {/* User info dropdown menu */} - -
{user.username}
-
+ {user.username && ( + + + + {/[\u4e00-\u9fa5]/.test(user.username) ? user.username.slice(0, 2) : user.username[0]} + + {user.username} +
+
+
+ )} diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 9b957a84..2975796a 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -980,6 +980,7 @@ export const en = { scene_id: 'Ontology Scenario', }, member: { + memberList: 'Member List', username: 'Username', account: 'Account', role: 'Role', @@ -1908,7 +1909,8 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re permissionInfo: 'Permission Information', is_expired: 'Status', active: 'Active', - inactive: 'Expired' + inactive: 'Expired', + noScopes: 'There is no permission information here…', }, tool: { mcp: 'MCP Services', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 862ed5d4..3edd84e3 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1375,6 +1375,7 @@ export const zh = { scene_id: '本体场景', }, member: { + memberList: '成员列表', username: '用户名', account: '账号', role: '角色', @@ -1905,7 +1906,8 @@ export const zh = { permissionInfo: '授权信息', is_expired: '状态', active: '活跃', - inactive: '过期' + inactive: '过期', + noScopes: '暂无权限信息…', }, tool: { mcp: 'MCP 服务', diff --git a/web/src/styles/index.css b/web/src/styles/index.css index 71a6cce4..684a4ba6 100644 --- a/web/src/styles/index.css +++ b/web/src/styles/index.css @@ -396,6 +396,11 @@ body { color: #FFFFFF; background-color: #171719; } +.ant-dropdown .ant-dropdown-menu .ant-dropdown-menu-item.ant-dropdown-menu-item-danger:not(.ant-dropdown-menu-item-disabled):hover, +.ant-dropdown-menu-submenu .ant-dropdown-menu .ant-dropdown-menu-item.ant-dropdown-menu-item-danger:not(.ant-dropdown-menu-item-disabled):hover { + background-color: #F6F6F6; + color: #FF5D34; +} .spin.ant-spin-nested-loading .ant-spin-container::after { background: transparent; diff --git a/web/src/views/ApiKeyManagement/index.tsx b/web/src/views/ApiKeyManagement/index.tsx index 7bb417c5..071b9ef5 100644 --- a/web/src/views/ApiKeyManagement/index.tsx +++ b/web/src/views/ApiKeyManagement/index.tsx @@ -6,20 +6,22 @@ */ import React, { useRef } from 'react'; import { useTranslation } from 'react-i18next'; -import { Button, App, Space } from 'antd'; +import { Button, App, Dropdown, Flex } from 'antd'; import clsx from 'clsx'; import { DeleteOutlined, EditOutlined, EyeOutlined } from '@ant-design/icons'; import copy from 'copy-to-clipboard' +import type { MenuInfo } from 'rc-menu/lib/interface'; import type { ApiKey, ApiKeyModalRef } from './types'; import ApiKeyModal from './components/ApiKeyModal'; import ApiKeyDetailModal from './components/ApiKeyDetailModal'; -import RbCard from '@/components/RbCard/Card' +import RbCard from '@/components/RbCard' import { getApiKeyListUrl, deleteApiKey } from '@/api/apiKey'; import PageScrollList, { type PageScrollListRef } from '@/components/PageScrollList' import { formatDateTime } from '@/utils/format'; import Tag from '@/components/Tag' import { maskApiKeys } from '@/utils/apiKeyReplacer'; +import RbDescriptions from '@/components/RbDescriptions'; /** * API Key Management page component @@ -87,59 +89,85 @@ const ApiKeyManagement: React.FC = () => { } return ( <> -
+ -
+ ref={scrollListRef} url={getApiKeyListUrl} query={{ is_active: true, type: 'service' }} - column={2} + column={3} renderItem={(apiKeyItem) => { return ( - - {['id', 'is_expired', 'created_at'].map((key, index) => ( -
- {t(`apiKey.${key}`)} - - { key === 'created_at' - ? formatDateTime(apiKeyItem[key], 'YYYY-MM-DD HH:mm:ss') - : key === 'is_expired' - ? {apiKeyItem[key] ? t('apiKey.inactive') : t('apiKey.active')} - : String(apiKeyItem[key as keyof ApiKey]) - } - -
- ))} + + + {apiKeyItem.name} + + {apiKeyItem.scopes?.includes('memory') && {t('apiKey.memoryEngine')}} + {apiKeyItem.scopes?.includes('rag') && {t('apiKey.knowledgeBase')}} + {!apiKeyItem.scopes?.includes('memory') && !apiKeyItem.scopes?.includes('rag') &&
{t('apiKey.noScopes')}
} +
+
+ , + label: t('common.edit'), + onClick: () => handleEdit(apiKeyItem), + }, + { + key: 'view', + icon:
, + label: t('common.view'), + onClick: () => handleView(apiKeyItem), + }, + { + key: 'delete', + danger: true, + icon:
, + label: t('common.delete'), + onClick: () => handleDelete(apiKeyItem), + }, + ] + }} + placement="bottomRight" + > +
+ + + } + isNeedTooltip={false} + headerClassName="rb:min-h-[78px]!" + > + ({ + key, + label: t(`apiKey.${key}`), + children: + {key === 'created_at' + ? formatDateTime(apiKeyItem[key], 'YYYY-MM-DD HH:mm:ss') + : key === 'is_expired' + ? {apiKeyItem[key] ? t('apiKey.inactive') : t('apiKey.active')} + : String(apiKeyItem[key as keyof ApiKey]) + } + + }))} + /> -
- {maskApiKeys(apiKeyItem.api_key)} - - -
- - - {apiKeyItem.scopes?.includes('memory') && {t('apiKey.memoryEngine')}} - {apiKeyItem.scopes?.includes('rag') && {t('apiKey.knowledgeBase')}} - - -
- - - -
+ + {maskApiKeys(apiKeyItem.api_key)} + +
handleCopy(apiKeyItem.api_key)} className="rb:cursor-pointer rb:rounded-md rb:size-6 rb:bg-[url('@/assets/images/common/copy_dark.svg')] rb:bg-size-[16px_16px] rb:bg-center rb:bg-no-repeat" style={{ backgroundColor: 'rgba(0,0,0,0.08)' }}>
+
); }} diff --git a/web/src/views/MemberManagement/index.tsx b/web/src/views/MemberManagement/index.tsx index f846a310..b9b392ff 100644 --- a/web/src/views/MemberManagement/index.tsx +++ b/web/src/views/MemberManagement/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:42:12 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-26 15:48:43 + * @Last Modified time: 2026-03-30 11:38:42 */ /** * Member Management Page @@ -106,9 +106,10 @@ const MemberManagement: React.FC = () => { return (
- + +
{t('member.memberList')}
diff --git a/web/src/views/Prompt/index.tsx b/web/src/views/Prompt/index.tsx index 469b1e39..095a8daa 100644 --- a/web/src/views/Prompt/index.tsx +++ b/web/src/views/Prompt/index.tsx @@ -188,7 +188,7 @@ const Prompt: FC = () => { @@ -249,7 +249,7 @@ const Prompt: FC = () => {
{ const { t } = useTranslation(); @@ -63,7 +61,9 @@ const SpaceConfig: FC = () => { } return ( -
+
+
{t('menu.spaceConfig')}
+
{t('space.configAlert')}
{pageLoading ? :
{ layout="vertical" > - - - - {t('space.configAlert')} - - - - +
}
diff --git a/web/src/views/UserManagement/index.tsx b/web/src/views/UserManagement/index.tsx index 2ffdc91b..ed09994f 100644 --- a/web/src/views/UserManagement/index.tsx +++ b/web/src/views/UserManagement/index.tsx @@ -144,7 +144,7 @@ const UserManagement: React.FC = () => { return (
-
{t('user.userList')}
+
{t('user.userList')}
From 7acb7045f081046de2f2d381c26eb833062ff39c Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Mon, 30 Mar 2026 11:47:58 +0800 Subject: [PATCH 006/117] feat(agent, memory): add agent-perceived memory writing --- .../controllers/public_share_controller.py | 80 -------- api/app/core/agent/langchain_agent.py | 90 ++------- .../langgraph_graph/routing/write_router.py | 83 +++------ .../agent/langgraph_graph/write_graph.py | 116 ++++-------- api/app/core/memory/agent/utils/redis_tool.py | 173 +++++++++--------- api/app/core/memory/llm_tools/llm_client.py | 2 +- api/app/schemas/memory_agent_schema.py | 10 +- api/app/services/app_chat_service.py | 65 ++++--- api/app/services/draft_run_service.py | 13 +- api/app/services/memory_perceptual_service.py | 21 --- api/app/services/model_service.py | 138 +++++++------- api/app/services/shared_chat_service.py | 43 ++--- 12 files changed, 304 insertions(+), 530 deletions(-) diff --git a/api/app/controllers/public_share_controller.py b/api/app/controllers/public_share_controller.py index f5284b46..26902b07 100644 --- a/api/app/controllers/public_share_controller.py +++ b/api/app/controllers/public_share_controller.py @@ -410,30 +410,6 @@ async def chat( agent_config = agent_config_4_app_release(release) if payload.stream: - # async def event_generator(): - # async for event in service.chat_stream( - # share_token=share_token, - # message=payload.message, - # conversation_id=conversation.id, # 使用已创建的会话 ID - # user_id=str(new_end_user.id), # 转换为字符串 - # variables=payload.variables, - # password=password, - # web_search=payload.web_search, - # memory=payload.memory, - # storage_type=storage_type, - # user_rag_memory_id=user_rag_memory_id - # ): - # yield event - - # return StreamingResponse( - # event_generator(), - # media_type="text/event-stream", - # headers={ - # "Cache-Control": "no-cache", - # "Connection": "keep-alive", - # "X-Accel-Buffering": "no" - # } - # ) async def event_generator(): async for event in app_chat_service.agnet_chat_stream( message=payload.message, @@ -459,20 +435,6 @@ async def chat( "X-Accel-Buffering": "no" } ) - # 非流式返回 - # result = await service.chat( - # share_token=share_token, - # message=payload.message, - # conversation_id=conversation.id, # 使用已创建的会话 ID - # user_id=str(new_end_user.id), # 转换为字符串 - # variables=payload.variables, - # password=password, - # web_search=payload.web_search, - # memory=payload.memory, - # storage_type=storage_type, - # user_rag_memory_id=user_rag_memory_id - # ) - # return success(data=conversation_schema.ChatResponse(**result)) result = await app_chat_service.agnet_chat( message=payload.message, conversation_id=conversation.id, # 使用已创建的会话 ID @@ -531,48 +493,6 @@ async def chat( ) return success(data=conversation_schema.ChatResponse(**result).model_dump(mode="json")) - # 多 Agent 流式返回 - # if payload.stream: - # async def event_generator(): - # async for event in service.multi_agent_chat_stream( - # share_token=share_token, - # message=payload.message, - # conversation_id=conversation.id, # 使用已创建的会话 ID - # user_id=str(new_end_user.id), # 转换为字符串 - # variables=payload.variables, - # password=password, - # web_search=payload.web_search, - # memory=payload.memory, - # storage_type=storage_type, - # user_rag_memory_id=user_rag_memory_id - # ): - # yield event - - # return StreamingResponse( - # event_generator(), - # media_type="text/event-stream", - # headers={ - # "Cache-Control": "no-cache", - # "Connection": "keep-alive", - # "X-Accel-Buffering": "no" - # } - # ) - - # # 多 Agent 非流式返回 - # result = await service.multi_agent_chat( - # share_token=share_token, - # message=payload.message, - # conversation_id=conversation.id, # 使用已创建的会话 ID - # user_id=str(new_end_user.id), # 转换为字符串 - # variables=payload.variables, - # password=password, - # web_search=payload.web_search, - # memory=payload.memory, - # storage_type=storage_type, - # user_rag_memory_id=user_rag_memory_id - # ) - - # return success(data=conversation_schema.ChatResponse(**result)) elif app_type == AppType.WORKFLOW: config = workflow_config_4_app_release(release) if not config.id: diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 464a668a..38821313 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -11,18 +11,14 @@ LangChain Agent 封装 import time from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence -from app.core.memory.agent.langgraph_graph.write_graph import write_long_term -from app.db import get_db -from app.core.logging_config import get_business_logger -from app.core.models import RedBearLLM, RedBearModelConfig -from app.models.models_model import ModelType, ModelProvider -from app.services.memory_agent_service import ( - get_end_user_connected_config, -) from langchain.agents import create_agent from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.tools import BaseTool +from app.core.logging_config import get_business_logger +from app.core.models import RedBearLLM, RedBearModelConfig +from app.models.models_model import ModelType + logger = get_business_logger() @@ -226,10 +222,9 @@ class LangChainAgent: Returns: List[BaseMessage]: 消息列表 """ - messages = [] + messages:list = [SystemMessage(content=self.system_prompt)] # 添加系统提示词 - messages.append(SystemMessage(content=self.system_prompt)) # 添加历史消息 if history: @@ -293,12 +288,7 @@ class LangChainAgent: message: str, history: Optional[List[Dict[str, str]]] = None, context: Optional[str] = None, - end_user_id: Optional[str] = None, - config_id: Optional[str] = None, # 添加这个参数 - storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None, - memory_flag: Optional[bool] = True, - files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件 + files: Optional[List[Dict[str, Any]]] = None ) -> Dict[str, Any]: """执行对话 @@ -306,32 +296,12 @@ class LangChainAgent: message: 用户消息 history: 历史消息列表 [{"role": "user/assistant", "content": "..."}] context: 上下文信息(如知识库检索结果) + files: 多模态文件 Returns: Dict: 包含 content 和元数据的字典 """ - message_chat = message start_time = time.time() - actual_config_id = config_id - # If config_id is None, try to get from end_user's connected config - if actual_config_id is None and end_user_id: - try: - from app.services.memory_agent_service import ( - get_end_user_connected_config, - ) - db = next(get_db()) - try: - connected_config = get_end_user_connected_config(end_user_id, db) - actual_config_id = connected_config.get("memory_config_id") - except Exception as e: - logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}") - finally: - db.close() - except Exception as e: - logger.warning(f"Failed to get db session: {e}") - actual_end_user_id = end_user_id if end_user_id is not None else "unknown" - logger.info(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}') - print(f'写入类型{storage_type, str(end_user_id), message, str(user_rag_memory_id)}') try: # 准备消息列表(支持多模态) messages = self._prepare_messages(message, history, context, files) @@ -419,9 +389,6 @@ class LangChainAgent: logger.info(f"最终提取的内容长度: {len(content)}") elapsed_time = time.time() - start_time - if memory_flag: - await write_long_term(storage_type, end_user_id, message_chat, content, user_rag_memory_id, - actual_config_id) response = { "content": content, "model": self.model_name, @@ -452,12 +419,7 @@ class LangChainAgent: message: str, history: Optional[List[Dict[str, str]]] = None, context: Optional[str] = None, - end_user_id: Optional[str] = None, - config_id: Optional[str] = None, - storage_type: Optional[str] = None, - user_rag_memory_id: Optional[str] = None, - memory_flag: Optional[bool] = True, - files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件 + files: Optional[List[Dict[str, Any]]] = None ) -> AsyncGenerator[str, None]: """执行流式对话 @@ -465,6 +427,7 @@ class LangChainAgent: message: 用户消息 history: 历史消息列表 context: 上下文信息 + files: 多模态文件 Yields: str: 消息内容块 @@ -475,23 +438,6 @@ class LangChainAgent: logger.info(f" Has tools: {bool(self.tools)}") logger.info(f" Tool count: {len(self.tools) if self.tools else 0}") logger.info("=" * 80) - message_chat = message - actual_config_id = config_id - # If config_id is None, try to get from end_user's connected config - if actual_config_id is None and end_user_id: - try: - db = next(get_db()) - try: - connected_config = get_end_user_connected_config(end_user_id, db) - actual_config_id = connected_config.get("memory_config_id") - except Exception as e: - logger.warning(f"Failed to get connected config for end_user {end_user_id}: {e}") - finally: - db.close() - except Exception as e: - logger.warning(f"Failed to get db session: {e}") - - # 注意:不在这里写入用户消息,等 AI 回复后一起写入 try: # 准备消息列表(支持多模态) messages = self._prepare_messages(message, history, context, files) @@ -501,17 +447,18 @@ class LangChainAgent: ) chunk_count = 0 - yielded_content = False # 统一使用 agent 的 astream_events 实现流式输出 logger.debug("使用 Agent astream_events 实现流式输出") full_content = '' try: + last_event = {} async for event in self.agent.astream_events( {"messages": messages}, version="v2", config={"recursion_limit": self.max_iterations} ): + last_event = event chunk_count += 1 kind = event.get("event") @@ -525,7 +472,6 @@ class LangChainAgent: if isinstance(chunk_content, str) and chunk_content: full_content += chunk_content yield chunk_content - yielded_content = True elif isinstance(chunk_content, list): # 多模态响应:提取文本部分 for item in chunk_content: @@ -536,18 +482,15 @@ class LangChainAgent: if text: full_content += text yield text - yielded_content = True # OpenAI 格式: {"type": "text", "text": "..."} elif item.get("type") == "text": text = item.get("text", "") if text: full_content += text yield text - yielded_content = True elif isinstance(item, str): full_content += item yield item - yielded_content = True elif kind == "on_llm_stream": # 另一种 LLM 流式事件 @@ -558,7 +501,6 @@ class LangChainAgent: if isinstance(chunk_content, str) and chunk_content: full_content += chunk_content yield chunk_content - yielded_content = True elif isinstance(chunk_content, list): # 多模态响应:提取文本部分 for item in chunk_content: @@ -569,22 +511,18 @@ class LangChainAgent: if text: full_content += text yield text - yielded_content = True # OpenAI 格式: {"type": "text", "text": "..."} elif item.get("type") == "text": text = item.get("text", "") if text: full_content += text yield text - yielded_content = True elif isinstance(item, str): full_content += item yield item - yielded_content = True elif isinstance(chunk, str): full_content += chunk yield chunk - yielded_content = True # 记录工具调用(可选) elif kind == "on_tool_start": @@ -594,7 +532,7 @@ class LangChainAgent: logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件") # 统计token消耗 - output_messages = event.get("data", {}).get("output", {}).get("messages", []) + output_messages = last_event.get("data", {}).get("output", {}).get("messages", []) for msg in reversed(output_messages): if isinstance(msg, AIMessage): response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None @@ -604,9 +542,7 @@ class LangChainAgent: ) if response_meta else 0 yield total_tokens break - if memory_flag: - await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, - actual_config_id) + except Exception as e: logger.error(f"Agent astream_events 失败: {str(e)}", exc_info=True) raise diff --git a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py index 2074b6ca..74fb6bae 100644 --- a/api/app/core/memory/agent/langgraph_graph/routing/write_router.py +++ b/api/app/core/memory/agent/langgraph_graph/routing/write_router.py @@ -12,7 +12,6 @@ from app.core.memory.utils.llm.llm_utils import MemoryClientFactory from app.db import get_db_context from app.repositories.memory_short_repository import LongTermMemoryRepository from app.schemas.memory_agent_schema import AgentMemory_Long_Term -from app.services.memory_konwledges_server import write_rag from app.services.task_service import get_task_memory_write_result from app.tasks import write_message_task from app.utils.config_utils import resolve_config_id @@ -21,25 +20,6 @@ logger = get_agent_logger(__name__) template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt') -async def write_rag_agent(end_user_id, user_message, ai_message, user_rag_memory_id): - """ - Write messages to RAG storage system - - Combines user and AI messages into a single string format and stores them - in the RAG (Retrieval-Augmented Generation) knowledge base for future retrieval. - - Args: - end_user_id: User identifier for the conversation - user_message: User's input message content - ai_message: AI's response message content - user_rag_memory_id: RAG memory identifier for storage location - """ - # RAG mode: combine messages into string format (maintain original logic) - combined_message = f"user: {user_message}\nassistant: {ai_message}" - await write_rag(end_user_id, combined_message, user_rag_memory_id) - logger.info(f'RAG_Agent:{end_user_id};{user_rag_memory_id}') - - async def write( storage_type, end_user_id, @@ -118,7 +98,7 @@ async def write( logger.info(f'[WRITE] Task result - user={actual_end_user_id}, status={write_status}') -async def term_memory_save(long_term_messages, actual_config_id, end_user_id, type, scope): +async def term_memory_save(end_user_id, strategy_type, scope): """ Save long-term memory data to database @@ -127,10 +107,8 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty to long-term memory storage. Args: - long_term_messages: Long-term message data to be saved - actual_config_id: Configuration identifier for memory settings end_user_id: User identifier for memory association - type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE) + strategy_type: Memory storage strategy type (STRATEGY_CHUNK or STRATEGY_AGGREGATE) scope: Scope/window size for memory processing """ with get_db_context() as db_session: @@ -138,7 +116,10 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty from app.core.memory.agent.utils.redis_tool import write_store result = write_store.get_session_by_userid(end_user_id) - if type == AgentMemory_Long_Term.STRATEGY_CHUNK or AgentMemory_Long_Term.STRATEGY_AGGREGATE: + if not result: + logger.warning(f"No write data found for user {end_user_id}") + return + if strategy_type in [AgentMemory_Long_Term.STRATEGY_CHUNK, AgentMemory_Long_Term.STRATEGY_AGGREGATE]: data = await format_parsing(result, "dict") chunk_data = data[:scope] if len(chunk_data) == scope: @@ -151,9 +132,6 @@ async def term_memory_save(long_term_messages, actual_config_id, end_user_id, ty logger.info(f'写入短长期:') -"""Window-based dialogue processing""" - - async def window_dialogue(end_user_id, langchain_messages, memory_config, scope): """ Process dialogue based on window size and write to Neo4j @@ -167,40 +145,33 @@ async def window_dialogue(end_user_id, langchain_messages, memory_config, scope) langchain_messages: Original message data list scope: Window size determining when to trigger long-term storage """ - scope = scope - is_end_user_id = count_store.get_sessions_count(end_user_id) - if is_end_user_id is not False: - is_end_user_id = count_store.get_sessions_count(end_user_id)[0] - redis_messages = count_store.get_sessions_count(end_user_id)[1] - if is_end_user_id and int(is_end_user_id) != int(scope): - is_end_user_id += 1 - langchain_messages += redis_messages - count_store.update_sessions_count(end_user_id, is_end_user_id, langchain_messages) - elif int(is_end_user_id) == int(scope): + is_end_user_has_history = count_store.get_sessions_count(end_user_id) + if is_end_user_has_history: + end_user_visit_count, redis_messages = is_end_user_has_history + else: + count_store.save_sessions_count(end_user_id, 1, langchain_messages) + return + end_user_visit_count += 1 + if end_user_visit_count < scope: + redis_messages.extend(langchain_messages) + count_store.update_sessions_count(end_user_id, end_user_visit_count, redis_messages) + else: logger.info('写入长期记忆NEO4J') - formatted_messages = redis_messages + redis_messages.extend(langchain_messages) # Get config_id (if memory_config is an object, extract config_id; otherwise use directly) if hasattr(memory_config, 'config_id'): config_id = memory_config.config_id else: config_id = memory_config - await write( - AgentMemory_Long_Term.STORAGE_NEO4J, - end_user_id, - "", - "", - None, - end_user_id, - config_id, - formatted_messages + write_message_task.delay( + end_user_id, # end_user_id: User ID + redis_messages, # message: JSON string format message list + config_id, # config_id: Configuration ID string + AgentMemory_Long_Term.STORAGE_NEO4J, # storage_type: "neo4j" + "" # user_rag_memory_id: RAG memory ID (not used in Neo4j mode) ) - count_store.update_sessions_count(end_user_id, 1, langchain_messages) - else: - count_store.save_sessions_count(end_user_id, 1, langchain_messages) - - -"""Time-based memory processing""" + count_store.update_sessions_count(end_user_id, 0, []) async def memory_long_term_storage(end_user_id, memory_config, time): @@ -291,9 +262,7 @@ async def aggregate_judgment(end_user_id: str, ori_messages: list, memory_config return result_dict except Exception as e: - print(f"[aggregate_judgment] 发生错误: {e}") - import traceback - traceback.print_exc() + logger.error(f"[aggregate_judgment] 发生错误: {e}", exc_info=True) return { "is_same_event": False, diff --git a/api/app/core/memory/agent/langgraph_graph/write_graph.py b/api/app/core/memory/agent/langgraph_graph/write_graph.py index bf3c6597..32fc7d8a 100644 --- a/api/app/core/memory/agent/langgraph_graph/write_graph.py +++ b/api/app/core/memory/agent/langgraph_graph/write_graph.py @@ -1,49 +1,25 @@ -import asyncio -import json -import sys import warnings -from contextlib import asynccontextmanager -from langgraph.constants import END, START -from langgraph.graph import StateGraph -from app.db import get_db, get_db_context from app.core.logging_config import get_agent_logger -from app.core.memory.agent.utils.llm_tools import WriteState -from app.core.memory.agent.langgraph_graph.nodes.write_nodes import write_node +from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \ + aggregate_judgment +from app.core.memory.agent.utils.redis_tool import write_store +from app.db import get_db_context from app.schemas.memory_agent_schema import AgentMemory_Long_Term from app.services.memory_config_service import MemoryConfigService +from app.services.memory_konwledges_server import write_rag warnings.filterwarnings("ignore", category=RuntimeWarning) logger = get_agent_logger(__name__) -if sys.platform.startswith("win"): - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) - -@asynccontextmanager -async def make_write_graph(): - """ - Create a write graph workflow for memory operations. - - Args: - user_id: User identifier - tools: MCP tools loaded from session - apply_id: Application identifier - end_user_id: Group identifier - memory_config: MemoryConfig object containing all configuration - """ - workflow = StateGraph(WriteState) - workflow.add_node("save_neo4j", write_node) - workflow.add_edge(START, "save_neo4j") - workflow.add_edge("save_neo4j", END) - - graph = workflow.compile() - - yield graph - - -async def long_term_storage(long_term_type: str = "chunk", langchain_messages: list = [], memory_config: str = '', - end_user_id: str = '', scope: int = 6): +async def long_term_storage( + long_term_type: str, + langchain_messages: list, + memory_config_id: str, + end_user_id: str, + scope: int = 6 +): """ Handle long-term memory storage with different strategies @@ -53,33 +29,39 @@ async def long_term_storage(long_term_type: str = "chunk", langchain_messages: l Args: long_term_type: Storage strategy type ('chunk', 'time', 'aggregate') langchain_messages: List of messages to store - memory_config: Memory configuration identifier + memory_config_id: Memory configuration identifier end_user_id: User group identifier scope: Scope parameter for chunk-based storage (default: 6) """ - from app.core.memory.agent.langgraph_graph.routing.write_router import memory_long_term_storage, window_dialogue, \ - aggregate_judgment - from app.core.memory.agent.utils.redis_tool import write_store + if langchain_messages is None: + langchain_messages = [] + write_store.save_session_write(end_user_id, langchain_messages) # 获取数据库会话 with get_db_context() as db_session: config_service = MemoryConfigService(db_session) memory_config = config_service.load_memory_config( - config_id=memory_config, # 改为整数 + config_id=memory_config_id, # 改为整数 service_name="MemoryAgentService" ) if long_term_type == AgentMemory_Long_Term.STRATEGY_CHUNK: - '''Strategy 1: Dialogue window with 6 rounds of conversation''' + # Dialogue window with 6 rounds of conversation await window_dialogue(end_user_id, langchain_messages, memory_config, scope) if long_term_type == AgentMemory_Long_Term.STRATEGY_TIME: - """Time-based strategy""" + # Time-based strategy await memory_long_term_storage(end_user_id, memory_config, AgentMemory_Long_Term.TIME_SCOPE) if long_term_type == AgentMemory_Long_Term.STRATEGY_AGGREGATE: - """Strategy 3: Aggregate judgment""" + # Aggregate judgment await aggregate_judgment(end_user_id, langchain_messages, memory_config) -async def write_long_term(storage_type, end_user_id, message_chat, aimessages, user_rag_memory_id, actual_config_id): +async def write_long_term( + storage_type: str, + end_user_id: str, + messages: list[dict], + user_rag_memory_id: str, + actual_config_id: str +): """ Write long-term memory with different storage types @@ -89,44 +71,24 @@ async def write_long_term(storage_type, end_user_id, message_chat, aimessages, u Args: storage_type: Type of storage (RAG or traditional) end_user_id: User group identifier - message_chat: User message content - aimessages: AI response messages + messages: message list user_rag_memory_id: RAG memory identifier actual_config_id: Actual configuration ID """ - from app.core.memory.agent.langgraph_graph.routing.write_router import write_rag_agent from app.core.memory.agent.langgraph_graph.routing.write_router import term_memory_save - from app.core.memory.agent.langgraph_graph.tools.write_tool import agent_chat_messages if storage_type == AgentMemory_Long_Term.STORAGE_RAG: - await write_rag_agent(end_user_id, message_chat, aimessages, user_rag_memory_id) + message_content = [] + for message in messages: + message_content.append(f'{message.get("role")}:{message.get("content")}') + messages_string = "\n".join(message_content) + await write_rag(end_user_id, messages_string, user_rag_memory_id) else: # AI reply writing (user messages and AI replies paired, written as complete dialogue at once) CHUNK = AgentMemory_Long_Term.STRATEGY_CHUNK SCOPE = AgentMemory_Long_Term.DEFAULT_SCOPE - long_term_messages = await agent_chat_messages(message_chat, aimessages) - await long_term_storage(long_term_type=CHUNK, langchain_messages=long_term_messages, - memory_config=actual_config_id, end_user_id=end_user_id, scope=SCOPE) - await term_memory_save(long_term_messages, actual_config_id, end_user_id, CHUNK, scope=SCOPE) - -# async def main(): -# """主函数 - 运行工作流""" -# langchain_messages = [ -# { -# "role": "user", -# "content": "今天周五去爬山" -# }, -# { -# "role": "assistant", -# "content": "好耶" -# } -# -# ] -# end_user_id = '837fee1b-04a2-48ee-94d7-211488908940' # 组ID -# memory_config="08ed205c-0f05-49c3-8e0c-a580d28f5fd4" -# await long_term_storage(long_term_type="chunk",langchain_messages=langchain_messages,memory_config=memory_config,end_user_id=end_user_id,scope=2) -# -# -# -# if __name__ == "__main__": -# import asyncio -# asyncio.run(main()) + await long_term_storage(long_term_type=CHUNK, + langchain_messages=messages, + memory_config_id=actual_config_id, + end_user_id=end_user_id, + scope=SCOPE) + await term_memory_save(end_user_id, CHUNK, scope=SCOPE) diff --git a/api/app/core/memory/agent/utils/redis_tool.py b/api/app/core/memory/agent/utils/redis_tool.py index c5729628..82b22c9e 100644 --- a/api/app/core/memory/agent/utils/redis_tool.py +++ b/api/app/core/memory/agent/utils/redis_tool.py @@ -3,8 +3,9 @@ import uuid from app.core.config import settings from typing import List, Dict, Any, Optional, Union +from app.core.logging_config import get_logger from app.core.memory.agent.utils.redis_base import ( - serialize_messages, + serialize_messages, deserialize_messages, fix_encoding, format_session_data, @@ -14,12 +15,12 @@ from app.core.memory.agent.utils.redis_base import ( get_current_timestamp ) - +logger = get_logger(__name__) class RedisWriteStore: """Redis Write 类型存储类,用于管理 save_session_write 相关的数据""" - + def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): """ 初始化 Redis 连接 @@ -66,10 +67,10 @@ class RedisWriteStore: }) result = pipe.execute() - print(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}") + logger.debug(f"[save_session_write] 保存结果: {result[0]}, session_id: {session_id}") return session_id except Exception as e: - print(f"[save_session_write] 保存会话失败: {e}") + logger.error(f"[save_session_write] 保存会话失败: {e}") raise e def get_session_by_userid(self, userid: str) -> Union[List[Dict[str, str]], bool]: @@ -99,7 +100,7 @@ class RedisWriteStore: for key, data in zip(keys, all_data): if not data: continue - + # 从 write 类型读取,匹配 sessionid 字段 if data.get('sessionid') == userid: # 从 key 中提取 session_id: session:write:{session_id} @@ -108,16 +109,16 @@ class RedisWriteStore: "sessionid": session_id, "messages": fix_encoding(data.get('messages', '')) }) - + if not results: return False - - print(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据") + + logger.debug(f"[get_session_by_userid] userid={userid}, 找到 {len(results)} 条数据") return results except Exception as e: - print(f"[get_session_by_userid] 查询失败: {e}") + logger.error(f"[get_session_by_userid] 查询失败: {e}") return False - + def get_all_sessions_by_end_user_id(self, end_user_id: str) -> Union[List[Dict[str, Any]], bool]: """ 通过 end_user_id 获取所有 write 类型的会话数据 @@ -144,7 +145,7 @@ class RedisWriteStore: # 只查询 write 类型的 key keys = self.r.keys('session:write:*') if not keys: - print(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话") + logger.debug(f"[get_all_sessions_by_end_user_id] 没有找到任何 write 类型的会话") return False # 批量获取数据 @@ -158,12 +159,12 @@ class RedisWriteStore: for key, data in zip(keys, all_data): if not data: continue - + # 从 write 类型读取,匹配 sessionid 字段 if data.get('sessionid') == end_user_id: # 从 key 中提取 session_id: session:write:{session_id} session_id = key.split(':')[-1] - + # 构建完整的会话信息 session_info = { "session_id": session_id, @@ -173,23 +174,21 @@ class RedisWriteStore: "starttime": data.get('starttime', '') } results.append(session_info) - + if not results: - print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据") + logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 没有找到数据") return False - + # 按时间排序(最新的在前) results.sort(key=lambda x: x.get('starttime', ''), reverse=True) - - print(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据") + + logger.debug(f"[get_all_sessions_by_end_user_id] end_user_id={end_user_id}, 找到 {len(results)} 条数据") return results except Exception as e: - print(f"[get_all_sessions_by_end_user_id] 查询失败: {e}") - import traceback - traceback.print_exc() + logger.error(f"[get_all_sessions_by_end_user_id] 查询失败: {e}", exc_info=True) return False - def find_user_recent_sessions(self, userid: str, + def find_user_recent_sessions(self, userid: str, minutes: int = 5) -> List[Dict[str, str]]: """ 根据 userid 从 save_session_write 写入的数据中查询最近 N 分钟内的会话数据 @@ -203,11 +202,11 @@ class RedisWriteStore: """ import time start_time = time.time() - + # 只查询 write 类型的 key keys = self.r.keys('session:write:*') if not keys: - print(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") + logger.debug(f"[find_user_recent_sessions] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") return [] # 批量获取数据 @@ -221,7 +220,7 @@ class RedisWriteStore: for data in all_data: if not data: continue - + # 从 write 类型读取,匹配 sessionid 字段 if data.get('sessionid') == userid and data.get('starttime'): # write 类型没有 aimessages,所以 Answer 为空 @@ -230,15 +229,14 @@ class RedisWriteStore: "Answer": "", "starttime": data.get('starttime', '') }) - + # 根据时间范围过滤 filtered_items = filter_by_time_range(matched_items, minutes) # 排序并移除时间字段 - result_items = sort_and_limit_results(filtered_items, limit=None) - print(result_items) + result_items = sort_and_limit_results(filtered_items) elapsed_time = time.time() - start_time - print(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, " + logger.debug(f"[find_user_recent_sessions] userid={userid}, minutes={minutes}, " f"查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") return result_items @@ -258,7 +256,7 @@ class RedisWriteStore: class RedisCountStore: """Redis Count 类型存储类,用于管理访问次数统计相关的数据""" - + def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): """ 初始化 Redis 连接 @@ -278,7 +276,7 @@ class RedisCountStore: decode_responses=True, encoding='utf-8' ) - self.uudi = session_id + self.uuid = session_id def save_sessions_count(self, end_user_id: str, count: int, messages: Any) -> str: """ @@ -295,26 +293,26 @@ class RedisCountStore: session_id = str(uuid.uuid4()) key = generate_session_key(session_id, key_type="count") index_key = f'session:count:index:{end_user_id}' # 索引键 - + pipe = self.r.pipeline() pipe.hset(key, mapping={ - "id": self.uudi, + "id": self.uuid, "end_user_id": end_user_id, "count": int(count), "messages": serialize_messages(messages), "starttime": get_current_timestamp() }) pipe.expire(key, 30 * 24 * 60 * 60) # 30天过期 - + # 创建索引:end_user_id -> session_id 映射 pipe.set(index_key, session_id, ex=30 * 24 * 60 * 60) - + result = pipe.execute() - - print(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}") + + logger.debug(f"[save_sessions_count] 保存结果: {result}, session_id: {session_id}") return session_id - def get_sessions_count(self, end_user_id: str) -> Union[List[Any], bool]: + def get_sessions_count(self, end_user_id: str) -> tuple[int, list[dict]] | bool: """ 通过 end_user_id 查询访问次数统计 @@ -327,7 +325,7 @@ class RedisCountStore: try: # 使用索引键快速查找 index_key = f'session:count:index:{end_user_id}' - + # 检查索引键类型,避免 WRONGTYPE 错误 try: key_type = self.r.type(index_key) @@ -335,35 +333,40 @@ class RedisCountStore: self.r.delete(index_key) return False except Exception as type_error: - print(f"[get_sessions_count] 检查键类型失败: {type_error}") - + logger.error(f"[get_sessions_count] 检查键类型失败: {type_error}") + session_id = self.r.get(index_key) - + if not session_id: return False - + # 直接获取数据 key = generate_session_key(session_id, key_type="count") data = self.r.hgetall(key) - + if not data: # 索引存在但数据不存在,清理索引 self.r.delete(index_key) return False - + count = data.get('count') messages_str = data.get('messages') - + if count is not None: - messages = deserialize_messages(messages_str) - return [int(count), messages] - + messages: list[dict] = deserialize_messages(messages_str) + return int(count), messages + return False except Exception as e: - print(f"[get_sessions_count] 查询失败: {e}") + logger.error(f"[get_sessions_count] 查询失败: {e}") return False - def update_sessions_count(self, end_user_id: str, new_count: int, - messages: Any) -> bool: + + def update_sessions_count( + self, + end_user_id: str, + new_count: int, + messages: Any + ) -> bool: """ 通过 end_user_id 修改访问次数统计(优化版:使用索引) @@ -378,39 +381,39 @@ class RedisCountStore: try: # 使用索引键快速查找 index_key = f'session:count:index:{end_user_id}' - + # 检查索引键类型,避免 WRONGTYPE 错误 try: key_type = self.r.type(index_key) if key_type != 'string' and key_type != 'none': # 索引键类型错误,删除并返回 False - print(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引") + logger.warning(f"[update_sessions_count] 索引键类型错误: {key_type},删除索引") self.r.delete(index_key) - print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") + logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") return False except Exception as type_error: - print(f"[update_sessions_count] 检查键类型失败: {type_error}") - + logger.error(f"[update_sessions_count] 检查键类型失败: {type_error}") + session_id = self.r.get(index_key) - + if not session_id: - print(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") + logger.debug(f"[update_sessions_count] 未找到记录: end_user_id={end_user_id}") return False - + # 直接更新数据 key = generate_session_key(session_id, key_type="count") messages_str = serialize_messages(messages) - + pipe = self.r.pipeline() - pipe.hset(key, 'count', int(new_count)) + pipe.hset(key, 'count', str(new_count)) pipe.hset(key, 'messages', messages_str) result = pipe.execute() - - print(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}") + + logger.debug(f"[update_sessions_count] 更新成功: end_user_id={end_user_id}, new_count={new_count}, key={key}") return True - + except Exception as e: - print(f"[update_sessions_count] 更新失败: {e}") + logger.debug(f"[update_sessions_count] 更新失败: {e}") return False def delete_all_count_sessions(self) -> int: @@ -428,7 +431,7 @@ class RedisCountStore: class RedisSessionStore: """Redis 会话存储类,用于管理会话数据""" - + def __init__(self, host='localhost', port=6379, db=0, password=None, session_id=''): """ 初始化 Redis 连接 @@ -451,9 +454,9 @@ class RedisSessionStore: self.uudi = session_id # ==================== 写入操作 ==================== - - def save_session(self, userid: str, messages: str, aimessages: str, - apply_id: str, end_user_id: str) -> str: + + def save_session(self, userid: str, messages: str, aimessages: str, + apply_id: str, end_user_id: str) -> str: """ 写入一条会话数据,返回 session_id @@ -483,14 +486,14 @@ class RedisSessionStore: }) result = pipe.execute() - print(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}") + logger.debug(f"[save_session] 保存结果: {result[0]}, session_id: {session_id}") return session_id except Exception as e: - print(f"[save_session] 保存会话失败: {e}") + logger.error(f"[save_session] 保存会话失败: {e}") raise e # ==================== 读取操作 ==================== - + def get_session(self, session_id: str) -> Optional[Dict[str, Any]]: """ 读取一条会话数据 @@ -520,8 +523,8 @@ class RedisSessionStore: sessions[sid] = self.get_session(sid) return sessions - def find_user_apply_group(self, sessionid: str, apply_id: str, - end_user_id: str) -> List[Dict[str, str]]: + def find_user_apply_group(self, sessionid: str, apply_id: str, + end_user_id: str) -> List[Dict[str, str]]: """ 根据 sessionid、apply_id 和 end_user_id 查询会话数据,返回最新的6条 @@ -535,10 +538,10 @@ class RedisSessionStore: """ import time start_time = time.time() - + keys = self.r.keys('session:*') if not keys: - print(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") + logger.debug(f"[find_user_apply_group] 查询耗时: {time.time() - start_time:.3f}秒, 结果数: 0") return [] # 批量获取数据 @@ -556,21 +559,21 @@ class RedisSessionStore: continue if (data.get('apply_id') == apply_id and - data.get('end_user_id') == end_user_id): + data.get('end_user_id') == end_user_id): # 支持模糊匹配或完全匹配 sessionid if sessionid in data.get('sessionid', '') or data.get('sessionid') == sessionid: matched_items.append(format_session_data(data, include_time=True)) - + # 排序、限制数量并移除时间字段 result_items = sort_and_limit_results(matched_items, limit=6) elapsed_time = time.time() - start_time - print(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") + logger.debug(f"[find_user_apply_group] 查询耗时: {elapsed_time:.3f}秒, 结果数: {len(result_items)}") return result_items # ==================== 更新操作 ==================== - + def update_session(self, session_id: str, field: str, value: Any) -> bool: """ 更新单个字段 @@ -591,7 +594,7 @@ class RedisSessionStore: return bool(results[0]) # ==================== 删除操作 ==================== - + def delete_session(self, session_id: str) -> int: """ 删除单条会话 @@ -632,7 +635,7 @@ class RedisSessionStore: keys = self.r.keys('session:*') if not keys: - print("[delete_duplicate_sessions] 没有会话数据") + logger.debug("[delete_duplicate_sessions] 没有会话数据") return 0 # 批量获取所有数据 @@ -678,7 +681,7 @@ class RedisSessionStore: deleted_count += len(batch) elapsed_time = time.time() - start_time - print(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒") + logger.debug(f"[delete_duplicate_sessions] 删除重复会话数量: {deleted_count}, 耗时: {elapsed_time:.3f}秒") return deleted_count diff --git a/api/app/core/memory/llm_tools/llm_client.py b/api/app/core/memory/llm_tools/llm_client.py index e26aba3e..49cd9434 100644 --- a/api/app/core/memory/llm_tools/llm_client.py +++ b/api/app/core/memory/llm_tools/llm_client.py @@ -56,7 +56,7 @@ class LLMClient(ABC): self.max_retries = self.config.max_retries self.timeout = self.config.timeout - logger.info( + logger.debug( f"初始化 LLM 客户端: provider={self.provider}, " f"model={self.model_name}, max_retries={self.max_retries}" ) diff --git a/api/app/schemas/memory_agent_schema.py b/api/app/schemas/memory_agent_schema.py index b4efe61d..97aa5bb5 100644 --- a/api/app/schemas/memory_agent_schema.py +++ b/api/app/schemas/memory_agent_schema.py @@ -17,6 +17,7 @@ class Write_UserInput(BaseModel): end_user_id: str config_id: Optional[str] = None + class AgentMemory_Long_Term(ABC): """长期记忆配置常量""" STORAGE_NEO4J = "neo4j" @@ -25,8 +26,9 @@ class AgentMemory_Long_Term(ABC): STRATEGY_CHUNK = "chunk" STRATEGY_TIME = "time" DEFAULT_SCOPE = 6 - TIME_SCOPE=5 -class AgentMemoryDataset(ABC): - PRONOUN=['我','本人','在下','自己','咱','鄙人','吴','余'] - NAME='用户' + TIME_SCOPE = 5 + +class AgentMemoryDataset(ABC): + PRONOUN = ['我', '本人', '在下', '自己', '咱', '鄙人', '吴', '余'] + NAME = '用户' diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 90474428..17c2f98c 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -10,6 +10,7 @@ from sqlalchemy.orm import Session from app.core.agent.langchain_agent import LangChainAgent from app.core.logging_config import get_business_logger +from app.core.memory.agent.langgraph_graph.write_graph import write_long_term from app.db import get_db from app.models import MultiAgentConfig, AgentConfig, ModelType from app.models import WorkflowConfig @@ -20,11 +21,11 @@ from app.schemas.model_schema import ModelInfo from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole from app.services.conversation_service import ConversationService from app.services.draft_run_service import AgentRunService +from app.services.memory_agent_service import get_end_user_connected_config from app.services.model_service import ModelApiKeyService from app.services.multi_agent_orchestrator import MultiAgentOrchestrator from app.services.multimodal_service import MultimodalService from app.services.workflow_service import WorkflowService -from app.schemas import FileType logger = get_business_logger() @@ -43,18 +44,17 @@ class AppChatService: message: str, conversation_id: uuid.UUID, config: AgentConfig, - user_id: Optional[str] = None, + files: list[FileInput], + user_id: str, variables: Optional[Dict[str, Any]] = None, web_search: bool = False, memory: bool = True, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, - workspace_id: Optional[str] = None, - files: Optional[List[FileInput]] = None + workspace_id: Optional[str] = None ) -> Dict[str, Any]: """聊天(非流式)""" start_time = time.time() - config_id = None # 应用 features 配置 features_config: dict = config.features or {} @@ -93,7 +93,8 @@ class AppChatService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) + kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, + user_id) tools.extend(kb_tools) memory_flag = False if memory: @@ -168,11 +169,6 @@ class AppChatService: message=message, history=history, context=None, - end_user_id=user_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - config_id=config_id, - memory_flag=memory_flag, files=processed_files # 传递处理后的文件 ) @@ -229,6 +225,21 @@ class AppChatService: # 保存消息 if audio_url: assistant_meta["audio_url"] = audio_url + if memory_flag: + connected_config = get_end_user_connected_config(user_id, self.db) + memory_config_id: str = connected_config.get("memory_config_id") + messages = [ + {"role": "user", "content": message, "files": [file.model_dump() for file in files]}, + {"role": "assistant", "content": result["content"]} + ] + if memory_config_id: + await write_long_term( + storage_type, + user_id, + messages, + user_rag_memory_id, + memory_config_id + ) self.conversation_service.add_message( conversation_id=conversation_id, role="user", @@ -264,20 +275,19 @@ class AppChatService: message: str, conversation_id: uuid.UUID, config: AgentConfig, + files: list[FileInput], user_id: Optional[str] = None, variables: Optional[Dict[str, Any]] = None, web_search: bool = False, memory: bool = True, storage_type: Optional[str] = None, user_rag_memory_id: Optional[str] = None, - workspace_id: Optional[str] = None, - files: Optional[List[FileInput]] = None + workspace_id: Optional[str] = None ) -> AsyncGenerator[str, None]: """聊天(流式)""" try: start_time = time.time() - config_id = None message_id = uuid.uuid4() # 应用 features 配置 @@ -319,7 +329,8 @@ class AppChatService: tools.extend(skill_tools) if skill_prompts: system_prompt = f"{system_prompt}\n\n{skill_prompts}" - kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config(config.knowledge_retrieval, user_id) + kb_tools, citations_collector = self.agent_service.load_knowledge_retrieval_config( + config.knowledge_retrieval, user_id) tools.extend(kb_tools) # 添加长期记忆工具 memory_flag = False @@ -411,11 +422,6 @@ class AppChatService: message=message, history=history, context=None, - end_user_id=user_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - config_id=config_id, - memory_flag=memory_flag, files=processed_files ): if isinstance(chunk, int): @@ -459,7 +465,7 @@ class AppChatService: # 保存消息 human_meta = { - "files":[], + "files": [], "history_files": {} } assistant_meta = { @@ -484,6 +490,22 @@ class AppChatService: if stream_audio_url: assistant_meta["audio_url"] = stream_audio_url + + if memory_flag: + connected_config = get_end_user_connected_config(user_id, self.db) + memory_config_id: str = connected_config.get("memory_config_id") + messages = [ + {"role": "user", "content": message, "files": [file.model_dump() for file in files]}, + {"role": "assistant", "content": full_content} + ] + if memory_config_id: + await write_long_term( + storage_type, + user_id, + messages, + user_rag_memory_id, + memory_config_id + ) self.conversation_service.add_message( conversation_id=conversation_id, role="user", @@ -618,7 +640,6 @@ class AppChatService: # 2. 创建编排器 orchestrator = MultiAgentOrchestrator(self.db, config) - # 3. 流式执行任务 async for event in orchestrator.execute_stream( message=message, diff --git a/api/app/services/draft_run_service.py b/api/app/services/draft_run_service.py index e188872f..aef54847 100644 --- a/api/app/services/draft_run_service.py +++ b/api/app/services/draft_run_service.py @@ -24,7 +24,7 @@ from app.core.exceptions import BusinessException from app.core.logging_config import get_business_logger from app.core.rag.nlp.search import knowledge_retrieval from app.db import get_db_context -from app.models import AgentConfig, ModelConfig, ModelType +from app.models import AgentConfig, ModelConfig from app.repositories.tool_repository import ToolRepository from app.schemas.app_schema import FileInput, Citation from app.schemas.model_schema import ModelInfo @@ -37,7 +37,6 @@ from app.services.model_parameter_merger import ModelParameterMerger from app.services.model_service import ModelApiKeyService from app.services.multimodal_service import MultimodalService from app.services.tool_service import ToolService -from app.schemas import FileType logger = get_business_logger() @@ -657,11 +656,6 @@ class AgentRunService: message=message, history=history, context=context, - end_user_id=user_id, - config_id=config_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_flag=memory_flag, files=processed_files # 传递处理后的文件 ) @@ -911,11 +905,6 @@ class AgentRunService: message=message, history=history, context=context, - end_user_id=user_id, - config_id=config_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - memory_flag=memory_flag, files=processed_files ): if isinstance(chunk, int): diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index 3ee238e2..5c838fc0 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -243,27 +243,6 @@ class MemoryPerceptualService: memory_config: MemoryConfig, file: FileInput ): - memories = self.repository.get_by_url(file.url) - if memories: - business_logger.info(f"Perceptual memory already exists: {file.url}") - if end_user_id not in [memory.end_user_id for memory in memories]: - business_logger.info(f"Copy perceptual memory end_user_id: {end_user_id}") - memory_cache = memories[0] - memory = self.repository.create_perceptual_memory( - end_user_id=uuid.UUID(end_user_id), - perceptual_type=PerceptualType(memory_cache.perceptual_type), - file_path=memory_cache.file_path, - file_name=memory_cache.file_name, - file_ext=memory_cache.file_ext, - summary=memory_cache.summary, - meta_data=memory_cache.meta_data - ) - self.db.commit() - return memory - else: - for memory in memories: - if memory.end_user_id == uuid.UUID(end_user_id): - return memory llm, model_config = self._get_mutlimodal_client(file.type, memory_config) multimodel_service = MultimodalService(self.db, ModelInfo( model_name=model_config.model_name, diff --git a/api/app/services/model_service.py b/api/app/services/model_service.py index b98674ba..c9266667 100644 --- a/api/app/services/model_service.py +++ b/api/app/services/model_service.py @@ -69,7 +69,8 @@ class ModelConfigService: return items @staticmethod - def get_model_by_name(db: Session, name: str, provider: str | None = None, tenant_id: uuid.UUID | None = None) -> ModelConfig: + def get_model_by_name(db: Session, name: str, provider: str | None = None, + tenant_id: uuid.UUID | None = None) -> ModelConfig: """根据名称获取模型配置""" model = ModelConfigRepository.get_by_name(db, name, provider=provider, tenant_id=tenant_id) if not model: @@ -77,21 +78,22 @@ class ModelConfigService: return model @staticmethod - def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ModelConfig]: + def search_models_by_name(db: Session, name: str, tenant_id: uuid.UUID | None = None, limit: int = 10) -> List[ + ModelConfig]: """按名称模糊匹配获取模型配置列表""" return ModelConfigRepository.search_by_name(db, name, tenant_id=tenant_id, limit=limit) @staticmethod async def validate_model_config( - db: Session, - *, - model_name: str, - provider: str, - api_key: str, - api_base: Optional[str] = None, - model_type: str = "llm", - test_message: str = "Hello", - is_omni: bool = False + db: Session, + *, + model_name: str, + provider: str, + api_key: str, + api_base: Optional[str] = None, + model_type: str = "llm", + test_message: str = "Hello", + is_omni: bool = False ) -> Dict[str, Any]: """验证模型配置是否有效 @@ -158,13 +160,13 @@ class ModelConfigService: # 统一使用 RedBearEmbeddings(自动支持火山引擎多模态) embedding = RedBearEmbeddings(model_config) test_texts = [test_message, "测试文本"] - + # 火山引擎使用 embed_batch,其他使用 embed_documents if provider.lower() == "volcano": vectors = await asyncio.to_thread(embedding.embed_batch, test_texts) else: vectors = await asyncio.to_thread(embedding.embed_documents, test_texts) - + elapsed_time = time.time() - start_time return { @@ -200,11 +202,11 @@ class ModelConfigService: }, "error": None } - + elif model_type_lower == "image": # 图片生成模型验证 from app.core.models.generation import RedBearImageGenerator - + generator = RedBearImageGenerator(model_config) result = await generator.agenerate( prompt="a cute panda", @@ -212,7 +214,7 @@ class ModelConfigService: ) elapsed_time = time.time() - start_time logger.info(f"成功生成图片,结果: {result}") - + return { "valid": True, "message": "图片生成模型配置验证成功", @@ -224,21 +226,21 @@ class ModelConfigService: }, "error": None } - + elif model_type_lower == "video": # 视频生成模型验证 from app.core.models.generation import RedBearVideoGenerator - + generator = RedBearVideoGenerator(model_config) result = await generator.agenerate( prompt="a cute panda playing in bamboo forest", duration=5 ) elapsed_time = time.time() - start_time - + # 视频生成是异步任务,返回任务ID task_id = result.get("task_id") if isinstance(result, dict) else None - + return { "valid": True, "message": "视频生成模型配置验证成功", @@ -265,7 +267,6 @@ class ModelConfigService: # 提取详细的错误信息 error_message = str(e) error_type = type(e).__name__ - print("=========error_message:",error_message.lower()) # 特殊处理常见的错误类型 if "unsupported countries" in error_message.lower() or "unsupported region" in error_message.lower(): # 区域/国家限制(适用于所有提供商) @@ -354,14 +355,16 @@ class ModelConfigService: return model @staticmethod - def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, tenant_id: uuid.UUID | None = None) -> ModelConfig: + def update_model(db: Session, model_id: uuid.UUID, model_data: ModelConfigUpdate, + tenant_id: uuid.UUID | None = None) -> ModelConfig: """更新模型配置""" existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id) if not existing_model: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) if model_data.name and model_data.name != existing_model.name: - if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id): + if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, + tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) model = ModelConfigRepository.update(db, model_id, model_data, tenant_id=tenant_id) @@ -370,25 +373,27 @@ class ModelConfigService: return model @staticmethod - async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig: + async def create_composite_model(db: Session, model_data: model_schema.CompositeModelCreate, + tenant_id: uuid.UUID) -> ModelConfig: """创建组合模型""" - if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, tenant_id=tenant_id): + if ModelConfigRepository.get_by_name(db, model_data.name, provider=ModelProvider.COMPOSITE, + tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) - + # 验证所有 API Key 存在且类型匹配 for api_key_id in model_data.api_key_ids: api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if not api_key: raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND) - + # 检查 API Key 关联的模型配置类型 for model_config in api_key.model_configs: # chat 和 llm 类型可以兼容 compatible_types = {ModelType.LLM, ModelType.CHAT} config_type = model_config.type request_type = model_data.type - - if not (config_type == request_type or + + if not (config_type == request_type or (config_type in compatible_types and request_type in compatible_types)): raise BusinessException( f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配", @@ -399,7 +404,7 @@ class ModelConfigService: # f"API Key {api_key_id} 关联的模型是组合模型,不能用于创建新的组合模型", # BizCode.INVALID_PARAMETER # ) - + # 创建组合模型 model_config_data = { "tenant_id": tenant_id, @@ -418,49 +423,51 @@ class ModelConfigService: model = ModelConfigRepository.create(db, model_config_data) db.flush() - + # 关联 API Keys for api_key_id in model_data.api_key_ids: api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if api_key: model.api_keys.append(api_key) - + db.commit() db.refresh(model) return model @staticmethod - async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, tenant_id: uuid.UUID) -> ModelConfig: + async def update_composite_model(db: Session, model_id: uuid.UUID, model_data: model_schema.CompositeModelCreate, + tenant_id: uuid.UUID) -> ModelConfig: """更新组合模型""" existing_model = ModelConfigRepository.get_by_id(db, model_id, tenant_id=tenant_id) if not existing_model: raise BusinessException("模型配置不存在", BizCode.MODEL_NOT_FOUND) if model_data.name and model_data.name != existing_model.name: - if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, tenant_id=tenant_id): + if ModelConfigRepository.get_by_name(db, model_data.name, provider=existing_model.provider, + tenant_id=tenant_id): raise BusinessException("模型名称已存在", BizCode.DUPLICATE_NAME) - + if not existing_model.is_composite: raise BusinessException("该模型不是组合模型", BizCode.INVALID_PARAMETER) - + # 验证所有 API Key 存在且类型匹配 for api_key_id in model_data.api_key_ids: api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if not api_key: raise BusinessException(f"API Key {api_key_id} 不存在", BizCode.NOT_FOUND) - + for model_config in api_key.model_configs: compatible_types = {ModelType.LLM, ModelType.CHAT} config_type = model_config.type request_type = existing_model.type - - if not (config_type == request_type or + + if not (config_type == request_type or (config_type in compatible_types and request_type in compatible_types)): raise BusinessException( f"API Key {api_key_id} 关联的模型类型 ({model_config.type}) 与组合模型类型 ({model_data.type}) 不匹配", BizCode.INVALID_PARAMETER ) - + # 更新基本信息 existing_model.name = model_data.name # existing_model.type = model_data.type @@ -471,14 +478,14 @@ class ModelConfigService: existing_model.is_public = model_data.is_public if "load_balance_strategy" in model_data.model_fields_set: existing_model.load_balance_strategy = model_data.load_balance_strategy - + # 更新 API Keys 关联 existing_model.api_keys.clear() for api_key_id in model_data.api_key_ids: api_key = ModelApiKeyRepository.get_by_id(db, api_key_id) if api_key: existing_model.api_keys.append(api_key) - + db.commit() db.refresh(existing_model) return existing_model @@ -532,7 +539,7 @@ class ModelApiKeyService: """根据provider为多个ModelConfig创建API Key""" created_keys = [] failed_models = [] # 记录验证失败的模型 - + for model_config_id in data.model_config_ids: model_config = ModelConfigRepository.get_by_id(db, model_config_id) if not model_config: @@ -540,10 +547,10 @@ class ModelApiKeyService: data.is_omni = model_config.is_omni data.capability = model_config.capability - + # 从ModelBase获取model_name model_name = model_config.model_base.name if model_config.model_base else model_config.name - + # 检查是否存在API Key(包括软删除),需要考虑tenant_id existing_key = db.query(ModelApiKey).join( ModelApiKey.model_configs @@ -553,7 +560,7 @@ class ModelApiKeyService: ModelApiKey.model_name == model_name, ModelConfig.tenant_id == model_config.tenant_id ).first() - + if existing_key: # 如果已存在,重新激活并更新 if existing_key.is_active: @@ -566,14 +573,14 @@ class ModelApiKeyService: existing_key.model_name = model_name existing_key.capability = data.capability existing_key.is_omni = data.is_omni - + # 检查是否已关联该模型配置 if model_config not in existing_key.model_configs: existing_key.model_configs.append(model_config) - + created_keys.append(existing_key) continue - + # 验证配置 validation_result = await ModelConfigService.validate_model_config( db=db, @@ -589,7 +596,7 @@ class ModelApiKeyService: # 记录验证失败的模型,但不抛出异常 failed_models.append(model_name) continue - + # 创建API Key api_key_data = ModelApiKeyCreate( model_config_ids=[model_config_id], @@ -606,12 +613,12 @@ class ModelApiKeyService: ) api_key_obj = ModelApiKeyRepository.create(db, api_key_data) created_keys.append(api_key_obj) - + if created_keys: db.commit() for key in created_keys: db.refresh(key) - + return created_keys, failed_models @staticmethod @@ -626,7 +633,7 @@ class ModelApiKeyService: api_key_data.is_omni = model_config.is_omni if api_key_data.capability is None: api_key_data.capability = model_config.capability - + # 检查API Key是否已存在(包括软删除),需要考虑tenant_id existing_key = db.query(ModelApiKey).join( ModelApiKey.model_configs @@ -650,15 +657,15 @@ class ModelApiKeyService: existing_key.model_name = api_key_data.model_name existing_key.capability = api_key_data.capability existing_key.is_omni = api_key_data.is_omni - + # 检查是否已关联该模型配置 if model_config not in existing_key.model_configs: existing_key.model_configs.append(model_config) - + db.commit() db.refresh(existing_key) return existing_key - + # 验证配置 validation_result = await ModelConfigService.validate_model_config( db=db, @@ -691,7 +698,7 @@ class ModelApiKeyService: # 获取关联的模型配置以获取模型类型 if existing_api_key.model_configs: model_config = existing_api_key.model_configs[0] - + validation_result = await ModelConfigService.validate_model_config( db=db, model_name=api_key_data.model_name or existing_api_key.model_name, @@ -729,15 +736,15 @@ class ModelApiKeyService: model_config = ModelConfigRepository.get_by_id(db, model_config_id) if not model_config: return None - + api_keys = [key for key in model_config.api_keys if key.is_active] if not api_keys: return None - + # 如果是轮询策略,按使用次数最少,次数相同则选最早使用的 if model_config.load_balance_strategy == LoadBalanceStrategy.ROUND_ROBIN: return min(api_keys, key=lambda x: (int(x.usage_count or "0"), x.last_used_at or datetime.min)) - + # 否则返回第一个 return api_keys[0] @@ -760,20 +767,19 @@ class ModelApiKeyService: raise BusinessException("没有可用的 API Key", BizCode.AGENT_CONFIG_MISSING) - class ModelBaseService: """基础模型服务""" @staticmethod def get_model_base_list(db: Session, query: model_schema.ModelBaseQuery, tenant_id: uuid.UUID = None) -> List: models = ModelBaseRepository.get_list(db, query) - + provider_groups = {} for m in models: model_dict = model_schema.ModelBase.model_validate(m).model_dump() if tenant_id: model_dict['is_added'] = ModelBaseRepository.check_added_by_tenant(db, m.id, tenant_id) - + provider = m.provider if provider not in provider_groups: provider_groups[provider] = { @@ -781,7 +787,7 @@ class ModelBaseService: "models": [] } provider_groups[provider]["models"].append(model_dict) - + return list(provider_groups.values()) @staticmethod @@ -823,10 +829,10 @@ class ModelBaseService: model_base = ModelBaseRepository.get_by_id(db, model_base_id) if not model_base: raise BusinessException("基础模型不存在", BizCode.MODEL_NOT_FOUND) - + if ModelBaseRepository.check_added_by_tenant(db, model_base_id, tenant_id): raise BusinessException("模型已添加", BizCode.DUPLICATE_NAME) - + model_config_data = { "model_id": model_base_id, "tenant_id": tenant_id, diff --git a/api/app/services/shared_chat_service.py b/api/app/services/shared_chat_service.py index 0d659832..c74604a5 100644 --- a/api/app/services/shared_chat_service.py +++ b/api/app/services/shared_chat_service.py @@ -1,26 +1,24 @@ """基于分享链接的聊天服务""" -import uuid -import time import asyncio +import json +import time +import uuid from typing import Optional, Dict, Any, AsyncGenerator + +from deprecated import deprecated from sqlalchemy.orm import Session -from app.repositories.model_repository import ModelApiKeyRepository -from app.services.memory_konwledges_server import write_rag +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException, ResourceNotFoundException +from app.core.logging_config import get_business_logger +from app.models import MultiAgentConfig from app.models import ReleaseShare, AppRelease, Conversation +from app.repositories import knowledge_repository from app.services.conversation_service import ConversationService from app.services.draft_run_service import create_web_search_tool from app.services.model_service import ModelApiKeyService -from app.services.release_share_service import ReleaseShareService -from app.core.exceptions import BusinessException, ResourceNotFoundException -from app.core.error_codes import BizCode -from app.core.logging_config import get_business_logger from app.services.multi_agent_service import MultiAgentService -from app.models import MultiAgentConfig -from app.repositories import knowledge_repository -import json -from app.services.task_service import get_task_memory_write_result -from app.tasks import write_message_task +from app.services.release_share_service import ReleaseShareService logger = get_business_logger() @@ -118,6 +116,7 @@ class SharedChatService: return conversation + @deprecated("Use the chat method under app_chat_service instead.") async def chat( self, share_token: str, @@ -136,10 +135,7 @@ class SharedChatService: config_id = actual_config_id from app.core.agent.langchain_agent import LangChainAgent from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool - from app.services.model_parameter_merger import ModelParameterMerger from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole - from sqlalchemy import select - from app.models import ModelApiKey start_time = time.time() actual_config_id = None @@ -273,11 +269,6 @@ class SharedChatService: message=message, history=history, context=None, - end_user_id=user_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - config_id=config_id, - memory_flag=memory_flag ) # 保存消息 @@ -324,6 +315,7 @@ class SharedChatService: "elapsed_time": elapsed_time } + @deprecated("Use the chat method under app_chat_service instead.") async def chat_stream( self, share_token: str, @@ -341,8 +333,6 @@ class SharedChatService: from app.core.agent.langchain_agent import LangChainAgent from app.services.draft_run_service import create_knowledge_retrieval_tool, create_long_term_memory_tool from app.schemas.prompt_schema import render_prompt_message, PromptMessageRole - from sqlalchemy import select - from app.models import ModelApiKey import json start_time = time.time() @@ -486,11 +476,6 @@ class SharedChatService: message=message, history=history, context=None, - end_user_id=user_id, - storage_type=storage_type, - user_rag_memory_id=user_rag_memory_id, - config_id=config_id, - memory_flag=memory_flag ): if isinstance(chunk, int): total_tokens = chunk @@ -585,6 +570,7 @@ class SharedChatService: return conversations, total + @deprecated("Use the chat method under app_chat_service instead.") async def multi_agent_chat( self, share_token: str, @@ -680,6 +666,7 @@ class SharedChatService: "elapsed_time": elapsed_time } + @deprecated("Use the chat method under app_chat_service instead.") async def multi_agent_chat_stream( self, share_token: str, From 13e35ed1228602744c825f5b46c36e36bef487e7 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 30 Mar 2026 15:15:20 +0800 Subject: [PATCH 007/117] feat(web) workflow edge center add add tool --- .../Workflow/components/Nodes/AddNode.tsx | 4 +- .../components/Nodes/ConditionNode.tsx | 2 +- .../components/Nodes/GroupStartNode.tsx | 2 +- .../Workflow/components/Nodes/LoopNode.tsx | 2 +- .../Workflow/components/Nodes/NormalNode.tsx | 2 +- .../Workflow/components/PortClickHandler.tsx | 87 +++++++++++++++---- web/src/views/Workflow/constant.ts | 70 ++++++++++++++- .../views/Workflow/hooks/useWorkflowGraph.ts | 15 +++- 8 files changed, 159 insertions(+), 25 deletions(-) diff --git a/web/src/views/Workflow/components/Nodes/AddNode.tsx b/web/src/views/Workflow/components/Nodes/AddNode.tsx index 9b9d2236..dd0ab23d 100644 --- a/web/src/views/Workflow/components/Nodes/AddNode.tsx +++ b/web/src/views/Workflow/components/Nodes/AddNode.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-09 18:31:30 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-06 11:43:58 + * @Last Modified time: 2026-03-30 11:55:10 */ import { useState } from 'react'; import { Popover, Flex } from 'antd'; @@ -173,7 +173,7 @@ const AddNode: ReactShapeConfig['component'] = ({ node, graph }) => { align="center" justify="center" gap={4} - className={clsx('rb:text-[#212332] rb:font-medium rb:text-[12px] rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:border rb:rounded-lg rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)] rb:border-[#DFE4ED] rb:flex rb:items-center rb:justify-center', { + className={clsx('rb:text-[#212332] rb:font-medium rb:text-[12px] rb:cursor-pointer rb:group rb:relative rb:h-full rb:w-full rb:border rb:rounded-lg rb:bg-[#FCFCFD] rb:shadow-[0px_2px_4px_0px_rgba(23,23,25,0.03)] rb:border-[#FCFCFD] rb:flex rb:items-center rb:justify-center', { 'rb:border-orange-500 rb:border-[3px] rb:bg-[#FCFCFD] rb:text-[#475467]': data.isSelected, 'rb:border-[#d1d5db] rb:bg-[#FCFCFD] rb:text-[#374151]': !data.isSelected })} diff --git a/web/src/views/Workflow/components/Nodes/ConditionNode.tsx b/web/src/views/Workflow/components/Nodes/ConditionNode.tsx index 12ae6ca0..516b5125 100644 --- a/web/src/views/Workflow/components/Nodes/ConditionNode.tsx +++ b/web/src/views/Workflow/components/Nodes/ConditionNode.tsx @@ -48,7 +48,7 @@ const ConditionNode: ReactShapeConfig['component'] = ({ node }) => { return (
diff --git a/web/src/views/Workflow/components/Nodes/GroupStartNode.tsx b/web/src/views/Workflow/components/Nodes/GroupStartNode.tsx index 0f963adc..4a29531f 100644 --- a/web/src/views/Workflow/components/Nodes/GroupStartNode.tsx +++ b/web/src/views/Workflow/components/Nodes/GroupStartNode.tsx @@ -3,7 +3,7 @@ import type { ReactShapeConfig } from '@antv/x6-react-shape'; const GroupStartNode: ReactShapeConfig['component'] = () => { return ( -
+
); diff --git a/web/src/views/Workflow/components/Nodes/LoopNode.tsx b/web/src/views/Workflow/components/Nodes/LoopNode.tsx index b8c2ea0c..29c683cc 100644 --- a/web/src/views/Workflow/components/Nodes/LoopNode.tsx +++ b/web/src/views/Workflow/components/Nodes/LoopNode.tsx @@ -122,7 +122,7 @@ const LoopNode: ReactShapeConfig['component'] = ({ node, graph }) => { return (
diff --git a/web/src/views/Workflow/components/Nodes/NormalNode.tsx b/web/src/views/Workflow/components/Nodes/NormalNode.tsx index 340e95dc..12e89cca 100644 --- a/web/src/views/Workflow/components/Nodes/NormalNode.tsx +++ b/web/src/views/Workflow/components/Nodes/NormalNode.tsx @@ -12,7 +12,7 @@ const NormalNode: ReactShapeConfig['component'] = ({ node }) => { return (
diff --git a/web/src/views/Workflow/components/PortClickHandler.tsx b/web/src/views/Workflow/components/PortClickHandler.tsx index 2cc0c3c5..13ad6b98 100644 --- a/web/src/views/Workflow/components/PortClickHandler.tsx +++ b/web/src/views/Workflow/components/PortClickHandler.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-09 18:30:28 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-24 11:11:56 + * @Last Modified time: 2026-03-30 15:14:02 */ import { useEffect, useState } from 'react'; import { Popover } from 'antd'; @@ -20,13 +20,15 @@ const PortClickHandler: React.FC = ({ graph }) => { const [sourceNode, setSourceNode] = useState(null); const [sourcePort, setSourcePort] = useState(''); const [tempElement, setTempElement] = useState(null); + const [edgeInsertion, setEdgeInsertion] = useState(null); useEffect(() => { const handlePortClick = (event: CustomEvent) => { - const { node, port, element, rect } = event.detail; + const { node, port, element, rect, edgeInsertion } = event.detail; setSourceNode(node); setSourcePort(port); setTempElement(element); + setEdgeInsertion(edgeInsertion || null); setPopoverPosition({ x: rect.left, y: rect.top }); setPopoverVisible(true); }; @@ -72,15 +74,47 @@ const PortClickHandler: React.FC = ({ graph }) => { const sourcePortInfo = sourceNode.getPorts().find((p: any) => p.id === sourcePort); const sourcePortGroup = sourcePortInfo?.group || sourcePort; - // If add-node position exists, use it; otherwise calculate new position + // Calculate new node position let newX, newY; - if (addNodePosition) { + if (edgeInsertion) { + // Edge insertion: place new node on the same row as target, between source and target + const targetBBox = edgeInsertion.targetCell.getBBox(); + const gap = targetBBox.x - (sourceBBox.x + sourceBBox.width); + const requiredSpace = nodeWidth + horizontalSpacing * 4; + + // New node x: right after source + spacing + newX = sourceBBox.x + sourceBBox.width + horizontalSpacing; + // Same row as target node + newY = targetBBox.y + (targetBBox.height - nodeHeight) / 2; + + // If not enough space, shift target and all downstream nodes to the right + if (gap < requiredSpace) { + const shiftX = requiredSpace - gap; + const visited = new Set(); + const shiftDownstream = (cell: any) => { + const cellId = cell.id; + if (visited.has(cellId)) return; + visited.add(cellId); + const pos = cell.getPosition(); + cell.setPosition(pos.x + shiftX, pos.y); + // Recursively shift nodes connected from right ports + graph.getConnectedEdges(cell, { outgoing: true }).forEach((e: any) => { + const tId = e.getTargetCellId(); + if (tId && !visited.has(tId)) { + const tCell = graph.getCellById(tId); + if (tCell?.isNode()) shiftDownstream(tCell); + } + }); + }; + shiftDownstream(edgeInsertion.targetCell); + } + } else if (addNodePosition) { newX = addNodePosition.x; newY = addNodePosition.y; } else { // Determine node placement direction based on port position if (sourcePortGroup === 'left') { - // Left port: add node to the left + // Left port: add node to the left newX = sourceBBox.x - nodeWidth*2 - horizontalSpacing; newY = sourceBBox.y; } else { @@ -91,7 +125,7 @@ const PortClickHandler: React.FC = ({ graph }) => { // Check if position overlaps with existing nodes (only consider connected nodes) const checkOverlap = (x: number, y: number) => { - // Get nodes connected to the source node + // Get nodes connected to the source node const connectedNodes = new Set(); graph.getConnectedEdges(sourceNode).forEach((edge: any) => { const sourceId = edge.getSourceCellId(); @@ -108,7 +142,7 @@ const PortClickHandler: React.FC = ({ graph }) => { y + nodeHeight < bbox.y || y > bbox.y + bbox.height); }); }; - + // If position is occupied, search downward for empty space while (checkOverlap(newX, newY)) { newY += nodeHeight + verticalSpacing; @@ -140,28 +174,51 @@ const PortClickHandler: React.FC = ({ graph }) => { } } + // Edge insertion: remove old edge immediately before creating new edges + if (edgeInsertion) { + const { edge: oldEdge } = edgeInsertion; + if (oldEdge.id && graph.getCellById(oldEdge.id)) { + graph.removeCell(oldEdge.id); + } else { + graph.removeEdge(oldEdge); + } + } + // Create edge connection setTimeout(() => { - const targetPorts = newNode.getPorts(); - let targetPort; - - if (sourcePortGroup === 'left') { + const newPorts = newNode.getPorts(); + + if (edgeInsertion) { + // Edge insertion: create source→new and new→target edges + const { targetCell, targetPort: origTargetPort } = edgeInsertion; + const newLeftPort = newPorts.find((p: any) => p.group === 'left')?.id || 'left'; + const newRightPort = newPorts.find((p: any) => p.group === 'right')?.id || 'right'; + graph.addEdge({ + source: { cell: sourceNode.id, port: sourcePort }, + target: { cell: newNode.id, port: newLeftPort }, + ...edgeAttrs + }); + graph.addEdge({ + source: { cell: newNode.id, port: newRightPort }, + target: { cell: targetCell.id, port: origTargetPort }, + ...edgeAttrs + }); + setEdgeInsertion(null); + } else if (sourcePortGroup === 'left') { // Connect from left port to new node's right side - targetPort = targetPorts.find((port: any) => port.group === 'right')?.id || 'right'; + const targetPort = newPorts.find((port: any) => port.group === 'right')?.id || 'right'; graph.addEdge({ source: { cell: newNode.id, port: targetPort }, target: { cell: sourceNode.id, port: sourcePort }, ...edgeAttrs - // zIndex: sourceNodeData.cycle && sourceNodeType == 'cycle-start' ? 1 : sourceNodeData.cycle ? 2 : 0 }); } else { // Connect from right port to new node's left side - targetPort = targetPorts.find((port: any) => port.group === 'left')?.id || 'left'; + const targetPort = newPorts.find((port: any) => port.group === 'left')?.id || 'left'; graph.addEdge({ source: { cell: sourceNode.id, port: sourcePort }, target: { cell: newNode.id, port: targetPort }, ...edgeAttrs - // zIndex: sourceNodeData.cycle && sourceNodeType == 'cycle-start' ? 1 : sourceNodeData.cycle ? 2 : 0 }); } diff --git a/web/src/views/Workflow/constant.ts b/web/src/views/Workflow/constant.ts index d62ef06f..50b92696 100644 --- a/web/src/views/Workflow/constant.ts +++ b/web/src/views/Workflow/constant.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 15:06:18 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-27 18:30:52 + * @Last Modified time: 2026-03-30 15:11:56 */ import LoopNode from './components/Nodes/LoopNode'; import NormalNode from './components/Nodes/NormalNode'; @@ -642,6 +642,8 @@ interface NodeConfig { /** Edge color for normal state */ export const edge_color = '#D4D5D9'; +/** Edge color for hover state */ +export const edge_hover_color = '#2E90FA'; /** Edge color for selected state */ export const edge_selected_color = '#171719' export const edge_width = 2; @@ -884,4 +886,70 @@ export const edgeAttrs = { }, }, }, +} + +/** + * Edge hover tool: circular "+" button shown at midpoint on hover + */ +export const edgeHoverTool = { + name: 'button', + args: { + markup: [ + { + tagName: 'circle', + selector: 'button', + attrs: { + r: 6, + stroke: port_color, + strokeWidth: edge_width, + fill: port_color, + cursor: 'pointer', + }, + }, + { + tagName: 'text', + textContent: '+', + selector: 'icon', + attrs: { + fontSize: 12, + fontWeight: 'bold', + fill: '#FFFFFF', + textAnchor: 'middle', + textVerticalAnchor: 'middle', + pointerEvents: 'none', + y: '0.3em', + }, + }, + ], + distance: 0.5, + offset: { x: 0, y: 0 }, + onClick({ e, cell: edge }: any) { + e.stopPropagation(); + const graph = edge.model?.graph; + if (!graph) return; + const sourceCell = graph.getCellById(edge.getSourceCellId()); + const targetCell = graph.getCellById(edge.getTargetCellId()); + const sourcePort = edge.getSourcePortId(); + const targetPort = edge.getTargetPortId(); + if (!sourceCell || !targetCell) return; + const rect = (e.target as HTMLElement).getBoundingClientRect(); + const tempDiv = document.createElement('div'); + tempDiv.style.position = 'fixed'; + tempDiv.style.left = rect.left + 'px'; + tempDiv.style.top = rect.top + 'px'; + tempDiv.style.width = '1px'; + tempDiv.style.height = '1px'; + tempDiv.style.zIndex = '9999'; + document.body.appendChild(tempDiv); + window.dispatchEvent(new CustomEvent('port:click', { + detail: { + node: sourceCell, + port: sourcePort, + element: tempDiv, + rect, + edgeInsertion: { edge, sourceCell, targetCell, sourcePort, targetPort } + } + })); + }, + }, } \ No newline at end of file diff --git a/web/src/views/Workflow/hooks/useWorkflowGraph.ts b/web/src/views/Workflow/hooks/useWorkflowGraph.ts index 626165da..4059c264 100644 --- a/web/src/views/Workflow/hooks/useWorkflowGraph.ts +++ b/web/src/views/Workflow/hooks/useWorkflowGraph.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 15:17:48 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-27 18:14:38 + * @Last Modified time: 2026-03-30 15:08:14 */ import { useRef, useEffect, useState } from 'react'; import { useParams } from 'react-router-dom'; @@ -12,7 +12,7 @@ import { Graph, Node, MiniMap, Snapline, Clipboard, Keyboard, type Edge } from ' import { register } from '@antv/x6-react-shape'; import type { PortMetadata } from '@antv/x6/lib/model/port'; -import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs, edgeAttrs, edge_color, edge_selected_color, portTextAttrs, defaultAbsolutePortGroups, nodeWidth, unknownNode, defaultPortItems, portItemArgsY, edge_width, conditionNodePortItemArgsY, conditionNodeItemHeight, conditionNodeHeight, notesConfig } from '../constant'; +import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs, edgeAttrs, edgeHoverTool, edge_color, edge_hover_color, edge_selected_color, portTextAttrs, defaultAbsolutePortGroups, nodeWidth, unknownNode, defaultPortItems, portItemArgsY, edge_width, conditionNodePortItemArgsY, conditionNodeItemHeight, conditionNodeHeight, notesConfig } from '../constant'; import type { WorkflowConfig, NodeProperties, ChatVariable } from '../types'; import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application' import { useUser } from '@/store/user'; @@ -881,12 +881,21 @@ export const useWorkflowGraph = ({ }); // Use plugins setupPlugins(); - // Listen to edge mouseleave event + // Listen to edge mouseenter event: show hover style and add button + graphRef.current.on('edge:mouseenter', ({ edge }: { edge: Edge }) => { + if (edge.getAttrByPath('line/stroke') !== edge_selected_color) { + edge.setAttrByPath('line/stroke', edge_hover_color); + edge.setAttrByPath('line/strokeWidth', edge_width); + } + edge.addTools([edgeHoverTool]); + }); + // Listen to edge mouseleave event: revert style and remove add button graphRef.current.on('edge:mouseleave', ({ edge }: { edge: Edge }) => { if (edge.getAttrByPath('line/stroke') !== edge_selected_color) { edge.setAttrByPath('line/stroke', edge_color); edge.setAttrByPath('line/strokeWidth', edge_width); } + edge.removeTools(); }); // Listen to node selection event graphRef.current.on('node:click', nodeClick); From e9ad13504ae4fbcc78ed3d81a39b0f660468653f Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Mon, 30 Mar 2026 16:00:49 +0800 Subject: [PATCH 008/117] fix(memory,task): add Redis fair lock for ordered memory writes --- .../core/memory/llm_tools/openai_client.py | 2 +- api/app/tasks.py | 40 ++++-- api/app/utils/redis_lock.py | 133 +++++++++++++++--- 3 files changed, 145 insertions(+), 30 deletions(-) diff --git a/api/app/core/memory/llm_tools/openai_client.py b/api/app/core/memory/llm_tools/openai_client.py index 43c2b445..c70fef5f 100644 --- a/api/app/core/memory/llm_tools/openai_client.py +++ b/api/app/core/memory/llm_tools/openai_client.py @@ -65,7 +65,7 @@ class OpenAIClient(LLMClient): type=type_ ) - logger.info(f"OpenAI 客户端初始化完成: type={type_}") + logger.debug(f"OpenAI 客户端初始化完成: type={type_}") async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any: """ diff --git a/api/app/tasks.py b/api/app/tasks.py index d5f09a29..0e909fcc 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1,5 +1,4 @@ import asyncio -import hashlib import os import re import shutil @@ -38,12 +37,10 @@ from app.db import get_db, get_db_context from app.models import Document, File, Knowledge from app.models.end_user_model import EndUser from app.schemas import document_schema, file_schema -from app.schemas.model_schema import ModelInfo from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config from app.services.memory_forget_service import MemoryForgetService -from app.services.memory_perceptual_service import MemoryPerceptualService from app.utils.config_utils import resolve_config_id -from app.utils.redis_lock import RedisLock +from app.utils.redis_lock import RedisFairLock logger = get_logger(__name__) @@ -1148,8 +1145,28 @@ def write_message_task( logger.info(f"[CELERY WRITE] Write completed successfully: {result}") return result + redis_client = get_sync_redis_client() + lock = None + if redis_client is not None: + lock = RedisFairLock( + key=f"memory_write:{end_user_id}", + redis_client=redis_client, + expire=120, + timeout=300, + auto_renewal=True, + ) + if not lock.acquire(): + logger.warning(f"[CELERY WRITE] 获取锁超时,跳过本次写入: end_user_id={end_user_id}") + return { + "status": "SKIPPED", + "error": "acquire lock timeout", + "end_user_id": end_user_id, + "config_id": str(config_id), + "elapsed_time": time.time() - start_time, + "task_id": self.request.id, + } + try: - # 尝试获取现有事件循环,如果不存在则创建新的 loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) @@ -1158,7 +1175,6 @@ def write_message_task( logger.info(f"[CELERY WRITE] Task completed successfully " f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") - # 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用 try: _r = get_sync_redis_client() if _r is not None: @@ -1199,9 +1215,12 @@ def write_message_task( "elapsed_time": elapsed_time, "task_id": self.request.id } - - -# unused task + finally: + if lock is not None: + try: + lock.release() + except Exception as e: + logger.warning(f"[CELERY WRITE] 释放锁失败: {e}") # @celery_app.task(name="app.core.memory.agent.health.check_read_service") # def check_read_service_task() -> Dict[str, str]: # """Call read_service and write latest status to Redis. @@ -2879,3 +2898,6 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace "elapsed_time": time.time() - start_time, "task_id": self.request.id, } + + +# unused task \ No newline at end of file diff --git a/api/app/utils/redis_lock.py b/api/app/utils/redis_lock.py index 99f62d84..a86ba46e 100644 --- a/api/app/utils/redis_lock.py +++ b/api/app/utils/redis_lock.py @@ -1,6 +1,7 @@ import redis import uuid import time +import threading UNLOCK_SCRIPT = """ if redis.call("get", KEYS[1]) == ARGV[1] then @@ -10,45 +11,136 @@ else end """ +RENEW_SCRIPT = """ +if redis.call("get", KEYS[1]) == ARGV[1] then + return redis.call("expire", KEYS[1], ARGV[2]) +else + return 0 +end +""" -class RedisLock: +CLEANUP_DEAD_HEAD_SCRIPT = """ +local queue_key = KEYS[1] +local lock_key = KEYS[2] + +local first = redis.call("lindex", queue_key, 0) +if not first then + return 0 +end + +if redis.call("exists", lock_key) == 1 then + return 0 +end + +redis.call("lpop", queue_key) +return 1 +""" + +SAFE_RELEASE_QUEUE_SCRIPT = """ +local queue_key = KEYS[1] +local value = ARGV[1] + +local first = redis.call("lindex", queue_key, 0) +if first == value then + redis.call("lpop", queue_key) + return 1 +end +return 0 +""" + + +def _ensure_str(val): + """统一将 Redis 返回值转为 str,兼容 decode_responses=True/False""" + if val is None: + return None + if isinstance(val, bytes): + return val.decode("utf-8") + return str(val) + + +class RedisFairLock: def __init__( self, key: str, redis_client: redis.StrictRedis, - expire: int = 60, - retry_interval: float = 0.1, - timeout: float = 30 - + expire: int = 30, + retry_interval: float = 0.05, + timeout: float = 600, + auto_renewal: bool = True ): self.key = key - self.expire = expire + self.queue_key = f"{key}:queue" self.value = str(uuid.uuid4()) - self._locked = False + self.expire = expire self.retry_interval = retry_interval self.timeout = timeout - self.redis_client = redis_client + self.redis = redis_client + self._locked = False + self.auto_renewal = auto_renewal + self._renew_thread = None + self._stop_renew = threading.Event() - def acquire(self) -> bool: + def acquire(self): start = time.time() + + self.redis.rpush(self.queue_key, self.value) + while True: - ok = self.redis_client.set(self.key, self.value, ex=self.expire, nx=True) - if ok: - self._locked = True - return True - if time.time() - start >= self.timeout: + first = _ensure_str(self.redis.lindex(self.queue_key, 0)) + + if first == self.value: + ok = self.redis.set(self.key, self.value, nx=True, ex=self.expire) + if ok: + self._locked = True + + if self.auto_renewal: + self._start_renewal() + return True + + if first: + self.redis.eval(CLEANUP_DEAD_HEAD_SCRIPT, 2, self.queue_key, self.key) + + if time.time() - start > self.timeout: + self.redis.lrem(self.queue_key, 0, self.value) return False + time.sleep(self.retry_interval) + def _renewal_loop(self): + while not self._stop_renew.is_set(): + time.sleep(self.expire / 3) + if self._stop_renew.is_set(): + break + + self.redis.eval( + RENEW_SCRIPT, + 1, + self.key, + self.value, + str(self.expire) + ) + + def _start_renewal(self): + self._stop_renew = threading.Event() + self._renew_thread = threading.Thread(target=self._renewal_loop, daemon=True) + self._renew_thread.start() + + def _stop_renewal(self): + self._stop_renew.set() + if self._renew_thread: + self._renew_thread.join(timeout=1) + def release(self): if not self._locked: return - self.redis_client.eval( - UNLOCK_SCRIPT, - 1, - self.key, - self.value - ) + + if self.auto_renewal: + self._stop_renewal() + + self.redis.eval(UNLOCK_SCRIPT, 1, self.key, self.value) + + self.redis.eval(SAFE_RELEASE_QUEUE_SCRIPT, 1, self.queue_key, self.value) + self._locked = False def __enter__(self): @@ -59,3 +151,4 @@ class RedisLock: def __exit__(self, exc_type, exc_val, exc_tb): self.release() + From 891cfc27047a6befcdeb05699180499985537448 Mon Sep 17 00:00:00 2001 From: Ke Sun Date: Mon, 30 Mar 2026 16:50:56 +0800 Subject: [PATCH 009/117] feat(end-user-api): add authenticated API endpoint for end user creation - Should be merged after v0.2.9 - Create new end_user_api_controller.py with POST /end_user/create endpoint - Implement API Key authentication requirement with memory scope - Add support for optional memory_config_id parameter with workspace default fallback - Update memory_api_schema.py to remove workspace_id from request (now derived from API key auth) - Add memory_config_id field to CreateEndUserResponse schema - Register end_user_api_controller router in service module - Migrate end user creation from unauthenticated to authenticated API flow --- api/app/controllers/service/__init__.py | 3 +- .../service/end_user_api_controller.py | 92 +++++++++++++++++++ api/app/schemas/memory_api_schema.py | 14 +-- 3 files changed, 98 insertions(+), 11 deletions(-) create mode 100644 api/app/controllers/service/end_user_api_controller.py diff --git a/api/app/controllers/service/__init__.py b/api/app/controllers/service/__init__.py index 8c679c1f..96da0949 100644 --- a/api/app/controllers/service/__init__.py +++ b/api/app/controllers/service/__init__.py @@ -4,7 +4,7 @@ 认证方式: API Key """ from fastapi import APIRouter -from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller +from . import app_api_controller, rag_api_knowledge_controller, rag_api_document_controller, rag_api_file_controller, rag_api_chunk_controller, memory_api_controller, end_user_api_controller # 创建 V1 API 路由器 service_router = APIRouter() @@ -16,5 +16,6 @@ service_router.include_router(rag_api_document_controller.router) service_router.include_router(rag_api_file_controller.router) service_router.include_router(rag_api_chunk_controller.router) service_router.include_router(memory_api_controller.router) +service_router.include_router(end_user_api_controller.router) __all__ = ["service_router"] diff --git a/api/app/controllers/service/end_user_api_controller.py b/api/app/controllers/service/end_user_api_controller.py new file mode 100644 index 00000000..9d410bd2 --- /dev/null +++ b/api/app/controllers/service/end_user_api_controller.py @@ -0,0 +1,92 @@ +"""End User 服务接口 - 基于 API Key 认证""" + +import uuid + +from fastapi import APIRouter, Body, Depends, Request +from sqlalchemy.orm import Session + +from app.core.api_key_auth import require_api_key +from app.core.error_codes import BizCode +from app.core.exceptions import BusinessException +from app.core.logging_config import get_business_logger +from app.core.response_utils import success +from app.db import get_db +from app.repositories.end_user_repository import EndUserRepository +from app.schemas.api_key_schema import ApiKeyAuth +from app.schemas.memory_api_schema import CreateEndUserRequest, CreateEndUserResponse +from app.services.memory_config_service import MemoryConfigService + +router = APIRouter(prefix="/end_user", tags=["V1 - End User API"]) +logger = get_business_logger() + + +@router.post("/create") +@require_api_key(scopes=["memory"]) +async def create_end_user( + request: Request, + api_key_auth: ApiKeyAuth = None, + db: Session = Depends(get_db), + message: str = Body(..., description="Request body"), +): + """ + Create or retrieve an end user for the workspace. + + Creates a new end user and connects it to a memory configuration. + If an end user with the same other_id already exists in the workspace, + returns the existing one. + + Optionally accepts a memory_config_id to connect the end user to a specific + memory configuration. If not provided, falls back to the workspace default config. + """ + body = await request.json() + payload = CreateEndUserRequest(**body) + workspace_id = api_key_auth.workspace_id + + logger.info(f"Create end user request - other_id: {payload.other_id}, workspace_id: {workspace_id}") + + # Resolve memory_config_id: explicit > workspace default + memory_config_id = None + config_service = MemoryConfigService(db) + + if payload.memory_config_id: + try: + memory_config_id = uuid.UUID(payload.memory_config_id) + except ValueError: + raise BusinessException( + f"Invalid memory_config_id format: {payload.memory_config_id}", + BizCode.INVALID_PARAMETER + ) + config = config_service.get_config_with_fallback(memory_config_id, workspace_id) + if not config: + raise BusinessException( + f"Memory config not found: {payload.memory_config_id}", + BizCode.MEMORY_CONFIG_NOT_FOUND + ) + memory_config_id = config.config_id + else: + default_config = config_service.get_workspace_default_config(workspace_id) + if default_config: + memory_config_id = default_config.config_id + logger.info(f"Using workspace default memory config: {memory_config_id}") + else: + logger.warning(f"No default memory config found for workspace: {workspace_id}") + + end_user_repo = EndUserRepository(db) + end_user = end_user_repo.get_or_create_end_user_with_config( + app_id=api_key_auth.resource_id, + workspace_id=workspace_id, + other_id=payload.other_id, + memory_config_id=memory_config_id, + ) + + logger.info(f"End user ready: {end_user.id}") + + result = { + "id": str(end_user.id), + "other_id": end_user.other_id or "", + "other_name": end_user.other_name or "", + "workspace_id": str(end_user.workspace_id), + "memory_config_id": str(end_user.memory_config_id) if end_user.memory_config_id else None, + } + + return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully") diff --git a/api/app/schemas/memory_api_schema.py b/api/app/schemas/memory_api_schema.py index 84a34e8a..ff62355f 100644 --- a/api/app/schemas/memory_api_schema.py +++ b/api/app/schemas/memory_api_schema.py @@ -138,21 +138,13 @@ class CreateEndUserRequest(BaseModel): """Request schema for creating an end user. Attributes: - workspace_id: Workspace ID (required) other_id: External user identifier (required) other_name: Display name for the end user + memory_config_id: Optional memory config ID. If not provided, uses workspace default. """ - workspace_id: str = Field(..., description="Workspace ID (required)") other_id: str = Field(..., description="External user identifier (required)") other_name: Optional[str] = Field("", description="Display name") - - @field_validator("workspace_id") - @classmethod - def validate_workspace_id(cls, v: str) -> str: - """Validate that workspace_id is not empty.""" - if not v or not v.strip(): - raise ValueError("workspace_id is required and cannot be empty") - return v.strip() + memory_config_id: Optional[str] = Field(None, description="Memory config ID. Falls back to workspace default if not provided.") @field_validator("other_id") @classmethod @@ -171,11 +163,13 @@ class CreateEndUserResponse(BaseModel): other_id: External user identifier other_name: Display name workspace_id: Workspace the user belongs to + memory_config_id: Connected memory config ID """ id: str = Field(..., description="End user UUID") other_id: str = Field(..., description="External user identifier") other_name: str = Field("", description="Display name") workspace_id: str = Field(..., description="Workspace ID") + memory_config_id: Optional[str] = Field(None, description="Connected memory config ID") class MemoryConfigItem(BaseModel): From 533000030fbd4b409bc9557b3ebbf83cb4535a48 Mon Sep 17 00:00:00 2001 From: Ke Sun <33739460+keeees@users.noreply.github.com> Date: Mon, 30 Mar 2026 17:16:14 +0800 Subject: [PATCH 010/117] Revert "fix(memory,task): add Redis fair lock for ordered memory writes" --- .../core/memory/llm_tools/openai_client.py | 2 +- api/app/tasks.py | 40 ++---- api/app/utils/redis_lock.py | 133 +++--------------- 3 files changed, 30 insertions(+), 145 deletions(-) diff --git a/api/app/core/memory/llm_tools/openai_client.py b/api/app/core/memory/llm_tools/openai_client.py index c70fef5f..43c2b445 100644 --- a/api/app/core/memory/llm_tools/openai_client.py +++ b/api/app/core/memory/llm_tools/openai_client.py @@ -65,7 +65,7 @@ class OpenAIClient(LLMClient): type=type_ ) - logger.debug(f"OpenAI 客户端初始化完成: type={type_}") + logger.info(f"OpenAI 客户端初始化完成: type={type_}") async def chat(self, messages: List[Dict[str, str]], **kwargs) -> Any: """ diff --git a/api/app/tasks.py b/api/app/tasks.py index 0e909fcc..d5f09a29 100644 --- a/api/app/tasks.py +++ b/api/app/tasks.py @@ -1,4 +1,5 @@ import asyncio +import hashlib import os import re import shutil @@ -37,10 +38,12 @@ from app.db import get_db, get_db_context from app.models import Document, File, Knowledge from app.models.end_user_model import EndUser from app.schemas import document_schema, file_schema +from app.schemas.model_schema import ModelInfo from app.services.memory_agent_service import MemoryAgentService, get_end_user_connected_config from app.services.memory_forget_service import MemoryForgetService +from app.services.memory_perceptual_service import MemoryPerceptualService from app.utils.config_utils import resolve_config_id -from app.utils.redis_lock import RedisFairLock +from app.utils.redis_lock import RedisLock logger = get_logger(__name__) @@ -1145,28 +1148,8 @@ def write_message_task( logger.info(f"[CELERY WRITE] Write completed successfully: {result}") return result - redis_client = get_sync_redis_client() - lock = None - if redis_client is not None: - lock = RedisFairLock( - key=f"memory_write:{end_user_id}", - redis_client=redis_client, - expire=120, - timeout=300, - auto_renewal=True, - ) - if not lock.acquire(): - logger.warning(f"[CELERY WRITE] 获取锁超时,跳过本次写入: end_user_id={end_user_id}") - return { - "status": "SKIPPED", - "error": "acquire lock timeout", - "end_user_id": end_user_id, - "config_id": str(config_id), - "elapsed_time": time.time() - start_time, - "task_id": self.request.id, - } - try: + # 尝试获取现有事件循环,如果不存在则创建新的 loop = set_asyncio_event_loop() result = loop.run_until_complete(_run()) @@ -1175,6 +1158,7 @@ def write_message_task( logger.info(f"[CELERY WRITE] Task completed successfully " f"- elapsed_time={elapsed_time:.2f}s, task_id={self.request.id}") + # 记录该用户最后一次 write_message 成功的时间,供时间轴筛选使用 try: _r = get_sync_redis_client() if _r is not None: @@ -1215,12 +1199,9 @@ def write_message_task( "elapsed_time": elapsed_time, "task_id": self.request.id } - finally: - if lock is not None: - try: - lock.release() - except Exception as e: - logger.warning(f"[CELERY WRITE] 释放锁失败: {e}") + + +# unused task # @celery_app.task(name="app.core.memory.agent.health.check_read_service") # def check_read_service_task() -> Dict[str, str]: # """Call read_service and write latest status to Redis. @@ -2898,6 +2879,3 @@ def init_community_clustering_for_users(self, end_user_ids: List[str], workspace "elapsed_time": time.time() - start_time, "task_id": self.request.id, } - - -# unused task \ No newline at end of file diff --git a/api/app/utils/redis_lock.py b/api/app/utils/redis_lock.py index a86ba46e..99f62d84 100644 --- a/api/app/utils/redis_lock.py +++ b/api/app/utils/redis_lock.py @@ -1,7 +1,6 @@ import redis import uuid import time -import threading UNLOCK_SCRIPT = """ if redis.call("get", KEYS[1]) == ARGV[1] then @@ -11,136 +10,45 @@ else end """ -RENEW_SCRIPT = """ -if redis.call("get", KEYS[1]) == ARGV[1] then - return redis.call("expire", KEYS[1], ARGV[2]) -else - return 0 -end -""" -CLEANUP_DEAD_HEAD_SCRIPT = """ -local queue_key = KEYS[1] -local lock_key = KEYS[2] - -local first = redis.call("lindex", queue_key, 0) -if not first then - return 0 -end - -if redis.call("exists", lock_key) == 1 then - return 0 -end - -redis.call("lpop", queue_key) -return 1 -""" - -SAFE_RELEASE_QUEUE_SCRIPT = """ -local queue_key = KEYS[1] -local value = ARGV[1] - -local first = redis.call("lindex", queue_key, 0) -if first == value then - redis.call("lpop", queue_key) - return 1 -end -return 0 -""" - - -def _ensure_str(val): - """统一将 Redis 返回值转为 str,兼容 decode_responses=True/False""" - if val is None: - return None - if isinstance(val, bytes): - return val.decode("utf-8") - return str(val) - - -class RedisFairLock: +class RedisLock: def __init__( self, key: str, redis_client: redis.StrictRedis, - expire: int = 30, - retry_interval: float = 0.05, - timeout: float = 600, - auto_renewal: bool = True + expire: int = 60, + retry_interval: float = 0.1, + timeout: float = 30 + ): self.key = key - self.queue_key = f"{key}:queue" - self.value = str(uuid.uuid4()) self.expire = expire + self.value = str(uuid.uuid4()) + self._locked = False self.retry_interval = retry_interval self.timeout = timeout - self.redis = redis_client - self._locked = False - self.auto_renewal = auto_renewal - self._renew_thread = None - self._stop_renew = threading.Event() + self.redis_client = redis_client - def acquire(self): + def acquire(self) -> bool: start = time.time() - - self.redis.rpush(self.queue_key, self.value) - while True: - first = _ensure_str(self.redis.lindex(self.queue_key, 0)) - - if first == self.value: - ok = self.redis.set(self.key, self.value, nx=True, ex=self.expire) - if ok: - self._locked = True - - if self.auto_renewal: - self._start_renewal() - return True - - if first: - self.redis.eval(CLEANUP_DEAD_HEAD_SCRIPT, 2, self.queue_key, self.key) - - if time.time() - start > self.timeout: - self.redis.lrem(self.queue_key, 0, self.value) + ok = self.redis_client.set(self.key, self.value, ex=self.expire, nx=True) + if ok: + self._locked = True + return True + if time.time() - start >= self.timeout: return False - time.sleep(self.retry_interval) - def _renewal_loop(self): - while not self._stop_renew.is_set(): - time.sleep(self.expire / 3) - if self._stop_renew.is_set(): - break - - self.redis.eval( - RENEW_SCRIPT, - 1, - self.key, - self.value, - str(self.expire) - ) - - def _start_renewal(self): - self._stop_renew = threading.Event() - self._renew_thread = threading.Thread(target=self._renewal_loop, daemon=True) - self._renew_thread.start() - - def _stop_renewal(self): - self._stop_renew.set() - if self._renew_thread: - self._renew_thread.join(timeout=1) - def release(self): if not self._locked: return - - if self.auto_renewal: - self._stop_renewal() - - self.redis.eval(UNLOCK_SCRIPT, 1, self.key, self.value) - - self.redis.eval(SAFE_RELEASE_QUEUE_SCRIPT, 1, self.queue_key, self.value) - + self.redis_client.eval( + UNLOCK_SCRIPT, + 1, + self.key, + self.value + ) self._locked = False def __enter__(self): @@ -151,4 +59,3 @@ class RedisFairLock: def __exit__(self, exc_type, exc_val, exc_tb): self.release() - From 3cca35a74f364715bf0028917ec01730b6867b5c Mon Sep 17 00:00:00 2001 From: zhaoying Date: Mon, 30 Mar 2026 17:21:09 +0800 Subject: [PATCH 011/117] feat(web): workflow node port view update --- web/src/views/Workflow/constant.ts | 65 +++++++++++++++++-- .../views/Workflow/hooks/useWorkflowGraph.ts | 56 +++++++++++++++- 2 files changed, 113 insertions(+), 8 deletions(-) diff --git a/web/src/views/Workflow/constant.ts b/web/src/views/Workflow/constant.ts index 50b92696..06eb4d99 100644 --- a/web/src/views/Workflow/constant.ts +++ b/web/src/views/Workflow/constant.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 15:06:18 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-30 15:11:56 + * @Last Modified time: 2026-03-30 16:52:54 */ import LoopNode from './components/Nodes/LoopNode'; import NormalNode from './components/Nodes/NormalNode'; @@ -695,8 +695,59 @@ export const portArgs = { x: nodeWidth, y: portItemArgsY } const defaultPortGroup = { position: { name: 'absolute' }, - markup: portMarkup, - attrs: portAttrs + markup: [ + { tagName: 'rect', selector: 'body' }, + { tagName: 'circle', selector: 'hoverBody' }, + { tagName: 'text', selector: 'label' }, + ], + attrs: { + body: { + width: 1, + height: 8, + x: -1, + magnet: true, + stroke: port_color, + strokeWidth: edge_width, + fill: port_color, + }, + hoverBody: { + r: 6, + cy: 2, + magnet: true, + stroke: port_color, + strokeWidth: edge_width, + fill: port_color, + opacity: 0, + }, + label: { + text: '+', + fontSize: 12, + fontWeight: 'bold', + fill: '#FFFFFF', + textAnchor: 'middle', + textVerticalAnchor: 'middle', + pointerEvents: 'none', + y: '0.15em', + opacity: 0, + }, + }, +} + +const leftPortGroup = { + position: { name: 'absolute' }, + markup: [{ tagName: 'rect', selector: 'body' }], + attrs: { + body: { + width: 1, + height: 8, + x: -1, + y: -4, + magnet: true, + stroke: port_color, + strokeWidth: edge_width, + fill: port_color, + }, + }, } /** @@ -705,7 +756,7 @@ const defaultPortGroup = { */ export const defaultAbsolutePortGroups = { right: defaultPortGroup, - left: defaultPortGroup, + left: leftPortGroup, } /** * Default port items for standard nodes @@ -799,7 +850,7 @@ export const graphNodeLibrary: Record = { height: 28, shape: 'add-node', ports: { - groups: { left: defaultPortGroup }, + groups: { left: leftPortGroup }, items: [{ group: 'left', args: { x: 0, y: 18 }}], }, }, @@ -826,7 +877,7 @@ export const graphNodeLibrary: Record = { height: 28, shape: 'add-node', ports: { - groups: { left: defaultPortGroup }, + groups: { left: leftPortGroup }, items: [{ group: 'left', args: { x: 0, y: 14 } }], }, }, @@ -835,7 +886,7 @@ export const graphNodeLibrary: Record = { height: 76, shape: 'normal-node', ports: { - groups: { left: defaultPortGroup }, + groups: { left: leftPortGroup }, items: [defaultPortItems[0]], }, }, diff --git a/web/src/views/Workflow/hooks/useWorkflowGraph.ts b/web/src/views/Workflow/hooks/useWorkflowGraph.ts index 4059c264..dd6f6eb7 100644 --- a/web/src/views/Workflow/hooks/useWorkflowGraph.ts +++ b/web/src/views/Workflow/hooks/useWorkflowGraph.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 15:17:48 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-30 15:08:14 + * @Last Modified time: 2026-03-30 17:18:11 */ import { useRef, useEffect, useState } from 'react'; import { useParams } from 'react-router-dom'; @@ -441,6 +441,7 @@ export const useWorkflowGraph = ({ setTimeout(() => { if (graphRef.current) { graphRef.current.centerContent() + graphRef.current.getNodes().forEach(node => node.toFront()); } }, 200) } @@ -719,6 +720,7 @@ export const useWorkflowGraph = ({ }; const nodePortClickEvent = ({ e, node, port }: { e: MouseEvent, node: Node, port: string }) => { e.stopPropagation(); + e.preventDefault(); const portElement = e.target as HTMLElement; const rect = portElement.getBoundingClientRect(); @@ -903,13 +905,65 @@ export const useWorkflowGraph = ({ graphRef.current.on('edge:click', edgeClick); // Listen to port click event graphRef.current.on('node:port:click', nodePortClickEvent); + // Port hover: show circle style on right ports + graphRef.current.on('node:port:mouseenter', ({ node, port }) => { + if (!port) return; + const portData = node.getPort(port); + if (portData?.group !== 'right') return; + node.toFront(); + node.setPortProp(port, 'attrs/body/opacity', 0); + node.setPortProp(port, 'attrs/hoverBody/opacity', 1); + node.setPortProp(port, 'attrs/label/opacity', 1); + }); + graphRef.current.on('node:port:mouseleave', ({ node, port }) => { + if (!port) return; + const portData = node.getPort(port); + if (portData?.group !== 'right') return; + node.setPortProp(port, 'attrs/body/opacity', 1); + node.setPortProp(port, 'attrs/hoverBody/opacity', 0); + node.setPortProp(port, 'attrs/label/opacity', 0); + }); // Listen to canvas click event, cancel selection graphRef.current.on('blank:click', blankClick); + // Node hover: highlight connected edges + graphRef.current.on('node:mouseenter', ({ node }) => { + graphRef.current?.getEdges().forEach(edge => { + const view = graphRef.current?.findViewByCell(edge); + view?.removeTools(); + if (edge.getAttrByPath('line/stroke') !== edge_selected_color) { + edge.setAttrByPath('line/stroke', edge_color); + } + }); + graphRef.current?.getConnectedEdges(node).forEach(edge => { + edge.setAttrByPath('line/stroke', edge_hover_color); + }); + node.getPorts().filter(p => p.group === 'right').forEach(p => { + node.setPortProp(p.id!, 'attrs/body/opacity', 0); + node.setPortProp(p.id!, 'attrs/hoverBody/opacity', 1); + node.setPortProp(p.id!, 'attrs/label/opacity', 1); + }); + }); + graphRef.current.on('node:mouseleave', ({ node }) => { + graphRef.current?.getConnectedEdges(node).forEach(edge => { + if (edge.getAttrByPath('line/stroke') !== edge_selected_color) { + edge.setAttrByPath('line/stroke', edge_color); + } + }); + node.getPorts().filter(p => p.group === 'right').forEach(p => { + node.setPortProp(p.id!, 'attrs/body/opacity', 1); + node.setPortProp(p.id!, 'attrs/hoverBody/opacity', 0); + node.setPortProp(p.id!, 'attrs/label/opacity', 0); + }); + }); // Listen to zoom event graphRef.current.on('scale', scaleEvent); // Listen to node move event graphRef.current.on('node:moved', nodeMoved); graphRef.current.on('node:removed', blankClick) + // When edge changes, bring connected nodes' ports to front + graphRef.current.on('edge:change', () => { + graphRef.current?.getNodes().forEach(node => node.toFront()); + }); // Listen to copy keyboard event graphRef.current.bindKey(['ctrl+c', 'cmd+c'], copyEvent); // Listen to paste keyboard event From 876c39b1b0a8e4bed9a0f8339222839fb6027118 Mon Sep 17 00:00:00 2001 From: Timebomb2018 <18868801967@163.com> Date: Mon, 30 Mar 2026 18:37:09 +0800 Subject: [PATCH 012/117] fix(app): 1. Token consumption of the omni model; 2. Token consumption of the cluster includes sub-agents --- api/app/core/agent/langchain_agent.py | 42 +++++++++++++++----- api/app/core/models/base.py | 12 +++++- api/app/core/tools/mcp/client.py | 8 ++-- api/app/services/app_chat_service.py | 6 +-- api/app/services/master_agent_router.py | 11 +++++ api/app/services/multi_agent_orchestrator.py | 34 ++++++++++++++-- 6 files changed, 92 insertions(+), 21 deletions(-) diff --git a/api/app/core/agent/langchain_agent.py b/api/app/core/agent/langchain_agent.py index 464a668a..9776cc29 100644 --- a/api/app/core/agent/langchain_agent.py +++ b/api/app/core/agent/langchain_agent.py @@ -254,6 +254,33 @@ class LangChainAgent: return messages + @staticmethod + def _extract_tokens_from_message(msg) -> int: + """从 AIMessage 或类似对象中提取 total_tokens,兼容多种 provider 格式 + + 支持的格式: + - response_metadata.token_usage.total_tokens (OpenAI/ChatOpenAI) + - response_metadata.usage.total_tokens (部分 provider) + - usage_metadata.total_tokens (LangChain 新版) + """ + total = 0 + # 1. response_metadata + response_meta = getattr(msg, "response_metadata", None) + if response_meta and isinstance(response_meta, dict): + # 尝试 token_usage 路径 + token_usage = response_meta.get("token_usage") or response_meta.get("usage", {}) + if isinstance(token_usage, dict): + total = token_usage.get("total_tokens", 0) + # 2. usage_metadata(LangChain 新版 AIMessage 属性) + if not total: + usage_meta = getattr(msg, "usage_metadata", None) + if usage_meta: + if isinstance(usage_meta, dict): + total = usage_meta.get("total_tokens", 0) + else: + total = getattr(usage_meta, "total_tokens", 0) + return total or 0 + def _build_multimodal_content(self, text: str, files: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """ 构建多模态消息内容 @@ -412,8 +439,7 @@ class LangChainAgent: else: content = str(msg.content) logger.debug(f"转换为字符串: {content[:100]}...") - response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None - total_tokens = response_meta.get("token_usage", {}).get("total_tokens", 0) if response_meta else 0 + total_tokens = self._extract_tokens_from_message(msg) break logger.info(f"最终提取的内容长度: {len(content)}") @@ -458,7 +484,7 @@ class LangChainAgent: user_rag_memory_id: Optional[str] = None, memory_flag: Optional[bool] = True, files: Optional[List[Dict[str, Any]]] = None # 新增:多模态文件 - ) -> AsyncGenerator[str, None]: + ) -> AsyncGenerator[str | int, None]: """执行流式对话 Args: @@ -594,15 +620,13 @@ class LangChainAgent: logger.debug(f"Agent 流式完成,共 {chunk_count} 个事件") # 统计token消耗 + # 统计 token 消耗:优先使用流式过程中捕获的值,回退到最后 event 的 messages output_messages = event.get("data", {}).get("output", {}).get("messages", []) for msg in reversed(output_messages): if isinstance(msg, AIMessage): - response_meta = msg.response_metadata if hasattr(msg, 'response_metadata') else None - total_tokens = response_meta.get("token_usage", {}).get( - "total_tokens", - 0 - ) if response_meta else 0 - yield total_tokens + stream_total_tokens = self._extract_tokens_from_message(msg) + logger.info(f"流式 token 统计: total_tokens={stream_total_tokens}") + yield stream_total_tokens break if memory_flag: await write_long_term(storage_type, end_user_id, message_chat, full_content, user_rag_memory_id, diff --git a/api/app/core/models/base.py b/api/app/core/models/base.py index 80117f27..a4dbc092 100644 --- a/api/app/core/models/base.py +++ b/api/app/core/models/base.py @@ -58,7 +58,7 @@ class RedBearModelFactory: write=60.0, pool=10.0, ) - return { + params = { "model": config.model_name, "base_url": config.base_url, "api_key": config.api_key, @@ -66,6 +66,10 @@ class RedBearModelFactory: "max_retries": config.max_retries, **config.extra_params } + # 流式模式下启用 stream_usage 以获取 token 统计 + if config.extra_params.get("streaming"): + params["stream_usage"] = True + return params if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA, ModelProvider.VOLCANO]: # 使用 httpx.Timeout 对象来设置详细的超时配置 @@ -78,7 +82,7 @@ class RedBearModelFactory: write=60.0, # 写入超时:60秒 pool=10.0, # 连接池超时:10秒 ) - return { + params = { "model": config.model_name, "base_url": config.base_url, "api_key": config.api_key, @@ -86,6 +90,10 @@ class RedBearModelFactory: "max_retries": config.max_retries, **config.extra_params } + # 流式模式下启用 stream_usage 以获取 token 统计 + if config.extra_params.get("streaming"): + params["stream_usage"] = True + return params elif provider == ModelProvider.DASHSCOPE: # DashScope (通义千问) 使用自己的参数格式 # 注意: DashScopeEmbeddings 不支持 timeout 和 base_url 参数 diff --git a/api/app/core/tools/mcp/client.py b/api/app/core/tools/mcp/client.py index 6df6df51..3539d33a 100644 --- a/api/app/core/tools/mcp/client.py +++ b/api/app/core/tools/mcp/client.py @@ -99,7 +99,7 @@ class SimpleMCPClient: # 建立 SSE 连接 response = await self._session.get(self.server_url) - if response.status != 200: + if not (200 <= response.status < 300): error_text = await response.text() raise MCPConnectionError(f"SSE 连接失败 {response.status}: {error_text}") @@ -190,7 +190,7 @@ class SimpleMCPClient: try: async with self._session.post(self._endpoint_url, json=request) as response: - if response.status != 200: + if not (200 <= response.status < 300): error_text = await response.text() raise MCPConnectionError(f"请求失败 {response.status}: {error_text}") @@ -205,7 +205,7 @@ class SimpleMCPClient: raise MCPConnectionError("endpoint URL 未初始化") async with self._session.post(self._endpoint_url, json=notification) as response: - if response.status != 200: + if not (200 <= response.status < 300): logger.warning(f"通知发送失败: {response.status}") async def _initialize_modelscope_session(self): @@ -223,7 +223,7 @@ class SimpleMCPClient: try: async with self._session.post(self.server_url, json=init_request) as response: - if response.status != 200: + if not (200 <= response.status < 300): error_text = await response.text() raise MCPConnectionError(f"初始化失败 {response.status}: {error_text}") diff --git a/api/app/services/app_chat_service.py b/api/app/services/app_chat_service.py index 90474428..b5f9f194 100644 --- a/api/app/services/app_chat_service.py +++ b/api/app/services/app_chat_service.py @@ -631,13 +631,13 @@ class AppChatService: storage_type=storage_type, user_rag_memory_id=user_rag_memory_id ): - if "sub_usage" in event: + # 拦截 sub_usage 事件,累加 token + if "event: sub_usage" in event: if "data:" in event: try: data_line = event.split("data: ", 1)[1].strip() data = json.loads(data_line) - if "total_tokens" in data: - total_tokens += data["total_tokens"] + total_tokens += data.get("total_tokens", 0) except: pass else: diff --git a/api/app/services/master_agent_router.py b/api/app/services/master_agent_router.py index b0f43b51..954d3b2b 100644 --- a/api/app/services/master_agent_router.py +++ b/api/app/services/master_agent_router.py @@ -403,6 +403,17 @@ class MasterAgentRouter: response = await llm.ainvoke(prompt) ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id) + # 提取 token 消耗 + self._last_routing_tokens = 0 + if hasattr(response, 'usage_metadata') and response.usage_metadata: + um = response.usage_metadata + self._last_routing_tokens = um.get("total_tokens", 0) if isinstance(um, dict) else getattr(um, "total_tokens", 0) + elif hasattr(response, 'response_metadata') and response.response_metadata: + token_usage = response.response_metadata.get("token_usage") or response.response_metadata.get("usage", {}) + if isinstance(token_usage, dict): + self._last_routing_tokens = token_usage.get("total_tokens", 0) + logger.info(f"Master Agent 路由 token 消耗: {self._last_routing_tokens}") + # 提取响应内容 if hasattr(response, 'content'): return response.content diff --git a/api/app/services/multi_agent_orchestrator.py b/api/app/services/multi_agent_orchestrator.py index 60a3b5b8..1330caad 100644 --- a/api/app/services/multi_agent_orchestrator.py +++ b/api/app/services/multi_agent_orchestrator.py @@ -287,6 +287,11 @@ class MultiAgentOrchestrator: sub_conversation_id = None total_tokens = 0 + # 累加 Master Agent 路由决策消耗的 token + total_tokens += task_analysis.get("routing_tokens", 0) + # 累加 Master Agent 整合消耗的 token + total_tokens += getattr(self, '_last_merge_tokens', 0) + if isinstance(results, dict): sub_conversation_id = results.get("conversation_id") or results.get("result", {}).get("conversation_id") # 提取 token 信息 @@ -358,12 +363,16 @@ class MultiAgentOrchestrator: variables=variables ) + # 获取路由决策消耗的 token + routing_tokens = getattr(self.router, '_last_routing_tokens', 0) + logger.info( "Master Agent 分析完成", extra={ "selected_agent": routing_decision.get("selected_agent_id"), "confidence": routing_decision.get("confidence"), - "strategy": routing_decision.get("strategy") + "strategy": routing_decision.get("strategy"), + "routing_tokens": routing_tokens } ) @@ -372,7 +381,8 @@ class MultiAgentOrchestrator: "variables": variables or {}, "sub_agents": self.config.sub_agents, "initial_context": variables or {}, - "routing_decision": routing_decision + "routing_decision": routing_decision, + "routing_tokens": routing_tokens } async def _execute_sequential( @@ -1032,6 +1042,11 @@ class MultiAgentOrchestrator: # 5. 流式执行子 Agent sub_conversation_id = None + # Master Agent 路由决策消耗的 token,通过 sub_usage 事件发送给上层 + routing_tokens = task_analysis.get("routing_tokens", 0) + if routing_tokens > 0: + yield self._format_sse_event("sub_usage", {"total_tokens": routing_tokens}) + async for event in self._execute_sub_agent_stream( agent_data["config"], message, @@ -1054,6 +1069,7 @@ class MultiAgentOrchestrator: except: pass + # 直接透传所有事件(包括 sub_usage),累加统一由上层处理 yield event # 6. 如果有会话 ID,发送一个包含它的事件 @@ -2612,6 +2628,17 @@ class MultiAgentOrchestrator: ModelApiKeyService.record_api_key_usage(self.db, api_key_config.id) + # 提取整合消耗的 token + merge_tokens = 0 + if hasattr(response, 'usage_metadata') and response.usage_metadata: + um = response.usage_metadata + merge_tokens = um.get("total_tokens", 0) if isinstance(um, dict) else getattr(um, "total_tokens", 0) + elif hasattr(response, 'response_metadata') and response.response_metadata: + token_usage = response.response_metadata.get("token_usage") or response.response_metadata.get("usage", {}) + if isinstance(token_usage, dict): + merge_tokens = token_usage.get("total_tokens", 0) + self._last_merge_tokens = merge_tokens + # 提取响应内容 if hasattr(response, 'content'): merged_response = response.content @@ -2621,7 +2648,8 @@ class MultiAgentOrchestrator: logger.info( "Master Agent 整合完成", extra={ - "merged_length": len(merged_response) + "merged_length": len(merged_response), + "merge_tokens": merge_tokens } ) From db8b3416a645fe98d538df0220aa25ab213cfa29 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 31 Mar 2026 11:38:42 +0800 Subject: [PATCH 013/117] feat(web): workflow edge ui --- web/src/views/Workflow/constant.ts | 11 +- .../views/Workflow/hooks/useWorkflowGraph.ts | 121 +++++++++++++++--- 2 files changed, 106 insertions(+), 26 deletions(-) diff --git a/web/src/views/Workflow/constant.ts b/web/src/views/Workflow/constant.ts index 06eb4d99..92773191 100644 --- a/web/src/views/Workflow/constant.ts +++ b/web/src/views/Workflow/constant.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 15:06:18 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-30 16:52:54 + * @Last Modified time: 2026-03-31 10:08:26 */ import LoopNode from './components/Nodes/LoopNode'; import NormalNode from './components/Nodes/NormalNode'; @@ -642,8 +642,6 @@ interface NodeConfig { /** Edge color for normal state */ export const edge_color = '#D4D5D9'; -/** Edge color for hover state */ -export const edge_hover_color = '#2E90FA'; /** Edge color for selected state */ export const edge_selected_color = '#171719' export const edge_width = 2; @@ -930,11 +928,8 @@ export const edgeAttrs = { line: { stroke: edge_color, strokeWidth: edge_width, - targetMarker: { - name: 'block', - width: 4, - height: 4, - }, + targetMarker: null, + sourceMarker: null, }, }, } diff --git a/web/src/views/Workflow/hooks/useWorkflowGraph.ts b/web/src/views/Workflow/hooks/useWorkflowGraph.ts index dd6f6eb7..c427788b 100644 --- a/web/src/views/Workflow/hooks/useWorkflowGraph.ts +++ b/web/src/views/Workflow/hooks/useWorkflowGraph.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 15:17:48 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-30 17:18:11 + * @Last Modified time: 2026-03-31 11:13:23 */ import { useRef, useEffect, useState } from 'react'; import { useParams } from 'react-router-dom'; @@ -12,7 +12,7 @@ import { Graph, Node, MiniMap, Snapline, Clipboard, Keyboard, type Edge } from ' import { register } from '@antv/x6-react-shape'; import type { PortMetadata } from '@antv/x6/lib/model/port'; -import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs, edgeAttrs, edgeHoverTool, edge_color, edge_hover_color, edge_selected_color, portTextAttrs, defaultAbsolutePortGroups, nodeWidth, unknownNode, defaultPortItems, portItemArgsY, edge_width, conditionNodePortItemArgsY, conditionNodeItemHeight, conditionNodeHeight, notesConfig } from '../constant'; +import { nodeRegisterLibrary, graphNodeLibrary, nodeLibrary, portMarkup, portAttrs, edgeAttrs, edgeHoverTool, edge_color, edge_selected_color, portTextAttrs, defaultAbsolutePortGroups, nodeWidth, unknownNode, defaultPortItems, portItemArgsY, edge_width, conditionNodePortItemArgsY, conditionNodeItemHeight, conditionNodeHeight, notesConfig } from '../constant'; import type { WorkflowConfig, NodeProperties, ChatVariable } from '../types'; import { getWorkflowConfig, saveWorkflowConfig } from '@/api/application' import { useUser } from '@/store/user'; @@ -523,7 +523,9 @@ export const useWorkflowGraph = ({ * @param edge - Clicked edge */ const edgeClick = ({ edge }: { edge: Edge }) => { + clearEdgeSelect(); edge.setAttrByPath('line/stroke', edge_selected_color); + edge.setData({ ...edge.getData(), isSelected: true }); clearNodeSelect(); }; /** @@ -548,6 +550,7 @@ export const useWorkflowGraph = ({ */ const clearEdgeSelect = () => { graphRef.current?.getEdges().forEach(e => { + e.setData({ ...e.getData(), isSelected: false, isNodeHover: false }); e.setAttrByPath('line/stroke', edge_color); e.setAttrByPath('line/strokeWidth', edge_width); }); @@ -840,15 +843,25 @@ export const useWorkflowGraph = ({ // 1. If both nodes have parent IDs, they must be same to connect // 2. If both have no parent ID, can connect normally // 3. If one has parent, one doesn't, cannot connect - console.log('sourceParentId', sourceParentId, targetParentId) if (sourceParentId && targetParentId) { // Child nodes under same parent can connect to each other - return sourceParentId === targetParentId; + if (sourceParentId !== targetParentId) return false; } else if (sourceParentId || targetParentId) { // One has parent, one doesn't, cannot connect return false; } - + + // Prevent duplicate connections between same ports + const sourcePortId = sourceMagnet?.getAttribute('port') ?? sourceMagnet?.closest('[port]')?.getAttribute('port'); + const targetPortId = targetMagnet?.getAttribute('port') ?? targetMagnet?.closest('[port]')?.getAttribute('port'); + const duplicate = graphRef.current?.getEdges().some(e => + e.getSourceCellId() === sourceCell?.id && + e.getTargetCellId() === targetCell?.id && + e.getSourcePortId() === sourcePortId && + e.getTargetPortId() === targetPortId + ); + if (duplicate) return false; + return true; }, }, @@ -885,17 +898,20 @@ export const useWorkflowGraph = ({ setupPlugins(); // Listen to edge mouseenter event: show hover style and add button graphRef.current.on('edge:mouseenter', ({ edge }: { edge: Edge }) => { - if (edge.getAttrByPath('line/stroke') !== edge_selected_color) { - edge.setAttrByPath('line/stroke', edge_hover_color); - edge.setAttrByPath('line/strokeWidth', edge_width); - } - edge.addTools([edgeHoverTool]); + setTimeout(() => { + edge.addTools([edgeHoverTool]); + }, 0) }); // Listen to edge mouseleave event: revert style and remove add button graphRef.current.on('edge:mouseleave', ({ edge }: { edge: Edge }) => { - if (edge.getAttrByPath('line/stroke') !== edge_selected_color) { - edge.setAttrByPath('line/stroke', edge_color); - edge.setAttrByPath('line/strokeWidth', edge_width); + const data = edge.getData(); + if (!data?.isSelected) { + if (data?.isNodeHover) { + edge.setAttrByPath('line/stroke', edge_selected_color); + } else { + edge.setAttrByPath('line/stroke', edge_color); + edge.setAttrByPath('line/strokeWidth', edge_width); + } } edge.removeTools(); }); @@ -907,6 +923,7 @@ export const useWorkflowGraph = ({ graphRef.current.on('node:port:click', nodePortClickEvent); // Port hover: show circle style on right ports graphRef.current.on('node:port:mouseenter', ({ node, port }) => { + console.log('node:port:mouseenter', port) if (!port) return; const portData = node.getPort(port); if (portData?.group !== 'right') return; @@ -930,12 +947,15 @@ export const useWorkflowGraph = ({ graphRef.current?.getEdges().forEach(edge => { const view = graphRef.current?.findViewByCell(edge); view?.removeTools(); - if (edge.getAttrByPath('line/stroke') !== edge_selected_color) { + if (!edge.getData()?.isSelected && edge.getAttrByPath('line/stroke') === edge_selected_color) { edge.setAttrByPath('line/stroke', edge_color); } }); graphRef.current?.getConnectedEdges(node).forEach(edge => { - edge.setAttrByPath('line/stroke', edge_hover_color); + if (!edge.getData()?.isSelected) { + edge.setAttrByPath('line/stroke', edge_selected_color); + edge.setData({ ...edge.getData(), isNodeHover: true }); + } }); node.getPorts().filter(p => p.group === 'right').forEach(p => { node.setPortProp(p.id!, 'attrs/body/opacity', 0); @@ -945,8 +965,9 @@ export const useWorkflowGraph = ({ }); graphRef.current.on('node:mouseleave', ({ node }) => { graphRef.current?.getConnectedEdges(node).forEach(edge => { - if (edge.getAttrByPath('line/stroke') !== edge_selected_color) { + if (!edge.getData()?.isSelected) { edge.setAttrByPath('line/stroke', edge_color); + edge.setData({ ...edge.getData(), isNodeHover: false }); } }); node.getPorts().filter(p => p.group === 'right').forEach(p => { @@ -960,9 +981,73 @@ export const useWorkflowGraph = ({ // Listen to node move event graphRef.current.on('node:moved', nodeMoved); graphRef.current.on('node:removed', blankClick) - // When edge changes, bring connected nodes' ports to front - graphRef.current.on('edge:change', () => { + // When edge connected, bring connected nodes' ports to front + graphRef.current.on('edge:connected', ({ isNew }) => { graphRef.current?.getNodes().forEach(node => node.toFront()); + // Reset any port hover state left from dragging + if (isNew) { + graphRef.current?.getNodes().forEach(node => { + node.getPorts().filter(p => p.group === 'right').forEach(p => { + node.setPortProp(p.id!, 'attrs/body/opacity', 1); + node.setPortProp(p.id!, 'attrs/hoverBody/opacity', 0); + node.setPortProp(p.id!, 'attrs/label/opacity', 0); + }); + }); + } + }); + + // During edge dragging, manually detect port hover since the dragging edge blocks mouse events + let lastHoveredPort: { node: Node; portId: string } | null = null; + graphRef.current.on('edge:mousemove', ({ e }: { e: MouseEvent }) => { + if (!graphRef.current) return; + const { clientX, clientY } = e; + let found: { node: Node; portId: string } | null = null; + + for (const node of graphRef.current.getNodes()) { + for (const port of node.getPorts().filter(p => p.group === 'right')) { + const portView = graphRef.current.findViewByCell(node); + if (!portView) continue; + const portEl = (portView as any).findPortElem(port.id!, 'body') as SVGElement | null; + if (!portEl) continue; + const rect = portEl.getBoundingClientRect(); + const hitRadius = 16; + const cx = rect.left + rect.width / 2; + const cy = rect.top + rect.height / 2; + if (Math.abs(clientX - cx) <= hitRadius && Math.abs(clientY - cy) <= hitRadius) { + found = { node, portId: port.id! }; + break; + } + } + if (found) break; + } + + if (found?.node.id !== lastHoveredPort?.node.id || found?.portId !== lastHoveredPort?.portId) { + // Leave previous + if (lastHoveredPort) { + const { node, portId } = lastHoveredPort; + node.setPortProp(portId, 'attrs/body/opacity', 1); + node.setPortProp(portId, 'attrs/hoverBody/opacity', 0); + node.setPortProp(portId, 'attrs/label/opacity', 0); + } + // Enter new + if (found) { + const { node, portId } = found; + node.toFront(); + node.setPortProp(portId, 'attrs/body/opacity', 0); + node.setPortProp(portId, 'attrs/hoverBody/opacity', 1); + node.setPortProp(portId, 'attrs/label/opacity', 1); + } + lastHoveredPort = found; + } + }); + graphRef.current.on('edge:mouseup', () => { + if (lastHoveredPort) { + const { node, portId } = lastHoveredPort; + node.setPortProp(portId, 'attrs/body/opacity', 1); + node.setPortProp(portId, 'attrs/hoverBody/opacity', 0); + node.setPortProp(portId, 'attrs/label/opacity', 0); + lastHoveredPort = null; + } }); // Listen to copy keyboard event graphRef.current.bindKey(['ctrl+c', 'cmd+c'], copyEvent); From 02660c7c971e373701244a55015b1fcc1339a2af Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 31 Mar 2026 12:26:17 +0800 Subject: [PATCH 014/117] feat(web): end user list support page --- web/src/api/memory.ts | 6 +- web/src/components/DebounceSelect/index.tsx | 106 ++++++++++++++++ web/src/views/MemoryConversation/index.tsx | 39 +++--- web/src/views/UserMemory/index.tsx | 130 ++++++++------------ 4 files changed, 176 insertions(+), 105 deletions(-) create mode 100644 web/src/components/DebounceSelect/index.tsx diff --git a/web/src/api/memory.ts b/web/src/api/memory.ts index 1ec2d7dc..ee71bea8 100644 --- a/web/src/api/memory.ts +++ b/web/src/api/memory.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 14:00:06 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-24 17:48:01 + * @Last Modified time: 2026-03-31 12:25:53 */ import { request } from '@/utils/request' import type { AxiosRequestConfig } from 'axios' @@ -63,8 +63,8 @@ export const getDashboardData = () => { /****************** User Memory APIs *******************************/ export const userMemoryListUrl = '/dashboard/end_users' -export const getUserMemoryList = () => { - return request.get(userMemoryListUrl) +export const getUserMemoryList = (query?: { keyword?: string }) => { + return request.get(userMemoryListUrl, query) } // User Memory - Total end users export const getTotalEndUsers = () => { diff --git a/web/src/components/DebounceSelect/index.tsx b/web/src/components/DebounceSelect/index.tsx new file mode 100644 index 00000000..ab8379ad --- /dev/null +++ b/web/src/components/DebounceSelect/index.tsx @@ -0,0 +1,106 @@ +import { useRef, useState, useCallback, useEffect, type FC } from 'react'; +import { Select, Spin, Avatar } from 'antd'; +import type { SelectProps, DefaultOptionType } from 'antd/es/select'; + +import { request } from '@/utils/request'; + +interface OptionType { + [key: string]: any; +} + +interface ApiResponse { + items?: T[]; +} + +export interface DebounceSelectProps extends Omit { + /** API endpoint URL — mutually exclusive with fetchOptions */ + url?: string; + /** Extra query params merged with the search keyword */ + params?: Record; + /** Key used as option value */ + valueKey?: string; + /** Key used as option label */ + labelKey?: string; + /** Key name sent to the API for the search keyword */ + searchKey?: string; + /** Custom fetch function — mutually exclusive with url */ + fetchOptions?: (search: string | null) => Promise; + /** Transform raw API items before rendering */ + format?: (items: OptionType[]) => OptionType[]; + debounceTimeout?: number; +} + +const DebounceSelect: FC = ({ + url, + params = { page: 1, pagesize: 20 }, + valueKey = 'value', + labelKey = 'label', + searchKey = 'search', + fetchOptions, + format, + debounceTimeout = 300, + ...props +}) => { + const [fetching, setFetching] = useState(false); + const [options, setOptions] = useState([]); + const fetchRef = useRef(0); + + const timerRef = useRef>(); + + // Load initial options on mount + useEffect(() => { + debounceFetcher(null); + }, []); + + const debounceFetcher = useCallback((keyword: string | null) => { + clearTimeout(timerRef.current); + timerRef.current = setTimeout(() => { + fetchRef.current += 1; + const fetchId = fetchRef.current; + setOptions([]); + setFetching(true); + + const promise: Promise = fetchOptions + ? fetchOptions(keyword) + : request + .get>(url!, { ...params, [searchKey]: keyword }) + .then((res) => { + const data: OptionType[] = Array.isArray(res) ? res : res?.items || []; + const formatted = format ? format(data) : data.map((item) => ({ + label: item[labelKey], + value: item[valueKey], + avatar: item.avatar, + raw: item, + })); + return formatted; + }); + + promise + .then((newOptions) => { + if (fetchId !== fetchRef.current) return; + setOptions(newOptions); + setFetching(false); + }) + .catch(() => setFetching(false)); + }, debounceTimeout); + }, [url, params, searchKey, fetchOptions, format, valueKey, labelKey, debounceTimeout]); + + return ( + ({ + (items as Data[]).map(item => ({ + ...item, + 'end_user.id': item.end_user?.id, + label: item.end_user?.other_name || item.end_user?.id, value: item.end_user?.id, - label: item?.name, }))} - filterOption={(inputValue, option) => option?.label?.toLowerCase().indexOf(inputValue.toLowerCase()) !== -1} - showSearch={true} - // filterOption={(inputValue, option) => option.label?.toLowerCase().indexOf(inputValue.toLowerCase()) !== -1} placeholder={t('memoryConversation.searchPlaceholder')} style={{ width: '100%', marginBottom: '16px' }} - onChange={setUserId} + onChange={(opt: DefaultOptionType) => setUserId(opt?.value as string)} variant="borderless" className="rb:bg-white rb:rounded-lg" + showSearch /> diff --git a/web/src/views/UserMemory/index.tsx b/web/src/views/UserMemory/index.tsx index 96da9dec..7d5dbdfa 100644 --- a/web/src/views/UserMemory/index.tsx +++ b/web/src/views/UserMemory/index.tsx @@ -2,51 +2,36 @@ * @Author: ZhaoYing * @Date: 2026-02-03 17:53:44 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-26 14:58:48 + * @Last Modified time: 2026-03-31 12:15:59 */ /** * User Memory Page * Displays list of end users with their memory statistics and configuration */ -import { useEffect, useState, useMemo } from 'react'; +import { useRef } from 'react'; import { useTranslation } from 'react-i18next'; import { useNavigate } from 'react-router-dom' import { Row, Col, Form, Flex, Tooltip } from 'antd'; import type { Data } from './types' -import { getUserMemoryList } from '@/api/memory'; +import { userMemoryListUrl } from '@/api/memory'; import { useUser } from '@/store/user' import RbCard from '@/components/RbCard/Card' import SearchInput from '@/components/SearchInput'; import RbStatistic from '@/components/RbStatistic'; -import BodyWrapper from '@/components/Empty/BodyWrapper' +import PageScrollList, { type PageScrollListRef } from '@/components/PageScrollList' export default function UserMemory() { const { t } = useTranslation(); const navigate = useNavigate() const { storageType } = useUser() - const [loading, setLoading] = useState(false); - const [data, setData] = useState([]); const [form] = Form.useForm() - const search = Form.useWatch(['search'], form) + const keyword = Form.useWatch(['keyword'], form) - /** Fetch user memory list */ - useEffect(() => { - getData() - }, []); + const scrollListRef = useRef(null) - /** Get data from API */ - const getData = () => { - setLoading(true) - getUserMemoryList().then((res) => { - setData(res as Data[] || []) - }) - .finally(() => { - setLoading(false) - }) - } /** Navigate to user memory detail */ const handleViewDetail = (id: string | number) => { switch (storageType) { @@ -64,25 +49,12 @@ export default function UserMemory() { navigate(`/memory`) } - /** Filter data by search term */ - const filterData = useMemo(() => { - if (search && search.trim() !== '') { - return data.filter((item) => { - const { end_user } = item as Data; - const name = end_user?.other_name && end_user?.other_name !== '' ? end_user?.other_name : end_user?.id - return name?.includes(search) - }) - } - - return data - }, [search, data]) - return (
- +
- - - {filterData.map((item, index) => { - const { end_user, memory_num, memory_config } = item as Data; - const name = end_user?.other_name && end_user?.other_name !== '' ? end_user?.other_name : end_user?.id - return ( - - -
{name[0]}
-
{name || '-'}
- } - headerType="border" - headerClassName="rb:h-[48px]! rb:mx-4!" - bodyClassName="rb:py-3! rb:px-4!" - className="rb:cursor-pointer" - onClick={() => handleViewDetail(end_user.id)} - > - - - - - - - - + + + ref={scrollListRef} + url={userMemoryListUrl} + query={{ keyword }} + column={3} + renderItem={(item) => { + const { end_user, memory_num, memory_config } = item as Data; + const name = end_user?.other_name && end_user?.other_name !== '' ? end_user?.other_name : end_user?.id + return ( + +
{name[0]}
-
- - {t('userMemory.memory_config_name')} -
-
-
{memory_config?.memory_config_name || '-'}
-
-
- - ) - })} -
-
+
{name || '-'}
+ } + headerType="border" + headerClassName="rb:h-[48px]! rb:mx-4!" + bodyClassName="rb:py-3! rb:px-4!" + className="rb:cursor-pointer" + onClick={() => handleViewDetail(end_user.id)} + > + + + + + + + + + +
+ + {t('userMemory.memory_config_name')} +
+
+
{memory_config?.memory_config_name || '-'}
+
+ + ) + }} + />
); } \ No newline at end of file From 2ad25c48d282916e3c855170031999fc844c9f6d Mon Sep 17 00:00:00 2001 From: Eternity <1533512157@qq.com> Date: Tue, 31 Mar 2026 13:52:41 +0800 Subject: [PATCH 015/117] refactor(memory_agent_service, memory_perceptual_service): Simplify audit logger import and usage - Removed try-except block for importing `audit_logger` and directly imported it. - Removed redundant checks for `audit_logger` being `None` before logging operations. - Added a check in `MemoryPerceptualService` to return `None` if `model_config` or `llm` is `None`. --- api/app/services/memory_agent_service.py | 108 ++++++++---------- api/app/services/memory_perceptual_service.py | 13 ++- 2 files changed, 60 insertions(+), 61 deletions(-) diff --git a/api/app/services/memory_agent_service.py b/api/app/services/memory_agent_service.py index 289fd74c..c27a75be 100644 --- a/api/app/services/memory_agent_service.py +++ b/api/app/services/memory_agent_service.py @@ -37,6 +37,7 @@ from app.core.memory.agent.utils.type_classifier import status_typle from app.core.memory.agent.utils.write_tools import write as write_neo4j from app.core.memory.analytics.hot_memory_tags import get_interest_distribution from app.core.memory.utils.llm.llm_utils import MemoryClientFactory +from app.core.memory.utils.log.audit_logger import audit_logger from app.db import get_db_context from app.models.knowledge_model import Knowledge, KnowledgeType from app.repositories.neo4j.neo4j_connector import Neo4jConnector @@ -49,10 +50,6 @@ from app.services.memory_konwledges_server import ( ) from app.services.memory_perceptual_service import MemoryPerceptualService -try: - from app.core.memory.utils.log.audit_logger import audit_logger -except ImportError: - audit_logger = None logger = get_logger(__name__) config_logger = get_config_logger() @@ -68,24 +65,22 @@ class MemoryAgentService: if str(messages) == 'success': logger.info(f"Write operation successful for group {end_user_id} with config_id {config_id}") # 记录成功的操作 - if audit_logger: - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, - success=True, - duration=duration, details={"message_length": len(message)}) + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, + success=True, + duration=duration, details={"message_length": len(message)}) return context else: logger.warning(f"Write operation failed for group {end_user_id}") # 记录失败的操作 - if audit_logger: - audit_logger.log_operation( - operation="WRITE", - config_id=config_id, - end_user_id=end_user_id, - success=False, - duration=duration, - error=f"写入失败: {messages[:100]}" - ) + audit_logger.log_operation( + operation="WRITE", + config_id=config_id, + end_user_id=end_user_id, + success=False, + duration=duration, + error=f"写入失败: {messages[:100]}" + ) raise ValueError(f"写入失败: {messages}") @@ -338,10 +333,9 @@ class MemoryAgentService: logger.error(error_msg) # Log failed operation - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, - success=False, duration=duration, error=error_msg) + duration = time.time() - start_time + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, + success=False, duration=duration, error=error_msg) raise ValueError(error_msg) @@ -401,10 +395,10 @@ class MemoryAgentService: # Ensure proper error handling and logging error_msg = f"Write operation failed: {str(e)}" logger.error(error_msg) - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, - success=False, duration=duration, error=error_msg) + + duration = time.time() - start_time + audit_logger.log_operation(operation="WRITE", config_id=config_id, end_user_id=end_user_id, + success=False, duration=duration, error=error_msg) raise ValueError(error_msg) async def read_memory( @@ -469,10 +463,9 @@ class MemoryAgentService: logger.info(f"Read operation for group {end_user_id} with config_id {config_id}") # 导入审计日志记录器 - try: - from app.core.memory.utils.log.audit_logger import audit_logger - except ImportError: - audit_logger = None + + + config_load_start = time.time() try: @@ -492,16 +485,15 @@ class MemoryAgentService: logger.error(error_msg) # Log failed operation - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation( - operation="READ", - config_id=config_id, - end_user_id=end_user_id, - success=False, - duration=duration, - error=error_msg - ) + duration = time.time() - start_time + audit_logger.log_operation( + operation="READ", + config_id=config_id, + end_user_id=end_user_id, + success=False, + duration=duration, + error=error_msg + ) raise ValueError(error_msg) @@ -633,15 +625,15 @@ class MemoryAgentService: total_time = time.time() - start_time logger.info( f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)") - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation( - operation="READ", - config_id=config_id, - end_user_id=end_user_id, - success=True, - duration=duration - ) + + duration = time.time() - start_time + audit_logger.log_operation( + operation="READ", + config_id=config_id, + end_user_id=end_user_id, + success=True, + duration=duration + ) return { "answer": summary, @@ -651,16 +643,16 @@ class MemoryAgentService: # Ensure proper error handling and logging error_msg = f"Read operation failed: {str(e)}" logger.error(error_msg) - if audit_logger: - duration = time.time() - start_time - audit_logger.log_operation( - operation="READ", - config_id=config_id, - end_user_id=end_user_id, - success=False, - duration=duration, - error=error_msg - ) + + duration = time.time() - start_time + audit_logger.log_operation( + operation="READ", + config_id=config_id, + end_user_id=end_user_id, + success=False, + duration=duration, + error=error_msg + ) raise ValueError(error_msg) def get_messages_list(self, user_input: Write_UserInput) -> list[dict]: diff --git a/api/app/services/memory_perceptual_service.py b/api/app/services/memory_perceptual_service.py index 5c838fc0..7cf94a1a 100644 --- a/api/app/services/memory_perceptual_service.py +++ b/api/app/services/memory_perceptual_service.py @@ -244,6 +244,8 @@ class MemoryPerceptualService: file: FileInput ): llm, model_config = self._get_mutlimodal_client(file.type, memory_config) + if model_config is None or llm is None: + return None multimodel_service = MultimodalService(self.db, ModelInfo( model_name=model_config.model_name, provider=model_config.provider, @@ -265,15 +267,20 @@ class MemoryPerceptualService: with open(os.path.join(prompt_path, 'perceptual_summary_system.jinja2'), 'r', encoding='utf-8') as f: opt_system_prompt = f.read() rendered_system_message = Template(opt_system_prompt).render(file_type=file.type, language='zh') - except FileNotFoundError: - raise BusinessException(message="System prompt template not found", code=BizCode.NOT_FOUND) + except FileNotFoundError as e: + business_logger.error(f"Failed to generate perceptual memory: {str(e)}") + return None messages = [ {"role": RoleType.SYSTEM.value, "content": [{"type": "text", "text": rendered_system_message}]}, {"role": RoleType.USER.value, "content": [ {"type": "text", "text": "Summarize the following file"}, file_message ]} ] - result = await llm.ainvoke(messages) + try: + result = await llm.ainvoke(messages) + except Exception as e: + business_logger.error(f"Failed to generate perceptual memory: {str(e)}") + return None content = result.content final_output = "" if isinstance(content, list): From 52ae914e1711923086fd694c9832605a72220612 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 31 Mar 2026 15:43:18 +0800 Subject: [PATCH 016/117] feat(web): rag content api --- web/src/assets/images/conversation/ai.png | Bin 0 -> 6364 bytes web/src/assets/images/conversation/user.png | Bin 0 -> 7990 bytes web/src/components/PageScrollList/index.tsx | 16 ++++-- web/src/components/RbModal/index.css | 3 ++ web/src/i18n/en.ts | 7 +++ web/src/i18n/zh.ts | 6 +++ web/src/views/UserMemoryDetail/Rag.tsx | 6 +-- .../components/ConversationMemory.tsx | 48 ++++++++++++++---- 8 files changed, 68 insertions(+), 18 deletions(-) create mode 100644 web/src/assets/images/conversation/ai.png create mode 100644 web/src/assets/images/conversation/user.png diff --git a/web/src/assets/images/conversation/ai.png b/web/src/assets/images/conversation/ai.png new file mode 100644 index 0000000000000000000000000000000000000000..3783a5436736da0e92bf5c19a27081878ca1ec33 GIT binary patch literal 6364 zcmV<27$fJ2P)@F_t-bd-=bn4()`KdGvUmtZ6(AzT%L-sOrXvANh_;W!lsdXQjop%w9!dX^ z&KTV>w2g_KL_#_bF;SBt$itW*rfn3XI06cS#>gumZzzfaZq=>(IM2P-n*GPxd!K#I zy;W58NMFX`oI20F=9-Uhe)C(xk^f~$lDthst^iIKW4C~AF=D30SP&2d5dn$!@u2|4 zy$XVYk=oY^peTUq=xGI2MbUV#<N@Jd#Dp5xi%! zwf3XHZV_1{QZCXUfEa@qQ+%0>9*baV3ylT<5z7D7Zsp?qm=RPp8lr9l)ygmcAt-17 z4Je=?psI_i`UUU($AEo@Tbw=2YYk(}Z;8mwB60!HkZNb+g^K|Z6Q3*oqqgJYyHDXI zfS47FuY1F!(Z=5*h*myRT59yUXut+=o-yVl*4kzW;k$r88V}fPt-VcEcU2bxnA+|} zH#dl}(bW&<{>92i1{z=J=u)I`Z9fVFN-qe6P`ooJo=|y-YJf2Efan>a{^j?1@87kW zH<(oxf2N2$FuC~heud>kSC6ggbG0{*Eq?MAfKV989CmxO_TrJmu&VAd#ykML>9ssy zvxqzZY^}R+Wv$W!Omtym%vcB@f|$|nkGZ}`VQ*b!_g7w0ndAU8j3+}?MHE8m1tFlp zR}r8Z2wMC6D6i}ZA)HtDgt`YbMC3ML>*01ZZn_q+EVbP|^mfRY8?5g7llwb%X=V6u zB;Q}p3oHJ%7He#*mIQxGM0VG|w7%$i{0jocV#=7;ng`Ufe--;;YP&yv;UZ>LFA&Vi z@RHgCoiLe9B*Q~&9sTU=rd*357q9* zm(0Ybxru}rpm6^c<**Ek@;R@7FJ2$r#!BkwzN4z!fPLi^eDqL@KN9qyt6O@Nchujv z+RmHm{n3b3-CP+4wSPyC92rgydo2L8jWHiBe=!4gb5wtV$hgB&PxLD;z|?mC(8)qX z*m%rFe)r@5lY74P_x$yj{)P)KxbTqCFNZL+zGl7vRox9_k|fDbJLkSOmgFWpp#JZ; zGZrzmn8&PM0jeLgH5_6N5%I>v;YUiF|d}(V(}5@%4MOdYfLYh2;!MuaU#75xPWsbZZ&7))1-*29CKJukk% z-~R3G3tFvrmF*@I;b;@&Xq>J5+hSD54WOzV)+5A3FR0p6tyO{g zcN|$os6X$$<14rSHG{#B`MDRl_4A+Q*MIHb^YXrzR>14=u5$F%N9>e^0TN*PHE{HV z5p+E3BZoAWOzcjx(PVn2%~WfK_3PHrYE2O|u&}hi;=%%*PM5b|av5u8TU`6eKjVhK z{L{LRPlS@LVpttRfFiPDLnr>GlcOzvEB8$?MJHJPm=v+73TxM{<(#w6<=mZTbM{$h zv17-n9J_f7Nt(o^No=T=yB(f*@<}$mZVT!IgYLmq!l(ADWlu16Z3ZM93I?yfB5M=% z-f_4xI0=Je2`ow>V5CNh6lL&)<4@p&zxV+zzw~lW*|v?Wokk|dfiGIu(N8F11e+Q{ z0I#0;`2`+*$Jgj|J8|JbtJZj|DJsieg#uT$#S>^Se*S0-fEbKO(XEaQ zP<4bNY*@dZ5ANR0RsZpqIriAsfr58MM_2rCs4X00Q1$Wm%qW>#Yq0`B;l6wCa1C{r=rtcikV*YBtFSj)+P_;ieP0(~P@2Lj|u2K?sEl7!iyFtWi>d)C$c+S(B8#qG76?a?d^Y z^6!4>->`4r%P~=)1P##u=&E5b*&eTIfWudT;x&_i2-~-v!k2HmjpL6$k$mV0LF2-l z(DyL#FmTZKLeB}koa*_0fI*1=&OHnjoJQ~P8sdf zG%cy|y!CZ~3pY4gt&AW2;}3c7d)~vr`2z($(1IWxfu(;|1E>gDl5xuyzr>}NUP8Y= zz`2BOSLivz_y6G`KK+^Nc=qYX(clTr#e$~+4UXVle7;DMwWgw&G$nZN81_4;dNg>P zDkok1K6ZWh<4n)4XMHp9&aDHswgQb-lVAOnU*?uuZ;qw0$U%hF>9NWCqlCioB6pm+ zowLt52bX(75C#KbkQ<(U_C>D$^zXChg=cA`4G>7OH8iJ>C2LL5S)8LgzlVPJ0Kw-q zBg6!gWNq5BTWC#hz&pqCffwj^=J3AHQ{TJ^F~Zw^^^aH_7@q1_HZ>fH54`JrKhG_< z-hyako=W~{vXqAqR`dtfT6=W~8xO|_Q9$6dF~d*4?>}+bPrVh5t4 zckcZr%|-)6$eL@}{DyP+*(=_`WtUvY%=+V)oA1)=?8m#oSo31jDb{V;!39704tBr) z3eMcQlZAeR#rZk%;WBBX#r*S+vGvT0nccjFDXSdc%t=J?$y)CI!4u5ydurACU!%cR zRJx<~fVaN?*SYBO50DvU=mNPH21A%%9P-q&Pm&}F7=ty2_Uv(7@s6M5Bkz9;)(Drq z^%RD~A$RV7iQ!-gG+?YHgh0}mVfz{9@*7uO!M5Yp0dV}5^?dB*z4Vvo(6EH>by(Q* zV@^2hY=%A}zuY-yrlxr3f4GL*uJ7{nKYSZi#TdhC8#@V%%#n70BJhSc?&SPS-pw)X z2tmE~c&FsKFc^5coh~AA)L?8zd+kQvwrd;KL`r7;FtG*A(I@c8~RX@QYE5L~Pdn>ViG;$OImq|ro7Qb6|6w7Se? zV~w52N&~DUKt^rbsXH%XdU_LU8{mQ`c#rc+o`WAM-g%5EMrOv6Bneq#d;r~QrX)5& zUCyApKzDhbyx#>eG#lCY=-L_@S&E2_m!XOYp3HzLpt3_rQz7v=8;;%1aVPJHJ&F;; zCK$1CgkOX2*4oMt=13WUBuUtG!m0Foo?|9%o)mZA+NyCg#$rqo@2x3@HwJ6%cn(;N(J4E!w9z1GuVdKv1T|<7 zeA(o9!3CIqWe;<&yg;Mb0u_2oeLjEFMw(f|X(w&wJ3n}uuibGci?2R~23I2lltFin z@85SfpStk`-g(J699-)07q{NUzGol74?Ccd;ms05%R1RBTT(fjf+P1>G~ZXW_- zin+5!&{PsAYD$_sj0|w7#g~u0^DG=(X3gtTmY4geCJ|tS*zg!qn$0P^9}-+f2#&>r zdwA%+JGkb>7nohQo>%w1$ji_Dh+%JjT;PD>@(%l+`5`xcF6Ud{`YtXX@X~XSv3TGa zLKq@quqI);)uNR^I{}GI0M0ARJ*Dd$z2#m^fKq3S>O_e_P>d?7mSnXt*oqP;f_IK* z9=e~6C;w|++V6SO$)ol*c%{+IIPSPpx&Qlj(rnJeVMXY(wEtO_54?yVgb)#+hLvIs z4S{~=0Q;W4pZ$9t22nKRsFXN9kT#lZIsQai2BvM4#)t^ZJ;!1)-3)0 zawJJq2|mXSml*Ws$%h>@nBtn@Wq$n9N zi9`m-4f?$N^!+4_Eu4S$DKxTxa}Lpz#YM}yV>YvX!&ZLugKx6DJRhNNh;TNji*8tn zpNav7*4#5hDU8;7`#kT{U0PtrS-ZIMzx^Q_Hy*=G3hPaeB;+J%!Xr=hxc!TN##0a6 zNbnx-73bq(LvTf+EP^g*Rc2Kcn=DYdFGKjIqlT# zY~Qhi!C-(f0p~LMeK_HyH*&_#wY@96=QzjKbjm6pwn>9!q zOtYIpd5o`SZ`-fY_Ab&|$R2 z-2NFBmMsAxNbK2S6ic0yzDo#B znQpgfw%V*M_&zV?&PPp!F<=s~2GRtqh15bLfkp}qtF%(3VI2+WqrEO61RBkhrNsqq z`26P>=5E!jB6X#VP?4xqwLwUH2qZR*i#4NBQ@QYJ1YQaq%TPhN@$-Md+2@?k_EXPb zaj8e{J2YBdnyn1)69R^y2CoJm3|3hYR+Pj9tOP^?q5+$Gbl@3y zPm&5*nvf(ZpZ(0IdEkNjtIg4pDIT+cDNL`Uc~T{yQe=8&t$OcC(A${3ZU^AN(GRON&JrjhR9Mc9141TI{q| z?mWstb&Y7`@kdG5*mz4*L)OTcYBbsR@?Nh0A-;iY)vSxO}5-=@dFjla}Vy!JkS6V4B(ngAvq+(f& zwc`qv38OU9lsE4@kDs{cUvSPj=QF!*LlukEv$c@|bVV-{B9rs%vrlvHcmAIH?!AX6 zfBblb4&!2|Ey!16w)j*boDPatKV|?eSm-RxuxUzgZb@cl*FSEA?Usux;tsRumngJ>=2)>}n0yxEMOvL4=Jxw!Er9bFZ%#~ukowQn0j}C{!i%a+P z2o_U_4|GibU&gC*d1QUFMijJG{Qg9K&tra2`9j^M{sB~KXF2q=mvl4lkBr8-+RFc=VA)aH~FdIgsYaYK1}BLrVSWM3?*05HV`mPQ?$ z>eRd~ZcKtTu(Y(y!ooaxK8(GFHQ3^8f=a9o6%f-T=MI;sbb>7h0f*^`l@a zfFA*#s6+CVe6E&f?#uK1VqmaB`Pgo^H>>Ic#w1&fmB|VNd1$>GbK?sHVEO-K; zI!0IM1ok-R#tu!bkbbt?*_mkiRk6udF;cX}W~2}=o1&^R6t}7#$HAx=GbXdG-;bT> zD)D-8(Paq^HNj)N_eJumNnVO-P#vwQCDy1EsUgM?_84O>9}b5Pz2?EmcDub<#N4Ks zUB;FQE(4|ph-C6`L>(j8%0Q+ru`4h}C`wT&s@AbKRY%&tdReR+q>|E=hbBklNo0Y; zq!{A9JkM7>IC(g2bz^FJ`oD_me->-ohI;o`oy-~gmQf#Hlqs>}bB3}7)|xs*O~DY= zj=bo&9(aTcLQO7Pwzp9gLbdohn?y(8+Tn2c@iOL*`tWp-7d|Rtb|b8j;&?_BNt;Q; zTAd3%)Ty(>ix=xEu1L+53yp`y!x&~Ts#P1T>S`oQJUpG}`D-4YJ{om>)^4}orm9zn zn9~v6qR1A5OjA3_5z3P!A~hIYtLHj#;POyr&_Xj;ofH}eO;tFBu&kS4FR~lVh^ZJtW8P9) z?4udkE3aj2gE6)-GoSzWedpZe+;i@^_j#WCJoofM=Y9Pjekh5q_Yfgjq(vl~;QA&-XmB z??w2`y7&CI%pcvHL>VuBpwT85Pao|J_}7?&RFhG!&(Ypb*J<4ahSu$t5{8&sKp`tI z$yyBN084v*t>(EZ&@8DMoLCD@^?*p%o@teLIH8bTI0*|yFtTb{6KsqJ$g_BAguq7$h(4f4^+I&^QnTbc&`1^NdZpJcV{BZ<0 zYpFcgF;BdG_o__nC(7h($cKHqv^x6+iQIl@0X#!qFC9t+*N360~N>er%fHG0=NN2?7y3HFGW!=bPUF|NTwBC9;sYHFb7MK)B+O!TqWL818;=b zZqkTd%dFVy9b*qi9DW14546lv+NP(a^_k4x`6A&`jMS5lHw&T|DozQ9#8yW3qOway zx10mnGC%?rc^o;gP#`=aXoejvsSO}|mZgzTp%z}U#*T4qoGr&S@v#mO8&$1Ol}(|M zSm8>Dy!9yGQ}~aAk5YqSsIeMa!a$Qo+ytSR)BnMF(nI}4SF4}+ZPz~ZbTZ$LhP%!) zzIc}wXMKIo#?Oy}sGoKdz6}srEolV9jG3-wcM)49+w`J_lE@ZlJ~ubUPD&x!QIpWU zC!3kWE<^a9xb&|n#n^MJ3IZ7^FjItp&H)_cI^n!MH(g|Ab}cyn6|**REFZkHL_uS5 z^DmEQcJ_9P$@AXNweVB#etdGp5xgR!Lf?w7J{zb8O)TwgPsJ^2pVCMeWiwU+H8Ea& zf46+D76q0-MV90OTpXwn_jKVCA`t|!W%EajO~@+MX+`_~8VR!=X^(0X&&u*tDdl0VtM$--tdH&7WPgywdK-F+(0S~OyN zvt*3vcRXaOkepO+oY3hrLMjmSbAX_rw3Q}*4R!Can9=4#{~^@($!0eSc4g7uu0E?0 zsVkruf6d-k5B2+S`=_5b>6N|I;$l$zL<2SPj=)*YXUY|uv1a#{Uo+Q@4YnzL8p!Pf zDfW~-r@fbuS!-h&04@;ZRUvipVuc$BcUWQ@Eo&7hLJ>BhGXq*I4@*;R;#D|i(*L4R zK;q|zC>ote&BDcvth9%!1Um+;(D5F=+QkfmmkT!}e(tAJm%6FjV13j-LE-*2D4>GMIG+&PM!vDM@a=N0wAJl3u_g13HI14Su|#<%p%!CSBod>@-H>EXb#+; zh**hv*G};L@o8{sx|5ma`osd-ab2reESWXm`>U^pF<~%BAq1NlD}TQq6gzwE!QMvh z)nX6ks*-G!6d~hdB84BGA~>i3b8|s0IQ6ghWwNNz~3a z-2~o zHQm1#0}z=4-n!yZ&sC#RS!tJPSJ?(`A*Q4cI#XnA+%3* z(Dfv&`tLg92197U)PR|`vBPiSw2l6|_U{u|C7EIxx86q|-Aip>jtZry?Ox!Va?dAE zBC{@svmCZJk77(xa38!mBURJW$8}|5@5HW6+Ihydi>g-ff@S5^gne&2D5@zS;c7Gq zeCJC$Ayc57T9QeS_35Y9wb~)^3p?#Z>Qvj3Gii-rI7i9$auYAU)YC9wmcBf{g%3(tC$cBYJEFdG_J`PKigCz& zvOart=W3s!ZuZ3`|0^NaZyX21w1k1M#|i+U2;HxFrG*Uw4%T1-U+ZNxVHmpLpXZ`< zY!?c-T<8>b%d_d(93*5ApMOgI=I-Z0^p|zm*YN#P<1Ka0%&zg`(ebn)qOP_f_%lO# zyfImq%xkh#ldXXTd0(<+IGZxzv_I*ytU=f2RrjqQ$>zI`(~TNNwCT}~?#Rb66o8eo z5H|N>DzEKb@K3Pr*U#P0xf9I_g)-nv$|6`j%{fXl;sfn?_rW-71taAYx3)vSXrvbr zj@a%QBJ>OdiG_rO#jO54Arn0tuI||x927hG#_ihGf5hRoE(Lz@;~~$15vDNbb#FG^_7TGi!!8t&bRw2t$meWe{w<9{qbdzWa4E7>th*7ma92u4EZa;PnEgQyV+dk zBZX-?pZAe#9SF{?un#-$Vx~r(WSu2)oMjSbM|avpciP%lLqlTrv!1Y*CULi1?GrW# zvnz%}hD|ptlV0dZ5m~XQzrhPzsfPB%wN5D~jS1)Mv}fhW<G7nPTQ18SbgG#T2=`|ZFL&QZ(eDO+6|$3Zf!p?4s<~viwk%fR0Ywz zr<2SaT7EWo@$&=ptNQK*0V6yO5xf(JghP1@9%d?VKWgQN)ESh zsNSBMyY_BI8~p=P>g#Lok#82C%rc)bTC1`uRh!)jt6_2fogyS~hDXLbYFhC5^g`6W zYsKb7Sr_{Wk?9#d8(U8AzpiElAA$!<(=U zLdd<)(t2m9(&eoFot-GwqVc}d3#HKy4R^P1B#fDPi)8h z+n}+l?Fa+rXJ6xuotzg=SDuV6MK5CW^L=uBeZqd}sGqc|%>}-hd2L&YA6lidw3+B!>BIHm4S34(Y)d~z6 zacQL}YLBGjaPcTcQ>>0`{a;gq6w%f_D3m{0W^Qv>kT2r&(hw_SurJ6c!Ve6@qs z5kR!DHn&j#fR!=LuwvDWB5@U~067o@A}uEfLksf=r}-@|H4}{YQhGEv*$CS$fLD5rqM7_&ag_kc|ncvxhj& zhD;ag(R%dXA=enM7OtZdJ;RaAW||ukHx^Y+igoHt zQ^rC?V0tdqvL-^N$=gEiF6lyO72Bgm&a&oAaucV&ne=CxB{(W_|O-oHT=4syF=nOSL`TKgv-pPm_;8k_iiMNQQb)x$eqby&7Wl@GH0W4gsOd4 z;~i6*FXL}J5Sb+}5DY@*#V+IsXK(hf-e^;bxJEEO^s=9QUrS`XdmBlyC zX`@{@B zmeo1o@@>#)lIfDsp&5)}bQs~3ueNn;7u4u_v~|VYe;8Hgm`@q=h+g#Et{*do`wF5( zjyVttobk^un}ULDrGES}L&M!wF!JX{29ii7d>HjAjiLz(>2I)ai7s|+e0sc1TZgV1 zI66=o9!{I^YYGV){|OIG9?jRxxmY%?GhyVJ4(Umo{%9}+$)aoOTsgQC`Q6M`s#tw( z<*R?-L_^BOQm;(@G!I!p>-s`bYo~N>xvbh}YPBE>ET+Gtvevs`5BSXuG?mUw5SkfylBtc=Nmv9LZ6)HTqEI5(E0-K9YAYlH z5v~Q^)dRpF`I4f8Mu#$uWtm%lo%sFh{YEq;UGAv0e(ZS6U)1x7^5#KATyO{e!RFhb zqRGj&_VzP#4MQyfcv5}V{7f5vmLx3kYSCb9_zs-1;2ff6I2AL22>USPcTaNRjHQ%v z62#i7u;^(eut4Zp0mbYjOtF%iRBIgoJqH)s9`p`iR8?Fv*frnUQ17vH+&YxOcwJ+B zoVfqmZ=_4-^YWovg@3a6VrY4%kygNnaF5AJ?S63u)1TLnL9|Y>uk9Gd#E015tABk| zpV40ffkGlls=^U!wj{v}W!S*43=hxt@gI1pv6o7OA>=@vk;KIP>!A{ec12qnNCQ@@P3oU()TV)rA)9Dm(*J>4lp1ieCfO1JgL1lrg+4E|!lwY8q;rg!* z2yFYg@5#F0JgY1&k<`u#UhYSW$KQ7HlZ|Qj8d=B3fW5QoaJLKO_4(7 z+6G4+`159d$o1^1C*_c8UT-)L?`bkf&?=EDPk#7XY1?khtY_=1%k;J4V4cN4!xcw` z5MN27Pv6|6d@md^<&(0G$bFk4JQ$5sz z!P#m~YQnAyBzM>%HA>07e%iPCvZ-5>uZ&nM$KqxiTD5qFnLIDgIo<%-B{H~H!8WOy zq~&Y*oi@9oZR<`$W3QxIazTep%}l6d;Nv79qj~r_a6&mA(q4|(8PF)DP2~-Ac%5_2 z^b3mZ+o6ztl|brjop&9G8!z^yX*$M?R}bIUz)T=q4Q4deZ;3)YR5S@(8!)F#|FJ-?h*;b;LRVC{WFP&cs}xE4 z;qCZekN&WXgXy`%<}ey{6c`EQIzriG*a(ToOvIrYqk4iz$Xd=?p_THi>FA_23sO|Q z!LXF&6L|o=hBRSWE^B9Lsy##}sB!dC>E~N6Z^`f95Z9JdrNwaTF%RDO z<4ZufwufN&z_$zD^)P3am=5_ta^5DIA6}t?j5DCQSI%kG+Ia8{=ccohGGIc7glLVR@^o!J~DvD zLXG4c4o|wxWNR}9_!aN|%DhXbN1wiHatB_VF z1uc1=Yd!gTg|n*L{ilVaV$kUs7KU>4<(q`}DCpvshqtZEbvFnp9MiNio`5 z58*6zJz2lgm9W)QY>;rRWHGzCvH{YB5e`U{%@gD`9c~0#LJM0heS*_HYOS51RwZV9 zyy<3uVPOdcVPUS?FPqd}+~God+}19wGma&luxyMZ_9}GdeN$3>HRfQ5 zI5}R(7dOWzAP}{>x^{0hRI0rl-8rx>lDT`R+!x;%wQ7)zn+{=lezEUnOY5)6_T3&%cfx$$7XQ)Y7zfNvUtI(&#JIX3GI~Mq^3_}I;?H^_B7$X|JxdC z#hj7RzTkG=z`&&grBiw%@WsLwx@(h^j7Jy|cZ$p14<5JE-oKId$;1=`T);f3rY0$} z4O~_dqi~VJYU>nkMR~!LnHuc*ULA2MgiQH9!n1qju4m`UpjY^Zi`x$|qBSs?pWa?m zp3$4@t^zrOt1_$KZc5JVA9w5Yi|4d$VrcnH&}MeZm5O$XuRAR%%2i!?3>4%RQD}vyU%#M)pA~Wuw1UbB7MMXuB?`Tw8q-Ew z_6duy3RJ?}NZ@4WS;`82e`)yd4_J}w#G5x$6KyT+iyucYQmj1R+vwIH+OUR!7q1QA z4!QN2OSRbP_&PEA&)3+PZDh=!k&d%Q9h*+wh)I3*9Bi(tp#snZsPeorLw@domVjD2 z%uoz1dk&=BKxzcbAMK6W`Ei^>SMu1*(iIyF`s%k{OTAQ&uzw%lH={{wi(7rD)VTh= zzT(aJ2kOS&)cTGY>}CDk}{|+w{E%26JMyE`m}+&_A?v5Kjy})E2L0L)~4O- z>+7M9N7GwWWIF!^*RyUL7`Z4UCTH@or^qw~$T}j`GYm zHDZ}vQEj)a6hI9H1cH0~1Q>MFxv(W&3p(!H3mx%?RHdrMs}ee$_qX}-Wsre8gF$>F zCz`Aa_)`;=t-#gsHD?ZR_+emUo}){r-=&Z|9h(*9z%Ax%=L&~8dTH1G==FW68x-Gi zJ&dw2FtMWja9q0HX(x1|1zg2lG!u zVw}fhs|2?PwJf~W@4UFO&{1%+v{YlJWE%PAE&O_lDZhOBQ&avX3A3Qc==iY12%C3y zWE{7x)(Eiisb$qyZ}B*G>DhRC);QAe?JXxIi{(Abroo@c8XYN#l7TX(knEg+s``RQ zoYq&(nOr7?ST2*`x<|AK9VtNV3%$DhkH5<=%i!i^GRVku| zPBm4q-78dBIXP;;^0u*41^0UY3+A5txAd$NdthvE%Fz8qeY4&|I$aa*@z%L9AaA~T zVy3;Do?+C)C=f8d3-r9X+&Qm{Zah_-5fuQ6$nU0_88#7yJZjKSBfTH(=QjC zMeaIO*9S4bA%VMJq{#m3 z{ggIyYv*N{bSHCT>cco#E{k{Ne+oNm+V+=uf%V1iAUqu7>kvwd^8@*zB{ncBNJU5p zaCXSA{bC5bp}`wZJ-D!PgwPDFd^!x_P1G`s{ldUM<(;GH8qr0cRSr<)kA9< zWn*kej-`h0F86$=)-_gVU>ktxSn!8^_#D1Ry*eYnZ&crJPQ1zVye;EK+VuvZZI953 z&-Fz882SSre%@m^yV3!BMdxE$fM2t-sQ3#LIqlKf-nV)QJ>eCc!Su-|zM2CI0VJk^ zipx?NS)wt#E~`YRA9TCuRNbV;$VopB=`!pz!V9y>$Su;eW{ytsef7aQ z63A=)D4I0r&^mfnKsg!O-(pl8un-Z--VTjp4Bn>s-r-*+`yS#~m`vPpljBD+OdO2r z49NszG_wCtrVDXI2fO;vnP6xItb%M{V_Vyz=NYBsh-_sZNkj{y#cfAuvk1n~)b0NUAHGz=ib;a= zmo-!A=7VRQ%P`%!MmDr}zZ`lXDC#7sL~Pym;;Ta2oPZ2Wu}44Q4xx*=)Zb2fHGGnA z(-om-evkh;|43!Ap6AOUyE~1A1Aac5k)v)p;p2DaBUC**=l;tH4Pno)cLptXcXvHc z{a-nr{}Z|Y`?py4Ue*!G8~2VqAD}6XwTr|m_m>He-fHxDIrvg1upCo;CYV;5BfvPy zlTnf-x{qD-M5?JuTlI@w4K`jyT!`29Z-Ww%Nj77~?d5smZ}G>>GA2>geJ( z>#pwLz2mDL7lUp4^RdqF?#J3oPiq>#Z=8Lpy9(iDLKnG6|R zAZ40T)0NK*HRWqD2jUhjEHI7O(6-siK3R?@g!o~W$G-1%{> { /** API endpoint URL */ url: string; /** Function to render each list item */ - renderItem: (item: T) => React.ReactNode; + renderItem: (item: T, index: number) => React.ReactNode; /** Query parameters for API request */ query?: Q; /** Number of columns in grid layout */ @@ -57,6 +57,8 @@ interface PageScrollListProps> { className?: string; needLoading?: boolean; heightClass?: string; + gutter?: [number, number] | number; + onTotalChange?: (total: number) => void; } const defaultHeightClass = 'rb:h-[calc(100vh-116px)]!'; @@ -70,6 +72,8 @@ const PageScrollList = forwardRef(>({ className = '', needLoading = true, heightClass, + gutter = [12, 12], + onTotalChange, }: PageScrollListProps, ref: React.Ref) => { /** Expose refresh method to parent component */ useImperativeHandle(ref, () => ({ @@ -88,6 +92,7 @@ const PageScrollList = forwardRef(>({ const pageRef = useRef(1); const loadingRef = useRef(false); const hasMoreRef = useRef(true); + const [total, setTotal] = useState(0); /** Load more data from API with pagination */ const loadMoreData = (reset?: boolean) => { @@ -107,6 +112,9 @@ const PageScrollList = forwardRef(>({ setData(prev => reset ? results : [...prev, ...results]); hasMoreRef.current = response.page?.hasnext; setHasMore(response.page?.hasnext); + const newTotal = response.page?.total || 0; + setTotal(newTotal); + onTotalChange?.(newTotal); }) .catch(() => { hasMoreRef.current = false; @@ -156,11 +164,11 @@ const PageScrollList = forwardRef(>({ {/* Render grid list or empty state */} {data.length > 0 ? ( {data.map((item, index) => ( - {renderItem(item)} + {renderItem(item, index)} ))} diff --git a/web/src/components/RbModal/index.css b/web/src/components/RbModal/index.css index 8dabc4ab..56d95248 100644 --- a/web/src/components/RbModal/index.css +++ b/web/src/components/RbModal/index.css @@ -1,3 +1,6 @@ +.rb-modal { + top: 40px; +} .rb-modal .ant-modal-footer .ant-btn { height: 32px !important; padding: 0 15px !important; diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 2975796a..294b9bae 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -627,6 +627,8 @@ export const en = { vision: 'Vision', audio: 'Audio', video: 'Video', + thinking: 'Deep Thinking', + is_thinking: 'Deep Thinking Support', }, knowledgeBase: { home: 'Home', @@ -1421,6 +1423,7 @@ export const en = { citation: 'Citation and Attribution', citation_desc: 'Display the attribution of source documents and generated content', invalidVariablesTitle: "The following undefined variables are referenced in the conversation opening. Do you want to save the opening configuration?", + deep_thinking: 'Enable Deep Thinking', apps: 'My Apps', sharing: 'Sharing', @@ -1594,6 +1597,8 @@ export const en = { core_entities: 'Core Entities', communityDetailEmptyDesc: 'Click on a community in the chart on the left to view details', communityLoadingTip: 'Generating community graph', + assistant: 'AI Assistant', + totalRagMemory: 'Total number of memories', }, space: { createSpace: 'Create Space', @@ -1828,6 +1833,8 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re memoryTipTitle: 'Are you sure you want to enable conversation memory? Conversations will be saved to the memory store.', stopAudioRecorder: 'Stop Recording', startAudioRecorder: 'Start Recording', + citations: 'Citations', + reasoning_content: 'Deep reasoning Content', }, login: { title: 'Red Bear Memory Science', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 3edd84e3..38b2a76a 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -795,6 +795,7 @@ export const zh = { citation: '引用和归属', citation_desc: '显示源文档和生成内容的归属部分', invalidVariablesTitle: "对话开场白中引用了以下未定义的变量,是否保存开场白配置?", + deep_thinking: '开启深度思考', apps: '我的应用', sharing: '共享', @@ -1274,6 +1275,8 @@ export const zh = { vision: '视觉', audio: '音频', video: '视频', + thinking: '深度思考', + is_thinking: '支持深度思考', }, timezones: { 'Asia/Shanghai': '中国标准时间 (UTC+8)', @@ -1592,6 +1595,8 @@ export const zh = { core_entities: '核心实体', communityDetailEmptyDesc: '点击左侧图表中的社区查看详情', communityLoadingTip: '社区图谱生成中', + assistant: 'AI 助手', + totalRagMemory: '记忆总数', }, space: { createSpace: '创建空间', @@ -1825,6 +1830,7 @@ export const zh = { stopAudioRecorder: '停止录音', startAudioRecorder: '开始录音', citations: '引用', + reasoning_content: '深度思考内容', }, login: { title: '红熊记忆科学', diff --git a/web/src/views/UserMemoryDetail/Rag.tsx b/web/src/views/UserMemoryDetail/Rag.tsx index f770fafc..a11d4295 100644 --- a/web/src/views/UserMemoryDetail/Rag.tsx +++ b/web/src/views/UserMemoryDetail/Rag.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 17:57:11 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-27 10:26:31 + * @Last Modified time: 2026-03-31 15:29:45 */ /** * RAG User Memory Detail View @@ -114,7 +114,7 @@ const Rag: FC = () => { } return ( - + { - + diff --git a/web/src/views/UserMemoryDetail/components/ConversationMemory.tsx b/web/src/views/UserMemoryDetail/components/ConversationMemory.tsx index f956eca4..c209274b 100644 --- a/web/src/views/UserMemoryDetail/components/ConversationMemory.tsx +++ b/web/src/views/UserMemoryDetail/components/ConversationMemory.tsx @@ -2,38 +2,64 @@ * @Author: ZhaoYing * @Date: 2026-02-03 18:34:04 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-27 10:28:53 + * @Last Modified time: 2026-03-31 15:35:13 */ -import { type FC } from 'react' +import { type FC, useState } from 'react' import { useTranslation } from 'react-i18next' import { useParams } from 'react-router-dom' +import { Divider, Flex } from 'antd' +import clsx from 'clsx' import RbCard from '@/components/RbCard/Card' import PageScrollList from '@/components/PageScrollList' import Markdown from '@/components/Markdown' import { getRagContentUrl } from '@/api/memory' +interface DataItem { + role: 'user' | 'assistant'; + content: string; +} + const ConversationMemory: FC = () => { const { t } = useTranslation() const { id } = useParams() + const [total, setTotal] = useState(0) return ( {t('userMemory.conversationMemory')}} headerType="borderless" - headerClassName="rb:min-h-[54px]! rb:pt-0! rb:mb-0! rb:font-[MiSans-Bold] rb:font-bold" - bodyClassName="rb:p-4! rb:pt-0! rb:h-[calc(100%-54px)]!" + headerClassName="rb:min-h-[54px]! rb:pt-0! rb:mb-0!" + bodyClassName="rb:p-4! rb:pt-0! rb:pb-1! rb:h-[calc(100%-54px)]!" className="rb:h-full!" + extra={
{t('userMemory.totalRagMemory')}: {total}
} > - + url={getRagContentUrl} query={{ end_user_id: id }} column={1} - renderItem={(item: string) => ( -
- + gutter={0} + onTotalChange={setTotal} + renderItem={(item, index) => ( +
+ {index !== 0 && } + +
+
+
+ {item.role === 'assistant' ? t('userMemory.assistant') : t('userMemory.user')} +
+ +
+
)} className="rb:h-full!" From b40f4829cb20d4091f9735debf4c23f6e9363a06 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 31 Mar 2026 15:48:19 +0800 Subject: [PATCH 017/117] feat(web): custom model add thinking capability --- .../components/CustomModelModal.tsx | 53 ++++++++++++------- web/src/views/ModelManagement/types.ts | 7 +-- 2 files changed, 39 insertions(+), 21 deletions(-) diff --git a/web/src/views/ModelManagement/components/CustomModelModal.tsx b/web/src/views/ModelManagement/components/CustomModelModal.tsx index abede886..01cc0fd6 100644 --- a/web/src/views/ModelManagement/components/CustomModelModal.tsx +++ b/web/src/views/ModelManagement/components/CustomModelModal.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:49:28 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-25 14:07:10 + * @Last Modified time: 2026-03-31 13:56:18 */ /** * Custom Model Modal @@ -11,7 +11,7 @@ */ import { forwardRef, useEffect, useImperativeHandle, useState } from 'react'; -import { Form, Input, App, Checkbox, Button } from 'antd'; +import { Form, Input, App, Checkbox, Button, Row, Col } from 'antd'; import { useTranslation } from 'react-i18next'; import type { CustomModelForm, ModelListItem, CustomModelModalRef, CustomModelModalProps } from '../types'; @@ -72,6 +72,7 @@ const CustomModelModal = forwardRef( is_vision: capability?.includes('vision') || false, is_video: capability?.includes('video') || false, is_audio: capability?.includes('audio') || false, + is_thinking: capability?.includes('thinking') || false, }); } else { setIsEdit(false); @@ -101,7 +102,7 @@ const CustomModelModal = forwardRef( form .validateFields() .then((values) => { - const { logo, type, is_vision, is_video, is_audio, is_omni, ...rest } = values; + const { logo, type, is_vision, is_video, is_audio, is_omni, is_thinking, ...rest } = values; const formData: CustomModelForm = { ...rest, type, @@ -120,6 +121,9 @@ const CustomModelModal = forwardRef( capability.push('video') } } + if (is_thinking) { + capability.push('thinking') + } formData.capability = capability formData.is_omni = is_omni @@ -238,21 +242,34 @@ const CustomModelModal = forwardRef( - {!['embedding', 'rerank'].includes(modelType as string) && - <> - - {t('modelNew.is_omni')} - - - {t('modelNew.is_vision')} - - - {t('modelNew.is_video')} - - - {t('modelNew.is_audio')} - - + {['llm', 'chat'].includes(modelType as string) && + + + + {t('modelNew.is_omni')} + + + + + {t('modelNew.is_vision')} + + + + + {t('modelNew.is_video')} + + + + + {t('modelNew.is_audio')} + + + + + {t('modelNew.is_thinking')} + + + } diff --git a/web/src/views/ModelManagement/types.ts b/web/src/views/ModelManagement/types.ts index 1662775f..cafac4b3 100644 --- a/web/src/views/ModelManagement/types.ts +++ b/web/src/views/ModelManagement/types.ts @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:50:18 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-25 12:28:10 + * @Last Modified time: 2026-03-31 15:48:02 */ /** * Type definitions for Model Management @@ -295,7 +295,8 @@ export interface CustomModelForm { is_video?: boolean; is_audio?: boolean; is_omni?: boolean; - capability?: string[]; + is_thinking?: boolean; + capability?: Capability[]; } /** @@ -324,7 +325,7 @@ export interface BaseRef { modelListDetailRefresh?: () => void; } -export type Capability = 'vision' | 'audio' | 'video'; +export type Capability = 'vision' | 'audio' | 'video' | 'thinking'; export interface Model { name: string; type: string; From ca255304d919cf7b38408447140b7f426c71bd39 Mon Sep 17 00:00:00 2001 From: zhaoying Date: Tue, 31 Mar 2026 18:07:32 +0800 Subject: [PATCH 018/117] feat(web): agent support deep thinking --- web/src/components/Chat/ChatContent.tsx | 36 ++++++++- web/src/components/Chat/index.tsx | 4 +- web/src/components/Chat/types.ts | 6 +- web/src/i18n/en.ts | 5 ++ web/src/i18n/zh.ts | 5 ++ web/src/views/ApplicationConfig/Agent.tsx | 7 +- .../ApplicationConfig/TestChat/index.tsx | 25 +++++- .../ApplicationConfig/components/Chat.tsx | 36 ++++++++- .../components/ModelConfigModal.tsx | 30 ++++++-- web/src/views/ApplicationConfig/types.ts | 3 +- web/src/views/Conversation/index.tsx | 76 +++++++++++++++---- web/src/views/Conversation/types.ts | 3 +- 12 files changed, 203 insertions(+), 33 deletions(-) diff --git a/web/src/components/Chat/ChatContent.tsx b/web/src/components/Chat/ChatContent.tsx index ddb25838..2a86ad93 100644 --- a/web/src/components/Chat/ChatContent.tsx +++ b/web/src/components/Chat/ChatContent.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2025-12-10 16:46:17 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-27 14:17:38 + * @Last Modified time: 2026-03-31 15:01:53 */ import { type FC, useRef, useEffect, useState } from 'react' import clsx from 'clsx' @@ -38,6 +38,22 @@ const ChatContent: FC = ({ const isScrolledToBottomRef = useRef(true); const audioRef = useRef(null) const [playingIndex, setPlayingIndex] = useState(null) + const [expandedReasoning, setExpandedReasoning] = useState>(new Set()) + const [manualToggledReasoning, setManualToggledReasoning] = useState>(new Set()) + + const toggleReasoning = (index: number) => { + setManualToggledReasoning(prev => new Set(prev).add(index)) + setExpandedReasoning(prev => { + const next = new Set(prev) + next.has(index) ? next.delete(index) : next.add(index) + return next + }) + } + + const isReasoningExpanded = (index: number) => { + if (manualToggledReasoning.has(index)) return expandedReasoning.has(index) + return !data[index]?.content + } const handlePlay = (index: number, audio_url: string, audio_status?: string) => { if (audio_status !== 'completed' && !audio_status) return @@ -120,7 +136,7 @@ const ChatContent: FC = ({ {labelFormat(item)}
} - {item.meta_data?.files && item.meta_data?.files.length > 0 && + {item.meta_data?.files && item.meta_data?.files.length > 0 && {item.meta_data?.files?.map((file) => { if (file.type.includes('image')) { return ( @@ -174,6 +190,22 @@ const ChatContent: FC = ({ 'rb:mt-1.5': labelPosition === 'top', 'rb:mb-1.5': labelPosition === 'bottom', })}> + {item.meta_data?.reasoning_content &&
+ toggleReasoning(index)} + > + {t('memoryConversation.reasoning_content')} +
+
+ {isReasoningExpanded(index) && } +
} {item.status &&
} {item.subContent && renderRuntime && renderRuntime(item, index)} {/* Render message content using Markdown component */} diff --git a/web/src/components/Chat/index.tsx b/web/src/components/Chat/index.tsx index 49feaf33..f7c0f32e 100644 --- a/web/src/components/Chat/index.tsx +++ b/web/src/components/Chat/index.tsx @@ -27,12 +27,14 @@ const Chat: FC = ({ fileList, fileChange, className, - renderRuntime + renderRuntime, + conversationId }) => { return (
{/* Chat content display area */} void; className?: string; renderRuntime?: (item: ChatItem, index: number) => ReactNode; + conversationId?: string | null; } /** diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 294b9bae..57e95d81 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -1787,6 +1787,11 @@ Memory Bear: After the rebellion, regional warlordism intensified for several re vision_id: 'Vision model', audio_id: 'Audio model', video_id: 'Video model', + onlyDelete: 'Only Delete Fill', + semanticFiltering: 'Semantic Filtering', + sceneFocus: 'Scene Focus', + loose: 'Loose', + strict: 'Strict', }, memoryConversation: { searchPlaceholder: 'Enter user ID...', diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index 38b2a76a..39d63399 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1783,6 +1783,11 @@ export const zh = { vision_id: '视觉模型', audio_id: '音频模型', video_id: '视频模型', + onlyDelete: '仅删填充', + semanticFiltering: '语义过滤', + sceneFocus: '场景聚焦', + loose: '宽松', + strict: '严格', }, memoryConversation: { chatEmpty:'有什么我可以帮您的吗?', diff --git a/web/src/views/ApplicationConfig/Agent.tsx b/web/src/views/ApplicationConfig/Agent.tsx index 0cfdde05..07859527 100644 --- a/web/src/views/ApplicationConfig/Agent.tsx +++ b/web/src/views/ApplicationConfig/Agent.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:29:21 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-27 18:13:51 + * @Last Modified time: 2026-03-31 16:50:10 */ import { useEffect, useRef, useState, forwardRef, useImperativeHandle, useMemo } from 'react'; import { useTranslation } from 'react-i18next' @@ -194,7 +194,7 @@ const Agent = forwardRef { - modelConfigModalRef.current?.handleOpen('model') + modelConfigModalRef.current?.handleOpen('model', { ...defaultModel, model_parameters : values?.model_parameters }) } /** * Clear all debugging chat sessions @@ -287,7 +287,7 @@ const Agent = forwardRef { const opening_statement = form.getFieldValue(['features', 'opening_statement']) - console.log('opening_statement', opening_statement, defaultModel, chatList) if (opening_statement?.enabled && opening_statement?.statement && opening_statement?.statement.trim() !== '') { const assistantMsg: ChatItem = { diff --git a/web/src/views/ApplicationConfig/TestChat/index.tsx b/web/src/views/ApplicationConfig/TestChat/index.tsx index de98b9a7..b3fca33f 100644 --- a/web/src/views/ApplicationConfig/TestChat/index.tsx +++ b/web/src/views/ApplicationConfig/TestChat/index.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-03-13 17:27:52 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-26 15:35:13 + * @Last Modified time: 2026-03-31 16:04:15 */ import { type FC, useState, useRef, useEffect } from 'react' import { useTranslation } from 'react-i18next' @@ -171,6 +171,7 @@ const TestChat: FC = ({ ...lastMsg, content: lastMsg.content + content, meta_data: { + ...(lastMsg.meta_data || {}), audio_url: audio_url || lastMsg.meta_data?.audio_url, audio_status: audio_status || lastMsg.meta_data?.audio_status, citations: citations || lastMsg.meta_data?.citations @@ -180,6 +181,24 @@ const TestChat: FC = ({ return newList }) } + const updateAssistantReasoningMessage = (content: string) => { + if (!content) return + if (streamLoading) setStreamLoading(false) + setChatList(prev => { + const newList = [...prev] + const lastMsg = newList[newList.length - 1] + if (lastMsg?.role === 'assistant') { + newList[newList.length - 1] = { + ...lastMsg, + meta_data: { + ...(lastMsg.meta_data || {}), + reasoning_content: (lastMsg.meta_data?.reasoning_content || '') + content + } + } + } + return newList + }) + } const updateErrorAssistantMessage = (message_length: number) => { if (message_length > 0) return @@ -273,6 +292,10 @@ const TestChat: FC = ({ case 'start': if (conversation_id && conversationId !== conversation_id) setConversationId(conversation_id) break + case 'reasoning': + updateAssistantReasoningMessage(content) + if (conversation_id && conversationId !== conversation_id) setConversationId(conversation_id) + break case 'message': updateAssistantMessage(content) if (conversation_id && conversationId !== conversation_id) setConversationId(conversation_id) diff --git a/web/src/views/ApplicationConfig/components/Chat.tsx b/web/src/views/ApplicationConfig/components/Chat.tsx index 42ae43a9..c2abf17d 100644 --- a/web/src/views/ApplicationConfig/components/Chat.tsx +++ b/web/src/views/ApplicationConfig/components/Chat.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:27:39 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-27 17:59:07 + * @Last Modified time: 2026-03-31 15:02:07 */ /** * Chat debugging component for application testing @@ -141,6 +141,36 @@ const Chat: FC = ({ } } /** Update assistant message with streaming content */ + const updateAssistantReasoningMessage = (content?: string, model_config_id?: string, conversation_id?: string) => { + if (!content || !model_config_id) return + updateChatList(prev => { + const targetIndex = prev.findIndex(item => item.model_config_id === model_config_id); + if (targetIndex !== -1) { + const modelChatList = [...prev] + const curModelChat = modelChatList[targetIndex] + const curChatMsgList = curModelChat.list || [] + const lastMsg = curChatMsgList[curChatMsgList.length - 1] + if (lastMsg && lastMsg.role === 'assistant') { + modelChatList[targetIndex] = { + ...modelChatList[targetIndex], + conversation_id, + list: [ + ...curChatMsgList.slice(0, curChatMsgList.length - 1), + { + ...lastMsg, + meta_data: { + reasoning_content: (lastMsg.meta_data?.reasoning_content || '') + (content || ''), + } + } + ] + } + } + return [...modelChatList] + } + return prev; + }) + } + /** Update assistant message with streaming content */ const updateAssistantMessage = (content?: string, model_config_id?: string, conversation_id?: string, audio_url?: string, citations?: any[]) => { if ((!content && !audio_url && (!citations || citations?.length < 1)) || !model_config_id) return updateChatList(prev => { @@ -160,6 +190,7 @@ const Chat: FC = ({ ...lastMsg, content: lastMsg.content + (content || ''), meta_data: { + ...(lastMsg.meta_data || {}), ...(audio_url !== undefined ? { audio_url, audio_status: 'pending' } : {}), citations: citations || lastMsg.meta_data?.citations } @@ -274,6 +305,9 @@ const Chat: FC = ({ }; switch (item.event) { + case 'model_reasoning': + updateAssistantReasoningMessage(content, model_config_id, conversation_id) + break; case 'model_message': updateAssistantMessage(content, model_config_id, conversation_id, audio_url) break; diff --git a/web/src/views/ApplicationConfig/components/ModelConfigModal.tsx b/web/src/views/ApplicationConfig/components/ModelConfigModal.tsx index 148afd5a..a80a5905 100644 --- a/web/src/views/ApplicationConfig/components/ModelConfigModal.tsx +++ b/web/src/views/ApplicationConfig/components/ModelConfigModal.tsx @@ -2,7 +2,7 @@ * @Author: ZhaoYing * @Date: 2026-02-03 16:28:07 * @Last Modified by: ZhaoYing - * @Last Modified time: 2026-03-25 11:28:02 + * @Last Modified time: 2026-03-31 16:56:57 */ /** * Model Configuration Modal @@ -11,7 +11,7 @@ */ import { forwardRef, useImperativeHandle, useState, useEffect } from 'react'; -import { Form, type SelectProps } from 'antd'; +import { Form, type SelectProps, Checkbox } from 'antd'; import { useTranslation } from 'react-i18next'; import type { ModelConfig, ModelConfigModalRef, Config, Source } from '../types' @@ -70,7 +70,8 @@ const ModelConfigModal = forwardRef( if (source === 'model') { form.setFieldsValue({ ...(data?.model_parameters || {}), - default_model_config_id: data.default_model_config_id || '' + default_model_config_id: data.default_model_config_id || '', + capability: model?.capability || [] }) } else if (source === 'chat' || source === 'multi_agent') { if (model) { @@ -103,9 +104,12 @@ const ModelConfigModal = forwardRef( const handleChange: SelectProps['onChange'] = (_value, option) => { if (source === 'chat') { form.setFieldValue('label', (option as Model).name) - } else { - form.setFieldValue('capability', (option as Model).capability) } + + form.setFieldsValue({ + capability: (option as Model).capability, + deep_thinking: false, + }) } /** Expose methods to parent component */ @@ -115,8 +119,12 @@ const ModelConfigModal = forwardRef( })); useEffect(() => { - form.setFieldsValue({...(data?.model_parameters || {})}) + const { deep_thinking: _, ...rest } = data?.model_parameters || {} + form.setFieldsValue(rest) }, [values?.default_model_config_id]) + + + console.log('handleChange values', values) return ( ( /> } - {source === 'model' &&