Merge branch 'refs/heads/develop' into fix/memory_bug_fix

# Conflicts:
#	api/app/services/memory_agent_service.py
This commit is contained in:
lixinyue
2026-01-23 14:57:25 +08:00
38 changed files with 685 additions and 353 deletions

View File

@@ -261,9 +261,7 @@ async def read_server(
""" """
config_id = user_input.config_id config_id = user_input.config_id
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
api_logger.info(f"Read service: workspace_id={workspace_id}, config_id={config_id}")
# 获取 storage_type如果为 None 则使用默认值
storage_type = workspace_service.get_workspace_storage_type( storage_type = workspace_service.get_workspace_storage_type(
db=db, db=db,
workspace_id=workspace_id, workspace_id=workspace_id,
@@ -300,12 +298,15 @@ async def read_server(
# 调用 memory_agent_service 的方法生成最终答案 # 调用 memory_agent_service 的方法生成最终答案
result['answer'] = await memory_agent_service.generate_summary_from_retrieve( result['answer'] = await memory_agent_service.generate_summary_from_retrieve(
group_id=user_input.group_id,
retrieve_info=retrieve_info, retrieve_info=retrieve_info,
history=history, history=history,
query=query, query=query,
config_id=config_id, config_id=config_id,
db=db db=db
) )
if "信息不足,无法回答" in result['answer']:
result['answer']=retrieve_info
return success(data=result, msg="回复对话消息成功") return success(data=result, msg="回复对话消息成功")
except BaseException as e: except BaseException as e:
# Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup # Handle ExceptionGroup from TaskGroup (Python 3.11+) or BaseExceptionGroup

View File

@@ -49,63 +49,135 @@ async def get_workspace_end_users(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
): ):
""" """
获取工作空间的宿主列表 获取工作空间的宿主列表(高性能优化版本 v2
返回格式与原 memory_list 接口中的 end_users 字段相同, 优化策略:
并包含每个用户的记忆配置信息memory_config_id 和 memory_config_name 1. 批量查询 end_users一次查询而非循环
2. 并发查询所有用户的记忆数量Neo4j
3. RAG 模式使用批量查询(一次 SQL
4. 只返回必要字段减少数据传输
5. 添加短期缓存减少重复查询
6. 并发执行配置查询和记忆数量查询
返回格式:
{
"end_user": {"id": "uuid", "other_name": "名称"},
"memory_num": {"total": 数量},
"memory_config": {"memory_config_id": "id", "memory_config_name": "名称"}
}
""" """
import asyncio
import json
from app.aioRedis import aio_redis_get, aio_redis_set
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# 尝试从缓存获取30秒缓存
cache_key = f"end_users:workspace:{workspace_id}"
try:
cached_data = await aio_redis_get(cache_key)
if cached_data:
api_logger.info(f"从缓存获取宿主列表: workspace_id={workspace_id}")
return success(data=json.loads(cached_data), msg="宿主列表获取成功")
except Exception as e:
api_logger.warning(f"Redis 缓存读取失败: {str(e)}")
# 获取当前空间类型 # 获取当前空间类型
current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user) current_workspace_type = memory_dashboard_service.get_current_workspace_type(db, workspace_id, current_user)
api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表") api_logger.info(f"用户 {current_user.username} 请求获取工作空间 {workspace_id} 的宿主列表")
# 获取 end_users已优化为批量查询
end_users = memory_dashboard_service.get_workspace_end_users( end_users = memory_dashboard_service.get_workspace_end_users(
db=db, db=db,
workspace_id=workspace_id, workspace_id=workspace_id,
current_user=current_user current_user=current_user
) )
# 批量获取所有用户的记忆配置信息(优化:一次查询而非 N 次) if not end_users:
end_user_ids = [str(user.id) for user in end_users] api_logger.info("工作空间下没有宿主")
memory_configs_map = {} # 缓存空结果,避免重复查询
if end_user_ids:
try: try:
memory_configs_map = get_end_users_connected_configs_batch(end_user_ids, db) await aio_redis_set(cache_key, json.dumps([]), expire=30)
except Exception as e:
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
return success(data=[], msg="宿主列表获取成功")
end_user_ids = [str(user.id) for user in end_users]
# 并发执行两个独立的查询任务
async def get_memory_configs():
"""获取记忆配置(在线程池中执行同步查询)"""
try:
return await asyncio.to_thread(
get_end_users_connected_configs_batch,
end_user_ids, db
)
except Exception as e: except Exception as e:
api_logger.error(f"批量获取记忆配置失败: {str(e)}") api_logger.error(f"批量获取记忆配置失败: {str(e)}")
# 失败时使用空字典,不影响其他数据返回 return {}
async def get_memory_nums():
"""获取记忆数量"""
if current_workspace_type == "rag":
# RAG 模式:批量查询
try:
chunk_map = await asyncio.to_thread(
memory_dashboard_service.get_users_total_chunk_batch,
end_user_ids, db, current_user
)
return {uid: {"total": count} for uid, count in chunk_map.items()}
except Exception as e:
api_logger.error(f"批量获取 RAG chunk 数量失败: {str(e)}")
return {uid: {"total": 0} for uid in end_user_ids}
elif current_workspace_type == "neo4j":
# Neo4j 模式:并发查询(带并发限制)
# 使用信号量限制并发数,避免大量用户时压垮 Neo4j
MAX_CONCURRENT_QUERIES = 10
semaphore = asyncio.Semaphore(MAX_CONCURRENT_QUERIES)
async def get_neo4j_memory_num(end_user_id: str):
async with semaphore:
try:
return await memory_storage_service.search_all(end_user_id)
except Exception as e:
api_logger.error(f"获取用户 {end_user_id} Neo4j 记忆数量失败: {str(e)}")
return {"total": 0}
memory_nums_list = await asyncio.gather(*[get_neo4j_memory_num(uid) for uid in end_user_ids])
return {end_user_ids[i]: memory_nums_list[i] for i in range(len(end_user_ids))}
return {uid: {"total": 0} for uid in end_user_ids}
# 并发执行配置查询和记忆数量查询
memory_configs_map, memory_nums_map = await asyncio.gather(
get_memory_configs(),
get_memory_nums()
)
# 构建结果(优化:使用列表推导式)
result = [] result = []
for end_user in end_users: for end_user in end_users:
memory_num = {}
if current_workspace_type == "neo4j":
# EndUser 是 Pydantic 模型,直接访问属性而不是使用 .get()
memory_num = await memory_storage_service.search_all(str(end_user.id))
elif current_workspace_type == "rag":
memory_num = {
"total":memory_dashboard_service.get_current_user_total_chunk(str(end_user.id), db, current_user)
}
# 从批量查询结果中获取配置信息
user_id = str(end_user.id) user_id = str(end_user.id)
memory_config_info = memory_configs_map.get(user_id, { config_info = memory_configs_map.get(user_id, {})
"memory_config_id": None, result.append({
"memory_config_name": None 'end_user': {
}) 'id': user_id,
'other_name': end_user.other_name
# 只保留需要的字段,移除 error 字段(如果有) },
memory_config = { 'memory_num': memory_nums_map.get(user_id, {"total": 0}),
"memory_config_id": memory_config_info.get("memory_config_id"), 'memory_config': {
"memory_config_name": memory_config_info.get("memory_config_name") "memory_config_id": config_info.get("memory_config_id"),
} "memory_config_name": config_info.get("memory_config_name")
result.append(
{
'end_user': end_user,
'memory_num': memory_num,
'memory_config': memory_config
} }
) })
# 写入缓存30秒过期
try:
await aio_redis_set(cache_key, json.dumps(result), expire=30)
except Exception as e:
api_logger.warning(f"Redis 缓存写入失败: {str(e)}")
api_logger.info(f"成功获取 {len(end_users)} 个宿主记录") api_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
return success(data=result, msg="宿主列表获取成功") return success(data=result, msg="宿主列表获取成功")

View File

@@ -421,15 +421,95 @@ async def get_hot_memory_tags_api(
db: Session = Depends(get_db), db: Session = Depends(get_db),
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> dict: ) -> dict:
api_logger.info(f"Hot memory tags requested for current_user: {current_user.id}") """
获取热门记忆标签带Redis缓存
缓存策略:
- 缓存键workspace_id + limit
- 过期时间5分钟300秒
- 缓存命中:~50ms
- 缓存未命中:~600-800ms取决于LLM速度
"""
workspace_id = current_user.current_workspace_id
# 构建缓存键
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
api_logger.info(f"Hot memory tags requested for workspace: {workspace_id}, limit: {limit}")
try: try:
# 尝试从Redis缓存获取
from app.aioRedis import aio_redis_get, aio_redis_set
import json
cached_result = await aio_redis_get(cache_key)
if cached_result:
api_logger.info(f"Cache hit for key: {cache_key}")
try:
data = json.loads(cached_result)
return success(data=data, msg="查询成功(缓存)")
except json.JSONDecodeError:
api_logger.warning(f"Failed to parse cached data, will refresh")
# 缓存未命中,执行查询
api_logger.info(f"Cache miss for key: {cache_key}, executing query")
result = await analytics_hot_memory_tags(db, current_user, limit) result = await analytics_hot_memory_tags(db, current_user, limit)
# 写入缓存过期时间5分钟
# 注意result是列表需要转换为JSON字符串
try:
cache_data = json.dumps(result, ensure_ascii=False)
await aio_redis_set(cache_key, cache_data, expire=300)
api_logger.info(f"Cached result for key: {cache_key}")
except Exception as cache_error:
# 缓存写入失败不影响主流程
api_logger.warning(f"Failed to cache result: {str(cache_error)}")
return success(data=result, msg="查询成功") return success(data=result, msg="查询成功")
except Exception as e: except Exception as e:
api_logger.error(f"Hot memory tags failed: {str(e)}") api_logger.error(f"Hot memory tags failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e)) return fail(BizCode.INTERNAL_ERROR, "热门标签查询失败", str(e))
@router.delete("/analytics/hot_memory_tags/cache", response_model=ApiResponse)
async def clear_hot_memory_tags_cache(
current_user: User = Depends(get_current_user),
) -> dict:
"""
清除热门标签缓存
用于:
- 手动刷新数据
- 调试和测试
- 数据更新后立即生效
"""
workspace_id = current_user.current_workspace_id
api_logger.info(f"Clear hot memory tags cache requested for workspace: {workspace_id}")
try:
from app.aioRedis import aio_redis_delete
# 清除所有limit的缓存常见的limit值
cleared_count = 0
for limit in [5, 10, 15, 20, 30, 50]:
cache_key = f"hot_memory_tags:{workspace_id}:{limit}"
result = await aio_redis_delete(cache_key)
if result:
cleared_count += 1
api_logger.info(f"Cleared cache for key: {cache_key}")
return success(
data={"cleared_count": cleared_count},
msg=f"成功清除 {cleared_count} 个缓存"
)
except Exception as e:
api_logger.error(f"Clear cache failed: {str(e)}")
return fail(BizCode.INTERNAL_ERROR, "清除缓存失败", str(e))
@router.get("/analytics/recent_activity_stats", response_model=ApiResponse) @router.get("/analytics/recent_activity_stats", response_model=ApiResponse)
async def get_recent_activity_stats_api( async def get_recent_activity_stats_api(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),

View File

@@ -317,9 +317,12 @@ async def chat(
appid = share.app_id appid = share.app_id
"""获取存储类型和工作空间的ID""" """获取存储类型和工作空间的ID"""
# 直接通过 SQLAlchemy 查询 app # 直接通过 SQLAlchemy 查询 app(仅查询未删除的应用)
from app.models.app_model import App from app.models.app_model import App
app = db.query(App).filter(App.id == appid).first() app = db.query(App).filter(
App.id == appid,
App.is_active.is_(True)
).first()
if not app: if not app:
raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND) raise BusinessException("应用不存在", BizCode.APP_NOT_FOUND)

View File

@@ -54,7 +54,7 @@ async def create_workflow_config(
app = db.query(App).filter( app = db.query(App).filter(
App.id == app_id, App.id == app_id,
App.workspace_id == current_user.current_workspace_id, App.workspace_id == current_user.current_workspace_id,
App.is_active == True App.is_active.is_(True)
).first() ).first()
if not app: if not app:
@@ -214,7 +214,7 @@ async def delete_workflow_config(
app = db.query(App).filter( app = db.query(App).filter(
App.id == app_id, App.id == app_id,
App.workspace_id == current_user.current_workspace_id, App.workspace_id == current_user.current_workspace_id,
App.is_active == True App.is_active.is_(True)
).first() ).first()
if not app: if not app:
@@ -259,7 +259,7 @@ async def validate_workflow_config(
app = db.query(App).filter( app = db.query(App).filter(
App.id == app_id, App.id == app_id,
App.workspace_id == current_user.current_workspace_id, App.workspace_id == current_user.current_workspace_id,
App.is_active == True App.is_active.is_(True)
).first() ).first()
if not app: if not app:
@@ -329,7 +329,7 @@ async def get_workflow_executions(
app = db.query(App).filter( app = db.query(App).filter(
App.id == app_id, App.id == app_id,
App.workspace_id == current_user.current_workspace_id, App.workspace_id == current_user.current_workspace_id,
App.is_active == True App.is_active.is_(True)
).first() ).first()
if not app: if not app:
@@ -389,7 +389,7 @@ async def get_workflow_execution(
app = db.query(App).filter( app = db.query(App).filter(
App.id == execution.app_id, App.id == execution.app_id,
App.workspace_id == current_user.current_workspace_id, App.workspace_id == current_user.current_workspace_id,
App.is_active == True App.is_active.is_(True)
).first() ).first()
if not app: if not app:
@@ -440,7 +440,7 @@ async def run_workflow(
app = db.query(App).filter( app = db.query(App).filter(
App.id == app_id, App.id == app_id,
App.workspace_id == current_user.current_workspace_id, App.workspace_id == current_user.current_workspace_id,
App.is_active == True App.is_active.is_(True)
).first() ).first()
if not app: if not app:
@@ -578,7 +578,7 @@ async def cancel_workflow_execution(
app = db.query(App).filter( app = db.query(App).filter(
App.id == execution.app_id, App.id == execution.app_id,
App.workspace_id == current_user.current_workspace_id, App.workspace_id == current_user.current_workspace_id,
App.is_active == True App.is_active.is_(True)
).first() ).first()
if not app: if not app:

View File

@@ -14,7 +14,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
db_session = next(get_db()) db_session = next(get_db())
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)

View File

@@ -19,7 +19,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.db import get_db from app.db import get_db
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)
db_session = next(get_db()) db_session = next(get_db())

View File

@@ -12,7 +12,7 @@ from app.core.memory.agent.utils.session_tools import SessionService
from app.core.memory.agent.utils.template_tools import TemplateService from app.core.memory.agent.utils.template_tools import TemplateService
from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin from app.core.memory.agent.services.optimized_llm_service import LLMServiceMixin
template_root = os.path.join(PROJECT_ROOT_, 'agent', 'utils', 'prompt') template_root = os.path.join(PROJECT_ROOT_, 'memory', 'agent', 'utils', 'prompt')
db_session = next(get_db()) db_session = next(get_db())
logger = get_agent_logger(__name__) logger = get_agent_logger(__name__)

View File

@@ -1,11 +1,12 @@
import os import os
from collections import defaultdict from collections import defaultdict
from pathlib import Path
from typing import Annotated, TypedDict from typing import Annotated, TypedDict
from langchain_core.messages import AnyMessage from langchain_core.messages import AnyMessage
from langgraph.graph import add_messages from langgraph.graph import add_messages
PROJECT_ROOT_ = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) PROJECT_ROOT_ = str(Path(__file__).resolve().parents[3])
class WriteState(TypedDict): class WriteState(TypedDict):
''' '''

View File

@@ -0,0 +1,61 @@
# 角色
你是一个智能问答助手,基于检索信息和历史对话回答用户问题。
# 任务
根据提供的上下文信息回答用户的问题。
# 输入信息
- 历史对话:{{history}}
- 检索信息:{{retrieve_info}}
# 用户问题
{{query}}
# 回答指南
## 1. 仔细阅读检索信息
- 答案可能直接或间接地出现在检索信息中
- 如果检索信息中提到"小曼会使用Python",说明用户名是"小曼"
- 第三人称描述的偏好、行为通常指用户本人
## 2. 判断信息相关性
**情况A信息匹配问题**
- 直接回答,像自然对话一样
- 例:检索到"小曼会使用Python" → 问"我叫什么" → 答"你叫小曼"
**情况B信息部分相关**
- 先回答已知部分,再自然地询问更多信息
- 例:检索到"用户去过上海的面包店" → 问"我吃过哪家面包" → 答"我记得你去过上海的面包店,但具体是哪家我不太清楚,是哪家呢?"
**情况C信息完全不相关**
- 自然地表达不知道,但可以提及检索到的相关信息,让对话更连贯
- 使用友好的表达:
- "你好像没和我说过...,但是我知道你[检索到的相关信息]"
- "关于这个我不太清楚,不过我记得你[检索到的相关信息],能告诉我更多吗?"
- "我不记得你提到过...,但你[检索到的相关信息]"
- 即使检索信息不直接回答问题,也可以自然地融入对话中
- 避免僵硬的"信息不足,无法回答"
## 3. 回答要求
- 像人类对话一样自然流畅
- 不要提及"检索信息"、"搜索结果"、"根据资料"等技术术语
- 不要解释推理过程或引用信息来源
- 保持友好、乐于助人的语气
- 使用与问题相同的语言回答
# 关键示例
**示例1 - 直接匹配:**
- 检索信息:"小曼会使用Python..."
- 问题:"我叫什么"
- ✓ 正确:"你叫小曼"
- ✗ 错误:"你没有告诉我你的名字"
**示例2 - 间接匹配:**
- 检索信息:"用户很喜欢吃星巴克的甜品"
- 问题:"我喜欢什么"
- ✓ 正确:"你很喜欢吃星巴克的甜品"
- ✗ 错误:"信息不足"
**示例3 - 信息不匹配(推荐做法):**
- 检索信息:"用户只喝拿铁咖啡,认为美式咖啡太苦"
- 问题:"我吃过哪家面包"
- ✓ 最佳:"你好像没和我说过吃过哪家面包,但是我知道你喜欢喝拿铁,能跟我分享一下吗?"
- ✓ 可以:"你好像没和我说过吃过哪家面包,能跟我分享一下吗?"
- ✗ 错误:"用户只喝拿铁咖啡,认为美式咖啡太苦。"(答非所问)
- ✗ 错误:"信息不足,无法回答。"(太僵硬)
# 重要提醒
- 检索信息中描述用户行为/偏好时提到的名字,就是用户的名字
- 信息不匹配时,不要强行回答无关内容,但可以自然地提及检索到的信息,让对话更有温度
- 用对话式语言表达"不知道",而非机械模板
- 检索信息代表你对用户的了解,即使不直接回答问题,也能体现你对用户的记忆

View File

@@ -139,7 +139,8 @@ def parse_api_docs(file_path: str) -> Dict[str, Any]:
def get_default_docs_path() -> str: def get_default_docs_path() -> str:
project_root = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) from pathlib import Path
project_root = str(Path(__file__).resolve().parents[2])
return os.path.join(project_root, "src", "analytics", "API接口.md") return os.path.join(project_root, "src", "analytics", "API接口.md")

View File

@@ -2,13 +2,16 @@ import os
import re import re
import glob import glob
import json import json
from pathlib import Path
from typing import Tuple from typing import Tuple
try: try:
from app.core.memory.utils.config.definitions import PROJECT_ROOT from app.core.memory.utils.config.definitions import PROJECT_ROOT
except Exception: except Exception:
# Fallback: derive project root from this file location # Fallback: derive project root from this file location
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # 当前文件在 api/app/core/memory/analytics/recent_activity_stats.py
# 需要向上 5 级到达 api/ 目录
PROJECT_ROOT = str(Path(__file__).resolve().parents[4])
def _get_latest_prompt_log_path() -> str | None: def _get_latest_prompt_log_path() -> str | None:
@@ -67,44 +70,43 @@ def parse_stats_from_log(log_path: str) -> dict:
triplet_relations_count = 0 triplet_relations_count = 0
temporal_count = 0 temporal_count = 0
# Patterns # 正则表达式模式 - 匹配当前日志格式
pat_chunk_render = re.compile(r"===\s*RENDERED\s*STATEMENT\s*EXTRACTION\s*PROMPT\s*===") pat_chunk_render = re.compile(r"===\s*RENDERED\s*STATEMENT\s*EXTRACTION\s*PROMPT\s*===")
pat_triplet_start = re.compile(r"\[Triplet\].*statements_to_process\s*=\s*(\d+)") pat_triplet_started = re.compile(r"\[Triplet\]\s+Started\s+-\s+statement_id=")
pat_triplet_done = re.compile( pat_triplet_completed = re.compile(
r"\[Triplet\].*completed,\s*total_triplets\s*=\s*(\d+),\s*total_entities\s*=\s*(\d+)" r"\[Triplet\]\s+Completed\s+-\s+statement_id=[^,]+,\s+triplets=(\d+),\s+entities=(\d+)"
) )
pat_temporal_done = re.compile( pat_temporal_completed = re.compile(
r"\[Temporal\].*completed,\s*extracted_valid_ranges\s*=\s*(\d+)" r"\[Temporal\]\s+Completed\s+-\s+statement_id=[^,]+,\s+valid_ranges=(\d+)"
) )
with open(log_path, "r", encoding="utf-8", errors="ignore") as f: with open(log_path, "r", encoding="utf-8", errors="ignore") as f:
for line in f: for line in f:
# Chunk prompts count (each chunk triggers one statement-extraction prompt render) # 文本块数量(每个块触发一次陈述提取提示)
if pat_chunk_render.search(line): if pat_chunk_render.search(line):
chunk_count += 1 chunk_count += 1
continue continue
m1 = pat_triplet_start.search(line) # 陈述数量(每个 Triplet Started 代表一个陈述被处理)
if m1: if pat_triplet_started.search(line):
statements_count += 1
continue
# 三元组完成:[Triplet] Completed - statement_id=xxx, triplets=X, entities=Y
m_triplet = pat_triplet_completed.search(line)
if m_triplet:
try: try:
statements_count += int(m1.group(1)) triplet_relations_count += int(m_triplet.group(1))
triplet_entities_count += int(m_triplet.group(2))
except Exception: except Exception:
pass pass
continue continue
m2 = pat_triplet_done.search(line) # 时间信息完成:[Temporal] Completed - statement_id=xxx, valid_ranges=X
if m2: m_temporal = pat_temporal_completed.search(line)
if m_temporal:
try: try:
triplet_relations_count += int(m2.group(1)) temporal_count += int(m_temporal.group(1))
triplet_entities_count += int(m2.group(2))
except Exception:
pass
continue
m3 = pat_temporal_done.search(line)
if m3:
try:
temporal_count += int(m3.group(1))
except Exception: except Exception:
pass pass
continue continue
@@ -120,15 +122,20 @@ def parse_stats_from_log(log_path: str) -> dict:
def get_recent_activity_stats() -> Tuple[dict, str]: def get_recent_activity_stats() -> Tuple[dict, str]:
"""Get aggregated stats from all prompt logs in logs/. """Get stats from the latest prompt log file only.
Returns (stats_dict, message). Returns (stats_dict, message).
""" """
all_logs = _get_all_prompt_logs() # 获取最新的日志文件
# Fallback to recursive search if none found in logs/ latest_log = _get_latest_prompt_log_path()
if not all_logs:
# 如果没有找到,尝试递归搜索
if not latest_log:
all_logs = _get_any_logs_recursive() all_logs = _get_any_logs_recursive()
if not all_logs: if all_logs:
latest_log = all_logs[-1] # 取最新的
if not latest_log:
return ( return (
{ {
"chunk_count": 0, "chunk_count": 0,
@@ -141,24 +148,13 @@ def get_recent_activity_stats() -> Tuple[dict, str]:
"未找到日志文件,请确认已运行过提取流程。", "未找到日志文件,请确认已运行过提取流程。",
) )
agg = { # 只解析最新的日志文件
"chunk_count": 0, stats = parse_stats_from_log(latest_log)
"statements_count": 0,
"triplet_entities_count": 0, # 添加日志文件路径信息
"triplet_relations_count": 0, stats["log_path"] = f"最新:{latest_log}"
"temporal_count": 0,
} return stats, "成功读取最近一次记忆活动统计。"
for path in all_logs:
s = parse_stats_from_log(path)
agg["chunk_count"] += s.get("chunk_count", 0)
agg["statements_count"] += s.get("statements_count", 0)
agg["triplet_entities_count"] += s.get("triplet_entities_count", 0)
agg["triplet_relations_count"] += s.get("triplet_relations_count", 0)
agg["temporal_count"] += s.get("temporal_count", 0)
# Attach a summary of files combined
agg["log_path"] = f"{len(all_logs)} 个日志文件,最新:{all_logs[-1]}"
return agg, "成功汇总 logs 目录中所有提示日志。"
def _format_summary(stats: dict) -> str: def _format_summary(stats: dict) -> str:

View File

@@ -8,13 +8,14 @@ import sys
import time import time
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, List from typing import Any, Dict, List
from pathlib import Path
from dotenv import load_dotenv from dotenv import load_dotenv
# 1 # 1
# 添加项目根目录到路径 # 添加项目根目录到路径
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = Path(__file__).resolve().parent
project_root = os.path.dirname(current_dir) project_root = str(current_dir.parent)
if project_root not in sys.path: if project_root not in sys.path:
sys.path.insert(0, project_root) sys.path.insert(0, project_root)
# 关键:将 src 目录置于最前,确保从当前仓库加载模块 # 关键:将 src 目录置于最前,确保从当前仓库加载模块

View File

@@ -16,9 +16,10 @@ except Exception:
# 确保可以找到 src 及项目根路径 # 确保可以找到 src 及项目根路径
import sys import sys
from pathlib import Path
_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) _THIS_DIR = Path(__file__).resolve().parent
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(_THIS_DIR))) _PROJECT_ROOT = str(_THIS_DIR.parents[2])
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src") _SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
for _p in (_SRC_DIR, _PROJECT_ROOT): for _p in (_SRC_DIR, _PROJECT_ROOT):
if _p not in sys.path: if _p not in sys.path:

View File

@@ -15,9 +15,10 @@ except Exception:
# 路径与模块导入保持与现有评估脚本一致 # 路径与模块导入保持与现有评估脚本一致
import sys import sys
from pathlib import Path
_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) _THIS_DIR = Path(__file__).resolve().parent
_PROJECT_ROOT = os.path.dirname(os.path.dirname(_THIS_DIR)) _PROJECT_ROOT = str(_THIS_DIR.parents[1])
_SRC_DIR = os.path.join(_PROJECT_ROOT, "src") _SRC_DIR = os.path.join(_PROJECT_ROOT, "src")
for _p in (_SRC_DIR, _PROJECT_ROOT): for _p in (_SRC_DIR, _PROJECT_ROOT):
if _p not in sys.path: if _p not in sys.path:

View File

@@ -15,9 +15,13 @@ class AppRepository:
self.db = db self.db = db
def get_apps_by_workspace_id(self, workspace_id: uuid.UUID) -> list[App]: def get_apps_by_workspace_id(self, workspace_id: uuid.UUID) -> list[App]:
"""根据工作空间ID查询应用""" """根据工作空间ID查询应用(仅返回未删除的应用)"""
try: try:
apps = self.db.query(App).filter(App.workspace_id == workspace_id).all() apps = (
self.db.query(App)
.filter(App.workspace_id == workspace_id, App.is_active.is_(True))
.all()
)
db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(apps)} 个应用") db_logger.info(f"成功查询工作空间 {workspace_id} 下的 {len(apps)} 个应用")
return apps return apps
except Exception as e: except Exception as e:
@@ -26,7 +30,7 @@ class AppRepository:
def get_apps_by_id(self, app_id: uuid.UUID) -> App: def get_apps_by_id(self, app_id: uuid.UUID) -> App:
try: try:
app = self.db.query(App).filter(App.id == app_id, App.is_active == True).first() app = self.db.query(App).filter(App.id == app_id, App.is_active.is_(True)).first()
return app return app
except Exception as e: except Exception as e:
raise raise

View File

@@ -17,24 +17,24 @@ class HomePageRepository:
"""获取模型统计数据""" """获取模型统计数据"""
total_models = db.query(ModelConfig).filter( total_models = db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id, ModelConfig.tenant_id == tenant_id,
ModelConfig.is_active == True ModelConfig.is_active.is_(True)
).count() ).count()
total_llm = db.query(ModelConfig).filter( total_llm = db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id, ModelConfig.tenant_id == tenant_id,
ModelConfig.is_active == True, ModelConfig.is_active.is_(True),
ModelConfig.type == "llm" ModelConfig.type == "llm"
).count() ).count()
total_embedding = db.query(ModelConfig).filter( total_embedding = db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id, ModelConfig.tenant_id == tenant_id,
ModelConfig.is_active == True, ModelConfig.is_active.is_(True),
ModelConfig.type == "embedding" ModelConfig.type == "embedding"
).count() ).count()
new_models_this_week = db.query(ModelConfig).filter( new_models_this_week = db.query(ModelConfig).filter(
ModelConfig.tenant_id == tenant_id, ModelConfig.tenant_id == tenant_id,
ModelConfig.is_active == True, ModelConfig.is_active.is_(True),
ModelConfig.created_at >= week_start ModelConfig.created_at >= week_start
).count() ).count()
@@ -56,12 +56,12 @@ class HomePageRepository:
"""获取工作空间统计数据""" """获取工作空间统计数据"""
active_workspaces = db.query(Workspace).filter( active_workspaces = db.query(Workspace).filter(
Workspace.tenant_id == tenant_id, Workspace.tenant_id == tenant_id,
Workspace.is_active == True Workspace.is_active.is_(True)
).count() ).count()
new_workspaces_this_week = db.query(Workspace).filter( new_workspaces_this_week = db.query(Workspace).filter(
Workspace.tenant_id == tenant_id, Workspace.tenant_id == tenant_id,
Workspace.is_active == True, Workspace.is_active.is_(True),
Workspace.created_at >= week_start Workspace.created_at >= week_start
).count() ).count()
@@ -83,7 +83,7 @@ class HomePageRepository:
"""获取用户统计数据""" """获取用户统计数据"""
workspace_ids = db.query(Workspace.id).filter( workspace_ids = db.query(Workspace.id).filter(
Workspace.tenant_id == tenant_id, Workspace.tenant_id == tenant_id,
Workspace.is_active == True Workspace.is_active.is_(True)
).subquery() ).subquery()
total_users = db.query(EndUser).join( total_users = db.query(EndUser).join(
@@ -91,7 +91,7 @@ class HomePageRepository:
EndUser.app_id == App.id EndUser.app_id == App.id
).filter( ).filter(
App.workspace_id.in_(workspace_ids), App.workspace_id.in_(workspace_ids),
App.is_active == True, App.is_active.is_(True),
App.status == "active" App.status == "active"
).count() ).count()
@@ -100,7 +100,7 @@ class HomePageRepository:
EndUser.app_id == App.id EndUser.app_id == App.id
).filter( ).filter(
App.workspace_id.in_(workspace_ids), App.workspace_id.in_(workspace_ids),
App.is_active == True, App.is_active.is_(True),
App.status == "active", App.status == "active",
EndUser.created_at >= week_start EndUser.created_at >= week_start
).count() ).count()
@@ -123,18 +123,18 @@ class HomePageRepository:
"""获取应用统计数据""" """获取应用统计数据"""
workspace_ids = db.query(Workspace.id).filter( workspace_ids = db.query(Workspace.id).filter(
Workspace.tenant_id == tenant_id, Workspace.tenant_id == tenant_id,
Workspace.is_active == True Workspace.is_active.is_(True)
).subquery() ).subquery()
running_apps = db.query(App).filter( running_apps = db.query(App).filter(
App.workspace_id.in_(workspace_ids), App.workspace_id.in_(workspace_ids),
App.is_active == True, App.is_active.is_(True),
App.status == "active" App.status == "active"
).count() ).count()
new_apps_this_week = db.query(App).filter( new_apps_this_week = db.query(App).filter(
App.workspace_id.in_(workspace_ids), App.workspace_id.in_(workspace_ids),
App.is_active == True, App.is_active.is_(True),
App.status == "active", App.status == "active",
App.created_at >= week_start App.created_at >= week_start
).count() ).count()
@@ -158,7 +158,7 @@ class HomePageRepository:
# 获取工作空间列表 # 获取工作空间列表
workspaces = db.query(Workspace).filter( workspaces = db.query(Workspace).filter(
Workspace.tenant_id == tenant_id, Workspace.tenant_id == tenant_id,
Workspace.is_active == True Workspace.is_active.is_(True)
).all() ).all()
workspace_ids = [ws.id for ws in workspaces] workspace_ids = [ws.id for ws in workspaces]
@@ -169,7 +169,7 @@ class HomePageRepository:
func.count(App.id).label('count') func.count(App.id).label('count')
).filter( ).filter(
App.workspace_id.in_(workspace_ids), App.workspace_id.in_(workspace_ids),
App.is_active, App.is_active.is_(True),
App.status == "active" App.status == "active"
).group_by(App.workspace_id).all() ).group_by(App.workspace_id).all()
@@ -184,7 +184,7 @@ class HomePageRepository:
EndUser.app_id == App.id EndUser.app_id == App.id
).filter( ).filter(
App.workspace_id.in_(workspace_ids), App.workspace_id.in_(workspace_ids),
App.is_active, App.is_active.is_(True),
App.status == "active" App.status == "active"
).group_by(App.workspace_id).all() ).group_by(App.workspace_id).all()

