Merge branch 'develop' into feature/ui_upgrade_zy

This commit is contained in:
zhaoying
2026-03-23 15:06:53 +08:00
71 changed files with 1766 additions and 763 deletions

.gitignore vendored
View File

@@ -25,6 +25,8 @@ examples/
time.log
celerybeat-schedule.db
search_results.json
redbear-mem-metrics/
pitch-deck/
api/migrations/versions
tmp

View File

@@ -13,6 +13,7 @@ from . import (
document_controller,
emotion_config_controller,
emotion_controller,
end_user_controller,
file_controller,
file_storage_controller,
home_page_controller,
@@ -96,5 +97,6 @@ manager_router.include_router(file_storage_controller.router)
manager_router.include_router(ontology_controller.router)
manager_router.include_router(skill_controller.router)
manager_router.include_router(i18n_controller.router)
manager_router.include_router(end_user_controller.router)
__all__ = ["manager_router"]

View File

@@ -0,0 +1,48 @@
"""End User 管理接口 - 无需认证"""
from app.core.logging_config import get_business_logger
from app.core.response_utils import success
from app.db import get_db
from app.repositories.end_user_repository import EndUserRepository
from app.schemas.memory_api_schema import (
CreateEndUserRequest,
CreateEndUserResponse,
)
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
router = APIRouter(prefix="/end_users", tags=["End Users"])
logger = get_business_logger()
@router.post("")
async def create_end_user(
data: CreateEndUserRequest,
db: Session = Depends(get_db),
):
"""
Create an end user.
Creates a new end user for the given workspace.
If an end user with the same other_id already exists in the workspace,
returns the existing one.
"""
logger.info(f"Create end user request - other_id: {data.other_id}, workspace_id: {data.workspace_id}")
end_user_repo = EndUserRepository(db)
end_user = end_user_repo.get_or_create_end_user(
app_id=None,
workspace_id=data.workspace_id,
other_id=data.other_id,
)
logger.info(f"End user ready: {end_user.id}")
result = {
"id": str(end_user.id),
"other_id": end_user.other_id or "",
"other_name": end_user.other_name or "",
"workspace_id": str(end_user.workspace_id),
}
return success(data=CreateEndUserResponse(**result).model_dump(), msg="End user created successfully")
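For quick verification, a minimal client sketch for the new endpoint (base URL and any router prefix are assumptions; field names follow CreateEndUserRequest above):
# Hypothetical smoke test; adjust the base URL to wherever the manager router is mounted.
import httpx
resp = httpx.post(
    "http://localhost:8000/end_users",
    json={
        "workspace_id": "11111111-1111-1111-1111-111111111111",  # placeholder UUID
        "other_id": "external-user-42",
    },
)
print(resp.json())  # expected msg: "End user created successfully"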

View File

@@ -91,7 +91,7 @@ async def upload_file(
if file_size > settings.MAX_FILE_SIZE:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
status_code=status.HTTP_413_CONTENT_TOO_LARGE,
detail=f"The file size exceeds the {settings.MAX_FILE_SIZE} byte limit"
)
@@ -172,7 +172,6 @@ async def upload_file_with_share_token(
# Get share and release info from share_token
service = ReleaseShareService(db)
share_info = service.get_shared_release_info(share_token=share_data.share_token)
# Get share object to access app_id
share = service.repo.get_by_share_token(share_data.share_token)
@@ -499,6 +498,51 @@ async def get_file_url(
)
@router.get("/files/{file_id}/public-url", response_model=ApiResponse)
async def get_permanent_file_url(
file_id: uuid.UUID,
db: Session = Depends(get_db),
storage_service: FileStorageService = Depends(get_file_storage_service),
):
"""
Get a permanent public URL for a file (no expiration).
- Local storage: returns a permanent API access URL (based on the FILE_LOCAL_SERVER_URL setting)
- Remote storage (OSS/S3): returns the bucket public-read URL (the bucket must be configured for public read access)
"""
file_metadata = db.query(FileMetadata).filter(FileMetadata.id == file_id).first()
if not file_metadata:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="The file does not exist")
if file_metadata.status != "completed":
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
detail=f"File upload not completed, status: {file_metadata.status}")
file_key = file_metadata.file_key
storage = storage_service.storage
try:
if isinstance(storage, LocalStorage):
url = f"{settings.FILE_LOCAL_SERVER_URL}/storage/permanent/{file_id}"
else:
url = await storage.get_permanent_url(file_key)
if not url:
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED,
detail="Permanent URL not supported for current storage backend")
api_logger.info(f"Generated permanent URL: file_id={file_id}")
return success(
data={"url": url, "expires_in": None, "permanent": True, "file_name": file_metadata.file_name},
msg="Permanent file URL generated successfully"
)
except HTTPException:
raise
except Exception as e:
api_logger.error(f"Failed to generate permanent URL: {e}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to generate permanent URL: {str(e)}")
@router.get("/public/{file_id}", response_model=Any)
async def public_download_file(
request: Request,

View File

@@ -195,10 +195,9 @@ async def get_workspace_end_users(
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
# 触发社区聚类补全任务(异步,不阻塞接口响应)
# 对有 ExtractedEntity 但无 Community 节点的存量用户自动补跑全量聚类
try:
from app.tasks import init_community_clustering_for_users
init_community_clustering_for_users.delay(end_user_ids=end_user_ids)
init_community_clustering_for_users.delay(end_user_ids=end_user_ids, workspace_id=str(workspace_id))
api_logger.info(f"已触发社区聚类补全任务,候选用户数: {len(end_user_ids)}")
except Exception as e:
api_logger.warning(f"触发社区聚类补全任务失败(不影响主流程): {str(e)}")

View File

@@ -33,35 +33,47 @@ def get_memory_count(
@router.get("/{end_user_id}/conversations", response_model=ApiResponse)
def get_conversations(
end_user_id: uuid.UUID,
page: int = 1,
pagesize: int = 20,
current_user: User = Depends(get_current_user),
db: Session = Depends(get_db)
):
"""
Retrieve all conversations for the current user in a specific group.
Retrieve conversations for the current user in a specific group with pagination.
Args:
end_user_id (UUID): The group identifier.
page (int): Page number (1-based). Defaults to 1.
pagesize (int): Number of items per page. Defaults to 20.
current_user (User, optional): The authenticated user.
db (Session, optional): SQLAlchemy session.
Returns:
ApiResponse: Contains a list of conversation IDs.
Notes:
- Initializes the ConversationService with the current DB session.
- Returns only conversation IDs for lightweight response.
- Logs can be added to trace requests in production.
ApiResponse: Contains a paginated list of conversations.
"""
page = max(1, page)
page_size = max(1, min(pagesize, 100)) # Limit page size between 1 and 100
conversation_service = ConversationService(db)
conversations = conversation_service.get_user_conversations(
end_user_id
conversations, total = conversation_service.get_user_conversations(
end_user_id,
page=page,
page_size=page_size
)
return success(data=[
{
"id": conversation.id,
"title": conversation.title
} for conversation in conversations
], msg="get conversations success")
return success(data={
"items": [
{
"id": conversation.id,
"title": conversation.title
} for conversation in conversations
],
"total": total,
"page": {
"page": page,
"pagesize": page_size,
"total": total,
"hasnext": (page * page_size) < total
},
}, msg="get conversations success")
@router.get("/{end_user_id}/messages", response_model=ApiResponse)

View File

@@ -6,6 +6,7 @@ from app.core.response_utils import success
from app.db import get_db
from app.schemas.api_key_schema import ApiKeyAuth
from app.schemas.memory_api_schema import (
ListConfigsResponse,
MemoryReadRequest,
MemoryReadResponse,
MemoryWriteRequest,
@@ -31,14 +32,15 @@ async def write_memory_api_service(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
payload: MemoryWriteRequest = Body(..., embed=False),
message: str = Body(..., description="Message content"),
):
"""
Write memory to storage.
Stores memory content for the specified end user using the Memory API Service.
"""
body = await request.json()
payload = MemoryWriteRequest(**body)
logger.info(f"Memory write request - end_user_id: {payload.end_user_id}, workspace_id: {api_key_auth.workspace_id}")
memory_api_service = MemoryAPIService(db)
@@ -62,13 +64,15 @@ async def read_memory_api_service(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
payload: MemoryReadRequest = Body(..., embed=False),
message: str = Body(..., description="Query message"),
):
"""
Read memory from storage.
Queries and retrieves memories for the specified end user with context-aware responses.
"""
body = await request.json()
payload = MemoryReadRequest(**body)
logger.info(f"Memory read request - end_user_id: {payload.end_user_id}")
memory_api_service = MemoryAPIService(db)
@@ -85,3 +89,27 @@ async def read_memory_api_service(
logger.info(f"Memory read successful for end_user: {payload.end_user_id}")
return success(data=MemoryReadResponse(**result).model_dump(), msg="Memory read successfully")
@router.get("/configs")
@require_api_key(scopes=["memory"])
async def list_memory_configs(
request: Request,
api_key_auth: ApiKeyAuth = None,
db: Session = Depends(get_db),
):
"""
List all memory configs for the workspace.
Returns all available memory configurations associated with the authorized workspace.
"""
logger.info(f"List configs request - workspace_id: {api_key_auth.workspace_id}")
memory_api_service = MemoryAPIService(db)
result = memory_api_service.list_memory_configs(
workspace_id=api_key_auth.workspace_id,
)
logger.info(f"Listed {result['total']} configs for workspace: {api_key_auth.workspace_id}")
return success(data=ListConfigsResponse(**result).model_dump(), msg="Configs listed successfully")

View File

@@ -76,6 +76,8 @@ async def get_tool_methods(
if methods is None:
raise HTTPException(status_code=404, detail="工具不存在")
return success(data=methods, msg="获取工具方法成功")
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -121,6 +123,8 @@ async def create_tool(
raise HTTPException(status_code=400, detail=e.message)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -149,6 +153,8 @@ async def update_tool(
return success(msg="工具更新成功")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -191,6 +197,8 @@ async def set_tool_active(
return success(msg=f"工具已{action}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@@ -223,6 +231,8 @@ async def execute_tool(
},
msg="工具执行完成"
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
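The repeated `except HTTPException: raise` guards added above exist because the broad `except Exception` fallback would otherwise convert deliberate 4xx responses into 500s; a minimal sketch of the failure mode being fixed:
# Without the re-raise guard, the 404 below would be caught by the generic
# handler and re-surfaced as a 500.
from fastapi import HTTPException
def handler_sketch():
    try:
        raise HTTPException(status_code=404, detail="not found")
    except HTTPException:
        raise  # preserve the intended status code
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))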

View File

@@ -97,6 +97,7 @@ class Settings:
# File Upload
MAX_FILE_SIZE: int = int(os.getenv("MAX_FILE_SIZE", "52428800"))
MAX_FILE_COUNT: int = int(os.getenv("MAX_FILE_COUNT", "20"))
FILE_PATH: str = os.getenv("FILE_PATH", "/files")
FILE_URL_EXPIRES: int = int(os.getenv("FILE_URL_EXPIRES", "3600"))

View File

@@ -166,15 +166,12 @@ async def write(
statement_entity_edges=all_statement_entity_edges,
entity_edges=all_entity_entity_edges,
connector=neo4j_connector,
config_id=config_id,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
)
if success:
logger.info("Successfully saved all data to Neo4j")
# After a successful write, trigger clustering asynchronously (does not block the write response)
schedule_clustering_after_write(
all_entity_nodes,
config_id=config_id,
llm_model_id=str(memory_config.llm_model_id) if memory_config.llm_model_id else None,
embedding_model_id=str(memory_config.embedding_model_id) if memory_config.embedding_model_id else None,
)

View File

@@ -69,15 +69,15 @@ class LabelPropagationEngine:
def __init__(
self,
connector: Neo4jConnector,
config_id: Optional[str] = None,
llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
):
self.connector = connector
self.repo = CommunityRepository(connector)
self.config_id = config_id
self.llm_model_id = llm_model_id
self.embedding_model_id = embedding_model_id
self.embedding_model_id = embedding_model_id
# ──────────────────────────────────────────────────────────────────────────
# Public interface
@@ -439,15 +439,17 @@ class LabelPropagationEngine:
@staticmethod
def _build_entity_lines(members: List[Dict]) -> List[str]:
"""将实体列表格式化为 prompt 行,包含 name、aliases、description。"""
"""将实体列表格式化为 prompt 行,包含 name、aliases、description、example"""
lines = []
for m in members:
m_name = m.get("name", "")
aliases = m.get("aliases") or []
description = m.get("description") or ""
example = m.get("example") or ""
aliases_str = f"(别名:{''.join(aliases)}" if aliases else ""
desc_str = f"{description}" if description else ""
lines.append(f"- {m_name}{aliases_str}{desc_str}")
example_str = f"(示例:{example}" if example else ""
lines.append(f"- {m_name}{aliases_str}{desc_str}{example_str}")
return lines
async def _generate_community_metadata(
@@ -481,11 +483,24 @@ class LabelPropagationEngine:
core_entities = [m["name"] for m in sorted_members[:CORE_ENTITY_LIMIT] if m.get("name")]
entity_list_str = "\n".join(self._build_entity_lines(members))
# Approach 4: inject relationship triples between entities within the community
relationships = await self.repo.get_community_relationships(cid, end_user_id)
rel_lines = [
f"- {r['subject']}{r['predicate']}{r['object']}"
for r in relationships
if r.get("subject") and r.get("predicate") and r.get("object")
]
rel_section = (
f"\n实体间关系:\n" + "\n".join(rel_lines)
if rel_lines else ""
)
prompt = (
f"以下是一组语义相关的实体:\n{entity_list_str}\n\n"
f"以下是一组语义相关的实体:\n{entity_list_str}{rel_section}\n\n"
f"请为这组实体所代表的主题:\n"
f"1. 起一个简洁的中文名称不超过10个字\n"
f"2. 写一句话摘要(不超过50个字\n\n"
f"2. 写一句话摘要(不超过80个字\n\n"
f"严格按以下格式输出,不要有其他内容:\n"
f"名称:<名称>\n摘要:<摘要>"
)

View File

@@ -121,3 +121,18 @@ class StorageBackend(ABC):
URL for accessing the file.
"""
pass
async def get_permanent_url(self, file_key: str) -> Optional[str]:
"""
Get a permanent public URL for the file (no expiration).
Returns None by default; remote storage backends should override this
if the bucket is configured for public read access.
Args:
file_key: Unique identifier for the file in the storage system.
Returns:
A permanent public URL, or None if not supported.
"""
return None

View File

@@ -261,3 +261,13 @@ class OSSStorage(StorageBackend):
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
# Return a basic URL format as fallback
return f"https://{self.bucket_name}.{self.endpoint.replace('https://', '').replace('http://', '')}/{file_key}"
async def get_permanent_url(self, file_key: str) -> str:
"""
Get a permanent public URL for the file (requires bucket public read).
Returns:
A permanent URL in the format: https://{bucket}.{endpoint}/{file_key}
"""
host = self.endpoint.replace("https://", "").replace("http://", "")
return f"https://{self.bucket_name}.{host}/{file_key}"

View File

@@ -378,3 +378,12 @@ class S3Storage(StorageBackend):
logger.error(f"Failed to generate presigned URL for {file_key}: {e}")
# Return a basic URL format as fallback
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"
async def get_permanent_url(self, file_key: str) -> str:
"""
Get a permanent public URL for the file (requires bucket public read).
Returns:
A permanent URL in the format: https://{bucket}.s3.{region}.amazonaws.com/{file_key}
"""
return f"https://{self.bucket_name}.s3.{self.region}.amazonaws.com/{file_key}"

View File

@@ -20,9 +20,21 @@ from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes import NodeFactory
from app.core.workflow.nodes.enums import NodeType, BRANCH_NODES
from app.core.workflow.utils.expression_evaluator import evaluate_condition
from app.core.workflow.validator import WorkflowValidator
logger = logging.getLogger(__name__)
# Regex to split output into:
# - variable placeholders: {{ ... }}
# - normal literal text
#
# Example:
# "Hello {{user.name}}!" ->
# ["Hello ", "{{user.name}}", "!"]
_OUTPUT_PATTERN = re.compile(r'\{\{.*?}}|[^{}]+')
# Strict variable format: {{ node_id.field_name }}
_VARIABLE_PATTERN = re.compile(r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*}}')
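A quick demonstration of what the two module-level patterns match:
_OUTPUT_PATTERN.findall("Hello {{ user.name }}!")
# -> ['Hello ', '{{ user.name }}', '!']
bool(_VARIABLE_PATTERN.match("{{ user.name }}"))  # True (node_id.field_name form)
bool(_VARIABLE_PATTERN.match("{{ user }}"))       # False (no field part)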
class GraphBuilder:
def __init__(
@@ -37,13 +49,13 @@ class GraphBuilder:
self.stream = stream
self.subgraph = subgraph
self.start_node_id = None
self.end_node_ids = []
self.start_node_id: str | None = None
self.node_map = {node["id"]: node for node in self.nodes}
self.end_node_map: dict[str, StreamOutputConfig] = {}
self._find_upstream_branch_node = lru_cache(
self._find_upstream_activation_dep = lru_cache(
maxsize=len(self.nodes) * 2
)(self._find_upstream_branch_node)
)(self._find_upstream_activation_dep)
if variable_pool:
self.variable_pool = variable_pool
else:
@@ -51,10 +63,19 @@ class GraphBuilder:
self.graph = StateGraph(WorkflowState)
self.add_nodes()
self.reachable_nodes = WorkflowValidator.get_reachable_nodes(self.start_node_id, self.edges)
self.end_nodes = [
node
for node in self.nodes
if node.get("type") == "end" and node.get("id") in self.reachable_nodes
]
self.add_edges()
self._analyze_end_node_output()
# EDGES MUST BE ADDED AFTER NODES ARE ADDED.
self._reverse_adj: dict[str, list[dict]] = defaultdict(list)
self._build_reverse_adj()
self._analyze_end_node_output()
@property
def nodes(self) -> list[dict[str, Any]]:
return self.workflow_config.get("nodes", [])
@@ -87,60 +108,50 @@ class GraphBuilder:
result[node[0]].append(node[1])
return result
def _find_upstream_branch_node(self, target_node: str) -> tuple[bool, tuple[tuple[str, str]]]:
"""
Recursively find all upstream branch (control) nodes that influence the execution
of the given target node.
def _build_reverse_adj(self):
for edge in self.edges:
if edge["source"] not in self.reachable_nodes:
continue
self._reverse_adj[edge.get("target")].append({
"id": edge["source"], "branch": edge.get("label")
})
This method walks upstream along the workflow graph starting from `target_node`.
It distinguishes between:
- branch nodes (node types listed in `BRANCH_NODES`)
- non-branch nodes (ordinary processing nodes)
def _find_upstream_activation_dep(
self,
target_node: str
) -> tuple[tuple[tuple[str, str]], tuple[str]]:
"""Find upstream dependencies that affect the activation of a target node.
Traversal rules:
1. For each immediate upstream node:
- If it is a branch node, it is recorded as an affecting control node.
- If it is a non-branch node, the traversal continues recursively upstream.
2. If ANY upstream path reaches a START / CYCLE_START node without encountering
a branch node, the traversal is considered invalid:
- `has_branch` will be False
- no branch nodes are returned.
3. Only when ALL upstream non-branch paths eventually lead to at least one
branch node will `has_branch` be True.
Walks upstream along the workflow graph from the target node, collecting
two types of dependencies:
- Branch control nodes: upstream branch nodes (e.g. if-else) whose
routing outcome determines whether the target node executes.
- Output nodes: upstream END nodes that must complete their output
before the target node can activate.
Special case:
- If `target_node` has no upstream nodes AND its type is START or CYCLE_START,
it is considered directly reachable from the workflow entry, and therefore
has no controlling branch nodes.
The traversal terminates early and returns empty tuples if any upstream
path reaches START/CYCLE_START without encountering a branch or output
node, indicating the target node is directly reachable and should be
activated immediately.
Args:
target_node (str):
The identifier of the node whose upstream control branches
are to be resolved.
target_node: The ID of the node whose upstream activation
dependencies are to be resolved.
Returns:
tuple[bool, tuple[tuple[str, str]]]:
- has_branch (bool):
True if every upstream path from `target_node` encounters
at least one branch node.
False if any path reaches a start node without a branch.
- branch_nodes (tuple[tuple[str, str]]):
A deduplicated tuple of `(branch_node_id, branch_label)` pairs
representing all branch nodes that can influence `target_node`.
Returns an empty tuple if `has_branch` is False.
A tuple of two elements:
- A deduplicated tuple of (branch_node_id, branch_label) pairs
representing upstream branch control dependencies. Empty if
any clean path to START exists.
- A deduplicated tuple of upstream output node IDs that must
complete before this node activates.
"""
source_nodes = [
{
"id": edge.get("source"),
"branch": edge.get("label")
}
for edge in self.edges
if edge.get("target") == target_node
]
source_nodes = self._reverse_adj[target_node]
if not source_nodes and self.get_node_type(target_node) in [NodeType.START, NodeType.CYCLE_START]:
return False, tuple()
return tuple(), tuple()
branch_nodes = []
output_nodes = []
non_branch_nodes = []
for node_info in source_nodes:
@@ -149,19 +160,23 @@ class GraphBuilder:
(node_info["id"], node_info["branch"])
)
else:
if self.get_node_type(node_info["id"]) == NodeType.END:
output_nodes.append(node_info["id"])
non_branch_nodes.append(node_info["id"])
has_branch = True
for node_id in non_branch_nodes:
node_has_branch, nodes = self._find_upstream_branch_node(node_id)
has_branch = has_branch and node_has_branch
if not has_branch:
break
branch_nodes.extend(nodes)
if not has_branch:
branch_nodes = []
upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(node_id)
if not upstream_control_nodes:
if not upstream_output_nodes and node_id not in output_nodes:
return tuple(), tuple()
branch_nodes = []
has_branch = False
if has_branch:
branch_nodes.extend(upstream_control_nodes)
output_nodes.extend(upstream_output_nodes)
return has_branch, tuple(set(branch_nodes))
return tuple(set(branch_nodes)), tuple(set(output_nodes))
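# A toy illustration of the two dependency sets returned above (node IDs hypothetical):
#   Graph: start -> if_1 -(yes)-> end_a -> llm_2 -> end_b
#   _find_upstream_activation_dep("end_b") walks back through llm_2 and end_a:
#     branch deps: (("if_1", "yes"),)  # if_1's outcome gates the whole chain
#     output deps: ("end_a",)          # end_a must finish emitting first
#   _find_upstream_activation_dep("start") returns ((), ()): directly reachable.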
def _analyze_end_node_output(self):
"""
@@ -182,11 +197,10 @@ class GraphBuilder:
"""
# Collect all End nodes in the workflow
end_nodes = [node for node in self.nodes if node.get("type") == "end"]
logger.info(f"[Prefix Analysis] Found {len(end_nodes)} End nodes")
logger.info(f"[Prefix Analysis] Found {len(self.end_nodes)} End nodes")
# Iterate through each End node to analyze its output
for end_node in end_nodes:
for end_node in self.end_nodes:
end_node_id = end_node.get("id")
config = end_node.get("config", {})
output = config.get("output")
@@ -195,42 +209,33 @@ class GraphBuilder:
if not output:
continue
# Regex to split output into:
# - variable placeholders: {{ ... }}
# - normal literal text
#
# Example:
# "Hello {{user.name}}!" ->
# ["Hello ", "{{user.name}}", "!"]
pattern = r'\{\{.*?\}\}|[^{}]+'
# Strict variable format: {{ node_id.field_name }}
variable_pattern_string = r'\{\{\s*[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\s*\}\}'
variable_pattern = re.compile(variable_pattern_string)
# Split output into ordered segments
output_template = list(re.findall(pattern, output))
output_template = list(_OUTPUT_PATTERN.findall(output))
# Determine whether each segment is literal text
# True -> literal (can be directly output)
# False -> variable placeholder (needs runtime value)
output_flag = [
not bool(variable_pattern.match(item))
not bool(_VARIABLE_PATTERN.match(item))
for item in output_template
]
# Stream mode: output activation depends on upstream branch nodes
if self.stream:
# Find upstream branch nodes that can control this End node
has_branch, control_nodes = self._find_upstream_branch_node(end_node_id)
upstream_control_nodes, upstream_output_nodes = self._find_upstream_activation_dep(end_node_id)
activate = not bool(upstream_control_nodes) and not bool(upstream_output_nodes)
# Build StreamOutputConfig for this End node
self.end_node_map[end_node_id] = StreamOutputConfig(
id=end_node_id,
# If there is no upstream branch, output is active immediately
activate=not has_branch,
activate=activate,
# Branch nodes that control activation of this End node
control_nodes=self._merge_control_nodes(control_nodes),
control_nodes=self._merge_control_nodes(upstream_control_nodes),
upstream_output_nodes=list(upstream_output_nodes),
control_resolved=not bool(upstream_control_nodes),
output_resolved=not bool(upstream_output_nodes),
# Convert output segments into OutputContent objects
outputs=list(
@@ -249,14 +254,16 @@ class GraphBuilder:
cursor=0
)
logger.info(f"[Stream Analysis] end_id: {end_node_id}, "
f"activate: {not has_branch}, "
f"control_nodes: {control_nodes},"
f"activate: {activate}, "
f"control_nodes: {upstream_control_nodes},"
f"ref_outputs: {upstream_output_nodes},"
f"output: {output_template},"
f"output_activate: {output_flag}")
# Non-stream mode: all outputs are activated by default
else:
self.end_node_map[end_node_id] = StreamOutputConfig(
id=end_node_id,
activate=True,
control_nodes={},
outputs=list(
@@ -269,7 +276,10 @@ class GraphBuilder:
for output_string, activate in zip(output_template, output_flag)
]
),
cursor=0
cursor=0,
upstream_output_nodes=[],
control_resolved=True,
output_resolved=True,
)
def add_nodes(self):
@@ -304,8 +314,6 @@ class GraphBuilder:
# Record start and end node IDs
if node_type in [NodeType.START, NodeType.CYCLE_START]:
self.start_node_id = node_id
elif node_type == NodeType.END:
self.end_node_ids.append(node_id)
# Create node instance (start and end nodes are also created)
# NOTE: Loop node creation automatically removes the nodes and edges of the subgraph from the current graph
@@ -448,7 +456,7 @@ class GraphBuilder:
branch_activate = []
new_state = state.copy()
new_state["activate"] = dict(state.get("activate", {})) # deep copy of activate
node_output = variable_pool.get_node_output(src, defalut=dict(), strict=False)
node_output = variable_pool.get_node_output(src, default=dict(), strict=False)
for label, branch in unique_branch.items():
if node_output and evaluate_condition(
branch["condition"],
@@ -494,9 +502,11 @@ class GraphBuilder:
logger.debug(f"Added waiting edge: {sources} -> {target}")
# Connect End nodes to the global END node
for end_node_id in self.end_node_ids:
self.graph.add_edge(end_node_id, END)
logger.debug(f"Added edge: {end_node_id} -> END")
for end_node in self.end_nodes:
end_node_id = end_node.get("id")
if end_node_id:
self.graph.add_edge(end_node_id, END)
logger.debug(f"Added edge: {end_node_id} -> END")
return
def build(self) -> CompiledStateGraph:

View File

@@ -12,6 +12,7 @@ class WorkflowResultBuilder:
variable_pool: VariablePool,
elapsed_time: float,
final_output: str,
success: bool
):
"""Construct the final standardized output of the workflow execution.
@@ -29,6 +30,7 @@ class WorkflowResultBuilder:
elapsed_time (float): Total execution time in seconds.
final_output (Any): The aggregated or final output content of the workflow
(e.g., combined messages from all End nodes).
success (bool): Whether the execution was successful.
Returns:
dict: A dictionary containing the final workflow execution result with keys:
@@ -49,7 +51,7 @@ class WorkflowResultBuilder:
conversation_id = variable_pool.get_value("sys.conversation_id")
return {
"status": "completed",
"status": "completed" if success else "failed",
"output": final_output,
"variables": {
"conv": variable_pool.get_all_conversation_vars(),

View File

@@ -3,6 +3,7 @@
# @Email: 1533512157@qq.com
# @Time : 2026/2/9 15:11
import re
from queue import Queue
from typing import AsyncGenerator
from pydantic import BaseModel, Field, PrivateAttr
@@ -37,8 +38,8 @@ class OutputContent(BaseModel):
activate: bool = Field(
...,
description=(
"Whether this output segment is currently active.\n"
"- True: allowed to be emitted/output\n"
"Whether this output segment is currently active."
"- True: allowed to be emitted/output"
"- False: blocked until activated by branch control"
)
)
@@ -46,8 +47,8 @@ class OutputContent(BaseModel):
is_variable: bool = Field(
...,
description=(
"Whether this segment represents a variable placeholder.\n"
"True -> variable (e.g. {{ node.field }})\n"
"Whether this segment represents a variable placeholder."
"True -> variable (e.g. {{ node.field }})"
"False -> literal text"
)
)
@@ -86,12 +87,16 @@ class StreamOutputConfig(BaseModel):
- which upstream branch/control nodes gate the activation
- how each parsed output segment is streamed and activated
"""
id: str = Field(
...,
description="ID of the End node this configuration belongs to."
)
activate: bool = Field(
...,
description=(
"Global activation flag for the End node output.\n"
"When False, output segments should not be emitted even if available.\n"
"Global activation flag for the End node output."
"When False, output segments should not be emitted even if available."
"This flag typically becomes True once required control branch conditions "
"are satisfied."
)
@@ -100,17 +105,46 @@ class StreamOutputConfig(BaseModel):
control_nodes: dict[str, list[str]] = Field(
...,
description=(
"Control branch conditions for this End node output.\n"
"Mapping of `branch_node_id -> expected_branch_label`.\n"
"Control branch conditions for this End node output."
"Mapping of `branch_node_id -> expected_branch_label`."
"The End node output becomes globally active when a controlling branch node "
"reports a matching completion status."
)
)
upstream_output_nodes: list[str] = Field(
...,
description=(
"Upstream output node dependencies (data flow)."
"Represents END/output nodes that this output depends on."
"These nodes provide data sources required before this output can be activated "
"or streamed."
"Used to ensure correct ordering and dependency resolution in streaming mode."
)
)
control_resolved: bool = Field(
...,
description=(
"Whether all upstream branch control dependencies have been satisfied."
"True if no upstream branch nodes exist or the required branch "
"conditions have been met."
)
)
output_resolved: bool = Field(
...,
description=(
"Whether all upstream output node dependencies have been completed."
"True if no upstream output nodes exist or all upstream output "
"nodes have finished their output."
)
)
outputs: list[OutputContent] = Field(
...,
description=(
"Ordered list of output segments parsed from the output template.\n"
"Ordered list of output segments parsed from the output template."
"Each segment represents either a literal text block or a variable placeholder "
"that may be activated independently."
)
@@ -119,49 +153,97 @@ class StreamOutputConfig(BaseModel):
cursor: int = Field(
...,
description=(
"Streaming cursor index.\n"
"Indicates the next output segment index to be emitted.\n"
"Streaming cursor index."
"Indicates the next output segment index to be emitted."
"Segments with index < cursor are considered already streamed."
)
)
force: bool = Field(
default=False,
description=(
"Force flag for output emission."
"When True, all output segments are emitted regardless of activation state."
"Triggered when this output node has finished execution."
)
)
def update_activate(self, scope: str, status=None):
"""
Update streaming activation state based on an upstream node or special variable.
Update streaming activation state based on upstream events.
Args:
scope (str):
Identifier of the completed upstream entity.
- If a control branch node, it should match a key in `control_nodes`.
- If a variable placeholder (e.g., "sys.xxx"), it may appear in output segments.
- If an upstream output node, it should match an entry in `upstream_output_nodes`.
- If a variable placeholder (e.g., "sys.xxx" or "node_id.field"),
it may appear in output segments.
status (optional):
Completion status of the control branch node.
Required when `scope` refers to a control node.
Behavior:
1. Control branch nodes:
- If `scope` matches a key in `control_nodes` and `status` matches the expected
branch label, the End node output becomes globally active (`activate = True`).
1. Force activation:
- If `self.force` is True, the method returns immediately.
- If `scope == self.id`, the node marks itself as completed:
- `activate = True`
- `force = True`
This is typically used for final flushing when the node finishes execution.
2. Variable output segments:
- For each segment that is a variable (`is_variable=True`):
- If the segment literal references `scope`, mark the segment as active.
- This applies both to regular node variables (e.g., "node_id.field")
and special system variables (e.g., "sys.xxx").
2. Control dependency resolution:
- If `scope` matches a key in `control_nodes`:
- `status` must be provided.
- If `status` matches expected branch labels, mark control as resolved
(`control_resolved = True`).
3. Upstream output dependency resolution:
- If `scope` is in `upstream_output_nodes`,
mark data dependency as resolved (`output_resolved = True`).
4. Global activation condition:
- The node becomes active when BOTH conditions are satisfied:
- control_resolved == True
- output_resolved == True
- Once activated, `activate` remains True.
5. Variable segment activation:
- For each output segment that is a variable (`is_variable=True`):
- If the segment depends on the given `scope`,
mark the segment as active.
- This applies to both node variables (e.g., "node_id.field")
and system variables (e.g., "sys.xxx").
Notes:
- This method does not emit output or advance the streaming cursor.
- It only updates activation flags based on upstream events or special variables.
- This method does NOT emit output or advance the streaming cursor.
- It only updates activation and dependency resolution states.
- Activation is driven by both control flow (branch nodes) and
data flow (upstream output nodes).
"""
if self.force:
return
# Case 1: resolve control branch dependency
if scope == self.id:
self.activate = True
self.force = True
return
# resolve control branch dependency
if scope in self.control_nodes:
if status is None:
raise RuntimeError("[Stream Output] Control node activation status not provided")
if status in self.control_nodes[scope]:
self.activate = True
self.control_resolved = True
# Case 2: activate variable segments related to this node
if scope in self.upstream_output_nodes:
self.upstream_output_nodes.remove(scope)
if not self.upstream_output_nodes:
self.output_resolved = True
self.activate = self.activate or (self.control_resolved and self.output_resolved)
# activate variable segments related to this node
for i in range(len(self.outputs)):
if (
self.outputs[i].is_variable
@@ -174,12 +256,17 @@ class StreamOutputCoordinator:
def __init__(self):
self.end_outputs: dict[str, StreamOutputConfig] = {}
self.activate_end: str | None = None
self.output_queue: Queue = Queue()
self.processed_outputs = []
def initialize_end_outputs(
self,
end_node_map: dict[str, StreamOutputConfig]
):
self.end_outputs = end_node_map
self.processed_outputs = []
self.activate_end = None
self.output_queue = Queue()
@property
def current_activate_end_info(self):
@@ -211,8 +298,11 @@ class StreamOutputCoordinator:
"""
for node in self.end_outputs.keys():
self.end_outputs[node].update_activate(scope, status)
if self.end_outputs[node].activate and self.activate_end is None:
self.activate_end = node
if self.end_outputs[node].activate and node not in self.processed_outputs:
self.output_queue.put(node)
self.processed_outputs.append(node)
if self.activate_end is None and not self.output_queue.empty():
self.activate_end = self.output_queue.get_nowait()
async def emit_activate_chunk(
self,
@@ -256,7 +346,7 @@ class StreamOutputCoordinator:
final_chunk = ''
current_segment = end_info.outputs[end_info.cursor]
if not current_segment.activate and not force:
if not current_segment.activate and not force and not end_info.force:
# Stop processing until this segment becomes active
break
@@ -273,7 +363,7 @@ class StreamOutputCoordinator:
logger.warning(f"[STREAM] Failed to evaluate segment: {current_segment.literal}, error: {e}")
if final_chunk:
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk:{final_chunk}")
logger.info(f"[STREAM] StreamOutput Node:{self.activate_end}, chunk_length:{len(final_chunk)}")
yield {
"event": "message",
"data": {
@@ -285,8 +375,7 @@ class StreamOutputCoordinator:
end_info.cursor += 1
if end_info.cursor >= len(end_info.outputs):
self.end_outputs.pop(self.activate_end)
self.activate_end = None
self.pop_current_activate_end()
async def flush_remaining_chunk(
self,
@@ -325,6 +414,8 @@ class StreamOutputCoordinator:
async for msg_event in self.emit_activate_chunk(variable_pool, force=True):
yield msg_event
if not self.output_queue.empty():
self.activate_end = self.output_queue.get_nowait()
# Move to next active End node if current one is done
if not self.activate_end and self.end_outputs:
self.activate_end = list(self.end_outputs.keys())[0]

View File

@@ -351,12 +351,12 @@ class VariablePool:
}
return runtime_vars
def get_node_output(self, node_id: str, defalut: Any = None, strict: bool = True) -> dict[str, Any] | None:
def get_node_output(self, node_id: str, default: Any = None, strict: bool = True) -> dict[str, Any] | None:
"""获取指定节点的输出(运行时变量)
Args:
node_id: 节点 ID
defalut: 默认值
default: 默认值
strict: 是否严格模式
Returns:
@@ -368,7 +368,7 @@ class VariablePool:
if strict:
raise KeyError(f"node {node_id} output not exist")
else:
return defalut
return default
def copy(self, pool: 'VariablePool'):
self.variables = deepcopy(pool.variables)

View File

@@ -128,89 +128,100 @@ class WorkflowExecutor:
- token_usage: aggregated token usage if available
- error: error message if any
"""
logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
start_time = datetime.datetime.now()
# Execute the workflow
try:
# Build the workflow graph
graph = self.build_graph()
# Initialize the variable pool with input data
await self.variable_initializer.initialize(
variable_pool=self.variable_pool,
input_data=input_data,
execution_context=self.execution_context
)
initial_state = self.state_manager.create_initial_state(
workflow_config=self.workflow_config,
input_data=input_data,
execution_context=self.execution_context,
start_node_id=self.start_node_id
)
result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
# Aggregate output from all End nodes
full_content = ''
for end_id in self.stream_coordinator.end_outputs.keys():
full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
# Append messages for user and assistant
if input_data.get("files"):
result["messages"].extend(
[
{
"role": "user",
"content": input_data.get("message", '')
},
{
"role": "user",
"content": input_data.get("files")
},
{
"role": "assistant",
"content": full_content
}
]
)
else:
result["messages"].extend(
[
{
"role": "user",
"content": input_data.get("message", '')
},
{
"role": "assistant",
"content": full_content
}
]
)
# Calculate elapsed time
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
logger.info(
f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
except Exception as e:
end_time = datetime.datetime.now()
elapsed_time = (end_time - start_time).total_seconds()
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
exc_info=True)
return {
"status": "failed",
"error": str(e),
"output": None,
"node_outputs": {},
"elapsed_time": elapsed_time,
"token_usage": None
}
start = datetime.datetime.now()
async for event in self.execute_stream(input_data):
if event.get("event") == "workflow_end":
return event.get("data")
return self.result_builder.build_final_output(
{"error": "Workflow execution did not end as expected"},
self.variable_pool,
(datetime.datetime.now() - start).total_seconds(),
"",
success=False
)
# logger.info(f"Starting workflow execution: execution_id={self.execution_context.execution_id}")
#
# start_time = datetime.datetime.now()
#
# # Execute the workflow
# try:
# # Build the workflow graph
# graph = self.build_graph()
#
# # Initialize the variable pool with input data
# await self.variable_initializer.initialize(
# variable_pool=self.variable_pool,
# input_data=input_data,
# execution_context=self.execution_context
# )
# initial_state = self.state_manager.create_initial_state(
# workflow_config=self.workflow_config,
# input_data=input_data,
# execution_context=self.execution_context,
# start_node_id=self.start_node_id
# )
#
# result = await graph.ainvoke(initial_state, config=self.execution_context.checkpoint_config)
#
# # Aggregate output from all End nodes
# full_content = ''
# for end_id in self.stream_coordinator.end_outputs.keys():
# full_content += self.variable_pool.get_value(f"{end_id}.output", default="", strict=False)
#
# # Append messages for user and assistant
# if input_data.get("files"):
# result["messages"].extend(
# [
# {
# "role": "user",
# "content": input_data.get("message", '')
# },
# {
# "role": "user",
# "content": input_data.get("files")
# },
# {
# "role": "assistant",
# "content": full_content
# }
# ]
# )
# else:
# result["messages"].extend(
# [
# {
# "role": "user",
# "content": input_data.get("message", '')
# },
# {
# "role": "assistant",
# "content": full_content
# }
# ]
# )
# # Calculate elapsed time
# end_time = datetime.datetime.now()
# elapsed_time = (end_time - start_time).total_seconds()
#
# logger.info(
# f"Workflow execution completed: execution_id={self.execution_context.execution_id}, elapsed_time={elapsed_time:.2f}ms")
#
# return self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
#
# except Exception as e:
# end_time = datetime.datetime.now()
# elapsed_time = (end_time - start_time).total_seconds()
#
# logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
# exc_info=True)
# return {
# "status": "failed",
# "error": str(e),
# "output": None,
# "node_outputs": {},
# "elapsed_time": elapsed_time,
# "token_usage": None
# }
async def execute_stream(
self,
@@ -248,7 +259,8 @@ class WorkflowExecutor:
"timestamp": int(start_time.timestamp() * 1000)
}
}
result = None
full_content = ''
try:
# Build the workflow graph in streaming mode
graph = self.build_graph(stream=True)
@@ -266,7 +278,6 @@ class WorkflowExecutor:
start_node_id=self.start_node_id
)
full_content = ''
self.stream_coordinator.update_scope_activation("sys")
# Execute the workflow with streaming
@@ -363,7 +374,12 @@ class WorkflowExecutor:
yield {
"event": "workflow_end",
"data": self.result_builder.build_final_output(result, self.variable_pool, elapsed_time, full_content)
"data": self.result_builder.build_final_output(
result,
self.variable_pool,
elapsed_time,
full_content,
success=True)
}
except Exception as e:
@@ -372,16 +388,19 @@ class WorkflowExecutor:
logger.error(f"Workflow execution failed: execution_id={self.execution_context.execution_id}, error={e}",
exc_info=True)
if result is None:
result = {"error": str(e)}
else:
result["error"] = str(e)
yield {
"event": "workflow_end",
"data": {
"execution_id": self.execution_context.execution_id,
"status": "failed",
"error": str(e),
"elapsed_time": elapsed_time,
"timestamp": end_time.isoformat()
}
"data": self.result_builder.build_final_output(
result,
self.variable_pool,
elapsed_time,
full_content,
success=False
)
}
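With the old body commented out, execute() now simply drains execute_stream() and returns the workflow_end payload; a hypothetical driver makes the delegation explicit:
# Sketch of consuming the streaming API directly (executor construction omitted).
async def run(executor, input_data: dict) -> dict:
    async for event in executor.execute_stream(input_data):
        if event.get("event") == "workflow_end":
            return event.get("data")  # same dict execute() returns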

View File

@@ -128,7 +128,7 @@ class CodeNode(BaseNode):
else:
raise ValueError(f"Unsupported language: {self.typed_config.language}")
async with httpx.AsyncClient() as client:
async with httpx.AsyncClient(timeout=60) as client:
response = await client.post(
"http://sandbox:8194/v1/sandbox/run",
headers={

View File

@@ -51,7 +51,7 @@ class ConditionDetail(BaseModel):
)
right: Any = Field(
...,
default=None,
description="Right-hand operand of the comparison expression"
)

View File

@@ -158,7 +158,7 @@ class LoopRuntime:
self.variable_pool.variables["conv"].update(
self.child_variable_pool.variables["conv"]
)
loop_vars = self.child_variable_pool.get_node_output(self.node_id, defalut={}, strict=False)
loop_vars = self.child_variable_pool.get_node_output(self.node_id, default={}, strict=False)
loopstate["node_outputs"][self.node_id] = loop_vars
def evaluate_conditional(self) -> bool:
@@ -261,4 +261,4 @@ class LoopRuntime:
idx += 1
logger.info(f"loop node {self.node_id}: execution completed")
return self.child_variable_pool.get_node_output(self.node_id) | {"__child_state": child_state}
return self.child_variable_pool.get_node_output(self.node_id, default={}, strict=False) | {"__child_state": child_state}

View File

@@ -18,7 +18,7 @@ class ConditionDetail(BaseModel):
)
right: Any = Field(
...,
default=None,
description="Value to compare with"
)

View File

@@ -31,13 +31,13 @@ class IfElseNode(BaseNode):
expressions.append({
"left": self.get_variable(expression.left, variable_pool, strict=False),
"right": expression.right
if expression.input_type == ValueInputType.CONSTANT
if expression.input_type == ValueInputType.CONSTANT or expression.right is None
else self.get_variable(expression.right, variable_pool, strict=False),
"operator": expression.operator,
"operator": str(expression.operator),
})
result.append({
"expressions": expressions,
"logical_operator": case.logical_operator,
"logical_operator": str(case.logical_operator),
})
return {
"cases": result

View File

@@ -5,7 +5,7 @@ from typing import Any
from app.core.error_codes import BizCode
from app.core.exceptions import BusinessException
from app.core.models import RedBearRerank, RedBearModelConfig
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory
from app.core.rag.vdb.elasticsearch.elasticsearch_vector import ElasticSearchVectorFactory, ElasticSearchVector
from app.core.workflow.engine.state_manager import WorkflowState
from app.core.workflow.engine.variable_pool import VariablePool
from app.core.workflow.nodes.base_node import BaseNode
@@ -24,6 +24,7 @@ class KnowledgeRetrievalNode(BaseNode):
def __init__(self, node_config: dict[str, Any], workflow_config: dict[str, Any]):
super().__init__(node_config, workflow_config)
self.typed_config: KnowledgeRetrievalNodeConfig | None = None
self.vector_service: ElasticSearchVector | None = None
def _output_types(self) -> dict[str, VariableType]:
return {
@@ -163,6 +164,50 @@ class KnowledgeRetrievalNode(BaseNode):
)
return reranker
def knowledge_retrieval(self, db, query, rs, db_knowledge, kb_config):
if db_knowledge.type == knowledge_model.KnowledgeType.FOLDER:
children = knowledge_repository.get_knowledges_by_parent_id(db=db, parent_id=db_knowledge.id)
for child in children:
if not (child and child.chunk_num > 0 and child.status == 1):
continue
kb_config.kb_id = child.id
self.knowledge_retrieval(db, query, rs, child, kb_config)
return
self.vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
match kb_config.retrieve_type:
case RetrieveType.PARTICIPLE:
rs.extend(self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.similarity_threshold))
case RetrieveType.SEMANTIC:
rs.extend(self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.vector_similarity_weight))
case RetrieveType.HYBRID:
rs1 = self.vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.vector_similarity_weight)
rs2 = self.vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.similarity_threshold)
# Deduplicate hybrid retrieval results
unique_rs = self._deduplicate_docs(rs1, rs2)
if not unique_rs:
return
if self.typed_config.reranker_id:
self.vector_service.reranker = self.get_reranker_model()
rs.extend(self.vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
else:
rs.extend(sorted(
unique_rs,
key=lambda d: d.metadata.get("score", 0),
reverse=True
)[:kb_config.top_k])
case _:
raise RuntimeError("Unknown retrieval type")
async def execute(self, state: WorkflowState, variable_pool: VariablePool) -> Any:
"""
Execute the knowledge retrieval workflow node.
@@ -191,56 +236,19 @@ class KnowledgeRetrievalNode(BaseNode):
query = self._render_template(self.typed_config.query, variable_pool)
with get_db_read() as db:
knowledge_bases = self.typed_config.knowledge_bases
existing_ids = self._get_existing_kb_ids(db, [kb.kb_id for kb in knowledge_bases])
if not existing_ids:
raise RuntimeError("Knowledge base retrieval failed: the knowledge base does not exist.")
rs = []
for kb_config in knowledge_bases:
db_knowledge = knowledge_repository.get_knowledge_by_id(db=db, knowledge_id=kb_config.kb_id)
if not db_knowledge:
raise RuntimeError("The knowledge base does not exist or access is denied.")
self.knowledge_retrieval(db, query, rs, db_knowledge, kb_config)
vector_service = ElasticSearchVectorFactory().init_vector(knowledge=db_knowledge)
indices = f"Vector_index_{kb_config.kb_id}_Node".lower()
match kb_config.retrieve_type:
case RetrieveType.PARTICIPLE:
rs.extend(vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.similarity_threshold))
case RetrieveType.SEMANTIC:
rs.extend(vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.vector_similarity_weight))
case RetrieveType.HYBRID:
rs1 = vector_service.search_by_vector(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.vector_similarity_weight)
rs2 = vector_service.search_by_full_text(query=query, top_k=kb_config.top_k,
indices=indices,
score_threshold=kb_config.similarity_threshold)
# Deduplicate hybrid retrieval results
unique_rs = self._deduplicate_docs(rs1, rs2)
if not unique_rs:
continue
if self.typed_config.reranker_id:
vector_service.reranker = self.get_reranker_model()
rs.extend(vector_service.rerank(query=query, docs=unique_rs, top_k=kb_config.top_k))
else:
rs.extend(sorted(
unique_rs,
key=lambda d: d.metadata.get("score", 0),
reverse=True
)[:kb_config.top_k])
case _:
raise RuntimeError("Unknown retrieval type")
if not rs:
return []
if self.typed_config.reranker_id:
vector_service.reranker = self.get_reranker_model()
final_rs = vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
self.vector_service.reranker = self.get_reranker_model()
final_rs = self.vector_service.rerank(query=query, docs=rs, top_k=self.typed_config.reranker_top_k)
else:
final_rs = sorted(
rs,

View File

@@ -250,6 +250,8 @@ class ConditionBase(ABC):
self.type_limit = getattr(self, "type_limit", None)
def resolve_right_literal_value(self):
if self.right_selector is None:
return None
if self.input_type == ValueInputType.VARIABLE:
pattern = r"\{\{\s*(.*?)\s*\}\}"
right_expression = re.sub(pattern, r"\1", self.right_selector).strip()

View File

@@ -170,7 +170,7 @@ class WorkflowValidator:
# Only validate that all nodes are reachable at publish time
# 6. Validate that all nodes are reachable (starting from the start node)
if start_nodes and not errors:  # only check reachability if the previous validations passed
reachable = WorkflowValidator._get_reachable_nodes(
reachable = WorkflowValidator.get_reachable_nodes(
start_nodes[0]["id"],
edges
)
@@ -194,7 +194,7 @@ class WorkflowValidator:
return len(errors) == 0, errors
@staticmethod
def _get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
def get_reachable_nodes(start_id: str, edges: list[dict]) -> set[str]:
"""获取从 start 节点可达的所有节点
Args:

View File

@@ -2,7 +2,7 @@ from enum import StrEnum
from abc import abstractmethod, ABC
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr
from app.schemas import FileType
@@ -41,10 +41,10 @@ class VariableType(StrEnum):
"""
if isinstance(var, str):
return cls.STRING
elif isinstance(var, (int, float)):
return cls.NUMBER
elif isinstance(var, bool):
return cls.BOOLEAN
elif isinstance(var, (int, float)):
return cls.NUMBER
elif isinstance(var, FileObject) or (isinstance(var, dict) and var.get('is_file')):
return cls.FILE
elif isinstance(var, dict):
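The reorder matters because bool is a subclass of int in Python, so the NUMBER check previously claimed booleans; a one-line demonstration:
isinstance(True, int)  # True -> the old order classified True as NUMBER
# With the BOOLEAN check first, True now infers as cls.BOOLEAN.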
@@ -116,7 +116,7 @@ class FileObject(BaseModel):
content_cache: dict = Field(default_factory=dict)
is_file: bool
_byte_content: bytes | None = None
_byte_content: bytes | None = PrivateAttr(default=None)
def get_content(self):
return self._byte_content

View File

@@ -10,6 +10,7 @@ T = TypeVar("T", bound=BaseVariable)
class StringVariable(BaseVariable):
value: str
type = 'str'
def valid_value(self, value) -> str:
@@ -22,6 +23,7 @@ class StringVariable(BaseVariable):
class NumberVariable(BaseVariable):
value: int | float
type = 'number'
def valid_value(self, value) -> int | float:
@@ -34,6 +36,7 @@ class NumberVariable(BaseVariable):
class BooleanVariable(BaseVariable):
value: bool
type = 'boolean'
def valid_value(self, value) -> bool:
@@ -46,6 +49,7 @@ class BooleanVariable(BaseVariable):
class DictVariable(BaseVariable):
value: dict
type = 'object'
def valid_value(self, value) -> dict:
@@ -58,6 +62,7 @@ class DictVariable(BaseVariable):
class FileVariable(BaseVariable):
value: FileObject
type = 'file'
def valid_value(self, value) -> FileObject:
@@ -102,6 +107,7 @@ class FileVariable(BaseVariable):
class ArrayVariable(BaseVariable, Generic[T]):
value: list[T]
type = 'array'
def __init__(self, child_type: Type[T], value: list[Any]):
@@ -129,6 +135,7 @@ class ArrayVariable(BaseVariable, Generic[T]):
class NestedArrayVariable(BaseVariable):
value: list[ArrayVariable]
type = 'array_nest'
def valid_value(self, value: list[T]) -> list[T]:
@@ -153,6 +160,7 @@ class NestedArrayVariable(BaseVariable):
category=RuntimeWarning
)
class AnyVariable(BaseVariable):
value: Any
type = 'any'
def valid_value(self, value: Any) -> Any:

View File

@@ -65,6 +65,7 @@ def get_db_read() -> Generator[Session, None, None]:
yield db
finally:
db.rollback()  # read-only tasks do not need to commit
db.close()
def get_pool_status():

View File

@@ -506,10 +506,13 @@ async def http_exception_handler(request: Request, exc: HTTPException):
404: "errors.common.not_found",
405: "errors.common.method_not_allowed",
409: "errors.common.conflict",
413: "errors.common.payload_too_large",
422: "errors.common.validation_failed",
429: "errors.common.too_many_requests",
500: "errors.common.internal_error",
502: "errors.common.bad_gateway",
503: "errors.common.service_unavailable",
504: "errors.common.gateway_timeout",
}
# If a matching translation key exists, use it
@@ -534,7 +537,7 @@ async def http_exception_handler(request: Request, exc: HTTPException):
return JSONResponse(
status_code=exc.status_code,
content=fail(code=exc.status_code, msg=translated_message, error=translated_message)
content=fail(code=exc.status_code, msg=translated_message, error=exc.detail)
)

View File

@@ -90,27 +90,27 @@ class ConversationRepository:
self,
user_id: uuid.UUID,
workspace_id: uuid.UUID = None,
limit: int = 10,
is_activate: bool = True
) -> list[Conversation]:
is_activate: bool = True,
page: int = 1,
page_size: int = 20
) -> tuple[list[Conversation], int]:
"""
Retrieve recent conversations for a specific user.
Retrieve recent conversations for a specific user with pagination.
This method queries conversations associated with the given user ID,
optionally scoped to a specific workspace. Results are ordered by the
most recently updated conversations and limited to a fixed number.
most recently updated conversations.
Args:
user_id (uuid.UUID): Unique identifier of the user.
workspace_id (uuid.UUID, optional): Workspace scope for the query.
If provided, only conversations under this workspace will be returned.
limit (int): Maximum number of conversations to return.
Defaults to 10.
is_activate (bool): Convsersation State limit
is_activate (bool): Conversation State limit.
page (int): Page number (1-based). Defaults to 1.
page_size (int): Number of items per page. Defaults to 20.
Returns:
list[Conversation]: A list of conversation entities ordered by
last updated time (descending).
tuple[list[Conversation], int]: A list of conversation entities and total count.
"""
logger.info(f"Fetching conversation by user_id: {user_id}")
@@ -122,18 +122,25 @@ class ConversationRepository:
if workspace_id:
stmt = stmt.where(Conversation.workspace_id == workspace_id)
stmt = stmt.order_by(desc(Conversation.updated_at))
stmt = stmt.limit(limit)
# Calculate total count
total = int(self.db.execute(
select(func.count()).select_from(stmt.subquery())
).scalar_one())
convsersations = list(self.db.scalars(stmt).all())
# Apply ordering and pagination
stmt = stmt.order_by(desc(Conversation.updated_at))
stmt = stmt.offset((page - 1) * page_size).limit(page_size)
conversations = list(self.db.scalars(stmt).all())
logger.info(
"Conversation fetched successfully",
extra={
"user_id": str(user_id),
"workspace_id": str(workspace_id),
"total": total,
}
)
return convsersations
return conversations, total
def list_conversations(
self,

View File

@@ -17,12 +17,17 @@ from app.repositories.neo4j.cypher_queries import (
GET_ALL_ENTITY_IDS_FOR_USER,
GET_ENTITIES_PAGE,
GET_COMMUNITY_MEMBERS,
GET_COMMUNITY_RELATIONSHIPS,
GET_ALL_COMMUNITY_MEMBERS_BATCH,
GET_ALL_ENTITY_NEIGHBORS_BATCH,
GET_ENTITY_NEIGHBORS_BATCH_FOR_IDS,
CHECK_USER_HAS_COMMUNITIES,
UPDATE_COMMUNITY_MEMBER_COUNT,
UPDATE_COMMUNITY_METADATA,
GET_INCOMPLETE_COMMUNITIES,
GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING,
CHECK_COMMUNITY_IS_COMPLETE,
CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING,
BATCH_UPDATE_COMMUNITY_METADATA,
)
@@ -177,7 +182,7 @@ class CommunityRepository:
async def get_community_members(
self, community_id: str, end_user_id: str
) -> List[Dict]:
"""Query the community member list."""
"""Query the community member list (including the example field)."""
try:
return await self.connector.execute_query(
GET_COMMUNITY_MEMBERS,
@@ -188,6 +193,20 @@ class CommunityRepository:
logger.error(f"get_community_members failed: {e}")
return []
async def get_community_relationships(
self, community_id: str, end_user_id: str
) -> List[Dict]:
"""Query relationship triples (subject, predicate, object) between entities within a community."""
try:
return await self.connector.execute_query(
GET_COMMUNITY_RELATIONSHIPS,
community_id=community_id,
end_user_id=end_user_id,
)
except Exception as e:
logger.error(f"get_community_relationships failed: {e}")
return []
async def get_all_community_members_batch(
self, community_ids: List[str], end_user_id: str
) -> Dict[str, List[Dict]]:
@@ -234,6 +253,31 @@ class CommunityRepository:
logger.error(f"refresh_member_count failed: {e}")
return 0
async def get_incomplete_communities(self, end_user_id: str, check_embedding: bool = False) -> List[str]:
"""Return the IDs of this user's Community nodes whose attributes are incomplete.
Args:
end_user_id: user ID
check_embedding: when True, additionally check whether summary_embedding is missing (pass True only when the user has an embedding model configured)
"""
try:
query = GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING if check_embedding else GET_INCOMPLETE_COMMUNITIES
result = await self.connector.execute_query(query, end_user_id=end_user_id)
return [row["community_id"] for row in result]
except Exception as e:
logger.error(f"get_incomplete_communities failed: {e}")
return []
async def is_community_complete(self, community_id: str, end_user_id: str, check_embedding: bool = False) -> bool:
"""Check whether a single community node's attributes are complete."""
try:
query = CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING if check_embedding else CHECK_COMMUNITY_IS_COMPLETE
result = await self.connector.execute_query(query, community_id=community_id, end_user_id=end_user_id)
return result[0]["is_complete"] if result else False
except Exception as e:
logger.error(f"is_community_complete failed: {e}")
return False
async def update_community_metadata(
self,
community_id: str,
@@ -243,7 +287,7 @@ class CommunityRepository:
core_entities: List[str],
summary_embedding: Optional[List[float]] = None,
) -> bool:
"""Update the community's name, summary, core entity list, and summary vector"""
"""Update the community's name, summary, core entity list, and summary_embedding"""
try:
result = await self.connector.execute_query(
UPDATE_COMMUNITY_METADATA,

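Taken together, these helpers support an idempotent repair pass over a user's graph; a sketch, assuming an initialized `CommunityRepository` and `LabelPropagationEngine` are at hand (the Celery task later in this diff calls the engine's private `_generate_community_metadata` in the same way):

async def repair_incomplete_communities(repo, engine, end_user_id: str,
                                        has_embedding_model: bool) -> int:
    # Only check summary_embedding when the user actually has an
    # embedding model configured, mirroring the docstring above.
    incomplete = await repo.get_incomplete_communities(
        end_user_id, check_embedding=has_embedding_model
    )
    fixed = 0
    for community_id in incomplete:
        await engine._generate_community_metadata(community_id, end_user_id)
        if await repo.is_community_complete(
            community_id, end_user_id, check_embedding=has_embedding_model
        ):
            fixed += 1
    return fixed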
View File

@@ -1137,10 +1137,20 @@ MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(
RETURN e.id AS id, e.name AS name, e.entity_type AS entity_type,
e.importance_score AS importance_score, e.activation_value AS activation_value,
e.name_embedding AS name_embedding,
e.aliases AS aliases, e.description AS description
e.aliases AS aliases, e.description AS description,
e.example AS example
ORDER BY coalesce(e.activation_value, 0) DESC
"""
GET_COMMUNITY_RELATIONSHIPS = """
MATCH (e1:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community {community_id: $community_id})
MATCH (e2:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c)
MATCH (e1)-[r:EXTRACTED_RELATIONSHIP]->(e2)
RETURN e1.name AS subject, r.predicate AS predicate, e2.name AS object
ORDER BY e1.name, r.predicate, e2.name
LIMIT 20
"""
GET_ALL_COMMUNITY_MEMBERS_BATCH = """
MATCH (e:ExtractedEntity {end_user_id: $end_user_id})-[:BELONGS_TO_COMMUNITY]->(c:Community)
RETURN c.community_id AS community_id,
@@ -1316,3 +1326,38 @@ RETURN s.statement AS statement,
ORDER BY COALESCE(s.activation_value, 0) DESC
LIMIT $limit
"""
CHECK_COMMUNITY_IS_COMPLETE = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
RETURN (
c.name IS NOT NULL AND c.name <> '' AND
c.summary IS NOT NULL AND c.summary <> '' AND
c.core_entities IS NOT NULL
) AS is_complete
"""
CHECK_COMMUNITY_IS_COMPLETE_WITH_EMBEDDING = """
MATCH (c:Community {community_id: $community_id, end_user_id: $end_user_id})
RETURN (
c.name IS NOT NULL AND c.name <> '' AND
c.summary IS NOT NULL AND c.summary <> '' AND
c.core_entities IS NOT NULL AND
c.summary_embedding IS NOT NULL
) AS is_complete
"""
GET_INCOMPLETE_COMMUNITIES = """
MATCH (c:Community {end_user_id: $end_user_id})
WHERE c.name IS NULL OR c.summary IS NULL OR c.core_entities IS NULL
OR c.name = '' OR c.summary = ''
RETURN c.community_id AS community_id
"""
GET_INCOMPLETE_COMMUNITIES_WITH_EMBEDDING = """
MATCH (c:Community {end_user_id: $end_user_id})
WHERE c.name IS NULL OR c.name = ''
OR c.summary IS NULL OR c.summary = ''
OR c.core_entities IS NULL
OR (c.summary_embedding IS NULL AND c.summary IS NOT NULL AND c.summary <> '(empty)')
RETURN c.community_id AS community_id
"""

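Each row returned by GET_COMMUNITY_RELATIONSHIPS follows its RETURN clause, so a consumer such as a community-summary prompt builder can flatten the triples like this (a sketch):

def format_relationship_lines(rows: list[dict]) -> str:
    # Each row: {"subject": ..., "predicate": ..., "object": ...};
    # at most 20 triples due to the LIMIT in the query above.
    return "\n".join(
        f'{row["subject"]} -[{row["predicate"]}]-> {row["object"]}'
        for row in rows
    )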
View File

@@ -1,4 +1,5 @@
import asyncio
import os
from typing import List, Optional
# Uses the new repository layer
@@ -304,7 +305,6 @@ async def save_dialog_and_statements_to_neo4j(
def schedule_clustering_after_write(
entity_nodes: List,
config_id: Optional[str] = None,
llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
) -> None:
@@ -325,13 +325,12 @@ def schedule_clustering_after_write(
end_user_id = entity_nodes[0].end_user_id
new_entity_ids = [e.id for e in entity_nodes]
logger.info(f"[Clustering] about to trigger clustering, entity count: {len(new_entity_ids)}, end_user_id: {end_user_id}")
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, config_id=config_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id))
asyncio.create_task(_trigger_clustering(new_entity_ids, end_user_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id))
async def _trigger_clustering(
new_entity_ids: List[str],
end_user_id: str,
config_id: Optional[str] = None,
llm_model_id: Optional[str] = None,
embedding_model_id: Optional[str] = None,
) -> None:
@@ -343,7 +342,7 @@ async def _trigger_clustering(
from app.core.memory.storage_services.clustering_engine import LabelPropagationEngine
logger.info(f"[Clustering] starting clustering: end_user_id={end_user_id}, entity count={len(new_entity_ids)}")
connector = Neo4jConnector()
engine = LabelPropagationEngine(connector, config_id=config_id, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)
engine = LabelPropagationEngine(connector, llm_model_id=llm_model_id, embedding_model_id=embedding_model_id)
await engine.run(end_user_id=end_user_id, new_entity_ids=new_entity_ids)
logger.info(f"[Clustering] clustering finished: end_user_id={end_user_id}")
except Exception as e:

View File

@@ -43,6 +43,7 @@ class WorkflowConfigRepository:
edges: list[dict[str, Any]],
variables: list[dict[str, Any]] | None = None,
execution_config: dict[str, Any] | None = None,
features: dict[str, Any] | None = None,
triggers: list[dict[str, Any]] | None = None
) -> WorkflowConfig:
"""Create or update a workflow configuration
@@ -53,6 +54,7 @@ class WorkflowConfigRepository:
edges: edge list
variables: variable list
execution_config: execution configuration
features: feature flags
triggers: trigger list
Returns:
@@ -82,6 +84,7 @@ class WorkflowConfigRepository:
edges=edges,
variables=variables or [],
execution_config=execution_config or {},
features=features or {},
triggers=triggers or []
)
self.db.add(config)

View File

@@ -149,18 +149,26 @@ class FileUploadConfig(BaseModel):
)
# General documents (PDF/DOCX/XLSX/TXT/CSV/JSON), size capped by document_max_size_mb
document_enabled: bool = Field(default=False)
document_max_size_mb: int = Field(default=100)
document_max_size_mb: int = Field(default=50)
document_allowed_extensions: List[str] = Field(
default=["pdf", "docx", "xlsx", "txt", "csv", "json", "md"]
default=["pdf", "docx", "doc", "xlsx", "xls", "txt", "csv", "json", "md"]
)
# Video files, size capped by video_max_size_mb
video_enabled: bool = Field(default=False)
video_max_size_mb: int = Field(default=500)
video_max_size_mb: int = Field(default=50)
video_allowed_extensions: List[str] = Field(
default=["mp4", "mov"]
default=["mp4"]
)
# Maximum number of files
max_file_count: int = Field(default=5, ge=1, le=20)
max_file_count: int = Field(default=5, ge=1)
@field_validator("max_file_count")
@classmethod
def validate_max_file_count(cls, v: int) -> int:
from app.core.config import settings
if v > settings.MAX_FILE_COUNT:
raise ValueError(f"max_file_count cannot exceed {settings.MAX_FILE_COUNT}")
return v
class OpeningStatementConfig(BaseModel):

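The new validator defers the hard ceiling to `settings.MAX_FILE_COUNT`, so a value that passes the field constraint (`ge=1`) can still be rejected at runtime; a sketch, assuming the deployment sets `MAX_FILE_COUNT` below 50 and that `FileUploadConfig` is importable from its schema module:

from pydantic import ValidationError

try:
    FileUploadConfig(max_file_count=50)
except ValidationError as e:
    # Raised when 50 exceeds settings.MAX_FILE_COUNT in this deployment
    print(e)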
View File

@@ -21,7 +21,7 @@ class MemoryWriteRequest(BaseModel):
"""
end_user_id: str = Field(..., description="End user ID (required)")
message: str = Field(..., description="Message content to store")
config_id: Optional[str] = Field(None, description="Memory configuration ID")
config_id: str = Field(..., description="Memory configuration ID (required)")
storage_type: str = Field("neo4j", description="Storage type: neo4j or rag")
user_rag_memory_id: Optional[str] = Field(None, description="RAG memory ID")
@@ -68,7 +68,7 @@ class MemoryReadRequest(BaseModel):
"0",
description="Search mode: 0=verify, 1=direct, 2=context"
)
config_id: Optional[str] = Field(None, description="Memory configuration ID")
config_id: str = Field(..., description="Memory configuration ID (required)")
storage_type: str = Field("neo4j", description="Storage type: neo4j or rag")
user_rag_memory_id: Optional[str] = Field(None, description="RAG memory ID")
@@ -132,3 +132,79 @@ class MemoryReadResponse(BaseModel):
description="Intermediate retrieval outputs"
)
end_user_id: str = Field(..., description="End user ID")
class CreateEndUserRequest(BaseModel):
"""Request schema for creating an end user.
Attributes:
workspace_id: Workspace ID (required)
other_id: External user identifier (required)
other_name: Display name for the end user
"""
workspace_id: str = Field(..., description="Workspace ID (required)")
other_id: str = Field(..., description="External user identifier (required)")
other_name: Optional[str] = Field("", description="Display name")
@field_validator("workspace_id")
@classmethod
def validate_workspace_id(cls, v: str) -> str:
"""Validate that workspace_id is not empty."""
if not v or not v.strip():
raise ValueError("workspace_id is required and cannot be empty")
return v.strip()
@field_validator("other_id")
@classmethod
def validate_other_id(cls, v: str) -> str:
"""Validate that other_id is not empty."""
if not v or not v.strip():
raise ValueError("other_id is required and cannot be empty")
return v.strip()
class CreateEndUserResponse(BaseModel):
"""Response schema for end user creation.
Attributes:
id: Created end user UUID
other_id: External user identifier
other_name: Display name
workspace_id: Workspace the user belongs to
"""
id: str = Field(..., description="End user UUID")
other_id: str = Field(..., description="External user identifier")
other_name: str = Field("", description="Display name")
workspace_id: str = Field(..., description="Workspace ID")
class MemoryConfigItem(BaseModel):
"""Schema for a single memory config in the list response.
Attributes:
config_id: Configuration UUID
config_name: Configuration name
config_desc: Configuration description
is_default: Whether this is the workspace default config
scene_name: Associated ontology scene name
created_at: Creation timestamp
updated_at: Last update timestamp
"""
config_id: str = Field(..., description="Configuration ID")
config_name: str = Field(..., description="Configuration name")
config_desc: Optional[str] = Field(None, description="Configuration description")
is_default: bool = Field(False, description="Whether this is the workspace default")
scene_name: Optional[str] = Field(None, description="Associated ontology scene name")
created_at: Optional[str] = Field(None, description="Creation timestamp")
updated_at: Optional[str] = Field(None, description="Last update timestamp")
class ListConfigsResponse(BaseModel):
"""Response schema for listing memory configs.
Attributes:
configs: List of memory config items
total: Total number of configs
"""
configs: List[MemoryConfigItem] = Field(default_factory=list, description="List of configs")
total: int = Field(0, description="Total number of configs")

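In practice the two validators trim surrounding whitespace and reject empty identifiers; a small sketch (identifier values are illustrative):

from pydantic import ValidationError

from app.schemas.memory_api_schema import CreateEndUserRequest

req = CreateEndUserRequest(workspace_id=" ws-123 ", other_id=" ext-42 ")
assert req.workspace_id == "ws-123" and req.other_id == "ext-42"  # stripped

try:
    CreateEndUserRequest(workspace_id="", other_id="ext-42")
except ValidationError:
    pass  # empty workspace_id is rejected by the field validator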
View File

@@ -118,28 +118,54 @@ class AppChatService:
)
model_info = ModelInfo(
model_name=api_key_obj.model_name,
provider=api_key_obj.provider,
api_key=api_key_obj.api_key,
api_base=api_key_obj.api_base,
capability=api_key_obj.capability,
is_omni=api_key_obj.is_omni,
model_type=ModelType.LLM
)
# Load history messages
messages = self.conversation_service.get_messages(
conversation_id=conversation_id,
limit=10
)
history = [
{"role": msg.role, "content": msg.content}
for msg in messages
]
history = []
for msg in messages:
content = [{"type": "text", "text": msg.content}]
# Process files from meta_data
if msg.meta_data and msg.meta_data.get("files"):
files = msg.meta_data.get("files", [])
# Process the files with MultimodalService
multimodal_service = MultimodalService(self.db, api_config=model_info)
# Convert files to FileInput format
file_inputs = []
for file in files:
from app.schemas.app_schema import FileInput, TransferMethod
file_input = FileInput(
type=file.get("type"),
transfer_method=TransferMethod.REMOTE_URL,
url=file.get("url")
)
file_inputs.append(file_input)
history_processed_files = await multimodal_service.history_process_files(files=file_inputs)
content.extend(history_processed_files)
history.append({
"role": msg.role,
"content": content
})
# Process multimodal files
processed_files = None
if files:
model_info = ModelInfo(
model_name=api_key_obj.model_name,
provider=api_key_obj.provider,
api_key=api_key_obj.api_key,
api_base=api_key_obj.api_base,
capability=api_key_obj.capability,
is_omni=api_key_obj.is_omni,
model_type=ModelType.LLM
)
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"Processed {len(processed_files)} files")
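After this change a history entry is a list of content parts rather than a bare string; illustratively (the non-text part shape depends on the provider strategy, with an OpenAI-style image part shown here as an assumption):

history_entry = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        # Appended by history_process_files(); the exact shape varies
        # by provider strategy (OpenAI-style shown as an assumption).
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ],
}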
@@ -313,31 +339,54 @@ class AppChatService:
streaming=True
)
model_info = ModelInfo(
model_name=api_key_obj.model_name,
provider=api_key_obj.provider,
api_key=api_key_obj.api_key,
api_base=api_key_obj.api_base,
capability=api_key_obj.capability,
is_omni=api_key_obj.is_omni,
model_type=ModelType.LLM
)
# Load history messages
messages = self.conversation_service.get_messages(
conversation_id=conversation_id,
limit=10
)
history = []
memory_config = {"enabled": True, 'max_history': 10}
if memory_config.get("enabled"):
messages = self.conversation_service.get_messages(
conversation_id=conversation_id,
limit=memory_config.get("max_history", 10)
)
history = [
{"role": msg.role, "content": msg.content}
for msg in messages
]
for msg in messages:
content = [{"type": "text", "text": msg.content}]
# Process files from meta_data
if msg.meta_data and msg.meta_data.get("files"):
history_files = msg.meta_data.get("files", [])
# Process the files with MultimodalService
multimodal_service = MultimodalService(self.db, api_config=model_info)
# Convert files to FileInput format
file_inputs = []
for file in history_files:
from app.schemas.app_schema import FileInput, TransferMethod
file_input = FileInput(
type=file.get("type"),
transfer_method=TransferMethod.REMOTE_URL,
url=file.get("url")
)
file_inputs.append(file_input)
history_processed_files = await multimodal_service.history_process_files(files=file_inputs)
content.extend(history_processed_files)
history.append({
"role": msg.role,
"content": content
})
# Process multimodal files
processed_files = None
if files:
model_info = ModelInfo(
model_name=api_key_obj.model_name,
provider=api_key_obj.provider,
api_key=api_key_obj.api_key,
api_base=api_key_obj.api_base,
capability=api_key_obj.capability,
is_omni=api_key_obj.is_omni,
model_type=ModelType.LLM
)
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
logger.info(f"Processed {len(processed_files)} files")
@@ -347,8 +396,14 @@ class AppChatService:
total_tokens = 0
text_queue: asyncio.Queue = asyncio.Queue()
api_key_config = {
"model_name": api_key_obj.model_name,
"api_key": api_key_obj.api_key,
"api_base": api_key_obj.api_base,
"provider": api_key_obj.provider,
}
stream_audio_url, tts_task = await self.agent_service._generate_tts_streaming(
features_config, api_key_obj,
features_config, api_key_config,
text_queue=text_queue,
tenant_id=tenant_id, workspace_id=workspace_id
)

View File

@@ -16,6 +16,7 @@ from app.models.app_release_model import AppRelease
from app.models.knowledge_model import Knowledge
from app.models.models_model import ModelConfig
from app.models.tool_model import ToolConfig as ToolConfigModel
from app.models.skill_model import Skill
from app.models.workflow_model import WorkflowConfig
from app.services.workflow_service import WorkflowService
from app.core.workflow.adapters.memory_bear.memory_bear_adapter import MemoryBearAdapter
@@ -84,7 +85,9 @@ class AppDslService:
if "knowledge_retrieval" in cfg:
enriched["knowledge_retrieval"] = self._enrich_knowledge_retrieval(cfg["knowledge_retrieval"])
if "tools" in cfg:
enriched["tools"] = self._enrich_tools(cfg["tools"])
enriched["tools"] = self._enrich_tools(cfg.get("tools"))
if "skills" in cfg:
enriched["skills"] = self._enrich_skills(cfg.get("skills"))
return enriched
if app_type == AppType.MULTI_AGENT:
enriched = {**cfg}
@@ -108,6 +111,7 @@ class AppDslService:
"variables": config.variables if config else [],
"edges": config.edges if config else [],
"nodes": config.nodes if config else [],
"features": config.features if config else {},
"execution_config": config.execution_config if config else {},
"triggers": config.triggers if config else [],
} if config else {}
@@ -123,7 +127,8 @@ class AppDslService:
"memory": config.memory if config else None,
"variables": config.variables if config else [],
"tools": self._enrich_tools(config.tools) if config else [],
"skills": config.skills if config else {},
"skills": self._enrich_skills(config.skills) if config else {},
"features": config.features if config else {}
} if config else {}
dsl = {**meta, "app": app_meta, "agent_config": config_data}
@@ -185,6 +190,22 @@ class AppDslService:
def _enrich_tools(self, tools: list) -> list:
return [{**t, "_ref": self._tool_ref(t.get("tool_id"))} for t in (tools or [])]
def _skill_ref(self, skill_id) -> Optional[dict]:
if not skill_id:
return None
s = self.db.query(Skill).filter(Skill.id == skill_id).first()
return {"id": str(skill_id), "name": s.name} if s else {"id": str(skill_id)}
def _enrich_skills(self, skills: Optional[dict]) -> Optional[dict]:
if not skills:
return skills
skill_ids = skills.get("skill_ids", [])
enriched_ids = [
{"id": sid, "_ref": self._skill_ref(sid)}
for sid in (skill_ids or [])
]
return {**skills, "skill_ids": enriched_ids}
def _agent_ref(self, agent_id) -> Optional[dict]:
if not agent_id:
return None
@@ -249,7 +270,8 @@ class AppDslService:
memory=self._resolve_memory(cfg.get("memory"), workspace_id, warnings),
variables=cfg.get("variables", []),
tools=self._resolve_tools(cfg.get("tools", []), tenant_id, warnings),
skills=cfg.get("skills", {}),
skills=self._resolve_skills(cfg.get("skills", {}), tenant_id, warnings),
features=cfg.get("features", {}),
is_active=True,
created_at=now,
updated_at=now,
@@ -290,6 +312,7 @@ class AppDslService:
edges=[e.model_dump() for e in result.edges],
variables=[v.model_dump() for v in result.variables],
execution_config=wf.get("execution_config", {}),
features=wf.get("features", {}),
triggers=wf.get("triggers", []),
validate=False,
)
@@ -444,6 +467,46 @@ class AppDslService:
return {**memory, "memory_config_id": None, "enabled": False}
return memory
def _resolve_skills(self, skills: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> dict:
if not skills:
return skills or {}
resolved_ids = []
for entry in (skills.get("skill_ids") or []):
# entry may be {"id": "...", "_ref": {...}} or a plain string
if isinstance(entry, dict):
ref = entry.get("_ref") or ({"name": None, "id": entry.get("id")} if entry.get("id") else None)
skill_id = self._resolve_skill(ref, tenant_id, warnings)
else:
skill_id = self._resolve_skill({"id": str(entry)}, tenant_id, warnings)
if skill_id:
resolved_ids.append(str(skill_id))
return {**{k: v for k, v in skills.items() if k != "skill_ids"}, "skill_ids": resolved_ids}
def _resolve_skill(self, ref: Optional[dict], tenant_id: uuid.UUID, warnings: list) -> Optional[str]:
if not ref:
return None
# Match by id first
if ref.get("id"):
try:
s = self.db.query(Skill).filter(
Skill.id == uuid.UUID(str(ref["id"])),
Skill.tenant_id == tenant_id
).first()
if s:
return str(s.id)
except Exception:
pass
# Then match by name
if ref.get("name"):
s = self.db.query(Skill).filter(
Skill.name == ref["name"],
Skill.tenant_id == tenant_id
).first()
if s:
return str(s.id)
warnings.append(f"Skill not found: {ref}")
return None
def _resolve_tools(self, tools: list, tenant_id: uuid.UUID, warnings: list) -> list:
result = []
for t in (tools or []):

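The skills block therefore round-trips through the DSL roughly like this (a sketch; the UUID and the `enabled` sibling key are illustrative, since only `skill_ids` is read by the code above):

# Exported DSL (after _enrich_skills): every skill id carries a _ref
exported = {
    "enabled": True,  # illustrative sibling key, passed through untouched
    "skill_ids": [
        {
            "id": "0b6f4a0e-9c1d-4e52-8a77-3f2d1c0b9a88",
            "_ref": {"id": "0b6f4a0e-9c1d-4e52-8a77-3f2d1c0b9a88", "name": "web_search"},
        },
    ],
}

# After import (_resolve_skills): matched by id first, then by name;
# unresolved entries are dropped and recorded as warnings.
resolved = {"enabled": True, "skill_ids": ["0b6f4a0e-9c1d-4e52-8a77-3f2d1c0b9a88"]}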
View File

@@ -833,8 +833,6 @@ class AppService:
# For cross-workspace copies, fetch the target workspace's tenant_id to decide whether model configs are usable
target_tenant_id = None
available_model_ids: set = set()
available_kb_ids: set = set()
if is_cross_workspace:
target_ws = self.db.get(Workspace, target_workspace_id)
if not target_ws:
@@ -849,28 +847,29 @@ class AppService:
if source_config:
if is_cross_workspace:
# Batch-collect and preload all referenced resources
model_ids, kb_ids = self._collect_resource_ids_from_config(
source_config.default_model_config_id,
source_config.knowledge_retrieval,
source_config.tools
# 跨工作空间model/tools/skills 属于 tenant 级别直接保留,
# knowledge_bases 属于 workspace 级别需过滤memory_config 需清空
_, kb_ids = self._collect_resource_ids_from_config(
None, source_config.knowledge_retrieval
)
available_model_ids, available_kb_ids = self._preload_cross_workspace_resources(
target_tenant_id, target_workspace_id, model_ids, kb_ids
)
new_model_config_id = self._is_model_available(
source_config.default_model_config_id, available_model_ids
_, available_kb_ids = self._preload_cross_workspace_resources(
target_tenant_id, target_workspace_id, set(), kb_ids
)
new_model_config_id = source_config.default_model_config_id
new_knowledge_retrieval = self._clean_knowledge_retrieval(
source_config.knowledge_retrieval, available_kb_ids
)
new_tools = self._clean_tools(
source_config.tools, available_kb_ids
new_tools = copy.deepcopy(source_config.tools) if source_config.tools else []
new_memory = self._clean_memory_cross_workspace(
source_config.memory, target_workspace_id
)
new_skills = copy.deepcopy(source_config.skills) if source_config.skills else {}
else:
new_model_config_id = source_config.default_model_config_id
new_knowledge_retrieval = copy.deepcopy(source_config.knowledge_retrieval) if source_config.knowledge_retrieval else None
new_tools = copy.deepcopy(source_config.tools) if source_config.tools else []
new_memory = copy.deepcopy(source_config.memory) if source_config.memory else None
new_skills = copy.deepcopy(source_config.skills) if source_config.skills else {}
new_config = AgentConfig(
id=uuid.uuid4(),
@@ -879,9 +878,11 @@ class AppService:
default_model_config_id=new_model_config_id,
model_parameters=copy.deepcopy(source_config.model_parameters) if source_config.model_parameters else None,
knowledge_retrieval=new_knowledge_retrieval,
memory=copy.deepcopy(source_config.memory) if source_config.memory else None,
memory=new_memory,
variables=copy.deepcopy(source_config.variables) if source_config.variables else [],
tools=new_tools,
skills=new_skills,
features=copy.deepcopy(source_config.features) if source_config.features else {},
is_active=True,
created_at=now,
updated_at=now,
@@ -894,28 +895,14 @@ class AppService:
).first()
if source_config:
if is_cross_workspace:
model_ids, kb_ids = self._collect_resource_ids_from_workflow_nodes(
source_config.nodes
)
available_model_ids, available_kb_ids = self._preload_cross_workspace_resources(
target_tenant_id, target_workspace_id, model_ids, kb_ids
)
new_nodes = self._clean_workflow_nodes_for_cross_workspace(
source_config.nodes or [],
available_model_ids,
available_kb_ids
)
else:
new_nodes = copy.deepcopy(source_config.nodes) if source_config.nodes else []
new_config = WorkflowConfig(
id=uuid.uuid4(),
app_id=new_app.id,
nodes=new_nodes,
nodes=copy.deepcopy(source_config.nodes) if source_config.nodes else [],
edges=copy.deepcopy(source_config.edges) if source_config.edges else [],
variables=copy.deepcopy(source_config.variables) if source_config.variables else [],
execution_config=copy.deepcopy(source_config.execution_config) if source_config.execution_config else {},
features=copy.deepcopy(source_config.features) if source_config.features else {},
triggers=copy.deepcopy(source_config.triggers) if source_config.triggers else [],
is_active=True,
created_at=now,
@@ -929,24 +916,15 @@ class AppService:
).first()
if source_config:
if is_cross_workspace:
model_ids = {source_config.default_model_config_id} if source_config.default_model_config_id else set()
available_model_ids, _ = self._preload_cross_workspace_resources(
target_tenant_id, target_workspace_id, model_ids, set()
)
new_model_config_id = self._is_model_available(
source_config.default_model_config_id, available_model_ids
)
else:
new_model_config_id = source_config.default_model_config_id
# multi_agent's model_config_id/sub_agents/routing_rules are all tenant-scoped and kept as-is
# master_agent_id (an AppRelease) belongs to the source workspace and is cleared on cross-workspace copy
new_config = MultiAgentConfig(
id=uuid.uuid4(),
app_id=new_app.id,
master_agent_id=source_config.master_agent_id if not is_cross_workspace else None,
master_agent_name=source_config.master_agent_name,
default_model_config_id=new_model_config_id,
model_parameters=source_config.model_parameters,
default_model_config_id=source_config.default_model_config_id,
model_parameters=copy.deepcopy(source_config.model_parameters) if source_config.model_parameters else None,
orchestration_mode=source_config.orchestration_mode,
sub_agents=copy.deepcopy(source_config.sub_agents) if source_config.sub_agents else [],
routing_rules=copy.deepcopy(source_config.routing_rules) if source_config.routing_rules else None,
@@ -1037,8 +1015,7 @@ class AppService:
@staticmethod
def _collect_resource_ids_from_config(
model_config_id: Optional[uuid.UUID],
knowledge_retrieval: Optional[dict],
tools: Optional[list]
knowledge_retrieval: Optional[dict]
) -> tuple:
"""Extract all model config IDs and knowledge base IDs from an app config."""
model_ids: set = set()
@@ -1048,62 +1025,12 @@ class AppService:
model_ids.add(model_config_id)
if knowledge_retrieval and isinstance(knowledge_retrieval, dict):
if "kb_ids" in knowledge_retrieval:
for kid in knowledge_retrieval.get("kb_ids", []):
if kid:
kb_ids.add(str(kid))
if knowledge_retrieval.get("knowledge_id"):
kb_ids.add(str(knowledge_retrieval["knowledge_id"]))
if tools:
for tool in tools:
if isinstance(tool, dict):
kid = tool.get("knowledge_id") or tool.get("kb_id")
if kid:
kb_ids.add(str(kid))
if "knowledge_bases" in knowledge_retrieval:
for kid in knowledge_retrieval.get("knowledge_bases", []):
kb_ids.add(str(kid.get("kb_id")))
return model_ids, kb_ids
@staticmethod
def _collect_resource_ids_from_workflow_nodes(nodes: list) -> tuple:
"""Extract all model config IDs and knowledge base IDs from workflow nodes."""
model_ids: set = set()
kb_ids: set = set()
for node in (nodes or []):
if not isinstance(node, dict):
continue
data = node.get("data", {})
if not isinstance(data, dict):
continue
for key in ("model_config_id", "default_model_config_id"):
val = data.get(key)
if val:
try:
model_ids.add(uuid.UUID(str(val)))
except (ValueError, AttributeError):
pass
kr = data.get("knowledge_retrieval")
if isinstance(kr, dict):
for kid in kr.get("kb_ids", []):
if kid:
kb_ids.add(str(kid))
if kr.get("knowledge_id"):
kb_ids.add(str(kr["knowledge_id"]))
if data.get("knowledge_id"):
kb_ids.add(str(data["knowledge_id"]))
for kid in data.get("kb_ids", []):
if kid:
kb_ids.add(str(kid))
return model_ids, kb_ids
@staticmethod
def _is_model_available(model_config_id: Optional[uuid.UUID], available_model_ids: set) -> Optional[uuid.UUID]:
if not model_config_id:
return None
return model_config_id if model_config_id in available_model_ids else None
@staticmethod
def _is_kb_available(kb_id: Optional[str], available_kb_ids: set) -> Optional[str]:
if not kb_id:
@@ -1124,95 +1051,53 @@ class AppService:
cleaned = copy.deepcopy(knowledge_retrieval)
if "kb_ids" in cleaned and isinstance(cleaned["kb_ids"], list):
cleaned["kb_ids"] = [
kid for kid in cleaned["kb_ids"]
if self._is_kb_available(kid, available_kb_ids)
if "knowledge_bases" in cleaned and isinstance(cleaned["knowledge_bases"], list):
cleaned["knowledge_bases"] = [
kb for kb in cleaned["knowledge_bases"]
if self._is_kb_available(kb.get("kb_id"), available_kb_ids)
]
if "knowledge_id" in cleaned:
cleaned["knowledge_id"] = self._is_kb_available(
cleaned.get("knowledge_id"), available_kb_ids
)
return cleaned
def _clean_tools(
def _clean_memory_cross_workspace(
self,
tools: Optional[list],
available_kb_ids: set
) -> list:
"""Clean tools config, keeping built-in tools and tools with available KBs."""
if not tools:
return []
memory: Optional[dict],
target_workspace_id: uuid.UUID
) -> Optional[dict]:
"""Clear memory_config_id/memory_content if it doesn't belong to target workspace."""
if not memory:
return None
cleaned = []
for tool in tools:
if not isinstance(tool, dict):
cleaned.append(tool)
continue
from app.models.memory_config_model import MemoryConfig
tool_type = tool.get("type", "")
if tool_type in ("builtin", "built_in", "system"):
cleaned.append(copy.deepcopy(tool))
continue
cleaned = copy.deepcopy(memory)
# Support both the legacy memory_content field and the new memory_config_id
mid = cleaned.get("memory_config_id") or cleaned.get("memory_content")
if mid:
try:
mid_uuid = uuid.UUID(str(mid))
except (ValueError, AttributeError):
exists = self.db.query(MemoryConfig).filter(
MemoryConfig.config_id_old == int(mid),
MemoryConfig.workspace_id == target_workspace_id
).first()
if not exists:
cleaned["memory_config_id"] = None
cleaned.pop("memory_content", None)
cleaned["enabled"] = False
return cleaned
kb_id = tool.get("knowledge_id") or tool.get("kb_id")
if kb_id:
if self._is_kb_available(kb_id, available_kb_ids):
cleaned.append(copy.deepcopy(tool))
continue
exists = self.db.query(
self.db.query(MemoryConfig).filter(
MemoryConfig.config_id == mid_uuid,
MemoryConfig.workspace_id == target_workspace_id
).exists()
).scalar()
if not exists:
cleaned["memory_config_id"] = None
cleaned.pop("memory_content", None)
cleaned["enabled"] = False
cleaned.append(copy.deepcopy(tool))
return cleaned
def _clean_workflow_nodes_for_cross_workspace(
self,
nodes: list,
available_model_ids: set,
available_kb_ids: set
) -> list:
"""Clean workflow nodes, using pre-loaded resource sets. Uses deepcopy to avoid mutating source."""
if not nodes:
return []
cleaned = []
for node in nodes:
if not isinstance(node, dict):
cleaned.append(node)
continue
node_copy = copy.deepcopy(node)
data = node_copy.get("data")
if not isinstance(data, dict):
cleaned.append(node_copy)
continue
for key in ("model_config_id", "default_model_config_id"):
if key in data and data[key]:
try:
mid = uuid.UUID(str(data[key]))
except (ValueError, AttributeError):
data[key] = None
continue
data[key] = str(mid) if mid in available_model_ids else None
if "knowledge_retrieval" in data and data["knowledge_retrieval"]:
data["knowledge_retrieval"] = self._clean_knowledge_retrieval(
data["knowledge_retrieval"], available_kb_ids
)
if "knowledge_id" in data:
data["knowledge_id"] = self._is_kb_available(
data.get("knowledge_id"), available_kb_ids
)
if "kb_ids" in data and isinstance(data["kb_ids"], list):
data["kb_ids"] = [
kid for kid in data["kb_ids"]
if self._is_kb_available(kid, available_kb_ids)
]
cleaned.append(node_copy)
return cleaned
def list_apps(

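Flattened out, the id check inside `_clean_memory_cross_workspace` reads as follows; a sketch assuming `mid` is the raw config reference pulled from the memory dict (legacy ids are assumed numeric, matching the `int(mid)` cast above):

try:
    mid_uuid = uuid.UUID(str(mid))
except (ValueError, AttributeError):
    # Legacy integer id: look it up via the config_id_old column
    exists = db.query(MemoryConfig).filter(
        MemoryConfig.config_id_old == int(mid),
        MemoryConfig.workspace_id == target_workspace_id,
    ).first() is not None
else:
    # UUID id: an EXISTS subquery avoids loading the row
    exists = db.query(
        db.query(MemoryConfig).filter(
            MemoryConfig.config_id == mid_uuid,
            MemoryConfig.workspace_id == target_workspace_id,
        ).exists()
    ).scalar()
if not exists:
    cleaned["memory_config_id"] = None
    cleaned.pop("memory_content", None)
    cleaned["enabled"] = False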
View File

@@ -21,6 +21,7 @@ from app.models.conversation_model import ConversationDetail
from app.models.prompt_optimizer_model import RoleType
from app.repositories.conversation_repository import ConversationRepository, MessageRepository
from app.schemas.conversation_schema import ConversationOut
from app.schemas.model_schema import ModelInfo
from app.services import workspace_service
from app.services.model_service import ModelConfigService
@@ -119,25 +120,27 @@ class ConversationService:
def get_user_conversations(
self,
user_id: uuid.UUID
) -> list[Conversation]:
user_id: uuid.UUID,
page: int = 1,
page_size: int = 20
) -> tuple[list[Conversation], int]:
"""
Retrieve recent conversations for a specific user
This method delegates persistence logic to the repository layer and
applies service-level defaults (e.g. recent conversation limit).
Retrieve recent conversations for a specific user with pagination.
Args:
user_id (uuid.UUID): Unique identifier of the user.
page (int): Page number (1-based). Defaults to 1.
page_size (int): Number of items per page. Defaults to 20.
Returns:
list[Conversation]: A list of recent conversation entities.
tuple[list[Conversation], int]: A list of recent conversation entities and total count.
"""
conversations = self.conversation_repo.get_conversation_by_user_id(
conversations, total = self.conversation_repo.get_conversation_by_user_id(
user_id,
limit=10
page=page,
page_size=page_size
)
return conversations
return conversations, total
def list_conversations(
self,
@@ -267,10 +270,11 @@ class ConversationService:
return messages
def get_conversation_history(
async def get_conversation_history(
self,
conversation_id: uuid.UUID,
max_history: Optional[int] = None
max_history: Optional[int] = None,
api_config: Optional[ModelInfo] = None
) -> List[dict]:
"""
Retrieve historical conversation messages formatted as dictionaries.
@@ -278,6 +282,7 @@ class ConversationService:
Args:
conversation_id (uuid.UUID): Conversation UUID.
max_history (Optional[int]): Maximum number of messages to retrieve.
api_config (Optional[ModelInfo]): Model API configuration for multimodal processing.
Returns:
List[dict]: List of message dictionaries with keys 'role' and 'content'.
@@ -288,13 +293,37 @@ class ConversationService:
)
# Convert to dict format
history = [
{
history = []
for msg in messages:
content = [{"type": "text", "text": msg.content}]
# Process files from meta_data
if msg.meta_data and msg.meta_data.get("files"):
files = msg.meta_data.get("files", [])
if api_config:
# Process the files with MultimodalService
from app.services.multimodal_service import MultimodalService
multimodal_service = MultimodalService(self.db, api_config=api_config)
# Convert files to FileInput format
file_inputs = []
for file in files:
from app.schemas.app_schema import FileInput, TransferMethod
file_input = FileInput(
type=file.get("type"),
transfer_method=TransferMethod.REMOTE_URL,
url=file.get("url")
)
file_inputs.append(file_input)
processed_files = await multimodal_service.history_process_files(files=file_inputs)
content.extend(processed_files)
history.append({
"role": msg.role,
"content": msg.content
}
for msg in messages
]
"content": content
})
return history
@@ -522,9 +551,18 @@ class ConversationService:
type=ModelType(model_type)
)
conversation_messages = self.get_conversation_history(
conversation_messages = await self.get_conversation_history(
conversation_id=conversation_id,
max_history=20
max_history=20,
api_config=ModelInfo(
model_name=model_name,
provider=provider,
api_key=api_key,
api_base=api_base,
capability=api_config.capability,
is_omni=api_config.is_omni,
model_type=model_type
)
)
if len(conversation_messages) == 0:
return ConversationOut(

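Since `get_conversation_history` is now a coroutine, callers must await it; omitting `api_config` keeps the history text-only, while passing a `ModelInfo` re-hydrates file attachments. A minimal call sketch:

history = await conversation_service.get_conversation_history(
    conversation_id=conversation_id,
    max_history=20,
    api_config=model_info,  # optional; enables multimodal file processing
)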
View File

@@ -579,9 +579,20 @@ class AgentRunService:
user_id=user_id
)
model_info = ModelInfo(
model_name=api_key_config["model_name"],
provider=api_key_config["provider"],
api_key=api_key_config["api_key"],
api_base=api_key_config["api_base"],
capability=api_key_config["capability"],
is_omni=api_key_config["is_omni"],
model_type=model_config.type
)
# 6. Load history messages
history = await self._load_conversation_history(
conversation_id=conversation_id,
api_config=model_info,
max_history=10
)
@@ -589,15 +600,6 @@ class AgentRunService:
processed_files = None
if files:
# Get provider info
model_info = ModelInfo(
model_name=api_key_config["model_name"],
provider=api_key_config["provider"],
api_key=api_key_config["api_key"],
api_base=api_key_config["api_base"],
capability=api_key_config["capability"],
is_omni=api_key_config["is_omni"],
model_type=ModelType.LLM
)
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
@@ -815,9 +817,20 @@ class AgentRunService:
sub_agent=sub_agent
)
model_info = ModelInfo(
model_name=api_key_config["model_name"],
provider=api_key_config["provider"],
api_key=api_key_config["api_key"],
api_base=api_key_config["api_base"],
capability=api_key_config["capability"],
is_omni=api_key_config["is_omni"],
model_type=model_config.type
)
# 6. Load history messages
history = await self._load_conversation_history(
conversation_id=conversation_id,
api_config=model_info,
max_history=memory_config.get("max_history", 10)
)
@@ -825,15 +838,6 @@ class AgentRunService:
processed_files = None
if files:
# Get provider info
model_info = ModelInfo(
model_name=api_key_config["model_name"],
provider=api_key_config["provider"],
api_key=api_key_config["api_key"],
api_base=api_key_config["api_base"],
capability=api_key_config["capability"],
is_omni=api_key_config["is_omni"],
model_type=ModelType.LLM
)
provider = api_key_config.get("provider", "openai")
multimodal_service = MultimodalService(self.db, model_info)
processed_files = await multimodal_service.process_files(user_id, files)
@@ -1115,6 +1119,7 @@ class AgentRunService:
async def _load_conversation_history(
self,
conversation_id: str,
api_config: ModelInfo | None = None,
max_history: int = 10
) -> List[Dict[str, str]]:
"""Load conversation history messages
@@ -1129,9 +1134,11 @@ class AgentRunService:
try:
conversation_service = ConversationService(self.db)
history = conversation_service.get_conversation_history(
# Pass the API config for multimodal processing
history = await conversation_service.get_conversation_history(
conversation_id=uuid.UUID(conversation_id),
max_history=max_history
max_history=max_history,
api_config=api_config
)
logger.debug(

View File

@@ -1179,7 +1179,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
app = db.query(App).filter(App.id == app_id).first()
if not app:
logger.warning(f"App not found: {app_id}")
raise ValueError(f"App does not exist: {app_id}")
# raise ValueError(f"App does not exist: {app_id}")
# TODO: temp fix for draft run
# if not app.current_release_id:
# logger.warning(f"No current release for app: {app_id}")
@@ -1252,17 +1252,15 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
memory_config_service = MemoryConfigService(db)
memory_config = memory_config_service.get_config_with_fallback(
memory_config_id=memory_config_id_to_use,
workspace_id=app.workspace_id
workspace_id=end_user.workspace_id
)
memory_config_id = str(memory_config.config_id) if memory_config else None
result = {
"end_user_id": str(end_user_id),
"app_id": str(app_id),
"release_id": str(app.current_release_id) if app.current_release_id else None,
"memory_config_id": memory_config_id,
"workspace_id": str(app.workspace_id)
"workspace_id": str(end_user.workspace_id)
}
logger.info(

View File

@@ -84,43 +84,65 @@ class MemoryAPIService:
if not app:
logger.warning(f"App not found for end_user: {end_user_id}")
raise ResourceNotFoundException(
resource_type="App",
resource_id=str(end_user.app_id)
)
if app.workspace_id != workspace_id:
logger.warning(
f"End user {end_user_id} belongs to workspace {app.workspace_id}, "
f"not authorized workspace {workspace_id}"
)
raise BusinessException(
message="End user does not belong to authorized workspace",
code=BizCode.FORBIDDEN
)
# raise ResourceNotFoundException(
# resource_type="App",
# resource_id=str(end_user.app_id)
# )
# temporarily allow any workspace to access
# if end_user.workspace_id != workspace_id:
# print(f"[DEBUG] end_user.workspace_id={end_user.workspace_id}, api_key.workspace_id={workspace_id}")
# logger.warning(
# f"End user {end_user_id} belongs to workspace {end_user.workspace_id}, "
# f"not authorized workspace {workspace_id}"
# )
# raise BusinessException(
# message=f"End user does not belong to authorized workspace. end_user.workspace_id={end_user.workspace_id}, api_key.workspace_id={workspace_id}",
# code=BizCode.FORBIDDEN
# )
logger.info(f"End user {end_user_id} validated successfully")
return end_user
def _update_end_user_config(self, end_user_id: str, config_id: str) -> None:
"""Update the end user's memory_config_id.
Silently updates the config association. Logs warnings on failure
but does not raise, so it won't block the main read/write operation.
Args:
end_user_id: End user identifier
config_id: Memory configuration ID to assign
"""
try:
config_uuid = uuid.UUID(config_id)
from app.repositories.end_user_repository import EndUserRepository
end_user_repo = EndUserRepository(self.db)
end_user_repo.update_memory_config_id(
end_user_id=uuid.UUID(end_user_id),
memory_config_id=config_uuid,
)
except Exception as e:
logger.warning(f"Failed to update memory_config_id for end_user {end_user_id}: {e}")
async def write_memory(
self,
workspace_id: uuid.UUID,
end_user_id: str,
message: str,
config_id: Optional[str] = None,
config_id: str,
storage_type: str = "neo4j",
user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Write memory with validation.
Validates end_user exists and belongs to workspace, then delegates
to MemoryAgentService.write_memory.
Validates end_user exists and belongs to workspace, updates the end user's
memory_config_id, then delegates to MemoryAgentService.write_memory.
Args:
workspace_id: Workspace ID for resource validation
end_user_id: End user identifier (used as end_user_id)
message: Message content to store
config_id: Optional memory configuration ID
config_id: Memory configuration ID (required)
storage_type: Storage backend (neo4j or rag)
user_rag_memory_id: Optional RAG memory ID
@@ -136,7 +158,8 @@ class MemoryAPIService:
# Validate end_user exists and belongs to workspace
self.validate_end_user(end_user_id, workspace_id)
# Use end_user_id as end_user_id for memory operations
# Update end user's memory_config_id
self._update_end_user_config(end_user_id, config_id)
try:
# Delegate to MemoryAgentService
@@ -188,21 +211,21 @@ class MemoryAPIService:
end_user_id: str,
message: str,
search_switch: str = "0",
config_id: Optional[str] = None,
config_id: str = "",
storage_type: str = "neo4j",
user_rag_memory_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Read memory with validation.
Validates end_user exists and belongs to workspace, then delegates
to MemoryAgentService.read_memory.
Validates end_user exists and belongs to workspace, updates the end user's
memory_config_id, then delegates to MemoryAgentService.read_memory.
Args:
workspace_id: Workspace ID for resource validation
end_user_id: End user identifier (used as end_user_id)
message: Query message
search_switch: Search mode (0=deep search with verification, 1=deep search, 2=fast search)
config_id: Optional memory configuration ID
config_id: Memory configuration ID (required)
storage_type: Storage backend (neo4j or rag)
user_rag_memory_id: Optional RAG memory ID
@@ -218,7 +241,8 @@ class MemoryAPIService:
# Validate end_user exists and belongs to workspace
self.validate_end_user(end_user_id, workspace_id)
# Use end_user_id as end_user_id for memory operations
# Update end user's memory_config_id
self._update_end_user_config(end_user_id, config_id)
try:
@@ -256,3 +280,50 @@ class MemoryAPIService:
message=f"Memory read failed: {str(e)}",
code=BizCode.MEMORY_READ_FAILED
)
def list_memory_configs(
self,
workspace_id: uuid.UUID,
) -> Dict[str, Any]:
"""List all memory configs for a workspace.
Args:
workspace_id: Workspace ID from API key authorization
Returns:
Dict with configs list and total count
Raises:
BusinessException: If listing fails
"""
logger.info(f"Listing memory configs for workspace: {workspace_id}")
try:
from app.repositories.memory_config_repository import MemoryConfigRepository
results = MemoryConfigRepository.get_all(self.db, workspace_id=workspace_id)
configs = []
for config, scene_name in results:
configs.append({
"config_id": str(config.config_id),
"config_name": config.config_name,
"config_desc": config.config_desc,
"is_default": config.is_default or False,
"scene_name": scene_name,
"created_at": config.created_at.isoformat() if config.created_at else None,
"updated_at": config.updated_at.isoformat() if config.updated_at else None,
})
logger.info(f"Found {len(configs)} memory configs for workspace {workspace_id}")
return {
"configs": configs,
"total": len(configs),
}
except Exception as e:
logger.error(f"Failed to list memory configs for workspace {workspace_id}: {e}")
raise BusinessException(
message=f"Failed to list memory configs: {str(e)}",
code=BizCode.MEMORY_READ_FAILED
)

View File

@@ -619,7 +619,7 @@ class MemoryForgetService:
recent_trends.append({
'date': date_str,
'merged_count': record.merged_count,
'average_activation': record.average_activation_value,
'average_activation': round(record.average_activation_value, 2) if record.average_activation_value is not None else None,
'total_nodes': record.total_nodes,
'execution_time': int(record.execution_time.timestamp() * 1000)
})

View File

@@ -11,6 +11,8 @@
import base64
import io
import uuid
import zipfile
import chardet
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
@@ -42,12 +44,10 @@ PDF_MIME = ['application/pdf']
DOC_MIME = [
'application/msword',
'application/vnd.openxmlformats-officedocument.wordprocessingml.document',
'application/zip'
]
XLSX_MIME = [
'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
'application/vnd.ms-excel',
'application/zip'
]
CSV_MIME = ['text/csv', 'application/csv']
JSON_MIME = ['application/json']
@@ -418,6 +418,71 @@ class MultimodalService:
logger.info(f"Successfully processed {len(result)}/{len(files)} files, provider={self.provider}")
return result
async def history_process_files(
self,
files: Optional[List[FileInput]],
) -> List[Dict[str, Any]]:
"""
Process a list of files and return them in an LLM-consumable format
Args:
files: list of file inputs
Returns:
List[Dict]: list of content parts consumable by the LLM (format varies by provider)
"""
if not files:
return []
# Pick the matching strategy
# dashscope omni models use the OpenAI-compatible format
if self.provider == "dashscope" and self.is_omni:
strategy_class = OpenAIFormatStrategy
else:
strategy_class = PROVIDER_STRATEGIES.get(self.provider)
if not strategy_class:
logger.warning(f"No strategy found for provider '{self.provider}', using the default strategy")
strategy_class = DashScopeFormatStrategy
result = []
for idx, file in enumerate(files):
strategy = strategy_class(file)
if not file.url:
file.url = await self.get_file_url(file)
try:
if file.type == FileType.IMAGE and "vision" in self.capability:
is_support, content = await self._process_image(file, strategy)
result.append(content)
elif file.type == FileType.DOCUMENT:
is_support, content = await self._process_document(file, strategy)
result.append(content)
elif file.type == FileType.AUDIO and "audio" in self.capability:
is_support, content = await self._process_audio(file, strategy)
result.append(content)
elif file.type == FileType.VIDEO and "video" in self.capability:
is_support, content = await self._process_video(file, strategy)
result.append(content)
else:
logger.warning(f"Unsupported file type: {file.type}")
except Exception as e:
logger.error(
"Failed to process file",
extra={
"file_index": idx,
"file_type": file.type,
"error": str(e)
},
exc_info=True
)
# Keep processing the remaining files instead of aborting the whole flow
result.append({
"type": "text",
"text": f"[File processing failed: {str(e)}]"
})
logger.info(f"Successfully processed {len(result)}/{len(files)} files, provider={self.provider}")
return result
def write_perceptual_memory(
self,
end_user_id: str,
@@ -588,12 +653,12 @@ class MultimodalService:
file.set_content(file_content)
file_mime_type = magic.from_buffer(file_content, mime=True)
if file_mime_type in TEXT_MIME:
return file_content.decode("utf-8")
return self._decode_text_safe(file_content)
elif file_mime_type in PDF_MIME:
return await self._extract_pdf_text(file_content)
elif file_mime_type in DOC_MIME and file.file_type.endswith(('docx', 'doc')):
elif self._is_word_file(file_content, file_mime_type):
return await self._extract_word_text(file_content)
elif file_mime_type in XLSX_MIME and file.file_type.endswith(("xlsx", "xls")):
elif self._is_excel_file(file_content, file_mime_type):
return await self._extract_xlsx_text(file_content)
elif file_mime_type in CSV_MIME:
return await self._extract_csv_text(file_content)
@@ -622,52 +687,156 @@ class MultimodalService:
@staticmethod
async def _extract_word_text(file_content: bytes) -> str:
"""Extract text from a Word document"""
"""Extract text from a Word document (supports .docx and legacy .doc)"""
# Try docx first (ZIP format)
if file_content[:2] == b'PK':
try:
word_file = io.BytesIO(file_content)
doc = Document(word_file)
return '\n'.join(p.text for p in doc.paragraphs)
except Exception as e:
logger.error(f"Failed to extract docx text: {e}")
return f"[docx extraction failed: {str(e)}]"
# Legacy .doc (OLE2 format)
try:
word_file = io.BytesIO(file_content)
doc = Document(word_file)
text_parts = [paragraph.text for paragraph in doc.paragraphs]
return '\n'.join(text_parts)
import olefile
ole = olefile.OleFileIO(io.BytesIO(file_content))
if not ole.exists('WordDocument'):
return "[doc extraction failed: WordDocument stream not found]"
# Read the WordDocument stream and extract visible ASCII/Unicode text
stream = ole.openstream('WordDocument').read()
# Word Binary Format: text is stored in the stream as UTF-16-LE
# Naive extraction: keep the printable character runs
try:
text = stream.decode('utf-16-le', errors='ignore')
except Exception:
text = stream.decode('latin-1', errors='ignore')
# Strip control characters, keep printable content
import re
text = re.sub(r'[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]', '', text)
text = re.sub(r' +', ' ', text).strip()
ole.close()
return text
except Exception as e:
logger.error(f"Failed to extract Word text: {e}")
return f"[Word extraction failed: {str(e)}]"
logger.error(f"Failed to extract doc text: {e}")
return f"[doc extraction failed: {str(e)}]"
@staticmethod
async def _extract_xlsx_text(file_content: bytes) -> str:
"""Extract Excel text"""
"""Extract Excel text (supports .xlsx and legacy .xls)"""
# xlsx (ZIP format)
if file_content[:2] == b'PK':
try:
wb = openpyxl.load_workbook(io.BytesIO(file_content), read_only=True, data_only=True)
parts = []
for sheet in wb.worksheets:
parts.append(f"[Sheet: {sheet.title}]")
for row in sheet.iter_rows(values_only=True):
parts.append('\t'.join('' if v is None else str(v) for v in row))
return '\n'.join(parts)
except Exception as e:
logger.error(f"Failed to extract xlsx text: {e}")
return f"[xlsx extraction failed: {str(e)}]"
# xls (OLE2/BIFF format)
try:
wb = openpyxl.load_workbook(io.BytesIO(file_content), read_only=True, data_only=True)
import xlrd
wb = xlrd.open_workbook(file_contents=file_content)
parts = []
for sheet in wb.worksheets:
parts.append(f"[Sheet: {sheet.title}]")
for row in sheet.iter_rows(values_only=True):
parts.append('\t'.join('' if v is None else str(v) for v in row))
for sheet in wb.sheets():
parts.append(f"[Sheet: {sheet.name}]")
for row_idx in range(sheet.nrows):
parts.append('\t'.join(str(sheet.cell_value(row_idx, col)) for col in range(sheet.ncols)))
return '\n'.join(parts)
except Exception as e:
logger.error(f"Failed to extract Excel text: {e}")
return f"[Excel extraction failed: {str(e)}]"
logger.error(f"Failed to extract xls text: {e}")
return f"[xls extraction failed: {str(e)}]"
@staticmethod
async def _extract_csv_text(file_content: bytes) -> str:
async def _extract_csv_text(self, file_content: bytes) -> str:
"""Extract CSV text"""
try:
text = file_content.decode('utf-8-sig')
text = self._decode_text_safe(file_content)
reader = csv.reader(io.StringIO(text))
return '\n'.join('\t'.join(row) for row in reader)
except Exception as e:
logger.error(f"Failed to extract CSV text: {e}")
return f"[CSV extraction failed: {str(e)}]"
@staticmethod
async def _extract_json_text(file_content: bytes) -> str:
async def _extract_json_text(self, file_content: bytes) -> str:
"""Extract JSON text"""
try:
data = json.loads(file_content.decode('utf-8'))
text = self._decode_text_safe(file_content)
data = json.loads(text)
return json.dumps(data, ensure_ascii=False, indent=2)
except Exception as e:
logger.error(f"Failed to extract JSON text: {e}")
return f"[JSON extraction failed: {str(e)}]"
def _is_word_file(self, file_content: bytes, mime_type: str) -> bool:
"""Determine whether the content is a Word file (doc / docx) without relying on the file extension"""
# Legacy .doc
if mime_type == 'application/msword':
return True
# Modern .docx (a ZIP containing word/document.xml)
header = file_content[:4]
if header == b'PK\x03\x04':
try:
with zipfile.ZipFile(io.BytesIO(file_content)) as zf:
return "word/document.xml" in zf.namelist()
except Exception:  # not a readable ZIP archive
pass
return False
def _is_excel_file(self, file_content: bytes, mime_type: str) -> bool:
"""Determine whether the content is an Excel file (xls / xlsx) without relying on the file extension"""
# Legacy .xls
if mime_type == 'application/vnd.ms-excel':
return True
# Modern .xlsx (a ZIP containing xl/workbook.xml)
header = file_content[:4]
if header == b'PK\x03\x04':
try:
with zipfile.ZipFile(io.BytesIO(file_content)) as zf:
return "xl/workbook.xml" in zf.namelist()
except Exception:  # not a readable ZIP archive
pass
return False
@staticmethod
def _decode_text_safe(file_content: bytes) -> str:
"""
Universal text decoding.
Auto-detects the encoding; supports utf-8 / gbk / gb2312 / utf-8-sig / ascii, and more.
Never raises; falls back to replacement characters rather than failing.
"""
if not file_content:
return ""
# 1. Auto-detect the file encoding
detect = chardet.detect(file_content)
encoding = detect.get("encoding") or "utf-8"
encoding = encoding.lower()
# 2. Fallback list covering common Chinese encodings
compatible_encodings = ["utf-8", "gbk", "gb18030", "gb2312", "ascii", "latin-1"]
# 3. Try decoding in priority order
for enc in [encoding] + compatible_encodings:
if not enc:
continue
try:
return file_content.decode(enc.strip())
except (UnicodeDecodeError, LookupError):
continue
# Last-resort fallback
return file_content.decode("utf-8", errors="replace")
def get_multimodal_service(db: Session) -> MultimodalService:
"""Get a MultimodalService instance (dependency injection)"""

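A quick usage sketch of `_decode_text_safe` (a static method, so callable without an instance); chardet typically identifies this sample as GB2312/GBK, and when detection fails outright the fallback list still includes gbk before the errors="replace" last resort:

from app.services.multimodal_service import MultimodalService

gbk_bytes = "记忆熊是一个开放的 AI 记忆平台".encode("gbk")
text = MultimodalService._decode_text_safe(gbk_bytes)
print(text)  # expected to round-trip without mojibake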
View File

@@ -1408,12 +1408,11 @@ async def analytics_memory_types(
if end_user_id:
try:
conversation_repo = ConversationRepository(db)
conversations = conversation_repo.get_conversation_by_user_id(
conversations, total = conversation_repo.get_conversation_by_user_id(
user_id=uuid.UUID(end_user_id),
limit=100,  # fetch more conversations for an accurate count
is_activate=True
)
work_count = len(conversations)
work_count = total
logger.debug(f"Working memory count (conversation count): {work_count} (end_user_id={end_user_id})")
except Exception as e:
logger.warning(f"Failed to fetch the conversation count; working memory count set to 0: {str(e)}")

View File

@@ -25,7 +25,7 @@ from app.repositories.workflow_repository import (
WorkflowExecutionRepository,
WorkflowNodeExecutionRepository
)
from app.schemas import DraftRunRequest, FileInput, FileType
from app.schemas import DraftRunRequest, FileInput
from app.services.conversation_service import ConversationService
from app.services.multi_agent_service import convert_uuids_to_str
from app.services.multimodal_service import MultimodalService
@@ -55,6 +55,7 @@ class WorkflowService:
edges: list[dict[str, Any]],
variables: list[dict[str, Any]] | None = None,
execution_config: dict[str, Any] | None = None,
features: dict[str, Any] | None = None,
triggers: list[dict[str, Any]] | None = None,
validate: bool = True
) -> WorkflowConfig:
@@ -66,6 +67,7 @@ class WorkflowService:
edges: edge list
variables: variable list
execution_config: execution configuration
features: feature flags
triggers: trigger list
validate: whether to validate the configuration
@@ -81,6 +83,7 @@ class WorkflowService:
"edges": edges,
"variables": variables or [],
"execution_config": execution_config or {},
"features": features or {},
"triggers": triggers or []
}
@@ -101,6 +104,7 @@ class WorkflowService:
edges=edges,
variables=variables,
execution_config=execution_config,
features=features,
triggers=triggers
)

View File

@@ -2675,13 +2675,15 @@ def write_perceptual_memory(
time_limit=7200,  # 2-hour hard timeout
soft_time_limit=6900,
)
def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[str, Any]:
def init_community_clustering_for_users(self, end_user_ids: List[str], workspace_id: Optional[str] = None) -> Dict[str, Any]:
"""Triggered task: check the given list of users and run full clustering for those that have ExtractedEntity nodes but no Community nodes.
Triggered by the /dashboard/end_users endpoint; users that already have community nodes are skipped.
When the task finishes and every user's data is complete, a Redis marker is written to avoid re-dispatching next time.
Args:
end_user_ids: list of user IDs to check
workspace_id: workspace ID (used for the completion marker)
Returns:
A dict containing the task execution result
@@ -2707,6 +2709,7 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
# Batch-prefetch all users' configs (built-in fallback: if a user's config is unavailable, fall back to the workspace default)
user_llm_map: Dict[str, Optional[str]] = {}
user_embedding_map: Dict[str, Optional[str]] = {}
try:
with get_db_context() as db:
from app.services.memory_agent_service import get_end_users_connected_configs_batch
@@ -2718,21 +2721,54 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
try:
cfg = MemoryConfigService(db).load_memory_config(config_id=config_id)
user_llm_map[uid] = str(cfg.llm_model_id) if cfg.llm_model_id else None
user_embedding_map[uid] = str(cfg.embedding_model_id) if cfg.embedding_model_id else None
except Exception as e:
logger.warning(f"[CommunityCluster] user {uid} failed to load LLM config, falling back to None: {e}")
logger.warning(f"[CommunityCluster] user {uid} failed to load config, falling back to None: {e}")
user_llm_map[uid] = None
user_embedding_map[uid] = None
else:
user_llm_map[uid] = None
user_embedding_map[uid] = None
except Exception as e:
logger.warning(f"[CommunityCluster] batch fetch of LLM configs failed; all users will use None: {e}")
logger.warning(f"[CommunityCluster] batch fetch of configs failed; all users will use None: {e}")
for end_user_id in end_user_ids:
try:
# Skip if community nodes already exist
# When community nodes exist, check for nodes with incomplete attributes
has_communities = await repo.has_communities(end_user_id)
if has_communities:
skipped += 1
logger.debug(f"[CommunityCluster] user {end_user_id} already has community nodes, skipping")
llm_model_id = user_llm_map.get(end_user_id)
embedding_model_id = user_embedding_map.get(end_user_id)
incomplete_ids = await repo.get_incomplete_communities(
end_user_id, check_embedding=bool(embedding_model_id)
)
if not incomplete_ids:
skipped += 1
logger.debug(f"[CommunityCluster] 用户 {end_user_id} 社区节点均完整,跳过")
continue
# Backfill metadata for each incomplete community node
engine = LabelPropagationEngine(
connector=connector,
llm_model_id=llm_model_id,
embedding_model_id=embedding_model_id,
)
logger.info(
f"[CommunityCluster] User {end_user_id} has {len(incomplete_ids)} communities with incomplete properties, starting backfill"
)
patch_ok = 0
patch_fail = 0
for cid in incomplete_ids:
try:
await engine._generate_community_metadata(cid, end_user_id)
patch_ok += 1
except Exception as patch_err:
patch_fail += 1
logger.error(f"[CommunityCluster] 社区 {cid} 元数据补全失败: {patch_err}")
logger.info(
f"[CommunityCluster] 用户 {end_user_id} 社区补全完成: 成功={patch_ok}, 失败={patch_fail}"
)
initialized += 1
continue
# Check whether ExtractedEntity nodes exist
@@ -2742,11 +2778,13 @@ def init_community_clustering_for_users(self, end_user_ids: List[str]) -> Dict[s
logger.debug(f"[CommunityCluster] 用户 {end_user_id} 无实体节点,跳过")
continue
# Each user uses their own llm_model_id
# Each user uses their own llm_model_id / embedding_model_id
llm_model_id = user_llm_map.get(end_user_id)
embedding_model_id = user_embedding_map.get(end_user_id)
engine = LabelPropagationEngine(
connector=connector,
llm_model_id=llm_model_id,
embedding_model_id=embedding_model_id,
)
logger.info(f"[CommunityCluster] 用户 {end_user_id}{len(entities)} 个实体开始全量聚类llm_model_id={llm_model_id}")

View File

@@ -1,4 +1,38 @@
{
"v0.2.8": {
"introduction": {
"codeName": "景玉",
"releaseDate": "2026-3-20",
"upgradePosition": "🐻 MemoryBear v0.2.8 社区版全面升级应用共享、多模态交互与平台基础设施,引入语音交互、感知记忆和云端存储,打造更强大的开放 AI 记忆平台",
"coreUpgrades": [
"1. 应用共享与发布<br>* 应用共享Agent、工作流、Agent 集群):全类型应用共享至其他空间<br>* 分享应用默认开启记忆功能:发布分享后记忆默认开启,关闭时提醒<br>* 工作流记忆分享规则:按记忆配置自动控制分享页记忆开关<br>* 分享会话联网搜索修复:恢复分享应用的联网搜索能力",
"2. 多模态与交互 💬<br>* 语音输入:模型接口和应用支持语音输入<br>* 语音回复:应用支持语音回复模态<br>* 多模态感知记忆:记忆系统支持视觉、音频、图片和文件的感知记忆<br>* 对话框文件展示:试运行和体验分享中正确展示上传文件",
"3. 平台与基础设施 ⚙️<br>* i18n 国际化:全面多语言多地区支持<br>* 云端文件存储OSS + S3支持阿里云 OSS 和 S3 云端上传<br>* Flower 容器监控Celery 异步任务监控与管理",
"4. EndUser 身份迁移 🔐<br>* EndUser 从 app_id 迁移至 workspace_id身份从应用级迁移至工作空间级",
"5. 情景记忆 🧠<br>* 情景记忆聚类算法:基于社区图谱的聚类算法,支持老用户图谱生成",
"6. 稳健性与缺陷修复 🔧<br>* MCP 服务删除后工具 404修复删除 MCP 服务后接口报错<br>* 应用导出配置不一致:导出已保存配置而非画布状态<br>* 工作流节点 ID 重复:修复复制节点后 ID 冲突<br>* 条件分支连线错误:修复保存刷新后连线错乱<br>* 回复节点内容丢失:修复点击画布后内容消失<br>* 连接桩规则优化:禁止非法连接方向<br>* 知识库状态列宽度:锁定或自适应宽度<br>* 等待中文档预览:支持未完成解析文档预览<br>* 知识库关联修复:统一修复关联问题<br>* 多模态对话连续性:修复多模态内容后无法继续对话<br>* 时区统一:环境变量统一控制存储和任务时区<br>* 遗忘强度精度:修复小数显示过长",
"<br>",
"v0.2.8 社区版在应用共享和多模态交互方面实现重大升级,感知记忆扩展了平台的认知维度。后续将深化多智能体协作、情景记忆聚类,并持续优化平台稳定性与开放生态。",
"MemoryBear —— 让 AI 拥有记忆 🐻✨"
]
},
"introduction_en": {
"codeName": "JingYu",
"releaseDate": "2026-3-20",
"upgradePosition": "🐻 MemoryBear v0.2.8 Community delivers multimodal interaction, perceptual memory, cloud storage, and workspace-level identity for a more capable open AI memory platform",
"coreUpgrades": [
"1. Application Sharing & Publishing<br>* Application Sharing (Agent, Workflow, Agent Cluster): Full sharing across all app types<br>* Memory Enabled by Default: Memory auto-enabled on shared apps with disable reminder<br>* Workflow Memory Sharing Rules: Auto-controlled based on memory configuration<br>* Shared Session Web Search Fix: Restored web search for shared apps",
"2. Multimodal & Interaction 💬<br>* Voice Input: Model interfaces and apps support voice input<br>* Voice Reply: Apps support voice reply modality<br>* Multimodal Perceptual Memory: Memory system supports visual, audio, image, and file perception<br>* File Display in Chat: Uploaded files display correctly in dry-run and sharing",
"3. Platform & Infrastructure ⚙️<br>* i18n Internationalization: Full multi-language multi-region support<br>* Cloud File Storage (OSS + S3): Alibaba Cloud OSS and S3 cloud uploads<br>* Flower Container Monitoring: Celery async task monitoring and management",
"4. EndUser Identity Migration 🔐<br>* EndUser Migration from app_id to workspace_id: Identity migrated to workspace level",
"5. Episodic Memory 🧠<br>* Episodic Memory Clustering: Community-graph-based clustering with legacy user support",
"6. Robustness & Bug Fixes 🔧<br>* MCP Service Deletion 404: Fixed tool endpoint error after MCP removal<br>* App Export Config Mismatch: Exports saved config instead of canvas state<br>* Workflow Duplicate Node ID: Fixed ID conflict on node duplication<br>* Conditional Branch Wiring: Fixed wiring reset after save/refresh<br>* Reply Node Content Loss: Fixed content disappearing on canvas click<br>* Port Connection Rules: Prohibited invalid connection directions<br>* Knowledge Base Status Width: Locked or adaptive column width<br>* Pending Document Preview: Preview support for unparsed documents<br>* Knowledge Base Association Fixes: Consolidated association fixes<br>* Multimodal Conversation Continuity: Fixed single-round limit after multimodal input<br>* Timezone Unification: Env-var controlled unified timezone<br>* Forgetting Strength Precision: Fixed excessive decimal display",
"<br>",
"v0.2.8 Community delivers major upgrades in application sharing and multimodal interaction, with perceptual memory expanding the platform's cognitive dimensions. Multi-agent collaboration, episodic clustering, and continued platform stability improvements are ahead.",
"MemoryBear — Give AI Memory 🐻✨"
]
}
},
"v0.2.7": {
"introduction": {
"codeName": "武陵",

View File

@@ -303,7 +303,7 @@ async def test_get_node_output_not_exist_with_default():
"""测试获取不存在的节点输出(使用默认值)"""
pool = VariablePool()
result = pool.get_node_output("nonexistent_node", defalut=None, strict=False)
result = pool.get_node_output("nonexistent_node", default=None, strict=False)
assert result is None

View File

@@ -52,6 +52,10 @@ export const getKnowledgeBaseTypeList = async (): Promise<string[]> => {
// Return an empty array if the response is not an array
return [];
};
// Get the file URL
export const getFileUrl = (fileId: string) => {
return `${apiPrefix}/files/${fileId}`;
};
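A quick usage sketch for the new helper; the import path is the one DocumentDetails uses later in this diff, and the file id is a placeholder.

```typescript
import { getFileUrl } from '@/api/knowledgeBase';

const href = getFileUrl('3f2b1c'); // placeholder id -> `${apiPrefix}/files/3f2b1c`
```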
// Knowledge base document parser types
export const getKnowledgeBaseDocumentParseTypeList = async () => {
const response = await request.get(`${apiPrefix}/knowledges/parsertype`);

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-03 14:00:06
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-13 10:48:41
* @Last Modified time: 2026-03-19 18:35:10
*/
import { request } from '@/utils/request'
import type { AxiosRequestConfig } from 'axios'
@@ -218,8 +218,8 @@ export const getExplicitMemory = (end_user_id: string) => {
export const getExplicitMemoryDetails = (data: { end_user_id: string, memory_id: string; }) => {
return request.post(`/memory/explicit-memory/details`, data)
}
export const getConversations = (end_user_id: string) => {
return request.get(`/memory/work/${end_user_id}/conversations`)
export const getConversations = (end_user_id: string, page = 1, pagesize = 20) => {
return request.get(`/memory/work/${end_user_id}/conversations`, { page, pagesize })
}
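A hedged sketch of the paginated call; the response shape is the one the WorkingDetail view further down in this diff expects.

```typescript
const res = await getConversations(endUserId, 1, 20);
const { items, page } = res as { items: Conversation[]; page: { hasnext: boolean } };
if (page.hasnext) {
  // fetch page 2 on scroll, as the infinite-scroll list below does
}
```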
export const getConversationMessages = (end_user_id: string, conversation_id: string) => {
return request.get(`/memory/work/${end_user_id}/messages`, { conversation_id })

View File

@@ -143,15 +143,20 @@ const ChatContent: FC<ChatContentProps> = ({
}
return (
<div key={file.url || file.uid} className="rb:relative rb:rounded-lg rb:bg-[#F0F3F8] rb:p-1! rb:cursor-pointer" onClick={() => handleDownload(file)}>
{(file.type.includes('doc') || file.type.includes('docx') || file.type.includes('word') || file.type.includes('wordprocessingml.document')) && <div
className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/word.svg')]"
></div>}
{(file.type.includes('pdf')) && <div
className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf.svg')]"
></div>}
{(file.type.includes('excel') || file.type.includes('spreadsheetml.sheet') || file.type.includes('csv')) && <div
className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/excel.svg')]"
></div>}
{(file.type.includes('excel') || file.type.includes('spreadsheetml.sheet') || file.type.includes('csv'))
? <div
className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/excel.svg')]"
></div>
:(file.type.includes('pdf'))
? <div
className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/pdf.svg')]"
></div>
: (file.type.includes('doc') || file.type.includes('docx') || file.type.includes('word') || file.type.includes('wordprocessingml.document'))
? <div
className="rb:size-10 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/conversation/word.svg')]"
></div>
: null
}
</div>
)
})}
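Why this refactor switches to an exclusive ternary chain: the old independent `includes` checks could match more than one icon, because the xlsx MIME type itself contains the substring 'doc'. A minimal demonstration:

```typescript
const xlsxMime = 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet';
xlsxMime.includes('doc');                 // true -> old code also rendered the Word icon
xlsxMime.includes('spreadsheetml.sheet'); // true -> checked first in the new chain
```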

View File

@@ -49,6 +49,7 @@ interface FormValues {
memory?: boolean;
}
const max_file_count = 1;
const ChatToolbar = forwardRef<ChatToolbarRef, ChatToolbarProps>(({
features,
leftExtra,
@@ -86,10 +87,16 @@ const ChatToolbar = forwardRef<ChatToolbarRef, ChatToolbarProps>(({
// Append newly uploaded file to the file list when upload is complete
const fileChange = (file?: any) => {
if (file?.status !== 'done') return
const files = [...(queryValues?.files || []), file]
form.setFieldValue('files', files)
onFilesChange?.(files)
const lastFiles = form.getFieldValue('files') || [];
const index = lastFiles.findIndex((item: any) => item.uid === file.uid)
if (index > -1) {
lastFiles[index] = file
} else {
lastFiles.push(file)
}
form.setFieldValue('files', [...lastFiles])
onFilesChange?.([...lastFiles])
}
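The new fileChange body is effectively an upsert keyed by uid: a re-fired event for the same file (e.g. an 'uploading' placeholder later resolving to 'done') replaces the earlier entry instead of appending a duplicate. A standalone restatement, with the antd type assumed:

```typescript
import type { UploadFile } from 'antd';

const upsertByUid = (list: UploadFile[], file: UploadFile): UploadFile[] => {
  const index = list.findIndex(item => item.uid === file.uid);
  if (index > -1) list[index] = file;
  else list.push(file);
  return [...list]; // fresh array reference so React re-renders
};
```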
// Append recorded audio file to the file list and notify parent
@@ -129,8 +136,8 @@ const ChatToolbar = forwardRef<ChatToolbarRef, ChatToolbarProps>(({
key: 'url',
label: t('memoryConversation.addRemoteFile'),
onClick: () => {
if ((queryValues?.files?.length || 0) >= file_upload.max_file_count) {
messageApi.warning(t('common.fileNumTip', { num: file_upload.max_file_count }))
if ((queryValues?.files?.length || 0) >= max_file_count) {
messageApi.warning(t('common.fileNumTip', { num: max_file_count }))
return
}
uploadFileListModalRef.current?.handleOpen()
@@ -146,7 +153,7 @@ const ChatToolbar = forwardRef<ChatToolbarRef, ChatToolbarProps>(({
onChange={fileChange}
requestConfig={uploadRequestConfig}
featureConfig={file_upload}
disabled={(queryValues?.files?.length || 0) >= file_upload.max_file_count}
disabled={(queryValues?.files?.length || 0) >= max_file_count}
/>
)
})
@@ -184,7 +191,7 @@ const ChatToolbar = forwardRef<ChatToolbarRef, ChatToolbarProps>(({
{rightExtra}
{file_upload?.audio_enabled && file_upload?.allowed_transfer_methods?.includes('local_file') &&
<AudioRecorder
disabled={(queryValues?.files?.length || 0) >= file_upload.max_file_count}
disabled={(queryValues?.files?.length || 0) >= max_file_count}
action={uploadAction}
requestConfig={uploadRequestConfig}
onRecordingComplete={handleRecordingComplete}

View File

@@ -4,7 +4,7 @@
* @Author: yujiangping
* @Date: 2026-03-16 19:01:12
* @LastEditors: yujiangping
* @LastEditTime: 2026-03-18 18:35:53
* @LastEditTime: 2026-03-20 12:12:20
*/
import { useState, useEffect, useRef, useCallback, type FC } from 'react';
import { Spin, Alert, Button, Table, InputNumber, Image } from 'antd';
@@ -309,23 +309,64 @@ const DocumentPreview: FC<DocumentPreviewProps> = ({
}
};
const [csvTruncated, setCsvTruncated] = useState(false);
const isCsvFile = () => getFileExtension() === '.csv';
// CSV preview size limit (1 MB)
const CSV_PREVIEW_SIZE = 1 * 1024 * 1024;
// Maximum number of rows to preview
const MAX_PREVIEW_ROWS = 500;
const fetchFileBufferWithLimit = async (url: string, maxBytes?: number): Promise<ArrayBuffer> => {
const requestUrl = getRequestUrl(url);
const headers: Record<string, string> = {
'Authorization': `Bearer ${cookieUtils.get('authToken') || ''}`,
};
if (maxBytes) {
headers['Range'] = `bytes=0-${maxBytes - 1}`;
}
const response = await fetch(requestUrl, {
credentials: 'include',
headers,
});
if (!response.ok && response.status !== 206) {
throw new Error(`HTTP ${response.status}: ${response.statusText}`);
}
return response.arrayBuffer();
};
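A note on the fallback above: a server honouring Range replies 206 with at most maxBytes bytes, while one ignoring it replies 200 with the full body; both are accepted, and `byteLength >= CSV_PREVIEW_SIZE` is only a heuristic for truncation. A usage sketch with an assumed file path:

```typescript
// At most the first 1 MB on a Range-capable server; the full body otherwise.
const buffer = await fetchFileBufferWithLimit('/api/files/3f2b1c', 1024 * 1024);
```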
const loadExcelFile = async () => {
setLoading(true);
setError(false);
setErrorMessage('');
setCsvTruncated(false);
try {
const arrayBuffer = await fetchFileBuffer(fileUrl);
// CSV files need encoding handling (may be GBK/GB2312)
// CSV files need encoding handling (may be GBK/GB2312), and only the first 1 MB of large files is read
if (isCsvFile()) {
let arrayBuffer: ArrayBuffer;
let truncated = false;
try {
// Try a Range request first, fetching only the first 1 MB
arrayBuffer = await fetchFileBufferWithLimit(fileUrl, CSV_PREVIEW_SIZE);
// If the returned data is exactly the limit size, it was probably truncated
if (arrayBuffer.byteLength >= CSV_PREVIEW_SIZE) {
truncated = true;
}
} catch {
// If Range requests are not supported, fetch everything and truncate
const fullBuffer = await fetchFileBuffer(fileUrl);
if (fullBuffer.byteLength > CSV_PREVIEW_SIZE) {
arrayBuffer = fullBuffer.slice(0, CSV_PREVIEW_SIZE);
truncated = true;
} else {
arrayBuffer = fullBuffer;
}
}
let csvText: string;
// Try UTF-8 decoding first
const utf8Text = new TextDecoder('utf-8').decode(arrayBuffer);
// Detect mojibake (GBK bytes mis-decoded as UTF-8 commonly yield the replacement character)
if (utf8Text.includes('\uFFFD') || /[\x80-\xff]/.test(utf8Text.slice(0, 200))) {
// Try GBK decoding
try {
csvText = new TextDecoder('gbk').decode(arrayBuffer);
} catch {
@@ -334,19 +375,35 @@ const DocumentPreview: FC<DocumentPreviewProps> = ({
} else {
csvText = utf8Text;
}
// If truncated, drop the trailing incomplete line
if (truncated) {
const lastNewline = csvText.lastIndexOf('\n');
if (lastNewline > 0) {
csvText = csvText.substring(0, lastNewline);
}
}
const workbook = XLSX.read(csvText, { type: 'string' });
const sheets = workbook.SheetNames.map(sheetName => {
const worksheet = workbook.Sheets[sheetName];
const data = XLSX.utils.sheet_to_json(worksheet, { header: 1 }) as any[][];
let data = XLSX.utils.sheet_to_json(worksheet, { header: 1 }) as any[][];
// Cap the number of rows
if (data.length > MAX_PREVIEW_ROWS + 1) {
data = data.slice(0, MAX_PREVIEW_ROWS + 1); // +1 keeps the header row
truncated = true;
}
return { sheetName, data };
});
setCsvTruncated(truncated);
setExcelData(sheets);
setLoading(false);
return;
}
const arrayBuffer = await fetchFileBuffer(fileUrl);
const workbook = XLSX.read(arrayBuffer, { type: 'array' });
const sheets = workbook.SheetNames.map(sheetName => {
const sheets = workbook.SheetNames.map((sheetName: string) => {
const worksheet = workbook.Sheets[sheetName];
const data = XLSX.utils.sheet_to_json(worksheet, { header: 1 }) as any[][];
return { sheetName, data };
@@ -522,9 +579,14 @@ const DocumentPreview: FC<DocumentPreviewProps> = ({
)
)}
{/* Excel preview */}
{/* Excel/CSV preview */}
{isExcelFile() && !error && !loading && (
<div className="rb:w-full rb:flex-1 rb:overflow-auto rb:bg-white rb:p-4 rb:rounded rb:border rb:border-gray-200">
{csvTruncated && (
<div className="rb:mb-3 rb:px-3 rb:py-2 rb:bg-yellow-50 rb:border rb:border-yellow-200 rb:rounded rb:text-sm rb:text-yellow-700">
{`Preview truncated to the first ${MAX_PREVIEW_ROWS} rows`}
</div>
)}
{excelData.map((sheet, index) => (
<div key={index} className="rb:mb-6">
<h3 className="rb:text-lg rb:font-semibold rb:mb-3">{sheet.sheetName}</h3>
@@ -541,6 +603,7 @@ const DocumentPreview: FC<DocumentPreviewProps> = ({
scroll={{ x: 'max-content' }}
size="small"
bordered
virtual
/>
)}
</div>

View File

@@ -469,6 +469,7 @@ export const en = {
download: 'Download',
view: 'View',
updated_at: 'Updated At',
callbackUrlInvalid: 'Please enter a valid URL',
},
model: {
searchPlaceholder: 'search model…',

View File

@@ -1106,6 +1106,7 @@ export const zh = {
download: '下载',
view: '查看',
updated_at: '更新时间',
callbackUrlInvalid: '请输入有效的 URL',
},
model: {
searchPlaceholder: '搜索模型…',

View File

@@ -183,7 +183,7 @@ const TestChat: FC<TestChatProps> = ({
const handleSend = () => {
if (loading || !application || !message || !message?.trim()) return
const files = toolbarRef.current?.getFiles() || []
const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status))
const variables = toolbarRef.current?.getVariables() || []
const { isCanSend, params } = buildVariableParams(variables)
if (!isCanSend) return
@@ -235,7 +235,7 @@ const TestChat: FC<TestChatProps> = ({
const handleWorkflowSend = () => {
if (loading || !application || !message || !message?.trim()) return
const files = toolbarRef.current?.getFiles() || []
const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status))
const variables = toolbarRef.current?.getVariables() || []
const { isCanSend, params } = buildVariableParams(variables)
if (!isCanSend) return
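This status filter recurs in every send path touched by this diff (TestChat, Chat, the shared Conversation view, and the flow Chat); restated once, with the antd type assumed:

```typescript
import type { UploadFile } from 'antd';

// Only files that finished uploading are attached to the outgoing message.
const sendableFiles = (files: UploadFile[]): UploadFile[] =>
  files.filter(item => !['uploading', 'error'].includes(item.status as string));
```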

View File

@@ -189,7 +189,7 @@ const Chat: FC<ChatProps> = ({
.then(() => {
const message = msg
if (!message?.trim()) return
const files = toolbarRef.current?.getFiles() || []
const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status))
// Validate required variables before sending
let isCanSend = true
const params: Record<string, any> = {}
@@ -350,7 +350,7 @@ const Chat: FC<ChatProps> = ({
.then(() => {
const message = msg
if (!message || message.trim() === '') return
const files = toolbarRef.current?.getFiles() || []
const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status))
addUserMessage(message, files)
setMessage(undefined)
toolbarRef.current?.setFiles([])

View File

@@ -24,7 +24,7 @@ interface FeaturesConfigModalProps {
refresh: (value: FeaturesConfigForm) => void;
source?: Application['type'];
}
const max_file_count = 1;
/**
* Modal for copying applications
*/
@@ -133,7 +133,7 @@ const FeaturesConfigModal = forwardRef<FeaturesConfigModalRef, FeaturesConfigMod
</div>
<div>
<div className="rb:text-[12px] rb:text-[#5B6167] rb:py-1">{t('application.maxCount')}</div>
{fu.max_file_count} {t('application.unix')}
{max_file_count} {t('application.unix')}
</div>
</Flex>
<Button block onClick={handleOpenSettings}>{t('application.setting')}</Button>

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-03-05
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-19 15:18:20
* @Last Modified time: 2026-03-19 20:19:14
*/
import { forwardRef, useImperativeHandle, useState } from 'react';
import { Form, InputNumber, Flex, Switch, Row, Col, Radio } from 'antd';
@@ -82,28 +82,27 @@ const defaultValues: FileUpload = {
"mp3",
"wav",
"m4a",
"ogg",
"flac"
],
document_enabled: false,
document_max_size_mb: 100,
document_allowed_extensions: [
"pdf",
"docx",
"doc",
"xlsx",
"xls",
"txt",
"csv",
"json"
"json",
"md",
],
video_enabled: false,
video_max_size_mb: 100,
video_allowed_extensions: [
"mp4",
"mov",
"avi",
"webm"
],
max_file_count: 5,
max_file_count: 1,
allowed_transfer_methods: 'both'
}
@@ -168,8 +167,8 @@ const FileUploadSettingModal = forwardRef<FileUploadSettingModalRef, FileUploadS
</Radio.Group>
</Form.Item>
<div className="rb:text-[12px] rb:text-[#5B6167] rb:mb-1">{t('application.maxCount')}</div>
<Form.Item label={t('application.maxCount')} name="max_file_count">
{/* <div className="rb:text-[12px] rb:text-[#5B6167] rb:mb-1">{t('application.maxCount')}</div> */}
<Form.Item label={t('application.maxCount')} name="max_file_count" hidden>
<InputNumber min={1} max={20} precision={0} className="rb:w-full!" placeholder={t('common.pleaseEnter')} />
</Form.Item>

View File

@@ -23,7 +23,7 @@
import { useState, useEffect, forwardRef, useImperativeHandle, useMemo } from 'react';
import { Upload, Progress, App, Flex } from 'antd';
import type { UploadProps, UploadFile } from 'antd';
import type { UploadProps as RcUploadProps } from 'antd/es/upload/interface';
import type { UploadProps as RcUploadProps, RcFile, UploadFileStatus } from 'antd/es/upload/interface';
import { useTranslation } from 'react-i18next';
import { request } from '@/utils/request'
@@ -221,17 +221,29 @@ const UploadFiles = forwardRef<UploadFilesRef, UploadFilesProps>(({
*/
const handleCustomRequest: RcUploadProps['customRequest'] = async (options) => {
const { file, onSuccess, onError } = options;
try {
const formData = new FormData();
formData.append('file', file);
const response = await request.uploadFile(action, formData, requestConfig);
onSuccess?.({data: response});
} catch (error) {
onError?.(error as Error);
if (typeof file === 'string') return;
const rcFile = file as RcFile;
const formData = new FormData();
formData.append('file', rcFile);
const fileVo: UploadFile = {
uid: rcFile.uid,
name: rcFile.name,
status: 'uploading' as UploadFileStatus,
percent: 0,
type: rcFile.type,
originFileObj: rcFile,
thumbUrl: URL.createObjectURL(rcFile)
}
onChange?.(fileVo)
request.uploadFile(action, formData, requestConfig)
.then(res => {
onSuccess?.({ data: res });
})
.catch((error) => {
onError?.(error as Error);
fileVo.status = 'error'
onChange?.(fileVo)
})
};
/**

View File

@@ -2,7 +2,7 @@
* @Author: ZhaoYing
* @Date: 2026-02-06 21:09:47
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-18 21:10:01
* @Last Modified time: 2026-03-19 20:32:32
*/
/**
* Upload File List Modal Component
@@ -19,7 +19,10 @@
* @component
*/
import { forwardRef, useImperativeHandle, useState, useMemo } from 'react';
import { Form, Input, Select, Button, Flex } from 'antd';
import { Form, Input, Select,
// Button,
Flex
} from 'antd';
import { useTranslation } from 'react-i18next';
import type { UploadFileListModalRef } from '../types'
@@ -105,9 +108,11 @@ const UploadFileListModal = forwardRef<UploadFileListModalRef, UploadFileListMod
onOk={handleSave}
confirmLoading={loading}
>
<Form form={form} layout="vertical">
<Form form={form} layout="vertical" initialValues={{ files: [{ type: undefined, url: undefined }] }}>
<Form.List name="files">
{(fields, { add, remove }) => (
{(fields,
// { add, remove }
) => (
<>
{/* Render each file entry with type selector and URL input */}
{fields.map(({ key, name, ...restField }) => (
@@ -116,6 +121,9 @@ const UploadFileListModal = forwardRef<UploadFileListModalRef, UploadFileListMod
{...restField}
name={[name, 'type']}
className="rb:mb-0!"
rules={[
{ required: true, message: t('common.pleaseSelect') }
]}
>
<Select
placeholder={t('memoryConversation.fileType')}
@@ -126,22 +134,25 @@ const UploadFileListModal = forwardRef<UploadFileListModalRef, UploadFileListMod
<FormItem
{...restField}
name={[name, 'url']}
rules={[{ required: true, message: t('common.pleaseEnter') }]}
rules={[
{ required: true, message: t('common.pleaseEnter') },
{ type: 'url', message: t('common.callbackUrlInvalid') },
]}
className="rb:mb-0! rb:flex-1!"
>
<Input placeholder={t('memoryConversation.fileUrl')} />
</FormItem>
<div
{/* <div
className="rb:w-5 rb:h-5 rb:cursor-pointer rb:bg-cover rb:bg-[url('@/assets/images/delete.svg')] rb:hover:bg-[url('@/assets/images/delete_hover.svg')]"
onClick={() => remove(name)}
></div>
></div> */}
</Flex>
))}
<Form.Item noStyle>
{/* <Form.Item noStyle>
<Button type="dashed" onClick={() => add()} block>
+ {t('common.add')}
</Button>
</Form.Item>
</Form.Item> */}
</>
)}
</Form.List>

View File

@@ -194,7 +194,7 @@ const Conversation: FC = () => {
/** Send message and handle streaming response */
const handleSend = () => {
if (!token || !shareToken) return
const files = toolbarRef.current?.getFiles() || []
const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status))
const variables = toolbarRef.current?.getVariables() || []
let isCanSend = true
const params: Record<string, any> = {}

View File

@@ -11,7 +11,7 @@ import { useNavigate, useParams, useLocation } from 'react-router-dom';
import { useTranslation } from 'react-i18next';
import { useBreadcrumbManager, type BreadcrumbPath } from '@/hooks/useBreadcrumbManager';
import { Button, Spin, message, Switch } from 'antd';
import { getDocumentDetail, getDocumentChunkList, downloadFile, updateDocument, updateDocumentChunk, createDocumentChunk } from '@/api/knowledgeBase';
import { getDocumentDetail, getDocumentChunkList, downloadFile, updateDocument, updateDocumentChunk, createDocumentChunk, getFileUrl } from '@/api/knowledgeBase';
import type { KnowledgeBaseDocumentData, RecallTestData } from '@/views/KnowledgeBase/types';
import { formatDateTime } from '@/utils/format';
import InfoPanel, { type InfoItem } from '../components/InfoPanel';
@@ -138,7 +138,7 @@ const DocumentDetails: FC = () => {
const response = await getDocumentDetail(documentId);
setDocument(response);
setInfoItems(formatDocumentInfo(response));
const url = `${imagePath}/api/files/${response.file_id}`
const url = `${window.location.origin}/api/files/${response.file_id}`;
setFileUrl(url);
setParserMode(response?.parser_config?.auto_questions || 0)
// ChunkList will be called automatically in useEffect based on document.progress

View File

@@ -191,24 +191,28 @@ const RelationshipNetwork: FC = () => {
})}>
{(selectedNode as RawCommunityNode).properties.community_id
? <div>
<div className="rb:font-medium rb:text-[#212332] rb:text-[16px] rb:leading-5.5 rb:pl-1">
{(selectedNode as RawCommunityNode).properties.name}
</div>
<div className="rb:mt-3 rb:font-medium rb:leading-5 rb:pl-1">{t('userMemory.summary')}</div>
<div className="rb:bg-[#F6F6F6] rb:rounded-xl rb:px-3 rb:py-2.5 rb:mt-2">
{(selectedNode as RawCommunityNode).properties.summary}
</div>
<Flex align="center" justify="space-between" className="rb:mt-5!">
<span className="rb:text-[#5B6167] rb:font-regular rb:pl-1">{t('userMemory.member_count')}</span>
<span className="rb:font-medium">{(selectedNode as RawCommunityNode).properties.member_count}{t('userMemory.member_count_desc')}</span>
</Flex>
<div className="rb:font-medium rb:text-[#212332] rb:text-[16px] rb:leading-5.5 rb:pl-1">
{(selectedNode as RawCommunityNode).properties.name || selectedNode.id}
</div>
{(selectedNode as RawCommunityNode).properties.summary && <>
<div className="rb:mt-3 rb:font-medium rb:leading-5 rb:pl-1">{t('userMemory.summary')}</div>
<div className="rb:bg-[#F6F6F6] rb:rounded-xl rb:px-3 rb:py-2.5 rb:mt-2">
{(selectedNode as RawCommunityNode).properties.summary}
</div>
</>}
<Flex align="center" justify="space-between" className="rb:mt-5!">
<span className="rb:text-[#5B6167] rb:font-regular rb:pl-1">{t('userMemory.member_count')}</span>
<span className="rb:font-medium">{(selectedNode as RawCommunityNode).properties.member_count}{t('userMemory.member_count_desc')}</span>
</Flex>
<Divider className='rb:my-2.5!' />
<div className="rb:font-medium rb:leading-5 rb:pl-1">{t('userMemory.core_entities')}</div>
<ul className="rb:list-disc rb:pl-4 rb:text-[#5B6167] rb:mt-2">
{(selectedNode as RawCommunityNode).properties.core_entities.map((entity, index) => <li key={index}>{entity}</li>)}
</ul>
</div>
{(selectedNode as RawCommunityNode).properties.core_entities && <>
<Divider className='rb:my-2.5!' />
<div className="rb:font-medium rb:leading-5 rb:pl-1">{t('userMemory.core_entities')}</div>
<ul className="rb:list-disc rb:pl-4 rb:text-[#5B6167] rb:mt-2">
{(selectedNode as RawCommunityNode).properties.core_entities?.map((entity, index) => <li key={index}>{entity}</li>)}
</ul>
</>}
</div>
: <>
{(selectedNode as Node).name &&
<div className="rb:font-medium rb:text-[16px] rb:text-[#212332] rb:leading-5.5 rb:mb-3">

View File

@@ -4,12 +4,14 @@
* @Last Modified by: ZhaoYing
* @Last Modified time: 2026-03-16 15:10:17
*/
import { type FC, useEffect, useState, useMemo } from 'react'
import { type FC, useEffect, useState, useMemo, useRef } from 'react'
import clsx from 'clsx'
import { useTranslation } from 'react-i18next'
import { useParams } from 'react-router-dom'
import { Row, Col, Skeleton, Button, Divider, Tooltip, Flex } from 'antd'
import InfiniteScroll from 'react-infinite-scroll-component'
import RbCard from '@/components/RbCard/Card'
import {
getConversations,
@@ -61,6 +63,8 @@ const WorkingDetail: FC = () => {
const { id } = useParams()
const [loading, setLoading] = useState<boolean>(false)
const [data, setData] = useState<Conversation[]>([])
const [hasMore, setHasMore] = useState<boolean>(true)
const pageRef = useRef<number>(1)
const [messagesLoading, setMessagesLoading] = useState<boolean>(false)
const [messages, setMessages] = useState<ChatItem[]>([])
const [detailLoading, setDetailLoading] = useState<boolean>(false)
@@ -80,17 +84,30 @@ const WorkingDetail: FC = () => {
setSelected(null)
setDetail(null)
setData([])
getConversations(id).then((res) => {
const response = res as Conversation[]
setData(response)
setSelected(response[0] || null)
setHasMore(true)
pageRef.current = 1
getConversations(id, 1).then((res) => {
const response = res as { items: Conversation[], page: { hasnext: boolean } }
setData(response.items)
setSelected(response.items[0] || null)
setHasMore(response.page.hasnext)
})
.finally(() => {
setLoading(false)
})
}
/* Load messages and AI insight whenever the selected conversation changes. */
const loadMore = () => {
if (!id) return
const nextPage = pageRef.current + 1
getConversations(id, nextPage).then((res) => {
const response = res as {items: Conversation[], page: { hasnext: boolean }}
setData(prev => [...prev, ...response.items])
pageRef.current = nextPage
setHasMore(response.page.hasnext)
})
}
useEffect(() => {
if (!id || !selected || !selected.id) return
getDetail(selected.id)
@@ -138,16 +155,16 @@ const WorkingDetail: FC = () => {
: data.length === 0
? <Empty />
:(
<Row gutter={16} className="rb:h-full">
<Col flex='360px' className="rb:h-full">
<RbCard
title={t('workingDetail.conversation')}
headerType="borderless"
headerClassName="rb:min-h-[54px]! rb:font-[MiSans-Bold] rb:font-bold"
bodyClassName='rb:p-3! rb:pt-0! rb:h-[calc(100%-54px)]'
className="rb:h-full!"
>
<Flex gap={8} vertical>
<Row gutter={16}>
<Col span={5}>
<div id="conversation-list" className="rb:h-[calc(100vh-76px)]! rb:border-r rb:border-[#EAECEE] rb:py-3 rb:px-4 rb:overflow-y-auto">
<InfiniteScroll
dataLength={data.length}
next={loadMore}
hasMore={hasMore}
loader={null}
scrollableTarget="conversation-list"
>
{data.map(item => (
<Flex
key={item.id}
@@ -166,8 +183,8 @@ const WorkingDetail: FC = () => {
</Tooltip>
</Flex>
))}
</Flex>
</RbCard>
</InfiniteScroll>
</div>
</Col>
{selected && <>
<Col flex="auto" className="rb:h-full">

View File

@@ -151,7 +151,7 @@ const Chat = forwardRef<ChatRef, { appId: string; graphRef: GraphRef; data: Work
setLoading(true)
const message = msg
const files = toolbarRef.current?.getFiles() || []
const files = (toolbarRef.current?.getFiles() || []).filter(item => !['uploading', 'error'].includes(item.status))
setChatList(prev => [...prev, {
role: 'user',
content: message,

View File

@@ -18,8 +18,8 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
const isUserInputRef = useRef(false);
useEffect(() => {
// Listen for editor changes and flag whether the update came from user input
const removeListener = editor.registerUpdateListener(({ editorState }) => {
const removeListener = editor.registerUpdateListener(({ editorState, tags }) => {
if (tags.has('programmatic')) return;
editorState.read(() => {
const root = $getRoot();
const textContent = root.getTextContent();
@@ -107,7 +107,7 @@ const InitialValuePlugin: React.FC<InitialValuePluginProps> = ({ value, options
});
root.append(paragraph);
}
}, { discrete: true });
}, { discrete: true, tag: 'programmatic' });
});
}
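The Lexical tag round-trip used above, restated as a self-contained sketch (editor setup omitted): programmatic writes pass a tag through editor.update, and the update listener checks the tags set to skip them.

```typescript
// Programmatic write: tagged so listeners can recognise it as ours.
editor.update(() => {
  // ...rebuild the root from `value`...
}, { discrete: true, tag: 'programmatic' });

// Listener: ignore tagged writes; everything else came from user input.
editor.registerUpdateListener(({ editorState, tags }) => {
  if (tags.has('programmatic')) return;
  editorState.read(() => {
    /* react to user edits */
  });
});
```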