检查项目,修复group_id的遗留问题
This commit is contained in:
@@ -1009,7 +1009,7 @@ async def run_longmemeval_test(
|
||||
kw_fallback = await search_graph(
|
||||
connector=connector,
|
||||
q=question,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=max(search_limit, 5),
|
||||
)
|
||||
fb_dialogs = kw_fallback.get("dialogues", []) or []
|
||||
@@ -1223,7 +1223,7 @@ async def run_longmemeval_test(
|
||||
"count_avg": statistics.mean(per_query_context_counts) if per_query_context_counts else 0.0,
|
||||
},
|
||||
"params": {
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"search_limit": search_limit,
|
||||
"context_char_budget": context_char_budget,
|
||||
"search_type": search_type,
|
||||
|
||||
@@ -876,7 +876,7 @@ async def run_longmemeval_test(
|
||||
opt_res = await search_graph(
|
||||
connector=connector,
|
||||
q=str(opt),
|
||||
end_user_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=max(3, search_limit // 2),
|
||||
)
|
||||
if isinstance(opt_res, dict):
|
||||
@@ -971,7 +971,7 @@ async def run_longmemeval_test(
|
||||
kw_fallback = await search_graph(
|
||||
connector=connector,
|
||||
q=question,
|
||||
group_id=group_id,
|
||||
end_user_id=end_user_id,
|
||||
limit=max(search_limit, 5),
|
||||
)
|
||||
fb_dialogs = kw_fallback.get("dialogues", []) or []
|
||||
@@ -1199,7 +1199,7 @@ async def run_longmemeval_test(
|
||||
"count_avg": statistics.mean(per_query_context_counts) if per_query_context_counts else 0.0,
|
||||
},
|
||||
"params": {
|
||||
"group_id": group_id,
|
||||
"end_user_id": end_user_id,
|
||||
"search_limit": search_limit,
|
||||
"context_char_budget": context_char_budget,
|
||||
"search_type": search_type,
|
||||
@@ -1278,7 +1278,7 @@ def main():
|
||||
result = asyncio.run(
|
||||
run_longmemeval_test(
|
||||
sample_size=sample_size,
|
||||
group_id=args.group_id,
|
||||
end_user_id=args.end_user_id,
|
||||
search_limit=args.search_limit,
|
||||
context_char_budget=args.context_char_budget,
|
||||
llm_temperature=args.llm_temperature,
|
||||
|
||||
@@ -17,7 +17,7 @@ class MemoryConfig(Base):
|
||||
|
||||
# 组织信息
|
||||
workspace_id = Column(UUID(as_uuid=True), nullable=True, comment="工作空间ID")
|
||||
group_id = Column(String, nullable=True, comment="组ID")
|
||||
end_user_id = Column(String, nullable=True, comment="组ID")
|
||||
user_id = Column(String, nullable=True, comment="用户ID")
|
||||
apply_id = Column(String, nullable=True, comment="应用ID")
|
||||
|
||||
|
||||
@@ -217,14 +217,14 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
||||
|
||||
async def find_by_content_keywords(
|
||||
self,
|
||||
group_id: str,
|
||||
end_user_id: str,
|
||||
keywords: List[str],
|
||||
limit: int = 100
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Query memory summaries by content keywords
|
||||
|
||||
Args:
|
||||
group_id: Group ID to filter by
|
||||
end_user_id: Group ID to filter by
|
||||
keywords: List of keywords to search for in content
|
||||
limit: Maximum number of results to return
|
||||
|
||||
@@ -233,7 +233,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
||||
"""
|
||||
# Build keyword search conditions
|
||||
keyword_conditions = []
|
||||
params = {"end_user_id": group_id, "limit": limit}
|
||||
params = {"end_user_id": end_user_id, "limit": limit}
|
||||
|
||||
for i, keyword in enumerate(keywords):
|
||||
keyword_conditions.append(f"toLower(n.content) CONTAINS toLower($keyword_{i})")
|
||||
@@ -257,7 +257,7 @@ class MemorySummaryRepository(BaseNeo4jRepository):
|
||||
"""Get count of memory summaries for a group
|
||||
|
||||
Args:
|
||||
group_id: Group ID to count summaries for
|
||||
end_user_id: Group ID to count summaries for
|
||||
|
||||
Returns:
|
||||
int: Number of memory summaries
|
||||
|
||||
Reference in New Issue
Block a user