[changes]Ensure that there are sufficient labels for LLM to process, and control the number of label returns.
This commit is contained in:
@@ -142,11 +142,11 @@ async def get_raw_tags_from_db(
|
||||
async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool = False) -> List[Tuple[str, int]]:
|
||||
"""
|
||||
获取原始标签,然后使用LLM进行筛选,返回最终的热门标签列表。
|
||||
查询更多的标签(limit=10)给LLM提供更丰富的上下文进行筛选。
|
||||
查询更多的标签(40条)给LLM提供更丰富的上下文进行筛选,但最终返回数量由limit参数控制。
|
||||
|
||||
Args:
|
||||
end_user_id: 必需参数。如果by_user=False,则为end_user_id;如果by_user=True,则为user_id
|
||||
limit: 返回的标签数量限制
|
||||
limit: 最终返回的标签数量限制(默认10)
|
||||
by_user: 是否按user_id查询(默认False,按end_user_id查询)
|
||||
|
||||
Raises:
|
||||
@@ -161,8 +161,9 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool =
|
||||
# 使用项目的Neo4jConnector
|
||||
connector = Neo4jConnector()
|
||||
try:
|
||||
# 1. 从数据库获取原始排名靠前的标签
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, limit, by_user=by_user)
|
||||
# 1. 从数据库获取原始排名靠前的标签(查询40条给LLM提供更丰富的上下文)
|
||||
query_limit = 40
|
||||
raw_tags_with_freq = await get_raw_tags_from_db(connector, end_user_id, query_limit, by_user=by_user)
|
||||
if not raw_tags_with_freq:
|
||||
return []
|
||||
|
||||
@@ -177,7 +178,8 @@ async def get_hot_memory_tags(end_user_id: str, limit: int = 10, by_user: bool =
|
||||
if tag in meaningful_tag_names:
|
||||
final_tags.append((tag, freq))
|
||||
|
||||
return final_tags
|
||||
# 4. 限制返回的标签数量
|
||||
return final_tags[:limit]
|
||||
finally:
|
||||
# 确保关闭连接
|
||||
await connector.close()
|
||||
|
||||
Reference in New Issue
Block a user