View File

@@ -68,7 +68,7 @@ class UserRepository:
db_logger.debug("查询超级用户") db_logger.debug("查询超级用户")
try: try:
user = self.db.query(User).options(joinedload(User.tenant)).filter(User.is_active == True).filter(User.is_superuser == True).first() user = self.db.query(User).options(joinedload(User.tenant)).filter(User.is_active.is_(True)).filter(User.is_superuser.is_(True)).first()
if user: if user:
db_logger.debug(f"超级用户查询成功: {user.username}") db_logger.debug(f"超级用户查询成功: {user.username}")
else: else:
@@ -82,7 +82,7 @@ class UserRepository:
db_logger.debug("检查是否只有一个超级用户") db_logger.debug("检查是否只有一个超级用户")
try: try:
count = self.db.query(User).options(joinedload(User.tenant)).filter(User.is_active == True).filter(User.is_superuser == True).count() count = self.db.query(User).options(joinedload(User.tenant)).filter(User.is_active.is_(True)).filter(User.is_superuser.is_(True)).count()
return count == 1 return count == 1
except Exception as e: except Exception as e:
db_logger.error(f"检查超级用户数量失败: {str(e)}") db_logger.error(f"检查超级用户数量失败: {str(e)}")

View File

@@ -33,7 +33,7 @@ class WorkflowConfigRepository:
""" """
return self.db.query(WorkflowConfig).filter( return self.db.query(WorkflowConfig).filter(
WorkflowConfig.app_id == app_id, WorkflowConfig.app_id == app_id,
WorkflowConfig.is_active == True WorkflowConfig.is_active.is_(True)
).first() ).first()
def create_or_update( def create_or_update(

View File

@@ -103,7 +103,7 @@ class WorkspaceRepository:
workspaces = ( workspaces = (
self.db.query(Workspace) self.db.query(Workspace)
.filter(Workspace.tenant_id == user.tenant_id) .filter(Workspace.tenant_id == user.tenant_id)
.filter(Workspace.is_active == True) .filter(Workspace.is_active.is_(True))
.order_by(Workspace.updated_at.desc()) .order_by(Workspace.updated_at.desc())
.all() .all()
) )
@@ -115,7 +115,7 @@ class WorkspaceRepository:
self.db.query(Workspace) self.db.query(Workspace)
.join(WorkspaceMember, Workspace.id == WorkspaceMember.workspace_id) .join(WorkspaceMember, Workspace.id == WorkspaceMember.workspace_id)
.filter(WorkspaceMember.user_id == user_id) .filter(WorkspaceMember.user_id == user_id)
.filter(Workspace.is_active == True) .filter(Workspace.is_active.is_(True))
.order_by(Workspace.updated_at.desc()) .order_by(Workspace.updated_at.desc())
.all() .all()
) )
@@ -134,7 +134,7 @@ class WorkspaceRepository:
workspaces = ( workspaces = (
self.db.query(Workspace) self.db.query(Workspace)
.filter(Workspace.tenant_id == tenant_id) .filter(Workspace.tenant_id == tenant_id)
.filter(Workspace.is_active == True) .filter(Workspace.is_active.is_(True))
.all() .all()
) )
db_logger.debug(f"租户工作空间查询成功: tenant_id={tenant_id}, 数量={len(workspaces)}") db_logger.debug(f"租户工作空间查询成功: tenant_id={tenant_id}, 数量={len(workspaces)}")
@@ -169,7 +169,7 @@ class WorkspaceRepository:
member = self.db.query(WorkspaceMember).filter( member = self.db.query(WorkspaceMember).filter(
WorkspaceMember.user_id == user_id, WorkspaceMember.user_id == user_id,
WorkspaceMember.workspace_id == workspace_id, WorkspaceMember.workspace_id == workspace_id,
WorkspaceMember.is_active == True, WorkspaceMember.is_active.is_(True),
).first() ).first()
if member: if member:
db_logger.debug(f"工作空间成员查询成功: user_id={user_id}, workspace_id={workspace_id}, role={member.role}") db_logger.debug(f"工作空间成员查询成功: user_id={user_id}, workspace_id={workspace_id}, role={member.role}")
@@ -189,8 +189,8 @@ class WorkspaceRepository:
.join(User, WorkspaceMember.user_id == User.id) .join(User, WorkspaceMember.user_id == User.id)
.options(joinedload(WorkspaceMember.user), joinedload(WorkspaceMember.workspace)) .options(joinedload(WorkspaceMember.user), joinedload(WorkspaceMember.workspace))
.filter(WorkspaceMember.workspace_id == workspace_id) .filter(WorkspaceMember.workspace_id == workspace_id)
.filter(WorkspaceMember.is_active == True) .filter(WorkspaceMember.is_active.is_(True))
.filter(User.is_active == True) .filter(User.is_active.is_(True))
.all() .all()
) )
db_logger.debug(f"成员列表查询成功: workspace_id={workspace_id}, 数量={len(members)}") db_logger.debug(f"成员列表查询成功: workspace_id={workspace_id}, 数量={len(members)}")
@@ -208,8 +208,8 @@ class WorkspaceRepository:
.join(User, WorkspaceMember.user_id == User.id) .join(User, WorkspaceMember.user_id == User.id)
.options(joinedload(WorkspaceMember.user), joinedload(WorkspaceMember.workspace)) .options(joinedload(WorkspaceMember.user), joinedload(WorkspaceMember.workspace))
.filter(WorkspaceMember.id == member_id) .filter(WorkspaceMember.id == member_id)
.filter(WorkspaceMember.is_active == True) .filter(WorkspaceMember.is_active.is_(True))
.filter(User.is_active == True) .filter(User.is_active.is_(True))
.first() .first()
) )
if member: if member:
@@ -226,7 +226,7 @@ class WorkspaceRepository:
member = self.db.query(WorkspaceMember).filter( member = self.db.query(WorkspaceMember).filter(
WorkspaceMember.workspace_id == workspace_id, WorkspaceMember.workspace_id == workspace_id,
WorkspaceMember.user_id == user_id, WorkspaceMember.user_id == user_id,
WorkspaceMember.is_active == True, WorkspaceMember.is_active.is_(True),
).first() ).first()
if not member: if not member:
return None return None
@@ -243,7 +243,7 @@ class WorkspaceRepository:
member = self.db.query(WorkspaceMember).filter( member = self.db.query(WorkspaceMember).filter(
WorkspaceMember.workspace_id == workspace_id, WorkspaceMember.workspace_id == workspace_id,
WorkspaceMember.user_id == user_id, WorkspaceMember.user_id == user_id,
WorkspaceMember.is_active == True, WorkspaceMember.is_active.is_(True),
).first() ).first()
if not member: if not member:
return None return None
@@ -259,7 +259,7 @@ class WorkspaceRepository:
try: try:
member = self.db.query(WorkspaceMember).filter( member = self.db.query(WorkspaceMember).filter(
WorkspaceMember.id == member_id, WorkspaceMember.id == member_id,
WorkspaceMember.is_active == True, WorkspaceMember.is_active.is_(True),
).first() ).first()
if not member: if not member:
return None return None
@@ -275,7 +275,7 @@ class WorkspaceRepository:
try: try:
member = self.db.query(WorkspaceMember).filter( member = self.db.query(WorkspaceMember).filter(
WorkspaceMember.id == id, WorkspaceMember.id == id,
WorkspaceMember.is_active == True, WorkspaceMember.is_active.is_(True),
).first() ).first()
if not member: if not member:
return None return None

View File

@@ -55,8 +55,8 @@ class AgentRegistry:
""" """
# 构建查询 # 构建查询
stmt = select(AgentConfig).join(App).where( stmt = select(AgentConfig).join(App).where(
AgentConfig.is_active == True, AgentConfig.is_active.is_(True),
App.is_active == True App.is_active.is_(True)
) )
# 工作空间过滤(同工作空间或公开) # 工作空间过滤(同工作空间或公开)

View File

@@ -758,7 +758,7 @@ class AppService:
) )
# 构建查询条件 # 构建查询条件
filters = [App.is_active == True] filters = [App.is_active.is_(True)]
if type: if type:
filters.append(App.type == type) filters.append(App.type == type)
if visibility: if visibility:
@@ -873,7 +873,7 @@ class AppService:
self._validate_workspace_access(app, workspace_id) self._validate_workspace_access(app, workspace_id)
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active == True).order_by( stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active.is_(True)).order_by(
AgentConfig.updated_at.desc()) AgentConfig.updated_at.desc())
agent_cfg: Optional[AgentConfig] = self.db.scalars(stmt).first() agent_cfg: Optional[AgentConfig] = self.db.scalars(stmt).first()
now = datetime.datetime.now() now = datetime.datetime.now()
@@ -1204,7 +1204,7 @@ class AppService:
default_model_config_id = None default_model_config_id = None
if app.type == AppType.AGENT: if app.type == AppType.AGENT:
stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active == True).order_by( stmt = select(AgentConfig).where(AgentConfig.app_id == app_id, AgentConfig.is_active.is_(True)).order_by(
AgentConfig.updated_at.desc()) AgentConfig.updated_at.desc())
agent_cfg = self.db.scalars(stmt).first() agent_cfg = self.db.scalars(stmt).first()
if not agent_cfg: if not agent_cfg:
@@ -1226,7 +1226,7 @@ class AppService:
select(MultiAgentConfig) select(MultiAgentConfig)
.where( .where(
MultiAgentConfig.app_id == app_id, MultiAgentConfig.app_id == app_id,
MultiAgentConfig.is_active == True MultiAgentConfig.is_active.is_(True)
) )
.order_by(MultiAgentConfig.updated_at.desc()) .order_by(MultiAgentConfig.updated_at.desc())
) )
@@ -1380,7 +1380,7 @@ class AppService:
stmt = ( stmt = (
select(AppRelease) select(AppRelease)
.where(AppRelease.app_id == app_id, AppRelease.is_active == True) .where(AppRelease.app_id == app_id, AppRelease.is_active.is_(True))
.order_by(AppRelease.version.desc()) .order_by(AppRelease.version.desc())
) )
return list(self.db.scalars(stmt).all()) return list(self.db.scalars(stmt).all())

View File

@@ -728,7 +728,7 @@ class DraftRunService:
select(ModelApiKey) select(ModelApiKey)
.where( .where(
ModelApiKey.model_config_id == model_config_id, ModelApiKey.model_config_id == model_config_id,
ModelApiKey.is_active == True ModelApiKey.is_active.is_(True)
) )
.order_by(ModelApiKey.priority.desc()) .order_by(ModelApiKey.priority.desc())
.limit(1) .limit(1)

View File

@@ -175,10 +175,9 @@ class MemoryAgentService:
""" """
logger.info("Reading log file") logger.info("Reading log file")
# Get log file path - use project root directory
current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py from pathlib import Path
app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory project_root = str(Path(__file__).resolve().parents[2]) # api directory
project_root = os.path.dirname(app_dir) # redbear-mem directory
log_path = os.path.join(project_root, "logs", "agent_service.log") log_path = os.path.join(project_root, "logs", "agent_service.log")
summer = '' summer = ''
@@ -217,9 +216,8 @@ class MemoryAgentService:
logger.info("Starting log content streaming") logger.info("Starting log content streaming")
# Get log file path - use project root directory # Get log file path - use project root directory
current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py from pathlib import Path
app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory project_root = str(Path(__file__).resolve().parents[2]) # api directory
project_root = os.path.dirname(app_dir) # redbear-mem directory
log_path = os.path.join(project_root, "logs", "agent_service.log") log_path = os.path.join(project_root, "logs", "agent_service.log")
# Check if file exists before starting stream # Check if file exists before starting stream
@@ -431,13 +429,15 @@ class MemoryAgentService:
audit_logger = None audit_logger = None
config_load_start = time.time()
try: try:
config_service = MemoryConfigService(db) config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config( memory_config = config_service.load_memory_config(
config_id=config_id, config_id=config_id,
service_name="MemoryAgentService" service_name="MemoryAgentService"
) )
logger.info(f"Configuration loaded successfully: {memory_config.config_name}") config_load_time = time.time() - config_load_start
logger.info(f"[PERF] Configuration loaded in {config_load_time:.4f}s: {memory_config.config_name}")
except ConfigurationError as e: except ConfigurationError as e:
error_msg = f"Failed to load configuration for config_id: {config_id}: {e}" error_msg = f"Failed to load configuration for config_id: {config_id}: {e}"
logger.error(error_msg) logger.error(error_msg)
@@ -578,6 +578,8 @@ class MemoryAgentService:
logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True) logger.error(f"保存短期记忆失败: {str(save_error)}", exc_info=True)
# Log successful operation # Log successful operation
total_time = time.time() - start_time
logger.info(f"[PERF] read_memory completed successfully in {total_time:.4f}s (config: {config_load_time:.4f}s, graph: {graph_exec_time:.4f}s)")
if audit_logger: if audit_logger:
duration = time.time() - start_time duration = time.time() - start_time
audit_logger.log_operation( audit_logger.log_operation(
@@ -668,6 +670,8 @@ class MemoryAgentService:
""" """
logger.info("Classifying message type") logger.info("Classifying message type")
# Load configuration to get LLM model ID # Load configuration to get LLM model ID
config_service = MemoryConfigService(db) config_service = MemoryConfigService(db)
memory_config = config_service.load_memory_config( memory_config = config_service.load_memory_config(
@@ -681,10 +685,11 @@ class MemoryAgentService:
async def generate_summary_from_retrieve( async def generate_summary_from_retrieve(
self, self,
end_user_id: str,
retrieve_info: str, retrieve_info: str,
history: List[Dict], history: List[Dict],
query: str, query: str,
config_id: UUID, config_id: str,
db: Session db: Session
) -> str: ) -> str:
""" """
@@ -702,6 +707,18 @@ class MemoryAgentService:
Returns: Returns:
生成的答案文本 生成的答案文本
""" """
if config_id is None:
try:
config_id = get_end_user_connected_config(end_user_id, db)
config_id = config_id.get('memory_config_id')
if config_id is None:
raise ValueError(
f"No memory configuration found for end_user {end_user_id}. Please ensure the user has a connected memory configuration.")
except Exception as e:
if "No memory configuration found" in str(e):
raise # Re-raise our specific error
logger.error(f"Failed to get connected config for end_user {end_user_id}: {e}")
raise ValueError(f"Unable to determine memory configuration for end_user {group_id}: {e}")
logger.info(f"Generating summary from retrieve info for query: {query[:50]}...") logger.info(f"Generating summary from retrieve info for query: {query[:50]}...")
try: try:
@@ -727,7 +744,7 @@ class MemoryAgentService:
state=state, state=state,
history=history, history=history,
retrieve_info=retrieve_info, retrieve_info=retrieve_info,
template_name='Retrieve_Summary_prompt.jinja2', template_name='direct_summary_prompt.jinja2',
operation_name='retrieve_summary', operation_name='retrieve_summary',
response_model=RetrieveSummaryResponse, response_model=RetrieveSummaryResponse,
search_mode="1" search_mode="1"
@@ -1075,9 +1092,8 @@ class MemoryAgentService:
logger.info("Starting log content streaming") logger.info("Starting log content streaming")
# Get log file path - use project root directory # Get log file path - use project root directory
current_file = os.path.abspath(__file__) # app/services/memory_agent_service.py from pathlib import Path
app_dir = os.path.dirname(os.path.dirname(current_file)) # app directory project_root = str(Path(__file__).resolve().parents[2]) # api directory
project_root = os.path.dirname(app_dir) # redbear-mem directory
log_path = os.path.join(project_root, "logs", "agent_service.log") log_path = os.path.join(project_root, "logs", "agent_service.log")
# Check if file exists before starting stream # Check if file exists before starting stream
@@ -1175,7 +1191,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
# 3. 从 config 中提取 memory_config_id # 3. 从 config 中提取 memory_config_id
config = latest_release.config or {} config = latest_release.config or {}
# 如果 config 是字符串,解析为字典 # 如果 config 是字符串,解析为字典
if isinstance(config, str): if isinstance(config, str):
import json import json
@@ -1184,7 +1200,7 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
except json.JSONDecodeError: except json.JSONDecodeError:
logger.warning(f"Failed to parse config JSON for release {latest_release.id}") logger.warning(f"Failed to parse config JSON for release {latest_release.id}")
config = {} config = {}
memory_obj = config.get('memory', {}) memory_obj = config.get('memory', {})
memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None memory_config_id = memory_obj.get('memory_content') if isinstance(memory_obj, dict) else None
@@ -1196,10 +1212,6 @@ def get_end_user_connected_config(end_user_id: str, db: Session) -> Dict[str, An
"memory_config_id": memory_config_id "memory_config_id": memory_config_id
} }
print(188*'*')
print(result)
print(188 * '*')
logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}") logger.info(f"Successfully retrieved connected config: memory_config_id={memory_config_id}")
return result return result

View File

@@ -77,7 +77,10 @@ class MemoryAPIService:
) )
# Verify end_user belongs to the workspace via App relationship # Verify end_user belongs to the workspace via App relationship
app = self.db.query(App).filter(App.id == end_user.app_id).first() app = self.db.query(App).filter(
App.id == end_user.app_id,
App.is_active.is_(True)
).first()
if not app: if not app:
logger.warning(f"App not found for end_user: {end_user_id}") logger.warning(f"App not found for end_user: {end_user_id}")

View File

@@ -53,18 +53,28 @@ def get_workspace_end_users(
workspace_id: uuid.UUID, workspace_id: uuid.UUID,
current_user: User current_user: User
) -> List[EndUser]: ) -> List[EndUser]:
"""获取工作空间的所有宿主""" """获取工作空间的所有宿主(优化版本:减少数据库查询次数)"""
business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}") business_logger.info(f"获取工作空间宿主列表: workspace_id={workspace_id}, 操作者: {current_user.username}")
try: try:
# 查询应用ORM并转换为 Pydantic 模型 # 查询应用ORM
apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id) apps_orm = app_repository.get_apps_by_workspace_id(db, workspace_id)
apps = [AppSchema.model_validate(h) for h in apps_orm]
app_ids = [app.id for app in apps] if not apps_orm:
end_users = [] business_logger.info("工作空间下没有应用")
for app_id in app_ids: return []
end_user_orm_list = end_user_repository.get_end_users_by_app_id(db, app_id)
end_users.extend([EndUserSchema.model_validate(h) for h in end_user_orm_list]) # 提取所有 app_id
app_ids = [app.id for app in apps_orm]
# 批量查询所有 end_users一次查询而非循环查询
from app.models.end_user_model import EndUser as EndUserModel
end_users_orm = db.query(EndUserModel).filter(
EndUserModel.app_id.in_(app_ids)
).all()
# 转换为 Pydantic 模型(只在需要时转换)
end_users = [EndUserSchema.model_validate(eu) for eu in end_users_orm]
business_logger.info(f"成功获取 {len(end_users)} 个宿主记录") business_logger.info(f"成功获取 {len(end_users)} 个宿主记录")
return end_users return end_users
@@ -414,6 +424,67 @@ def get_current_user_total_chunk(
business_logger.error(f"获取用户总chunk数失败: end_user_id={end_user_id} - {str(e)}") business_logger.error(f"获取用户总chunk数失败: end_user_id={end_user_id} - {str(e)}")
raise raise
def get_users_total_chunk_batch(
end_user_ids: List[str],
db: Session,
current_user: User
) -> dict:
"""
批量获取多个用户的总chunk数性能优化版本
Args:
end_user_ids: 用户ID列表
db: 数据库会话
current_user: 当前用户
Returns:
字典key为end_user_idvalue为chunk总数
格式: {"user_id_1": 100, "user_id_2": 50, ...}
"""
business_logger.info(f"批量获取 {len(end_user_ids)} 个用户的总chunk数, 操作者: {current_user.username}")
try:
from app.models.document_model import Document
from sqlalchemy import func, case
if not end_user_ids:
return {}
# 构造所有文件名
file_names = [f"{user_id}.txt" for user_id in end_user_ids]
# 一次查询获取所有用户的chunk总数
# 使用 GROUP BY file_name 来分组统计
results = db.query(
Document.file_name,
func.sum(Document.chunk_num).label('total_chunk')
).filter(
Document.file_name.in_(file_names)
).group_by(
Document.file_name
).all()
# 构建结果字典
chunk_map = {}
for file_name, total_chunk in results:
# 从文件名中提取 end_user_id (去掉 .txt 后缀)
user_id = file_name.replace('.txt', '')
chunk_map[user_id] = int(total_chunk or 0)
# 对于没有记录的用户设置为0
for user_id in end_user_ids:
if user_id not in chunk_map:
chunk_map[user_id] = 0
business_logger.info(f"成功批量获取 {len(chunk_map)} 个用户的总chunk数")
return chunk_map
except Exception as e:
business_logger.error(f"批量获取用户总chunk数失败: {str(e)}")
raise
def get_rag_content( def get_rag_content(
end_user_id: str, end_user_id: str,
limit: int, limit: int,

View File

@@ -38,7 +38,10 @@ class WorkspaceAppService:
Returns: Returns:
Dictionary containing detailed application information Dictionary containing detailed application information
""" """
apps = self.db.query(App).filter(App.workspace_id == workspace_id).all() apps = self.db.query(App).filter(
App.workspace_id == workspace_id,
App.is_active.is_(True)
).all()
app_ids = [str(app.id) for app in apps] app_ids = [str(app.id) for app in apps]
apps_detailed_info = [] apps_detailed_info = []

View File

@@ -12,7 +12,11 @@ from datetime import datetime
from typing import Any, AsyncGenerator, Dict, List, Optional from typing import Any, AsyncGenerator, Dict, List, Optional
from app.core.logging_config import get_config_logger, get_logger from app.core.logging_config import get_config_logger, get_logger
from app.core.memory.analytics.hot_memory_tags import get_hot_memory_tags from app.core.memory.analytics.hot_memory_tags import (
get_hot_memory_tags,
get_raw_tags_from_db,
filter_tags_with_llm,
)
from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats from app.core.memory.analytics.recent_activity_stats import get_recent_activity_stats
from app.models.user_model import User from app.models.user_model import User
from app.repositories.memory_config_repository import MemoryConfigRepository from app.repositories.memory_config_repository import MemoryConfigRepository
@@ -237,7 +241,8 @@ class DataConfigService: # 数据配置服务类PostgreSQL
ValueError: 当配置无效或参数缺失时 ValueError: 当配置无效或参数缺失时
RuntimeError: 当管线执行失败时 RuntimeError: 当管线执行失败时
""" """
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from pathlib import Path
project_root = str(Path(__file__).resolve().parents[2])
try: try:
# 发出初始进度事件 # 发出初始进度事件
@@ -512,27 +517,79 @@ async def analytics_hot_memory_tags(
) -> List[Dict[str, Any]]: ) -> List[Dict[str, Any]]:
""" """
获取热门记忆标签按数量排序并返回前N个 获取热门记忆标签按数量排序并返回前N个
优化策略:
1. 先从所有用户收集原始标签不调用LLM
2. 聚合并合并相同标签的频率
3. 排序后取前N个
4. 只调用一次LLM进行筛选
""" """
workspace_id = current_user.current_workspace_id workspace_id = current_user.current_workspace_id
# 获取更多标签供LLM筛选获取limit*4个标签 # 获取更多标签供LLM筛选获取limit*4个标签
raw_limit = limit * 4 raw_limit = limit * 4
from app.services.memory_dashboard_service import get_workspace_end_users from app.services.memory_dashboard_service import get_workspace_end_users
end_users = get_workspace_end_users(db, workspace_id, current_user) # 使用 asyncio.to_thread 避免阻塞事件循环
end_users = await asyncio.to_thread(get_workspace_end_users, db, workspace_id, current_user)
tags = [] if not end_users:
for end_user in end_users: return []
tag = await get_hot_memory_tags(str(end_user.id), limit=raw_limit)
if tag:
# 将每个用户的标签列表展平到总列表中
tags.extend(tag)
# 按频率降序排序(虽然数据库已经排序,但为了确保正确性再次排序)
sorted_tags = sorted(tags, key=lambda x: x[1], reverse=True)
# 只返回前limit个 # 步骤1: 收集所有用户的原始标签不调用LLM
top_tags = sorted_tags[:limit] connector = Neo4jConnector()
try:
return [{"name": t, "frequency": f} for t, f in top_tags] all_raw_tags = []
for end_user in end_users:
raw_tags = await get_raw_tags_from_db(
connector,
str(end_user.id),
limit=raw_limit,
by_user=False
)
if raw_tags:
all_raw_tags.extend(raw_tags)
if not all_raw_tags:
return []
# 步骤2: 聚合相同标签的频率
tag_frequency_map = {}
for tag_name, frequency in all_raw_tags:
if tag_name in tag_frequency_map:
tag_frequency_map[tag_name] += frequency
else:
tag_frequency_map[tag_name] = frequency
# 步骤3: 按频率降序排序取前raw_limit个
sorted_tags = sorted(
tag_frequency_map.items(),
key=lambda x: x[1],
reverse=True
)[:raw_limit]
if not sorted_tags:
return []
# 步骤4: 只调用一次LLM进行筛选
tag_names = [tag for tag, _ in sorted_tags]
# 使用第一个用户的group_id来获取LLM配置
# 因为同一工作空间下的用户应该使用相同的配置
first_end_user_id = str(end_users[0].id)
filtered_tag_names = await filter_tags_with_llm(tag_names, first_end_user_id)
# 步骤5: 根据LLM筛选结果构建最终列表保留频率
final_tags = []
for tag, freq in sorted_tags:
if tag in filtered_tag_names:
final_tags.append((tag, freq))
# 步骤6: 只返回前limit个
top_tags = final_tags[:limit]
return [{"name": t, "frequency": f} for t, f in top_tags]
finally:
await connector.close()
async def analytics_recent_activity_stats() -> Dict[str, Any]: async def analytics_recent_activity_stats() -> Dict[str, Any]:

View File

@@ -2548,7 +2548,7 @@ class MultiAgentOrchestrator:
# 获取 API Key 配置 # 获取 API Key 配置
api_key_config = self.db.query(ModelApiKey).filter( api_key_config = self.db.query(ModelApiKey).filter(
ModelApiKey.model_config_id == default_model_config_id, ModelApiKey.model_config_id == default_model_config_id,
ModelApiKey.is_active == True ModelApiKey.is_active.is_(True)
).first() ).first()
if not api_key_config: if not api_key_config:
@@ -2705,7 +2705,7 @@ class MultiAgentOrchestrator:
# 获取 API Key 配置 # 获取 API Key 配置
api_key_config = self.db.query(ModelApiKey).filter( api_key_config = self.db.query(ModelApiKey).filter(
ModelApiKey.model_config_id == default_model_config_id, ModelApiKey.model_config_id == default_model_config_id,
ModelApiKey.is_active == True ModelApiKey.is_active.is_(True)
).first() ).first()
if not api_key_config: if not api_key_config:

View File

@@ -74,7 +74,7 @@ class MultiAgentService:
select(MultiAgentConfig) select(MultiAgentConfig)
.where( .where(
MultiAgentConfig.app_id == app_id, MultiAgentConfig.app_id == app_id,
MultiAgentConfig.is_active == True MultiAgentConfig.is_active.is_(True)
) )
.order_by(MultiAgentConfig.updated_at.desc()) .order_by(MultiAgentConfig.updated_at.desc())
).first() ).first()
@@ -144,7 +144,7 @@ class MultiAgentService:
select(MultiAgentConfig) select(MultiAgentConfig)
.where( .where(
MultiAgentConfig.app_id == app_id, MultiAgentConfig.app_id == app_id,
MultiAgentConfig.is_active == True MultiAgentConfig.is_active.is_(True)
) )
.order_by(MultiAgentConfig.updated_at.desc()) .order_by(MultiAgentConfig.updated_at.desc())
).first() ).first()

View File

@@ -168,7 +168,7 @@ class SharedChatService:
select(ModelApiKey) select(ModelApiKey)
.where( .where(
ModelApiKey.model_config_id == model_config_id, ModelApiKey.model_config_id == model_config_id,
ModelApiKey.is_active == True ModelApiKey.is_active.is_(True)
) )
.order_by(ModelApiKey.priority.desc()) .order_by(ModelApiKey.priority.desc())
.limit(1) .limit(1)
@@ -362,7 +362,7 @@ class SharedChatService:
select(ModelApiKey) select(ModelApiKey)
.where( .where(
ModelApiKey.model_config_id == model_config_id, ModelApiKey.model_config_id == model_config_id,
ModelApiKey.is_active == True ModelApiKey.is_active.is_(True)
) )
.order_by(ModelApiKey.priority.desc()) .order_by(ModelApiKey.priority.desc())
.limit(1) .limit(1)
@@ -598,7 +598,7 @@ class SharedChatService:
# 获取多 Agent 配置 # 获取多 Agent 配置
multi_agent_config = self.db.query(MultiAgentConfig).filter( multi_agent_config = self.db.query(MultiAgentConfig).filter(
MultiAgentConfig.app_id == release.app_id, MultiAgentConfig.app_id == release.app_id,
MultiAgentConfig.is_active == True MultiAgentConfig.is_active.is_(True)
).first() ).first()
if not multi_agent_config: if not multi_agent_config:
@@ -695,7 +695,7 @@ class SharedChatService:
# 获取多 Agent 配置 # 获取多 Agent 配置
multi_agent_config = self.db.query(MultiAgentConfig).filter( multi_agent_config = self.db.query(MultiAgentConfig).filter(
MultiAgentConfig.app_id == release.app_id, MultiAgentConfig.app_id == release.app_id,
MultiAgentConfig.is_active == True MultiAgentConfig.is_active.is_(True)
).first() ).first()
if not multi_agent_config: if not multi_agent_config:

View File

@@ -761,7 +761,10 @@ class WorkflowService:
# 4. 获取工作空间 ID从 app 获取) # 4. 获取工作空间 ID从 app 获取)
from app.models import App from app.models import App
app = self.db.query(App).filter(App.id == app_id).first() app = self.db.query(App).filter(
App.id == app_id,
App.is_active.is_(True)
).first()
if not app: if not app:
raise BusinessException( raise BusinessException(
code=BizCode.NOT_FOUND, code=BizCode.NOT_FOUND,

View File

@@ -690,8 +690,11 @@ def write_total_memory_task(workspace_id: str) -> Dict[str, Any]:
try: try:
workspace_uuid = uuid.UUID(workspace_id) workspace_uuid = uuid.UUID(workspace_id)
# 1. 查询当前workspace下的所有app # 1. 查询当前workspace下的所有app(仅未删除的)
apps = db.query(App).filter(App.workspace_id == workspace_uuid).all() apps = db.query(App).filter(
App.workspace_id == workspace_uuid,
App.is_active.is_(True)
).all()
if not apps: if not apps:
# 如果没有app总量为0 # 如果没有app总量为0

View File

@@ -46,7 +46,8 @@ def import_all_models_from_package(package_name: str):
# Add the project root to sys.path if not already there # Add the project root to sys.path if not already there
# This is crucial for relative imports like 'app.db' to work # This is crucial for relative imports like 'app.db' to work
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) from pathlib import Path
project_root = str(Path(__file__).resolve().parent.parent)
if project_root not in sys.path: if project_root not in sys.path:
sys.path.insert(0, project_root) sys.path.insert(0, project_root)

View File

@@ -89,21 +89,25 @@ const UploadFiles = forwardRef<UploadFilesRef, UploadFilesProps>(({
// 处理文件移除 // 处理文件移除
const handleRemove = (file: UploadFile) => { const handleRemove = (file: UploadFile) => {
// 如果有自定义的 onRemove 回调,先执行它 // 显示确认弹窗
if (customOnRemove) {
const result = customOnRemove(file);
// 如果返回 false阻止移除
if (result === false) {
return false;
}
}
confirm({ confirm({
title: `${t('common.confirmRemoveFile')}`, title: `${t('common.confirmRemoveFile')}`,
okText: `${t('common.confirm')}`, okText: `${t('common.confirm')}`,
okType: 'danger', okType: 'danger',
cancelText: `${t('common.cancel')}`, cancelText: `${t('common.cancel')}`,
onOk: () => { onOk: async () => {
// 如果有自定义的 onRemove 回调,在确认后执行
if (customOnRemove) {
const result = customOnRemove(file);
// 等待 Promise 结果
const finalResult = result instanceof Promise ? await result : result;
// 如果返回 false阻止移除
if (finalResult === false) {
return;
}
}
// 移除文件
const newFileList = fileList.filter((item) => item.uid !== file.uid); const newFileList = fileList.filter((item) => item.uid !== file.uid);
setFileList(newFileList); setFileList(newFileList);
onChange?.(newFileList); onChange?.(newFileList);

View File

@@ -91,15 +91,11 @@ const VariableSelect: FC<VariableSelectProps> = ({
onChange={handleChange} onChange={handleChange}
showSearch showSearch
allowClear={allowClear} allowClear={allowClear}
optionFilterProp="value"
filterOption={(input, option) => { filterOption={(input, option) => {
if (input === '/') return true; if (input === '/') return true;
if (option?.options) { const value = 'value' in option! ? option.value as string : '';
return option.label?.toLowerCase().includes(input.toLowerCase()) || return value.toLowerCase().includes(input.toLowerCase());
option.options.some((opt: any) =>
opt.value.toLowerCase().includes(input.toLowerCase())
);
}
return option?.label?.toLowerCase().includes(input.toLowerCase()) ?? false;
}} }}
/> />
) )

View File

@@ -135,6 +135,78 @@ export const getCurrentNodeVariables = (nodeData: any, values: any): Suggestion[
return nodeData.type === 'var-aggregator' && !nodeData.config.group.defaultValue ? [] : list; return nodeData.type === 'var-aggregator' && !nodeData.config.group.defaultValue ? [] : list;
}; };
export const getChildNodeVariables = (
selectedNode: Node,
graphRef: React.MutableRefObject<Graph | undefined>
): Suggestion[] => {
const graph = graphRef.current;
if (!graph) return [];
const list: Suggestion[] = [];
const nodes = graph.getNodes();
const edges = graph.getEdges();
const keys = new Set<string>();
const childNodes = nodes.filter(node => node.getData()?.cycle === selectedNode.id);
const getConnectedNodes = (nodeId: string, visited = new Set<string>()): string[] => {
if (visited.has(nodeId)) return [];
visited.add(nodeId);
const prev = edges.filter(e => e.getTargetCellId() === nodeId).map(e => e.getSourceCellId());
return [...prev, ...prev.flatMap(id => getConnectedNodes(id, visited))];
};
const relevantIds = new Set<string>();
childNodes.forEach(child => {
relevantIds.add(child.id);
getConnectedNodes(child.id).forEach(id => relevantIds.add(id));
});
relevantIds.forEach(id => {
const node = nodes.find(n => n.id === id);
if (!node) return;
const nodeData = node.getData();
const nodeId = nodeData.id;
const { type } = nodeData;
if (type in NODE_VARIABLES) {
NODE_VARIABLES[type as keyof typeof NODE_VARIABLES].forEach(({ label, dataType, field }) => {
const varKey = `${nodeId}_${label}`;
if (!keys.has(varKey)) {
keys.add(varKey);
list.push({
key: varKey,
label,
type: 'variable',
dataType,
value: `${nodeId}.${field}`,
nodeData,
});
}
});
}
if (type === 'parameter-extractor') {
(nodeData.config?.params?.defaultValue || []).forEach((p: any) => {
if (p?.name && !keys.has(`${nodeId}_${p.name}`)) {
keys.add(`${nodeId}_${p.name}`);
list.push({
key: `${nodeId}_${p.name}`,
label: p.name,
type: 'variable',
dataType: p.type || 'string',
value: `${nodeId}.${p.name}`,
nodeData,
});
}
});
}
});
return list;
};
export const useVariableList = ( export const useVariableList = (
selectedNode: Node | null | undefined, selectedNode: Node | null | undefined,
graphRef: React.MutableRefObject<Graph | undefined>, graphRef: React.MutableRefObject<Graph | undefined>,
@@ -187,13 +259,13 @@ export const useVariableList = (
} else if (pd.type === 'iteration' && pd.config.input.defaultValue) { } else if (pd.type === 'iteration' && pd.config.input.defaultValue) {
let itemType = 'object'; let itemType = 'object';
const iv = list.find(v => `{{${v.value}}}` === pd.config.input.defaultValue); const iv = list.find(v => `{{${v.value}}}` === pd.config.input.defaultValue);
if (iv?.dataType.startsWith('array[')) itemType = iv.dataType.replace(/^array\[(.+)\]$/, '$1'); if (iv?.dataType.startsWith('array[')) {itemType = iv.dataType.replace(/^array\[(.+)\]$/, '$1');}
addVariable(list, keys, `${pid}_item`, 'item', itemType, `${pid}.item`, pd); addVariable(list, keys, `${pid}_item`, 'item', itemType, `${pid}.item`, pd);
addVariable(list, keys, `${pid}_index`, 'index', 'number', `${pid}.index`, pd); addVariable(list, keys, `${pid}_index`, 'index', 'number', `${pid}.index`, pd);
} else if (pd.type === 'iteration' && !pd.config.input.defaultValue) { } else if (pd.type === 'iteration' && !pd.config.input.defaultValue) {
let itemType = 'object'; let itemType = 'object';
const iv = list.find(v => `{{${v.value}}}` === pd.config.input.defaultValue); const iv = list.find(v => `{{${v.value}}}` === pd.config.input.defaultValue);
if (iv?.dataType.startsWith('array[')) itemType = iv.dataType.replace(/^array\[(.+)\]$/, '$1'); if (iv?.dataType.startsWith('array[')) {itemType = iv.dataType.replace(/^array\[(.+)\]$/, '$1');}
addVariable(list, keys, `${pid}_item`, 'item', 'string', `${pid}.item`, pd); addVariable(list, keys, `${pid}_item`, 'item', 'string', `${pid}.item`, pd);
addVariable(list, keys, `${pid}_index`, 'index', 'number', `${pid}.index`, pd); addVariable(list, keys, `${pid}_index`, 'index', 'number', `${pid}.index`, pd);
} }

View File

@@ -24,7 +24,7 @@ import AssignmentList from './AssignmentList'
import ToolConfig from './ToolConfig' import ToolConfig from './ToolConfig'
import MemoryConfig from './MemoryConfig' import MemoryConfig from './MemoryConfig'
import VariableList from './VariableList' import VariableList from './VariableList'
import { useVariableList, getCurrentNodeVariables } from './hooks/useVariableList' import { useVariableList, getCurrentNodeVariables, getChildNodeVariables } from './hooks/useVariableList'
import styles from './properties.module.css' import styles from './properties.module.css'
import Editor from "../Editor"; import Editor from "../Editor";
import RbSlider from './RbSlider' import RbSlider from './RbSlider'
@@ -290,141 +290,26 @@ const Properties: FC<PropertiesProps> = ({
let filteredList = addParentIterationVars(variableList).filter(variable => variable.dataType === 'string' || variable.dataType === 'number'); let filteredList = addParentIterationVars(variableList).filter(variable => variable.dataType === 'string' || variable.dataType === 'number');
return filteredList; return filteredList;
} }
if (nodeType === 'iteration' && key === 'output') { if (nodeType === 'iteration' && key === 'output' || nodeType === 'loop' && key === 'condition') {
let filteredList = variableList.filter(variable => variable.value.includes('sys.')); if (!selectedNode) return [];
// Add child node output variables for loop nodes let filteredList = nodeType === 'iteration'
if (selectedNode) { ? variableList.filter(variable => variable.value.includes('sys.'))
const graph = graphRef.current; : addParentIterationVars(variableList).filter(variable => variable.nodeData.type !== 'loop');
if (graph) {
const nodes = graph.getNodes(); const childVariables = getChildNodeVariables(selectedNode, graphRef);
const childNodes = nodes.filter(node => { const existingKeys = new Set(filteredList.map(v => v.key));
const nodeData = node.getData(); childVariables.forEach(v => {
return nodeData?.cycle === selectedNode.id; if (!existingKeys.has(v.key)) {
}); filteredList.push(v);
existingKeys.add(v.key);
// Add output variables from child nodes
childNodes.forEach(childNode => {
const childData = childNode.getData();
const childNodeId = childData.id;
// Add child node output variables based on their type
switch (childData.type) {
case 'llm':
case 'jinja-render':
case 'tool':
const outputKey = `${childNodeId}_output`;
const existingOutput = filteredList.find(v => v.key === outputKey);
if (!existingOutput) {
filteredList.push({
key: outputKey,
label: 'output',
type: 'variable',
dataType: 'string',
value: `${childNodeId}.output`,
nodeData: childData,
});
}
break;
case 'http-request':
const bodyKey = `${childNodeId}_body`;
const statusKey = `${childNodeId}_status_code`;
if (!filteredList.find(v => v.key === bodyKey)) {
filteredList.push({
key: bodyKey,
label: 'body',
type: 'variable',
dataType: 'string',
value: `${childNodeId}.body`,
nodeData: childData,
});
}
if (!filteredList.find(v => v.key === statusKey)) {
filteredList.push({
key: statusKey,
label: 'status_code',
type: 'variable',
dataType: 'number',
value: `${childNodeId}.status_code`,
nodeData: childData,
});
}
break;
}
});
} }
} });
return filteredList; return filteredList;
} }
if (nodeType === 'iteration') { if (nodeType === 'iteration') {
return variableList.filter(variable => variable.dataType.includes('array')); return variableList.filter(variable => variable.dataType.includes('array'));
} }
if (nodeType === 'loop' && key === 'condition') {
let filteredList = addParentIterationVars(variableList).filter(variable => variable.nodeData.type !== 'loop');
// Add child node output variables for loop nodes
if (selectedNode) {
const graph = graphRef.current;
if (graph) {
const nodes = graph.getNodes();
const childNodes = nodes.filter(node => {
const nodeData = node.getData();
return nodeData?.cycle === selectedNode.id;
});
// Add output variables from child nodes
childNodes.forEach(childNode => {
const childData = childNode.getData();
const childNodeId = childData.id;
// Add child node output variables based on their type
switch(childData.type) {
case 'llm':
case 'jinja-render':
case 'tool':
const outputKey = `${childNodeId}_output`;
const existingOutput = filteredList.find(v => v.key === outputKey);
if (!existingOutput) {
filteredList.push({
key: outputKey,
label: 'output',
type: 'variable',
dataType: 'string',
value: `${childNodeId}.output`,
nodeData: childData,
});
}
break;
case 'http-request':
const bodyKey = `${childNodeId}_body`;
const statusKey = `${childNodeId}_status_code`;
if (!filteredList.find(v => v.key === bodyKey)) {
filteredList.push({
key: bodyKey,
label: 'body',
type: 'variable',
dataType: 'string',
value: `${childNodeId}.body`,
nodeData: childData,
});
}
if (!filteredList.find(v => v.key === statusKey)) {
filteredList.push({
key: statusKey,
label: 'status_code',
type: 'variable',
dataType: 'number',
value: `${childNodeId}.status_code`,
nodeData: childData,
});
}
break;
}
});
}
}
return filteredList;
}
// For all other node types, add parent iteration variables if applicable // For all other node types, add parent iteration variables if applicable
let baseList = variableList; let baseList = variableList;